feat(Sora): 完成Sora网关接入与媒体能力

新增 Sora 网关路由、账号调度与同步服务\n补充媒体代理与签名 URL、模型列表动态拉取\n完善计费配置、前端支持与相关测试
This commit is contained in:
yangjianbo
2026-01-31 20:22:22 +08:00
parent 99dc3b59bc
commit 618a614cbf
67 changed files with 4840 additions and 202 deletions

View File

@@ -300,6 +300,27 @@ default:
rate_multiplier: 1.0 rate_multiplier: 1.0
``` ```
### Sora 媒体签名 URL可选
当配置 `gateway.sora_media_signing_key``gateway.sora_media_signed_url_ttl_seconds > 0` 时,网关会将 Sora 输出的媒体地址改写为临时签名 URL`/sora/media-signed/...`)。这样无需 API Key 即可在浏览器中直接访问,且具备过期控制与防篡改能力(签名包含 path + query
```yaml
gateway:
# /sora/media 是否强制要求 API Key默认 false
sora_media_require_api_key: false
# 媒体临时签名密钥(为空则禁用签名)
sora_media_signing_key: "your-signing-key"
# 临时签名 URL 有效期(秒)
sora_media_signed_url_ttl_seconds: 900
```
> 若未配置签名密钥,`/sora/media-signed` 将返回 503。
> 如需更严格的访问控制,可将 `sora_media_require_api_key` 设为 true仅允许携带 API Key 的 `/sora/media` 访问。
访问策略说明:
- `/sora/media`:内部调用或客户端携带 API Key 才能下载
- `/sora/media-signed`:外部可访问,但有签名 + 过期控制
`config.yaml` 还支持以下安全相关配置: `config.yaml` 还支持以下安全相关配置:
- `cors.allowed_origins` 配置 CORS 白名单 - `cors.allowed_origins` 配置 CORS 白名单

View File

@@ -87,10 +87,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
schedulerCache := repository.NewSchedulerCache(redisClient) schedulerCache := repository.NewSchedulerCache(redisClient)
accountRepository := repository.NewAccountRepository(client, db, schedulerCache) accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
soraAccountRepository := repository.NewSoraAccountRepository(db) soraAccountRepository := repository.NewSoraAccountRepository(db)
sora2APIService := service.NewSora2APIService(configConfig)
sora2APISyncService := service.NewSora2APISyncService(sora2APIService, accountRepository)
proxyRepository := repository.NewProxyRepository(client, db) proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, sora2APISyncService, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
adminUserHandler := admin.NewUserHandler(adminService) adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService) groupHandler := admin.NewGroupHandler(adminService)
claudeOAuthClient := repository.NewClaudeOAuthClient() claudeOAuthClient := repository.NewClaudeOAuthClient()
@@ -162,11 +164,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client) userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService) userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler) modelHandler := admin.NewModelHandler(sora2APIService)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, modelHandler)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, sora2APIService, concurrencyService, billingCacheService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
soraGatewayService := service.NewSoraGatewayService(sora2APIService, httpUpstream, rateLimitService, configConfig)
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
@@ -177,7 +182,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig) opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, sora2APISyncService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository) accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{ application := &Application{

View File

@@ -52,6 +52,14 @@ type Group struct {
ImagePrice2k *float64 `json:"image_price_2k,omitempty"` ImagePrice2k *float64 `json:"image_price_2k,omitempty"`
// ImagePrice4k holds the value of the "image_price_4k" field. // ImagePrice4k holds the value of the "image_price_4k" field.
ImagePrice4k *float64 `json:"image_price_4k,omitempty"` ImagePrice4k *float64 `json:"image_price_4k,omitempty"`
// SoraImagePrice360 holds the value of the "sora_image_price_360" field.
SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
// SoraImagePrice540 holds the value of the "sora_image_price_540" field.
SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
// SoraVideoPricePerRequest holds the value of the "sora_video_price_per_request" field.
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
// SoraVideoPricePerRequestHd holds the value of the "sora_video_price_per_request_hd" field.
SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"`
// 是否仅允许 Claude Code 客户端 // 是否仅允许 Claude Code 客户端
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"` ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
// 非 Claude Code 请求降级使用的分组 ID // 非 Claude Code 请求降级使用的分组 ID
@@ -170,7 +178,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values[i] = new([]byte) values[i] = new([]byte)
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled: case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled:
values[i] = new(sql.NullBool) values[i] = new(sql.NullBool)
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
values[i] = new(sql.NullFloat64) values[i] = new(sql.NullFloat64)
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID: case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
@@ -309,6 +317,34 @@ func (_m *Group) assignValues(columns []string, values []any) error {
_m.ImagePrice4k = new(float64) _m.ImagePrice4k = new(float64)
*_m.ImagePrice4k = value.Float64 *_m.ImagePrice4k = value.Float64
} }
case group.FieldSoraImagePrice360:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field sora_image_price_360", values[i])
} else if value.Valid {
_m.SoraImagePrice360 = new(float64)
*_m.SoraImagePrice360 = value.Float64
}
case group.FieldSoraImagePrice540:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field sora_image_price_540", values[i])
} else if value.Valid {
_m.SoraImagePrice540 = new(float64)
*_m.SoraImagePrice540 = value.Float64
}
case group.FieldSoraVideoPricePerRequest:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field sora_video_price_per_request", values[i])
} else if value.Valid {
_m.SoraVideoPricePerRequest = new(float64)
*_m.SoraVideoPricePerRequest = value.Float64
}
case group.FieldSoraVideoPricePerRequestHd:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field sora_video_price_per_request_hd", values[i])
} else if value.Valid {
_m.SoraVideoPricePerRequestHd = new(float64)
*_m.SoraVideoPricePerRequestHd = value.Float64
}
case group.FieldClaudeCodeOnly: case group.FieldClaudeCodeOnly:
if value, ok := values[i].(*sql.NullBool); !ok { if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field claude_code_only", values[i]) return fmt.Errorf("unexpected type %T for field claude_code_only", values[i])
@@ -479,6 +515,26 @@ func (_m *Group) String() string {
builder.WriteString(fmt.Sprintf("%v", *v)) builder.WriteString(fmt.Sprintf("%v", *v))
} }
builder.WriteString(", ") builder.WriteString(", ")
if v := _m.SoraImagePrice360; v != nil {
builder.WriteString("sora_image_price_360=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.SoraImagePrice540; v != nil {
builder.WriteString("sora_image_price_540=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.SoraVideoPricePerRequest; v != nil {
builder.WriteString("sora_video_price_per_request=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.SoraVideoPricePerRequestHd; v != nil {
builder.WriteString("sora_video_price_per_request_hd=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("claude_code_only=") builder.WriteString("claude_code_only=")
builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly)) builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly))
builder.WriteString(", ") builder.WriteString(", ")

View File

@@ -49,6 +49,14 @@ const (
FieldImagePrice2k = "image_price_2k" FieldImagePrice2k = "image_price_2k"
// FieldImagePrice4k holds the string denoting the image_price_4k field in the database. // FieldImagePrice4k holds the string denoting the image_price_4k field in the database.
FieldImagePrice4k = "image_price_4k" FieldImagePrice4k = "image_price_4k"
// FieldSoraImagePrice360 holds the string denoting the sora_image_price_360 field in the database.
FieldSoraImagePrice360 = "sora_image_price_360"
// FieldSoraImagePrice540 holds the string denoting the sora_image_price_540 field in the database.
FieldSoraImagePrice540 = "sora_image_price_540"
// FieldSoraVideoPricePerRequest holds the string denoting the sora_video_price_per_request field in the database.
FieldSoraVideoPricePerRequest = "sora_video_price_per_request"
// FieldSoraVideoPricePerRequestHd holds the string denoting the sora_video_price_per_request_hd field in the database.
FieldSoraVideoPricePerRequestHd = "sora_video_price_per_request_hd"
// FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database. // FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database.
FieldClaudeCodeOnly = "claude_code_only" FieldClaudeCodeOnly = "claude_code_only"
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database. // FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
@@ -149,6 +157,10 @@ var Columns = []string{
FieldImagePrice1k, FieldImagePrice1k,
FieldImagePrice2k, FieldImagePrice2k,
FieldImagePrice4k, FieldImagePrice4k,
FieldSoraImagePrice360,
FieldSoraImagePrice540,
FieldSoraVideoPricePerRequest,
FieldSoraVideoPricePerRequestHd,
FieldClaudeCodeOnly, FieldClaudeCodeOnly,
FieldFallbackGroupID, FieldFallbackGroupID,
FieldModelRouting, FieldModelRouting,
@@ -307,6 +319,26 @@ func ByImagePrice4k(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc() return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc()
} }
// BySoraImagePrice360 orders the results by the sora_image_price_360 field.
func BySoraImagePrice360(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSoraImagePrice360, opts...).ToFunc()
}
// BySoraImagePrice540 orders the results by the sora_image_price_540 field.
func BySoraImagePrice540(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSoraImagePrice540, opts...).ToFunc()
}
// BySoraVideoPricePerRequest orders the results by the sora_video_price_per_request field.
func BySoraVideoPricePerRequest(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSoraVideoPricePerRequest, opts...).ToFunc()
}
// BySoraVideoPricePerRequestHd orders the results by the sora_video_price_per_request_hd field.
func BySoraVideoPricePerRequestHd(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSoraVideoPricePerRequestHd, opts...).ToFunc()
}
// ByClaudeCodeOnly orders the results by the claude_code_only field. // ByClaudeCodeOnly orders the results by the claude_code_only field.
func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption { func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc() return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc()

View File

@@ -140,6 +140,26 @@ func ImagePrice4k(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v)) return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v))
} }
// SoraImagePrice360 applies equality check predicate on the "sora_image_price_360" field. It's identical to SoraImagePrice360EQ.
func SoraImagePrice360(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v))
}
// SoraImagePrice540 applies equality check predicate on the "sora_image_price_540" field. It's identical to SoraImagePrice540EQ.
func SoraImagePrice540(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v))
}
// SoraVideoPricePerRequest applies equality check predicate on the "sora_video_price_per_request" field. It's identical to SoraVideoPricePerRequestEQ.
func SoraVideoPricePerRequest(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v))
}
// SoraVideoPricePerRequestHd applies equality check predicate on the "sora_video_price_per_request_hd" field. It's identical to SoraVideoPricePerRequestHdEQ.
func SoraVideoPricePerRequestHd(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v))
}
// ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ. // ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ.
func ClaudeCodeOnly(v bool) predicate.Group { func ClaudeCodeOnly(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
@@ -1010,6 +1030,206 @@ func ImagePrice4kNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldImagePrice4k)) return predicate.Group(sql.FieldNotNull(FieldImagePrice4k))
} }
// SoraImagePrice360EQ applies the EQ predicate on the "sora_image_price_360" field.
func SoraImagePrice360EQ(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v))
}
// SoraImagePrice360NEQ applies the NEQ predicate on the "sora_image_price_360" field.
func SoraImagePrice360NEQ(v float64) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice360, v))
}
// SoraImagePrice360In applies the In predicate on the "sora_image_price_360" field.
func SoraImagePrice360In(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldIn(FieldSoraImagePrice360, vs...))
}
// SoraImagePrice360NotIn applies the NotIn predicate on the "sora_image_price_360" field.
func SoraImagePrice360NotIn(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice360, vs...))
}
// SoraImagePrice360GT applies the GT predicate on the "sora_image_price_360" field.
func SoraImagePrice360GT(v float64) predicate.Group {
return predicate.Group(sql.FieldGT(FieldSoraImagePrice360, v))
}
// SoraImagePrice360GTE applies the GTE predicate on the "sora_image_price_360" field.
func SoraImagePrice360GTE(v float64) predicate.Group {
return predicate.Group(sql.FieldGTE(FieldSoraImagePrice360, v))
}
// SoraImagePrice360LT applies the LT predicate on the "sora_image_price_360" field.
func SoraImagePrice360LT(v float64) predicate.Group {
return predicate.Group(sql.FieldLT(FieldSoraImagePrice360, v))
}
// SoraImagePrice360LTE applies the LTE predicate on the "sora_image_price_360" field.
func SoraImagePrice360LTE(v float64) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldSoraImagePrice360, v))
}
// SoraImagePrice360IsNil applies the IsNil predicate on the "sora_image_price_360" field.
func SoraImagePrice360IsNil() predicate.Group {
return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice360))
}
// SoraImagePrice360NotNil applies the NotNil predicate on the "sora_image_price_360" field.
func SoraImagePrice360NotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice360))
}
// SoraImagePrice540EQ applies the EQ predicate on the "sora_image_price_540" field.
func SoraImagePrice540EQ(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v))
}
// SoraImagePrice540NEQ applies the NEQ predicate on the "sora_image_price_540" field.
func SoraImagePrice540NEQ(v float64) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice540, v))
}
// SoraImagePrice540In applies the In predicate on the "sora_image_price_540" field.
func SoraImagePrice540In(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldIn(FieldSoraImagePrice540, vs...))
}
// SoraImagePrice540NotIn applies the NotIn predicate on the "sora_image_price_540" field.
func SoraImagePrice540NotIn(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice540, vs...))
}
// SoraImagePrice540GT applies the GT predicate on the "sora_image_price_540" field.
func SoraImagePrice540GT(v float64) predicate.Group {
return predicate.Group(sql.FieldGT(FieldSoraImagePrice540, v))
}
// SoraImagePrice540GTE applies the GTE predicate on the "sora_image_price_540" field.
func SoraImagePrice540GTE(v float64) predicate.Group {
return predicate.Group(sql.FieldGTE(FieldSoraImagePrice540, v))
}
// SoraImagePrice540LT applies the LT predicate on the "sora_image_price_540" field.
func SoraImagePrice540LT(v float64) predicate.Group {
return predicate.Group(sql.FieldLT(FieldSoraImagePrice540, v))
}
// SoraImagePrice540LTE applies the LTE predicate on the "sora_image_price_540" field.
func SoraImagePrice540LTE(v float64) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldSoraImagePrice540, v))
}
// SoraImagePrice540IsNil applies the IsNil predicate on the "sora_image_price_540" field.
func SoraImagePrice540IsNil() predicate.Group {
return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice540))
}
// SoraImagePrice540NotNil applies the NotNil predicate on the "sora_image_price_540" field.
func SoraImagePrice540NotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice540))
}
// SoraVideoPricePerRequestEQ applies the EQ predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestEQ(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v))
}
// SoraVideoPricePerRequestNEQ applies the NEQ predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestNEQ(v float64) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequest, v))
}
// SoraVideoPricePerRequestIn applies the In predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestIn(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequest, vs...))
}
// SoraVideoPricePerRequestNotIn applies the NotIn predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestNotIn(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequest, vs...))
}
// SoraVideoPricePerRequestGT applies the GT predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestGT(v float64) predicate.Group {
return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequest, v))
}
// SoraVideoPricePerRequestGTE applies the GTE predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestGTE(v float64) predicate.Group {
return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequest, v))
}
// SoraVideoPricePerRequestLT applies the LT predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestLT(v float64) predicate.Group {
return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequest, v))
}
// SoraVideoPricePerRequestLTE applies the LTE predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestLTE(v float64) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequest, v))
}
// SoraVideoPricePerRequestIsNil applies the IsNil predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestIsNil() predicate.Group {
return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequest))
}
// SoraVideoPricePerRequestNotNil applies the NotNil predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequest))
}
// SoraVideoPricePerRequestHdEQ applies the EQ predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdEQ(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v))
}
// SoraVideoPricePerRequestHdNEQ applies the NEQ predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdNEQ(v float64) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequestHd, v))
}
// SoraVideoPricePerRequestHdIn applies the In predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdIn(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequestHd, vs...))
}
// SoraVideoPricePerRequestHdNotIn applies the NotIn predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdNotIn(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequestHd, vs...))
}
// SoraVideoPricePerRequestHdGT applies the GT predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdGT(v float64) predicate.Group {
return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequestHd, v))
}
// SoraVideoPricePerRequestHdGTE applies the GTE predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdGTE(v float64) predicate.Group {
return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequestHd, v))
}
// SoraVideoPricePerRequestHdLT applies the LT predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdLT(v float64) predicate.Group {
return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequestHd, v))
}
// SoraVideoPricePerRequestHdLTE applies the LTE predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdLTE(v float64) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequestHd, v))
}
// SoraVideoPricePerRequestHdIsNil applies the IsNil predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdIsNil() predicate.Group {
return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequestHd))
}
// SoraVideoPricePerRequestHdNotNil applies the NotNil predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequestHd))
}
// ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field. // ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field.
func ClaudeCodeOnlyEQ(v bool) predicate.Group { func ClaudeCodeOnlyEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))

View File

@@ -258,6 +258,62 @@ func (_c *GroupCreate) SetNillableImagePrice4k(v *float64) *GroupCreate {
return _c return _c
} }
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
func (_c *GroupCreate) SetSoraImagePrice360(v float64) *GroupCreate {
_c.mutation.SetSoraImagePrice360(v)
return _c
}
// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil.
func (_c *GroupCreate) SetNillableSoraImagePrice360(v *float64) *GroupCreate {
if v != nil {
_c.SetSoraImagePrice360(*v)
}
return _c
}
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
func (_c *GroupCreate) SetSoraImagePrice540(v float64) *GroupCreate {
_c.mutation.SetSoraImagePrice540(v)
return _c
}
// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil.
func (_c *GroupCreate) SetNillableSoraImagePrice540(v *float64) *GroupCreate {
if v != nil {
_c.SetSoraImagePrice540(*v)
}
return _c
}
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
func (_c *GroupCreate) SetSoraVideoPricePerRequest(v float64) *GroupCreate {
_c.mutation.SetSoraVideoPricePerRequest(v)
return _c
}
// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil.
func (_c *GroupCreate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupCreate {
if v != nil {
_c.SetSoraVideoPricePerRequest(*v)
}
return _c
}
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
func (_c *GroupCreate) SetSoraVideoPricePerRequestHd(v float64) *GroupCreate {
_c.mutation.SetSoraVideoPricePerRequestHd(v)
return _c
}
// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil.
func (_c *GroupCreate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupCreate {
if v != nil {
_c.SetSoraVideoPricePerRequestHd(*v)
}
return _c
}
// SetClaudeCodeOnly sets the "claude_code_only" field. // SetClaudeCodeOnly sets the "claude_code_only" field.
func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate { func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate {
_c.mutation.SetClaudeCodeOnly(v) _c.mutation.SetClaudeCodeOnly(v)
@@ -632,6 +688,22 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value) _spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value)
_node.ImagePrice4k = &value _node.ImagePrice4k = &value
} }
if value, ok := _c.mutation.SoraImagePrice360(); ok {
_spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
_node.SoraImagePrice360 = &value
}
if value, ok := _c.mutation.SoraImagePrice540(); ok {
_spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
_node.SoraImagePrice540 = &value
}
if value, ok := _c.mutation.SoraVideoPricePerRequest(); ok {
_spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
_node.SoraVideoPricePerRequest = &value
}
if value, ok := _c.mutation.SoraVideoPricePerRequestHd(); ok {
_spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
_node.SoraVideoPricePerRequestHd = &value
}
if value, ok := _c.mutation.ClaudeCodeOnly(); ok { if value, ok := _c.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
_node.ClaudeCodeOnly = value _node.ClaudeCodeOnly = value
@@ -1092,6 +1164,102 @@ func (u *GroupUpsert) ClearImagePrice4k() *GroupUpsert {
return u return u
} }
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
func (u *GroupUpsert) SetSoraImagePrice360(v float64) *GroupUpsert {
u.Set(group.FieldSoraImagePrice360, v)
return u
}
// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create.
func (u *GroupUpsert) UpdateSoraImagePrice360() *GroupUpsert {
u.SetExcluded(group.FieldSoraImagePrice360)
return u
}
// AddSoraImagePrice360 adds v to the "sora_image_price_360" field.
func (u *GroupUpsert) AddSoraImagePrice360(v float64) *GroupUpsert {
u.Add(group.FieldSoraImagePrice360, v)
return u
}
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
func (u *GroupUpsert) ClearSoraImagePrice360() *GroupUpsert {
u.SetNull(group.FieldSoraImagePrice360)
return u
}
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
func (u *GroupUpsert) SetSoraImagePrice540(v float64) *GroupUpsert {
u.Set(group.FieldSoraImagePrice540, v)
return u
}
// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create.
func (u *GroupUpsert) UpdateSoraImagePrice540() *GroupUpsert {
u.SetExcluded(group.FieldSoraImagePrice540)
return u
}
// AddSoraImagePrice540 adds v to the "sora_image_price_540" field.
func (u *GroupUpsert) AddSoraImagePrice540(v float64) *GroupUpsert {
u.Add(group.FieldSoraImagePrice540, v)
return u
}
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
func (u *GroupUpsert) ClearSoraImagePrice540() *GroupUpsert {
u.SetNull(group.FieldSoraImagePrice540)
return u
}
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
func (u *GroupUpsert) SetSoraVideoPricePerRequest(v float64) *GroupUpsert {
u.Set(group.FieldSoraVideoPricePerRequest, v)
return u
}
// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create.
func (u *GroupUpsert) UpdateSoraVideoPricePerRequest() *GroupUpsert {
u.SetExcluded(group.FieldSoraVideoPricePerRequest)
return u
}
// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field.
func (u *GroupUpsert) AddSoraVideoPricePerRequest(v float64) *GroupUpsert {
u.Add(group.FieldSoraVideoPricePerRequest, v)
return u
}
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
func (u *GroupUpsert) ClearSoraVideoPricePerRequest() *GroupUpsert {
u.SetNull(group.FieldSoraVideoPricePerRequest)
return u
}
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
func (u *GroupUpsert) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsert {
u.Set(group.FieldSoraVideoPricePerRequestHd, v)
return u
}
// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create.
func (u *GroupUpsert) UpdateSoraVideoPricePerRequestHd() *GroupUpsert {
u.SetExcluded(group.FieldSoraVideoPricePerRequestHd)
return u
}
// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field.
func (u *GroupUpsert) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsert {
u.Add(group.FieldSoraVideoPricePerRequestHd, v)
return u
}
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
func (u *GroupUpsert) ClearSoraVideoPricePerRequestHd() *GroupUpsert {
u.SetNull(group.FieldSoraVideoPricePerRequestHd)
return u
}
// SetClaudeCodeOnly sets the "claude_code_only" field. // SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert { func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert {
u.Set(group.FieldClaudeCodeOnly, v) u.Set(group.FieldClaudeCodeOnly, v)
@@ -1539,6 +1707,118 @@ func (u *GroupUpsertOne) ClearImagePrice4k() *GroupUpsertOne {
}) })
} }
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
func (u *GroupUpsertOne) SetSoraImagePrice360(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetSoraImagePrice360(v)
})
}
// AddSoraImagePrice360 adds v to the "sora_image_price_360" field.
func (u *GroupUpsertOne) AddSoraImagePrice360(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.AddSoraImagePrice360(v)
})
}
// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateSoraImagePrice360() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraImagePrice360()
})
}
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
func (u *GroupUpsertOne) ClearSoraImagePrice360() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraImagePrice360()
})
}
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
func (u *GroupUpsertOne) SetSoraImagePrice540(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetSoraImagePrice540(v)
})
}
// AddSoraImagePrice540 adds v to the "sora_image_price_540" field.
func (u *GroupUpsertOne) AddSoraImagePrice540(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.AddSoraImagePrice540(v)
})
}
// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateSoraImagePrice540() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraImagePrice540()
})
}
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
func (u *GroupUpsertOne) ClearSoraImagePrice540() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraImagePrice540()
})
}
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
func (u *GroupUpsertOne) SetSoraVideoPricePerRequest(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetSoraVideoPricePerRequest(v)
})
}
// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field.
func (u *GroupUpsertOne) AddSoraVideoPricePerRequest(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.AddSoraVideoPricePerRequest(v)
})
}
// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequest() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraVideoPricePerRequest()
})
}
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
func (u *GroupUpsertOne) ClearSoraVideoPricePerRequest() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraVideoPricePerRequest()
})
}
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
func (u *GroupUpsertOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetSoraVideoPricePerRequestHd(v)
})
}
// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field.
func (u *GroupUpsertOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.AddSoraVideoPricePerRequestHd(v)
})
}
// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequestHd() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraVideoPricePerRequestHd()
})
}
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
func (u *GroupUpsertOne) ClearSoraVideoPricePerRequestHd() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraVideoPricePerRequestHd()
})
}
// SetClaudeCodeOnly sets the "claude_code_only" field. // SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne { func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) { return u.Update(func(s *GroupUpsert) {
@@ -2163,6 +2443,118 @@ func (u *GroupUpsertBulk) ClearImagePrice4k() *GroupUpsertBulk {
}) })
} }
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
func (u *GroupUpsertBulk) SetSoraImagePrice360(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetSoraImagePrice360(v)
})
}
// AddSoraImagePrice360 adds v to the "sora_image_price_360" field.
func (u *GroupUpsertBulk) AddSoraImagePrice360(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.AddSoraImagePrice360(v)
})
}
// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateSoraImagePrice360() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraImagePrice360()
})
}
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
func (u *GroupUpsertBulk) ClearSoraImagePrice360() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraImagePrice360()
})
}
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
func (u *GroupUpsertBulk) SetSoraImagePrice540(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetSoraImagePrice540(v)
})
}
// AddSoraImagePrice540 adds v to the "sora_image_price_540" field.
func (u *GroupUpsertBulk) AddSoraImagePrice540(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.AddSoraImagePrice540(v)
})
}
// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateSoraImagePrice540() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraImagePrice540()
})
}
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
func (u *GroupUpsertBulk) ClearSoraImagePrice540() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraImagePrice540()
})
}
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
func (u *GroupUpsertBulk) SetSoraVideoPricePerRequest(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetSoraVideoPricePerRequest(v)
})
}
// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field.
func (u *GroupUpsertBulk) AddSoraVideoPricePerRequest(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.AddSoraVideoPricePerRequest(v)
})
}
// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequest() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraVideoPricePerRequest()
})
}
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequest() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraVideoPricePerRequest()
})
}
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
func (u *GroupUpsertBulk) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetSoraVideoPricePerRequestHd(v)
})
}
// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field.
func (u *GroupUpsertBulk) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.AddSoraVideoPricePerRequestHd(v)
})
}
// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequestHd() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraVideoPricePerRequestHd()
})
}
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequestHd() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraVideoPricePerRequestHd()
})
}
// SetClaudeCodeOnly sets the "claude_code_only" field. // SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk { func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) { return u.Update(func(s *GroupUpsert) {

View File

@@ -354,6 +354,114 @@ func (_u *GroupUpdate) ClearImagePrice4k() *GroupUpdate {
return _u return _u
} }
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
func (_u *GroupUpdate) SetSoraImagePrice360(v float64) *GroupUpdate {
_u.mutation.ResetSoraImagePrice360()
_u.mutation.SetSoraImagePrice360(v)
return _u
}
// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableSoraImagePrice360(v *float64) *GroupUpdate {
if v != nil {
_u.SetSoraImagePrice360(*v)
}
return _u
}
// AddSoraImagePrice360 adds value to the "sora_image_price_360" field.
func (_u *GroupUpdate) AddSoraImagePrice360(v float64) *GroupUpdate {
_u.mutation.AddSoraImagePrice360(v)
return _u
}
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
func (_u *GroupUpdate) ClearSoraImagePrice360() *GroupUpdate {
_u.mutation.ClearSoraImagePrice360()
return _u
}
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
func (_u *GroupUpdate) SetSoraImagePrice540(v float64) *GroupUpdate {
_u.mutation.ResetSoraImagePrice540()
_u.mutation.SetSoraImagePrice540(v)
return _u
}
// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableSoraImagePrice540(v *float64) *GroupUpdate {
if v != nil {
_u.SetSoraImagePrice540(*v)
}
return _u
}
// AddSoraImagePrice540 adds value to the "sora_image_price_540" field.
func (_u *GroupUpdate) AddSoraImagePrice540(v float64) *GroupUpdate {
_u.mutation.AddSoraImagePrice540(v)
return _u
}
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
func (_u *GroupUpdate) ClearSoraImagePrice540() *GroupUpdate {
_u.mutation.ClearSoraImagePrice540()
return _u
}
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
func (_u *GroupUpdate) SetSoraVideoPricePerRequest(v float64) *GroupUpdate {
_u.mutation.ResetSoraVideoPricePerRequest()
_u.mutation.SetSoraVideoPricePerRequest(v)
return _u
}
// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdate {
if v != nil {
_u.SetSoraVideoPricePerRequest(*v)
}
return _u
}
// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field.
func (_u *GroupUpdate) AddSoraVideoPricePerRequest(v float64) *GroupUpdate {
_u.mutation.AddSoraVideoPricePerRequest(v)
return _u
}
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
func (_u *GroupUpdate) ClearSoraVideoPricePerRequest() *GroupUpdate {
_u.mutation.ClearSoraVideoPricePerRequest()
return _u
}
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
func (_u *GroupUpdate) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdate {
_u.mutation.ResetSoraVideoPricePerRequestHd()
_u.mutation.SetSoraVideoPricePerRequestHd(v)
return _u
}
// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdate {
if v != nil {
_u.SetSoraVideoPricePerRequestHd(*v)
}
return _u
}
// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field.
func (_u *GroupUpdate) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdate {
_u.mutation.AddSoraVideoPricePerRequestHd(v)
return _u
}
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
func (_u *GroupUpdate) ClearSoraVideoPricePerRequestHd() *GroupUpdate {
_u.mutation.ClearSoraVideoPricePerRequestHd()
return _u
}
// SetClaudeCodeOnly sets the "claude_code_only" field. // SetClaudeCodeOnly sets the "claude_code_only" field.
func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate { func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate {
_u.mutation.SetClaudeCodeOnly(v) _u.mutation.SetClaudeCodeOnly(v)
@@ -817,6 +925,42 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.ImagePrice4kCleared() { if _u.mutation.ImagePrice4kCleared() {
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64) _spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
} }
if value, ok := _u.mutation.SoraImagePrice360(); ok {
_spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraImagePrice360(); ok {
_spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
}
if _u.mutation.SoraImagePrice360Cleared() {
_spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64)
}
if value, ok := _u.mutation.SoraImagePrice540(); ok {
_spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraImagePrice540(); ok {
_spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
}
if _u.mutation.SoraImagePrice540Cleared() {
_spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64)
}
if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok {
_spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok {
_spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
}
if _u.mutation.SoraVideoPricePerRequestCleared() {
_spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64)
}
if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok {
_spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok {
_spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
}
if _u.mutation.SoraVideoPricePerRequestHdCleared() {
_spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64)
}
if value, ok := _u.mutation.ClaudeCodeOnly(); ok { if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
} }
@@ -1472,6 +1616,114 @@ func (_u *GroupUpdateOne) ClearImagePrice4k() *GroupUpdateOne {
return _u return _u
} }
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
func (_u *GroupUpdateOne) SetSoraImagePrice360(v float64) *GroupUpdateOne {
_u.mutation.ResetSoraImagePrice360()
_u.mutation.SetSoraImagePrice360(v)
return _u
}
// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableSoraImagePrice360(v *float64) *GroupUpdateOne {
if v != nil {
_u.SetSoraImagePrice360(*v)
}
return _u
}
// AddSoraImagePrice360 adds value to the "sora_image_price_360" field.
func (_u *GroupUpdateOne) AddSoraImagePrice360(v float64) *GroupUpdateOne {
_u.mutation.AddSoraImagePrice360(v)
return _u
}
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
func (_u *GroupUpdateOne) ClearSoraImagePrice360() *GroupUpdateOne {
_u.mutation.ClearSoraImagePrice360()
return _u
}
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
func (_u *GroupUpdateOne) SetSoraImagePrice540(v float64) *GroupUpdateOne {
_u.mutation.ResetSoraImagePrice540()
_u.mutation.SetSoraImagePrice540(v)
return _u
}
// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableSoraImagePrice540(v *float64) *GroupUpdateOne {
if v != nil {
_u.SetSoraImagePrice540(*v)
}
return _u
}
// AddSoraImagePrice540 adds value to the "sora_image_price_540" field.
func (_u *GroupUpdateOne) AddSoraImagePrice540(v float64) *GroupUpdateOne {
_u.mutation.AddSoraImagePrice540(v)
return _u
}
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
func (_u *GroupUpdateOne) ClearSoraImagePrice540() *GroupUpdateOne {
_u.mutation.ClearSoraImagePrice540()
return _u
}
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
func (_u *GroupUpdateOne) SetSoraVideoPricePerRequest(v float64) *GroupUpdateOne {
_u.mutation.ResetSoraVideoPricePerRequest()
_u.mutation.SetSoraVideoPricePerRequest(v)
return _u
}
// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdateOne {
if v != nil {
_u.SetSoraVideoPricePerRequest(*v)
}
return _u
}
// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field.
func (_u *GroupUpdateOne) AddSoraVideoPricePerRequest(v float64) *GroupUpdateOne {
_u.mutation.AddSoraVideoPricePerRequest(v)
return _u
}
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequest() *GroupUpdateOne {
_u.mutation.ClearSoraVideoPricePerRequest()
return _u
}
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
func (_u *GroupUpdateOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne {
_u.mutation.ResetSoraVideoPricePerRequestHd()
_u.mutation.SetSoraVideoPricePerRequestHd(v)
return _u
}
// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdateOne {
if v != nil {
_u.SetSoraVideoPricePerRequestHd(*v)
}
return _u
}
// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field.
func (_u *GroupUpdateOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne {
_u.mutation.AddSoraVideoPricePerRequestHd(v)
return _u
}
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequestHd() *GroupUpdateOne {
_u.mutation.ClearSoraVideoPricePerRequestHd()
return _u
}
// SetClaudeCodeOnly sets the "claude_code_only" field. // SetClaudeCodeOnly sets the "claude_code_only" field.
func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne { func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne {
_u.mutation.SetClaudeCodeOnly(v) _u.mutation.SetClaudeCodeOnly(v)
@@ -1965,6 +2217,42 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if _u.mutation.ImagePrice4kCleared() { if _u.mutation.ImagePrice4kCleared() {
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64) _spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
} }
if value, ok := _u.mutation.SoraImagePrice360(); ok {
_spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraImagePrice360(); ok {
_spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
}
if _u.mutation.SoraImagePrice360Cleared() {
_spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64)
}
if value, ok := _u.mutation.SoraImagePrice540(); ok {
_spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraImagePrice540(); ok {
_spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
}
if _u.mutation.SoraImagePrice540Cleared() {
_spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64)
}
if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok {
_spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok {
_spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
}
if _u.mutation.SoraVideoPricePerRequestCleared() {
_spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64)
}
if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok {
_spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok {
_spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
}
if _u.mutation.SoraVideoPricePerRequestHdCleared() {
_spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64)
}
if value, ok := _u.mutation.ClaudeCodeOnly(); ok { if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
} }

View File

@@ -224,6 +224,10 @@ var (
{Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "sora_image_price_360", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "sora_image_price_540", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "sora_video_price_per_request", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "sora_video_price_per_request_hd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "claude_code_only", Type: field.TypeBool, Default: false}, {Name: "claude_code_only", Type: field.TypeBool, Default: false},
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true}, {Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
{Name: "model_routing", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "model_routing", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
@@ -499,6 +503,7 @@ var (
{Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45}, {Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45},
{Name: "image_count", Type: field.TypeInt, Default: 0}, {Name: "image_count", Type: field.TypeInt, Default: 0},
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10}, {Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
{Name: "media_type", Type: field.TypeString, Nullable: true, Size: 16},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "api_key_id", Type: field.TypeInt64}, {Name: "api_key_id", Type: field.TypeInt64},
{Name: "account_id", Type: field.TypeInt64}, {Name: "account_id", Type: field.TypeInt64},
@@ -514,31 +519,31 @@ var (
ForeignKeys: []*schema.ForeignKey{ ForeignKeys: []*schema.ForeignKey{
{ {
Symbol: "usage_logs_api_keys_usage_logs", Symbol: "usage_logs_api_keys_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[26]}, Columns: []*schema.Column{UsageLogsColumns[27]},
RefColumns: []*schema.Column{APIKeysColumns[0]}, RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_accounts_usage_logs", Symbol: "usage_logs_accounts_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[27]}, Columns: []*schema.Column{UsageLogsColumns[28]},
RefColumns: []*schema.Column{AccountsColumns[0]}, RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_groups_usage_logs", Symbol: "usage_logs_groups_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[28]}, Columns: []*schema.Column{UsageLogsColumns[29]},
RefColumns: []*schema.Column{GroupsColumns[0]}, RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull, OnDelete: schema.SetNull,
}, },
{ {
Symbol: "usage_logs_users_usage_logs", Symbol: "usage_logs_users_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[29]}, Columns: []*schema.Column{UsageLogsColumns[30]},
RefColumns: []*schema.Column{UsersColumns[0]}, RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_user_subscriptions_usage_logs", Symbol: "usage_logs_user_subscriptions_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[30]}, Columns: []*schema.Column{UsageLogsColumns[31]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull, OnDelete: schema.SetNull,
}, },
@@ -547,32 +552,32 @@ var (
{ {
Name: "usagelog_user_id", Name: "usagelog_user_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[29]}, Columns: []*schema.Column{UsageLogsColumns[30]},
}, },
{ {
Name: "usagelog_api_key_id", Name: "usagelog_api_key_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[26]}, Columns: []*schema.Column{UsageLogsColumns[27]},
}, },
{ {
Name: "usagelog_account_id", Name: "usagelog_account_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[27]}, Columns: []*schema.Column{UsageLogsColumns[28]},
}, },
{ {
Name: "usagelog_group_id", Name: "usagelog_group_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28]}, Columns: []*schema.Column{UsageLogsColumns[29]},
}, },
{ {
Name: "usagelog_subscription_id", Name: "usagelog_subscription_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[30]}, Columns: []*schema.Column{UsageLogsColumns[31]},
}, },
{ {
Name: "usagelog_created_at", Name: "usagelog_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[25]}, Columns: []*schema.Column{UsageLogsColumns[26]},
}, },
{ {
Name: "usagelog_model", Name: "usagelog_model",
@@ -587,12 +592,12 @@ var (
{ {
Name: "usagelog_user_id_created_at", Name: "usagelog_user_id_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[25]}, Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[26]},
}, },
{ {
Name: "usagelog_api_key_id_created_at", Name: "usagelog_api_key_id_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[26], UsageLogsColumns[25]}, Columns: []*schema.Column{UsageLogsColumns[27], UsageLogsColumns[26]},
}, },
}, },
} }

View File

@@ -3836,61 +3836,69 @@ func (m *AccountGroupMutation) ResetEdge(name string) error {
// GroupMutation represents an operation that mutates the Group nodes in the graph. // GroupMutation represents an operation that mutates the Group nodes in the graph.
type GroupMutation struct { type GroupMutation struct {
config config
op Op op Op
typ string typ string
id *int64 id *int64
created_at *time.Time created_at *time.Time
updated_at *time.Time updated_at *time.Time
deleted_at *time.Time deleted_at *time.Time
name *string name *string
description *string description *string
rate_multiplier *float64 rate_multiplier *float64
addrate_multiplier *float64 addrate_multiplier *float64
is_exclusive *bool is_exclusive *bool
status *string status *string
platform *string platform *string
subscription_type *string subscription_type *string
daily_limit_usd *float64 daily_limit_usd *float64
adddaily_limit_usd *float64 adddaily_limit_usd *float64
weekly_limit_usd *float64 weekly_limit_usd *float64
addweekly_limit_usd *float64 addweekly_limit_usd *float64
monthly_limit_usd *float64 monthly_limit_usd *float64
addmonthly_limit_usd *float64 addmonthly_limit_usd *float64
default_validity_days *int default_validity_days *int
adddefault_validity_days *int adddefault_validity_days *int
image_price_1k *float64 image_price_1k *float64
addimage_price_1k *float64 addimage_price_1k *float64
image_price_2k *float64 image_price_2k *float64
addimage_price_2k *float64 addimage_price_2k *float64
image_price_4k *float64 image_price_4k *float64
addimage_price_4k *float64 addimage_price_4k *float64
claude_code_only *bool sora_image_price_360 *float64
fallback_group_id *int64 addsora_image_price_360 *float64
addfallback_group_id *int64 sora_image_price_540 *float64
model_routing *map[string][]int64 addsora_image_price_540 *float64
model_routing_enabled *bool sora_video_price_per_request *float64
clearedFields map[string]struct{} addsora_video_price_per_request *float64
api_keys map[int64]struct{} sora_video_price_per_request_hd *float64
removedapi_keys map[int64]struct{} addsora_video_price_per_request_hd *float64
clearedapi_keys bool claude_code_only *bool
redeem_codes map[int64]struct{} fallback_group_id *int64
removedredeem_codes map[int64]struct{} addfallback_group_id *int64
clearedredeem_codes bool model_routing *map[string][]int64
subscriptions map[int64]struct{} model_routing_enabled *bool
removedsubscriptions map[int64]struct{} clearedFields map[string]struct{}
clearedsubscriptions bool api_keys map[int64]struct{}
usage_logs map[int64]struct{} removedapi_keys map[int64]struct{}
removedusage_logs map[int64]struct{} clearedapi_keys bool
clearedusage_logs bool redeem_codes map[int64]struct{}
accounts map[int64]struct{} removedredeem_codes map[int64]struct{}
removedaccounts map[int64]struct{} clearedredeem_codes bool
clearedaccounts bool subscriptions map[int64]struct{}
allowed_users map[int64]struct{} removedsubscriptions map[int64]struct{}
removedallowed_users map[int64]struct{} clearedsubscriptions bool
clearedallowed_users bool usage_logs map[int64]struct{}
done bool removedusage_logs map[int64]struct{}
oldValue func(context.Context) (*Group, error) clearedusage_logs bool
predicates []predicate.Group accounts map[int64]struct{}
removedaccounts map[int64]struct{}
clearedaccounts bool
allowed_users map[int64]struct{}
removedallowed_users map[int64]struct{}
clearedallowed_users bool
done bool
oldValue func(context.Context) (*Group, error)
predicates []predicate.Group
} }
var _ ent.Mutation = (*GroupMutation)(nil) var _ ent.Mutation = (*GroupMutation)(nil)
@@ -4873,6 +4881,286 @@ func (m *GroupMutation) ResetImagePrice4k() {
delete(m.clearedFields, group.FieldImagePrice4k) delete(m.clearedFields, group.FieldImagePrice4k)
} }
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
func (m *GroupMutation) SetSoraImagePrice360(f float64) {
m.sora_image_price_360 = &f
m.addsora_image_price_360 = nil
}
// SoraImagePrice360 returns the value of the "sora_image_price_360" field in the mutation.
func (m *GroupMutation) SoraImagePrice360() (r float64, exists bool) {
v := m.sora_image_price_360
if v == nil {
return
}
return *v, true
}
// OldSoraImagePrice360 returns the old "sora_image_price_360" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldSoraImagePrice360(ctx context.Context) (v *float64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldSoraImagePrice360 is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldSoraImagePrice360 requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldSoraImagePrice360: %w", err)
}
return oldValue.SoraImagePrice360, nil
}
// AddSoraImagePrice360 adds f to the "sora_image_price_360" field.
func (m *GroupMutation) AddSoraImagePrice360(f float64) {
if m.addsora_image_price_360 != nil {
*m.addsora_image_price_360 += f
} else {
m.addsora_image_price_360 = &f
}
}
// AddedSoraImagePrice360 returns the value that was added to the "sora_image_price_360" field in this mutation.
func (m *GroupMutation) AddedSoraImagePrice360() (r float64, exists bool) {
v := m.addsora_image_price_360
if v == nil {
return
}
return *v, true
}
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
func (m *GroupMutation) ClearSoraImagePrice360() {
m.sora_image_price_360 = nil
m.addsora_image_price_360 = nil
m.clearedFields[group.FieldSoraImagePrice360] = struct{}{}
}
// SoraImagePrice360Cleared returns if the "sora_image_price_360" field was cleared in this mutation.
func (m *GroupMutation) SoraImagePrice360Cleared() bool {
_, ok := m.clearedFields[group.FieldSoraImagePrice360]
return ok
}
// ResetSoraImagePrice360 resets all changes to the "sora_image_price_360" field.
func (m *GroupMutation) ResetSoraImagePrice360() {
m.sora_image_price_360 = nil
m.addsora_image_price_360 = nil
delete(m.clearedFields, group.FieldSoraImagePrice360)
}
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
func (m *GroupMutation) SetSoraImagePrice540(f float64) {
m.sora_image_price_540 = &f
m.addsora_image_price_540 = nil
}
// SoraImagePrice540 returns the value of the "sora_image_price_540" field in the mutation.
func (m *GroupMutation) SoraImagePrice540() (r float64, exists bool) {
v := m.sora_image_price_540
if v == nil {
return
}
return *v, true
}
// OldSoraImagePrice540 returns the old "sora_image_price_540" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldSoraImagePrice540(ctx context.Context) (v *float64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldSoraImagePrice540 is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldSoraImagePrice540 requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldSoraImagePrice540: %w", err)
}
return oldValue.SoraImagePrice540, nil
}
// AddSoraImagePrice540 adds f to the "sora_image_price_540" field.
func (m *GroupMutation) AddSoraImagePrice540(f float64) {
if m.addsora_image_price_540 != nil {
*m.addsora_image_price_540 += f
} else {
m.addsora_image_price_540 = &f
}
}
// AddedSoraImagePrice540 returns the value that was added to the "sora_image_price_540" field in this mutation.
func (m *GroupMutation) AddedSoraImagePrice540() (r float64, exists bool) {
v := m.addsora_image_price_540
if v == nil {
return
}
return *v, true
}
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
func (m *GroupMutation) ClearSoraImagePrice540() {
m.sora_image_price_540 = nil
m.addsora_image_price_540 = nil
m.clearedFields[group.FieldSoraImagePrice540] = struct{}{}
}
// SoraImagePrice540Cleared returns if the "sora_image_price_540" field was cleared in this mutation.
func (m *GroupMutation) SoraImagePrice540Cleared() bool {
_, ok := m.clearedFields[group.FieldSoraImagePrice540]
return ok
}
// ResetSoraImagePrice540 resets all changes to the "sora_image_price_540" field.
func (m *GroupMutation) ResetSoraImagePrice540() {
m.sora_image_price_540 = nil
m.addsora_image_price_540 = nil
delete(m.clearedFields, group.FieldSoraImagePrice540)
}
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
func (m *GroupMutation) SetSoraVideoPricePerRequest(f float64) {
m.sora_video_price_per_request = &f
m.addsora_video_price_per_request = nil
}
// SoraVideoPricePerRequest returns the value of the "sora_video_price_per_request" field in the mutation.
func (m *GroupMutation) SoraVideoPricePerRequest() (r float64, exists bool) {
v := m.sora_video_price_per_request
if v == nil {
return
}
return *v, true
}
// OldSoraVideoPricePerRequest returns the old "sora_video_price_per_request" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldSoraVideoPricePerRequest(ctx context.Context) (v *float64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldSoraVideoPricePerRequest is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldSoraVideoPricePerRequest requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldSoraVideoPricePerRequest: %w", err)
}
return oldValue.SoraVideoPricePerRequest, nil
}
// AddSoraVideoPricePerRequest adds f to the "sora_video_price_per_request" field.
func (m *GroupMutation) AddSoraVideoPricePerRequest(f float64) {
if m.addsora_video_price_per_request != nil {
*m.addsora_video_price_per_request += f
} else {
m.addsora_video_price_per_request = &f
}
}
// AddedSoraVideoPricePerRequest returns the value that was added to the "sora_video_price_per_request" field in this mutation.
func (m *GroupMutation) AddedSoraVideoPricePerRequest() (r float64, exists bool) {
v := m.addsora_video_price_per_request
if v == nil {
return
}
return *v, true
}
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
func (m *GroupMutation) ClearSoraVideoPricePerRequest() {
m.sora_video_price_per_request = nil
m.addsora_video_price_per_request = nil
m.clearedFields[group.FieldSoraVideoPricePerRequest] = struct{}{}
}
// SoraVideoPricePerRequestCleared returns if the "sora_video_price_per_request" field was cleared in this mutation.
func (m *GroupMutation) SoraVideoPricePerRequestCleared() bool {
_, ok := m.clearedFields[group.FieldSoraVideoPricePerRequest]
return ok
}
// ResetSoraVideoPricePerRequest resets all changes to the "sora_video_price_per_request" field.
func (m *GroupMutation) ResetSoraVideoPricePerRequest() {
m.sora_video_price_per_request = nil
m.addsora_video_price_per_request = nil
delete(m.clearedFields, group.FieldSoraVideoPricePerRequest)
}
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
func (m *GroupMutation) SetSoraVideoPricePerRequestHd(f float64) {
m.sora_video_price_per_request_hd = &f
m.addsora_video_price_per_request_hd = nil
}
// SoraVideoPricePerRequestHd returns the value of the "sora_video_price_per_request_hd" field in the mutation.
func (m *GroupMutation) SoraVideoPricePerRequestHd() (r float64, exists bool) {
v := m.sora_video_price_per_request_hd
if v == nil {
return
}
return *v, true
}
// OldSoraVideoPricePerRequestHd returns the old "sora_video_price_per_request_hd" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldSoraVideoPricePerRequestHd(ctx context.Context) (v *float64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldSoraVideoPricePerRequestHd is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldSoraVideoPricePerRequestHd requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldSoraVideoPricePerRequestHd: %w", err)
}
return oldValue.SoraVideoPricePerRequestHd, nil
}
// AddSoraVideoPricePerRequestHd adds f to the "sora_video_price_per_request_hd" field.
func (m *GroupMutation) AddSoraVideoPricePerRequestHd(f float64) {
if m.addsora_video_price_per_request_hd != nil {
*m.addsora_video_price_per_request_hd += f
} else {
m.addsora_video_price_per_request_hd = &f
}
}
// AddedSoraVideoPricePerRequestHd returns the value that was added to the "sora_video_price_per_request_hd" field in this mutation.
func (m *GroupMutation) AddedSoraVideoPricePerRequestHd() (r float64, exists bool) {
v := m.addsora_video_price_per_request_hd
if v == nil {
return
}
return *v, true
}
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
func (m *GroupMutation) ClearSoraVideoPricePerRequestHd() {
m.sora_video_price_per_request_hd = nil
m.addsora_video_price_per_request_hd = nil
m.clearedFields[group.FieldSoraVideoPricePerRequestHd] = struct{}{}
}
// SoraVideoPricePerRequestHdCleared returns if the "sora_video_price_per_request_hd" field was cleared in this mutation.
func (m *GroupMutation) SoraVideoPricePerRequestHdCleared() bool {
_, ok := m.clearedFields[group.FieldSoraVideoPricePerRequestHd]
return ok
}
// ResetSoraVideoPricePerRequestHd resets all changes to the "sora_video_price_per_request_hd" field.
func (m *GroupMutation) ResetSoraVideoPricePerRequestHd() {
m.sora_video_price_per_request_hd = nil
m.addsora_video_price_per_request_hd = nil
delete(m.clearedFields, group.FieldSoraVideoPricePerRequestHd)
}
// SetClaudeCodeOnly sets the "claude_code_only" field. // SetClaudeCodeOnly sets the "claude_code_only" field.
func (m *GroupMutation) SetClaudeCodeOnly(b bool) { func (m *GroupMutation) SetClaudeCodeOnly(b bool) {
m.claude_code_only = &b m.claude_code_only = &b
@@ -5422,7 +5710,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *GroupMutation) Fields() []string { func (m *GroupMutation) Fields() []string {
fields := make([]string, 0, 21) fields := make([]string, 0, 25)
if m.created_at != nil { if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt) fields = append(fields, group.FieldCreatedAt)
} }
@@ -5474,6 +5762,18 @@ func (m *GroupMutation) Fields() []string {
if m.image_price_4k != nil { if m.image_price_4k != nil {
fields = append(fields, group.FieldImagePrice4k) fields = append(fields, group.FieldImagePrice4k)
} }
if m.sora_image_price_360 != nil {
fields = append(fields, group.FieldSoraImagePrice360)
}
if m.sora_image_price_540 != nil {
fields = append(fields, group.FieldSoraImagePrice540)
}
if m.sora_video_price_per_request != nil {
fields = append(fields, group.FieldSoraVideoPricePerRequest)
}
if m.sora_video_price_per_request_hd != nil {
fields = append(fields, group.FieldSoraVideoPricePerRequestHd)
}
if m.claude_code_only != nil { if m.claude_code_only != nil {
fields = append(fields, group.FieldClaudeCodeOnly) fields = append(fields, group.FieldClaudeCodeOnly)
} }
@@ -5528,6 +5828,14 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.ImagePrice2k() return m.ImagePrice2k()
case group.FieldImagePrice4k: case group.FieldImagePrice4k:
return m.ImagePrice4k() return m.ImagePrice4k()
case group.FieldSoraImagePrice360:
return m.SoraImagePrice360()
case group.FieldSoraImagePrice540:
return m.SoraImagePrice540()
case group.FieldSoraVideoPricePerRequest:
return m.SoraVideoPricePerRequest()
case group.FieldSoraVideoPricePerRequestHd:
return m.SoraVideoPricePerRequestHd()
case group.FieldClaudeCodeOnly: case group.FieldClaudeCodeOnly:
return m.ClaudeCodeOnly() return m.ClaudeCodeOnly()
case group.FieldFallbackGroupID: case group.FieldFallbackGroupID:
@@ -5579,6 +5887,14 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldImagePrice2k(ctx) return m.OldImagePrice2k(ctx)
case group.FieldImagePrice4k: case group.FieldImagePrice4k:
return m.OldImagePrice4k(ctx) return m.OldImagePrice4k(ctx)
case group.FieldSoraImagePrice360:
return m.OldSoraImagePrice360(ctx)
case group.FieldSoraImagePrice540:
return m.OldSoraImagePrice540(ctx)
case group.FieldSoraVideoPricePerRequest:
return m.OldSoraVideoPricePerRequest(ctx)
case group.FieldSoraVideoPricePerRequestHd:
return m.OldSoraVideoPricePerRequestHd(ctx)
case group.FieldClaudeCodeOnly: case group.FieldClaudeCodeOnly:
return m.OldClaudeCodeOnly(ctx) return m.OldClaudeCodeOnly(ctx)
case group.FieldFallbackGroupID: case group.FieldFallbackGroupID:
@@ -5715,6 +6031,34 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
} }
m.SetImagePrice4k(v) m.SetImagePrice4k(v)
return nil return nil
case group.FieldSoraImagePrice360:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetSoraImagePrice360(v)
return nil
case group.FieldSoraImagePrice540:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetSoraImagePrice540(v)
return nil
case group.FieldSoraVideoPricePerRequest:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetSoraVideoPricePerRequest(v)
return nil
case group.FieldSoraVideoPricePerRequestHd:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetSoraVideoPricePerRequestHd(v)
return nil
case group.FieldClaudeCodeOnly: case group.FieldClaudeCodeOnly:
v, ok := value.(bool) v, ok := value.(bool)
if !ok { if !ok {
@@ -5775,6 +6119,18 @@ func (m *GroupMutation) AddedFields() []string {
if m.addimage_price_4k != nil { if m.addimage_price_4k != nil {
fields = append(fields, group.FieldImagePrice4k) fields = append(fields, group.FieldImagePrice4k)
} }
if m.addsora_image_price_360 != nil {
fields = append(fields, group.FieldSoraImagePrice360)
}
if m.addsora_image_price_540 != nil {
fields = append(fields, group.FieldSoraImagePrice540)
}
if m.addsora_video_price_per_request != nil {
fields = append(fields, group.FieldSoraVideoPricePerRequest)
}
if m.addsora_video_price_per_request_hd != nil {
fields = append(fields, group.FieldSoraVideoPricePerRequestHd)
}
if m.addfallback_group_id != nil { if m.addfallback_group_id != nil {
fields = append(fields, group.FieldFallbackGroupID) fields = append(fields, group.FieldFallbackGroupID)
} }
@@ -5802,6 +6158,14 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedImagePrice2k() return m.AddedImagePrice2k()
case group.FieldImagePrice4k: case group.FieldImagePrice4k:
return m.AddedImagePrice4k() return m.AddedImagePrice4k()
case group.FieldSoraImagePrice360:
return m.AddedSoraImagePrice360()
case group.FieldSoraImagePrice540:
return m.AddedSoraImagePrice540()
case group.FieldSoraVideoPricePerRequest:
return m.AddedSoraVideoPricePerRequest()
case group.FieldSoraVideoPricePerRequestHd:
return m.AddedSoraVideoPricePerRequestHd()
case group.FieldFallbackGroupID: case group.FieldFallbackGroupID:
return m.AddedFallbackGroupID() return m.AddedFallbackGroupID()
} }
@@ -5869,6 +6233,34 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
} }
m.AddImagePrice4k(v) m.AddImagePrice4k(v)
return nil return nil
case group.FieldSoraImagePrice360:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddSoraImagePrice360(v)
return nil
case group.FieldSoraImagePrice540:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddSoraImagePrice540(v)
return nil
case group.FieldSoraVideoPricePerRequest:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddSoraVideoPricePerRequest(v)
return nil
case group.FieldSoraVideoPricePerRequestHd:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddSoraVideoPricePerRequestHd(v)
return nil
case group.FieldFallbackGroupID: case group.FieldFallbackGroupID:
v, ok := value.(int64) v, ok := value.(int64)
if !ok { if !ok {
@@ -5908,6 +6300,18 @@ func (m *GroupMutation) ClearedFields() []string {
if m.FieldCleared(group.FieldImagePrice4k) { if m.FieldCleared(group.FieldImagePrice4k) {
fields = append(fields, group.FieldImagePrice4k) fields = append(fields, group.FieldImagePrice4k)
} }
if m.FieldCleared(group.FieldSoraImagePrice360) {
fields = append(fields, group.FieldSoraImagePrice360)
}
if m.FieldCleared(group.FieldSoraImagePrice540) {
fields = append(fields, group.FieldSoraImagePrice540)
}
if m.FieldCleared(group.FieldSoraVideoPricePerRequest) {
fields = append(fields, group.FieldSoraVideoPricePerRequest)
}
if m.FieldCleared(group.FieldSoraVideoPricePerRequestHd) {
fields = append(fields, group.FieldSoraVideoPricePerRequestHd)
}
if m.FieldCleared(group.FieldFallbackGroupID) { if m.FieldCleared(group.FieldFallbackGroupID) {
fields = append(fields, group.FieldFallbackGroupID) fields = append(fields, group.FieldFallbackGroupID)
} }
@@ -5952,6 +6356,18 @@ func (m *GroupMutation) ClearField(name string) error {
case group.FieldImagePrice4k: case group.FieldImagePrice4k:
m.ClearImagePrice4k() m.ClearImagePrice4k()
return nil return nil
case group.FieldSoraImagePrice360:
m.ClearSoraImagePrice360()
return nil
case group.FieldSoraImagePrice540:
m.ClearSoraImagePrice540()
return nil
case group.FieldSoraVideoPricePerRequest:
m.ClearSoraVideoPricePerRequest()
return nil
case group.FieldSoraVideoPricePerRequestHd:
m.ClearSoraVideoPricePerRequestHd()
return nil
case group.FieldFallbackGroupID: case group.FieldFallbackGroupID:
m.ClearFallbackGroupID() m.ClearFallbackGroupID()
return nil return nil
@@ -6017,6 +6433,18 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldImagePrice4k: case group.FieldImagePrice4k:
m.ResetImagePrice4k() m.ResetImagePrice4k()
return nil return nil
case group.FieldSoraImagePrice360:
m.ResetSoraImagePrice360()
return nil
case group.FieldSoraImagePrice540:
m.ResetSoraImagePrice540()
return nil
case group.FieldSoraVideoPricePerRequest:
m.ResetSoraVideoPricePerRequest()
return nil
case group.FieldSoraVideoPricePerRequestHd:
m.ResetSoraVideoPricePerRequestHd()
return nil
case group.FieldClaudeCodeOnly: case group.FieldClaudeCodeOnly:
m.ResetClaudeCodeOnly() m.ResetClaudeCodeOnly()
return nil return nil
@@ -11504,6 +11932,7 @@ type UsageLogMutation struct {
image_count *int image_count *int
addimage_count *int addimage_count *int
image_size *string image_size *string
media_type *string
created_at *time.Time created_at *time.Time
clearedFields map[string]struct{} clearedFields map[string]struct{}
user *int64 user *int64
@@ -13130,6 +13559,55 @@ func (m *UsageLogMutation) ResetImageSize() {
delete(m.clearedFields, usagelog.FieldImageSize) delete(m.clearedFields, usagelog.FieldImageSize)
} }
// SetMediaType sets the "media_type" field.
func (m *UsageLogMutation) SetMediaType(s string) {
m.media_type = &s
}
// MediaType returns the value of the "media_type" field in the mutation.
func (m *UsageLogMutation) MediaType() (r string, exists bool) {
v := m.media_type
if v == nil {
return
}
return *v, true
}
// OldMediaType returns the old "media_type" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldMediaType(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldMediaType is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldMediaType requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldMediaType: %w", err)
}
return oldValue.MediaType, nil
}
// ClearMediaType clears the value of the "media_type" field.
func (m *UsageLogMutation) ClearMediaType() {
m.media_type = nil
m.clearedFields[usagelog.FieldMediaType] = struct{}{}
}
// MediaTypeCleared returns if the "media_type" field was cleared in this mutation.
func (m *UsageLogMutation) MediaTypeCleared() bool {
_, ok := m.clearedFields[usagelog.FieldMediaType]
return ok
}
// ResetMediaType resets all changes to the "media_type" field.
func (m *UsageLogMutation) ResetMediaType() {
m.media_type = nil
delete(m.clearedFields, usagelog.FieldMediaType)
}
// SetCreatedAt sets the "created_at" field. // SetCreatedAt sets the "created_at" field.
func (m *UsageLogMutation) SetCreatedAt(t time.Time) { func (m *UsageLogMutation) SetCreatedAt(t time.Time) {
m.created_at = &t m.created_at = &t
@@ -13335,7 +13813,7 @@ func (m *UsageLogMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *UsageLogMutation) Fields() []string { func (m *UsageLogMutation) Fields() []string {
fields := make([]string, 0, 30) fields := make([]string, 0, 31)
if m.user != nil { if m.user != nil {
fields = append(fields, usagelog.FieldUserID) fields = append(fields, usagelog.FieldUserID)
} }
@@ -13423,6 +13901,9 @@ func (m *UsageLogMutation) Fields() []string {
if m.image_size != nil { if m.image_size != nil {
fields = append(fields, usagelog.FieldImageSize) fields = append(fields, usagelog.FieldImageSize)
} }
if m.media_type != nil {
fields = append(fields, usagelog.FieldMediaType)
}
if m.created_at != nil { if m.created_at != nil {
fields = append(fields, usagelog.FieldCreatedAt) fields = append(fields, usagelog.FieldCreatedAt)
} }
@@ -13492,6 +13973,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.ImageCount() return m.ImageCount()
case usagelog.FieldImageSize: case usagelog.FieldImageSize:
return m.ImageSize() return m.ImageSize()
case usagelog.FieldMediaType:
return m.MediaType()
case usagelog.FieldCreatedAt: case usagelog.FieldCreatedAt:
return m.CreatedAt() return m.CreatedAt()
} }
@@ -13561,6 +14044,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldImageCount(ctx) return m.OldImageCount(ctx)
case usagelog.FieldImageSize: case usagelog.FieldImageSize:
return m.OldImageSize(ctx) return m.OldImageSize(ctx)
case usagelog.FieldMediaType:
return m.OldMediaType(ctx)
case usagelog.FieldCreatedAt: case usagelog.FieldCreatedAt:
return m.OldCreatedAt(ctx) return m.OldCreatedAt(ctx)
} }
@@ -13775,6 +14260,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
} }
m.SetImageSize(v) m.SetImageSize(v)
return nil return nil
case usagelog.FieldMediaType:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetMediaType(v)
return nil
case usagelog.FieldCreatedAt: case usagelog.FieldCreatedAt:
v, ok := value.(time.Time) v, ok := value.(time.Time)
if !ok { if !ok {
@@ -14055,6 +14547,9 @@ func (m *UsageLogMutation) ClearedFields() []string {
if m.FieldCleared(usagelog.FieldImageSize) { if m.FieldCleared(usagelog.FieldImageSize) {
fields = append(fields, usagelog.FieldImageSize) fields = append(fields, usagelog.FieldImageSize)
} }
if m.FieldCleared(usagelog.FieldMediaType) {
fields = append(fields, usagelog.FieldMediaType)
}
return fields return fields
} }
@@ -14093,6 +14588,9 @@ func (m *UsageLogMutation) ClearField(name string) error {
case usagelog.FieldImageSize: case usagelog.FieldImageSize:
m.ClearImageSize() m.ClearImageSize()
return nil return nil
case usagelog.FieldMediaType:
m.ClearMediaType()
return nil
} }
return fmt.Errorf("unknown UsageLog nullable field %s", name) return fmt.Errorf("unknown UsageLog nullable field %s", name)
} }
@@ -14188,6 +14686,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldImageSize: case usagelog.FieldImageSize:
m.ResetImageSize() m.ResetImageSize()
return nil return nil
case usagelog.FieldMediaType:
m.ResetMediaType()
return nil
case usagelog.FieldCreatedAt: case usagelog.FieldCreatedAt:
m.ResetCreatedAt() m.ResetCreatedAt()
return nil return nil

View File

@@ -278,11 +278,11 @@ func init() {
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field. // group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int) group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
// groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field. // groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field.
groupDescClaudeCodeOnly := groupFields[14].Descriptor() groupDescClaudeCodeOnly := groupFields[18].Descriptor()
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field. // group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool) group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field. // groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
groupDescModelRoutingEnabled := groupFields[17].Descriptor() groupDescModelRoutingEnabled := groupFields[21].Descriptor()
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field. // group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool) group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
promocodeFields := schema.PromoCode{}.Fields() promocodeFields := schema.PromoCode{}.Fields()
@@ -647,8 +647,12 @@ func init() {
usagelogDescImageSize := usagelogFields[28].Descriptor() usagelogDescImageSize := usagelogFields[28].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
// usagelogDescMediaType is the schema descriptor for media_type field.
usagelogDescMediaType := usagelogFields[29].Descriptor()
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
// usagelogDescCreatedAt is the schema descriptor for created_at field. // usagelogDescCreatedAt is the schema descriptor for created_at field.
usagelogDescCreatedAt := usagelogFields[29].Descriptor() usagelogDescCreatedAt := usagelogFields[30].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin() userMixin := schema.User{}.Mixin()

View File

@@ -87,6 +87,24 @@ func (Group) Fields() []ent.Field {
Nillable(). Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
// Sora 按次计费配置(阶段 1
field.Float("sora_image_price_360").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
field.Float("sora_image_price_540").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
field.Float("sora_video_price_per_request").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
field.Float("sora_video_price_per_request_hd").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
// Claude Code 客户端限制 (added by migration 029) // Claude Code 客户端限制 (added by migration 029)
field.Bool("claude_code_only"). field.Bool("claude_code_only").
Default(false). Default(false).

View File

@@ -118,6 +118,11 @@ func (UsageLog) Fields() []ent.Field {
MaxLen(10). MaxLen(10).
Optional(). Optional().
Nillable(), Nillable(),
// 媒体类型字段sora 使用)
field.String("media_type").
MaxLen(16).
Optional().
Nillable(),
// 时间戳(只有 created_at日志不可修改 // 时间戳(只有 created_at日志不可修改
field.Time("created_at"). field.Time("created_at").

View File

@@ -80,6 +80,8 @@ type UsageLog struct {
ImageCount int `json:"image_count,omitempty"` ImageCount int `json:"image_count,omitempty"`
// ImageSize holds the value of the "image_size" field. // ImageSize holds the value of the "image_size" field.
ImageSize *string `json:"image_size,omitempty"` ImageSize *string `json:"image_size,omitempty"`
// MediaType holds the value of the "media_type" field.
MediaType *string `json:"media_type,omitempty"`
// CreatedAt holds the value of the "created_at" field. // CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"` CreatedAt time.Time `json:"created_at,omitempty"`
// Edges holds the relations/edges for other nodes in the graph. // Edges holds the relations/edges for other nodes in the graph.
@@ -171,7 +173,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullFloat64) values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize: case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
values[i] = new(sql.NullString) values[i] = new(sql.NullString)
case usagelog.FieldCreatedAt: case usagelog.FieldCreatedAt:
values[i] = new(sql.NullTime) values[i] = new(sql.NullTime)
@@ -378,6 +380,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
_m.ImageSize = new(string) _m.ImageSize = new(string)
*_m.ImageSize = value.String *_m.ImageSize = value.String
} }
case usagelog.FieldMediaType:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field media_type", values[i])
} else if value.Valid {
_m.MediaType = new(string)
*_m.MediaType = value.String
}
case usagelog.FieldCreatedAt: case usagelog.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok { if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i]) return fmt.Errorf("unexpected type %T for field created_at", values[i])
@@ -548,6 +557,11 @@ func (_m *UsageLog) String() string {
builder.WriteString(*v) builder.WriteString(*v)
} }
builder.WriteString(", ") builder.WriteString(", ")
if v := _m.MediaType; v != nil {
builder.WriteString("media_type=")
builder.WriteString(*v)
}
builder.WriteString(", ")
builder.WriteString("created_at=") builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteByte(')') builder.WriteByte(')')

View File

@@ -72,6 +72,8 @@ const (
FieldImageCount = "image_count" FieldImageCount = "image_count"
// FieldImageSize holds the string denoting the image_size field in the database. // FieldImageSize holds the string denoting the image_size field in the database.
FieldImageSize = "image_size" FieldImageSize = "image_size"
// FieldMediaType holds the string denoting the media_type field in the database.
FieldMediaType = "media_type"
// FieldCreatedAt holds the string denoting the created_at field in the database. // FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at" FieldCreatedAt = "created_at"
// EdgeUser holds the string denoting the user edge name in mutations. // EdgeUser holds the string denoting the user edge name in mutations.
@@ -155,6 +157,7 @@ var Columns = []string{
FieldIPAddress, FieldIPAddress,
FieldImageCount, FieldImageCount,
FieldImageSize, FieldImageSize,
FieldMediaType,
FieldCreatedAt, FieldCreatedAt,
} }
@@ -211,6 +214,8 @@ var (
DefaultImageCount int DefaultImageCount int
// ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. // ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
ImageSizeValidator func(string) error ImageSizeValidator func(string) error
// MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
MediaTypeValidator func(string) error
// DefaultCreatedAt holds the default value on creation for the "created_at" field. // DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time DefaultCreatedAt func() time.Time
) )
@@ -368,6 +373,11 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImageSize, opts...).ToFunc() return sql.OrderByField(FieldImageSize, opts...).ToFunc()
} }
// ByMediaType orders the results by the media_type field.
func ByMediaType(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldMediaType, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field. // ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()

View File

@@ -200,6 +200,11 @@ func ImageSize(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v)) return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v))
} }
// MediaType applies equality check predicate on the "media_type" field. It's identical to MediaTypeEQ.
func MediaType(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.UsageLog { func CreatedAt(v time.Time) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v)) return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v))
@@ -1440,6 +1445,81 @@ func ImageSizeContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v)) return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v))
} }
// MediaTypeEQ applies the EQ predicate on the "media_type" field.
func MediaTypeEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v))
}
// MediaTypeNEQ applies the NEQ predicate on the "media_type" field.
func MediaTypeNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldMediaType, v))
}
// MediaTypeIn applies the In predicate on the "media_type" field.
func MediaTypeIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldMediaType, vs...))
}
// MediaTypeNotIn applies the NotIn predicate on the "media_type" field.
func MediaTypeNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldMediaType, vs...))
}
// MediaTypeGT applies the GT predicate on the "media_type" field.
func MediaTypeGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldMediaType, v))
}
// MediaTypeGTE applies the GTE predicate on the "media_type" field.
func MediaTypeGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldMediaType, v))
}
// MediaTypeLT applies the LT predicate on the "media_type" field.
func MediaTypeLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldMediaType, v))
}
// MediaTypeLTE applies the LTE predicate on the "media_type" field.
func MediaTypeLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldMediaType, v))
}
// MediaTypeContains applies the Contains predicate on the "media_type" field.
func MediaTypeContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldMediaType, v))
}
// MediaTypeHasPrefix applies the HasPrefix predicate on the "media_type" field.
func MediaTypeHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldMediaType, v))
}
// MediaTypeHasSuffix applies the HasSuffix predicate on the "media_type" field.
func MediaTypeHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldMediaType, v))
}
// MediaTypeIsNil applies the IsNil predicate on the "media_type" field.
func MediaTypeIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldMediaType))
}
// MediaTypeNotNil applies the NotNil predicate on the "media_type" field.
func MediaTypeNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldMediaType))
}
// MediaTypeEqualFold applies the EqualFold predicate on the "media_type" field.
func MediaTypeEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldMediaType, v))
}
// MediaTypeContainsFold applies the ContainsFold predicate on the "media_type" field.
func MediaTypeContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldMediaType, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field. // CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.UsageLog { func CreatedAtEQ(v time.Time) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v)) return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v))

View File

@@ -393,6 +393,20 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate {
return _c return _c
} }
// SetMediaType sets the "media_type" field.
func (_c *UsageLogCreate) SetMediaType(v string) *UsageLogCreate {
_c.mutation.SetMediaType(v)
return _c
}
// SetNillableMediaType sets the "media_type" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableMediaType(v *string) *UsageLogCreate {
if v != nil {
_c.SetMediaType(*v)
}
return _c
}
// SetCreatedAt sets the "created_at" field. // SetCreatedAt sets the "created_at" field.
func (_c *UsageLogCreate) SetCreatedAt(v time.Time) *UsageLogCreate { func (_c *UsageLogCreate) SetCreatedAt(v time.Time) *UsageLogCreate {
_c.mutation.SetCreatedAt(v) _c.mutation.SetCreatedAt(v)
@@ -627,6 +641,11 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
} }
} }
if v, ok := _c.mutation.MediaType(); ok {
if err := usagelog.MediaTypeValidator(v); err != nil {
return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
}
}
if _, ok := _c.mutation.CreatedAt(); !ok { if _, ok := _c.mutation.CreatedAt(); !ok {
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UsageLog.created_at"`)} return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UsageLog.created_at"`)}
} }
@@ -762,6 +781,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldImageSize, field.TypeString, value) _spec.SetField(usagelog.FieldImageSize, field.TypeString, value)
_node.ImageSize = &value _node.ImageSize = &value
} }
if value, ok := _c.mutation.MediaType(); ok {
_spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
_node.MediaType = &value
}
if value, ok := _c.mutation.CreatedAt(); ok { if value, ok := _c.mutation.CreatedAt(); ok {
_spec.SetField(usagelog.FieldCreatedAt, field.TypeTime, value) _spec.SetField(usagelog.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = value _node.CreatedAt = value
@@ -1407,6 +1430,24 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert {
return u return u
} }
// SetMediaType sets the "media_type" field.
func (u *UsageLogUpsert) SetMediaType(v string) *UsageLogUpsert {
u.Set(usagelog.FieldMediaType, v)
return u
}
// UpdateMediaType sets the "media_type" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateMediaType() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldMediaType)
return u
}
// ClearMediaType clears the value of the "media_type" field.
func (u *UsageLogUpsert) ClearMediaType() *UsageLogUpsert {
u.SetNull(usagelog.FieldMediaType)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create. // UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using: // Using this option is equivalent to using:
// //
@@ -2040,6 +2081,27 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne {
}) })
} }
// SetMediaType sets the "media_type" field.
func (u *UsageLogUpsertOne) SetMediaType(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetMediaType(v)
})
}
// UpdateMediaType sets the "media_type" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateMediaType() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateMediaType()
})
}
// ClearMediaType clears the value of the "media_type" field.
func (u *UsageLogUpsertOne) ClearMediaType() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearMediaType()
})
}
// Exec executes the query. // Exec executes the query.
func (u *UsageLogUpsertOne) Exec(ctx context.Context) error { func (u *UsageLogUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 { if len(u.create.conflict) == 0 {
@@ -2839,6 +2901,27 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk {
}) })
} }
// SetMediaType sets the "media_type" field.
func (u *UsageLogUpsertBulk) SetMediaType(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetMediaType(v)
})
}
// UpdateMediaType sets the "media_type" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateMediaType() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateMediaType()
})
}
// ClearMediaType clears the value of the "media_type" field.
func (u *UsageLogUpsertBulk) ClearMediaType() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearMediaType()
})
}
// Exec executes the query. // Exec executes the query.
func (u *UsageLogUpsertBulk) Exec(ctx context.Context) error { func (u *UsageLogUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil { if u.create.err != nil {

View File

@@ -612,6 +612,26 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate {
return _u return _u
} }
// SetMediaType sets the "media_type" field.
func (_u *UsageLogUpdate) SetMediaType(v string) *UsageLogUpdate {
_u.mutation.SetMediaType(v)
return _u
}
// SetNillableMediaType sets the "media_type" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableMediaType(v *string) *UsageLogUpdate {
if v != nil {
_u.SetMediaType(*v)
}
return _u
}
// ClearMediaType clears the value of the "media_type" field.
func (_u *UsageLogUpdate) ClearMediaType() *UsageLogUpdate {
_u.mutation.ClearMediaType()
return _u
}
// SetUser sets the "user" edge to the User entity. // SetUser sets the "user" edge to the User entity.
func (_u *UsageLogUpdate) SetUser(v *User) *UsageLogUpdate { func (_u *UsageLogUpdate) SetUser(v *User) *UsageLogUpdate {
return _u.SetUserID(v.ID) return _u.SetUserID(v.ID)
@@ -726,6 +746,11 @@ func (_u *UsageLogUpdate) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
} }
} }
if v, ok := _u.mutation.MediaType(); ok {
if err := usagelog.MediaTypeValidator(v); err != nil {
return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
}
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`) return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
} }
@@ -894,6 +919,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.ImageSizeCleared() { if _u.mutation.ImageSizeCleared() {
_spec.ClearField(usagelog.FieldImageSize, field.TypeString) _spec.ClearField(usagelog.FieldImageSize, field.TypeString)
} }
if value, ok := _u.mutation.MediaType(); ok {
_spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
}
if _u.mutation.MediaTypeCleared() {
_spec.ClearField(usagelog.FieldMediaType, field.TypeString)
}
if _u.mutation.UserCleared() { if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O, Rel: sqlgraph.M2O,
@@ -1639,6 +1670,26 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne {
return _u return _u
} }
// SetMediaType sets the "media_type" field.
func (_u *UsageLogUpdateOne) SetMediaType(v string) *UsageLogUpdateOne {
_u.mutation.SetMediaType(v)
return _u
}
// SetNillableMediaType sets the "media_type" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableMediaType(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetMediaType(*v)
}
return _u
}
// ClearMediaType clears the value of the "media_type" field.
func (_u *UsageLogUpdateOne) ClearMediaType() *UsageLogUpdateOne {
_u.mutation.ClearMediaType()
return _u
}
// SetUser sets the "user" edge to the User entity. // SetUser sets the "user" edge to the User entity.
func (_u *UsageLogUpdateOne) SetUser(v *User) *UsageLogUpdateOne { func (_u *UsageLogUpdateOne) SetUser(v *User) *UsageLogUpdateOne {
return _u.SetUserID(v.ID) return _u.SetUserID(v.ID)
@@ -1766,6 +1817,11 @@ func (_u *UsageLogUpdateOne) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
} }
} }
if v, ok := _u.mutation.MediaType(); ok {
if err := usagelog.MediaTypeValidator(v); err != nil {
return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
}
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`) return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
} }
@@ -1951,6 +2007,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if _u.mutation.ImageSizeCleared() { if _u.mutation.ImageSizeCleared() {
_spec.ClearField(usagelog.FieldImageSize, field.TypeString) _spec.ClearField(usagelog.FieldImageSize, field.TypeString)
} }
if value, ok := _u.mutation.MediaType(); ok {
_spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
}
if _u.mutation.MediaTypeCleared() {
_spec.ClearField(usagelog.FieldMediaType, field.TypeString)
}
if _u.mutation.UserCleared() { if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O, Rel: sqlgraph.M2O,

View File

@@ -58,6 +58,7 @@ type Config struct {
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"` Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
Sora2API Sora2APIConfig `mapstructure:"sora2api"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"` RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Gemini GeminiConfig `mapstructure:"gemini"` Gemini GeminiConfig `mapstructure:"gemini"`
@@ -204,6 +205,24 @@ type ConcurrencyConfig struct {
PingInterval int `mapstructure:"ping_interval"` PingInterval int `mapstructure:"ping_interval"`
} }
// Sora2APIConfig Sora2API 服务配置
type Sora2APIConfig struct {
// BaseURL Sora2API 服务地址(例如 http://localhost:8000
BaseURL string `mapstructure:"base_url"`
// APIKey Sora2API OpenAI 兼容接口的 API Key
APIKey string `mapstructure:"api_key"`
// AdminUsername 管理员用户名(用于 token 同步)
AdminUsername string `mapstructure:"admin_username"`
// AdminPassword 管理员密码(用于 token 同步)
AdminPassword string `mapstructure:"admin_password"`
// AdminTokenTTLSeconds 管理员 Token 缓存时长(秒)
AdminTokenTTLSeconds int `mapstructure:"admin_token_ttl_seconds"`
// AdminTimeoutSeconds 管理接口请求超时(秒)
AdminTimeoutSeconds int `mapstructure:"admin_timeout_seconds"`
// TokenImportMode token 导入模式at/offline
TokenImportMode string `mapstructure:"token_import_mode"`
}
// GatewayConfig API网关相关配置 // GatewayConfig API网关相关配置
type GatewayConfig struct { type GatewayConfig struct {
// 等待上游响应头的超时时间0表示无超时 // 等待上游响应头的超时时间0表示无超时
@@ -258,6 +277,24 @@ type GatewayConfig struct {
// 是否允许对部分 400 错误触发 failover默认关闭以避免改变语义 // 是否允许对部分 400 错误触发 failover默认关闭以避免改变语义
FailoverOn400 bool `mapstructure:"failover_on_400"` FailoverOn400 bool `mapstructure:"failover_on_400"`
// Sora 专用配置
// SoraMaxBodySize: Sora 请求体最大字节数0 表示使用 gateway.max_body_size
SoraMaxBodySize int64 `mapstructure:"sora_max_body_size"`
// SoraStreamTimeoutSeconds: Sora 流式请求总超时0 表示不限制)
SoraStreamTimeoutSeconds int `mapstructure:"sora_stream_timeout_seconds"`
// SoraRequestTimeoutSeconds: Sora 非流式请求超时0 表示不限制)
SoraRequestTimeoutSeconds int `mapstructure:"sora_request_timeout_seconds"`
// SoraStreamMode: stream 强制策略force/error
SoraStreamMode string `mapstructure:"sora_stream_mode"`
// SoraModelFilters: 模型列表过滤配置
SoraModelFilters SoraModelFiltersConfig `mapstructure:"sora_model_filters"`
// SoraMediaRequireAPIKey: 是否要求访问 /sora/media 携带 API Key
SoraMediaRequireAPIKey bool `mapstructure:"sora_media_require_api_key"`
// SoraMediaSigningKey: /sora/media 临时签名密钥(空表示禁用签名)
SoraMediaSigningKey string `mapstructure:"sora_media_signing_key"`
// SoraMediaSignedURLTTLSeconds: 临时签名 URL 有效期(秒,<=0 表示禁用)
SoraMediaSignedURLTTLSeconds int `mapstructure:"sora_media_signed_url_ttl_seconds"`
// 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限) // 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限)
MaxAccountSwitches int `mapstructure:"max_account_switches"` MaxAccountSwitches int `mapstructure:"max_account_switches"`
// Gemini 账户切换最大次数Gemini 平台单独配置,因 API 限制更严格) // Gemini 账户切换最大次数Gemini 平台单独配置,因 API 限制更严格)
@@ -273,6 +310,12 @@ type GatewayConfig struct {
TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"` TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"`
} }
// SoraModelFiltersConfig Sora 模型过滤配置
type SoraModelFiltersConfig struct {
// HidePromptEnhance 是否隐藏 prompt-enhance 模型
HidePromptEnhance bool `mapstructure:"hide_prompt_enhance"`
}
// TLSFingerprintConfig TLS指纹伪装配置 // TLSFingerprintConfig TLS指纹伪装配置
// 用于模拟 Claude CLI (Node.js) 的 TLS 握手特征,避免被识别为非官方客户端 // 用于模拟 Claude CLI (Node.js) 的 TLS 握手特征,避免被识别为非官方客户端
type TLSFingerprintConfig struct { type TLSFingerprintConfig struct {
@@ -823,6 +866,13 @@ func setDefaults() {
viper.SetDefault("gateway.max_account_switches_gemini", 3) viper.SetDefault("gateway.max_account_switches_gemini", 3)
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024))
viper.SetDefault("gateway.sora_stream_timeout_seconds", 900)
viper.SetDefault("gateway.sora_request_timeout_seconds", 180)
viper.SetDefault("gateway.sora_stream_mode", "force")
viper.SetDefault("gateway.sora_model_filters.hide_prompt_enhance", true)
viper.SetDefault("gateway.sora_media_require_api_key", true)
viper.SetDefault("gateway.sora_media_signed_url_ttl_seconds", 900)
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化) // HTTP 上游连接池配置(针对 5000+ 并发用户优化)
viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数HTTP/2 场景默认) viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数HTTP/2 场景默认)
@@ -869,6 +919,15 @@ func setDefaults() {
viper.SetDefault("gemini.oauth.client_secret", "") viper.SetDefault("gemini.oauth.client_secret", "")
viper.SetDefault("gemini.oauth.scopes", "") viper.SetDefault("gemini.oauth.scopes", "")
viper.SetDefault("gemini.quota.policy", "") viper.SetDefault("gemini.quota.policy", "")
// Sora2API
viper.SetDefault("sora2api.base_url", "")
viper.SetDefault("sora2api.api_key", "")
viper.SetDefault("sora2api.admin_username", "")
viper.SetDefault("sora2api.admin_password", "")
viper.SetDefault("sora2api.admin_token_ttl_seconds", 900)
viper.SetDefault("sora2api.admin_timeout_seconds", 10)
viper.SetDefault("sora2api.token_import_mode", "at")
} }
func (c *Config) Validate() error { func (c *Config) Validate() error {
@@ -1085,6 +1144,25 @@ func (c *Config) Validate() error {
if c.Gateway.MaxBodySize <= 0 { if c.Gateway.MaxBodySize <= 0 {
return fmt.Errorf("gateway.max_body_size must be positive") return fmt.Errorf("gateway.max_body_size must be positive")
} }
if c.Gateway.SoraMaxBodySize < 0 {
return fmt.Errorf("gateway.sora_max_body_size must be non-negative")
}
if c.Gateway.SoraStreamTimeoutSeconds < 0 {
return fmt.Errorf("gateway.sora_stream_timeout_seconds must be non-negative")
}
if c.Gateway.SoraRequestTimeoutSeconds < 0 {
return fmt.Errorf("gateway.sora_request_timeout_seconds must be non-negative")
}
if c.Gateway.SoraMediaSignedURLTTLSeconds < 0 {
return fmt.Errorf("gateway.sora_media_signed_url_ttl_seconds must be non-negative")
}
if mode := strings.TrimSpace(strings.ToLower(c.Gateway.SoraStreamMode)); mode != "" {
switch mode {
case "force", "error":
default:
return fmt.Errorf("gateway.sora_stream_mode must be one of: force/error")
}
}
if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" { if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" {
switch c.Gateway.ConnectionPoolIsolation { switch c.Gateway.ConnectionPoolIsolation {
case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy: case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy:
@@ -1181,6 +1259,25 @@ func (c *Config) Validate() error {
c.Gateway.Scheduling.OutboxLagRebuildSeconds < c.Gateway.Scheduling.OutboxLagWarnSeconds { c.Gateway.Scheduling.OutboxLagRebuildSeconds < c.Gateway.Scheduling.OutboxLagWarnSeconds {
return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_seconds must be >= outbox_lag_warn_seconds") return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_seconds must be >= outbox_lag_warn_seconds")
} }
if strings.TrimSpace(c.Sora2API.BaseURL) != "" {
if err := ValidateAbsoluteHTTPURL(c.Sora2API.BaseURL); err != nil {
return fmt.Errorf("sora2api.base_url invalid: %w", err)
}
warnIfInsecureURL("sora2api.base_url", c.Sora2API.BaseURL)
}
if mode := strings.TrimSpace(strings.ToLower(c.Sora2API.TokenImportMode)); mode != "" {
switch mode {
case "at", "offline":
default:
return fmt.Errorf("sora2api.token_import_mode must be one of: at/offline")
}
}
if c.Sora2API.AdminTokenTTLSeconds < 0 {
return fmt.Errorf("sora2api.admin_token_ttl_seconds must be non-negative")
}
if c.Sora2API.AdminTimeoutSeconds < 0 {
return fmt.Errorf("sora2api.admin_timeout_seconds must be non-negative")
}
if c.Ops.MetricsCollectorCache.TTL < 0 { if c.Ops.MetricsCollectorCache.TTL < 0 {
return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative") return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative")
} }

View File

@@ -27,7 +27,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler {
type CreateGroupRequest struct { type CreateGroupRequest struct {
Name string `json:"name" binding:"required"` Name string `json:"name" binding:"required"`
Description string `json:"description"` Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier float64 `json:"rate_multiplier"` RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"` IsExclusive bool `json:"is_exclusive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"` SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
@@ -38,6 +38,10 @@ type CreateGroupRequest struct {
ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"` ImagePrice4K *float64 `json:"image_price_4k"`
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly bool `json:"claude_code_only"` ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupID *int64 `json:"fallback_group_id"`
// 模型路由配置(仅 anthropic 平台使用) // 模型路由配置(仅 anthropic 平台使用)
@@ -49,7 +53,7 @@ type CreateGroupRequest struct {
type UpdateGroupRequest struct { type UpdateGroupRequest struct {
Name string `json:"name"` Name string `json:"name"`
Description string `json:"description"` Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier *float64 `json:"rate_multiplier"` RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"` IsExclusive *bool `json:"is_exclusive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"` Status string `json:"status" binding:"omitempty,oneof=active inactive"`
@@ -61,6 +65,10 @@ type UpdateGroupRequest struct {
ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"` ImagePrice4K *float64 `json:"image_price_4k"`
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly *bool `json:"claude_code_only"` ClaudeCodeOnly *bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupID *int64 `json:"fallback_group_id"`
// 模型路由配置(仅 anthropic 平台使用) // 模型路由配置(仅 anthropic 平台使用)
@@ -167,6 +175,10 @@ func (h *GroupHandler) Create(c *gin.Context) {
ImagePrice1K: req.ImagePrice1K, ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K, ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K, ImagePrice4K: req.ImagePrice4K,
SoraImagePrice360: req.SoraImagePrice360,
SoraImagePrice540: req.SoraImagePrice540,
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: req.ClaudeCodeOnly, ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID, FallbackGroupID: req.FallbackGroupID,
ModelRouting: req.ModelRouting, ModelRouting: req.ModelRouting,
@@ -209,6 +221,10 @@ func (h *GroupHandler) Update(c *gin.Context) {
ImagePrice1K: req.ImagePrice1K, ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K, ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K, ImagePrice4K: req.ImagePrice4K,
SoraImagePrice360: req.SoraImagePrice360,
SoraImagePrice540: req.SoraImagePrice540,
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: req.ClaudeCodeOnly, ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID, FallbackGroupID: req.FallbackGroupID,
ModelRouting: req.ModelRouting, ModelRouting: req.ModelRouting,

View File

@@ -0,0 +1,55 @@
package admin
import (
"net/http"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ModelHandler handles admin model listing requests.
type ModelHandler struct {
sora2apiService *service.Sora2APIService
}
// NewModelHandler creates a new ModelHandler.
func NewModelHandler(sora2apiService *service.Sora2APIService) *ModelHandler {
return &ModelHandler{
sora2apiService: sora2apiService,
}
}
// List handles listing models for a specific platform
// GET /api/v1/admin/models?platform=sora
func (h *ModelHandler) List(c *gin.Context) {
platform := strings.TrimSpace(strings.ToLower(c.Query("platform")))
if platform == "" {
response.BadRequest(c, "platform is required")
return
}
switch platform {
case service.PlatformSora:
if h.sora2apiService == nil || !h.sora2apiService.Enabled() {
response.Error(c, http.StatusServiceUnavailable, "sora2api not configured")
return
}
models, err := h.sora2apiService.ListModels(c.Request.Context())
if err != nil {
response.Error(c, http.StatusServiceUnavailable, "failed to fetch sora models")
return
}
ids := make([]string, 0, len(models))
for _, m := range models {
if strings.TrimSpace(m.ID) != "" {
ids = append(ids, m.ID)
}
}
response.Success(c, ids)
default:
response.BadRequest(c, "unsupported platform")
}
}

View File

@@ -0,0 +1,87 @@
package admin
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
func TestModelHandlerListSoraSuccess(t *testing.T) {
gin.SetMode(gin.TestMode)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"object":"list","data":[{"id":"m1"},{"id":"m2"}]}`))
}))
t.Cleanup(upstream.Close)
cfg := &config.Config{}
cfg.Sora2API.BaseURL = upstream.URL
cfg.Sora2API.APIKey = "test-key"
soraService := service.NewSora2APIService(cfg)
h := NewModelHandler(soraService)
router := gin.New()
router.GET("/admin/models", h.List)
req := httptest.NewRequest(http.MethodGet, "/admin/models?platform=sora", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusOK {
t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String())
}
var resp response.Response
if err := json.Unmarshal(recorder.Body.Bytes(), &resp); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if resp.Code != 0 {
t.Fatalf("响应 code=%d", resp.Code)
}
data, ok := resp.Data.([]any)
if !ok {
t.Fatalf("响应 data 类型错误")
}
if len(data) != 2 {
t.Fatalf("模型数量不符: %d", len(data))
}
}
func TestModelHandlerListSoraNotConfigured(t *testing.T) {
gin.SetMode(gin.TestMode)
h := NewModelHandler(&service.Sora2APIService{})
router := gin.New()
router.GET("/admin/models", h.List)
req := httptest.NewRequest(http.MethodGet, "/admin/models?platform=sora", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String())
}
}
func TestModelHandlerListInvalidPlatform(t *testing.T) {
gin.SetMode(gin.TestMode)
h := NewModelHandler(&service.Sora2APIService{})
router := gin.New()
router.GET("/admin/models", h.List)
req := httptest.NewRequest(http.MethodGet, "/admin/models?platform=unknown", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusBadRequest {
t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String())
}
}

View File

@@ -136,6 +136,10 @@ func groupFromServiceBase(g *service.Group) Group {
ImagePrice1K: g.ImagePrice1K, ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K, ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K, ImagePrice4K: g.ImagePrice4K,
SoraImagePrice360: g.SoraImagePrice360,
SoraImagePrice540: g.SoraImagePrice540,
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: g.ClaudeCodeOnly, ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID, FallbackGroupID: g.FallbackGroupID,
CreatedAt: g.CreatedAt, CreatedAt: g.CreatedAt,
@@ -379,6 +383,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
FirstTokenMs: l.FirstTokenMs, FirstTokenMs: l.FirstTokenMs,
ImageCount: l.ImageCount, ImageCount: l.ImageCount,
ImageSize: l.ImageSize, ImageSize: l.ImageSize,
MediaType: l.MediaType,
UserAgent: l.UserAgent, UserAgent: l.UserAgent,
CreatedAt: l.CreatedAt, CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User), User: UserFromServiceShallow(l.User),

View File

@@ -61,6 +61,12 @@ type Group struct {
ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"` ImagePrice4K *float64 `json:"image_price_4k"`
// Sora 按次计费配置
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
// Claude Code 客户端限制 // Claude Code 客户端限制
ClaudeCodeOnly bool `json:"claude_code_only"` ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupID *int64 `json:"fallback_group_id"`
@@ -246,6 +252,7 @@ type UsageLog struct {
// 图片生成字段 // 图片生成字段
ImageCount int `json:"image_count"` ImageCount int `json:"image_count"`
ImageSize *string `json:"image_size"` ImageSize *string `json:"image_size"`
MediaType *string `json:"media_type"`
// User-Agent // User-Agent
UserAgent *string `json:"user_agent"` UserAgent *string `json:"user_agent"`

View File

@@ -29,6 +29,7 @@ type GatewayHandler struct {
geminiCompatService *service.GeminiMessagesCompatService geminiCompatService *service.GeminiMessagesCompatService
antigravityGatewayService *service.AntigravityGatewayService antigravityGatewayService *service.AntigravityGatewayService
userService *service.UserService userService *service.UserService
sora2apiService *service.Sora2APIService
billingCacheService *service.BillingCacheService billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int maxAccountSwitches int
@@ -41,6 +42,7 @@ func NewGatewayHandler(
geminiCompatService *service.GeminiMessagesCompatService, geminiCompatService *service.GeminiMessagesCompatService,
antigravityGatewayService *service.AntigravityGatewayService, antigravityGatewayService *service.AntigravityGatewayService,
userService *service.UserService, userService *service.UserService,
sora2apiService *service.Sora2APIService,
concurrencyService *service.ConcurrencyService, concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService, billingCacheService *service.BillingCacheService,
cfg *config.Config, cfg *config.Config,
@@ -62,6 +64,7 @@ func NewGatewayHandler(
geminiCompatService: geminiCompatService, geminiCompatService: geminiCompatService,
antigravityGatewayService: antigravityGatewayService, antigravityGatewayService: antigravityGatewayService,
userService: userService, userService: userService,
sora2apiService: sora2apiService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
maxAccountSwitches: maxAccountSwitches, maxAccountSwitches: maxAccountSwitches,
@@ -478,6 +481,26 @@ func (h *GatewayHandler) Models(c *gin.Context) {
groupID = &apiKey.Group.ID groupID = &apiKey.Group.ID
platform = apiKey.Group.Platform platform = apiKey.Group.Platform
} }
if forcedPlatform, ok := middleware2.GetForcePlatformFromContext(c); ok && strings.TrimSpace(forcedPlatform) != "" {
platform = forcedPlatform
}
if platform == service.PlatformSora {
if h.sora2apiService == nil || !h.sora2apiService.Enabled() {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "sora2api not configured")
return
}
models, err := h.sora2apiService.ListModels(c.Request.Context())
if err != nil {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Failed to fetch Sora models")
return
}
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": models,
})
return
}
// Get available models from account configurations (without platform filter) // Get available models from account configurations (without platform filter)
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "") availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")

View File

@@ -23,6 +23,7 @@ type AdminHandlers struct {
Subscription *admin.SubscriptionHandler Subscription *admin.SubscriptionHandler
Usage *admin.UsageHandler Usage *admin.UsageHandler
UserAttribute *admin.UserAttributeHandler UserAttribute *admin.UserAttributeHandler
Model *admin.ModelHandler
} }
// Handlers contains all HTTP handlers // Handlers contains all HTTP handlers
@@ -36,6 +37,7 @@ type Handlers struct {
Admin *AdminHandlers Admin *AdminHandlers
Gateway *GatewayHandler Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler OpenAIGateway *OpenAIGatewayHandler
SoraGateway *SoraGatewayHandler
Setting *SettingHandler Setting *SettingHandler
} }

View File

@@ -0,0 +1,474 @@
package handler
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"path"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// SoraGatewayHandler handles Sora chat completions requests
type SoraGatewayHandler struct {
gatewayService *service.GatewayService
soraGatewayService *service.SoraGatewayService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
streamMode string
sora2apiBaseURL string
soraMediaSigningKey string
}
// NewSoraGatewayHandler creates a new SoraGatewayHandler
func NewSoraGatewayHandler(
gatewayService *service.GatewayService,
soraGatewayService *service.SoraGatewayService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
cfg *config.Config,
) *SoraGatewayHandler {
pingInterval := time.Duration(0)
maxAccountSwitches := 3
streamMode := "force"
signKey := ""
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
if cfg.Gateway.MaxAccountSwitches > 0 {
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
}
if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" {
streamMode = mode
}
signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey)
}
baseURL := ""
if cfg != nil {
baseURL = strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/")
}
return &SoraGatewayHandler{
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
streamMode: strings.ToLower(streamMode),
sora2apiBaseURL: baseURL,
soraMediaSigningKey: signKey,
}
}
// ChatCompletions handles Sora /v1/chat/completions endpoint
func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
body, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
setOpsRequestContext(c, "", false, body)
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
reqModel, _ := reqBody["model"].(string)
if reqModel == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqMessages, _ := reqBody["messages"].([]any)
if len(reqMessages) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required")
return
}
clientStream, _ := reqBody["stream"].(bool)
if !clientStream {
if h.streamMode == "error" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true")
return
}
reqBody["stream"] = true
updated, err := json.Marshal(reqBody)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
return
}
body = updated
}
setOpsRequestContext(c, reqModel, clientStream, body)
platform := ""
if forced, ok := middleware2.GetForcePlatformFromContext(c); ok {
platform = forced
} else if apiKey.Group != nil {
platform = apiKey.Group.Platform
}
if platform != service.PlatformSora {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "This endpoint only supports Sora platform")
return
}
streamStarted := false
subscription, _ := middleware2.GetSubscriptionFromContext(c)
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil {
log.Printf("Increment wait count failed: %v", err)
} else if !canWait {
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if err == nil && canWait {
waitCounted = true
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, clientStream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false
}
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
sessionHash := generateOpenAISessionHash(c, reqBody)
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
if err != nil {
log.Printf("[Sora Handler] SelectAccount failed: %v", err)
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID)
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
if err == nil && canWait {
accountWaitCounted = true
}
defer func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}()
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
clientStream,
&streamStarted,
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
}
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
result, err := h.soraGatewayService.Forward(c.Request.Context(), c, account, body, clientStream)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
}
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
return
}
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: ip,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account, userAgent, clientIP)
return
}
}
func generateOpenAISessionHash(c *gin.Context, reqBody map[string]any) string {
if c == nil {
return ""
}
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
if sessionID == "" {
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
}
if sessionID == "" && reqBody != nil {
if v, ok := reqBody["prompt_cache_key"].(string); ok {
sessionID = strings.TrimSpace(v)
}
}
if sessionID == "" {
return ""
}
hash := sha256.Sum256([]byte(sessionID))
return hex.EncodeToString(hash[:])
}
func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
switch statusCode {
case 401:
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
case 403:
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
case 429:
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
case 529:
return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
case 500, 502, 503, 504:
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
default:
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
}
}
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {
flusher, ok := c.Writer.(http.Flusher)
if ok {
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}
flusher.Flush()
}
return
}
h.errorResponse(c, status, errType, message)
}
func (h *SoraGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"type": errType,
"message": message,
},
})
}
// MediaProxy proxies /tmp or /static media files from sora2api
func (h *SoraGatewayHandler) MediaProxy(c *gin.Context) {
h.proxySoraMedia(c, false)
}
// MediaProxySigned proxies /tmp or /static media files with signature verification
func (h *SoraGatewayHandler) MediaProxySigned(c *gin.Context) {
h.proxySoraMedia(c, true)
}
func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature bool) {
if h.sora2apiBaseURL == "" {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": gin.H{
"type": "api_error",
"message": "sora2api 未配置",
},
})
return
}
rawPath := c.Param("filepath")
if rawPath == "" {
c.Status(http.StatusNotFound)
return
}
cleaned := path.Clean(rawPath)
if !strings.HasPrefix(cleaned, "/tmp/") && !strings.HasPrefix(cleaned, "/static/") {
c.Status(http.StatusNotFound)
return
}
query := c.Request.URL.Query()
if requireSignature {
if h.soraMediaSigningKey == "" {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": gin.H{
"type": "api_error",
"message": "Sora 媒体签名未配置",
},
})
return
}
expiresStr := strings.TrimSpace(query.Get("expires"))
signature := strings.TrimSpace(query.Get("sig"))
expires, err := strconv.ParseInt(expiresStr, 10, 64)
if err != nil || expires <= time.Now().Unix() {
c.JSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"type": "authentication_error",
"message": "Sora 媒体签名已过期",
},
})
return
}
query.Del("sig")
query.Del("expires")
signingQuery := query.Encode()
if !service.VerifySoraMediaURL(cleaned, signingQuery, expires, signature, h.soraMediaSigningKey) {
c.JSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"type": "authentication_error",
"message": "Sora 媒体签名无效",
},
})
return
}
}
targetURL := h.sora2apiBaseURL + cleaned
if rawQuery := query.Encode(); rawQuery != "" {
targetURL += "?" + rawQuery
}
req, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL, nil)
if err != nil {
c.Status(http.StatusBadGateway)
return
}
copyHeaders := []string{"Range", "If-Range", "If-Modified-Since", "If-None-Match", "Accept", "User-Agent"}
for _, key := range copyHeaders {
if val := c.GetHeader(key); val != "" {
req.Header.Set(key, val)
}
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
c.Status(http.StatusBadGateway)
return
}
defer func() { _ = resp.Body.Close() }()
for _, key := range []string{"Content-Type", "Content-Length", "Accept-Ranges", "Content-Range", "Cache-Control", "Last-Modified", "ETag"} {
if val := resp.Header.Get(key); val != "" {
c.Header(key, val)
}
}
c.Status(resp.StatusCode)
_, _ = io.Copy(c.Writer, resp.Body)
}

View File

@@ -26,6 +26,7 @@ func ProvideAdminHandlers(
subscriptionHandler *admin.SubscriptionHandler, subscriptionHandler *admin.SubscriptionHandler,
usageHandler *admin.UsageHandler, usageHandler *admin.UsageHandler,
userAttributeHandler *admin.UserAttributeHandler, userAttributeHandler *admin.UserAttributeHandler,
modelHandler *admin.ModelHandler,
) *AdminHandlers { ) *AdminHandlers {
return &AdminHandlers{ return &AdminHandlers{
Dashboard: dashboardHandler, Dashboard: dashboardHandler,
@@ -45,6 +46,7 @@ func ProvideAdminHandlers(
Subscription: subscriptionHandler, Subscription: subscriptionHandler,
Usage: usageHandler, Usage: usageHandler,
UserAttribute: userAttributeHandler, UserAttribute: userAttributeHandler,
Model: modelHandler,
} }
} }
@@ -69,6 +71,7 @@ func ProvideHandlers(
adminHandlers *AdminHandlers, adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler, gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler, openaiGatewayHandler *OpenAIGatewayHandler,
soraGatewayHandler *SoraGatewayHandler,
settingHandler *SettingHandler, settingHandler *SettingHandler,
) *Handlers { ) *Handlers {
return &Handlers{ return &Handlers{
@@ -81,6 +84,7 @@ func ProvideHandlers(
Admin: adminHandlers, Admin: adminHandlers,
Gateway: gatewayHandler, Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler, OpenAIGateway: openaiGatewayHandler,
SoraGateway: soraGatewayHandler,
Setting: settingHandler, Setting: settingHandler,
} }
} }
@@ -96,6 +100,7 @@ var ProviderSet = wire.NewSet(
NewSubscriptionHandler, NewSubscriptionHandler,
NewGatewayHandler, NewGatewayHandler,
NewOpenAIGatewayHandler, NewOpenAIGatewayHandler,
NewSoraGatewayHandler,
ProvideSettingHandler, ProvideSettingHandler,
// Admin handlers // Admin handlers
@@ -116,6 +121,7 @@ var ProviderSet = wire.NewSet(
admin.NewSubscriptionHandler, admin.NewSubscriptionHandler,
admin.NewUsageHandler, admin.NewUsageHandler,
admin.NewUserAttributeHandler, admin.NewUserAttributeHandler,
admin.NewModelHandler,
// AdminHandlers and Handlers constructors // AdminHandlers and Handlers constructors
ProvideAdminHandlers, ProvideAdminHandlers,

View File

@@ -13,6 +13,7 @@ import (
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"os"
"strings" "strings"
"testing" "testing"
"time" "time"
@@ -38,9 +39,7 @@ type TLSInfo struct {
// TestDialerBasicConnection tests that the dialer can establish TLS connections. // TestDialerBasicConnection tests that the dialer can establish TLS connections.
func TestDialerBasicConnection(t *testing.T) { func TestDialerBasicConnection(t *testing.T) {
if testing.Short() { skipNetworkTest(t)
t.Skip("skipping network test in short mode")
}
// Create a dialer with default profile // Create a dialer with default profile
profile := &Profile{ profile := &Profile{
@@ -74,10 +73,7 @@ func TestDialerBasicConnection(t *testing.T) {
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x) // Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP) // Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
func TestJA3Fingerprint(t *testing.T) { func TestJA3Fingerprint(t *testing.T) {
// Skip if network is unavailable or if running in short mode skipNetworkTest(t)
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
profile := &Profile{ profile := &Profile{
Name: "Claude CLI Test", Name: "Claude CLI Test",
@@ -178,6 +174,15 @@ func TestJA3Fingerprint(t *testing.T) {
} }
} }
func skipNetworkTest(t *testing.T) {
if testing.Short() {
t.Skip("跳过网络测试short 模式)")
}
if os.Getenv("TLSFINGERPRINT_NETWORK_TESTS") != "1" {
t.Skip("跳过网络测试(需要设置 TLSFINGERPRINT_NETWORK_TESTS=1")
}
}
// TestDialerWithProfile tests that different profiles produce different fingerprints. // TestDialerWithProfile tests that different profiles produce different fingerprints.
func TestDialerWithProfile(t *testing.T) { func TestDialerWithProfile(t *testing.T) {
// Create two dialers with different profiles // Create two dialers with different profiles
@@ -317,9 +322,7 @@ type TestProfileExpectation struct {
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws. // TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
// Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/... // Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
func TestAllProfiles(t *testing.T) { func TestAllProfiles(t *testing.T) {
if testing.Short() { skipNetworkTest(t)
t.Skip("skipping integration test in short mode")
}
// Define all profiles to test with their expected fingerprints // Define all profiles to test with their expected fingerprints
// These profiles are from config.yaml gateway.tls_fingerprint.profiles // These profiles are from config.yaml gateway.tls_fingerprint.profiles

View File

@@ -134,6 +134,10 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldImagePrice1k, group.FieldImagePrice1k,
group.FieldImagePrice2k, group.FieldImagePrice2k,
group.FieldImagePrice4k, group.FieldImagePrice4k,
group.FieldSoraImagePrice360,
group.FieldSoraImagePrice540,
group.FieldSoraVideoPricePerRequest,
group.FieldSoraVideoPricePerRequestHd,
group.FieldClaudeCodeOnly, group.FieldClaudeCodeOnly,
group.FieldFallbackGroupID, group.FieldFallbackGroupID,
group.FieldModelRoutingEnabled, group.FieldModelRoutingEnabled,
@@ -421,6 +425,10 @@ func groupEntityToService(g *dbent.Group) *service.Group {
ImagePrice1K: g.ImagePrice1k, ImagePrice1K: g.ImagePrice1k,
ImagePrice2K: g.ImagePrice2k, ImagePrice2K: g.ImagePrice2k,
ImagePrice4K: g.ImagePrice4k, ImagePrice4K: g.ImagePrice4k,
SoraImagePrice360: g.SoraImagePrice360,
SoraImagePrice540: g.SoraImagePrice540,
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd,
DefaultValidityDays: g.DefaultValidityDays, DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly, ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID, FallbackGroupID: g.FallbackGroupID,

View File

@@ -47,6 +47,10 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K). SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
SetDefaultValidityDays(groupIn.DefaultValidityDays). SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetNillableFallbackGroupID(groupIn.FallbackGroupID). SetNillableFallbackGroupID(groupIn.FallbackGroupID).
@@ -106,6 +110,10 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K). SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
SetDefaultValidityDays(groupIn.DefaultValidityDays). SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled) SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)

View File

@@ -22,7 +22,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, created_at"
type usageLogRepository struct { type usageLogRepository struct {
client *dbent.Client client *dbent.Client
@@ -114,6 +114,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
ip_address, ip_address,
image_count, image_count,
image_size, image_size,
media_type,
created_at created_at
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $1, $2, $3, $4, $5,
@@ -121,7 +122,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11, $8, $9, $10, $11,
$12, $13, $12, $13,
$14, $15, $16, $17, $18, $19, $14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30 $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
@@ -134,6 +135,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
userAgent := nullString(log.UserAgent) userAgent := nullString(log.UserAgent)
ipAddress := nullString(log.IPAddress) ipAddress := nullString(log.IPAddress)
imageSize := nullString(log.ImageSize) imageSize := nullString(log.ImageSize)
mediaType := nullString(log.MediaType)
var requestIDArg any var requestIDArg any
if requestID != "" { if requestID != "" {
@@ -170,6 +172,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
ipAddress, ipAddress,
log.ImageCount, log.ImageCount,
imageSize, imageSize,
mediaType,
createdAt, createdAt,
} }
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil { if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
@@ -2090,6 +2093,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
ipAddress sql.NullString ipAddress sql.NullString
imageCount int imageCount int
imageSize sql.NullString imageSize sql.NullString
mediaType sql.NullString
createdAt time.Time createdAt time.Time
) )
@@ -2124,6 +2128,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&ipAddress, &ipAddress,
&imageCount, &imageCount,
&imageSize, &imageSize,
&mediaType,
&createdAt, &createdAt,
); err != nil { ); err != nil {
return nil, err return nil, err
@@ -2183,6 +2188,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if imageSize.Valid { if imageSize.Valid {
log.ImageSize = &imageSize.String log.ImageSize = &imageSize.String
} }
if mediaType.Valid {
log.MediaType = &mediaType.String
}
return log, nil return log, nil
} }

View File

@@ -64,6 +64,9 @@ func RegisterAdminRoutes(
// 用户属性管理 // 用户属性管理
registerUserAttributeRoutes(admin, h) registerUserAttributeRoutes(admin, h)
// 模型列表
registerModelRoutes(admin, h)
} }
} }
@@ -371,3 +374,7 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition) attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition)
} }
} }
func registerModelRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
admin.GET("/models", h.Admin.Model.List)
}

View File

@@ -20,6 +20,11 @@ func RegisterGatewayRoutes(
cfg *config.Config, cfg *config.Config,
) { ) {
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize) bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
soraMaxBodySize := cfg.Gateway.SoraMaxBodySize
if soraMaxBodySize <= 0 {
soraMaxBodySize = cfg.Gateway.MaxBodySize
}
soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize)
clientRequestID := middleware.ClientRequestID() clientRequestID := middleware.ClientRequestID()
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService) opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
@@ -38,6 +43,16 @@ func RegisterGatewayRoutes(
gateway.POST("/responses", h.OpenAIGateway.Responses) gateway.POST("/responses", h.OpenAIGateway.Responses)
} }
// Sora Chat Completions
soraGateway := r.Group("/v1")
soraGateway.Use(soraBodyLimit)
soraGateway.Use(clientRequestID)
soraGateway.Use(opsErrorLogger)
soraGateway.Use(gin.HandlerFunc(apiKeyAuth))
{
soraGateway.POST("/chat/completions", h.SoraGateway.ChatCompletions)
}
// Gemini 原生 API 兼容层Gemini SDK/CLI 直连) // Gemini 原生 API 兼容层Gemini SDK/CLI 直连)
gemini := r.Group("/v1beta") gemini := r.Group("/v1beta")
gemini.Use(bodyLimit) gemini.Use(bodyLimit)
@@ -82,4 +97,25 @@ func RegisterGatewayRoutes(
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels) antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
} }
// Sora 专用路由(强制使用 sora 平台)
soraV1 := r.Group("/sora/v1")
soraV1.Use(soraBodyLimit)
soraV1.Use(clientRequestID)
soraV1.Use(opsErrorLogger)
soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
soraV1.Use(gin.HandlerFunc(apiKeyAuth))
{
soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions)
soraV1.GET("/models", h.Gateway.Models)
}
// Sora 媒体代理(可选 API Key 验证)
if cfg.Gateway.SoraMediaRequireAPIKey {
r.GET("/sora/media/*filepath", gin.HandlerFunc(apiKeyAuth), h.SoraGateway.MediaProxy)
} else {
r.GET("/sora/media/*filepath", h.SoraGateway.MediaProxy)
}
// Sora 媒体代理(签名 URL无需 API Key
r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned)
} }

View File

@@ -102,11 +102,16 @@ type CreateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD) WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用) // 图片生成计费配置(仅 antigravity 平台使用)
ImagePrice1K *float64 ImagePrice1K *float64
ImagePrice2K *float64 ImagePrice2K *float64
ImagePrice4K *float64 ImagePrice4K *float64
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 // Sora 按次计费配置
FallbackGroupID *int64 // 降级分组 ID SoraImagePrice360 *float64
SoraImagePrice540 *float64
SoraVideoPricePerRequest *float64
SoraVideoPricePerRequestHD *float64
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
// 模型路由配置(仅 anthropic 平台使用) // 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 ModelRouting map[string][]int64
ModelRoutingEnabled bool // 是否启用模型路由 ModelRoutingEnabled bool // 是否启用模型路由
@@ -124,11 +129,16 @@ type UpdateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD) WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用) // 图片生成计费配置(仅 antigravity 平台使用)
ImagePrice1K *float64 ImagePrice1K *float64
ImagePrice2K *float64 ImagePrice2K *float64
ImagePrice4K *float64 ImagePrice4K *float64
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 // Sora 按次计费配置
FallbackGroupID *int64 // 降级分组 ID SoraImagePrice360 *float64
SoraImagePrice540 *float64
SoraVideoPricePerRequest *float64
SoraVideoPricePerRequestHD *float64
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
// 模型路由配置(仅 anthropic 平台使用) // 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 ModelRouting map[string][]int64
ModelRoutingEnabled *bool // 是否启用模型路由 ModelRoutingEnabled *bool // 是否启用模型路由
@@ -273,6 +283,7 @@ type adminServiceImpl struct {
groupRepo GroupRepository groupRepo GroupRepository
accountRepo AccountRepository accountRepo AccountRepository
soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储 soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储
soraSyncService *Sora2APISyncService // Sora2API 同步服务
proxyRepo ProxyRepository proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository redeemCodeRepo RedeemCodeRepository
@@ -288,6 +299,7 @@ func NewAdminService(
groupRepo GroupRepository, groupRepo GroupRepository,
accountRepo AccountRepository, accountRepo AccountRepository,
soraAccountRepo SoraAccountRepository, soraAccountRepo SoraAccountRepository,
soraSyncService *Sora2APISyncService,
proxyRepo ProxyRepository, proxyRepo ProxyRepository,
apiKeyRepo APIKeyRepository, apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository, redeemCodeRepo RedeemCodeRepository,
@@ -301,6 +313,7 @@ func NewAdminService(
groupRepo: groupRepo, groupRepo: groupRepo,
accountRepo: accountRepo, accountRepo: accountRepo,
soraAccountRepo: soraAccountRepo, soraAccountRepo: soraAccountRepo,
soraSyncService: soraSyncService,
proxyRepo: proxyRepo, proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo, apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo, redeemCodeRepo: redeemCodeRepo,
@@ -567,6 +580,10 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
imagePrice1K := normalizePrice(input.ImagePrice1K) imagePrice1K := normalizePrice(input.ImagePrice1K)
imagePrice2K := normalizePrice(input.ImagePrice2K) imagePrice2K := normalizePrice(input.ImagePrice2K)
imagePrice4K := normalizePrice(input.ImagePrice4K) imagePrice4K := normalizePrice(input.ImagePrice4K)
soraImagePrice360 := normalizePrice(input.SoraImagePrice360)
soraImagePrice540 := normalizePrice(input.SoraImagePrice540)
soraVideoPrice := normalizePrice(input.SoraVideoPricePerRequest)
soraVideoPriceHD := normalizePrice(input.SoraVideoPricePerRequestHD)
// 校验降级分组 // 校验降级分组
if input.FallbackGroupID != nil { if input.FallbackGroupID != nil {
@@ -576,22 +593,26 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
} }
group := &Group{ group := &Group{
Name: input.Name, Name: input.Name,
Description: input.Description, Description: input.Description,
Platform: platform, Platform: platform,
RateMultiplier: input.RateMultiplier, RateMultiplier: input.RateMultiplier,
IsExclusive: input.IsExclusive, IsExclusive: input.IsExclusive,
Status: StatusActive, Status: StatusActive,
SubscriptionType: subscriptionType, SubscriptionType: subscriptionType,
DailyLimitUSD: dailyLimit, DailyLimitUSD: dailyLimit,
WeeklyLimitUSD: weeklyLimit, WeeklyLimitUSD: weeklyLimit,
MonthlyLimitUSD: monthlyLimit, MonthlyLimitUSD: monthlyLimit,
ImagePrice1K: imagePrice1K, ImagePrice1K: imagePrice1K,
ImagePrice2K: imagePrice2K, ImagePrice2K: imagePrice2K,
ImagePrice4K: imagePrice4K, ImagePrice4K: imagePrice4K,
ClaudeCodeOnly: input.ClaudeCodeOnly, SoraImagePrice360: soraImagePrice360,
FallbackGroupID: input.FallbackGroupID, SoraImagePrice540: soraImagePrice540,
ModelRouting: input.ModelRouting, SoraVideoPricePerRequest: soraVideoPrice,
SoraVideoPricePerRequestHD: soraVideoPriceHD,
ClaudeCodeOnly: input.ClaudeCodeOnly,
FallbackGroupID: input.FallbackGroupID,
ModelRouting: input.ModelRouting,
} }
if err := s.groupRepo.Create(ctx, group); err != nil { if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err return nil, err
@@ -702,6 +723,18 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.ImagePrice4K != nil { if input.ImagePrice4K != nil {
group.ImagePrice4K = normalizePrice(input.ImagePrice4K) group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
} }
if input.SoraImagePrice360 != nil {
group.SoraImagePrice360 = normalizePrice(input.SoraImagePrice360)
}
if input.SoraImagePrice540 != nil {
group.SoraImagePrice540 = normalizePrice(input.SoraImagePrice540)
}
if input.SoraVideoPricePerRequest != nil {
group.SoraVideoPricePerRequest = normalizePrice(input.SoraVideoPricePerRequest)
}
if input.SoraVideoPricePerRequestHD != nil {
group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD)
}
// Claude Code 客户端限制 // Claude Code 客户端限制
if input.ClaudeCodeOnly != nil { if input.ClaudeCodeOnly != nil {
@@ -884,6 +917,9 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
} }
} }
// 同步到 sora2api异步不阻塞创建
s.syncSoraAccountAsync(account)
return account, nil return account, nil
} }
@@ -974,7 +1010,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
} }
// 重新查询以确保返回完整数据(包括正确的 Proxy 关联对象) // 重新查询以确保返回完整数据(包括正确的 Proxy 关联对象)
return s.accountRepo.GetByID(ctx, id) updated, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
s.syncSoraAccountAsync(updated)
return updated, nil
} }
// BulkUpdateAccounts updates multiple accounts in one request. // BulkUpdateAccounts updates multiple accounts in one request.
@@ -990,16 +1031,23 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
return result, nil return result, nil
} }
// Preload account platforms for mixed channel risk checks if group bindings are requested. needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck
needSoraSync := s != nil && s.soraSyncService != nil
// 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。
platformByID := map[int64]string{} platformByID := map[int64]string{}
if input.GroupIDs != nil && !input.SkipMixedChannelCheck { if needMixedChannelCheck || needSoraSync {
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs) accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
if err != nil { if err != nil {
return nil, err if needMixedChannelCheck {
} return nil, err
for _, account := range accounts { }
if account != nil { log.Printf("[AdminService] 预加载账号平台信息失败,将逐个降级同步: err=%v", err)
platformByID[account.ID] = account.Platform } else {
for _, account := range accounts {
if account != nil {
platformByID[account.ID] = account.Platform
}
} }
} }
} }
@@ -1086,13 +1134,46 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
result.Success++ result.Success++
result.SuccessIDs = append(result.SuccessIDs, accountID) result.SuccessIDs = append(result.SuccessIDs, accountID)
result.Results = append(result.Results, entry) result.Results = append(result.Results, entry)
// 批量更新后同步 sora2api
if needSoraSync {
platform := platformByID[accountID]
if platform == "" {
updated, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
log.Printf("[AdminService] 批量更新后获取账号失败,无法同步 sora2api: account_id=%d err=%v", accountID, err)
continue
}
if updated.Platform == PlatformSora {
s.syncSoraAccountAsync(updated)
}
continue
}
if platform == PlatformSora {
updated, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
log.Printf("[AdminService] 批量更新后获取账号失败,无法同步 sora2api: account_id=%d err=%v", accountID, err)
continue
}
s.syncSoraAccountAsync(updated)
}
}
} }
return result, nil return result, nil
} }
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error { func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
return s.accountRepo.Delete(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return err
}
if err := s.accountRepo.Delete(ctx, id); err != nil {
return err
}
s.deleteSoraAccountAsync(account)
return nil
} }
func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) { func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) {
@@ -1125,7 +1206,46 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil { if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
return nil, err return nil, err
} }
return s.accountRepo.GetByID(ctx, id) updated, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
s.syncSoraAccountAsync(updated)
return updated, nil
}
func (s *adminServiceImpl) syncSoraAccountAsync(account *Account) {
if s == nil || s.soraSyncService == nil || account == nil {
return
}
if account.Platform != PlatformSora {
return
}
syncAccount := *account
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := s.soraSyncService.SyncAccount(ctx, &syncAccount); err != nil {
log.Printf("[AdminService] 同步 sora2api 失败: account_id=%d err=%v", syncAccount.ID, err)
}
}()
}
func (s *adminServiceImpl) deleteSoraAccountAsync(account *Account) {
if s == nil || s.soraSyncService == nil || account == nil {
return
}
if account.Platform != PlatformSora {
return
}
syncAccount := *account
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := s.soraSyncService.DeleteAccount(ctx, &syncAccount); err != nil {
log.Printf("[AdminService] 删除 sora2api token 失败: account_id=%d err=%v", syncAccount.ID, err)
}
}()
} }
// Proxy management implementations // Proxy management implementations

View File

@@ -15,6 +15,13 @@ type accountRepoStubForBulkUpdate struct {
bulkUpdateErr error bulkUpdateErr error
bulkUpdateIDs []int64 bulkUpdateIDs []int64
bindGroupErrByID map[int64]error bindGroupErrByID map[int64]error
getByIDsAccounts []*Account
getByIDsErr error
getByIDsCalled bool
getByIDsIDs []int64
getByIDAccounts map[int64]*Account
getByIDErrByID map[int64]error
getByIDCalled []int64
} }
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) { func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
@@ -32,6 +39,26 @@ func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID i
return nil return nil
} }
func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) {
s.getByIDsCalled = true
s.getByIDsIDs = append([]int64{}, ids...)
if s.getByIDsErr != nil {
return nil, s.getByIDsErr
}
return s.getByIDsAccounts, nil
}
func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Account, error) {
s.getByIDCalled = append(s.getByIDCalled, id)
if err, ok := s.getByIDErrByID[id]; ok {
return nil, err
}
if account, ok := s.getByIDAccounts[id]; ok {
return account, nil
}
return nil, errors.New("account not found")
}
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。 // TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) { func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{} repo := &accountRepoStubForBulkUpdate{}
@@ -78,3 +105,31 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
require.ElementsMatch(t, []int64{2}, result.FailedIDs) require.ElementsMatch(t, []int64{2}, result.FailedIDs)
require.Len(t, result.Results, 3) require.Len(t, result.Results, 3)
} }
// TestAdminService_BulkUpdateAccounts_SoraSyncWithoutGroupIDs 验证无分组更新时仍会触发 Sora 同步。
func TestAdminService_BulkUpdateAccounts_SoraSyncWithoutGroupIDs(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{
getByIDsAccounts: []*Account{
{ID: 1, Platform: PlatformSora},
},
getByIDAccounts: map[int64]*Account{
1: {ID: 1, Platform: PlatformSora},
},
}
svc := &adminServiceImpl{
accountRepo: repo,
soraSyncService: &Sora2APISyncService{},
}
schedulable := true
input := &BulkUpdateAccountsInput{
AccountIDs: []int64{1},
Schedulable: &schedulable,
}
result, err := svc.BulkUpdateAccounts(context.Background(), input)
require.NoError(t, err)
require.Equal(t, 1, result.Success)
require.True(t, repo.getByIDsCalled)
require.ElementsMatch(t, []int64{1}, repo.getByIDCalled)
}

View File

@@ -35,6 +35,10 @@ type APIKeyAuthGroupSnapshot struct {
ImagePrice1K *float64 `json:"image_price_1k,omitempty"` ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
ImagePrice2K *float64 `json:"image_price_2k,omitempty"` ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
ImagePrice4K *float64 `json:"image_price_4k,omitempty"` ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"`
ClaudeCodeOnly bool `json:"claude_code_only"` ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`

View File

@@ -235,6 +235,10 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
ImagePrice1K: apiKey.Group.ImagePrice1K, ImagePrice1K: apiKey.Group.ImagePrice1K,
ImagePrice2K: apiKey.Group.ImagePrice2K, ImagePrice2K: apiKey.Group.ImagePrice2K,
ImagePrice4K: apiKey.Group.ImagePrice4K, ImagePrice4K: apiKey.Group.ImagePrice4K,
SoraImagePrice360: apiKey.Group.SoraImagePrice360,
SoraImagePrice540: apiKey.Group.SoraImagePrice540,
SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
FallbackGroupID: apiKey.Group.FallbackGroupID, FallbackGroupID: apiKey.Group.FallbackGroupID,
ModelRouting: apiKey.Group.ModelRouting, ModelRouting: apiKey.Group.ModelRouting,
@@ -279,6 +283,10 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
ImagePrice1K: snapshot.Group.ImagePrice1K, ImagePrice1K: snapshot.Group.ImagePrice1K,
ImagePrice2K: snapshot.Group.ImagePrice2K, ImagePrice2K: snapshot.Group.ImagePrice2K,
ImagePrice4K: snapshot.Group.ImagePrice4K, ImagePrice4K: snapshot.Group.ImagePrice4K,
SoraImagePrice360: snapshot.Group.SoraImagePrice360,
SoraImagePrice540: snapshot.Group.SoraImagePrice540,
SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
FallbackGroupID: snapshot.Group.FallbackGroupID, FallbackGroupID: snapshot.Group.FallbackGroupID,
ModelRouting: snapshot.Group.ModelRouting, ModelRouting: snapshot.Group.ModelRouting,

View File

@@ -303,6 +303,14 @@ type ImagePriceConfig struct {
Price4K *float64 // 4K 尺寸价格nil 表示使用默认值) Price4K *float64 // 4K 尺寸价格nil 表示使用默认值)
} }
// SoraPriceConfig Sora 按次计费配置
type SoraPriceConfig struct {
ImagePrice360 *float64
ImagePrice540 *float64
VideoPricePerRequest *float64
VideoPricePerRequestHD *float64
}
// CalculateImageCost 计算图片生成费用 // CalculateImageCost 计算图片生成费用
// model: 请求的模型名称(用于获取 LiteLLM 默认价格) // model: 请求的模型名称(用于获取 LiteLLM 默认价格)
// imageSize: 图片尺寸 "1K", "2K", "4K" // imageSize: 图片尺寸 "1K", "2K", "4K"
@@ -332,6 +340,65 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
} }
} }
// CalculateSoraImageCost 计算 Sora 图片按次费用
func (s *BillingService) CalculateSoraImageCost(imageSize string, imageCount int, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown {
if imageCount <= 0 {
return &CostBreakdown{}
}
unitPrice := 0.0
if groupConfig != nil {
switch imageSize {
case "540":
if groupConfig.ImagePrice540 != nil {
unitPrice = *groupConfig.ImagePrice540
}
default:
if groupConfig.ImagePrice360 != nil {
unitPrice = *groupConfig.ImagePrice360
}
}
}
totalCost := unitPrice * float64(imageCount)
if rateMultiplier <= 0 {
rateMultiplier = 1.0
}
actualCost := totalCost * rateMultiplier
return &CostBreakdown{
TotalCost: totalCost,
ActualCost: actualCost,
}
}
// CalculateSoraVideoCost 计算 Sora 视频按次费用
func (s *BillingService) CalculateSoraVideoCost(model string, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown {
unitPrice := 0.0
if groupConfig != nil {
modelLower := strings.ToLower(model)
if strings.Contains(modelLower, "sora2pro-hd") {
if groupConfig.VideoPricePerRequestHD != nil {
unitPrice = *groupConfig.VideoPricePerRequestHD
}
}
if unitPrice <= 0 && groupConfig.VideoPricePerRequest != nil {
unitPrice = *groupConfig.VideoPricePerRequest
}
}
totalCost := unitPrice
if rateMultiplier <= 0 {
rateMultiplier = 1.0
}
actualCost := totalCost * rateMultiplier
return &CostBreakdown{
TotalCost: totalCost,
ActualCost: actualCost,
}
}
// getImageUnitPrice 获取图片单价 // getImageUnitPrice 获取图片单价
func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 { func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 {
// 优先使用分组配置的价格 // 优先使用分组配置的价格

View File

@@ -184,6 +184,10 @@ type ForwardResult struct {
// 图片生成计费字段(仅 gemini-3-pro-image 使用) // 图片生成计费字段(仅 gemini-3-pro-image 使用)
ImageCount int // 生成的图片数量 ImageCount int // 生成的图片数量
ImageSize string // 图片尺寸 "1K", "2K", "4K" ImageSize string // 图片尺寸 "1K", "2K", "4K"
// Sora 媒体字段
MediaType string // image / video / prompt
MediaURL string // 生成后的媒体地址(可选)
} }
// UpstreamFailoverError indicates an upstream error that should trigger account failover. // UpstreamFailoverError indicates an upstream error that should trigger account failover.
@@ -3461,7 +3465,22 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
var cost *CostBreakdown var cost *CostBreakdown
// 根据请求类型选择计费方式 // 根据请求类型选择计费方式
if result.ImageCount > 0 { if result.MediaType == "image" || result.MediaType == "video" || result.MediaType == "prompt" {
var soraConfig *SoraPriceConfig
if apiKey.Group != nil {
soraConfig = &SoraPriceConfig{
ImagePrice360: apiKey.Group.SoraImagePrice360,
ImagePrice540: apiKey.Group.SoraImagePrice540,
VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
}
}
if result.MediaType == "image" {
cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
} else {
cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier)
}
} else if result.ImageCount > 0 {
// 图片生成计费 // 图片生成计费
var groupConfig *ImagePriceConfig var groupConfig *ImagePriceConfig
if apiKey.Group != nil { if apiKey.Group != nil {
@@ -3501,6 +3520,10 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if result.ImageSize != "" { if result.ImageSize != "" {
imageSize = &result.ImageSize imageSize = &result.ImageSize
} }
var mediaType *string
if strings.TrimSpace(result.MediaType) != "" {
mediaType = &result.MediaType
}
accountRateMultiplier := account.BillingRateMultiplier() accountRateMultiplier := account.BillingRateMultiplier()
usageLog := &UsageLog{ usageLog := &UsageLog{
UserID: user.ID, UserID: user.ID,
@@ -3526,6 +3549,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
FirstTokenMs: result.FirstTokenMs, FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount, ImageCount: result.ImageCount,
ImageSize: imageSize, ImageSize: imageSize,
MediaType: mediaType,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }

View File

@@ -26,6 +26,12 @@ type Group struct {
ImagePrice2K *float64 ImagePrice2K *float64
ImagePrice4K *float64 ImagePrice4K *float64
// Sora 按次计费配置(阶段 1
SoraImagePrice360 *float64
SoraImagePrice540 *float64
SoraVideoPricePerRequest *float64
SoraVideoPricePerRequestHD *float64
// Claude Code 客户端限制 // Claude Code 客户端限制
ClaudeCodeOnly bool ClaudeCodeOnly bool
FallbackGroupID *int64 FallbackGroupID *int64
@@ -83,6 +89,18 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
} }
} }
// GetSoraImagePrice 根据 Sora 图片尺寸返回价格360/540
func (g *Group) GetSoraImagePrice(imageSize string) *float64 {
switch imageSize {
case "360":
return g.SoraImagePrice360
case "540":
return g.SoraImagePrice540
default:
return g.SoraImagePrice360
}
}
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions. // IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
func IsGroupContextValid(group *Group) bool { func IsGroupContextValid(group *Group) bool {
if group == nil { if group == nil {

View File

@@ -41,8 +41,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
if account == nil { if account == nil {
return "", errors.New("account is nil") return "", errors.New("account is nil")
} }
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth { if (account.Platform != PlatformOpenAI && account.Platform != PlatformSora) || account.Type != AccountTypeOAuth {
return "", errors.New("not an openai oauth account") return "", errors.New("not an openai/sora oauth account")
} }
cacheKey := OpenAITokenCacheKey(account) cacheKey := OpenAITokenCacheKey(account)
@@ -157,7 +157,7 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
} }
} }
accessToken := account.GetOpenAIAccessToken() accessToken := account.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" { if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found in credentials") return "", errors.New("access_token not found in credentials")
} }

View File

@@ -375,7 +375,7 @@ func TestOpenAITokenProvider_WrongPlatform(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), account) token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "not an openai oauth account") require.Contains(t, err.Error(), "not an openai/sora oauth account")
require.Empty(t, token) require.Empty(t, token)
} }
@@ -389,7 +389,7 @@ func TestOpenAITokenProvider_WrongAccountType(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), account) token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "not an openai oauth account") require.Contains(t, err.Error(), "not an openai/sora oauth account")
require.Empty(t, token) require.Empty(t, token)
} }

View File

@@ -0,0 +1,355 @@
package service
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// Sora2APIModel represents a model entry returned by sora2api.
type Sora2APIModel struct {
ID string `json:"id"`
Object string `json:"object"`
OwnedBy string `json:"owned_by,omitempty"`
Description string `json:"description,omitempty"`
}
// Sora2APIModelList represents /v1/models response.
type Sora2APIModelList struct {
Object string `json:"object"`
Data []Sora2APIModel `json:"data"`
}
// Sora2APIImportTokenItem mirrors sora2api ImportTokenItem.
type Sora2APIImportTokenItem struct {
Email string `json:"email"`
AccessToken string `json:"access_token,omitempty"`
SessionToken string `json:"session_token,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
ClientID string `json:"client_id,omitempty"`
ProxyURL string `json:"proxy_url,omitempty"`
Remark string `json:"remark,omitempty"`
IsActive bool `json:"is_active"`
ImageEnabled bool `json:"image_enabled"`
VideoEnabled bool `json:"video_enabled"`
ImageConcurrency int `json:"image_concurrency"`
VideoConcurrency int `json:"video_concurrency"`
}
// Sora2APIToken represents minimal fields for admin list.
type Sora2APIToken struct {
ID int64 `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Remark string `json:"remark"`
}
// Sora2APIService provides access to sora2api endpoints.
type Sora2APIService struct {
cfg *config.Config
baseURL string
apiKey string
adminUsername string
adminPassword string
adminTokenTTL time.Duration
adminTimeout time.Duration
tokenImportMode string
client *http.Client
adminClient *http.Client
adminToken string
adminTokenAt time.Time
adminMu sync.Mutex
modelCache []Sora2APIModel
modelCacheAt time.Time
modelMu sync.RWMutex
}
func NewSora2APIService(cfg *config.Config) *Sora2APIService {
if cfg == nil {
return &Sora2APIService{}
}
adminTTL := time.Duration(cfg.Sora2API.AdminTokenTTLSeconds) * time.Second
if adminTTL <= 0 {
adminTTL = 15 * time.Minute
}
adminTimeout := time.Duration(cfg.Sora2API.AdminTimeoutSeconds) * time.Second
if adminTimeout <= 0 {
adminTimeout = 10 * time.Second
}
return &Sora2APIService{
cfg: cfg,
baseURL: strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/"),
apiKey: strings.TrimSpace(cfg.Sora2API.APIKey),
adminUsername: strings.TrimSpace(cfg.Sora2API.AdminUsername),
adminPassword: strings.TrimSpace(cfg.Sora2API.AdminPassword),
adminTokenTTL: adminTTL,
adminTimeout: adminTimeout,
tokenImportMode: strings.ToLower(strings.TrimSpace(cfg.Sora2API.TokenImportMode)),
client: &http.Client{},
adminClient: &http.Client{Timeout: adminTimeout},
}
}
func (s *Sora2APIService) Enabled() bool {
return s != nil && s.baseURL != "" && s.apiKey != ""
}
func (s *Sora2APIService) AdminEnabled() bool {
return s != nil && s.baseURL != "" && s.adminUsername != "" && s.adminPassword != ""
}
func (s *Sora2APIService) buildURL(path string) string {
if s.baseURL == "" {
return path
}
if strings.HasPrefix(path, "/") {
return s.baseURL + path
}
return s.baseURL + "/" + path
}
// BuildURL 返回完整的 sora2api URL用于代理媒体
func (s *Sora2APIService) BuildURL(path string) string {
return s.buildURL(path)
}
func (s *Sora2APIService) NewAPIRequest(ctx context.Context, method string, path string, body []byte) (*http.Request, error) {
if !s.Enabled() {
return nil, errors.New("sora2api not configured")
}
req, err := http.NewRequestWithContext(ctx, method, s.buildURL(path), bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+s.apiKey)
req.Header.Set("Content-Type", "application/json")
return req, nil
}
func (s *Sora2APIService) ListModels(ctx context.Context) ([]Sora2APIModel, error) {
if !s.Enabled() {
return nil, errors.New("sora2api not configured")
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.buildURL("/v1/models"), nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+s.apiKey)
resp, err := s.client.Do(req)
if err != nil {
return s.cachedModelsOnError(err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return s.cachedModelsOnError(fmt.Errorf("sora2api models status: %d", resp.StatusCode))
}
var payload Sora2APIModelList
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
return s.cachedModelsOnError(err)
}
models := payload.Data
if s.cfg != nil && s.cfg.Gateway.SoraModelFilters.HidePromptEnhance {
filtered := make([]Sora2APIModel, 0, len(models))
for _, m := range models {
if strings.HasPrefix(strings.ToLower(m.ID), "prompt-enhance") {
continue
}
filtered = append(filtered, m)
}
models = filtered
}
s.modelMu.Lock()
s.modelCache = models
s.modelCacheAt = time.Now()
s.modelMu.Unlock()
return models, nil
}
func (s *Sora2APIService) cachedModelsOnError(err error) ([]Sora2APIModel, error) {
s.modelMu.RLock()
cached := append([]Sora2APIModel(nil), s.modelCache...)
s.modelMu.RUnlock()
if len(cached) > 0 {
log.Printf("[Sora2API] 模型列表拉取失败,回退缓存: %v", err)
return cached, nil
}
return nil, err
}
func (s *Sora2APIService) ImportTokens(ctx context.Context, items []Sora2APIImportTokenItem) error {
if !s.AdminEnabled() {
return errors.New("sora2api admin not configured")
}
mode := s.tokenImportMode
if mode == "" {
mode = "at"
}
payload := map[string]any{
"tokens": items,
"mode": mode,
}
_, err := s.doAdminRequest(ctx, http.MethodPost, "/api/tokens/import", payload, nil)
return err
}
func (s *Sora2APIService) ListTokens(ctx context.Context) ([]Sora2APIToken, error) {
if !s.AdminEnabled() {
return nil, errors.New("sora2api admin not configured")
}
var tokens []Sora2APIToken
_, err := s.doAdminRequest(ctx, http.MethodGet, "/api/tokens", nil, &tokens)
return tokens, err
}
func (s *Sora2APIService) DisableToken(ctx context.Context, tokenID int64) error {
if !s.AdminEnabled() {
return errors.New("sora2api admin not configured")
}
path := fmt.Sprintf("/api/tokens/%d/disable", tokenID)
_, err := s.doAdminRequest(ctx, http.MethodPost, path, nil, nil)
return err
}
func (s *Sora2APIService) DeleteToken(ctx context.Context, tokenID int64) error {
if !s.AdminEnabled() {
return errors.New("sora2api admin not configured")
}
path := fmt.Sprintf("/api/tokens/%d", tokenID)
_, err := s.doAdminRequest(ctx, http.MethodDelete, path, nil, nil)
return err
}
func (s *Sora2APIService) doAdminRequest(ctx context.Context, method string, path string, body any, out any) (*http.Response, error) {
if !s.AdminEnabled() {
return nil, errors.New("sora2api admin not configured")
}
token, err := s.getAdminToken(ctx)
if err != nil {
return nil, err
}
resp, err := s.doAdminRequestWithToken(ctx, method, path, token, body, out)
if err == nil && resp != nil && resp.StatusCode != http.StatusUnauthorized {
return resp, nil
}
if resp != nil && resp.StatusCode == http.StatusUnauthorized {
s.invalidateAdminToken()
token, err = s.getAdminToken(ctx)
if err != nil {
return resp, err
}
return s.doAdminRequestWithToken(ctx, method, path, token, body, out)
}
return resp, err
}
func (s *Sora2APIService) doAdminRequestWithToken(ctx context.Context, method string, path string, token string, body any, out any) (*http.Response, error) {
var reader *bytes.Reader
if body != nil {
buf, err := json.Marshal(body)
if err != nil {
return nil, err
}
reader = bytes.NewReader(buf)
} else {
reader = bytes.NewReader(nil)
}
req, err := http.NewRequestWithContext(ctx, method, s.buildURL(path), reader)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+token)
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
resp, err := s.adminClient.Do(req)
if err != nil {
return resp, err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return resp, fmt.Errorf("sora2api admin status: %d", resp.StatusCode)
}
if out != nil {
if err := json.NewDecoder(resp.Body).Decode(out); err != nil {
return resp, err
}
}
return resp, nil
}
func (s *Sora2APIService) getAdminToken(ctx context.Context) (string, error) {
s.adminMu.Lock()
defer s.adminMu.Unlock()
if s.adminToken != "" && time.Since(s.adminTokenAt) < s.adminTokenTTL {
return s.adminToken, nil
}
if !s.AdminEnabled() {
return "", errors.New("sora2api admin not configured")
}
payload := map[string]string{
"username": s.adminUsername,
"password": s.adminPassword,
}
buf, err := json.Marshal(payload)
if err != nil {
return "", err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.buildURL("/api/login"), bytes.NewReader(buf))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
resp, err := s.adminClient.Do(req)
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("sora2api login failed: %d", resp.StatusCode)
}
var result struct {
Success bool `json:"success"`
Token string `json:"token"`
Message string `json:"message"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", err
}
if !result.Success || result.Token == "" {
if result.Message == "" {
result.Message = "sora2api login failed"
}
return "", errors.New(result.Message)
}
s.adminToken = result.Token
s.adminTokenAt = time.Now()
return result.Token, nil
}
func (s *Sora2APIService) invalidateAdminToken() {
s.adminMu.Lock()
defer s.adminMu.Unlock()
s.adminToken = ""
s.adminTokenAt = time.Time{}
}

View File

@@ -0,0 +1,255 @@
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
)
// Sora2APISyncService 用于同步 Sora 账号到 sora2api token 池
type Sora2APISyncService struct {
sora2api *Sora2APIService
accountRepo AccountRepository
httpClient *http.Client
}
func NewSora2APISyncService(sora2api *Sora2APIService, accountRepo AccountRepository) *Sora2APISyncService {
return &Sora2APISyncService{
sora2api: sora2api,
accountRepo: accountRepo,
httpClient: &http.Client{Timeout: 10 * time.Second},
}
}
func (s *Sora2APISyncService) Enabled() bool {
return s != nil && s.sora2api != nil && s.sora2api.AdminEnabled()
}
// SyncAccount 将 Sora 账号同步到 sora2api导入或更新
func (s *Sora2APISyncService) SyncAccount(ctx context.Context, account *Account) error {
if !s.Enabled() {
return nil
}
if account == nil || account.Platform != PlatformSora {
return nil
}
accessToken := strings.TrimSpace(account.GetCredential("access_token"))
if accessToken == "" {
return errors.New("sora 账号缺少 access_token")
}
email, updated := s.resolveAccountEmail(ctx, account)
if email == "" {
return errors.New("无法解析 Sora 账号邮箱")
}
if updated && s.accountRepo != nil {
if err := s.accountRepo.Update(ctx, account); err != nil {
log.Printf("[SoraSync] 更新账号邮箱失败: account_id=%d err=%v", account.ID, err)
}
}
item := Sora2APIImportTokenItem{
Email: email,
AccessToken: accessToken,
SessionToken: strings.TrimSpace(account.GetCredential("session_token")),
RefreshToken: strings.TrimSpace(account.GetCredential("refresh_token")),
ClientID: strings.TrimSpace(account.GetCredential("client_id")),
Remark: account.Name,
IsActive: account.IsActive() && account.Schedulable,
ImageEnabled: true,
VideoEnabled: true,
ImageConcurrency: normalizeSoraConcurrency(account.Concurrency),
VideoConcurrency: normalizeSoraConcurrency(account.Concurrency),
}
if err := s.sora2api.ImportTokens(ctx, []Sora2APIImportTokenItem{item}); err != nil {
return err
}
return nil
}
// DisableAccount 禁用 sora2api 中的 token
func (s *Sora2APISyncService) DisableAccount(ctx context.Context, account *Account) error {
if !s.Enabled() {
return nil
}
if account == nil || account.Platform != PlatformSora {
return nil
}
tokenID, err := s.resolveTokenID(ctx, account)
if err != nil {
return err
}
return s.sora2api.DisableToken(ctx, tokenID)
}
// DeleteAccount 删除 sora2api 中的 token
func (s *Sora2APISyncService) DeleteAccount(ctx context.Context, account *Account) error {
if !s.Enabled() {
return nil
}
if account == nil || account.Platform != PlatformSora {
return nil
}
tokenID, err := s.resolveTokenID(ctx, account)
if err != nil {
return err
}
return s.sora2api.DeleteToken(ctx, tokenID)
}
func normalizeSoraConcurrency(value int) int {
if value <= 0 {
return -1
}
return value
}
func (s *Sora2APISyncService) resolveAccountEmail(ctx context.Context, account *Account) (string, bool) {
if account == nil {
return "", false
}
if email := strings.TrimSpace(account.GetCredential("email")); email != "" {
return email, false
}
if email := strings.TrimSpace(account.GetExtraString("email")); email != "" {
if account.Credentials == nil {
account.Credentials = map[string]any{}
}
account.Credentials["email"] = email
return email, true
}
if email := strings.TrimSpace(account.GetExtraString("sora_email")); email != "" {
if account.Credentials == nil {
account.Credentials = map[string]any{}
}
account.Credentials["email"] = email
return email, true
}
accessToken := strings.TrimSpace(account.GetCredential("access_token"))
if accessToken != "" {
if email := extractEmailFromAccessToken(accessToken); email != "" {
if account.Credentials == nil {
account.Credentials = map[string]any{}
}
account.Credentials["email"] = email
return email, true
}
if email := s.fetchEmailFromSora(ctx, accessToken); email != "" {
if account.Credentials == nil {
account.Credentials = map[string]any{}
}
account.Credentials["email"] = email
return email, true
}
}
return "", false
}
func (s *Sora2APISyncService) resolveTokenID(ctx context.Context, account *Account) (int64, error) {
if account == nil {
return 0, errors.New("account is nil")
}
if account.Extra != nil {
if v, ok := account.Extra["sora2api_token_id"]; ok {
if id, ok := v.(float64); ok && id > 0 {
return int64(id), nil
}
if id, ok := v.(int64); ok && id > 0 {
return id, nil
}
if id, ok := v.(int); ok && id > 0 {
return int64(id), nil
}
}
}
email := strings.TrimSpace(account.GetCredential("email"))
if email == "" {
email, _ = s.resolveAccountEmail(ctx, account)
}
if email == "" {
return 0, errors.New("sora2api token email missing")
}
tokenID, err := s.findTokenIDByEmail(ctx, email)
if err != nil {
return 0, err
}
return tokenID, nil
}
func (s *Sora2APISyncService) findTokenIDByEmail(ctx context.Context, email string) (int64, error) {
if !s.Enabled() {
return 0, errors.New("sora2api admin not configured")
}
tokens, err := s.sora2api.ListTokens(ctx)
if err != nil {
return 0, err
}
for _, token := range tokens {
if strings.EqualFold(strings.TrimSpace(token.Email), strings.TrimSpace(email)) {
return token.ID, nil
}
}
return 0, fmt.Errorf("sora2api token not found for email: %s", email)
}
func extractEmailFromAccessToken(accessToken string) string {
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
claims := jwt.MapClaims{}
_, _, err := parser.ParseUnverified(accessToken, claims)
if err != nil {
return ""
}
if email, ok := claims["email"].(string); ok && strings.TrimSpace(email) != "" {
return email
}
if profile, ok := claims["https://api.openai.com/profile"].(map[string]any); ok {
if email, ok := profile["email"].(string); ok && strings.TrimSpace(email) != "" {
return email
}
}
return ""
}
func (s *Sora2APISyncService) fetchEmailFromSora(ctx context.Context, accessToken string) string {
if s.httpClient == nil {
return ""
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, soraMeAPIURL, nil)
if err != nil {
return ""
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
req.Header.Set("Accept", "application/json")
resp, err := s.httpClient.Do(req)
if err != nil {
return ""
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return ""
}
var payload map[string]any
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
return ""
}
if email, ok := payload["email"].(string); ok && strings.TrimSpace(email) != "" {
return email
}
return ""
}

View File

@@ -0,0 +1,660 @@
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
var soraSSEDataRe = regexp.MustCompile(`^data:\s*`)
var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`)
var soraVideoHTMLRe = regexp.MustCompile(`(?i)<video[^>]+src=['"]([^'"]+)['"]`)
var soraImageSizeMap = map[string]string{
"gpt-image": "360",
"gpt-image-landscape": "540",
"gpt-image-portrait": "540",
}
type soraStreamingResult struct {
content string
mediaType string
mediaURLs []string
imageCount int
imageSize string
firstTokenMs *int
}
// SoraGatewayService handles forwarding requests to sora2api.
type SoraGatewayService struct {
sora2api *Sora2APIService
httpUpstream HTTPUpstream
rateLimitService *RateLimitService
cfg *config.Config
}
func NewSoraGatewayService(
sora2api *Sora2APIService,
httpUpstream HTTPUpstream,
rateLimitService *RateLimitService,
cfg *config.Config,
) *SoraGatewayService {
return &SoraGatewayService{
sora2api: sora2api,
httpUpstream: httpUpstream,
rateLimitService: rateLimitService,
cfg: cfg,
}
}
func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) {
startTime := time.Now()
if s.sora2api == nil || !s.sora2api.Enabled() {
if c != nil {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": gin.H{
"type": "api_error",
"message": "sora2api 未配置",
},
})
}
return nil, errors.New("sora2api not configured")
}
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
return nil, fmt.Errorf("parse request: %w", err)
}
reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool)
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel && mappedModel != "" {
reqBody["model"] = mappedModel
if updated, err := json.Marshal(reqBody); err == nil {
body = updated
}
}
reqCtx, cancel := s.withSoraTimeout(ctx, reqStream)
if cancel != nil {
defer cancel()
}
upstreamReq, err := s.sora2api.NewAPIRequest(reqCtx, http.MethodPost, "/v1/chat/completions", body)
if err != nil {
return nil, err
}
if c != nil {
if ua := strings.TrimSpace(c.GetHeader("User-Agent")); ua != "" {
upstreamReq.Header.Set("User-Agent", ua)
}
}
if reqStream {
upstreamReq.Header.Set("Accept", "text/event-stream")
}
if c != nil {
c.Set(OpsUpstreamRequestBodyKey, string(body))
}
proxyURL := ""
if account != nil && account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
var resp *http.Response
if s.httpUpstream != nil {
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
} else {
resp, err = http.DefaultClient.Do(upstreamReq)
}
if err != nil {
s.setUpstreamRequestError(c, account, err)
return nil, fmt.Errorf("upstream request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
if s.shouldFailoverUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
})
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return s.handleErrorResponse(ctx, resp, c, account, reqModel)
}
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, reqModel, clientStream)
if err != nil {
return nil, err
}
result := &ForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Model: reqModel,
Stream: clientStream,
Duration: time.Since(startTime),
FirstTokenMs: streamResult.firstTokenMs,
Usage: ClaudeUsage{},
MediaType: streamResult.mediaType,
MediaURL: firstMediaURL(streamResult.mediaURLs),
ImageCount: streamResult.imageCount,
ImageSize: streamResult.imageSize,
}
return result, nil
}
func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
if s == nil || s.cfg == nil {
return ctx, nil
}
timeoutSeconds := s.cfg.Gateway.SoraRequestTimeoutSeconds
if stream {
timeoutSeconds = s.cfg.Gateway.SoraStreamTimeoutSeconds
}
if timeoutSeconds <= 0 {
return ctx, nil
}
return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
}
func (s *SoraGatewayService) setUpstreamRequestError(c *gin.Context, account *Account, err error) {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
if c != nil {
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
}
}
func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode {
case 401, 402, 403, 429, 529:
return true
default:
return statusCode >= 500
}
}
func (s *SoraGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
if s.rateLimitService == nil || account == nil || resp == nil {
return
}
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
func (s *SoraGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, reqModel string) (*ForwardResult, error) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if msg := soraProErrorMessage(reqModel, upstreamMsg); msg != "" {
upstreamMsg = msg
}
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "http_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
if c != nil {
responsePayload := s.buildErrorPayload(respBody, upstreamMsg)
c.JSON(resp.StatusCode, responsePayload)
}
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
}
func (s *SoraGatewayService) buildErrorPayload(respBody []byte, overrideMessage string) map[string]any {
if len(respBody) > 0 {
var payload map[string]any
if err := json.Unmarshal(respBody, &payload); err == nil {
if errObj, ok := payload["error"].(map[string]any); ok {
if overrideMessage != "" {
errObj["message"] = overrideMessage
}
payload["error"] = errObj
return payload
}
}
}
return map[string]any{
"error": map[string]any{
"type": "upstream_error",
"message": overrideMessage,
},
}
}
func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel string, clientStream bool) (*soraStreamingResult, error) {
if resp == nil {
return nil, errors.New("empty response")
}
if clientStream {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
if v := resp.Header.Get("x-request-id"); v != "" {
c.Header("x-request-id", v)
}
}
w := c.Writer
flusher, _ := w.(http.Flusher)
contentBuilder := strings.Builder{}
var firstTokenMs *int
var upstreamError error
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
sendLine := func(line string) error {
if !clientStream {
return nil
}
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return err
}
if flusher != nil {
flusher.Flush()
}
return nil
}
for scanner.Scan() {
line := scanner.Text()
if soraSSEDataRe.MatchString(line) {
data := soraSSEDataRe.ReplaceAllString(line, "")
if data == "[DONE]" {
if err := sendLine("data: [DONE]"); err != nil {
return nil, err
}
break
}
updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel)
if errEvent != nil && upstreamError == nil {
upstreamError = errEvent
}
if contentDelta != "" {
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
contentBuilder.WriteString(contentDelta)
}
if err := sendLine(updatedLine); err != nil {
return nil, err
}
continue
}
if err := sendLine(line); err != nil {
return nil, err
}
}
if err := scanner.Err(); err != nil {
if errors.Is(err, bufio.ErrTooLong) {
if clientStream {
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"response_too_large\"}\n\n")
if flusher != nil {
flusher.Flush()
}
}
return nil, err
}
if ctx.Err() == context.DeadlineExceeded && s.rateLimitService != nil && account != nil {
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
}
if clientStream {
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"stream_read_error\"}\n\n")
if flusher != nil {
flusher.Flush()
}
}
return nil, err
}
content := contentBuilder.String()
mediaType, mediaURLs := s.extractSoraMedia(content)
if mediaType == "" && isSoraPromptEnhanceModel(originalModel) {
mediaType = "prompt"
}
imageSize := ""
imageCount := 0
if mediaType == "image" {
imageSize = soraImageSizeFromModel(originalModel)
imageCount = len(mediaURLs)
}
if upstreamError != nil && !clientStream {
if c != nil {
c.JSON(http.StatusBadGateway, map[string]any{
"error": map[string]any{
"type": "upstream_error",
"message": upstreamError.Error(),
},
})
}
return nil, upstreamError
}
if !clientStream {
response := buildSoraNonStreamResponse(content, originalModel)
if len(mediaURLs) > 0 {
response["media_url"] = mediaURLs[0]
if len(mediaURLs) > 1 {
response["media_urls"] = mediaURLs
}
}
c.JSON(http.StatusOK, response)
}
return &soraStreamingResult{
content: content,
mediaType: mediaType,
mediaURLs: mediaURLs,
imageCount: imageCount,
imageSize: imageSize,
firstTokenMs: firstTokenMs,
}, nil
}
func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string) (string, string, error) {
if strings.TrimSpace(data) == "" {
return "data: ", "", nil
}
var payload map[string]any
if err := json.Unmarshal([]byte(data), &payload); err != nil {
return "data: " + data, "", nil
}
if errObj, ok := payload["error"].(map[string]any); ok {
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
return "data: " + data, "", errors.New(msg)
}
}
if model, ok := payload["model"].(string); ok && model != "" && originalModel != "" {
payload["model"] = originalModel
}
contentDelta, updated := extractSoraContent(payload)
if updated {
rewritten := s.rewriteSoraContent(contentDelta)
if rewritten != contentDelta {
applySoraContent(payload, rewritten)
contentDelta = rewritten
}
}
updatedData, err := json.Marshal(payload)
if err != nil {
return "data: " + data, contentDelta, nil
}
return "data: " + string(updatedData), contentDelta, nil
}
func extractSoraContent(payload map[string]any) (string, bool) {
choices, ok := payload["choices"].([]any)
if !ok || len(choices) == 0 {
return "", false
}
choice, ok := choices[0].(map[string]any)
if !ok {
return "", false
}
if delta, ok := choice["delta"].(map[string]any); ok {
if content, ok := delta["content"].(string); ok {
return content, true
}
}
if message, ok := choice["message"].(map[string]any); ok {
if content, ok := message["content"].(string); ok {
return content, true
}
}
return "", false
}
func applySoraContent(payload map[string]any, content string) {
choices, ok := payload["choices"].([]any)
if !ok || len(choices) == 0 {
return
}
choice, ok := choices[0].(map[string]any)
if !ok {
return
}
if delta, ok := choice["delta"].(map[string]any); ok {
delta["content"] = content
choice["delta"] = delta
return
}
if message, ok := choice["message"].(map[string]any); ok {
message["content"] = content
choice["message"] = message
}
}
func (s *SoraGatewayService) rewriteSoraContent(content string) string {
if content == "" {
return content
}
content = soraImageMarkdownRe.ReplaceAllStringFunc(content, func(match string) string {
sub := soraImageMarkdownRe.FindStringSubmatch(match)
if len(sub) < 2 {
return match
}
rewritten := s.rewriteSoraURL(sub[1])
if rewritten == sub[1] {
return match
}
return strings.Replace(match, sub[1], rewritten, 1)
})
content = soraVideoHTMLRe.ReplaceAllStringFunc(content, func(match string) string {
sub := soraVideoHTMLRe.FindStringSubmatch(match)
if len(sub) < 2 {
return match
}
rewritten := s.rewriteSoraURL(sub[1])
if rewritten == sub[1] {
return match
}
return strings.Replace(match, sub[1], rewritten, 1)
})
return content
}
func (s *SoraGatewayService) rewriteSoraURL(raw string) string {
raw = strings.TrimSpace(raw)
if raw == "" {
return raw
}
parsed, err := url.Parse(raw)
if err != nil {
return raw
}
path := parsed.Path
if !strings.HasPrefix(path, "/tmp/") && !strings.HasPrefix(path, "/static/") {
return raw
}
return s.buildSoraMediaURL(path, parsed.RawQuery)
}
func (s *SoraGatewayService) extractSoraMedia(content string) (string, []string) {
if content == "" {
return "", nil
}
if match := soraVideoHTMLRe.FindStringSubmatch(content); len(match) > 1 {
return "video", []string{match[1]}
}
imageMatches := soraImageMarkdownRe.FindAllStringSubmatch(content, -1)
if len(imageMatches) == 0 {
return "", nil
}
urls := make([]string, 0, len(imageMatches))
for _, match := range imageMatches {
if len(match) > 1 {
urls = append(urls, match[1])
}
}
return "image", urls
}
func buildSoraNonStreamResponse(content, model string) map[string]any {
return map[string]any{
"id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()),
"object": "chat.completion",
"created": time.Now().Unix(),
"model": model,
"choices": []any{
map[string]any{
"index": 0,
"message": map[string]any{
"role": "assistant",
"content": content,
},
"finish_reason": "stop",
},
},
}
}
func soraImageSizeFromModel(model string) string {
modelLower := strings.ToLower(model)
if size, ok := soraImageSizeMap[modelLower]; ok {
return size
}
if strings.Contains(modelLower, "landscape") || strings.Contains(modelLower, "portrait") {
return "540"
}
return "360"
}
func isSoraPromptEnhanceModel(model string) bool {
return strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "prompt-enhance")
}
func soraProErrorMessage(model, upstreamMsg string) string {
modelLower := strings.ToLower(model)
if strings.Contains(modelLower, "sora2pro-hd") {
return "当前账号无法使用 Sora Pro-HD 模型,请更换模型或账号"
}
if strings.Contains(modelLower, "sora2pro") {
return "当前账号无法使用 Sora Pro 模型,请更换模型或账号"
}
return ""
}
func firstMediaURL(urls []string) string {
if len(urls) == 0 {
return ""
}
return urls[0]
}
func (s *SoraGatewayService) buildSoraMediaURL(path string, rawQuery string) string {
if path == "" {
return path
}
prefix := "/sora/media"
values := url.Values{}
if rawQuery != "" {
if parsed, err := url.ParseQuery(rawQuery); err == nil {
values = parsed
}
}
signKey := ""
ttlSeconds := 0
if s != nil && s.cfg != nil {
signKey = strings.TrimSpace(s.cfg.Gateway.SoraMediaSigningKey)
ttlSeconds = s.cfg.Gateway.SoraMediaSignedURLTTLSeconds
}
values.Del("sig")
values.Del("expires")
signingQuery := values.Encode()
if signKey != "" && ttlSeconds > 0 {
expires := time.Now().Add(time.Duration(ttlSeconds) * time.Second).Unix()
signature := SignSoraMediaURL(path, signingQuery, expires, signKey)
if signature != "" {
values.Set("expires", strconv.FormatInt(expires, 10))
values.Set("sig", signature)
prefix = "/sora/media-signed"
}
}
encoded := values.Encode()
if encoded == "" {
return prefix + path
}
return prefix + path + "?" + encoded
}

View File

@@ -0,0 +1,42 @@
package service
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"strconv"
"strings"
)
// SignSoraMediaURL 生成 Sora 媒体临时签名
func SignSoraMediaURL(path string, query string, expires int64, key string) string {
key = strings.TrimSpace(key)
if key == "" {
return ""
}
mac := hmac.New(sha256.New, []byte(key))
mac.Write([]byte(buildSoraMediaSignPayload(path, query)))
mac.Write([]byte("|"))
mac.Write([]byte(strconv.FormatInt(expires, 10)))
return hex.EncodeToString(mac.Sum(nil))
}
// VerifySoraMediaURL 校验 Sora 媒体签名
func VerifySoraMediaURL(path string, query string, expires int64, signature string, key string) bool {
signature = strings.TrimSpace(signature)
if signature == "" {
return false
}
expected := SignSoraMediaURL(path, query, expires, key)
if expected == "" {
return false
}
return hmac.Equal([]byte(signature), []byte(expected))
}
func buildSoraMediaSignPayload(path string, query string) string {
if strings.TrimSpace(query) == "" {
return path
}
return path + "?" + query
}

View File

@@ -0,0 +1,34 @@
package service
import "testing"
func TestSoraMediaSignVerify(t *testing.T) {
key := "test-key"
path := "/tmp/abc.png"
query := "a=1&b=2"
expires := int64(1700000000)
signature := SignSoraMediaURL(path, query, expires, key)
if signature == "" {
t.Fatal("签名为空")
}
if !VerifySoraMediaURL(path, query, expires, signature, key) {
t.Fatal("签名校验失败")
}
if VerifySoraMediaURL(path, "a=1", expires, signature, key) {
t.Fatal("签名参数不同仍然通过")
}
if VerifySoraMediaURL(path, query, expires+1, signature, key) {
t.Fatal("签名过期校验未失败")
}
}
func TestSoraMediaSignWithEmptyKey(t *testing.T) {
signature := SignSoraMediaURL("/tmp/a.png", "a=1", 1, "")
if signature != "" {
t.Fatalf("空密钥不应生成签名")
}
if VerifySoraMediaURL("/tmp/a.png", "a=1", 1, "sig", "") {
t.Fatalf("空密钥不应通过校验")
}
}

View File

@@ -42,7 +42,7 @@ func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, ac
// Antigravity 同样可能有两种缓存键 // Antigravity 同样可能有两种缓存键
keysToDelete = append(keysToDelete, AntigravityTokenCacheKey(account)) keysToDelete = append(keysToDelete, AntigravityTokenCacheKey(account))
keysToDelete = append(keysToDelete, "ag:"+accountIDKey) keysToDelete = append(keysToDelete, "ag:"+accountIDKey)
case PlatformOpenAI: case PlatformOpenAI, PlatformSora:
keysToDelete = append(keysToDelete, OpenAITokenCacheKey(account)) keysToDelete = append(keysToDelete, OpenAITokenCacheKey(account))
case PlatformAnthropic: case PlatformAnthropic:
keysToDelete = append(keysToDelete, ClaudeTokenCacheKey(account)) keysToDelete = append(keysToDelete, ClaudeTokenCacheKey(account))

View File

@@ -19,6 +19,7 @@ type TokenRefreshService struct {
refreshers []TokenRefresher refreshers []TokenRefresher
cfg *config.TokenRefreshConfig cfg *config.TokenRefreshConfig
cacheInvalidator TokenCacheInvalidator cacheInvalidator TokenCacheInvalidator
soraSyncService *Sora2APISyncService
stopCh chan struct{} stopCh chan struct{}
wg sync.WaitGroup wg sync.WaitGroup
@@ -65,6 +66,17 @@ func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) {
} }
} }
// SetSoraSyncService 设置 Sora2API 同步服务
// 需要在 Start() 之前调用
func (s *TokenRefreshService) SetSoraSyncService(svc *Sora2APISyncService) {
s.soraSyncService = svc
for _, refresher := range s.refreshers {
if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok {
openaiRefresher.SetSoraSyncService(svc)
}
}
}
// Start 启动后台刷新服务 // Start 启动后台刷新服务
func (s *TokenRefreshService) Start() { func (s *TokenRefreshService) Start() {
if !s.cfg.Enabled { if !s.cfg.Enabled {

View File

@@ -86,6 +86,7 @@ type OpenAITokenRefresher struct {
openaiOAuthService *OpenAIOAuthService openaiOAuthService *OpenAIOAuthService
accountRepo AccountRepository accountRepo AccountRepository
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步 soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
soraSyncService *Sora2APISyncService // Sora2API 同步服务
} }
// NewOpenAITokenRefresher 创建 OpenAI token刷新器 // NewOpenAITokenRefresher 创建 OpenAI token刷新器
@@ -103,17 +104,22 @@ func (r *OpenAITokenRefresher) SetSoraAccountRepo(repo SoraAccountRepository) {
r.soraAccountRepo = repo r.soraAccountRepo = repo
} }
// SetSoraSyncService 设置 Sora2API 同步服务
func (r *OpenAITokenRefresher) SetSoraSyncService(svc *Sora2APISyncService) {
r.soraSyncService = svc
}
// CanRefresh 检查是否能处理此账号 // CanRefresh 检查是否能处理此账号
// 只处理 openai 平台的 oauth 类型账号 // 只处理 openai 平台的 oauth 类型账号
func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool { func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
return account.Platform == PlatformOpenAI && return (account.Platform == PlatformOpenAI || account.Platform == PlatformSora) &&
account.Type == AccountTypeOAuth account.Type == AccountTypeOAuth
} }
// NeedsRefresh 检查token是否需要刷新 // NeedsRefresh 检查token是否需要刷新
// 基于 expires_at 字段判断是否在刷新窗口内 // 基于 expires_at 字段判断是否在刷新窗口内
func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool { func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
expiresAt := account.GetOpenAITokenExpiresAt() expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAt == nil { if expiresAt == nil {
return false return false
} }
@@ -145,6 +151,17 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m
go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials) go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials)
} }
// 如果是 Sora 平台账号,同步到 sora2api不阻塞主流程
if account.Platform == PlatformSora && r.soraSyncService != nil {
syncAccount := *account
syncAccount.Credentials = newCredentials
go func() {
if err := r.soraSyncService.SyncAccount(context.Background(), &syncAccount); err != nil {
log.Printf("[TokenSync] 同步 Sora2API 失败: account_id=%d err=%v", syncAccount.ID, err)
}
}()
}
return newCredentials, nil return newCredentials, nil
} }
@@ -201,6 +218,13 @@ func (r *OpenAITokenRefresher) syncLinkedSoraAccounts(ctx context.Context, opena
} }
} }
// 2.3 同步到 sora2api如果配置
if r.soraSyncService != nil {
if err := r.soraSyncService.SyncAccount(ctx, &soraAccount); err != nil {
log.Printf("[TokenSync] 同步 sora2api 失败: account_id=%d err=%v", soraAccount.ID, err)
}
}
log.Printf("[TokenSync] 成功同步 Sora 账号 token: sora_account_id=%d openai_account_id=%d dual_table=%v", log.Printf("[TokenSync] 成功同步 Sora 账号 token: sora_account_id=%d openai_account_id=%d dual_table=%v",
soraAccount.ID, openaiAccountID, r.soraAccountRepo != nil) soraAccount.ID, openaiAccountID, r.soraAccountRepo != nil)
} }

View File

@@ -46,6 +46,7 @@ type UsageLog struct {
// 图片生成字段 // 图片生成字段
ImageCount int ImageCount int
ImageSize *string ImageSize *string
MediaType *string
CreatedAt time.Time CreatedAt time.Time

View File

@@ -40,6 +40,7 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
func ProvideTokenRefreshService( func ProvideTokenRefreshService(
accountRepo AccountRepository, accountRepo AccountRepository,
soraAccountRepo SoraAccountRepository, // Sora 扩展表仓储,用于双表同步 soraAccountRepo SoraAccountRepository, // Sora 扩展表仓储,用于双表同步
soraSyncService *Sora2APISyncService,
oauthService *OAuthService, oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService, openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService, geminiOAuthService *GeminiOAuthService,
@@ -50,6 +51,9 @@ func ProvideTokenRefreshService(
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg) svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg)
// 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表 // 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表
svc.SetSoraAccountRepo(soraAccountRepo) svc.SetSoraAccountRepo(soraAccountRepo)
if soraSyncService != nil {
svc.SetSoraSyncService(soraSyncService)
}
svc.Start() svc.Start()
return svc return svc
} }
@@ -224,6 +228,7 @@ var ProviderSet = wire.NewSet(
NewBillingCacheService, NewBillingCacheService,
NewAdminService, NewAdminService,
NewGatewayService, NewGatewayService,
NewSoraGatewayService,
NewOpenAIGatewayService, NewOpenAIGatewayService,
NewOAuthService, NewOAuthService,
NewOpenAIOAuthService, NewOpenAIOAuthService,
@@ -237,6 +242,8 @@ var ProviderSet = wire.NewSet(
NewAntigravityTokenProvider, NewAntigravityTokenProvider,
NewOpenAITokenProvider, NewOpenAITokenProvider,
NewClaudeTokenProvider, NewClaudeTokenProvider,
NewSora2APIService,
NewSora2APISyncService,
NewAntigravityGatewayService, NewAntigravityGatewayService,
ProvideRateLimitService, ProvideRateLimitService,
NewAccountUsageService, NewAccountUsageService,

View File

@@ -0,0 +1,11 @@
-- Migration: 047_add_sora_pricing_and_media_type
-- 新增 Sora 按次计费字段与 usage_logs.media_type
ALTER TABLE groups
ADD COLUMN IF NOT EXISTS sora_image_price_360 decimal(20,8),
ADD COLUMN IF NOT EXISTS sora_image_price_540 decimal(20,8),
ADD COLUMN IF NOT EXISTS sora_video_price_per_request decimal(20,8),
ADD COLUMN IF NOT EXISTS sora_video_price_per_request_hd decimal(20,8);
ALTER TABLE usage_logs
ADD COLUMN IF NOT EXISTS media_type VARCHAR(16);

View File

@@ -1,39 +1,6 @@
# =============================================================================
# Sub2API Caddy Reverse Proxy Configuration (宿主机部署)
# =============================================================================
# 使用方法:
# 1. 安装 Caddy: https://caddyserver.com/docs/install
# 2. 修改下方 example.com 为你的域名
# 3. 确保域名 DNS 已指向服务器
# 4. 复制配置: sudo cp Caddyfile /etc/caddy/Caddyfile
# 5. 重载配置: sudo systemctl reload caddy
#
# Caddy 会自动申请和续期 Let's Encrypt SSL 证书
# =============================================================================
# 全局配置
{
# Let's Encrypt 邮箱通知
email admin@example.com
# 服务器配置
servers {
# 启用 HTTP/2 和 HTTP/3
protocols h1 h2 h3
# 超时配置
timeouts {
read_body 30s
read_header 10s
write 300s
idle 300s
}
}
}
# 修改为你的域名 # 修改为你的域名
example.com { api.sub2api.com {
# ========================================================================= # =========================================================================
# 静态资源长期缓存(高优先级,放在最前面) # 静态资源长期缓存(高优先级,放在最前面)
# 带 hash 的文件可以永久缓存,浏览器和 CDN 都会缓存 # 带 hash 的文件可以永久缓存,浏览器和 CDN 都会缓存
# ========================================================================= # =========================================================================
@@ -87,17 +54,13 @@ example.com {
# 连接池优化 # 连接池优化
transport http { transport http {
versions h2c h1
keepalive 120s keepalive 120s
keepalive_idle_conns 256 keepalive_idle_conns 256
read_buffer 16KB read_buffer 16KB
write_buffer 16KB write_buffer 16KB
compression off compression off
} }
# SSE/流式传输优化:禁用响应缓冲,立即刷新数据给客户端
flush_interval -1
# 故障转移 # 故障转移
fail_duration 30s fail_duration 30s
max_fails 3 max_fails 3
@@ -112,10 +75,6 @@ example.com {
gzip 6 gzip 6
minimum_length 256 minimum_length 256
match { match {
# SSE 请求通常会带 Accept: text/event-stream需排除压缩
not header Accept text/event-stream*
# 排除已知 SSE 路径(即便 Accept 缺失)
not path /v1/messages /v1/responses /responses /antigravity/v1/messages /v1beta/models/* /antigravity/v1beta/models/*
header Content-Type text/* header Content-Type text/*
header Content-Type application/json* header Content-Type application/json*
header Content-Type application/javascript* header Content-Type application/javascript*
@@ -199,7 +158,3 @@ example.com {
respond "{err.status_code} {err.status_text}" respond "{err.status_code} {err.status_text}"
} }
} }
# =============================================================================
# HTTP 重定向到 HTTPS (Caddy 默认自动处理,此处显式声明)
# =============================================================================

View File

@@ -116,6 +116,33 @@ gateway:
# Max request body size in bytes (default: 100MB) # Max request body size in bytes (default: 100MB)
# 请求体最大字节数(默认 100MB # 请求体最大字节数(默认 100MB
max_body_size: 104857600 max_body_size: 104857600
# Sora max request body size in bytes (0=use max_body_size)
# Sora 请求体最大字节数0=使用 max_body_size
sora_max_body_size: 268435456
# Sora stream timeout (seconds, 0=disable)
# Sora 流式请求总超时0=禁用)
sora_stream_timeout_seconds: 900
# Sora non-stream timeout (seconds, 0=disable)
# Sora 非流式请求超时0=禁用)
sora_request_timeout_seconds: 180
# Sora stream enforcement mode: force/error
# Sora stream 强制策略force/error
sora_stream_mode: "force"
# Sora model filters
# Sora 模型过滤配置
sora_model_filters:
# Hide prompt-enhance models by default
# 默认隐藏 prompt-enhance 模型
hide_prompt_enhance: true
# Require API key for /sora/media proxy (default: false)
# /sora/media 是否强制要求 API Key默认 true
sora_media_require_api_key: true
# Sora media temporary signing key (empty disables signed URL)
# Sora 媒体临时签名密钥(为空则禁用签名)
sora_media_signing_key: ""
# Signed URL TTL seconds (<=0 disables)
# 临时签名 URL 有效期(秒,<=0 表示禁用)
sora_media_signed_url_ttl_seconds: 900
# Connection pool isolation strategy: # Connection pool isolation strategy:
# 连接池隔离策略: # 连接池隔离策略:
# - proxy: Isolate by proxy, same proxy shares connection pool (suitable for few proxies, many accounts) # - proxy: Isolate by proxy, same proxy shares connection pool (suitable for few proxies, many accounts)
@@ -220,6 +247,31 @@ gateway:
# name: "Custom Profile 1" # name: "Custom Profile 1"
# profile_2: # profile_2:
# name: "Custom Profile 2" # name: "Custom Profile 2"
# =============================================================================
# Sora2API Configuration
# Sora2API 配置
# =============================================================================
sora2api:
# Sora2API base URL
# Sora2API 服务地址
base_url: "http://127.0.0.1:8000"
# Sora2API API Key (for /v1/chat/completions and /v1/models)
# Sora2API API Key用于生成/模型列表)
api_key: ""
# Admin username/password (for token sync)
# 管理口用户名/密码(用于 token 同步)
admin_username: "admin"
admin_password: "admin"
# Admin token cache ttl (seconds)
# 管理口 token 缓存时长(秒)
admin_token_ttl_seconds: 900
# Admin request timeout (seconds)
# 管理口请求超时(秒)
admin_timeout_seconds: 10
# Token import mode: at/offline
# Token 导入模式at/offline
token_import_mode: "at"
# cipher_suites: [4866, 4867, 4865, 49199, 49195, 49200, 49196] # cipher_suites: [4866, 4867, 4865, 49199, 49195, 49200, 49196]
# curves: [29, 23, 24] # curves: [29, 23, 24]
# point_formats: [0] # point_formats: [0]

View File

@@ -18,6 +18,7 @@ import geminiAPI from './gemini'
import antigravityAPI from './antigravity' import antigravityAPI from './antigravity'
import userAttributesAPI from './userAttributes' import userAttributesAPI from './userAttributes'
import opsAPI from './ops' import opsAPI from './ops'
import modelsAPI from './models'
/** /**
* Unified admin API object for convenient access * Unified admin API object for convenient access
@@ -37,7 +38,8 @@ export const adminAPI = {
gemini: geminiAPI, gemini: geminiAPI,
antigravity: antigravityAPI, antigravity: antigravityAPI,
userAttributes: userAttributesAPI, userAttributes: userAttributesAPI,
ops: opsAPI ops: opsAPI,
models: modelsAPI
} }
export { export {
@@ -55,7 +57,8 @@ export {
geminiAPI, geminiAPI,
antigravityAPI, antigravityAPI,
userAttributesAPI, userAttributesAPI,
opsAPI opsAPI,
modelsAPI
} }
export default adminAPI export default adminAPI

View File

@@ -0,0 +1,14 @@
import { apiClient } from '@/api/client'
export async function getPlatformModels(platform: string): Promise<string[]> {
const { data } = await apiClient.get<string[]>('/admin/models', {
params: { platform }
})
return data
}
export const modelsAPI = {
getPlatformModels
}
export default modelsAPI

View File

@@ -45,6 +45,19 @@
:placeholder="t('admin.accounts.searchModels')" :placeholder="t('admin.accounts.searchModels')"
@click.stop @click.stop
/> />
<div v-if="props.platform === 'sora'" class="mt-2 flex items-center gap-2 text-xs">
<span v-if="loadingSoraModels" class="text-gray-500">
{{ t('admin.accounts.soraModelsLoading') }}
</span>
<button
v-else-if="soraLoadError"
type="button"
class="text-primary-600 hover:underline dark:text-primary-400"
@click.stop="loadSoraModels"
>
{{ t('admin.accounts.soraModelsRetry') }}
</button>
</div>
</div> </div>
<div class="max-h-52 overflow-auto"> <div class="max-h-52 overflow-auto">
<button <button
@@ -120,12 +133,13 @@
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { ref, computed } from 'vue' import { ref, computed, watch } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app' import { useAppStore } from '@/stores/app'
import ModelIcon from '@/components/common/ModelIcon.vue' import ModelIcon from '@/components/common/ModelIcon.vue'
import Icon from '@/components/icons/Icon.vue' import Icon from '@/components/icons/Icon.vue'
import { allModels, getModelsByPlatform } from '@/composables/useModelWhitelist' import { allModels, getModelsByPlatform } from '@/composables/useModelWhitelist'
import { adminAPI } from '@/api/admin'
const { t } = useI18n() const { t } = useI18n()
@@ -144,11 +158,24 @@ const showDropdown = ref(false)
const searchQuery = ref('') const searchQuery = ref('')
const customModel = ref('') const customModel = ref('')
const isComposing = ref(false) const isComposing = ref(false)
const soraModelOptions = ref<{ value: string; label: string }[]>([])
const loadingSoraModels = ref(false)
const soraLoadError = ref(false)
const availableOptions = computed(() => {
if (props.platform === 'sora') {
if (soraModelOptions.value.length > 0) {
return soraModelOptions.value
}
return getModelsByPlatform('sora').map(m => ({ value: m, label: m }))
}
return allModels
})
const filteredModels = computed(() => { const filteredModels = computed(() => {
const query = searchQuery.value.toLowerCase().trim() const query = searchQuery.value.toLowerCase().trim()
if (!query) return allModels if (!query) return availableOptions.value
return allModels.filter( return availableOptions.value.filter(
m => m.value.toLowerCase().includes(query) || m.label.toLowerCase().includes(query) m => m.value.toLowerCase().includes(query) || m.label.toLowerCase().includes(query)
) )
}) })
@@ -186,7 +213,9 @@ const handleEnter = () => {
} }
const fillRelated = () => { const fillRelated = () => {
const models = getModelsByPlatform(props.platform) const models = props.platform === 'sora' && soraModelOptions.value.length > 0
? soraModelOptions.value.map(m => m.value)
: getModelsByPlatform(props.platform)
const newModels = [...props.modelValue] const newModels = [...props.modelValue]
for (const model of models) { for (const model of models) {
if (!newModels.includes(model)) newModels.push(model) if (!newModels.includes(model)) newModels.push(model)
@@ -197,4 +226,32 @@ const fillRelated = () => {
const clearAll = () => { const clearAll = () => {
emit('update:modelValue', []) emit('update:modelValue', [])
} }
const loadSoraModels = async () => {
if (props.platform !== 'sora') {
soraModelOptions.value = []
return
}
if (loadingSoraModels.value) return
soraLoadError.value = false
loadingSoraModels.value = true
try {
const models = await adminAPI.models.getPlatformModels('sora')
soraModelOptions.value = (models || []).map((m) => ({ value: m, label: m }))
} catch (error) {
console.warn('加载 Sora 模型列表失败', error)
soraLoadError.value = true
appStore.showWarning(t('admin.accounts.soraModelsLoadFailed'))
} finally {
loadingSoraModels.value = false
}
}
watch(
() => props.platform,
() => {
loadSoraModels()
},
{ immediate: true }
)
</script> </script>

View File

@@ -19,7 +19,7 @@ const props = defineProps(['searchQuery', 'filters']); const emit = defineEmits(
const updatePlatform = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, platform: value }) } const updatePlatform = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, platform: value }) }
const updateType = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, type: value }) } const updateType = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, type: value }) }
const updateStatus = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, status: value }) } const updateStatus = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, status: value }) }
const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }]) const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }, { value: 'sora', label: 'Sora' }])
const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }]) const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }])
const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }]) const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }])
</script> </script>

View File

@@ -97,6 +97,9 @@ const labelClass = computed(() => {
if (props.platform === 'gemini') { if (props.platform === 'gemini') {
return `${base} bg-blue-200/60 text-blue-800 dark:bg-blue-800/40 dark:text-blue-300` return `${base} bg-blue-200/60 text-blue-800 dark:bg-blue-800/40 dark:text-blue-300`
} }
if (props.platform === 'sora') {
return `${base} bg-rose-200/60 text-rose-800 dark:bg-rose-800/40 dark:text-rose-300`
}
return `${base} bg-violet-200/60 text-violet-800 dark:bg-violet-800/40 dark:text-violet-300` return `${base} bg-violet-200/60 text-violet-800 dark:bg-violet-800/40 dark:text-violet-300`
}) })
@@ -118,6 +121,11 @@ const badgeClass = computed(() => {
? 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-400' ? 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-400'
: 'bg-sky-50 text-sky-700 dark:bg-sky-900/20 dark:text-sky-400' : 'bg-sky-50 text-sky-700 dark:bg-sky-900/20 dark:text-sky-400'
} }
if (props.platform === 'sora') {
return isSubscription.value
? 'bg-rose-100 text-rose-700 dark:bg-rose-900/30 dark:text-rose-400'
: 'bg-rose-50 text-rose-700 dark:bg-rose-900/20 dark:text-rose-400'
}
// Fallback: original colors // Fallback: original colors
return isSubscription.value return isSubscription.value
? 'bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-400' ? 'bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-400'

View File

@@ -19,6 +19,12 @@
<svg v-else-if="platform === 'antigravity'" :class="sizeClass" viewBox="0 0 24 24" fill="currentColor"> <svg v-else-if="platform === 'antigravity'" :class="sizeClass" viewBox="0 0 24 24" fill="currentColor">
<path d="M19.35 10.04C18.67 6.59 15.64 4 12 4 9.11 4 6.6 5.64 5.35 8.04 2.34 8.36 0 10.91 0 14c0 3.31 2.69 6 6 6h13c2.76 0 5-2.24 5-5 0-2.64-2.05-4.78-4.65-4.96z" /> <path d="M19.35 10.04C18.67 6.59 15.64 4 12 4 9.11 4 6.6 5.64 5.35 8.04 2.34 8.36 0 10.91 0 14c0 3.31 2.69 6 6 6h13c2.76 0 5-2.24 5-5 0-2.64-2.05-4.78-4.65-4.96z" />
</svg> </svg>
<!-- Sora logo (sparkle) -->
<svg v-else-if="platform === 'sora'" :class="sizeClass" viewBox="0 0 24 24" fill="currentColor">
<path
d="M12 2.5l2.1 4.7 5.1.5-3.9 3.4 1.2 5-4.5-2.6-4.5 2.6 1.2-5-3.9-3.4 5.1-.5L12 2.5z"
/>
</svg>
<!-- Fallback: generic platform icon --> <!-- Fallback: generic platform icon -->
<svg v-else :class="sizeClass" fill="currentColor" viewBox="0 0 24 24"> <svg v-else :class="sizeClass" fill="currentColor" viewBox="0 0 24 24">
<path <path

View File

@@ -48,6 +48,7 @@ const platformLabel = computed(() => {
if (props.platform === 'anthropic') return 'Anthropic' if (props.platform === 'anthropic') return 'Anthropic'
if (props.platform === 'openai') return 'OpenAI' if (props.platform === 'openai') return 'OpenAI'
if (props.platform === 'antigravity') return 'Antigravity' if (props.platform === 'antigravity') return 'Antigravity'
if (props.platform === 'sora') return 'Sora'
return 'Gemini' return 'Gemini'
}) })
@@ -74,6 +75,9 @@ const platformClass = computed(() => {
if (props.platform === 'antigravity') { if (props.platform === 'antigravity') {
return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400' return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
} }
if (props.platform === 'sora') {
return 'bg-rose-100 text-rose-700 dark:bg-rose-900/30 dark:text-rose-400'
}
return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-400' return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-400'
}) })
@@ -87,6 +91,9 @@ const typeClass = computed(() => {
if (props.platform === 'antigravity') { if (props.platform === 'antigravity') {
return 'bg-purple-100 text-purple-600 dark:bg-purple-900/30 dark:text-purple-400' return 'bg-purple-100 text-purple-600 dark:bg-purple-900/30 dark:text-purple-400'
} }
if (props.platform === 'sora') {
return 'bg-rose-100 text-rose-600 dark:bg-rose-900/30 dark:text-rose-400'
}
return 'bg-blue-100 text-blue-600 dark:bg-blue-900/30 dark:text-blue-400' return 'bg-blue-100 text-blue-600 dark:bg-blue-900/30 dark:text-blue-400'
}) })
</script> </script>

View File

@@ -52,6 +52,22 @@ const geminiModels = [
'gemini-3-pro-preview' 'gemini-3-pro-preview'
] ]
// Sora (sora2api)
const soraModels = [
'gpt-image', 'gpt-image-landscape', 'gpt-image-portrait',
'sora2-landscape-10s', 'sora2-portrait-10s',
'sora2-landscape-15s', 'sora2-portrait-15s',
'sora2-landscape-25s', 'sora2-portrait-25s',
'sora2pro-landscape-10s', 'sora2pro-portrait-10s',
'sora2pro-landscape-15s', 'sora2pro-portrait-15s',
'sora2pro-landscape-25s', 'sora2pro-portrait-25s',
'sora2pro-hd-landscape-10s', 'sora2pro-hd-portrait-10s',
'sora2pro-hd-landscape-15s', 'sora2pro-hd-portrait-15s',
'prompt-enhance-short-10s', 'prompt-enhance-short-15s', 'prompt-enhance-short-20s',
'prompt-enhance-medium-10s', 'prompt-enhance-medium-15s', 'prompt-enhance-medium-20s',
'prompt-enhance-long-10s', 'prompt-enhance-long-15s', 'prompt-enhance-long-20s'
]
// 智谱 GLM // 智谱 GLM
const zhipuModels = [ const zhipuModels = [
'glm-4', 'glm-4v', 'glm-4-plus', 'glm-4-0520', 'glm-4', 'glm-4v', 'glm-4-plus', 'glm-4-0520',
@@ -182,6 +198,7 @@ const allModelsList: string[] = [
...openaiModels, ...openaiModels,
...claudeModels, ...claudeModels,
...geminiModels, ...geminiModels,
...soraModels,
...zhipuModels, ...zhipuModels,
...qwenModels, ...qwenModels,
...deepseekModels, ...deepseekModels,
@@ -227,6 +244,8 @@ const openaiPresetMappings = [
{ label: 'GPT-5.1 Codex', from: 'gpt-5.1-codex', to: 'gpt-5.1-codex', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' } { label: 'GPT-5.1 Codex', from: 'gpt-5.1-codex', to: 'gpt-5.1-codex', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' }
] ]
const soraPresetMappings: { label: string; from: string; to: string; color: string }[] = []
const geminiPresetMappings = [ const geminiPresetMappings = [
{ label: 'Flash 2.0', from: 'gemini-2.0-flash', to: 'gemini-2.0-flash', color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400' }, { label: 'Flash 2.0', from: 'gemini-2.0-flash', to: 'gemini-2.0-flash', color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400' },
{ label: '2.5 Flash', from: 'gemini-2.5-flash', to: 'gemini-2.5-flash', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' }, { label: '2.5 Flash', from: 'gemini-2.5-flash', to: 'gemini-2.5-flash', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' },
@@ -258,6 +277,7 @@ export function getModelsByPlatform(platform: string): string[] {
case 'anthropic': case 'anthropic':
case 'claude': return claudeModels case 'claude': return claudeModels
case 'gemini': return geminiModels case 'gemini': return geminiModels
case 'sora': return soraModels
case 'zhipu': return zhipuModels case 'zhipu': return zhipuModels
case 'qwen': return qwenModels case 'qwen': return qwenModels
case 'deepseek': return deepseekModels case 'deepseek': return deepseekModels
@@ -281,6 +301,7 @@ export function getModelsByPlatform(platform: string): string[] {
export function getPresetMappingsByPlatform(platform: string) { export function getPresetMappingsByPlatform(platform: string) {
if (platform === 'openai') return openaiPresetMappings if (platform === 'openai') return openaiPresetMappings
if (platform === 'gemini') return geminiPresetMappings if (platform === 'gemini') return geminiPresetMappings
if (platform === 'sora') return soraPresetMappings
return anthropicPresetMappings return anthropicPresetMappings
} }

View File

@@ -895,7 +895,8 @@ export default {
anthropic: 'Anthropic', anthropic: 'Anthropic',
openai: 'OpenAI', openai: 'OpenAI',
gemini: 'Gemini', gemini: 'Gemini',
antigravity: 'Antigravity' antigravity: 'Antigravity',
sora: 'Sora'
}, },
deleteConfirm: deleteConfirm:
"Are you sure you want to delete '{name}'? All associated API keys will no longer belong to any group.", "Are you sure you want to delete '{name}'? All associated API keys will no longer belong to any group.",
@@ -920,6 +921,14 @@ export default {
title: 'Image Generation Pricing', title: 'Image Generation Pricing',
description: 'Configure pricing for gemini-3-pro-image model. Leave empty to use default prices.' description: 'Configure pricing for gemini-3-pro-image model. Leave empty to use default prices.'
}, },
soraPricing: {
title: 'Sora Per-Request Pricing',
description: 'Configure per-request pricing for Sora image/video generation. Leave empty to disable billing.',
image360: 'Image 360px ($)',
image540: 'Image 540px ($)',
video: 'Video (standard) ($)',
videoHd: 'Video (Pro-HD) ($)'
},
claudeCode: { claudeCode: {
title: 'Claude Code Client Restriction', title: 'Claude Code Client Restriction',
tooltip: 'When enabled, this group only allows official Claude Code clients. Non-Claude Code requests will be rejected or fallback to the specified group.', tooltip: 'When enabled, this group only allows official Claude Code clients. Non-Claude Code requests will be rejected or fallback to the specified group.',
@@ -1079,7 +1088,8 @@ export default {
claude: 'Claude', claude: 'Claude',
openai: 'OpenAI', openai: 'OpenAI',
gemini: 'Gemini', gemini: 'Gemini',
antigravity: 'Antigravity' antigravity: 'Antigravity',
sora: 'Sora'
}, },
types: { types: {
oauth: 'OAuth', oauth: 'OAuth',
@@ -1257,6 +1267,9 @@ export default {
'Map request models to actual models. Left is the requested model, right is the actual model sent to API.', 'Map request models to actual models. Left is the requested model, right is the actual model sent to API.',
selectedModels: 'Selected {count} model(s)', selectedModels: 'Selected {count} model(s)',
supportsAllModels: '(supports all models)', supportsAllModels: '(supports all models)',
soraModelsLoadFailed: 'Failed to load Sora models, fallback to default list',
soraModelsLoading: 'Loading Sora models...',
soraModelsRetry: 'Load failed, click to retry',
requestModel: 'Request model', requestModel: 'Request model',
actualModel: 'Actual model', actualModel: 'Actual model',
addMapping: 'Add Mapping', addMapping: 'Add Mapping',

View File

@@ -941,7 +941,8 @@ export default {
anthropic: 'Anthropic', anthropic: 'Anthropic',
openai: 'OpenAI', openai: 'OpenAI',
gemini: 'Gemini', gemini: 'Gemini',
antigravity: 'Antigravity' antigravity: 'Antigravity',
sora: 'Sora'
}, },
saving: '保存中...', saving: '保存中...',
noGroups: '暂无分组', noGroups: '暂无分组',
@@ -995,6 +996,14 @@ export default {
title: '图片生成计费', title: '图片生成计费',
description: '配置 gemini-3-pro-image 模型的图片生成价格,留空则使用默认价格' description: '配置 gemini-3-pro-image 模型的图片生成价格,留空则使用默认价格'
}, },
soraPricing: {
title: 'Sora 按次计费',
description: '配置 Sora 图片/视频按次收费价格,留空则默认不计费',
image360: '图片 360px ($)',
image540: '图片 540px ($)',
video: '视频(标准)($)',
videoHd: '视频Pro-HD($)'
},
claudeCode: { claudeCode: {
title: 'Claude Code 客户端限制', title: 'Claude Code 客户端限制',
tooltip: '启用后,此分组仅允许 Claude Code 官方客户端访问。非 Claude Code 请求将被拒绝或降级到指定分组。', tooltip: '启用后,此分组仅允许 Claude Code 官方客户端访问。非 Claude Code 请求将被拒绝或降级到指定分组。',
@@ -1199,7 +1208,8 @@ export default {
openai: 'OpenAI', openai: 'OpenAI',
anthropic: 'Anthropic', anthropic: 'Anthropic',
gemini: 'Gemini', gemini: 'Gemini',
antigravity: 'Antigravity' antigravity: 'Antigravity',
sora: 'Sora'
}, },
types: { types: {
oauth: 'OAuth', oauth: 'OAuth',
@@ -1391,6 +1401,9 @@ export default {
mapRequestModels: '将请求模型映射到实际模型。左边是请求的模型,右边是发送到 API 的实际模型。', mapRequestModels: '将请求模型映射到实际模型。左边是请求的模型,右边是发送到 API 的实际模型。',
selectedModels: '已选择 {count} 个模型', selectedModels: '已选择 {count} 个模型',
supportsAllModels: '(支持所有模型)', supportsAllModels: '(支持所有模型)',
soraModelsLoadFailed: '加载 Sora 模型列表失败,已回退到默认列表',
soraModelsLoading: '正在加载 Sora 模型...',
soraModelsRetry: '加载失败,点击重试',
requestModel: '请求模型', requestModel: '请求模型',
actualModel: '实际模型', actualModel: '实际模型',
addMapping: '添加映射', addMapping: '添加映射',

View File

@@ -252,7 +252,7 @@ export interface PaginationConfig {
// ==================== API Key & Group Types ==================== // ==================== API Key & Group Types ====================
export type GroupPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' export type GroupPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' | 'sora'
export type SubscriptionType = 'standard' | 'subscription' export type SubscriptionType = 'standard' | 'subscription'
@@ -272,6 +272,11 @@ export interface Group {
image_price_1k: number | null image_price_1k: number | null
image_price_2k: number | null image_price_2k: number | null
image_price_4k: number | null image_price_4k: number | null
// Sora 按次计费配置
sora_image_price_360: number | null
sora_image_price_540: number | null
sora_video_price_per_request: number | null
sora_video_price_per_request_hd: number | null
// Claude Code 客户端限制 // Claude Code 客户端限制
claude_code_only: boolean claude_code_only: boolean
fallback_group_id: number | null fallback_group_id: number | null
@@ -331,6 +336,10 @@ export interface CreateGroupRequest {
image_price_1k?: number | null image_price_1k?: number | null
image_price_2k?: number | null image_price_2k?: number | null
image_price_4k?: number | null image_price_4k?: number | null
sora_image_price_360?: number | null
sora_image_price_540?: number | null
sora_video_price_per_request?: number | null
sora_video_price_per_request_hd?: number | null
claude_code_only?: boolean claude_code_only?: boolean
fallback_group_id?: number | null fallback_group_id?: number | null
} }
@@ -349,13 +358,17 @@ export interface UpdateGroupRequest {
image_price_1k?: number | null image_price_1k?: number | null
image_price_2k?: number | null image_price_2k?: number | null
image_price_4k?: number | null image_price_4k?: number | null
sora_image_price_360?: number | null
sora_image_price_540?: number | null
sora_video_price_per_request?: number | null
sora_video_price_per_request_hd?: number | null
claude_code_only?: boolean claude_code_only?: boolean
fallback_group_id?: number | null fallback_group_id?: number | null
} }
// ==================== Account & Proxy Types ==================== // ==================== Account & Proxy Types ====================
export type AccountPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' export type AccountPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' | 'sora'
export type AccountType = 'oauth' | 'setup-token' | 'apikey' export type AccountType = 'oauth' | 'setup-token' | 'apikey'
export type OAuthAddMethod = 'oauth' | 'setup-token' export type OAuthAddMethod = 'oauth' | 'setup-token'
export type ProxyProtocol = 'http' | 'https' | 'socks5' | 'socks5h' export type ProxyProtocol = 'http' | 'https' | 'socks5' | 'socks5h'

View File

@@ -404,6 +404,64 @@
</div> </div>
</div> </div>
<!-- Sora 按次计费配置 -->
<div v-if="createForm.platform === 'sora'" class="border-t pt-4">
<label class="block mb-2 font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.groups.soraPricing.title') }}
</label>
<p class="text-xs text-gray-500 dark:text-gray-400 mb-3">
{{ t('admin.groups.soraPricing.description') }}
</p>
<div class="grid grid-cols-2 gap-3 mb-4">
<div>
<label class="input-label">{{ t('admin.groups.soraPricing.image360') }}</label>
<input
v-model.number="createForm.sora_image_price_360"
type="number"
step="0.001"
min="0"
class="input"
placeholder="0.05"
/>
</div>
<div>
<label class="input-label">{{ t('admin.groups.soraPricing.image540') }}</label>
<input
v-model.number="createForm.sora_image_price_540"
type="number"
step="0.001"
min="0"
class="input"
placeholder="0.08"
/>
</div>
</div>
<div class="grid grid-cols-2 gap-3">
<div>
<label class="input-label">{{ t('admin.groups.soraPricing.video') }}</label>
<input
v-model.number="createForm.sora_video_price_per_request"
type="number"
step="0.001"
min="0"
class="input"
placeholder="0.5"
/>
</div>
<div>
<label class="input-label">{{ t('admin.groups.soraPricing.videoHd') }}</label>
<input
v-model.number="createForm.sora_video_price_per_request_hd"
type="number"
step="0.001"
min="0"
class="input"
placeholder="0.8"
/>
</div>
</div>
</div>
<!-- Claude Code 客户端限制 anthropic 平台 --> <!-- Claude Code 客户端限制 anthropic 平台 -->
<div v-if="createForm.platform === 'anthropic'" class="border-t pt-4"> <div v-if="createForm.platform === 'anthropic'" class="border-t pt-4">
<div class="mb-1.5 flex items-center gap-1"> <div class="mb-1.5 flex items-center gap-1">
@@ -848,6 +906,64 @@
</div> </div>
</div> </div>
<!-- Sora 按次计费配置 -->
<div v-if="editForm.platform === 'sora'" class="border-t pt-4">
<label class="block mb-2 font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.groups.soraPricing.title') }}
</label>
<p class="text-xs text-gray-500 dark:text-gray-400 mb-3">
{{ t('admin.groups.soraPricing.description') }}
</p>
<div class="grid grid-cols-2 gap-3 mb-4">
<div>
<label class="input-label">{{ t('admin.groups.soraPricing.image360') }}</label>
<input
v-model.number="editForm.sora_image_price_360"
type="number"
step="0.001"
min="0"
class="input"
placeholder="0.05"
/>
</div>
<div>
<label class="input-label">{{ t('admin.groups.soraPricing.image540') }}</label>
<input
v-model.number="editForm.sora_image_price_540"
type="number"
step="0.001"
min="0"
class="input"
placeholder="0.08"
/>
</div>
</div>
<div class="grid grid-cols-2 gap-3">
<div>
<label class="input-label">{{ t('admin.groups.soraPricing.video') }}</label>
<input
v-model.number="editForm.sora_video_price_per_request"
type="number"
step="0.001"
min="0"
class="input"
placeholder="0.5"
/>
</div>
<div>
<label class="input-label">{{ t('admin.groups.soraPricing.videoHd') }}</label>
<input
v-model.number="editForm.sora_video_price_per_request_hd"
type="number"
step="0.001"
min="0"
class="input"
placeholder="0.8"
/>
</div>
</div>
</div>
<!-- Claude Code 客户端限制 anthropic 平台 --> <!-- Claude Code 客户端限制 anthropic 平台 -->
<div v-if="editForm.platform === 'anthropic'" class="border-t pt-4"> <div v-if="editForm.platform === 'anthropic'" class="border-t pt-4">
<div class="mb-1.5 flex items-center gap-1"> <div class="mb-1.5 flex items-center gap-1">
@@ -1152,7 +1268,8 @@ const platformOptions = computed(() => [
{ value: 'anthropic', label: 'Anthropic' }, { value: 'anthropic', label: 'Anthropic' },
{ value: 'openai', label: 'OpenAI' }, { value: 'openai', label: 'OpenAI' },
{ value: 'gemini', label: 'Gemini' }, { value: 'gemini', label: 'Gemini' },
{ value: 'antigravity', label: 'Antigravity' } { value: 'antigravity', label: 'Antigravity' },
{ value: 'sora', label: 'Sora' }
]) ])
const platformFilterOptions = computed(() => [ const platformFilterOptions = computed(() => [
@@ -1160,7 +1277,8 @@ const platformFilterOptions = computed(() => [
{ value: 'anthropic', label: 'Anthropic' }, { value: 'anthropic', label: 'Anthropic' },
{ value: 'openai', label: 'OpenAI' }, { value: 'openai', label: 'OpenAI' },
{ value: 'gemini', label: 'Gemini' }, { value: 'gemini', label: 'Gemini' },
{ value: 'antigravity', label: 'Antigravity' } { value: 'antigravity', label: 'Antigravity' },
{ value: 'sora', label: 'Sora' }
]) ])
const editStatusOptions = computed(() => [ const editStatusOptions = computed(() => [
@@ -1240,6 +1358,16 @@ const createForm = reactive({
image_price_1k: null as number | null, image_price_1k: null as number | null,
image_price_2k: null as number | null, image_price_2k: null as number | null,
image_price_4k: null as number | null, image_price_4k: null as number | null,
// Sora 按次计费配置
sora_image_price_360: null as number | null,
sora_image_price_540: null as number | null,
sora_video_price_per_request: null as number | null,
sora_video_price_per_request_hd: null as number | null,
// Sora 按次计费配置
sora_image_price_360: null as number | null,
sora_image_price_540: null as number | null,
sora_video_price_per_request: null as number | null,
sora_video_price_per_request_hd: null as number | null,
// Claude Code 客户端限制(仅 anthropic 平台使用) // Claude Code 客户端限制(仅 anthropic 平台使用)
claude_code_only: false, claude_code_only: false,
fallback_group_id: null as number | null, fallback_group_id: null as number | null,
@@ -1411,6 +1539,11 @@ const editForm = reactive({
image_price_1k: null as number | null, image_price_1k: null as number | null,
image_price_2k: null as number | null, image_price_2k: null as number | null,
image_price_4k: null as number | null, image_price_4k: null as number | null,
// Sora 按次计费配置
sora_image_price_360: null as number | null,
sora_image_price_540: null as number | null,
sora_video_price_per_request: null as number | null,
sora_video_price_per_request_hd: null as number | null,
// Claude Code 客户端限制(仅 anthropic 平台使用) // Claude Code 客户端限制(仅 anthropic 平台使用)
claude_code_only: false, claude_code_only: false,
fallback_group_id: null as number | null, fallback_group_id: null as number | null,
@@ -1495,6 +1628,10 @@ const closeCreateModal = () => {
createForm.image_price_1k = null createForm.image_price_1k = null
createForm.image_price_2k = null createForm.image_price_2k = null
createForm.image_price_4k = null createForm.image_price_4k = null
createForm.sora_image_price_360 = null
createForm.sora_image_price_540 = null
createForm.sora_video_price_per_request = null
createForm.sora_video_price_per_request_hd = null
createForm.claude_code_only = false createForm.claude_code_only = false
createForm.fallback_group_id = null createForm.fallback_group_id = null
createModelRoutingRules.value = [] createModelRoutingRules.value = []
@@ -1544,6 +1681,10 @@ const handleEdit = async (group: AdminGroup) => {
editForm.image_price_1k = group.image_price_1k editForm.image_price_1k = group.image_price_1k
editForm.image_price_2k = group.image_price_2k editForm.image_price_2k = group.image_price_2k
editForm.image_price_4k = group.image_price_4k editForm.image_price_4k = group.image_price_4k
editForm.sora_image_price_360 = group.sora_image_price_360
editForm.sora_image_price_540 = group.sora_image_price_540
editForm.sora_video_price_per_request = group.sora_video_price_per_request
editForm.sora_video_price_per_request_hd = group.sora_video_price_per_request_hd
editForm.claude_code_only = group.claude_code_only || false editForm.claude_code_only = group.claude_code_only || false
editForm.fallback_group_id = group.fallback_group_id editForm.fallback_group_id = group.fallback_group_id
editForm.model_routing_enabled = group.model_routing_enabled || false editForm.model_routing_enabled = group.model_routing_enabled || false