mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-07 17:00:20 +08:00
Compare commits
130 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8f39754812 | ||
|
|
ac4371fa98 | ||
|
|
9985c4a344 | ||
|
|
804b6f2282 | ||
|
|
cb58daf38d | ||
|
|
a9398d210b | ||
|
|
df1c2383da | ||
|
|
ff8b1b4ae3 | ||
|
|
4cce21b125 | ||
|
|
6baf810885 | ||
|
|
9a48b2e942 | ||
|
|
c0c9c984d1 | ||
|
|
3fed478e4d | ||
|
|
0afc5d0b1a | ||
|
|
ba5a0d47eb | ||
|
|
be7bc658fc | ||
|
|
c89bbf5130 | ||
|
|
e59e3a9f00 | ||
|
|
6146be1474 | ||
|
|
730d2a9ad2 | ||
|
|
d008941cb3 | ||
|
|
df7a3e65ee | ||
|
|
0707f3d963 | ||
|
|
976d6fb03f | ||
|
|
f1aafbc06f | ||
|
|
7cb5444dbb | ||
|
|
3bede6e65f | ||
|
|
ad90bb4645 | ||
|
|
2220fd18ca | ||
|
|
bb3df5785a | ||
|
|
6e54eda41f | ||
|
|
df4c0adf0b | ||
|
|
7cbe4afdb8 | ||
|
|
16f150caae | ||
|
|
7229b41fc7 | ||
|
|
2cd5037878 | ||
|
|
53ee6383db | ||
|
|
a09478f374 | ||
|
|
d3c1d77a35 | ||
|
|
8824400c3e | ||
|
|
6e8eff9bb9 | ||
|
|
f5884d1608 | ||
|
|
56949a58bc | ||
|
|
7d256879c5 | ||
|
|
f9512fda58 | ||
|
|
beb63cb152 | ||
|
|
11ff73b578 | ||
|
|
0ed4a404e4 | ||
|
|
6c86501d11 | ||
|
|
2fe8932c1d | ||
|
|
0ab68aa9fb | ||
|
|
2f92b06869 | ||
|
|
03e94f9f53 | ||
|
|
606e29d390 | ||
|
|
3ecadf4aad | ||
|
|
0170d19fa7 | ||
|
|
ce1d2904c7 | ||
|
|
ea41f830fd | ||
|
|
e1a4a7b8c0 | ||
|
|
b381e8ee73 | ||
|
|
45e1429ae8 | ||
|
|
adb77af1d9 | ||
|
|
3a34746668 | ||
|
|
fe17058700 | ||
|
|
602bf9c017 | ||
|
|
7ade9baa15 | ||
|
|
fa454b1b99 | ||
|
|
8375094c69 | ||
|
|
91079d3f15 | ||
|
|
63412a9fcc | ||
|
|
d98648f03b | ||
|
|
c37fe91672 | ||
|
|
4d40fb6b60 | ||
|
|
be3b788b8f | ||
|
|
723e54013a | ||
|
|
4d566f68b6 | ||
|
|
31f817d189 | ||
|
|
59231668c5 | ||
|
|
5b787334c8 | ||
|
|
f761afb1ef | ||
|
|
877c17251d | ||
|
|
ffe43f6098 | ||
|
|
66f49b67d6 | ||
|
|
08d6dc5227 | ||
|
|
7cea6b6fc9 | ||
|
|
4b57e80e6a | ||
|
|
a161fcc89b | ||
|
|
e316a923d4 | ||
|
|
fd0370c07a | ||
|
|
316f2fee21 | ||
|
|
3002c7a17f | ||
|
|
207e09500a | ||
|
|
52c745bc62 | ||
|
|
498c6cfae9 | ||
|
|
71f8b9e473 | ||
|
|
3a31fa4768 | ||
|
|
65e69738cc | ||
|
|
549c134bb8 | ||
|
|
d206721fc1 | ||
|
|
64795a03e3 | ||
|
|
c8e2f614fa | ||
|
|
86d63f919d | ||
|
|
c43aa22cdb | ||
|
|
d1a6303e49 | ||
|
|
c0347cde85 | ||
|
|
2f2e76f9c6 | ||
|
|
dd7f21244b | ||
|
|
49be9d08f3 | ||
|
|
bba5b3c037 | ||
|
|
26298c4a5f | ||
|
|
eb7d830296 | ||
|
|
02db4c7671 | ||
|
|
a05b8b56e3 | ||
|
|
eca3898410 | ||
|
|
6901b64fce | ||
|
|
32c47b1509 | ||
|
|
39e430018b | ||
|
|
6549a40cf4 | ||
|
|
4e75d8fda9 | ||
|
|
8917a3ea8f | ||
|
|
0c011b889b | ||
|
|
0962ba43c0 | ||
|
|
b8c48fb477 | ||
|
|
2a7d04fec4 | ||
|
|
bd854e1750 | ||
|
|
65fd0d15ae | ||
|
|
c11f14f3a0 | ||
|
|
98b65e67f2 | ||
|
|
c579439c1e | ||
|
|
46e5ac9672 |
@@ -1,4 +1,4 @@
|
|||||||
FROM golang:1.25.5-alpine
|
FROM golang:1.25.6-alpine
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
@@ -15,7 +15,7 @@ RUN go mod download
|
|||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
# 构建应用
|
# 构建应用
|
||||||
RUN go build -o main cmd/server/main.go
|
RUN go build -o main ./cmd/server/
|
||||||
|
|
||||||
# 暴露端口
|
# 暴露端口
|
||||||
EXPOSE 8080
|
EXPOSE 8080
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ func main() {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
userRepo := repository.NewUserRepository(client, sqlDB)
|
userRepo := repository.NewUserRepository(client, sqlDB)
|
||||||
authService := service.NewAuthService(userRepo, cfg, nil, nil, nil, nil, nil)
|
authService := service.NewAuthService(userRepo, nil, cfg, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
userRepository := repository.NewUserRepository(client, db)
|
userRepository := repository.NewUserRepository(client, db)
|
||||||
|
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
|
||||||
settingRepository := repository.NewSettingRepository(client)
|
settingRepository := repository.NewSettingRepository(client)
|
||||||
settingService := service.NewSettingService(settingRepository, configConfig)
|
settingService := service.NewSettingService(settingRepository, configConfig)
|
||||||
redisClient := repository.ProvideRedis(configConfig)
|
redisClient := repository.ProvideRedis(configConfig)
|
||||||
@@ -61,24 +62,23 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
|
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
|
||||||
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||||
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
|
authService := service.NewAuthService(userRepository, redeemCodeRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
|
||||||
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
|
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
|
||||||
|
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
|
||||||
|
redeemCache := repository.NewRedeemCache(redisClient)
|
||||||
|
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||||
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
|
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
totpCache := repository.NewTotpCache(redisClient)
|
totpCache := repository.NewTotpCache(redisClient)
|
||||||
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
|
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
|
||||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, totpService)
|
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
|
||||||
userHandler := handler.NewUserHandler(userService)
|
userHandler := handler.NewUserHandler(userService)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
||||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||||
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
|
|
||||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
|
|
||||||
redeemCache := repository.NewRedeemCache(redisClient)
|
|
||||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
|
||||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||||
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||||
announcementRepository := repository.NewAnnouncementRepository(client)
|
announcementRepository := repository.NewAnnouncementRepository(client)
|
||||||
@@ -173,8 +173,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
|
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
|
||||||
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
|
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
|
||||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
|
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
|
||||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, configConfig)
|
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, configConfig)
|
||||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
|
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, configConfig)
|
||||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||||
totpHandler := handler.NewTotpHandler(totpService)
|
totpHandler := handler.NewTotpHandler(totpService)
|
||||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler)
|
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler)
|
||||||
|
|||||||
@@ -40,6 +40,12 @@ type APIKey struct {
|
|||||||
IPWhitelist []string `json:"ip_whitelist,omitempty"`
|
IPWhitelist []string `json:"ip_whitelist,omitempty"`
|
||||||
// Blocked IPs/CIDRs
|
// Blocked IPs/CIDRs
|
||||||
IPBlacklist []string `json:"ip_blacklist,omitempty"`
|
IPBlacklist []string `json:"ip_blacklist,omitempty"`
|
||||||
|
// Quota limit in USD for this API key (0 = unlimited)
|
||||||
|
Quota float64 `json:"quota,omitempty"`
|
||||||
|
// Used quota amount in USD
|
||||||
|
QuotaUsed float64 `json:"quota_used,omitempty"`
|
||||||
|
// Expiration time for this API key (null = never expires)
|
||||||
|
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||||
// Edges holds the relations/edges for other nodes in the graph.
|
// Edges holds the relations/edges for other nodes in the graph.
|
||||||
// The values are being populated by the APIKeyQuery when eager-loading is set.
|
// The values are being populated by the APIKeyQuery when eager-loading is set.
|
||||||
Edges APIKeyEdges `json:"edges"`
|
Edges APIKeyEdges `json:"edges"`
|
||||||
@@ -97,11 +103,13 @@ func (*APIKey) scanValues(columns []string) ([]any, error) {
|
|||||||
switch columns[i] {
|
switch columns[i] {
|
||||||
case apikey.FieldIPWhitelist, apikey.FieldIPBlacklist:
|
case apikey.FieldIPWhitelist, apikey.FieldIPBlacklist:
|
||||||
values[i] = new([]byte)
|
values[i] = new([]byte)
|
||||||
|
case apikey.FieldQuota, apikey.FieldQuotaUsed:
|
||||||
|
values[i] = new(sql.NullFloat64)
|
||||||
case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID:
|
case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus:
|
case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus:
|
||||||
values[i] = new(sql.NullString)
|
values[i] = new(sql.NullString)
|
||||||
case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt:
|
case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt, apikey.FieldExpiresAt:
|
||||||
values[i] = new(sql.NullTime)
|
values[i] = new(sql.NullTime)
|
||||||
default:
|
default:
|
||||||
values[i] = new(sql.UnknownType)
|
values[i] = new(sql.UnknownType)
|
||||||
@@ -190,6 +198,25 @@ func (_m *APIKey) assignValues(columns []string, values []any) error {
|
|||||||
return fmt.Errorf("unmarshal field ip_blacklist: %w", err)
|
return fmt.Errorf("unmarshal field ip_blacklist: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case apikey.FieldQuota:
|
||||||
|
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field quota", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.Quota = value.Float64
|
||||||
|
}
|
||||||
|
case apikey.FieldQuotaUsed:
|
||||||
|
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field quota_used", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.QuotaUsed = value.Float64
|
||||||
|
}
|
||||||
|
case apikey.FieldExpiresAt:
|
||||||
|
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field expires_at", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.ExpiresAt = new(time.Time)
|
||||||
|
*_m.ExpiresAt = value.Time
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
_m.selectValues.Set(columns[i], values[i])
|
_m.selectValues.Set(columns[i], values[i])
|
||||||
}
|
}
|
||||||
@@ -274,6 +301,17 @@ func (_m *APIKey) String() string {
|
|||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
builder.WriteString("ip_blacklist=")
|
builder.WriteString("ip_blacklist=")
|
||||||
builder.WriteString(fmt.Sprintf("%v", _m.IPBlacklist))
|
builder.WriteString(fmt.Sprintf("%v", _m.IPBlacklist))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("quota=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.Quota))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("quota_used=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.QuotaUsed))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
if v := _m.ExpiresAt; v != nil {
|
||||||
|
builder.WriteString("expires_at=")
|
||||||
|
builder.WriteString(v.Format(time.ANSIC))
|
||||||
|
}
|
||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
return builder.String()
|
return builder.String()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,6 +35,12 @@ const (
|
|||||||
FieldIPWhitelist = "ip_whitelist"
|
FieldIPWhitelist = "ip_whitelist"
|
||||||
// FieldIPBlacklist holds the string denoting the ip_blacklist field in the database.
|
// FieldIPBlacklist holds the string denoting the ip_blacklist field in the database.
|
||||||
FieldIPBlacklist = "ip_blacklist"
|
FieldIPBlacklist = "ip_blacklist"
|
||||||
|
// FieldQuota holds the string denoting the quota field in the database.
|
||||||
|
FieldQuota = "quota"
|
||||||
|
// FieldQuotaUsed holds the string denoting the quota_used field in the database.
|
||||||
|
FieldQuotaUsed = "quota_used"
|
||||||
|
// FieldExpiresAt holds the string denoting the expires_at field in the database.
|
||||||
|
FieldExpiresAt = "expires_at"
|
||||||
// EdgeUser holds the string denoting the user edge name in mutations.
|
// EdgeUser holds the string denoting the user edge name in mutations.
|
||||||
EdgeUser = "user"
|
EdgeUser = "user"
|
||||||
// EdgeGroup holds the string denoting the group edge name in mutations.
|
// EdgeGroup holds the string denoting the group edge name in mutations.
|
||||||
@@ -79,6 +85,9 @@ var Columns = []string{
|
|||||||
FieldStatus,
|
FieldStatus,
|
||||||
FieldIPWhitelist,
|
FieldIPWhitelist,
|
||||||
FieldIPBlacklist,
|
FieldIPBlacklist,
|
||||||
|
FieldQuota,
|
||||||
|
FieldQuotaUsed,
|
||||||
|
FieldExpiresAt,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidColumn reports if the column name is valid (part of the table columns).
|
// ValidColumn reports if the column name is valid (part of the table columns).
|
||||||
@@ -113,6 +122,10 @@ var (
|
|||||||
DefaultStatus string
|
DefaultStatus string
|
||||||
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||||
StatusValidator func(string) error
|
StatusValidator func(string) error
|
||||||
|
// DefaultQuota holds the default value on creation for the "quota" field.
|
||||||
|
DefaultQuota float64
|
||||||
|
// DefaultQuotaUsed holds the default value on creation for the "quota_used" field.
|
||||||
|
DefaultQuotaUsed float64
|
||||||
)
|
)
|
||||||
|
|
||||||
// OrderOption defines the ordering options for the APIKey queries.
|
// OrderOption defines the ordering options for the APIKey queries.
|
||||||
@@ -163,6 +176,21 @@ func ByStatus(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldStatus, opts...).ToFunc()
|
return sql.OrderByField(FieldStatus, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ByQuota orders the results by the quota field.
|
||||||
|
func ByQuota(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldQuota, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByQuotaUsed orders the results by the quota_used field.
|
||||||
|
func ByQuotaUsed(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldQuotaUsed, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByExpiresAt orders the results by the expires_at field.
|
||||||
|
func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldExpiresAt, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// ByUserField orders the results by user field.
|
// ByUserField orders the results by user field.
|
||||||
func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
|
func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
|
||||||
return func(s *sql.Selector) {
|
return func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -95,6 +95,21 @@ func Status(v string) predicate.APIKey {
|
|||||||
return predicate.APIKey(sql.FieldEQ(FieldStatus, v))
|
return predicate.APIKey(sql.FieldEQ(FieldStatus, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Quota applies equality check predicate on the "quota" field. It's identical to QuotaEQ.
|
||||||
|
func Quota(v float64) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldEQ(FieldQuota, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuotaUsed applies equality check predicate on the "quota_used" field. It's identical to QuotaUsedEQ.
|
||||||
|
func QuotaUsed(v float64) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldEQ(FieldQuotaUsed, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ.
|
||||||
|
func ExpiresAt(v time.Time) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldEQ(FieldExpiresAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||||
func CreatedAtEQ(v time.Time) predicate.APIKey {
|
func CreatedAtEQ(v time.Time) predicate.APIKey {
|
||||||
return predicate.APIKey(sql.FieldEQ(FieldCreatedAt, v))
|
return predicate.APIKey(sql.FieldEQ(FieldCreatedAt, v))
|
||||||
@@ -490,6 +505,136 @@ func IPBlacklistNotNil() predicate.APIKey {
|
|||||||
return predicate.APIKey(sql.FieldNotNull(FieldIPBlacklist))
|
return predicate.APIKey(sql.FieldNotNull(FieldIPBlacklist))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QuotaEQ applies the EQ predicate on the "quota" field.
|
||||||
|
func QuotaEQ(v float64) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldEQ(FieldQuota, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuotaNEQ applies the NEQ predicate on the "quota" field.
|
||||||
|
func QuotaNEQ(v float64) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldNEQ(FieldQuota, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuotaIn applies the In predicate on the "quota" field.
|
||||||
|
func QuotaIn(vs ...float64) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldIn(FieldQuota, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuotaNotIn applies the NotIn predicate on the "quota" field.
|
||||||
|
func QuotaNotIn(vs ...float64) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldNotIn(FieldQuota, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuotaGT applies the GT predicate on the "quota" field.
|
||||||
|
func QuotaGT(v float64) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldGT(FieldQuota, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuotaGTE applies the GTE predicate on the "quota" field.
|
||||||
|
func QuotaGTE(v float64) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldGTE(FieldQuota, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuotaLT applies the LT predicate on the "quota" field.
|
||||||
|
func QuotaLT(v float64) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldLT(FieldQuota, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuotaLTE applies the LTE predicate on the "quota" field.
|
||||||
|
func QuotaLTE(v float64) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldLTE(FieldQuota, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuotaUsedEQ applies the EQ predicate on the "quota_used" field.
|
||||||
|
func QuotaUsedEQ(v float64) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldEQ(FieldQuotaUsed, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuotaUsedNEQ applies the NEQ predicate on the "quota_used" field.
|
||||||
|
func QuotaUsedNEQ(v float64) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldNEQ(FieldQuotaUsed, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuotaUsedIn applies the In predicate on the "quota_used" field.
|
||||||
|
func QuotaUsedIn(vs ...float64) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldIn(FieldQuotaUsed, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuotaUsedNotIn applies the NotIn predicate on the "quota_used" field.
|
||||||
|
func QuotaUsedNotIn(vs ...float64) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldNotIn(FieldQuotaUsed, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuotaUsedGT applies the GT predicate on the "quota_used" field.
|
||||||
|
func QuotaUsedGT(v float64) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldGT(FieldQuotaUsed, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuotaUsedGTE applies the GTE predicate on the "quota_used" field.
|
||||||
|
func QuotaUsedGTE(v float64) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldGTE(FieldQuotaUsed, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuotaUsedLT applies the LT predicate on the "quota_used" field.
|
||||||
|
func QuotaUsedLT(v float64) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldLT(FieldQuotaUsed, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuotaUsedLTE applies the LTE predicate on the "quota_used" field.
|
||||||
|
func QuotaUsedLTE(v float64) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldLTE(FieldQuotaUsed, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpiresAtEQ applies the EQ predicate on the "expires_at" field.
|
||||||
|
func ExpiresAtEQ(v time.Time) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldEQ(FieldExpiresAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field.
|
||||||
|
func ExpiresAtNEQ(v time.Time) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldNEQ(FieldExpiresAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpiresAtIn applies the In predicate on the "expires_at" field.
|
||||||
|
func ExpiresAtIn(vs ...time.Time) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldIn(FieldExpiresAt, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field.
|
||||||
|
func ExpiresAtNotIn(vs ...time.Time) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldNotIn(FieldExpiresAt, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpiresAtGT applies the GT predicate on the "expires_at" field.
|
||||||
|
func ExpiresAtGT(v time.Time) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldGT(FieldExpiresAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpiresAtGTE applies the GTE predicate on the "expires_at" field.
|
||||||
|
func ExpiresAtGTE(v time.Time) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldGTE(FieldExpiresAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpiresAtLT applies the LT predicate on the "expires_at" field.
|
||||||
|
func ExpiresAtLT(v time.Time) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldLT(FieldExpiresAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpiresAtLTE applies the LTE predicate on the "expires_at" field.
|
||||||
|
func ExpiresAtLTE(v time.Time) predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldLTE(FieldExpiresAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpiresAtIsNil applies the IsNil predicate on the "expires_at" field.
|
||||||
|
func ExpiresAtIsNil() predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldIsNull(FieldExpiresAt))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpiresAtNotNil applies the NotNil predicate on the "expires_at" field.
|
||||||
|
func ExpiresAtNotNil() predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldNotNull(FieldExpiresAt))
|
||||||
|
}
|
||||||
|
|
||||||
// HasUser applies the HasEdge predicate on the "user" edge.
|
// HasUser applies the HasEdge predicate on the "user" edge.
|
||||||
func HasUser() predicate.APIKey {
|
func HasUser() predicate.APIKey {
|
||||||
return predicate.APIKey(func(s *sql.Selector) {
|
return predicate.APIKey(func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -125,6 +125,48 @@ func (_c *APIKeyCreate) SetIPBlacklist(v []string) *APIKeyCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetQuota sets the "quota" field.
|
||||||
|
func (_c *APIKeyCreate) SetQuota(v float64) *APIKeyCreate {
|
||||||
|
_c.mutation.SetQuota(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableQuota sets the "quota" field if the given value is not nil.
|
||||||
|
func (_c *APIKeyCreate) SetNillableQuota(v *float64) *APIKeyCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetQuota(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetQuotaUsed sets the "quota_used" field.
|
||||||
|
func (_c *APIKeyCreate) SetQuotaUsed(v float64) *APIKeyCreate {
|
||||||
|
_c.mutation.SetQuotaUsed(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableQuotaUsed sets the "quota_used" field if the given value is not nil.
|
||||||
|
func (_c *APIKeyCreate) SetNillableQuotaUsed(v *float64) *APIKeyCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetQuotaUsed(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetExpiresAt sets the "expires_at" field.
|
||||||
|
func (_c *APIKeyCreate) SetExpiresAt(v time.Time) *APIKeyCreate {
|
||||||
|
_c.mutation.SetExpiresAt(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
|
||||||
|
func (_c *APIKeyCreate) SetNillableExpiresAt(v *time.Time) *APIKeyCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetExpiresAt(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// SetUser sets the "user" edge to the User entity.
|
// SetUser sets the "user" edge to the User entity.
|
||||||
func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate {
|
func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate {
|
||||||
return _c.SetUserID(v.ID)
|
return _c.SetUserID(v.ID)
|
||||||
@@ -205,6 +247,14 @@ func (_c *APIKeyCreate) defaults() error {
|
|||||||
v := apikey.DefaultStatus
|
v := apikey.DefaultStatus
|
||||||
_c.mutation.SetStatus(v)
|
_c.mutation.SetStatus(v)
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.Quota(); !ok {
|
||||||
|
v := apikey.DefaultQuota
|
||||||
|
_c.mutation.SetQuota(v)
|
||||||
|
}
|
||||||
|
if _, ok := _c.mutation.QuotaUsed(); !ok {
|
||||||
|
v := apikey.DefaultQuotaUsed
|
||||||
|
_c.mutation.SetQuotaUsed(v)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -243,6 +293,12 @@ func (_c *APIKeyCreate) check() error {
|
|||||||
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "APIKey.status": %w`, err)}
|
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "APIKey.status": %w`, err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.Quota(); !ok {
|
||||||
|
return &ValidationError{Name: "quota", err: errors.New(`ent: missing required field "APIKey.quota"`)}
|
||||||
|
}
|
||||||
|
if _, ok := _c.mutation.QuotaUsed(); !ok {
|
||||||
|
return &ValidationError{Name: "quota_used", err: errors.New(`ent: missing required field "APIKey.quota_used"`)}
|
||||||
|
}
|
||||||
if len(_c.mutation.UserIDs()) == 0 {
|
if len(_c.mutation.UserIDs()) == 0 {
|
||||||
return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "APIKey.user"`)}
|
return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "APIKey.user"`)}
|
||||||
}
|
}
|
||||||
@@ -305,6 +361,18 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value)
|
_spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value)
|
||||||
_node.IPBlacklist = value
|
_node.IPBlacklist = value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.Quota(); ok {
|
||||||
|
_spec.SetField(apikey.FieldQuota, field.TypeFloat64, value)
|
||||||
|
_node.Quota = value
|
||||||
|
}
|
||||||
|
if value, ok := _c.mutation.QuotaUsed(); ok {
|
||||||
|
_spec.SetField(apikey.FieldQuotaUsed, field.TypeFloat64, value)
|
||||||
|
_node.QuotaUsed = value
|
||||||
|
}
|
||||||
|
if value, ok := _c.mutation.ExpiresAt(); ok {
|
||||||
|
_spec.SetField(apikey.FieldExpiresAt, field.TypeTime, value)
|
||||||
|
_node.ExpiresAt = &value
|
||||||
|
}
|
||||||
if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
|
if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.M2O,
|
Rel: sqlgraph.M2O,
|
||||||
@@ -539,6 +607,60 @@ func (u *APIKeyUpsert) ClearIPBlacklist() *APIKeyUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetQuota sets the "quota" field.
|
||||||
|
func (u *APIKeyUpsert) SetQuota(v float64) *APIKeyUpsert {
|
||||||
|
u.Set(apikey.FieldQuota, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateQuota sets the "quota" field to the value that was provided on create.
|
||||||
|
func (u *APIKeyUpsert) UpdateQuota() *APIKeyUpsert {
|
||||||
|
u.SetExcluded(apikey.FieldQuota)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddQuota adds v to the "quota" field.
|
||||||
|
func (u *APIKeyUpsert) AddQuota(v float64) *APIKeyUpsert {
|
||||||
|
u.Add(apikey.FieldQuota, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetQuotaUsed sets the "quota_used" field.
|
||||||
|
func (u *APIKeyUpsert) SetQuotaUsed(v float64) *APIKeyUpsert {
|
||||||
|
u.Set(apikey.FieldQuotaUsed, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateQuotaUsed sets the "quota_used" field to the value that was provided on create.
|
||||||
|
func (u *APIKeyUpsert) UpdateQuotaUsed() *APIKeyUpsert {
|
||||||
|
u.SetExcluded(apikey.FieldQuotaUsed)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddQuotaUsed adds v to the "quota_used" field.
|
||||||
|
func (u *APIKeyUpsert) AddQuotaUsed(v float64) *APIKeyUpsert {
|
||||||
|
u.Add(apikey.FieldQuotaUsed, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetExpiresAt sets the "expires_at" field.
|
||||||
|
func (u *APIKeyUpsert) SetExpiresAt(v time.Time) *APIKeyUpsert {
|
||||||
|
u.Set(apikey.FieldExpiresAt, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
|
||||||
|
func (u *APIKeyUpsert) UpdateExpiresAt() *APIKeyUpsert {
|
||||||
|
u.SetExcluded(apikey.FieldExpiresAt)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||||
|
func (u *APIKeyUpsert) ClearExpiresAt() *APIKeyUpsert {
|
||||||
|
u.SetNull(apikey.FieldExpiresAt)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||||
// Using this option is equivalent to using:
|
// Using this option is equivalent to using:
|
||||||
//
|
//
|
||||||
@@ -738,6 +860,69 @@ func (u *APIKeyUpsertOne) ClearIPBlacklist() *APIKeyUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetQuota sets the "quota" field.
|
||||||
|
func (u *APIKeyUpsertOne) SetQuota(v float64) *APIKeyUpsertOne {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.SetQuota(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddQuota adds v to the "quota" field.
|
||||||
|
func (u *APIKeyUpsertOne) AddQuota(v float64) *APIKeyUpsertOne {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.AddQuota(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateQuota sets the "quota" field to the value that was provided on create.
|
||||||
|
func (u *APIKeyUpsertOne) UpdateQuota() *APIKeyUpsertOne {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.UpdateQuota()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetQuotaUsed sets the "quota_used" field.
|
||||||
|
func (u *APIKeyUpsertOne) SetQuotaUsed(v float64) *APIKeyUpsertOne {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.SetQuotaUsed(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddQuotaUsed adds v to the "quota_used" field.
|
||||||
|
func (u *APIKeyUpsertOne) AddQuotaUsed(v float64) *APIKeyUpsertOne {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.AddQuotaUsed(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateQuotaUsed sets the "quota_used" field to the value that was provided on create.
|
||||||
|
func (u *APIKeyUpsertOne) UpdateQuotaUsed() *APIKeyUpsertOne {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.UpdateQuotaUsed()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetExpiresAt sets the "expires_at" field.
|
||||||
|
func (u *APIKeyUpsertOne) SetExpiresAt(v time.Time) *APIKeyUpsertOne {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.SetExpiresAt(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
|
||||||
|
func (u *APIKeyUpsertOne) UpdateExpiresAt() *APIKeyUpsertOne {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.UpdateExpiresAt()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||||
|
func (u *APIKeyUpsertOne) ClearExpiresAt() *APIKeyUpsertOne {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.ClearExpiresAt()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *APIKeyUpsertOne) Exec(ctx context.Context) error {
|
func (u *APIKeyUpsertOne) Exec(ctx context.Context) error {
|
||||||
if len(u.create.conflict) == 0 {
|
if len(u.create.conflict) == 0 {
|
||||||
@@ -1103,6 +1288,69 @@ func (u *APIKeyUpsertBulk) ClearIPBlacklist() *APIKeyUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetQuota sets the "quota" field.
|
||||||
|
func (u *APIKeyUpsertBulk) SetQuota(v float64) *APIKeyUpsertBulk {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.SetQuota(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddQuota adds v to the "quota" field.
|
||||||
|
func (u *APIKeyUpsertBulk) AddQuota(v float64) *APIKeyUpsertBulk {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.AddQuota(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateQuota sets the "quota" field to the value that was provided on create.
|
||||||
|
func (u *APIKeyUpsertBulk) UpdateQuota() *APIKeyUpsertBulk {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.UpdateQuota()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetQuotaUsed sets the "quota_used" field.
|
||||||
|
func (u *APIKeyUpsertBulk) SetQuotaUsed(v float64) *APIKeyUpsertBulk {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.SetQuotaUsed(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddQuotaUsed adds v to the "quota_used" field.
|
||||||
|
func (u *APIKeyUpsertBulk) AddQuotaUsed(v float64) *APIKeyUpsertBulk {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.AddQuotaUsed(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateQuotaUsed sets the "quota_used" field to the value that was provided on create.
|
||||||
|
func (u *APIKeyUpsertBulk) UpdateQuotaUsed() *APIKeyUpsertBulk {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.UpdateQuotaUsed()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetExpiresAt sets the "expires_at" field.
|
||||||
|
func (u *APIKeyUpsertBulk) SetExpiresAt(v time.Time) *APIKeyUpsertBulk {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.SetExpiresAt(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
|
||||||
|
func (u *APIKeyUpsertBulk) UpdateExpiresAt() *APIKeyUpsertBulk {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.UpdateExpiresAt()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||||
|
func (u *APIKeyUpsertBulk) ClearExpiresAt() *APIKeyUpsertBulk {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.ClearExpiresAt()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error {
|
func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error {
|
||||||
if u.create.err != nil {
|
if u.create.err != nil {
|
||||||
|
|||||||
@@ -170,6 +170,68 @@ func (_u *APIKeyUpdate) ClearIPBlacklist() *APIKeyUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetQuota sets the "quota" field.
|
||||||
|
func (_u *APIKeyUpdate) SetQuota(v float64) *APIKeyUpdate {
|
||||||
|
_u.mutation.ResetQuota()
|
||||||
|
_u.mutation.SetQuota(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableQuota sets the "quota" field if the given value is not nil.
|
||||||
|
func (_u *APIKeyUpdate) SetNillableQuota(v *float64) *APIKeyUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetQuota(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddQuota adds value to the "quota" field.
|
||||||
|
func (_u *APIKeyUpdate) AddQuota(v float64) *APIKeyUpdate {
|
||||||
|
_u.mutation.AddQuota(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetQuotaUsed sets the "quota_used" field.
|
||||||
|
func (_u *APIKeyUpdate) SetQuotaUsed(v float64) *APIKeyUpdate {
|
||||||
|
_u.mutation.ResetQuotaUsed()
|
||||||
|
_u.mutation.SetQuotaUsed(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableQuotaUsed sets the "quota_used" field if the given value is not nil.
|
||||||
|
func (_u *APIKeyUpdate) SetNillableQuotaUsed(v *float64) *APIKeyUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetQuotaUsed(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddQuotaUsed adds value to the "quota_used" field.
|
||||||
|
func (_u *APIKeyUpdate) AddQuotaUsed(v float64) *APIKeyUpdate {
|
||||||
|
_u.mutation.AddQuotaUsed(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetExpiresAt sets the "expires_at" field.
|
||||||
|
func (_u *APIKeyUpdate) SetExpiresAt(v time.Time) *APIKeyUpdate {
|
||||||
|
_u.mutation.SetExpiresAt(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
|
||||||
|
func (_u *APIKeyUpdate) SetNillableExpiresAt(v *time.Time) *APIKeyUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetExpiresAt(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||||
|
func (_u *APIKeyUpdate) ClearExpiresAt() *APIKeyUpdate {
|
||||||
|
_u.mutation.ClearExpiresAt()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetUser sets the "user" edge to the User entity.
|
// SetUser sets the "user" edge to the User entity.
|
||||||
func (_u *APIKeyUpdate) SetUser(v *User) *APIKeyUpdate {
|
func (_u *APIKeyUpdate) SetUser(v *User) *APIKeyUpdate {
|
||||||
return _u.SetUserID(v.ID)
|
return _u.SetUserID(v.ID)
|
||||||
@@ -350,6 +412,24 @@ func (_u *APIKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if _u.mutation.IPBlacklistCleared() {
|
if _u.mutation.IPBlacklistCleared() {
|
||||||
_spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON)
|
_spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.Quota(); ok {
|
||||||
|
_spec.SetField(apikey.FieldQuota, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedQuota(); ok {
|
||||||
|
_spec.AddField(apikey.FieldQuota, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.QuotaUsed(); ok {
|
||||||
|
_spec.SetField(apikey.FieldQuotaUsed, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedQuotaUsed(); ok {
|
||||||
|
_spec.AddField(apikey.FieldQuotaUsed, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.ExpiresAt(); ok {
|
||||||
|
_spec.SetField(apikey.FieldExpiresAt, field.TypeTime, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.ExpiresAtCleared() {
|
||||||
|
_spec.ClearField(apikey.FieldExpiresAt, field.TypeTime)
|
||||||
|
}
|
||||||
if _u.mutation.UserCleared() {
|
if _u.mutation.UserCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.M2O,
|
Rel: sqlgraph.M2O,
|
||||||
@@ -611,6 +691,68 @@ func (_u *APIKeyUpdateOne) ClearIPBlacklist() *APIKeyUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetQuota sets the "quota" field.
|
||||||
|
func (_u *APIKeyUpdateOne) SetQuota(v float64) *APIKeyUpdateOne {
|
||||||
|
_u.mutation.ResetQuota()
|
||||||
|
_u.mutation.SetQuota(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableQuota sets the "quota" field if the given value is not nil.
|
||||||
|
func (_u *APIKeyUpdateOne) SetNillableQuota(v *float64) *APIKeyUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetQuota(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddQuota adds value to the "quota" field.
|
||||||
|
func (_u *APIKeyUpdateOne) AddQuota(v float64) *APIKeyUpdateOne {
|
||||||
|
_u.mutation.AddQuota(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetQuotaUsed sets the "quota_used" field.
|
||||||
|
func (_u *APIKeyUpdateOne) SetQuotaUsed(v float64) *APIKeyUpdateOne {
|
||||||
|
_u.mutation.ResetQuotaUsed()
|
||||||
|
_u.mutation.SetQuotaUsed(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableQuotaUsed sets the "quota_used" field if the given value is not nil.
|
||||||
|
func (_u *APIKeyUpdateOne) SetNillableQuotaUsed(v *float64) *APIKeyUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetQuotaUsed(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddQuotaUsed adds value to the "quota_used" field.
|
||||||
|
func (_u *APIKeyUpdateOne) AddQuotaUsed(v float64) *APIKeyUpdateOne {
|
||||||
|
_u.mutation.AddQuotaUsed(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetExpiresAt sets the "expires_at" field.
|
||||||
|
func (_u *APIKeyUpdateOne) SetExpiresAt(v time.Time) *APIKeyUpdateOne {
|
||||||
|
_u.mutation.SetExpiresAt(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
|
||||||
|
func (_u *APIKeyUpdateOne) SetNillableExpiresAt(v *time.Time) *APIKeyUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetExpiresAt(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||||
|
func (_u *APIKeyUpdateOne) ClearExpiresAt() *APIKeyUpdateOne {
|
||||||
|
_u.mutation.ClearExpiresAt()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetUser sets the "user" edge to the User entity.
|
// SetUser sets the "user" edge to the User entity.
|
||||||
func (_u *APIKeyUpdateOne) SetUser(v *User) *APIKeyUpdateOne {
|
func (_u *APIKeyUpdateOne) SetUser(v *User) *APIKeyUpdateOne {
|
||||||
return _u.SetUserID(v.ID)
|
return _u.SetUserID(v.ID)
|
||||||
@@ -821,6 +963,24 @@ func (_u *APIKeyUpdateOne) sqlSave(ctx context.Context) (_node *APIKey, err erro
|
|||||||
if _u.mutation.IPBlacklistCleared() {
|
if _u.mutation.IPBlacklistCleared() {
|
||||||
_spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON)
|
_spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.Quota(); ok {
|
||||||
|
_spec.SetField(apikey.FieldQuota, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedQuota(); ok {
|
||||||
|
_spec.AddField(apikey.FieldQuota, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.QuotaUsed(); ok {
|
||||||
|
_spec.SetField(apikey.FieldQuotaUsed, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedQuotaUsed(); ok {
|
||||||
|
_spec.AddField(apikey.FieldQuotaUsed, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.ExpiresAt(); ok {
|
||||||
|
_spec.SetField(apikey.FieldExpiresAt, field.TypeTime, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.ExpiresAtCleared() {
|
||||||
|
_spec.ClearField(apikey.FieldExpiresAt, field.TypeTime)
|
||||||
|
}
|
||||||
if _u.mutation.UserCleared() {
|
if _u.mutation.UserCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.M2O,
|
Rel: sqlgraph.M2O,
|
||||||
|
|||||||
@@ -56,10 +56,16 @@ type Group struct {
|
|||||||
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
||||||
// 非 Claude Code 请求降级使用的分组 ID
|
// 非 Claude Code 请求降级使用的分组 ID
|
||||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||||
|
// 无效请求兜底使用的分组 ID
|
||||||
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
|
||||||
// 模型路由配置:模型模式 -> 优先账号ID列表
|
// 模型路由配置:模型模式 -> 优先账号ID列表
|
||||||
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
||||||
// 是否启用模型路由配置
|
// 是否启用模型路由配置
|
||||||
ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"`
|
ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"`
|
||||||
|
// 是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)
|
||||||
|
McpXMLInject bool `json:"mcp_xml_inject,omitempty"`
|
||||||
|
// 支持的模型系列:claude, gemini_text, gemini_image
|
||||||
|
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
||||||
// Edges holds the relations/edges for other nodes in the graph.
|
// Edges holds the relations/edges for other nodes in the graph.
|
||||||
// The values are being populated by the GroupQuery when eager-loading is set.
|
// The values are being populated by the GroupQuery when eager-loading is set.
|
||||||
Edges GroupEdges `json:"edges"`
|
Edges GroupEdges `json:"edges"`
|
||||||
@@ -166,13 +172,13 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
|||||||
values := make([]any, len(columns))
|
values := make([]any, len(columns))
|
||||||
for i := range columns {
|
for i := range columns {
|
||||||
switch columns[i] {
|
switch columns[i] {
|
||||||
case group.FieldModelRouting:
|
case group.FieldModelRouting, group.FieldSupportedModelScopes:
|
||||||
values[i] = new([]byte)
|
values[i] = new([]byte)
|
||||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled:
|
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject:
|
||||||
values[i] = new(sql.NullBool)
|
values[i] = new(sql.NullBool)
|
||||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
|
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
|
||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID:
|
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
|
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
|
||||||
values[i] = new(sql.NullString)
|
values[i] = new(sql.NullString)
|
||||||
@@ -322,6 +328,13 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
|||||||
_m.FallbackGroupID = new(int64)
|
_m.FallbackGroupID = new(int64)
|
||||||
*_m.FallbackGroupID = value.Int64
|
*_m.FallbackGroupID = value.Int64
|
||||||
}
|
}
|
||||||
|
case group.FieldFallbackGroupIDOnInvalidRequest:
|
||||||
|
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field fallback_group_id_on_invalid_request", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.FallbackGroupIDOnInvalidRequest = new(int64)
|
||||||
|
*_m.FallbackGroupIDOnInvalidRequest = value.Int64
|
||||||
|
}
|
||||||
case group.FieldModelRouting:
|
case group.FieldModelRouting:
|
||||||
if value, ok := values[i].(*[]byte); !ok {
|
if value, ok := values[i].(*[]byte); !ok {
|
||||||
return fmt.Errorf("unexpected type %T for field model_routing", values[i])
|
return fmt.Errorf("unexpected type %T for field model_routing", values[i])
|
||||||
@@ -336,6 +349,20 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
|||||||
} else if value.Valid {
|
} else if value.Valid {
|
||||||
_m.ModelRoutingEnabled = value.Bool
|
_m.ModelRoutingEnabled = value.Bool
|
||||||
}
|
}
|
||||||
|
case group.FieldMcpXMLInject:
|
||||||
|
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field mcp_xml_inject", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.McpXMLInject = value.Bool
|
||||||
|
}
|
||||||
|
case group.FieldSupportedModelScopes:
|
||||||
|
if value, ok := values[i].(*[]byte); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field supported_model_scopes", values[i])
|
||||||
|
} else if value != nil && len(*value) > 0 {
|
||||||
|
if err := json.Unmarshal(*value, &_m.SupportedModelScopes); err != nil {
|
||||||
|
return fmt.Errorf("unmarshal field supported_model_scopes: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
_m.selectValues.Set(columns[i], values[i])
|
_m.selectValues.Set(columns[i], values[i])
|
||||||
}
|
}
|
||||||
@@ -487,11 +514,22 @@ func (_m *Group) String() string {
|
|||||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||||
}
|
}
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
|
if v := _m.FallbackGroupIDOnInvalidRequest; v != nil {
|
||||||
|
builder.WriteString("fallback_group_id_on_invalid_request=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||||
|
}
|
||||||
|
builder.WriteString(", ")
|
||||||
builder.WriteString("model_routing=")
|
builder.WriteString("model_routing=")
|
||||||
builder.WriteString(fmt.Sprintf("%v", _m.ModelRouting))
|
builder.WriteString(fmt.Sprintf("%v", _m.ModelRouting))
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
builder.WriteString("model_routing_enabled=")
|
builder.WriteString("model_routing_enabled=")
|
||||||
builder.WriteString(fmt.Sprintf("%v", _m.ModelRoutingEnabled))
|
builder.WriteString(fmt.Sprintf("%v", _m.ModelRoutingEnabled))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("mcp_xml_inject=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.McpXMLInject))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("supported_model_scopes=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.SupportedModelScopes))
|
||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
return builder.String()
|
return builder.String()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,10 +53,16 @@ const (
|
|||||||
FieldClaudeCodeOnly = "claude_code_only"
|
FieldClaudeCodeOnly = "claude_code_only"
|
||||||
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
|
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
|
||||||
FieldFallbackGroupID = "fallback_group_id"
|
FieldFallbackGroupID = "fallback_group_id"
|
||||||
|
// FieldFallbackGroupIDOnInvalidRequest holds the string denoting the fallback_group_id_on_invalid_request field in the database.
|
||||||
|
FieldFallbackGroupIDOnInvalidRequest = "fallback_group_id_on_invalid_request"
|
||||||
// FieldModelRouting holds the string denoting the model_routing field in the database.
|
// FieldModelRouting holds the string denoting the model_routing field in the database.
|
||||||
FieldModelRouting = "model_routing"
|
FieldModelRouting = "model_routing"
|
||||||
// FieldModelRoutingEnabled holds the string denoting the model_routing_enabled field in the database.
|
// FieldModelRoutingEnabled holds the string denoting the model_routing_enabled field in the database.
|
||||||
FieldModelRoutingEnabled = "model_routing_enabled"
|
FieldModelRoutingEnabled = "model_routing_enabled"
|
||||||
|
// FieldMcpXMLInject holds the string denoting the mcp_xml_inject field in the database.
|
||||||
|
FieldMcpXMLInject = "mcp_xml_inject"
|
||||||
|
// FieldSupportedModelScopes holds the string denoting the supported_model_scopes field in the database.
|
||||||
|
FieldSupportedModelScopes = "supported_model_scopes"
|
||||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||||
EdgeAPIKeys = "api_keys"
|
EdgeAPIKeys = "api_keys"
|
||||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||||
@@ -151,8 +157,11 @@ var Columns = []string{
|
|||||||
FieldImagePrice4k,
|
FieldImagePrice4k,
|
||||||
FieldClaudeCodeOnly,
|
FieldClaudeCodeOnly,
|
||||||
FieldFallbackGroupID,
|
FieldFallbackGroupID,
|
||||||
|
FieldFallbackGroupIDOnInvalidRequest,
|
||||||
FieldModelRouting,
|
FieldModelRouting,
|
||||||
FieldModelRoutingEnabled,
|
FieldModelRoutingEnabled,
|
||||||
|
FieldMcpXMLInject,
|
||||||
|
FieldSupportedModelScopes,
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -212,6 +221,10 @@ var (
|
|||||||
DefaultClaudeCodeOnly bool
|
DefaultClaudeCodeOnly bool
|
||||||
// DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field.
|
// DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field.
|
||||||
DefaultModelRoutingEnabled bool
|
DefaultModelRoutingEnabled bool
|
||||||
|
// DefaultMcpXMLInject holds the default value on creation for the "mcp_xml_inject" field.
|
||||||
|
DefaultMcpXMLInject bool
|
||||||
|
// DefaultSupportedModelScopes holds the default value on creation for the "supported_model_scopes" field.
|
||||||
|
DefaultSupportedModelScopes []string
|
||||||
)
|
)
|
||||||
|
|
||||||
// OrderOption defines the ordering options for the Group queries.
|
// OrderOption defines the ordering options for the Group queries.
|
||||||
@@ -317,11 +330,21 @@ func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc()
|
return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ByFallbackGroupIDOnInvalidRequest orders the results by the fallback_group_id_on_invalid_request field.
|
||||||
|
func ByFallbackGroupIDOnInvalidRequest(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldFallbackGroupIDOnInvalidRequest, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// ByModelRoutingEnabled orders the results by the model_routing_enabled field.
|
// ByModelRoutingEnabled orders the results by the model_routing_enabled field.
|
||||||
func ByModelRoutingEnabled(opts ...sql.OrderTermOption) OrderOption {
|
func ByModelRoutingEnabled(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return sql.OrderByField(FieldModelRoutingEnabled, opts...).ToFunc()
|
return sql.OrderByField(FieldModelRoutingEnabled, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ByMcpXMLInject orders the results by the mcp_xml_inject field.
|
||||||
|
func ByMcpXMLInject(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldMcpXMLInject, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// ByAPIKeysCount orders the results by api_keys count.
|
// ByAPIKeysCount orders the results by api_keys count.
|
||||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return func(s *sql.Selector) {
|
return func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -150,11 +150,21 @@ func FallbackGroupID(v int64) predicate.Group {
|
|||||||
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
|
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequest applies equality check predicate on the "fallback_group_id_on_invalid_request" field. It's identical to FallbackGroupIDOnInvalidRequestEQ.
|
||||||
|
func FallbackGroupIDOnInvalidRequest(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldEQ(FieldFallbackGroupIDOnInvalidRequest, v))
|
||||||
|
}
|
||||||
|
|
||||||
// ModelRoutingEnabled applies equality check predicate on the "model_routing_enabled" field. It's identical to ModelRoutingEnabledEQ.
|
// ModelRoutingEnabled applies equality check predicate on the "model_routing_enabled" field. It's identical to ModelRoutingEnabledEQ.
|
||||||
func ModelRoutingEnabled(v bool) predicate.Group {
|
func ModelRoutingEnabled(v bool) predicate.Group {
|
||||||
return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v))
|
return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// McpXMLInject applies equality check predicate on the "mcp_xml_inject" field. It's identical to McpXMLInjectEQ.
|
||||||
|
func McpXMLInject(v bool) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldEQ(FieldMcpXMLInject, v))
|
||||||
|
}
|
||||||
|
|
||||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||||
func CreatedAtEQ(v time.Time) predicate.Group {
|
func CreatedAtEQ(v time.Time) predicate.Group {
|
||||||
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
||||||
@@ -1070,6 +1080,56 @@ func FallbackGroupIDNotNil() predicate.Group {
|
|||||||
return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID))
|
return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequestEQ applies the EQ predicate on the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func FallbackGroupIDOnInvalidRequestEQ(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldEQ(FieldFallbackGroupIDOnInvalidRequest, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequestNEQ applies the NEQ predicate on the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func FallbackGroupIDOnInvalidRequestNEQ(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldNEQ(FieldFallbackGroupIDOnInvalidRequest, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequestIn applies the In predicate on the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func FallbackGroupIDOnInvalidRequestIn(vs ...int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldIn(FieldFallbackGroupIDOnInvalidRequest, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequestNotIn applies the NotIn predicate on the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func FallbackGroupIDOnInvalidRequestNotIn(vs ...int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldNotIn(FieldFallbackGroupIDOnInvalidRequest, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequestGT applies the GT predicate on the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func FallbackGroupIDOnInvalidRequestGT(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldGT(FieldFallbackGroupIDOnInvalidRequest, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequestGTE applies the GTE predicate on the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func FallbackGroupIDOnInvalidRequestGTE(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldGTE(FieldFallbackGroupIDOnInvalidRequest, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequestLT applies the LT predicate on the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func FallbackGroupIDOnInvalidRequestLT(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldLT(FieldFallbackGroupIDOnInvalidRequest, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequestLTE applies the LTE predicate on the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func FallbackGroupIDOnInvalidRequestLTE(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldLTE(FieldFallbackGroupIDOnInvalidRequest, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequestIsNil applies the IsNil predicate on the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func FallbackGroupIDOnInvalidRequestIsNil() predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldIsNull(FieldFallbackGroupIDOnInvalidRequest))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequestNotNil applies the NotNil predicate on the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func FallbackGroupIDOnInvalidRequestNotNil() predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldNotNull(FieldFallbackGroupIDOnInvalidRequest))
|
||||||
|
}
|
||||||
|
|
||||||
// ModelRoutingIsNil applies the IsNil predicate on the "model_routing" field.
|
// ModelRoutingIsNil applies the IsNil predicate on the "model_routing" field.
|
||||||
func ModelRoutingIsNil() predicate.Group {
|
func ModelRoutingIsNil() predicate.Group {
|
||||||
return predicate.Group(sql.FieldIsNull(FieldModelRouting))
|
return predicate.Group(sql.FieldIsNull(FieldModelRouting))
|
||||||
@@ -1090,6 +1150,16 @@ func ModelRoutingEnabledNEQ(v bool) predicate.Group {
|
|||||||
return predicate.Group(sql.FieldNEQ(FieldModelRoutingEnabled, v))
|
return predicate.Group(sql.FieldNEQ(FieldModelRoutingEnabled, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// McpXMLInjectEQ applies the EQ predicate on the "mcp_xml_inject" field.
|
||||||
|
func McpXMLInjectEQ(v bool) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldEQ(FieldMcpXMLInject, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// McpXMLInjectNEQ applies the NEQ predicate on the "mcp_xml_inject" field.
|
||||||
|
func McpXMLInjectNEQ(v bool) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldNEQ(FieldMcpXMLInject, v))
|
||||||
|
}
|
||||||
|
|
||||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||||
func HasAPIKeys() predicate.Group {
|
func HasAPIKeys() predicate.Group {
|
||||||
return predicate.Group(func(s *sql.Selector) {
|
return predicate.Group(func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -286,6 +286,20 @@ func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (_c *GroupCreate) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupCreate {
|
||||||
|
_c.mutation.SetFallbackGroupIDOnInvalidRequest(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil.
|
||||||
|
func (_c *GroupCreate) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetFallbackGroupIDOnInvalidRequest(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// SetModelRouting sets the "model_routing" field.
|
// SetModelRouting sets the "model_routing" field.
|
||||||
func (_c *GroupCreate) SetModelRouting(v map[string][]int64) *GroupCreate {
|
func (_c *GroupCreate) SetModelRouting(v map[string][]int64) *GroupCreate {
|
||||||
_c.mutation.SetModelRouting(v)
|
_c.mutation.SetModelRouting(v)
|
||||||
@@ -306,6 +320,26 @@ func (_c *GroupCreate) SetNillableModelRoutingEnabled(v *bool) *GroupCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetMcpXMLInject sets the "mcp_xml_inject" field.
|
||||||
|
func (_c *GroupCreate) SetMcpXMLInject(v bool) *GroupCreate {
|
||||||
|
_c.mutation.SetMcpXMLInject(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableMcpXMLInject sets the "mcp_xml_inject" field if the given value is not nil.
|
||||||
|
func (_c *GroupCreate) SetNillableMcpXMLInject(v *bool) *GroupCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetMcpXMLInject(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSupportedModelScopes sets the "supported_model_scopes" field.
|
||||||
|
func (_c *GroupCreate) SetSupportedModelScopes(v []string) *GroupCreate {
|
||||||
|
_c.mutation.SetSupportedModelScopes(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
||||||
_c.mutation.AddAPIKeyIDs(ids...)
|
_c.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -479,6 +513,14 @@ func (_c *GroupCreate) defaults() error {
|
|||||||
v := group.DefaultModelRoutingEnabled
|
v := group.DefaultModelRoutingEnabled
|
||||||
_c.mutation.SetModelRoutingEnabled(v)
|
_c.mutation.SetModelRoutingEnabled(v)
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.McpXMLInject(); !ok {
|
||||||
|
v := group.DefaultMcpXMLInject
|
||||||
|
_c.mutation.SetMcpXMLInject(v)
|
||||||
|
}
|
||||||
|
if _, ok := _c.mutation.SupportedModelScopes(); !ok {
|
||||||
|
v := group.DefaultSupportedModelScopes
|
||||||
|
_c.mutation.SetSupportedModelScopes(v)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -537,6 +579,12 @@ func (_c *GroupCreate) check() error {
|
|||||||
if _, ok := _c.mutation.ModelRoutingEnabled(); !ok {
|
if _, ok := _c.mutation.ModelRoutingEnabled(); !ok {
|
||||||
return &ValidationError{Name: "model_routing_enabled", err: errors.New(`ent: missing required field "Group.model_routing_enabled"`)}
|
return &ValidationError{Name: "model_routing_enabled", err: errors.New(`ent: missing required field "Group.model_routing_enabled"`)}
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.McpXMLInject(); !ok {
|
||||||
|
return &ValidationError{Name: "mcp_xml_inject", err: errors.New(`ent: missing required field "Group.mcp_xml_inject"`)}
|
||||||
|
}
|
||||||
|
if _, ok := _c.mutation.SupportedModelScopes(); !ok {
|
||||||
|
return &ValidationError{Name: "supported_model_scopes", err: errors.New(`ent: missing required field "Group.supported_model_scopes"`)}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -640,6 +688,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
|
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
|
||||||
_node.FallbackGroupID = &value
|
_node.FallbackGroupID = &value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.FallbackGroupIDOnInvalidRequest(); ok {
|
||||||
|
_spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value)
|
||||||
|
_node.FallbackGroupIDOnInvalidRequest = &value
|
||||||
|
}
|
||||||
if value, ok := _c.mutation.ModelRouting(); ok {
|
if value, ok := _c.mutation.ModelRouting(); ok {
|
||||||
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
|
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
|
||||||
_node.ModelRouting = value
|
_node.ModelRouting = value
|
||||||
@@ -648,6 +700,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
|
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
|
||||||
_node.ModelRoutingEnabled = value
|
_node.ModelRoutingEnabled = value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.McpXMLInject(); ok {
|
||||||
|
_spec.SetField(group.FieldMcpXMLInject, field.TypeBool, value)
|
||||||
|
_node.McpXMLInject = value
|
||||||
|
}
|
||||||
|
if value, ok := _c.mutation.SupportedModelScopes(); ok {
|
||||||
|
_spec.SetField(group.FieldSupportedModelScopes, field.TypeJSON, value)
|
||||||
|
_node.SupportedModelScopes = value
|
||||||
|
}
|
||||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@@ -1128,6 +1188,30 @@ func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (u *GroupUpsert) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsert {
|
||||||
|
u.Set(group.FieldFallbackGroupIDOnInvalidRequest, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsert) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsert {
|
||||||
|
u.SetExcluded(group.FieldFallbackGroupIDOnInvalidRequest)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (u *GroupUpsert) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsert {
|
||||||
|
u.Add(group.FieldFallbackGroupIDOnInvalidRequest, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (u *GroupUpsert) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsert {
|
||||||
|
u.SetNull(group.FieldFallbackGroupIDOnInvalidRequest)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// SetModelRouting sets the "model_routing" field.
|
// SetModelRouting sets the "model_routing" field.
|
||||||
func (u *GroupUpsert) SetModelRouting(v map[string][]int64) *GroupUpsert {
|
func (u *GroupUpsert) SetModelRouting(v map[string][]int64) *GroupUpsert {
|
||||||
u.Set(group.FieldModelRouting, v)
|
u.Set(group.FieldModelRouting, v)
|
||||||
@@ -1158,6 +1242,30 @@ func (u *GroupUpsert) UpdateModelRoutingEnabled() *GroupUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetMcpXMLInject sets the "mcp_xml_inject" field.
|
||||||
|
func (u *GroupUpsert) SetMcpXMLInject(v bool) *GroupUpsert {
|
||||||
|
u.Set(group.FieldMcpXMLInject, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateMcpXMLInject sets the "mcp_xml_inject" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsert) UpdateMcpXMLInject() *GroupUpsert {
|
||||||
|
u.SetExcluded(group.FieldMcpXMLInject)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSupportedModelScopes sets the "supported_model_scopes" field.
|
||||||
|
func (u *GroupUpsert) SetSupportedModelScopes(v []string) *GroupUpsert {
|
||||||
|
u.Set(group.FieldSupportedModelScopes, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSupportedModelScopes sets the "supported_model_scopes" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsert) UpdateSupportedModelScopes() *GroupUpsert {
|
||||||
|
u.SetExcluded(group.FieldSupportedModelScopes)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||||
// Using this option is equivalent to using:
|
// Using this option is equivalent to using:
|
||||||
//
|
//
|
||||||
@@ -1581,6 +1689,34 @@ func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (u *GroupUpsertOne) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.SetFallbackGroupIDOnInvalidRequest(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (u *GroupUpsertOne) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.AddFallbackGroupIDOnInvalidRequest(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsertOne) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.UpdateFallbackGroupIDOnInvalidRequest()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (u *GroupUpsertOne) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.ClearFallbackGroupIDOnInvalidRequest()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetModelRouting sets the "model_routing" field.
|
// SetModelRouting sets the "model_routing" field.
|
||||||
func (u *GroupUpsertOne) SetModelRouting(v map[string][]int64) *GroupUpsertOne {
|
func (u *GroupUpsertOne) SetModelRouting(v map[string][]int64) *GroupUpsertOne {
|
||||||
return u.Update(func(s *GroupUpsert) {
|
return u.Update(func(s *GroupUpsert) {
|
||||||
@@ -1616,6 +1752,34 @@ func (u *GroupUpsertOne) UpdateModelRoutingEnabled() *GroupUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetMcpXMLInject sets the "mcp_xml_inject" field.
|
||||||
|
func (u *GroupUpsertOne) SetMcpXMLInject(v bool) *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.SetMcpXMLInject(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateMcpXMLInject sets the "mcp_xml_inject" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsertOne) UpdateMcpXMLInject() *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.UpdateMcpXMLInject()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSupportedModelScopes sets the "supported_model_scopes" field.
|
||||||
|
func (u *GroupUpsertOne) SetSupportedModelScopes(v []string) *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.SetSupportedModelScopes(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSupportedModelScopes sets the "supported_model_scopes" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsertOne) UpdateSupportedModelScopes() *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.UpdateSupportedModelScopes()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
||||||
if len(u.create.conflict) == 0 {
|
if len(u.create.conflict) == 0 {
|
||||||
@@ -2205,6 +2369,34 @@ func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (u *GroupUpsertBulk) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.SetFallbackGroupIDOnInvalidRequest(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (u *GroupUpsertBulk) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.AddFallbackGroupIDOnInvalidRequest(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsertBulk) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.UpdateFallbackGroupIDOnInvalidRequest()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (u *GroupUpsertBulk) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.ClearFallbackGroupIDOnInvalidRequest()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetModelRouting sets the "model_routing" field.
|
// SetModelRouting sets the "model_routing" field.
|
||||||
func (u *GroupUpsertBulk) SetModelRouting(v map[string][]int64) *GroupUpsertBulk {
|
func (u *GroupUpsertBulk) SetModelRouting(v map[string][]int64) *GroupUpsertBulk {
|
||||||
return u.Update(func(s *GroupUpsert) {
|
return u.Update(func(s *GroupUpsert) {
|
||||||
@@ -2240,6 +2432,34 @@ func (u *GroupUpsertBulk) UpdateModelRoutingEnabled() *GroupUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetMcpXMLInject sets the "mcp_xml_inject" field.
|
||||||
|
func (u *GroupUpsertBulk) SetMcpXMLInject(v bool) *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.SetMcpXMLInject(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateMcpXMLInject sets the "mcp_xml_inject" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsertBulk) UpdateMcpXMLInject() *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.UpdateMcpXMLInject()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSupportedModelScopes sets the "supported_model_scopes" field.
|
||||||
|
func (u *GroupUpsertBulk) SetSupportedModelScopes(v []string) *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.SetSupportedModelScopes(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSupportedModelScopes sets the "supported_model_scopes" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsertBulk) UpdateSupportedModelScopes() *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.UpdateSupportedModelScopes()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
||||||
if u.create.err != nil {
|
if u.create.err != nil {
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"entgo.io/ent/dialect/sql"
|
"entgo.io/ent/dialect/sql"
|
||||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||||
|
"entgo.io/ent/dialect/sql/sqljson"
|
||||||
"entgo.io/ent/schema/field"
|
"entgo.io/ent/schema/field"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/account"
|
"github.com/Wei-Shaw/sub2api/ent/account"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||||
@@ -395,6 +396,33 @@ func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (_u *GroupUpdate) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdate {
|
||||||
|
_u.mutation.ResetFallbackGroupIDOnInvalidRequest()
|
||||||
|
_u.mutation.SetFallbackGroupIDOnInvalidRequest(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil.
|
||||||
|
func (_u *GroupUpdate) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetFallbackGroupIDOnInvalidRequest(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFallbackGroupIDOnInvalidRequest adds value to the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (_u *GroupUpdate) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdate {
|
||||||
|
_u.mutation.AddFallbackGroupIDOnInvalidRequest(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (_u *GroupUpdate) ClearFallbackGroupIDOnInvalidRequest() *GroupUpdate {
|
||||||
|
_u.mutation.ClearFallbackGroupIDOnInvalidRequest()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetModelRouting sets the "model_routing" field.
|
// SetModelRouting sets the "model_routing" field.
|
||||||
func (_u *GroupUpdate) SetModelRouting(v map[string][]int64) *GroupUpdate {
|
func (_u *GroupUpdate) SetModelRouting(v map[string][]int64) *GroupUpdate {
|
||||||
_u.mutation.SetModelRouting(v)
|
_u.mutation.SetModelRouting(v)
|
||||||
@@ -421,6 +449,32 @@ func (_u *GroupUpdate) SetNillableModelRoutingEnabled(v *bool) *GroupUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetMcpXMLInject sets the "mcp_xml_inject" field.
|
||||||
|
func (_u *GroupUpdate) SetMcpXMLInject(v bool) *GroupUpdate {
|
||||||
|
_u.mutation.SetMcpXMLInject(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableMcpXMLInject sets the "mcp_xml_inject" field if the given value is not nil.
|
||||||
|
func (_u *GroupUpdate) SetNillableMcpXMLInject(v *bool) *GroupUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetMcpXMLInject(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSupportedModelScopes sets the "supported_model_scopes" field.
|
||||||
|
func (_u *GroupUpdate) SetSupportedModelScopes(v []string) *GroupUpdate {
|
||||||
|
_u.mutation.SetSupportedModelScopes(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppendSupportedModelScopes appends value to the "supported_model_scopes" field.
|
||||||
|
func (_u *GroupUpdate) AppendSupportedModelScopes(v []string) *GroupUpdate {
|
||||||
|
_u.mutation.AppendSupportedModelScopes(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -829,6 +883,15 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if _u.mutation.FallbackGroupIDCleared() {
|
if _u.mutation.FallbackGroupIDCleared() {
|
||||||
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
|
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.FallbackGroupIDOnInvalidRequest(); ok {
|
||||||
|
_spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedFallbackGroupIDOnInvalidRequest(); ok {
|
||||||
|
_spec.AddField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.FallbackGroupIDOnInvalidRequestCleared() {
|
||||||
|
_spec.ClearField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.ModelRouting(); ok {
|
if value, ok := _u.mutation.ModelRouting(); ok {
|
||||||
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
|
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
|
||||||
}
|
}
|
||||||
@@ -838,6 +901,17 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if value, ok := _u.mutation.ModelRoutingEnabled(); ok {
|
if value, ok := _u.mutation.ModelRoutingEnabled(); ok {
|
||||||
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
|
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.McpXMLInject(); ok {
|
||||||
|
_spec.SetField(group.FieldMcpXMLInject, field.TypeBool, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.SupportedModelScopes(); ok {
|
||||||
|
_spec.SetField(group.FieldSupportedModelScopes, field.TypeJSON, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AppendedSupportedModelScopes(); ok {
|
||||||
|
_spec.AddModifier(func(u *sql.UpdateBuilder) {
|
||||||
|
sqljson.Append(u, group.FieldSupportedModelScopes, value)
|
||||||
|
})
|
||||||
|
}
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@@ -1513,6 +1587,33 @@ func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (_u *GroupUpdateOne) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdateOne {
|
||||||
|
_u.mutation.ResetFallbackGroupIDOnInvalidRequest()
|
||||||
|
_u.mutation.SetFallbackGroupIDOnInvalidRequest(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil.
|
||||||
|
func (_u *GroupUpdateOne) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetFallbackGroupIDOnInvalidRequest(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFallbackGroupIDOnInvalidRequest adds value to the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (_u *GroupUpdateOne) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdateOne {
|
||||||
|
_u.mutation.AddFallbackGroupIDOnInvalidRequest(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (_u *GroupUpdateOne) ClearFallbackGroupIDOnInvalidRequest() *GroupUpdateOne {
|
||||||
|
_u.mutation.ClearFallbackGroupIDOnInvalidRequest()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetModelRouting sets the "model_routing" field.
|
// SetModelRouting sets the "model_routing" field.
|
||||||
func (_u *GroupUpdateOne) SetModelRouting(v map[string][]int64) *GroupUpdateOne {
|
func (_u *GroupUpdateOne) SetModelRouting(v map[string][]int64) *GroupUpdateOne {
|
||||||
_u.mutation.SetModelRouting(v)
|
_u.mutation.SetModelRouting(v)
|
||||||
@@ -1539,6 +1640,32 @@ func (_u *GroupUpdateOne) SetNillableModelRoutingEnabled(v *bool) *GroupUpdateOn
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetMcpXMLInject sets the "mcp_xml_inject" field.
|
||||||
|
func (_u *GroupUpdateOne) SetMcpXMLInject(v bool) *GroupUpdateOne {
|
||||||
|
_u.mutation.SetMcpXMLInject(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableMcpXMLInject sets the "mcp_xml_inject" field if the given value is not nil.
|
||||||
|
func (_u *GroupUpdateOne) SetNillableMcpXMLInject(v *bool) *GroupUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetMcpXMLInject(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSupportedModelScopes sets the "supported_model_scopes" field.
|
||||||
|
func (_u *GroupUpdateOne) SetSupportedModelScopes(v []string) *GroupUpdateOne {
|
||||||
|
_u.mutation.SetSupportedModelScopes(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppendSupportedModelScopes appends value to the "supported_model_scopes" field.
|
||||||
|
func (_u *GroupUpdateOne) AppendSupportedModelScopes(v []string) *GroupUpdateOne {
|
||||||
|
_u.mutation.AppendSupportedModelScopes(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -1977,6 +2104,15 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
|||||||
if _u.mutation.FallbackGroupIDCleared() {
|
if _u.mutation.FallbackGroupIDCleared() {
|
||||||
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
|
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.FallbackGroupIDOnInvalidRequest(); ok {
|
||||||
|
_spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedFallbackGroupIDOnInvalidRequest(); ok {
|
||||||
|
_spec.AddField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.FallbackGroupIDOnInvalidRequestCleared() {
|
||||||
|
_spec.ClearField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.ModelRouting(); ok {
|
if value, ok := _u.mutation.ModelRouting(); ok {
|
||||||
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
|
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
|
||||||
}
|
}
|
||||||
@@ -1986,6 +2122,17 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
|||||||
if value, ok := _u.mutation.ModelRoutingEnabled(); ok {
|
if value, ok := _u.mutation.ModelRoutingEnabled(); ok {
|
||||||
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
|
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.McpXMLInject(); ok {
|
||||||
|
_spec.SetField(group.FieldMcpXMLInject, field.TypeBool, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.SupportedModelScopes(); ok {
|
||||||
|
_spec.SetField(group.FieldSupportedModelScopes, field.TypeJSON, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AppendedSupportedModelScopes(); ok {
|
||||||
|
_spec.AddModifier(func(u *sql.UpdateBuilder) {
|
||||||
|
sqljson.Append(u, group.FieldSupportedModelScopes, value)
|
||||||
|
})
|
||||||
|
}
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
|
|||||||
@@ -20,6 +20,9 @@ var (
|
|||||||
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
|
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
|
||||||
{Name: "ip_whitelist", Type: field.TypeJSON, Nullable: true},
|
{Name: "ip_whitelist", Type: field.TypeJSON, Nullable: true},
|
||||||
{Name: "ip_blacklist", Type: field.TypeJSON, Nullable: true},
|
{Name: "ip_blacklist", Type: field.TypeJSON, Nullable: true},
|
||||||
|
{Name: "quota", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||||
|
{Name: "quota_used", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||||
|
{Name: "expires_at", Type: field.TypeTime, Nullable: true},
|
||||||
{Name: "group_id", Type: field.TypeInt64, Nullable: true},
|
{Name: "group_id", Type: field.TypeInt64, Nullable: true},
|
||||||
{Name: "user_id", Type: field.TypeInt64},
|
{Name: "user_id", Type: field.TypeInt64},
|
||||||
}
|
}
|
||||||
@@ -31,13 +34,13 @@ var (
|
|||||||
ForeignKeys: []*schema.ForeignKey{
|
ForeignKeys: []*schema.ForeignKey{
|
||||||
{
|
{
|
||||||
Symbol: "api_keys_groups_api_keys",
|
Symbol: "api_keys_groups_api_keys",
|
||||||
Columns: []*schema.Column{APIKeysColumns[9]},
|
Columns: []*schema.Column{APIKeysColumns[12]},
|
||||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||||
OnDelete: schema.SetNull,
|
OnDelete: schema.SetNull,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "api_keys_users_api_keys",
|
Symbol: "api_keys_users_api_keys",
|
||||||
Columns: []*schema.Column{APIKeysColumns[10]},
|
Columns: []*schema.Column{APIKeysColumns[13]},
|
||||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
@@ -46,12 +49,12 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "apikey_user_id",
|
Name: "apikey_user_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{APIKeysColumns[10]},
|
Columns: []*schema.Column{APIKeysColumns[13]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "apikey_group_id",
|
Name: "apikey_group_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{APIKeysColumns[9]},
|
Columns: []*schema.Column{APIKeysColumns[12]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "apikey_status",
|
Name: "apikey_status",
|
||||||
@@ -63,6 +66,16 @@ var (
|
|||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{APIKeysColumns[3]},
|
Columns: []*schema.Column{APIKeysColumns[3]},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "apikey_quota_quota_used",
|
||||||
|
Unique: false,
|
||||||
|
Columns: []*schema.Column{APIKeysColumns[9], APIKeysColumns[10]},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "apikey_expires_at",
|
||||||
|
Unique: false,
|
||||||
|
Columns: []*schema.Column{APIKeysColumns[11]},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
// AccountsColumns holds the columns for the "accounts" table.
|
// AccountsColumns holds the columns for the "accounts" table.
|
||||||
@@ -318,8 +331,11 @@ var (
|
|||||||
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||||
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
|
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
|
||||||
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
|
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
|
||||||
|
{Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true},
|
||||||
{Name: "model_routing", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
{Name: "model_routing", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||||
{Name: "model_routing_enabled", Type: field.TypeBool, Default: false},
|
{Name: "model_routing_enabled", Type: field.TypeBool, Default: false},
|
||||||
|
{Name: "mcp_xml_inject", Type: field.TypeBool, Default: true},
|
||||||
|
{Name: "supported_model_scopes", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||||
}
|
}
|
||||||
// GroupsTable holds the schema information for the "groups" table.
|
// GroupsTable holds the schema information for the "groups" table.
|
||||||
GroupsTable = &schema.Table{
|
GroupsTable = &schema.Table{
|
||||||
|
|||||||
@@ -79,6 +79,11 @@ type APIKeyMutation struct {
|
|||||||
appendip_whitelist []string
|
appendip_whitelist []string
|
||||||
ip_blacklist *[]string
|
ip_blacklist *[]string
|
||||||
appendip_blacklist []string
|
appendip_blacklist []string
|
||||||
|
quota *float64
|
||||||
|
addquota *float64
|
||||||
|
quota_used *float64
|
||||||
|
addquota_used *float64
|
||||||
|
expires_at *time.Time
|
||||||
clearedFields map[string]struct{}
|
clearedFields map[string]struct{}
|
||||||
user *int64
|
user *int64
|
||||||
cleareduser bool
|
cleareduser bool
|
||||||
@@ -634,6 +639,167 @@ func (m *APIKeyMutation) ResetIPBlacklist() {
|
|||||||
delete(m.clearedFields, apikey.FieldIPBlacklist)
|
delete(m.clearedFields, apikey.FieldIPBlacklist)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetQuota sets the "quota" field.
|
||||||
|
func (m *APIKeyMutation) SetQuota(f float64) {
|
||||||
|
m.quota = &f
|
||||||
|
m.addquota = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Quota returns the value of the "quota" field in the mutation.
|
||||||
|
func (m *APIKeyMutation) Quota() (r float64, exists bool) {
|
||||||
|
v := m.quota
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldQuota returns the old "quota" field's value of the APIKey entity.
|
||||||
|
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *APIKeyMutation) OldQuota(ctx context.Context) (v float64, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldQuota is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldQuota requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldQuota: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.Quota, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddQuota adds f to the "quota" field.
|
||||||
|
func (m *APIKeyMutation) AddQuota(f float64) {
|
||||||
|
if m.addquota != nil {
|
||||||
|
*m.addquota += f
|
||||||
|
} else {
|
||||||
|
m.addquota = &f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddedQuota returns the value that was added to the "quota" field in this mutation.
|
||||||
|
func (m *APIKeyMutation) AddedQuota() (r float64, exists bool) {
|
||||||
|
v := m.addquota
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetQuota resets all changes to the "quota" field.
|
||||||
|
func (m *APIKeyMutation) ResetQuota() {
|
||||||
|
m.quota = nil
|
||||||
|
m.addquota = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetQuotaUsed sets the "quota_used" field.
|
||||||
|
func (m *APIKeyMutation) SetQuotaUsed(f float64) {
|
||||||
|
m.quota_used = &f
|
||||||
|
m.addquota_used = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuotaUsed returns the value of the "quota_used" field in the mutation.
|
||||||
|
func (m *APIKeyMutation) QuotaUsed() (r float64, exists bool) {
|
||||||
|
v := m.quota_used
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldQuotaUsed returns the old "quota_used" field's value of the APIKey entity.
|
||||||
|
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *APIKeyMutation) OldQuotaUsed(ctx context.Context) (v float64, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldQuotaUsed is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldQuotaUsed requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldQuotaUsed: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.QuotaUsed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddQuotaUsed adds f to the "quota_used" field.
|
||||||
|
func (m *APIKeyMutation) AddQuotaUsed(f float64) {
|
||||||
|
if m.addquota_used != nil {
|
||||||
|
*m.addquota_used += f
|
||||||
|
} else {
|
||||||
|
m.addquota_used = &f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddedQuotaUsed returns the value that was added to the "quota_used" field in this mutation.
|
||||||
|
func (m *APIKeyMutation) AddedQuotaUsed() (r float64, exists bool) {
|
||||||
|
v := m.addquota_used
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetQuotaUsed resets all changes to the "quota_used" field.
|
||||||
|
func (m *APIKeyMutation) ResetQuotaUsed() {
|
||||||
|
m.quota_used = nil
|
||||||
|
m.addquota_used = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetExpiresAt sets the "expires_at" field.
|
||||||
|
func (m *APIKeyMutation) SetExpiresAt(t time.Time) {
|
||||||
|
m.expires_at = &t
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpiresAt returns the value of the "expires_at" field in the mutation.
|
||||||
|
func (m *APIKeyMutation) ExpiresAt() (r time.Time, exists bool) {
|
||||||
|
v := m.expires_at
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldExpiresAt returns the old "expires_at" field's value of the APIKey entity.
|
||||||
|
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *APIKeyMutation) OldExpiresAt(ctx context.Context) (v *time.Time, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldExpiresAt requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.ExpiresAt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||||
|
func (m *APIKeyMutation) ClearExpiresAt() {
|
||||||
|
m.expires_at = nil
|
||||||
|
m.clearedFields[apikey.FieldExpiresAt] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpiresAtCleared returns if the "expires_at" field was cleared in this mutation.
|
||||||
|
func (m *APIKeyMutation) ExpiresAtCleared() bool {
|
||||||
|
_, ok := m.clearedFields[apikey.FieldExpiresAt]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetExpiresAt resets all changes to the "expires_at" field.
|
||||||
|
func (m *APIKeyMutation) ResetExpiresAt() {
|
||||||
|
m.expires_at = nil
|
||||||
|
delete(m.clearedFields, apikey.FieldExpiresAt)
|
||||||
|
}
|
||||||
|
|
||||||
// ClearUser clears the "user" edge to the User entity.
|
// ClearUser clears the "user" edge to the User entity.
|
||||||
func (m *APIKeyMutation) ClearUser() {
|
func (m *APIKeyMutation) ClearUser() {
|
||||||
m.cleareduser = true
|
m.cleareduser = true
|
||||||
@@ -776,7 +942,7 @@ func (m *APIKeyMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *APIKeyMutation) Fields() []string {
|
func (m *APIKeyMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 10)
|
fields := make([]string, 0, 13)
|
||||||
if m.created_at != nil {
|
if m.created_at != nil {
|
||||||
fields = append(fields, apikey.FieldCreatedAt)
|
fields = append(fields, apikey.FieldCreatedAt)
|
||||||
}
|
}
|
||||||
@@ -807,6 +973,15 @@ func (m *APIKeyMutation) Fields() []string {
|
|||||||
if m.ip_blacklist != nil {
|
if m.ip_blacklist != nil {
|
||||||
fields = append(fields, apikey.FieldIPBlacklist)
|
fields = append(fields, apikey.FieldIPBlacklist)
|
||||||
}
|
}
|
||||||
|
if m.quota != nil {
|
||||||
|
fields = append(fields, apikey.FieldQuota)
|
||||||
|
}
|
||||||
|
if m.quota_used != nil {
|
||||||
|
fields = append(fields, apikey.FieldQuotaUsed)
|
||||||
|
}
|
||||||
|
if m.expires_at != nil {
|
||||||
|
fields = append(fields, apikey.FieldExpiresAt)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -835,6 +1010,12 @@ func (m *APIKeyMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.IPWhitelist()
|
return m.IPWhitelist()
|
||||||
case apikey.FieldIPBlacklist:
|
case apikey.FieldIPBlacklist:
|
||||||
return m.IPBlacklist()
|
return m.IPBlacklist()
|
||||||
|
case apikey.FieldQuota:
|
||||||
|
return m.Quota()
|
||||||
|
case apikey.FieldQuotaUsed:
|
||||||
|
return m.QuotaUsed()
|
||||||
|
case apikey.FieldExpiresAt:
|
||||||
|
return m.ExpiresAt()
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -864,6 +1045,12 @@ func (m *APIKeyMutation) OldField(ctx context.Context, name string) (ent.Value,
|
|||||||
return m.OldIPWhitelist(ctx)
|
return m.OldIPWhitelist(ctx)
|
||||||
case apikey.FieldIPBlacklist:
|
case apikey.FieldIPBlacklist:
|
||||||
return m.OldIPBlacklist(ctx)
|
return m.OldIPBlacklist(ctx)
|
||||||
|
case apikey.FieldQuota:
|
||||||
|
return m.OldQuota(ctx)
|
||||||
|
case apikey.FieldQuotaUsed:
|
||||||
|
return m.OldQuotaUsed(ctx)
|
||||||
|
case apikey.FieldExpiresAt:
|
||||||
|
return m.OldExpiresAt(ctx)
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unknown APIKey field %s", name)
|
return nil, fmt.Errorf("unknown APIKey field %s", name)
|
||||||
}
|
}
|
||||||
@@ -943,6 +1130,27 @@ func (m *APIKeyMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetIPBlacklist(v)
|
m.SetIPBlacklist(v)
|
||||||
return nil
|
return nil
|
||||||
|
case apikey.FieldQuota:
|
||||||
|
v, ok := value.(float64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetQuota(v)
|
||||||
|
return nil
|
||||||
|
case apikey.FieldQuotaUsed:
|
||||||
|
v, ok := value.(float64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetQuotaUsed(v)
|
||||||
|
return nil
|
||||||
|
case apikey.FieldExpiresAt:
|
||||||
|
v, ok := value.(time.Time)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetExpiresAt(v)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown APIKey field %s", name)
|
return fmt.Errorf("unknown APIKey field %s", name)
|
||||||
}
|
}
|
||||||
@@ -951,6 +1159,12 @@ func (m *APIKeyMutation) SetField(name string, value ent.Value) error {
|
|||||||
// this mutation.
|
// this mutation.
|
||||||
func (m *APIKeyMutation) AddedFields() []string {
|
func (m *APIKeyMutation) AddedFields() []string {
|
||||||
var fields []string
|
var fields []string
|
||||||
|
if m.addquota != nil {
|
||||||
|
fields = append(fields, apikey.FieldQuota)
|
||||||
|
}
|
||||||
|
if m.addquota_used != nil {
|
||||||
|
fields = append(fields, apikey.FieldQuotaUsed)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -959,6 +1173,10 @@ func (m *APIKeyMutation) AddedFields() []string {
|
|||||||
// was not set, or was not defined in the schema.
|
// was not set, or was not defined in the schema.
|
||||||
func (m *APIKeyMutation) AddedField(name string) (ent.Value, bool) {
|
func (m *APIKeyMutation) AddedField(name string) (ent.Value, bool) {
|
||||||
switch name {
|
switch name {
|
||||||
|
case apikey.FieldQuota:
|
||||||
|
return m.AddedQuota()
|
||||||
|
case apikey.FieldQuotaUsed:
|
||||||
|
return m.AddedQuotaUsed()
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -968,6 +1186,20 @@ func (m *APIKeyMutation) AddedField(name string) (ent.Value, bool) {
|
|||||||
// type.
|
// type.
|
||||||
func (m *APIKeyMutation) AddField(name string, value ent.Value) error {
|
func (m *APIKeyMutation) AddField(name string, value ent.Value) error {
|
||||||
switch name {
|
switch name {
|
||||||
|
case apikey.FieldQuota:
|
||||||
|
v, ok := value.(float64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.AddQuota(v)
|
||||||
|
return nil
|
||||||
|
case apikey.FieldQuotaUsed:
|
||||||
|
v, ok := value.(float64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.AddQuotaUsed(v)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown APIKey numeric field %s", name)
|
return fmt.Errorf("unknown APIKey numeric field %s", name)
|
||||||
}
|
}
|
||||||
@@ -988,6 +1220,9 @@ func (m *APIKeyMutation) ClearedFields() []string {
|
|||||||
if m.FieldCleared(apikey.FieldIPBlacklist) {
|
if m.FieldCleared(apikey.FieldIPBlacklist) {
|
||||||
fields = append(fields, apikey.FieldIPBlacklist)
|
fields = append(fields, apikey.FieldIPBlacklist)
|
||||||
}
|
}
|
||||||
|
if m.FieldCleared(apikey.FieldExpiresAt) {
|
||||||
|
fields = append(fields, apikey.FieldExpiresAt)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1014,6 +1249,9 @@ func (m *APIKeyMutation) ClearField(name string) error {
|
|||||||
case apikey.FieldIPBlacklist:
|
case apikey.FieldIPBlacklist:
|
||||||
m.ClearIPBlacklist()
|
m.ClearIPBlacklist()
|
||||||
return nil
|
return nil
|
||||||
|
case apikey.FieldExpiresAt:
|
||||||
|
m.ClearExpiresAt()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown APIKey nullable field %s", name)
|
return fmt.Errorf("unknown APIKey nullable field %s", name)
|
||||||
}
|
}
|
||||||
@@ -1052,6 +1290,15 @@ func (m *APIKeyMutation) ResetField(name string) error {
|
|||||||
case apikey.FieldIPBlacklist:
|
case apikey.FieldIPBlacklist:
|
||||||
m.ResetIPBlacklist()
|
m.ResetIPBlacklist()
|
||||||
return nil
|
return nil
|
||||||
|
case apikey.FieldQuota:
|
||||||
|
m.ResetQuota()
|
||||||
|
return nil
|
||||||
|
case apikey.FieldQuotaUsed:
|
||||||
|
m.ResetQuotaUsed()
|
||||||
|
return nil
|
||||||
|
case apikey.FieldExpiresAt:
|
||||||
|
m.ResetExpiresAt()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown APIKey field %s", name)
|
return fmt.Errorf("unknown APIKey field %s", name)
|
||||||
}
|
}
|
||||||
@@ -5506,61 +5753,66 @@ func (m *AnnouncementReadMutation) ResetEdge(name string) error {
|
|||||||
// GroupMutation represents an operation that mutates the Group nodes in the graph.
|
// GroupMutation represents an operation that mutates the Group nodes in the graph.
|
||||||
type GroupMutation struct {
|
type GroupMutation struct {
|
||||||
config
|
config
|
||||||
op Op
|
op Op
|
||||||
typ string
|
typ string
|
||||||
id *int64
|
id *int64
|
||||||
created_at *time.Time
|
created_at *time.Time
|
||||||
updated_at *time.Time
|
updated_at *time.Time
|
||||||
deleted_at *time.Time
|
deleted_at *time.Time
|
||||||
name *string
|
name *string
|
||||||
description *string
|
description *string
|
||||||
rate_multiplier *float64
|
rate_multiplier *float64
|
||||||
addrate_multiplier *float64
|
addrate_multiplier *float64
|
||||||
is_exclusive *bool
|
is_exclusive *bool
|
||||||
status *string
|
status *string
|
||||||
platform *string
|
platform *string
|
||||||
subscription_type *string
|
subscription_type *string
|
||||||
daily_limit_usd *float64
|
daily_limit_usd *float64
|
||||||
adddaily_limit_usd *float64
|
adddaily_limit_usd *float64
|
||||||
weekly_limit_usd *float64
|
weekly_limit_usd *float64
|
||||||
addweekly_limit_usd *float64
|
addweekly_limit_usd *float64
|
||||||
monthly_limit_usd *float64
|
monthly_limit_usd *float64
|
||||||
addmonthly_limit_usd *float64
|
addmonthly_limit_usd *float64
|
||||||
default_validity_days *int
|
default_validity_days *int
|
||||||
adddefault_validity_days *int
|
adddefault_validity_days *int
|
||||||
image_price_1k *float64
|
image_price_1k *float64
|
||||||
addimage_price_1k *float64
|
addimage_price_1k *float64
|
||||||
image_price_2k *float64
|
image_price_2k *float64
|
||||||
addimage_price_2k *float64
|
addimage_price_2k *float64
|
||||||
image_price_4k *float64
|
image_price_4k *float64
|
||||||
addimage_price_4k *float64
|
addimage_price_4k *float64
|
||||||
claude_code_only *bool
|
claude_code_only *bool
|
||||||
fallback_group_id *int64
|
fallback_group_id *int64
|
||||||
addfallback_group_id *int64
|
addfallback_group_id *int64
|
||||||
model_routing *map[string][]int64
|
fallback_group_id_on_invalid_request *int64
|
||||||
model_routing_enabled *bool
|
addfallback_group_id_on_invalid_request *int64
|
||||||
clearedFields map[string]struct{}
|
model_routing *map[string][]int64
|
||||||
api_keys map[int64]struct{}
|
model_routing_enabled *bool
|
||||||
removedapi_keys map[int64]struct{}
|
mcp_xml_inject *bool
|
||||||
clearedapi_keys bool
|
supported_model_scopes *[]string
|
||||||
redeem_codes map[int64]struct{}
|
appendsupported_model_scopes []string
|
||||||
removedredeem_codes map[int64]struct{}
|
clearedFields map[string]struct{}
|
||||||
clearedredeem_codes bool
|
api_keys map[int64]struct{}
|
||||||
subscriptions map[int64]struct{}
|
removedapi_keys map[int64]struct{}
|
||||||
removedsubscriptions map[int64]struct{}
|
clearedapi_keys bool
|
||||||
clearedsubscriptions bool
|
redeem_codes map[int64]struct{}
|
||||||
usage_logs map[int64]struct{}
|
removedredeem_codes map[int64]struct{}
|
||||||
removedusage_logs map[int64]struct{}
|
clearedredeem_codes bool
|
||||||
clearedusage_logs bool
|
subscriptions map[int64]struct{}
|
||||||
accounts map[int64]struct{}
|
removedsubscriptions map[int64]struct{}
|
||||||
removedaccounts map[int64]struct{}
|
clearedsubscriptions bool
|
||||||
clearedaccounts bool
|
usage_logs map[int64]struct{}
|
||||||
allowed_users map[int64]struct{}
|
removedusage_logs map[int64]struct{}
|
||||||
removedallowed_users map[int64]struct{}
|
clearedusage_logs bool
|
||||||
clearedallowed_users bool
|
accounts map[int64]struct{}
|
||||||
done bool
|
removedaccounts map[int64]struct{}
|
||||||
oldValue func(context.Context) (*Group, error)
|
clearedaccounts bool
|
||||||
predicates []predicate.Group
|
allowed_users map[int64]struct{}
|
||||||
|
removedallowed_users map[int64]struct{}
|
||||||
|
clearedallowed_users bool
|
||||||
|
done bool
|
||||||
|
oldValue func(context.Context) (*Group, error)
|
||||||
|
predicates []predicate.Group
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ ent.Mutation = (*GroupMutation)(nil)
|
var _ ent.Mutation = (*GroupMutation)(nil)
|
||||||
@@ -6649,6 +6901,76 @@ func (m *GroupMutation) ResetFallbackGroupID() {
|
|||||||
delete(m.clearedFields, group.FieldFallbackGroupID)
|
delete(m.clearedFields, group.FieldFallbackGroupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (m *GroupMutation) SetFallbackGroupIDOnInvalidRequest(i int64) {
|
||||||
|
m.fallback_group_id_on_invalid_request = &i
|
||||||
|
m.addfallback_group_id_on_invalid_request = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequest returns the value of the "fallback_group_id_on_invalid_request" field in the mutation.
|
||||||
|
func (m *GroupMutation) FallbackGroupIDOnInvalidRequest() (r int64, exists bool) {
|
||||||
|
v := m.fallback_group_id_on_invalid_request
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldFallbackGroupIDOnInvalidRequest returns the old "fallback_group_id_on_invalid_request" field's value of the Group entity.
|
||||||
|
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *GroupMutation) OldFallbackGroupIDOnInvalidRequest(ctx context.Context) (v *int64, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldFallbackGroupIDOnInvalidRequest is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldFallbackGroupIDOnInvalidRequest requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldFallbackGroupIDOnInvalidRequest: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.FallbackGroupIDOnInvalidRequest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFallbackGroupIDOnInvalidRequest adds i to the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (m *GroupMutation) AddFallbackGroupIDOnInvalidRequest(i int64) {
|
||||||
|
if m.addfallback_group_id_on_invalid_request != nil {
|
||||||
|
*m.addfallback_group_id_on_invalid_request += i
|
||||||
|
} else {
|
||||||
|
m.addfallback_group_id_on_invalid_request = &i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddedFallbackGroupIDOnInvalidRequest returns the value that was added to the "fallback_group_id_on_invalid_request" field in this mutation.
|
||||||
|
func (m *GroupMutation) AddedFallbackGroupIDOnInvalidRequest() (r int64, exists bool) {
|
||||||
|
v := m.addfallback_group_id_on_invalid_request
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (m *GroupMutation) ClearFallbackGroupIDOnInvalidRequest() {
|
||||||
|
m.fallback_group_id_on_invalid_request = nil
|
||||||
|
m.addfallback_group_id_on_invalid_request = nil
|
||||||
|
m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequestCleared returns if the "fallback_group_id_on_invalid_request" field was cleared in this mutation.
|
||||||
|
func (m *GroupMutation) FallbackGroupIDOnInvalidRequestCleared() bool {
|
||||||
|
_, ok := m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetFallbackGroupIDOnInvalidRequest resets all changes to the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (m *GroupMutation) ResetFallbackGroupIDOnInvalidRequest() {
|
||||||
|
m.fallback_group_id_on_invalid_request = nil
|
||||||
|
m.addfallback_group_id_on_invalid_request = nil
|
||||||
|
delete(m.clearedFields, group.FieldFallbackGroupIDOnInvalidRequest)
|
||||||
|
}
|
||||||
|
|
||||||
// SetModelRouting sets the "model_routing" field.
|
// SetModelRouting sets the "model_routing" field.
|
||||||
func (m *GroupMutation) SetModelRouting(value map[string][]int64) {
|
func (m *GroupMutation) SetModelRouting(value map[string][]int64) {
|
||||||
m.model_routing = &value
|
m.model_routing = &value
|
||||||
@@ -6734,6 +7056,93 @@ func (m *GroupMutation) ResetModelRoutingEnabled() {
|
|||||||
m.model_routing_enabled = nil
|
m.model_routing_enabled = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetMcpXMLInject sets the "mcp_xml_inject" field.
|
||||||
|
func (m *GroupMutation) SetMcpXMLInject(b bool) {
|
||||||
|
m.mcp_xml_inject = &b
|
||||||
|
}
|
||||||
|
|
||||||
|
// McpXMLInject returns the value of the "mcp_xml_inject" field in the mutation.
|
||||||
|
func (m *GroupMutation) McpXMLInject() (r bool, exists bool) {
|
||||||
|
v := m.mcp_xml_inject
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldMcpXMLInject returns the old "mcp_xml_inject" field's value of the Group entity.
|
||||||
|
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *GroupMutation) OldMcpXMLInject(ctx context.Context) (v bool, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldMcpXMLInject is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldMcpXMLInject requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldMcpXMLInject: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.McpXMLInject, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetMcpXMLInject resets all changes to the "mcp_xml_inject" field.
|
||||||
|
func (m *GroupMutation) ResetMcpXMLInject() {
|
||||||
|
m.mcp_xml_inject = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSupportedModelScopes sets the "supported_model_scopes" field.
|
||||||
|
func (m *GroupMutation) SetSupportedModelScopes(s []string) {
|
||||||
|
m.supported_model_scopes = &s
|
||||||
|
m.appendsupported_model_scopes = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SupportedModelScopes returns the value of the "supported_model_scopes" field in the mutation.
|
||||||
|
func (m *GroupMutation) SupportedModelScopes() (r []string, exists bool) {
|
||||||
|
v := m.supported_model_scopes
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldSupportedModelScopes returns the old "supported_model_scopes" field's value of the Group entity.
|
||||||
|
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *GroupMutation) OldSupportedModelScopes(ctx context.Context) (v []string, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldSupportedModelScopes is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldSupportedModelScopes requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldSupportedModelScopes: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.SupportedModelScopes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppendSupportedModelScopes adds s to the "supported_model_scopes" field.
|
||||||
|
func (m *GroupMutation) AppendSupportedModelScopes(s []string) {
|
||||||
|
m.appendsupported_model_scopes = append(m.appendsupported_model_scopes, s...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppendedSupportedModelScopes returns the list of values that were appended to the "supported_model_scopes" field in this mutation.
|
||||||
|
func (m *GroupMutation) AppendedSupportedModelScopes() ([]string, bool) {
|
||||||
|
if len(m.appendsupported_model_scopes) == 0 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return m.appendsupported_model_scopes, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetSupportedModelScopes resets all changes to the "supported_model_scopes" field.
|
||||||
|
func (m *GroupMutation) ResetSupportedModelScopes() {
|
||||||
|
m.supported_model_scopes = nil
|
||||||
|
m.appendsupported_model_scopes = nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
||||||
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
|
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
|
||||||
if m.api_keys == nil {
|
if m.api_keys == nil {
|
||||||
@@ -7092,7 +7501,7 @@ func (m *GroupMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *GroupMutation) Fields() []string {
|
func (m *GroupMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 21)
|
fields := make([]string, 0, 24)
|
||||||
if m.created_at != nil {
|
if m.created_at != nil {
|
||||||
fields = append(fields, group.FieldCreatedAt)
|
fields = append(fields, group.FieldCreatedAt)
|
||||||
}
|
}
|
||||||
@@ -7150,12 +7559,21 @@ func (m *GroupMutation) Fields() []string {
|
|||||||
if m.fallback_group_id != nil {
|
if m.fallback_group_id != nil {
|
||||||
fields = append(fields, group.FieldFallbackGroupID)
|
fields = append(fields, group.FieldFallbackGroupID)
|
||||||
}
|
}
|
||||||
|
if m.fallback_group_id_on_invalid_request != nil {
|
||||||
|
fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest)
|
||||||
|
}
|
||||||
if m.model_routing != nil {
|
if m.model_routing != nil {
|
||||||
fields = append(fields, group.FieldModelRouting)
|
fields = append(fields, group.FieldModelRouting)
|
||||||
}
|
}
|
||||||
if m.model_routing_enabled != nil {
|
if m.model_routing_enabled != nil {
|
||||||
fields = append(fields, group.FieldModelRoutingEnabled)
|
fields = append(fields, group.FieldModelRoutingEnabled)
|
||||||
}
|
}
|
||||||
|
if m.mcp_xml_inject != nil {
|
||||||
|
fields = append(fields, group.FieldMcpXMLInject)
|
||||||
|
}
|
||||||
|
if m.supported_model_scopes != nil {
|
||||||
|
fields = append(fields, group.FieldSupportedModelScopes)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -7202,10 +7620,16 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.ClaudeCodeOnly()
|
return m.ClaudeCodeOnly()
|
||||||
case group.FieldFallbackGroupID:
|
case group.FieldFallbackGroupID:
|
||||||
return m.FallbackGroupID()
|
return m.FallbackGroupID()
|
||||||
|
case group.FieldFallbackGroupIDOnInvalidRequest:
|
||||||
|
return m.FallbackGroupIDOnInvalidRequest()
|
||||||
case group.FieldModelRouting:
|
case group.FieldModelRouting:
|
||||||
return m.ModelRouting()
|
return m.ModelRouting()
|
||||||
case group.FieldModelRoutingEnabled:
|
case group.FieldModelRoutingEnabled:
|
||||||
return m.ModelRoutingEnabled()
|
return m.ModelRoutingEnabled()
|
||||||
|
case group.FieldMcpXMLInject:
|
||||||
|
return m.McpXMLInject()
|
||||||
|
case group.FieldSupportedModelScopes:
|
||||||
|
return m.SupportedModelScopes()
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -7253,10 +7677,16 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
|
|||||||
return m.OldClaudeCodeOnly(ctx)
|
return m.OldClaudeCodeOnly(ctx)
|
||||||
case group.FieldFallbackGroupID:
|
case group.FieldFallbackGroupID:
|
||||||
return m.OldFallbackGroupID(ctx)
|
return m.OldFallbackGroupID(ctx)
|
||||||
|
case group.FieldFallbackGroupIDOnInvalidRequest:
|
||||||
|
return m.OldFallbackGroupIDOnInvalidRequest(ctx)
|
||||||
case group.FieldModelRouting:
|
case group.FieldModelRouting:
|
||||||
return m.OldModelRouting(ctx)
|
return m.OldModelRouting(ctx)
|
||||||
case group.FieldModelRoutingEnabled:
|
case group.FieldModelRoutingEnabled:
|
||||||
return m.OldModelRoutingEnabled(ctx)
|
return m.OldModelRoutingEnabled(ctx)
|
||||||
|
case group.FieldMcpXMLInject:
|
||||||
|
return m.OldMcpXMLInject(ctx)
|
||||||
|
case group.FieldSupportedModelScopes:
|
||||||
|
return m.OldSupportedModelScopes(ctx)
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unknown Group field %s", name)
|
return nil, fmt.Errorf("unknown Group field %s", name)
|
||||||
}
|
}
|
||||||
@@ -7399,6 +7829,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetFallbackGroupID(v)
|
m.SetFallbackGroupID(v)
|
||||||
return nil
|
return nil
|
||||||
|
case group.FieldFallbackGroupIDOnInvalidRequest:
|
||||||
|
v, ok := value.(int64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetFallbackGroupIDOnInvalidRequest(v)
|
||||||
|
return nil
|
||||||
case group.FieldModelRouting:
|
case group.FieldModelRouting:
|
||||||
v, ok := value.(map[string][]int64)
|
v, ok := value.(map[string][]int64)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -7413,6 +7850,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetModelRoutingEnabled(v)
|
m.SetModelRoutingEnabled(v)
|
||||||
return nil
|
return nil
|
||||||
|
case group.FieldMcpXMLInject:
|
||||||
|
v, ok := value.(bool)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetMcpXMLInject(v)
|
||||||
|
return nil
|
||||||
|
case group.FieldSupportedModelScopes:
|
||||||
|
v, ok := value.([]string)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetSupportedModelScopes(v)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown Group field %s", name)
|
return fmt.Errorf("unknown Group field %s", name)
|
||||||
}
|
}
|
||||||
@@ -7448,6 +7899,9 @@ func (m *GroupMutation) AddedFields() []string {
|
|||||||
if m.addfallback_group_id != nil {
|
if m.addfallback_group_id != nil {
|
||||||
fields = append(fields, group.FieldFallbackGroupID)
|
fields = append(fields, group.FieldFallbackGroupID)
|
||||||
}
|
}
|
||||||
|
if m.addfallback_group_id_on_invalid_request != nil {
|
||||||
|
fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -7474,6 +7928,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
|
|||||||
return m.AddedImagePrice4k()
|
return m.AddedImagePrice4k()
|
||||||
case group.FieldFallbackGroupID:
|
case group.FieldFallbackGroupID:
|
||||||
return m.AddedFallbackGroupID()
|
return m.AddedFallbackGroupID()
|
||||||
|
case group.FieldFallbackGroupIDOnInvalidRequest:
|
||||||
|
return m.AddedFallbackGroupIDOnInvalidRequest()
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -7546,6 +8002,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.AddFallbackGroupID(v)
|
m.AddFallbackGroupID(v)
|
||||||
return nil
|
return nil
|
||||||
|
case group.FieldFallbackGroupIDOnInvalidRequest:
|
||||||
|
v, ok := value.(int64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.AddFallbackGroupIDOnInvalidRequest(v)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown Group numeric field %s", name)
|
return fmt.Errorf("unknown Group numeric field %s", name)
|
||||||
}
|
}
|
||||||
@@ -7581,6 +8044,9 @@ func (m *GroupMutation) ClearedFields() []string {
|
|||||||
if m.FieldCleared(group.FieldFallbackGroupID) {
|
if m.FieldCleared(group.FieldFallbackGroupID) {
|
||||||
fields = append(fields, group.FieldFallbackGroupID)
|
fields = append(fields, group.FieldFallbackGroupID)
|
||||||
}
|
}
|
||||||
|
if m.FieldCleared(group.FieldFallbackGroupIDOnInvalidRequest) {
|
||||||
|
fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest)
|
||||||
|
}
|
||||||
if m.FieldCleared(group.FieldModelRouting) {
|
if m.FieldCleared(group.FieldModelRouting) {
|
||||||
fields = append(fields, group.FieldModelRouting)
|
fields = append(fields, group.FieldModelRouting)
|
||||||
}
|
}
|
||||||
@@ -7625,6 +8091,9 @@ func (m *GroupMutation) ClearField(name string) error {
|
|||||||
case group.FieldFallbackGroupID:
|
case group.FieldFallbackGroupID:
|
||||||
m.ClearFallbackGroupID()
|
m.ClearFallbackGroupID()
|
||||||
return nil
|
return nil
|
||||||
|
case group.FieldFallbackGroupIDOnInvalidRequest:
|
||||||
|
m.ClearFallbackGroupIDOnInvalidRequest()
|
||||||
|
return nil
|
||||||
case group.FieldModelRouting:
|
case group.FieldModelRouting:
|
||||||
m.ClearModelRouting()
|
m.ClearModelRouting()
|
||||||
return nil
|
return nil
|
||||||
@@ -7693,12 +8162,21 @@ func (m *GroupMutation) ResetField(name string) error {
|
|||||||
case group.FieldFallbackGroupID:
|
case group.FieldFallbackGroupID:
|
||||||
m.ResetFallbackGroupID()
|
m.ResetFallbackGroupID()
|
||||||
return nil
|
return nil
|
||||||
|
case group.FieldFallbackGroupIDOnInvalidRequest:
|
||||||
|
m.ResetFallbackGroupIDOnInvalidRequest()
|
||||||
|
return nil
|
||||||
case group.FieldModelRouting:
|
case group.FieldModelRouting:
|
||||||
m.ResetModelRouting()
|
m.ResetModelRouting()
|
||||||
return nil
|
return nil
|
||||||
case group.FieldModelRoutingEnabled:
|
case group.FieldModelRoutingEnabled:
|
||||||
m.ResetModelRoutingEnabled()
|
m.ResetModelRoutingEnabled()
|
||||||
return nil
|
return nil
|
||||||
|
case group.FieldMcpXMLInject:
|
||||||
|
m.ResetMcpXMLInject()
|
||||||
|
return nil
|
||||||
|
case group.FieldSupportedModelScopes:
|
||||||
|
m.ResetSupportedModelScopes()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown Group field %s", name)
|
return fmt.Errorf("unknown Group field %s", name)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -91,6 +91,14 @@ func init() {
|
|||||||
apikey.DefaultStatus = apikeyDescStatus.Default.(string)
|
apikey.DefaultStatus = apikeyDescStatus.Default.(string)
|
||||||
// apikey.StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
// apikey.StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||||
apikey.StatusValidator = apikeyDescStatus.Validators[0].(func(string) error)
|
apikey.StatusValidator = apikeyDescStatus.Validators[0].(func(string) error)
|
||||||
|
// apikeyDescQuota is the schema descriptor for quota field.
|
||||||
|
apikeyDescQuota := apikeyFields[7].Descriptor()
|
||||||
|
// apikey.DefaultQuota holds the default value on creation for the quota field.
|
||||||
|
apikey.DefaultQuota = apikeyDescQuota.Default.(float64)
|
||||||
|
// apikeyDescQuotaUsed is the schema descriptor for quota_used field.
|
||||||
|
apikeyDescQuotaUsed := apikeyFields[8].Descriptor()
|
||||||
|
// apikey.DefaultQuotaUsed holds the default value on creation for the quota_used field.
|
||||||
|
apikey.DefaultQuotaUsed = apikeyDescQuotaUsed.Default.(float64)
|
||||||
accountMixin := schema.Account{}.Mixin()
|
accountMixin := schema.Account{}.Mixin()
|
||||||
accountMixinHooks1 := accountMixin[1].Hooks()
|
accountMixinHooks1 := accountMixin[1].Hooks()
|
||||||
account.Hooks[0] = accountMixinHooks1[0]
|
account.Hooks[0] = accountMixinHooks1[0]
|
||||||
@@ -334,9 +342,17 @@ func init() {
|
|||||||
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
|
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
|
||||||
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
|
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
|
||||||
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
|
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
|
||||||
groupDescModelRoutingEnabled := groupFields[17].Descriptor()
|
groupDescModelRoutingEnabled := groupFields[18].Descriptor()
|
||||||
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
|
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
|
||||||
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
|
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
|
||||||
|
// groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field.
|
||||||
|
groupDescMcpXMLInject := groupFields[19].Descriptor()
|
||||||
|
// group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field.
|
||||||
|
group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool)
|
||||||
|
// groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field.
|
||||||
|
groupDescSupportedModelScopes := groupFields[20].Descriptor()
|
||||||
|
// group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field.
|
||||||
|
group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string)
|
||||||
promocodeFields := schema.PromoCode{}.Fields()
|
promocodeFields := schema.PromoCode{}.Fields()
|
||||||
_ = promocodeFields
|
_ = promocodeFields
|
||||||
// promocodeDescCode is the schema descriptor for code field.
|
// promocodeDescCode is the schema descriptor for code field.
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||||
|
|
||||||
"entgo.io/ent"
|
"entgo.io/ent"
|
||||||
|
"entgo.io/ent/dialect"
|
||||||
"entgo.io/ent/dialect/entsql"
|
"entgo.io/ent/dialect/entsql"
|
||||||
"entgo.io/ent/schema"
|
"entgo.io/ent/schema"
|
||||||
"entgo.io/ent/schema/edge"
|
"entgo.io/ent/schema/edge"
|
||||||
@@ -52,6 +53,23 @@ func (APIKey) Fields() []ent.Field {
|
|||||||
field.JSON("ip_blacklist", []string{}).
|
field.JSON("ip_blacklist", []string{}).
|
||||||
Optional().
|
Optional().
|
||||||
Comment("Blocked IPs/CIDRs"),
|
Comment("Blocked IPs/CIDRs"),
|
||||||
|
|
||||||
|
// ========== Quota fields ==========
|
||||||
|
// Quota limit in USD (0 = unlimited)
|
||||||
|
field.Float("quota").
|
||||||
|
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
|
||||||
|
Default(0).
|
||||||
|
Comment("Quota limit in USD for this API key (0 = unlimited)"),
|
||||||
|
// Used quota amount
|
||||||
|
field.Float("quota_used").
|
||||||
|
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
|
||||||
|
Default(0).
|
||||||
|
Comment("Used quota amount in USD"),
|
||||||
|
// Expiration time (nil = never expires)
|
||||||
|
field.Time("expires_at").
|
||||||
|
Optional().
|
||||||
|
Nillable().
|
||||||
|
Comment("Expiration time for this API key (null = never expires)"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,5 +95,8 @@ func (APIKey) Indexes() []ent.Index {
|
|||||||
index.Fields("group_id"),
|
index.Fields("group_id"),
|
||||||
index.Fields("status"),
|
index.Fields("status"),
|
||||||
index.Fields("deleted_at"),
|
index.Fields("deleted_at"),
|
||||||
|
// Index for quota queries
|
||||||
|
index.Fields("quota", "quota_used"),
|
||||||
|
index.Fields("expires_at"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -95,6 +95,10 @@ func (Group) Fields() []ent.Field {
|
|||||||
Optional().
|
Optional().
|
||||||
Nillable().
|
Nillable().
|
||||||
Comment("非 Claude Code 请求降级使用的分组 ID"),
|
Comment("非 Claude Code 请求降级使用的分组 ID"),
|
||||||
|
field.Int64("fallback_group_id_on_invalid_request").
|
||||||
|
Optional().
|
||||||
|
Nillable().
|
||||||
|
Comment("无效请求兜底使用的分组 ID"),
|
||||||
|
|
||||||
// 模型路由配置 (added by migration 040)
|
// 模型路由配置 (added by migration 040)
|
||||||
field.JSON("model_routing", map[string][]int64{}).
|
field.JSON("model_routing", map[string][]int64{}).
|
||||||
@@ -106,6 +110,17 @@ func (Group) Fields() []ent.Field {
|
|||||||
field.Bool("model_routing_enabled").
|
field.Bool("model_routing_enabled").
|
||||||
Default(false).
|
Default(false).
|
||||||
Comment("是否启用模型路由配置"),
|
Comment("是否启用模型路由配置"),
|
||||||
|
|
||||||
|
// MCP XML 协议注入开关 (added by migration 042)
|
||||||
|
field.Bool("mcp_xml_inject").
|
||||||
|
Default(true).
|
||||||
|
Comment("是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)"),
|
||||||
|
|
||||||
|
// 支持的模型系列 (added by migration 046)
|
||||||
|
field.JSON("supported_model_scopes", []string{}).
|
||||||
|
Default([]string{"claude", "gemini_text", "gemini_image"}).
|
||||||
|
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
||||||
|
Comment("支持的模型系列:claude, gemini_text, gemini_image"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ go 1.25.6
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
entgo.io/ent v0.14.5
|
entgo.io/ent v0.14.5
|
||||||
|
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||||
|
github.com/dgraph-io/ristretto v0.2.0
|
||||||
github.com/gin-gonic/gin v1.9.1
|
github.com/gin-gonic/gin v1.9.1
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.2
|
github.com/golang-jwt/jwt/v5 v5.2.2
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
@@ -11,7 +13,10 @@ require (
|
|||||||
github.com/gorilla/websocket v1.5.3
|
github.com/gorilla/websocket v1.5.3
|
||||||
github.com/imroc/req/v3 v3.57.0
|
github.com/imroc/req/v3 v3.57.0
|
||||||
github.com/lib/pq v1.10.9
|
github.com/lib/pq v1.10.9
|
||||||
|
github.com/pquerna/otp v1.5.0
|
||||||
github.com/redis/go-redis/v9 v9.17.2
|
github.com/redis/go-redis/v9 v9.17.2
|
||||||
|
github.com/refraction-networking/utls v1.8.1
|
||||||
|
github.com/robfig/cron/v3 v3.0.1
|
||||||
github.com/shirou/gopsutil/v4 v4.25.6
|
github.com/shirou/gopsutil/v4 v4.25.6
|
||||||
github.com/spf13/viper v1.18.2
|
github.com/spf13/viper v1.18.2
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
@@ -25,13 +30,13 @@ require (
|
|||||||
golang.org/x/sync v0.19.0
|
golang.org/x/sync v0.19.0
|
||||||
golang.org/x/term v0.38.0
|
golang.org/x/term v0.38.0
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
|
modernc.org/sqlite v1.44.3
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9 // indirect
|
ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9 // indirect
|
||||||
dario.cat/mergo v1.0.2 // indirect
|
dario.cat/mergo v1.0.2 // indirect
|
||||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
|
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect
|
|
||||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||||
github.com/agext/levenshtein v1.2.3 // indirect
|
github.com/agext/levenshtein v1.2.3 // indirect
|
||||||
github.com/andybalholm/brotli v1.2.0 // indirect
|
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||||
@@ -48,7 +53,6 @@ require (
|
|||||||
github.com/containerd/platforms v0.2.1 // indirect
|
github.com/containerd/platforms v0.2.1 // indirect
|
||||||
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||||
github.com/dgraph-io/ristretto v0.2.0 // indirect
|
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
github.com/distribution/reference v0.6.0 // indirect
|
github.com/distribution/reference v0.6.0 // indirect
|
||||||
github.com/docker/docker v28.5.1+incompatible // indirect
|
github.com/docker/docker v28.5.1+incompatible // indirect
|
||||||
@@ -107,13 +111,10 @@ require (
|
|||||||
github.com/pkg/errors v0.9.1 // indirect
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||||
github.com/pquerna/otp v1.5.0 // indirect
|
|
||||||
github.com/quic-go/qpack v0.6.0 // indirect
|
github.com/quic-go/qpack v0.6.0 // indirect
|
||||||
github.com/quic-go/quic-go v0.57.1 // indirect
|
github.com/quic-go/quic-go v0.57.1 // indirect
|
||||||
github.com/refraction-networking/utls v1.8.1 // indirect
|
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
github.com/rivo/uniseg v0.2.0 // indirect
|
github.com/rivo/uniseg v0.2.0 // indirect
|
||||||
github.com/robfig/cron/v3 v3.0.1 // indirect
|
|
||||||
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
||||||
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
||||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||||
@@ -149,12 +150,10 @@ require (
|
|||||||
golang.org/x/sys v0.39.0 // indirect
|
golang.org/x/sys v0.39.0 // indirect
|
||||||
golang.org/x/text v0.32.0 // indirect
|
golang.org/x/text v0.32.0 // indirect
|
||||||
golang.org/x/tools v0.39.0 // indirect
|
golang.org/x/tools v0.39.0 // indirect
|
||||||
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated // indirect
|
|
||||||
google.golang.org/grpc v1.75.1 // indirect
|
google.golang.org/grpc v1.75.1 // indirect
|
||||||
google.golang.org/protobuf v1.36.10 // indirect
|
google.golang.org/protobuf v1.36.10 // indirect
|
||||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||||
modernc.org/libc v1.67.6 // indirect
|
modernc.org/libc v1.67.6 // indirect
|
||||||
modernc.org/mathutil v1.7.1 // indirect
|
modernc.org/mathutil v1.7.1 // indirect
|
||||||
modernc.org/memory v1.11.0 // indirect
|
modernc.org/memory v1.11.0 // indirect
|
||||||
modernc.org/sqlite v1.44.1 // indirect
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -55,6 +55,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1
|
|||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/dgraph-io/ristretto v0.2.0 h1:XAfl+7cmoUDWW/2Lx8TGZQjjxIQ2Ley9DSf52dru4WE=
|
github.com/dgraph-io/ristretto v0.2.0 h1:XAfl+7cmoUDWW/2Lx8TGZQjjxIQ2Ley9DSf52dru4WE=
|
||||||
github.com/dgraph-io/ristretto v0.2.0/go.mod h1:8uBHCU/PBV4Ag0CJrP47b9Ofby5dqWNh4FicAdoqFNU=
|
github.com/dgraph-io/ristretto v0.2.0/go.mod h1:8uBHCU/PBV4Ag0CJrP47b9Ofby5dqWNh4FicAdoqFNU=
|
||||||
|
github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 h1:fAjc9m62+UWV/WAFKLNi6ZS0675eEUC9y3AlwSbQu1Y=
|
||||||
|
github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||||
@@ -113,6 +115,8 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX
|
|||||||
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
|
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
|
||||||
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
|
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||||
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||||
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
|
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
|
||||||
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
@@ -123,6 +127,9 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN
|
|||||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg=
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg=
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4=
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4=
|
||||||
|
github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc=
|
||||||
|
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||||
|
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||||
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
||||||
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
||||||
github.com/hashicorp/hcl/v2 v2.18.1 h1:6nxnOJFku1EuSawSD81fuviYUV8DxFr3fp2dUi3ZYSo=
|
github.com/hashicorp/hcl/v2 v2.18.1 h1:6nxnOJFku1EuSawSD81fuviYUV8DxFr3fp2dUi3ZYSo=
|
||||||
@@ -345,8 +352,6 @@ golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
|||||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||||
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
|
|
||||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
|
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||||
@@ -374,9 +379,8 @@ golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
|||||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||||
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
|
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
|
||||||
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
|
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
|
||||||
golang.org/x/tools/go/expect v0.1.0-deprecated h1:jY2C5HGYR5lqex3gEniOQL0r7Dq5+VGVgY1nudX5lXY=
|
|
||||||
golang.org/x/tools/go/expect v0.1.0-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY=
|
|
||||||
golang.org/x/tools/go/expect v0.1.1-deprecated h1:jpBZDwmgPhXsKZC6WhL20P4b/wmnpsEAGHaNy0n/rJM=
|
golang.org/x/tools/go/expect v0.1.1-deprecated h1:jpBZDwmgPhXsKZC6WhL20P4b/wmnpsEAGHaNy0n/rJM=
|
||||||
|
golang.org/x/tools/go/expect v0.1.1-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY=
|
||||||
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated h1:1h2MnaIAIXISqTFKdENegdpAgUXz6NrPEsbIeWaBRvM=
|
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated h1:1h2MnaIAIXISqTFKdENegdpAgUXz6NrPEsbIeWaBRvM=
|
||||||
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8=
|
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
@@ -399,12 +403,32 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
|||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
||||||
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
||||||
|
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
||||||
|
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||||
|
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
|
||||||
|
modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM=
|
||||||
|
modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA=
|
||||||
|
modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
||||||
|
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||||
|
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||||
|
modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
|
||||||
|
modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
||||||
|
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||||
|
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||||
modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI=
|
modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI=
|
||||||
modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE=
|
modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE=
|
||||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||||
modernc.org/sqlite v1.44.1 h1:qybx/rNpfQipX/t47OxbHmkkJuv2JWifCMH8SVUiDas=
|
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||||
modernc.org/sqlite v1.44.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
|
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||||
|
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||||
|
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||||
|
modernc.org/sqlite v1.44.3 h1:+39JvV/HWMcYslAwRxHb8067w+2zowvFOUrOWIy9PjY=
|
||||||
|
modernc.org/sqlite v1.44.3/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
|
||||||
|
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||||
|
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||||
|
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||||
|
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ const (
|
|||||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||||
|
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||||
)
|
)
|
||||||
|
|
||||||
// Redeem type constants
|
// Redeem type constants
|
||||||
@@ -36,6 +37,7 @@ const (
|
|||||||
RedeemTypeBalance = "balance"
|
RedeemTypeBalance = "balance"
|
||||||
RedeemTypeConcurrency = "concurrency"
|
RedeemTypeConcurrency = "concurrency"
|
||||||
RedeemTypeSubscription = "subscription"
|
RedeemTypeSubscription = "subscription"
|
||||||
|
RedeemTypeInvitation = "invitation"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PromoCode status constants
|
// PromoCode status constants
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ type CreateAccountRequest struct {
|
|||||||
Name string `json:"name" binding:"required"`
|
Name string `json:"name" binding:"required"`
|
||||||
Notes *string `json:"notes"`
|
Notes *string `json:"notes"`
|
||||||
Platform string `json:"platform" binding:"required"`
|
Platform string `json:"platform" binding:"required"`
|
||||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
|
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream"`
|
||||||
Credentials map[string]any `json:"credentials" binding:"required"`
|
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||||
Extra map[string]any `json:"extra"`
|
Extra map[string]any `json:"extra"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
@@ -102,7 +102,7 @@ type CreateAccountRequest struct {
|
|||||||
type UpdateAccountRequest struct {
|
type UpdateAccountRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Notes *string `json:"notes"`
|
Notes *string `json:"notes"`
|
||||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
|
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream"`
|
||||||
Credentials map[string]any `json:"credentials"`
|
Credentials map[string]any `json:"credentials"`
|
||||||
Extra map[string]any `json:"extra"`
|
Extra map[string]any `json:"extra"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
|
|||||||
@@ -290,5 +290,9 @@ func (s *stubAdminService) ExpireRedeemCode(ctx context.Context, id int64) (*ser
|
|||||||
return &code, nil
|
return &code, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubAdminService) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]service.RedeemCode, int64, float64, error) {
|
||||||
|
return s.redeems, int64(len(s.redeems)), 100.0, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure stub implements interface.
|
// Ensure stub implements interface.
|
||||||
var _ service.AdminService = (*stubAdminService)(nil)
|
var _ service.AdminService = (*stubAdminService)(nil)
|
||||||
|
|||||||
@@ -35,14 +35,20 @@ type CreateGroupRequest struct {
|
|||||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||||
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||||
|
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||||
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
|
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||||
|
// 从指定分组复制账号(创建后自动绑定)
|
||||||
|
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateGroupRequest represents update group request
|
// UpdateGroupRequest represents update group request
|
||||||
@@ -58,14 +64,20 @@ type UpdateGroupRequest struct {
|
|||||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||||
ClaudeCodeOnly *bool `json:"claude_code_only"`
|
ClaudeCodeOnly *bool `json:"claude_code_only"`
|
||||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||||
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||||
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
|
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
|
||||||
|
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||||
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
|
SupportedModelScopes *[]string `json:"supported_model_scopes"`
|
||||||
|
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||||
|
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// List handles listing all groups with pagination
|
// List handles listing all groups with pagination
|
||||||
@@ -155,22 +167,26 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{
|
group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Description: req.Description,
|
Description: req.Description,
|
||||||
Platform: req.Platform,
|
Platform: req.Platform,
|
||||||
RateMultiplier: req.RateMultiplier,
|
RateMultiplier: req.RateMultiplier,
|
||||||
IsExclusive: req.IsExclusive,
|
IsExclusive: req.IsExclusive,
|
||||||
SubscriptionType: req.SubscriptionType,
|
SubscriptionType: req.SubscriptionType,
|
||||||
DailyLimitUSD: req.DailyLimitUSD,
|
DailyLimitUSD: req.DailyLimitUSD,
|
||||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||||
ImagePrice1K: req.ImagePrice1K,
|
ImagePrice1K: req.ImagePrice1K,
|
||||||
ImagePrice2K: req.ImagePrice2K,
|
ImagePrice2K: req.ImagePrice2K,
|
||||||
ImagePrice4K: req.ImagePrice4K,
|
ImagePrice4K: req.ImagePrice4K,
|
||||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||||
FallbackGroupID: req.FallbackGroupID,
|
FallbackGroupID: req.FallbackGroupID,
|
||||||
ModelRouting: req.ModelRouting,
|
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
|
||||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
ModelRouting: req.ModelRouting,
|
||||||
|
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||||
|
MCPXMLInject: req.MCPXMLInject,
|
||||||
|
SupportedModelScopes: req.SupportedModelScopes,
|
||||||
|
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
@@ -196,23 +212,27 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{
|
group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Description: req.Description,
|
Description: req.Description,
|
||||||
Platform: req.Platform,
|
Platform: req.Platform,
|
||||||
RateMultiplier: req.RateMultiplier,
|
RateMultiplier: req.RateMultiplier,
|
||||||
IsExclusive: req.IsExclusive,
|
IsExclusive: req.IsExclusive,
|
||||||
Status: req.Status,
|
Status: req.Status,
|
||||||
SubscriptionType: req.SubscriptionType,
|
SubscriptionType: req.SubscriptionType,
|
||||||
DailyLimitUSD: req.DailyLimitUSD,
|
DailyLimitUSD: req.DailyLimitUSD,
|
||||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||||
ImagePrice1K: req.ImagePrice1K,
|
ImagePrice1K: req.ImagePrice1K,
|
||||||
ImagePrice2K: req.ImagePrice2K,
|
ImagePrice2K: req.ImagePrice2K,
|
||||||
ImagePrice4K: req.ImagePrice4K,
|
ImagePrice4K: req.ImagePrice4K,
|
||||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||||
FallbackGroupID: req.FallbackGroupID,
|
FallbackGroupID: req.FallbackGroupID,
|
||||||
ModelRouting: req.ModelRouting,
|
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
|
||||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
ModelRouting: req.ModelRouting,
|
||||||
|
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||||
|
MCPXMLInject: req.MCPXMLInject,
|
||||||
|
SupportedModelScopes: req.SupportedModelScopes,
|
||||||
|
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ func NewRedeemHandler(adminService service.AdminService) *RedeemHandler {
|
|||||||
// GenerateRedeemCodesRequest represents generate redeem codes request
|
// GenerateRedeemCodesRequest represents generate redeem codes request
|
||||||
type GenerateRedeemCodesRequest struct {
|
type GenerateRedeemCodesRequest struct {
|
||||||
Count int `json:"count" binding:"required,min=1,max=100"`
|
Count int `json:"count" binding:"required,min=1,max=100"`
|
||||||
Type string `json:"type" binding:"required,oneof=balance concurrency subscription"`
|
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
|
||||||
Value float64 `json:"value" binding:"min=0"`
|
Value float64 `json:"value" binding:"min=0"`
|
||||||
GroupID *int64 `json:"group_id"` // 订阅类型必填
|
GroupID *int64 `json:"group_id"` // 订阅类型必填
|
||||||
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用,默认30天,最大100年
|
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用,默认30天,最大100年
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||||
|
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||||
TotpEnabled: settings.TotpEnabled,
|
TotpEnabled: settings.TotpEnabled,
|
||||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||||
SMTPHost: settings.SMTPHost,
|
SMTPHost: settings.SMTPHost,
|
||||||
@@ -94,11 +95,12 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
// UpdateSettingsRequest 更新设置请求
|
// UpdateSettingsRequest 更新设置请求
|
||||||
type UpdateSettingsRequest struct {
|
type UpdateSettingsRequest struct {
|
||||||
// 注册设置
|
// 注册设置
|
||||||
RegistrationEnabled bool `json:"registration_enabled"`
|
RegistrationEnabled bool `json:"registration_enabled"`
|
||||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||||
|
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||||
|
|
||||||
// 邮件服务设置
|
// 邮件服务设置
|
||||||
SMTPHost string `json:"smtp_host"`
|
SMTPHost string `json:"smtp_host"`
|
||||||
@@ -291,6 +293,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||||
|
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||||
TotpEnabled: req.TotpEnabled,
|
TotpEnabled: req.TotpEnabled,
|
||||||
SMTPHost: req.SMTPHost,
|
SMTPHost: req.SMTPHost,
|
||||||
SMTPPort: req.SMTPPort,
|
SMTPPort: req.SMTPPort,
|
||||||
@@ -370,6 +373,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||||
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
||||||
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
||||||
|
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
|
||||||
TotpEnabled: updatedSettings.TotpEnabled,
|
TotpEnabled: updatedSettings.TotpEnabled,
|
||||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||||
SMTPHost: updatedSettings.SMTPHost,
|
SMTPHost: updatedSettings.SMTPHost,
|
||||||
|
|||||||
@@ -277,3 +277,44 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) {
|
|||||||
|
|
||||||
response.Success(c, stats)
|
response.Success(c, stats)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetBalanceHistory handles getting user's balance/concurrency change history
|
||||||
|
// GET /api/v1/admin/users/:id/balance-history
|
||||||
|
// Query params:
|
||||||
|
// - type: filter by record type (balance, admin_balance, concurrency, admin_concurrency, subscription)
|
||||||
|
func (h *UserHandler) GetBalanceHistory(c *gin.Context) {
|
||||||
|
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid user ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
page, pageSize := response.ParsePagination(c)
|
||||||
|
codeType := c.Query("type")
|
||||||
|
|
||||||
|
codes, total, totalRecharged, err := h.adminService.GetUserBalanceHistory(c.Request.Context(), userID, page, pageSize, codeType)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to admin DTO (includes notes field for admin visibility)
|
||||||
|
out := make([]dto.AdminRedeemCode, 0, len(codes))
|
||||||
|
for i := range codes {
|
||||||
|
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Custom response with total_recharged alongside pagination
|
||||||
|
pages := int((total + int64(pageSize) - 1) / int64(pageSize))
|
||||||
|
if pages < 1 {
|
||||||
|
pages = 1
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"items": out,
|
||||||
|
"total": total,
|
||||||
|
"page": page,
|
||||||
|
"page_size": pageSize,
|
||||||
|
"pages": pages,
|
||||||
|
"total_recharged": totalRecharged,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
@@ -27,11 +28,13 @@ func NewAPIKeyHandler(apiKeyService *service.APIKeyService) *APIKeyHandler {
|
|||||||
|
|
||||||
// CreateAPIKeyRequest represents the create API key request payload
|
// CreateAPIKeyRequest represents the create API key request payload
|
||||||
type CreateAPIKeyRequest struct {
|
type CreateAPIKeyRequest struct {
|
||||||
Name string `json:"name" binding:"required"`
|
Name string `json:"name" binding:"required"`
|
||||||
GroupID *int64 `json:"group_id"` // nullable
|
GroupID *int64 `json:"group_id"` // nullable
|
||||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||||
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
|
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
|
||||||
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
|
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
|
||||||
|
Quota *float64 `json:"quota"` // 配额限制 (USD)
|
||||||
|
ExpiresInDays *int `json:"expires_in_days"` // 过期天数
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAPIKeyRequest represents the update API key request payload
|
// UpdateAPIKeyRequest represents the update API key request payload
|
||||||
@@ -41,6 +44,9 @@ type UpdateAPIKeyRequest struct {
|
|||||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||||
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
|
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
|
||||||
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
|
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
|
||||||
|
Quota *float64 `json:"quota"` // 配额限制 (USD), 0=无限制
|
||||||
|
ExpiresAt *string `json:"expires_at"` // 过期时间 (ISO 8601)
|
||||||
|
ResetQuota *bool `json:"reset_quota"` // 重置已用配额
|
||||||
}
|
}
|
||||||
|
|
||||||
// List handles listing user's API keys with pagination
|
// List handles listing user's API keys with pagination
|
||||||
@@ -114,11 +120,15 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
svcReq := service.CreateAPIKeyRequest{
|
svcReq := service.CreateAPIKeyRequest{
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
GroupID: req.GroupID,
|
GroupID: req.GroupID,
|
||||||
CustomKey: req.CustomKey,
|
CustomKey: req.CustomKey,
|
||||||
IPWhitelist: req.IPWhitelist,
|
IPWhitelist: req.IPWhitelist,
|
||||||
IPBlacklist: req.IPBlacklist,
|
IPBlacklist: req.IPBlacklist,
|
||||||
|
ExpiresInDays: req.ExpiresInDays,
|
||||||
|
}
|
||||||
|
if req.Quota != nil {
|
||||||
|
svcReq.Quota = *req.Quota
|
||||||
}
|
}
|
||||||
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
|
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -153,6 +163,8 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
|
|||||||
svcReq := service.UpdateAPIKeyRequest{
|
svcReq := service.UpdateAPIKeyRequest{
|
||||||
IPWhitelist: req.IPWhitelist,
|
IPWhitelist: req.IPWhitelist,
|
||||||
IPBlacklist: req.IPBlacklist,
|
IPBlacklist: req.IPBlacklist,
|
||||||
|
Quota: req.Quota,
|
||||||
|
ResetQuota: req.ResetQuota,
|
||||||
}
|
}
|
||||||
if req.Name != "" {
|
if req.Name != "" {
|
||||||
svcReq.Name = &req.Name
|
svcReq.Name = &req.Name
|
||||||
@@ -161,6 +173,21 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
|
|||||||
if req.Status != "" {
|
if req.Status != "" {
|
||||||
svcReq.Status = &req.Status
|
svcReq.Status = &req.Status
|
||||||
}
|
}
|
||||||
|
// Parse expires_at if provided
|
||||||
|
if req.ExpiresAt != nil {
|
||||||
|
if *req.ExpiresAt == "" {
|
||||||
|
// Empty string means clear expiration
|
||||||
|
svcReq.ExpiresAt = nil
|
||||||
|
svcReq.ClearExpiration = true
|
||||||
|
} else {
|
||||||
|
t, err := time.Parse(time.RFC3339, *req.ExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid expires_at format: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
svcReq.ExpiresAt = &t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, subject.UserID, svcReq)
|
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, subject.UserID, svcReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -15,23 +15,25 @@ import (
|
|||||||
|
|
||||||
// AuthHandler handles authentication-related requests
|
// AuthHandler handles authentication-related requests
|
||||||
type AuthHandler struct {
|
type AuthHandler struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
authService *service.AuthService
|
authService *service.AuthService
|
||||||
userService *service.UserService
|
userService *service.UserService
|
||||||
settingSvc *service.SettingService
|
settingSvc *service.SettingService
|
||||||
promoService *service.PromoService
|
promoService *service.PromoService
|
||||||
totpService *service.TotpService
|
redeemService *service.RedeemService
|
||||||
|
totpService *service.TotpService
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAuthHandler creates a new AuthHandler
|
// NewAuthHandler creates a new AuthHandler
|
||||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, totpService *service.TotpService) *AuthHandler {
|
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, redeemService *service.RedeemService, totpService *service.TotpService) *AuthHandler {
|
||||||
return &AuthHandler{
|
return &AuthHandler{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
authService: authService,
|
authService: authService,
|
||||||
userService: userService,
|
userService: userService,
|
||||||
settingSvc: settingService,
|
settingSvc: settingService,
|
||||||
promoService: promoService,
|
promoService: promoService,
|
||||||
totpService: totpService,
|
redeemService: redeemService,
|
||||||
|
totpService: totpService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -41,7 +43,8 @@ type RegisterRequest struct {
|
|||||||
Password string `json:"password" binding:"required,min=6"`
|
Password string `json:"password" binding:"required,min=6"`
|
||||||
VerifyCode string `json:"verify_code"`
|
VerifyCode string `json:"verify_code"`
|
||||||
TurnstileToken string `json:"turnstile_token"`
|
TurnstileToken string `json:"turnstile_token"`
|
||||||
PromoCode string `json:"promo_code"` // 注册优惠码
|
PromoCode string `json:"promo_code"` // 注册优惠码
|
||||||
|
InvitationCode string `json:"invitation_code"` // 邀请码
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendVerifyCodeRequest 发送验证码请求
|
// SendVerifyCodeRequest 发送验证码请求
|
||||||
@@ -87,7 +90,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode)
|
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
@@ -346,6 +349,67 @@ func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateInvitationCodeRequest 验证邀请码请求
|
||||||
|
type ValidateInvitationCodeRequest struct {
|
||||||
|
Code string `json:"code" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateInvitationCodeResponse 验证邀请码响应
|
||||||
|
type ValidateInvitationCodeResponse struct {
|
||||||
|
Valid bool `json:"valid"`
|
||||||
|
ErrorCode string `json:"error_code,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateInvitationCode 验证邀请码(公开接口,注册前调用)
|
||||||
|
// POST /api/v1/auth/validate-invitation-code
|
||||||
|
func (h *AuthHandler) ValidateInvitationCode(c *gin.Context) {
|
||||||
|
// 检查邀请码功能是否启用
|
||||||
|
if h.settingSvc == nil || !h.settingSvc.IsInvitationCodeEnabled(c.Request.Context()) {
|
||||||
|
response.Success(c, ValidateInvitationCodeResponse{
|
||||||
|
Valid: false,
|
||||||
|
ErrorCode: "INVITATION_CODE_DISABLED",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req ValidateInvitationCodeRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证邀请码
|
||||||
|
redeemCode, err := h.redeemService.GetByCode(c.Request.Context(), req.Code)
|
||||||
|
if err != nil {
|
||||||
|
response.Success(c, ValidateInvitationCodeResponse{
|
||||||
|
Valid: false,
|
||||||
|
ErrorCode: "INVITATION_CODE_NOT_FOUND",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查类型和状态
|
||||||
|
if redeemCode.Type != service.RedeemTypeInvitation {
|
||||||
|
response.Success(c, ValidateInvitationCodeResponse{
|
||||||
|
Valid: false,
|
||||||
|
ErrorCode: "INVITATION_CODE_INVALID",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if redeemCode.Status != service.StatusUnused {
|
||||||
|
response.Success(c, ValidateInvitationCodeResponse{
|
||||||
|
Valid: false,
|
||||||
|
ErrorCode: "INVITATION_CODE_USED",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, ValidateInvitationCodeResponse{
|
||||||
|
Valid: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// ForgotPasswordRequest 忘记密码请求
|
// ForgotPasswordRequest 忘记密码请求
|
||||||
type ForgotPasswordRequest struct {
|
type ForgotPasswordRequest struct {
|
||||||
Email string `json:"email" binding:"required,email"`
|
Email string `json:"email" binding:"required,email"`
|
||||||
|
|||||||
@@ -76,6 +76,9 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
|
|||||||
Status: k.Status,
|
Status: k.Status,
|
||||||
IPWhitelist: k.IPWhitelist,
|
IPWhitelist: k.IPWhitelist,
|
||||||
IPBlacklist: k.IPBlacklist,
|
IPBlacklist: k.IPBlacklist,
|
||||||
|
Quota: k.Quota,
|
||||||
|
QuotaUsed: k.QuotaUsed,
|
||||||
|
ExpiresAt: k.ExpiresAt,
|
||||||
CreatedAt: k.CreatedAt,
|
CreatedAt: k.CreatedAt,
|
||||||
UpdatedAt: k.UpdatedAt,
|
UpdatedAt: k.UpdatedAt,
|
||||||
User: UserFromServiceShallow(k.User),
|
User: UserFromServiceShallow(k.User),
|
||||||
@@ -105,10 +108,12 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
out := &AdminGroup{
|
out := &AdminGroup{
|
||||||
Group: groupFromServiceBase(g),
|
Group: groupFromServiceBase(g),
|
||||||
ModelRouting: g.ModelRouting,
|
ModelRouting: g.ModelRouting,
|
||||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||||
AccountCount: g.AccountCount,
|
MCPXMLInject: g.MCPXMLInject,
|
||||||
|
SupportedModelScopes: g.SupportedModelScopes,
|
||||||
|
AccountCount: g.AccountCount,
|
||||||
}
|
}
|
||||||
if len(g.AccountGroups) > 0 {
|
if len(g.AccountGroups) > 0 {
|
||||||
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
||||||
@@ -138,8 +143,10 @@ func groupFromServiceBase(g *service.Group) Group {
|
|||||||
ImagePrice4K: g.ImagePrice4K,
|
ImagePrice4K: g.ImagePrice4K,
|
||||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||||
FallbackGroupID: g.FallbackGroupID,
|
FallbackGroupID: g.FallbackGroupID,
|
||||||
CreatedAt: g.CreatedAt,
|
// 无效请求兜底分组
|
||||||
UpdatedAt: g.UpdatedAt,
|
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
||||||
|
CreatedAt: g.CreatedAt,
|
||||||
|
UpdatedAt: g.UpdatedAt,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -204,6 +211,17 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if scopeLimits := a.GetAntigravityScopeRateLimits(); len(scopeLimits) > 0 {
|
||||||
|
out.ScopeRateLimits = make(map[string]ScopeRateLimitInfo, len(scopeLimits))
|
||||||
|
now := time.Now()
|
||||||
|
for scope, remainingSec := range scopeLimits {
|
||||||
|
out.ScopeRateLimits[scope] = ScopeRateLimitInfo{
|
||||||
|
ResetAt: now.Add(time.Duration(remainingSec) * time.Second),
|
||||||
|
RemainingSec: remainingSec,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -366,6 +384,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
|||||||
AccountID: l.AccountID,
|
AccountID: l.AccountID,
|
||||||
RequestID: l.RequestID,
|
RequestID: l.RequestID,
|
||||||
Model: l.Model,
|
Model: l.Model,
|
||||||
|
ReasoningEffort: l.ReasoningEffort,
|
||||||
GroupID: l.GroupID,
|
GroupID: l.GroupID,
|
||||||
SubscriptionID: l.SubscriptionID,
|
SubscriptionID: l.SubscriptionID,
|
||||||
InputTokens: l.InputTokens,
|
InputTokens: l.InputTokens,
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ type SystemSettings struct {
|
|||||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||||
|
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||||
|
|
||||||
@@ -63,6 +64,7 @@ type PublicSettings struct {
|
|||||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||||
|
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||||
|
|||||||
@@ -2,6 +2,11 @@ package dto
|
|||||||
|
|
||||||
import "time"
|
import "time"
|
||||||
|
|
||||||
|
type ScopeRateLimitInfo struct {
|
||||||
|
ResetAt time.Time `json:"reset_at"`
|
||||||
|
RemainingSec int64 `json:"remaining_sec"`
|
||||||
|
}
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
@@ -27,16 +32,19 @@ type AdminUser struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type APIKey struct {
|
type APIKey struct {
|
||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
UserID int64 `json:"user_id"`
|
UserID int64 `json:"user_id"`
|
||||||
Key string `json:"key"`
|
Key string `json:"key"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
GroupID *int64 `json:"group_id"`
|
GroupID *int64 `json:"group_id"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
IPWhitelist []string `json:"ip_whitelist"`
|
IPWhitelist []string `json:"ip_whitelist"`
|
||||||
IPBlacklist []string `json:"ip_blacklist"`
|
IPBlacklist []string `json:"ip_blacklist"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited)
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
QuotaUsed float64 `json:"quota_used"` // Used quota amount in USD
|
||||||
|
ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = never expires)
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
|
||||||
User *User `json:"user,omitempty"`
|
User *User `json:"user,omitempty"`
|
||||||
Group *Group `json:"group,omitempty"`
|
Group *Group `json:"group,omitempty"`
|
||||||
@@ -64,6 +72,8 @@ type Group struct {
|
|||||||
// Claude Code 客户端限制
|
// Claude Code 客户端限制
|
||||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||||
|
// 无效请求兜底分组
|
||||||
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||||
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
@@ -78,8 +88,13 @@ type AdminGroup struct {
|
|||||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||||
|
|
||||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
// MCP XML 协议注入(仅 antigravity 平台使用)
|
||||||
AccountCount int64 `json:"account_count,omitempty"`
|
MCPXMLInject bool `json:"mcp_xml_inject"`
|
||||||
|
|
||||||
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
|
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||||
|
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||||
|
AccountCount int64 `json:"account_count,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Account struct {
|
type Account struct {
|
||||||
@@ -108,6 +123,9 @@ type Account struct {
|
|||||||
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
|
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
|
||||||
OverloadUntil *time.Time `json:"overload_until"`
|
OverloadUntil *time.Time `json:"overload_until"`
|
||||||
|
|
||||||
|
// Antigravity scope 级限流状态(从 extra 提取)
|
||||||
|
ScopeRateLimits map[string]ScopeRateLimitInfo `json:"scope_rate_limits,omitempty"`
|
||||||
|
|
||||||
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
|
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
|
||||||
TempUnschedulableReason string `json:"temp_unschedulable_reason"`
|
TempUnschedulableReason string `json:"temp_unschedulable_reason"`
|
||||||
|
|
||||||
@@ -222,6 +240,9 @@ type UsageLog struct {
|
|||||||
AccountID int64 `json:"account_id"`
|
AccountID int64 `json:"account_id"`
|
||||||
RequestID string `json:"request_id"`
|
RequestID string `json:"request_id"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
|
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
|
||||||
|
// nil means not provided / not applicable.
|
||||||
|
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||||
|
|
||||||
GroupID *int64 `json:"group_id"`
|
GroupID *int64 `json:"group_id"`
|
||||||
SubscriptionID *int64 `json:"subscription_id"`
|
SubscriptionID *int64 `json:"subscription_id"`
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
@@ -31,6 +32,7 @@ type GatewayHandler struct {
|
|||||||
userService *service.UserService
|
userService *service.UserService
|
||||||
billingCacheService *service.BillingCacheService
|
billingCacheService *service.BillingCacheService
|
||||||
usageService *service.UsageService
|
usageService *service.UsageService
|
||||||
|
apiKeyService *service.APIKeyService
|
||||||
concurrencyHelper *ConcurrencyHelper
|
concurrencyHelper *ConcurrencyHelper
|
||||||
maxAccountSwitches int
|
maxAccountSwitches int
|
||||||
maxAccountSwitchesGemini int
|
maxAccountSwitchesGemini int
|
||||||
@@ -45,6 +47,7 @@ func NewGatewayHandler(
|
|||||||
concurrencyService *service.ConcurrencyService,
|
concurrencyService *service.ConcurrencyService,
|
||||||
billingCacheService *service.BillingCacheService,
|
billingCacheService *service.BillingCacheService,
|
||||||
usageService *service.UsageService,
|
usageService *service.UsageService,
|
||||||
|
apiKeyService *service.APIKeyService,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
) *GatewayHandler {
|
) *GatewayHandler {
|
||||||
pingInterval := time.Duration(0)
|
pingInterval := time.Duration(0)
|
||||||
@@ -66,6 +69,7 @@ func NewGatewayHandler(
|
|||||||
userService: userService,
|
userService: userService,
|
||||||
billingCacheService: billingCacheService,
|
billingCacheService: billingCacheService,
|
||||||
usageService: usageService,
|
usageService: usageService,
|
||||||
|
apiKeyService: apiKeyService,
|
||||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
||||||
maxAccountSwitches: maxAccountSwitches,
|
maxAccountSwitches: maxAccountSwitches,
|
||||||
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
|
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
|
||||||
@@ -281,10 +285,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
|
|
||||||
// 转发请求 - 根据账号平台分流
|
// 转发请求 - 根据账号平台分流
|
||||||
var result *service.ForwardResult
|
var result *service.ForwardResult
|
||||||
|
requestCtx := c.Request.Context()
|
||||||
|
if switchCount > 0 {
|
||||||
|
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
||||||
|
}
|
||||||
if account.Platform == service.PlatformAntigravity {
|
if account.Platform == service.PlatformAntigravity {
|
||||||
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body)
|
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body)
|
||||||
} else {
|
} else {
|
||||||
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
|
result, err = h.geminiCompatService.Forward(requestCtx, c, account, body)
|
||||||
}
|
}
|
||||||
if accountReleaseFunc != nil {
|
if accountReleaseFunc != nil {
|
||||||
accountReleaseFunc()
|
accountReleaseFunc()
|
||||||
@@ -316,13 +324,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: usedAccount,
|
Account: usedAccount,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: ua,
|
UserAgent: ua,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Printf("Record usage failed: %v", err)
|
log.Printf("Record usage failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -331,139 +340,193 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
maxAccountSwitches := h.maxAccountSwitches
|
currentAPIKey := apiKey
|
||||||
switchCount := 0
|
currentSubscription := subscription
|
||||||
failedAccountIDs := make(map[int64]struct{})
|
var fallbackGroupID *int64
|
||||||
lastFailoverStatus := 0
|
if apiKey.Group != nil {
|
||||||
|
fallbackGroupID = apiKey.Group.FallbackGroupIDOnInvalidRequest
|
||||||
|
}
|
||||||
|
fallbackUsed := false
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// 选择支持该模型的账号
|
maxAccountSwitches := h.maxAccountSwitches
|
||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
|
switchCount := 0
|
||||||
if err != nil {
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
if len(failedAccountIDs) == 0 {
|
lastFailoverStatus := 0
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
retryWithFallback := false
|
||||||
return
|
|
||||||
}
|
|
||||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
account := selection.Account
|
|
||||||
setOpsSelectedAccount(c, account.ID)
|
|
||||||
|
|
||||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
for {
|
||||||
if account.IsInterceptWarmupEnabled() {
|
// 选择支持该模型的账号
|
||||||
interceptType := detectInterceptType(body)
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
|
||||||
if interceptType != InterceptTypeNone {
|
|
||||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
|
||||||
selection.ReleaseFunc()
|
|
||||||
}
|
|
||||||
if reqStream {
|
|
||||||
sendMockInterceptStream(c, reqModel, interceptType)
|
|
||||||
} else {
|
|
||||||
sendMockInterceptResponse(c, reqModel, interceptType)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. 获取账号并发槽位
|
|
||||||
accountReleaseFunc := selection.ReleaseFunc
|
|
||||||
if !selection.Acquired {
|
|
||||||
if selection.WaitPlan == nil {
|
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
accountWaitCounted := false
|
|
||||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Increment account wait count failed: %v", err)
|
if len(failedAccountIDs) == 0 {
|
||||||
} else if !canWait {
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
|
||||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err == nil && canWait {
|
|
||||||
accountWaitCounted = true
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if accountWaitCounted {
|
|
||||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
|
||||||
c,
|
|
||||||
account.ID,
|
|
||||||
selection.WaitPlan.MaxConcurrency,
|
|
||||||
selection.WaitPlan.Timeout,
|
|
||||||
reqStream,
|
|
||||||
&streamStarted,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Account concurrency acquire failed: %v", err)
|
|
||||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if accountWaitCounted {
|
|
||||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
|
||||||
accountWaitCounted = false
|
|
||||||
}
|
|
||||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
|
||||||
log.Printf("Bind sticky session failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
|
||||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
|
||||||
|
|
||||||
// 转发请求 - 根据账号平台分流
|
|
||||||
var result *service.ForwardResult
|
|
||||||
if account.Platform == service.PlatformAntigravity {
|
|
||||||
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body)
|
|
||||||
} else {
|
|
||||||
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq)
|
|
||||||
}
|
|
||||||
if accountReleaseFunc != nil {
|
|
||||||
accountReleaseFunc()
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
var failoverErr *service.UpstreamFailoverError
|
|
||||||
if errors.As(err, &failoverErr) {
|
|
||||||
failedAccountIDs[account.ID] = struct{}{}
|
|
||||||
lastFailoverStatus = failoverErr.StatusCode
|
|
||||||
if switchCount >= maxAccountSwitches {
|
|
||||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
switchCount++
|
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
return
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
// 错误响应已在Forward中处理,这里只记录日志
|
account := selection.Account
|
||||||
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
|
setOpsSelectedAccount(c, account.ID)
|
||||||
|
|
||||||
|
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||||
|
if account.IsInterceptWarmupEnabled() {
|
||||||
|
interceptType := detectInterceptType(body)
|
||||||
|
if interceptType != InterceptTypeNone {
|
||||||
|
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||||
|
selection.ReleaseFunc()
|
||||||
|
}
|
||||||
|
if reqStream {
|
||||||
|
sendMockInterceptStream(c, reqModel, interceptType)
|
||||||
|
} else {
|
||||||
|
sendMockInterceptResponse(c, reqModel, interceptType)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 获取账号并发槽位
|
||||||
|
accountReleaseFunc := selection.ReleaseFunc
|
||||||
|
if !selection.Acquired {
|
||||||
|
if selection.WaitPlan == nil {
|
||||||
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
accountWaitCounted := false
|
||||||
|
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Increment account wait count failed: %v", err)
|
||||||
|
} else if !canWait {
|
||||||
|
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||||
|
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err == nil && canWait {
|
||||||
|
accountWaitCounted = true
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if accountWaitCounted {
|
||||||
|
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||||
|
c,
|
||||||
|
account.ID,
|
||||||
|
selection.WaitPlan.MaxConcurrency,
|
||||||
|
selection.WaitPlan.Timeout,
|
||||||
|
reqStream,
|
||||||
|
&streamStarted,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Account concurrency acquire failed: %v", err)
|
||||||
|
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if accountWaitCounted {
|
||||||
|
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||||
|
accountWaitCounted = false
|
||||||
|
}
|
||||||
|
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
|
||||||
|
log.Printf("Bind sticky session failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||||
|
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||||
|
|
||||||
|
// 转发请求 - 根据账号平台分流
|
||||||
|
var result *service.ForwardResult
|
||||||
|
requestCtx := c.Request.Context()
|
||||||
|
if switchCount > 0 {
|
||||||
|
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
||||||
|
}
|
||||||
|
if account.Platform == service.PlatformAntigravity {
|
||||||
|
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body)
|
||||||
|
} else {
|
||||||
|
result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq)
|
||||||
|
}
|
||||||
|
if accountReleaseFunc != nil {
|
||||||
|
accountReleaseFunc()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
var promptTooLongErr *service.PromptTooLongError
|
||||||
|
if errors.As(err, &promptTooLongErr) {
|
||||||
|
log.Printf("Prompt too long from antigravity: group=%d fallback_group_id=%v fallback_used=%v", currentAPIKey.GroupID, fallbackGroupID, fallbackUsed)
|
||||||
|
if !fallbackUsed && fallbackGroupID != nil && *fallbackGroupID > 0 {
|
||||||
|
fallbackGroup, err := h.gatewayService.ResolveGroupByID(c.Request.Context(), *fallbackGroupID)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Resolve fallback group failed: %v", err)
|
||||||
|
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if fallbackGroup.Platform != service.PlatformAnthropic ||
|
||||||
|
fallbackGroup.SubscriptionType == service.SubscriptionTypeSubscription ||
|
||||||
|
fallbackGroup.FallbackGroupIDOnInvalidRequest != nil {
|
||||||
|
log.Printf("Fallback group invalid: group=%d platform=%s subscription=%s", fallbackGroup.ID, fallbackGroup.Platform, fallbackGroup.SubscriptionType)
|
||||||
|
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup)
|
||||||
|
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil {
|
||||||
|
status, code, message := billingErrorDetails(err)
|
||||||
|
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 兜底重试按“直接请求兜底分组”处理:清除强制平台,允许按分组平台调度
|
||||||
|
ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, "")
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
currentAPIKey = fallbackAPIKey
|
||||||
|
currentSubscription = nil
|
||||||
|
fallbackUsed = true
|
||||||
|
retryWithFallback = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var failoverErr *service.UpstreamFailoverError
|
||||||
|
if errors.As(err, &failoverErr) {
|
||||||
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
|
lastFailoverStatus = failoverErr.StatusCode
|
||||||
|
if switchCount >= maxAccountSwitches {
|
||||||
|
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switchCount++
|
||||||
|
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 错误响应已在Forward中处理,这里只记录日志
|
||||||
|
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||||
|
userAgent := c.GetHeader("User-Agent")
|
||||||
|
clientIP := ip.GetClientIP(c)
|
||||||
|
|
||||||
|
// 异步记录使用量(subscription已在函数开头获取)
|
||||||
|
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
|
Result: result,
|
||||||
|
APIKey: currentAPIKey,
|
||||||
|
User: currentAPIKey.User,
|
||||||
|
Account: usedAccount,
|
||||||
|
Subscription: currentSubscription,
|
||||||
|
UserAgent: ua,
|
||||||
|
IPAddress: clientIP,
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
|
}); err != nil {
|
||||||
|
log.Printf("Record usage failed: %v", err)
|
||||||
|
}
|
||||||
|
}(result, account, userAgent, clientIP)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !retryWithFallback {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
|
||||||
userAgent := c.GetHeader("User-Agent")
|
|
||||||
clientIP := ip.GetClientIP(c)
|
|
||||||
|
|
||||||
// 异步记录使用量(subscription已在函数开头获取)
|
|
||||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
|
||||||
Result: result,
|
|
||||||
APIKey: apiKey,
|
|
||||||
User: apiKey.User,
|
|
||||||
Account: usedAccount,
|
|
||||||
Subscription: subscription,
|
|
||||||
UserAgent: ua,
|
|
||||||
IPAddress: clientIP,
|
|
||||||
}); err != nil {
|
|
||||||
log.Printf("Record usage failed: %v", err)
|
|
||||||
}
|
|
||||||
}(result, account, userAgent, clientIP)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -527,6 +590,17 @@ func (h *GatewayHandler) AntigravityModels(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service.APIKey {
|
||||||
|
if apiKey == nil || group == nil {
|
||||||
|
return apiKey
|
||||||
|
}
|
||||||
|
cloned := *apiKey
|
||||||
|
groupID := group.ID
|
||||||
|
cloned.GroupID = &groupID
|
||||||
|
cloned.Group = group
|
||||||
|
return &cloned
|
||||||
|
}
|
||||||
|
|
||||||
// Usage handles getting account balance and usage statistics for CC Switch integration
|
// Usage handles getting account balance and usage statistics for CC Switch integration
|
||||||
// GET /v1/usage
|
// GET /v1/usage
|
||||||
func (h *GatewayHandler) Usage(c *gin.Context) {
|
func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||||
@@ -779,6 +853,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查是否为 Claude Code 客户端,设置到 context 中
|
||||||
|
SetClaudeCodeClientContext(c, body)
|
||||||
|
|
||||||
setOpsRequestContext(c, "", false, body)
|
setOpsRequestContext(c, "", false, body)
|
||||||
|
|
||||||
parsedReq, err := service.ParseGatewayRequest(body)
|
parsedReq, err := service.ParseGatewayRequest(body)
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
@@ -335,10 +336,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
|
|
||||||
// 5) forward (根据平台分流)
|
// 5) forward (根据平台分流)
|
||||||
var result *service.ForwardResult
|
var result *service.ForwardResult
|
||||||
|
requestCtx := c.Request.Context()
|
||||||
|
if switchCount > 0 {
|
||||||
|
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
||||||
|
}
|
||||||
if account.Platform == service.PlatformAntigravity {
|
if account.Platform == service.PlatformAntigravity {
|
||||||
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, modelName, action, stream, body)
|
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body)
|
||||||
} else {
|
} else {
|
||||||
result, err = h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
|
result, err = h.geminiCompatService.ForwardNative(requestCtx, c, account, modelName, action, stream, body)
|
||||||
}
|
}
|
||||||
if accountReleaseFunc != nil {
|
if accountReleaseFunc != nil {
|
||||||
accountReleaseFunc()
|
accountReleaseFunc()
|
||||||
@@ -366,18 +371,22 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
|
|
||||||
// 6) record usage async
|
// 6) record usage async (Gemini 使用长上下文双倍计费)
|
||||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) {
|
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
|
||||||
Result: result,
|
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||||
APIKey: apiKey,
|
Result: result,
|
||||||
User: apiKey.User,
|
APIKey: apiKey,
|
||||||
Account: usedAccount,
|
User: apiKey.User,
|
||||||
Subscription: subscription,
|
Account: usedAccount,
|
||||||
UserAgent: ua,
|
Subscription: subscription,
|
||||||
IPAddress: ip,
|
UserAgent: ua,
|
||||||
|
IPAddress: ip,
|
||||||
|
LongContextThreshold: 200000, // Gemini 200K 阈值
|
||||||
|
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Printf("Record usage failed: %v", err)
|
log.Printf("Record usage failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import (
|
|||||||
type OpenAIGatewayHandler struct {
|
type OpenAIGatewayHandler struct {
|
||||||
gatewayService *service.OpenAIGatewayService
|
gatewayService *service.OpenAIGatewayService
|
||||||
billingCacheService *service.BillingCacheService
|
billingCacheService *service.BillingCacheService
|
||||||
|
apiKeyService *service.APIKeyService
|
||||||
concurrencyHelper *ConcurrencyHelper
|
concurrencyHelper *ConcurrencyHelper
|
||||||
maxAccountSwitches int
|
maxAccountSwitches int
|
||||||
}
|
}
|
||||||
@@ -33,6 +34,7 @@ func NewOpenAIGatewayHandler(
|
|||||||
gatewayService *service.OpenAIGatewayService,
|
gatewayService *service.OpenAIGatewayService,
|
||||||
concurrencyService *service.ConcurrencyService,
|
concurrencyService *service.ConcurrencyService,
|
||||||
billingCacheService *service.BillingCacheService,
|
billingCacheService *service.BillingCacheService,
|
||||||
|
apiKeyService *service.APIKeyService,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
) *OpenAIGatewayHandler {
|
) *OpenAIGatewayHandler {
|
||||||
pingInterval := time.Duration(0)
|
pingInterval := time.Duration(0)
|
||||||
@@ -46,6 +48,7 @@ func NewOpenAIGatewayHandler(
|
|||||||
return &OpenAIGatewayHandler{
|
return &OpenAIGatewayHandler{
|
||||||
gatewayService: gatewayService,
|
gatewayService: gatewayService,
|
||||||
billingCacheService: billingCacheService,
|
billingCacheService: billingCacheService,
|
||||||
|
apiKeyService: apiKeyService,
|
||||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||||
maxAccountSwitches: maxAccountSwitches,
|
maxAccountSwitches: maxAccountSwitches,
|
||||||
}
|
}
|
||||||
@@ -299,13 +302,14 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: usedAccount,
|
Account: usedAccount,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: ua,
|
UserAgent: ua,
|
||||||
IPAddress: ip,
|
IPAddress: ip,
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Printf("Record usage failed: %v", err)
|
log.Printf("Record usage failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
|||||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||||
|
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||||
TotpEnabled: settings.TotpEnabled,
|
TotpEnabled: settings.TotpEnabled,
|
||||||
TurnstileEnabled: settings.TurnstileEnabled,
|
TurnstileEnabled: settings.TurnstileEnabled,
|
||||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||||
|
|||||||
@@ -40,17 +40,48 @@ const (
|
|||||||
|
|
||||||
// URL 可用性 TTL(不可用 URL 的恢复时间)
|
// URL 可用性 TTL(不可用 URL 的恢复时间)
|
||||||
URLAvailabilityTTL = 5 * time.Minute
|
URLAvailabilityTTL = 5 * time.Minute
|
||||||
|
|
||||||
|
// Antigravity API 端点
|
||||||
|
antigravityProdBaseURL = "https://cloudcode-pa.googleapis.com"
|
||||||
|
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||||
)
|
)
|
||||||
|
|
||||||
// BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致)
|
// BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致)
|
||||||
var BaseURLs = []string{
|
var BaseURLs = []string{
|
||||||
"https://cloudcode-pa.googleapis.com", // prod (优先)
|
antigravityProdBaseURL, // prod (优先)
|
||||||
"https://daily-cloudcode-pa.sandbox.googleapis.com", // daily sandbox (备用)
|
antigravityDailyBaseURL, // daily sandbox (备用)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BaseURL 默认 URL(保持向后兼容)
|
// BaseURL 默认 URL(保持向后兼容)
|
||||||
var BaseURL = BaseURLs[0]
|
var BaseURL = BaseURLs[0]
|
||||||
|
|
||||||
|
// ForwardBaseURLs 返回 API 转发用的 URL 顺序(daily 优先)
|
||||||
|
func ForwardBaseURLs() []string {
|
||||||
|
if len(BaseURLs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
urls := append([]string(nil), BaseURLs...)
|
||||||
|
dailyIndex := -1
|
||||||
|
for i, url := range urls {
|
||||||
|
if url == antigravityDailyBaseURL {
|
||||||
|
dailyIndex = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if dailyIndex <= 0 {
|
||||||
|
return urls
|
||||||
|
}
|
||||||
|
reordered := make([]string, 0, len(urls))
|
||||||
|
reordered = append(reordered, urls[dailyIndex])
|
||||||
|
for i, url := range urls {
|
||||||
|
if i == dailyIndex {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
reordered = append(reordered, url)
|
||||||
|
}
|
||||||
|
return reordered
|
||||||
|
}
|
||||||
|
|
||||||
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级)
|
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级)
|
||||||
type URLAvailability struct {
|
type URLAvailability struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
@@ -100,22 +131,37 @@ func (u *URLAvailability) IsAvailable(url string) bool {
|
|||||||
// GetAvailableURLs 返回可用的 URL 列表
|
// GetAvailableURLs 返回可用的 URL 列表
|
||||||
// 最近成功的 URL 优先,其他按默认顺序
|
// 最近成功的 URL 优先,其他按默认顺序
|
||||||
func (u *URLAvailability) GetAvailableURLs() []string {
|
func (u *URLAvailability) GetAvailableURLs() []string {
|
||||||
|
return u.GetAvailableURLsWithBase(BaseURLs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAvailableURLsWithBase 返回可用的 URL 列表(使用自定义顺序)
|
||||||
|
// 最近成功的 URL 优先,其他按传入顺序
|
||||||
|
func (u *URLAvailability) GetAvailableURLsWithBase(baseURLs []string) []string {
|
||||||
u.mu.RLock()
|
u.mu.RLock()
|
||||||
defer u.mu.RUnlock()
|
defer u.mu.RUnlock()
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
result := make([]string, 0, len(BaseURLs))
|
result := make([]string, 0, len(baseURLs))
|
||||||
|
|
||||||
// 如果有最近成功的 URL 且可用,放在最前面
|
// 如果有最近成功的 URL 且可用,放在最前面
|
||||||
if u.lastSuccess != "" {
|
if u.lastSuccess != "" {
|
||||||
expiry, exists := u.unavailable[u.lastSuccess]
|
found := false
|
||||||
if !exists || now.After(expiry) {
|
for _, url := range baseURLs {
|
||||||
result = append(result, u.lastSuccess)
|
if url == u.lastSuccess {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if found {
|
||||||
|
expiry, exists := u.unavailable[u.lastSuccess]
|
||||||
|
if !exists || now.After(expiry) {
|
||||||
|
result = append(result, u.lastSuccess)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 添加其他可用的 URL(按默认顺序)
|
// 添加其他可用的 URL(按传入顺序)
|
||||||
for _, url := range BaseURLs {
|
for _, url := range baseURLs {
|
||||||
// 跳过已添加的 lastSuccess
|
// 跳过已添加的 lastSuccess
|
||||||
if url == u.lastSuccess {
|
if url == u.lastSuccess {
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -44,11 +44,13 @@ type TransformOptions struct {
|
|||||||
// IdentityPatch 可选:自定义注入到 systemInstruction 开头的身份防护提示词;
|
// IdentityPatch 可选:自定义注入到 systemInstruction 开头的身份防护提示词;
|
||||||
// 为空时使用默认模板(包含 [IDENTITY_PATCH] 及 SYSTEM_PROMPT_BEGIN 标记)。
|
// 为空时使用默认模板(包含 [IDENTITY_PATCH] 及 SYSTEM_PROMPT_BEGIN 标记)。
|
||||||
IdentityPatch string
|
IdentityPatch string
|
||||||
|
EnableMCPXML bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func DefaultTransformOptions() TransformOptions {
|
func DefaultTransformOptions() TransformOptions {
|
||||||
return TransformOptions{
|
return TransformOptions{
|
||||||
EnableIdentityPatch: true,
|
EnableIdentityPatch: true,
|
||||||
|
EnableMCPXML: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -257,8 +259,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
|||||||
// 添加用户的 system prompt
|
// 添加用户的 system prompt
|
||||||
parts = append(parts, userSystemParts...)
|
parts = append(parts, userSystemParts...)
|
||||||
|
|
||||||
// 检测是否有 MCP 工具,如有则注入 XML 调用协议
|
// 检测是否有 MCP 工具,如有且启用了 MCP XML 注入则注入 XML 调用协议
|
||||||
if hasMCPTools(tools) {
|
if opts.EnableMCPXML && hasMCPTools(tools) {
|
||||||
parts = append(parts, GeminiPart{Text: mcpXMLProtocol})
|
parts = append(parts, GeminiPart{Text: mcpXMLProtocol})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -312,7 +314,7 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
|
|||||||
parts = append([]GeminiPart{{
|
parts = append([]GeminiPart{{
|
||||||
Text: "Thinking...",
|
Text: "Thinking...",
|
||||||
Thought: true,
|
Thought: true,
|
||||||
ThoughtSignature: dummyThoughtSignature,
|
ThoughtSignature: DummyThoughtSignature,
|
||||||
}}, parts...)
|
}}, parts...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -330,9 +332,10 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
|
|||||||
return contents, strippedThinking, nil
|
return contents, strippedThinking, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// dummyThoughtSignature 用于跳过 Gemini 3 thought_signature 验证
|
// DummyThoughtSignature 用于跳过 Gemini 3 thought_signature 验证
|
||||||
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
|
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
|
||||||
const dummyThoughtSignature = "skip_thought_signature_validator"
|
// 导出供跨包使用(如 gemini_native_signature_cleaner 跨账号修复)
|
||||||
|
const DummyThoughtSignature = "skip_thought_signature_validator"
|
||||||
|
|
||||||
// buildParts 构建消息的 parts
|
// buildParts 构建消息的 parts
|
||||||
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
|
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
|
||||||
@@ -370,7 +373,7 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
|
|||||||
// signature 处理:
|
// signature 处理:
|
||||||
// - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
|
// - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
|
||||||
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
|
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
|
||||||
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
|
if block.Signature != "" && (allowDummyThought || block.Signature != DummyThoughtSignature) {
|
||||||
part.ThoughtSignature = block.Signature
|
part.ThoughtSignature = block.Signature
|
||||||
} else if !allowDummyThought {
|
} else if !allowDummyThought {
|
||||||
// Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。
|
// Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。
|
||||||
@@ -381,7 +384,7 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
|
|||||||
continue
|
continue
|
||||||
} else {
|
} else {
|
||||||
// Gemini 模型使用 dummy signature
|
// Gemini 模型使用 dummy signature
|
||||||
part.ThoughtSignature = dummyThoughtSignature
|
part.ThoughtSignature = DummyThoughtSignature
|
||||||
}
|
}
|
||||||
parts = append(parts, part)
|
parts = append(parts, part)
|
||||||
|
|
||||||
@@ -411,10 +414,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
|
|||||||
// tool_use 的 signature 处理:
|
// tool_use 的 signature 处理:
|
||||||
// - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
|
// - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
|
||||||
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
|
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
|
||||||
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
|
if block.Signature != "" && (allowDummyThought || block.Signature != DummyThoughtSignature) {
|
||||||
part.ThoughtSignature = block.Signature
|
part.ThoughtSignature = block.Signature
|
||||||
} else if allowDummyThought {
|
} else if allowDummyThought {
|
||||||
part.ThoughtSignature = dummyThoughtSignature
|
part.ThoughtSignature = DummyThoughtSignature
|
||||||
}
|
}
|
||||||
parts = append(parts, part)
|
parts = append(parts, part)
|
||||||
|
|
||||||
@@ -492,9 +495,23 @@ func parseToolResultContent(content json.RawMessage, isError bool) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// buildGenerationConfig 构建 generationConfig
|
// buildGenerationConfig 构建 generationConfig
|
||||||
|
const (
|
||||||
|
defaultMaxOutputTokens = 64000
|
||||||
|
maxOutputTokensUpperBound = 65000
|
||||||
|
maxOutputTokensClaude = 64000
|
||||||
|
)
|
||||||
|
|
||||||
|
func maxOutputTokensLimit(model string) int {
|
||||||
|
if strings.HasPrefix(model, "claude-") {
|
||||||
|
return maxOutputTokensClaude
|
||||||
|
}
|
||||||
|
return maxOutputTokensUpperBound
|
||||||
|
}
|
||||||
|
|
||||||
func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
|
func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
|
||||||
|
maxLimit := maxOutputTokensLimit(req.Model)
|
||||||
config := &GeminiGenerationConfig{
|
config := &GeminiGenerationConfig{
|
||||||
MaxOutputTokens: 64000, // 默认最大输出
|
MaxOutputTokens: defaultMaxOutputTokens, // 默认最大输出
|
||||||
StopSequences: DefaultStopSequences,
|
StopSequences: DefaultStopSequences,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -518,6 +535,10 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.MaxOutputTokens > maxLimit {
|
||||||
|
config.MaxOutputTokens = maxLimit
|
||||||
|
}
|
||||||
|
|
||||||
// 其他参数
|
// 其他参数
|
||||||
if req.Temperature != nil {
|
if req.Temperature != nil {
|
||||||
config.Temperature = req.Temperature
|
config.Temperature = req.Temperature
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
|
|||||||
if len(parts) != 3 {
|
if len(parts) != 3 {
|
||||||
t.Fatalf("expected 3 parts, got %d", len(parts))
|
t.Fatalf("expected 3 parts, got %d", len(parts))
|
||||||
}
|
}
|
||||||
if !parts[1].Thought || parts[1].ThoughtSignature != dummyThoughtSignature {
|
if !parts[1].Thought || parts[1].ThoughtSignature != DummyThoughtSignature {
|
||||||
t.Fatalf("expected dummy thought signature, got thought=%v signature=%q",
|
t.Fatalf("expected dummy thought signature, got thought=%v signature=%q",
|
||||||
parts[1].Thought, parts[1].ThoughtSignature)
|
parts[1].Thought, parts[1].ThoughtSignature)
|
||||||
}
|
}
|
||||||
@@ -126,8 +126,8 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
|
|||||||
if len(parts) != 1 || parts[0].FunctionCall == nil {
|
if len(parts) != 1 || parts[0].FunctionCall == nil {
|
||||||
t.Fatalf("expected 1 functionCall part, got %+v", parts)
|
t.Fatalf("expected 1 functionCall part, got %+v", parts)
|
||||||
}
|
}
|
||||||
if parts[0].ThoughtSignature != dummyThoughtSignature {
|
if parts[0].ThoughtSignature != DummyThoughtSignature {
|
||||||
t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature)
|
t.Fatalf("expected dummy tool signature %q, got %q", DummyThoughtSignature, parts[0].ThoughtSignature)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -9,11 +9,26 @@ const (
|
|||||||
BetaClaudeCode = "claude-code-20250219"
|
BetaClaudeCode = "claude-code-20250219"
|
||||||
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
|
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
|
||||||
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
|
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
|
||||||
|
BetaTokenCounting = "token-counting-2024-11-01"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||||
|
|
||||||
|
// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header
|
||||||
|
//
|
||||||
|
// NOTE: Claude Code OAuth credentials are scoped to Claude Code. When we "mimic"
|
||||||
|
// Claude Code for non-Claude-Code clients, we must include the claude-code beta
|
||||||
|
// even if the request doesn't use tools, otherwise upstream may reject the
|
||||||
|
// request as a non-Claude-Code API request.
|
||||||
|
const MessageBetaHeaderNoTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
|
||||||
|
|
||||||
|
// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header
|
||||||
|
const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
|
||||||
|
|
||||||
|
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
|
||||||
|
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting
|
||||||
|
|
||||||
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
|
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
|
||||||
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
|
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
|
||||||
|
|
||||||
@@ -25,15 +40,17 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking
|
|||||||
|
|
||||||
// DefaultHeaders 是 Claude Code 客户端默认请求头。
|
// DefaultHeaders 是 Claude Code 客户端默认请求头。
|
||||||
var DefaultHeaders = map[string]string{
|
var DefaultHeaders = map[string]string{
|
||||||
"User-Agent": "claude-cli/2.0.62 (external, cli)",
|
// Keep these in sync with recent Claude CLI traffic to reduce the chance
|
||||||
|
// that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage.
|
||||||
|
"User-Agent": "claude-cli/2.1.22 (external, cli)",
|
||||||
"X-Stainless-Lang": "js",
|
"X-Stainless-Lang": "js",
|
||||||
"X-Stainless-Package-Version": "0.52.0",
|
"X-Stainless-Package-Version": "0.70.0",
|
||||||
"X-Stainless-OS": "Linux",
|
"X-Stainless-OS": "Linux",
|
||||||
"X-Stainless-Arch": "x64",
|
"X-Stainless-Arch": "arm64",
|
||||||
"X-Stainless-Runtime": "node",
|
"X-Stainless-Runtime": "node",
|
||||||
"X-Stainless-Runtime-Version": "v22.14.0",
|
"X-Stainless-Runtime-Version": "v24.13.0",
|
||||||
"X-Stainless-Retry-Count": "0",
|
"X-Stainless-Retry-Count": "0",
|
||||||
"X-Stainless-Timeout": "60",
|
"X-Stainless-Timeout": "600",
|
||||||
"X-App": "cli",
|
"X-App": "cli",
|
||||||
"Anthropic-Dangerous-Direct-Browser-Access": "true",
|
"Anthropic-Dangerous-Direct-Browser-Access": "true",
|
||||||
}
|
}
|
||||||
@@ -79,3 +96,39 @@ func DefaultModelIDs() []string {
|
|||||||
|
|
||||||
// DefaultTestModel 测试时使用的默认模型
|
// DefaultTestModel 测试时使用的默认模型
|
||||||
const DefaultTestModel = "claude-sonnet-4-5-20250929"
|
const DefaultTestModel = "claude-sonnet-4-5-20250929"
|
||||||
|
|
||||||
|
// ModelIDOverrides Claude OAuth 请求需要的模型 ID 映射
|
||||||
|
var ModelIDOverrides = map[string]string{
|
||||||
|
"claude-sonnet-4-5": "claude-sonnet-4-5-20250929",
|
||||||
|
"claude-opus-4-5": "claude-opus-4-5-20251101",
|
||||||
|
"claude-haiku-4-5": "claude-haiku-4-5-20251001",
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelIDReverseOverrides 用于将上游模型 ID 还原为短名
|
||||||
|
var ModelIDReverseOverrides = map[string]string{
|
||||||
|
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
|
||||||
|
"claude-opus-4-5-20251101": "claude-opus-4-5",
|
||||||
|
"claude-haiku-4-5-20251001": "claude-haiku-4-5",
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizeModelID 根据 Claude OAuth 规则映射模型
|
||||||
|
func NormalizeModelID(id string) string {
|
||||||
|
if id == "" {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
if mapped, ok := ModelIDOverrides[id]; ok {
|
||||||
|
return mapped
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
// DenormalizeModelID 将上游模型 ID 转换为短名
|
||||||
|
func DenormalizeModelID(id string) string {
|
||||||
|
if id == "" {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
if mapped, ok := ModelIDReverseOverrides[id]; ok {
|
||||||
|
return mapped
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|||||||
@@ -14,6 +14,9 @@ const (
|
|||||||
// RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。
|
// RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。
|
||||||
RetryCount Key = "ctx_retry_count"
|
RetryCount Key = "ctx_retry_count"
|
||||||
|
|
||||||
|
// AccountSwitchCount 表示请求过程中发生的账号切换次数
|
||||||
|
AccountSwitchCount Key = "ctx_account_switch_count"
|
||||||
|
|
||||||
// IsClaudeCodeClient 标识当前请求是否来自 Claude Code 客户端
|
// IsClaudeCodeClient 标识当前请求是否来自 Claude Code 客户端
|
||||||
IsClaudeCodeClient Key = "ctx_is_claude_code_client"
|
IsClaudeCodeClient Key = "ctx_is_claude_code_client"
|
||||||
// Group 认证后的分组信息,由 API Key 认证中间件设置
|
// Group 认证后的分组信息,由 API Key 认证中间件设置
|
||||||
|
|||||||
@@ -809,12 +809,21 @@ func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, i
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
path := "{antigravity_quota_scopes," + string(scope) + "}"
|
scopeKey := string(scope)
|
||||||
client := clientFromContext(ctx, r.client)
|
client := clientFromContext(ctx, r.client)
|
||||||
result, err := client.ExecContext(
|
result, err := client.ExecContext(
|
||||||
ctx,
|
ctx,
|
||||||
"UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
|
`UPDATE accounts SET
|
||||||
path,
|
extra = jsonb_set(
|
||||||
|
jsonb_set(COALESCE(extra, '{}'::jsonb), '{antigravity_quota_scopes}'::text[], COALESCE(extra->'antigravity_quota_scopes', '{}'::jsonb), true),
|
||||||
|
ARRAY['antigravity_quota_scopes', $1]::text[],
|
||||||
|
$2::jsonb,
|
||||||
|
true
|
||||||
|
),
|
||||||
|
updated_at = NOW(),
|
||||||
|
last_used_at = NOW()
|
||||||
|
WHERE id = $3 AND deleted_at IS NULL`,
|
||||||
|
scopeKey,
|
||||||
raw,
|
raw,
|
||||||
id,
|
id,
|
||||||
)
|
)
|
||||||
@@ -829,6 +838,7 @@ func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, i
|
|||||||
if affected == 0 {
|
if affected == 0 {
|
||||||
return service.ErrAccountNotFound
|
return service.ErrAccountNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err)
|
log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err)
|
||||||
}
|
}
|
||||||
@@ -849,12 +859,19 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
path := "{model_rate_limits," + scope + "}"
|
|
||||||
client := clientFromContext(ctx, r.client)
|
client := clientFromContext(ctx, r.client)
|
||||||
result, err := client.ExecContext(
|
result, err := client.ExecContext(
|
||||||
ctx,
|
ctx,
|
||||||
"UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
|
`UPDATE accounts SET
|
||||||
path,
|
extra = jsonb_set(
|
||||||
|
jsonb_set(COALESCE(extra, '{}'::jsonb), '{model_rate_limits}'::text[], COALESCE(extra->'model_rate_limits', '{}'::jsonb), true),
|
||||||
|
ARRAY['model_rate_limits', $1]::text[],
|
||||||
|
$2::jsonb,
|
||||||
|
true
|
||||||
|
),
|
||||||
|
updated_at = NOW()
|
||||||
|
WHERE id = $3 AND deleted_at IS NULL`,
|
||||||
|
scope,
|
||||||
raw,
|
raw,
|
||||||
id,
|
id,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -33,7 +33,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
|
|||||||
SetKey(key.Key).
|
SetKey(key.Key).
|
||||||
SetName(key.Name).
|
SetName(key.Name).
|
||||||
SetStatus(key.Status).
|
SetStatus(key.Status).
|
||||||
SetNillableGroupID(key.GroupID)
|
SetNillableGroupID(key.GroupID).
|
||||||
|
SetQuota(key.Quota).
|
||||||
|
SetQuotaUsed(key.QuotaUsed).
|
||||||
|
SetNillableExpiresAt(key.ExpiresAt)
|
||||||
|
|
||||||
if len(key.IPWhitelist) > 0 {
|
if len(key.IPWhitelist) > 0 {
|
||||||
builder.SetIPWhitelist(key.IPWhitelist)
|
builder.SetIPWhitelist(key.IPWhitelist)
|
||||||
@@ -110,6 +113,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
|||||||
apikey.FieldStatus,
|
apikey.FieldStatus,
|
||||||
apikey.FieldIPWhitelist,
|
apikey.FieldIPWhitelist,
|
||||||
apikey.FieldIPBlacklist,
|
apikey.FieldIPBlacklist,
|
||||||
|
apikey.FieldQuota,
|
||||||
|
apikey.FieldQuotaUsed,
|
||||||
|
apikey.FieldExpiresAt,
|
||||||
).
|
).
|
||||||
WithUser(func(q *dbent.UserQuery) {
|
WithUser(func(q *dbent.UserQuery) {
|
||||||
q.Select(
|
q.Select(
|
||||||
@@ -136,8 +142,11 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
|||||||
group.FieldImagePrice4k,
|
group.FieldImagePrice4k,
|
||||||
group.FieldClaudeCodeOnly,
|
group.FieldClaudeCodeOnly,
|
||||||
group.FieldFallbackGroupID,
|
group.FieldFallbackGroupID,
|
||||||
|
group.FieldFallbackGroupIDOnInvalidRequest,
|
||||||
group.FieldModelRoutingEnabled,
|
group.FieldModelRoutingEnabled,
|
||||||
group.FieldModelRouting,
|
group.FieldModelRouting,
|
||||||
|
group.FieldMcpXMLInject,
|
||||||
|
group.FieldSupportedModelScopes,
|
||||||
)
|
)
|
||||||
}).
|
}).
|
||||||
Only(ctx)
|
Only(ctx)
|
||||||
@@ -161,6 +170,8 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
|
|||||||
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
|
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
|
||||||
SetName(key.Name).
|
SetName(key.Name).
|
||||||
SetStatus(key.Status).
|
SetStatus(key.Status).
|
||||||
|
SetQuota(key.Quota).
|
||||||
|
SetQuotaUsed(key.QuotaUsed).
|
||||||
SetUpdatedAt(now)
|
SetUpdatedAt(now)
|
||||||
if key.GroupID != nil {
|
if key.GroupID != nil {
|
||||||
builder.SetGroupID(*key.GroupID)
|
builder.SetGroupID(*key.GroupID)
|
||||||
@@ -168,6 +179,13 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
|
|||||||
builder.ClearGroupID()
|
builder.ClearGroupID()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Expiration time
|
||||||
|
if key.ExpiresAt != nil {
|
||||||
|
builder.SetExpiresAt(*key.ExpiresAt)
|
||||||
|
} else {
|
||||||
|
builder.ClearExpiresAt()
|
||||||
|
}
|
||||||
|
|
||||||
// IP 限制字段
|
// IP 限制字段
|
||||||
if len(key.IPWhitelist) > 0 {
|
if len(key.IPWhitelist) > 0 {
|
||||||
builder.SetIPWhitelist(key.IPWhitelist)
|
builder.SetIPWhitelist(key.IPWhitelist)
|
||||||
@@ -357,6 +375,38 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64)
|
|||||||
return keys, nil
|
return keys, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IncrementQuotaUsed atomically increments the quota_used field and returns the new value
|
||||||
|
func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
|
||||||
|
// Use raw SQL for atomic increment to avoid race conditions
|
||||||
|
// First get current value
|
||||||
|
m, err := r.activeQuery().
|
||||||
|
Where(apikey.IDEQ(id)).
|
||||||
|
Select(apikey.FieldQuotaUsed).
|
||||||
|
Only(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if dbent.IsNotFound(err) {
|
||||||
|
return 0, service.ErrAPIKeyNotFound
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
newValue := m.QuotaUsed + amount
|
||||||
|
|
||||||
|
// Update with new value
|
||||||
|
affected, err := r.client.APIKey.Update().
|
||||||
|
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||||
|
SetQuotaUsed(newValue).
|
||||||
|
Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if affected == 0 {
|
||||||
|
return 0, service.ErrAPIKeyNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return newValue, nil
|
||||||
|
}
|
||||||
|
|
||||||
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||||
if m == nil {
|
if m == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -372,6 +422,9 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
|||||||
CreatedAt: m.CreatedAt,
|
CreatedAt: m.CreatedAt,
|
||||||
UpdatedAt: m.UpdatedAt,
|
UpdatedAt: m.UpdatedAt,
|
||||||
GroupID: m.GroupID,
|
GroupID: m.GroupID,
|
||||||
|
Quota: m.Quota,
|
||||||
|
QuotaUsed: m.QuotaUsed,
|
||||||
|
ExpiresAt: m.ExpiresAt,
|
||||||
}
|
}
|
||||||
if m.Edges.User != nil {
|
if m.Edges.User != nil {
|
||||||
out.User = userEntityToService(m.Edges.User)
|
out.User = userEntityToService(m.Edges.User)
|
||||||
@@ -409,28 +462,31 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &service.Group{
|
return &service.Group{
|
||||||
ID: g.ID,
|
ID: g.ID,
|
||||||
Name: g.Name,
|
Name: g.Name,
|
||||||
Description: derefString(g.Description),
|
Description: derefString(g.Description),
|
||||||
Platform: g.Platform,
|
Platform: g.Platform,
|
||||||
RateMultiplier: g.RateMultiplier,
|
RateMultiplier: g.RateMultiplier,
|
||||||
IsExclusive: g.IsExclusive,
|
IsExclusive: g.IsExclusive,
|
||||||
Status: g.Status,
|
Status: g.Status,
|
||||||
Hydrated: true,
|
Hydrated: true,
|
||||||
SubscriptionType: g.SubscriptionType,
|
SubscriptionType: g.SubscriptionType,
|
||||||
DailyLimitUSD: g.DailyLimitUsd,
|
DailyLimitUSD: g.DailyLimitUsd,
|
||||||
WeeklyLimitUSD: g.WeeklyLimitUsd,
|
WeeklyLimitUSD: g.WeeklyLimitUsd,
|
||||||
MonthlyLimitUSD: g.MonthlyLimitUsd,
|
MonthlyLimitUSD: g.MonthlyLimitUsd,
|
||||||
ImagePrice1K: g.ImagePrice1k,
|
ImagePrice1K: g.ImagePrice1k,
|
||||||
ImagePrice2K: g.ImagePrice2k,
|
ImagePrice2K: g.ImagePrice2k,
|
||||||
ImagePrice4K: g.ImagePrice4k,
|
ImagePrice4K: g.ImagePrice4k,
|
||||||
DefaultValidityDays: g.DefaultValidityDays,
|
DefaultValidityDays: g.DefaultValidityDays,
|
||||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||||
FallbackGroupID: g.FallbackGroupID,
|
FallbackGroupID: g.FallbackGroupID,
|
||||||
ModelRouting: g.ModelRouting,
|
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
||||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
ModelRouting: g.ModelRouting,
|
||||||
CreatedAt: g.CreatedAt,
|
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||||
UpdatedAt: g.UpdatedAt,
|
MCPXMLInject: g.McpXMLInject,
|
||||||
|
SupportedModelScopes: g.SupportedModelScopes,
|
||||||
|
CreatedAt: g.CreatedAt,
|
||||||
|
UpdatedAt: g.UpdatedAt,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -50,13 +50,18 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
|||||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||||
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
|
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
|
||||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
|
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
|
||||||
|
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||||
|
SetMcpXMLInject(groupIn.MCPXMLInject)
|
||||||
|
|
||||||
// 设置模型路由配置
|
// 设置模型路由配置
|
||||||
if groupIn.ModelRouting != nil {
|
if groupIn.ModelRouting != nil {
|
||||||
builder = builder.SetModelRouting(groupIn.ModelRouting)
|
builder = builder.SetModelRouting(groupIn.ModelRouting)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 设置支持的模型系列(始终设置,空数组表示不限制)
|
||||||
|
builder = builder.SetSupportedModelScopes(groupIn.SupportedModelScopes)
|
||||||
|
|
||||||
created, err := builder.Save(ctx)
|
created, err := builder.Save(ctx)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
groupIn.ID = created.ID
|
groupIn.ID = created.ID
|
||||||
@@ -87,7 +92,6 @@ func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.G
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
return groupEntityToService(m), nil
|
return groupEntityToService(m), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,7 +112,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
|||||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
|
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||||
|
SetMcpXMLInject(groupIn.MCPXMLInject)
|
||||||
|
|
||||||
// 处理 FallbackGroupID:nil 时清除,否则设置
|
// 处理 FallbackGroupID:nil 时清除,否则设置
|
||||||
if groupIn.FallbackGroupID != nil {
|
if groupIn.FallbackGroupID != nil {
|
||||||
@@ -116,6 +121,12 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
|||||||
} else {
|
} else {
|
||||||
builder = builder.ClearFallbackGroupID()
|
builder = builder.ClearFallbackGroupID()
|
||||||
}
|
}
|
||||||
|
// 处理 FallbackGroupIDOnInvalidRequest:nil 时清除,否则设置
|
||||||
|
if groupIn.FallbackGroupIDOnInvalidRequest != nil {
|
||||||
|
builder = builder.SetFallbackGroupIDOnInvalidRequest(*groupIn.FallbackGroupIDOnInvalidRequest)
|
||||||
|
} else {
|
||||||
|
builder = builder.ClearFallbackGroupIDOnInvalidRequest()
|
||||||
|
}
|
||||||
|
|
||||||
// 处理 ModelRouting:nil 时清除,否则设置
|
// 处理 ModelRouting:nil 时清除,否则设置
|
||||||
if groupIn.ModelRouting != nil {
|
if groupIn.ModelRouting != nil {
|
||||||
@@ -124,6 +135,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
|||||||
builder = builder.ClearModelRouting()
|
builder = builder.ClearModelRouting()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 处理 SupportedModelScopes(始终设置,空数组表示不限制)
|
||||||
|
builder = builder.SetSupportedModelScopes(groupIn.SupportedModelScopes)
|
||||||
|
|
||||||
updated, err := builder.Save(ctx)
|
updated, err := builder.Save(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
|
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
|
||||||
@@ -425,3 +439,61 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
|
|||||||
|
|
||||||
return counts, nil
|
return counts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重)
|
||||||
|
func (r *groupRepository) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||||
|
if len(groupIDs) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := r.sql.QueryContext(
|
||||||
|
ctx,
|
||||||
|
"SELECT DISTINCT account_id FROM account_groups WHERE group_id = ANY($1) ORDER BY account_id",
|
||||||
|
pq.Array(groupIDs),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var accountIDs []int64
|
||||||
|
for rows.Next() {
|
||||||
|
var accountID int64
|
||||||
|
if err := rows.Scan(&accountID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
accountIDs = append(accountIDs, accountID)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return accountIDs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BindAccountsToGroup 将多个账号绑定到指定分组(批量插入,忽略已存在的绑定)
|
||||||
|
func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||||
|
if len(accountIDs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用 INSERT ... ON CONFLICT DO NOTHING 忽略已存在的绑定
|
||||||
|
_, err := r.sql.ExecContext(
|
||||||
|
ctx,
|
||||||
|
`INSERT INTO account_groups (account_id, group_id, priority, created_at)
|
||||||
|
SELECT unnest($1::bigint[]), $2, 50, NOW()
|
||||||
|
ON CONFLICT (account_id, group_id) DO NOTHING`,
|
||||||
|
pq.Array(accountIDs),
|
||||||
|
groupID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送调度器事件
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ INSERT INTO ops_system_metrics (
|
|||||||
upstream_529_count,
|
upstream_529_count,
|
||||||
|
|
||||||
token_consumed,
|
token_consumed,
|
||||||
|
account_switch_count,
|
||||||
qps,
|
qps,
|
||||||
tps,
|
tps,
|
||||||
|
|
||||||
@@ -81,14 +82,14 @@ INSERT INTO ops_system_metrics (
|
|||||||
$1,$2,$3,$4,
|
$1,$2,$3,$4,
|
||||||
$5,$6,$7,$8,
|
$5,$6,$7,$8,
|
||||||
$9,$10,$11,
|
$9,$10,$11,
|
||||||
$12,$13,$14,
|
$12,$13,$14,$15,
|
||||||
$15,$16,$17,$18,$19,$20,
|
$16,$17,$18,$19,$20,$21,
|
||||||
$21,$22,$23,$24,$25,$26,
|
$22,$23,$24,$25,$26,$27,
|
||||||
$27,$28,$29,$30,
|
$28,$29,$30,$31,
|
||||||
$31,$32,
|
$32,$33,
|
||||||
$33,$34,
|
$34,$35,
|
||||||
$35,$36,$37,
|
$36,$37,$38,
|
||||||
$38,$39
|
$39,$40
|
||||||
)`
|
)`
|
||||||
|
|
||||||
_, err := r.db.ExecContext(
|
_, err := r.db.ExecContext(
|
||||||
@@ -109,6 +110,7 @@ INSERT INTO ops_system_metrics (
|
|||||||
input.Upstream529Count,
|
input.Upstream529Count,
|
||||||
|
|
||||||
input.TokenConsumed,
|
input.TokenConsumed,
|
||||||
|
input.AccountSwitchCount,
|
||||||
opsNullFloat64(input.QPS),
|
opsNullFloat64(input.QPS),
|
||||||
opsNullFloat64(input.TPS),
|
opsNullFloat64(input.TPS),
|
||||||
|
|
||||||
@@ -177,7 +179,8 @@ SELECT
|
|||||||
db_conn_waiting,
|
db_conn_waiting,
|
||||||
|
|
||||||
goroutine_count,
|
goroutine_count,
|
||||||
concurrency_queue_depth
|
concurrency_queue_depth,
|
||||||
|
account_switch_count
|
||||||
FROM ops_system_metrics
|
FROM ops_system_metrics
|
||||||
WHERE window_minutes = $1
|
WHERE window_minutes = $1
|
||||||
AND platform IS NULL
|
AND platform IS NULL
|
||||||
@@ -199,6 +202,7 @@ LIMIT 1`
|
|||||||
var dbWaiting sql.NullInt64
|
var dbWaiting sql.NullInt64
|
||||||
var goroutines sql.NullInt64
|
var goroutines sql.NullInt64
|
||||||
var queueDepth sql.NullInt64
|
var queueDepth sql.NullInt64
|
||||||
|
var accountSwitchCount sql.NullInt64
|
||||||
|
|
||||||
if err := r.db.QueryRowContext(ctx, q, windowMinutes).Scan(
|
if err := r.db.QueryRowContext(ctx, q, windowMinutes).Scan(
|
||||||
&out.ID,
|
&out.ID,
|
||||||
@@ -217,6 +221,7 @@ LIMIT 1`
|
|||||||
&dbWaiting,
|
&dbWaiting,
|
||||||
&goroutines,
|
&goroutines,
|
||||||
&queueDepth,
|
&queueDepth,
|
||||||
|
&accountSwitchCount,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -273,6 +278,10 @@ LIMIT 1`
|
|||||||
v := int(queueDepth.Int64)
|
v := int(queueDepth.Int64)
|
||||||
out.ConcurrencyQueueDepth = &v
|
out.ConcurrencyQueueDepth = &v
|
||||||
}
|
}
|
||||||
|
if accountSwitchCount.Valid {
|
||||||
|
v := accountSwitchCount.Int64
|
||||||
|
out.AccountSwitchCount = &v
|
||||||
|
}
|
||||||
|
|
||||||
return &out, nil
|
return &out, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -56,18 +56,44 @@ error_buckets AS (
|
|||||||
AND COALESCE(status_code, 0) >= 400
|
AND COALESCE(status_code, 0) >= 400
|
||||||
GROUP BY 1
|
GROUP BY 1
|
||||||
),
|
),
|
||||||
|
switch_buckets AS (
|
||||||
|
SELECT ` + errorBucketExpr + ` AS bucket,
|
||||||
|
COALESCE(SUM(CASE
|
||||||
|
WHEN split_part(ev->>'kind', ':', 1) IN ('failover', 'retry_exhausted_failover', 'failover_on_400') THEN 1
|
||||||
|
ELSE 0
|
||||||
|
END), 0) AS switch_count
|
||||||
|
FROM ops_error_logs
|
||||||
|
CROSS JOIN LATERAL jsonb_array_elements(
|
||||||
|
COALESCE(NULLIF(upstream_errors, 'null'::jsonb), '[]'::jsonb)
|
||||||
|
) AS ev
|
||||||
|
` + errorWhere + `
|
||||||
|
AND upstream_errors IS NOT NULL
|
||||||
|
GROUP BY 1
|
||||||
|
),
|
||||||
combined AS (
|
combined AS (
|
||||||
SELECT COALESCE(u.bucket, e.bucket) AS bucket,
|
SELECT
|
||||||
COALESCE(u.success_count, 0) AS success_count,
|
bucket,
|
||||||
COALESCE(e.error_count, 0) AS error_count,
|
SUM(success_count) AS success_count,
|
||||||
COALESCE(u.token_consumed, 0) AS token_consumed
|
SUM(error_count) AS error_count,
|
||||||
FROM usage_buckets u
|
SUM(token_consumed) AS token_consumed,
|
||||||
FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
|
SUM(switch_count) AS switch_count
|
||||||
|
FROM (
|
||||||
|
SELECT bucket, success_count, 0 AS error_count, token_consumed, 0 AS switch_count
|
||||||
|
FROM usage_buckets
|
||||||
|
UNION ALL
|
||||||
|
SELECT bucket, 0, error_count, 0, 0
|
||||||
|
FROM error_buckets
|
||||||
|
UNION ALL
|
||||||
|
SELECT bucket, 0, 0, 0, switch_count
|
||||||
|
FROM switch_buckets
|
||||||
|
) t
|
||||||
|
GROUP BY bucket
|
||||||
)
|
)
|
||||||
SELECT
|
SELECT
|
||||||
bucket,
|
bucket,
|
||||||
(success_count + error_count) AS request_count,
|
(success_count + error_count) AS request_count,
|
||||||
token_consumed
|
token_consumed,
|
||||||
|
switch_count
|
||||||
FROM combined
|
FROM combined
|
||||||
ORDER BY bucket ASC`
|
ORDER BY bucket ASC`
|
||||||
|
|
||||||
@@ -84,13 +110,18 @@ ORDER BY bucket ASC`
|
|||||||
var bucket time.Time
|
var bucket time.Time
|
||||||
var requests int64
|
var requests int64
|
||||||
var tokens sql.NullInt64
|
var tokens sql.NullInt64
|
||||||
if err := rows.Scan(&bucket, &requests, &tokens); err != nil {
|
var switches sql.NullInt64
|
||||||
|
if err := rows.Scan(&bucket, &requests, &tokens, &switches); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
tokenConsumed := int64(0)
|
tokenConsumed := int64(0)
|
||||||
if tokens.Valid {
|
if tokens.Valid {
|
||||||
tokenConsumed = tokens.Int64
|
tokenConsumed = tokens.Int64
|
||||||
}
|
}
|
||||||
|
switchCount := int64(0)
|
||||||
|
if switches.Valid {
|
||||||
|
switchCount = switches.Int64
|
||||||
|
}
|
||||||
|
|
||||||
denom := float64(bucketSeconds)
|
denom := float64(bucketSeconds)
|
||||||
if denom <= 0 {
|
if denom <= 0 {
|
||||||
@@ -103,6 +134,7 @@ ORDER BY bucket ASC`
|
|||||||
BucketStart: bucket.UTC(),
|
BucketStart: bucket.UTC(),
|
||||||
RequestCount: requests,
|
RequestCount: requests,
|
||||||
TokenConsumed: tokenConsumed,
|
TokenConsumed: tokenConsumed,
|
||||||
|
SwitchCount: switchCount,
|
||||||
QPS: qps,
|
QPS: qps,
|
||||||
TPS: tps,
|
TPS: tps,
|
||||||
})
|
})
|
||||||
@@ -385,6 +417,7 @@ func fillOpsThroughputBuckets(start, end time.Time, bucketSeconds int, points []
|
|||||||
BucketStart: cursor,
|
BucketStart: cursor,
|
||||||
RequestCount: 0,
|
RequestCount: 0,
|
||||||
TokenConsumed: 0,
|
TokenConsumed: 0,
|
||||||
|
SwitchCount: 0,
|
||||||
QPS: 0,
|
QPS: 0,
|
||||||
TPS: 0,
|
TPS: 0,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
|||||||
log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
|
log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
|
||||||
}
|
}
|
||||||
return &proxyProbeService{
|
return &proxyProbeService{
|
||||||
ipInfoURL: defaultIPInfoURL,
|
|
||||||
insecureSkipVerify: insecure,
|
insecureSkipVerify: insecure,
|
||||||
allowPrivateHosts: allowPrivate,
|
allowPrivateHosts: allowPrivate,
|
||||||
validateResolvedIP: validateResolvedIP,
|
validateResolvedIP: validateResolvedIP,
|
||||||
@@ -36,12 +35,20 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultIPInfoURL = "http://ip-api.com/json/?lang=zh-CN"
|
|
||||||
defaultProxyProbeTimeout = 30 * time.Second
|
defaultProxyProbeTimeout = 30 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// probeURLs 按优先级排列的探测 URL 列表
|
||||||
|
// 某些 AI API 专用代理只允许访问特定域名,因此需要多个备选
|
||||||
|
var probeURLs = []struct {
|
||||||
|
url string
|
||||||
|
parser string // "ip-api" or "httpbin"
|
||||||
|
}{
|
||||||
|
{"http://ip-api.com/json/?lang=zh-CN", "ip-api"},
|
||||||
|
{"http://httpbin.org/ip", "httpbin"},
|
||||||
|
}
|
||||||
|
|
||||||
type proxyProbeService struct {
|
type proxyProbeService struct {
|
||||||
ipInfoURL string
|
|
||||||
insecureSkipVerify bool
|
insecureSkipVerify bool
|
||||||
allowPrivateHosts bool
|
allowPrivateHosts bool
|
||||||
validateResolvedIP bool
|
validateResolvedIP bool
|
||||||
@@ -60,8 +67,21 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
|||||||
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
|
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
for _, probe := range probeURLs {
|
||||||
|
exitInfo, latencyMs, err := s.probeWithURL(ctx, client, probe.url, probe.parser)
|
||||||
|
if err == nil {
|
||||||
|
return exitInfo, latencyMs, nil
|
||||||
|
}
|
||||||
|
lastErr = err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, 0, fmt.Errorf("all probe URLs failed, last error: %w", lastErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Client, url string, parser string) (*service.ProxyExitInfo, int64, error) {
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", s.ipInfoURL, nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("failed to create request: %w", err)
|
return nil, 0, fmt.Errorf("failed to create request: %w", err)
|
||||||
}
|
}
|
||||||
@@ -78,6 +98,22 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
|||||||
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
|
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch parser {
|
||||||
|
case "ip-api":
|
||||||
|
return s.parseIPAPI(body, latencyMs)
|
||||||
|
case "httpbin":
|
||||||
|
return s.parseHTTPBin(body, latencyMs)
|
||||||
|
default:
|
||||||
|
return nil, latencyMs, fmt.Errorf("unknown parser: %s", parser)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *proxyProbeService) parseIPAPI(body []byte, latencyMs int64) (*service.ProxyExitInfo, int64, error) {
|
||||||
var ipInfo struct {
|
var ipInfo struct {
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
@@ -89,13 +125,12 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
|||||||
CountryCode string `json:"countryCode"`
|
CountryCode string `json:"countryCode"`
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(body, &ipInfo); err != nil {
|
if err := json.Unmarshal(body, &ipInfo); err != nil {
|
||||||
return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
|
preview := string(body)
|
||||||
|
if len(preview) > 200 {
|
||||||
|
preview = preview[:200] + "..."
|
||||||
|
}
|
||||||
|
return nil, latencyMs, fmt.Errorf("failed to parse response: %w (body: %s)", err, preview)
|
||||||
}
|
}
|
||||||
if strings.ToLower(ipInfo.Status) != "success" {
|
if strings.ToLower(ipInfo.Status) != "success" {
|
||||||
if ipInfo.Message == "" {
|
if ipInfo.Message == "" {
|
||||||
@@ -116,3 +151,19 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
|||||||
CountryCode: ipInfo.CountryCode,
|
CountryCode: ipInfo.CountryCode,
|
||||||
}, latencyMs, nil
|
}, latencyMs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *proxyProbeService) parseHTTPBin(body []byte, latencyMs int64) (*service.ProxyExitInfo, int64, error) {
|
||||||
|
// httpbin.org/ip 返回格式: {"origin": "1.2.3.4"}
|
||||||
|
var result struct {
|
||||||
|
Origin string `json:"origin"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
return nil, latencyMs, fmt.Errorf("failed to parse httpbin response: %w", err)
|
||||||
|
}
|
||||||
|
if result.Origin == "" {
|
||||||
|
return nil, latencyMs, fmt.Errorf("httpbin: no IP found in response")
|
||||||
|
}
|
||||||
|
return &service.ProxyExitInfo{
|
||||||
|
IP: result.Origin,
|
||||||
|
}, latencyMs, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -21,7 +22,6 @@ type ProxyProbeServiceSuite struct {
|
|||||||
func (s *ProxyProbeServiceSuite) SetupTest() {
|
func (s *ProxyProbeServiceSuite) SetupTest() {
|
||||||
s.ctx = context.Background()
|
s.ctx = context.Background()
|
||||||
s.prober = &proxyProbeService{
|
s.prober = &proxyProbeService{
|
||||||
ipInfoURL: "http://ip-api.test/json/?lang=zh-CN",
|
|
||||||
allowPrivateHosts: true,
|
allowPrivateHosts: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -49,12 +49,16 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_UnsupportedProxyScheme() {
|
|||||||
require.ErrorContains(s.T(), err, "failed to create proxy client")
|
require.ErrorContains(s.T(), err, "failed to create proxy client")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
|
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success_IPAPI() {
|
||||||
seen := make(chan string, 1)
|
|
||||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
seen <- r.RequestURI
|
// 检查是否是 ip-api 请求
|
||||||
w.Header().Set("Content-Type", "application/json")
|
if strings.Contains(r.RequestURI, "ip-api.com") {
|
||||||
_, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`)
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 其他请求返回错误
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||||
@@ -65,45 +69,59 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
|
|||||||
require.Equal(s.T(), "r", info.Region)
|
require.Equal(s.T(), "r", info.Region)
|
||||||
require.Equal(s.T(), "cc", info.Country)
|
require.Equal(s.T(), "cc", info.Country)
|
||||||
require.Equal(s.T(), "CC", info.CountryCode)
|
require.Equal(s.T(), "CC", info.CountryCode)
|
||||||
|
|
||||||
// Verify proxy received the request
|
|
||||||
select {
|
|
||||||
case uri := <-seen:
|
|
||||||
require.Contains(s.T(), uri, "ip-api.test", "expected request to go through proxy")
|
|
||||||
default:
|
|
||||||
require.Fail(s.T(), "expected proxy to receive request")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_NonOKStatus() {
|
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success_HTTPBinFallback() {
|
||||||
|
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// ip-api 失败
|
||||||
|
if strings.Contains(r.RequestURI, "ip-api.com") {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// httpbin 成功
|
||||||
|
if strings.Contains(r.RequestURI, "httpbin.org") {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = io.WriteString(w, `{"origin": "5.6.7.8"}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
}))
|
||||||
|
|
||||||
|
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||||
|
require.NoError(s.T(), err, "ProbeProxy should fallback to httpbin")
|
||||||
|
require.GreaterOrEqual(s.T(), latencyMs, int64(0), "unexpected latency")
|
||||||
|
require.Equal(s.T(), "5.6.7.8", info.IP)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyProbeServiceSuite) TestProbeProxy_AllFailed() {
|
||||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusServiceUnavailable)
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||||
require.Error(s.T(), err)
|
require.Error(s.T(), err)
|
||||||
require.ErrorContains(s.T(), err, "status: 503")
|
require.ErrorContains(s.T(), err, "all probe URLs failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidJSON() {
|
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidJSON() {
|
||||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
if strings.Contains(r.RequestURI, "ip-api.com") {
|
||||||
_, _ = io.WriteString(w, "not-json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = io.WriteString(w, "not-json")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// httpbin 也返回无效响应
|
||||||
|
if strings.Contains(r.RequestURI, "httpbin.org") {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = io.WriteString(w, "not-json")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||||
require.Error(s.T(), err)
|
require.Error(s.T(), err)
|
||||||
require.ErrorContains(s.T(), err, "failed to parse response")
|
require.ErrorContains(s.T(), err, "all probe URLs failed")
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidIPInfoURL() {
|
|
||||||
s.prober.ipInfoURL = "://invalid-url"
|
|
||||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}))
|
|
||||||
|
|
||||||
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
|
||||||
require.Error(s.T(), err, "expected error for invalid ipInfoURL")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() {
|
func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() {
|
||||||
@@ -114,6 +132,40 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() {
|
|||||||
require.Error(s.T(), err, "expected error when proxy server is closed")
|
require.Error(s.T(), err, "expected error when proxy server is closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *ProxyProbeServiceSuite) TestParseIPAPI_Success() {
|
||||||
|
body := []byte(`{"status":"success","query":"1.2.3.4","city":"Beijing","regionName":"Beijing","country":"China","countryCode":"CN"}`)
|
||||||
|
info, latencyMs, err := s.prober.parseIPAPI(body, 100)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.Equal(s.T(), int64(100), latencyMs)
|
||||||
|
require.Equal(s.T(), "1.2.3.4", info.IP)
|
||||||
|
require.Equal(s.T(), "Beijing", info.City)
|
||||||
|
require.Equal(s.T(), "Beijing", info.Region)
|
||||||
|
require.Equal(s.T(), "China", info.Country)
|
||||||
|
require.Equal(s.T(), "CN", info.CountryCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyProbeServiceSuite) TestParseIPAPI_Failure() {
|
||||||
|
body := []byte(`{"status":"fail","message":"rate limited"}`)
|
||||||
|
_, _, err := s.prober.parseIPAPI(body, 100)
|
||||||
|
require.Error(s.T(), err)
|
||||||
|
require.ErrorContains(s.T(), err, "rate limited")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyProbeServiceSuite) TestParseHTTPBin_Success() {
|
||||||
|
body := []byte(`{"origin": "9.8.7.6"}`)
|
||||||
|
info, latencyMs, err := s.prober.parseHTTPBin(body, 50)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.Equal(s.T(), int64(50), latencyMs)
|
||||||
|
require.Equal(s.T(), "9.8.7.6", info.IP)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyProbeServiceSuite) TestParseHTTPBin_NoIP() {
|
||||||
|
body := []byte(`{"origin": ""}`)
|
||||||
|
_, _, err := s.prober.parseHTTPBin(body, 50)
|
||||||
|
require.Error(s.T(), err)
|
||||||
|
require.ErrorContains(s.T(), err, "no IP found")
|
||||||
|
}
|
||||||
|
|
||||||
func TestProxyProbeServiceSuite(t *testing.T) {
|
func TestProxyProbeServiceSuite(t *testing.T) {
|
||||||
suite.Run(t, new(ProxyProbeServiceSuite))
|
suite.Run(t, new(ProxyProbeServiceSuite))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -202,6 +202,57 @@ func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, lim
|
|||||||
return redeemCodeEntitiesToService(codes), nil
|
return redeemCodeEntitiesToService(codes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListByUserPaginated returns paginated balance/concurrency history for a user.
|
||||||
|
// Supports optional type filter (e.g. "balance", "admin_balance", "concurrency", "admin_concurrency", "subscription").
|
||||||
|
func (r *redeemCodeRepository) ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
|
||||||
|
q := r.client.RedeemCode.Query().
|
||||||
|
Where(redeemcode.UsedByEQ(userID))
|
||||||
|
|
||||||
|
// Optional type filter
|
||||||
|
if codeType != "" {
|
||||||
|
q = q.Where(redeemcode.TypeEQ(codeType))
|
||||||
|
}
|
||||||
|
|
||||||
|
total, err := q.Count(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
codes, err := q.
|
||||||
|
WithGroup().
|
||||||
|
Offset(params.Offset()).
|
||||||
|
Limit(params.Limit()).
|
||||||
|
Order(dbent.Desc(redeemcode.FieldUsedAt)).
|
||||||
|
All(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return redeemCodeEntitiesToService(codes), paginationResultFromTotal(int64(total), params), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SumPositiveBalanceByUser returns total recharged amount (sum of value > 0 where type is balance/admin_balance).
|
||||||
|
func (r *redeemCodeRepository) SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) {
|
||||||
|
var result []struct {
|
||||||
|
Sum float64 `json:"sum"`
|
||||||
|
}
|
||||||
|
err := r.client.RedeemCode.Query().
|
||||||
|
Where(
|
||||||
|
redeemcode.UsedByEQ(userID),
|
||||||
|
redeemcode.ValueGT(0),
|
||||||
|
redeemcode.TypeIn("balance", "admin_balance"),
|
||||||
|
).
|
||||||
|
Aggregate(dbent.As(dbent.Sum(redeemcode.FieldValue), "sum")).
|
||||||
|
Scan(ctx, &result)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if len(result) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return result[0].Sum, nil
|
||||||
|
}
|
||||||
|
|
||||||
func redeemCodeEntityToService(m *dbent.RedeemCode) *service.RedeemCode {
|
func redeemCodeEntityToService(m *dbent.RedeemCode) *service.RedeemCode {
|
||||||
if m == nil {
|
if m == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import (
|
|||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
)
|
)
|
||||||
|
|
||||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
|
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, created_at"
|
||||||
|
|
||||||
type usageLogRepository struct {
|
type usageLogRepository struct {
|
||||||
client *dbent.Client
|
client *dbent.Client
|
||||||
@@ -111,21 +111,22 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
|||||||
duration_ms,
|
duration_ms,
|
||||||
first_token_ms,
|
first_token_ms,
|
||||||
user_agent,
|
user_agent,
|
||||||
ip_address,
|
ip_address,
|
||||||
image_count,
|
image_count,
|
||||||
image_size,
|
image_size,
|
||||||
created_at
|
reasoning_effort,
|
||||||
) VALUES (
|
created_at
|
||||||
$1, $2, $3, $4, $5,
|
) VALUES (
|
||||||
$6, $7,
|
$1, $2, $3, $4, $5,
|
||||||
$8, $9, $10, $11,
|
$6, $7,
|
||||||
$12, $13,
|
$8, $9, $10, $11,
|
||||||
$14, $15, $16, $17, $18, $19,
|
$12, $13,
|
||||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30
|
$14, $15, $16, $17, $18, $19,
|
||||||
)
|
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31
|
||||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
)
|
||||||
RETURNING id, created_at
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
`
|
RETURNING id, created_at
|
||||||
|
`
|
||||||
|
|
||||||
groupID := nullInt64(log.GroupID)
|
groupID := nullInt64(log.GroupID)
|
||||||
subscriptionID := nullInt64(log.SubscriptionID)
|
subscriptionID := nullInt64(log.SubscriptionID)
|
||||||
@@ -134,6 +135,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
|||||||
userAgent := nullString(log.UserAgent)
|
userAgent := nullString(log.UserAgent)
|
||||||
ipAddress := nullString(log.IPAddress)
|
ipAddress := nullString(log.IPAddress)
|
||||||
imageSize := nullString(log.ImageSize)
|
imageSize := nullString(log.ImageSize)
|
||||||
|
reasoningEffort := nullString(log.ReasoningEffort)
|
||||||
|
|
||||||
var requestIDArg any
|
var requestIDArg any
|
||||||
if requestID != "" {
|
if requestID != "" {
|
||||||
@@ -170,6 +172,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
|||||||
ipAddress,
|
ipAddress,
|
||||||
log.ImageCount,
|
log.ImageCount,
|
||||||
imageSize,
|
imageSize,
|
||||||
|
reasoningEffort,
|
||||||
createdAt,
|
createdAt,
|
||||||
}
|
}
|
||||||
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
|
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
|
||||||
@@ -2090,6 +2093,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
ipAddress sql.NullString
|
ipAddress sql.NullString
|
||||||
imageCount int
|
imageCount int
|
||||||
imageSize sql.NullString
|
imageSize sql.NullString
|
||||||
|
reasoningEffort sql.NullString
|
||||||
createdAt time.Time
|
createdAt time.Time
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -2124,6 +2128,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
&ipAddress,
|
&ipAddress,
|
||||||
&imageCount,
|
&imageCount,
|
||||||
&imageSize,
|
&imageSize,
|
||||||
|
&reasoningEffort,
|
||||||
&createdAt,
|
&createdAt,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -2183,6 +2188,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
if imageSize.Valid {
|
if imageSize.Valid {
|
||||||
log.ImageSize = &imageSize.String
|
log.ImageSize = &imageSize.String
|
||||||
}
|
}
|
||||||
|
if reasoningEffort.Valid {
|
||||||
|
log.ReasoningEffort = &reasoningEffort.String
|
||||||
|
}
|
||||||
|
|
||||||
return log, nil
|
return log, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -83,6 +83,9 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"status": "active",
|
"status": "active",
|
||||||
"ip_whitelist": null,
|
"ip_whitelist": null,
|
||||||
"ip_blacklist": null,
|
"ip_blacklist": null,
|
||||||
|
"quota": 0,
|
||||||
|
"quota_used": 0,
|
||||||
|
"expires_at": null,
|
||||||
"created_at": "2025-01-02T03:04:05Z",
|
"created_at": "2025-01-02T03:04:05Z",
|
||||||
"updated_at": "2025-01-02T03:04:05Z"
|
"updated_at": "2025-01-02T03:04:05Z"
|
||||||
}
|
}
|
||||||
@@ -119,6 +122,9 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"status": "active",
|
"status": "active",
|
||||||
"ip_whitelist": null,
|
"ip_whitelist": null,
|
||||||
"ip_blacklist": null,
|
"ip_blacklist": null,
|
||||||
|
"quota": 0,
|
||||||
|
"quota_used": 0,
|
||||||
|
"expires_at": null,
|
||||||
"created_at": "2025-01-02T03:04:05Z",
|
"created_at": "2025-01-02T03:04:05Z",
|
||||||
"updated_at": "2025-01-02T03:04:05Z"
|
"updated_at": "2025-01-02T03:04:05Z"
|
||||||
}
|
}
|
||||||
@@ -180,6 +186,7 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"image_price_4k": null,
|
"image_price_4k": null,
|
||||||
"claude_code_only": false,
|
"claude_code_only": false,
|
||||||
"fallback_group_id": null,
|
"fallback_group_id": null,
|
||||||
|
"fallback_group_id_on_invalid_request": null,
|
||||||
"created_at": "2025-01-02T03:04:05Z",
|
"created_at": "2025-01-02T03:04:05Z",
|
||||||
"updated_at": "2025-01-02T03:04:05Z"
|
"updated_at": "2025-01-02T03:04:05Z"
|
||||||
}
|
}
|
||||||
@@ -488,6 +495,7 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"fallback_model_openai": "gpt-4o",
|
"fallback_model_openai": "gpt-4o",
|
||||||
"enable_identity_patch": true,
|
"enable_identity_patch": true,
|
||||||
"identity_patch_prompt": "",
|
"identity_patch_prompt": "",
|
||||||
|
"invitation_code_enabled": false,
|
||||||
"home_content": "",
|
"home_content": "",
|
||||||
"hide_ccs_import_button": false,
|
"hide_ccs_import_button": false,
|
||||||
"purchase_subscription_enabled": false,
|
"purchase_subscription_enabled": false,
|
||||||
@@ -600,7 +608,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
settingService := service.NewSettingService(settingRepo, cfg)
|
settingService := service.NewSettingService(settingRepo, cfg)
|
||||||
|
|
||||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
|
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
|
||||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, nil)
|
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
|
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
|
||||||
@@ -880,6 +888,14 @@ func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID i
|
|||||||
return 0, errors.New("not implemented")
|
return 0, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (stubGroupRepo) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
type stubAccountRepo struct {
|
type stubAccountRepo struct {
|
||||||
bulkUpdateIDs []int64
|
bulkUpdateIDs []int64
|
||||||
}
|
}
|
||||||
@@ -1141,6 +1157,14 @@ func (r *stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit
|
|||||||
return append([]service.RedeemCode(nil), codes...), nil
|
return append([]service.RedeemCode(nil), codes...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (stubRedeemCodeRepo) ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubRedeemCodeRepo) SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) {
|
||||||
|
return 0, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
type stubUserSubscriptionRepo struct {
|
type stubUserSubscriptionRepo struct {
|
||||||
byUser map[int64][]service.UserSubscription
|
byUser map[int64][]service.UserSubscription
|
||||||
activeByUser map[int64][]service.UserSubscription
|
activeByUser map[int64][]service.UserSubscription
|
||||||
@@ -1425,6 +1449,10 @@ func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) (
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
|
||||||
|
return 0, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
type stubUsageLogRepo struct {
|
type stubUsageLogRepo struct {
|
||||||
userLogs map[int64][]service.UsageLog
|
userLogs map[int64][]service.UsageLog
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -70,7 +70,27 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
|||||||
|
|
||||||
// 检查API key是否激活
|
// 检查API key是否激活
|
||||||
if !apiKey.IsActive() {
|
if !apiKey.IsActive() {
|
||||||
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
|
// Provide more specific error message based on status
|
||||||
|
switch apiKey.Status {
|
||||||
|
case service.StatusAPIKeyQuotaExhausted:
|
||||||
|
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
|
||||||
|
case service.StatusAPIKeyExpired:
|
||||||
|
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
|
||||||
|
default:
|
||||||
|
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查API Key是否过期(即使状态是active,也要检查时间)
|
||||||
|
if apiKey.IsExpired() {
|
||||||
|
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查API Key配额是否耗尽
|
||||||
|
if apiKey.IsQuotaExhausted() {
|
||||||
|
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
|
|||||||
abortWithGoogleError(c, 400, "Query parameter api_key is deprecated. Use Authorization header or key instead.")
|
abortWithGoogleError(c, 400, "Query parameter api_key is deprecated. Use Authorization header or key instead.")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
apiKeyString := extractAPIKeyFromRequest(c)
|
apiKeyString := extractAPIKeyForGoogle(c)
|
||||||
if apiKeyString == "" {
|
if apiKeyString == "" {
|
||||||
abortWithGoogleError(c, 401, "API key is required")
|
abortWithGoogleError(c, 401, "API key is required")
|
||||||
return
|
return
|
||||||
@@ -108,25 +108,38 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractAPIKeyFromRequest(c *gin.Context) string {
|
// extractAPIKeyForGoogle extracts API key for Google/Gemini endpoints.
|
||||||
authHeader := c.GetHeader("Authorization")
|
// Priority: x-goog-api-key > Authorization: Bearer > x-api-key > query key
|
||||||
if authHeader != "" {
|
// This allows OpenClaw and other clients using Bearer auth to work with Gemini endpoints.
|
||||||
parts := strings.SplitN(authHeader, " ", 2)
|
func extractAPIKeyForGoogle(c *gin.Context) string {
|
||||||
if len(parts) == 2 && parts[0] == "Bearer" && strings.TrimSpace(parts[1]) != "" {
|
// 1) preferred: Gemini native header
|
||||||
return strings.TrimSpace(parts[1])
|
if k := strings.TrimSpace(c.GetHeader("x-goog-api-key")); k != "" {
|
||||||
|
return k
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) fallback: Authorization: Bearer <key>
|
||||||
|
auth := strings.TrimSpace(c.GetHeader("Authorization"))
|
||||||
|
if auth != "" {
|
||||||
|
parts := strings.SplitN(auth, " ", 2)
|
||||||
|
if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") {
|
||||||
|
if k := strings.TrimSpace(parts[1]); k != "" {
|
||||||
|
return k
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if v := strings.TrimSpace(c.GetHeader("x-api-key")); v != "" {
|
|
||||||
return v
|
// 3) x-api-key header (backward compatibility)
|
||||||
}
|
if k := strings.TrimSpace(c.GetHeader("x-api-key")); k != "" {
|
||||||
if v := strings.TrimSpace(c.GetHeader("x-goog-api-key")); v != "" {
|
return k
|
||||||
return v
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 4) query parameter key (for specific paths)
|
||||||
if allowGoogleQueryKey(c.Request.URL.Path) {
|
if allowGoogleQueryKey(c.Request.URL.Path) {
|
||||||
if v := strings.TrimSpace(c.Query("key")); v != "" {
|
if v := strings.TrimSpace(c.Query("key")); v != "" {
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -75,6 +75,9 @@ func (f fakeAPIKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]s
|
|||||||
func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
func (f fakeAPIKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
|
||||||
|
return 0, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
type googleErrorResponse struct {
|
type googleErrorResponse struct {
|
||||||
Error struct {
|
Error struct {
|
||||||
|
|||||||
@@ -319,6 +319,10 @@ func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) (
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
|
||||||
|
return 0, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
type stubUserSubscriptionRepo struct {
|
type stubUserSubscriptionRepo struct {
|
||||||
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
|
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
|
||||||
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
|
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
|
||||||
|
|||||||
@@ -175,6 +175,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
|
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
|
||||||
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
|
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
|
||||||
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
|
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
|
||||||
|
users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory)
|
||||||
|
|
||||||
// User attribute values
|
// User attribute values
|
||||||
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
|
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
|
||||||
|
|||||||
@@ -32,6 +32,10 @@ func RegisterAuthRoutes(
|
|||||||
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
|
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
|
||||||
FailureMode: middleware.RateLimitFailClose,
|
FailureMode: middleware.RateLimitFailClose,
|
||||||
}), h.Auth.ValidatePromoCode)
|
}), h.Auth.ValidatePromoCode)
|
||||||
|
// 邀请码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
|
||||||
|
auth.POST("/validate-invitation-code", rateLimiter.LimitWithOptions("validate-invitation", 10, time.Minute, middleware.RateLimitOptions{
|
||||||
|
FailureMode: middleware.RateLimitFailClose,
|
||||||
|
}), h.Auth.ValidateInvitationCode)
|
||||||
// 忘记密码接口添加速率限制:每分钟最多 5 次(Redis 故障时 fail-close)
|
// 忘记密码接口添加速率限制:每分钟最多 5 次(Redis 故障时 fail-close)
|
||||||
auth.POST("/forgot-password", rateLimiter.LimitWithOptions("forgot-password", 5, time.Minute, middleware.RateLimitOptions{
|
auth.POST("/forgot-password", rateLimiter.LimitWithOptions("forgot-password", 5, time.Minute, middleware.RateLimitOptions{
|
||||||
FailureMode: middleware.RateLimitFailClose,
|
FailureMode: middleware.RateLimitFailClose,
|
||||||
|
|||||||
@@ -410,6 +410,22 @@ func (a *Account) GetExtraString(key string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Account) GetClaudeUserID() string {
|
||||||
|
if v := strings.TrimSpace(a.GetExtraString("claude_user_id")); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(a.GetExtraString("anthropic_user_id")); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(a.GetCredential("claude_user_id")); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(a.GetCredential("anthropic_user_id")); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||||
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
|
|||||||
"system": []map[string]any{
|
"system": []map[string]any{
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
|
"text": claudeCodeSystemPrompt,
|
||||||
"cache_control": map[string]string{
|
"cache_control": map[string]string{
|
||||||
"type": "ephemeral",
|
"type": "ephemeral",
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -22,6 +22,10 @@ type AdminService interface {
|
|||||||
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
|
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
|
||||||
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error)
|
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error)
|
||||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
||||||
|
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
|
||||||
|
// codeType is optional - pass empty string to return all types.
|
||||||
|
// Also returns totalRecharged (sum of all positive balance top-ups).
|
||||||
|
GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error)
|
||||||
|
|
||||||
// Group management
|
// Group management
|
||||||
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error)
|
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error)
|
||||||
@@ -107,9 +111,16 @@ type CreateGroupInput struct {
|
|||||||
ImagePrice4K *float64
|
ImagePrice4K *float64
|
||||||
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
|
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
|
||||||
FallbackGroupID *int64 // 降级分组 ID
|
FallbackGroupID *int64 // 降级分组 ID
|
||||||
|
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||||
|
FallbackGroupIDOnInvalidRequest *int64
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
ModelRouting map[string][]int64
|
ModelRouting map[string][]int64
|
||||||
ModelRoutingEnabled bool // 是否启用模型路由
|
ModelRoutingEnabled bool // 是否启用模型路由
|
||||||
|
MCPXMLInject *bool
|
||||||
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
|
SupportedModelScopes []string
|
||||||
|
// 从指定分组复制账号(创建分组后在同一事务内绑定)
|
||||||
|
CopyAccountsFromGroupIDs []int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type UpdateGroupInput struct {
|
type UpdateGroupInput struct {
|
||||||
@@ -129,9 +140,16 @@ type UpdateGroupInput struct {
|
|||||||
ImagePrice4K *float64
|
ImagePrice4K *float64
|
||||||
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
|
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
|
||||||
FallbackGroupID *int64 // 降级分组 ID
|
FallbackGroupID *int64 // 降级分组 ID
|
||||||
|
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||||
|
FallbackGroupIDOnInvalidRequest *int64
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
ModelRouting map[string][]int64
|
ModelRouting map[string][]int64
|
||||||
ModelRoutingEnabled *bool // 是否启用模型路由
|
ModelRoutingEnabled *bool // 是否启用模型路由
|
||||||
|
MCPXMLInject *bool
|
||||||
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
|
SupportedModelScopes *[]string
|
||||||
|
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||||
|
CopyAccountsFromGroupIDs []int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateAccountInput struct {
|
type CreateAccountInput struct {
|
||||||
@@ -522,6 +540,21 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
|
||||||
|
func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) {
|
||||||
|
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||||
|
codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, codeType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, 0, err
|
||||||
|
}
|
||||||
|
// Aggregate total recharged amount (only once, regardless of type filter)
|
||||||
|
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, 0, err
|
||||||
|
}
|
||||||
|
return codes, result.Total, totalRecharged, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Group management implementations
|
// Group management implementations
|
||||||
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) {
|
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) {
|
||||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||||
@@ -571,28 +604,88 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
fallbackOnInvalidRequest := input.FallbackGroupIDOnInvalidRequest
|
||||||
|
if fallbackOnInvalidRequest != nil && *fallbackOnInvalidRequest <= 0 {
|
||||||
|
fallbackOnInvalidRequest = nil
|
||||||
|
}
|
||||||
|
// 校验无效请求兜底分组
|
||||||
|
if fallbackOnInvalidRequest != nil {
|
||||||
|
if err := s.validateFallbackGroupOnInvalidRequest(ctx, 0, platform, subscriptionType, *fallbackOnInvalidRequest); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MCPXMLInject:默认为 true,仅当显式传入 false 时关闭
|
||||||
|
mcpXMLInject := true
|
||||||
|
if input.MCPXMLInject != nil {
|
||||||
|
mcpXMLInject = *input.MCPXMLInject
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果指定了复制账号的源分组,先获取账号 ID 列表
|
||||||
|
var accountIDsToCopy []int64
|
||||||
|
if len(input.CopyAccountsFromGroupIDs) > 0 {
|
||||||
|
// 去重源分组 IDs
|
||||||
|
seen := make(map[int64]struct{})
|
||||||
|
uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs))
|
||||||
|
for _, srcGroupID := range input.CopyAccountsFromGroupIDs {
|
||||||
|
if _, exists := seen[srcGroupID]; !exists {
|
||||||
|
seen[srcGroupID] = struct{}{}
|
||||||
|
uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 校验源分组的平台是否与新分组一致
|
||||||
|
for _, srcGroupID := range uniqueSourceGroupIDs {
|
||||||
|
srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err)
|
||||||
|
}
|
||||||
|
if srcGroup.Platform != platform {
|
||||||
|
return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, platform, srcGroup.Platform)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取所有源分组的账号(去重)
|
||||||
|
var err error
|
||||||
|
accountIDsToCopy, err = s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get accounts from source groups: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
group := &Group{
|
group := &Group{
|
||||||
Name: input.Name,
|
Name: input.Name,
|
||||||
Description: input.Description,
|
Description: input.Description,
|
||||||
Platform: platform,
|
Platform: platform,
|
||||||
RateMultiplier: input.RateMultiplier,
|
RateMultiplier: input.RateMultiplier,
|
||||||
IsExclusive: input.IsExclusive,
|
IsExclusive: input.IsExclusive,
|
||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
SubscriptionType: subscriptionType,
|
SubscriptionType: subscriptionType,
|
||||||
DailyLimitUSD: dailyLimit,
|
DailyLimitUSD: dailyLimit,
|
||||||
WeeklyLimitUSD: weeklyLimit,
|
WeeklyLimitUSD: weeklyLimit,
|
||||||
MonthlyLimitUSD: monthlyLimit,
|
MonthlyLimitUSD: monthlyLimit,
|
||||||
ImagePrice1K: imagePrice1K,
|
ImagePrice1K: imagePrice1K,
|
||||||
ImagePrice2K: imagePrice2K,
|
ImagePrice2K: imagePrice2K,
|
||||||
ImagePrice4K: imagePrice4K,
|
ImagePrice4K: imagePrice4K,
|
||||||
ClaudeCodeOnly: input.ClaudeCodeOnly,
|
ClaudeCodeOnly: input.ClaudeCodeOnly,
|
||||||
FallbackGroupID: input.FallbackGroupID,
|
FallbackGroupID: input.FallbackGroupID,
|
||||||
ModelRouting: input.ModelRouting,
|
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
|
||||||
|
ModelRouting: input.ModelRouting,
|
||||||
|
MCPXMLInject: mcpXMLInject,
|
||||||
|
SupportedModelScopes: input.SupportedModelScopes,
|
||||||
}
|
}
|
||||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 如果有需要复制的账号,绑定到新分组
|
||||||
|
if len(accountIDsToCopy) > 0 {
|
||||||
|
if err := s.groupRepo.BindAccountsToGroup(ctx, group.ID, accountIDsToCopy); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to bind accounts to new group: %w", err)
|
||||||
|
}
|
||||||
|
group.AccountCount = int64(len(accountIDsToCopy))
|
||||||
|
}
|
||||||
|
|
||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -650,6 +743,37 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateFallbackGroupOnInvalidRequest 校验无效请求兜底分组的有效性
|
||||||
|
// currentGroupID: 当前分组 ID(新建时为 0)
|
||||||
|
// platform/subscriptionType: 当前分组的有效平台/订阅类型
|
||||||
|
// fallbackGroupID: 兜底分组 ID
|
||||||
|
func (s *adminServiceImpl) validateFallbackGroupOnInvalidRequest(ctx context.Context, currentGroupID int64, platform, subscriptionType string, fallbackGroupID int64) error {
|
||||||
|
if platform != PlatformAnthropic && platform != PlatformAntigravity {
|
||||||
|
return fmt.Errorf("invalid request fallback only supported for anthropic or antigravity groups")
|
||||||
|
}
|
||||||
|
if subscriptionType == SubscriptionTypeSubscription {
|
||||||
|
return fmt.Errorf("subscription groups cannot set invalid request fallback")
|
||||||
|
}
|
||||||
|
if currentGroupID > 0 && currentGroupID == fallbackGroupID {
|
||||||
|
return fmt.Errorf("cannot set self as invalid request fallback group")
|
||||||
|
}
|
||||||
|
|
||||||
|
fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, fallbackGroupID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("fallback group not found: %w", err)
|
||||||
|
}
|
||||||
|
if fallbackGroup.Platform != PlatformAnthropic {
|
||||||
|
return fmt.Errorf("fallback group must be anthropic platform")
|
||||||
|
}
|
||||||
|
if fallbackGroup.SubscriptionType == SubscriptionTypeSubscription {
|
||||||
|
return fmt.Errorf("fallback group cannot be subscription type")
|
||||||
|
}
|
||||||
|
if fallbackGroup.FallbackGroupIDOnInvalidRequest != nil {
|
||||||
|
return fmt.Errorf("fallback group cannot have invalid request fallback configured")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
|
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
|
||||||
group, err := s.groupRepo.GetByID(ctx, id)
|
group, err := s.groupRepo.GetByID(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -716,6 +840,20 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
|||||||
group.FallbackGroupID = nil
|
group.FallbackGroupID = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
fallbackOnInvalidRequest := group.FallbackGroupIDOnInvalidRequest
|
||||||
|
if input.FallbackGroupIDOnInvalidRequest != nil {
|
||||||
|
if *input.FallbackGroupIDOnInvalidRequest > 0 {
|
||||||
|
fallbackOnInvalidRequest = input.FallbackGroupIDOnInvalidRequest
|
||||||
|
} else {
|
||||||
|
fallbackOnInvalidRequest = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if fallbackOnInvalidRequest != nil {
|
||||||
|
if err := s.validateFallbackGroupOnInvalidRequest(ctx, id, group.Platform, group.SubscriptionType, *fallbackOnInvalidRequest); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
group.FallbackGroupIDOnInvalidRequest = fallbackOnInvalidRequest
|
||||||
|
|
||||||
// 模型路由配置
|
// 模型路由配置
|
||||||
if input.ModelRouting != nil {
|
if input.ModelRouting != nil {
|
||||||
@@ -724,10 +862,66 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
|||||||
if input.ModelRoutingEnabled != nil {
|
if input.ModelRoutingEnabled != nil {
|
||||||
group.ModelRoutingEnabled = *input.ModelRoutingEnabled
|
group.ModelRoutingEnabled = *input.ModelRoutingEnabled
|
||||||
}
|
}
|
||||||
|
if input.MCPXMLInject != nil {
|
||||||
|
group.MCPXMLInject = *input.MCPXMLInject
|
||||||
|
}
|
||||||
|
|
||||||
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
|
if input.SupportedModelScopes != nil {
|
||||||
|
group.SupportedModelScopes = *input.SupportedModelScopes
|
||||||
|
}
|
||||||
|
|
||||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号)
|
||||||
|
if len(input.CopyAccountsFromGroupIDs) > 0 {
|
||||||
|
// 去重源分组 IDs
|
||||||
|
seen := make(map[int64]struct{})
|
||||||
|
uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs))
|
||||||
|
for _, srcGroupID := range input.CopyAccountsFromGroupIDs {
|
||||||
|
// 校验:源分组不能是自身
|
||||||
|
if srcGroupID == id {
|
||||||
|
return nil, fmt.Errorf("cannot copy accounts from self")
|
||||||
|
}
|
||||||
|
// 去重
|
||||||
|
if _, exists := seen[srcGroupID]; !exists {
|
||||||
|
seen[srcGroupID] = struct{}{}
|
||||||
|
uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 校验源分组的平台是否与当前分组一致
|
||||||
|
for _, srcGroupID := range uniqueSourceGroupIDs {
|
||||||
|
srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err)
|
||||||
|
}
|
||||||
|
if srcGroup.Platform != group.Platform {
|
||||||
|
return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, group.Platform, srcGroup.Platform)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取所有源分组的账号(去重)
|
||||||
|
accountIDsToCopy, err := s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get accounts from source groups: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 先清空当前分组的所有账号绑定
|
||||||
|
if _, err := s.groupRepo.DeleteAccountGroupsByGroupID(ctx, id); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to clear existing account bindings: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 再绑定源分组的账号
|
||||||
|
if len(accountIDsToCopy) > 0 {
|
||||||
|
if err := s.groupRepo.BindAccountsToGroup(ctx, id, accountIDsToCopy); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to bind accounts to group: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if s.authCacheInvalidator != nil {
|
if s.authCacheInvalidator != nil {
|
||||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -164,6 +164,14 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
|
|||||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStub) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||||
|
panic("unexpected BindAccountsToGroup call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStub) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||||
|
panic("unexpected GetAccountIDsByGroupIDs call")
|
||||||
|
}
|
||||||
|
|
||||||
type proxyRepoStub struct {
|
type proxyRepoStub struct {
|
||||||
deleteErr error
|
deleteErr error
|
||||||
countErr error
|
countErr error
|
||||||
@@ -274,6 +282,14 @@ func (s *redeemRepoStub) ListByUser(ctx context.Context, userID int64, limit int
|
|||||||
panic("unexpected ListByUser call")
|
panic("unexpected ListByUser call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *redeemRepoStub) ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]RedeemCode, *pagination.PaginationResult, error) {
|
||||||
|
panic("unexpected ListByUserPaginated call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *redeemRepoStub) SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) {
|
||||||
|
panic("unexpected SumPositiveBalanceByUser call")
|
||||||
|
}
|
||||||
|
|
||||||
type subscriptionInvalidateCall struct {
|
type subscriptionInvalidateCall struct {
|
||||||
userID int64
|
userID int64
|
||||||
groupID int64
|
groupID int64
|
||||||
|
|||||||
@@ -108,6 +108,14 @@ func (s *groupRepoStubForAdmin) DeleteAccountGroupsByGroupID(_ context.Context,
|
|||||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForAdmin) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
|
||||||
|
panic("unexpected BindAccountsToGroup call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForAdmin) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
|
||||||
|
panic("unexpected GetAccountIDsByGroupIDs call")
|
||||||
|
}
|
||||||
|
|
||||||
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
|
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
|
||||||
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
|
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
|
||||||
repo := &groupRepoStubForAdmin{}
|
repo := &groupRepoStubForAdmin{}
|
||||||
@@ -378,3 +386,390 @@ func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int
|
|||||||
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForFallbackCycle) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
|
||||||
|
panic("unexpected BindAccountsToGroup call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
|
||||||
|
panic("unexpected GetAccountIDsByGroupIDs call")
|
||||||
|
}
|
||||||
|
|
||||||
|
type groupRepoStubForInvalidRequestFallback struct {
|
||||||
|
groups map[int64]*Group
|
||||||
|
created *Group
|
||||||
|
updated *Group
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) Create(_ context.Context, g *Group) error {
|
||||||
|
s.created = g
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) Update(_ context.Context, g *Group) error {
|
||||||
|
s.updated = g
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||||
|
return s.GetByIDLite(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) GetByIDLite(_ context.Context, id int64) (*Group, error) {
|
||||||
|
if g, ok := s.groups[id]; ok {
|
||||||
|
return g, nil
|
||||||
|
}
|
||||||
|
return nil, ErrGroupNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) Delete(_ context.Context, _ int64) error {
|
||||||
|
panic("unexpected Delete call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
|
||||||
|
panic("unexpected DeleteCascade call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||||
|
panic("unexpected List call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||||
|
panic("unexpected ListWithFilters call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) ListActive(_ context.Context) ([]Group, error) {
|
||||||
|
panic("unexpected ListActive call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
|
||||||
|
panic("unexpected ListActiveByPlatform call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context, _ string) (bool, error) {
|
||||||
|
panic("unexpected ExistsByName call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
||||||
|
panic("unexpected GetAccountCount call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||||||
|
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
|
||||||
|
panic("unexpected GetAccountIDsByGroupIDs call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
|
||||||
|
panic("unexpected BindAccountsToGroup call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatform(t *testing.T) {
|
||||||
|
fallbackID := int64(10)
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups")
|
||||||
|
require.Nil(t, repo.created)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *testing.T) {
|
||||||
|
fallbackID := int64(10)
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
SubscriptionType: SubscriptionTypeSubscription,
|
||||||
|
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback")
|
||||||
|
require.Nil(t, repo.created)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fallback *Group
|
||||||
|
wantMessage string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "openai_target",
|
||||||
|
fallback: &Group{ID: 10, Platform: PlatformOpenAI, SubscriptionType: SubscriptionTypeStandard},
|
||||||
|
wantMessage: "fallback group must be anthropic platform",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "antigravity_target",
|
||||||
|
fallback: &Group{ID: 10, Platform: PlatformAntigravity, SubscriptionType: SubscriptionTypeStandard},
|
||||||
|
wantMessage: "fallback group must be anthropic platform",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subscription_group",
|
||||||
|
fallback: &Group{ID: 10, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription},
|
||||||
|
wantMessage: "fallback group cannot be subscription type",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested_fallback",
|
||||||
|
fallback: &Group{
|
||||||
|
ID: 10,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
FallbackGroupIDOnInvalidRequest: func() *int64 { v := int64(99); return &v }(),
|
||||||
|
},
|
||||||
|
wantMessage: "fallback group cannot have invalid request fallback configured",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
fallbackID := tc.fallback.ID
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
fallbackID: tc.fallback,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), tc.wantMessage)
|
||||||
|
require.Nil(t, repo.created)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) {
|
||||||
|
fallbackID := int64(10)
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "fallback group not found")
|
||||||
|
require.Nil(t, repo.created)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) {
|
||||||
|
fallbackID := int64(10)
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, group)
|
||||||
|
require.NotNil(t, repo.created)
|
||||||
|
require.Equal(t, fallbackID, *repo.created.FallbackGroupIDOnInvalidRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) {
|
||||||
|
zero := int64(0)
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
FallbackGroupIDOnInvalidRequest: &zero,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, group)
|
||||||
|
require.NotNil(t, repo.created)
|
||||||
|
require.Nil(t, repo.created.FallbackGroupIDOnInvalidRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_UpdateGroup_InvalidRequestFallbackPlatformMismatch(t *testing.T) {
|
||||||
|
fallbackID := int64(10)
|
||||||
|
existing := &Group{
|
||||||
|
ID: 1,
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
Status: StatusActive,
|
||||||
|
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||||
|
}
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
existing.ID: existing,
|
||||||
|
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups")
|
||||||
|
require.Nil(t, repo.updated)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_UpdateGroup_InvalidRequestFallbackSubscriptionMismatch(t *testing.T) {
|
||||||
|
fallbackID := int64(10)
|
||||||
|
existing := &Group{
|
||||||
|
ID: 1,
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
Status: StatusActive,
|
||||||
|
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||||
|
}
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
existing.ID: existing,
|
||||||
|
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||||
|
SubscriptionType: SubscriptionTypeSubscription,
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback")
|
||||||
|
require.Nil(t, repo.updated)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_UpdateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) {
|
||||||
|
fallbackID := int64(10)
|
||||||
|
existing := &Group{
|
||||||
|
ID: 1,
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
Status: StatusActive,
|
||||||
|
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||||
|
}
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
existing.ID: existing,
|
||||||
|
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
clear := int64(0)
|
||||||
|
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
FallbackGroupIDOnInvalidRequest: &clear,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, group)
|
||||||
|
require.NotNil(t, repo.updated)
|
||||||
|
require.Nil(t, repo.updated.FallbackGroupIDOnInvalidRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_UpdateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) {
|
||||||
|
fallbackID := int64(10)
|
||||||
|
existing := &Group{
|
||||||
|
ID: 1,
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
Status: StatusActive,
|
||||||
|
}
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
existing.ID: existing,
|
||||||
|
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||||
|
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "fallback group cannot be subscription type")
|
||||||
|
require.Nil(t, repo.updated)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_UpdateGroup_InvalidRequestFallbackSetSuccess(t *testing.T) {
|
||||||
|
fallbackID := int64(10)
|
||||||
|
existing := &Group{
|
||||||
|
ID: 1,
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
Status: StatusActive,
|
||||||
|
}
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
existing.ID: existing,
|
||||||
|
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||||
|
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, group)
|
||||||
|
require.NotNil(t, repo.updated)
|
||||||
|
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) {
|
||||||
|
fallbackID := int64(10)
|
||||||
|
existing := &Group{
|
||||||
|
ID: 1,
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
Status: StatusActive,
|
||||||
|
}
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
existing.ID: existing,
|
||||||
|
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||||
|
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, group)
|
||||||
|
require.NotNil(t, repo.updated)
|
||||||
|
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
|
||||||
|
}
|
||||||
|
|||||||
@@ -152,6 +152,14 @@ func (s *redeemRepoStubForAdminList) ListWithFilters(_ context.Context, params p
|
|||||||
return s.listWithFiltersCodes, result, nil
|
return s.listWithFiltersCodes, result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *redeemRepoStubForAdminList) ListByUserPaginated(_ context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]RedeemCode, *pagination.PaginationResult, error) {
|
||||||
|
panic("unexpected ListByUserPaginated call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *redeemRepoStubForAdminList) SumPositiveBalanceByUser(_ context.Context, userID int64) (float64, error) {
|
||||||
|
panic("unexpected SumPositiveBalanceByUser call")
|
||||||
|
}
|
||||||
|
|
||||||
func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
|
func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
|
||||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||||
repo := &accountRepoStubForAdminList{
|
repo := &accountRepoStubForAdminList{
|
||||||
|
|||||||
@@ -13,23 +13,34 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
antigravityStickySessionTTL = time.Hour
|
antigravityStickySessionTTL = time.Hour
|
||||||
antigravityMaxRetries = 3
|
antigravityDefaultMaxRetries = 3
|
||||||
antigravityRetryBaseDelay = 1 * time.Second
|
antigravityRetryBaseDelay = 1 * time.Second
|
||||||
antigravityRetryMaxDelay = 16 * time.Second
|
antigravityRetryMaxDelay = 16 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
const antigravityScopeRateLimitEnv = "GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT"
|
const (
|
||||||
|
antigravityMaxRetriesEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES"
|
||||||
|
antigravityMaxRetriesAfterSwitchEnv = "GATEWAY_ANTIGRAVITY_AFTER_SWITCHMAX_RETRIES"
|
||||||
|
antigravityMaxRetriesClaudeEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_CLAUDE"
|
||||||
|
antigravityMaxRetriesGeminiTextEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_TEXT"
|
||||||
|
antigravityMaxRetriesGeminiImageEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_IMAGE"
|
||||||
|
antigravityScopeRateLimitEnv = "GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT"
|
||||||
|
antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
|
||||||
|
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
|
||||||
|
)
|
||||||
|
|
||||||
// antigravityRetryLoopParams 重试循环的参数
|
// antigravityRetryLoopParams 重试循环的参数
|
||||||
type antigravityRetryLoopParams struct {
|
type antigravityRetryLoopParams struct {
|
||||||
@@ -41,6 +52,7 @@ type antigravityRetryLoopParams struct {
|
|||||||
action string
|
action string
|
||||||
body []byte
|
body []byte
|
||||||
quotaScope AntigravityQuotaScope
|
quotaScope AntigravityQuotaScope
|
||||||
|
maxRetries int
|
||||||
c *gin.Context
|
c *gin.Context
|
||||||
httpUpstream HTTPUpstream
|
httpUpstream HTTPUpstream
|
||||||
settingService *SettingService
|
settingService *SettingService
|
||||||
@@ -52,11 +64,28 @@ type antigravityRetryLoopResult struct {
|
|||||||
resp *http.Response
|
resp *http.Response
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PromptTooLongError 表示上游明确返回 prompt too long
|
||||||
|
type PromptTooLongError struct {
|
||||||
|
StatusCode int
|
||||||
|
RequestID string
|
||||||
|
Body []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *PromptTooLongError) Error() string {
|
||||||
|
return fmt.Sprintf("prompt too long: status=%d", e.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
// antigravityRetryLoop 执行带 URL fallback 的重试循环
|
// antigravityRetryLoop 执行带 URL fallback 的重试循环
|
||||||
func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
|
func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
|
||||||
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
baseURLs := antigravity.ForwardBaseURLs()
|
||||||
|
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLsWithBase(baseURLs)
|
||||||
if len(availableURLs) == 0 {
|
if len(availableURLs) == 0 {
|
||||||
availableURLs = antigravity.BaseURLs
|
availableURLs = baseURLs
|
||||||
|
}
|
||||||
|
|
||||||
|
maxRetries := p.maxRetries
|
||||||
|
if maxRetries <= 0 {
|
||||||
|
maxRetries = antigravityDefaultMaxRetries
|
||||||
}
|
}
|
||||||
|
|
||||||
var resp *http.Response
|
var resp *http.Response
|
||||||
@@ -76,7 +105,7 @@ func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopRe
|
|||||||
urlFallbackLoop:
|
urlFallbackLoop:
|
||||||
for urlIdx, baseURL := range availableURLs {
|
for urlIdx, baseURL := range availableURLs {
|
||||||
usedBaseURL = baseURL
|
usedBaseURL = baseURL
|
||||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||||
select {
|
select {
|
||||||
case <-p.ctx.Done():
|
case <-p.ctx.Done():
|
||||||
log.Printf("%s status=context_canceled error=%v", p.prefix, p.ctx.Err())
|
log.Printf("%s status=context_canceled error=%v", p.prefix, p.ctx.Err())
|
||||||
@@ -109,8 +138,8 @@ urlFallbackLoop:
|
|||||||
log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
|
log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
|
||||||
continue urlFallbackLoop
|
continue urlFallbackLoop
|
||||||
}
|
}
|
||||||
if attempt < antigravityMaxRetries {
|
if attempt < maxRetries {
|
||||||
log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err)
|
log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, maxRetries, err)
|
||||||
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
|
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
|
||||||
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
|
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
|
||||||
return nil, p.ctx.Err()
|
return nil, p.ctx.Err()
|
||||||
@@ -134,7 +163,7 @@ urlFallbackLoop:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 账户/模型配额限流,重试 3 次(指数退避)
|
// 账户/模型配额限流,重试 3 次(指数退避)
|
||||||
if attempt < antigravityMaxRetries {
|
if attempt < maxRetries {
|
||||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
|
||||||
@@ -147,7 +176,7 @@ urlFallbackLoop:
|
|||||||
Message: upstreamMsg,
|
Message: upstreamMsg,
|
||||||
Detail: getUpstreamDetail(respBody),
|
Detail: getUpstreamDetail(respBody),
|
||||||
})
|
})
|
||||||
log.Printf("%s status=429 retry=%d/%d body=%s", p.prefix, attempt, antigravityMaxRetries, truncateForLog(respBody, 200))
|
log.Printf("%s status=429 retry=%d/%d body=%s", p.prefix, attempt, maxRetries, truncateForLog(respBody, 200))
|
||||||
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
|
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
|
||||||
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
|
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
|
||||||
return nil, p.ctx.Err()
|
return nil, p.ctx.Err()
|
||||||
@@ -171,7 +200,7 @@ urlFallbackLoop:
|
|||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
if attempt < antigravityMaxRetries {
|
if attempt < maxRetries {
|
||||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
|
||||||
@@ -184,7 +213,7 @@ urlFallbackLoop:
|
|||||||
Message: upstreamMsg,
|
Message: upstreamMsg,
|
||||||
Detail: getUpstreamDetail(respBody),
|
Detail: getUpstreamDetail(respBody),
|
||||||
})
|
})
|
||||||
log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
|
log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, maxRetries, truncateForLog(respBody, 500))
|
||||||
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
|
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
|
||||||
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
|
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
|
||||||
return nil, p.ctx.Err()
|
return nil, p.ctx.Err()
|
||||||
@@ -273,13 +302,11 @@ func logPrefix(sessionID, accountName string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Antigravity 直接支持的模型(精确匹配透传)
|
// Antigravity 直接支持的模型(精确匹配透传)
|
||||||
|
// 注意:gemini-2.5 系列已移除,统一映射到 gemini-3 系列
|
||||||
var antigravitySupportedModels = map[string]bool{
|
var antigravitySupportedModels = map[string]bool{
|
||||||
"claude-opus-4-5-thinking": true,
|
"claude-opus-4-5-thinking": true,
|
||||||
"claude-sonnet-4-5": true,
|
"claude-sonnet-4-5": true,
|
||||||
"claude-sonnet-4-5-thinking": true,
|
"claude-sonnet-4-5-thinking": true,
|
||||||
"gemini-2.5-flash": true,
|
|
||||||
"gemini-2.5-flash-lite": true,
|
|
||||||
"gemini-2.5-flash-thinking": true,
|
|
||||||
"gemini-3-flash": true,
|
"gemini-3-flash": true,
|
||||||
"gemini-3-pro-low": true,
|
"gemini-3-pro-low": true,
|
||||||
"gemini-3-pro-high": true,
|
"gemini-3-pro-high": true,
|
||||||
@@ -288,23 +315,32 @@ var antigravitySupportedModels = map[string]bool{
|
|||||||
|
|
||||||
// Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先)
|
// Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先)
|
||||||
// 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀)
|
// 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀)
|
||||||
|
// gemini-2.5 系列统一映射到 gemini-3 系列(Antigravity 上游不再支持 2.5)
|
||||||
var antigravityPrefixMapping = []struct {
|
var antigravityPrefixMapping = []struct {
|
||||||
prefix string
|
prefix string
|
||||||
target string
|
target string
|
||||||
}{
|
}{
|
||||||
// 长前缀优先
|
// gemini-2.5 → gemini-3 映射(长前缀优先)
|
||||||
{"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → 3-pro-image
|
{"gemini-2.5-flash-thinking", "gemini-3-flash"}, // gemini-2.5-flash-thinking → gemini-3-flash
|
||||||
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
|
{"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → gemini-3-pro-image
|
||||||
{"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview 等 → gemini-3-flash
|
{"gemini-2.5-flash-lite", "gemini-3-flash"}, // gemini-2.5-flash-lite → gemini-3-flash
|
||||||
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
|
{"gemini-2.5-flash", "gemini-3-flash"}, // gemini-2.5-flash → gemini-3-flash
|
||||||
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
|
{"gemini-2.5-pro-preview", "gemini-3-pro-high"}, // gemini-2.5-pro-preview → gemini-3-pro-high
|
||||||
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
|
{"gemini-2.5-pro-exp", "gemini-3-pro-high"}, // gemini-2.5-pro-exp → gemini-3-pro-high
|
||||||
|
{"gemini-2.5-pro", "gemini-3-pro-high"}, // gemini-2.5-pro → gemini-3-pro-high
|
||||||
|
// gemini-3 前缀映射
|
||||||
|
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
|
||||||
|
{"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview 等 → gemini-3-flash
|
||||||
|
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
|
||||||
|
// Claude 映射
|
||||||
|
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
|
||||||
|
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
|
||||||
|
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
|
||||||
{"claude-opus-4-5", "claude-opus-4-5-thinking"},
|
{"claude-opus-4-5", "claude-opus-4-5-thinking"},
|
||||||
{"claude-3-haiku", "claude-sonnet-4-5"}, // 旧版 claude-3-haiku-xxx → sonnet
|
{"claude-3-haiku", "claude-sonnet-4-5"}, // 旧版 claude-3-haiku-xxx → sonnet
|
||||||
{"claude-sonnet-4", "claude-sonnet-4-5"},
|
{"claude-sonnet-4", "claude-sonnet-4-5"},
|
||||||
{"claude-haiku-4", "claude-sonnet-4-5"}, // → sonnet
|
{"claude-haiku-4", "claude-sonnet-4-5"}, // → sonnet
|
||||||
{"claude-opus-4", "claude-opus-4-5-thinking"},
|
{"claude-opus-4", "claude-opus-4-5-thinking"},
|
||||||
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
|
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
|
||||||
@@ -383,6 +419,11 @@ type TestConnectionResult struct {
|
|||||||
// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费)
|
// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费)
|
||||||
// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择
|
// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择
|
||||||
func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
|
func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
|
||||||
|
// 上游透传账号使用专用测试方法
|
||||||
|
if account.Type == AccountTypeUpstream {
|
||||||
|
return s.testUpstreamConnection(ctx, account, modelID)
|
||||||
|
}
|
||||||
|
|
||||||
// 获取 token
|
// 获取 token
|
||||||
if s.tokenProvider == nil {
|
if s.tokenProvider == nil {
|
||||||
return nil, errors.New("antigravity token provider not configured")
|
return nil, errors.New("antigravity token provider not configured")
|
||||||
@@ -477,6 +518,87 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
|
|||||||
return nil, lastErr
|
return nil, lastErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// testUpstreamConnection 测试上游透传账号连接
|
||||||
|
func (s *AntigravityGatewayService) testUpstreamConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
|
||||||
|
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||||
|
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
||||||
|
if baseURL == "" || apiKey == "" {
|
||||||
|
return nil, errors.New("upstream account missing base_url or api_key")
|
||||||
|
}
|
||||||
|
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||||
|
|
||||||
|
// 使用 Claude 模型进行测试
|
||||||
|
if modelID == "" {
|
||||||
|
modelID = "claude-sonnet-4-20250514"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建最小测试请求
|
||||||
|
testReq := map[string]any{
|
||||||
|
"model": modelID,
|
||||||
|
"max_tokens": 1,
|
||||||
|
"messages": []map[string]any{
|
||||||
|
{"role": "user", "content": "."},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
requestBody, err := json.Marshal(testReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("构建请求失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建 HTTP 请求
|
||||||
|
upstreamURL := baseURL + "/v1/messages"
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(requestBody))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
req.Header.Set("x-api-key", apiKey)
|
||||||
|
req.Header.Set("anthropic-version", "2023-06-01")
|
||||||
|
|
||||||
|
// 代理 URL
|
||||||
|
proxyURL := ""
|
||||||
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
|
proxyURL = account.Proxy.URL()
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[antigravity-Test-Upstream] account=%s url=%s", account.Name, upstreamURL)
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("请求失败: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取响应文本
|
||||||
|
var respData map[string]any
|
||||||
|
text := ""
|
||||||
|
if json.Unmarshal(respBody, &respData) == nil {
|
||||||
|
if content, ok := respData["content"].([]any); ok && len(content) > 0 {
|
||||||
|
if block, ok := content[0].(map[string]any); ok {
|
||||||
|
if t, ok := block["text"].(string); ok {
|
||||||
|
text = t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &TestConnectionResult{
|
||||||
|
Text: text,
|
||||||
|
MappedModel: modelID,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// buildGeminiTestRequest 构建 Gemini 格式测试请求
|
// buildGeminiTestRequest 构建 Gemini 格式测试请求
|
||||||
// 使用最小 token 消耗:输入 "." + maxOutputTokens: 1
|
// 使用最小 token 消耗:输入 "." + maxOutputTokens: 1
|
||||||
func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model string) ([]byte, error) {
|
func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model string) ([]byte, error) {
|
||||||
@@ -527,6 +649,10 @@ func (s *AntigravityGatewayService) getClaudeTransformOptions(ctx context.Contex
|
|||||||
}
|
}
|
||||||
opts.EnableIdentityPatch = s.settingService.IsIdentityPatchEnabled(ctx)
|
opts.EnableIdentityPatch = s.settingService.IsIdentityPatchEnabled(ctx)
|
||||||
opts.IdentityPatch = s.settingService.GetIdentityPatchPrompt(ctx)
|
opts.IdentityPatch = s.settingService.GetIdentityPatchPrompt(ctx)
|
||||||
|
|
||||||
|
if group, ok := ctx.Value(ctxkey.Group).(*Group); ok && group != nil {
|
||||||
|
opts.EnableMCPXML = group.MCPXMLInject
|
||||||
|
}
|
||||||
return opts
|
return opts
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -695,6 +821,11 @@ func isModelNotFoundError(statusCode int, body []byte) bool {
|
|||||||
|
|
||||||
// Forward 转发 Claude 协议请求(Claude → Gemini 转换)
|
// Forward 转发 Claude 协议请求(Claude → Gemini 转换)
|
||||||
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||||||
|
// 上游透传账号直接转发,不走 OAuth token 刷新
|
||||||
|
if account.Type == AccountTypeUpstream {
|
||||||
|
return s.ForwardUpstream(ctx, c, account, body)
|
||||||
|
}
|
||||||
|
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
sessionID := getSessionID(c)
|
sessionID := getSessionID(c)
|
||||||
prefix := logPrefix(sessionID, account.Name)
|
prefix := logPrefix(sessionID, account.Name)
|
||||||
@@ -711,6 +842,12 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
originalModel := claudeReq.Model
|
originalModel := claudeReq.Model
|
||||||
mappedModel := s.getMappedModel(account, claudeReq.Model)
|
mappedModel := s.getMappedModel(account, claudeReq.Model)
|
||||||
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
|
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
|
||||||
|
billingModel := originalModel
|
||||||
|
if antigravityUseMappedModelForBilling() && strings.TrimSpace(mappedModel) != "" {
|
||||||
|
billingModel = mappedModel
|
||||||
|
}
|
||||||
|
afterSwitch := antigravityHasAccountSwitch(ctx)
|
||||||
|
maxRetries := antigravityMaxRetriesForModel(originalModel, afterSwitch)
|
||||||
|
|
||||||
// 获取 access_token
|
// 获取 access_token
|
||||||
if s.tokenProvider == nil {
|
if s.tokenProvider == nil {
|
||||||
@@ -759,6 +896,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
httpUpstream: s.httpUpstream,
|
httpUpstream: s.httpUpstream,
|
||||||
settingService: s.settingService,
|
settingService: s.settingService,
|
||||||
handleError: s.handleUpstreamError,
|
handleError: s.handleUpstreamError,
|
||||||
|
maxRetries: maxRetries,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
|
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
|
||||||
@@ -835,6 +973,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
httpUpstream: s.httpUpstream,
|
httpUpstream: s.httpUpstream,
|
||||||
settingService: s.settingService,
|
settingService: s.settingService,
|
||||||
handleError: s.handleUpstreamError,
|
handleError: s.handleUpstreamError,
|
||||||
|
maxRetries: maxRetries,
|
||||||
})
|
})
|
||||||
if retryErr != nil {
|
if retryErr != nil {
|
||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
@@ -910,6 +1049,39 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
|
|
||||||
// 处理错误响应(重试后仍失败或不触发重试)
|
// 处理错误响应(重试后仍失败或不触发重试)
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
|
if resp.StatusCode == http.StatusBadRequest {
|
||||||
|
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
log.Printf("%s status=400 prompt_too_long=%v upstream_message=%q request_id=%s body=%s", prefix, isPromptTooLongError(respBody), upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, 500))
|
||||||
|
}
|
||||||
|
if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) {
|
||||||
|
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||||||
|
maxBytes := 2048
|
||||||
|
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||||
|
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
}
|
||||||
|
upstreamDetail := ""
|
||||||
|
if logBody {
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "prompt_too_long",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
return nil, &PromptTooLongError{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
RequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Body: respBody,
|
||||||
|
}
|
||||||
|
}
|
||||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
|
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
|
||||||
|
|
||||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||||
@@ -971,7 +1143,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Usage: *usage,
|
Usage: *usage,
|
||||||
Model: originalModel, // 使用原始模型用于计费和日志
|
Model: billingModel, // 计费模型(可按映射模型覆盖)
|
||||||
Stream: claudeReq.Stream,
|
Stream: claudeReq.Stream,
|
||||||
Duration: time.Since(startTime),
|
Duration: time.Since(startTime),
|
||||||
FirstTokenMs: firstTokenMs,
|
FirstTokenMs: firstTokenMs,
|
||||||
@@ -996,24 +1168,64 @@ func isSignatureRelatedError(respBody []byte) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Detect thinking block modification errors:
|
||||||
|
// "thinking or redacted_thinking blocks in the latest assistant message cannot be modified"
|
||||||
|
if strings.Contains(msg, "cannot be modified") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isPromptTooLongError(respBody []byte) bool {
|
||||||
|
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
|
||||||
|
if msg == "" {
|
||||||
|
msg = strings.ToLower(string(respBody))
|
||||||
|
}
|
||||||
|
return strings.Contains(msg, "prompt is too long")
|
||||||
|
}
|
||||||
|
|
||||||
func extractAntigravityErrorMessage(body []byte) string {
|
func extractAntigravityErrorMessage(body []byte) string {
|
||||||
var payload map[string]any
|
var payload map[string]any
|
||||||
if err := json.Unmarshal(body, &payload); err != nil {
|
if err := json.Unmarshal(body, &payload); err != nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
parseNestedMessage := func(msg string) string {
|
||||||
|
trimmed := strings.TrimSpace(msg)
|
||||||
|
if trimmed == "" || !strings.HasPrefix(trimmed, "{") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var nested map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(trimmed), &nested); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if errObj, ok := nested["error"].(map[string]any); ok {
|
||||||
|
if innerMsg, ok := errObj["message"].(string); ok && strings.TrimSpace(innerMsg) != "" {
|
||||||
|
return innerMsg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if innerMsg, ok := nested["message"].(string); ok && strings.TrimSpace(innerMsg) != "" {
|
||||||
|
return innerMsg
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// Google-style: {"error": {"message": "..."}}
|
// Google-style: {"error": {"message": "..."}}
|
||||||
if errObj, ok := payload["error"].(map[string]any); ok {
|
if errObj, ok := payload["error"].(map[string]any); ok {
|
||||||
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
||||||
|
if innerMsg := parseNestedMessage(msg); innerMsg != "" {
|
||||||
|
return innerMsg
|
||||||
|
}
|
||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback: top-level message
|
// Fallback: top-level message
|
||||||
if msg, ok := payload["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
if msg, ok := payload["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
||||||
|
if innerMsg := parseNestedMessage(msg); innerMsg != "" {
|
||||||
|
return innerMsg
|
||||||
|
}
|
||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1241,6 +1453,208 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque
|
|||||||
return changed, nil
|
return changed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ForwardUpstream 透传请求到上游 Antigravity 服务
|
||||||
|
// 用于 upstream 类型账号,直接使用 base_url + api_key 转发,不走 OAuth token
|
||||||
|
func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||||||
|
startTime := time.Now()
|
||||||
|
sessionID := getSessionID(c)
|
||||||
|
prefix := logPrefix(sessionID, account.Name)
|
||||||
|
|
||||||
|
// 获取上游配置
|
||||||
|
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||||
|
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
||||||
|
if baseURL == "" || apiKey == "" {
|
||||||
|
return nil, fmt.Errorf("upstream account missing base_url or api_key")
|
||||||
|
}
|
||||||
|
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||||
|
|
||||||
|
// 解析请求获取模型信息
|
||||||
|
var claudeReq antigravity.ClaudeRequest
|
||||||
|
if err := json.Unmarshal(body, &claudeReq); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse claude request: %w", err)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(claudeReq.Model) == "" {
|
||||||
|
return nil, fmt.Errorf("missing model")
|
||||||
|
}
|
||||||
|
originalModel := claudeReq.Model
|
||||||
|
billingModel := originalModel
|
||||||
|
|
||||||
|
// 构建上游请求 URL
|
||||||
|
upstreamURL := baseURL + "/v1/messages"
|
||||||
|
|
||||||
|
// 创建请求
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create upstream request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置请求头
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
req.Header.Set("x-api-key", apiKey) // Claude API 兼容
|
||||||
|
|
||||||
|
// 透传 Claude 相关 headers
|
||||||
|
if v := c.GetHeader("anthropic-version"); v != "" {
|
||||||
|
req.Header.Set("anthropic-version", v)
|
||||||
|
}
|
||||||
|
if v := c.GetHeader("anthropic-beta"); v != "" {
|
||||||
|
req.Header.Set("anthropic-beta", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 代理 URL
|
||||||
|
proxyURL := ""
|
||||||
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
|
proxyURL = account.Proxy.URL()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("%s upstream request failed: %v", prefix, err)
|
||||||
|
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
// 处理错误响应
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
|
||||||
|
// 429 错误时标记账号限流
|
||||||
|
if resp.StatusCode == http.StatusTooManyRequests {
|
||||||
|
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, AntigravityQuotaScopeClaude)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 透传上游错误
|
||||||
|
c.Header("Content-Type", resp.Header.Get("Content-Type"))
|
||||||
|
c.Status(resp.StatusCode)
|
||||||
|
_, _ = c.Writer.Write(respBody)
|
||||||
|
|
||||||
|
return &ForwardResult{
|
||||||
|
Model: billingModel,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理成功响应(流式/非流式)
|
||||||
|
var usage *ClaudeUsage
|
||||||
|
var firstTokenMs *int
|
||||||
|
|
||||||
|
if claudeReq.Stream {
|
||||||
|
// 流式响应:透传
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("X-Accel-Buffering", "no")
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
|
||||||
|
usage, firstTokenMs = s.streamUpstreamResponse(c, resp, startTime)
|
||||||
|
} else {
|
||||||
|
// 非流式响应:直接透传
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read upstream response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取 usage
|
||||||
|
usage = s.extractClaudeUsage(respBody)
|
||||||
|
|
||||||
|
c.Header("Content-Type", resp.Header.Get("Content-Type"))
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
_, _ = c.Writer.Write(respBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建计费结果
|
||||||
|
duration := time.Since(startTime)
|
||||||
|
log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds())
|
||||||
|
|
||||||
|
return &ForwardResult{
|
||||||
|
Model: billingModel,
|
||||||
|
Stream: claudeReq.Stream,
|
||||||
|
Duration: duration,
|
||||||
|
FirstTokenMs: firstTokenMs,
|
||||||
|
Usage: ClaudeUsage{
|
||||||
|
InputTokens: usage.InputTokens,
|
||||||
|
OutputTokens: usage.OutputTokens,
|
||||||
|
CacheReadInputTokens: usage.CacheReadInputTokens,
|
||||||
|
CacheCreationInputTokens: usage.CacheCreationInputTokens,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// streamUpstreamResponse 透传上游流式响应并提取 usage
|
||||||
|
func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*ClaudeUsage, *int) {
|
||||||
|
usage := &ClaudeUsage{}
|
||||||
|
var firstTokenMs *int
|
||||||
|
var firstTokenRecorded bool
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
buf := make([]byte, 0, 64*1024)
|
||||||
|
scanner.Buffer(buf, 1024*1024)
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Bytes()
|
||||||
|
|
||||||
|
// 记录首 token 时间
|
||||||
|
if !firstTokenRecorded && len(line) > 0 {
|
||||||
|
ms := int(time.Since(startTime).Milliseconds())
|
||||||
|
firstTokenMs = &ms
|
||||||
|
firstTokenRecorded = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试从 message_delta 或 message_stop 事件提取 usage
|
||||||
|
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||||
|
dataStr := bytes.TrimPrefix(line, []byte("data: "))
|
||||||
|
var event map[string]any
|
||||||
|
if json.Unmarshal(dataStr, &event) == nil {
|
||||||
|
if u, ok := event["usage"].(map[string]any); ok {
|
||||||
|
if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 {
|
||||||
|
usage.InputTokens = int(v)
|
||||||
|
}
|
||||||
|
if v, ok := u["output_tokens"].(float64); ok && int(v) > 0 {
|
||||||
|
usage.OutputTokens = int(v)
|
||||||
|
}
|
||||||
|
if v, ok := u["cache_read_input_tokens"].(float64); ok && int(v) > 0 {
|
||||||
|
usage.CacheReadInputTokens = int(v)
|
||||||
|
}
|
||||||
|
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
|
||||||
|
usage.CacheCreationInputTokens = int(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 透传行
|
||||||
|
_, _ = c.Writer.Write(line)
|
||||||
|
_, _ = c.Writer.Write([]byte("\n"))
|
||||||
|
c.Writer.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
return usage, firstTokenMs
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractClaudeUsage 从非流式 Claude 响应提取 usage
|
||||||
|
func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage {
|
||||||
|
usage := &ClaudeUsage{}
|
||||||
|
var resp map[string]any
|
||||||
|
if json.Unmarshal(body, &resp) != nil {
|
||||||
|
return usage
|
||||||
|
}
|
||||||
|
if u, ok := resp["usage"].(map[string]any); ok {
|
||||||
|
if v, ok := u["input_tokens"].(float64); ok {
|
||||||
|
usage.InputTokens = int(v)
|
||||||
|
}
|
||||||
|
if v, ok := u["output_tokens"].(float64); ok {
|
||||||
|
usage.OutputTokens = int(v)
|
||||||
|
}
|
||||||
|
if v, ok := u["cache_read_input_tokens"].(float64); ok {
|
||||||
|
usage.CacheReadInputTokens = int(v)
|
||||||
|
}
|
||||||
|
if v, ok := u["cache_creation_input_tokens"].(float64); ok {
|
||||||
|
usage.CacheCreationInputTokens = int(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return usage
|
||||||
|
}
|
||||||
|
|
||||||
// ForwardGemini 转发 Gemini 协议请求
|
// ForwardGemini 转发 Gemini 协议请求
|
||||||
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
|
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
@@ -1280,6 +1694,12 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
|||||||
}
|
}
|
||||||
|
|
||||||
mappedModel := s.getMappedModel(account, originalModel)
|
mappedModel := s.getMappedModel(account, originalModel)
|
||||||
|
billingModel := originalModel
|
||||||
|
if antigravityUseMappedModelForBilling() && strings.TrimSpace(mappedModel) != "" {
|
||||||
|
billingModel = mappedModel
|
||||||
|
}
|
||||||
|
afterSwitch := antigravityHasAccountSwitch(ctx)
|
||||||
|
maxRetries := antigravityMaxRetriesForModel(originalModel, afterSwitch)
|
||||||
|
|
||||||
// 获取 access_token
|
// 获取 access_token
|
||||||
if s.tokenProvider == nil {
|
if s.tokenProvider == nil {
|
||||||
@@ -1299,8 +1719,15 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
|||||||
proxyURL = account.Proxy.URL()
|
proxyURL = account.Proxy.URL()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 过滤掉 parts 为空的消息(Gemini API 不接受空 parts)
|
||||||
|
filteredBody, err := filterEmptyPartsFromGeminiRequest(body)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Antigravity] Failed to filter empty parts: %v", err)
|
||||||
|
filteredBody = body
|
||||||
|
}
|
||||||
|
|
||||||
// Antigravity 上游要求必须包含身份提示词,注入到请求中
|
// Antigravity 上游要求必须包含身份提示词,注入到请求中
|
||||||
injectedBody, err := injectIdentityPatchToGeminiRequest(body)
|
injectedBody, err := injectIdentityPatchToGeminiRequest(filteredBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1337,6 +1764,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
|||||||
httpUpstream: s.httpUpstream,
|
httpUpstream: s.httpUpstream,
|
||||||
settingService: s.settingService,
|
settingService: s.settingService,
|
||||||
handleError: s.handleUpstreamError,
|
handleError: s.handleUpstreamError,
|
||||||
|
maxRetries: maxRetries,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
|
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
|
||||||
@@ -1486,7 +1914,7 @@ handleSuccess:
|
|||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Usage: *usage,
|
Usage: *usage,
|
||||||
Model: originalModel,
|
Model: billingModel,
|
||||||
Stream: stream,
|
Stream: stream,
|
||||||
Duration: time.Since(startTime),
|
Duration: time.Since(startTime),
|
||||||
FirstTokenMs: firstTokenMs,
|
FirstTokenMs: firstTokenMs,
|
||||||
@@ -1530,9 +1958,88 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
|
|||||||
|
|
||||||
func antigravityUseScopeRateLimit() bool {
|
func antigravityUseScopeRateLimit() bool {
|
||||||
v := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityScopeRateLimitEnv)))
|
v := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityScopeRateLimitEnv)))
|
||||||
|
// 默认开启按配额域限流,只有明确设置为禁用值时才关闭
|
||||||
|
if v == "0" || v == "false" || v == "no" || v == "off" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func antigravityHasAccountSwitch(ctx context.Context) bool {
|
||||||
|
if ctx == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if v, ok := ctx.Value(ctxkey.AccountSwitchCount).(int); ok {
|
||||||
|
return v > 0
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func antigravityMaxRetries() int {
|
||||||
|
raw := strings.TrimSpace(os.Getenv(antigravityMaxRetriesEnv))
|
||||||
|
if raw == "" {
|
||||||
|
return antigravityDefaultMaxRetries
|
||||||
|
}
|
||||||
|
value, err := strconv.Atoi(raw)
|
||||||
|
if err != nil || value <= 0 {
|
||||||
|
return antigravityDefaultMaxRetries
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
func antigravityMaxRetriesAfterSwitch() int {
|
||||||
|
raw := strings.TrimSpace(os.Getenv(antigravityMaxRetriesAfterSwitchEnv))
|
||||||
|
if raw == "" {
|
||||||
|
return antigravityMaxRetries()
|
||||||
|
}
|
||||||
|
value, err := strconv.Atoi(raw)
|
||||||
|
if err != nil || value <= 0 {
|
||||||
|
return antigravityMaxRetries()
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
// antigravityMaxRetriesForModel 根据模型类型获取重试次数
|
||||||
|
// 优先使用模型细分配置,未设置则回退到平台级配置
|
||||||
|
func antigravityMaxRetriesForModel(model string, afterSwitch bool) int {
|
||||||
|
var envKey string
|
||||||
|
if strings.HasPrefix(model, "claude-") {
|
||||||
|
envKey = antigravityMaxRetriesClaudeEnv
|
||||||
|
} else if isImageGenerationModel(model) {
|
||||||
|
envKey = antigravityMaxRetriesGeminiImageEnv
|
||||||
|
} else if strings.HasPrefix(model, "gemini-") {
|
||||||
|
envKey = antigravityMaxRetriesGeminiTextEnv
|
||||||
|
}
|
||||||
|
|
||||||
|
if envKey != "" {
|
||||||
|
if raw := strings.TrimSpace(os.Getenv(envKey)); raw != "" {
|
||||||
|
if value, err := strconv.Atoi(raw); err == nil && value > 0 {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if afterSwitch {
|
||||||
|
return antigravityMaxRetriesAfterSwitch()
|
||||||
|
}
|
||||||
|
return antigravityMaxRetries()
|
||||||
|
}
|
||||||
|
|
||||||
|
func antigravityUseMappedModelForBilling() bool {
|
||||||
|
v := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityBillingModelEnv)))
|
||||||
return v == "1" || v == "true" || v == "yes" || v == "on"
|
return v == "1" || v == "true" || v == "yes" || v == "on"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func antigravityFallbackCooldownSeconds() (time.Duration, bool) {
|
||||||
|
raw := strings.TrimSpace(os.Getenv(antigravityFallbackSecondsEnv))
|
||||||
|
if raw == "" {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
seconds, err := strconv.Atoi(raw)
|
||||||
|
if err != nil || seconds <= 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return time.Duration(seconds) * time.Second, true
|
||||||
|
}
|
||||||
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
|
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
|
||||||
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
|
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
|
||||||
if statusCode == 429 {
|
if statusCode == 429 {
|
||||||
@@ -1545,6 +2052,9 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
|
|||||||
fallbackMinutes = s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes
|
fallbackMinutes = s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes
|
||||||
}
|
}
|
||||||
defaultDur := time.Duration(fallbackMinutes) * time.Minute
|
defaultDur := time.Duration(fallbackMinutes) * time.Minute
|
||||||
|
if fallbackDur, ok := antigravityFallbackCooldownSeconds(); ok {
|
||||||
|
defaultDur = fallbackDur
|
||||||
|
}
|
||||||
ra := time.Now().Add(defaultDur)
|
ra := time.Now().Add(defaultDur)
|
||||||
if useScopeLimit {
|
if useScopeLimit {
|
||||||
log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur)
|
log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur)
|
||||||
@@ -2182,6 +2692,10 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
|
|||||||
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
|
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *AntigravityGatewayService) WriteMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
|
||||||
|
return s.writeMappedClaudeError(c, account, upstreamStatus, upstreamRequestID, body)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error {
|
func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error {
|
||||||
statusStr := "UNKNOWN"
|
statusStr := "UNKNOWN"
|
||||||
switch status {
|
switch status {
|
||||||
@@ -2607,3 +3121,55 @@ func cleanGeminiRequest(body []byte) ([]byte, error) {
|
|||||||
|
|
||||||
return json.Marshal(payload)
|
return json.Marshal(payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// filterEmptyPartsFromGeminiRequest 过滤 Gemini 请求中 parts 为空的消息
|
||||||
|
// Gemini API 不接受 parts 为空数组的消息,会返回 400 错误
|
||||||
|
func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) {
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal(body, &payload); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
contents, ok := payload["contents"].([]any)
|
||||||
|
if !ok || len(contents) == 0 {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := make([]any, 0, len(contents))
|
||||||
|
modified := false
|
||||||
|
|
||||||
|
for _, c := range contents {
|
||||||
|
contentMap, ok := c.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
filtered = append(filtered, c)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
parts, hasParts := contentMap["parts"]
|
||||||
|
if !hasParts {
|
||||||
|
filtered = append(filtered, c)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
partsSlice, ok := parts.([]any)
|
||||||
|
if !ok {
|
||||||
|
filtered = append(filtered, c)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 跳过 parts 为空数组的消息
|
||||||
|
if len(partsSlice) == 0 {
|
||||||
|
modified = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered = append(filtered, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !modified {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
payload["contents"] = filtered
|
||||||
|
return json.Marshal(payload)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,10 +1,16 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -81,3 +87,106 @@ func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) {
|
|||||||
require.Equal(t, "secret plan", blocks[0]["text"])
|
require.Equal(t, "secret plan", blocks[0]["text"])
|
||||||
require.Equal(t, "tool_use", blocks[1]["type"])
|
require.Equal(t, "tool_use", blocks[1]["type"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsPromptTooLongError(t *testing.T) {
|
||||||
|
require.True(t, isPromptTooLongError([]byte(`{"error":{"message":"Prompt is too long"}}`)))
|
||||||
|
require.True(t, isPromptTooLongError([]byte(`{"message":"Prompt is too long"}`)))
|
||||||
|
require.False(t, isPromptTooLongError([]byte(`{"error":{"message":"other"}}`)))
|
||||||
|
}
|
||||||
|
|
||||||
|
type httpUpstreamStub struct {
|
||||||
|
resp *http.Response
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
||||||
|
return s.resp, s.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) {
|
||||||
|
return s.resp, s.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
writer := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(writer)
|
||||||
|
|
||||||
|
body, err := json.Marshal(map[string]any{
|
||||||
|
"model": "claude-opus-4-5",
|
||||||
|
"messages": []map[string]any{
|
||||||
|
{"role": "user", "content": "hi"},
|
||||||
|
},
|
||||||
|
"max_tokens": 1,
|
||||||
|
"stream": false,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
respBody := []byte(`{"error":{"message":"Prompt is too long"}}`)
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Header: http.Header{"X-Request-Id": []string{"req-1"}},
|
||||||
|
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{
|
||||||
|
tokenProvider: &AntigravityTokenProvider{},
|
||||||
|
httpUpstream: &httpUpstreamStub{resp: resp},
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "acc-1",
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.Forward(context.Background(), c, account, body)
|
||||||
|
require.Nil(t, result)
|
||||||
|
|
||||||
|
var promptErr *PromptTooLongError
|
||||||
|
require.ErrorAs(t, err, &promptErr)
|
||||||
|
require.Equal(t, http.StatusBadRequest, promptErr.StatusCode)
|
||||||
|
require.Equal(t, "req-1", promptErr.RequestID)
|
||||||
|
require.NotEmpty(t, promptErr.Body)
|
||||||
|
|
||||||
|
raw, ok := c.Get(OpsUpstreamErrorsKey)
|
||||||
|
require.True(t, ok)
|
||||||
|
events, ok := raw.([]*OpsUpstreamErrorEvent)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, events, 1)
|
||||||
|
require.Equal(t, "prompt_too_long", events[0].Kind)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAntigravityMaxRetriesForModel_AfterSwitch(t *testing.T) {
|
||||||
|
t.Setenv(antigravityMaxRetriesEnv, "4")
|
||||||
|
t.Setenv(antigravityMaxRetriesAfterSwitchEnv, "7")
|
||||||
|
t.Setenv(antigravityMaxRetriesClaudeEnv, "")
|
||||||
|
t.Setenv(antigravityMaxRetriesGeminiTextEnv, "")
|
||||||
|
t.Setenv(antigravityMaxRetriesGeminiImageEnv, "")
|
||||||
|
|
||||||
|
got := antigravityMaxRetriesForModel("claude-sonnet-4-5", false)
|
||||||
|
require.Equal(t, 4, got)
|
||||||
|
|
||||||
|
got = antigravityMaxRetriesForModel("claude-sonnet-4-5", true)
|
||||||
|
require.Equal(t, 7, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAntigravityMaxRetriesForModel_AfterSwitchFallback(t *testing.T) {
|
||||||
|
t.Setenv(antigravityMaxRetriesEnv, "5")
|
||||||
|
t.Setenv(antigravityMaxRetriesAfterSwitchEnv, "")
|
||||||
|
t.Setenv(antigravityMaxRetriesClaudeEnv, "")
|
||||||
|
t.Setenv(antigravityMaxRetriesGeminiTextEnv, "")
|
||||||
|
t.Setenv(antigravityMaxRetriesGeminiImageEnv, "")
|
||||||
|
|
||||||
|
got := antigravityMaxRetriesForModel("gemini-2.5-flash", true)
|
||||||
|
require.Equal(t, 5, got)
|
||||||
|
}
|
||||||
|
|||||||
@@ -134,18 +134,18 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
|
|||||||
expected: "claude-sonnet-4-5",
|
expected: "claude-sonnet-4-5",
|
||||||
},
|
},
|
||||||
|
|
||||||
// 3. Gemini 透传
|
// 3. Gemini 2.5 → 3 映射
|
||||||
{
|
{
|
||||||
name: "Gemini透传 - gemini-2.5-flash",
|
name: "Gemini映射 - gemini-2.5-flash → gemini-3-flash",
|
||||||
requestedModel: "gemini-2.5-flash",
|
requestedModel: "gemini-2.5-flash",
|
||||||
accountMapping: nil,
|
accountMapping: nil,
|
||||||
expected: "gemini-2.5-flash",
|
expected: "gemini-3-flash",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Gemini透传 - gemini-2.5-pro",
|
name: "Gemini映射 - gemini-2.5-pro → gemini-3-pro-high",
|
||||||
requestedModel: "gemini-2.5-pro",
|
requestedModel: "gemini-2.5-pro",
|
||||||
accountMapping: nil,
|
accountMapping: nil,
|
||||||
expected: "gemini-2.5-pro",
|
expected: "gemini-3-pro-high",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Gemini透传 - gemini-future-model",
|
name: "Gemini透传 - gemini-future-model",
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -16,6 +17,21 @@ const (
|
|||||||
AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image"
|
AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// IsScopeSupported 检查给定的 scope 是否在分组支持的 scope 列表中
|
||||||
|
func IsScopeSupported(supportedScopes []string, scope AntigravityQuotaScope) bool {
|
||||||
|
if len(supportedScopes) == 0 {
|
||||||
|
// 未配置时默认全部支持
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
supported := slices.Contains(supportedScopes, string(scope))
|
||||||
|
return supported
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveAntigravityQuotaScope 根据模型名称解析配额域(导出版本)
|
||||||
|
func ResolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
|
||||||
|
return resolveAntigravityQuotaScope(requestedModel)
|
||||||
|
}
|
||||||
|
|
||||||
// resolveAntigravityQuotaScope 根据模型名称解析配额域
|
// resolveAntigravityQuotaScope 根据模型名称解析配额域
|
||||||
func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
|
func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
|
||||||
model := normalizeAntigravityModelName(requestedModel)
|
model := normalizeAntigravityModelName(requestedModel)
|
||||||
@@ -89,3 +105,30 @@ func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *tim
|
|||||||
}
|
}
|
||||||
return &resetAt
|
return &resetAt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var antigravityAllScopes = []AntigravityQuotaScope{
|
||||||
|
AntigravityQuotaScopeClaude,
|
||||||
|
AntigravityQuotaScopeGeminiText,
|
||||||
|
AntigravityQuotaScopeGeminiImage,
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) GetAntigravityScopeRateLimits() map[string]int64 {
|
||||||
|
if a == nil || a.Platform != PlatformAntigravity {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
result := make(map[string]int64)
|
||||||
|
for _, scope := range antigravityAllScopes {
|
||||||
|
resetAt := a.antigravityQuotaScopeResetAt(scope)
|
||||||
|
if resetAt != nil && now.Before(*resetAt) {
|
||||||
|
remainingSec := int64(time.Until(*resetAt).Seconds())
|
||||||
|
if remainingSec > 0 {
|
||||||
|
result[string(scope)] = remainingSec
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(result) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,14 @@ package service
|
|||||||
|
|
||||||
import "time"
|
import "time"
|
||||||
|
|
||||||
|
// API Key status constants
|
||||||
|
const (
|
||||||
|
StatusAPIKeyActive = "active"
|
||||||
|
StatusAPIKeyDisabled = "disabled"
|
||||||
|
StatusAPIKeyQuotaExhausted = "quota_exhausted"
|
||||||
|
StatusAPIKeyExpired = "expired"
|
||||||
|
)
|
||||||
|
|
||||||
type APIKey struct {
|
type APIKey struct {
|
||||||
ID int64
|
ID int64
|
||||||
UserID int64
|
UserID int64
|
||||||
@@ -15,8 +23,53 @@ type APIKey struct {
|
|||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
User *User
|
User *User
|
||||||
Group *Group
|
Group *Group
|
||||||
|
|
||||||
|
// Quota fields
|
||||||
|
Quota float64 // Quota limit in USD (0 = unlimited)
|
||||||
|
QuotaUsed float64 // Used quota amount
|
||||||
|
ExpiresAt *time.Time // Expiration time (nil = never expires)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *APIKey) IsActive() bool {
|
func (k *APIKey) IsActive() bool {
|
||||||
return k.Status == StatusActive
|
return k.Status == StatusActive
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsExpired checks if the API key has expired
|
||||||
|
func (k *APIKey) IsExpired() bool {
|
||||||
|
if k.ExpiresAt == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return time.Now().After(*k.ExpiresAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsQuotaExhausted checks if the API key quota is exhausted
|
||||||
|
func (k *APIKey) IsQuotaExhausted() bool {
|
||||||
|
if k.Quota <= 0 {
|
||||||
|
return false // unlimited
|
||||||
|
}
|
||||||
|
return k.QuotaUsed >= k.Quota
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQuotaRemaining returns remaining quota (-1 for unlimited)
|
||||||
|
func (k *APIKey) GetQuotaRemaining() float64 {
|
||||||
|
if k.Quota <= 0 {
|
||||||
|
return -1 // unlimited
|
||||||
|
}
|
||||||
|
remaining := k.Quota - k.QuotaUsed
|
||||||
|
if remaining < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return remaining
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDaysUntilExpiry returns days until expiry (-1 for never expires)
|
||||||
|
func (k *APIKey) GetDaysUntilExpiry() int {
|
||||||
|
if k.ExpiresAt == nil {
|
||||||
|
return -1 // never expires
|
||||||
|
}
|
||||||
|
duration := time.Until(*k.ExpiresAt)
|
||||||
|
if duration < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return int(duration.Hours() / 24)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
|
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
|
||||||
type APIKeyAuthSnapshot struct {
|
type APIKeyAuthSnapshot struct {
|
||||||
APIKeyID int64 `json:"api_key_id"`
|
APIKeyID int64 `json:"api_key_id"`
|
||||||
@@ -10,6 +12,13 @@ type APIKeyAuthSnapshot struct {
|
|||||||
IPBlacklist []string `json:"ip_blacklist,omitempty"`
|
IPBlacklist []string `json:"ip_blacklist,omitempty"`
|
||||||
User APIKeyAuthUserSnapshot `json:"user"`
|
User APIKeyAuthUserSnapshot `json:"user"`
|
||||||
Group *APIKeyAuthGroupSnapshot `json:"group,omitempty"`
|
Group *APIKeyAuthGroupSnapshot `json:"group,omitempty"`
|
||||||
|
|
||||||
|
// Quota fields for API Key independent quota feature
|
||||||
|
Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited)
|
||||||
|
QuotaUsed float64 `json:"quota_used"` // Used quota amount
|
||||||
|
|
||||||
|
// Expiration field for API Key expiration feature
|
||||||
|
ExpiresAt *time.Time `json:"expires_at,omitempty"` // Expiration time (nil = never expires)
|
||||||
}
|
}
|
||||||
|
|
||||||
// APIKeyAuthUserSnapshot 用户快照
|
// APIKeyAuthUserSnapshot 用户快照
|
||||||
@@ -23,25 +32,30 @@ type APIKeyAuthUserSnapshot struct {
|
|||||||
|
|
||||||
// APIKeyAuthGroupSnapshot 分组快照
|
// APIKeyAuthGroupSnapshot 分组快照
|
||||||
type APIKeyAuthGroupSnapshot struct {
|
type APIKeyAuthGroupSnapshot struct {
|
||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Platform string `json:"platform"`
|
Platform string `json:"platform"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
SubscriptionType string `json:"subscription_type"`
|
SubscriptionType string `json:"subscription_type"`
|
||||||
RateMultiplier float64 `json:"rate_multiplier"`
|
RateMultiplier float64 `json:"rate_multiplier"`
|
||||||
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
|
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
|
||||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
|
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
|
||||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
|
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
|
||||||
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
||||||
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
||||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||||
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
|
||||||
|
|
||||||
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
|
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
|
||||||
// Only anthropic groups use these fields; others may leave them empty.
|
// Only anthropic groups use these fields; others may leave them empty.
|
||||||
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
||||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||||
|
MCPXMLInject bool `json:"mcp_xml_inject"`
|
||||||
|
|
||||||
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
|
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
|
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
|
||||||
|
|||||||
@@ -213,6 +213,9 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
|||||||
Status: apiKey.Status,
|
Status: apiKey.Status,
|
||||||
IPWhitelist: apiKey.IPWhitelist,
|
IPWhitelist: apiKey.IPWhitelist,
|
||||||
IPBlacklist: apiKey.IPBlacklist,
|
IPBlacklist: apiKey.IPBlacklist,
|
||||||
|
Quota: apiKey.Quota,
|
||||||
|
QuotaUsed: apiKey.QuotaUsed,
|
||||||
|
ExpiresAt: apiKey.ExpiresAt,
|
||||||
User: APIKeyAuthUserSnapshot{
|
User: APIKeyAuthUserSnapshot{
|
||||||
ID: apiKey.User.ID,
|
ID: apiKey.User.ID,
|
||||||
Status: apiKey.User.Status,
|
Status: apiKey.User.Status,
|
||||||
@@ -223,22 +226,25 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
|||||||
}
|
}
|
||||||
if apiKey.Group != nil {
|
if apiKey.Group != nil {
|
||||||
snapshot.Group = &APIKeyAuthGroupSnapshot{
|
snapshot.Group = &APIKeyAuthGroupSnapshot{
|
||||||
ID: apiKey.Group.ID,
|
ID: apiKey.Group.ID,
|
||||||
Name: apiKey.Group.Name,
|
Name: apiKey.Group.Name,
|
||||||
Platform: apiKey.Group.Platform,
|
Platform: apiKey.Group.Platform,
|
||||||
Status: apiKey.Group.Status,
|
Status: apiKey.Group.Status,
|
||||||
SubscriptionType: apiKey.Group.SubscriptionType,
|
SubscriptionType: apiKey.Group.SubscriptionType,
|
||||||
RateMultiplier: apiKey.Group.RateMultiplier,
|
RateMultiplier: apiKey.Group.RateMultiplier,
|
||||||
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
|
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
|
||||||
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
|
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
|
||||||
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
|
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
|
||||||
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
||||||
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
||||||
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
||||||
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
||||||
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
||||||
ModelRouting: apiKey.Group.ModelRouting,
|
FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest,
|
||||||
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
ModelRouting: apiKey.Group.ModelRouting,
|
||||||
|
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
||||||
|
MCPXMLInject: apiKey.Group.MCPXMLInject,
|
||||||
|
SupportedModelScopes: apiKey.Group.SupportedModelScopes,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return snapshot
|
return snapshot
|
||||||
@@ -256,6 +262,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
|||||||
Status: snapshot.Status,
|
Status: snapshot.Status,
|
||||||
IPWhitelist: snapshot.IPWhitelist,
|
IPWhitelist: snapshot.IPWhitelist,
|
||||||
IPBlacklist: snapshot.IPBlacklist,
|
IPBlacklist: snapshot.IPBlacklist,
|
||||||
|
Quota: snapshot.Quota,
|
||||||
|
QuotaUsed: snapshot.QuotaUsed,
|
||||||
|
ExpiresAt: snapshot.ExpiresAt,
|
||||||
User: &User{
|
User: &User{
|
||||||
ID: snapshot.User.ID,
|
ID: snapshot.User.ID,
|
||||||
Status: snapshot.User.Status,
|
Status: snapshot.User.Status,
|
||||||
@@ -266,23 +275,26 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
|||||||
}
|
}
|
||||||
if snapshot.Group != nil {
|
if snapshot.Group != nil {
|
||||||
apiKey.Group = &Group{
|
apiKey.Group = &Group{
|
||||||
ID: snapshot.Group.ID,
|
ID: snapshot.Group.ID,
|
||||||
Name: snapshot.Group.Name,
|
Name: snapshot.Group.Name,
|
||||||
Platform: snapshot.Group.Platform,
|
Platform: snapshot.Group.Platform,
|
||||||
Status: snapshot.Group.Status,
|
Status: snapshot.Group.Status,
|
||||||
Hydrated: true,
|
Hydrated: true,
|
||||||
SubscriptionType: snapshot.Group.SubscriptionType,
|
SubscriptionType: snapshot.Group.SubscriptionType,
|
||||||
RateMultiplier: snapshot.Group.RateMultiplier,
|
RateMultiplier: snapshot.Group.RateMultiplier,
|
||||||
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
|
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
|
||||||
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
|
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
|
||||||
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
|
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
|
||||||
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
||||||
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
||||||
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
||||||
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
||||||
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
||||||
ModelRouting: snapshot.Group.ModelRouting,
|
FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest,
|
||||||
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
ModelRouting: snapshot.Group.ModelRouting,
|
||||||
|
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
||||||
|
MCPXMLInject: snapshot.Group.MCPXMLInject,
|
||||||
|
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return apiKey
|
return apiKey
|
||||||
|
|||||||
@@ -24,6 +24,10 @@ var (
|
|||||||
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
|
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
|
||||||
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
|
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
|
||||||
ErrInvalidIPPattern = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern")
|
ErrInvalidIPPattern = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern")
|
||||||
|
// ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key has expired")
|
||||||
|
ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key 已过期")
|
||||||
|
// ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key quota exhausted")
|
||||||
|
ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key 额度已用完")
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -51,6 +55,9 @@ type APIKeyRepository interface {
|
|||||||
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
|
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||||
ListKeysByUserID(ctx context.Context, userID int64) ([]string, error)
|
ListKeysByUserID(ctx context.Context, userID int64) ([]string, error)
|
||||||
ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error)
|
ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error)
|
||||||
|
|
||||||
|
// Quota methods
|
||||||
|
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// APIKeyCache defines cache operations for API key service
|
// APIKeyCache defines cache operations for API key service
|
||||||
@@ -85,6 +92,10 @@ type CreateAPIKeyRequest struct {
|
|||||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||||
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
|
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
|
||||||
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
|
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
|
||||||
|
|
||||||
|
// Quota fields
|
||||||
|
Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited)
|
||||||
|
ExpiresInDays *int `json:"expires_in_days"` // Days until expiry (nil = never expires)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAPIKeyRequest 更新API Key请求
|
// UpdateAPIKeyRequest 更新API Key请求
|
||||||
@@ -94,6 +105,12 @@ type UpdateAPIKeyRequest struct {
|
|||||||
Status *string `json:"status"`
|
Status *string `json:"status"`
|
||||||
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单(空数组清空)
|
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单(空数组清空)
|
||||||
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单(空数组清空)
|
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单(空数组清空)
|
||||||
|
|
||||||
|
// Quota fields
|
||||||
|
Quota *float64 `json:"quota"` // Quota limit in USD (nil = no change, 0 = unlimited)
|
||||||
|
ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = no change)
|
||||||
|
ClearExpiration bool `json:"-"` // Clear expiration (internal use)
|
||||||
|
ResetQuota *bool `json:"reset_quota"` // Reset quota_used to 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// APIKeyService API Key服务
|
// APIKeyService API Key服务
|
||||||
@@ -289,6 +306,14 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
|
|||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
IPWhitelist: req.IPWhitelist,
|
IPWhitelist: req.IPWhitelist,
|
||||||
IPBlacklist: req.IPBlacklist,
|
IPBlacklist: req.IPBlacklist,
|
||||||
|
Quota: req.Quota,
|
||||||
|
QuotaUsed: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set expiration time if specified
|
||||||
|
if req.ExpiresInDays != nil && *req.ExpiresInDays > 0 {
|
||||||
|
expiresAt := time.Now().AddDate(0, 0, *req.ExpiresInDays)
|
||||||
|
apiKey.ExpiresAt = &expiresAt
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
|
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
|
||||||
@@ -436,6 +461,35 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update quota fields
|
||||||
|
if req.Quota != nil {
|
||||||
|
apiKey.Quota = *req.Quota
|
||||||
|
// If quota is increased and status was quota_exhausted, reactivate
|
||||||
|
if apiKey.Status == StatusAPIKeyQuotaExhausted && *req.Quota > apiKey.QuotaUsed {
|
||||||
|
apiKey.Status = StatusActive
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if req.ResetQuota != nil && *req.ResetQuota {
|
||||||
|
apiKey.QuotaUsed = 0
|
||||||
|
// If resetting quota and status was quota_exhausted, reactivate
|
||||||
|
if apiKey.Status == StatusAPIKeyQuotaExhausted {
|
||||||
|
apiKey.Status = StatusActive
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if req.ClearExpiration {
|
||||||
|
apiKey.ExpiresAt = nil
|
||||||
|
// If clearing expiry and status was expired, reactivate
|
||||||
|
if apiKey.Status == StatusAPIKeyExpired {
|
||||||
|
apiKey.Status = StatusActive
|
||||||
|
}
|
||||||
|
} else if req.ExpiresAt != nil {
|
||||||
|
apiKey.ExpiresAt = req.ExpiresAt
|
||||||
|
// If extending expiry and status was expired, reactivate
|
||||||
|
if apiKey.Status == StatusAPIKeyExpired && time.Now().Before(*req.ExpiresAt) {
|
||||||
|
apiKey.Status = StatusActive
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 更新 IP 限制(空数组会清空设置)
|
// 更新 IP 限制(空数组会清空设置)
|
||||||
apiKey.IPWhitelist = req.IPWhitelist
|
apiKey.IPWhitelist = req.IPWhitelist
|
||||||
apiKey.IPBlacklist = req.IPBlacklist
|
apiKey.IPBlacklist = req.IPBlacklist
|
||||||
@@ -572,3 +626,51 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword
|
|||||||
}
|
}
|
||||||
return keys, nil
|
return keys, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted)
|
||||||
|
// Returns nil if valid, error if invalid
|
||||||
|
func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error {
|
||||||
|
// Check expiration
|
||||||
|
if apiKey.IsExpired() {
|
||||||
|
return ErrAPIKeyExpired
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check quota
|
||||||
|
if apiKey.IsQuotaExhausted() {
|
||||||
|
return ErrAPIKeyQuotaExhausted
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateQuotaUsed updates the quota_used field after a request
|
||||||
|
// Also checks if quota is exhausted and updates status accordingly
|
||||||
|
func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error {
|
||||||
|
if cost <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use repository to atomically increment quota_used
|
||||||
|
newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("increment quota used: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if quota is now exhausted and update status if needed
|
||||||
|
apiKey, err := s.apiKeyRepo.GetByID(ctx, apiKeyID)
|
||||||
|
if err != nil {
|
||||||
|
return nil // Don't fail the request, just log
|
||||||
|
}
|
||||||
|
|
||||||
|
// If quota is set and now exhausted, update status
|
||||||
|
if apiKey.Quota > 0 && newQuotaUsed >= apiKey.Quota {
|
||||||
|
apiKey.Status = StatusAPIKeyQuotaExhausted
|
||||||
|
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
|
||||||
|
return nil // Don't fail the request
|
||||||
|
}
|
||||||
|
// Invalidate cache so next request sees the new status
|
||||||
|
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -99,6 +99,10 @@ func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]
|
|||||||
return s.listKeysByGroupID(ctx, groupID)
|
return s.listKeysByGroupID(ctx, groupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
|
||||||
|
panic("unexpected IncrementQuotaUsed call")
|
||||||
|
}
|
||||||
|
|
||||||
type authCacheStub struct {
|
type authCacheStub struct {
|
||||||
getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
|
getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
|
||||||
setAuthKeys []string
|
setAuthKeys []string
|
||||||
|
|||||||
@@ -118,6 +118,10 @@ func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) (
|
|||||||
panic("unexpected ListKeysByGroupID call")
|
panic("unexpected ListKeysByGroupID call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
|
||||||
|
panic("unexpected IncrementQuotaUsed call")
|
||||||
|
}
|
||||||
|
|
||||||
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
|
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
|
||||||
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
|
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
|
||||||
//
|
//
|
||||||
|
|||||||
@@ -19,17 +19,19 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
|
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
|
||||||
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
|
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
|
||||||
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
||||||
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
|
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
|
||||||
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
||||||
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
||||||
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
||||||
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
|
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
|
||||||
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
||||||
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||||
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
||||||
|
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
|
||||||
|
ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code")
|
||||||
)
|
)
|
||||||
|
|
||||||
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
|
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
|
||||||
@@ -47,6 +49,7 @@ type JWTClaims struct {
|
|||||||
// AuthService 认证服务
|
// AuthService 认证服务
|
||||||
type AuthService struct {
|
type AuthService struct {
|
||||||
userRepo UserRepository
|
userRepo UserRepository
|
||||||
|
redeemRepo RedeemCodeRepository
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
settingService *SettingService
|
settingService *SettingService
|
||||||
emailService *EmailService
|
emailService *EmailService
|
||||||
@@ -58,6 +61,7 @@ type AuthService struct {
|
|||||||
// NewAuthService 创建认证服务实例
|
// NewAuthService 创建认证服务实例
|
||||||
func NewAuthService(
|
func NewAuthService(
|
||||||
userRepo UserRepository,
|
userRepo UserRepository,
|
||||||
|
redeemRepo RedeemCodeRepository,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
settingService *SettingService,
|
settingService *SettingService,
|
||||||
emailService *EmailService,
|
emailService *EmailService,
|
||||||
@@ -67,6 +71,7 @@ func NewAuthService(
|
|||||||
) *AuthService {
|
) *AuthService {
|
||||||
return &AuthService{
|
return &AuthService{
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
|
redeemRepo: redeemRepo,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
settingService: settingService,
|
settingService: settingService,
|
||||||
emailService: emailService,
|
emailService: emailService,
|
||||||
@@ -78,11 +83,11 @@ func NewAuthService(
|
|||||||
|
|
||||||
// Register 用户注册,返回token和用户
|
// Register 用户注册,返回token和用户
|
||||||
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
|
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
|
||||||
return s.RegisterWithVerification(ctx, email, password, "", "")
|
return s.RegisterWithVerification(ctx, email, password, "", "", "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterWithVerification 用户注册(支持邮件验证和优惠码),返回token和用户
|
// RegisterWithVerification 用户注册(支持邮件验证、优惠码和邀请码),返回token和用户
|
||||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode string) (string, *User, error) {
|
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string) (string, *User, error) {
|
||||||
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
|
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
|
||||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||||
return "", nil, ErrRegDisabled
|
return "", nil, ErrRegDisabled
|
||||||
@@ -93,6 +98,26 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
|||||||
return "", nil, ErrEmailReserved
|
return "", nil, ErrEmailReserved
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查是否需要邀请码
|
||||||
|
var invitationRedeemCode *RedeemCode
|
||||||
|
if s.settingService != nil && s.settingService.IsInvitationCodeEnabled(ctx) {
|
||||||
|
if invitationCode == "" {
|
||||||
|
return "", nil, ErrInvitationCodeRequired
|
||||||
|
}
|
||||||
|
// 验证邀请码
|
||||||
|
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Auth] Invalid invitation code: %s, error: %v", invitationCode, err)
|
||||||
|
return "", nil, ErrInvitationCodeInvalid
|
||||||
|
}
|
||||||
|
// 检查类型和状态
|
||||||
|
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
|
||||||
|
log.Printf("[Auth] Invitation code invalid: type=%s, status=%s", redeemCode.Type, redeemCode.Status)
|
||||||
|
return "", nil, ErrInvitationCodeInvalid
|
||||||
|
}
|
||||||
|
invitationRedeemCode = redeemCode
|
||||||
|
}
|
||||||
|
|
||||||
// 检查是否需要邮件验证
|
// 检查是否需要邮件验证
|
||||||
if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) {
|
if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) {
|
||||||
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
|
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
|
||||||
@@ -153,6 +178,13 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
|||||||
return "", nil, ErrServiceUnavailable
|
return "", nil, ErrServiceUnavailable
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 标记邀请码为已使用(如果使用了邀请码)
|
||||||
|
if invitationRedeemCode != nil {
|
||||||
|
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
||||||
|
// 邀请码标记失败不影响注册,只记录日志
|
||||||
|
log.Printf("[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
// 应用优惠码(如果提供且功能已启用)
|
// 应用优惠码(如果提供且功能已启用)
|
||||||
if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) {
|
if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) {
|
||||||
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
|
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
|
||||||
|
|||||||
@@ -115,6 +115,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
|
|||||||
|
|
||||||
return NewAuthService(
|
return NewAuthService(
|
||||||
repo,
|
repo,
|
||||||
|
nil, // redeemRepo
|
||||||
cfg,
|
cfg,
|
||||||
settingService,
|
settingService,
|
||||||
emailService,
|
emailService,
|
||||||
@@ -152,7 +153,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
|
|||||||
}, nil)
|
}, nil)
|
||||||
|
|
||||||
// 应返回服务不可用错误,而不是允许绕过验证
|
// 应返回服务不可用错误,而不是允许绕过验证
|
||||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "")
|
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "")
|
||||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -164,7 +165,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
|
|||||||
SettingKeyEmailVerifyEnabled: "true",
|
SettingKeyEmailVerifyEnabled: "true",
|
||||||
}, cache)
|
}, cache)
|
||||||
|
|
||||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "")
|
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "")
|
||||||
require.ErrorIs(t, err, ErrEmailVerifyRequired)
|
require.ErrorIs(t, err, ErrEmailVerifyRequired)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -178,7 +179,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
|
|||||||
SettingKeyEmailVerifyEnabled: "true",
|
SettingKeyEmailVerifyEnabled: "true",
|
||||||
}, cache)
|
}, cache)
|
||||||
|
|
||||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "")
|
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "")
|
||||||
require.ErrorIs(t, err, ErrInvalidVerifyCode)
|
require.ErrorIs(t, err, ErrInvalidVerifyCode)
|
||||||
require.ErrorContains(t, err, "verify code")
|
require.ErrorContains(t, err, "verify code")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -241,6 +241,76 @@ func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageToken
|
|||||||
return s.CalculateCost(model, tokens, multiplier)
|
return s.CalculateCost(model, tokens, multiplier)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CalculateCostWithLongContext 计算费用,支持长上下文双倍计费
|
||||||
|
// threshold: 阈值(如 200000),超过此值的部分按 extraMultiplier 倍计费
|
||||||
|
// extraMultiplier: 超出部分的倍率(如 2.0 表示双倍)
|
||||||
|
//
|
||||||
|
// 示例:缓存 210k + 输入 10k = 220k,阈值 200k,倍率 2.0
|
||||||
|
// 拆分为:范围内 (200k, 0) + 范围外 (10k, 10k)
|
||||||
|
// 范围内正常计费,范围外 × 2 计费
|
||||||
|
func (s *BillingService) CalculateCostWithLongContext(model string, tokens UsageTokens, rateMultiplier float64, threshold int, extraMultiplier float64) (*CostBreakdown, error) {
|
||||||
|
// 未启用长上下文计费,直接走正常计费
|
||||||
|
if threshold <= 0 || extraMultiplier <= 1 {
|
||||||
|
return s.CalculateCost(model, tokens, rateMultiplier)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算总输入 token(缓存读取 + 新输入)
|
||||||
|
total := tokens.CacheReadTokens + tokens.InputTokens
|
||||||
|
if total <= threshold {
|
||||||
|
return s.CalculateCost(model, tokens, rateMultiplier)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 拆分成范围内和范围外
|
||||||
|
var inRangeCacheTokens, inRangeInputTokens int
|
||||||
|
var outRangeCacheTokens, outRangeInputTokens int
|
||||||
|
|
||||||
|
if tokens.CacheReadTokens >= threshold {
|
||||||
|
// 缓存已超过阈值:范围内只有缓存,范围外是超出的缓存+全部输入
|
||||||
|
inRangeCacheTokens = threshold
|
||||||
|
inRangeInputTokens = 0
|
||||||
|
outRangeCacheTokens = tokens.CacheReadTokens - threshold
|
||||||
|
outRangeInputTokens = tokens.InputTokens
|
||||||
|
} else {
|
||||||
|
// 缓存未超过阈值:范围内是全部缓存+部分输入,范围外是剩余输入
|
||||||
|
inRangeCacheTokens = tokens.CacheReadTokens
|
||||||
|
inRangeInputTokens = threshold - tokens.CacheReadTokens
|
||||||
|
outRangeCacheTokens = 0
|
||||||
|
outRangeInputTokens = tokens.InputTokens - inRangeInputTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
// 范围内部分:正常计费
|
||||||
|
inRangeTokens := UsageTokens{
|
||||||
|
InputTokens: inRangeInputTokens,
|
||||||
|
OutputTokens: tokens.OutputTokens, // 输出只算一次
|
||||||
|
CacheCreationTokens: tokens.CacheCreationTokens,
|
||||||
|
CacheReadTokens: inRangeCacheTokens,
|
||||||
|
}
|
||||||
|
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 范围外部分:× extraMultiplier 计费
|
||||||
|
outRangeTokens := UsageTokens{
|
||||||
|
InputTokens: outRangeInputTokens,
|
||||||
|
CacheReadTokens: outRangeCacheTokens,
|
||||||
|
}
|
||||||
|
outRangeCost, err := s.CalculateCost(model, outRangeTokens, rateMultiplier*extraMultiplier)
|
||||||
|
if err != nil {
|
||||||
|
return inRangeCost, nil // 出错时返回范围内成本
|
||||||
|
}
|
||||||
|
|
||||||
|
// 合并成本
|
||||||
|
return &CostBreakdown{
|
||||||
|
InputCost: inRangeCost.InputCost + outRangeCost.InputCost,
|
||||||
|
OutputCost: inRangeCost.OutputCost,
|
||||||
|
CacheCreationCost: inRangeCost.CacheCreationCost,
|
||||||
|
CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost,
|
||||||
|
TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost,
|
||||||
|
ActualCost: inRangeCost.ActualCost + outRangeCost.ActualCost,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配)
|
// ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配)
|
||||||
func (s *BillingService) ListSupportedModels() []string {
|
func (s *BillingService) ListSupportedModels() []string {
|
||||||
models := make([]string, 0)
|
models := make([]string, 0)
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ const (
|
|||||||
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
|
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
|
||||||
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
|
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
|
||||||
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
|
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
|
||||||
|
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||||
)
|
)
|
||||||
|
|
||||||
// Redeem type constants
|
// Redeem type constants
|
||||||
@@ -38,6 +39,7 @@ const (
|
|||||||
RedeemTypeBalance = domain.RedeemTypeBalance
|
RedeemTypeBalance = domain.RedeemTypeBalance
|
||||||
RedeemTypeConcurrency = domain.RedeemTypeConcurrency
|
RedeemTypeConcurrency = domain.RedeemTypeConcurrency
|
||||||
RedeemTypeSubscription = domain.RedeemTypeSubscription
|
RedeemTypeSubscription = domain.RedeemTypeSubscription
|
||||||
|
RedeemTypeInvitation = domain.RedeemTypeInvitation
|
||||||
)
|
)
|
||||||
|
|
||||||
// PromoCode status constants
|
// PromoCode status constants
|
||||||
@@ -71,10 +73,11 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
|
|||||||
// Setting keys
|
// Setting keys
|
||||||
const (
|
const (
|
||||||
// 注册设置
|
// 注册设置
|
||||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||||
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
|
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
|
||||||
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
|
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
|
||||||
|
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
|
||||||
|
|
||||||
// 邮件服务设置
|
// 邮件服务设置
|
||||||
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
||||||
|
|||||||
23
backend/internal/service/gateway_beta_test.go
Normal file
23
backend/internal/service/gateway_beta_test.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMergeAnthropicBeta(t *testing.T) {
|
||||||
|
got := mergeAnthropicBeta(
|
||||||
|
[]string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"},
|
||||||
|
"foo, oauth-2025-04-20,bar, foo",
|
||||||
|
)
|
||||||
|
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo,bar", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) {
|
||||||
|
got := mergeAnthropicBeta(
|
||||||
|
[]string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"},
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got)
|
||||||
|
}
|
||||||
@@ -266,6 +266,14 @@ func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Conte
|
|||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockGroupRepoForGateway) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockGroupRepoForGateway) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
func ptr[T any](v T) *T {
|
func ptr[T any](v T) *T {
|
||||||
return &v
|
return &v
|
||||||
}
|
}
|
||||||
|
|||||||
62
backend/internal/service/gateway_oauth_metadata_test.go
Normal file
62
backend/internal/service/gateway_oauth_metadata_test.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildOAuthMetadataUserID_FallbackWithoutAccountUUID(t *testing.T) {
|
||||||
|
svc := &GatewayService{}
|
||||||
|
|
||||||
|
parsed := &ParsedRequest{
|
||||||
|
Model: "claude-sonnet-4-5",
|
||||||
|
Stream: true,
|
||||||
|
MetadataUserID: "",
|
||||||
|
System: nil,
|
||||||
|
Messages: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 123,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{}, // intentionally missing account_uuid / claude_user_id
|
||||||
|
}
|
||||||
|
|
||||||
|
fp := &Fingerprint{ClientID: "deadbeef"} // should be used as user id in legacy format
|
||||||
|
|
||||||
|
got := svc.buildOAuthMetadataUserID(parsed, account, fp)
|
||||||
|
require.NotEmpty(t, got)
|
||||||
|
|
||||||
|
// Legacy format: user_{client}_account__session_{uuid}
|
||||||
|
re := regexp.MustCompile(`^user_[a-zA-Z0-9]+_account__session_[a-f0-9-]{36}$`)
|
||||||
|
require.True(t, re.MatchString(got), "unexpected user_id format: %s", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildOAuthMetadataUserID_UsesAccountUUIDWhenPresent(t *testing.T) {
|
||||||
|
svc := &GatewayService{}
|
||||||
|
|
||||||
|
parsed := &ParsedRequest{
|
||||||
|
Model: "claude-sonnet-4-5",
|
||||||
|
Stream: true,
|
||||||
|
MetadataUserID: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 123,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"account_uuid": "acc-uuid",
|
||||||
|
"claude_user_id": "clientid123",
|
||||||
|
"anthropic_user_id": "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
got := svc.buildOAuthMetadataUserID(parsed, account, nil)
|
||||||
|
require.NotEmpty(t, got)
|
||||||
|
|
||||||
|
// New format: user_{client}_account_{account_uuid}_session_{uuid}
|
||||||
|
re := regexp.MustCompile(`^user_clientid123_account_acc-uuid_session_[a-f0-9-]{36}$`)
|
||||||
|
require.True(t, re.MatchString(got), "unexpected user_id format: %s", got)
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestInjectClaudeCodePrompt(t *testing.T) {
|
func TestInjectClaudeCodePrompt(t *testing.T) {
|
||||||
|
claudePrefix := strings.TrimSpace(claudeCodeSystemPrompt)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
body string
|
body string
|
||||||
@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
|
|||||||
system: "Custom prompt",
|
system: "Custom prompt",
|
||||||
wantSystemLen: 2,
|
wantSystemLen: 2,
|
||||||
wantFirstText: claudeCodeSystemPrompt,
|
wantFirstText: claudeCodeSystemPrompt,
|
||||||
wantSecondText: "Custom prompt",
|
wantSecondText: claudePrefix + "\n\nCustom prompt",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "string system equals Claude Code prompt",
|
name: "string system equals Claude Code prompt",
|
||||||
@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
|
|||||||
// Claude Code + Custom = 2
|
// Claude Code + Custom = 2
|
||||||
wantSystemLen: 2,
|
wantSystemLen: 2,
|
||||||
wantFirstText: claudeCodeSystemPrompt,
|
wantFirstText: claudeCodeSystemPrompt,
|
||||||
wantSecondText: "Custom",
|
wantSecondText: claudePrefix + "\n\nCustom",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "array system with existing Claude Code prompt (should dedupe)",
|
name: "array system with existing Claude Code prompt (should dedupe)",
|
||||||
@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
|
|||||||
// Claude Code at start + Other = 2 (deduped)
|
// Claude Code at start + Other = 2 (deduped)
|
||||||
wantSystemLen: 2,
|
wantSystemLen: 2,
|
||||||
wantFirstText: claudeCodeSystemPrompt,
|
wantFirstText: claudeCodeSystemPrompt,
|
||||||
wantSecondText: "Other",
|
wantSecondText: claudePrefix + "\n\nOther",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "empty array",
|
name: "empty array",
|
||||||
|
|||||||
21
backend/internal/service/gateway_sanitize_test.go
Normal file
21
backend/internal/service/gateway_sanitize_test.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSanitizeOpenCodeText_RewritesCanonicalSentence(t *testing.T) {
|
||||||
|
in := "You are OpenCode, the best coding agent on the planet."
|
||||||
|
got := sanitizeSystemText(in)
|
||||||
|
require.Equal(t, strings.TrimSpace(claudeCodeSystemPrompt), got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeToolDescription_DoesNotRewriteKeywords(t *testing.T) {
|
||||||
|
in := "OpenCode and opencode are mentioned."
|
||||||
|
got := sanitizeToolDescription(in)
|
||||||
|
// We no longer rewrite tool descriptions; only redact obvious path leaks.
|
||||||
|
require.Equal(t, in, got)
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -36,6 +36,11 @@ const (
|
|||||||
geminiRetryMaxDelay = 16 * time.Second
|
geminiRetryMaxDelay = 16 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Gemini tool calling now requires `thoughtSignature` in parts that include `functionCall`.
|
||||||
|
// Many clients don't send it; we inject a known dummy signature to satisfy the validator.
|
||||||
|
// Ref: https://ai.google.dev/gemini-api/docs/thought-signatures
|
||||||
|
const geminiDummyThoughtSignature = "skip_thought_signature_validator"
|
||||||
|
|
||||||
type GeminiMessagesCompatService struct {
|
type GeminiMessagesCompatService struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
groupRepo GroupRepository
|
groupRepo GroupRepository
|
||||||
@@ -528,6 +533,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
|
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
|
||||||
}
|
}
|
||||||
|
geminiReq = ensureGeminiFunctionCallThoughtSignatures(geminiReq)
|
||||||
originalClaudeBody := body
|
originalClaudeBody := body
|
||||||
|
|
||||||
proxyURL := ""
|
proxyURL := ""
|
||||||
@@ -971,6 +977,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
|
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 过滤掉 parts 为空的消息(Gemini API 不接受空 parts)
|
||||||
|
if filteredBody, err := filterEmptyPartsFromGeminiRequest(body); err == nil {
|
||||||
|
body = filteredBody
|
||||||
|
}
|
||||||
|
|
||||||
switch action {
|
switch action {
|
||||||
case "generateContent", "streamGenerateContent", "countTokens":
|
case "generateContent", "streamGenerateContent", "countTokens":
|
||||||
// ok
|
// ok
|
||||||
@@ -978,6 +989,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
|
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Some Gemini upstreams validate tool call parts strictly; ensure any `functionCall` part includes a
|
||||||
|
// `thoughtSignature` to avoid frequent INVALID_ARGUMENT 400s.
|
||||||
|
body = ensureGeminiFunctionCallThoughtSignatures(body)
|
||||||
|
|
||||||
mappedModel := originalModel
|
mappedModel := originalModel
|
||||||
if account.Type == AccountTypeAPIKey {
|
if account.Type == AccountTypeAPIKey {
|
||||||
mappedModel = account.GetMappedModel(originalModel)
|
mappedModel = account.GetMappedModel(originalModel)
|
||||||
@@ -2657,6 +2672,58 @@ func nextGeminiDailyResetUnix() *int64 {
|
|||||||
return &ts
|
return &ts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ensureGeminiFunctionCallThoughtSignatures(body []byte) []byte {
|
||||||
|
// Fast path: only run when functionCall is present.
|
||||||
|
if !bytes.Contains(body, []byte(`"functionCall"`)) {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal(body, &payload); err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
contentsAny, ok := payload["contents"].([]any)
|
||||||
|
if !ok || len(contentsAny) == 0 {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
modified := false
|
||||||
|
for _, c := range contentsAny {
|
||||||
|
cm, ok := c.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
partsAny, ok := cm["parts"].([]any)
|
||||||
|
if !ok || len(partsAny) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, p := range partsAny {
|
||||||
|
pm, ok := p.(map[string]any)
|
||||||
|
if !ok || pm == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if fc, ok := pm["functionCall"].(map[string]any); !ok || fc == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ts, _ := pm["thoughtSignature"].(string)
|
||||||
|
if strings.TrimSpace(ts) == "" {
|
||||||
|
pm["thoughtSignature"] = geminiDummyThoughtSignature
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !modified {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
func extractGeminiFinishReason(geminiResp map[string]any) string {
|
func extractGeminiFinishReason(geminiResp map[string]any) string {
|
||||||
if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 {
|
if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 {
|
||||||
if cand, ok := candidates[0].(map[string]any); ok {
|
if cand, ok := candidates[0].(map[string]any); ok {
|
||||||
@@ -2856,7 +2923,13 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
|
|||||||
if strings.TrimSpace(id) != "" && strings.TrimSpace(name) != "" {
|
if strings.TrimSpace(id) != "" && strings.TrimSpace(name) != "" {
|
||||||
toolUseIDToName[id] = name
|
toolUseIDToName[id] = name
|
||||||
}
|
}
|
||||||
|
signature, _ := bm["signature"].(string)
|
||||||
|
signature = strings.TrimSpace(signature)
|
||||||
|
if signature == "" {
|
||||||
|
signature = geminiDummyThoughtSignature
|
||||||
|
}
|
||||||
parts = append(parts, map[string]any{
|
parts = append(parts, map[string]any{
|
||||||
|
"thoughtSignature": signature,
|
||||||
"functionCall": map[string]any{
|
"functionCall": map[string]any{
|
||||||
"name": name,
|
"name": name,
|
||||||
"args": bm["input"],
|
"args": bm["input"],
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -126,3 +128,78 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
|
||||||
|
claudeReq := map[string]any{
|
||||||
|
"model": "claude-haiku-4-5-20251001",
|
||||||
|
"max_tokens": 10,
|
||||||
|
"messages": []any{
|
||||||
|
map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{"type": "text", "text": "hi"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
map[string]any{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{"type": "text", "text": "ok"},
|
||||||
|
map[string]any{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "toolu_123",
|
||||||
|
"name": "default_api:write_file",
|
||||||
|
"input": map[string]any{"path": "a.txt", "content": "x"},
|
||||||
|
// no signature on purpose
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"tools": []any{
|
||||||
|
map[string]any{
|
||||||
|
"name": "default_api:write_file",
|
||||||
|
"description": "write file",
|
||||||
|
"input_schema": map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{"path": map[string]any{"type": "string"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b, _ := json.Marshal(claudeReq)
|
||||||
|
|
||||||
|
out, err := convertClaudeMessagesToGeminiGenerateContent(b)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("convert failed: %v", err)
|
||||||
|
}
|
||||||
|
s := string(out)
|
||||||
|
if !strings.Contains(s, "\"functionCall\"") {
|
||||||
|
t.Fatalf("expected functionCall in output, got: %s", s)
|
||||||
|
}
|
||||||
|
if !strings.Contains(s, "\"thoughtSignature\":\""+geminiDummyThoughtSignature+"\"") {
|
||||||
|
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing(t *testing.T) {
|
||||||
|
geminiReq := map[string]any{
|
||||||
|
"contents": []any{
|
||||||
|
map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"parts": []any{
|
||||||
|
map[string]any{
|
||||||
|
"functionCall": map[string]any{
|
||||||
|
"name": "default_api:write_file",
|
||||||
|
"args": map[string]any{"path": "a.txt"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b, _ := json.Marshal(geminiReq)
|
||||||
|
out := ensureGeminiFunctionCallThoughtSignatures(b)
|
||||||
|
s := string(out)
|
||||||
|
if !strings.Contains(s, "\"thoughtSignature\":\""+geminiDummyThoughtSignature+"\"") {
|
||||||
|
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -218,6 +218,14 @@ func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Contex
|
|||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockGroupRepoForGemini) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockGroupRepoForGemini) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
|
var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
|
||||||
|
|
||||||
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
|
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
|
||||||
|
|||||||
@@ -2,20 +2,22 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中移除 thoughtSignature 字段,
|
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中替换 thoughtSignature 字段为 dummy 签名,
|
||||||
// 以避免跨账号签名验证错误。
|
// 以避免跨账号签名验证错误。
|
||||||
//
|
//
|
||||||
// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature
|
// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature
|
||||||
// 会导致新账号的签名验证失败。通过移除这些签名,让新账号重新生成有效的签名。
|
// 会导致新账号的签名验证失败。通过替换为 dummy 签名,跳过签名验证。
|
||||||
//
|
//
|
||||||
// CleanGeminiNativeThoughtSignatures removes thoughtSignature fields from Gemini native API requests
|
// CleanGeminiNativeThoughtSignatures replaces thoughtSignature fields with dummy signature
|
||||||
// to avoid cross-account signature validation errors.
|
// in Gemini native API requests to avoid cross-account signature validation errors.
|
||||||
//
|
//
|
||||||
// When sticky session switches accounts (e.g., original account becomes unavailable),
|
// When sticky session switches accounts (e.g., original account becomes unavailable),
|
||||||
// thoughtSignatures from the old account will cause validation failures on the new account.
|
// thoughtSignatures from the old account will cause validation failures on the new account.
|
||||||
// By removing these signatures, we allow the new account to generate valid signatures.
|
// By replacing with dummy signature, we skip signature validation.
|
||||||
func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
|
func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
|
||||||
if len(body) == 0 {
|
if len(body) == 0 {
|
||||||
return body
|
return body
|
||||||
@@ -28,11 +30,11 @@ func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
|
|||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
// 递归清理 thoughtSignature
|
// 递归替换 thoughtSignature 为 dummy 签名
|
||||||
cleaned := cleanThoughtSignaturesRecursive(data)
|
replaced := replaceThoughtSignaturesRecursive(data)
|
||||||
|
|
||||||
// 重新序列化
|
// 重新序列化
|
||||||
result, err := json.Marshal(cleaned)
|
result, err := json.Marshal(replaced)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// 如果序列化失败,返回原始 body
|
// 如果序列化失败,返回原始 body
|
||||||
return body
|
return body
|
||||||
@@ -41,19 +43,20 @@ func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// cleanThoughtSignaturesRecursive 递归遍历数据结构,移除所有 thoughtSignature 字段
|
// replaceThoughtSignaturesRecursive 递归遍历数据结构,将所有 thoughtSignature 字段替换为 dummy 签名
|
||||||
func cleanThoughtSignaturesRecursive(data any) any {
|
func replaceThoughtSignaturesRecursive(data any) any {
|
||||||
switch v := data.(type) {
|
switch v := data.(type) {
|
||||||
case map[string]any:
|
case map[string]any:
|
||||||
// 创建新的 map,移除 thoughtSignature
|
// 创建新的 map,替换 thoughtSignature 为 dummy 签名
|
||||||
result := make(map[string]any, len(v))
|
result := make(map[string]any, len(v))
|
||||||
for key, value := range v {
|
for key, value := range v {
|
||||||
// 跳过 thoughtSignature 字段
|
// 替换 thoughtSignature 字段为 dummy 签名
|
||||||
if key == "thoughtSignature" {
|
if key == "thoughtSignature" {
|
||||||
|
result[key] = antigravity.DummyThoughtSignature
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// 递归处理嵌套结构
|
// 递归处理嵌套结构
|
||||||
result[key] = cleanThoughtSignaturesRecursive(value)
|
result[key] = replaceThoughtSignaturesRecursive(value)
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -61,7 +64,7 @@ func cleanThoughtSignaturesRecursive(data any) any {
|
|||||||
// 递归处理数组中的每个元素
|
// 递归处理数组中的每个元素
|
||||||
result := make([]any, len(v))
|
result := make([]any, len(v))
|
||||||
for i, item := range v {
|
for i, item := range v {
|
||||||
result[i] = cleanThoughtSignaturesRecursive(item)
|
result[i] = replaceThoughtSignaturesRecursive(item)
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
@@ -29,6 +29,8 @@ type Group struct {
|
|||||||
// Claude Code 客户端限制
|
// Claude Code 客户端限制
|
||||||
ClaudeCodeOnly bool
|
ClaudeCodeOnly bool
|
||||||
FallbackGroupID *int64
|
FallbackGroupID *int64
|
||||||
|
// 无效请求兜底分组(仅 anthropic 平台使用)
|
||||||
|
FallbackGroupIDOnInvalidRequest *int64
|
||||||
|
|
||||||
// 模型路由配置
|
// 模型路由配置
|
||||||
// key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*")
|
// key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*")
|
||||||
@@ -36,6 +38,13 @@ type Group struct {
|
|||||||
ModelRouting map[string][]int64
|
ModelRouting map[string][]int64
|
||||||
ModelRoutingEnabled bool
|
ModelRoutingEnabled bool
|
||||||
|
|
||||||
|
// MCP XML 协议注入开关(仅 antigravity 平台使用)
|
||||||
|
MCPXMLInject bool
|
||||||
|
|
||||||
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
|
// 可选值: claude, gemini_text, gemini_image
|
||||||
|
SupportedModelScopes []string
|
||||||
|
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
|
|
||||||
|
|||||||
@@ -29,6 +29,10 @@ type GroupRepository interface {
|
|||||||
ExistsByName(ctx context.Context, name string) (bool, error)
|
ExistsByName(ctx context.Context, name string) (bool, error)
|
||||||
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
|
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
|
||||||
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
|
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||||
|
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重)
|
||||||
|
GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error)
|
||||||
|
// BindAccountsToGroup 将多个账号绑定到指定分组
|
||||||
|
BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateGroupRequest 创建分组请求
|
// CreateGroupRequest 创建分组请求
|
||||||
|
|||||||
@@ -26,13 +26,13 @@ var (
|
|||||||
|
|
||||||
// 默认指纹值(当客户端未提供时使用)
|
// 默认指纹值(当客户端未提供时使用)
|
||||||
var defaultFingerprint = Fingerprint{
|
var defaultFingerprint = Fingerprint{
|
||||||
UserAgent: "claude-cli/2.0.62 (external, cli)",
|
UserAgent: "claude-cli/2.1.22 (external, cli)",
|
||||||
StainlessLang: "js",
|
StainlessLang: "js",
|
||||||
StainlessPackageVersion: "0.52.0",
|
StainlessPackageVersion: "0.70.0",
|
||||||
StainlessOS: "Linux",
|
StainlessOS: "Linux",
|
||||||
StainlessArch: "x64",
|
StainlessArch: "arm64",
|
||||||
StainlessRuntime: "node",
|
StainlessRuntime: "node",
|
||||||
StainlessRuntimeVersion: "v22.14.0",
|
StainlessRuntimeVersion: "v24.13.0",
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fingerprint represents account fingerprint data
|
// Fingerprint represents account fingerprint data
|
||||||
@@ -169,22 +169,31 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
|
|||||||
// RewriteUserID 重写body中的metadata.user_id
|
// RewriteUserID 重写body中的metadata.user_id
|
||||||
// 输入格式:user_{clientId}_account__session_{sessionUUID}
|
// 输入格式:user_{clientId}_account__session_{sessionUUID}
|
||||||
// 输出格式:user_{cachedClientID}_account_{accountUUID}_session_{newHash}
|
// 输出格式:user_{cachedClientID}_account_{accountUUID}_session_{newHash}
|
||||||
|
//
|
||||||
|
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
|
||||||
|
// 避免重新序列化导致 thinking 块等内容被修改。
|
||||||
func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID string) ([]byte, error) {
|
func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID string) ([]byte, error) {
|
||||||
if len(body) == 0 || accountUUID == "" || cachedClientID == "" {
|
if len(body) == 0 || accountUUID == "" || cachedClientID == "" {
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析JSON
|
// 使用 RawMessage 保留其他字段的原始字节
|
||||||
var reqMap map[string]any
|
var reqMap map[string]json.RawMessage
|
||||||
if err := json.Unmarshal(body, &reqMap); err != nil {
|
if err := json.Unmarshal(body, &reqMap); err != nil {
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
metadata, ok := reqMap["metadata"].(map[string]any)
|
// 解析 metadata 字段
|
||||||
|
metadataRaw, ok := reqMap["metadata"]
|
||||||
if !ok {
|
if !ok {
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var metadata map[string]any
|
||||||
|
if err := json.Unmarshal(metadataRaw, &metadata); err != nil {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
userID, ok := metadata["user_id"].(string)
|
userID, ok := metadata["user_id"].(string)
|
||||||
if !ok || userID == "" {
|
if !ok || userID == "" {
|
||||||
return body, nil
|
return body, nil
|
||||||
@@ -207,7 +216,13 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
|
|||||||
newUserID := fmt.Sprintf("user_%s_account_%s_session_%s", cachedClientID, accountUUID, newSessionHash)
|
newUserID := fmt.Sprintf("user_%s_account_%s_session_%s", cachedClientID, accountUUID, newSessionHash)
|
||||||
|
|
||||||
metadata["user_id"] = newUserID
|
metadata["user_id"] = newUserID
|
||||||
reqMap["metadata"] = metadata
|
|
||||||
|
// 只重新序列化 metadata 字段
|
||||||
|
newMetadataRaw, err := json.Marshal(metadata)
|
||||||
|
if err != nil {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
reqMap["metadata"] = newMetadataRaw
|
||||||
|
|
||||||
return json.Marshal(reqMap)
|
return json.Marshal(reqMap)
|
||||||
}
|
}
|
||||||
@@ -215,6 +230,9 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
|
|||||||
// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装
|
// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装
|
||||||
// 如果账号启用了会话ID伪装(session_id_masking_enabled),
|
// 如果账号启用了会话ID伪装(session_id_masking_enabled),
|
||||||
// 则在完成常规重写后,将 session 部分替换为固定的伪装ID(15分钟内保持不变)
|
// 则在完成常规重写后,将 session 部分替换为固定的伪装ID(15分钟内保持不变)
|
||||||
|
//
|
||||||
|
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
|
||||||
|
// 避免重新序列化导致 thinking 块等内容被修改。
|
||||||
func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) {
|
func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) {
|
||||||
// 先执行常规的 RewriteUserID 逻辑
|
// 先执行常规的 RewriteUserID 逻辑
|
||||||
newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID)
|
newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID)
|
||||||
@@ -227,17 +245,23 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
|
|||||||
return newBody, nil
|
return newBody, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析重写后的 body,提取 user_id
|
// 使用 RawMessage 保留其他字段的原始字节
|
||||||
var reqMap map[string]any
|
var reqMap map[string]json.RawMessage
|
||||||
if err := json.Unmarshal(newBody, &reqMap); err != nil {
|
if err := json.Unmarshal(newBody, &reqMap); err != nil {
|
||||||
return newBody, nil
|
return newBody, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
metadata, ok := reqMap["metadata"].(map[string]any)
|
// 解析 metadata 字段
|
||||||
|
metadataRaw, ok := reqMap["metadata"]
|
||||||
if !ok {
|
if !ok {
|
||||||
return newBody, nil
|
return newBody, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var metadata map[string]any
|
||||||
|
if err := json.Unmarshal(metadataRaw, &metadata); err != nil {
|
||||||
|
return newBody, nil
|
||||||
|
}
|
||||||
|
|
||||||
userID, ok := metadata["user_id"].(string)
|
userID, ok := metadata["user_id"].(string)
|
||||||
if !ok || userID == "" {
|
if !ok || userID == "" {
|
||||||
return newBody, nil
|
return newBody, nil
|
||||||
@@ -278,7 +302,13 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
|
|||||||
)
|
)
|
||||||
|
|
||||||
metadata["user_id"] = newUserID
|
metadata["user_id"] = newUserID
|
||||||
reqMap["metadata"] = metadata
|
|
||||||
|
// 只重新序列化 metadata 字段
|
||||||
|
newMetadataRaw, marshalErr := json.Marshal(metadata)
|
||||||
|
if marshalErr != nil {
|
||||||
|
return newBody, nil
|
||||||
|
}
|
||||||
|
reqMap["metadata"] = newMetadataRaw
|
||||||
|
|
||||||
return json.Marshal(reqMap)
|
return json.Marshal(reqMap)
|
||||||
}
|
}
|
||||||
@@ -327,7 +357,7 @@ func generateUUIDFromSeed(seed string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// parseUserAgentVersion 解析user-agent版本号
|
// parseUserAgentVersion 解析user-agent版本号
|
||||||
// 例如:claude-cli/2.0.62 -> (2, 0, 62)
|
// 例如:claude-cli/2.1.2 -> (2, 1, 2)
|
||||||
func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
|
func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
|
||||||
// 匹配 xxx/x.y.z 格式
|
// 匹配 xxx/x.y.z 格式
|
||||||
matches := userAgentVersionRegex.FindStringSubmatch(ua)
|
matches := userAgentVersionRegex.FindStringSubmatch(ua)
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ type opencodeCacheMetadata struct {
|
|||||||
LastChecked int64 `json:"lastChecked"`
|
LastChecked int64 `json:"lastChecked"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
|
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTransformResult {
|
||||||
result := codexTransformResult{}
|
result := codexTransformResult{}
|
||||||
// 工具续链需求会影响存储策略与 input 过滤逻辑。
|
// 工具续链需求会影响存储策略与 input 过滤逻辑。
|
||||||
needsToolContinuation := NeedsToolContinuation(reqBody)
|
needsToolContinuation := NeedsToolContinuation(reqBody)
|
||||||
@@ -118,22 +118,9 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
|
|||||||
result.PromptCacheKey = strings.TrimSpace(v)
|
result.PromptCacheKey = strings.TrimSpace(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
|
// instructions 处理逻辑:根据是否是 Codex CLI 分别调用不同方法
|
||||||
existingInstructions, _ := reqBody["instructions"].(string)
|
if applyInstructions(reqBody, isCodexCLI) {
|
||||||
existingInstructions = strings.TrimSpace(existingInstructions)
|
result.Modified = true
|
||||||
|
|
||||||
if instructions != "" {
|
|
||||||
if existingInstructions != instructions {
|
|
||||||
reqBody["instructions"] = instructions
|
|
||||||
result.Modified = true
|
|
||||||
}
|
|
||||||
} else if existingInstructions == "" {
|
|
||||||
// 未获取到 opencode 指令时,回退使用 Codex CLI 指令。
|
|
||||||
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
|
|
||||||
if codexInstructions != "" {
|
|
||||||
reqBody["instructions"] = codexInstructions
|
|
||||||
result.Modified = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。
|
// 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。
|
||||||
@@ -276,6 +263,72 @@ func GetCodexCLIInstructions() string {
|
|||||||
return getCodexCLIInstructions()
|
return getCodexCLIInstructions()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// applyInstructions 处理 instructions 字段
|
||||||
|
// isCodexCLI=true: 仅补充缺失的 instructions(使用 opencode 指令)
|
||||||
|
// isCodexCLI=false: 优先使用 opencode 指令覆盖
|
||||||
|
func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool {
|
||||||
|
if isCodexCLI {
|
||||||
|
return applyCodexCLIInstructions(reqBody)
|
||||||
|
}
|
||||||
|
return applyOpenCodeInstructions(reqBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions
|
||||||
|
// 仅在 instructions 为空时添加 opencode 指令
|
||||||
|
func applyCodexCLIInstructions(reqBody map[string]any) bool {
|
||||||
|
if !isInstructionsEmpty(reqBody) {
|
||||||
|
return false // 已有有效 instructions,不修改
|
||||||
|
}
|
||||||
|
|
||||||
|
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
|
||||||
|
if instructions != "" {
|
||||||
|
reqBody["instructions"] = instructions
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyOpenCodeInstructions 为非 Codex CLI 请求应用 opencode 指令
|
||||||
|
// 优先使用 opencode 指令覆盖
|
||||||
|
func applyOpenCodeInstructions(reqBody map[string]any) bool {
|
||||||
|
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
|
||||||
|
existingInstructions, _ := reqBody["instructions"].(string)
|
||||||
|
existingInstructions = strings.TrimSpace(existingInstructions)
|
||||||
|
|
||||||
|
if instructions != "" {
|
||||||
|
if existingInstructions != instructions {
|
||||||
|
reqBody["instructions"] = instructions
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
} else if existingInstructions == "" {
|
||||||
|
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
|
||||||
|
if codexInstructions != "" {
|
||||||
|
reqBody["instructions"] = codexInstructions
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// isInstructionsEmpty 检查 instructions 字段是否为空
|
||||||
|
// 处理以下情况:字段不存在、nil、空字符串、纯空白字符串
|
||||||
|
func isInstructionsEmpty(reqBody map[string]any) bool {
|
||||||
|
val, exists := reqBody["instructions"]
|
||||||
|
if !exists {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if val == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
str, ok := val.(string)
|
||||||
|
if !ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(str) == ""
|
||||||
|
}
|
||||||
|
|
||||||
// ReplaceWithCodexInstructions 将请求 instructions 替换为内置 Codex 指令(必要时)。
|
// ReplaceWithCodexInstructions 将请求 instructions 替换为内置 Codex 指令(必要时)。
|
||||||
func ReplaceWithCodexInstructions(reqBody map[string]any) bool {
|
func ReplaceWithCodexInstructions(reqBody map[string]any) bool {
|
||||||
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
|
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
|
|||||||
"tool_choice": "auto",
|
"tool_choice": "auto",
|
||||||
}
|
}
|
||||||
|
|
||||||
applyCodexOAuthTransform(reqBody)
|
applyCodexOAuthTransform(reqBody, false)
|
||||||
|
|
||||||
// 未显式设置 store=true,默认为 false。
|
// 未显式设置 store=true,默认为 false。
|
||||||
store, ok := reqBody["store"].(bool)
|
store, ok := reqBody["store"].(bool)
|
||||||
@@ -59,7 +59,7 @@ func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
|
|||||||
"tool_choice": "auto",
|
"tool_choice": "auto",
|
||||||
}
|
}
|
||||||
|
|
||||||
applyCodexOAuthTransform(reqBody)
|
applyCodexOAuthTransform(reqBody, false)
|
||||||
|
|
||||||
store, ok := reqBody["store"].(bool)
|
store, ok := reqBody["store"].(bool)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
@@ -79,7 +79,7 @@ func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
|
|||||||
"tool_choice": "auto",
|
"tool_choice": "auto",
|
||||||
}
|
}
|
||||||
|
|
||||||
applyCodexOAuthTransform(reqBody)
|
applyCodexOAuthTransform(reqBody, false)
|
||||||
|
|
||||||
store, ok := reqBody["store"].(bool)
|
store, ok := reqBody["store"].(bool)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
@@ -97,7 +97,7 @@ func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
applyCodexOAuthTransform(reqBody)
|
applyCodexOAuthTransform(reqBody, false)
|
||||||
|
|
||||||
store, ok := reqBody["store"].(bool)
|
store, ok := reqBody["store"].(bool)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
@@ -148,7 +148,7 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
applyCodexOAuthTransform(reqBody)
|
applyCodexOAuthTransform(reqBody, false)
|
||||||
|
|
||||||
tools, ok := reqBody["tools"].([]any)
|
tools, ok := reqBody["tools"].([]any)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
@@ -169,7 +169,7 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
|
|||||||
"input": []any{},
|
"input": []any{},
|
||||||
}
|
}
|
||||||
|
|
||||||
applyCodexOAuthTransform(reqBody)
|
applyCodexOAuthTransform(reqBody, false)
|
||||||
|
|
||||||
input, ok := reqBody["input"].([]any)
|
input, ok := reqBody["input"].([]any)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
@@ -196,3 +196,77 @@ func setupCodexCache(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644))
|
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) {
|
||||||
|
// Codex CLI 场景:已有 instructions 时不修改
|
||||||
|
setupCodexCache(t)
|
||||||
|
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"model": "gpt-5.1",
|
||||||
|
"instructions": "existing instructions",
|
||||||
|
}
|
||||||
|
|
||||||
|
result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true
|
||||||
|
|
||||||
|
instructions, ok := reqBody["instructions"].(string)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "existing instructions", instructions)
|
||||||
|
// Modified 仍可能为 true(因为其他字段被修改),但 instructions 应保持不变
|
||||||
|
_ = result
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T) {
|
||||||
|
// Codex CLI 场景:无 instructions 时补充默认值
|
||||||
|
setupCodexCache(t)
|
||||||
|
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"model": "gpt-5.1",
|
||||||
|
// 没有 instructions 字段
|
||||||
|
}
|
||||||
|
|
||||||
|
result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true
|
||||||
|
|
||||||
|
instructions, ok := reqBody["instructions"].(string)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.NotEmpty(t, instructions)
|
||||||
|
require.True(t, result.Modified)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_NonCodexCLI_OverridesInstructions(t *testing.T) {
|
||||||
|
// 非 Codex CLI 场景:使用 opencode 指令覆盖
|
||||||
|
setupCodexCache(t)
|
||||||
|
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"model": "gpt-5.1",
|
||||||
|
"instructions": "old instructions",
|
||||||
|
}
|
||||||
|
|
||||||
|
result := applyCodexOAuthTransform(reqBody, false) // isCodexCLI=false
|
||||||
|
|
||||||
|
instructions, ok := reqBody["instructions"].(string)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.NotEqual(t, "old instructions", instructions)
|
||||||
|
require.True(t, result.Modified)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsInstructionsEmpty(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
reqBody map[string]any
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"missing field", map[string]any{}, true},
|
||||||
|
{"nil value", map[string]any{"instructions": nil}, true},
|
||||||
|
{"empty string", map[string]any{"instructions": ""}, true},
|
||||||
|
{"whitespace only", map[string]any{"instructions": " "}, true},
|
||||||
|
{"non-string", map[string]any{"instructions": 123}, true},
|
||||||
|
{"valid string", map[string]any{"instructions": "hello"}, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := isInstructionsEmpty(tt.reqBody)
|
||||||
|
require.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -156,12 +156,15 @@ type OpenAIUsage struct {
|
|||||||
|
|
||||||
// OpenAIForwardResult represents the result of forwarding
|
// OpenAIForwardResult represents the result of forwarding
|
||||||
type OpenAIForwardResult struct {
|
type OpenAIForwardResult struct {
|
||||||
RequestID string
|
RequestID string
|
||||||
Usage OpenAIUsage
|
Usage OpenAIUsage
|
||||||
Model string
|
Model string
|
||||||
Stream bool
|
// ReasoningEffort is extracted from request body (reasoning.effort) or derived from model suffix.
|
||||||
Duration time.Duration
|
// Stored for usage records display; nil means not provided / not applicable.
|
||||||
FirstTokenMs *int
|
ReasoningEffort *string
|
||||||
|
Stream bool
|
||||||
|
Duration time.Duration
|
||||||
|
FirstTokenMs *int
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenAIGatewayService handles OpenAI API gateway operations
|
// OpenAIGatewayService handles OpenAI API gateway operations
|
||||||
@@ -793,8 +796,8 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.Type == AccountTypeOAuth && !isCodexCLI {
|
if account.Type == AccountTypeOAuth {
|
||||||
codexResult := applyCodexOAuthTransform(reqBody)
|
codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI)
|
||||||
if codexResult.Modified {
|
if codexResult.Modified {
|
||||||
bodyModified = true
|
bodyModified = true
|
||||||
}
|
}
|
||||||
@@ -842,6 +845,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
bodyModified = true
|
bodyModified = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Remove prompt_cache_retention (not supported by upstream OpenAI API)
|
||||||
|
if _, has := reqBody["prompt_cache_retention"]; has {
|
||||||
|
delete(reqBody, "prompt_cache_retention")
|
||||||
|
bodyModified = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Re-serialize body only if modified
|
// Re-serialize body only if modified
|
||||||
@@ -958,13 +967,16 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
|
||||||
|
|
||||||
return &OpenAIForwardResult{
|
return &OpenAIForwardResult{
|
||||||
RequestID: resp.Header.Get("x-request-id"),
|
RequestID: resp.Header.Get("x-request-id"),
|
||||||
Usage: *usage,
|
Usage: *usage,
|
||||||
Model: originalModel,
|
Model: originalModel,
|
||||||
Stream: reqStream,
|
ReasoningEffort: reasoningEffort,
|
||||||
Duration: time.Since(startTime),
|
Stream: reqStream,
|
||||||
FirstTokenMs: firstTokenMs,
|
Duration: time.Since(startTime),
|
||||||
|
FirstTokenMs: firstTokenMs,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1260,15 +1272,29 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
|
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
|
||||||
lastDataAt := time.Now()
|
lastDataAt := time.Now()
|
||||||
|
|
||||||
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
|
// 仅发送一次错误事件,避免多次写入导致协议混乱。
|
||||||
|
// 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema;
|
||||||
|
// 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。
|
||||||
errorEventSent := false
|
errorEventSent := false
|
||||||
|
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
|
||||||
sendErrorEvent := func(reason string) {
|
sendErrorEvent := func(reason string) {
|
||||||
if errorEventSent {
|
if errorEventSent || clientDisconnected {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
errorEventSent = true
|
errorEventSent = true
|
||||||
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
payload := map[string]any{
|
||||||
flusher.Flush()
|
"type": "error",
|
||||||
|
"sequence_number": 0,
|
||||||
|
"error": map[string]any{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": reason,
|
||||||
|
"code": reason,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if b, err := json.Marshal(payload); err == nil {
|
||||||
|
_, _ = fmt.Fprintf(w, "data: %s\n\n", b)
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
needModelReplace := originalModel != mappedModel
|
needModelReplace := originalModel != mappedModel
|
||||||
@@ -1280,6 +1306,17 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||||
}
|
}
|
||||||
if ev.err != nil {
|
if ev.err != nil {
|
||||||
|
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
|
||||||
|
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
|
||||||
|
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
|
||||||
|
log.Printf("Context canceled during streaming, returning collected usage")
|
||||||
|
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||||
|
}
|
||||||
|
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
|
||||||
|
if clientDisconnected {
|
||||||
|
log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
|
||||||
|
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||||
|
}
|
||||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||||
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||||
sendErrorEvent("response_too_large")
|
sendErrorEvent("response_too_large")
|
||||||
@@ -1303,15 +1340,19 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
|
|
||||||
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
|
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
|
||||||
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
|
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
|
||||||
|
data = correctedData
|
||||||
line = "data: " + correctedData
|
line = "data: " + correctedData
|
||||||
}
|
}
|
||||||
|
|
||||||
// Forward line
|
// 写入客户端(客户端断开后继续 drain 上游)
|
||||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
if !clientDisconnected {
|
||||||
sendErrorEvent("write_failed")
|
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
clientDisconnected = true
|
||||||
|
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||||
|
} else {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
flusher.Flush()
|
|
||||||
|
|
||||||
// Record first token time
|
// Record first token time
|
||||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
||||||
@@ -1321,11 +1362,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
s.parseSSEUsage(data, usage)
|
s.parseSSEUsage(data, usage)
|
||||||
} else {
|
} else {
|
||||||
// Forward non-data lines as-is
|
// Forward non-data lines as-is
|
||||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
if !clientDisconnected {
|
||||||
sendErrorEvent("write_failed")
|
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
clientDisconnected = true
|
||||||
|
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||||
|
} else {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
flusher.Flush()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case <-intervalCh:
|
case <-intervalCh:
|
||||||
@@ -1333,6 +1377,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
if time.Since(lastRead) < streamInterval {
|
if time.Since(lastRead) < streamInterval {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if clientDisconnected {
|
||||||
|
log.Printf("Upstream timeout after client disconnect, returning collected usage")
|
||||||
|
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||||
|
}
|
||||||
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||||
// 处理流超时,可能标记账户为临时不可调度或错误状态
|
// 处理流超时,可能标记账户为临时不可调度或错误状态
|
||||||
if s.rateLimitService != nil {
|
if s.rateLimitService != nil {
|
||||||
@@ -1342,11 +1390,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
|
|
||||||
case <-keepaliveCh:
|
case <-keepaliveCh:
|
||||||
|
if clientDisconnected {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if time.Since(lastDataAt) < keepaliveInterval {
|
if time.Since(lastDataAt) < keepaliveInterval {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
|
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
|
||||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
clientDisconnected = true
|
||||||
|
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
@@ -1628,13 +1681,14 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
|
|||||||
|
|
||||||
// OpenAIRecordUsageInput input for recording usage
|
// OpenAIRecordUsageInput input for recording usage
|
||||||
type OpenAIRecordUsageInput struct {
|
type OpenAIRecordUsageInput struct {
|
||||||
Result *OpenAIForwardResult
|
Result *OpenAIForwardResult
|
||||||
APIKey *APIKey
|
APIKey *APIKey
|
||||||
User *User
|
User *User
|
||||||
Account *Account
|
Account *Account
|
||||||
Subscription *UserSubscription
|
Subscription *UserSubscription
|
||||||
UserAgent string // 请求的 User-Agent
|
UserAgent string // 请求的 User-Agent
|
||||||
IPAddress string // 请求的客户端 IP 地址
|
IPAddress string // 请求的客户端 IP 地址
|
||||||
|
APIKeyService APIKeyQuotaUpdater
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecordUsage records usage and deducts balance
|
// RecordUsage records usage and deducts balance
|
||||||
@@ -1687,6 +1741,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
|||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
RequestID: result.RequestID,
|
RequestID: result.RequestID,
|
||||||
Model: result.Model,
|
Model: result.Model,
|
||||||
|
ReasoningEffort: result.ReasoningEffort,
|
||||||
InputTokens: actualInputTokens,
|
InputTokens: actualInputTokens,
|
||||||
OutputTokens: result.Usage.OutputTokens,
|
OutputTokens: result.Usage.OutputTokens,
|
||||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||||
@@ -1745,6 +1800,13 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update API key quota if applicable (only for balance mode with quota set)
|
||||||
|
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
|
||||||
|
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||||
|
log.Printf("Update API key quota failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Schedule batch update for account last_used_at
|
// Schedule batch update for account last_used_at
|
||||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||||
|
|
||||||
@@ -1881,3 +1943,86 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
|
|||||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getOpenAIReasoningEffortFromReqBody(reqBody map[string]any) (value string, present bool) {
|
||||||
|
if reqBody == nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Primary: reasoning.effort
|
||||||
|
if reasoning, ok := reqBody["reasoning"].(map[string]any); ok {
|
||||||
|
if effort, ok := reasoning["effort"].(string); ok {
|
||||||
|
return normalizeOpenAIReasoningEffort(effort), true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: some clients may use a flat field.
|
||||||
|
if effort, ok := reqBody["reasoning_effort"].(string); ok {
|
||||||
|
return normalizeOpenAIReasoningEffort(effort), true
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
func deriveOpenAIReasoningEffortFromModel(model string) string {
|
||||||
|
if strings.TrimSpace(model) == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
modelID := strings.TrimSpace(model)
|
||||||
|
if strings.Contains(modelID, "/") {
|
||||||
|
parts := strings.Split(modelID, "/")
|
||||||
|
modelID = parts[len(parts)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.FieldsFunc(strings.ToLower(modelID), func(r rune) bool {
|
||||||
|
switch r {
|
||||||
|
case '-', '_', ' ':
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if len(parts) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return normalizeOpenAIReasoningEffort(parts[len(parts)-1])
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string {
|
||||||
|
if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present {
|
||||||
|
if value == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &value
|
||||||
|
}
|
||||||
|
|
||||||
|
value := deriveOpenAIReasoningEffortFromModel(requestedModel)
|
||||||
|
if value == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &value
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeOpenAIReasoningEffort(raw string) string {
|
||||||
|
value := strings.ToLower(strings.TrimSpace(raw))
|
||||||
|
if value == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize separators for "x-high"/"x_high" variants.
|
||||||
|
value = strings.NewReplacer("-", "", "_", "", " ", "").Replace(value)
|
||||||
|
|
||||||
|
switch value {
|
||||||
|
case "none", "minimal":
|
||||||
|
return ""
|
||||||
|
case "low", "medium", "high":
|
||||||
|
return value
|
||||||
|
case "xhigh", "extrahigh":
|
||||||
|
return "xhigh"
|
||||||
|
default:
|
||||||
|
// Only store known effort levels for now to keep UI consistent.
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -59,6 +59,25 @@ type stubConcurrencyCache struct {
|
|||||||
skipDefaultLoad bool
|
skipDefaultLoad bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type cancelReadCloser struct{}
|
||||||
|
|
||||||
|
func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
|
||||||
|
func (c cancelReadCloser) Close() error { return nil }
|
||||||
|
|
||||||
|
type failingGinWriter struct {
|
||||||
|
gin.ResponseWriter
|
||||||
|
failAfter int
|
||||||
|
writes int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *failingGinWriter) Write(p []byte) (int, error) {
|
||||||
|
if w.writes >= w.failAfter {
|
||||||
|
return 0, errors.New("write failed")
|
||||||
|
}
|
||||||
|
w.writes++
|
||||||
|
return w.ResponseWriter.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
if c.acquireResults != nil {
|
if c.acquireResults != nil {
|
||||||
if result, ok := c.acquireResults[accountID]; ok {
|
if result, ok := c.acquireResults[accountID]; ok {
|
||||||
@@ -814,8 +833,85 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
|
|||||||
if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") {
|
if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") {
|
||||||
t.Fatalf("expected stream timeout error, got %v", err)
|
t.Fatalf("expected stream timeout error, got %v", err)
|
||||||
}
|
}
|
||||||
if !strings.Contains(rec.Body.String(), "stream_timeout") {
|
if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "stream_timeout") {
|
||||||
t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String())
|
t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
StreamDataIntervalTimeout: 0,
|
||||||
|
StreamKeepaliveInterval: 0,
|
||||||
|
MaxLineSize: defaultMaxLineSize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &OpenAIGatewayService{cfg: cfg}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: cancelReadCloser{},
|
||||||
|
Header: http.Header{},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil error, got %v", err)
|
||||||
|
}
|
||||||
|
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") {
|
||||||
|
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
StreamDataIntervalTimeout: 0,
|
||||||
|
StreamKeepaliveInterval: 0,
|
||||||
|
MaxLineSize: defaultMaxLineSize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &OpenAIGatewayService{cfg: cfg}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
c.Writer = &failingGinWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: pr,
|
||||||
|
Header: http.Header{},
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
|
||||||
|
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":3,\"output_tokens\":5,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
|
||||||
|
}()
|
||||||
|
|
||||||
|
result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||||
|
_ = pr.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil error, got %v", err)
|
||||||
|
}
|
||||||
|
if result == nil || result.usage == nil {
|
||||||
|
t.Fatalf("expected usage result")
|
||||||
|
}
|
||||||
|
if result.usage.InputTokens != 3 || result.usage.OutputTokens != 5 || result.usage.CacheReadInputTokens != 1 {
|
||||||
|
t.Fatalf("unexpected usage: %+v", *result.usage)
|
||||||
|
}
|
||||||
|
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "write_failed") {
|
||||||
|
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -854,8 +950,8 @@ func TestOpenAIStreamingTooLong(t *testing.T) {
|
|||||||
if !errors.Is(err, bufio.ErrTooLong) {
|
if !errors.Is(err, bufio.ErrTooLong) {
|
||||||
t.Fatalf("expected ErrTooLong, got %v", err)
|
t.Fatalf("expected ErrTooLong, got %v", err)
|
||||||
}
|
}
|
||||||
if !strings.Contains(rec.Body.String(), "response_too_large") {
|
if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "response_too_large") {
|
||||||
t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String())
|
t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -67,6 +67,8 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
|
|||||||
|
|
||||||
isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched
|
isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched
|
||||||
|
|
||||||
|
scopeRateLimits := acc.GetAntigravityScopeRateLimits()
|
||||||
|
|
||||||
if acc.Platform != "" {
|
if acc.Platform != "" {
|
||||||
if _, ok := platform[acc.Platform]; !ok {
|
if _, ok := platform[acc.Platform]; !ok {
|
||||||
platform[acc.Platform] = &PlatformAvailability{
|
platform[acc.Platform] = &PlatformAvailability{
|
||||||
@@ -84,6 +86,14 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
|
|||||||
if hasError {
|
if hasError {
|
||||||
p.ErrorCount++
|
p.ErrorCount++
|
||||||
}
|
}
|
||||||
|
if len(scopeRateLimits) > 0 {
|
||||||
|
if p.ScopeRateLimitCount == nil {
|
||||||
|
p.ScopeRateLimitCount = make(map[string]int64)
|
||||||
|
}
|
||||||
|
for scope := range scopeRateLimits {
|
||||||
|
p.ScopeRateLimitCount[scope]++
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, grp := range acc.Groups {
|
for _, grp := range acc.Groups {
|
||||||
@@ -108,6 +118,14 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
|
|||||||
if hasError {
|
if hasError {
|
||||||
g.ErrorCount++
|
g.ErrorCount++
|
||||||
}
|
}
|
||||||
|
if len(scopeRateLimits) > 0 {
|
||||||
|
if g.ScopeRateLimitCount == nil {
|
||||||
|
g.ScopeRateLimitCount = make(map[string]int64)
|
||||||
|
}
|
||||||
|
for scope := range scopeRateLimits {
|
||||||
|
g.ScopeRateLimitCount[scope]++
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
displayGroupID := int64(0)
|
displayGroupID := int64(0)
|
||||||
@@ -140,6 +158,9 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
|
|||||||
item.RateLimitRemainingSec = &remainingSec
|
item.RateLimitRemainingSec = &remainingSec
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if len(scopeRateLimits) > 0 {
|
||||||
|
item.ScopeRateLimits = scopeRateLimits
|
||||||
|
}
|
||||||
if isOverloaded && acc.OverloadUntil != nil {
|
if isOverloaded && acc.OverloadUntil != nil {
|
||||||
item.OverloadUntil = acc.OverloadUntil
|
item.OverloadUntil = acc.OverloadUntil
|
||||||
remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds())
|
remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds())
|
||||||
|
|||||||
@@ -285,6 +285,11 @@ func (c *OpsMetricsCollector) collectAndPersist(ctx context.Context) error {
|
|||||||
return fmt.Errorf("query error counts: %w", err)
|
return fmt.Errorf("query error counts: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
accountSwitchCount, err := c.queryAccountSwitchCount(ctx, windowStart, windowEnd)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("query account switch counts: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
windowSeconds := windowEnd.Sub(windowStart).Seconds()
|
windowSeconds := windowEnd.Sub(windowStart).Seconds()
|
||||||
if windowSeconds <= 0 {
|
if windowSeconds <= 0 {
|
||||||
windowSeconds = 60
|
windowSeconds = 60
|
||||||
@@ -309,9 +314,10 @@ func (c *OpsMetricsCollector) collectAndPersist(ctx context.Context) error {
|
|||||||
Upstream429Count: upstream429,
|
Upstream429Count: upstream429,
|
||||||
Upstream529Count: upstream529,
|
Upstream529Count: upstream529,
|
||||||
|
|
||||||
TokenConsumed: tokenConsumed,
|
TokenConsumed: tokenConsumed,
|
||||||
QPS: float64Ptr(roundTo1DP(qps)),
|
AccountSwitchCount: accountSwitchCount,
|
||||||
TPS: float64Ptr(roundTo1DP(tps)),
|
QPS: float64Ptr(roundTo1DP(qps)),
|
||||||
|
TPS: float64Ptr(roundTo1DP(tps)),
|
||||||
|
|
||||||
DurationP50Ms: duration.p50,
|
DurationP50Ms: duration.p50,
|
||||||
DurationP90Ms: duration.p90,
|
DurationP90Ms: duration.p90,
|
||||||
@@ -551,6 +557,27 @@ WHERE created_at >= $1 AND created_at < $2`
|
|||||||
return errorTotal, businessLimited, errorSLA, upstreamExcl429529, upstream429, upstream529, nil
|
return errorTotal, businessLimited, errorSLA, upstreamExcl429529, upstream429, upstream529, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *OpsMetricsCollector) queryAccountSwitchCount(ctx context.Context, start, end time.Time) (int64, error) {
|
||||||
|
q := `
|
||||||
|
SELECT
|
||||||
|
COALESCE(SUM(CASE
|
||||||
|
WHEN split_part(ev->>'kind', ':', 1) IN ('failover', 'retry_exhausted_failover', 'failover_on_400') THEN 1
|
||||||
|
ELSE 0
|
||||||
|
END), 0) AS switch_count
|
||||||
|
FROM ops_error_logs o
|
||||||
|
CROSS JOIN LATERAL jsonb_array_elements(
|
||||||
|
COALESCE(NULLIF(o.upstream_errors, 'null'::jsonb), '[]'::jsonb)
|
||||||
|
) AS ev
|
||||||
|
WHERE o.created_at >= $1 AND o.created_at < $2
|
||||||
|
AND o.is_count_tokens = FALSE`
|
||||||
|
|
||||||
|
var count int64
|
||||||
|
if err := c.db.QueryRowContext(ctx, q, start, end).Scan(&count); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
type opsCollectedSystemStats struct {
|
type opsCollectedSystemStats struct {
|
||||||
cpuUsagePercent *float64
|
cpuUsagePercent *float64
|
||||||
memoryUsedMB *int64
|
memoryUsedMB *int64
|
||||||
|
|||||||
@@ -161,7 +161,8 @@ type OpsInsertSystemMetricsInput struct {
|
|||||||
Upstream429Count int64
|
Upstream429Count int64
|
||||||
Upstream529Count int64
|
Upstream529Count int64
|
||||||
|
|
||||||
TokenConsumed int64
|
TokenConsumed int64
|
||||||
|
AccountSwitchCount int64
|
||||||
|
|
||||||
QPS *float64
|
QPS *float64
|
||||||
TPS *float64
|
TPS *float64
|
||||||
@@ -223,8 +224,9 @@ type OpsSystemMetricsSnapshot struct {
|
|||||||
DBConnIdle *int `json:"db_conn_idle"`
|
DBConnIdle *int `json:"db_conn_idle"`
|
||||||
DBConnWaiting *int `json:"db_conn_waiting"`
|
DBConnWaiting *int `json:"db_conn_waiting"`
|
||||||
|
|
||||||
GoroutineCount *int `json:"goroutine_count"`
|
GoroutineCount *int `json:"goroutine_count"`
|
||||||
ConcurrencyQueueDepth *int `json:"concurrency_queue_depth"`
|
ConcurrencyQueueDepth *int `json:"concurrency_queue_depth"`
|
||||||
|
AccountSwitchCount *int64 `json:"account_switch_count"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpsUpsertJobHeartbeatInput struct {
|
type OpsUpsertJobHeartbeatInput struct {
|
||||||
|
|||||||
@@ -39,22 +39,24 @@ type AccountConcurrencyInfo struct {
|
|||||||
|
|
||||||
// PlatformAvailability aggregates account availability by platform.
|
// PlatformAvailability aggregates account availability by platform.
|
||||||
type PlatformAvailability struct {
|
type PlatformAvailability struct {
|
||||||
Platform string `json:"platform"`
|
Platform string `json:"platform"`
|
||||||
TotalAccounts int64 `json:"total_accounts"`
|
TotalAccounts int64 `json:"total_accounts"`
|
||||||
AvailableCount int64 `json:"available_count"`
|
AvailableCount int64 `json:"available_count"`
|
||||||
RateLimitCount int64 `json:"rate_limit_count"`
|
RateLimitCount int64 `json:"rate_limit_count"`
|
||||||
ErrorCount int64 `json:"error_count"`
|
ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"`
|
||||||
|
ErrorCount int64 `json:"error_count"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GroupAvailability aggregates account availability by group.
|
// GroupAvailability aggregates account availability by group.
|
||||||
type GroupAvailability struct {
|
type GroupAvailability struct {
|
||||||
GroupID int64 `json:"group_id"`
|
GroupID int64 `json:"group_id"`
|
||||||
GroupName string `json:"group_name"`
|
GroupName string `json:"group_name"`
|
||||||
Platform string `json:"platform"`
|
Platform string `json:"platform"`
|
||||||
TotalAccounts int64 `json:"total_accounts"`
|
TotalAccounts int64 `json:"total_accounts"`
|
||||||
AvailableCount int64 `json:"available_count"`
|
AvailableCount int64 `json:"available_count"`
|
||||||
RateLimitCount int64 `json:"rate_limit_count"`
|
RateLimitCount int64 `json:"rate_limit_count"`
|
||||||
ErrorCount int64 `json:"error_count"`
|
ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"`
|
||||||
|
ErrorCount int64 `json:"error_count"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// AccountAvailability represents current availability for a single account.
|
// AccountAvailability represents current availability for a single account.
|
||||||
@@ -72,10 +74,11 @@ type AccountAvailability struct {
|
|||||||
IsOverloaded bool `json:"is_overloaded"`
|
IsOverloaded bool `json:"is_overloaded"`
|
||||||
HasError bool `json:"has_error"`
|
HasError bool `json:"has_error"`
|
||||||
|
|
||||||
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
|
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
|
||||||
RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"`
|
RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"`
|
||||||
OverloadUntil *time.Time `json:"overload_until"`
|
ScopeRateLimits map[string]int64 `json:"scope_rate_limits,omitempty"`
|
||||||
OverloadRemainingSec *int64 `json:"overload_remaining_sec"`
|
OverloadUntil *time.Time `json:"overload_until"`
|
||||||
ErrorMessage string `json:"error_message"`
|
OverloadRemainingSec *int64 `json:"overload_remaining_sec"`
|
||||||
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"`
|
ErrorMessage string `json:"error_message"`
|
||||||
|
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
@@ -476,9 +477,13 @@ func (s *OpsService) executeClientRetry(ctx context.Context, reqType opsRetryReq
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
attemptCtx := ctx
|
||||||
|
if switches > 0 {
|
||||||
|
attemptCtx = context.WithValue(attemptCtx, ctxkey.AccountSwitchCount, switches)
|
||||||
|
}
|
||||||
exec := func() *opsRetryExecution {
|
exec := func() *opsRetryExecution {
|
||||||
defer selection.ReleaseFunc()
|
defer selection.ReleaseFunc()
|
||||||
return s.executeWithAccount(ctx, reqType, errorLog, body, account)
|
return s.executeWithAccount(attemptCtx, reqType, errorLog, body, account)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if exec != nil {
|
if exec != nil {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ type OpsThroughputTrendPoint struct {
|
|||||||
BucketStart time.Time `json:"bucket_start"`
|
BucketStart time.Time `json:"bucket_start"`
|
||||||
RequestCount int64 `json:"request_count"`
|
RequestCount int64 `json:"request_count"`
|
||||||
TokenConsumed int64 `json:"token_consumed"`
|
TokenConsumed int64 `json:"token_consumed"`
|
||||||
|
SwitchCount int64 `json:"switch_count"`
|
||||||
QPS float64 `json:"qps"`
|
QPS float64 `json:"qps"`
|
||||||
TPS float64 `json:"tps"`
|
TPS float64 `json:"tps"`
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user