merge: sync upstream/main before PR

This commit is contained in:
Wang Lvyuan
2026-03-19 16:37:28 +08:00
107 changed files with 2973 additions and 341 deletions

View File

@@ -110,7 +110,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
groupHandler := admin.NewGroupHandler(adminService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
openAIOAuthClient := repository.NewOpenAIOAuthClient()
@@ -143,6 +142,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
rpmCache := repository.NewRPMCache(redisClient)
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
dataManagementService := service.NewDataManagementService()

View File

@@ -716,6 +716,7 @@ var (
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "request_id", Type: field.TypeString, Size: 64},
{Name: "model", Type: field.TypeString, Size: 100},
{Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100},
{Name: "input_tokens", Type: field.TypeInt, Default: 0},
{Name: "output_tokens", Type: field.TypeInt, Default: 0},
{Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0},
@@ -755,31 +756,31 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "usage_logs_api_keys_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[28]},
Columns: []*schema.Column{UsageLogsColumns[29]},
RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_accounts_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[29]},
Columns: []*schema.Column{UsageLogsColumns[30]},
RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_groups_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[30]},
Columns: []*schema.Column{UsageLogsColumns[31]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "usage_logs_users_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[31]},
Columns: []*schema.Column{UsageLogsColumns[32]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_user_subscriptions_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[32]},
Columns: []*schema.Column{UsageLogsColumns[33]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull,
},
@@ -788,32 +789,32 @@ var (
{
Name: "usagelog_user_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[31]},
Columns: []*schema.Column{UsageLogsColumns[32]},
},
{
Name: "usagelog_api_key_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28]},
Columns: []*schema.Column{UsageLogsColumns[29]},
},
{
Name: "usagelog_account_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[29]},
Columns: []*schema.Column{UsageLogsColumns[30]},
},
{
Name: "usagelog_group_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[30]},
Columns: []*schema.Column{UsageLogsColumns[31]},
},
{
Name: "usagelog_subscription_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[32]},
Columns: []*schema.Column{UsageLogsColumns[33]},
},
{
Name: "usagelog_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[28]},
},
{
Name: "usagelog_model",
@@ -828,17 +829,17 @@ var (
{
Name: "usagelog_user_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[28]},
},
{
Name: "usagelog_api_key_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[28]},
},
{
Name: "usagelog_group_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[28]},
},
},
}

View File

@@ -18239,6 +18239,7 @@ type UsageLogMutation struct {
id *int64
request_id *string
model *string
upstream_model *string
input_tokens *int
addinput_tokens *int
output_tokens *int
@@ -18576,6 +18577,55 @@ func (m *UsageLogMutation) ResetModel() {
m.model = nil
}
// SetUpstreamModel sets the "upstream_model" field.
func (m *UsageLogMutation) SetUpstreamModel(s string) {
m.upstream_model = &s
}
// UpstreamModel returns the value of the "upstream_model" field in the mutation.
func (m *UsageLogMutation) UpstreamModel() (r string, exists bool) {
v := m.upstream_model
if v == nil {
return
}
return *v, true
}
// OldUpstreamModel returns the old "upstream_model" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldUpstreamModel(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldUpstreamModel is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldUpstreamModel requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldUpstreamModel: %w", err)
}
return oldValue.UpstreamModel, nil
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (m *UsageLogMutation) ClearUpstreamModel() {
m.upstream_model = nil
m.clearedFields[usagelog.FieldUpstreamModel] = struct{}{}
}
// UpstreamModelCleared returns if the "upstream_model" field was cleared in this mutation.
func (m *UsageLogMutation) UpstreamModelCleared() bool {
_, ok := m.clearedFields[usagelog.FieldUpstreamModel]
return ok
}
// ResetUpstreamModel resets all changes to the "upstream_model" field.
func (m *UsageLogMutation) ResetUpstreamModel() {
m.upstream_model = nil
delete(m.clearedFields, usagelog.FieldUpstreamModel)
}
// SetGroupID sets the "group_id" field.
func (m *UsageLogMutation) SetGroupID(i int64) {
m.group = &i
@@ -20197,7 +20247,7 @@ func (m *UsageLogMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UsageLogMutation) Fields() []string {
fields := make([]string, 0, 32)
fields := make([]string, 0, 33)
if m.user != nil {
fields = append(fields, usagelog.FieldUserID)
}
@@ -20213,6 +20263,9 @@ func (m *UsageLogMutation) Fields() []string {
if m.model != nil {
fields = append(fields, usagelog.FieldModel)
}
if m.upstream_model != nil {
fields = append(fields, usagelog.FieldUpstreamModel)
}
if m.group != nil {
fields = append(fields, usagelog.FieldGroupID)
}
@@ -20312,6 +20365,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.RequestID()
case usagelog.FieldModel:
return m.Model()
case usagelog.FieldUpstreamModel:
return m.UpstreamModel()
case usagelog.FieldGroupID:
return m.GroupID()
case usagelog.FieldSubscriptionID:
@@ -20385,6 +20440,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldRequestID(ctx)
case usagelog.FieldModel:
return m.OldModel(ctx)
case usagelog.FieldUpstreamModel:
return m.OldUpstreamModel(ctx)
case usagelog.FieldGroupID:
return m.OldGroupID(ctx)
case usagelog.FieldSubscriptionID:
@@ -20483,6 +20540,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
}
m.SetModel(v)
return nil
case usagelog.FieldUpstreamModel:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetUpstreamModel(v)
return nil
case usagelog.FieldGroupID:
v, ok := value.(int64)
if !ok {
@@ -20921,6 +20985,9 @@ func (m *UsageLogMutation) AddField(name string, value ent.Value) error {
// mutation.
func (m *UsageLogMutation) ClearedFields() []string {
var fields []string
if m.FieldCleared(usagelog.FieldUpstreamModel) {
fields = append(fields, usagelog.FieldUpstreamModel)
}
if m.FieldCleared(usagelog.FieldGroupID) {
fields = append(fields, usagelog.FieldGroupID)
}
@@ -20962,6 +21029,9 @@ func (m *UsageLogMutation) FieldCleared(name string) bool {
// error if the field is not defined in the schema.
func (m *UsageLogMutation) ClearField(name string) error {
switch name {
case usagelog.FieldUpstreamModel:
m.ClearUpstreamModel()
return nil
case usagelog.FieldGroupID:
m.ClearGroupID()
return nil
@@ -21012,6 +21082,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldModel:
m.ResetModel()
return nil
case usagelog.FieldUpstreamModel:
m.ResetUpstreamModel()
return nil
case usagelog.FieldGroupID:
m.ResetGroupID()
return nil

View File

@@ -821,92 +821,96 @@ func init() {
return nil
}
}()
// usagelogDescUpstreamModel is the schema descriptor for upstream_model field.
usagelogDescUpstreamModel := usagelogFields[5].Descriptor()
// usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error)
// usagelogDescInputTokens is the schema descriptor for input_tokens field.
usagelogDescInputTokens := usagelogFields[7].Descriptor()
usagelogDescInputTokens := usagelogFields[8].Descriptor()
// usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field.
usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int)
// usagelogDescOutputTokens is the schema descriptor for output_tokens field.
usagelogDescOutputTokens := usagelogFields[8].Descriptor()
usagelogDescOutputTokens := usagelogFields[9].Descriptor()
// usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field.
usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int)
// usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field.
usagelogDescCacheCreationTokens := usagelogFields[9].Descriptor()
usagelogDescCacheCreationTokens := usagelogFields[10].Descriptor()
// usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field.
usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int)
// usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field.
usagelogDescCacheReadTokens := usagelogFields[10].Descriptor()
usagelogDescCacheReadTokens := usagelogFields[11].Descriptor()
// usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field.
usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int)
// usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field.
usagelogDescCacheCreation5mTokens := usagelogFields[11].Descriptor()
usagelogDescCacheCreation5mTokens := usagelogFields[12].Descriptor()
// usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field.
usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int)
// usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field.
usagelogDescCacheCreation1hTokens := usagelogFields[12].Descriptor()
usagelogDescCacheCreation1hTokens := usagelogFields[13].Descriptor()
// usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field.
usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int)
// usagelogDescInputCost is the schema descriptor for input_cost field.
usagelogDescInputCost := usagelogFields[13].Descriptor()
usagelogDescInputCost := usagelogFields[14].Descriptor()
// usagelog.DefaultInputCost holds the default value on creation for the input_cost field.
usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64)
// usagelogDescOutputCost is the schema descriptor for output_cost field.
usagelogDescOutputCost := usagelogFields[14].Descriptor()
usagelogDescOutputCost := usagelogFields[15].Descriptor()
// usagelog.DefaultOutputCost holds the default value on creation for the output_cost field.
usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64)
// usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field.
usagelogDescCacheCreationCost := usagelogFields[15].Descriptor()
usagelogDescCacheCreationCost := usagelogFields[16].Descriptor()
// usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field.
usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64)
// usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field.
usagelogDescCacheReadCost := usagelogFields[16].Descriptor()
usagelogDescCacheReadCost := usagelogFields[17].Descriptor()
// usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field.
usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64)
// usagelogDescTotalCost is the schema descriptor for total_cost field.
usagelogDescTotalCost := usagelogFields[17].Descriptor()
usagelogDescTotalCost := usagelogFields[18].Descriptor()
// usagelog.DefaultTotalCost holds the default value on creation for the total_cost field.
usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64)
// usagelogDescActualCost is the schema descriptor for actual_cost field.
usagelogDescActualCost := usagelogFields[18].Descriptor()
usagelogDescActualCost := usagelogFields[19].Descriptor()
// usagelog.DefaultActualCost holds the default value on creation for the actual_cost field.
usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64)
// usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field.
usagelogDescRateMultiplier := usagelogFields[19].Descriptor()
usagelogDescRateMultiplier := usagelogFields[20].Descriptor()
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
// usagelogDescBillingType is the schema descriptor for billing_type field.
usagelogDescBillingType := usagelogFields[21].Descriptor()
usagelogDescBillingType := usagelogFields[22].Descriptor()
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
// usagelogDescStream is the schema descriptor for stream field.
usagelogDescStream := usagelogFields[22].Descriptor()
usagelogDescStream := usagelogFields[23].Descriptor()
// usagelog.DefaultStream holds the default value on creation for the stream field.
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
// usagelogDescUserAgent is the schema descriptor for user_agent field.
usagelogDescUserAgent := usagelogFields[25].Descriptor()
usagelogDescUserAgent := usagelogFields[26].Descriptor()
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
// usagelogDescIPAddress is the schema descriptor for ip_address field.
usagelogDescIPAddress := usagelogFields[26].Descriptor()
usagelogDescIPAddress := usagelogFields[27].Descriptor()
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
// usagelogDescImageCount is the schema descriptor for image_count field.
usagelogDescImageCount := usagelogFields[27].Descriptor()
usagelogDescImageCount := usagelogFields[28].Descriptor()
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
// usagelogDescImageSize is the schema descriptor for image_size field.
usagelogDescImageSize := usagelogFields[28].Descriptor()
usagelogDescImageSize := usagelogFields[29].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
// usagelogDescMediaType is the schema descriptor for media_type field.
usagelogDescMediaType := usagelogFields[29].Descriptor()
usagelogDescMediaType := usagelogFields[30].Descriptor()
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
usagelogDescCacheTTLOverridden := usagelogFields[30].Descriptor()
usagelogDescCacheTTLOverridden := usagelogFields[31].Descriptor()
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
// usagelogDescCreatedAt is the schema descriptor for created_at field.
usagelogDescCreatedAt := usagelogFields[31].Descriptor()
usagelogDescCreatedAt := usagelogFields[32].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin()

View File

@@ -41,6 +41,12 @@ func (UsageLog) Fields() []ent.Field {
field.String("model").
MaxLen(100).
NotEmpty(),
// UpstreamModel stores the actual upstream model name when model mapping
// is applied. NULL means no mapping — the requested model was used as-is.
field.String("upstream_model").
MaxLen(100).
Optional().
Nillable(),
field.Int64("group_id").
Optional().
Nillable(),

View File

@@ -32,6 +32,8 @@ type UsageLog struct {
RequestID string `json:"request_id,omitempty"`
// Model holds the value of the "model" field.
Model string `json:"model,omitempty"`
// UpstreamModel holds the value of the "upstream_model" field.
UpstreamModel *string `json:"upstream_model,omitempty"`
// GroupID holds the value of the "group_id" field.
GroupID *int64 `json:"group_id,omitempty"`
// SubscriptionID holds the value of the "subscription_id" field.
@@ -175,7 +177,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64)
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
values[i] = new(sql.NullString)
case usagelog.FieldCreatedAt:
values[i] = new(sql.NullTime)
@@ -230,6 +232,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.Model = value.String
}
case usagelog.FieldUpstreamModel:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field upstream_model", values[i])
} else if value.Valid {
_m.UpstreamModel = new(string)
*_m.UpstreamModel = value.String
}
case usagelog.FieldGroupID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field group_id", values[i])
@@ -477,6 +486,11 @@ func (_m *UsageLog) String() string {
builder.WriteString("model=")
builder.WriteString(_m.Model)
builder.WriteString(", ")
if v := _m.UpstreamModel; v != nil {
builder.WriteString("upstream_model=")
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.GroupID; v != nil {
builder.WriteString("group_id=")
builder.WriteString(fmt.Sprintf("%v", *v))

View File

@@ -24,6 +24,8 @@ const (
FieldRequestID = "request_id"
// FieldModel holds the string denoting the model field in the database.
FieldModel = "model"
// FieldUpstreamModel holds the string denoting the upstream_model field in the database.
FieldUpstreamModel = "upstream_model"
// FieldGroupID holds the string denoting the group_id field in the database.
FieldGroupID = "group_id"
// FieldSubscriptionID holds the string denoting the subscription_id field in the database.
@@ -135,6 +137,7 @@ var Columns = []string{
FieldAccountID,
FieldRequestID,
FieldModel,
FieldUpstreamModel,
FieldGroupID,
FieldSubscriptionID,
FieldInputTokens,
@@ -179,6 +182,8 @@ var (
RequestIDValidator func(string) error
// ModelValidator is a validator for the "model" field. It is called by the builders before save.
ModelValidator func(string) error
// UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
UpstreamModelValidator func(string) error
// DefaultInputTokens holds the default value on creation for the "input_tokens" field.
DefaultInputTokens int
// DefaultOutputTokens holds the default value on creation for the "output_tokens" field.
@@ -258,6 +263,11 @@ func ByModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldModel, opts...).ToFunc()
}
// ByUpstreamModel orders the results by the upstream_model field.
func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc()
}
// ByGroupID orders the results by the group_id field.
func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldGroupID, opts...).ToFunc()

View File

@@ -80,6 +80,11 @@ func Model(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldModel, v))
}
// UpstreamModel applies equality check predicate on the "upstream_model" field. It's identical to UpstreamModelEQ.
func UpstreamModel(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
}
// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
func GroupID(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
@@ -405,6 +410,81 @@ func ModelContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v))
}
// UpstreamModelEQ applies the EQ predicate on the "upstream_model" field.
func UpstreamModelEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
}
// UpstreamModelNEQ applies the NEQ predicate on the "upstream_model" field.
func UpstreamModelNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldUpstreamModel, v))
}
// UpstreamModelIn applies the In predicate on the "upstream_model" field.
func UpstreamModelIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldUpstreamModel, vs...))
}
// UpstreamModelNotIn applies the NotIn predicate on the "upstream_model" field.
func UpstreamModelNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldUpstreamModel, vs...))
}
// UpstreamModelGT applies the GT predicate on the "upstream_model" field.
func UpstreamModelGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldUpstreamModel, v))
}
// UpstreamModelGTE applies the GTE predicate on the "upstream_model" field.
func UpstreamModelGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldUpstreamModel, v))
}
// UpstreamModelLT applies the LT predicate on the "upstream_model" field.
func UpstreamModelLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldUpstreamModel, v))
}
// UpstreamModelLTE applies the LTE predicate on the "upstream_model" field.
func UpstreamModelLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldUpstreamModel, v))
}
// UpstreamModelContains applies the Contains predicate on the "upstream_model" field.
func UpstreamModelContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldUpstreamModel, v))
}
// UpstreamModelHasPrefix applies the HasPrefix predicate on the "upstream_model" field.
func UpstreamModelHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldUpstreamModel, v))
}
// UpstreamModelHasSuffix applies the HasSuffix predicate on the "upstream_model" field.
func UpstreamModelHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldUpstreamModel, v))
}
// UpstreamModelIsNil applies the IsNil predicate on the "upstream_model" field.
func UpstreamModelIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldUpstreamModel))
}
// UpstreamModelNotNil applies the NotNil predicate on the "upstream_model" field.
func UpstreamModelNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldUpstreamModel))
}
// UpstreamModelEqualFold applies the EqualFold predicate on the "upstream_model" field.
func UpstreamModelEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldUpstreamModel, v))
}
// UpstreamModelContainsFold applies the ContainsFold predicate on the "upstream_model" field.
func UpstreamModelContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldUpstreamModel, v))
}
// GroupIDEQ applies the EQ predicate on the "group_id" field.
func GroupIDEQ(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))

View File

@@ -57,6 +57,20 @@ func (_c *UsageLogCreate) SetModel(v string) *UsageLogCreate {
return _c
}
// SetUpstreamModel sets the "upstream_model" field.
func (_c *UsageLogCreate) SetUpstreamModel(v string) *UsageLogCreate {
_c.mutation.SetUpstreamModel(v)
return _c
}
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableUpstreamModel(v *string) *UsageLogCreate {
if v != nil {
_c.SetUpstreamModel(*v)
}
return _c
}
// SetGroupID sets the "group_id" field.
func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate {
_c.mutation.SetGroupID(v)
@@ -596,6 +610,11 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
}
}
if v, ok := _c.mutation.UpstreamModel(); ok {
if err := usagelog.UpstreamModelValidator(v); err != nil {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
}
}
if _, ok := _c.mutation.InputTokens(); !ok {
return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)}
}
@@ -714,6 +733,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
_node.Model = value
}
if value, ok := _c.mutation.UpstreamModel(); ok {
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
_node.UpstreamModel = &value
}
if value, ok := _c.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
_node.InputTokens = value
@@ -1011,6 +1034,24 @@ func (u *UsageLogUpsert) UpdateModel() *UsageLogUpsert {
return u
}
// SetUpstreamModel sets the "upstream_model" field.
func (u *UsageLogUpsert) SetUpstreamModel(v string) *UsageLogUpsert {
u.Set(usagelog.FieldUpstreamModel, v)
return u
}
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateUpstreamModel() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldUpstreamModel)
return u
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (u *UsageLogUpsert) ClearUpstreamModel() *UsageLogUpsert {
u.SetNull(usagelog.FieldUpstreamModel)
return u
}
// SetGroupID sets the "group_id" field.
func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert {
u.Set(usagelog.FieldGroupID, v)
@@ -1600,6 +1641,27 @@ func (u *UsageLogUpsertOne) UpdateModel() *UsageLogUpsertOne {
})
}
// SetUpstreamModel sets the "upstream_model" field.
func (u *UsageLogUpsertOne) SetUpstreamModel(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetUpstreamModel(v)
})
}
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateUpstreamModel() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateUpstreamModel()
})
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (u *UsageLogUpsertOne) ClearUpstreamModel() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearUpstreamModel()
})
}
// SetGroupID sets the "group_id" field.
func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
@@ -2434,6 +2496,27 @@ func (u *UsageLogUpsertBulk) UpdateModel() *UsageLogUpsertBulk {
})
}
// SetUpstreamModel sets the "upstream_model" field.
func (u *UsageLogUpsertBulk) SetUpstreamModel(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetUpstreamModel(v)
})
}
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateUpstreamModel() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateUpstreamModel()
})
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (u *UsageLogUpsertBulk) ClearUpstreamModel() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearUpstreamModel()
})
}
// SetGroupID sets the "group_id" field.
func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {

View File

@@ -102,6 +102,26 @@ func (_u *UsageLogUpdate) SetNillableModel(v *string) *UsageLogUpdate {
return _u
}
// SetUpstreamModel sets the "upstream_model" field.
func (_u *UsageLogUpdate) SetUpstreamModel(v string) *UsageLogUpdate {
_u.mutation.SetUpstreamModel(v)
return _u
}
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableUpstreamModel(v *string) *UsageLogUpdate {
if v != nil {
_u.SetUpstreamModel(*v)
}
return _u
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (_u *UsageLogUpdate) ClearUpstreamModel() *UsageLogUpdate {
_u.mutation.ClearUpstreamModel()
return _u
}
// SetGroupID sets the "group_id" field.
func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate {
_u.mutation.SetGroupID(v)
@@ -745,6 +765,11 @@ func (_u *UsageLogUpdate) check() error {
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
}
}
if v, ok := _u.mutation.UpstreamModel(); ok {
if err := usagelog.UpstreamModelValidator(v); err != nil {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
}
}
if v, ok := _u.mutation.UserAgent(); ok {
if err := usagelog.UserAgentValidator(v); err != nil {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
@@ -795,6 +820,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.Model(); ok {
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
}
if value, ok := _u.mutation.UpstreamModel(); ok {
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
}
if _u.mutation.UpstreamModelCleared() {
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
}
if value, ok := _u.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
}
@@ -1177,6 +1208,26 @@ func (_u *UsageLogUpdateOne) SetNillableModel(v *string) *UsageLogUpdateOne {
return _u
}
// SetUpstreamModel sets the "upstream_model" field.
func (_u *UsageLogUpdateOne) SetUpstreamModel(v string) *UsageLogUpdateOne {
_u.mutation.SetUpstreamModel(v)
return _u
}
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableUpstreamModel(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetUpstreamModel(*v)
}
return _u
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (_u *UsageLogUpdateOne) ClearUpstreamModel() *UsageLogUpdateOne {
_u.mutation.ClearUpstreamModel()
return _u
}
// SetGroupID sets the "group_id" field.
func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne {
_u.mutation.SetGroupID(v)
@@ -1833,6 +1884,11 @@ func (_u *UsageLogUpdateOne) check() error {
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
}
}
if v, ok := _u.mutation.UpstreamModel(); ok {
if err := usagelog.UpstreamModelValidator(v); err != nil {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
}
}
if v, ok := _u.mutation.UserAgent(); ok {
if err := usagelog.UserAgentValidator(v); err != nil {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
@@ -1900,6 +1956,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if value, ok := _u.mutation.Model(); ok {
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
}
if value, ok := _u.mutation.UpstreamModel(); ok {
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
}
if _u.mutation.UpstreamModelCleared() {
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
}
if value, ok := _u.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
}

View File

@@ -22,8 +22,6 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA=
github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
@@ -60,8 +58,6 @@ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 h1:edCcNp9eGIUDUCrzoCu1jWA
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb83BbyggcUBVksN7c=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
@@ -98,10 +94,6 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
@@ -238,8 +230,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@@ -273,8 +263,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@@ -326,8 +314,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=

View File

@@ -82,8 +82,8 @@ var DefaultAntigravityModelMapping = map[string]string{
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking", // 迁移旧模型
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
// Claude Haiku → Sonnet无 Haiku 支持)
"claude-haiku-4-5": "claude-sonnet-4-5",
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
"claude-haiku-4-5": "claude-sonnet-4-6",
"claude-haiku-4-5-20251001": "claude-sonnet-4-6",
// Gemini 2.5 白名单
"gemini-2.5-flash": "gemini-2.5-flash",
"gemini-2.5-flash-image": "gemini-2.5-flash-image",

View File

@@ -17,7 +17,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
adminSvc := newStubAdminService()
userHandler := NewUserHandler(adminSvc, nil)
groupHandler := NewGroupHandler(adminSvc)
groupHandler := NewGroupHandler(adminSvc, nil, nil)
proxyHandler := NewProxyHandler(adminSvc)
redeemHandler := NewRedeemHandler(adminSvc, nil)

View File

@@ -273,6 +273,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
// Parse optional filter params
var userID, apiKeyID, accountID, groupID int64
modelSource := usagestats.ModelSourceRequested
var requestType *int16
var stream *bool
var billingType *int8
@@ -297,6 +298,13 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
groupID = id
}
}
if rawModelSource := strings.TrimSpace(c.Query("model_source")); rawModelSource != "" {
if !usagestats.IsValidModelSource(rawModelSource) {
response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping")
return
}
modelSource = rawModelSource
}
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
@@ -323,7 +331,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
}
}
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, modelSource, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
@@ -619,6 +627,12 @@ func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) {
}
}
dim.Model = c.Query("model")
rawModelSource := strings.TrimSpace(c.DefaultQuery("model_source", usagestats.ModelSourceRequested))
if !usagestats.IsValidModelSource(rawModelSource) {
response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping")
return
}
dim.ModelType = rawModelSource
dim.Endpoint = c.Query("endpoint")
dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound")

View File

@@ -149,6 +149,28 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) {
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardModelStatsInvalidModelSource(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=invalid", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardModelStatsValidModelSource(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=upstream", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
repo := &dashboardUsageRepoCapture{

View File

@@ -73,9 +73,35 @@ func TestGetUserBreakdown_ModelFilter(t *testing.T) {
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, "claude-opus-4-6", repo.capturedDim.Model)
require.Equal(t, usagestats.ModelSourceRequested, repo.capturedDim.ModelType)
require.Equal(t, int64(0), repo.capturedDim.GroupID)
}
func TestGetUserBreakdown_ModelSourceFilter(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6&model_source=upstream", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, usagestats.ModelSourceUpstream, repo.capturedDim.ModelType)
}
func TestGetUserBreakdown_InvalidModelSource(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model_source=foobar", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code)
}
func TestGetUserBreakdown_EndpointFilter(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)

View File

@@ -38,6 +38,7 @@ type dashboardModelGroupCacheKey struct {
APIKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
GroupID int64 `json:"group_id"`
ModelSource string `json:"model_source,omitempty"`
RequestType *int16 `json:"request_type"`
Stream *bool `json:"stream"`
BillingType *int8 `json:"billing_type"`
@@ -111,6 +112,7 @@ func (h *DashboardHandler) getModelStatsCached(
ctx context.Context,
startTime, endTime time.Time,
userID, apiKeyID, accountID, groupID int64,
modelSource string,
requestType *int16,
stream *bool,
billingType *int8,
@@ -122,12 +124,13 @@ func (h *DashboardHandler) getModelStatsCached(
APIKeyID: apiKeyID,
AccountID: accountID,
GroupID: groupID,
ModelSource: usagestats.NormalizeModelSource(modelSource),
RequestType: requestType,
Stream: stream,
BillingType: billingType,
})
entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) {
return h.dashboardService.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
return h.dashboardService.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, modelSource)
})
if err != nil {
return nil, hit, err

View File

@@ -200,6 +200,7 @@ func (h *DashboardHandler) buildSnapshotV2Response(
filters.APIKeyID,
filters.AccountID,
filters.GroupID,
usagestats.ModelSourceRequested,
filters.RequestType,
filters.Stream,
filters.BillingType,

View File

@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -16,7 +17,9 @@ import (
// GroupHandler handles admin group management
type GroupHandler struct {
adminService service.AdminService
adminService service.AdminService
dashboardService *service.DashboardService
groupCapacityService *service.GroupCapacityService
}
type optionalLimitField struct {
@@ -69,9 +72,11 @@ func (f optionalLimitField) ToServiceInput() *float64 {
}
// NewGroupHandler creates a new admin group handler
func NewGroupHandler(adminService service.AdminService) *GroupHandler {
func NewGroupHandler(adminService service.AdminService, dashboardService *service.DashboardService, groupCapacityService *service.GroupCapacityService) *GroupHandler {
return &GroupHandler{
adminService: adminService,
adminService: adminService,
dashboardService: dashboardService,
groupCapacityService: groupCapacityService,
}
}
@@ -363,6 +368,33 @@ func (h *GroupHandler) GetStats(c *gin.Context) {
_ = groupID // TODO: implement actual stats
}
// GetUsageSummary returns today's and cumulative cost for all groups.
// GET /api/v1/admin/groups/usage-summary?timezone=Asia/Shanghai
func (h *GroupHandler) GetUsageSummary(c *gin.Context) {
userTZ := c.Query("timezone")
now := timezone.NowInUserLocation(userTZ)
todayStart := timezone.StartOfDayInUserLocation(now, userTZ)
results, err := h.dashboardService.GetGroupUsageSummary(c.Request.Context(), todayStart)
if err != nil {
response.Error(c, 500, "Failed to get group usage summary")
return
}
response.Success(c, results)
}
// GetCapacitySummary returns aggregated capacity (concurrency/sessions/RPM) for all active groups.
// GET /api/v1/admin/groups/capacity-summary
func (h *GroupHandler) GetCapacitySummary(c *gin.Context) {
results, err := h.groupCapacityService.GetAllGroupCapacity(c.Request.Context())
if err != nil {
response.Error(c, 500, "Failed to get group capacity summary")
return
}
response.Success(c, results)
}
// GetGroupAPIKeys handles getting API keys in a group
// GET /api/v1/admin/groups/:id/api-keys
func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {

View File

@@ -977,6 +977,58 @@ func (h *SettingHandler) DeleteAdminAPIKey(c *gin.Context) {
response.Success(c, gin.H{"message": "Admin API key deleted"})
}
// GetOverloadCooldownSettings 获取529过载冷却配置
// GET /api/v1/admin/settings/overload-cooldown
func (h *SettingHandler) GetOverloadCooldownSettings(c *gin.Context) {
settings, err := h.settingService.GetOverloadCooldownSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.OverloadCooldownSettings{
Enabled: settings.Enabled,
CooldownMinutes: settings.CooldownMinutes,
})
}
// UpdateOverloadCooldownSettingsRequest 更新529过载冷却配置请求
type UpdateOverloadCooldownSettingsRequest struct {
Enabled bool `json:"enabled"`
CooldownMinutes int `json:"cooldown_minutes"`
}
// UpdateOverloadCooldownSettings 更新529过载冷却配置
// PUT /api/v1/admin/settings/overload-cooldown
func (h *SettingHandler) UpdateOverloadCooldownSettings(c *gin.Context) {
var req UpdateOverloadCooldownSettingsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
settings := &service.OverloadCooldownSettings{
Enabled: req.Enabled,
CooldownMinutes: req.CooldownMinutes,
}
if err := h.settingService.SetOverloadCooldownSettings(c.Request.Context(), settings); err != nil {
response.BadRequest(c, err.Error())
return
}
updatedSettings, err := h.settingService.GetOverloadCooldownSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.OverloadCooldownSettings{
Enabled: updatedSettings.Enabled,
CooldownMinutes: updatedSettings.CooldownMinutes,
})
}
// GetStreamTimeoutSettings 获取流超时处理配置
// GET /api/v1/admin/settings/stream-timeout
func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {

View File

@@ -77,12 +77,13 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
}
}
status := c.Query("status")
platform := c.Query("platform")
// Parse sorting parameters
sortBy := c.DefaultQuery("sort_by", "created_at")
sortOrder := c.DefaultQuery("sort_order", "desc")
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, platform, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -135,14 +135,16 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
return nil
}
out := &AdminGroup{
Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.MCPXMLInject,
DefaultMappedModel: g.DefaultMappedModel,
SupportedModelScopes: g.SupportedModelScopes,
AccountCount: g.AccountCount,
SortOrder: g.SortOrder,
Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.MCPXMLInject,
DefaultMappedModel: g.DefaultMappedModel,
SupportedModelScopes: g.SupportedModelScopes,
AccountCount: g.AccountCount,
ActiveAccountCount: g.ActiveAccountCount,
RateLimitedAccountCount: g.RateLimitedAccountCount,
SortOrder: g.SortOrder,
}
if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
@@ -521,6 +523,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
AccountID: l.AccountID,
RequestID: l.RequestID,
Model: l.Model,
UpstreamModel: l.UpstreamModel,
ServiceTier: l.ServiceTier,
ReasoningEffort: l.ReasoningEffort,
InboundEndpoint: l.InboundEndpoint,

View File

@@ -157,6 +157,12 @@ type ListSoraS3ProfilesResponse struct {
Items []SoraS3Profile `json:"items"`
}
// OverloadCooldownSettings 529过载冷却配置 DTO
type OverloadCooldownSettings struct {
Enabled bool `json:"enabled"`
CooldownMinutes int `json:"cooldown_minutes"`
}
// StreamTimeoutSettings 流超时处理配置 DTO
type StreamTimeoutSettings struct {
Enabled bool `json:"enabled"`

View File

@@ -122,9 +122,11 @@ type AdminGroup struct {
DefaultMappedModel string `json:"default_mapped_model"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
AccountCount int64 `json:"account_count,omitempty"`
SupportedModelScopes []string `json:"supported_model_scopes"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
AccountCount int64 `json:"account_count,omitempty"`
ActiveAccountCount int64 `json:"active_account_count,omitempty"`
RateLimitedAccountCount int64 `json:"rate_limited_account_count,omitempty"`
// 分组排序
SortOrder int `json:"sort_order"`
@@ -332,6 +334,9 @@ type UsageLog struct {
AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"`
Model string `json:"model"`
// UpstreamModel is the actual model sent to the upstream provider after mapping.
// Omitted when no mapping was applied (requested model was used as-is).
UpstreamModel *string `json:"upstream_model,omitempty"`
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier *string `json:"service_tier,omitempty"`
// ReasoningEffort is the request's reasoning effort level.

View File

@@ -76,7 +76,7 @@ func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service
return nil, nil
}
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil }
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, int64, error) { return 0, 0, nil }
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
return 0, nil
}

View File

@@ -136,7 +136,7 @@ func validClaudeCodeBodyJSON() []byte {
return []byte(`{
"model":"claude-3-5-sonnet-20241022",
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}
}`)
}
@@ -190,7 +190,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
System: []any{
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
},
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123",
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
}
// body 非法 JSON如果函数复用 parsedReq 成功则仍应判定为 Claude Code。
@@ -209,7 +209,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
"system": []any{
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
},
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"},
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"},
})
SetClaudeCodeClientContext(c, []byte(`{invalid`), nil)

View File

@@ -273,8 +273,8 @@ func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform strin
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil
}
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, 0, nil
}
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
@@ -348,6 +348,9 @@ func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTi
func (s *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, nil
}

View File

@@ -3,6 +3,28 @@ package usagestats
import "time"
const (
ModelSourceRequested = "requested"
ModelSourceUpstream = "upstream"
ModelSourceMapping = "mapping"
)
func IsValidModelSource(source string) bool {
switch source {
case ModelSourceRequested, ModelSourceUpstream, ModelSourceMapping:
return true
default:
return false
}
}
func NormalizeModelSource(source string) string {
if IsValidModelSource(source) {
return source
}
return ModelSourceRequested
}
// DashboardStats 仪表盘统计
type DashboardStats struct {
// 用户统计
@@ -90,6 +112,13 @@ type EndpointStat struct {
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// GroupUsageSummary represents today's and cumulative cost for a single group.
type GroupUsageSummary struct {
GroupID int64 `json:"group_id"`
TodayCost float64 `json:"today_cost"`
TotalCost float64 `json:"total_cost"`
}
// GroupStat represents usage statistics for a single group
type GroupStat struct {
GroupID int64 `json:"group_id"`
@@ -143,6 +172,7 @@ type UserBreakdownItem struct {
type UserBreakdownDimension struct {
GroupID int64 // filter by group_id (>0 to enable)
Model string // filter by model name (non-empty to enable)
ModelType string // "requested", "upstream", or "mapping"
Endpoint string // filter by endpoint value (non-empty to enable)
EndpointType string // "inbound", "upstream", or "path"
}

View File

@@ -0,0 +1,47 @@
package usagestats
import "testing"
func TestIsValidModelSource(t *testing.T) {
tests := []struct {
name string
source string
want bool
}{
{name: "requested", source: ModelSourceRequested, want: true},
{name: "upstream", source: ModelSourceUpstream, want: true},
{name: "mapping", source: ModelSourceMapping, want: true},
{name: "invalid", source: "foobar", want: false},
{name: "empty", source: "", want: false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := IsValidModelSource(tc.source); got != tc.want {
t.Fatalf("IsValidModelSource(%q)=%v want %v", tc.source, got, tc.want)
}
})
}
}
func TestNormalizeModelSource(t *testing.T) {
tests := []struct {
name string
source string
want string
}{
{name: "requested", source: ModelSourceRequested, want: ModelSourceRequested},
{name: "upstream", source: ModelSourceUpstream, want: ModelSourceUpstream},
{name: "mapping", source: ModelSourceMapping, want: ModelSourceMapping},
{name: "invalid falls back", source: "foobar", want: ModelSourceRequested},
{name: "empty falls back", source: "", want: ModelSourceRequested},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := NormalizeModelSource(tc.source); got != tc.want {
t.Fatalf("NormalizeModelSource(%q)=%q want %q", tc.source, got, tc.want)
}
})
}
}

View File

@@ -88,8 +88,9 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
if err != nil {
return nil, err
}
count, _ := r.GetAccountCount(ctx, out.ID)
out.AccountCount = count
total, active, _ := r.GetAccountCount(ctx, out.ID)
out.AccountCount = total
out.ActiveAccountCount = active
return out, nil
}
@@ -256,7 +257,10 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
counts, err := r.loadAccountCounts(ctx, groupIDs)
if err == nil {
for i := range outGroups {
outGroups[i].AccountCount = counts[outGroups[i].ID]
c := counts[outGroups[i].ID]
outGroups[i].AccountCount = c.Total
outGroups[i].ActiveAccountCount = c.Active
outGroups[i].RateLimitedAccountCount = c.RateLimited
}
}
@@ -283,7 +287,10 @@ func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, erro
counts, err := r.loadAccountCounts(ctx, groupIDs)
if err == nil {
for i := range outGroups {
outGroups[i].AccountCount = counts[outGroups[i].ID]
c := counts[outGroups[i].ID]
outGroups[i].AccountCount = c.Total
outGroups[i].ActiveAccountCount = c.Active
outGroups[i].RateLimitedAccountCount = c.RateLimited
}
}
@@ -310,7 +317,10 @@ func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform str
counts, err := r.loadAccountCounts(ctx, groupIDs)
if err == nil {
for i := range outGroups {
outGroups[i].AccountCount = counts[outGroups[i].ID]
c := counts[outGroups[i].ID]
outGroups[i].AccountCount = c.Total
outGroups[i].ActiveAccountCount = c.Active
outGroups[i].RateLimitedAccountCount = c.RateLimited
}
}
@@ -369,12 +379,20 @@ func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int
return result, nil
}
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil {
return 0, err
}
return count, nil
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error) {
var rateLimited int64
err = scanSingleRow(ctx, r.sql,
`SELECT COUNT(*),
COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true),
COUNT(*) FILTER (WHERE a.status = 'active' AND (
a.rate_limit_reset_at > NOW() OR
a.overload_until > NOW() OR
a.temp_unschedulable_until > NOW()
))
FROM account_groups ag JOIN accounts a ON a.id = ag.account_id
WHERE ag.group_id = $1`,
[]any{groupID}, &total, &active, &rateLimited)
return
}
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
@@ -500,15 +518,32 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return affectedUserIDs, nil
}
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]int64, err error) {
counts = make(map[int64]int64, len(groupIDs))
type groupAccountCounts struct {
Total int64
Active int64
RateLimited int64
}
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]groupAccountCounts, err error) {
counts = make(map[int64]groupAccountCounts, len(groupIDs))
if len(groupIDs) == 0 {
return counts, nil
}
rows, err := r.sql.QueryContext(
ctx,
"SELECT group_id, COUNT(*) FROM account_groups WHERE group_id = ANY($1) GROUP BY group_id",
`SELECT ag.group_id,
COUNT(*) AS total,
COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true) AS active,
COUNT(*) FILTER (WHERE a.status = 'active' AND (
a.rate_limit_reset_at > NOW() OR
a.overload_until > NOW() OR
a.temp_unschedulable_until > NOW()
)) AS rate_limited
FROM account_groups ag
JOIN accounts a ON a.id = ag.account_id
WHERE ag.group_id = ANY($1)
GROUP BY ag.group_id`,
pq.Array(groupIDs),
)
if err != nil {
@@ -523,11 +558,11 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
for rows.Next() {
var groupID int64
var count int64
if err = rows.Scan(&groupID, &count); err != nil {
var c groupAccountCounts
if err = rows.Scan(&groupID, &c.Total, &c.Active, &c.RateLimited); err != nil {
return nil, err
}
counts[groupID] = count
counts[groupID] = c
}
if err = rows.Err(); err != nil {
return nil, err

View File

@@ -603,7 +603,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() {
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, group.ID, 2)
s.Require().NoError(err)
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
count, _, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err, "GetAccountCount")
s.Require().Equal(int64(2), count)
}
@@ -619,7 +619,7 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
}
s.Require().NoError(s.repo.Create(s.ctx, group))
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
count, _, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err)
s.Require().Zero(count)
}
@@ -651,7 +651,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
s.Require().NoError(err, "DeleteAccountGroupsByGroupID")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
count, err := s.repo.GetAccountCount(s.ctx, g.ID)
count, _, err := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().NoError(err, "GetAccountCount")
s.Require().Equal(int64(0), count, "expected 0 account groups")
}
@@ -692,7 +692,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
s.Require().NoError(err)
s.Require().Equal(int64(3), affected)
count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
count, _, _ := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().Zero(count)
}

View File

@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
var usageLogInsertArgTypes = [...]string{
"bigint",
@@ -36,6 +36,7 @@ var usageLogInsertArgTypes = [...]string{
"bigint",
"text",
"text",
"text",
"bigint",
"bigint",
"integer",
@@ -277,6 +278,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -311,12 +313,12 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
cache_ttl_overridden,
created_at
) VALUES (
$1, $2, $3, $4, $5,
$6, $7,
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38
$1, $2, $3, $4, $5, $6,
$7, $8,
$9, $10, $11, $12,
$13, $14,
$15, $16, $17, $18, $19, $20,
$21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -707,6 +709,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -742,7 +745,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
created_at
) AS (VALUES `)
args := make([]any, 0, len(keys)*38)
args := make([]any, 0, len(keys)*39)
argPos := 1
for idx, key := range keys {
if idx > 0 {
@@ -776,6 +779,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -816,6 +820,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -896,6 +901,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -931,7 +937,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
created_at
) AS (VALUES `)
args := make([]any, 0, len(preparedList)*38)
args := make([]any, 0, len(preparedList)*39)
argPos := 1
for idx, prepared := range preparedList {
if idx > 0 {
@@ -962,6 +968,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -1002,6 +1009,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -1050,6 +1058,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -1084,12 +1093,12 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
cache_ttl_overridden,
created_at
) VALUES (
$1, $2, $3, $4, $5,
$6, $7,
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38
$1, $2, $3, $4, $5, $6,
$7, $8,
$9, $10, $11, $12,
$13, $14,
$15, $16, $17, $18, $19, $20,
$21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...)
@@ -1121,6 +1130,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
reasoningEffort := nullString(log.ReasoningEffort)
inboundEndpoint := nullString(log.InboundEndpoint)
upstreamEndpoint := nullString(log.UpstreamEndpoint)
upstreamModel := nullString(log.UpstreamModel)
var requestIDArg any
if requestID != "" {
@@ -1138,6 +1148,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
log.AccountID,
requestIDArg,
log.Model,
upstreamModel,
groupID,
subscriptionID,
log.InputTokens,
@@ -2864,15 +2875,26 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st
// GetModelStatsWithFilters returns model statistics with optional filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) {
return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, usagestats.ModelSourceRequested)
}
// GetModelStatsWithFiltersBySource returns model statistics with optional filters and model source dimension.
// source: requested | upstream | mapping.
func (r *usageLogRepository) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) {
return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, source)
}
func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时实际费用使用账号倍率total_cost * account_rate_multiplier
if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
modelExpr := resolveModelDimensionExpression(source)
query := fmt.Sprintf(`
SELECT
model,
%s as model,
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
@@ -2883,7 +2905,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
%s
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`, actualCostExpr)
`, modelExpr, actualCostExpr)
args := []any{startTime, endTime}
if userID > 0 {
@@ -2907,7 +2929,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType))
}
query += " GROUP BY model ORDER BY total_tokens DESC"
query += fmt.Sprintf(" GROUP BY %s ORDER BY total_tokens DESC", modelExpr)
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
@@ -3021,7 +3043,7 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
args = append(args, dim.GroupID)
}
if dim.Model != "" {
query += fmt.Sprintf(" AND ul.model = $%d", len(args)+1)
query += fmt.Sprintf(" AND %s = $%d", resolveModelDimensionExpression(dim.ModelType), len(args)+1)
args = append(args, dim.Model)
}
if dim.Endpoint != "" {
@@ -3067,6 +3089,53 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
return results, nil
}
// GetAllGroupUsageSummary returns today's and cumulative actual_cost for every group.
// todayStart is the start-of-day in the caller's timezone (UTC-based).
// TODO(perf): This query scans ALL usage_logs rows for total_cost aggregation.
// When usage_logs exceeds ~1M rows, consider adding a short-lived cache (30s)
// or a materialized view / pre-aggregation table for cumulative costs.
func (r *usageLogRepository) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
query := `
SELECT
g.id AS group_id,
COALESCE(SUM(ul.actual_cost), 0) AS total_cost,
COALESCE(SUM(CASE WHEN ul.created_at >= $1 THEN ul.actual_cost ELSE 0 END), 0) AS today_cost
FROM groups g
LEFT JOIN usage_logs ul ON ul.group_id = g.id
GROUP BY g.id
`
rows, err := r.sql.QueryContext(ctx, query, todayStart)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
var results []usagestats.GroupUsageSummary
for rows.Next() {
var row usagestats.GroupUsageSummary
if err := rows.Scan(&row.GroupID, &row.TotalCost, &row.TodayCost); err != nil {
return nil, err
}
results = append(results, row)
}
if err := rows.Err(); err != nil {
return nil, err
}
return results, nil
}
// resolveModelDimensionExpression maps model source type to a safe SQL expression.
func resolveModelDimensionExpression(modelType string) string {
switch usagestats.NormalizeModelSource(modelType) {
case usagestats.ModelSourceUpstream:
return "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"
case usagestats.ModelSourceMapping:
return "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"
default:
return "model"
}
}
// resolveEndpointColumn maps endpoint type to the corresponding DB column name.
func resolveEndpointColumn(endpointType string) string {
switch endpointType {
@@ -3819,6 +3888,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
accountID int64
requestID sql.NullString
model string
upstreamModel sql.NullString
groupID sql.NullInt64
subscriptionID sql.NullInt64
inputTokens int
@@ -3861,6 +3931,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&accountID,
&requestID,
&model,
&upstreamModel,
&groupID,
&subscriptionID,
&inputTokens,
@@ -3973,6 +4044,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if upstreamEndpoint.Valid {
log.UpstreamEndpoint = &upstreamEndpoint.String
}
if upstreamModel.Valid {
log.UpstreamModel = &upstreamModel.String
}
return log, nil
}

View File

@@ -5,6 +5,7 @@ package repository
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/stretchr/testify/require"
)
@@ -16,8 +17,8 @@ func TestResolveEndpointColumn(t *testing.T) {
{"inbound", "ul.inbound_endpoint"},
{"upstream", "ul.upstream_endpoint"},
{"path", "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint"},
{"", "ul.inbound_endpoint"}, // default
{"unknown", "ul.inbound_endpoint"}, // fallback
{"", "ul.inbound_endpoint"}, // default
{"unknown", "ul.inbound_endpoint"}, // fallback
}
for _, tc := range tests {
@@ -27,3 +28,23 @@ func TestResolveEndpointColumn(t *testing.T) {
})
}
}
func TestResolveModelDimensionExpression(t *testing.T) {
tests := []struct {
modelType string
want string
}{
{usagestats.ModelSourceRequested, "model"},
{usagestats.ModelSourceUpstream, "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"},
{usagestats.ModelSourceMapping, "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"},
{"", "model"},
{"invalid", "model"},
}
for _, tc := range tests {
t.Run(tc.modelType, func(t *testing.T) {
got := resolveModelDimensionExpression(tc.modelType)
require.Equal(t, tc.want, got)
})
}
}

View File

@@ -44,6 +44,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
log.AccountID,
log.RequestID,
log.Model,
sqlmock.AnyArg(), // upstream_model
sqlmock.AnyArg(), // group_id
sqlmock.AnyArg(), // subscription_id
log.InputTokens,
@@ -116,6 +117,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
log.Model,
sqlmock.AnyArg(),
sqlmock.AnyArg(),
sqlmock.AnyArg(),
log.InputTokens,
log.OutputTokens,
log.CacheCreationTokens,
@@ -353,6 +355,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(30), // account_id
sql.NullString{Valid: true, String: "req-1"},
"gpt-5", // model
sql.NullString{}, // upstream_model
sql.NullInt64{}, // group_id
sql.NullInt64{}, // subscription_id
1, // input_tokens
@@ -404,6 +407,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(31),
sql.NullString{Valid: true, String: "req-2"},
"gpt-5",
sql.NullString{},
sql.NullInt64{},
sql.NullInt64{},
1, 2, 3, 4, 5, 6,
@@ -445,6 +449,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(32),
sql.NullString{Valid: true, String: "req-3"},
"gpt-5.4",
sql.NullString{},
sql.NullInt64{},
sql.NullInt64{},
1, 2, 3, 4, 5, 6,

View File

@@ -5,6 +5,7 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -190,7 +191,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
}
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
client := clientFromContext(ctx, r.client)
q := client.UserSubscription.Query()
if userID != nil {
@@ -199,6 +200,9 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
if groupID != nil {
q = q.Where(usersubscription.GroupIDEQ(*groupID))
}
if platform != "" {
q = q.Where(usersubscription.HasGroupWith(group.PlatformEQ(platform)))
}
// Status filtering with real-time expiration check
now := time.Now()

View File

@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
group := s.mustCreateGroup("g-list")
s.mustCreateSubscription(user.ID, group.ID, nil)
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "")
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "", "")
s.Require().NoError(err, "List")
s.Require().Len(subs, 1)
s.Require().Equal(int64(1), page.Total)
@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
s.mustCreateSubscription(user1.ID, group.ID, nil)
s.mustCreateSubscription(user2.ID, group.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "")
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(user1.ID, subs[0].UserID)
@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
s.mustCreateSubscription(user.ID, g1.ID, nil)
s.mustCreateSubscription(user.ID, g2.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "")
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(g1.ID, subs[0].GroupID)
@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
})
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "")
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)

View File

@@ -924,8 +924,8 @@ func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error
return false, errors.New("not implemented")
}
func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, 0, errors.New("not implemented")
}
func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
@@ -1289,7 +1289,7 @@ func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userI
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
@@ -1786,6 +1786,9 @@ func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID i
func (r *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
return nil, errors.New("not implemented")
}
type stubSettingRepo struct {
all map[string]string

View File

@@ -135,7 +135,7 @@ func (f fakeGoogleSubscriptionRepo) ListActiveByUserID(ctx context.Context, user
func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {

View File

@@ -646,7 +646,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
return nil, nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}

View File

@@ -227,6 +227,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{
groups.GET("", h.Admin.Group.List)
groups.GET("/all", h.Admin.Group.GetAll)
groups.GET("/usage-summary", h.Admin.Group.GetUsageSummary)
groups.GET("/capacity-summary", h.Admin.Group.GetCapacitySummary)
groups.PUT("/sort-order", h.Admin.Group.UpdateSortOrder)
groups.GET("/:id", h.Admin.Group.GetByID)
groups.POST("", h.Admin.Group.Create)
@@ -400,6 +402,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey)
// 529过载冷却配置
adminSettings.GET("/overload-cooldown", h.Admin.Setting.GetOverloadCooldownSettings)
adminSettings.PUT("/overload-cooldown", h.Admin.Setting.UpdateOverloadCooldownSettings)
// 流超时处理配置
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)

View File

@@ -113,15 +113,18 @@ func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error)
return normalized, nil
}
// generateSessionString generates a Claude Code style session string
// generateSessionString generates a Claude Code style session string.
// The output format is determined by the UA version in claude.DefaultHeaders,
// ensuring consistency between the user_id format and the UA sent to upstream.
func generateSessionString() (string, error) {
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
hex64 := hex.EncodeToString(bytes)
hex64 := hex.EncodeToString(b)
sessionUUID := uuid.New().String()
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil
uaVersion := ExtractCLIVersion(claude.DefaultHeaders["User-Agent"])
return FormatMetadataUserID(hex64, "", sessionUUID, uaVersion), nil
}
// createTestPayload creates a Claude Code style test request payload

View File

@@ -49,6 +49,7 @@ type UsageLogRepository interface {
GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error)
GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error)

View File

@@ -194,7 +194,7 @@ func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, stri
func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) {
func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, int64, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {

View File

@@ -160,7 +160,7 @@ func (s *groupRepoStub) ExistsByName(ctx context.Context, name string) (bool, er
panic("unexpected ExistsByName call")
}
func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
panic("unexpected GetAccountCount call")
}

View File

@@ -100,7 +100,7 @@ func (s *groupRepoStubForAdmin) ExistsByName(_ context.Context, _ string) (bool,
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, error) {
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
panic("unexpected GetAccountCount call")
}
@@ -383,7 +383,7 @@ func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) {
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
panic("unexpected GetAccountCount call")
}
@@ -458,7 +458,7 @@ func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context,
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) {
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
panic("unexpected GetAccountCount call")
}

View File

@@ -57,16 +57,16 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected: "claude-opus-4-6-thinking",
},
{
name: "默认映射 - claude-haiku-4-5 → claude-sonnet-4-5",
name: "默认映射 - claude-haiku-4-5 → claude-sonnet-4-6",
requestedModel: "claude-haiku-4-5",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "claude-sonnet-4-6",
},
{
name: "默认映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5",
name: "默认映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-6",
requestedModel: "claude-haiku-4-5-20251001",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "claude-sonnet-4-6",
},
{
name: "默认映射 - claude-sonnet-4-5-20250929 → claude-sonnet-4-5",

View File

@@ -21,9 +21,6 @@ var (
// 带捕获组的版本提取正则
claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`)
// metadata.user_id 格式: user_{64位hex}_account__session_{uuid}
userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`)
// System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致)
systemPromptThreshold = 0.5
)
@@ -124,7 +121,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
return false
}
if !userIDPattern.MatchString(userID) {
if ParseMetadataUserID(userID) == nil {
return false
}
@@ -278,11 +275,7 @@ func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context
// ExtractVersion 从 User-Agent 中提取 Claude Code 版本号
// 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串
func (v *ClaudeCodeValidator) ExtractVersion(ua string) string {
matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua)
if len(matches) >= 2 {
return matches[1]
}
return ""
return ExtractCLIVersion(ua)
}
// SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中

View File

@@ -140,6 +140,27 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
return stats, nil
}
func (s *DashboardService) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, modelSource string) ([]usagestats.ModelStat, error) {
normalizedSource := usagestats.NormalizeModelSource(modelSource)
if normalizedSource == usagestats.ModelSourceRequested {
return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
}
type modelStatsBySourceRepo interface {
GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) ([]usagestats.ModelStat, error)
}
if sourceRepo, ok := s.usageRepo.(modelStatsBySourceRepo); ok {
stats, err := sourceRepo.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, normalizedSource)
if err != nil {
return nil, fmt.Errorf("get model stats with filters by source: %w", err)
}
return stats, nil
}
return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
}
func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
@@ -148,6 +169,15 @@ func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTi
return stats, nil
}
// GetGroupUsageSummary returns today's and cumulative cost for all groups.
func (s *DashboardService) GetGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
results, err := s.usageRepo.GetAllGroupUsageSummary(ctx, todayStart)
if err != nil {
return nil, fmt.Errorf("get group usage summary: %w", err)
}
return results, nil
}
func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) {
data, err := s.cache.GetDashboardStats(ctx)
if err != nil {

View File

@@ -170,6 +170,13 @@ const (
// SettingKeyOpsRuntimeLogConfig stores JSON config for runtime log settings.
SettingKeyOpsRuntimeLogConfig = "ops_runtime_log_config"
// =========================
// Overload Cooldown (529)
// =========================
// SettingKeyOverloadCooldownSettings stores JSON config for 529 overload cooldown handling.
SettingKeyOverloadCooldownSettings = "overload_cooldown_settings"
// =========================
// Stream Timeout Handling
// =========================

View File

@@ -788,7 +788,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuc
rateLimitService: &RateLimitService{},
}
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", false, time.Now())
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now())
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 12, result.Usage.InputTokens)
@@ -815,7 +815,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_InvalidTokenTyp
}
svc := &GatewayService{}
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", false, time.Now())
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now())
require.Nil(t, result)
require.Error(t, err)
require.Contains(t, err.Error(), "requires apikey token")
@@ -840,7 +840,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_UpstreamRequest
}
account := newAnthropicAPIKeyAccountForTest()
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", false, time.Now())
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", "x", false, time.Now())
require.Nil(t, result)
require.Error(t, err)
require.Contains(t, err.Error(), "upstream request failed")
@@ -873,7 +873,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_EmptyResponseBo
httpUpstream: upstream,
}
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", false, time.Now())
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", "x", false, time.Now())
require.Nil(t, result)
require.Error(t, err)
require.Contains(t, err.Error(), "empty response")

View File

@@ -278,8 +278,8 @@ func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, plat
func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil
}
func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, 0, nil
}
func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil

View File

@@ -28,6 +28,12 @@ var (
patternEmptyContentSpaced = []byte(`"content": []`)
patternEmptyContentSp1 = []byte(`"content" : []`)
patternEmptyContentSp2 = []byte(`"content" :[]`)
// Fast-path patterns for empty text blocks: {"type":"text","text":""}
patternEmptyText = []byte(`"text":""`)
patternEmptyTextSpaced = []byte(`"text": ""`)
patternEmptyTextSp1 = []byte(`"text" : ""`)
patternEmptyTextSp2 = []byte(`"text" :""`)
)
// SessionContext 粘性会话上下文,用于区分不同来源的请求。
@@ -233,15 +239,22 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
bytes.Contains(body, patternThinkingField) ||
bytes.Contains(body, patternThinkingFieldSpaced)
// Also check for empty content arrays that need fixing.
// Also check for empty content arrays and empty text blocks that need fixing.
// Note: This is a heuristic check; the actual empty content handling is done below.
hasEmptyContent := bytes.Contains(body, patternEmptyContent) ||
bytes.Contains(body, patternEmptyContentSpaced) ||
bytes.Contains(body, patternEmptyContentSp1) ||
bytes.Contains(body, patternEmptyContentSp2)
// Check for empty text blocks: {"type":"text","text":""}
// These cause upstream 400: "text content blocks must be non-empty"
hasEmptyTextBlock := bytes.Contains(body, patternEmptyText) ||
bytes.Contains(body, patternEmptyTextSpaced) ||
bytes.Contains(body, patternEmptyTextSp1) ||
bytes.Contains(body, patternEmptyTextSp2)
// Fast path: nothing to process
if !hasThinkingContent && !hasEmptyContent {
if !hasThinkingContent && !hasEmptyContent && !hasEmptyTextBlock {
return body
}
@@ -260,7 +273,7 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
bytes.Contains(body, patternTypeRedactedThinking) ||
bytes.Contains(body, patternTypeRedactedSpaced) ||
bytes.Contains(body, patternThinkingFieldSpaced)
if !hasEmptyContent && !containsThinkingBlocks {
if !hasEmptyContent && !hasEmptyTextBlock && !containsThinkingBlocks {
if topThinking := gjson.Get(jsonStr, "thinking"); topThinking.Exists() {
if out, err := sjson.DeleteBytes(body, "thinking"); err == nil {
out = removeThinkingDependentContextStrategies(out)
@@ -320,6 +333,16 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
blockType, _ := blockMap["type"].(string)
// Strip empty text blocks: {"type":"text","text":""}
// Upstream rejects these with 400: "text content blocks must be non-empty"
if blockType == "text" {
if txt, _ := blockMap["text"].(string); txt == "" {
modifiedThisMsg = true
ensureNewContent(bi)
continue
}
}
// Convert thinking blocks to text (preserve content) and drop redacted_thinking.
switch blockType {
case "thinking":

View File

@@ -404,6 +404,51 @@ func TestFilterThinkingBlocksForRetry_EmptyContentGetsPlaceholder(t *testing.T)
require.NotEmpty(t, content0["text"])
}
func TestFilterThinkingBlocksForRetry_StripsEmptyTextBlocks(t *testing.T) {
// Empty text blocks cause upstream 400: "text content blocks must be non-empty"
input := []byte(`{
"messages":[
{"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":""}]},
{"role":"assistant","content":[{"type":"text","text":""}]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs, ok := req["messages"].([]any)
require.True(t, ok)
// First message: empty text block stripped, "hello" preserved
msg0 := msgs[0].(map[string]any)
content0 := msg0["content"].([]any)
require.Len(t, content0, 1)
require.Equal(t, "hello", content0[0].(map[string]any)["text"])
// Second message: only had empty text block → gets placeholder
msg1 := msgs[1].(map[string]any)
content1 := msg1["content"].([]any)
require.Len(t, content1, 1)
block1 := content1[0].(map[string]any)
require.Equal(t, "text", block1["type"])
require.NotEmpty(t, block1["text"])
}
func TestFilterThinkingBlocksForRetry_PreservesNonEmptyTextBlocks(t *testing.T) {
// Non-empty text blocks should pass through unchanged
input := []byte(`{
"messages":[
{"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":"world"}]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
// Fast path: no thinking content, no empty content, no empty text blocks → unchanged
require.Equal(t, input, out)
}
func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) {
input := []byte(`{
"thinking":{"type":"enabled","budget_tokens":1024},

View File

@@ -326,7 +326,6 @@ func isClaudeCodeCredentialScopeError(msg string) bool {
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var (
sseDataRe = regexp.MustCompile(`^data:\s*`)
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
@@ -491,6 +490,7 @@ type ForwardResult struct {
RequestID string
Usage ClaudeUsage
Model string
UpstreamModel string // Actual upstream model after mapping (empty = no mapping)
Stream bool
Duration time.Duration
FirstTokenMs *int // 首字时间(流式请求)
@@ -644,8 +644,8 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
// 1. 最高优先级:从 metadata.user_id 提取 session_xxx
if parsed.MetadataUserID != "" {
if match := sessionIDRegex.FindStringSubmatch(parsed.MetadataUserID); len(match) > 1 {
return match[1]
if uid := ParseMetadataUserID(parsed.MetadataUserID); uid != nil && uid.SessionID != "" {
return uid.SessionID
}
}
@@ -1026,13 +1026,13 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
sessionID = generateSessionUUID(seed)
}
// Prefer the newer format that includes account_uuid (if present),
// otherwise fall back to the legacy Claude Code format.
accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid"))
if accountUUID != "" {
return fmt.Sprintf("user_%s_account_%s_session_%s", userID, accountUUID, sessionID)
// 根据指纹 UA 版本选择输出格式
var uaVersion string
if fp != nil {
uaVersion = ExtractCLIVersion(fp.UserAgent)
}
return fmt.Sprintf("user_%s_account__session_%s", userID, sessionID)
accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid"))
return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion)
}
// GenerateSessionUUID creates a deterministic UUID4 from a seed string.
@@ -3989,7 +3989,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
passthroughModel = mappedModel
}
}
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime)
return s.forwardAnthropicAPIKeyPassthroughWithInput(ctx, c, account, anthropicPassthroughForwardInput{
Body: passthroughBody,
RequestModel: passthroughModel,
OriginalModel: parsed.Model,
RequestStream: parsed.Stream,
StartTime: startTime,
})
}
if account != nil && account.IsBedrock() {
@@ -4513,6 +4519,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志
UpstreamModel: mappedModel,
Stream: reqStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
@@ -4520,14 +4527,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}, nil
}
type anthropicPassthroughForwardInput struct {
Body []byte
RequestModel string
OriginalModel string
RequestStream bool
StartTime time.Time
}
func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
reqModel string,
originalModel string,
reqStream bool,
startTime time.Time,
) (*ForwardResult, error) {
return s.forwardAnthropicAPIKeyPassthroughWithInput(ctx, c, account, anthropicPassthroughForwardInput{
Body: body,
RequestModel: reqModel,
OriginalModel: originalModel,
RequestStream: reqStream,
StartTime: startTime,
})
}
func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
ctx context.Context,
c *gin.Context,
account *Account,
input anthropicPassthroughForwardInput,
) (*ForwardResult, error) {
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
@@ -4543,19 +4574,19 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
}
logger.LegacyPrintf("service.gateway", "[Anthropic 自动透传] 命中 API Key 透传分支: account=%d name=%s model=%s stream=%v",
account.ID, account.Name, reqModel, reqStream)
account.ID, account.Name, input.RequestModel, input.RequestStream)
if c != nil {
c.Set("anthropic_passthrough", true)
}
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
setOpsUpstreamRequestBody(c, body)
setOpsUpstreamRequestBody(c, input.Body)
var resp *http.Response
retryStart := time.Now()
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, body, token)
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, input.RequestStream)
upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, input.Body, token)
releaseUpstreamCtx()
if err != nil {
return nil, err
@@ -4713,8 +4744,8 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
var usage *ClaudeUsage
var firstTokenMs *int
var clientDisconnect bool
if reqStream {
streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, startTime, reqModel)
if input.RequestStream {
streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, input.StartTime, input.RequestModel)
if err != nil {
return nil, err
}
@@ -4734,9 +4765,10 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
return &ForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: reqModel,
Stream: reqStream,
Duration: time.Since(startTime),
Model: input.OriginalModel,
UpstreamModel: input.RequestModel,
Stream: input.RequestStream,
Duration: time.Since(input.StartTime),
FirstTokenMs: firstTokenMs,
ClientDisconnect: clientDisconnect,
}, nil
@@ -5241,6 +5273,7 @@ func (s *GatewayService) forwardBedrock(
RequestID: resp.Header.Get("x-amzn-requestid"),
Usage: *usage,
Model: reqModel,
UpstreamModel: mappedModel,
Stream: reqStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
@@ -5533,7 +5566,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 如果启用了会话ID伪装会在重写后替换 session 部分为固定值
accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 {
body = newBody
}
}
@@ -6068,9 +6101,11 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
return true
}
// 检测空消息内容错误(可能是过滤 thinking blocks 后导致的)
// 检测空消息内容错误(可能是过滤 thinking blocks 后导致的,或客户端发送了空 text block
// 例如: "all messages must have non-empty content"
if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") {
// "messages: text content blocks must be non-empty"
if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") ||
strings.Contains(msg, "content blocks must be non-empty") {
logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected empty content error")
return true
}
@@ -7529,6 +7564,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
AccountID: account.ID,
RequestID: requestID,
Model: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
@@ -7710,6 +7746,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
AccountID: account.ID,
RequestID: requestID,
Model: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
@@ -8161,7 +8198,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if err == nil {
accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 {
body = newBody
}
}

View File

@@ -230,8 +230,8 @@ func (m *mockGroupRepoForGemini) ListActiveByPlatform(ctx context.Context, platf
func (m *mockGroupRepoForGemini) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil
}
func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, 0, nil
}
func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil

View File

@@ -24,7 +24,7 @@ func TestGenerateSessionHash_MetadataHasHighestPriority(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000",
MetadataUserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000",
System: "You are a helpful assistant.",
HasSystem: true,
Messages: []any{
@@ -196,7 +196,7 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000",
MetadataUserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
@@ -212,6 +212,22 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) {
"metadata session_id should take priority over SessionContext")
}
func TestGenerateSessionHash_MetadataJSON_HasHighestPriority(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
MetadataUserID: `{"device_id":"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`,
System: "You are a helpful assistant.",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
hash := svc.GenerateSessionHash(parsed)
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", hash, "JSON format metadata session_id should have highest priority")
}
func TestGenerateSessionHash_NilSessionContextBackwardCompatible(t *testing.T) {
svc := &GatewayService{}

View File

@@ -64,8 +64,10 @@ type Group struct {
CreatedAt time.Time
UpdatedAt time.Time
AccountGroups []AccountGroup
AccountCount int64
AccountGroups []AccountGroup
AccountCount int64
ActiveAccountCount int64
RateLimitedAccountCount int64
}
func (g *Group) IsActive() bool {

View File

@@ -0,0 +1,131 @@
package service
import (
"context"
"time"
)
// GroupCapacitySummary holds aggregated capacity for a single group.
type GroupCapacitySummary struct {
GroupID int64 `json:"group_id"`
ConcurrencyUsed int `json:"concurrency_used"`
ConcurrencyMax int `json:"concurrency_max"`
SessionsUsed int `json:"sessions_used"`
SessionsMax int `json:"sessions_max"`
RPMUsed int `json:"rpm_used"`
RPMMax int `json:"rpm_max"`
}
// GroupCapacityService aggregates per-group capacity from runtime data.
type GroupCapacityService struct {
accountRepo AccountRepository
groupRepo GroupRepository
concurrencyService *ConcurrencyService
sessionLimitCache SessionLimitCache
rpmCache RPMCache
}
// NewGroupCapacityService creates a new GroupCapacityService.
func NewGroupCapacityService(
accountRepo AccountRepository,
groupRepo GroupRepository,
concurrencyService *ConcurrencyService,
sessionLimitCache SessionLimitCache,
rpmCache RPMCache,
) *GroupCapacityService {
return &GroupCapacityService{
accountRepo: accountRepo,
groupRepo: groupRepo,
concurrencyService: concurrencyService,
sessionLimitCache: sessionLimitCache,
rpmCache: rpmCache,
}
}
// GetAllGroupCapacity returns capacity summary for all active groups.
func (s *GroupCapacityService) GetAllGroupCapacity(ctx context.Context) ([]GroupCapacitySummary, error) {
groups, err := s.groupRepo.ListActive(ctx)
if err != nil {
return nil, err
}
results := make([]GroupCapacitySummary, 0, len(groups))
for i := range groups {
cap, err := s.getGroupCapacity(ctx, groups[i].ID)
if err != nil {
// Skip groups with errors, return partial results
continue
}
cap.GroupID = groups[i].ID
results = append(results, cap)
}
return results, nil
}
func (s *GroupCapacityService) getGroupCapacity(ctx context.Context, groupID int64) (GroupCapacitySummary, error) {
accounts, err := s.accountRepo.ListSchedulableByGroupID(ctx, groupID)
if err != nil {
return GroupCapacitySummary{}, err
}
if len(accounts) == 0 {
return GroupCapacitySummary{}, nil
}
// Collect account IDs and config values
accountIDs := make([]int64, 0, len(accounts))
sessionTimeouts := make(map[int64]time.Duration)
var concurrencyMax, sessionsMax, rpmMax int
for i := range accounts {
acc := &accounts[i]
accountIDs = append(accountIDs, acc.ID)
concurrencyMax += acc.Concurrency
if ms := acc.GetMaxSessions(); ms > 0 {
sessionsMax += ms
timeout := time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
if timeout <= 0 {
timeout = 5 * time.Minute
}
sessionTimeouts[acc.ID] = timeout
}
if rpm := acc.GetBaseRPM(); rpm > 0 {
rpmMax += rpm
}
}
// Batch query runtime data from Redis
concurrencyMap, _ := s.concurrencyService.GetAccountConcurrencyBatch(ctx, accountIDs)
var sessionsMap map[int64]int
if sessionsMax > 0 && s.sessionLimitCache != nil {
sessionsMap, _ = s.sessionLimitCache.GetActiveSessionCountBatch(ctx, accountIDs, sessionTimeouts)
}
var rpmMap map[int64]int
if rpmMax > 0 && s.rpmCache != nil {
rpmMap, _ = s.rpmCache.GetRPMBatch(ctx, accountIDs)
}
// Aggregate
var concurrencyUsed, sessionsUsed, rpmUsed int
for _, id := range accountIDs {
concurrencyUsed += concurrencyMap[id]
if sessionsMap != nil {
sessionsUsed += sessionsMap[id]
}
if rpmMap != nil {
rpmUsed += rpmMap[id]
}
}
return GroupCapacitySummary{
ConcurrencyUsed: concurrencyUsed,
ConcurrencyMax: concurrencyMax,
SessionsUsed: sessionsUsed,
SessionsMax: sessionsMax,
RPMUsed: rpmUsed,
RPMMax: rpmMax,
}, nil
}

View File

@@ -27,7 +27,7 @@ type GroupRepository interface {
ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error)
ExistsByName(ctx context.Context, name string) (bool, error)
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error)
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID去重
GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error)
@@ -202,7 +202,7 @@ func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any,
}
// 获取账号数量
accountCount, err := s.groupRepo.GetAccountCount(ctx, id)
accountCount, _, err := s.groupRepo.GetAccountCount(ctx, id)
if err != nil {
return nil, fmt.Errorf("get account count: %w", err)
}

View File

@@ -19,10 +19,6 @@ import (
// 预编译正则表达式(避免每次调用重新编译)
var (
// 匹配 user_id 格式:
// 旧格式: user_{64位hex}_account__session_{uuid} (account 后无 UUID)
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid} (account 后有 UUID)
userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account_([a-f0-9-]*)_session_([a-f0-9-]{36})$`)
// 匹配 User-Agent 版本号: xxx/x.y.z
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
)
@@ -209,12 +205,12 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
}
// RewriteUserID 重写body中的metadata.user_id
// 输入格式user_{clientId}_account__session_{sessionUUID}
// 输出格式user_{cachedClientID}_account_{accountUUID}_session_{newHash}
// 支持旧拼接格式和新 JSON 格式的 user_id 解析,
// 根据 fingerprintUA 版本选择输出格式。
//
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
// 避免重新序列化导致 thinking 块等内容被修改。
func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID string) ([]byte, error) {
func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID, fingerprintUA string) ([]byte, error) {
if len(body) == 0 || accountUUID == "" || cachedClientID == "" {
return body, nil
}
@@ -241,24 +237,21 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
return body, nil
}
// 匹配格式:
// 旧格式: user_{64位hex}_account__session_{uuid}
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid}
matches := userIDRegex.FindStringSubmatch(userID)
if matches == nil {
// 解析 user_id兼容旧拼接格式和新 JSON 格式)
parsed := ParseMetadataUserID(userID)
if parsed == nil {
return body, nil
}
// matches[1] = account UUID (可能为空), matches[2] = session UUID
sessionTail := matches[2] // 原始session UUID
sessionTail := parsed.SessionID // 原始session UUID
// 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式
seed := fmt.Sprintf("%d::%s", accountID, sessionTail)
newSessionHash := generateUUIDFromSeed(seed)
// 构建新的user_id
// 格式: user_{cachedClientID}_account_{account_uuid}_session_{newSessionHash}
newUserID := fmt.Sprintf("user_%s_account_%s_session_%s", cachedClientID, accountUUID, newSessionHash)
// 根据客户端版本选择输出格式
version := ExtractCLIVersion(fingerprintUA)
newUserID := FormatMetadataUserID(cachedClientID, accountUUID, newSessionHash, version)
metadata["user_id"] = newUserID
@@ -278,9 +271,9 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
//
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
// 避免重新序列化导致 thinking 块等内容被修改。
func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) {
func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID, fingerprintUA string) ([]byte, error) {
// 先执行常规的 RewriteUserID 逻辑
newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID)
newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID, fingerprintUA)
if err != nil {
return newBody, err
}
@@ -312,10 +305,9 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
return newBody, nil
}
// 查找 _session_ 的位置,替换其后的内容
const sessionMarker = "_session_"
idx := strings.LastIndex(userID, sessionMarker)
if idx == -1 {
// 解析已重写的 user_id
uidParsed := ParseMetadataUserID(userID)
if uidParsed == nil {
return newBody, nil
}
@@ -337,8 +329,9 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
logger.LegacyPrintf("service.identity", "Warning: failed to set masked session ID for account %d: %v", account.ID, err)
}
// 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容
newUserID := userID[:idx+len(sessionMarker)] + maskedSessionID
// 用 FormatMetadataUserID 重建(保持与 RewriteUserID 相同的格式)
version := ExtractCLIVersion(fingerprintUA)
newUserID := FormatMetadataUserID(uidParsed.DeviceID, uidParsed.AccountUUID, maskedSessionID, version)
slog.Debug("session_id_masking_applied",
"account_id", account.ID,

View File

@@ -0,0 +1,104 @@
package service
import (
"encoding/json"
"regexp"
"strings"
)
// NewMetadataFormatMinVersion is the minimum Claude Code version that uses
// JSON-formatted metadata.user_id instead of the legacy concatenated string.
const NewMetadataFormatMinVersion = "2.1.78"
// ParsedUserID represents the components extracted from a metadata.user_id value.
type ParsedUserID struct {
DeviceID string // 64-char hex (or arbitrary client id)
AccountUUID string // may be empty
SessionID string // UUID
IsNewFormat bool // true if the original was JSON format
}
// legacyUserIDRegex matches the legacy user_id format:
//
// user_{64hex}_account_{optional_uuid}_session_{uuid}
var legacyUserIDRegex = regexp.MustCompile(`^user_([a-fA-F0-9]{64})_account_([a-fA-F0-9-]*)_session_([a-fA-F0-9-]{36})$`)
// jsonUserID is the JSON structure for the new metadata.user_id format.
type jsonUserID struct {
DeviceID string `json:"device_id"`
AccountUUID string `json:"account_uuid"`
SessionID string `json:"session_id"`
}
// ParseMetadataUserID parses a metadata.user_id string in either format.
// Returns nil if the input cannot be parsed.
func ParseMetadataUserID(raw string) *ParsedUserID {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
// Try JSON format first (starts with '{')
if raw[0] == '{' {
var j jsonUserID
if err := json.Unmarshal([]byte(raw), &j); err != nil {
return nil
}
if j.DeviceID == "" || j.SessionID == "" {
return nil
}
return &ParsedUserID{
DeviceID: j.DeviceID,
AccountUUID: j.AccountUUID,
SessionID: j.SessionID,
IsNewFormat: true,
}
}
// Try legacy format
matches := legacyUserIDRegex.FindStringSubmatch(raw)
if matches == nil {
return nil
}
return &ParsedUserID{
DeviceID: matches[1],
AccountUUID: matches[2],
SessionID: matches[3],
IsNewFormat: false,
}
}
// FormatMetadataUserID builds a metadata.user_id string in the format
// appropriate for the given CLI version. Components are the rewritten values
// (not necessarily the originals).
func FormatMetadataUserID(deviceID, accountUUID, sessionID, uaVersion string) string {
if IsNewMetadataFormatVersion(uaVersion) {
b, _ := json.Marshal(jsonUserID{
DeviceID: deviceID,
AccountUUID: accountUUID,
SessionID: sessionID,
})
return string(b)
}
// Legacy format
return "user_" + deviceID + "_account_" + accountUUID + "_session_" + sessionID
}
// IsNewMetadataFormatVersion returns true if the given CLI version uses the
// new JSON metadata.user_id format (>= 2.1.78).
func IsNewMetadataFormatVersion(version string) bool {
if version == "" {
return false
}
return CompareVersions(version, NewMetadataFormatMinVersion) >= 0
}
// ExtractCLIVersion extracts the Claude Code version from a User-Agent string.
// Returns "" if the UA doesn't match the expected pattern.
func ExtractCLIVersion(ua string) string {
matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua)
if len(matches) >= 2 {
return matches[1]
}
return ""
}

View File

@@ -0,0 +1,183 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// ============ ParseMetadataUserID Tests ============
func TestParseMetadataUserID_LegacyFormat_WithoutAccountUUID(t *testing.T) {
raw := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000"
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", parsed.SessionID)
require.False(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_LegacyFormat_WithAccountUUID(t *testing.T) {
raw := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account_550e8400-e29b-41d4-a716-446655440000_session_123e4567-e89b-12d3-a456-426614174000"
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", parsed.DeviceID)
require.Equal(t, "550e8400-e29b-41d4-a716-446655440000", parsed.AccountUUID)
require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", parsed.SessionID)
require.False(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_JSONFormat_WithoutAccountUUID(t *testing.T) {
raw := `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677", parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", parsed.SessionID)
require.True(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_JSONFormat_WithAccountUUID(t *testing.T) {
raw := `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"550e8400-e29b-41d4-a716-446655440000","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677", parsed.DeviceID)
require.Equal(t, "550e8400-e29b-41d4-a716-446655440000", parsed.AccountUUID)
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", parsed.SessionID)
require.True(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_InvalidInputs(t *testing.T) {
tests := []struct {
name string
raw string
}{
{"empty string", ""},
{"whitespace only", " "},
{"random text", "not-a-valid-user-id"},
{"partial legacy format", "session_123e4567-e89b-12d3-a456-426614174000"},
{"invalid JSON", `{"device_id":}`},
{"JSON missing device_id", `{"account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`},
{"JSON missing session_id", `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":""}`},
{"JSON empty device_id", `{"device_id":"","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`},
{"JSON empty session_id", `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"","session_id":""}`},
{"legacy format short hex", "user_a1b2c3d4_account__session_123e4567-e89b-12d3-a456-426614174000"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Nil(t, ParseMetadataUserID(tt.raw), "should return nil for: %s", tt.raw)
})
}
}
func TestParseMetadataUserID_HexCaseInsensitive(t *testing.T) {
// Legacy format should accept both upper and lower case hex
rawUpper := "user_A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2_account__session_123e4567-e89b-12d3-a456-426614174000"
parsed := ParseMetadataUserID(rawUpper)
require.NotNil(t, parsed, "legacy format should accept uppercase hex")
require.Equal(t, "A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2", parsed.DeviceID)
}
// ============ FormatMetadataUserID Tests ============
func TestFormatMetadataUserID_LegacyVersion(t *testing.T) {
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "acc-uuid", "sess-uuid", "2.1.77")
require.Equal(t, "user_deadbeef00112233445566778899aabbccddeeff0011223344556677_account_acc-uuid_session_sess-uuid", result)
}
func TestFormatMetadataUserID_NewVersion(t *testing.T) {
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "acc-uuid", "sess-uuid", "2.1.78")
require.Equal(t, `{"device_id":"deadbeef00112233445566778899aabbccddeeff0011223344556677","account_uuid":"acc-uuid","session_id":"sess-uuid"}`, result)
}
func TestFormatMetadataUserID_EmptyVersion_Legacy(t *testing.T) {
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "")
require.Equal(t, "user_deadbeef00112233445566778899aabbccddeeff0011223344556677_account__session_sess-uuid", result)
}
func TestFormatMetadataUserID_EmptyAccountUUID(t *testing.T) {
// Legacy format with empty account UUID → double underscore
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "2.1.22")
require.Contains(t, result, "_account__session_")
// New format with empty account UUID → empty string in JSON
result = FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "2.1.78")
require.Contains(t, result, `"account_uuid":""`)
}
// ============ IsNewMetadataFormatVersion Tests ============
func TestIsNewMetadataFormatVersion(t *testing.T) {
tests := []struct {
version string
want bool
}{
{"", false},
{"2.1.77", false},
{"2.1.78", true},
{"2.1.79", true},
{"2.2.0", true},
{"3.0.0", true},
{"2.0.100", false},
{"1.9.99", false},
}
for _, tt := range tests {
t.Run(tt.version, func(t *testing.T) {
require.Equal(t, tt.want, IsNewMetadataFormatVersion(tt.version))
})
}
}
// ============ Round-trip Tests ============
func TestParseFormat_RoundTrip_Legacy(t *testing.T) {
deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
accountUUID := "550e8400-e29b-41d4-a716-446655440000"
sessionID := "123e4567-e89b-12d3-a456-426614174000"
formatted := FormatMetadataUserID(deviceID, accountUUID, sessionID, "2.1.22")
parsed := ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, accountUUID, parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
require.False(t, parsed.IsNewFormat)
}
func TestParseFormat_RoundTrip_JSON(t *testing.T) {
deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
accountUUID := "550e8400-e29b-41d4-a716-446655440000"
sessionID := "123e4567-e89b-12d3-a456-426614174000"
formatted := FormatMetadataUserID(deviceID, accountUUID, sessionID, "2.1.78")
parsed := ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, accountUUID, parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
require.True(t, parsed.IsNewFormat)
}
func TestParseFormat_RoundTrip_EmptyAccountUUID(t *testing.T) {
deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
sessionID := "123e4567-e89b-12d3-a456-426614174000"
// Legacy round-trip with empty account UUID
formatted := FormatMetadataUserID(deviceID, "", sessionID, "2.1.22")
parsed := ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
// JSON round-trip with empty account UUID
formatted = FormatMetadataUserID(deviceID, "", sessionID, "2.1.78")
parsed = ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
}

View File

@@ -277,12 +277,13 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
c.JSON(http.StatusOK, chatResp)
return &OpenAIForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
Stream: false,
Duration: time.Since(startTime),
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
UpstreamModel: mappedModel,
Stream: false,
Duration: time.Since(startTime),
}, nil
}
@@ -324,13 +325,14 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
resultWithUsage := func() *OpenAIForwardResult {
return &OpenAIForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
UpstreamModel: mappedModel,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}
}

View File

@@ -299,12 +299,13 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
c.JSON(http.StatusOK, anthropicResp)
return &OpenAIForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
Stream: false,
Duration: time.Since(startTime),
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
UpstreamModel: mappedModel,
Stream: false,
Duration: time.Since(startTime),
}, nil
}
@@ -347,13 +348,14 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
// resultWithUsage builds the final result snapshot.
resultWithUsage := func() *OpenAIForwardResult {
return &OpenAIForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
UpstreamModel: mappedModel,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}
}

View File

@@ -846,7 +846,7 @@ func TestExtractOpenAIServiceTierFromBody(t *testing.T) {
require.Nil(t, extractOpenAIServiceTierFromBody(nil))
}
func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *testing.T) {
func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetadataFields(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
@@ -859,6 +859,7 @@ func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *te
RequestID: "resp_billing_model_override",
BillingModel: "gpt-5.1-codex",
Model: "gpt-5.1",
UpstreamModel: "gpt-5.1-codex",
ServiceTier: &serviceTier,
ReasoningEffort: &reasoning,
Usage: OpenAIUsage{
@@ -877,7 +878,9 @@ func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *te
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "gpt-5.1-codex", usageRepo.lastLog.Model)
require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model)
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
require.Equal(t, "gpt-5.1-codex", *usageRepo.lastLog.UpstreamModel)
require.NotNil(t, usageRepo.lastLog.ServiceTier)
require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier)
require.NotNil(t, usageRepo.lastLog.ReasoningEffort)

View File

@@ -216,6 +216,9 @@ type OpenAIForwardResult struct {
// This is set by the Anthropic Messages conversion path where
// the mapped upstream model differs from the client-facing model.
BillingModel string
// UpstreamModel is the actual model sent to the upstream provider after mapping.
// Empty when no mapping was applied (requested model was used as-is).
UpstreamModel string
// ServiceTier records the OpenAI Responses API service tier, e.g. "priority" / "flex".
// Nil means the request did not specify a recognized tier.
ServiceTier *string
@@ -2128,6 +2131,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
firstTokenMs,
wsAttempts,
)
wsResult.UpstreamModel = mappedModel
return wsResult, nil
}
s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr)
@@ -2263,6 +2267,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel,
UpstreamModel: mappedModel,
ServiceTier: serviceTier,
ReasoningEffort: reasoningEffort,
Stream: reqStream,
@@ -4134,7 +4139,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: billingModel,
Model: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ServiceTier: result.ServiceTier,
ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
@@ -4700,11 +4706,3 @@ func normalizeOpenAIReasoningEffort(raw string) string {
return ""
}
}
func optionalTrimmedStringPtr(raw string) *string {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return nil
}
return &trimmed
}

View File

@@ -0,0 +1,298 @@
//go:build unit
package service
import (
"context"
"encoding/json"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// errSettingRepo: a SettingRepository that always returns errors on read
// ---------------------------------------------------------------------------
type errSettingRepo struct {
mockSettingRepo // embed the existing mock from backup_service_test.go
readErr error
}
func (r *errSettingRepo) GetValue(_ context.Context, _ string) (string, error) {
return "", r.readErr
}
func (r *errSettingRepo) Get(_ context.Context, _ string) (*Setting, error) {
return nil, r.readErr
}
// ---------------------------------------------------------------------------
// overloadAccountRepoStub: records SetOverloaded calls
// ---------------------------------------------------------------------------
type overloadAccountRepoStub struct {
mockAccountRepoForGemini
overloadCalls int
lastOverloadID int64
lastOverloadEnd time.Time
}
func (r *overloadAccountRepoStub) SetOverloaded(_ context.Context, id int64, until time.Time) error {
r.overloadCalls++
r.lastOverloadID = id
r.lastOverloadEnd = until
return nil
}
// ===========================================================================
// SettingService: GetOverloadCooldownSettings
// ===========================================================================
func TestGetOverloadCooldownSettings_DefaultsWhenNotSet(t *testing.T) {
repo := newMockSettingRepo()
svc := NewSettingService(repo, &config.Config{})
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.True(t, settings.Enabled)
require.Equal(t, 10, settings.CooldownMinutes)
}
func TestGetOverloadCooldownSettings_ReadsFromDB(t *testing.T) {
repo := newMockSettingRepo()
data, _ := json.Marshal(OverloadCooldownSettings{Enabled: false, CooldownMinutes: 30})
repo.data[SettingKeyOverloadCooldownSettings] = string(data)
svc := NewSettingService(repo, &config.Config{})
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.False(t, settings.Enabled)
require.Equal(t, 30, settings.CooldownMinutes)
}
func TestGetOverloadCooldownSettings_ClampsMinValue(t *testing.T) {
repo := newMockSettingRepo()
data, _ := json.Marshal(OverloadCooldownSettings{Enabled: true, CooldownMinutes: 0})
repo.data[SettingKeyOverloadCooldownSettings] = string(data)
svc := NewSettingService(repo, &config.Config{})
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.Equal(t, 1, settings.CooldownMinutes)
}
func TestGetOverloadCooldownSettings_ClampsMaxValue(t *testing.T) {
repo := newMockSettingRepo()
data, _ := json.Marshal(OverloadCooldownSettings{Enabled: true, CooldownMinutes: 999})
repo.data[SettingKeyOverloadCooldownSettings] = string(data)
svc := NewSettingService(repo, &config.Config{})
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.Equal(t, 120, settings.CooldownMinutes)
}
func TestGetOverloadCooldownSettings_InvalidJSON_ReturnsDefaults(t *testing.T) {
repo := newMockSettingRepo()
repo.data[SettingKeyOverloadCooldownSettings] = "not-json"
svc := NewSettingService(repo, &config.Config{})
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.True(t, settings.Enabled)
require.Equal(t, 10, settings.CooldownMinutes)
}
func TestGetOverloadCooldownSettings_EmptyValue_ReturnsDefaults(t *testing.T) {
repo := newMockSettingRepo()
repo.data[SettingKeyOverloadCooldownSettings] = ""
svc := NewSettingService(repo, &config.Config{})
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.True(t, settings.Enabled)
require.Equal(t, 10, settings.CooldownMinutes)
}
// ===========================================================================
// SettingService: SetOverloadCooldownSettings
// ===========================================================================
func TestSetOverloadCooldownSettings_Success(t *testing.T) {
repo := newMockSettingRepo()
svc := NewSettingService(repo, &config.Config{})
err := svc.SetOverloadCooldownSettings(context.Background(), &OverloadCooldownSettings{
Enabled: false,
CooldownMinutes: 25,
})
require.NoError(t, err)
// Verify round-trip
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.False(t, settings.Enabled)
require.Equal(t, 25, settings.CooldownMinutes)
}
func TestSetOverloadCooldownSettings_RejectsNil(t *testing.T) {
svc := NewSettingService(newMockSettingRepo(), &config.Config{})
err := svc.SetOverloadCooldownSettings(context.Background(), nil)
require.Error(t, err)
}
func TestSetOverloadCooldownSettings_EnabledRejectsOutOfRange(t *testing.T) {
svc := NewSettingService(newMockSettingRepo(), &config.Config{})
for _, minutes := range []int{0, -1, 121, 999} {
err := svc.SetOverloadCooldownSettings(context.Background(), &OverloadCooldownSettings{
Enabled: true, CooldownMinutes: minutes,
})
require.Error(t, err, "should reject enabled=true + cooldown_minutes=%d", minutes)
require.Contains(t, err.Error(), "cooldown_minutes must be between 1-120")
}
}
func TestSetOverloadCooldownSettings_DisabledNormalizesOutOfRange(t *testing.T) {
repo := newMockSettingRepo()
svc := NewSettingService(repo, &config.Config{})
// enabled=false + cooldown_minutes=0 应该保存成功值被归一化为10
err := svc.SetOverloadCooldownSettings(context.Background(), &OverloadCooldownSettings{
Enabled: false, CooldownMinutes: 0,
})
require.NoError(t, err, "disabled with invalid minutes should NOT be rejected")
// 验证持久化后读回来的值
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.False(t, settings.Enabled)
require.Equal(t, 10, settings.CooldownMinutes, "should be normalized to default")
}
func TestSetOverloadCooldownSettings_AcceptsBoundaries(t *testing.T) {
svc := NewSettingService(newMockSettingRepo(), &config.Config{})
for _, minutes := range []int{1, 60, 120} {
err := svc.SetOverloadCooldownSettings(context.Background(), &OverloadCooldownSettings{
Enabled: true, CooldownMinutes: minutes,
})
require.NoError(t, err, "should accept cooldown_minutes=%d", minutes)
}
}
// ===========================================================================
// RateLimitService: handle529 behaviour
// ===========================================================================
func TestHandle529_EnabledFromDB_PausesAccount(t *testing.T) {
accountRepo := &overloadAccountRepoStub{}
settingRepo := newMockSettingRepo()
data, _ := json.Marshal(OverloadCooldownSettings{Enabled: true, CooldownMinutes: 15})
settingRepo.data[SettingKeyOverloadCooldownSettings] = string(data)
settingSvc := NewSettingService(settingRepo, &config.Config{})
svc := NewRateLimitService(accountRepo, nil, &config.Config{}, nil, nil)
svc.SetSettingService(settingSvc)
account := &Account{ID: 42, Platform: PlatformAnthropic, Type: AccountTypeOAuth}
before := time.Now()
svc.handle529(context.Background(), account)
require.Equal(t, 1, accountRepo.overloadCalls)
require.Equal(t, int64(42), accountRepo.lastOverloadID)
require.WithinDuration(t, before.Add(15*time.Minute), accountRepo.lastOverloadEnd, 2*time.Second)
}
func TestHandle529_DisabledFromDB_SkipsAccount(t *testing.T) {
accountRepo := &overloadAccountRepoStub{}
settingRepo := newMockSettingRepo()
data, _ := json.Marshal(OverloadCooldownSettings{Enabled: false, CooldownMinutes: 15})
settingRepo.data[SettingKeyOverloadCooldownSettings] = string(data)
settingSvc := NewSettingService(settingRepo, &config.Config{})
svc := NewRateLimitService(accountRepo, nil, &config.Config{}, nil, nil)
svc.SetSettingService(settingSvc)
account := &Account{ID: 42, Platform: PlatformAnthropic, Type: AccountTypeOAuth}
svc.handle529(context.Background(), account)
require.Equal(t, 0, accountRepo.overloadCalls, "should NOT pause when disabled")
}
func TestHandle529_NilSettingService_FallsBackToConfig(t *testing.T) {
accountRepo := &overloadAccountRepoStub{}
cfg := &config.Config{}
cfg.RateLimit.OverloadCooldownMinutes = 20
svc := NewRateLimitService(accountRepo, nil, cfg, nil, nil)
// NOT calling SetSettingService — remains nil
account := &Account{ID: 77, Platform: PlatformAnthropic, Type: AccountTypeOAuth}
before := time.Now()
svc.handle529(context.Background(), account)
require.Equal(t, 1, accountRepo.overloadCalls)
require.WithinDuration(t, before.Add(20*time.Minute), accountRepo.lastOverloadEnd, 2*time.Second)
}
func TestHandle529_NilSettingService_ZeroConfig_DefaultsTen(t *testing.T) {
accountRepo := &overloadAccountRepoStub{}
svc := NewRateLimitService(accountRepo, nil, &config.Config{}, nil, nil)
account := &Account{ID: 88, Platform: PlatformAnthropic, Type: AccountTypeOAuth}
before := time.Now()
svc.handle529(context.Background(), account)
require.Equal(t, 1, accountRepo.overloadCalls)
require.WithinDuration(t, before.Add(10*time.Minute), accountRepo.lastOverloadEnd, 2*time.Second)
}
func TestHandle529_DBReadError_FallsBackToConfig(t *testing.T) {
accountRepo := &overloadAccountRepoStub{}
errRepo := &errSettingRepo{readErr: context.DeadlineExceeded}
errRepo.data = make(map[string]string)
cfg := &config.Config{}
cfg.RateLimit.OverloadCooldownMinutes = 7
settingSvc := NewSettingService(errRepo, cfg)
svc := NewRateLimitService(accountRepo, nil, cfg, nil, nil)
svc.SetSettingService(settingSvc)
account := &Account{ID: 99, Platform: PlatformAnthropic, Type: AccountTypeOAuth}
before := time.Now()
svc.handle529(context.Background(), account)
require.Equal(t, 1, accountRepo.overloadCalls)
require.WithinDuration(t, before.Add(7*time.Minute), accountRepo.lastOverloadEnd, 2*time.Second)
}
// ===========================================================================
// Model: defaults & JSON round-trip
// ===========================================================================
func TestDefaultOverloadCooldownSettings(t *testing.T) {
d := DefaultOverloadCooldownSettings()
require.True(t, d.Enabled)
require.Equal(t, 10, d.CooldownMinutes)
}
func TestOverloadCooldownSettings_JSONRoundTrip(t *testing.T) {
original := OverloadCooldownSettings{Enabled: false, CooldownMinutes: 42}
data, err := json.Marshal(original)
require.NoError(t, err)
var decoded OverloadCooldownSettings
require.NoError(t, json.Unmarshal(data, &decoded))
require.Equal(t, original, decoded)
// Verify JSON uses snake_case field names
var raw map[string]any
require.NoError(t, json.Unmarshal(data, &raw))
_, hasEnabled := raw["enabled"]
_, hasCooldown := raw["cooldown_minutes"]
require.True(t, hasEnabled, "JSON must use 'enabled'")
require.True(t, hasCooldown, "JSON must use 'cooldown_minutes'")
}

View File

@@ -1023,11 +1023,34 @@ func parseOpenAIRateLimitResetTime(body []byte) *int64 {
}
// handle529 处理529过载错误
// 根据配置设置过载冷却时
// 根据配置决定是否暂停账号调度及冷却时
func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
cooldownMinutes := s.cfg.RateLimit.OverloadCooldownMinutes
var settings *OverloadCooldownSettings
if s.settingService != nil {
var err error
settings, err = s.settingService.GetOverloadCooldownSettings(ctx)
if err != nil {
slog.Warn("overload_settings_read_failed", "account_id", account.ID, "error", err)
settings = nil
}
}
// 回退到配置文件
if settings == nil {
cooldown := s.cfg.RateLimit.OverloadCooldownMinutes
if cooldown <= 0 {
cooldown = 10
}
settings = &OverloadCooldownSettings{Enabled: true, CooldownMinutes: cooldown}
}
if !settings.Enabled {
slog.Info("account_529_ignored", "account_id", account.ID, "reason", "overload_cooldown_disabled")
return
}
cooldownMinutes := settings.CooldownMinutes
if cooldownMinutes <= 0 {
cooldownMinutes = 10 // 默认10分钟
cooldownMinutes = 10
}
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)

View File

@@ -1172,6 +1172,57 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
return effective, nil
}
// GetOverloadCooldownSettings 获取529过载冷却配置
func (s *SettingService) GetOverloadCooldownSettings(ctx context.Context) (*OverloadCooldownSettings, error) {
value, err := s.settingRepo.GetValue(ctx, SettingKeyOverloadCooldownSettings)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return DefaultOverloadCooldownSettings(), nil
}
return nil, fmt.Errorf("get overload cooldown settings: %w", err)
}
if value == "" {
return DefaultOverloadCooldownSettings(), nil
}
var settings OverloadCooldownSettings
if err := json.Unmarshal([]byte(value), &settings); err != nil {
return DefaultOverloadCooldownSettings(), nil
}
// 修正配置值范围
if settings.CooldownMinutes < 1 {
settings.CooldownMinutes = 1
}
if settings.CooldownMinutes > 120 {
settings.CooldownMinutes = 120
}
return &settings, nil
}
// SetOverloadCooldownSettings 设置529过载冷却配置
func (s *SettingService) SetOverloadCooldownSettings(ctx context.Context, settings *OverloadCooldownSettings) error {
if settings == nil {
return fmt.Errorf("settings cannot be nil")
}
// 禁用时修正为合法值即可,不拒绝请求
if settings.CooldownMinutes < 1 || settings.CooldownMinutes > 120 {
if settings.Enabled {
return fmt.Errorf("cooldown_minutes must be between 1-120")
}
settings.CooldownMinutes = 10 // 禁用状态下归一化为默认值
}
data, err := json.Marshal(settings)
if err != nil {
return fmt.Errorf("marshal overload cooldown settings: %w", err)
}
return s.settingRepo.Set(ctx, SettingKeyOverloadCooldownSettings, string(data))
}
// GetStreamTimeoutSettings 获取流超时处理配置
func (s *SettingService) GetStreamTimeoutSettings(ctx context.Context) (*StreamTimeoutSettings, error) {
value, err := s.settingRepo.GetValue(ctx, SettingKeyStreamTimeoutSettings)

View File

@@ -222,6 +222,22 @@ type BetaPolicySettings struct {
Rules []BetaPolicyRule `json:"rules"`
}
// OverloadCooldownSettings 529过载冷却配置
type OverloadCooldownSettings struct {
// Enabled 是否在收到529时暂停账号调度
Enabled bool `json:"enabled"`
// CooldownMinutes 冷却时长(分钟)
CooldownMinutes int `json:"cooldown_minutes"`
}
// DefaultOverloadCooldownSettings 返回默认的过载冷却配置启用10分钟
func DefaultOverloadCooldownSettings() *OverloadCooldownSettings {
return &OverloadCooldownSettings{
Enabled: true,
CooldownMinutes: 10,
}
}
// DefaultBetaPolicySettings 返回默认的 Beta 策略配置
func DefaultBetaPolicySettings() *BetaPolicySettings {
return &BetaPolicySettings{

View File

@@ -52,8 +52,8 @@ func (r *stubGroupRepoForQuota) ListActiveByPlatform(context.Context, string) ([
func (r *stubGroupRepoForQuota) ExistsByName(context.Context, string) (bool, error) {
return false, nil
}
func (r *stubGroupRepoForQuota) GetAccountCount(context.Context, int64) (int64, error) {
return 0, nil
func (r *stubGroupRepoForQuota) GetAccountCount(context.Context, int64) (int64, int64, error) {
return 0, 0, nil
}
func (r *stubGroupRepoForQuota) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
return 0, nil

View File

@@ -40,7 +40,7 @@ func (groupRepoNoop) ListActiveByPlatform(context.Context, string) ([]Group, err
func (groupRepoNoop) ExistsByName(context.Context, string) (bool, error) {
panic("unexpected ExistsByName call")
}
func (groupRepoNoop) GetAccountCount(context.Context, int64) (int64, error) {
func (groupRepoNoop) GetAccountCount(context.Context, int64) (int64, int64, error) {
panic("unexpected GetAccountCount call")
}
func (groupRepoNoop) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
@@ -92,7 +92,7 @@ func (userSubRepoNoop) ListActiveByUserID(context.Context, int64) ([]UserSubscri
func (userSubRepoNoop) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call")
}
func (userSubRepoNoop) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) {
func (userSubRepoNoop) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (userSubRepoNoop) ExistsByUserIDAndGroupID(context.Context, int64, int64) (bool, error) {

View File

@@ -634,9 +634,9 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI
}
// List 获取所有订阅(分页,支持筛选和排序)
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) {
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status, sortBy, sortOrder)
subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status, platform, sortBy, sortOrder)
if err != nil {
return nil, nil, err
}

View File

@@ -98,6 +98,9 @@ type UsageLog struct {
AccountID int64
RequestID string
Model string
// UpstreamModel is the actual model sent to the upstream provider after mapping.
// Nil means no mapping was applied (requested model was used as-is).
UpstreamModel *string
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier *string
// ReasoningEffort is the request's reasoning effort level.

View File

@@ -0,0 +1,21 @@
package service
import "strings"
func optionalTrimmedStringPtr(raw string) *string {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return nil
}
return &trimmed
}
// optionalNonEqualStringPtr returns a pointer to value if it is non-empty and
// differs from compare; otherwise nil. Used to store upstream_model only when
// it differs from the requested model.
func optionalNonEqualStringPtr(value, compare string) *string {
if value == "" || value == compare {
return nil
}
return &value
}

View File

@@ -18,7 +18,7 @@ type UserSubscriptionRepository interface {
ListByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
ListActiveByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error)
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error)
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error)
ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error)
ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error

View File

@@ -486,4 +486,5 @@ var ProviderSet = wire.NewSet(
ProvideIdempotencyCleanupService,
ProvideScheduledTestService,
ProvideScheduledTestRunnerService,
NewGroupCapacityService,
)

View File

@@ -247,6 +247,12 @@ func install(c *gin.Context) {
return
}
req.Admin.Email = strings.TrimSpace(req.Admin.Email)
req.Database.Host = strings.TrimSpace(req.Database.Host)
req.Database.User = strings.TrimSpace(req.Database.User)
req.Database.DBName = strings.TrimSpace(req.Database.DBName)
req.Redis.Host = strings.TrimSpace(req.Redis.Host)
// ========== COMPREHENSIVE INPUT VALIDATION ==========
// Database validation
if !validateHostname(req.Database.Host) {
@@ -319,13 +325,6 @@ func install(c *gin.Context) {
return
}
// Trim whitespace from string inputs
req.Admin.Email = strings.TrimSpace(req.Admin.Email)
req.Database.Host = strings.TrimSpace(req.Database.Host)
req.Database.User = strings.TrimSpace(req.Database.User)
req.Database.DBName = strings.TrimSpace(req.Database.DBName)
req.Redis.Host = strings.TrimSpace(req.Redis.Host)
cfg := &SetupConfig{
Database: req.Database,
Redis: req.Redis,

View File

@@ -180,7 +180,37 @@ func (s *FrontendServer) injectSettings(settingsJSON []byte) []byte {
// Inject before </head>
headClose := []byte("</head>")
return bytes.Replace(s.baseHTML, headClose, append(script, headClose...), 1)
result := bytes.Replace(s.baseHTML, headClose, append(script, headClose...), 1)
// Replace <title> with custom site name so the browser tab shows it immediately
result = injectSiteTitle(result, settingsJSON)
return result
}
// injectSiteTitle replaces the static <title> in HTML with the configured site name.
// This ensures the browser tab shows the correct title before JS executes.
func injectSiteTitle(html, settingsJSON []byte) []byte {
var cfg struct {
SiteName string `json:"site_name"`
}
if err := json.Unmarshal(settingsJSON, &cfg); err != nil || cfg.SiteName == "" {
return html
}
// Find and replace the existing <title>...</title>
titleStart := bytes.Index(html, []byte("<title>"))
titleEnd := bytes.Index(html, []byte("</title>"))
if titleStart == -1 || titleEnd == -1 || titleEnd <= titleStart {
return html
}
newTitle := []byte("<title>" + cfg.SiteName + " - AI API Gateway</title>")
var buf bytes.Buffer
buf.Write(html[:titleStart])
buf.Write(newTitle)
buf.Write(html[titleEnd+len("</title>"):])
return buf.Bytes()
}
// replaceNoncePlaceholder replaces the nonce placeholder with actual nonce value

View File

@@ -20,6 +20,78 @@ func init() {
gin.SetMode(gin.TestMode)
}
func TestInjectSiteTitle(t *testing.T) {
t.Run("replaces_title_with_site_name", func(t *testing.T) {
html := []byte(`<html><head><title>Sub2API - AI API Gateway</title></head><body></body></html>`)
settingsJSON := []byte(`{"site_name":"MyCustomSite"}`)
result := injectSiteTitle(html, settingsJSON)
assert.Contains(t, string(result), "<title>MyCustomSite - AI API Gateway</title>")
assert.NotContains(t, string(result), "Sub2API")
})
t.Run("returns_unchanged_when_site_name_empty", func(t *testing.T) {
html := []byte(`<html><head><title>Sub2API - AI API Gateway</title></head><body></body></html>`)
settingsJSON := []byte(`{"site_name":""}`)
result := injectSiteTitle(html, settingsJSON)
assert.Equal(t, string(html), string(result))
})
t.Run("returns_unchanged_when_site_name_missing", func(t *testing.T) {
html := []byte(`<html><head><title>Sub2API - AI API Gateway</title></head><body></body></html>`)
settingsJSON := []byte(`{"other_field":"value"}`)
result := injectSiteTitle(html, settingsJSON)
assert.Equal(t, string(html), string(result))
})
t.Run("returns_unchanged_when_invalid_json", func(t *testing.T) {
html := []byte(`<html><head><title>Sub2API - AI API Gateway</title></head><body></body></html>`)
settingsJSON := []byte(`{invalid json}`)
result := injectSiteTitle(html, settingsJSON)
assert.Equal(t, string(html), string(result))
})
t.Run("returns_unchanged_when_no_title_tag", func(t *testing.T) {
html := []byte(`<html><head></head><body></body></html>`)
settingsJSON := []byte(`{"site_name":"MyCustomSite"}`)
result := injectSiteTitle(html, settingsJSON)
assert.Equal(t, string(html), string(result))
})
t.Run("returns_unchanged_when_title_has_attributes", func(t *testing.T) {
// The function looks for "<title>" literally, so attributes are not supported
// This is acceptable since index.html uses plain <title> without attributes
html := []byte(`<html><head><title lang="en">Sub2API</title></head><body></body></html>`)
settingsJSON := []byte(`{"site_name":"NewSite"}`)
result := injectSiteTitle(html, settingsJSON)
// Should return unchanged since <title> with attributes is not matched
assert.Equal(t, string(html), string(result))
})
t.Run("preserves_rest_of_html", func(t *testing.T) {
html := []byte(`<html><head><meta charset="UTF-8"><title>Sub2API</title><script src="app.js"></script></head><body><div id="app"></div></body></html>`)
settingsJSON := []byte(`{"site_name":"TestSite"}`)
result := injectSiteTitle(html, settingsJSON)
assert.Contains(t, string(result), `<meta charset="UTF-8">`)
assert.Contains(t, string(result), `<script src="app.js"></script>`)
assert.Contains(t, string(result), `<div id="app"></div>`)
assert.Contains(t, string(result), "<title>TestSite - AI API Gateway</title>")
})
}
func TestReplaceNoncePlaceholder(t *testing.T) {
t.Run("replaces_single_placeholder", func(t *testing.T) {
html := []byte(`<script nonce="__CSP_NONCE_VALUE__">console.log('test');</script>`)

View File

@@ -0,0 +1,4 @@
-- Add upstream_model field to usage_logs.
-- Stores the actual upstream model name when it differs from the requested model
-- (i.e., when model mapping is applied). NULL means no mapping was applied.
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS upstream_model VARCHAR(100);

View File

@@ -0,0 +1,17 @@
-- Map claude-haiku-4-5 variants target from claude-sonnet-4-5 to claude-sonnet-4-6
--
-- Only updates when the current target is exactly claude-sonnet-4-5.
-- 1. claude-haiku-4-5
UPDATE accounts
SET credentials = jsonb_set(credentials, '{model_mapping,claude-haiku-4-5}', '"claude-sonnet-4-6"')
WHERE platform = 'antigravity'
AND deleted_at IS NULL
AND credentials->'model_mapping'->>'claude-haiku-4-5' = 'claude-sonnet-4-5';
-- 2. claude-haiku-4-5-20251001
UPDATE accounts
SET credentials = jsonb_set(credentials, '{model_mapping,claude-haiku-4-5-20251001}', '"claude-sonnet-4-6"')
WHERE platform = 'antigravity'
AND deleted_at IS NULL
AND credentials->'model_mapping'->>'claude-haiku-4-5-20251001' = 'claude-sonnet-4-5';

View File

@@ -0,0 +1,3 @@
-- Support upstream_model / mapping model distribution aggregations with time-range filters.
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_logs_created_model_upstream_model
ON usage_logs (created_at, model, upstream_model);

View File

@@ -34,18 +34,18 @@ Example: `017_add_gemini_tier_id.sql`
## Migration File Structure
```sql
-- +goose Up
-- +goose StatementBegin
-- Your forward migration SQL here
-- +goose StatementEnd
This project uses a custom migration runner (`internal/repository/migrations_runner.go`) that executes the full SQL file content as-is.
-- +goose Down
-- +goose StatementBegin
-- Your rollback migration SQL here
-- +goose StatementEnd
- Regular migrations (`*.sql`): executed in a transaction.
- Non-transactional migrations (`*_notx.sql`): split by statement and executed without transaction (for `CONCURRENTLY`).
```sql
-- Forward-only migration (recommended)
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS example_column VARCHAR(100);
```
> ⚠️ Do **not** place executable "Down" SQL in the same file. The runner does not parse goose Up/Down sections and will execute all SQL statements in the file.
## Important Rules
### ⚠️ Immutability Principle
@@ -66,9 +66,9 @@ Why?
touch migrations/018_your_change.sql
```
2. **Write Up and Down migrations**
- Up: Apply the change
- Down: Revert the change (should be symmetric with Up)
2. **Write forward-only migration SQL**
- Put only the intended schema change in the file
- If rollback is needed, create a new migration file to revert
3. **Test locally**
```bash
@@ -144,8 +144,6 @@ touch migrations/018_your_new_change.sql
## Example Migration
```sql
-- +goose Up
-- +goose StatementBegin
-- Add tier_id field to Gemini OAuth accounts for quota tracking
UPDATE accounts
SET credentials = jsonb_set(
@@ -157,17 +155,6 @@ SET credentials = jsonb_set(
WHERE platform = 'gemini'
AND type = 'oauth'
AND credentials->>'tier_id' IS NULL;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
-- Remove tier_id field
UPDATE accounts
SET credentials = credentials - 'tier_id'
WHERE platform = 'gemini'
AND type = 'oauth'
AND credentials->>'tier_id' = 'LEGACY';
-- +goose StatementEnd
```
## Troubleshooting
@@ -194,5 +181,4 @@ VALUES ('NNN_migration.sql', 'calculated_checksum', NOW());
## References
- Migration runner: `internal/repository/migrations_runner.go`
- Goose syntax: https://github.com/pressly/goose
- PostgreSQL docs: https://www.postgresql.org/docs/

View File

@@ -38,7 +38,7 @@ services:
- ./data:/app/data
# Optional: Mount custom config.yaml (uncomment and create the file first)
# Copy config.example.yaml to config.yaml, modify it, then uncomment:
# - ./config.yaml:/app/data/config.yaml:ro
# - ./config.yaml:/app/data/config.yaml
environment:
# =======================================================================
# Auto Setup (REQUIRED for Docker deployment)

View File

@@ -30,7 +30,7 @@ services:
- sub2api_data:/app/data
# Optional: Mount custom config.yaml (uncomment and create the file first)
# Copy config.example.yaml to config.yaml, modify it, then uncomment:
# - ./config.yaml:/app/data/config.yaml:ro
# - ./config.yaml:/app/data/config.yaml
environment:
# =======================================================================
# Auto Setup (REQUIRED for Docker deployment)

View File

@@ -6,7 +6,8 @@ set -e
# preventing the non-root sub2api user from writing files.
if [ "$(id -u)" = "0" ]; then
mkdir -p /app/data
chown -R sub2api:sub2api /app/data
# Use || true to avoid failure on read-only mounted files (e.g. config.yaml:ro)
chown -R sub2api:sub2api /app/data 2>/dev/null || true
# Re-invoke this script as sub2api so the flag-detection below
# also runs under the correct user.
exec su-exec sub2api "$0" "$@"

View File

@@ -3,6 +3,7 @@ import { RouterView, useRouter, useRoute } from 'vue-router'
import { onMounted, onBeforeUnmount, watch } from 'vue'
import Toast from '@/components/common/Toast.vue'
import NavigationProgress from '@/components/common/NavigationProgress.vue'
import { resolveDocumentTitle } from '@/router/title'
import AnnouncementPopup from '@/components/common/AnnouncementPopup.vue'
import { useAppStore, useAuthStore, useSubscriptionStore, useAnnouncementStore } from '@/stores'
import { getSetupStatus } from '@/api/setup'
@@ -104,6 +105,9 @@ onMounted(async () => {
// Load public settings into appStore (will be cached for other components)
await appStore.fetchPublicSettings()
// Re-resolve document title now that siteName is available
document.title = resolveDocumentTitle(route.meta.title, appStore.siteName, route.meta.titleKey as string)
})
</script>

View File

@@ -81,6 +81,7 @@ export interface ModelStatsParams {
user_id?: number
api_key_id?: number
model?: string
model_source?: 'requested' | 'upstream' | 'mapping'
account_id?: number
group_id?: number
request_type?: UsageRequestType
@@ -162,6 +163,7 @@ export interface UserBreakdownParams {
end_date?: string
group_id?: number
model?: string
model_source?: 'requested' | 'upstream' | 'mapping'
endpoint?: string
endpoint_type?: 'inbound' | 'upstream' | 'path'
limit?: number

View File

@@ -218,6 +218,34 @@ export async function batchSetGroupRateMultipliers(
return data
}
/**
* Get usage summary (today + cumulative cost) for all groups
* @param timezone - IANA timezone string (e.g. "Asia/Shanghai")
* @returns Array of group usage summaries
*/
export async function getUsageSummary(
timezone?: string
): Promise<{ group_id: number; today_cost: number; total_cost: number }[]> {
const { data } = await apiClient.get<
{ group_id: number; today_cost: number; total_cost: number }[]
>('/admin/groups/usage-summary', {
params: timezone ? { timezone } : undefined
})
return data
}
/**
* Get capacity summary (concurrency/sessions/RPM) for all active groups
*/
export async function getCapacitySummary(): Promise<
{ group_id: number; concurrency_used: number; concurrency_max: number; sessions_used: number; sessions_max: number; rpm_used: number; rpm_max: number }[]
> {
const { data } = await apiClient.get<
{ group_id: number; concurrency_used: number; concurrency_max: number; sessions_used: number; sessions_max: number; rpm_used: number; rpm_max: number }[]
>('/admin/groups/capacity-summary')
return data
}
export const groupsAPI = {
list,
getAll,
@@ -232,7 +260,9 @@ export const groupsAPI = {
getGroupRateMultipliers,
clearGroupRateMultipliers,
batchSetGroupRateMultipliers,
updateSortOrder
updateSortOrder,
getUsageSummary,
getCapacitySummary
}
export default groupsAPI

View File

@@ -242,6 +242,33 @@ export async function deleteAdminApiKey(): Promise<{ message: string }> {
return data
}
// ==================== Overload Cooldown Settings ====================
/**
* Overload cooldown settings interface (529 handling)
*/
export interface OverloadCooldownSettings {
enabled: boolean
cooldown_minutes: number
}
export async function getOverloadCooldownSettings(): Promise<OverloadCooldownSettings> {
const { data } = await apiClient.get<OverloadCooldownSettings>('/admin/settings/overload-cooldown')
return data
}
export async function updateOverloadCooldownSettings(
settings: OverloadCooldownSettings
): Promise<OverloadCooldownSettings> {
const { data } = await apiClient.put<OverloadCooldownSettings>(
'/admin/settings/overload-cooldown',
settings
)
return data
}
// ==================== Stream Timeout Settings ====================
/**
* Stream timeout settings interface
*/
@@ -499,6 +526,8 @@ export const settingsAPI = {
getAdminApiKey,
regenerateAdminApiKey,
deleteAdminApiKey,
getOverloadCooldownSettings,
updateOverloadCooldownSettings,
getStreamTimeoutSettings,
updateStreamTimeoutSettings,
getRectifierSettings,

View File

@@ -27,6 +27,7 @@ export async function list(
status?: 'active' | 'expired' | 'revoked'
user_id?: number
group_id?: number
platform?: string
sort_by?: string
sort_order?: 'asc' | 'desc'
},

View File

@@ -82,6 +82,7 @@
:utilization="usageInfo.five_hour.utilization"
:resets-at="usageInfo.five_hour.resets_at"
:window-stats="usageInfo.five_hour.window_stats"
:show-now-when-idle="true"
color="indigo"
/>
<UsageProgressBar
@@ -90,6 +91,7 @@
:utilization="usageInfo.seven_day.utilization"
:resets-at="usageInfo.seven_day.resets_at"
:window-stats="usageInfo.seven_day.window_stats"
:show-now-when-idle="true"
color="emerald"
/>
</div>

View File

@@ -48,7 +48,7 @@
</span>
<!-- Reset time -->
<span v-if="resetsAt" class="shrink-0 text-[10px] text-gray-400">
<span v-if="shouldShowResetTime" class="shrink-0 text-[10px] text-gray-400">
{{ formatResetTime }}
</span>
</div>
@@ -68,6 +68,7 @@ const props = defineProps<{
resetsAt?: string | null
color: 'indigo' | 'emerald' | 'purple' | 'amber'
windowStats?: WindowStats | null
showNowWhenIdle?: boolean
}>()
const { t } = useI18n()
@@ -139,9 +140,20 @@ const displayPercent = computed(() => {
return percent > 999 ? '>999%' : `${percent}%`
})
const shouldShowResetTime = computed(() => {
if (props.resetsAt) return true
return Boolean(props.showNowWhenIdle && props.utilization <= 0)
})
// Format reset time
const formatResetTime = computed(() => {
// For rolling windows, when utilization is 0%, treat as immediately available.
if (props.showNowWhenIdle && props.utilization <= 0) {
return '现在'
}
if (!props.resetsAt) return '-'
const date = new Date(props.resetsAt)
const diffMs = date.getTime() - now.value.getTime()

View File

@@ -0,0 +1,69 @@
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { mount } from '@vue/test-utils'
import UsageProgressBar from '../UsageProgressBar.vue'
vi.mock('vue-i18n', async () => {
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
return {
...actual,
useI18n: () => ({
t: (key: string) => key
})
}
})
describe('UsageProgressBar', () => {
beforeEach(() => {
vi.useFakeTimers()
vi.setSystemTime(new Date('2026-03-17T00:00:00Z'))
})
afterEach(() => {
vi.useRealTimers()
})
it('showNowWhenIdle=true 且利用率为 0 时显示“现在”', () => {
const wrapper = mount(UsageProgressBar, {
props: {
label: '5h',
utilization: 0,
resetsAt: '2026-03-17T02:30:00Z',
showNowWhenIdle: true,
color: 'indigo'
}
})
expect(wrapper.text()).toContain('现在')
expect(wrapper.text()).not.toContain('2h 30m')
})
it('showNowWhenIdle=true 但利用率大于 0 时显示倒计时', () => {
const wrapper = mount(UsageProgressBar, {
props: {
label: '7d',
utilization: 12,
resetsAt: '2026-03-17T02:30:00Z',
showNowWhenIdle: true,
color: 'emerald'
}
})
expect(wrapper.text()).toContain('2h 30m')
expect(wrapper.text()).not.toContain('现在')
})
it('showNowWhenIdle=false 时保持原有倒计时行为', () => {
const wrapper = mount(UsageProgressBar, {
props: {
label: '1d',
utilization: 0,
resetsAt: '2026-03-17T02:30:00Z',
showNowWhenIdle: false,
color: 'indigo'
}
})
expect(wrapper.text()).toContain('2h 30m')
expect(wrapper.text()).not.toContain('现在')
})
})

View File

@@ -25,8 +25,16 @@
<span class="text-sm text-gray-900 dark:text-white">{{ row.account?.name || '-' }}</span>
</template>
<template #cell-model="{ value }">
<span class="font-medium text-gray-900 dark:text-white">{{ value }}</span>
<template #cell-model="{ row }">
<div v-if="row.upstream_model && row.upstream_model !== row.model" class="space-y-0.5 text-xs">
<div class="break-all font-medium text-gray-900 dark:text-white">
{{ row.model }}
</div>
<div class="break-all text-gray-500 dark:text-gray-400">
<span class="mr-0.5"></span>{{ row.upstream_model }}
</div>
</div>
<span v-else class="font-medium text-gray-900 dark:text-white">{{ row.model }}</span>
</template>
<template #cell-reasoning_effort="{ row }">

View File

@@ -1,10 +1,10 @@
<template>
<div class="card p-4">
<div class="mb-4 flex items-start justify-between gap-3">
<div class="mb-4 flex items-center justify-between gap-3">
<h3 class="text-sm font-semibold text-gray-900 dark:text-white">
{{ title || t('usage.endpointDistribution') }}
</h3>
<div class="flex flex-col items-end gap-2">
<div class="flex flex-wrap items-center justify-end gap-2">
<div
v-if="showSourceToggle"
class="inline-flex rounded-lg border border-gray-200 bg-gray-50 p-0.5 dark:border-gray-700 dark:bg-dark-800"

View File

@@ -6,7 +6,42 @@
? t('admin.dashboard.modelDistribution')
: t('admin.dashboard.spendingRankingTitle') }}
</h3>
<div class="flex items-center gap-2">
<div class="flex flex-wrap items-center justify-end gap-2">
<div
v-if="showSourceToggle"
class="inline-flex rounded-lg border border-gray-200 bg-gray-50 p-0.5 dark:border-gray-700 dark:bg-dark-800"
>
<button
type="button"
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
:class="source === 'requested'
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
@click="emit('update:source', 'requested')"
>
{{ t('usage.requestedModel') }}
</button>
<button
type="button"
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
:class="source === 'upstream'
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
@click="emit('update:source', 'upstream')"
>
{{ t('usage.upstreamModel') }}
</button>
<button
type="button"
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
:class="source === 'mapping'
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
@click="emit('update:source', 'mapping')"
>
{{ t('usage.mapping') }}
</button>
</div>
<div
v-if="showMetricToggle"
class="inline-flex rounded-lg border border-gray-200 bg-gray-50 p-0.5 dark:border-gray-700 dark:bg-dark-800"
@@ -215,9 +250,13 @@ ChartJS.register(ArcElement, Tooltip, Legend)
const { t } = useI18n()
type DistributionMetric = 'tokens' | 'actual_cost'
type ModelSource = 'requested' | 'upstream' | 'mapping'
type RankingDisplayItem = UserSpendingRankingItem & { isOther?: boolean }
const props = withDefaults(defineProps<{
modelStats: ModelStat[]
upstreamModelStats?: ModelStat[]
mappingModelStats?: ModelStat[]
source?: ModelSource
enableRankingView?: boolean
rankingItems?: UserSpendingRankingItem[]
rankingTotalActualCost?: number
@@ -225,12 +264,16 @@ const props = withDefaults(defineProps<{
rankingTotalTokens?: number
loading?: boolean
metric?: DistributionMetric
showSourceToggle?: boolean
showMetricToggle?: boolean
rankingLoading?: boolean
rankingError?: boolean
startDate?: string
endDate?: string
}>(), {
upstreamModelStats: () => [],
mappingModelStats: () => [],
source: 'requested',
enableRankingView: false,
rankingItems: () => [],
rankingTotalActualCost: 0,
@@ -238,6 +281,7 @@ const props = withDefaults(defineProps<{
rankingTotalTokens: 0,
loading: false,
metric: 'tokens',
showSourceToggle: false,
showMetricToggle: false,
rankingLoading: false,
rankingError: false
@@ -261,6 +305,7 @@ const toggleBreakdown = async (type: string, id: string) => {
start_date: props.startDate,
end_date: props.endDate,
model: id,
model_source: props.source,
})
breakdownItems.value = res.users || []
} catch {
@@ -272,6 +317,7 @@ const toggleBreakdown = async (type: string, id: string) => {
const emit = defineEmits<{
'update:metric': [value: DistributionMetric]
'update:source': [value: ModelSource]
'ranking-click': [item: UserSpendingRankingItem]
}>()
@@ -294,14 +340,19 @@ const chartColors = [
]
const displayModelStats = computed(() => {
if (!props.modelStats?.length) return []
const sourceStats = props.source === 'upstream'
? props.upstreamModelStats
: props.source === 'mapping'
? props.mappingModelStats
: props.modelStats
if (!sourceStats?.length) return []
const metricKey = props.metric === 'actual_cost' ? 'actual_cost' : 'total_tokens'
return [...props.modelStats].sort((a, b) => b[metricKey] - a[metricKey])
return [...sourceStats].sort((a, b) => b[metricKey] - a[metricKey])
})
const chartData = computed(() => {
if (!props.modelStats?.length) return null
if (!displayModelStats.value.length) return null
return {
labels: displayModelStats.value.map((m) => m.model),

View File

@@ -0,0 +1,84 @@
<template>
<div class="flex flex-col gap-1">
<!-- 并发槽位 -->
<div class="flex items-center gap-1">
<span
:class="[
'inline-flex items-center gap-1 rounded-md px-1.5 py-0.5 text-[10px] font-medium',
capacityClass(concurrencyUsed, concurrencyMax)
]"
>
<svg class="h-2.5 w-2.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M3.75 6A2.25 2.25 0 016 3.75h2.25A2.25 2.25 0 0110.5 6v2.25a2.25 2.25 0 01-2.25 2.25H6a2.25 2.25 0 01-2.25-2.25V6zM3.75 15.75A2.25 2.25 0 016 13.5h2.25a2.25 2.25 0 012.25 2.25V18a2.25 2.25 0 01-2.25 2.25H6A2.25 2.25 0 013.75 18v-2.25zM13.5 6a2.25 2.25 0 012.25-2.25H18A2.25 2.25 0 0120.25 6v2.25A2.25 2.25 0 0118 10.5h-2.25a2.25 2.25 0 01-2.25-2.25V6zM13.5 15.75a2.25 2.25 0 012.25-2.25H18a2.25 2.25 0 012.25 2.25V18A2.25 2.25 0 0118 20.25h-2.25A2.25 2.25 0 0113.5 18v-2.25z" />
</svg>
<span class="font-mono">{{ concurrencyUsed }}</span>
<span class="text-gray-400 dark:text-gray-500">/</span>
<span class="font-mono">{{ concurrencyMax }}</span>
</span>
</div>
<!-- 会话数 -->
<div v-if="sessionsMax > 0" class="flex items-center gap-1">
<span
:class="[
'inline-flex items-center gap-1 rounded-md px-1.5 py-0.5 text-[10px] font-medium',
capacityClass(sessionsUsed, sessionsMax)
]"
>
<svg class="h-2.5 w-2.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M15 19.128a9.38 9.38 0 002.625.372 9.337 9.337 0 004.121-.952 4.125 4.125 0 00-7.533-2.493M15 19.128v-.003c0-1.113-.285-2.16-.786-3.07M15 19.128v.106A12.318 12.318 0 018.624 21c-2.331 0-4.512-.645-6.374-1.766l-.001-.109a6.375 6.375 0 0111.964-3.07M12 6.375a3.375 3.375 0 11-6.75 0 3.375 3.375 0 016.75 0zm8.25 2.25a2.625 2.625 0 11-5.25 0 2.625 2.625 0 015.25 0z" />
</svg>
<span class="font-mono">{{ sessionsUsed }}</span>
<span class="text-gray-400 dark:text-gray-500">/</span>
<span class="font-mono">{{ sessionsMax }}</span>
</span>
</div>
<!-- RPM -->
<div v-if="rpmMax > 0" class="flex items-center gap-1">
<span
:class="[
'inline-flex items-center gap-1 rounded-md px-1.5 py-0.5 text-[10px] font-medium',
capacityClass(rpmUsed, rpmMax)
]"
>
<svg class="h-2.5 w-2.5" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" d="M12 6v6h4.5m4.5 0a9 9 0 1 1-18 0 9 9 0 0 1 18 0Z" />
</svg>
<span class="font-mono">{{ rpmUsed }}</span>
<span class="text-gray-400 dark:text-gray-500">/</span>
<span class="font-mono">{{ rpmMax }}</span>
</span>
</div>
</div>
</template>
<script setup lang="ts">
interface Props {
concurrencyUsed: number
concurrencyMax: number
sessionsUsed: number
sessionsMax: number
rpmUsed: number
rpmMax: number
}
withDefaults(defineProps<Props>(), {
concurrencyUsed: 0,
concurrencyMax: 0,
sessionsUsed: 0,
sessionsMax: 0,
rpmUsed: 0,
rpmMax: 0
})
function capacityClass(used: number, max: number): string {
if (max > 0 && used >= max) {
return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-400'
}
if (used > 0) {
return 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900/30 dark:text-yellow-400'
}
return 'bg-gray-100 text-gray-600 dark:bg-gray-800 dark:text-gray-400'
}
</script>

View File

@@ -218,7 +218,7 @@ export default {
email: 'Email',
password: 'Password',
confirmPassword: 'Confirm Password',
passwordPlaceholder: 'Min 6 characters',
passwordPlaceholder: 'Min 8 characters',
confirmPasswordPlaceholder: 'Confirm password',
passwordMismatch: 'Passwords do not match'
},
@@ -718,11 +718,14 @@ export default {
exporting: 'Exporting...',
preparingExport: 'Preparing export...',
model: 'Model',
requestedModel: 'Requested',
upstreamModel: 'Upstream',
reasoningEffort: 'Reasoning Effort',
endpoint: 'Endpoint',
endpointDistribution: 'Endpoint Distribution',
inbound: 'Inbound',
upstream: 'Upstream',
mapping: 'Mapping',
path: 'Path',
inboundEndpoint: 'Inbound Endpoint',
upstreamEndpoint: 'Upstream Endpoint',
@@ -1505,6 +1508,8 @@ export default {
rateMultiplier: 'Rate Multiplier',
type: 'Type',
accounts: 'Accounts',
capacity: 'Capacity',
usage: 'Usage',
status: 'Status',
actions: 'Actions',
billingType: 'Billing Type',
@@ -1513,6 +1518,12 @@ export default {
userNotes: 'Notes',
userStatus: 'Status'
},
usageToday: 'Today',
usageTotal: 'Total',
accountsAvailable: 'Avail:',
accountsRateLimited: 'Limited:',
accountsTotal: 'Total:',
accountsUnit: '',
rateAndAccounts: '{rate}x rate · {count} accounts',
accountsCount: '{count} accounts',
form: {
@@ -1694,6 +1705,7 @@ export default {
revokeSubscription: 'Revoke Subscription',
allStatus: 'All Status',
allGroups: 'All Groups',
allPlatforms: 'All Platforms',
daily: 'Daily',
weekly: 'Weekly',
monthly: 'Monthly',
@@ -1759,7 +1771,37 @@ export default {
pleaseSelectGroup: 'Please select a group',
validityDaysRequired: 'Please enter a valid number of days (at least 1)',
revokeConfirm:
"Are you sure you want to revoke the subscription for '{user}'? This action cannot be undone."
"Are you sure you want to revoke the subscription for '{user}'? This action cannot be undone.",
guide: {
title: 'Subscription Management Guide',
subtitle: 'Subscription mode lets you assign time-based usage quotas to users, with daily/weekly/monthly limits. Follow these steps to get started.',
showGuide: 'Usage Guide',
step1: {
title: 'Create a Subscription Group',
line1: 'Go to "Group Management" page, click "Create Group"',
line2: 'Set billing type to "Subscription", configure daily/weekly/monthly quota limits',
line3: 'Save the group and ensure its status is "Active"',
link: 'Go to Group Management'
},
step2: {
title: 'Assign Subscription to User',
line1: 'Click the "Assign Subscription" button in the top right',
line2: 'Search for a user by email and select them',
line3: 'Choose a subscription group, set validity days, then click "Assign"'
},
step3: {
title: 'Manage Existing Subscriptions'
},
actions: {
adjust: 'Adjust',
adjustDesc: 'Extend or shorten the subscription validity period',
resetQuota: 'Reset Quota',
resetQuotaDesc: 'Reset daily/weekly/monthly usage to zero',
revoke: 'Revoke',
revokeDesc: 'Immediately terminate the subscription (irreversible)'
},
tip: 'Tip: Only groups with billing type "Subscription" and status "Active" appear in the group dropdown. If no options are available, create one in Group Management first.'
}
},
// Accounts
@@ -4320,6 +4362,16 @@ export default {
testFailed: 'Google Drive storage test failed'
}
},
overloadCooldown: {
title: '529 Overload Cooldown',
description: 'Configure account scheduling pause strategy when upstream returns 529 (overloaded)',
enabled: 'Enable Overload Cooldown',
enabledHint: 'Pause account scheduling on 529 errors, auto-recover after cooldown',
cooldownMinutes: 'Cooldown Duration (minutes)',
cooldownMinutesHint: 'Duration to pause account scheduling (1-120 minutes)',
saved: 'Overload cooldown settings saved',
saveFailed: 'Failed to save overload cooldown settings'
},
streamTimeout: {
title: 'Stream Timeout Handling',
description: 'Configure account handling strategy when upstream response times out',

Some files were not shown because too many files have changed in this diff Show More