mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-05 05:30:44 +08:00
sync: bring over remaining release/custom-0.1.115 changes
- Extract PublicSettingsInjectionPayload named struct with drift test - Add channel_monitor_default_interval_seconds to SSR injection - Add image_output_price to SupportedModelChip - Simplify AppSidebar buildSelfNavItems (admins see available channels) - Add gateway WARN logs for 503 no-available-accounts branches - Wire ChannelMonitorRunner into provideCleanup for graceful shutdown - Add migrations 130/131 (CC template userid fix + mimicry field cleanup) - Clean up fork-only features (sora, claude max simulation, client affinity) - Remove ~320 obsolete i18n keys - Add codexUsage utility, WechatServiceButton, BulkEditAccountModal - Tidy go.sum
This commit is contained in:
@@ -97,6 +97,7 @@ func provideCleanup(
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
paymentOrderExpiry *service.PaymentOrderExpiryService,
|
||||
channelMonitorRunner *service.ChannelMonitorRunner,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -239,6 +240,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"ChannelMonitorRunner", func() error {
|
||||
if channelMonitorRunner != nil {
|
||||
channelMonitorRunner.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -222,7 +222,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService)
|
||||
channelMonitorUserHandler := handler.NewChannelMonitorUserHandler(channelMonitorService, settingService)
|
||||
channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
|
||||
_ = channelMonitorRunner
|
||||
registry := payment.ProvideRegistry()
|
||||
encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
|
||||
if err != nil {
|
||||
@@ -262,7 +261,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, channelMonitorRunner)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -316,6 +315,7 @@ func provideCleanup(
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
paymentOrderExpiry *service.PaymentOrderExpiryService,
|
||||
channelMonitorRunner *service.ChannelMonitorRunner,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
|
||||
@@ -76,6 +76,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
nil, // scheduledTestRunner
|
||||
nil, // backupSvc
|
||||
nil, // paymentOrderExpiry
|
||||
nil, // channelMonitorRunner
|
||||
)
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
|
||||
@@ -53,21 +53,21 @@ type Group struct {
|
||||
ImagePrice2k *float64 `json:"image_price_2k,omitempty"`
|
||||
// ImagePrice4k holds the value of the "image_price_4k" field.
|
||||
ImagePrice4k *float64 `json:"image_price_4k,omitempty"`
|
||||
// 是否仅允许 Claude Code 客户端
|
||||
// allow Claude Code client only
|
||||
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
||||
// 非 Claude Code 请求降级使用的分组 ID
|
||||
// fallback group for non-Claude-Code requests
|
||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||
// 无效请求兜底使用的分组 ID
|
||||
// fallback group for invalid request
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
|
||||
// 模型路由配置:模型模式 -> 优先账号ID列表
|
||||
// model routing config: pattern -> account ids
|
||||
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
||||
// 是否启用模型路由配置
|
||||
// whether model routing is enabled
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"`
|
||||
// 是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)
|
||||
// whether MCP XML prompt injection is enabled
|
||||
McpXMLInject bool `json:"mcp_xml_inject,omitempty"`
|
||||
// 支持的模型系列:claude, gemini_text, gemini_image
|
||||
// supported model scopes: claude, gemini_text, gemini_image
|
||||
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
||||
// 分组显示排序,数值越小越靠前
|
||||
// group display order, lower comes first
|
||||
SortOrder int `json:"sort_order,omitempty"`
|
||||
// 是否允许 /v1/messages 调度到此 OpenAI 分组
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch,omitempty"`
|
||||
|
||||
@@ -33,8 +33,6 @@ func (Group) Mixin() []ent.Mixin {
|
||||
|
||||
func (Group) Fields() []ent.Field {
|
||||
return []ent.Field{
|
||||
// 唯一约束通过部分索引实现(WHERE deleted_at IS NULL),支持软删除后重用
|
||||
// 见迁移文件 016_soft_delete_partial_unique_indexes.sql
|
||||
field.String("name").
|
||||
MaxLen(100).
|
||||
NotEmpty(),
|
||||
@@ -51,7 +49,6 @@ func (Group) Fields() []ent.Field {
|
||||
MaxLen(20).
|
||||
Default(domain.StatusActive),
|
||||
|
||||
// Subscription-related fields (added by migration 003)
|
||||
field.String("platform").
|
||||
MaxLen(50).
|
||||
Default(domain.PlatformAnthropic),
|
||||
@@ -73,7 +70,6 @@ func (Group) Fields() []ent.Field {
|
||||
field.Int("default_validity_days").
|
||||
Default(30),
|
||||
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用)
|
||||
field.Float("image_price_1k").
|
||||
Optional().
|
||||
Nillable().
|
||||
@@ -90,42 +86,36 @@ func (Group) Fields() []ent.Field {
|
||||
// Claude Code 客户端限制 (added by migration 029)
|
||||
field.Bool("claude_code_only").
|
||||
Default(false).
|
||||
Comment("是否仅允许 Claude Code 客户端"),
|
||||
Comment("allow Claude Code client only"),
|
||||
field.Int64("fallback_group_id").
|
||||
Optional().
|
||||
Nillable().
|
||||
Comment("非 Claude Code 请求降级使用的分组 ID"),
|
||||
Comment("fallback group for non-Claude-Code requests"),
|
||||
field.Int64("fallback_group_id_on_invalid_request").
|
||||
Optional().
|
||||
Nillable().
|
||||
Comment("无效请求兜底使用的分组 ID"),
|
||||
Comment("fallback group for invalid request"),
|
||||
|
||||
// 模型路由配置 (added by migration 040)
|
||||
field.JSON("model_routing", map[string][]int64{}).
|
||||
Optional().
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
||||
Comment("模型路由配置:模型模式 -> 优先账号ID列表"),
|
||||
|
||||
// 模型路由开关 (added by migration 041)
|
||||
Comment("model routing config: pattern -> account ids"),
|
||||
field.Bool("model_routing_enabled").
|
||||
Default(false).
|
||||
Comment("是否启用模型路由配置"),
|
||||
Comment("whether model routing is enabled"),
|
||||
|
||||
// MCP XML 协议注入开关 (added by migration 042)
|
||||
field.Bool("mcp_xml_inject").
|
||||
Default(true).
|
||||
Comment("是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)"),
|
||||
Comment("whether MCP XML prompt injection is enabled"),
|
||||
|
||||
// 支持的模型系列 (added by migration 046)
|
||||
field.JSON("supported_model_scopes", []string{}).
|
||||
Default([]string{"claude", "gemini_text", "gemini_image"}).
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
||||
Comment("支持的模型系列:claude, gemini_text, gemini_image"),
|
||||
Comment("supported model scopes: claude, gemini_text, gemini_image"),
|
||||
|
||||
// 分组排序 (added by migration 052)
|
||||
field.Int("sort_order").
|
||||
Default(0).
|
||||
Comment("分组显示排序,数值越小越靠前"),
|
||||
Comment("group display order, lower comes first"),
|
||||
|
||||
// OpenAI Messages 调度配置 (added by migration 069)
|
||||
field.Bool("allow_messages_dispatch").
|
||||
@@ -160,14 +150,11 @@ func (Group) Edges() []ent.Edge {
|
||||
edge.From("allowed_users", User.Type).
|
||||
Ref("allowed_groups").
|
||||
Through("user_allowed_groups", UserAllowedGroup.Type),
|
||||
// 注意:fallback_group_id 直接作为字段使用,不定义 edge
|
||||
// 这样允许多个分组指向同一个降级分组(M2O 关系)
|
||||
}
|
||||
}
|
||||
|
||||
func (Group) Indexes() []ent.Index {
|
||||
return []ent.Index{
|
||||
// name 字段已在 Fields() 中声明 Unique(),无需重复索引
|
||||
index.Fields("status"),
|
||||
index.Fields("platform"),
|
||||
index.Fields("subscription_type"),
|
||||
|
||||
@@ -162,8 +162,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
|
||||
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
||||
@@ -183,8 +181,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
||||
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
|
||||
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
|
||||
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
@@ -220,8 +216,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||
@@ -255,8 +249,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
@@ -286,8 +278,6 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv
|
||||
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
||||
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
@@ -320,8 +310,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
|
||||
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
||||
|
||||
@@ -111,7 +111,7 @@ func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, "mixed_channel_warning", resp["error"])
|
||||
require.Contains(t, resp["message"], "mixed_channel_warning")
|
||||
require.Contains(t, resp["message"], "claude-max")
|
||||
_, hasDetails := resp["details"]
|
||||
_, hasRequireConfirmation := resp["require_confirmation"]
|
||||
require.False(t, hasDetails)
|
||||
@@ -140,7 +140,7 @@ func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, "mixed_channel_warning", resp["error"])
|
||||
require.Contains(t, resp["message"], "mixed_channel_warning")
|
||||
require.Contains(t, resp["message"], "claude-max")
|
||||
_, hasDetails := resp["details"]
|
||||
_, hasRequireConfirmation := resp["require_confirmation"]
|
||||
require.False(t, hasDetails)
|
||||
|
||||
@@ -235,11 +235,9 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow,
|
||||
PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit,
|
||||
PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode,
|
||||
|
||||
ChannelMonitorEnabled: settings.ChannelMonitorEnabled,
|
||||
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
|
||||
|
||||
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
|
||||
ChannelMonitorEnabled: settings.ChannelMonitorEnabled,
|
||||
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
|
||||
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
|
||||
}
|
||||
response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
|
||||
}
|
||||
@@ -1479,11 +1477,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow,
|
||||
PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit,
|
||||
PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode,
|
||||
|
||||
ChannelMonitorEnabled: updatedSettings.ChannelMonitorEnabled,
|
||||
ChannelMonitorDefaultIntervalSeconds: updatedSettings.ChannelMonitorDefaultIntervalSeconds,
|
||||
|
||||
AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled,
|
||||
ChannelMonitorEnabled: updatedSettings.ChannelMonitorEnabled,
|
||||
ChannelMonitorDefaultIntervalSeconds: updatedSettings.ChannelMonitorDefaultIntervalSeconds,
|
||||
AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled,
|
||||
}
|
||||
response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// TestPublicSettingsInjectionPayload_SchemaDoesNotDrift guarantees the SSR
|
||||
// injection struct exposes every JSON field consumed by the frontend.
|
||||
//
|
||||
// Why this test exists: before we extracted a named PublicSettingsInjectionPayload
|
||||
// type, the inline struct was manually kept in sync with dto.PublicSettings and
|
||||
// drifted — ChannelMonitorEnabled / AvailableChannelsEnabled were missing, which
|
||||
// made the frontend read `undefined` on refresh and hide the "可用渠道" menu
|
||||
// until the async /api/v1/settings/public round-trip finished.
|
||||
//
|
||||
// This test compares the two JSON-tag sets and fails if injection is missing
|
||||
// any field that dto.PublicSettings exposes. Adding a new feature flag with
|
||||
// only a DTO entry will fail this test until the injection struct is updated.
|
||||
//
|
||||
// Intentional exclusions (fields present on dto.PublicSettings that SSR does
|
||||
// not need to inject) are listed in `dtoOnlyFields` below with a reason.
|
||||
func TestPublicSettingsInjectionPayload_SchemaDoesNotDrift(t *testing.T) {
|
||||
injection := jsonTags(reflect.TypeOf(service.PublicSettingsInjectionPayload{}))
|
||||
dtoKeys := jsonTags(reflect.TypeOf(PublicSettings{}))
|
||||
|
||||
// Fields that legitimately live only on the DTO. Keep tiny; document each.
|
||||
dtoOnlyFields := map[string]string{
|
||||
// sora_client_enabled is an upstream-only field the fork does not surface.
|
||||
"sora_client_enabled": "upstream-only field, not used on this fork",
|
||||
}
|
||||
|
||||
var missing []string
|
||||
for key := range dtoKeys {
|
||||
if _, ok := injection[key]; ok {
|
||||
continue
|
||||
}
|
||||
if _, allowed := dtoOnlyFields[key]; allowed {
|
||||
continue
|
||||
}
|
||||
missing = append(missing, key)
|
||||
}
|
||||
if len(missing) > 0 {
|
||||
t.Fatalf("service.PublicSettingsInjectionPayload is missing JSON fields present on dto.PublicSettings: %s\n"+
|
||||
"add the field to PublicSettingsInjectionPayload (and GetPublicSettingsForInjection), or "+
|
||||
"document the exclusion in dtoOnlyFields with a reason.", strings.Join(missing, ", "))
|
||||
}
|
||||
}
|
||||
|
||||
func jsonTags(t reflect.Type) map[string]struct{} {
|
||||
out := make(map[string]struct{})
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
f := t.Field(i)
|
||||
tag := f.Tag.Get("json")
|
||||
if tag == "" || tag == "-" {
|
||||
continue
|
||||
}
|
||||
name := strings.SplitN(tag, ",", 2)[0]
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
out[name] = struct{}{}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -301,6 +301,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
reqLog.Warn("gateway.select_account_no_available",
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64p("group_id", apiKey.GroupID),
|
||||
zap.String("platform", platform),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -344,6 +350,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
reqLog.Warn("gateway.select_account_no_slot_no_wait_plan",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("model", reqModel),
|
||||
zap.String("platform", platform),
|
||||
)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -525,6 +536,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID)
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
reqLog.Warn("gateway.select_account_no_available",
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64p("group_id", currentAPIKey.GroupID),
|
||||
zap.String("platform", platform),
|
||||
zap.Bool("fallback_used", fallbackUsed),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -568,6 +586,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
reqLog.Warn("gateway.select_account_no_slot_no_wait_plan",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("model", reqModel),
|
||||
zap.String("platform", platform),
|
||||
)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -15,9 +15,8 @@ import (
|
||||
|
||||
// Alipay product codes.
|
||||
const (
|
||||
alipayProductCodePreCreate = "FACE_TO_FACE_PAYMENT"
|
||||
alipayProductCodeWapPay = "QUICK_WAP_WAY"
|
||||
alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY"
|
||||
alipayProductCodeWapPay = "QUICK_WAP_WAY"
|
||||
alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY"
|
||||
)
|
||||
|
||||
// Alipay response constants.
|
||||
@@ -31,9 +30,6 @@ var (
|
||||
alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) {
|
||||
return client.TradeWapPay(param)
|
||||
}
|
||||
alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) {
|
||||
return client.TradePreCreate(ctx, param)
|
||||
}
|
||||
alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) {
|
||||
return client.TradePagePay(param)
|
||||
}
|
||||
@@ -103,13 +99,13 @@ func (a *Alipay) MerchantIdentityMetadata() map[string]string {
|
||||
return map[string]string{"app_id": appID}
|
||||
}
|
||||
|
||||
// CreatePayment creates an Alipay payment using the following routing:
|
||||
// - Mobile (H5): alipay.trade.wap.pay — browser redirect into Alipay.
|
||||
// - Desktop: prefer alipay.trade.precreate to get a scan payload directly.
|
||||
// - Desktop fallback: if precreate is unavailable for the merchant, fall back
|
||||
// to alipay.trade.page.pay and expose both pay_url and qr_code so the
|
||||
// frontend can render a QR while still allowing direct page open.
|
||||
func (a *Alipay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
|
||||
// CreatePayment creates an Alipay payment using redirect-only flow:
|
||||
// - Mobile (H5): alipay.trade.wap.pay — returns a URL the browser jumps to.
|
||||
// - PC: alipay.trade.page.pay — returns a gateway URL the browser opens in a
|
||||
// new window; Alipay's own page then shows login/QR. We intentionally do
|
||||
// NOT encode the URL into a QR on the client (it isn't a scannable payload
|
||||
// and would produce an invalid scan result).
|
||||
func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
|
||||
client, err := a.getClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -127,7 +123,7 @@ func (a *Alipay) CreatePayment(ctx context.Context, req payment.CreatePaymentReq
|
||||
if req.IsMobile {
|
||||
return a.createWapTrade(client, req, notifyURL, returnURL)
|
||||
}
|
||||
return a.createDesktopTrade(ctx, client, req, notifyURL, returnURL)
|
||||
return a.createPagePayTrade(client, req, notifyURL, returnURL)
|
||||
}
|
||||
|
||||
func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
|
||||
@@ -149,48 +145,6 @@ func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePayment
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *Alipay) createDesktopTrade(ctx context.Context, client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
|
||||
resp, precreateErr := a.createPrecreateTrade(ctx, client, req, notifyURL)
|
||||
if precreateErr == nil {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
resp, pagePayErr := a.createPagePayTrade(client, req, notifyURL, returnURL)
|
||||
if pagePayErr == nil {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("alipay desktop payment failed: precreate=%v; pagepay=%w", precreateErr, pagePayErr)
|
||||
}
|
||||
|
||||
func (a *Alipay) createPrecreateTrade(ctx context.Context, client *alipay.Client, req payment.CreatePaymentRequest, notifyURL string) (*payment.CreatePaymentResponse, error) {
|
||||
param := alipay.TradePreCreate{}
|
||||
param.OutTradeNo = req.OrderID
|
||||
param.TotalAmount = req.Amount
|
||||
param.Subject = req.Subject
|
||||
param.ProductCode = alipayProductCodePreCreate
|
||||
param.NotifyURL = notifyURL
|
||||
|
||||
rsp, err := alipayTradePreCreate(ctx, client, param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("alipay TradePreCreate: %w", err)
|
||||
}
|
||||
if rsp == nil {
|
||||
return nil, fmt.Errorf("alipay TradePreCreate: empty response")
|
||||
}
|
||||
if rsp.IsFailure() {
|
||||
return nil, fmt.Errorf("alipay TradePreCreate failed: %s", rsp.Error.Error())
|
||||
}
|
||||
if strings.TrimSpace(rsp.QRCode) == "" {
|
||||
return nil, fmt.Errorf("alipay TradePreCreate: empty qr_code")
|
||||
}
|
||||
|
||||
return &payment.CreatePaymentResponse{
|
||||
TradeNo: req.OrderID,
|
||||
QRCode: rsp.QRCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *Alipay) createPagePayTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
|
||||
param := alipay.TradePagePay{}
|
||||
param.OutTradeNo = req.OrderID
|
||||
@@ -207,7 +161,6 @@ func (a *Alipay) createPagePayTrade(client *alipay.Client, req payment.CreatePay
|
||||
return &payment.CreatePaymentResponse{
|
||||
TradeNo: req.OrderID,
|
||||
PayURL: payURL.String(),
|
||||
QRCode: payURL.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -239,15 +192,7 @@ func (a *Alipay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Query
|
||||
|
||||
amount, err := strconv.ParseFloat(result.TotalAmount, 64)
|
||||
if err != nil {
|
||||
amount, err = parseAlipayAmount(
|
||||
result.TotalAmount,
|
||||
result.ReceiptAmount,
|
||||
result.BuyerPayAmount,
|
||||
result.InvoiceAmount,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("alipay parse amount: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("alipay parse amount %q: %w", result.TotalAmount, err)
|
||||
}
|
||||
|
||||
return &payment.QueryOrderResponse{
|
||||
@@ -283,14 +228,7 @@ func (a *Alipay) VerifyNotification(ctx context.Context, rawBody string, _ map[s
|
||||
|
||||
amount, err := strconv.ParseFloat(notification.TotalAmount, 64)
|
||||
if err != nil {
|
||||
amount, err = parseAlipayAmount(
|
||||
notification.TotalAmount,
|
||||
notification.ReceiptAmount,
|
||||
notification.BuyerPayAmount,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("alipay parse notification amount: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("alipay parse notification amount %q: %w", notification.TotalAmount, err)
|
||||
}
|
||||
|
||||
metadata := a.MerchantIdentityMetadata()
|
||||
@@ -368,20 +306,6 @@ func isTradeNotExist(err error) bool {
|
||||
return strings.Contains(err.Error(), alipayErrTradeNotExist)
|
||||
}
|
||||
|
||||
func parseAlipayAmount(values ...string) (float64, error) {
|
||||
for _, raw := range values {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
continue
|
||||
}
|
||||
amount, err := strconv.ParseFloat(raw, 64)
|
||||
if err == nil {
|
||||
return amount, nil
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("no valid amount field")
|
||||
}
|
||||
|
||||
// Ensure interface compliance.
|
||||
var (
|
||||
_ payment.Provider = (*Alipay)(nil)
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -137,22 +136,15 @@ func TestNewAlipay(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCreateTradeUsesPagePayForDesktop(t *testing.T) {
|
||||
origPreCreate := alipayTradePreCreate
|
||||
origPagePay := alipayTradePagePay
|
||||
origWapPay := alipayTradeWapPay
|
||||
t.Cleanup(func() {
|
||||
alipayTradePreCreate = origPreCreate
|
||||
alipayTradePagePay = origPagePay
|
||||
alipayTradeWapPay = origWapPay
|
||||
})
|
||||
|
||||
preCreateCalls := 0
|
||||
pagePayCalls := 0
|
||||
wapPayCalls := 0
|
||||
alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) {
|
||||
preCreateCalls++
|
||||
return nil, errors.New("merchant does not have FACE_TO_FACE_PAYMENT")
|
||||
}
|
||||
alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) {
|
||||
pagePayCalls++
|
||||
if param.OutTradeNo != "sub2_100" {
|
||||
@@ -169,7 +161,7 @@ func TestCreateTradeUsesPagePayForDesktop(t *testing.T) {
|
||||
}
|
||||
|
||||
provider := &Alipay{}
|
||||
resp, err := provider.createDesktopTrade(context.Background(), &alipay.Client{}, payment.CreatePaymentRequest{
|
||||
resp, err := provider.createPagePayTrade(&alipay.Client{}, payment.CreatePaymentRequest{
|
||||
OrderID: "sub2_100",
|
||||
Amount: "88.00",
|
||||
Subject: "Balance recharge",
|
||||
@@ -177,9 +169,6 @@ func TestCreateTradeUsesPagePayForDesktop(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if preCreateCalls != 1 {
|
||||
t.Fatalf("precreate calls = %d, want 1", preCreateCalls)
|
||||
}
|
||||
if pagePayCalls != 1 {
|
||||
t.Fatalf("page pay calls = %d, want 1", pagePayCalls)
|
||||
}
|
||||
@@ -189,9 +178,6 @@ func TestCreateTradeUsesPagePayForDesktop(t *testing.T) {
|
||||
if resp.PayURL == "" {
|
||||
t.Fatal("expected pay_url for desktop page pay")
|
||||
}
|
||||
if resp.QRCode != resp.PayURL {
|
||||
t.Fatalf("qr_code = %q, want same as pay_url %q", resp.QRCode, resp.PayURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateTradeUsesWapPayForMobile(t *testing.T) {
|
||||
@@ -227,54 +213,6 @@ func TestCreateTradeUsesWapPayForMobile(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateTradeUsesPrecreateForDesktopWhenAvailable(t *testing.T) {
|
||||
origPreCreate := alipayTradePreCreate
|
||||
origPagePay := alipayTradePagePay
|
||||
t.Cleanup(func() {
|
||||
alipayTradePreCreate = origPreCreate
|
||||
alipayTradePagePay = origPagePay
|
||||
})
|
||||
|
||||
preCreateCalls := 0
|
||||
pagePayCalls := 0
|
||||
alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) {
|
||||
preCreateCalls++
|
||||
if param.ProductCode != alipayProductCodePreCreate {
|
||||
t.Fatalf("product_code = %q, want %q", param.ProductCode, alipayProductCodePreCreate)
|
||||
}
|
||||
return &alipay.TradePreCreateRsp{
|
||||
Error: alipay.Error{Code: alipay.CodeSuccess},
|
||||
QRCode: "https://qr.alipay.example.com/precreate-token",
|
||||
}, nil
|
||||
}
|
||||
alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) {
|
||||
pagePayCalls++
|
||||
return url.Parse("https://openapi.alipay.com/gateway.do?page-pay")
|
||||
}
|
||||
|
||||
provider := &Alipay{}
|
||||
resp, err := provider.createDesktopTrade(context.Background(), &alipay.Client{}, payment.CreatePaymentRequest{
|
||||
OrderID: "sub2_102",
|
||||
Amount: "66.00",
|
||||
Subject: "Balance recharge",
|
||||
}, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if preCreateCalls != 1 {
|
||||
t.Fatalf("precreate calls = %d, want 1", preCreateCalls)
|
||||
}
|
||||
if pagePayCalls != 0 {
|
||||
t.Fatalf("page pay calls = %d, want 0", pagePayCalls)
|
||||
}
|
||||
if resp.QRCode != "https://qr.alipay.example.com/precreate-token" {
|
||||
t.Fatalf("qr_code = %q", resp.QRCode)
|
||||
}
|
||||
if resp.PayURL != "" {
|
||||
t.Fatalf("pay_url = %q, want empty for precreate", resp.PayURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAlipayMerchantIdentityMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -289,19 +227,3 @@ func TestAlipayMerchantIdentityMetadata(t *testing.T) {
|
||||
t.Fatalf("app_id = %q, want %q", metadata["app_id"], "2021001234567890")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAlipayAmount(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
amount, err := parseAlipayAmount("", "88.00", "77.00")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if amount != 88 {
|
||||
t.Fatalf("amount = %v, want 88", amount)
|
||||
}
|
||||
|
||||
if _, err := parseAlipayAmount("", "not-a-number"); err == nil {
|
||||
t.Fatal("expected error when no valid amount field exists")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,4 +55,9 @@ const (
|
||||
|
||||
// ClaudeCodeVersion stores the extracted Claude Code version from User-Agent (e.g. "2.1.22")
|
||||
ClaudeCodeVersion Key = "ctx_claude_code_version"
|
||||
|
||||
// IsSignatureRectifyRetry marks a retry request that was produced by the signature rectifier
|
||||
// (strip or pool-replace). The harvester consults this flag to avoid ingesting signatures
|
||||
// from retries, which would pollute the pool with signatures we ourselves injected.
|
||||
IsSignatureRectifyRetry Key = "ctx_is_signature_rectify_retry"
|
||||
)
|
||||
|
||||
@@ -313,6 +313,31 @@ func (r *accountRepository) ListCRSAccountIDs(ctx context.Context) (map[string]i
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// CountByTLSFingerprintProfile 按 TLS 指纹模板 ID 聚合绑定账号数。
|
||||
// 走 108_add_tls_fingerprint_profile_id_index.sql 的表达式索引。
|
||||
func (r *accountRepository) CountByTLSFingerprintProfile(ctx context.Context) (map[int64]int, error) {
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT (extra->>'tls_fingerprint_profile_id')::bigint AS profile_id, COUNT(*)
|
||||
FROM accounts
|
||||
WHERE deleted_at IS NULL AND extra ? 'tls_fingerprint_profile_id'
|
||||
GROUP BY profile_id`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
counts := make(map[int64]int)
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
var n int
|
||||
if err := rows.Scan(&id, &n); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts[id] = n
|
||||
}
|
||||
return counts, rows.Err()
|
||||
}
|
||||
|
||||
func (r *accountRepository) Update(ctx context.Context, account *service.Account) error {
|
||||
if account == nil {
|
||||
return nil
|
||||
|
||||
@@ -9,7 +9,9 @@ import (
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const stickySessionPrefix = "sticky_session:"
|
||||
const (
|
||||
stickySessionPrefix = "sticky_session:"
|
||||
)
|
||||
|
||||
type gatewayCache struct {
|
||||
rdb *redis.Client
|
||||
@@ -41,12 +43,6 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses
|
||||
}
|
||||
|
||||
// DeleteSessionAccountID 删除粘性会话与账号的绑定关系。
|
||||
// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用,
|
||||
// 以便下次请求能够重新选择可用账号。
|
||||
//
|
||||
// DeleteSessionAccountID removes the sticky session binding for the given session.
|
||||
// Called when the bound account becomes unavailable (e.g., error status, disabled,
|
||||
// or unschedulable), allowing subsequent requests to select a new available account.
|
||||
func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
|
||||
@@ -3080,7 +3080,7 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start
|
||||
query := `
|
||||
SELECT
|
||||
COALESCE(ul.group_id, 0) as group_id,
|
||||
COALESCE(g.name, '') as group_name,
|
||||
COALESCE(g.name, '(无分组)') as group_name,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(ul.total_cost), 0) as cost,
|
||||
|
||||
@@ -54,8 +54,13 @@ func TestAPIContracts(t *testing.T) {
|
||||
"username": "alice",
|
||||
"role": "user",
|
||||
"balance": 12.5,
|
||||
"balance_notify_enabled": false,
|
||||
"balance_notify_extra_emails": null,
|
||||
"balance_notify_threshold": null,
|
||||
"balance_notify_threshold_type": "",
|
||||
"concurrency": 5,
|
||||
"status": "active",
|
||||
"total_recharged": 0,
|
||||
"allowed_groups": null,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"updated_at": "2025-01-02T03:04:05Z",
|
||||
@@ -764,10 +769,13 @@ func TestAPIContracts(t *testing.T) {
|
||||
"payment_cancel_rate_limit_unit": "",
|
||||
"payment_cancel_rate_limit_window_mode": "",
|
||||
"balance_low_notify_enabled": false,
|
||||
"account_quota_notify_enabled": false,
|
||||
"balance_low_notify_threshold": 0,
|
||||
"balance_low_notify_recharge_url": "",
|
||||
"account_quota_notify_enabled": false,
|
||||
"account_quota_notify_emails": [],
|
||||
"channel_monitor_enabled": true,
|
||||
"channel_monitor_default_interval_seconds": 60,
|
||||
"available_channels_enabled": false,
|
||||
"wechat_connect_enabled": false,
|
||||
"wechat_connect_app_id": "",
|
||||
"wechat_connect_app_secret_configured": false,
|
||||
@@ -975,7 +983,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
"auth_source_default_wechat_subscriptions": [],
|
||||
"auth_source_default_wechat_grant_on_signup": false,
|
||||
"auth_source_default_wechat_grant_on_first_bind": false,
|
||||
"force_email_on_third_party_signup": false
|
||||
"force_email_on_third_party_signup": false,
|
||||
"channel_monitor_enabled": true,
|
||||
"channel_monitor_default_interval_seconds": 60,
|
||||
"available_channels_enabled": false
|
||||
}
|
||||
}`,
|
||||
},
|
||||
@@ -1446,6 +1457,10 @@ func (s *stubAccountRepo) FindByExtraField(ctx context.Context, key string, valu
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) CountByTLSFingerprintProfile(ctx context.Context) (map[int64]int, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) Update(ctx context.Context, account *service.Account) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -30,6 +30,10 @@ type AccountRepository interface {
|
||||
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
|
||||
// FindByExtraField 根据 extra 字段中的键值对查找账号
|
||||
FindByExtraField(ctx context.Context, key string, value any) ([]Account, error)
|
||||
// CountByTLSFingerprintProfile 按 TLS 指纹模板 ID 聚合每个模板当前被多少账号绑定。
|
||||
// 返回 map[profile_id]count;未绑定任何账号的 profile 不出现在 map 中。
|
||||
// 查询走 108_add_tls_fingerprint_profile_id_index.sql 的表达式索引。
|
||||
CountByTLSFingerprintProfile(ctx context.Context) (map[int64]int, error)
|
||||
// ListCRSAccountIDs returns a map of crs_account_id -> local account ID
|
||||
// for all accounts that have been synced from CRS.
|
||||
ListCRSAccountIDs(ctx context.Context) (map[string]int64, error)
|
||||
|
||||
@@ -58,6 +58,10 @@ func (s *accountRepoStub) FindByExtraField(ctx context.Context, key string, valu
|
||||
panic("unexpected FindByExtraField call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) CountByTLSFingerprintProfile(ctx context.Context) (map[int64]int, error) {
|
||||
panic("unexpected CountByTLSFingerprintProfile call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
|
||||
panic("unexpected ListCRSAccountIDs call")
|
||||
}
|
||||
|
||||
@@ -43,6 +43,16 @@ func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID i
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
|
||||
if err, ok := s.listByGroupErr[groupID]; ok {
|
||||
return nil, err
|
||||
}
|
||||
if rows, ok := s.listByGroupData[groupID]; ok {
|
||||
return rows, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) {
|
||||
s.getByIDsCalled = true
|
||||
s.getByIDsIDs = append([]int64{}, ids...)
|
||||
@@ -63,16 +73,6 @@ func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Ac
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
|
||||
if err, ok := s.listByGroupErr[groupID]; ok {
|
||||
return nil, err
|
||||
}
|
||||
if rows, ok := s.listByGroupData[groupID]; ok {
|
||||
return rows, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
|
||||
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{}
|
||||
|
||||
@@ -170,11 +170,11 @@ func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, emai
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) {
|
||||
func (s *emailCacheStub) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) {
|
||||
func (s *emailCacheStub) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -87,6 +87,7 @@ func (c *stubConcurrencyCacheForTest) GetAccountsLoadBatch(_ context.Context, _
|
||||
func (c *stubConcurrencyCacheForTest) GetUsersLoadBatch(_ context.Context, _ []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
|
||||
return c.usersLoadBatch, c.usersLoadErr
|
||||
}
|
||||
|
||||
func (c *stubConcurrencyCacheForTest) CleanupExpiredAccountSlots(_ context.Context, _ int64) error {
|
||||
return c.cleanupErr
|
||||
}
|
||||
|
||||
@@ -220,7 +220,7 @@ func TestApplyErrorPassthroughRule_SkipMonitoringSetsContextKey(t *testing.T) {
|
||||
v, exists := c.Get(OpsSkipPassthroughKey)
|
||||
assert.True(t, exists, "OpsSkipPassthroughKey should be set when skip_monitoring=true")
|
||||
boolVal, ok := v.(bool)
|
||||
assert.True(t, ok, "value should be bool")
|
||||
assert.True(t, ok, "value should be a bool")
|
||||
assert.True(t, boolVal)
|
||||
}
|
||||
|
||||
|
||||
@@ -110,13 +110,12 @@ func TestCheckErrorPolicy(t *testing.T) {
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
// Antigravity 401 不走升级逻辑(由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制),
|
||||
// second hit 仍然返回 TempUnscheduled。
|
||||
name: "temp_unschedulable_401_second_hit_antigravity_stays_temp",
|
||||
// Gemini OAuth 401 second hit 会升级为 error(返回 None,交由默认错误逻辑处理)。
|
||||
name: "temp_unschedulable_401_second_hit_gemini_escalates",
|
||||
account: &Account{
|
||||
ID: 15,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Platform: PlatformGemini, // 非 Antigravity 平台 401 second hit 升级
|
||||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
@@ -131,7 +130,29 @@ func TestCheckErrorPolicy(t *testing.T) {
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
expected: ErrorPolicyNone, // Gemini 401 second hit 升级为 error
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_401_antigravity_no_escalation",
|
||||
account: &Account{
|
||||
ID: 16,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity, // Antigravity 跳过 401 升级,由 rules 正常处理
|
||||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicyTempUnscheduled, // Antigravity 不升级,继续走规则匹配
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_body_miss_returns_none",
|
||||
|
||||
@@ -143,7 +143,6 @@ func (s *stickyGatewayCacheHotpathStub) RefreshSessionTTL(ctx context.Context, g
|
||||
func (s *stickyGatewayCacheHotpathStub) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *modelsListAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
s.listByGroupCalls.Add(1)
|
||||
if s.err != nil {
|
||||
|
||||
@@ -82,6 +82,10 @@ func (m *mockAccountRepoForPlatform) FindByExtraField(ctx context.Context, key s
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForPlatform) CountByTLSFingerprintProfile(ctx context.Context) (map[int64]int, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForPlatform) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -71,6 +71,10 @@ func (m *mockAccountRepoForGemini) FindByExtraField(ctx context.Context, key str
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) CountByTLSFingerprintProfile(ctx context.Context) (map[int64]int, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -781,7 +781,7 @@ func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(account *Acco
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
|
||||
if req.RequestedModel != "" && !account.IsOpenAIPassthroughEnabled() && !account.IsModelSupported(req.RequestedModel) {
|
||||
return false
|
||||
}
|
||||
return account.SupportsOpenAIImageCapability(req.RequiredImageCapability)
|
||||
|
||||
@@ -187,13 +187,9 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
|
||||
}
|
||||
|
||||
func normalizeCodexModel(model string) string {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return "gpt-5.4"
|
||||
}
|
||||
if isOpenAIImageGenerationModel(model) {
|
||||
return model
|
||||
}
|
||||
|
||||
modelID := model
|
||||
if strings.Contains(modelID, "/") {
|
||||
@@ -235,78 +231,6 @@ func normalizeCodexModel(model string) string {
|
||||
return "gpt-5.4"
|
||||
}
|
||||
|
||||
func hasOpenAIImageGenerationTool(reqBody map[string]any) bool {
|
||||
rawTools, ok := reqBody["tools"]
|
||||
if !ok || rawTools == nil {
|
||||
return false
|
||||
}
|
||||
tools, ok := rawTools.([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, rawTool := range tools {
|
||||
toolMap, ok := rawTool.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(firstNonEmptyString(toolMap["type"])) == "image_generation" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func normalizeOpenAIResponsesImageGenerationTools(reqBody map[string]any) bool {
|
||||
rawTools, ok := reqBody["tools"]
|
||||
if !ok || rawTools == nil {
|
||||
return false
|
||||
}
|
||||
tools, ok := rawTools.([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
modified := false
|
||||
for _, rawTool := range tools {
|
||||
toolMap, ok := rawTool.(map[string]any)
|
||||
if !ok || strings.TrimSpace(firstNonEmptyString(toolMap["type"])) != "image_generation" {
|
||||
continue
|
||||
}
|
||||
if _, ok := toolMap["output_format"]; !ok {
|
||||
if value := strings.TrimSpace(firstNonEmptyString(toolMap["format"])); value != "" {
|
||||
toolMap["output_format"] = value
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if _, ok := toolMap["output_compression"]; !ok {
|
||||
if value, exists := toolMap["compression"]; exists && value != nil {
|
||||
toolMap["output_compression"] = value
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if _, ok := toolMap["format"]; ok {
|
||||
delete(toolMap, "format")
|
||||
modified = true
|
||||
}
|
||||
if _, ok := toolMap["compression"]; ok {
|
||||
delete(toolMap, "compression")
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
return modified
|
||||
}
|
||||
|
||||
func validateOpenAIResponsesImageModel(reqBody map[string]any, model string) error {
|
||||
if !hasOpenAIImageGenerationTool(reqBody) {
|
||||
return nil
|
||||
}
|
||||
model = strings.TrimSpace(model)
|
||||
if !isOpenAIImageGenerationModel(model) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("/v1/responses image_generation requests require a Responses-capable text model; image-only model %q is not allowed", model)
|
||||
}
|
||||
|
||||
func normalizeOpenAIModelForUpstream(account *Account, model string) string {
|
||||
if account == nil || account.Type == AccountTypeOAuth {
|
||||
return normalizeCodexModel(model)
|
||||
|
||||
@@ -217,42 +217,6 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction
|
||||
require.Equal(t, "bash", first["name"])
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIResponsesImageGenerationTools_RewritesLegacyFields(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"tools": []any{
|
||||
map[string]any{
|
||||
"type": "image_generation",
|
||||
"format": "png",
|
||||
"compression": 60,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
modified := normalizeOpenAIResponsesImageGenerationTools(reqBody)
|
||||
require.True(t, modified)
|
||||
|
||||
tools, ok := reqBody["tools"].([]any)
|
||||
require.True(t, ok)
|
||||
first, ok := tools[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "png", first["output_format"])
|
||||
require.Equal(t, 60, first["output_compression"])
|
||||
_, hasFormat := first["format"]
|
||||
require.False(t, hasFormat)
|
||||
_, hasCompression := first["compression"]
|
||||
require.False(t, hasCompression)
|
||||
}
|
||||
|
||||
func TestValidateOpenAIResponsesImageModel_RejectsImageOnlyModel(t *testing.T) {
|
||||
err := validateOpenAIResponsesImageModel(map[string]any{
|
||||
"tools": []any{
|
||||
map[string]any{"type": "image_generation"},
|
||||
},
|
||||
}, "gpt-image-2")
|
||||
|
||||
require.ErrorContains(t, err, `/v1/responses image_generation requests require a Responses-capable text model`)
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
|
||||
// 空 input 应保持为空且不触发异常。
|
||||
|
||||
|
||||
@@ -151,23 +151,38 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
}
|
||||
logger.L().Debug("openai chat_completions: model mapping applied", logFields...)
|
||||
|
||||
if account.Type == AccountTypeOAuth {
|
||||
{
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
|
||||
}
|
||||
codexResult := applyCodexOAuthTransform(reqBody, false, false)
|
||||
if codexResult.NormalizedModel != "" {
|
||||
upstreamModel = codexResult.NormalizedModel
|
||||
modified := false
|
||||
if account.Type == AccountTypeOAuth {
|
||||
codexResult := applyCodexOAuthTransform(reqBody, false, false)
|
||||
modified = codexResult.Modified
|
||||
if codexResult.NormalizedModel != "" {
|
||||
upstreamModel = codexResult.NormalizedModel
|
||||
}
|
||||
if codexResult.PromptCacheKey != "" {
|
||||
promptCacheKey = codexResult.PromptCacheKey
|
||||
} else if promptCacheKey != "" {
|
||||
reqBody["prompt_cache_key"] = promptCacheKey
|
||||
}
|
||||
} else {
|
||||
// 非 OAuth 账号也需要提取 system 消息并注入 instructions,
|
||||
// 否则上游 GPT-5/Codex 等模型会报 "Instructions are required"。
|
||||
if extractSystemMessagesFromInput(reqBody) {
|
||||
modified = true
|
||||
}
|
||||
if applyInstructions(reqBody, false) {
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if codexResult.PromptCacheKey != "" {
|
||||
promptCacheKey = codexResult.PromptCacheKey
|
||||
} else if promptCacheKey != "" {
|
||||
reqBody["prompt_cache_key"] = promptCacheKey
|
||||
}
|
||||
responsesBody, err = json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("remarshal after codex transform: %w", err)
|
||||
if modified {
|
||||
responsesBody, err = json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("remarshal after codex transform: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1503,7 +1503,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
if !acc.IsSchedulable() {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
if requestedModel != "" && !acc.IsOpenAIPassthroughEnabled() && !acc.IsModelSupported(requestedModel) {
|
||||
continue
|
||||
}
|
||||
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
|
||||
@@ -1665,7 +1665,7 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
|
||||
if !fresh.IsSchedulable() || !fresh.IsOpenAI() {
|
||||
return nil
|
||||
}
|
||||
if requestedModel != "" && !fresh.IsModelSupported(requestedModel) {
|
||||
if requestedModel != "" && !fresh.IsOpenAIPassthroughEnabled() && !fresh.IsModelSupported(requestedModel) {
|
||||
return nil
|
||||
}
|
||||
return fresh
|
||||
@@ -1935,12 +1935,6 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
markPatchSet("instructions", "You are a helpful coding assistant.")
|
||||
}
|
||||
|
||||
if normalizeOpenAIResponsesImageGenerationTools(reqBody) {
|
||||
bodyModified = true
|
||||
disablePatch()
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized /responses image_generation tool payload")
|
||||
}
|
||||
|
||||
// 对所有请求执行模型映射(包含 Codex CLI)。
|
||||
billingModel := account.GetMappedModel(reqModel)
|
||||
if billingModel != reqModel {
|
||||
@@ -1950,26 +1944,6 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
markPatchSet("model", billingModel)
|
||||
}
|
||||
upstreamModel := billingModel
|
||||
if err := validateOpenAIResponsesImageModel(reqBody, upstreamModel); err != nil {
|
||||
setOpsUpstreamError(c, http.StatusBadRequest, err.Error(), "")
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "invalid_request_error",
|
||||
"message": err.Error(),
|
||||
"param": "model",
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
if hasOpenAIImageGenerationTool(reqBody) {
|
||||
logger.LegacyPrintf(
|
||||
"service.openai_gateway",
|
||||
"[OpenAI] /responses image_generation request inbound_model=%s mapped_model=%s account_type=%s",
|
||||
reqModel,
|
||||
upstreamModel,
|
||||
account.Type,
|
||||
)
|
||||
}
|
||||
|
||||
// OpenAI OAuth 账号走 ChatGPT internal Codex endpoint,需要将模型名规范化为
|
||||
// 上游可识别的 Codex/GPT 系列。API Key 账号则应保留原始/映射后的模型名,
|
||||
|
||||
@@ -45,11 +45,8 @@ const (
|
||||
openAIChatGPTConversationPrepareURL = "https://chatgpt.com/backend-api/f/conversation/prepare"
|
||||
openAIChatGPTChatRequirementsURL = "https://chatgpt.com/backend-api/sentinel/chat-requirements"
|
||||
|
||||
openAIImageBackendUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
|
||||
openAIImageRequirementsDiff = "0fffff"
|
||||
openAIImageLifecycleTimeout = 2 * time.Minute
|
||||
openAIImageMaxDownloadBytes = 20 << 20 // 20MB per image download
|
||||
openAIImageMaxUploadPartSize = 20 << 20 // 20MB per multipart upload part
|
||||
openAIImageBackendUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
|
||||
openAIImageRequirementsDiff = "0fffff"
|
||||
)
|
||||
|
||||
type OpenAIImagesCapability string
|
||||
@@ -151,9 +148,6 @@ func (s *OpenAIGatewayService) ParseOpenAIImagesRequest(c *gin.Context, body []b
|
||||
}
|
||||
|
||||
applyOpenAIImagesDefaults(req)
|
||||
if err := validateOpenAIImagesModel(req.Model); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.SizeTier = normalizeOpenAIImageSizeTier(req.Size)
|
||||
req.RequiredCapability = classifyOpenAIImagesCapability(req)
|
||||
return req, nil
|
||||
@@ -220,7 +214,7 @@ func parseOpenAIImagesMultipartRequest(body []byte, contentType string, req *Ope
|
||||
continue
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(io.LimitReader(part, openAIImageMaxUploadPartSize))
|
||||
data, err := io.ReadAll(part)
|
||||
_ = part.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("read multipart field %s: %w", name, err)
|
||||
@@ -301,21 +295,6 @@ func applyOpenAIImagesDefaults(req *OpenAIImagesRequest) {
|
||||
req.Model = "gpt-image-2"
|
||||
}
|
||||
|
||||
func isOpenAIImageGenerationModel(model string) bool {
|
||||
return strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "gpt-image-")
|
||||
}
|
||||
|
||||
func validateOpenAIImagesModel(model string) error {
|
||||
model = strings.TrimSpace(model)
|
||||
if isOpenAIImageGenerationModel(model) {
|
||||
return nil
|
||||
}
|
||||
if model == "" {
|
||||
return fmt.Errorf("images endpoint requires an image model")
|
||||
}
|
||||
return fmt.Errorf("images endpoint requires an image model, got %q", model)
|
||||
}
|
||||
|
||||
func normalizeOpenAIImagesEndpointPath(path string) string {
|
||||
trimmed := strings.TrimSpace(path)
|
||||
switch {
|
||||
@@ -421,21 +400,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
|
||||
if mapped := strings.TrimSpace(channelMappedModel); mapped != "" {
|
||||
requestModel = mapped
|
||||
}
|
||||
if err := validateOpenAIImagesModel(requestModel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
upstreamModel := account.GetMappedModel(requestModel)
|
||||
if err := validateOpenAIImagesModel(upstreamModel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logger.LegacyPrintf(
|
||||
"service.openai_gateway",
|
||||
"[OpenAI] Images request routing request_model=%s upstream_model=%s endpoint=%s account_type=%s",
|
||||
strings.TrimSpace(parsed.Model),
|
||||
upstreamModel,
|
||||
parsed.Endpoint,
|
||||
account.Type,
|
||||
)
|
||||
forwardBody, forwardContentType, err := rewriteOpenAIImagesModel(body, parsed.ContentType, upstreamModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -794,17 +759,6 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
||||
if mapped := strings.TrimSpace(channelMappedModel); mapped != "" {
|
||||
requestModel = mapped
|
||||
}
|
||||
if err := validateOpenAIImagesModel(requestModel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logger.LegacyPrintf(
|
||||
"service.openai_gateway",
|
||||
"[OpenAI] Images request routing request_model=%s endpoint=%s account_type=%s uploads=%d",
|
||||
requestModel,
|
||||
parsed.Endpoint,
|
||||
account.Type,
|
||||
len(parsed.Uploads),
|
||||
)
|
||||
|
||||
token, _, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
@@ -890,18 +844,8 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
||||
return nil, err
|
||||
}
|
||||
pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, nil)
|
||||
logger.LegacyPrintf(
|
||||
"service.openai_gateway",
|
||||
"[OpenAI] Image extraction stream conversation_id=%s total_assets=%d file_service_assets=%d direct_assets=%d",
|
||||
conversationID,
|
||||
len(pointerInfos),
|
||||
countOpenAIFileServicePointerInfos(pointerInfos),
|
||||
countOpenAIDirectImageAssets(pointerInfos),
|
||||
)
|
||||
lifecycleCtx, releaseLifecycleCtx := detachOpenAIImageLifecycleContext(ctx, openAIImageLifecycleTimeout)
|
||||
defer releaseLifecycleCtx()
|
||||
if conversationID != "" && !hasOpenAIFileServicePointerInfos(pointerInfos) {
|
||||
polledPointers, pollErr := pollOpenAIImageConversation(lifecycleCtx, client, headers, conversationID)
|
||||
polledPointers, pollErr := pollOpenAIImageConversation(ctx, client, headers, conversationID)
|
||||
if pollErr != nil {
|
||||
return nil, s.wrapOpenAIImageBackendError(ctx, c, account, pollErr)
|
||||
}
|
||||
@@ -909,11 +853,10 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
||||
}
|
||||
pointerInfos = preferOpenAIFileServicePointerInfos(pointerInfos)
|
||||
if len(pointerInfos) == 0 {
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Image extraction yielded no assets conversation_id=%s", conversationID)
|
||||
return nil, fmt.Errorf("openai image conversation returned no downloadable images")
|
||||
}
|
||||
|
||||
responseBody, imageCount, err := buildOpenAIImageResponse(lifecycleCtx, client, headers, conversationID, pointerInfos)
|
||||
responseBody, imageCount, err := buildOpenAIImageResponse(ctx, client, headers, conversationID, pointerInfos)
|
||||
if err != nil {
|
||||
return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err)
|
||||
}
|
||||
@@ -1340,11 +1283,8 @@ func buildOpenAIImageConversationRequest(parsed *OpenAIImagesRequest, parentMess
|
||||
}
|
||||
|
||||
type openAIImagePointerInfo struct {
|
||||
Pointer string
|
||||
DownloadURL string
|
||||
B64JSON string
|
||||
MimeType string
|
||||
Prompt string
|
||||
Pointer string
|
||||
Prompt string
|
||||
}
|
||||
|
||||
type openAIImageToolMessage struct {
|
||||
@@ -1396,6 +1336,10 @@ func collectOpenAIImagePointers(body []byte) []openAIImagePointerInfo {
|
||||
if len(body) == 0 {
|
||||
return nil
|
||||
}
|
||||
matches := openAIImagePointerMatches(body)
|
||||
if len(matches) == 0 {
|
||||
return nil
|
||||
}
|
||||
prompt := ""
|
||||
for _, path := range []string{
|
||||
"message.metadata.dalle.prompt",
|
||||
@@ -1407,12 +1351,11 @@ func collectOpenAIImagePointers(body []byte) []openAIImagePointerInfo {
|
||||
break
|
||||
}
|
||||
}
|
||||
matches := openAIImagePointerMatches(body)
|
||||
out := make([]openAIImagePointerInfo, 0, len(matches))
|
||||
for _, pointer := range matches {
|
||||
out = append(out, openAIImagePointerInfo{Pointer: pointer, Prompt: prompt})
|
||||
}
|
||||
return mergeOpenAIImagePointerInfos(out, collectOpenAIImageInlineAssets(body, prompt))
|
||||
return out
|
||||
}
|
||||
|
||||
func openAIImagePointerMatches(body []byte) []string {
|
||||
@@ -1451,72 +1394,27 @@ func mergeOpenAIImagePointerInfos(existing []openAIImagePointerInfo, next []open
|
||||
seen := make(map[string]openAIImagePointerInfo, len(existing)+len(next))
|
||||
out := make([]openAIImagePointerInfo, 0, len(existing)+len(next))
|
||||
for _, item := range existing {
|
||||
if key := item.identityKey(); key != "" {
|
||||
seen[key] = item
|
||||
}
|
||||
seen[item.Pointer] = item
|
||||
out = append(out, item)
|
||||
}
|
||||
for _, item := range next {
|
||||
key := item.identityKey()
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if existingItem, ok := seen[key]; ok {
|
||||
merged := mergeOpenAIImagePointerInfo(existingItem, item)
|
||||
if merged != existingItem {
|
||||
if existingItem, ok := seen[item.Pointer]; ok {
|
||||
if existingItem.Prompt == "" && item.Prompt != "" {
|
||||
for i := range out {
|
||||
if out[i].identityKey() == key {
|
||||
out[i] = merged
|
||||
if out[i].Pointer == item.Pointer {
|
||||
out[i].Prompt = item.Prompt
|
||||
break
|
||||
}
|
||||
}
|
||||
seen[key] = merged
|
||||
}
|
||||
continue
|
||||
}
|
||||
seen[key] = item
|
||||
seen[item.Pointer] = item
|
||||
out = append(out, item)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (i openAIImagePointerInfo) identityKey() string {
|
||||
switch {
|
||||
case strings.TrimSpace(i.Pointer) != "":
|
||||
return "pointer:" + strings.TrimSpace(i.Pointer)
|
||||
case strings.TrimSpace(i.DownloadURL) != "":
|
||||
return "download:" + strings.TrimSpace(i.DownloadURL)
|
||||
case strings.TrimSpace(i.B64JSON) != "":
|
||||
b64 := strings.TrimSpace(i.B64JSON)
|
||||
if len(b64) > 64 {
|
||||
b64 = b64[:64]
|
||||
}
|
||||
return "b64:" + b64
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func mergeOpenAIImagePointerInfo(existing, next openAIImagePointerInfo) openAIImagePointerInfo {
|
||||
merged := existing
|
||||
if strings.TrimSpace(merged.Pointer) == "" {
|
||||
merged.Pointer = next.Pointer
|
||||
}
|
||||
if strings.TrimSpace(merged.DownloadURL) == "" {
|
||||
merged.DownloadURL = next.DownloadURL
|
||||
}
|
||||
if strings.TrimSpace(merged.B64JSON) == "" {
|
||||
merged.B64JSON = next.B64JSON
|
||||
}
|
||||
if strings.TrimSpace(merged.MimeType) == "" {
|
||||
merged.MimeType = next.MimeType
|
||||
}
|
||||
if strings.TrimSpace(merged.Prompt) == "" {
|
||||
merged.Prompt = next.Prompt
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
func hasOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) bool {
|
||||
for _, item := range items {
|
||||
if strings.HasPrefix(item.Pointer, "file-service://") {
|
||||
@@ -1526,26 +1424,6 @@ func hasOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func countOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) int {
|
||||
count := 0
|
||||
for _, item := range items {
|
||||
if strings.HasPrefix(item.Pointer, "file-service://") {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func countOpenAIDirectImageAssets(items []openAIImagePointerInfo) int {
|
||||
count := 0
|
||||
for _, item := range items {
|
||||
if strings.TrimSpace(item.DownloadURL) != "" || strings.TrimSpace(item.B64JSON) != "" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func preferOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) []openAIImagePointerInfo {
|
||||
if !hasOpenAIFileServicePointerInfos(items) {
|
||||
return items
|
||||
@@ -1713,7 +1591,11 @@ func buildOpenAIImageResponse(
|
||||
}
|
||||
items := make([]responseItem, 0, len(pointers))
|
||||
for _, pointer := range pointers {
|
||||
data, err := resolveOpenAIImageBytes(ctx, client, headers, conversationID, pointer)
|
||||
downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
data, err := downloadOpenAIImageBytes(ctx, client, headers, downloadURL)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@@ -1733,136 +1615,6 @@ func buildOpenAIImageResponse(
|
||||
return body, len(items), nil
|
||||
}
|
||||
|
||||
func resolveOpenAIImageBytes(
|
||||
ctx context.Context,
|
||||
client *req.Client,
|
||||
headers http.Header,
|
||||
conversationID string,
|
||||
pointer openAIImagePointerInfo,
|
||||
) ([]byte, error) {
|
||||
if normalized := normalizeOpenAIImageBase64(pointer.B64JSON); normalized != "" {
|
||||
return base64.StdEncoding.DecodeString(normalized)
|
||||
}
|
||||
if downloadURL := strings.TrimSpace(pointer.DownloadURL); downloadURL != "" {
|
||||
return downloadOpenAIImageBytes(ctx, client, headers, downloadURL)
|
||||
}
|
||||
if strings.TrimSpace(pointer.Pointer) == "" {
|
||||
return nil, fmt.Errorf("image asset is missing pointer, url, and base64 data")
|
||||
}
|
||||
downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return downloadOpenAIImageBytes(ctx, client, headers, downloadURL)
|
||||
}
|
||||
|
||||
func normalizeOpenAIImageBase64(raw string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(raw), "data:") {
|
||||
if idx := strings.Index(raw, ","); idx >= 0 && idx+1 < len(raw) {
|
||||
raw = raw[idx+1:]
|
||||
}
|
||||
}
|
||||
raw = strings.TrimSpace(raw)
|
||||
raw = strings.TrimRight(raw, "=") + strings.Repeat("=", (4-len(raw)%4)%4)
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
if _, err := base64.StdEncoding.DecodeString(raw); err != nil {
|
||||
return ""
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
func collectOpenAIImageInlineAssets(body []byte, fallbackPrompt string) []openAIImagePointerInfo {
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
return nil
|
||||
}
|
||||
var decoded any
|
||||
if err := json.Unmarshal(body, &decoded); err != nil {
|
||||
return nil
|
||||
}
|
||||
var out []openAIImagePointerInfo
|
||||
walkOpenAIImageInlineAssets(decoded, strings.TrimSpace(fallbackPrompt), &out)
|
||||
return out
|
||||
}
|
||||
|
||||
func walkOpenAIImageInlineAssets(node any, prompt string, out *[]openAIImagePointerInfo) {
|
||||
switch value := node.(type) {
|
||||
case map[string]any:
|
||||
localPrompt := prompt
|
||||
for _, key := range []string{"revised_prompt", "image_gen_title", "prompt"} {
|
||||
if v, ok := value[key].(string); ok && strings.TrimSpace(v) != "" {
|
||||
localPrompt = strings.TrimSpace(v)
|
||||
break
|
||||
}
|
||||
}
|
||||
item := openAIImagePointerInfo{
|
||||
Prompt: localPrompt,
|
||||
Pointer: firstNonEmptyString(value["asset_pointer"], value["pointer"]),
|
||||
DownloadURL: firstNonEmptyString(value["download_url"], value["url"], value["image_url"]),
|
||||
B64JSON: firstNonEmptyString(value["b64_json"], value["base64"], value["image_base64"]),
|
||||
MimeType: firstNonEmptyString(value["mime_type"], value["mimeType"], value["content_type"]),
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(strings.TrimSpace(item.Pointer), "file-service://"),
|
||||
strings.HasPrefix(strings.TrimSpace(item.Pointer), "sediment://"),
|
||||
isLikelyOpenAIImageDownloadURL(item.DownloadURL),
|
||||
normalizeOpenAIImageBase64(item.B64JSON) != "":
|
||||
*out = append(*out, item)
|
||||
}
|
||||
for _, child := range value {
|
||||
walkOpenAIImageInlineAssets(child, localPrompt, out)
|
||||
}
|
||||
case []any:
|
||||
for _, child := range value {
|
||||
walkOpenAIImageInlineAssets(child, prompt, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func firstNonEmptyString(values ...any) string {
|
||||
for _, value := range values {
|
||||
if s, ok := value.(string); ok && strings.TrimSpace(s) != "" {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func isLikelyOpenAIImageDownloadURL(raw string) bool {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return false
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(raw), "data:image/") {
|
||||
return true
|
||||
}
|
||||
if !strings.HasPrefix(strings.ToLower(raw), "http://") && !strings.HasPrefix(strings.ToLower(raw), "https://") {
|
||||
return false
|
||||
}
|
||||
lower := strings.ToLower(raw)
|
||||
return strings.Contains(lower, "/download") ||
|
||||
strings.Contains(lower, ".png") ||
|
||||
strings.Contains(lower, ".jpg") ||
|
||||
strings.Contains(lower, ".jpeg") ||
|
||||
strings.Contains(lower, ".webp")
|
||||
}
|
||||
|
||||
func detachOpenAIImageLifecycleContext(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
|
||||
base := context.Background()
|
||||
if ctx != nil {
|
||||
base = context.WithoutCancel(ctx)
|
||||
}
|
||||
if timeout <= 0 {
|
||||
return base, func() {}
|
||||
}
|
||||
return context.WithTimeout(base, timeout)
|
||||
}
|
||||
|
||||
func fetchOpenAIImageDownloadURL(
|
||||
ctx context.Context,
|
||||
client *req.Client,
|
||||
@@ -1954,7 +1706,7 @@ func downloadOpenAIImageBytes(ctx context.Context, client *req.Client, headers h
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, newOpenAIImageStatusError(resp, "download image bytes failed")
|
||||
}
|
||||
return io.ReadAll(io.LimitReader(resp.Body, openAIImageMaxDownloadBytes))
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
func handleOpenAIImageBackendError(resp *req.Response) error {
|
||||
|
||||
@@ -2,7 +2,6 @@ package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -104,56 +103,3 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_ExplicitSizeRequiresNative
|
||||
require.NotNil(t, parsed)
|
||||
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_RejectsNonImageModel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-5.4","prompt":"draw a cat"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
svc := &OpenAIGatewayService{}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.Nil(t, parsed)
|
||||
require.ErrorContains(t, err, `images endpoint requires an image model, got "gpt-5.4"`)
|
||||
}
|
||||
|
||||
func TestCollectOpenAIImagePointers_RecognizesDirectAssets(t *testing.T) {
|
||||
items := collectOpenAIImagePointers([]byte(`{
|
||||
"revised_prompt": "cat astronaut",
|
||||
"parts": [
|
||||
{"b64_json":"QUJD"},
|
||||
{"download_url":"https://files.example.com/image.png?sig=1"},
|
||||
{"asset_pointer":"file-service://file_123"}
|
||||
]
|
||||
}`))
|
||||
|
||||
require.Len(t, items, 3)
|
||||
var sawBase64, sawURL, sawPointer bool
|
||||
for _, item := range items {
|
||||
if item.B64JSON == "QUJD" {
|
||||
sawBase64 = true
|
||||
require.Equal(t, "cat astronaut", item.Prompt)
|
||||
}
|
||||
if item.DownloadURL == "https://files.example.com/image.png?sig=1" {
|
||||
sawURL = true
|
||||
}
|
||||
if item.Pointer == "file-service://file_123" {
|
||||
sawPointer = true
|
||||
}
|
||||
}
|
||||
require.True(t, sawBase64)
|
||||
require.True(t, sawURL)
|
||||
require.True(t, sawPointer)
|
||||
}
|
||||
|
||||
func TestResolveOpenAIImageBytes_PrefersInlineBase64(t *testing.T) {
|
||||
data, err := resolveOpenAIImageBytes(context.Background(), nil, nil, "", openAIImagePointerInfo{
|
||||
B64JSON: "data:image/png;base64,QUJD",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("ABC"), data)
|
||||
}
|
||||
|
||||
@@ -91,7 +91,6 @@ func TestNormalizeCodexModel(t *testing.T) {
|
||||
"gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
|
||||
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
|
||||
"gpt-5.3": "gpt-5.3-codex",
|
||||
"gpt-image-2": "gpt-image-2",
|
||||
}
|
||||
|
||||
for input, expected := range cases {
|
||||
|
||||
@@ -812,16 +812,6 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
|
||||
return openAIGPT54FallbackPricing
|
||||
}
|
||||
|
||||
if isOpenAIImageGenerationModel(model) {
|
||||
for _, candidate := range []string{"gpt-image-2", "gpt-image-1.5", "gpt-image-1"} {
|
||||
if pricing, ok := s.pricingData[candidate]; ok {
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] OpenAI image fallback matched %s -> %s", model, candidate)
|
||||
return pricing
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 最终回退到 DefaultTestModel
|
||||
defaultModel := strings.ToLower(openai.DefaultTestModel)
|
||||
if pricing, ok := s.pricingData[defaultModel]; ok {
|
||||
|
||||
@@ -128,21 +128,6 @@ func TestGetModelPricing_Gpt54NanoUsesDedicatedStaticFallbackWhenRemoteMissing(t
|
||||
require.Zero(t, got.LongContextInputTokenThreshold)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_ImageModelDoesNotFallbackToTextModel(t *testing.T) {
|
||||
imagePricing := &LiteLLMModelPricing{InputCostPerToken: 3}
|
||||
textPricing := &LiteLLMModelPricing{InputCostPerToken: 9}
|
||||
|
||||
svc := &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"gpt-image-2": imagePricing,
|
||||
"gpt-5.4": textPricing,
|
||||
},
|
||||
}
|
||||
|
||||
got := svc.GetModelPricing("gpt-image-3")
|
||||
require.Same(t, imagePricing, got)
|
||||
}
|
||||
|
||||
func TestParsePricingData_PreservesPriorityAndServiceTierFields(t *testing.T) {
|
||||
raw := map[string]any{
|
||||
"gpt-5.4": map[string]any{
|
||||
|
||||
@@ -73,6 +73,9 @@ func (m *sessionWindowMockRepo) GetByCRSAccountID(context.Context, string) (*Acc
|
||||
func (m *sessionWindowMockRepo) FindByExtraField(context.Context, string, any) ([]Account, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (m *sessionWindowMockRepo) CountByTLSFingerprintProfile(context.Context) (map[int64]int, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (m *sessionWindowMockRepo) ListCRSAccountIDs(context.Context) (map[string]int64, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
|
||||
@@ -546,8 +546,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
// channelMonitorIntervalMin / channelMonitorIntervalMax bound the default interval
|
||||
// (mirrors the monitor-level constraint but lives here so setting_service stays decoupled).
|
||||
const (
|
||||
channelMonitorIntervalMin = 15
|
||||
channelMonitorIntervalMax = 3600
|
||||
channelMonitorIntervalMin = 15
|
||||
channelMonitorIntervalMax = 3600
|
||||
channelMonitorIntervalFallback = 60
|
||||
)
|
||||
|
||||
@@ -578,8 +578,8 @@ func clampChannelMonitorInterval(v int) int {
|
||||
// ChannelMonitorRuntime is the lightweight view of the channel monitor feature
|
||||
// consumed by the runner and user-facing handlers.
|
||||
type ChannelMonitorRuntime struct {
|
||||
Enabled bool
|
||||
DefaultIntervalSeconds int
|
||||
Enabled bool
|
||||
DefaultIntervalSeconds int
|
||||
}
|
||||
|
||||
// GetChannelMonitorRuntime reads the channel monitor feature flags directly from
|
||||
@@ -628,56 +628,76 @@ func (s *SettingService) SetVersion(version string) {
|
||||
s.version = version
|
||||
}
|
||||
|
||||
// GetPublicSettingsForInjection returns public settings in a format suitable for HTML injection
|
||||
// This implements the web.PublicSettingsProvider interface
|
||||
// PublicSettingsInjectionPayload is the JSON shape embedded into HTML as
|
||||
// `window.__APP_CONFIG__` so the frontend can hydrate feature flags & site
|
||||
// config before the first XHR finishes.
|
||||
//
|
||||
// INVARIANT: every `json` tag here MUST also exist on handler/dto.PublicSettings.
|
||||
// If you forget a feature-flag field here, the frontend's
|
||||
// `cachedPublicSettings.xxx_enabled` will be `undefined` on refresh until the
|
||||
// async `/api/v1/settings/public` call returns — which causes opt-in menus
|
||||
// (strict `=== true`) to flicker off/on. See
|
||||
// frontend/src/utils/featureFlags.ts for the matching registry.
|
||||
//
|
||||
// A unit test diffs this struct's JSON keys against dto.PublicSettings to catch
|
||||
// drift automatically (see setting_service_injection_test.go).
|
||||
type PublicSettingsInjectionPayload struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
TableDefaultPageSize int `json:"table_default_page_size"`
|
||||
TablePageSizeOptions []int `json:"table_page_size_options"`
|
||||
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
|
||||
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
|
||||
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
|
||||
WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
|
||||
WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
|
||||
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
|
||||
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
PaymentEnabled bool `json:"payment_enabled"`
|
||||
Version string `json:"version"`
|
||||
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
|
||||
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
|
||||
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
|
||||
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
|
||||
|
||||
// Feature flags — MUST match the opt-in/opt-out registry in
|
||||
// frontend/src/utils/featureFlags.ts. Missing a field here is the bug
|
||||
// that hid the "可用渠道" menu on page refresh.
|
||||
ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
|
||||
ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
|
||||
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
|
||||
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
|
||||
}
|
||||
|
||||
// GetPublicSettingsForInjection returns public settings in a format suitable for HTML injection.
|
||||
// This implements the web.PublicSettingsProvider interface.
|
||||
func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any, error) {
|
||||
settings, err := s.GetPublicSettings(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return a struct that matches the frontend's expected format
|
||||
return &struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo,omitempty"`
|
||||
SiteSubtitle string `json:"site_subtitle,omitempty"`
|
||||
APIBaseURL string `json:"api_base_url,omitempty"`
|
||||
ContactInfo string `json:"contact_info,omitempty"`
|
||||
DocURL string `json:"doc_url,omitempty"`
|
||||
HomeContent string `json:"home_content,omitempty"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
|
||||
TableDefaultPageSize int `json:"table_default_page_size"`
|
||||
TablePageSizeOptions []int `json:"table_page_size_options"`
|
||||
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
|
||||
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
|
||||
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
|
||||
WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
|
||||
WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
PaymentEnabled bool `json:"payment_enabled"`
|
||||
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
|
||||
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
|
||||
Version string `json:"version,omitempty"`
|
||||
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
|
||||
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
|
||||
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
|
||||
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
|
||||
ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
|
||||
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
|
||||
}{
|
||||
return &PublicSettingsInjectionPayload{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
||||
@@ -706,17 +726,20 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
WeChatOAuthOpenEnabled: settings.WeChatOAuthOpenEnabled,
|
||||
WeChatOAuthMPEnabled: settings.WeChatOAuthMPEnabled,
|
||||
WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
PaymentEnabled: settings.PaymentEnabled,
|
||||
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
|
||||
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
PaymentEnabled: settings.PaymentEnabled,
|
||||
Version: s.version,
|
||||
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
|
||||
AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
|
||||
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
|
||||
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
|
||||
ChannelMonitorEnabled: settings.ChannelMonitorEnabled,
|
||||
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
|
||||
|
||||
ForceEmailOnThirdPartySignup: settings.ForceEmailOnThirdPartySignup,
|
||||
ChannelMonitorEnabled: settings.ChannelMonitorEnabled,
|
||||
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
|
||||
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -126,8 +126,8 @@ type SystemSettings struct {
|
||||
OpsMetricsIntervalSeconds int
|
||||
|
||||
// Channel Monitor feature
|
||||
ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
|
||||
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
|
||||
ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
|
||||
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
|
||||
|
||||
// Available Channels feature (user-facing aggregate view)
|
||||
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
|
||||
|
||||
@@ -122,8 +122,8 @@ func TestShouldClearStickySession(t *testing.T) {
|
||||
{
|
||||
name: "overloaded account",
|
||||
account: &Account{
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
OverloadUntil: &future,
|
||||
},
|
||||
requestedModel: "",
|
||||
|
||||
36
backend/migrations/130_fix_claude_code_template_userid.sql
Normal file
36
backend/migrations/130_fix_claude_code_template_userid.sql
Normal file
@@ -0,0 +1,36 @@
|
||||
-- Migration: 114_fix_claude_code_template_userid
|
||||
-- 113 的 seed 使用 legacy 格式的 metadata.user_id,但已部署环境此前是手工建的
|
||||
-- 「Claude Code 伪装」模板(用新版 JSON-string 格式 user_id),113 的 ON CONFLICT
|
||||
-- DO NOTHING 不会覆盖。本 migration 定向修复这一条历史记录及其下游监控快照。
|
||||
--
|
||||
-- 安全性:WHERE 条件同时匹配 (provider, name) + user_id 以 '{' 开头,
|
||||
-- 所以:
|
||||
-- - 用户自己改过 user_id(或者 seed 本来就是 legacy)→ LIKE 不中,保持原状
|
||||
-- - 用户改过 template name / provider → WHERE 不中,完全跳过
|
||||
-- 幂等:第二次跑时 user_id 已经是 legacy 格式,LIKE '{%' 不中,UPDATE 0 行。
|
||||
|
||||
UPDATE channel_monitor_request_templates
|
||||
SET body_override = jsonb_set(
|
||||
body_override,
|
||||
'{metadata,user_id}',
|
||||
'"user_0000000000000000000000000000000000000000000000000000000000000000_account_00000000-0000-0000-0000-000000000000_session_00000000-0000-0000-0000-000000000000"'::jsonb,
|
||||
false
|
||||
),
|
||||
updated_at = NOW()
|
||||
WHERE provider = 'anthropic'
|
||||
AND name = 'Claude Code 伪装'
|
||||
AND body_override #>> '{metadata,user_id}' LIKE '{%';
|
||||
|
||||
-- 同步已应用此模板的监控快照(监控采用 snapshot 语义,只更新那些明显还是 seed 原样的)。
|
||||
UPDATE channel_monitors m
|
||||
SET body_override = jsonb_set(
|
||||
m.body_override,
|
||||
'{metadata,user_id}',
|
||||
'"user_0000000000000000000000000000000000000000000000000000000000000000_account_00000000-0000-0000-0000-000000000000_session_00000000-0000-0000-0000-000000000000"'::jsonb,
|
||||
false
|
||||
)
|
||||
FROM channel_monitor_request_templates t
|
||||
WHERE m.template_id = t.id
|
||||
AND t.provider = 'anthropic'
|
||||
AND t.name = 'Claude Code 伪装'
|
||||
AND m.body_override #>> '{metadata,user_id}' LIKE '{%';
|
||||
@@ -0,0 +1,40 @@
|
||||
-- Migration: 115_cleanup_claude_code_mimicry_fields
|
||||
-- 清理 "Claude Code CLI 模拟套件 (A)" + "Signature Pool (B)" 回滚后遗留的 DB 状态。
|
||||
--
|
||||
-- 涉及回滚的功能:
|
||||
-- - 6d0e0562 feat(fingerprint): Claude Code CLI fingerprint mimicry suite
|
||||
-- - cfd95669 feat(tls-fingerprint): show binding count + fix randomized fingerprint visibility
|
||||
-- - 2df77c16/78de54b6/89d14a2 等 Signature Pool 相关 commits
|
||||
--
|
||||
-- 需要清理的字段:
|
||||
-- 1. accounts.extra->>'tls_fingerprint_randomized' — cfd95669 引入的随机指纹标记
|
||||
-- 2. accounts.extra->>'metadata' (内含 user_id) — sticky session UUID per Claude OAuth account
|
||||
-- 3. accounts.extra->>'sticky_session_user_id' — sticky session 备用键名(保险)
|
||||
--
|
||||
-- 需要清理的索引:
|
||||
-- - idx_accounts_tls_fp_profile_id — 来自 migration 108,加速绑定数聚合查询。
|
||||
-- 回滚后绑定数 UI 已移除,索引不再被任何查询使用,删除以释放空间。
|
||||
--
|
||||
-- 注意:上游已存在的 tls_fingerprint_profile_id / enable_tls_fingerprint 字段保留,
|
||||
-- 这些是上游 TLS fingerprint profile 功能本身的一部分,不在回滚范围内。
|
||||
|
||||
-- 1) 删除 cfd95669 引入的索引
|
||||
DROP INDEX IF EXISTS idx_accounts_tls_fp_profile_id;
|
||||
|
||||
-- 2) 清理 sticky session UUID(仅 Claude/Anthropic OAuth/SetupToken 账号会写入此字段)
|
||||
UPDATE accounts
|
||||
SET extra = extra - 'metadata'
|
||||
WHERE deleted_at IS NULL
|
||||
AND extra ? 'metadata';
|
||||
|
||||
-- 3) 清理随机指纹标记
|
||||
UPDATE accounts
|
||||
SET extra = extra - 'tls_fingerprint_randomized'
|
||||
WHERE deleted_at IS NULL
|
||||
AND extra ? 'tls_fingerprint_randomized';
|
||||
|
||||
-- 4) 清理可能残留的 sticky session 备用字段
|
||||
UPDATE accounts
|
||||
SET extra = extra - 'sticky_session_user_id'
|
||||
WHERE deleted_at IS NULL
|
||||
AND extra ? 'sticky_session_user_id';
|
||||
Reference in New Issue
Block a user