Compare commits

...

37 Commits

Author SHA1 Message Date
Wesley Liddick
5d1c12e60e Merge pull request #1943 from AyeSt0/fix/openai-responses-preoutput-failover
fix(openai): 修复 Responses 流式失败前置事件导致无法 failover
2026-04-25 15:43:00 +08:00
AyeSt0
5b63a9b02d fix(openai): fail over before responses stream output 2026-04-25 15:09:40 +08:00
Wesley Liddick
641e61073f Merge pull request #1940 from 4fuu/fix/bump-codex-cli-version-to-0.125.0
fix(openai): bump codex CLI version from 0.104.0 to 0.125.0
2026-04-25 14:57:51 +08:00
shaw
095f457c57 feat(openai): port /responses/compact account support flow (PR #1555)
vansour/sub2api#1555 的 OpenAI compact 能力建模手工移植到当前 main:账号
级 compact 状态/auto-force_on-force_off 模式、compact-only 模型映射、调度器
tier 分层(已支持 > 未知 > 已知不支持)、管理后台 compact 主动探测,以及对应
i18n/状态徽章。普通 /responses 流量行为不变,无数据库迁移。
2026-04-25 14:52:58 +08:00
4fuu
1e57e88e43 fix(openai): bump codex CLI version from 0.104.0 to 0.125.0
The hardcoded codex CLI version (0.104.0) causes upstream rejection
when using gpt-5.5 with compact, as the server treats the request
as an outdated client and returns 400/502.

Update codexCLIVersion, codexCLIUserAgent, and openAICodexProbeVersion
to 0.125.0 to match the current Codex CLI release.

Fixes #1933, #1887, #1865
Related: #1609, #1298, #849
2026-04-25 05:26:33 +00:00
Wesley Liddick
b95ffce244 Merge pull request #1772 from KnowSky404/fix/openai-test-state-reconciliation
[codex] reconcile OpenAI admin test rate-limit state
2026-04-25 10:02:21 +08:00
shaw
8f28a834f8 fix(payment): 同时启用易支付和 Stripe 时显示 Stripe 按钮
VISIBLE_METHOD_ALIASES 漏了 stripe,导致 getVisibleMethods 把后端返回
的 stripe 过滤掉。点 Stripe 按钮时省略 method 查询参数,让落地页渲染
完整的 Payment Element。
2026-04-25 09:46:27 +08:00
shaw
7424c73b05 chore: remove unused model IDs 2026-04-25 09:04:34 +08:00
Wesley Liddick
1afd81b019 Merge pull request #1920 from Wuxie233/fix/responses-web-search-tool-types
fix(apicompat): recognize web_search_20250305 / google_search in Responses→Anthropic tool conversion
2026-04-25 09:00:37 +08:00
shaw
732d6495ea chore(gateway): fix lint issues from cc-mimicry-parity merge
- staticcheck QF1001: apply De Morgan's law to the OAuth-mimic header
  passthrough guard (`!(a && b)` → `a != ... || !b`).
- unused: drop `isClaudeCodeRequest`, which became dead after PR #1914
  switched both `/v1/messages` and `/count_tokens` paths to unconditional
  `account.IsOAuth()` mimicry. The lowercase helper `isClaudeCodeClient`
  is kept (still referenced by `TestIsClaudeCodeClient`).
2026-04-25 08:58:57 +08:00
Wesley Liddick
6d20ab8082 Merge pull request #1914 from keh4l/feat/cc-mimicry-parity
fix(claude): align Claude Code OAuth mimicry with real CLI traffic
2026-04-25 08:54:04 +08:00
shaw
aa8ee33b0a refactor(affiliate): tighten DI and harden inviter code validation
- Drop SetAffiliateService setters and ProvideAuthService /
  ProvidePaymentService / ProvideUserHandler wrappers in favor of direct
  Wire constructor injection. AffiliateService has no back-edge to
  Auth/Payment/User, so the indirection was never required.
- Change RegisterWithVerification's variadic affiliateCode to a fixed
  parameter; adjust all call sites.
- Validate aff_code length and charset in BindInviterByCode before any
  DB lookup, eliminating timing-side-channel and useless DB roundtrips
  on malformed input.
- Make affiliate cache invalidation synchronous; surface Redis errors
  via the project logger instead of swallowing them in a detached
  goroutine.
- Add an integration test guarding cross-layer tx propagation in
  AccrueQuota and a unit test pinning the aff_code format rules.
2026-04-25 08:44:18 +08:00
Wuxie233
5f630fbb19 fix(apicompat): recognize web_search_20250305 / google_search in Responses to Anthropic tool conversion 2026-04-25 01:09:51 +08:00
keh4l
bdbd2916f5 fix(gateway): skip client header passthrough on OAuth mimicry path
Root cause of persistent third-party detection: sub2api's
buildUpstreamRequest transparently forwards client headers via
allowedHeaders whitelist (addHeaderRaw) before applying mimicry
overrides. When third-party clients (opencode, etc.) send their own
anthropic-beta / user-agent / x-stainless-* / x-claude-code-session-id
values, these get appended to the request alongside our injected
headers, creating an inconsistent header set that Anthropic detects.

Parrot's build_upstream_headers constructs exactly 9 headers from
scratch and never forwards anything from the client. This is why
'same opencode version, some users work some don't' — different
opencode configs/versions send different header combinations.

Fix: when tokenType=oauth and mimicClaudeCode=true, skip the
client header passthrough loop entirely. The subsequent
applyClaudeCodeMimicHeaders + ApplyFingerprint + beta merge
pipeline constructs all necessary headers from our controlled values.

Also: remove systemIncludesClaudeCodePrompt gate — OAuth accounts
now unconditionally rewrite system (even if client already sent a
Claude Code-style prompt), ensuring billing attribution block is
always present.
2026-04-25 00:43:38 +08:00
keh4l
6dc89765fd fix(gateway): always apply full mimicry for OAuth accounts regardless of client identity
Before: isClaudeCodeRequest() checked whether the client looks like a
real Claude Code CLI (UA, system prompt, X-App header, metadata format).
If it looked like Claude Code, all mimicry was skipped — the assumption
being that a real CLI needs no help.

Problem: third-party tools like opencode partially impersonate Claude
Code (sending claude-cli UA + claude-code beta + CC system prompt) but
miss critical details (billing attribution block, tool-name obfuscation,
cache breakpoints, full beta set). Some users' opencode instances pass
the isClaudeCodeRequest check, causing sub2api to skip mimicry entirely,
while Anthropic still detects the request as third-party.

This explains why 'same opencode version, some users work, some don't'
— it depends on which opencode features/config trigger the validator.

Fix: OAuth accounts now unconditionally run the full mimicry pipeline,
matching Parrot's behavior (Parrot never checks client identity).
This is safe because our mimicry is strictly more complete than any
third-party client's partial impersonation.

Changed:
  - /v1/messages path: remove isClaudeCode gate
  - /v1/messages/count_tokens path: same
2026-04-25 00:26:37 +08:00
keh4l
f3233db01f fix(gateway): apply D/E/F mimicry to native /v1/messages and count_tokens paths
The previous commit only wired stripMessageCacheControl,
addMessageCacheBreakpoints, and tool-name obfuscation into
applyClaudeCodeOAuthMimicryToBody (used by /chat/completions and
/responses). The native /v1/messages path and count_tokens path
have their own independent mimicry code blocks and were missed.

Now all three entry points share the same D/E/F pipeline:
  - /v1/messages (gateway_service.go forwardAnthropic)
  - /v1/messages/count_tokens (gateway_service.go countTokens)
  - OpenAI compat (applyClaudeCodeOAuthMimicryToBody)
2026-04-24 23:16:32 +08:00
keh4l
6e12578bc5 feat(gateway): port Parrot tool-name obfuscation + message cache breakpoints
Implements the remaining three parity items with Parrot cc_mimicry:

  D) Tool-name obfuscation
     - Dynamic mapping when tools.length > 5 (matches Parrot threshold).
       Fake names follow {prefix}{name[:3]}{i:02d} (e.g. 'manage_bas00').
       Go port of random.Random(hash(tuple(names))) uses fnv64a seed +
       math/rand; byte-exact reproduction is impossible (Python hash vs
       Go hash), but the two invariants that matter are preserved:
         * same input tool_names yield identical mapping (cache hit)
         * prefix pool is shuffled (names look distributed)
     - Static prefix map (sessions_ -> cc_sess_, session_ -> cc_ses_)
       applied as fallback, matching Parrot TOOL_NAME_REWRITES verbatim.
     - Server tools (web_search_20250305, computer_*, etc.) are NOT
       renamed; only type=='function' and type=='custom' tools are.
     - tool_choice.name is rewritten in sync (only when type=='tool').
     - Response side: bytes-level replace on every SSE chunk / JSON
       body at 6 injection points (standard stream/non-stream,
       passthrough stream/non-stream, chat_completions stream +
       non-stream, responses stream + non-stream). Reverse mapping
       applied longest-fake-name-first to prevent substring conflicts
       (parity with Parrot _restore_tool_names_in_chunk).
     - tool_choice is no longer unconditionally deleted in
       normalizeClaudeOAuthRequestBody — Parrot passes it through.

  E) tools[-1] cache_control breakpoint
     - Injected as {type:ephemeral, ttl:<DefaultCacheControlTTL>} when
       the last tool has no cache_control. Client-provided ttl is
       passed through unchanged (repo-wide policy).

  F) messages cache_control strategy
     - stripMessageCacheControl removes every client-provided
       messages[*].content[*].cache_control (multi-turn stability).
     - addMessageCacheBreakpoints then injects two stable breakpoints:
       (1) last message, and (2) second-to-last user turn when
       messages.length >= 4.
     - Combined with the system block breakpoint and tools[-1]
       breakpoint, this gives exactly the 4 breakpoints Anthropic
       allows per request.

Non-trivial implementation details to be aware of when rebasing:

  * Two new files, no upstream collision:
      gateway_tool_rewrite.go       (D + E algorithms)
      gateway_messages_cache.go     (F strip + breakpoints)
  * Two new feature calls bolted onto the tail of
    applyClaudeCodeOAuthMimicryToBody in gateway_service.go — rebase
    conflicts will be ~10 lines maximum.
  * Response-side injection points all wrap their existing write with
    reverseToolNamesIfPresent(c, ...), preserving original behavior
    when no mapping is stored (static prefix rollback still runs).
  * Non-stream chat/responses switched from c.JSON to
    json.Marshal + c.Data so bytes-level replace is possible.
  * Retry bodies (FilterThinkingBlocksForRetry,
    FilterSignatureSensitiveBlocksForRetry, RectifyThinkingBudget)
    only prune blocks — they preserve the already-obfuscated tool
    names, so no extra mapping re-application is needed.

Manual QA: end-to-end scenario verified with 6 tools (above threshold)
and tool_choice.type=='tool'. Obfuscation + restore roundtrip shown
in test logs; then removed the temp test file.

Tests (16 new):
  - buildDynamicToolMap stability + below-threshold guard
  - sanitizeToolName precedence (dynamic > static)
  - restoreToolNamesInBytes longest-first + static rollback
  - applyToolNameRewriteToBody skips server tools + syncs tool_choice
  - applyToolsLastCacheBreakpoint defaults to 5m + passes client ttl
  - stripMessageCacheControl + addMessageCacheBreakpoints in the
    1/4/string-content cases + second-to-last user turn selection
  - buildToolNameRewriteFromBody ReverseOrdered is desc-by-fake-length
  - fake name shape follows Parrot {prefix}{head3}{i:02d}
2026-04-24 23:16:32 +08:00
keh4l
a25faecadd feat(gateway): align body shape with real Claude Code CLI defaults
Three field-level alignments in normalizeClaudeOAuthRequestBody to
match real Claude Code CLI traffic byte-for-byte:

  1. temperature: previously deleted unconditionally; now passes
     through client value, defaults to 1 when absent (real CLI
     always sends temperature, default 1).

  2. max_tokens: defaults to 128000 when absent (real CLI default).

  3. context_management: when thinking.type is enabled/adaptive
     and the client did not provide context_management, inject
     {"edits":[{"type":"clear_thinking_20251015","keep":"all"}]}
     to mirror real CLI behavior.

tool_choice removal is unchanged (Claude Code OAuth credentials
do not allow client-supplied tool_choice).

Tests updated:
  - gateway_body_order_test.go: temperature/max_tokens are now
    expected in output; tool_choice still removed.
  - gateway_prompt_test.go: system array is now 2 blocks
    (billing + cc prompt), assertions adjusted.
  - gateway_anthropic_apikey_passthrough_test.go: same 2-block
    assertion.
2026-04-24 23:16:32 +08:00
keh4l
5862e2d8d9 feat(gateway): add billing attribution block with cc_version fingerprint
Real Claude Code CLI always sends a 2-block system array:

  [0] {"type":"text", "text":"x-anthropic-billing-header: cc_version=X.Y.Z.{fp}; cc_entrypoint=cli; cch=00000;"}
  [1] {"type":"text", "text":"You are Claude Code...", "cache_control":{...}}

Before this commit, sub2api's mimicry path only produced block [1].
The missing billing block is one of the primary third-party detection
signals Anthropic uses for Claude-Code-scoped OAuth tokens.

New file gateway_billing_block.go ports the fingerprint algorithm
(byte-for-byte from Parrot cc_mimicry.py:compute_fingerprint):
pick chars at positions [4,7,20] of the first user text, then
`sha256(SALT + chars + cc_version)[:3]`.

  - claude/constants.go: CLICurrentVersion = "2.1.92" (must match UA)
  - gateway_billing_block.go: computeClaudeCodeFingerprint +
    buildBillingAttributionBlockJSON + extractFirstUserText
  - gateway_service.go: rewriteSystemForNonClaudeCode now emits both
    blocks in order; cch=00000 is filled in later by
    signBillingHeaderCCH in buildUpstreamRequest.

Downstream compat note: syncBillingHeaderVersion's regex
`cc_version=\d+\.\d+\.\d+` only matches the semver triple,
leaving the `.{fp}` suffix intact when rewriting in buildUpstreamRequest.
2026-04-24 23:16:32 +08:00
keh4l
66d6454535 feat(claude): add ttl to cache_control with default 5m
Real Claude CLI traffic sends cache_control as
`{"type":"ephemeral","ttl":"1h"}`. Our previous payload only
sent `{"type":"ephemeral"}`, which is a bytewise mismatch with
the official CLI and one more third-party detection signal.

Policy: client-provided ttl is always passed through unchanged.
Proxy-generated cache_control blocks default to 5m (vs Parrot's 1h)
to avoid burning the 1h cache budget on automatic breakpoints while
still aligning with the `ttl` field being present.

  - claude/constants.go: DefaultCacheControlTTL = "5m"
  - apicompat/types.go: new AnthropicCacheControl type with TTL field;
    AnthropicTool gains optional CacheControl pointer so the mimicry
    path can attach a cache breakpoint to tools[-1] later.
  - service/gateway_service.go: anthropicCacheControlPayload gains TTL;
    marshalAnthropicSystemTextBlock and rewriteSystemForNonClaudeCode
    emit ttl=5m by default.
2026-04-24 23:16:32 +08:00
keh4l
165553cfb0 fix(gateway): use full beta list in buildUpstreamRequest mimicry path
The previous commit added FullClaudeCodeMimicryBetas() but the two
call sites in buildUpstreamRequest still hardcoded the old 3-token
subset. Anthropic now checks the complete set of beta tokens to
decide if a request qualifies as Claude Code. Wire them up:

  - /v1/messages mimic path: requiredBetas = FullClaudeCodeMimicryBetas()
  - /v1/messages/count_tokens mimic path: same + BetaTokenCounting

Haiku models keep the 2-token exemption (BetaOAuth + InterleaveThinking).
2026-04-24 23:16:32 +08:00
keh4l
b5467d610a fix(gateway): apply full Claude Code mimicry on /chat/completions and /responses
Before: the OpenAI-compat forwarders only called injectClaudeCodePrompt,
which prepends the Claude Code banner but leaves the rest of the body
in its original non-Claude-Code shape. The codebase already admits this
is insufficient (see the comment on rewriteSystemForNonClaudeCode in
gateway_service.go: "仅前置追加 Claude Code 提示词无法通过检测").

Effect: OAuth accounts served through /v1/chat/completions or /v1/responses
were detected as third-party apps and bled plan quota with:

    Third-party apps now draw from your extra usage, not your plan limits.

Fix:
  - apicompat.AnthropicRequest: add Metadata json.RawMessage so metadata
    survives the OpenAI->Anthropic->Marshal round trip; without it the
    downstream rewrite has no user_id to work with.
  - service: extract applyClaudeCodeOAuthMimicryToBody, a ParsedRequest-free
    variant of the /v1/messages mimicry pipeline
    (rewriteSystemForNonClaudeCode + normalizeClaudeOAuthRequestBody +
    metadata.user_id injection) so the OpenAI-compat forwarders can reuse it.
  - service: add buildOAuthMetadataUserIDFromBody + hashBodyForSessionSeed
    for the same reason (no ParsedRequest at the call site).
  - ForwardAsChatCompletions / ForwardAsResponses: replace the 3-line
    prompt-prepend with the full mimicry pipeline.
  - applyClaudeCodeMimicHeaders: set x-client-request-id per-request
    (real Claude CLI always does); missing/duplicated values are one more
    third-party fingerprint signal.

No change to the native /v1/messages path: it already called the full
pipeline, we only lift those helpers into a reusable function.

Tests:
  - go build ./... passes
  - go test ./internal/service/... ./internal/pkg/apicompat/... passes
  - lsp_diagnostics clean on all touched files
  - pre-existing failures in internal/config are unrelated (env-sensitive
    tests that also fail on upstream main)
2026-04-24 23:16:32 +08:00
keh4l
57ff97960d chore(claude): bump mimicked CLI to 2.1.92 and extend anthropic-beta list
Align Claude Code mimicry constants with the latest real CLI traffic
(see Parrot's src/transform/cc_mimicry.py). Anthropic now uses the full
set of anthropic-beta tokens to decide whether a request counts as
"official Claude Code"; requests missing tokens that real CLI ships
today are demoted to third-party usage:

  Third-party apps now draw from your extra usage, not your plan limits.

Changes:
  - claude/constants.go: add new beta tokens (prompt-caching-scope,
    effort, redact-thinking, context-management, extended-cache-ttl) and
    expose FullClaudeCodeMimicryBetas() for the OAuth mimicry path.
  - claude/constants.go: bump default User-Agent to claude-cli/2.1.92.
  - identity_service.go: bump defaultFingerprint User-Agent accordingly.

No behavioral change for clients that already send a newer UA (fingerprint
merge still prefers the incoming value).
2026-04-24 23:16:32 +08:00
Wesley Liddick
5b5db88550 Merge pull request #1897 from VpSanta33/codex/invite-affiliate-rebate
feat: 新增邀请返利功能,并支持后台配置返利比例
2026-04-24 22:36:53 +08:00
VpSanta33
f03de00cb9 feat: add affiliate invite rebate flow and admin rebate-rate setting 2026-04-24 22:22:26 +08:00
Wesley Liddick
76aae5aa74 Merge pull request #1911 from gaoren002/fix/codex-responses-payload-normalization-mainbase
fix(openai): normalize codex responses payloads
2026-04-24 21:37:32 +08:00
gaoren002
27ee141c1e fix(openai): preserve mcp tool call ids 2026-04-24 13:24:21 +00:00
gaoren002
e65574dea9 fix(openai): normalize codex responses payloads 2026-04-24 12:03:19 +00:00
Wesley Liddick
1ce9dc03f9 Merge pull request #1895 from gaoren002/fix/codex-spark-limitations
fix(openai): handle codex spark model limitations
2026-04-24 19:57:42 +08:00
Wesley Liddick
15ce914a62 Merge pull request #1910 from slovx2/fix/codex-tool-call-ids
fix(openai): 修复 Codex 工具调用 call_id 处理
2026-04-24 19:56:03 +08:00
song
959af1c8f6 fix(openai): preserve codex tool call ids 2026-04-24 19:31:49 +08:00
gaoren002
c4d496da18 fix(openai): handle codex spark model limitations 2026-04-24 07:42:31 +00:00
KnowSky404
f3ea878ba2 chore: trigger PR checks 2026-04-24 11:32:41 +08:00
KnowSky404
d80469ea35 test: fix OpenAI account test helper calls after rebase 2026-04-24 11:32:41 +08:00
KnowSky404
5fc30ea964 test: cover openai admin test state transitions 2026-04-24 11:32:41 +08:00
KnowSky404
f68909a68b fix: reconcile openai admin test rate-limit state 2026-04-24 11:32:41 +08:00
github-actions[bot]
d162604f32 chore: sync VERSION to 0.1.117 [skip ci] 2026-04-24 01:40:02 +00:00
101 changed files with 6824 additions and 406 deletions

View File

@@ -33,7 +33,7 @@ func main() {
}()
userRepo := repository.NewUserRepository(client, sqlDB)
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

View File

@@ -1 +1 @@
0.1.116
0.1.117

View File

@@ -69,7 +69,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
affiliateRepository := repository.NewAffiliateRepository(client, db)
affiliateService := service.NewAffiliateService(affiliateRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCacheService)
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService)
userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
@@ -80,7 +82,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
totpCache := repository.NewTotpCache(redisClient)
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache)
userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
@@ -91,6 +93,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
announcementReadRepository := repository.NewAnnouncementReadRepository(client)
announcementService := service.NewAnnouncementService(announcementRepository, announcementReadRepository, userRepository, userSubscriptionRepository)
announcementHandler := handler.NewAnnouncementHandler(announcementService)
channelMonitorRepository := repository.NewChannelMonitorRepository(client, db)
channelMonitorService := service.ProvideChannelMonitorService(channelMonitorRepository, secretEncryptor)
channelMonitorUserHandler := handler.NewChannelMonitorUserHandler(channelMonitorService, settingService)
dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig)
dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig)
@@ -192,7 +197,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
registry := payment.ProvideRegistry()
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository)
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
opsHandler := admin.NewOpsHandler(opsService)
updateCache := repository.NewUpdateCache(redisClient)
@@ -221,20 +226,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
channelHandler := admin.NewChannelHandler(channelService, billingService)
sqlDB, err := repository.ProvideSQLDB(client)
if err != nil {
return nil, err
}
channelMonitorRepository := repository.NewChannelMonitorRepository(client, sqlDB)
channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, sqlDB)
channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService)
channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, db)
channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository)
channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService)
channelMonitorService := service.ProvideChannelMonitorService(channelMonitorRepository, secretEncryptor)
channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService)
channelMonitorUserHandler := handler.NewChannelMonitorUserHandler(channelMonitorService, settingService)
channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
availableChannelUserHandler := handler.NewAvailableChannelHandler(channelService, apiKeyService, settingService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
@@ -245,9 +241,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
totpHandler := handler.NewTotpHandler(totpService)
handlerPaymentHandler := handler.NewPaymentHandler(paymentService, paymentConfigService, channelService)
paymentWebhookHandler := handler.NewPaymentWebhookHandler(paymentService, registry)
availableChannelHandler := handler.NewAvailableChannelHandler(channelService, apiKeyService, settingService)
idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig)
idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, channelMonitorUserHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, availableChannelUserHandler, idempotencyCoordinator, idempotencyCleanupService)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, channelMonitorUserHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, availableChannelHandler, idempotencyCoordinator, idempotencyCleanupService)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
@@ -263,6 +260,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, channelMonitorRunner)
application := &Application{
Server: httpServer,

View File

@@ -652,6 +652,7 @@ func (h *AccountHandler) Delete(c *gin.Context) {
type TestAccountRequest struct {
ModelID string `json:"model_id"`
Prompt string `json:"prompt"`
Mode string `json:"mode"`
}
type SyncFromCRSRequest struct {
@@ -682,7 +683,7 @@ func (h *AccountHandler) Test(c *gin.Context) {
_ = c.ShouldBindJSON(&req)
// Use AccountTestService to test the account with SSE streaming
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt); err != nil {
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt, req.Mode); err != nil {
// Error already sent via SSE, just log
return
}

View File

@@ -185,6 +185,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
AffiliateRebateRate: settings.AffiliateRebateRate,
DefaultUserRPMLimit: settings.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: settings.EnableModelFallback,
@@ -338,6 +339,7 @@ type UpdateSettingsRequest struct {
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"`
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
@@ -468,6 +470,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if req.DefaultBalance < 0 {
req.DefaultBalance = 0
}
affiliateRebateRate := previousSettings.AffiliateRebateRate
if req.AffiliateRebateRate != nil {
affiliateRebateRate = *req.AffiliateRebateRate
}
if affiliateRebateRate < service.AffiliateRebateRateMin {
affiliateRebateRate = service.AffiliateRebateRateMin
}
if affiliateRebateRate > service.AffiliateRebateRateMax {
affiliateRebateRate = service.AffiliateRebateRateMax
}
// 通用表格配置:兼容旧客户端未传字段时保留当前值。
if req.TableDefaultPageSize <= 0 {
req.TableDefaultPageSize = previousSettings.TableDefaultPageSize
@@ -1119,6 +1131,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
CustomEndpoints: customEndpointsJSON,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
AffiliateRebateRate: affiliateRebateRate,
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback,
@@ -1433,6 +1446,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
AffiliateRebateRate: updatedSettings.AffiliateRebateRate,
DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit,
DefaultSubscriptions: updatedDefaultSubscriptions,
EnableModelFallback: updatedSettings.EnableModelFallback,
@@ -1738,6 +1752,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.DefaultBalance != after.DefaultBalance {
changed = append(changed, "default_balance")
}
if before.AffiliateRebateRate != after.AffiliateRebateRate {
changed = append(changed, "affiliate_rebate_rate")
}
if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) {
changed = append(changed, "default_subscriptions")
}

View File

@@ -48,6 +48,7 @@ type RegisterRequest struct {
TurnstileToken string `json:"turnstile_token"`
PromoCode string `json:"promo_code"` // 注册优惠码
InvitationCode string `json:"invitation_code"` // 邀请码
AffCode string `json:"aff_code"` // 邀请返利码
}
// SendVerifyCodeRequest 发送验证码请求
@@ -164,7 +165,15 @@ func (h *AuthHandler) Register(c *gin.Context) {
return
}
_, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
_, user, err := h.authService.RegisterWithVerification(
c.Request.Context(),
req.Email,
req.Password,
req.VerifyCode,
req.PromoCode,
req.InvitationCode,
req.AffCode,
)
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -2210,6 +2210,7 @@ CREATE TABLE IF NOT EXISTS user_avatars (
nil,
nil,
options.defaultSubAssigner,
nil,
)
userSvc := service.NewUserService(userRepo, nil, nil, nil)
var totpSvc *service.TotpService

View File

@@ -35,7 +35,7 @@ func TestAuthHandlerRevokeAllSessionsInvalidatesAccessTokens(t *testing.T) {
ExpireHour: 1,
},
}
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
handler := &AuthHandler{authService: authService}
recorder := httptest.NewRecorder()

View File

@@ -1399,6 +1399,7 @@ func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool,
nil,
nil,
nil,
nil,
)
return &AuthHandler{

View File

@@ -108,6 +108,7 @@ type SystemSettings struct {
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
AffiliateRebateRate float64 `json:"affiliate_rebate_rate"`
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`

View File

@@ -130,6 +130,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
false,
)
if err != nil {
reqLog.Warn("openai_chat_completions.account_select_failed",
@@ -153,6 +154,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
defaultModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
false,
)
if err == nil && selection != nil {
c.Set("openai_chat_completions_fallback_model", defaultModel)

View File

@@ -116,7 +116,7 @@ func TestLogOpenAIRemoteCompactOutcome_Succeeded(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.125.0")
c.Set(opsModelKey, "gpt-5.3-codex")
c.Set(opsAccountIDKey, int64(123))
c.Header("x-request-id", "rid-compact-ok")
@@ -142,7 +142,7 @@ func TestLogOpenAIRemoteCompactOutcome_Failed(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact", nil)
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.125.0")
c.Status(http.StatusBadGateway)
h := &OpenAIGatewayHandler{}
@@ -180,7 +180,7 @@ func TestOpenAIResponses_CompactUnauthorizedLogsFailed(t *testing.T) {
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", strings.NewReader(`{"model":"gpt-5.3-codex"}`))
c.Request.Header.Set("Content-Type", "application/json")
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.125.0")
h := &OpenAIGatewayHandler{}
h.Responses(c)

View File

@@ -238,6 +238,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Generate session hash (header first; fallback to prompt_cache_key)
sessionHash := h.gatewayService.GenerateSessionHash(c, sessionHashBody)
requireCompact := isOpenAIRemoteCompactPath(c)
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
@@ -256,6 +257,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
requireCompact,
)
if err != nil {
reqLog.Warn("openai.account_select_failed",
@@ -263,6 +265,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
if errors.Is(err, service.ErrNoAvailableCompactAccounts) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "compact_not_supported", "No available OpenAI accounts support /responses/compact", streamStarted)
return
}
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
@@ -644,6 +650,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
currentRoutingModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
false,
)
if err != nil {
reqLog.Warn("openai_messages.account_select_failed",
@@ -1167,6 +1174,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
reqModel,
nil,
service.OpenAIUpstreamTransportResponsesWebsocketV2,
false,
)
if err != nil {
reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err))

View File

@@ -117,7 +117,7 @@ func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) {
Save(context.Background())
require.NoError(t, err)
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()
@@ -215,7 +215,7 @@ func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing
require.NoError(t, err)
configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil)
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()
@@ -302,7 +302,7 @@ func TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken(t *t
require.NoError(t, err)
configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil)
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()
@@ -342,7 +342,7 @@ func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()

View File

@@ -14,10 +14,11 @@ import (
// UserHandler handles user-related requests
type UserHandler struct {
userService *service.UserService
authService *service.AuthService
emailService *service.EmailService
emailCache service.EmailCache
userService *service.UserService
authService *service.AuthService
emailService *service.EmailService
emailCache service.EmailCache
affiliateService *service.AffiliateService
}
// NewUserHandler creates a new UserHandler
@@ -26,12 +27,14 @@ func NewUserHandler(
authService *service.AuthService,
emailService *service.EmailService,
emailCache service.EmailCache,
affiliateService *service.AffiliateService,
) *UserHandler {
return &UserHandler{
userService: userService,
authService: authService,
emailService: emailService,
emailCache: emailCache,
userService: userService,
authService: authService,
emailService: emailService,
emailCache: emailCache,
affiliateService: affiliateService,
}
}
@@ -159,6 +162,44 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
response.Success(c, profileResp)
}
// GetAffiliate returns the current user's affiliate details.
// GET /api/v1/user/aff
func (h *UserHandler) GetAffiliate(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
detail, err := h.affiliateService.GetAffiliateDetail(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, detail)
}
// TransferAffiliateQuota transfers all available affiliate quota into current balance.
// POST /api/v1/user/aff/transfer
func (h *UserHandler) TransferAffiliateQuota(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
transferred, balance, err := h.affiliateService.TransferAffiliateQuota(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{
"transferred_quota": transferred,
"balance": balance,
})
}
type StartIdentityBindingRequest struct {
Provider string `json:"provider" binding:"required"`
RedirectTo string `json:"redirect_to"`

View File

@@ -142,7 +142,7 @@ func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
Status: service.StatusActive,
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`)
recorder := httptest.NewRecorder()
@@ -200,7 +200,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@@ -283,7 +283,7 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@@ -362,7 +362,7 @@ func TestUserHandlerGetProfileDoesNotInferEditedProfileSourcesWithoutMatchingIde
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@@ -511,8 +511,8 @@ func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) {
},
}
emailService := service.NewEmailService(nil, emailCache)
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`)
recorder := httptest.NewRecorder()
@@ -566,7 +566,7 @@ func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) {
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@@ -625,8 +625,8 @@ func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigure
ExpireHour: 1,
},
}
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@@ -668,8 +668,8 @@ func TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound(t *
ExpireHour: 1,
},
}
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@@ -712,8 +712,8 @@ func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t
},
}
emailService := service.NewEmailService(nil, emailCache)
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`)
recorder := httptest.NewRecorder()
@@ -750,7 +750,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
Status: service.StatusActive,
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`)
recorder := httptest.NewRecorder()

View File

@@ -390,7 +390,7 @@ func convertResponsesToAnthropicTools(tools []ResponsesTool) []AnthropicTool {
var out []AnthropicTool
for _, t := range tools {
switch t.Type {
case "web_search":
case "web_search", "google_search", "web_search_20250305":
out = append(out, AnthropicTool{
Type: "web_search_20250305",
Name: "web_search",

View File

@@ -12,17 +12,23 @@ import "encoding/json"
// AnthropicRequest is the request body for POST /v1/messages.
type AnthropicRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
Messages []AnthropicMessage `json:"messages"`
Tools []AnthropicTool `json:"tools,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
StopSeqs []string `json:"stop_sequences,omitempty"`
Thinking *AnthropicThinking `json:"thinking,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
Messages []AnthropicMessage `json:"messages"`
Tools []AnthropicTool `json:"tools,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
StopSeqs []string `json:"stop_sequences,omitempty"`
Thinking *AnthropicThinking `json:"thinking,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
// Metadata 会被原样透传给上游。OAuth/Claude-Code 路径依赖 metadata.user_id
// 参与上游的"是否为官方 Claude Code 请求"判定;如果经由本结构体重新序列化
// 时丢弃该字段,网关侧后续的 metadata 重写(ensureClaudeOAuthMetadataUserID/
// RewriteUserIDWithMasking) 在 body 里拿不到起点,就无法重建一个合法的
// user_id进而导致请求被归类为第三方 app。
Metadata json.RawMessage `json:"metadata,omitempty"`
OutputConfig *AnthropicOutputConfig `json:"output_config,omitempty"`
}
@@ -76,10 +82,18 @@ type AnthropicImageSource struct {
// AnthropicTool describes a tool available to the model.
type AnthropicTool struct {
Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object
Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object
CacheControl *AnthropicCacheControl `json:"cache_control,omitempty"`
}
// AnthropicCacheControl 对应 Anthropic API 的 cache_control 字段。
// ttl 默认由调用方决定;本项目策略见 claude.DefaultCacheControlTTL。
type AnthropicCacheControl struct {
Type string `json:"type"` // "ephemeral"
TTL string `json:"ttl,omitempty"` // "5m" / "1h" / 省略=默认 5m由 Anthropic 判定)
}
// AnthropicResponse is the non-streaming response from POST /v1/messages.

View File

@@ -4,6 +4,12 @@ package claude
// Claude Code 客户端相关常量
// Beta header 常量
//
// 这里的常量对齐真实 Claude Code CLI 的最新流量(截至 2026-04
// 选型参考:与 Parrot (src/transform/cc_mimicry.py) 的 BETAS 保持一致,
// 原因Anthropic 上游会基于 anthropic-beta 的完整集合判定请求来源;
// 缺少任何"官方 Claude Code 请求才会带"的 beta都会被降级到第三方额度
// 对应报错:`Third-party apps now draw from your extra usage, not your plan limits.`
const (
BetaOAuth = "oauth-2025-04-20"
BetaClaudeCode = "claude-code-20250219"
@@ -12,6 +18,13 @@ const (
BetaTokenCounting = "token-counting-2024-11-01"
BetaContext1M = "context-1m-2025-08-07"
BetaFastMode = "fast-mode-2026-02-01"
// 新增(对齐官方 CLI 2.1.9x 以来的流量)
BetaPromptCachingScope = "prompt-caching-scope-2026-01-05"
BetaEffort = "effort-2025-11-24"
BetaRedactThinking = "redact-thinking-2026-02-12"
BetaContextManagement = "context-management-2025-06-27"
BetaExtendedCacheTTL = "extended-cache-ttl-2025-04-11"
)
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
@@ -44,11 +57,43 @@ const APIKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," +
// APIKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header不包含 oauth / claude-code
const APIKeyHaikuBetaHeader = BetaInterleavedThinking
// DefaultCacheControlTTL 是网关代理为自己生成的 cache_control 块默认使用的 ttl。
// 真实 Claude Code CLI 当前使用 "1h",但本仓策略是"客户端透传 ttl 优先;
// 客户端缺省时统一使用 5m",这样既不浪费 1h 缓存额度,也保留客户端自定义能力。
const DefaultCacheControlTTL = "5m"
// CLICurrentVersion 是 sub2api 当前对外伪装的 Claude Code CLI 版本号(三段 semver
// 用于 billing attribution block 中的 cc_version=X.Y.Z.{fp} 前缀以及 fingerprint 计算。
// 必须与 DefaultHeaders["User-Agent"] 中的版本号严格一致;不一致会被 Anthropic 判第三方。
const CLICurrentVersion = "2.1.92"
// FullClaudeCodeMimicryBetas 返回最"像"真实 Claude Code CLI 的完整 beta 列表,
// 用于 OAuth 账号伪装成 Claude Code 时使用。
// 顺序与真实 CLI 抓包一致。
//
// 使用建议:
// - OAuth 账号 + 非 haiku追加这整份列表再按需保留 client 带来的 beta。
// - OAuth 账号 + haikuAnthropic 对 haiku 不做 third-party 判定,使用 HaikuBetaHeader 即可。
// - API-key 账号:不要使用本函数,参见 APIKeyBetaHeader。
func FullClaudeCodeMimicryBetas() []string {
return []string{
BetaClaudeCode,
BetaOAuth,
BetaInterleavedThinking,
BetaPromptCachingScope,
BetaEffort,
BetaRedactThinking,
BetaContextManagement,
BetaExtendedCacheTTL,
}
}
// DefaultHeaders 是 Claude Code 客户端默认请求头。
var DefaultHeaders = map[string]string{
// Keep these in sync with recent Claude CLI traffic to reduce the chance
// that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage.
"User-Agent": "claude-cli/2.1.22 (external, cli)",
// 版本参考:对齐 Parrot (src/transform/cc_mimicry.py:49) 的 CLI_USER_AGENT。
"User-Agent": "claude-cli/2.1.92 (external, cli)",
"X-Stainless-Lang": "js",
"X-Stainless-Package-Version": "0.70.0",
"X-Stainless-OS": "Linux",

View File

@@ -0,0 +1,14 @@
package repository
import "testing"
func TestShouldEnqueueSchedulerOutboxForExtraUpdates_CompactCapabilityKeysAreRelevant(t *testing.T) {
updates := map[string]any{
"openai_compact_supported": true,
"openai_compact_checked_at": "2026-04-10T10:00:00Z",
}
if !shouldEnqueueSchedulerOutboxForExtraUpdates(updates) {
t.Fatalf("expected compact capability updates to enqueue scheduler outbox")
}
}

View File

@@ -0,0 +1,420 @@
package repository
import (
"context"
"crypto/rand"
"database/sql"
"errors"
"fmt"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
const (
affiliateCodeLength = 12
affiliateCodeMaxAttempts = 12
)
var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789")
type affiliateQueryExecer interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
type affiliateRepository struct {
client *dbent.Client
}
func NewAffiliateRepository(client *dbent.Client, _ *sql.DB) service.AffiliateRepository {
return &affiliateRepository{client: client}
}
func (r *affiliateRepository) EnsureUserAffiliate(ctx context.Context, userID int64) (*service.AffiliateSummary, error) {
if userID <= 0 {
return nil, service.ErrUserNotFound
}
client := clientFromContext(ctx, r.client)
return ensureUserAffiliateWithClient(ctx, client, userID)
}
func (r *affiliateRepository) GetAffiliateByCode(ctx context.Context, code string) (*service.AffiliateSummary, error) {
client := clientFromContext(ctx, r.client)
return queryAffiliateByCode(ctx, client, code)
}
func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID int64) (bool, error) {
var bound bool
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
return err
}
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, inviterID); err != nil {
return err
}
res, err := txClient.ExecContext(txCtx,
"UPDATE user_affiliates SET inviter_id = $1, updated_at = NOW() WHERE user_id = $2 AND inviter_id IS NULL",
inviterID, userID,
)
if err != nil {
return fmt.Errorf("bind inviter: %w", err)
}
affected, _ := res.RowsAffected()
if affected == 0 {
bound = false
return nil
}
if _, err = txClient.ExecContext(txCtx,
"UPDATE user_affiliates SET aff_count = aff_count + 1, updated_at = NOW() WHERE user_id = $1",
inviterID,
); err != nil {
return fmt.Errorf("increment inviter aff_count: %w", err)
}
bound = true
return nil
})
if err != nil {
return false, err
}
return bound, nil
}
func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error) {
if amount <= 0 {
return false, nil
}
var applied bool
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
res, err := txClient.ExecContext(txCtx,
"UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2",
amount, inviterID,
)
if err != nil {
return err
}
affected, _ := res.RowsAffected()
if affected == 0 {
applied = false
return nil
}
if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil {
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
}
applied = true
return nil
})
if err != nil {
return false, err
}
return applied, nil
}
func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) {
var transferred float64
var newBalance float64
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
return err
}
rows, err := txClient.QueryContext(txCtx, `
WITH claimed AS (
SELECT aff_quota::double precision AS amount
FROM user_affiliates
WHERE user_id = $1
AND aff_quota > 0
FOR UPDATE
),
cleared AS (
UPDATE user_affiliates ua
SET aff_quota = 0,
updated_at = NOW()
FROM claimed c
WHERE ua.user_id = $1
RETURNING c.amount
)
SELECT amount
FROM cleared`, userID)
if err != nil {
return fmt.Errorf("claim affiliate quota: %w", err)
}
if !rows.Next() {
_ = rows.Close()
if err := rows.Err(); err != nil {
return err
}
return service.ErrAffiliateQuotaEmpty
}
if err := rows.Scan(&transferred); err != nil {
_ = rows.Close()
return err
}
if err := rows.Close(); err != nil {
return err
}
if transferred <= 0 {
return service.ErrAffiliateQuotaEmpty
}
affected, err := txClient.User.Update().
Where(user.IDEQ(userID)).
AddBalance(transferred).
AddTotalRecharged(transferred).
Save(txCtx)
if err != nil {
return fmt.Errorf("credit user balance by affiliate quota: %w", err)
}
if affected == 0 {
return service.ErrUserNotFound
}
newBalance, err = queryUserBalance(txCtx, txClient, userID)
if err != nil {
return err
}
if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
VALUES ($1, 'transfer', $2, NULL, NOW(), NOW())`, userID, transferred); err != nil {
return fmt.Errorf("insert affiliate transfer ledger: %w", err)
}
return nil
})
if err != nil {
return 0, 0, err
}
return transferred, newBalance, nil
}
func (r *affiliateRepository) ListInvitees(ctx context.Context, inviterID int64, limit int) ([]service.AffiliateInvitee, error) {
if limit <= 0 {
limit = 100
}
client := clientFromContext(ctx, r.client)
rows, err := client.QueryContext(ctx, `
SELECT ua.user_id,
COALESCE(u.email, ''),
COALESCE(u.username, ''),
ua.created_at
FROM user_affiliates ua
LEFT JOIN users u ON u.id = ua.user_id
WHERE ua.inviter_id = $1
ORDER BY ua.created_at DESC
LIMIT $2`, inviterID, limit)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
invitees := make([]service.AffiliateInvitee, 0)
for rows.Next() {
var item service.AffiliateInvitee
var createdAt time.Time
if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt); err != nil {
return nil, err
}
item.CreatedAt = &createdAt
invitees = append(invitees, item)
}
if err := rows.Err(); err != nil {
return nil, err
}
return invitees, nil
}
func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error {
if tx := dbent.TxFromContext(ctx); tx != nil {
return fn(ctx, tx.Client())
}
tx, err := r.client.Tx(ctx)
if err != nil {
return fmt.Errorf("begin affiliate transaction: %w", err)
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
if err := fn(txCtx, tx.Client()); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("commit affiliate transaction: %w", err)
}
return nil
}
func ensureUserAffiliateWithClient(ctx context.Context, client affiliateQueryExecer, userID int64) (*service.AffiliateSummary, error) {
summary, err := queryAffiliateByUserID(ctx, client, userID)
if err == nil {
return summary, nil
}
if !errors.Is(err, service.ErrAffiliateProfileNotFound) {
return nil, err
}
for i := 0; i < affiliateCodeMaxAttempts; i++ {
code, codeErr := generateAffiliateCode()
if codeErr != nil {
return nil, codeErr
}
_, insertErr := client.ExecContext(ctx, `
INSERT INTO user_affiliates (user_id, aff_code, created_at, updated_at)
VALUES ($1, $2, NOW(), NOW())
ON CONFLICT (user_id) DO NOTHING`, userID, code)
if insertErr == nil {
break
}
if isAffiliateUniqueViolation(insertErr) {
continue
}
return nil, insertErr
}
return queryAffiliateByUserID(ctx, client, userID)
}
func queryAffiliateByUserID(ctx context.Context, client affiliateQueryExecer, userID int64) (*service.AffiliateSummary, error) {
rows, err := client.QueryContext(ctx, `
SELECT user_id,
aff_code,
inviter_id,
aff_count,
aff_quota::double precision,
aff_history_quota::double precision,
created_at,
updated_at
FROM user_affiliates
WHERE user_id = $1`, userID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
if err := rows.Err(); err != nil {
return nil, err
}
return nil, service.ErrAffiliateProfileNotFound
}
var out service.AffiliateSummary
var inviterID sql.NullInt64
if err := rows.Scan(
&out.UserID,
&out.AffCode,
&inviterID,
&out.AffCount,
&out.AffQuota,
&out.AffHistoryQuota,
&out.CreatedAt,
&out.UpdatedAt,
); err != nil {
return nil, err
}
if inviterID.Valid {
out.InviterID = &inviterID.Int64
}
return &out, nil
}
func queryAffiliateByCode(ctx context.Context, client affiliateQueryExecer, code string) (*service.AffiliateSummary, error) {
rows, err := client.QueryContext(ctx, `
SELECT user_id,
aff_code,
inviter_id,
aff_count,
aff_quota::double precision,
aff_history_quota::double precision,
created_at,
updated_at
FROM user_affiliates
WHERE aff_code = $1
LIMIT 1`, strings.ToUpper(strings.TrimSpace(code)))
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
if err := rows.Err(); err != nil {
return nil, err
}
return nil, service.ErrAffiliateProfileNotFound
}
var out service.AffiliateSummary
var inviterID sql.NullInt64
if err := rows.Scan(
&out.UserID,
&out.AffCode,
&inviterID,
&out.AffCount,
&out.AffQuota,
&out.AffHistoryQuota,
&out.CreatedAt,
&out.UpdatedAt,
); err != nil {
return nil, err
}
if inviterID.Valid {
out.InviterID = &inviterID.Int64
}
return &out, nil
}
func queryUserBalance(ctx context.Context, client affiliateQueryExecer, userID int64) (float64, error) {
rows, err := client.QueryContext(ctx,
"SELECT balance::double precision FROM users WHERE id = $1 LIMIT 1",
userID,
)
if err != nil {
return 0, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
if err := rows.Err(); err != nil {
return 0, err
}
return 0, service.ErrUserNotFound
}
var balance float64
if err := rows.Scan(&balance); err != nil {
return 0, err
}
return balance, nil
}
func generateAffiliateCode() (string, error) {
buf := make([]byte, affiliateCodeLength)
if _, err := rand.Read(buf); err != nil {
return "", fmt.Errorf("generate affiliate code: %w", err)
}
for i := range buf {
buf[i] = affiliateCodeCharset[int(buf[i])%len(affiliateCodeCharset)]
}
return string(buf), nil
}
func isAffiliateUniqueViolation(err error) bool {
var pqErr *pq.Error
if errors.As(err, &pqErr) {
return string(pqErr.Code) == "23505"
}
return false
}

View File

@@ -0,0 +1,184 @@
//go:build integration
package repository
import (
"context"
"fmt"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func querySingleFloat(t *testing.T, ctx context.Context, client *dbent.Client, query string, args ...any) float64 {
t.Helper()
rows, err := client.QueryContext(ctx, query, args...)
require.NoError(t, err)
defer func() { _ = rows.Close() }()
require.True(t, rows.Next(), "expected one row")
var value float64
require.NoError(t, rows.Scan(&value))
require.NoError(t, rows.Err())
return value
}
func querySingleInt(t *testing.T, ctx context.Context, client *dbent.Client, query string, args ...any) int {
t.Helper()
rows, err := client.QueryContext(ctx, query, args...)
require.NoError(t, err)
defer func() { _ = rows.Close() }()
require.True(t, rows.Next(), "expected one row")
var value int
require.NoError(t, rows.Scan(&value))
require.NoError(t, rows.Err())
return value
}
func TestAffiliateRepository_TransferQuotaToBalance_UsesClaimedQuotaBeforeClear(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
txCtx := dbent.NewTxContext(ctx, tx)
client := tx.Client()
repo := NewAffiliateRepository(client, integrationDB)
u := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-transfer-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 5.5,
Concurrency: 5,
})
affCode := fmt.Sprintf("AFF%09d", time.Now().UnixNano()%1_000_000_000)
_, err := client.ExecContext(txCtx, `
INSERT INTO user_affiliates (user_id, aff_code, aff_quota, aff_history_quota, created_at, updated_at)
VALUES ($1, $2, $3, $3, NOW(), NOW())`, u.ID, affCode, 12.34)
require.NoError(t, err)
transferred, balance, err := repo.TransferQuotaToBalance(txCtx, u.ID)
require.NoError(t, err)
require.InDelta(t, 12.34, transferred, 1e-9)
require.InDelta(t, 17.84, balance, 1e-9)
affQuota := querySingleFloat(t, txCtx, client,
"SELECT aff_quota::double precision FROM user_affiliates WHERE user_id = $1", u.ID)
require.InDelta(t, 0.0, affQuota, 1e-9)
persistedBalance := querySingleFloat(t, txCtx, client,
"SELECT balance::double precision FROM users WHERE id = $1", u.ID)
require.InDelta(t, 17.84, persistedBalance, 1e-9)
ledgerCount := querySingleInt(t, txCtx, client,
"SELECT COUNT(*) FROM user_affiliate_ledger WHERE user_id = $1 AND action = 'transfer'", u.ID)
require.Equal(t, 1, ledgerCount)
}
// TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction guards the
// cross-layer tx propagation invariant: when AccrueQuota is called with a ctx
// that already carries a transaction (via dbent.NewTxContext), repo.withTx
// must reuse that tx rather than opening a nested one. If this invariant
// breaks, AccrueQuota would commit independently and survive a rollback of
// the outer tx, which would violate payment_fulfillment's all-or-nothing
// semantics.
func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) {
ctx := context.Background()
outerTx, err := integrationEntClient.Tx(ctx)
require.NoError(t, err, "begin outer tx")
// Defensive cleanup: if any require.* below fires before the explicit
// Rollback, this prevents the tx from leaking until container teardown.
// Rollback is idempotent at the driver level (extra rollback returns an
// error we ignore).
t.Cleanup(func() { _ = outerTx.Rollback() })
client := outerTx.Client()
txCtx := dbent.NewTxContext(ctx, outerTx)
inviter := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-inviter-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
})
invitee := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-invitee-%d@example.com", time.Now().UnixNano()+1),
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
})
repo := NewAffiliateRepository(client, integrationDB)
_, err = repo.EnsureUserAffiliate(txCtx, inviter.ID)
require.NoError(t, err)
_, err = repo.EnsureUserAffiliate(txCtx, invitee.ID)
require.NoError(t, err)
bound, err := repo.BindInviter(txCtx, invitee.ID, inviter.ID)
require.NoError(t, err)
require.True(t, bound, "invitee must bind to inviter")
applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5)
require.NoError(t, err)
require.True(t, applied, "AccrueQuota must report applied=true")
// Visible inside the outer tx.
innerQuota := querySingleFloat(t, txCtx, client,
"SELECT aff_quota::double precision FROM user_affiliates WHERE user_id = $1", inviter.ID)
require.InDelta(t, 3.5, innerQuota, 1e-9)
// Roll back the outer tx; if AccrueQuota had opened its own inner tx and
// committed it, the rows would still be visible to the global client.
require.NoError(t, outerTx.Rollback())
rows, err := integrationEntClient.QueryContext(ctx,
"SELECT COUNT(*) FROM user_affiliates WHERE user_id IN ($1, $2)",
inviter.ID, invitee.ID)
require.NoError(t, err)
defer func() { _ = rows.Close() }()
require.True(t, rows.Next())
var postRollbackCount int
require.NoError(t, rows.Scan(&postRollbackCount))
require.Equal(t, 0, postRollbackCount,
"AccrueQuota must propagate the outer tx — found persisted rows after rollback")
}
func TestAffiliateRepository_TransferQuotaToBalance_EmptyQuota(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
txCtx := dbent.NewTxContext(ctx, tx)
client := tx.Client()
repo := NewAffiliateRepository(client, integrationDB)
u := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-empty-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 3.21,
Concurrency: 5,
})
affCode := fmt.Sprintf("AFF%09d", time.Now().UnixNano()%1_000_000_000)
_, err := client.ExecContext(txCtx, `
INSERT INTO user_affiliates (user_id, aff_code, aff_quota, aff_history_quota, created_at, updated_at)
VALUES ($1, $2, 0, 0, NOW(), NOW())`, u.ID, affCode)
require.NoError(t, err)
transferred, balance, err := repo.TransferQuotaToBalance(txCtx, u.ID)
require.ErrorIs(t, err, service.ErrAffiliateQuotaEmpty)
require.InDelta(t, 0.0, transferred, 1e-9)
require.InDelta(t, 0.0, balance, 1e-9)
persistedBalance := querySingleFloat(t, txCtx, client,
"SELECT balance::double precision FROM users WHERE id = $1", u.ID)
require.InDelta(t, 3.21, persistedBalance, 1e-9)
}

View File

@@ -91,6 +91,7 @@ var ProviderSet = wire.NewSet(
NewChannelRepository,
NewChannelMonitorRepository,
NewChannelMonitorRequestTemplateRepository,
NewAffiliateRepository,
// Cache implementations
NewGatewayCache,

View File

@@ -715,6 +715,7 @@ func TestAPIContracts(t *testing.T) {
"force_email_on_third_party_signup": false,
"default_concurrency": 5,
"default_balance": 1.25,
"affiliate_rebate_rate": 20,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
@@ -895,6 +896,7 @@ func TestAPIContracts(t *testing.T) {
"custom_endpoints": [],
"default_concurrency": 0,
"default_balance": 0,
"affiliate_rebate_rate": 20,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,

View File

@@ -20,7 +20,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
admin := &service.User{
ID: 1,

View File

@@ -60,7 +60,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
cfg.JWT.AccessTokenExpireMinutes = 60
userRepo := &stubJWTUserRepo{users: users}
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
userSvc := service.NewUserService(userRepo, nil, nil, nil)
mw := NewJWTAuthMiddleware(authSvc, userSvc)
@@ -143,7 +143,7 @@ func TestJWTAuth_ValidToken_TouchesLastActive(t *testing.T) {
cfg.JWT.AccessTokenExpireMinutes = 60
userRepo := &stubJWTUserRepo{users: map[int64]*service.User{1: user}}
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
userSvc := service.NewUserService(userRepo, nil, nil, nil)
toucher := &recordingActivityToucher{}

View File

@@ -25,6 +25,8 @@ func RegisterUserRoutes(
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
user.GET("/aff", h.User.GetAffiliate)
user.POST("/aff/transfer", h.User.TransferAffiliateQuota)
user.POST("/account-bindings/email/send-code", h.User.SendEmailBindingCode)
user.POST("/account-bindings/email", h.User.BindEmailIdentity)
user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity)

View File

@@ -393,6 +393,56 @@ func parseTempUnschedInt(value any) int {
return 0
}
const (
// OpenAICompactModeAuto follows compact-probe results when deciding compact eligibility.
OpenAICompactModeAuto = "auto"
// OpenAICompactModeForceOn always treats the account as compact-supported.
OpenAICompactModeForceOn = "force_on"
// OpenAICompactModeForceOff always treats the account as compact-unsupported.
OpenAICompactModeForceOff = "force_off"
)
func normalizeOpenAICompactMode(mode string) string {
switch strings.ToLower(strings.TrimSpace(mode)) {
case OpenAICompactModeForceOn:
return OpenAICompactModeForceOn
case OpenAICompactModeForceOff:
return OpenAICompactModeForceOff
default:
return OpenAICompactModeAuto
}
}
func stringMappingFromRaw(raw any) map[string]string {
switch mapping := raw.(type) {
case map[string]any:
if len(mapping) == 0 {
return nil
}
result := make(map[string]string, len(mapping))
for key, value := range mapping {
if str, ok := value.(string); ok {
result[key] = str
}
}
if len(result) == 0 {
return nil
}
return result
case map[string]string:
if len(mapping) == 0 {
return nil
}
result := make(map[string]string, len(mapping))
for key, value := range mapping {
result[key] = value
}
return result
default:
return nil
}
}
func (a *Account) GetModelMapping() map[string]string {
credentialsPtr := mapPtr(a.Credentials)
rawMapping, _ := a.Credentials["model_mapping"].(map[string]any)
@@ -598,6 +648,77 @@ func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string,
return requestedModel, false
}
// GetOpenAICompactMode returns the compact routing mode for an OpenAI account.
// Missing or invalid values fall back to "auto".
func (a *Account) GetOpenAICompactMode() string {
if a == nil || !a.IsOpenAI() || a.Extra == nil {
return OpenAICompactModeAuto
}
mode, _ := a.Extra["openai_compact_mode"].(string)
return normalizeOpenAICompactMode(mode)
}
// OpenAICompactSupportKnown reports whether compact capability is known for this
// account and, when known, whether it is supported.
func (a *Account) OpenAICompactSupportKnown() (supported bool, known bool) {
if a == nil || !a.IsOpenAI() {
return false, false
}
switch a.GetOpenAICompactMode() {
case OpenAICompactModeForceOn:
return true, true
case OpenAICompactModeForceOff:
return false, true
}
if a.Extra == nil {
return false, false
}
supported, ok := a.Extra["openai_compact_supported"].(bool)
if !ok {
return false, false
}
return supported, true
}
// AllowsOpenAICompact reports whether the account may be considered for compact
// requests. Unknown capability remains allowed to avoid breaking older accounts
// before an explicit probe has been run.
func (a *Account) AllowsOpenAICompact() bool {
if a == nil || !a.IsOpenAI() {
return false
}
supported, known := a.OpenAICompactSupportKnown()
if !known {
return true
}
return supported
}
// GetCompactModelMapping returns compact-only model remapping configuration.
// This mapping is intended for /responses/compact only and does not affect
// normal /responses traffic.
func (a *Account) GetCompactModelMapping() map[string]string {
if a == nil || a.Credentials == nil {
return nil
}
return stringMappingFromRaw(a.Credentials["compact_model_mapping"])
}
// ResolveCompactMappedModel resolves compact-only model remapping and reports
// whether a compact-specific mapping rule matched.
func (a *Account) ResolveCompactMappedModel(requestedModel string) (mappedModel string, matched bool) {
mapping := a.GetCompactModelMapping()
if len(mapping) == 0 {
return requestedModel, false
}
if mappedModel, matched := resolveRequestedModelInMapping(mapping, requestedModel); matched {
return mappedModel, true
}
return requestedModel, false
}
func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeAPIKey {
return ""

View File

@@ -0,0 +1,369 @@
package service
import "testing"
func TestAccountGetOpenAICompactMode(t *testing.T) {
tests := []struct {
name string
account *Account
want string
}{
{
name: "nil account defaults to auto",
want: OpenAICompactModeAuto,
},
{
name: "non openai account defaults to auto",
account: &Account{
Platform: PlatformAnthropic,
Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOn},
},
want: OpenAICompactModeAuto,
},
{
name: "missing extra defaults to auto",
account: &Account{
Platform: PlatformOpenAI,
},
want: OpenAICompactModeAuto,
},
{
name: "invalid mode falls back to auto",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_mode": " invalid "},
},
want: OpenAICompactModeAuto,
},
{
name: "force on is normalized",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_mode": " FORCE_ON "},
},
want: OpenAICompactModeForceOn,
},
{
name: "force off is normalized",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_mode": "force_off"},
},
want: OpenAICompactModeForceOff,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.account.GetOpenAICompactMode(); got != tt.want {
t.Fatalf("GetOpenAICompactMode() = %q, want %q", got, tt.want)
}
})
}
}
func TestAccountOpenAICompactSupportKnown(t *testing.T) {
tests := []struct {
name string
account *Account
wantSupported bool
wantKnown bool
}{
{
name: "nil account is unknown",
wantSupported: false,
wantKnown: false,
},
{
name: "non openai account is unknown",
account: &Account{
Platform: PlatformAnthropic,
Extra: map[string]any{"openai_compact_supported": true},
},
wantSupported: false,
wantKnown: false,
},
{
name: "force on overrides probe state",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{
"openai_compact_mode": OpenAICompactModeForceOn,
"openai_compact_supported": false,
},
},
wantSupported: true,
wantKnown: true,
},
{
name: "force off overrides probe state",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{
"openai_compact_mode": OpenAICompactModeForceOff,
"openai_compact_supported": true,
},
},
wantSupported: false,
wantKnown: true,
},
{
name: "auto true is known supported",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_supported": true},
},
wantSupported: true,
wantKnown: true,
},
{
name: "auto false is known unsupported",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_supported": false},
},
wantSupported: false,
wantKnown: true,
},
{
name: "auto without probe state remains unknown",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{},
},
wantSupported: false,
wantKnown: false,
},
{
name: "invalid probe field remains unknown",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_supported": "true"},
},
wantSupported: false,
wantKnown: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotSupported, gotKnown := tt.account.OpenAICompactSupportKnown()
if gotSupported != tt.wantSupported || gotKnown != tt.wantKnown {
t.Fatalf("OpenAICompactSupportKnown() = (%v, %v), want (%v, %v)", gotSupported, gotKnown, tt.wantSupported, tt.wantKnown)
}
})
}
}
func TestAccountAllowsOpenAICompact(t *testing.T) {
tests := []struct {
name string
account *Account
want bool
}{
{
name: "nil account does not allow compact",
want: false,
},
{
name: "non openai account does not allow compact",
account: &Account{
Platform: PlatformAnthropic,
},
want: false,
},
{
name: "unknown openai account remains allowed",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{},
},
want: true,
},
{
name: "supported openai account is allowed",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_supported": true},
},
want: true,
},
{
name: "unsupported openai account is rejected",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_supported": false},
},
want: false,
},
{
name: "force on is allowed",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOn},
},
want: true,
},
{
name: "force off is rejected",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOff},
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.account.AllowsOpenAICompact(); got != tt.want {
t.Fatalf("AllowsOpenAICompact() = %v, want %v", got, tt.want)
}
})
}
}
func TestAccountGetCompactModelMapping(t *testing.T) {
tests := []struct {
name string
account *Account
want map[string]string
}{
{
name: "nil account returns nil",
want: nil,
},
{
name: "missing credentials returns nil",
account: &Account{
Platform: PlatformOpenAI,
},
want: nil,
},
{
name: "map any is converted",
account: &Account{
Credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4-openai-compact",
"invalid": 1,
},
},
},
want: map[string]string{
"gpt-5.4": "gpt-5.4-openai-compact",
},
},
{
name: "map string string is copied",
account: &Account{
Credentials: map[string]any{
"compact_model_mapping": map[string]string{
"gpt-*": "compact-*",
},
},
},
want: map[string]string{
"gpt-*": "compact-*",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.account.GetCompactModelMapping()
if !equalStringMap(got, tt.want) {
t.Fatalf("GetCompactModelMapping() = %#v, want %#v", got, tt.want)
}
})
}
}
func TestAccountResolveCompactMappedModel(t *testing.T) {
tests := []struct {
name string
credentials map[string]any
requestedModel string
expectedModel string
expectedMatch bool
}{
{
name: "no compact mapping reports unmatched",
credentials: nil,
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: false,
},
{
name: "exact compact mapping matches",
credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4-openai-compact",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4-openai-compact",
expectedMatch: true,
},
{
name: "exact passthrough counts as match",
credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: true,
},
{
name: "longest wildcard wins",
credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-*": "fallback-compact",
"gpt-5.4*": "gpt-5.4-openai-compact",
"gpt-5.4-mini*": "gpt-5.4-mini-openai-compact",
},
},
requestedModel: "gpt-5.4-mini",
expectedModel: "gpt-5.4-mini-openai-compact",
expectedMatch: true,
},
{
name: "missing compact mapping reports unmatched",
credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-5.3": "gpt-5.3-openai-compact",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Credentials: tt.credentials,
}
gotModel, gotMatch := account.ResolveCompactMappedModel(tt.requestedModel)
if gotModel != tt.expectedModel || gotMatch != tt.expectedMatch {
t.Fatalf("ResolveCompactMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, gotModel, gotMatch, tt.expectedModel, tt.expectedMatch)
}
})
}
}
func equalStringMap(left, right map[string]string) bool {
if len(left) != len(right) {
return false
}
for key, want := range right {
if got, ok := left[key]; !ok || got != want {
return false
}
}
return true
}

View File

@@ -165,7 +165,8 @@ func createTestPayload(modelID string) (map[string]any, error) {
// TestAccountConnection tests an account's connection by sending a test request
// All account types use full Claude Code client characteristics, only auth header differs
// modelID is optional - if empty, defaults to claude.DefaultTestModel
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string) error {
// mode is optional - "compact" routes OpenAI accounts to the /responses/compact probe path
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string, mode string) error {
ctx := c.Request.Context()
// Get account
@@ -176,7 +177,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
// Route to platform-specific test method
if account.IsOpenAI() {
return s.testOpenAIAccountConnection(c, account, modelID, prompt)
return s.testOpenAIAccountConnection(c, account, modelID, prompt, normalizeAccountTestMode(mode))
}
if account.IsGemini() {
@@ -416,9 +417,10 @@ func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx co
}
// testOpenAIAccountConnection tests an OpenAI account's connection
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string, mode string) error {
ctx := c.Request.Context()
_ = prompt
mode = normalizeAccountTestMode(mode)
// Default to openai.DefaultTestModel for OpenAI testing
testModelID := modelID
@@ -426,14 +428,12 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
testModelID = openai.DefaultTestModel
}
// For API Key accounts with model mapping, map the model
if account.Type == "apikey" {
mapping := account.GetModelMapping()
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
testModelID = mappedModel
}
}
// Align test routing with gateway behavior: OpenAI accounts apply normal
// account model mapping, and compact mode applies compact-only mapping on top.
testModelID = account.GetMappedModel(testModelID)
if mode == AccountTestModeCompact {
testModelID = resolveOpenAICompactForwardModel(account, testModelID)
return s.testOpenAICompactConnection(c, account, testModelID)
}
// Route to image generation test if an image model is selected
@@ -538,6 +538,9 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode == http.StatusTooManyRequests {
s.reconcileOpenAI429State(ctx, account, resp.Header, body)
}
// 401 Unauthorized: 标记账号为永久错误
if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil {
errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body))
@@ -550,6 +553,154 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
return s.processOpenAIStream(c, resp.Body)
}
// testOpenAICompactConnection probes /responses/compact and persists the
// resulting capability state on the account.
func (s *AccountTestService) testOpenAICompactConnection(c *gin.Context, account *Account, testModelID string) error {
ctx := c.Request.Context()
authToken := ""
apiURL := ""
isOAuth := false
chatgptAccountID := ""
switch {
case account.IsOAuth():
isOAuth = true
authToken = account.GetOpenAIAccessToken()
if authToken == "" {
return s.sendErrorAndEnd(c, "No access token available")
}
apiURL = chatgptCodexAPIURL + "/compact"
chatgptAccountID = account.GetChatGPTAccountID()
case account.Type == AccountTypeAPIKey:
authToken = account.GetOpenAIApiKey()
if authToken == "" {
return s.sendErrorAndEnd(c, "No API key available")
}
baseURL := account.GetOpenAIBaseURL()
if baseURL == "" {
baseURL = "https://api.openai.com"
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
apiURL = appendOpenAIResponsesRequestPathSuffix(buildOpenAIResponsesURL(normalizedBaseURL), "/compact")
default:
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
payloadBytes, _ := json.Marshal(createOpenAICompactProbePayload(testModelID))
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
if err != nil {
return s.sendErrorAndEnd(c, "Failed to create request")
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("OpenAI-Beta", "responses=experimental")
req.Header.Set("Originator", "codex_cli_rs")
req.Header.Set("User-Agent", codexCLIUserAgent)
req.Header.Set("Version", codexCLIVersion)
probeSessionID := compactProbeSessionID(account.ID)
req.Header.Set("Session_ID", probeSessionID)
req.Header.Set("Conversation_ID", probeSessionID)
if isOAuth {
req.Host = "chatgpt.com"
if chatgptAccountID != "" {
req.Header.Set("chatgpt-account-id", chatgptAccountID)
}
}
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if err != nil {
if s.accountRepo != nil {
updates := buildOpenAICompactProbeExtraUpdates(nil, nil, err, time.Now())
_ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
mergeAccountExtra(account, updates)
}
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if s.accountRepo != nil {
updates := buildOpenAICompactProbeExtraUpdates(resp, body, nil, time.Now())
if codexUpdates, err := extractOpenAICodexProbeUpdates(resp); err == nil && len(codexUpdates) > 0 {
updates = mergeExtraUpdates(updates, codexUpdates)
}
if len(updates) > 0 {
_ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
mergeAccountExtra(account, updates)
}
// 探测如返回 429,主动同步限流状态,避免后续短时间内继续选中。
if resp.StatusCode == http.StatusTooManyRequests {
s.reconcileOpenAI429State(ctx, account, resp.Header, body)
}
}
if resp.StatusCode != http.StatusOK {
if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil {
errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body))
_ = s.accountRepo.SetError(ctx, account.ID, errMsg)
}
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
}
s.sendEvent(c, TestEvent{Type: "content", Text: "Compact probe succeeded"})
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
func (s *AccountTestService) reconcileOpenAI429State(ctx context.Context, account *Account, headers http.Header, body []byte) {
if s == nil || s.accountRepo == nil || account == nil {
return
}
var resetAt *time.Time
if calculated := calculateOpenAI429ResetTime(headers); calculated != nil {
resetAt = calculated
} else if unixTs := parseOpenAIRateLimitResetTime(body); unixTs != nil {
t := time.Unix(*unixTs, 0)
resetAt = &t
}
if resetAt == nil {
return
}
if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
return
}
now := time.Now()
account.RateLimitedAt = &now
account.RateLimitResetAt = resetAt
if account.Status == StatusError {
if err := s.accountRepo.ClearError(ctx, account.ID); err != nil {
return
}
account.Status = StatusActive
account.ErrorMessage = ""
}
}
// testGeminiAccountConnection tests a Gemini account's connection
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
ctx := c.Request.Context()
@@ -1261,7 +1412,7 @@ func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID in
ginCtx, _ := gin.CreateTestContext(w)
ginCtx.Request = (&http.Request{}).WithContext(ctx)
testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "")
testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "", AccountTestModeDefault)
finishedAt := time.Now()
body := w.Body.String()

View File

@@ -0,0 +1,199 @@
package service
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestAccountTestService_TestAccountConnection_OpenAICompactOAuthSuccessPersistsSupport(t *testing.T) {
gin.SetMode(gin.TestMode)
updateCalls := make(chan map[string]any, 1)
account := Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
repo := &snapshotUpdateAccountRepo{
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
updateExtraCalls: updateCalls,
}
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-probe"}},
Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe","status":"completed"}`)),
}}
svc := &AccountTestService{
accountRepo: repo,
httpUpstream: upstream,
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", bytes.NewReader(nil))
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
require.NoError(t, err)
require.Equal(t, chatgptCodexAPIURL+"/compact", upstream.lastReq.URL.String())
require.Equal(t, "chatgpt.com", upstream.lastReq.Host)
require.Equal(t, "application/json", upstream.lastReq.Header.Get("Accept"))
require.Equal(t, codexCLIVersion, upstream.lastReq.Header.Get("Version"))
require.NotEmpty(t, upstream.lastReq.Header.Get("Session_Id"))
require.Equal(t, codexCLIUserAgent, upstream.lastReq.Header.Get("User-Agent"))
require.Equal(t, "chatgpt-acc", upstream.lastReq.Header.Get("chatgpt-account-id"))
require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String())
updates := <-updateCalls
require.Equal(t, true, updates["openai_compact_supported"])
require.Equal(t, http.StatusOK, updates["openai_compact_last_status"])
require.Contains(t, rec.Body.String(), `"type":"test_complete"`)
}
func TestAccountTestService_TestAccountConnection_OpenAICompactOAuth404MarksUnsupported(t *testing.T) {
gin.SetMode(gin.TestMode)
updateCalls := make(chan map[string]any, 1)
account := Account{
ID: 2,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
repo := &snapshotUpdateAccountRepo{
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
updateExtraCalls: updateCalls,
}
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusNotFound,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`404 page not found`)),
}}
svc := &AccountTestService{
accountRepo: repo,
httpUpstream: upstream,
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/2/test", bytes.NewReader(nil))
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
require.Error(t, err)
updates := <-updateCalls
require.Equal(t, false, updates["openai_compact_supported"])
require.Equal(t, http.StatusNotFound, updates["openai_compact_last_status"])
require.Contains(t, rec.Body.String(), `"type":"error"`)
}
func TestAccountTestService_TestAccountConnection_OpenAICompactAPIKeyUsesCompactPath(t *testing.T) {
gin.SetMode(gin.TestMode)
updateCalls := make(chan map[string]any, 1)
account := Account{
ID: 3,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": "https://example.com/v1",
"compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"},
},
}
repo := &snapshotUpdateAccountRepo{
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
updateExtraCalls: updateCalls,
}
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe_apikey","status":"completed"}`)),
}}
svc := &AccountTestService{
accountRepo: repo,
httpUpstream: upstream,
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/3/test", bytes.NewReader(nil))
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
require.NoError(t, err)
require.Equal(t, "https://example.com/v1/responses/compact", upstream.lastReq.URL.String())
require.Equal(t, "gpt-5.4-openai-compact", gjson.GetBytes(upstream.lastBody, "model").String())
updates := <-updateCalls
require.Equal(t, true, updates["openai_compact_supported"])
}
func TestAccountTestService_TestAccountConnection_OpenAICompactAPIKeyDefaultBaseURLUsesV1Path(t *testing.T) {
gin.SetMode(gin.TestMode)
updateCalls := make(chan map[string]any, 1)
account := Account{
ID: 4,
Name: "openai-apikey-default",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
},
}
repo := &snapshotUpdateAccountRepo{
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
updateExtraCalls: updateCalls,
}
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe_apikey_default","status":"completed"}`)),
}}
svc := &AccountTestService{
accountRepo: repo,
httpUpstream: upstream,
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/4/test", bytes.NewReader(nil))
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
require.NoError(t, err)
require.Equal(t, "https://api.openai.com/v1/responses/compact", upstream.lastReq.URL.String())
<-updateCalls
}

View File

@@ -61,9 +61,12 @@ func newTestContext() (*gin.Context, *httptest.ResponseRecorder) {
type openAIAccountTestRepo struct {
mockAccountRepoForGemini
updatedExtra map[string]any
rateLimitedID int64
rateLimitedAt *time.Time
updatedExtra map[string]any
rateLimitedID int64
rateLimitedAt *time.Time
clearedErrorID int64
setErrorID int64
setErrorMsg string
}
func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
@@ -77,6 +80,17 @@ func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, rese
return nil
}
func (r *openAIAccountTestRepo) ClearError(_ context.Context, id int64) error {
r.clearedErrorID = id
return nil
}
func (r *openAIAccountTestRepo) SetError(_ context.Context, id int64, errorMsg string) error {
r.setErrorID = id
r.setErrorMsg = errorMsg
return nil
}
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, recorder := newTestContext()
@@ -103,7 +117,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "")
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.NoError(t, err)
require.NotEmpty(t, repo.updatedExtra)
require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"])
@@ -111,11 +125,11 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
require.Contains(t, recorder.Body.String(), "test_complete")
}
func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing.T) {
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimitState(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_at":1777283883}}`)
resp.Header.Set("x-codex-primary-used-percent", "100")
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
resp.Header.Set("x-codex-primary-window-minutes", "10080")
@@ -130,15 +144,132 @@ func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing
ID: 88,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusError,
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "")
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.NotEmpty(t, repo.updatedExtra)
require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"])
require.Equal(t, account.ID, repo.rateLimitedID)
require.NotNil(t, repo.rateLimitedAt)
require.Equal(t, account.ID, repo.clearedErrorID)
require.Equal(t, StatusActive, account.Status)
require.Empty(t, account.ErrorMessage)
require.NotNil(t, account.RateLimitResetAt)
}
func TestAccountTestService_OpenAI429BodyOnlyPersistsRateLimitAndClearsStaleError(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_at":"1777283883"}}`)
repo := &openAIAccountTestRepo{}
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
account := &Account{
ID: 77,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusError,
ErrorMessage: "Access forbidden (403): account may be suspended or lack permissions",
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.Equal(t, account.ID, repo.rateLimitedID)
require.NotNil(t, repo.rateLimitedAt)
require.Equal(t, account.ID, repo.clearedErrorID)
require.Equal(t, StatusActive, account.Status)
require.Empty(t, account.ErrorMessage)
require.NotNil(t, account.RateLimitResetAt)
require.Empty(t, repo.updatedExtra)
}
func TestAccountTestService_OpenAI429ActiveAccountDoesNotClearError(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_in_seconds":3600}}`)
repo := &openAIAccountTestRepo{}
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
account := &Account{
ID: 78,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.Equal(t, account.ID, repo.rateLimitedID)
require.NotNil(t, repo.rateLimitedAt)
require.Zero(t, repo.clearedErrorID)
require.Equal(t, StatusActive, account.Status)
require.NotNil(t, account.RateLimitResetAt)
}
func TestAccountTestService_OpenAI429WithoutResetSignalDoesNotMutateRuntimeState(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
repo := &openAIAccountTestRepo{}
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
account := &Account{
ID: 79,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusError,
ErrorMessage: "stale 403",
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.Zero(t, repo.rateLimitedID)
require.Nil(t, repo.rateLimitedAt)
require.Zero(t, repo.clearedErrorID)
require.Equal(t, StatusError, account.Status)
require.Equal(t, "stale 403", account.ErrorMessage)
require.Nil(t, account.RateLimitResetAt)
}
func TestAccountTestService_OpenAI401SetsPermanentErrorOnly(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusUnauthorized, `{"error":"bad token"}`)
repo := &openAIAccountTestRepo{}
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
account := &Account{
ID: 80,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.Equal(t, account.ID, repo.setErrorID)
require.Contains(t, repo.setErrorMsg, "Authentication failed (401)")
require.Zero(t, repo.rateLimitedID)
require.Zero(t, repo.clearedErrorID)
require.Nil(t, account.RateLimitResetAt)
}

View File

@@ -110,7 +110,7 @@ const (
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
windowStatsCacheTTL = 1 * time.Minute
openAIProbeCacheTTL = 10 * time.Minute
openAICodexProbeVersion = "0.104.0"
openAICodexProbeVersion = "0.125.0"
)
// UsageCache 封装账户使用量相关的缓存

View File

@@ -0,0 +1,314 @@
package service
import (
"context"
"errors"
"math"
"strconv"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
var (
ErrAffiliateProfileNotFound = infraerrors.NotFound("AFFILIATE_PROFILE_NOT_FOUND", "affiliate profile not found")
ErrAffiliateCodeInvalid = infraerrors.BadRequest("AFFILIATE_CODE_INVALID", "invalid affiliate code")
ErrAffiliateAlreadyBound = infraerrors.Conflict("AFFILIATE_ALREADY_BOUND", "affiliate inviter already bound")
ErrAffiliateQuotaEmpty = infraerrors.BadRequest("AFFILIATE_QUOTA_EMPTY", "no affiliate quota available to transfer")
)
const (
affiliateInviteesLimit = 100
// affiliateCodeFormatLength must stay in sync with repository.affiliateCodeLength.
affiliateCodeFormatLength = 12
)
// affiliateCodeValidChar is a 256-entry lookup table mirroring the charset used
// by the repository's generateAffiliateCode (A-Z minus I/O, digits 2-9).
var affiliateCodeValidChar = func() [256]bool {
var tbl [256]bool
for _, c := range []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789") {
tbl[c] = true
}
return tbl
}()
func isValidAffiliateCodeFormat(code string) bool {
if len(code) != affiliateCodeFormatLength {
return false
}
for i := 0; i < len(code); i++ {
if !affiliateCodeValidChar[code[i]] {
return false
}
}
return true
}
type AffiliateSummary struct {
UserID int64 `json:"user_id"`
AffCode string `json:"aff_code"`
InviterID *int64 `json:"inviter_id,omitempty"`
AffCount int `json:"aff_count"`
AffQuota float64 `json:"aff_quota"`
AffHistoryQuota float64 `json:"aff_history_quota"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type AffiliateInvitee struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
Username string `json:"username"`
CreatedAt *time.Time `json:"created_at,omitempty"`
}
type AffiliateDetail struct {
UserID int64 `json:"user_id"`
AffCode string `json:"aff_code"`
InviterID *int64 `json:"inviter_id,omitempty"`
AffCount int `json:"aff_count"`
AffQuota float64 `json:"aff_quota"`
AffHistoryQuota float64 `json:"aff_history_quota"`
Invitees []AffiliateInvitee `json:"invitees"`
}
type AffiliateRepository interface {
EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error)
GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error)
BindInviter(ctx context.Context, userID, inviterID int64) (bool, error)
AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error)
TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error)
ListInvitees(ctx context.Context, inviterID int64, limit int) ([]AffiliateInvitee, error)
}
type AffiliateService struct {
repo AffiliateRepository
settingRepo SettingRepository
authCacheInvalidator APIKeyAuthCacheInvalidator
billingCacheService *BillingCacheService
}
func NewAffiliateService(repo AffiliateRepository, settingRepo SettingRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCacheService *BillingCacheService) *AffiliateService {
return &AffiliateService{
repo: repo,
settingRepo: settingRepo,
authCacheInvalidator: authCacheInvalidator,
billingCacheService: billingCacheService,
}
}
func (s *AffiliateService) EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error) {
if userID <= 0 {
return nil, infraerrors.BadRequest("INVALID_USER", "invalid user")
}
if s == nil || s.repo == nil {
return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
return s.repo.EnsureUserAffiliate(ctx, userID)
}
func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64) (*AffiliateDetail, error) {
summary, err := s.EnsureUserAffiliate(ctx, userID)
if err != nil {
return nil, err
}
invitees, err := s.listInvitees(ctx, userID)
if err != nil {
return nil, err
}
return &AffiliateDetail{
UserID: summary.UserID,
AffCode: summary.AffCode,
InviterID: summary.InviterID,
AffCount: summary.AffCount,
AffQuota: summary.AffQuota,
AffHistoryQuota: summary.AffHistoryQuota,
Invitees: invitees,
}, nil
}
func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64, rawCode string) error {
code := strings.ToUpper(strings.TrimSpace(rawCode))
if code == "" {
return nil
}
if !isValidAffiliateCodeFormat(code) {
return ErrAffiliateCodeInvalid
}
if s == nil || s.repo == nil {
return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
selfSummary, err := s.repo.EnsureUserAffiliate(ctx, userID)
if err != nil {
return err
}
if selfSummary.InviterID != nil {
return nil
}
inviterSummary, err := s.repo.GetAffiliateByCode(ctx, code)
if err != nil {
if errors.Is(err, ErrAffiliateProfileNotFound) {
return ErrAffiliateCodeInvalid
}
return err
}
if inviterSummary == nil || inviterSummary.UserID <= 0 || inviterSummary.UserID == userID {
return ErrAffiliateCodeInvalid
}
bound, err := s.repo.BindInviter(ctx, userID, inviterSummary.UserID)
if err != nil {
return err
}
if !bound {
return ErrAffiliateAlreadyBound
}
return nil
}
func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64) (float64, error) {
if s == nil || s.repo == nil {
return 0, nil
}
if inviteeUserID <= 0 || baseRechargeAmount <= 0 || math.IsNaN(baseRechargeAmount) || math.IsInf(baseRechargeAmount, 0) {
return 0, nil
}
inviteeSummary, err := s.repo.EnsureUserAffiliate(ctx, inviteeUserID)
if err != nil {
return 0, err
}
if inviteeSummary.InviterID == nil || *inviteeSummary.InviterID <= 0 {
return 0, nil
}
rebateRatePercent := s.loadAffiliateRebateRatePercent(ctx)
rebate := roundTo(baseRechargeAmount*(rebateRatePercent/100), 8)
if rebate <= 0 {
return 0, nil
}
if _, err := s.repo.EnsureUserAffiliate(ctx, *inviteeSummary.InviterID); err != nil {
return 0, err
}
applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate)
if err != nil {
return 0, err
}
if !applied {
return 0, nil
}
return rebate, nil
}
func (s *AffiliateService) TransferAffiliateQuota(ctx context.Context, userID int64) (float64, float64, error) {
if s == nil || s.repo == nil {
return 0, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
transferred, balance, err := s.repo.TransferQuotaToBalance(ctx, userID)
if err != nil {
return 0, 0, err
}
if transferred > 0 {
s.invalidateAffiliateCaches(ctx, userID)
}
return transferred, balance, nil
}
func (s *AffiliateService) listInvitees(ctx context.Context, inviterID int64) ([]AffiliateInvitee, error) {
if s == nil || s.repo == nil {
return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
invitees, err := s.repo.ListInvitees(ctx, inviterID, affiliateInviteesLimit)
if err != nil {
return nil, err
}
for i := range invitees {
invitees[i].Email = maskEmail(invitees[i].Email)
}
return invitees, nil
}
func (s *AffiliateService) loadAffiliateRebateRatePercent(ctx context.Context) float64 {
if s == nil || s.settingRepo == nil {
return AffiliateRebateRateDefault
}
raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateRate)
if err != nil {
return AffiliateRebateRateDefault
}
rate, err := strconv.ParseFloat(strings.TrimSpace(raw), 64)
if err != nil {
return AffiliateRebateRateDefault
}
if math.IsNaN(rate) || math.IsInf(rate, 0) {
return AffiliateRebateRateDefault
}
if rate < AffiliateRebateRateMin {
return AffiliateRebateRateMin
}
if rate > AffiliateRebateRateMax {
return AffiliateRebateRateMax
}
return rate
}
func roundTo(v float64, scale int) float64 {
factor := math.Pow10(scale)
return math.Round(v*factor) / factor
}
func maskEmail(email string) string {
email = strings.TrimSpace(email)
if email == "" {
return ""
}
at := strings.Index(email, "@")
if at <= 0 || at >= len(email)-1 {
return "***"
}
local := email[:at]
domain := email[at+1:]
dot := strings.LastIndex(domain, ".")
maskedLocal := maskSegment(local)
if dot <= 0 || dot >= len(domain)-1 {
return maskedLocal + "@" + maskSegment(domain)
}
domainName := domain[:dot]
tld := domain[dot:]
return maskedLocal + "@" + maskSegment(domainName) + tld
}
func maskSegment(s string) string {
r := []rune(s)
if len(r) == 0 {
return "***"
}
if len(r) == 1 {
return string(r[0]) + "***"
}
return string(r[0]) + "***"
}
func (s *AffiliateService) invalidateAffiliateCaches(ctx context.Context, userID int64) {
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
if s.billingCacheService != nil {
if err := s.billingCacheService.InvalidateUserBalance(ctx, userID); err != nil {
logger.LegacyPrintf("service.affiliate", "[Affiliate] Failed to invalidate billing cache for user %d: %v", userID, err)
}
}
}

View File

@@ -0,0 +1,91 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
type affiliateSettingRepoStub struct {
value string
err error
}
func (s *affiliateSettingRepoStub) Get(context.Context, string) (*Setting, error) { return nil, s.err }
func (s *affiliateSettingRepoStub) GetValue(context.Context, string) (string, error) {
if s.err != nil {
return "", s.err
}
return s.value, nil
}
func (s *affiliateSettingRepoStub) Set(context.Context, string, string) error { return s.err }
func (s *affiliateSettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
if s.err != nil {
return nil, s.err
}
return map[string]string{}, nil
}
func (s *affiliateSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
return s.err
}
func (s *affiliateSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
if s.err != nil {
return nil, s.err
}
return map[string]string{}, nil
}
func (s *affiliateSettingRepoStub) Delete(context.Context, string) error { return s.err }
func TestAffiliateRebateRatePercentSemantics(t *testing.T) {
t.Parallel()
svc := &AffiliateService{settingRepo: &affiliateSettingRepoStub{value: "1"}}
rate := svc.loadAffiliateRebateRatePercent(context.Background())
require.Equal(t, 1.0, rate)
svc.settingRepo = &affiliateSettingRepoStub{value: "0.2"}
rate = svc.loadAffiliateRebateRatePercent(context.Background())
require.Equal(t, 0.2, rate)
}
func TestMaskEmail(t *testing.T) {
t.Parallel()
require.Equal(t, "a***@g***.com", maskEmail("alice@gmail.com"))
require.Equal(t, "x***@d***", maskEmail("x@domain"))
require.Equal(t, "", maskEmail(""))
}
func TestIsValidAffiliateCodeFormat(t *testing.T) {
t.Parallel()
cases := []struct {
name string
in string
want bool
}{
{"valid canonical", "ABCDEFGHJKLM", true},
{"valid all digits 2-9", "234567892345", true},
{"valid mixed", "A2B3C4D5E6F7", true},
{"too short", "ABCDEFGHJKL", false},
{"too long", "ABCDEFGHJKLMN", false},
{"contains excluded letter I", "IBCDEFGHJKLM", false},
{"contains excluded letter O", "OBCDEFGHJKLM", false},
{"contains excluded digit 0", "0BCDEFGHJKLM", false},
{"contains excluded digit 1", "1BCDEFGHJKLM", false},
{"lowercase rejected (caller must ToUpper first)", "abcdefghjklm", false},
{"empty", "", false},
{"12-byte utf8 non-ascii", "ÄÄÄÄÄÄ", false}, // 6×2 bytes = 12 bytes, bytes out of charset
{"ascii punctuation", "ABCDEFGHJK.M", false},
{"whitespace", "ABCDEFGHJK M", false},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tc.want, isValidAffiliateCodeFormat(tc.in))
})
}
}

View File

@@ -137,6 +137,7 @@ func newOAuthEmailFlowAuthService(
nil,
nil,
nil,
nil,
)
}

View File

@@ -72,6 +72,7 @@ type AuthService struct {
turnstileService *TurnstileService
emailQueueService *EmailQueueService
promoService *PromoService
affiliateService *AffiliateService
defaultSubAssigner DefaultSubscriptionAssigner
}
@@ -98,6 +99,7 @@ func NewAuthService(
emailQueueService *EmailQueueService,
promoService *PromoService,
defaultSubAssigner DefaultSubscriptionAssigner,
affiliateService *AffiliateService,
) *AuthService {
return &AuthService{
entClient: entClient,
@@ -110,6 +112,7 @@ func NewAuthService(
turnstileService: turnstileService,
emailQueueService: emailQueueService,
promoService: promoService,
affiliateService: affiliateService,
defaultSubAssigner: defaultSubAssigner,
}
}
@@ -123,11 +126,11 @@ func (s *AuthService) EntClient() *dbent.Client {
// Register 用户注册返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "", "", "")
return s.RegisterWithVerification(ctx, email, password, "", "", "", "")
}
// RegisterWithVerification 用户注册支持邮件验证、优惠码和邀请码返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string) (string, *User, error) {
// RegisterWithVerification 用户注册(支持邮件验证、优惠码、邀请码和邀请返利返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode, affiliateCode string) (string, *User, error) {
// 检查是否开放注册默认关闭settingService 未配置时不允许注册)
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled
@@ -223,6 +226,17 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
}
s.postAuthUserBootstrap(ctx, user, "email", true)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
if s.affiliateService != nil {
if _, err := s.affiliateService.EnsureUserAffiliate(ctx, user.ID); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", user.ID, err)
}
if code := strings.TrimSpace(affiliateCode); code != "" {
if err := s.affiliateService.BindInviterByCode(ctx, user.ID, code); err != nil {
// 邀请返利码绑定失败不影响注册,只记录日志
logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", user.ID, err)
}
}
}
// 标记邀请码为已使用(如果使用了邀请码)
if invitationRedeemCode != nil {

View File

@@ -110,7 +110,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
emailSvc = service.NewEmailService(settingRepo, emailCache)
}
svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner)
svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner, nil)
return svc, repo, client
}
@@ -467,7 +467,7 @@ func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *t
},
}
emailService := service.NewEmailService(nil, cache)
svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil)
svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil, nil)
oldTokenPair, err := svc.GenerateTokenPair(ctx, &service.User{
ID: 41,

View File

@@ -137,7 +137,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
values: settings,
}, cfg)
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner)
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner, nil)
return svc, repo, client
}

View File

@@ -212,6 +212,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
nil,
nil, // promoService
nil, // defaultSubAssigner
nil, // affiliateService
)
}
@@ -243,7 +244,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
}, nil)
// 应返回服务不可用错误,而不是允许绕过验证
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "", "")
require.ErrorIs(t, err, ErrServiceUnavailable)
}
@@ -255,7 +256,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true",
}, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "", "")
require.ErrorIs(t, err, ErrEmailVerifyRequired)
}
@@ -269,7 +270,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true",
}, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "", "")
require.ErrorIs(t, err, ErrInvalidVerifyCode)
require.ErrorContains(t, err, "verify code")
}

View File

@@ -54,6 +54,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier
nil, // emailQueueService
nil, // promoService
nil, // defaultSubAssigner
nil, // affiliateService
)
}

View File

@@ -18,6 +18,13 @@ const (
RoleUser = domain.RoleUser
)
// Affiliate rebate settings
const (
AffiliateRebateRateDefault = 20.0
AffiliateRebateRateMin = 0.0
AffiliateRebateRateMax = 100.0
)
// Platform constants
const (
PlatformAnthropic = domain.PlatformAnthropic
@@ -87,6 +94,7 @@ const (
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
SettingKeyFrontendURL = "frontend_url" // 前端基础URL用于生成邮件中的重置密码链接
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
SettingKeyAffiliateRebateRate = "affiliate_rebate_rate" // 邀请返利比例百分比0-100
// 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址

View File

@@ -762,8 +762,14 @@ func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock(
system := gjson.GetBytes(upstream.lastBody, "system")
require.True(t, system.Exists())
require.True(t, system.IsArray(), "system should be an array")
require.Equal(t, claudeCodeSystemPrompt, system.Array()[0].Get("text").String())
require.Equal(t, "ephemeral", system.Array()[0].Get("cache_control.type").String())
arr := system.Array()
require.Len(t, arr, 2, "system array should have billing block + cc prompt block")
require.Contains(t, arr[0].Get("text").String(), "x-anthropic-billing-header:")
require.Contains(t, arr[0].Get("text").String(), "cc_version=")
require.Equal(t, claudeCodeSystemPrompt, arr[1].Get("text").String())
require.Equal(t, "ephemeral", arr[1].Get("cache_control.type").String())
// 原始 system prompt 应迁移至 messages 中
messages := gjson.GetBytes(upstream.lastBody, "messages")

View File

@@ -0,0 +1,98 @@
package service
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"github.com/tidwall/gjson"
)
// fingerprintSalt 是计算 cc_version 后缀指纹的盐值。
//
// 来源:与 Parrot src/transform/cc_mimicry.py 的 FINGERPRINT_SALT 完全一致;
// 这是真实 Claude Code CLI 抓包推导出的常量,改动会导致 fp 与 CLI 不一致,
// 进一步触发 Anthropic 的第三方检测。
const fingerprintSalt = "59cf53e54c78"
// computeClaudeCodeFingerprint 复刻真实 Claude Code CLI 的 cc_version 指纹算法:
//
// 1. 取 messages 中第一条 role=user 的纯文本(首块 text
// 2. 取该文本的第 4、7、20 字符(不足以 '0' 补齐)
// 3. SHA256(SALT + chars + cc_version) 取 hex 前 3 字符
//
// 算法来自 Parrot src/transform/cc_mimicry.py:compute_fingerprint与官方 CLI 字节对齐。
// 任何偏差都会导致 cc_version=X.Y.Z.{fp} 在上游侧与真实 CLI 不一致。
func computeClaudeCodeFingerprint(body []byte, version string) string {
firstText := extractFirstUserText(body)
indices := []int{4, 7, 20}
chars := make([]byte, 0, 3)
for _, i := range indices {
if i < len(firstText) {
chars = append(chars, firstText[i])
} else {
chars = append(chars, '0')
}
}
sum := sha256.Sum256([]byte(fingerprintSalt + string(chars) + version))
return hex.EncodeToString(sum[:])[:3]
}
// extractFirstUserText 提取 messages 中第一条 user 消息的首段 text 内容。
// 兼容 string 和 []block 两种 content 格式。
func extractFirstUserText(body []byte) string {
messages := gjson.GetBytes(body, "messages")
if !messages.IsArray() {
return ""
}
first := ""
messages.ForEach(func(_, msg gjson.Result) bool {
if msg.Get("role").String() != "user" {
return true
}
content := msg.Get("content")
if content.Type == gjson.String {
first = content.String()
return false
}
if content.IsArray() {
content.ForEach(func(_, block gjson.Result) bool {
if block.Get("type").String() == "text" {
first = block.Get("text").String()
return false
}
return true
})
return false
}
return false
})
return first
}
// buildBillingAttributionBlockJSON 构造 system 数组的 billing attribution block。
//
// 形态严格对齐真实 Claude Code CLI
//
// {"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.92.{fp}; cc_entrypoint=cli; cch=00000;"}
//
// cch=00000 是签名占位符,由 signBillingHeaderCCH 在 buildUpstreamRequest 阶段
// 替换为基于完整 body 的 xxhash64 5 位十六进制摘要。
//
// 此 block 不带 cache_control与真实 CLI 一致cache breakpoint 由后续的
// Claude Code prompt block 承担)。
func buildBillingAttributionBlockJSON(body []byte, cliVersion string) ([]byte, error) {
if cliVersion == "" {
return nil, fmt.Errorf("cliVersion required")
}
fp := computeClaudeCodeFingerprint(body, cliVersion)
text := fmt.Sprintf(
"x-anthropic-billing-header: cc_version=%s.%s; cc_entrypoint=cli; cch=00000;",
cliVersion, fp,
)
return json.Marshal(map[string]string{
"type": "text",
"text": text,
})
}

View File

@@ -41,12 +41,13 @@ func TestNormalizeClaudeOAuthRequestBody_PreservesTopLevelFieldOrder(t *testing.
resultStr := string(result)
require.Equal(t, claude.NormalizeModelID("claude-3-5-sonnet-latest"), modelID)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"system"`, `"messages"`, `"omega"`, `"tools"`, `"metadata"`)
require.NotContains(t, resultStr, `"temperature"`)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"temperature"`, `"system"`, `"messages"`, `"omega"`, `"tools"`, `"metadata"`, `"max_tokens"`)
require.Contains(t, resultStr, `"temperature":0.2`)
require.NotContains(t, resultStr, `"tool_choice"`)
require.Contains(t, resultStr, `"system":"`+claudeCodeSystemPrompt+`"`)
require.Contains(t, resultStr, `"tools":[]`)
require.Contains(t, resultStr, `"metadata":{"user_id":"user-1"}`)
require.Contains(t, resultStr, `"max_tokens":128000`)
}
func TestInjectClaudeCodePrompt_PreservesFieldOrder(t *testing.T) {

View File

@@ -85,15 +85,16 @@ func (s *GatewayService) ForwardAsChatCompletions(
return nil, fmt.Errorf("marshal anthropic request: %w", err)
}
// 6. Apply Claude Code mimicry for OAuth accounts
isClaudeCode := false // CC API is never Claude Code
// 6. Apply Claude Code mimicry for OAuth accounts.
// Chat Completions 协议进来的请求永远不是 Claude Code 客户端,所以对 OAuth 账号
// 必须完整执行 /v1/messages 主路径上的伪装链路system 重写 + normalize + metadata 注入),
// 否则会被 Anthropic 判为第三方应用并扣 extra usage。
// 见 applyClaudeCodeOAuthMimicryToBody 的 godoc。
isClaudeCode := false
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
if !strings.Contains(strings.ToLower(mappedModel), "haiku") &&
!systemIncludesClaudeCodePrompt(anthropicReq.System) {
anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System)
}
anthropicBody = s.applyClaudeCodeOAuthMimicryToBody(ctx, c, account, anthropicBody, anthropicReq.System, mappedModel)
}
// 7. Enforce cache_control block limit
@@ -312,7 +313,14 @@ func (s *GatewayService) handleCCBufferedFromAnthropic(
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.JSON(http.StatusOK, ccResp)
// Marshal then bytes-replace so tool name mapping is reversed at byte level
// (parity with Parrot non-stream flow that marshals → restore → emit).
if respBytes, err := json.Marshal(ccResp); err == nil {
respBytes = reverseToolNamesIfPresent(c, respBytes)
c.Data(http.StatusOK, "application/json; charset=utf-8", respBytes)
} else {
c.JSON(http.StatusOK, ccResp)
}
return &ForwardResult{
RequestID: requestID,
@@ -383,7 +391,10 @@ func (s *GatewayService) handleCCStreamingFromAnthropic(
if err != nil {
return false
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
// Reverse tool name mapping: fake → real, per-chunk bytes.Replace.
// c 可能持有请求侧注入的 ToolNameRewrite无则仅做静态前缀还原。
out := string(reverseToolNamesIfPresent(c, []byte(sse)))
if _, err := fmt.Fprint(c.Writer, out); err != nil {
return true // client disconnected
}
return false

View File

@@ -82,15 +82,16 @@ func (s *GatewayService) ForwardAsResponses(
return nil, fmt.Errorf("marshal anthropic request: %w", err)
}
// 6. Apply Claude Code mimicry for OAuth accounts (non-Claude-Code endpoints)
isClaudeCode := false // Responses API is never Claude Code
// 6. Apply Claude Code mimicry for OAuth accounts (non-Claude-Code endpoints).
// OpenAI Responses 协议进来的请求永远不是 Claude Code 客户端,所以对 OAuth 账号
// 必须完整执行 /v1/messages 主路径上的伪装链路system 重写 + normalize + metadata 注入),
// 否则会被 Anthropic 判为第三方应用并扣 extra usage。
// 见 applyClaudeCodeOAuthMimicryToBody 的 godoc。
isClaudeCode := false
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
if !strings.Contains(strings.ToLower(mappedModel), "haiku") &&
!systemIncludesClaudeCodePrompt(anthropicReq.System) {
anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System)
}
anthropicBody = s.applyClaudeCodeOAuthMimicryToBody(ctx, c, account, anthropicBody, anthropicReq.System, mappedModel)
}
// 7. Enforce cache_control block limit
@@ -331,7 +332,12 @@ func (s *GatewayService) handleResponsesBufferedStreamingResponse(
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.JSON(http.StatusOK, responsesResp)
if respBytes, err := json.Marshal(responsesResp); err == nil {
respBytes = reverseToolNamesIfPresent(c, respBytes)
c.Data(http.StatusOK, "application/json; charset=utf-8", respBytes)
} else {
c.JSON(http.StatusOK, responsesResp)
}
return &ForwardResult{
RequestID: requestID,
@@ -419,7 +425,8 @@ func (s *GatewayService) handleResponsesStreamingResponse(
)
continue
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
out := string(reverseToolNamesIfPresent(c, []byte(sse)))
if _, err := fmt.Fprint(c.Writer, out); err != nil {
logger.L().Info("forward_as_responses stream: client disconnected",
zap.String("request_id", requestID),
)
@@ -439,7 +446,8 @@ func (s *GatewayService) handleResponsesStreamingResponse(
if err != nil {
continue
}
fmt.Fprint(c.Writer, sse) //nolint:errcheck
out := string(reverseToolNamesIfPresent(c, []byte(sse)))
fmt.Fprint(c.Writer, out) //nolint:errcheck
}
c.Writer.Flush()
}

View File

@@ -0,0 +1,141 @@
package service
import (
"fmt"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// stripMessageCacheControl 移除 $.messages[*].content[*].cache_control。
// 与 Parrot _strip_message_cache_control 语义一致。
//
// 为什么必须整体清空:客户端(特别是 Claude Code经常把 cache_control 打在
// "当前最后一条 user message" 上;下一轮对话 messages 追加后,原本的最后一条
// 变成中间某条cache_control 还挂着就导致"前缀签名变化",破坏缓存命中。
// 统一由代理重新打断点addMessageCacheBreakpoints才能在多轮间稳定。
func stripMessageCacheControl(body []byte) []byte {
messages := gjson.GetBytes(body, "messages")
if !messages.IsArray() {
return body
}
msgIdx := -1
messages.ForEach(func(_, msg gjson.Result) bool {
msgIdx++
content := msg.Get("content")
if !content.IsArray() {
return true
}
blockIdx := -1
content.ForEach(func(_, block gjson.Result) bool {
blockIdx++
if !block.Get("cache_control").Exists() {
return true
}
path := fmt.Sprintf("messages.%d.content.%d.cache_control", msgIdx, blockIdx)
if next, err := sjson.DeleteBytes(body, path); err == nil {
body = next
}
return true
})
return true
})
return body
}
// addMessageCacheBreakpoints 在 messages 上注入两个稳定的 cache 断点:
// 1. 最后一条 message
// 2. 当 messages 数量 ≥ 4 时,倒数第二个 role=user 的 message
//
// 与 Parrot add_cache_breakpoints 一致。两个断点 + system prompt block 的断点
// + tools[-1] 的断点共同构成最多 4 个断点Anthropic 上限)。
//
// cache_control ttl 策略:
// - 若目标 block 已有 cache_control.ttl → 不覆盖
// - 否则写入 {"type":"ephemeral","ttl": claude.DefaultCacheControlTTL}
//
// 调用前应先 stripMessageCacheControl 以保证幂等和稳定。
func addMessageCacheBreakpoints(body []byte) []byte {
messages := gjson.GetBytes(body, "messages")
if !messages.IsArray() {
return body
}
arr := messages.Array()
if len(arr) == 0 {
return body
}
body = injectCacheControlOnLastContentBlock(body, len(arr)-1, &arr[len(arr)-1])
if len(arr) >= 4 {
userCount := 0
for i := len(arr) - 1; i >= 0; i-- {
if arr[i].Get("role").String() != "user" {
continue
}
userCount++
if userCount == 2 {
body = injectCacheControlOnLastContentBlock(body, i, &arr[i])
break
}
}
}
return body
}
// injectCacheControlOnLastContentBlock 把 cache_control 断点打在 messages[idx]
// 的最后一个 content block 上。若 content 是 string先升级成单块 text 数组
// (对齐 Parrot _inject_cache_on_msg 的行为)。
//
// msg 是调用方已持有的 gjson.Result 快照,用于省一次 GetBytes。
func injectCacheControlOnLastContentBlock(body []byte, idx int, msg *gjson.Result) []byte {
content := msg.Get("content")
if content.Type == gjson.String {
text := content.String()
blockRaw := fmt.Sprintf(
`[{"type":"text","text":%s,"cache_control":{"type":"ephemeral","ttl":%q}}]`,
mustJSONString(text), claude.DefaultCacheControlTTL,
)
if next, err := sjson.SetRawBytes(body, fmt.Sprintf("messages.%d.content", idx), []byte(blockRaw)); err == nil {
body = next
}
return body
}
if !content.IsArray() {
return body
}
contentArr := content.Array()
if len(contentArr) == 0 {
return body
}
lastBlockIdx := len(contentArr) - 1
lastBlock := contentArr[lastBlockIdx]
if cc := lastBlock.Get("cache_control"); cc.Exists() && cc.Get("ttl").String() != "" {
return body
}
pathPrefix := fmt.Sprintf("messages.%d.content.%d.cache_control", idx, lastBlockIdx)
existingCC := lastBlock.Get("cache_control")
if existingCC.Exists() {
if next, err := sjson.SetBytes(body, pathPrefix+".ttl", claude.DefaultCacheControlTTL); err == nil {
body = next
}
return body
}
raw := fmt.Sprintf(`{"type":"ephemeral","ttl":%q}`, claude.DefaultCacheControlTTL)
if next, err := sjson.SetRawBytes(body, pathPrefix, []byte(raw)); err == nil {
body = next
}
return body
}
// mustJSONString 把一个 Go string 序列化为合法 JSON string含引号
// 用于 sjson.SetRawBytes 场景下手工拼 JSON。
func mustJSONString(s string) string {
return fmt.Sprintf("%q", s)
}

View File

@@ -378,16 +378,27 @@ func TestRewriteSystemForNonClaudeCode(t *testing.T) {
err := json.Unmarshal(result, &parsed)
require.NoError(t, err)
// system 应为 array 格式: [{type: "text", text: "...", cache_control: {type: "ephemeral"}}]
// system 应为 array 格式,对齐真实 Claude Code CLI 的 2-block 形态:
// [0] billing attribution block (x-anthropic-billing-header: cc_version=...;)
// [1] Claude Code prompt block (带 cache_control)
systemArr, ok := parsed["system"].([]any)
require.True(t, ok, "system should be an array, got %T", parsed["system"])
require.Len(t, systemArr, 1, "system array should have exactly 1 block")
systemBlock, ok := systemArr[0].(map[string]any)
require.Len(t, systemArr, 2, "system array should have exactly 2 blocks (billing + cc prompt)")
billingBlock, ok := systemArr[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", billingBlock["type"])
require.Contains(t, billingBlock["text"], "x-anthropic-billing-header:")
require.Contains(t, billingBlock["text"], "cc_version=")
require.Contains(t, billingBlock["text"], "cc_entrypoint=cli")
require.Contains(t, billingBlock["text"], "cch=00000")
systemBlock, ok := systemArr[1].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", systemBlock["type"])
require.Equal(t, tt.wantSystemText, systemBlock["text"])
cc, ok := systemBlock["cache_control"].(map[string]any)
require.True(t, ok, "system block should have cache_control")
require.True(t, ok, "cc prompt block should have cache_control")
require.Equal(t, "ephemeral", cc["type"])
// 检查 messages

View File

@@ -850,6 +850,7 @@ func (s *GatewayService) hashContent(content string) string {
type anthropicCacheControlPayload struct {
Type string `json:"type"`
TTL string `json:"ttl,omitempty"`
}
type anthropicSystemTextBlockPayload struct {
@@ -898,7 +899,10 @@ func marshalAnthropicSystemTextBlock(text string, includeCacheControl bool) ([]b
Text: text,
}
if includeCacheControl {
block.CacheControl = &anthropicCacheControlPayload{Type: "ephemeral"}
block.CacheControl = &anthropicCacheControlPayload{
Type: "ephemeral",
TTL: claude.DefaultCacheControlTTL,
}
}
return json.Marshal(block)
}
@@ -1074,19 +1078,52 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
}
}
if gjson.GetBytes(out, "temperature").Exists() {
if next, ok := deleteJSONPathBytes(out, "temperature"); ok {
// temperature真实 Claude Code CLI 总是发送 temperature默认 1客户端可覆盖
// 之前的实现直接 delete 会导致 payload 缺字段,与真实 CLI 字节级不一致。
// 策略:客户端传了什么就透传;没传则补默认 1。
if !gjson.GetBytes(out, "temperature").Exists() {
if next, ok := setJSONValueBytes(out, "temperature", 1); ok {
out = next
modified = true
}
}
if gjson.GetBytes(out, "tool_choice").Exists() {
if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok {
// max_tokens真实 CLI 的默认值是 128000。缺失时补齐以对齐指纹。
if !gjson.GetBytes(out, "max_tokens").Exists() {
if next, ok := setJSONValueBytes(out, "max_tokens", 128000); ok {
out = next
modified = true
}
}
// context_managementthinking.type 为 enabled/adaptive 时,真实 CLI 会自动
// 附带 {"edits":[{"type":"clear_thinking_20251015","keep":"all"}]}。
// 客户端显式传了就透传;否则按 CLI 行为补齐。
if !gjson.GetBytes(out, "context_management").Exists() {
thinkingType := gjson.GetBytes(out, "thinking.type").String()
if thinkingType == "enabled" || thinkingType == "adaptive" {
const cmDefault = `{"edits":[{"type":"clear_thinking_20251015","keep":"all"}]}`
if next, ok := setJSONRawBytes(out, "context_management", []byte(cmDefault)); ok {
out = next
modified = true
}
}
}
// tool_choice与 Parrot 对齐,不再无条件删除。
// - 客户端传了 {"type":"tool","name":"X"} → 保留结构name 由
// applyToolNameRewriteToBody 同步映射为假名
// - 其他形态auto/any/none原样透传
// 如果 body 里完全没有 tools空数组tool_choice 没意义时才删除
if !gjson.GetBytes(out, "tools").IsArray() || len(gjson.GetBytes(out, "tools").Array()) == 0 {
if gjson.GetBytes(out, "tool_choice").Exists() {
if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok {
out = next
modified = true
}
}
}
if !modified {
return body, modelID
}
@@ -1128,6 +1165,135 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion)
}
// applyClaudeCodeOAuthMimicryToBody 将"非 Claude Code 客户端 + Claude OAuth 账号"
// 路径上原本只在 /v1/messages 里做的完整伪装应用到任意 body 上。
//
// 这是 /v1/messages 主路径上 rewriteSystemForNonClaudeCode +
// normalizeClaudeOAuthRequestBody 流程的通用版,供 OpenAI 协议兼容层
// (ForwardAsChatCompletions / ForwardAsResponses) 复用。
//
// 未抽离之前OpenAI 协议兼容层仅做 injectClaudeCodePrompt前置追加
// 而仓内 /v1/messages 路径自己的注释明确说过"仅前置追加无法通过 Anthropic
// 第三方检测";那条注释就是本函数存在的根因。
//
// 参数:
// - ctx / c用于读取指纹和 gateway settingsc 可为 nil如 count_tokens
// - account必须是 OAuth 账号,且调用方已判断不是 Claude Code 客户端。
// - body已经 marshal 成 Anthropic /v1/messages 格式的请求体。
// - systemRawbody 中原始 system 字段(用于判断是否需要 rewrite
// - model最终会发给上游的模型 ID用于 haiku 旁路 + metadata 版本选择)。
//
// 返回:改写后的 body。即使中间任何一步失败也会退化成原 body不会 panic
func (s *GatewayService) applyClaudeCodeOAuthMimicryToBody(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
systemRaw any,
model string,
) []byte {
if account == nil || !account.IsOAuth() || len(body) == 0 {
return body
}
systemRewritten := false
if !strings.Contains(strings.ToLower(model), "haiku") {
body = rewriteSystemForNonClaudeCode(body, systemRaw)
systemRewritten = true
}
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten}
if s.identityService != nil && c != nil && c.Request != nil {
if fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header); err == nil && fp != nil {
mimicMPT := false
if s.settingService != nil {
_, mimicMPT, _ = s.settingService.GetGatewayForwardingSettings(ctx)
}
if !mimicMPT {
if uid := s.buildOAuthMetadataUserIDFromBody(ctx, account, fp, body); uid != "" {
normalizeOpts.injectMetadata = true
normalizeOpts.metadataUserID = uid
}
}
}
}
body, _ = normalizeClaudeOAuthRequestBody(body, model, normalizeOpts)
// Phase D+E+F: messages cache 策略 + 工具名混淆 + tools[-1] 断点
// 对齐 Parrot transform_request 里剩余的字段级改写。三步顺序有语义约束:
// 1) strip先清除客户端的 messages[*].cache_control多轮稳定性
// 2) breakpoints再注入 2 个断点(最后一条 + 倒数第二个 user turn
// 3) tool rewrite最后改 tools[*].name / tool_choice.name 并在 tools[-1]
// 上打断点mapping 存入 gin.Context 供响应侧 bytes.Replace 还原。
body = stripMessageCacheControl(body)
body = addMessageCacheBreakpoints(body)
if rw := buildToolNameRewriteFromBody(body); rw != nil {
body = applyToolNameRewriteToBody(body, rw)
if c != nil {
c.Set(toolNameRewriteKey, rw)
}
} else {
body = applyToolsLastCacheBreakpoint(body)
}
return body
}
// buildOAuthMetadataUserIDFromBody 是 buildOAuthMetadataUserID 的变体,
// 适用于调用方手上没有 ParsedRequest 的场景(如 OpenAI 协议兼容层)。
//
// 与 buildOAuthMetadataUserID 的唯一区别:
// - session hash 从 body 本体按同样规则重算,而不是读取 ParsedRequest 缓存值。
// - 如果 body 里已经存在 metadata.user_id则返回空由 ensureClaudeOAuthMetadataUserID
// 自行决定是否覆盖)。
func (s *GatewayService) buildOAuthMetadataUserIDFromBody(
ctx context.Context,
account *Account,
fp *Fingerprint,
body []byte,
) string {
_ = ctx
if account == nil {
return ""
}
if existing := gjson.GetBytes(body, "metadata.user_id").String(); existing != "" {
return ""
}
userID := strings.TrimSpace(account.GetClaudeUserID())
if userID == "" && fp != nil {
userID = fp.ClientID
}
if userID == "" {
userID = generateClientID()
}
sessionID := uuid.NewString()
if hash := hashBodyForSessionSeed(body); hash != "" {
sessionID = generateSessionUUID(fmt.Sprintf("%d::%s", account.ID, hash))
}
var uaVersion string
if fp != nil {
uaVersion = ExtractCLIVersion(fp.UserAgent)
}
accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid"))
return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion)
}
// hashBodyForSessionSeed 为 sessionID 提供一个稳定但仅对本次请求特征化的种子。
// 复用 SHA-256 + 截断,与 generateSessionUUID 的输入格式对齐。
func hashBodyForSessionSeed(body []byte) string {
if len(body) == 0 {
return ""
}
sum := sha256.Sum256(body)
return fmt.Sprintf("%x", sum[:16])
}
// GenerateSessionUUID creates a deterministic UUID4 from a seed string.
func GenerateSessionUUID(seed string) string {
return generateSessionUUID(seed)
@@ -3552,16 +3718,6 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
return claudeCliUserAgentRe.MatchString(userAgent)
}
func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool {
if IsClaudeCodeClient(ctx) {
return true
}
if parsed == nil || c == nil {
return false
}
return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
}
// normalizeSystemParam 将 json.RawMessage 类型的 system 参数转为标准 Go 类型string / []any / nil
// 避免 type switch 中 json.RawMessage底层 []byte无法匹配 case string / case []any / case nil 的问题。
// 这是 Go 的 typed nil 陷阱:(json.RawMessage, nil) ≠ (nil, nil)。
@@ -3738,17 +3894,20 @@ func rewriteSystemForNonClaudeCode(body []byte, system any) []byte {
originalSystemText = strings.Join(parts, "\n\n")
}
// 2. system 替换为 Claude Code 标准提示词array 格式,与真实 Claude Code 一致)
// 真实 Claude Code 始终以 [{type: "text", text: "...", cache_control: {type: "ephemeral"}}] 发送 system。
// 使用 string 格式会被 Anthropic 检测为第三方应用。
claudeCodeSystemBlock := []map[string]any{
{
"type": "text",
"text": claudeCodeSystemPrompt,
"cache_control": map[string]string{"type": "ephemeral"},
},
// 2. 构造 system 数组,对齐真实 Claude Code CLI 的 2-block 形态:
// [0] billing attribution blockcc_version={cliVer}.{fp}; cc_entrypoint=cli; cch=00000;
// [1] "You are Claude Code..." prompt block带 cache_control 作为稳定缓存断点)
//
// billing block 的 cch=00000 是占位符,会被 buildUpstreamRequest 里的
// signBillingHeaderCCH 替换成 xxhash64 签名。缺失 billing block 的系统 payload
// 是 Anthropic 判定第三方的关键信号之一(真实 CLI 每个请求都带)。
billingBlock, billingErr := buildBillingAttributionBlockJSON(body, claude.CLICurrentVersion)
ccPromptBlock, ccErr := marshalAnthropicSystemTextBlock(claudeCodeSystemPrompt, true)
if billingErr != nil || ccErr != nil {
logger.LegacyPrintf("service.gateway", "Warning: failed to build system blocks (billing=%v, cc=%v)", billingErr, ccErr)
return body
}
out, ok := setJSONValueBytes(body, "system", claudeCodeSystemBlock)
out, ok := setJSONRawBytes(body, "system", buildJSONArrayRaw([][]byte{billingBlock, ccPromptBlock}))
if !ok {
logger.LegacyPrintf("service.gateway", "Warning: failed to set Claude Code system prompt")
return body
@@ -3985,15 +4144,21 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
})
}
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
// OAuth 账号无条件走完整 mimicry与 Parrot 对齐。
// 不再检查 isClaudeCodeRequest —— 即使客户端自称 Claude Codeopencode 等
// 第三方工具会伪装 UA / X-App / system prompt它的伪装往往不完整缺 billing
// block / 工具名混淆 / cache 策略等),被 Anthropic 判为 third-party。
// 无条件覆盖不会对真正的 Claude Code 造成问题,因为我们的伪装更完整。
shouldMimicClaudeCode := account.IsOAuth()
if shouldMimicClaudeCode {
// 非 Claude Code 客户端:将 system 替换为 Claude Code 标识,原始 system 迁移至 messages
// 条件1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
// 与 Parrot 对齐OAuth 账号无条件重写 system即使客户端已发了 Claude Code
// 风格的 system prompt。原因第三方工具opencode 等)会发 "You are Claude
// Code..." system prompt 但缺少 billing attribution block导致 Anthropic
// 检测到"有 CC prompt 但无 billing block"的不一致而判为 third-party。
// Parrot 的 transform_request 从不检查客户端 system 内容,直接覆盖。
systemRewritten := false
if !strings.Contains(strings.ToLower(reqModel), "haiku") &&
!systemIncludesClaudeCodePrompt(parsed.System) {
if !strings.Contains(strings.ToLower(reqModel), "haiku") {
body = rewriteSystemForNonClaudeCode(body, parsed.System)
systemRewritten = true
}
@@ -4017,6 +4182,18 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
// D/E/F: messages cache 策略 + 工具名混淆 + tools[-1] 断点
// 与 forward_as_chat_completions / forward_as_responses 路径对齐,
// 保证原生 /v1/messages 路径也经过完整的 Parrot 字段级改写。
body = stripMessageCacheControl(body)
body = addMessageCacheBreakpoints(body)
if rw := buildToolNameRewriteFromBody(body); rw != nil {
body = applyToolNameRewriteToBody(body, rw)
c.Set(toolNameRewriteKey, rw)
} else {
body = applyToolsLastCacheBreakpoint(body)
}
}
// 强制执行 cache_control 块数量限制(最多 4 个)
@@ -4955,7 +5132,8 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
}
if !clientDisconnected {
if _, err := io.WriteString(w, line); err != nil {
restored := string(reverseToolNamesIfPresent(c, []byte(line)))
if _, err := io.WriteString(w, restored); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
} else if _, err := io.WriteString(w, "\n"); err != nil {
@@ -5125,6 +5303,7 @@ func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough(
if contentType == "" {
contentType = "application/json"
}
body = reverseToolNamesIfPresent(c, body)
c.Data(resp.StatusCode, contentType, body)
return usage, nil
}
@@ -5580,13 +5759,19 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
setHeaderRaw(req.Header, "x-api-key", token)
}
// 白名单透传headers(恢复真实 wire casing
for key, values := range clientHeaders {
lowerKey := strings.ToLower(key)
if allowedHeaders[lowerKey] {
wireKey := resolveWireCasing(key)
for _, v := range values {
addHeaderRaw(req.Header, wireKey, v)
// 白名单透传 headers
// OAuth mimicry 路径:跳过客户端 header 透传,与 Parrot 对齐。
// Parrot 的 build_upstream_headers 只发 9 个精确 header不透传任何客户端 header。
// 透传客户端 header 会引入不一致的 x-stainless-* / anthropic-beta / user-agent /
// x-claude-code-session-id 等值,和我们注入的伪装 header 冲突,被 Anthropic 判 third-party。
if tokenType != "oauth" || !mimicClaudeCode {
for key, values := range clientHeaders {
lowerKey := strings.ToLower(key)
if allowedHeaders[lowerKey] {
wireKey := resolveWireCasing(key)
for _, v := range values {
addHeaderRaw(req.Header, wireKey, v)
}
}
}
}
@@ -5627,7 +5812,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// Haiku models are exempt from third-party detection and don't need it.
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
if !strings.Contains(strings.ToLower(modelID), "haiku") {
requiredBetas = []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking}
requiredBetas = claude.FullClaudeCodeMimicryBetas()
}
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropSet))
} else {
@@ -6099,6 +6284,11 @@ func applyClaudeCodeMimicHeaders(req *http.Request, isStream bool) {
if isStream {
setHeaderRaw(req.Header, "x-stainless-helper-method", "stream")
}
// Real Claude CLI 每个请求都会生成一个新的 UUID 放在 x-client-request-id。
// 上游会以此作为会话/请求指纹的一部分,缺失或重复都可能触发第三方判定。
if getHeaderRaw(req.Header, "x-client-request-id") == "" {
setHeaderRaw(req.Header, "x-client-request-id", uuid.NewString())
}
}
func truncateForLog(b []byte, maxBytes int) string {
@@ -6864,7 +7054,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
for _, block := range outputBlocks {
if !clientDisconnected {
if _, werr := fmt.Fprint(w, block); werr != nil {
restored := reverseToolNamesIfPresent(c, []byte(block))
if _, werr := fmt.Fprint(w, string(restored)); werr != nil {
clientDisconnected = true
logger.LegacyPrintf("service.gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
break
@@ -7206,6 +7397,8 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
}
}
body = reverseToolNamesIfPresent(c, body)
// 写入响应
c.Data(resp.StatusCode, contentType, body)
@@ -8194,12 +8387,19 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
// Pre-filter: strip empty text blocks to prevent upstream 400.
body = StripEmptyTextBlocks(body)
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
shouldMimicClaudeCode := account.IsOAuth()
if shouldMimicClaudeCode {
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
body = stripMessageCacheControl(body)
body = addMessageCacheBreakpoints(body)
if rw := buildToolNameRewriteFromBody(body); rw != nil {
body = applyToolNameRewriteToBody(body, rw)
} else {
body = applyToolsLastCacheBreakpoint(body)
}
}
// Antigravity 账户不支持 count_tokens返回 404 让客户端 fallback 到本地估算。
@@ -8623,7 +8823,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
applyClaudeCodeMimicHeaders(req, false)
incomingBeta := getHeaderRaw(req.Header, "anthropic-beta")
requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting}
requiredBetas := append(claude.FullClaudeCodeMimicryBetas(), claude.BetaTokenCounting)
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, ctEffectiveDropSet))
} else {
clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta")

View File

@@ -0,0 +1,313 @@
package service
import (
"fmt"
"hash/fnv"
"math/rand"
"sort"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// toolNameRewriteKey 是 gin.Context 上存 ToolNameRewrite 映射的 key。
// 请求阶段写入,响应阶段读取,用于 bytes 级逆向还原假名 → 真名。
const toolNameRewriteKey = "claude_tool_name_rewrite"
// staticToolNameRewrites 是"静态前缀映射",与 Parrot src/transform/cc_mimicry.py
// TOOL_NAME_REWRITES 完全一致。只有以这些前缀开头的工具会被重写。
var staticToolNameRewrites = map[string]string{
"sessions_": "cc_sess_",
"session_": "cc_ses_",
}
// fakeToolNamePrefixes 是"动态映射"的前缀池,与 Parrot _FAKE_PREFIXES 一致。
// 当 tools 数量 > dynamicToolMapThreshold 时随机选用其中前缀生成可读假名。
var fakeToolNamePrefixes = []string{
"analyze_", "compute_", "fetch_", "generate_", "lookup_", "modify_",
"process_", "query_", "render_", "resolve_", "sync_", "update_",
"validate_", "convert_", "extract_", "manage_", "monitor_", "parse_",
"review_", "search_", "transform_", "handle_", "invoke_", "notify_",
}
// dynamicToolMapThreshold 与 Parrot 一致tools 数量超过 5 才启用动态映射。
// 少量工具不需要混淆(一般是 Claude Code 自己的核心工具 bash/edit/read 等)。
const dynamicToolMapThreshold = 5
// ToolNameRewrite 是单次请求内的工具名混淆映射。
// - Forward: real → fake请求阶段在 body 上应用。
// - Reverse: fake → real响应阶段对每个 chunk 做 bytes.Replace 还原。
//
// ReverseOrdered 是按假名长度倒序的 (fake, real) 列表,用于防止短假名是长假名的
// 子串时 bytes.Replace 先被吃掉(对齐 Parrot _restore_tool_names_in_chunk 的
// `sorted(..., key=lambda x: len(x[1]), reverse=True)`)。
type ToolNameRewrite struct {
Forward map[string]string
Reverse map[string]string
ReverseOrdered [][2]string
}
// buildDynamicToolMap 构造 tools 的动态假名映射。
//
// 与 Parrot _build_dynamic_tool_map 语义等价:
// - tools 数量 ≤ dynamicToolMapThreshold 时返回 nil不做动态映射走静态 fallback
// - 同一组 tool_names 在同进程内映射稳定(保证 cache 命中)
//
// Parrot 用 `random.Random(hash(tuple(tool_names)))` 作 seed + shuffle 前缀池;
// Go 无法字节级复刻 Python hash但"稳定性"和"前缀池打散"两个不变量都保留:
// 用 fnv64a(strings.Join(names, "\x00")) 作 seed 喂 math/rand.New。
// 字节级不同不影响上游判定Anthropic 不会验证我们的随机种子算法)。
func buildDynamicToolMap(toolNames []string) map[string]string {
if len(toolNames) <= dynamicToolMapThreshold {
return nil
}
h := fnv.New64a()
for i, n := range toolNames {
if i > 0 {
_, _ = h.Write([]byte{0})
}
_, _ = h.Write([]byte(n))
}
rng := rand.New(rand.NewSource(int64(h.Sum64())))
available := make([]string, len(fakeToolNamePrefixes))
copy(available, fakeToolNamePrefixes)
rng.Shuffle(len(available), func(i, j int) { available[i], available[j] = available[j], available[i] })
mapping := make(map[string]string, len(toolNames))
for i, name := range toolNames {
prefix := available[i%len(available)]
headLen := 3
if len(name) < 3 {
headLen = len(name)
}
fake := fmt.Sprintf("%s%s%02d", prefix, name[:headLen], i)
mapping[name] = fake
}
return mapping
}
// sanitizeToolName 把真名转成假名。
// 与 Parrot _sanitize_tool_name 语义一致:动态映射优先,再走静态前缀映射。
func sanitizeToolName(name string, dynamic map[string]string) string {
if dynamic != nil {
if fake, ok := dynamic[name]; ok {
return fake
}
}
for prefix, replacement := range staticToolNameRewrites {
if strings.HasPrefix(name, prefix) {
return replacement + name[len(prefix):]
}
}
return name
}
// shouldMimicToolName 指示某个 tool 是否需要重命名。
// server tooltype != "" 且不是 "function" / "custom")是 Anthropic 协议语义的一部分,
// 比如 "web_search_20250305" / "computer_20250124";误改会导致上游拒绝。
func shouldMimicToolName(toolType string) bool {
if toolType == "" || toolType == "function" || toolType == "custom" {
return true
}
return false
}
// buildToolNameRewriteFromBody 扫描 body 的 tools[*].name构造 ToolNameRewrite
// 并返回它。若不需要混淆tools 数量不足 + 没有匹配静态前缀的工具)返回 nil。
//
// 注意:只扫描,不改 body。真正的 body 改写在 applyToolNameRewriteToBody。
func buildToolNameRewriteFromBody(body []byte) *ToolNameRewrite {
tools := gjson.GetBytes(body, "tools")
if !tools.IsArray() {
return nil
}
mimicableNames := make([]string, 0)
toolsArr := tools.Array()
for _, t := range toolsArr {
if !shouldMimicToolName(t.Get("type").String()) {
continue
}
name := t.Get("name").String()
if name == "" {
continue
}
mimicableNames = append(mimicableNames, name)
}
dynamic := buildDynamicToolMap(mimicableNames)
rw := &ToolNameRewrite{
Forward: make(map[string]string),
Reverse: make(map[string]string),
}
for _, name := range mimicableNames {
fake := sanitizeToolName(name, dynamic)
if fake == name {
continue
}
rw.Forward[name] = fake
rw.Reverse[fake] = name
}
if len(rw.Forward) == 0 {
return nil
}
rw.ReverseOrdered = make([][2]string, 0, len(rw.Reverse))
for fake, real := range rw.Reverse {
rw.ReverseOrdered = append(rw.ReverseOrdered, [2]string{fake, real})
}
sort.SliceStable(rw.ReverseOrdered, func(i, j int) bool {
return len(rw.ReverseOrdered[i][0]) > len(rw.ReverseOrdered[j][0])
})
return rw
}
// applyToolNameRewriteToBody 把已构造的 ToolNameRewrite 应用到 body 上:
// - 改写 $.tools[*].name仅对 shouldMimicToolName 通过的 tool
// - 在 $.tools[last].cache_control 上打 ephemeral 缓存断点Parrot 行为对齐,
// ttl 客户端已有则透传,否则默认 claude.DefaultCacheControlTTL
// - 改写 $.tool_choice.name仅当 $.tool_choice.type == "tool"
//
// 历史 $.messages[*].content[*].nametool_use不在请求侧改写——这与 Parrot 一致;
// 响应侧 bytes.Replace 会连带还原它们。
func applyToolNameRewriteToBody(body []byte, rw *ToolNameRewrite) []byte {
if rw == nil || len(rw.Forward) == 0 {
body = applyToolsLastCacheBreakpoint(body)
return body
}
tools := gjson.GetBytes(body, "tools")
if tools.IsArray() {
idx := -1
tools.ForEach(func(_, t gjson.Result) bool {
idx++
if !shouldMimicToolName(t.Get("type").String()) {
return true
}
name := t.Get("name").String()
if name == "" {
return true
}
fake, ok := rw.Forward[name]
if !ok {
return true
}
if next, err := sjson.SetBytes(body, fmt.Sprintf("tools.%d.name", idx), fake); err == nil {
body = next
}
return true
})
}
if tc := gjson.GetBytes(body, "tool_choice"); tc.Exists() && tc.Get("type").String() == "tool" {
name := tc.Get("name").String()
if fake, ok := rw.Forward[name]; ok {
if next, err := sjson.SetBytes(body, "tool_choice.name", fake); err == nil {
body = next
}
}
}
body = applyToolsLastCacheBreakpoint(body)
return body
}
// applyToolsLastCacheBreakpoint 在 tools 数组最后一个工具上注入 cache_control
// 断点,对齐 Parrot `tools[-1]["cache_control"] = {"type":"ephemeral","ttl":"1h"}`
// 行为,但 ttl 按本仓规则:
// - 客户端已为该 tool 显式设置 cache_control.ttl → 完全透传不覆盖
// - 否则注入 {"type":"ephemeral","ttl": claude.DefaultCacheControlTTL}
//
// 纯副作用函数tools 不存在或为空数组时 no-op。
func applyToolsLastCacheBreakpoint(body []byte) []byte {
tools := gjson.GetBytes(body, "tools")
if !tools.IsArray() {
return body
}
arr := tools.Array()
if len(arr) == 0 {
return body
}
lastIdx := len(arr) - 1
existingCC := arr[lastIdx].Get("cache_control")
if existingCC.Exists() && existingCC.Get("ttl").String() != "" {
return body
}
if existingCC.Exists() {
if next, err := sjson.SetBytes(body, fmt.Sprintf("tools.%d.cache_control.ttl", lastIdx), claude.DefaultCacheControlTTL); err == nil {
body = next
}
return body
}
raw := fmt.Sprintf(`{"type":"ephemeral","ttl":%q}`, claude.DefaultCacheControlTTL)
if next, err := sjson.SetRawBytes(body, fmt.Sprintf("tools.%d.cache_control", lastIdx), []byte(raw)); err == nil {
body = next
}
return body
}
// restoreToolNamesInBytes 对 bytes chunk 做逆向还原:假名 → 真名。
// 按 ReverseOrdered 的假名长度倒序逐个 bytes.Replace防止子串冲突
// (与 Parrot _restore_tool_names_in_chunk 的 sorted(..., reverse=True) 等价)。
// 再做静态前缀还原cc_sess_ → sessions_ / cc_ses_ → session_
//
// rw 可为 nilnil 时仍会做静态前缀还原。
func restoreToolNamesInBytes(data []byte, rw *ToolNameRewrite) []byte {
if rw != nil {
for _, pair := range rw.ReverseOrdered {
fake, real := pair[0], pair[1]
if fake == "" || fake == real {
continue
}
data = replaceAllBytes(data, fake, real)
}
}
for prefix, replacement := range staticToolNameRewrites {
data = replaceAllBytes(data, replacement, prefix)
}
return data
}
// replaceAllBytes 是 bytes.ReplaceAll 的便捷封装,避免每个调用点各自做 []byte 转换。
func replaceAllBytes(data []byte, from, to string) []byte {
if len(data) == 0 || from == to || !strings.Contains(string(data), from) {
return data
}
return []byte(strings.ReplaceAll(string(data), from, to))
}
// toolNameRewriteFromContext 从 gin.Context 取出请求阶段保存的工具名映射。
// 找不到c==nil 或 key 不存在或类型不对)时返回 nil调用方必须能处理 nil。
func toolNameRewriteFromContext(c interface {
Get(string) (any, bool)
}) *ToolNameRewrite {
if c == nil {
return nil
}
raw, ok := c.Get(toolNameRewriteKey)
if !ok || raw == nil {
return nil
}
rw, _ := raw.(*ToolNameRewrite)
return rw
}
// reverseToolNamesIfPresent 是响应侧 5 处注入点的统一封装:从 c 取出 mapping
// 并对 chunk 做 bytes 级假名→真名替换。c 没有 mapping 时仍会做静态前缀还原。
func reverseToolNamesIfPresent(c interface {
Get(string) (any, bool)
}, chunk []byte) []byte {
rw := toolNameRewriteFromContext(c)
if rw == nil && len(staticToolNameRewrites) == 0 {
return chunk
}
return restoreToolNamesInBytes(chunk, rw)
}

View File

@@ -0,0 +1,185 @@
package service
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestBuildDynamicToolMap_BelowThreshold(t *testing.T) {
// Parrot 行为tools 数量 ≤ 5 时不做动态映射。
names := []string{"bash", "edit", "read", "write", "search"}
require.Nil(t, buildDynamicToolMap(names))
}
func TestBuildDynamicToolMap_AboveThresholdIsStable(t *testing.T) {
// Parrot 不变量:同一组 tool_names 在同进程内映射稳定(保证 cache 命中)。
names := []string{"alpha", "beta", "gamma", "delta", "epsilon", "zeta"}
a := buildDynamicToolMap(names)
b := buildDynamicToolMap(names)
require.NotNil(t, a)
require.Equal(t, a, b, "same input tool_names must yield identical mapping")
require.Len(t, a, 6)
for _, name := range names {
require.Contains(t, a, name)
require.NotEqual(t, name, a[name])
}
}
func TestSanitizeToolName_StaticPrefix(t *testing.T) {
require.Equal(t, "cc_sess_list", sanitizeToolName("sessions_list", nil))
require.Equal(t, "cc_ses_get", sanitizeToolName("session_get", nil))
require.Equal(t, "bash", sanitizeToolName("bash", nil))
}
func TestSanitizeToolName_DynamicTakesPrecedence(t *testing.T) {
dyn := map[string]string{"sessions_list": "analyze_ses00"}
got := sanitizeToolName("sessions_list", dyn)
require.Equal(t, "analyze_ses00", got, "dynamic mapping wins over static prefix")
}
func TestRestoreToolNamesInBytes_LongestFirst(t *testing.T) {
// 当假名 "abc_12" 是另一个更长假名的子串(真实场景极少但算法必须防御)时,
// 长的必须先替换。本测试用显式构造的映射来验证排序不变量。
rw := &ToolNameRewrite{
Forward: map[string]string{"foo": "abc_12", "bar": "abc_12_ext"},
Reverse: map[string]string{"abc_12": "foo", "abc_12_ext": "bar"},
}
// 手工构造 ReverseOrdered长的在前
rw.ReverseOrdered = [][2]string{
{"abc_12_ext", "bar"},
{"abc_12", "foo"},
}
data := []byte(`{"tool":"abc_12_ext","other":"abc_12"}`)
restored := string(restoreToolNamesInBytes(data, rw))
require.Equal(t, `{"tool":"bar","other":"foo"}`, restored)
}
func TestRestoreToolNamesInBytes_StaticPrefixRollback(t *testing.T) {
data := []byte(`{"name":"sessions_list","id":"cc_ses_xyz"}`)
got := string(restoreToolNamesInBytes(data, nil))
require.Equal(t, `{"name":"sessions_list","id":"session_xyz"}`, got)
}
func TestApplyToolNameRewriteToBody_RenamesToolsAndToolChoice(t *testing.T) {
body := []byte(`{"tools":[{"name":"sessions_list","input_schema":{}},{"name":"session_get","input_schema":{}},{"name":"web_search","type":"web_search_20250305"}],"tool_choice":{"type":"tool","name":"sessions_list"}}`)
rw := buildToolNameRewriteFromBody(body)
require.NotNil(t, rw)
require.Contains(t, rw.Forward, "sessions_list")
require.Contains(t, rw.Forward, "session_get")
// web_search is a server tool, not rewritten
require.NotContains(t, rw.Forward, "web_search")
out := applyToolNameRewriteToBody(body, rw)
// tools[0].name and tools[1].name rewritten; tools[2].name untouched
require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "tools.0.name").String())
require.Equal(t, "cc_ses_get", gjson.GetBytes(out, "tools.1.name").String())
require.Equal(t, "web_search", gjson.GetBytes(out, "tools.2.name").String())
// tool_choice.name rewritten
require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "tool_choice.name").String())
require.Equal(t, "tool", gjson.GetBytes(out, "tool_choice.type").String())
}
func TestApplyToolsLastCacheBreakpoint_InjectsDefault(t *testing.T) {
body := []byte(`{"tools":[{"name":"a","input_schema":{}},{"name":"b","input_schema":{}}]}`)
out := applyToolsLastCacheBreakpoint(body)
require.Equal(t, "ephemeral", gjson.GetBytes(out, "tools.1.cache_control.type").String())
require.Equal(t, "5m", gjson.GetBytes(out, "tools.1.cache_control.ttl").String())
// First tool untouched
require.False(t, gjson.GetBytes(out, "tools.0.cache_control").Exists())
}
func TestApplyToolsLastCacheBreakpoint_PassesThroughClientTTL(t *testing.T) {
body := []byte(`{"tools":[{"name":"a","input_schema":{},"cache_control":{"type":"ephemeral","ttl":"1h"}}]}`)
out := applyToolsLastCacheBreakpoint(body)
// User-provided ttl must be preserved.
require.Equal(t, "1h", gjson.GetBytes(out, "tools.0.cache_control.ttl").String())
}
func TestStripMessageCacheControl(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral"}}]}]}`)
out := stripMessageCacheControl(body)
require.False(t, gjson.GetBytes(out, "messages.0.content.0.cache_control").Exists())
}
func TestAddMessageCacheBreakpoints_LastMessageOnly(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
out := addMessageCacheBreakpoints(body)
require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.0.content.0.cache_control.type").String())
require.Equal(t, "5m", gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").String())
}
func TestAddMessageCacheBreakpoints_SecondToLastUserTurn(t *testing.T) {
// Parrot 不变量messages ≥ 4 时才打第二个断点,且位置是"倒数第二个 user turn"。
body := []byte(`{"messages":[
{"role":"user","content":[{"type":"text","text":"q1"}]},
{"role":"assistant","content":[{"type":"text","text":"a1"}]},
{"role":"user","content":[{"type":"text","text":"q2"}]},
{"role":"assistant","content":[{"type":"text","text":"a2"}]}
]}`)
out := addMessageCacheBreakpoints(body)
// 最后一条 assistant 被打断点
require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.3.content.0.cache_control.type").String())
// 倒数第二个 user turn = index 0唯一另一个 user
require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.0.content.0.cache_control.type").String())
// 其他不打断点
require.False(t, gjson.GetBytes(out, "messages.1.content.0.cache_control").Exists())
require.False(t, gjson.GetBytes(out, "messages.2.content.0.cache_control").Exists())
}
func TestAddMessageCacheBreakpoints_StringContentPromoted(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
out := addMessageCacheBreakpoints(body)
// content 升级成数组
require.True(t, gjson.GetBytes(out, "messages.0.content").IsArray())
require.Equal(t, "text", gjson.GetBytes(out, "messages.0.content.0.type").String())
require.Equal(t, "hi", gjson.GetBytes(out, "messages.0.content.0.text").String())
require.Equal(t, "5m", gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").String())
}
func TestBuildToolNameRewriteFromBody_ReverseOrderedByLengthDesc(t *testing.T) {
// 超过阈值触发动态映射,验证 ReverseOrdered 按假名长度倒序排列
body := []byte(`{"tools":[
{"name":"t1","input_schema":{}},
{"name":"t2","input_schema":{}},
{"name":"t3","input_schema":{}},
{"name":"t4","input_schema":{}},
{"name":"t5","input_schema":{}},
{"name":"t6","input_schema":{}}
]}`)
rw := buildToolNameRewriteFromBody(body)
require.NotNil(t, rw)
require.NotEmpty(t, rw.ReverseOrdered)
for i := 1; i < len(rw.ReverseOrdered); i++ {
require.GreaterOrEqual(t, len(rw.ReverseOrdered[i-1][0]), len(rw.ReverseOrdered[i][0]),
"ReverseOrdered must be sorted by fake-name length descending")
}
}
func TestRestoreToolNamesInBytes_NoMapping_NoStaticMatch_IsNoop(t *testing.T) {
data := []byte("plain text without any tool names")
require.Equal(t, string(data), string(restoreToolNamesInBytes(data, nil)))
}
// Ensure the fake name format follows Parrot's "{prefix}{name[:3]}{i:02d}".
func TestBuildDynamicToolMap_FakeNameShape(t *testing.T) {
names := []string{"alphabet", "bravo", "charlie", "delta", "echo", "foxtrot"}
m := buildDynamicToolMap(names)
require.NotNil(t, m)
for _, name := range names {
fake, ok := m[name]
require.True(t, ok)
// fake = prefix + head3 + "%02d"
// ends with two decimal digits
require.Regexp(t, `^[a-z]+_[a-z0-9]{1,3}\d{2}$`, fake)
head := name
if len(head) > 3 {
head = head[:3]
}
require.True(t, strings.Contains(fake, head), "fake %q should contain head3 %q of %q", fake, head, name)
}
}

View File

@@ -26,7 +26,7 @@ var (
// 默认指纹值(当客户端未提供时使用)
var defaultFingerprint = Fingerprint{
UserAgent: "claude-cli/2.1.22 (external, cli)",
UserAgent: "claude-cli/2.1.92 (external, cli)",
StainlessLang: "js",
StainlessPackageVersion: "0.70.0",
StainlessOS: "Linux",

View File

@@ -3,7 +3,6 @@ package service
import (
"container/heap"
"context"
"errors"
"fmt"
"hash/fnv"
"math"
@@ -45,6 +44,7 @@ type OpenAIAccountScheduleRequest struct {
RequestedModel string
RequiredTransport OpenAIUpstreamTransport
RequiredImageCapability OpenAIImagesCapability
RequireCompact bool
ExcludedIDs map[int64]struct{}
}
@@ -258,12 +258,16 @@ func (s *defaultOpenAIAccountScheduler) Select(
previousResponseID,
req.RequestedModel,
req.ExcludedIDs,
req.RequireCompact,
)
if err != nil {
return nil, decision, err
}
if selection != nil && selection.Account != nil {
if !s.isAccountTransportCompatible(selection.Account, req.RequiredTransport) {
if selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
selection = nil
}
}
@@ -348,8 +352,8 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel)
if account == nil {
account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel, req.RequireCompact)
if account == nil || !s.isAccountTransportCompatible(account, req.RequiredTransport) {
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
@@ -590,7 +594,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
return nil, 0, 0, 0, err
}
if len(accounts) == 0 {
return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
return nil, 0, 0, 0, noAvailableOpenAISelectionError(req.RequestedModel, false)
}
// require_privacy_set: 获取分组信息
@@ -630,7 +634,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
})
}
if len(filtered) == 0 {
return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
return nil, 0, 0, 0, noAvailableOpenAISelectionError(req.RequestedModel, false)
}
loadMap := map[int64]*AccountLoadInfo{}
@@ -640,45 +644,14 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
}
}
minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority
maxWaiting := 1
loadRateSum := 0.0
loadRateSumSquares := 0.0
minTTFT, maxTTFT := 0.0, 0.0
hasTTFTSample := false
candidates := make([]openAIAccountCandidateScore, 0, len(filtered))
allCandidates := make([]openAIAccountCandidateScore, 0, len(filtered))
for _, account := range filtered {
loadInfo := loadMap[account.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: account.ID}
}
if account.Priority < minPriority {
minPriority = account.Priority
}
if account.Priority > maxPriority {
maxPriority = account.Priority
}
if loadInfo.WaitingCount > maxWaiting {
maxWaiting = loadInfo.WaitingCount
}
errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID)
if hasTTFT && ttft > 0 {
if !hasTTFTSample {
minTTFT, maxTTFT = ttft, ttft
hasTTFTSample = true
} else {
if ttft < minTTFT {
minTTFT = ttft
}
if ttft > maxTTFT {
maxTTFT = ttft
}
}
}
loadRate := float64(loadInfo.LoadRate)
loadRateSum += loadRate
loadRateSumSquares += loadRate * loadRate
candidates = append(candidates, openAIAccountCandidateScore{
allCandidates = append(allCandidates, openAIAccountCandidateScore{
account: account,
loadInfo: loadInfo,
errorRate: errorRate,
@@ -686,53 +659,183 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
hasTTFT: hasTTFT,
})
}
loadSkew := calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
weights := s.service.openAIWSSchedulerWeights()
for i := range candidates {
item := &candidates[i]
priorityFactor := 1.0
if maxPriority > minPriority {
priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
// Compact 模式下把明确不支持 compact 的账号拆出,仅在 schedulerSnapshot 启用
// 时作为最后兜底snapshot 可能已陈旧)。
candidates := allCandidates
staleSnapshotCompactRetry := make([]openAIAccountCandidateScore, 0, len(allCandidates))
if req.RequireCompact {
candidates = make([]openAIAccountCandidateScore, 0, len(allCandidates))
for _, candidate := range allCandidates {
if openAICompactSupportTier(candidate.account) == 0 {
staleSnapshotCompactRetry = append(staleSnapshotCompactRetry, candidate)
continue
}
candidates = append(candidates, candidate)
}
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
errorFactor := 1 - clamp01(item.errorRate)
ttftFactor := 0.5
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
if len(candidates) == 0 && len(staleSnapshotCompactRetry) == 0 {
return nil, 0, 0, 0, ErrNoAvailableCompactAccounts
}
item.score = weights.Priority*priorityFactor +
weights.Load*loadFactor +
weights.Queue*queueFactor +
weights.ErrorRate*errorFactor +
weights.TTFT*ttftFactor
}
topK := s.service.openAIWSLBTopK()
if topK > len(candidates) {
topK = len(candidates)
}
if topK <= 0 {
topK = 1
}
rankedCandidates := selectTopKOpenAICandidates(candidates, topK)
selectionOrder := buildOpenAIWeightedSelectionOrder(rankedCandidates, req)
candidateCount := len(candidates)
loadSkew := 0.0
if len(candidates) > 0 {
minPriority, maxPriority := candidates[0].account.Priority, candidates[0].account.Priority
maxWaiting := 1
loadRateSum := 0.0
loadRateSumSquares := 0.0
minTTFT, maxTTFT := 0.0, 0.0
hasTTFTSample := false
for _, candidate := range candidates {
if candidate.account.Priority < minPriority {
minPriority = candidate.account.Priority
}
if candidate.account.Priority > maxPriority {
maxPriority = candidate.account.Priority
}
if candidate.loadInfo.WaitingCount > maxWaiting {
maxWaiting = candidate.loadInfo.WaitingCount
}
if candidate.hasTTFT && candidate.ttft > 0 {
if !hasTTFTSample {
minTTFT, maxTTFT = candidate.ttft, candidate.ttft
hasTTFTSample = true
} else {
if candidate.ttft < minTTFT {
minTTFT = candidate.ttft
}
if candidate.ttft > maxTTFT {
maxTTFT = candidate.ttft
}
}
}
loadRate := float64(candidate.loadInfo.LoadRate)
loadRateSum += loadRate
loadRateSumSquares += loadRate * loadRate
}
loadSkew = calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
weights := s.service.openAIWSSchedulerWeights()
for i := range candidates {
item := &candidates[i]
priorityFactor := 1.0
if maxPriority > minPriority {
priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
}
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
errorFactor := 1 - clamp01(item.errorRate)
ttftFactor := 0.5
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
}
item.score = weights.Priority*priorityFactor +
weights.Load*loadFactor +
weights.Queue*queueFactor +
weights.ErrorRate*errorFactor +
weights.TTFT*ttftFactor
}
}
topK := 0
if len(candidates) > 0 {
topK = s.service.openAIWSLBTopK()
if topK > len(candidates) {
topK = len(candidates)
}
if topK <= 0 {
topK = 1
}
}
buildSelectionOrder := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
if len(pool) == 0 || topK <= 0 {
return nil
}
groupTopK := topK
if groupTopK > len(pool) {
groupTopK = len(pool)
}
ranked := selectTopKOpenAICandidates(pool, groupTopK)
return buildOpenAIWeightedSelectionOrder(ranked, req)
}
sortCompactRetryCandidates := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
if len(pool) == 0 {
return nil
}
ordered := append([]openAIAccountCandidateScore(nil), pool...)
sort.SliceStable(ordered, func(i, j int) bool {
a, b := ordered[i], ordered[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
if a.loadInfo.WaitingCount != b.loadInfo.WaitingCount {
return a.loadInfo.WaitingCount < b.loadInfo.WaitingCount
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
return ordered
}
selectionOrder := make([]openAIAccountCandidateScore, 0, len(allCandidates))
if req.RequireCompact {
supported := make([]openAIAccountCandidateScore, 0, len(candidates))
unknown := make([]openAIAccountCandidateScore, 0, len(candidates))
for _, candidate := range candidates {
switch openAICompactSupportTier(candidate.account) {
case 2:
supported = append(supported, candidate)
case 1:
unknown = append(unknown, candidate)
}
}
if len(supported) == 0 && len(unknown) == 0 && s.service.schedulerSnapshot == nil {
return nil, candidateCount, topK, loadSkew, ErrNoAvailableCompactAccounts
}
selectionOrder = append(selectionOrder, buildSelectionOrder(supported)...)
selectionOrder = append(selectionOrder, buildSelectionOrder(unknown)...)
if len(staleSnapshotCompactRetry) > 0 && s.service.schedulerSnapshot != nil {
selectionOrder = append(selectionOrder, sortCompactRetryCandidates(staleSnapshotCompactRetry)...)
}
} else {
selectionOrder = buildSelectionOrder(candidates)
}
if len(selectionOrder) == 0 {
return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, req.RequireCompact && len(allCandidates) > 0)
}
compactBlocked := false
for i := 0; i < len(selectionOrder); i++ {
candidate := selectionOrder[i]
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
continue
}
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel)
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
continue
}
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
compactBlocked = true
continue
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if acquireErr != nil {
return nil, len(candidates), topK, loadSkew, acquireErr
return nil, candidateCount, topK, loadSkew, acquireErr
}
if result != nil && result.Acquired {
if req.SessionHash != "" {
@@ -742,17 +845,25 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
Account: fresh,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, len(candidates), topK, loadSkew, nil
}, candidateCount, topK, loadSkew, nil
}
}
cfg := s.service.schedulingConfig()
// WaitPlan.MaxConcurrency 使用 Concurrency非 EffectiveLoadFactor因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
for _, candidate := range selectionOrder {
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
continue
}
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
continue
}
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
compactBlocked = true
continue
}
return &AccountSelectionResult{
Account: fresh,
WaitPlan: &AccountWaitPlan{
@@ -761,10 +872,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, len(candidates), topK, loadSkew, nil
}, candidateCount, topK, loadSkew, nil
}
return nil, len(candidates), topK, loadSkew, ErrNoAvailableAccounts
return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, compactBlocked)
}
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
@@ -905,8 +1016,9 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
requestedModel string,
excludedIDs map[int64]struct{},
requiredTransport OpenAIUpstreamTransport,
requireCompact bool,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "")
return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "", requireCompact)
}
func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages(
@@ -917,13 +1029,13 @@ func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages(
excludedIDs map[int64]struct{},
requiredCapability OpenAIImagesCapability,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
selection, decision, err := s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability)
selection, decision, err := s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability, false)
if err == nil && selection != nil && selection.Account != nil {
return selection, decision, nil
}
// 如果要求 native 能力(如指定了模型)但没有可用的 APIKey 账号,回退到 basicOAuth 账号)
if requiredCapability == OpenAIImagesCapabilityNative {
return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, OpenAIImagesCapabilityBasic)
return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, OpenAIImagesCapabilityBasic, false)
}
return selection, decision, err
}
@@ -937,6 +1049,7 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
excludedIDs map[int64]struct{},
requiredTransport OpenAIUpstreamTransport,
requiredImageCapability OpenAIImagesCapability,
requireCompact bool,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
decision := OpenAIAccountScheduleDecision{}
scheduler := s.getOpenAIAccountScheduler(ctx)
@@ -945,7 +1058,7 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
for {
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs)
selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact)
if err != nil {
return nil, decision, err
}
@@ -970,7 +1083,7 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
for {
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs)
selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact)
if err != nil {
return nil, decision, err
}
@@ -1008,6 +1121,7 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
RequestedModel: requestedModel,
RequiredTransport: requiredTransport,
RequiredImageCapability: requiredImageCapability,
RequireCompact: requireCompact,
ExcludedIDs: excludedIDs,
})
}

View File

@@ -0,0 +1,195 @@
package service
import (
"context"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// TestOpenAIGatewayService_SelectAccountWithScheduler_CompactPrefersSupportedOverUnknown
// 验证 compact 调度时显式支持 (tier=2) 优先于未探测 (tier=1)。
func TestOpenAIGatewayService_SelectAccountWithScheduler_CompactPrefersSupportedOverUnknown(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
ctx := context.Background()
groupID := int64(91001)
accounts := []Account{
{
ID: 71001,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Extra: map[string]any{}, // unknown
},
{
ID: 71002,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Extra: map[string]any{"openai_compact_supported": true}, // tier=2
},
}
cfg := &config.Config{}
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, _, err := svc.SelectAccountWithScheduler(
ctx,
&groupID,
"",
"",
"gpt-5.4",
nil,
OpenAIUpstreamTransportAny,
true,
)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(71002), selection.Account.ID, "compact-supported account should win over unknown")
}
// TestOpenAIGatewayService_SelectAccountWithScheduler_CompactRejectsExplicitlyUnsupported
// 验证 force_off / 已探测不支持 (tier=0) 的账号不会被 compact 请求选中。
func TestOpenAIGatewayService_SelectAccountWithScheduler_CompactRejectsExplicitlyUnsupported(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
ctx := context.Background()
groupID := int64(91002)
accounts := []Account{
{
ID: 71010,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOff},
},
{
ID: 71011,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Extra: map[string]any{"openai_compact_supported": false},
},
}
cfg := &config.Config{}
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, _, err := svc.SelectAccountWithScheduler(
ctx,
&groupID,
"",
"",
"gpt-5.4",
nil,
OpenAIUpstreamTransportAny,
true,
)
require.Error(t, err)
require.True(t, errors.Is(err, ErrNoAvailableCompactAccounts), "compact-only accounts should rejected explicitly unsupported and return compact error")
require.Nil(t, selection)
}
// TestOpenAIGatewayService_SelectAccountWithScheduler_CompactFallsBackToUnknown
// 验证当没有"已知支持"账号时compact 请求会回退到"未探测"账号。
func TestOpenAIGatewayService_SelectAccountWithScheduler_CompactFallsBackToUnknown(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
ctx := context.Background()
groupID := int64(91003)
accounts := []Account{
{
ID: 71020,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Extra: map[string]any{"openai_compact_supported": false}, // tier=0
},
{
ID: 71021,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Extra: map[string]any{}, // unknown -> tier=1
},
}
cfg := &config.Config{}
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, _, err := svc.SelectAccountWithScheduler(
ctx,
&groupID,
"",
"",
"gpt-5.4",
nil,
OpenAIUpstreamTransportAny,
true,
)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(71021), selection.Account.ID, "unknown account should be picked when no supported account available")
}
// TestOpenAICompactSupportTier 验证 tier 分类逻辑。
func TestOpenAICompactSupportTier(t *testing.T) {
tests := []struct {
name string
account *Account
want int
}{
{name: "nil", account: nil, want: 0},
{name: "non openai", account: &Account{Platform: PlatformAnthropic}, want: 0},
{name: "openai unknown", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{}}, want: 1},
{name: "openai supported", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{"openai_compact_supported": true}}, want: 2},
{name: "openai unsupported", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{"openai_compact_supported": false}}, want: 0},
{name: "force on", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOn}}, want: 2},
{name: "force off overrides probe true", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOff, "openai_compact_supported": true}}, want: 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := openAICompactSupportTier(tt.account); got != tt.want {
t.Fatalf("openAICompactSupportTier(...) = %d, want %d", got, tt.want)
}
})
}
}

View File

@@ -289,6 +289,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLega
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -343,6 +344,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_Require
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -384,6 +386,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_Require
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
false,
)
require.ErrorContains(t, err, "no available OpenAI accounts")
require.Nil(t, selection)
@@ -445,6 +448,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPrev
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -486,7 +490,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimite
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny, false)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
@@ -540,7 +544,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeR
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny, false)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
@@ -616,6 +620,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -662,6 +667,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testin
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -740,6 +746,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -788,6 +795,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -857,6 +865,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -900,6 +909,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailabl
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
false,
)
require.Error(t, err)
require.Nil(t, selection)
@@ -976,6 +986,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -1014,7 +1025,7 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny, false)
require.NoError(t, err)
require.NotNil(t, selection)
svc.ReportOpenAIAccountScheduleResult(account.ID, true, intPtrForTest(120))
@@ -1218,6 +1229,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)

View File

@@ -54,6 +54,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapsh
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)

View File

@@ -1,6 +1,7 @@
package service
import (
"encoding/json"
"fmt"
"strings"
)
@@ -48,6 +49,8 @@ type codexTransformResult struct {
const (
codexImageGenerationBridgeMarker = "<sub2api-codex-image-generation>"
codexImageGenerationBridgeText = codexImageGenerationBridgeMarker + "\nWhen the user asks for raster image generation or editing, use the OpenAI Responses native `image_generation` tool attached to this request. The local Codex client may not expose an `image_gen` namespace, but that does not mean image generation is unavailable. Do not ask the user to switch to CLI fallback solely because `image_gen` is absent.\n</sub2api-codex-image-generation>"
codexSparkImageUnsupportedMarker = "<sub2api-codex-spark-image-unsupported>"
codexSparkImageUnsupportedText = codexSparkImageUnsupportedMarker + "\nThe current model is gpt-5.3-codex-spark, which does not support image generation, image editing, image input, the `image_generation` tool, or Codex `image_gen`/`$imagegen` workflows. If the user asks for image generation or image editing, clearly explain this model limitation and ask them to switch to a non-Spark Codex model such as gpt-5.3-codex or gpt-5.4. Do not claim that the local environment merely lacks image_gen tooling, and do not suggest CLI fallback as the primary fix while the model remains Spark.\n</sub2api-codex-spark-image-unsupported>"
)
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult {
@@ -151,6 +154,9 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
if normalizeCodexTools(reqBody) {
result.Modified = true
}
if normalizeCodexToolChoice(reqBody) {
result.Modified = true
}
if v, ok := reqBody["prompt_cache_key"].(string); ok {
result.PromptCacheKey = strings.TrimSpace(v)
@@ -165,9 +171,20 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
if applyInstructions(reqBody, isCodexCLI) {
result.Modified = true
}
if isCodexSparkModel(normalizedModel) && applyCodexSparkImageUnsupportedInstructions(reqBody) {
result.Modified = true
}
// 续链场景保留 item_reference 与 id避免 call_id 上下文丢失。
if input, ok := reqBody["input"].([]any); ok {
if normalizedInput, modified := normalizeCodexToolRoleMessages(input); modified {
input = normalizedInput
result.Modified = true
}
if normalizedInput, modified := normalizeCodexMessageContentText(input); modified {
input = normalizedInput
result.Modified = true
}
input = filterCodexInput(input, needsToolContinuation)
reqBody["input"] = input
result.Modified = true
@@ -192,6 +209,183 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
return result
}
func normalizeCodexToolChoice(reqBody map[string]any) bool {
choice, ok := reqBody["tool_choice"]
if !ok || choice == nil {
return false
}
choiceMap, ok := choice.(map[string]any)
if !ok {
return false
}
choiceType := strings.TrimSpace(firstNonEmptyString(choiceMap["type"]))
if choiceType == "" || codexToolsContainType(reqBody["tools"], choiceType) {
return false
}
reqBody["tool_choice"] = "auto"
return true
}
func codexToolsContainType(rawTools any, toolType string) bool {
tools, ok := rawTools.([]any)
if !ok || strings.TrimSpace(toolType) == "" {
return false
}
for _, rawTool := range tools {
tool, ok := rawTool.(map[string]any)
if !ok {
continue
}
if strings.TrimSpace(firstNonEmptyString(tool["type"])) == toolType {
return true
}
}
return false
}
func normalizeCodexToolRoleMessages(input []any) ([]any, bool) {
if len(input) == 0 {
return input, false
}
modified := false
normalized := make([]any, 0, len(input))
for _, item := range input {
m, ok := item.(map[string]any)
if !ok {
normalized = append(normalized, item)
continue
}
role, _ := m["role"].(string)
if strings.TrimSpace(role) != "tool" {
normalized = append(normalized, item)
continue
}
callID := firstNonEmptyString(m["call_id"], m["tool_call_id"], m["id"])
callID = strings.TrimSpace(callID)
if callID == "" {
// Responses does not accept role:"tool". If no call id is available,
// preserve the text as a user message instead of sending invalid input.
fallback := make(map[string]any, len(m))
for key, value := range m {
fallback[key] = value
}
fallback["role"] = "user"
delete(fallback, "tool_call_id")
normalized = append(normalized, fallback)
modified = true
continue
}
output := extractTextFromContent(m["content"])
if output == "" {
if value, ok := m["output"].(string); ok {
output = value
}
}
if output == "" && m["content"] != nil {
if b, err := json.Marshal(m["content"]); err == nil {
output = string(b)
}
}
normalized = append(normalized, map[string]any{
"type": "function_call_output",
"call_id": callID,
"output": output,
})
modified = true
}
if !modified {
return input, false
}
return normalized, true
}
func normalizeCodexMessageContentText(input []any) ([]any, bool) {
if len(input) == 0 {
return input, false
}
modified := false
normalized := make([]any, 0, len(input))
for _, item := range input {
m, ok := item.(map[string]any)
if !ok || strings.TrimSpace(firstNonEmptyString(m["type"])) != "message" {
normalized = append(normalized, item)
continue
}
parts, ok := m["content"].([]any)
if !ok {
normalized = append(normalized, item)
continue
}
var newItem map[string]any
var newParts []any
ensureItemCopy := func() {
if newItem != nil {
return
}
newItem = make(map[string]any, len(m))
for key, value := range m {
newItem[key] = value
}
newParts = make([]any, len(parts))
copy(newParts, parts)
}
for i, rawPart := range parts {
part, ok := rawPart.(map[string]any)
if !ok {
continue
}
text, hasText := part["text"]
if !hasText {
continue
}
if _, ok := text.(string); ok {
continue
}
ensureItemCopy()
newPart := make(map[string]any, len(part))
for key, value := range part {
newPart[key] = value
}
newPart["text"] = stringifyCodexContentText(text)
newParts[i] = newPart
modified = true
}
if newItem != nil {
newItem["content"] = newParts
normalized = append(normalized, newItem)
continue
}
normalized = append(normalized, item)
}
if !modified {
return input, false
}
return normalized, true
}
func stringifyCodexContentText(value any) string {
switch v := value.(type) {
case string:
return v
case nil:
return ""
default:
if b, err := json.Marshal(v); err == nil {
return string(b)
}
return fmt.Sprint(v)
}
}
func normalizeCodexModel(model string) string {
model = strings.TrimSpace(model)
if model == "" {
@@ -244,6 +438,10 @@ func normalizeCodexModel(model string) string {
return "gpt-5.4"
}
func isCodexSparkModel(model string) bool {
return normalizeCodexModel(model) == "gpt-5.3-codex-spark"
}
func hasOpenAIImageGenerationTool(reqBody map[string]any) bool {
rawTools, ok := reqBody["tools"]
if !ok || rawTools == nil {
@@ -265,6 +463,40 @@ func hasOpenAIImageGenerationTool(reqBody map[string]any) bool {
return false
}
func hasOpenAIInputImage(reqBody map[string]any) bool {
if reqBody == nil {
return false
}
return hasOpenAIInputImageValue(reqBody["input"]) || hasOpenAIInputImageValue(reqBody["messages"])
}
func hasOpenAIInputImageValue(value any) bool {
switch v := value.(type) {
case []any:
for _, item := range v {
if hasOpenAIInputImageValue(item) {
return true
}
}
case map[string]any:
if strings.TrimSpace(firstNonEmptyString(v["type"])) == "input_image" {
return true
}
if _, ok := v["image_url"]; ok {
return true
}
return hasOpenAIInputImageValue(v["content"])
}
return false
}
func validateCodexSparkInput(reqBody map[string]any, model string) error {
if !isCodexSparkModel(model) || !hasOpenAIInputImage(reqBody) {
return nil
}
return fmt.Errorf("model %q does not support image input", strings.TrimSpace(model))
}
func normalizeOpenAIResponsesImageGenerationTools(reqBody map[string]any) bool {
rawTools, ok := reqBody["tools"]
if !ok || rawTools == nil {
@@ -309,6 +541,9 @@ func ensureOpenAIResponsesImageGenerationTool(reqBody map[string]any) bool {
if len(reqBody) == 0 {
return false
}
if isCodexSparkModel(firstNonEmptyString(reqBody["model"])) {
return false
}
tool := map[string]any{
"type": "image_generation",
@@ -344,6 +579,9 @@ func applyCodexImageGenerationBridgeInstructions(reqBody map[string]any) bool {
if len(reqBody) == 0 || !hasOpenAIImageGenerationTool(reqBody) {
return false
}
if isCodexSparkModel(firstNonEmptyString(reqBody["model"])) {
return false
}
existing, _ := reqBody["instructions"].(string)
if strings.Contains(existing, codexImageGenerationBridgeMarker) {
@@ -360,6 +598,23 @@ func applyCodexImageGenerationBridgeInstructions(reqBody map[string]any) bool {
return true
}
func applyCodexSparkImageUnsupportedInstructions(reqBody map[string]any) bool {
if len(reqBody) == 0 {
return false
}
existing, _ := reqBody["instructions"].(string)
if strings.Contains(existing, codexSparkImageUnsupportedMarker) {
return false
}
existing = strings.TrimRight(existing, " \t\r\n")
if strings.TrimSpace(existing) == "" {
reqBody["instructions"] = codexSparkImageUnsupportedText
return true
}
reqBody["instructions"] = existing + "\n\n" + codexSparkImageUnsupportedText
return true
}
func validateOpenAIResponsesImageModel(reqBody map[string]any, model string) error {
if !hasOpenAIImageGenerationTool(reqBody) {
return nil
@@ -658,12 +913,30 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
}
}
if !isCodexToolCallItemType(typ) {
ensureCopy()
delete(newItem, "call_id")
}
if codexInputItemRequiresName(typ) {
if strings.TrimSpace(firstNonEmptyString(m["name"])) == "" {
name := firstNonEmptyString(m["tool_name"])
if name == "" {
if function, ok := m["function"].(map[string]any); ok {
name = firstNonEmptyString(function["name"])
}
}
if name == "" {
name = "tool"
}
ensureCopy()
newItem["name"] = name
}
}
if !preserveReferences {
ensureCopy()
delete(newItem, "id")
if !isCodexToolCallItemType(typ) {
delete(newItem, "call_id")
}
}
filtered = append(filtered, newItem)
@@ -672,10 +945,30 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
}
func isCodexToolCallItemType(typ string) bool {
if typ == "" {
switch typ {
case "function_call",
"tool_call",
"local_shell_call",
"tool_search_call",
"custom_tool_call",
"mcp_tool_call",
"function_call_output",
"mcp_tool_call_output",
"custom_tool_call_output",
"tool_search_output":
return true
default:
return false
}
}
func codexInputItemRequiresName(typ string) bool {
switch strings.TrimSpace(typ) {
case "function_call", "custom_tool_call", "mcp_tool_call":
return true
default:
return false
}
return strings.HasSuffix(typ, "_call") || strings.HasSuffix(typ, "_call_output")
}
func normalizeCodexTools(reqBody map[string]any) bool {

View File

@@ -92,6 +92,235 @@ func TestApplyCodexOAuthTransform_ToolContinuationNormalizesToolReferenceIDsOnly
require.Equal(t, "fc1", second["call_id"])
}
func TestApplyCodexOAuthTransform_ToolSearchOutputPreservesCallID(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.2",
"input": []any{
map[string]any{"type": "tool_search_output", "call_id": "call_1", "output": "ok"},
},
}
applyCodexOAuthTransform(reqBody, false, false)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 1)
first, ok := input[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "tool_search_output", first["type"])
require.Equal(t, "fc1", first["call_id"])
}
func TestApplyCodexOAuthTransform_CustomAndMCPToolOutputsPreserveCallID(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.2",
"input": []any{
map[string]any{"type": "custom_tool_call_output", "call_id": "call_custom", "output": "ok"},
map[string]any{"type": "mcp_tool_call_output", "call_id": "call_mcp", "output": "ok"},
},
}
applyCodexOAuthTransform(reqBody, false, false)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 2)
first, ok := input[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "fccustom", first["call_id"])
second, ok := input[1].(map[string]any)
require.True(t, ok)
require.Equal(t, "fcmcp", second["call_id"])
}
func TestApplyCodexOAuthTransform_ImageAndWebSearchCallsDoNotGainCallID(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.2",
"input": []any{
map[string]any{"type": "image_generation_call", "id": "ig_123", "status": "completed"},
map[string]any{"type": "web_search_call", "call_id": "call_bad", "status": "completed"},
},
"tool_choice": "auto",
}
applyCodexOAuthTransform(reqBody, false, false)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 2)
first, ok := input[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "ig_123", first["id"])
_, hasCallID := first["call_id"]
require.False(t, hasCallID)
second, ok := input[1].(map[string]any)
require.True(t, ok)
_, hasCallID = second["call_id"]
require.False(t, hasCallID)
}
func TestApplyCodexOAuthTransform_ConvertsToolRoleMessageToFunctionCallOutput(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"input": []any{
map[string]any{
"type": "message",
"role": "tool",
"tool_call_id": "call_1",
"content": "ok",
},
},
}
applyCodexOAuthTransform(reqBody, true, false)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 1)
item, ok := input[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "function_call_output", item["type"])
require.Equal(t, "fc1", item["call_id"])
require.Equal(t, "ok", item["output"])
_, hasRole := item["role"]
require.False(t, hasRole)
}
func TestApplyCodexOAuthTransform_StringifiesNonStringMessageContentText(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"input": []any{
map[string]any{
"type": "message",
"role": "user",
"content": []any{
map[string]any{"type": "input_text", "text": []any{"a", "b"}},
},
},
},
}
applyCodexOAuthTransform(reqBody, true, false)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
item, ok := input[0].(map[string]any)
require.True(t, ok)
content, ok := item["content"].([]any)
require.True(t, ok)
part, ok := content[0].(map[string]any)
require.True(t, ok)
require.Equal(t, `["a","b"]`, part["text"])
}
func TestApplyCodexOAuthTransform_DowngradesUnknownToolChoice(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"tools": []any{
map[string]any{"type": "function", "name": "shell"},
},
"tool_choice": map[string]any{"type": "custom"},
}
applyCodexOAuthTransform(reqBody, true, false)
require.Equal(t, "auto", reqBody["tool_choice"])
}
func TestApplyCodexOAuthTransform_PreservesKnownToolChoice(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"tools": []any{
map[string]any{"type": "custom", "name": "shell"},
},
"tool_choice": map[string]any{"type": "custom"},
}
applyCodexOAuthTransform(reqBody, true, false)
choice, ok := reqBody["tool_choice"].(map[string]any)
require.True(t, ok)
require.Equal(t, "custom", choice["type"])
}
func TestApplyCodexOAuthTransform_AddsFallbackNameForFunctionCallInput(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"input": []any{
map[string]any{"type": "message", "role": "user", "content": "run tool"},
map[string]any{"type": "function_call", "call_id": "call_1", "arguments": "{}"},
},
}
applyCodexOAuthTransform(reqBody, true, false)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 2)
item, ok := input[1].(map[string]any)
require.True(t, ok)
require.Equal(t, "function_call", item["type"])
require.Equal(t, "tool", item["name"])
require.Equal(t, "fc1", item["call_id"])
}
func TestApplyCodexOAuthTransform_PreservesFunctionCallInputName(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"input": []any{
map[string]any{"type": "custom_tool_call", "call_id": "call_1", "name": "shell", "input": "pwd"},
},
}
applyCodexOAuthTransform(reqBody, true, false)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 1)
item, ok := input[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "shell", item["name"])
require.Equal(t, "fc1", item["call_id"])
}
func TestApplyCodexOAuthTransform_PreservesMCPToolCallIDAndName(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"input": []any{
map[string]any{
"type": "mcp_tool_call",
"call_id": "call_abc",
"name": "remote_tool",
"arguments": "{}",
},
},
}
applyCodexOAuthTransform(reqBody, true, false)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 1)
item, ok := input[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "mcp_tool_call", item["type"])
require.Equal(t, "remote_tool", item["name"])
require.Equal(t, "fcabc", item["call_id"])
}
func TestCodexInputItemRequiresNameTypesAllowCallID(t *testing.T) {
for _, typ := range []string{"function_call", "custom_tool_call", "mcp_tool_call"} {
require.True(t, codexInputItemRequiresName(typ), typ)
require.True(t, isCodexToolCallItemType(typ), typ)
}
}
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
// 续链场景:显式 store=false 不再强制为 true保持 false。
@@ -261,6 +490,17 @@ func TestEnsureOpenAIResponsesImageGenerationTool_NoTools(t *testing.T) {
require.Equal(t, "png", tool["output_format"])
}
func TestEnsureOpenAIResponsesImageGenerationTool_SkipsSpark(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.3-codex-spark",
"input": "draw a cat",
}
modified := ensureOpenAIResponsesImageGenerationTool(reqBody)
require.False(t, modified)
require.NotContains(t, reqBody, "tools")
}
func TestEnsureOpenAIResponsesImageGenerationTool_AppendsToExistingTools(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
@@ -306,6 +546,7 @@ func TestEnsureOpenAIResponsesImageGenerationTool_PreservesExistingImageTool(t *
func TestApplyCodexImageGenerationBridgeInstructions_AppendsBridgeOnce(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"instructions": "existing instructions",
"tools": []any{
map[string]any{"type": "image_generation", "output_format": "png"},
@@ -325,6 +566,20 @@ func TestApplyCodexImageGenerationBridgeInstructions_AppendsBridgeOnce(t *testin
require.False(t, modified)
}
func TestApplyCodexImageGenerationBridgeInstructions_SkipsSpark(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.3-codex-spark",
"instructions": "existing instructions",
"tools": []any{
map[string]any{"type": "image_generation", "output_format": "png"},
},
}
modified := applyCodexImageGenerationBridgeInstructions(reqBody)
require.False(t, modified)
require.Equal(t, "existing instructions", reqBody["instructions"])
}
func TestApplyCodexImageGenerationBridgeInstructions_SkipsWithoutImageTool(t *testing.T) {
reqBody := map[string]any{
"instructions": "existing instructions",
@@ -338,6 +593,91 @@ func TestApplyCodexImageGenerationBridgeInstructions_SkipsWithoutImageTool(t *te
require.Equal(t, "existing instructions", reqBody["instructions"])
}
func TestValidateCodexSparkInputRejectsInputImage(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.3-codex-spark",
"input": []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "input_text", "text": "describe"},
map[string]any{"type": "input_image", "image_url": "data:image/png;base64,aGVsbG8="},
},
},
},
}
err := validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark")
require.Error(t, err)
require.Contains(t, err.Error(), "does not support image input")
}
func TestValidateCodexSparkInputRejectsChatImageURL(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.3-codex-spark",
"messages": []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "text", "text": "describe"},
map[string]any{"type": "image_url", "image_url": map[string]any{"url": "data:image/png;base64,aGVsbG8="}},
},
},
},
}
err := validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark")
require.Error(t, err)
}
func TestValidateCodexSparkInputAllowsTextOnly(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.3-codex-spark",
"input": []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "input_text", "text": "hello"},
},
},
},
}
require.NoError(t, validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark"))
}
func TestApplyCodexOAuthTransform_AddsSparkImageUnsupportedInstructions(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.3-codex-spark",
"instructions": "existing instructions",
"input": "hello",
}
result := applyCodexOAuthTransform(reqBody, true, false)
require.True(t, result.Modified)
instructions, ok := reqBody["instructions"].(string)
require.True(t, ok)
require.Contains(t, instructions, "existing instructions")
require.Contains(t, instructions, codexSparkImageUnsupportedMarker)
require.Contains(t, instructions, "does not support image generation")
require.Contains(t, instructions, "switch to a non-Spark Codex model")
require.NotContains(t, instructions, codexImageGenerationBridgeMarker)
}
func TestApplyCodexOAuthTransform_DoesNotAddSparkImageUnsupportedForNonSpark(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"instructions": "existing instructions",
"input": "hello",
}
applyCodexOAuthTransform(reqBody, true, false)
instructions, ok := reqBody["instructions"].(string)
require.True(t, ok)
require.NotContains(t, instructions, codexSparkImageUnsupportedMarker)
}
func TestNormalizeOpenAIResponsesImageOnlyModel_BuildsImageToolRequest(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-image-2",

View File

@@ -0,0 +1,135 @@
package service
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestOpenAIGatewayService_Forward_CompactOnlyModelMappingOverridesOAuthUpstreamModel(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4","stream":false,"instructions":"compact-test","input":"hello"}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-compact-map"}},
Body: io.NopCloser(strings.NewReader(`{"id":"resp_123","status":"completed","model":"gpt-5.4-openai-compact","output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)),
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
"compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"},
},
Status: StatusActive,
Schedulable: true,
}
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "gpt-5.4", result.Model)
require.Equal(t, "gpt-5.4-openai-compact", result.UpstreamModel)
require.Equal(t, "gpt-5.4-openai-compact", gjson.GetBytes(upstream.lastBody, "model").String())
}
func TestOpenAIGatewayService_Forward_NonCompactRequestIgnoresCompactOnlyModelMapping(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4","stream":false,"instructions":"normal-test","input":"hello"}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-normal-map"}},
Body: io.NopCloser(strings.NewReader(`{"id":"resp_124","status":"completed","model":"gpt-5.4","output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)),
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 2,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
"compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"},
},
Status: StatusActive,
Schedulable: true,
}
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "gpt-5.4", result.Model)
require.Equal(t, "gpt-5.4", result.UpstreamModel)
require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String())
}
func TestOpenAIGatewayService_OAuthPassthrough_CompactOnlyModelMappingOverridesUpstreamModel(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", bytes.NewReader(nil))
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
c.Request.Header.Set("Content-Type", "application/json")
originalBody := []byte(`{"model":"gpt-5.4","stream":true,"store":true,"instructions":"compact-pass","input":[{"type":"text","text":"compact me"}]}`)
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-compact-pass-map"}},
Body: io.NopCloser(strings.NewReader(`{"id":"cmp_124","model":"gpt-5.4-openai-compact","usage":{"input_tokens":2,"output_tokens":3}}`)),
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 3,
Name: "openai-oauth-pass",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
"compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"},
},
Extra: map[string]any{"openai_passthrough": true},
Status: StatusActive,
Schedulable: true,
}
result, err := svc.Forward(context.Background(), c, account, originalBody)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "gpt-5.4", result.Model)
require.Equal(t, "gpt-5.4-openai-compact", result.UpstreamModel)
require.Equal(t, "gpt-5.4-openai-compact", gjson.GetBytes(upstream.lastBody, "model").String())
require.Equal(t, "gpt-5.4", gjson.GetBytes(rec.Body.Bytes(), "model").String())
}

View File

@@ -0,0 +1,120 @@
package service
import (
"net/http"
"strconv"
"strings"
"time"
)
const (
// AccountTestModeDefault drives the standard /responses connection test.
AccountTestModeDefault = "default"
// AccountTestModeCompact drives the /responses/compact compact-probe test.
AccountTestModeCompact = "compact"
)
func normalizeAccountTestMode(mode string) string {
switch strings.ToLower(strings.TrimSpace(mode)) {
case AccountTestModeCompact:
return AccountTestModeCompact
default:
return AccountTestModeDefault
}
}
func createOpenAICompactProbePayload(model string) map[string]any {
return map[string]any{
"model": strings.TrimSpace(model),
"instructions": "You are a helpful coding assistant.",
"input": []any{
map[string]any{
"type": "message",
"role": "user",
"content": "Respond with OK.",
},
},
}
}
func shouldMarkOpenAICompactUnsupported(status int, body []byte) bool {
switch status {
case http.StatusNotFound, http.StatusMethodNotAllowed, http.StatusNotImplemented:
return true
case http.StatusBadRequest, http.StatusForbidden, http.StatusUnprocessableEntity:
lower := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(body) + " " + string(body)))
if strings.Contains(lower, "compact") {
for _, keyword := range []string{
"unsupported",
"not support",
"does not support",
"not available",
"disabled",
} {
if strings.Contains(lower, keyword) {
return true
}
}
}
}
return false
}
func buildOpenAICompactProbeExtraUpdates(resp *http.Response, body []byte, probeErr error, now time.Time) map[string]any {
updates := map[string]any{
"openai_compact_checked_at": now.Format(time.RFC3339),
"openai_compact_last_status": nil,
}
if resp != nil {
updates["openai_compact_last_status"] = resp.StatusCode
}
switch {
case probeErr != nil:
updates["openai_compact_last_error"] = truncateString(sanitizeUpstreamErrorMessage(probeErr.Error()), 2048)
case resp == nil:
updates["openai_compact_last_error"] = "compact probe failed"
default:
errMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
if errMsg == "" && len(body) > 0 {
errMsg = strings.TrimSpace(string(body))
}
if errMsg == "" && (resp.StatusCode < 200 || resp.StatusCode >= 300) {
errMsg = "HTTP " + strconv.Itoa(resp.StatusCode)
}
errMsg = truncateString(sanitizeUpstreamErrorMessage(errMsg), 2048)
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
updates["openai_compact_supported"] = true
updates["openai_compact_last_error"] = ""
} else {
if shouldMarkOpenAICompactUnsupported(resp.StatusCode, body) {
updates["openai_compact_supported"] = false
}
updates["openai_compact_last_error"] = errMsg
}
}
return updates
}
func mergeExtraUpdates(base map[string]any, more map[string]any) map[string]any {
if len(base) == 0 && len(more) == 0 {
return nil
}
out := make(map[string]any, len(base)+len(more))
for key, value := range base {
out[key] = value
}
for key, value := range more {
out[key] = value
}
return out
}
func compactProbeSessionID(accountID int64) string {
if accountID <= 0 {
return "probe_compact"
}
return "probe_compact_" + strconv.FormatInt(accountID, 10)
}

View File

@@ -0,0 +1,122 @@
package service
import (
"errors"
"net/http"
"testing"
"time"
)
func TestNormalizeAccountTestMode(t *testing.T) {
tests := []struct {
input string
want string
}{
{input: "", want: AccountTestModeDefault},
{input: "default", want: AccountTestModeDefault},
{input: " compact ", want: AccountTestModeCompact},
{input: "COMPACT", want: AccountTestModeCompact},
{input: "unknown", want: AccountTestModeDefault},
}
for _, tt := range tests {
if got := normalizeAccountTestMode(tt.input); got != tt.want {
t.Fatalf("normalizeAccountTestMode(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}
func TestBuildOpenAICompactProbeExtraUpdates_SuccessMarksSupported(t *testing.T) {
now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusOK}, []byte(`{"id":"cmp_1"}`), nil, now)
if got := updates["openai_compact_supported"]; got != true {
t.Fatalf("openai_compact_supported = %v, want true", got)
}
if got := updates["openai_compact_last_status"]; got != http.StatusOK {
t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusOK)
}
if got := updates["openai_compact_last_error"]; got != "" {
t.Fatalf("openai_compact_last_error = %v, want empty string", got)
}
if got := updates["openai_compact_checked_at"]; got != now.Format(time.RFC3339) {
t.Fatalf("openai_compact_checked_at = %v, want %s", got, now.Format(time.RFC3339))
}
}
func TestBuildOpenAICompactProbeExtraUpdates_404MarksUnsupported(t *testing.T) {
now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
body := []byte(`404 page not found`)
updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusNotFound}, body, nil, now)
if got := updates["openai_compact_supported"]; got != false {
t.Fatalf("openai_compact_supported = %v, want false", got)
}
if got := updates["openai_compact_last_status"]; got != http.StatusNotFound {
t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusNotFound)
}
}
func TestBuildOpenAICompactProbeExtraUpdates_502DoesNotMarkUnsupported(t *testing.T) {
now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusBadGateway}, []byte(`Upstream request failed`), nil, now)
if _, exists := updates["openai_compact_supported"]; exists {
t.Fatalf("did not expect openai_compact_supported for 502 response")
}
if got := updates["openai_compact_last_status"]; got != http.StatusBadGateway {
t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusBadGateway)
}
}
func TestBuildOpenAICompactProbeExtraUpdates_RequestErrorDoesNotMarkUnsupported(t *testing.T) {
now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
updates := buildOpenAICompactProbeExtraUpdates(nil, nil, errors.New("dial tcp timeout"), now)
if _, exists := updates["openai_compact_supported"]; exists {
t.Fatalf("did not expect openai_compact_supported for request error")
}
if got, exists := updates["openai_compact_last_status"]; !exists || got != nil {
t.Fatalf("openai_compact_last_status = %v, want nil key", got)
}
if got := updates["openai_compact_last_error"]; got == "" {
t.Fatalf("expected openai_compact_last_error to be populated")
}
}
func TestBuildOpenAICompactProbeExtraUpdates_NoResponseClearsLastStatus(t *testing.T) {
now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
updates := buildOpenAICompactProbeExtraUpdates(nil, nil, nil, now)
if got, exists := updates["openai_compact_last_status"]; !exists || got != nil {
t.Fatalf("openai_compact_last_status = %v, want nil key", got)
}
if got := updates["openai_compact_last_error"]; got != "compact probe failed" {
t.Fatalf("openai_compact_last_error = %v, want compact probe failed", got)
}
}
func TestBuildOpenAICompactProbeExtraUpdates_UnknownModelDoesNotMarkUnsupported(t *testing.T) {
now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
body := []byte(`{"error":{"message":"unknown model gpt-5.4-openai-compact"}}`)
updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusBadRequest}, body, nil, now)
if _, exists := updates["openai_compact_supported"]; exists {
t.Fatalf("did not expect openai_compact_supported for unknown-model diagnostics")
}
if got := updates["openai_compact_last_status"]; got != http.StatusBadRequest {
t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusBadRequest)
}
}
func TestBuildOpenAICompactProbeExtraUpdates_EmptyFailureBodyFallsBackToHTTPStatus(t *testing.T) {
now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusServiceUnavailable}, nil, nil, now)
if got := updates["openai_compact_last_status"]; got != http.StatusServiceUnavailable {
t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusServiceUnavailable)
}
if got := updates["openai_compact_last_error"]; got != "HTTP 503" {
t.Fatalf("openai_compact_last_error = %v, want HTTP 503", got)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -93,6 +93,13 @@ type cancelReadCloser struct{}
func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
func (c cancelReadCloser) Close() error { return nil }
type errReadCloser struct {
err error
}
func (r errReadCloser) Read([]byte) (int, error) { return 0, r.err }
func (r errReadCloser) Close() error { return nil }
type failingGinWriter struct {
gin.ResponseWriter
failAfter int
@@ -1003,6 +1010,150 @@ func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErr
}
}
func TestOpenAIStreamingReadErrorBeforeOutputReturnsFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: errReadCloser{err: io.ErrUnexpectedEOF},
Header: http.Header{"X-Request-Id": []string{"rid-disconnect"}},
}
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
require.False(t, c.Writer.Written())
require.Empty(t, rec.Body.String())
}
func TestOpenAIStreamingResponseFailedBeforeOutputReturnsFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
"event: response.created",
`data: {"type":"response.created","response":{"id":"resp_1"}}`,
"",
"event: response.in_progress",
`data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
"",
"event: response.failed",
`data: {"type":"response.failed","error":{"message":"An error occurred while processing your request."}}`,
"",
}, "\n"))),
Header: http.Header{"X-Request-Id": []string{"rid-failed"}},
}
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
require.Contains(t, string(failoverErr.ResponseBody), "An error occurred while processing your request")
require.False(t, c.Writer.Written())
require.Empty(t, rec.Body.String())
}
func TestOpenAIStreamingPreambleOnlyMissingTerminalReturnsFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
"event: response.created",
`data: {"type":"response.created","response":{"id":"resp_1"}}`,
"",
"event: response.in_progress",
`data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
"",
}, "\n"))),
Header: http.Header{"X-Request-Id": []string{"rid-missing-terminal"}},
}
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.False(t, c.Writer.Written())
require.Empty(t, rec.Body.String())
}
func TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
"event: response.created",
`data: {"type":"response.created","response":{"id":"resp_1"}}`,
"",
"event: response.failed",
`data: {"type":"response.failed","error":{"type":"safety_error","message":"This request has been flagged for potentially high-risk cyber activity."}}`,
"",
}, "\n"))),
Header: http.Header{"X-Request-Id": []string{"rid-policy-failed"}},
}
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.False(t, errors.As(err, &failoverErr))
require.True(t, c.Writer.Written())
require.Contains(t, rec.Body.String(), "response.failed")
require.Contains(t, rec.Body.String(), "high-risk cyber activity")
}
func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
@@ -1072,7 +1223,7 @@ func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError(t *testing.T)
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
_, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n"))
}()
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
@@ -1104,16 +1255,52 @@ func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
_, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n"))
}()
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "")
_ = pr.Close()
if err == nil || !strings.Contains(err.Error(), "missing terminal event") {
t.Fatalf("expected missing terminal event error, got %v", err)
}
}
func TestOpenAIStreamingPassthroughResponseFailedBeforeOutputReturnsFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
"event: response.created",
`data: {"type":"response.created","response":{"id":"resp_1"}}`,
"",
"event: response.failed",
`data: {"type":"response.failed","error":{"message":"upstream processing failed"}}`,
"",
}, "\n"))),
Header: http.Header{"X-Request-Id": []string{"rid-passthrough-failed"}},
}
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "", "")
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
require.Contains(t, string(failoverErr.ResponseBody), "upstream processing failed")
require.False(t, c.Writer.Written())
require.Empty(t, rec.Body.String())
}
func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
@@ -1139,7 +1326,7 @@ func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t
_, _ = pw.Write([]byte("data: {\"type\":\"response.done\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
}()
result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "")
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)

View File

@@ -1,5 +1,7 @@
package service
import "strings"
// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible
// forwarding. Group-level default mapping only applies when the account itself
// did not match any explicit model_mapping rule.
@@ -12,8 +14,47 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo
}
mappedModel, matched := account.ResolveMappedModel(requestedModel)
if !matched && defaultMappedModel != "" {
if !matched && defaultMappedModel != "" && !isExplicitCodexModel(requestedModel) {
return defaultMappedModel
}
return mappedModel
}
func isExplicitCodexModel(model string) bool {
model = strings.TrimSpace(model)
if model == "" {
return false
}
if strings.Contains(model, "/") {
parts := strings.Split(model, "/")
model = parts[len(parts)-1]
}
model = strings.ToLower(strings.TrimSpace(model))
if getNormalizedCodexModel(model) != "" {
return true
}
if strings.HasSuffix(model, "-openai-compact") {
base := strings.TrimSuffix(model, "-openai-compact")
return getNormalizedCodexModel(base) != ""
}
return false
}
// resolveOpenAICompactForwardModel determines the compact-only upstream model
// for /responses/compact requests. It never affects normal /responses traffic.
// When no compact-specific mapping matches, the input model is returned as-is.
func resolveOpenAICompactForwardModel(account *Account, model string) string {
trimmedModel := strings.TrimSpace(model)
if trimmedModel == "" || account == nil {
return trimmedModel
}
mappedModel, matched := account.ResolveCompactMappedModel(trimmedModel)
if !matched {
return trimmedModel
}
if trimmedMapped := strings.TrimSpace(mappedModel); trimmedMapped != "" {
return trimmedMapped
}
return trimmedModel
}

View File

@@ -15,10 +15,19 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "gpt-5.4",
requestedModel: "claude-opus-4-6",
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-4o-mini",
},
{
name: "preserves explicit gpt-5.4 instead of group default",
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "gpt-5.4",
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-5.4",
},
{
name: "preserves exact passthrough mapping instead of group default",
account: &Account{
@@ -58,6 +67,42 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-5.4",
},
{
name: "preserves codex spark instead of group default",
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "gpt-5.3-codex-spark",
defaultMappedModel: "gpt-5.4",
expectedModel: "gpt-5.3-codex-spark",
},
{
name: "preserves gpt-5.5 instead of group default",
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "gpt-5.5",
defaultMappedModel: "gpt-5.4",
expectedModel: "gpt-5.5",
},
{
name: "preserves openai namespaced gpt-5.5 instead of group default",
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "openai/gpt-5.5",
defaultMappedModel: "gpt-5.4",
expectedModel: "openai/gpt-5.5",
},
{
name: "preserves compact gpt-5.5 instead of group default",
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "gpt-5.5-openai-compact",
defaultMappedModel: "gpt-5.4",
expectedModel: "gpt-5.5-openai-compact",
},
}
for _, tt := range tests {
@@ -85,6 +130,74 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt54(t *
}
}
func TestResolveOpenAICompactForwardModel(t *testing.T) {
tests := []struct {
name string
account *Account
model string
expectedModel string
}{
{
name: "nil account keeps original model",
account: nil,
model: "gpt-5.4",
expectedModel: "gpt-5.4",
},
{
name: "missing compact mapping keeps original model",
account: &Account{
Credentials: map[string]any{},
},
model: "gpt-5.4",
expectedModel: "gpt-5.4",
},
{
name: "exact compact mapping overrides model",
account: &Account{
Credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4-openai-compact",
},
},
},
model: "gpt-5.4",
expectedModel: "gpt-5.4-openai-compact",
},
{
name: "wildcard compact mapping overrides model",
account: &Account{
Credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-5.*": "gpt-5-openai-compact",
},
},
},
model: "gpt-5.4",
expectedModel: "gpt-5-openai-compact",
},
{
name: "passthrough compact mapping remains unchanged",
account: &Account{
Credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4",
},
},
},
model: "gpt-5.4",
expectedModel: "gpt-5.4",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := resolveOpenAICompactForwardModel(tt.account, tt.model); got != tt.expectedModel {
t.Fatalf("resolveOpenAICompactForwardModel(...) = %q, want %q", got, tt.expectedModel)
}
})
}
}
func TestNormalizeCodexModel(t *testing.T) {
cases := map[string]string{
"gpt-5.3-codex-spark": "gpt-5.3-codex-spark",

View File

@@ -734,7 +734,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *te
require.NoError(t, err)
require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool())
require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool())
require.Equal(t, "codex_cli_rs/0.104.0", upstream.lastReq.Header.Get("User-Agent"))
require.Equal(t, "codex_cli_rs/0.125.0", upstream.lastReq.Header.Get("User-Agent"))
}
func TestOpenAIGatewayService_CodexCLIOnly_RejectsNonCodexClient(t *testing.T) {

View File

@@ -21,7 +21,7 @@ type FunctionCallOutputValidation struct {
}
// NeedsToolContinuation 判定请求是否需要工具调用续链处理。
// 满足以下任一信号即视为续链previous_response_id、input 内包含 function_call_output/item_reference、
// 满足以下任一信号即视为续链previous_response_id、input 内包含工具输出/item_reference、
// 或显式声明 tools/tool_choice。
func NeedsToolContinuation(reqBody map[string]any) bool {
if reqBody == nil {
@@ -46,7 +46,7 @@ func NeedsToolContinuation(reqBody map[string]any) bool {
continue
}
itemType, _ := itemMap["type"].(string)
if itemType == "function_call_output" || itemType == "item_reference" {
if isCodexToolCallItemType(itemType) || itemType == "item_reference" {
return true
}
}

View File

@@ -17,6 +17,9 @@ func TestNeedsToolContinuationSignals(t *testing.T) {
{name: "previous_response_id", body: map[string]any{"previous_response_id": "resp_1"}, want: true},
{name: "previous_response_id_blank", body: map[string]any{"previous_response_id": " "}, want: false},
{name: "function_call_output", body: map[string]any{"input": []any{map[string]any{"type": "function_call_output"}}}, want: true},
{name: "tool_search_output", body: map[string]any{"input": []any{map[string]any{"type": "tool_search_output"}}}, want: true},
{name: "custom_tool_call_output", body: map[string]any{"input": []any{map[string]any{"type": "custom_tool_call_output"}}}, want: true},
{name: "mcp_tool_call_output", body: map[string]any{"input": []any{map[string]any{"type": "mcp_tool_call_output"}}}, want: true},
{name: "item_reference", body: map[string]any{"input": []any{map[string]any{"type": "item_reference"}}}, want: true},
{name: "tools", body: map[string]any{"tools": []any{map[string]any{"type": "function"}}}, want: true},
{name: "tools_empty", body: map[string]any{"tools": []any{}}, want: false},

View File

@@ -37,7 +37,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit(t *testing.T
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_1", account.ID, time.Hour))
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_1", "gpt-5.1", nil)
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_1", "gpt-5.1", nil, false)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
@@ -77,7 +77,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_rl", account.ID, time.Hour))
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_rl", "gpt-5.1", nil)
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_rl", "gpt-5.1", nil, false)
require.NoError(t, err)
require.Nil(t, selection, "限额中的账号不应继续命中 previous_response_id 粘连")
boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_rl")
@@ -129,7 +129,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_DBRuntimeRecheck
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_db_rl", dbAccount.ID, time.Hour))
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_db_rl", "gpt-5.1", nil)
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_db_rl", "gpt-5.1", nil, false)
require.NoError(t, err)
require.Nil(t, selection, "DB 中已限流的账号不应继续命中 previous_response_id 粘连")
boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_db_rl")
@@ -164,7 +164,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *test
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_2", account.ID, time.Hour))
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_2", "gpt-5.1", map[int64]struct{}{account.ID: {}})
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_2", "gpt-5.1", map[int64]struct{}{account.ID: {}}, false)
require.NoError(t, err)
require.Nil(t, selection)
}
@@ -197,7 +197,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_ForceHTTPIgnored
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_force_http", account.ID, time.Hour))
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_force_http", "gpt-5.1", nil)
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_force_http", "gpt-5.1", nil, false)
require.NoError(t, err)
require.Nil(t, selection, "force_http 场景应忽略 previous_response_id 粘连")
}
@@ -258,7 +258,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_BusyKeepsSticky(
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_busy", 21, time.Hour))
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_busy", "gpt-5.1", nil)
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_busy", "gpt-5.1", nil, false)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)

View File

@@ -3800,6 +3800,7 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
previousResponseID string,
requestedModel string,
excludedIDs map[int64]struct{},
requireCompact bool,
) (*AccountSelectionResult, error) {
if s == nil {
return nil, nil
@@ -3840,11 +3841,16 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
return nil, nil
}
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
if account == nil {
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
return nil, nil
}
// 兜底:若上游 compact 能力刚被探测为不支持,但 sticky 还在,需要主动放弃。
if requireCompact && openAICompactSupportTier(account) == 0 {
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
return nil, nil
}
result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if acquireErr == nil && result.Acquired {

View File

@@ -2,6 +2,7 @@ package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
@@ -268,6 +269,7 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e
switch action {
case redeemActionSkipCompleted:
s.applyAffiliateRebateForOrder(ctx, o)
// Code already created and redeemed — just mark completed
return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
case redeemActionCreate:
@@ -281,6 +283,7 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e
if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil {
return fmt.Errorf("redeem balance: %w", err)
}
s.applyAffiliateRebateForOrder(ctx, o)
return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
}
@@ -358,6 +361,139 @@ func (s *PaymentService) hasAuditLog(ctx context.Context, orderID int64, action
return c > 0
}
func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *dbent.PaymentOrder) {
if o == nil || o.OrderType != payment.OrderTypeBalance || o.Amount <= 0 {
return
}
if s.affiliateService == nil {
return
}
tx, err := s.entClient.Tx(ctx)
if err != nil {
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
"error": fmt.Sprintf("begin affiliate rebate tx: %v", err),
})
return
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
claimed, err := s.tryClaimAffiliateRebateAudit(txCtx, tx.Client(), o.ID, o.Amount)
if err != nil {
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
"error": err.Error(),
})
return
}
if !claimed {
return
}
rebateAmount, err := s.affiliateService.AccrueInviteRebate(txCtx, o.UserID, o.Amount)
if err != nil {
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
"error": err.Error(),
})
return
}
if rebateAmount <= 0 {
if err := s.updateClaimedAffiliateRebateAudit(txCtx, tx.Client(), o.ID, "AFFILIATE_REBATE_SKIPPED", map[string]any{
"baseAmount": o.Amount,
"reason": "no inviter bound or rebate amount <= 0",
}); err != nil {
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
"error": err.Error(),
})
return
}
if err := tx.Commit(); err != nil {
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
"error": fmt.Sprintf("commit affiliate rebate tx: %v", err),
})
}
return
}
if err := s.updateClaimedAffiliateRebateAudit(txCtx, tx.Client(), o.ID, "AFFILIATE_REBATE_APPLIED", map[string]any{
"baseAmount": o.Amount,
"rebateAmount": rebateAmount,
}); err != nil {
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
"error": err.Error(),
})
return
}
if err := tx.Commit(); err != nil {
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
"error": fmt.Sprintf("commit affiliate rebate tx: %v", err),
})
}
}
func (s *PaymentService) tryClaimAffiliateRebateAudit(ctx context.Context, client *dbent.Client, orderID int64, baseAmount float64) (bool, error) {
if client == nil {
return false, errors.New("nil payment client")
}
oid := strconv.FormatInt(orderID, 10)
detail, _ := json.Marshal(map[string]any{
"baseAmount": baseAmount,
"status": "reserved",
})
rows, err := client.QueryContext(ctx, `
INSERT INTO payment_audit_logs (order_id, action, detail, operator, created_at)
SELECT $1, 'AFFILIATE_REBATE_APPLIED', $2, 'system', NOW()
WHERE NOT EXISTS (
SELECT 1
FROM payment_audit_logs
WHERE order_id = $1
AND action IN ('AFFILIATE_REBATE_APPLIED', 'AFFILIATE_REBATE_SKIPPED')
)
ON CONFLICT (order_id, action) DO NOTHING
RETURNING id`, oid, string(detail))
if err != nil {
return false, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
if err := rows.Err(); err != nil {
return false, err
}
return false, nil
}
var claimID int64
if err := rows.Scan(&claimID); err != nil {
return false, err
}
return true, nil
}
func (s *PaymentService) updateClaimedAffiliateRebateAudit(ctx context.Context, client *dbent.Client, orderID int64, action string, detail map[string]any) error {
if client == nil {
return errors.New("nil payment client")
}
oid := strconv.FormatInt(orderID, 10)
detailJSON, _ := json.Marshal(detail)
updated, err := client.PaymentAuditLog.Update().
Where(
paymentauditlog.OrderIDEQ(oid),
paymentauditlog.ActionEQ("AFFILIATE_REBATE_APPLIED"),
).
SetAction(action).
SetDetail(string(detailJSON)).
SetOperator("system").
Save(ctx)
if err != nil {
return err
}
if updated == 0 {
return errors.New("affiliate rebate claim log not found")
}
return nil
}
func (s *PaymentService) markFailed(ctx context.Context, oid int64, cause error) {
now := time.Now()
r := psErrMsg(cause)

View File

@@ -170,21 +170,22 @@ type TopUserStat struct {
// --- Service ---
type PaymentService struct {
providerMu sync.Mutex
providersLoaded bool
entClient *dbent.Client
registry *payment.Registry
loadBalancer payment.LoadBalancer
redeemService *RedeemService
subscriptionSvc *SubscriptionService
configService *PaymentConfigService
userRepo UserRepository
groupRepo GroupRepository
resumeService *PaymentResumeService
providerMu sync.Mutex
providersLoaded bool
entClient *dbent.Client
registry *payment.Registry
loadBalancer payment.LoadBalancer
redeemService *RedeemService
subscriptionSvc *SubscriptionService
configService *PaymentConfigService
userRepo UserRepository
groupRepo GroupRepository
resumeService *PaymentResumeService
affiliateService *AffiliateService
}
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository, affiliateService *AffiliateService) *PaymentService {
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo, affiliateService: affiliateService}
svc.resumeService = psNewPaymentResumeService(configService)
return svc
}

View File

@@ -931,7 +931,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// calculateOpenAI429ResetTime 从 OpenAI 429 响应头计算正确的重置时间
// 返回 nil 表示无法从响应头中确定重置时间
func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time {
func calculateOpenAI429ResetTime(headers http.Header) *time.Time {
snapshot := ParseCodexRateLimitHeaders(headers)
if snapshot == nil {
return nil
@@ -977,6 +977,10 @@ func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *tim
return nil
}
func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time {
return calculateOpenAI429ResetTime(headers)
}
// anthropic429Result holds the parsed Anthropic 429 rate-limit information.
type anthropic429Result struct {
resetAt time.Time // The correct reset time to use for SetRateLimited

View File

@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"log/slog"
"math"
"net/url"
"sort"
"strconv"
@@ -1167,6 +1168,8 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
// 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
settings.AffiliateRebateRate = clampAffiliateRebateRate(settings.AffiliateRebateRate)
updates[SettingKeyAffiliateRebateRate] = strconv.FormatFloat(settings.AffiliateRebateRate, 'f', 8, 64)
updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit)
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
if err != nil {
@@ -1719,6 +1722,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyOIDCConnectUserInfoUsernamePath: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeyAffiliateRebateRate: strconv.FormatFloat(AffiliateRebateRateDefault, 'f', 8, 64),
SettingKeyDefaultUserRPMLimit: "0",
SettingKeyDefaultSubscriptions: "[]",
SettingKeyAuthSourceDefaultEmailBalance: "0",
@@ -1846,6 +1850,11 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
} else {
result.DefaultBalance = s.cfg.Default.UserBalance
}
if rebateRate, err := strconv.ParseFloat(settings[SettingKeyAffiliateRebateRate], 64); err == nil {
result.AffiliateRebateRate = clampAffiliateRebateRate(rebateRate)
} else {
result.AffiliateRebateRate = AffiliateRebateRateDefault
}
result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions])
// 敏感信息直接返回,方便测试连接时使用
@@ -2130,6 +2139,19 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
return result
}
func clampAffiliateRebateRate(value float64) float64 {
if math.IsNaN(value) || math.IsInf(value, 0) {
return AffiliateRebateRateDefault
}
if value < AffiliateRebateRateMin {
return AffiliateRebateRateMin
}
if value > AffiliateRebateRateMax {
return AffiliateRebateRateMax
}
return value
}
func isFalseSettingValue(value string) bool {
switch strings.ToLower(strings.TrimSpace(value)) {
case "false", "0", "off", "disabled":

View File

@@ -106,6 +106,7 @@ type SystemSettings struct {
DefaultConcurrency int
DefaultBalance float64
AffiliateRebateRate float64
DefaultUserRPMLimit int
DefaultSubscriptions []DefaultSubscriptionSetting

View File

@@ -486,6 +486,7 @@ var ProviderSet = wire.NewSet(
NewGroupCapacityService,
NewChannelService,
NewModelPricingResolver,
NewAffiliateService,
ProvidePaymentConfigService,
NewPaymentService,
ProvidePaymentOrderExpiryService,

View File

@@ -0,0 +1,20 @@
CREATE TABLE IF NOT EXISTS user_affiliates (
user_id BIGINT PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE,
aff_code VARCHAR(32) NOT NULL UNIQUE,
inviter_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL,
aff_count INTEGER NOT NULL DEFAULT 0,
aff_quota DECIMAL(20,8) NOT NULL DEFAULT 0,
aff_history_quota DECIMAL(20,8) NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_user_affiliates_inviter_id ON user_affiliates(inviter_id);
CREATE INDEX IF NOT EXISTS idx_user_affiliates_aff_quota ON user_affiliates(aff_quota);
COMMENT ON TABLE user_affiliates IS '用户邀请返利信息';
COMMENT ON COLUMN user_affiliates.aff_code IS '用户邀请代码';
COMMENT ON COLUMN user_affiliates.inviter_id IS '邀请人用户ID';
COMMENT ON COLUMN user_affiliates.aff_count IS '累计邀请人数';
COMMENT ON COLUMN user_affiliates.aff_quota IS '当前可提取返利金额';
COMMENT ON COLUMN user_affiliates.aff_history_quota IS '累计返利历史金额';

View File

@@ -0,0 +1,58 @@
-- 1) Normalize historical affiliate rebate rate values.
-- Legacy compatibility treated 0<x<=1 as fractional inputs (e.g. 0.2 => 20%).
-- We now use pure percentage semantics, so convert persisted fractional values once.
UPDATE settings
SET value = to_char((value::numeric * 100), 'FM999999990.########'),
updated_at = NOW()
WHERE key = 'affiliate_rebate_rate'
AND value ~ '^-?[0-9]+(\\.[0-9]+)?$'
AND value::numeric > 0
AND value::numeric <= 1;
-- 2) Affiliate ledger for accrual/transfer traceability.
CREATE TABLE IF NOT EXISTS user_affiliate_ledger (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
action VARCHAR(32) NOT NULL,
amount DECIMAL(20,8) NOT NULL,
source_user_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_user_id ON user_affiliate_ledger(user_id);
CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_action ON user_affiliate_ledger(action);
COMMENT ON TABLE user_affiliate_ledger IS '邀请返利资金流水(累计/转入)';
COMMENT ON COLUMN user_affiliate_ledger.action IS 'accrue|transfer';
-- 3) Enforce idempotency at DB layer for payment audit actions.
WITH ranked AS (
SELECT id,
ROW_NUMBER() OVER (PARTITION BY order_id, action ORDER BY id) AS rn
FROM payment_audit_logs
)
DELETE FROM payment_audit_logs p
USING ranked r
WHERE p.id = r.id
AND r.rn > 1;
CREATE UNIQUE INDEX IF NOT EXISTS idx_payment_audit_logs_order_action_uniq
ON payment_audit_logs(order_id, action);
-- 4) Prevent retroactive affiliate rebate issuance for legacy completed balance orders.
INSERT INTO payment_audit_logs (order_id, action, detail, operator, created_at)
SELECT po.id::text,
'AFFILIATE_REBATE_SKIPPED',
'{"reason":"baseline before affiliate rebate idempotency rollout"}',
'system',
NOW()
FROM payment_orders po
WHERE po.order_type = 'balance'
AND po.status = 'COMPLETED'
AND NOT EXISTS (
SELECT 1
FROM payment_audit_logs pal
WHERE pal.order_id = po.id::text
AND pal.action IN ('AFFILIATE_REBATE_APPLIED', 'AFFILIATE_REBATE_SKIPPED')
);

View File

@@ -308,6 +308,7 @@ export interface SystemSettings {
totp_encryption_key_configured: boolean; // TOTP 加密密钥是否已配置
// Default settings
default_balance: number;
affiliate_rebate_rate: number;
default_concurrency: number;
default_user_rpm_limit: number;
default_subscriptions: DefaultSubscriptionSetting[];
@@ -489,6 +490,7 @@ export interface UpdateSettingsRequest {
invitation_code_enabled?: boolean;
totp_enabled?: boolean; // TOTP 双因素认证
default_balance?: number;
affiliate_rebate_rate?: number;
default_concurrency?: number;
default_user_rpm_limit?: number;
default_subscriptions?: DefaultSubscriptionSetting[];

View File

@@ -9,7 +9,14 @@ import {
prepareOAuthBindAccessTokenCookie,
type WeChatOAuthPublicSettings,
} from './auth'
import type { User, ChangePasswordRequest, NotifyEmailEntry, UserAuthProvider } from '@/types'
import type {
User,
ChangePasswordRequest,
NotifyEmailEntry,
UserAuthProvider,
UserAffiliateDetail,
AffiliateTransferResponse
} from '@/types'
/**
* Get current user profile
@@ -168,6 +175,16 @@ export async function startOAuthBinding(
window.location.href = startURL
}
export async function getAffiliateDetail(): Promise<UserAffiliateDetail> {
const { data } = await apiClient.get<UserAffiliateDetail>('/user/aff')
return data
}
export async function transferAffiliateQuota(): Promise<AffiliateTransferResponse> {
const { data } = await apiClient.post<AffiliateTransferResponse>('/user/aff/transfer')
return data
}
export const userAPI = {
getProfile,
updateProfile,
@@ -180,7 +197,9 @@ export const userAPI = {
bindEmailIdentity,
unbindAuthIdentity,
buildOAuthBindingStartURL,
startOAuthBinding
startOAuthBinding,
getAffiliateDetail,
transferAffiliateQuota
}
export default userAPI

View File

@@ -55,6 +55,17 @@
/>
</div>
<div v-if="isOpenAIAccount" class="space-y-1.5">
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.accounts.openai.testMode') }}
</label>
<Select
v-model="testMode"
:options="openAITestModeOptions"
:disabled="status === 'connecting'"
/>
</div>
<div v-if="supportsImageTest" class="space-y-1.5">
<TextArea
v-model="testPrompt"
@@ -274,6 +285,12 @@ const testPrompt = ref('')
const loadingModels = ref(false)
let abortController: AbortController | null = null
const generatedImages = ref<PreviewImage[]>([])
const testMode = ref<'default' | 'compact'>('default')
const isOpenAIAccount = computed(() => props.account?.platform === 'openai')
const openAITestModeOptions = computed(() => [
{ value: 'default', label: t('admin.accounts.openai.testModeDefault') },
{ value: 'compact', label: t('admin.accounts.openai.testModeCompact') }
])
const previewImageUrl = ref('')
const prioritizedGeminiModels = ['gemini-3.1-flash-image', 'gemini-2.5-flash-image', 'gemini-2.5-flash', 'gemini-2.5-pro', 'gemini-3-flash-preview', 'gemini-3-pro-preview', 'gemini-2.0-flash']
const supportsGeminiImageTest = computed(() => {
@@ -308,6 +325,7 @@ watch(
async (newVal) => {
if (newVal && props.account) {
testPrompt.value = ''
testMode.value = 'default'
resetState()
await loadAvailableModels()
} else {
@@ -410,9 +428,10 @@ const startTest = async () => {
'Content-Type': 'application/json'
},
body: JSON.stringify({
model_id: selectedModelId.value,
prompt: supportsImageTest.value ? testPrompt.value.trim() : ''
}),
model_id: selectedModelId.value,
prompt: supportsImageTest.value ? testPrompt.value.trim() : '',
mode: isOpenAIAccount.value ? testMode.value : 'default'
}),
signal: abortController.signal
})

View File

@@ -2449,6 +2449,45 @@
</div>
</div>
<!-- OpenAI Compact 能力配置 -->
<div
v-if="form.platform === 'openai' && (accountCategory === 'oauth-based' || accountCategory === 'apikey')"
class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4"
>
<div class="flex items-center justify-between">
<div>
<label class="input-label mb-0">{{ t('admin.accounts.openai.compactMode') }}</label>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.openai.compactModeDesc') }}
</p>
</div>
<div class="w-44">
<Select v-model="openAICompactMode" :options="openAICompactModeOptions" />
</div>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.openai.compactModelMapping') }}</label>
<p class="input-hint">{{ t('admin.accounts.openai.compactModelMappingDesc') }}</p>
<div v-if="openAICompactModelMappings.length > 0" class="mb-3 space-y-2">
<div
v-for="(mapping, index) in openAICompactModelMappings"
:key="getOpenAICompactModelMappingKey(mapping)"
class="flex items-center gap-2"
>
<input v-model="mapping.from" type="text" class="input flex-1" :placeholder="t('admin.accounts.fromModel')" />
<span class="text-gray-400"></span>
<input v-model="mapping.to" type="text" class="input flex-1" :placeholder="t('admin.accounts.toModel')" />
<button type="button" @click="removeOpenAICompactModelMapping(index)" class="text-red-500 hover:text-red-700">
<Icon name="trash" size="sm" />
</button>
</div>
</div>
<button type="button" @click="addOpenAICompactModelMapping" class="btn btn-secondary text-sm">
+ {{ t('admin.accounts.addMapping') }}
</button>
</div>
</div>
<div>
<div class="flex items-center justify-between">
<div>
@@ -2918,7 +2957,8 @@ import type {
AccountPlatform,
AccountType,
CheckMixedChannelResponse,
CreateAccountRequest
CreateAccountRequest,
OpenAICompactMode
} from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
@@ -3059,6 +3099,7 @@ const editWeeklyResetDay = ref<number | null>(null)
const editWeeklyResetHour = ref<number | null>(null)
const editResetTimezone = ref<string | null>(null)
const modelMappings = ref<ModelMapping[]>([])
const openAICompactModelMappings = ref<ModelMapping[]>([])
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
const allowedModels = ref<string[]>([])
const DEFAULT_POOL_MODE_RETRY_COUNT = 3
@@ -3071,6 +3112,7 @@ const customErrorCodeInput = ref<number | null>(null)
const interceptWarmupRequests = ref(false)
const autoPauseOnExpired = ref(true)
const openaiPassthroughEnabled = ref(false)
const openAICompactMode = ref<OpenAICompactMode>('auto')
const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
const openaiAPIKeyResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
const codexCLIOnlyEnabled = ref(false)
@@ -3112,10 +3154,16 @@ const bedrockApiKeyValue = ref('')
const tempUnschedEnabled = ref(false)
const tempUnschedRules = ref<TempUnschedRuleForm[]>([])
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-model-mapping')
const getOpenAICompactModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-openai-compact-model-mapping')
const getAntigravityModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-antigravity-model-mapping')
const getTempUnschedRuleKey = createStableObjectKeyResolver<TempUnschedRuleForm>('create-temp-unsched-rule')
const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('google_one')
const geminiAIStudioOAuthEnabled = ref(false)
const openAICompactModeOptions = computed(() => [
{ value: 'auto', label: t('admin.accounts.openai.compactModeAuto') },
{ value: 'force_on', label: t('admin.accounts.openai.compactModeForceOn') },
{ value: 'force_off', label: t('admin.accounts.openai.compactModeForceOff') }
])
function buildAntigravityExtra(): Record<string, unknown> | undefined {
const extra: Record<string, unknown> = {}
@@ -3124,6 +3172,9 @@ function buildAntigravityExtra(): Record<string, unknown> | undefined {
return Object.keys(extra).length > 0 ? extra : undefined
}
const buildOpenAICompactModelMapping = () =>
buildModelMappingObject('mapping', [], openAICompactModelMappings.value)
const showMixedChannelWarning = ref(false)
const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(
null
@@ -3489,6 +3540,14 @@ const addModelMapping = () => {
modelMappings.value.push({ from: '', to: '' })
}
const addOpenAICompactModelMapping = () => {
openAICompactModelMappings.value.push({ from: '', to: '' })
}
const removeOpenAICompactModelMapping = (index: number) => {
openAICompactModelMappings.value.splice(index, 1)
}
const removeModelMapping = (index: number) => {
modelMappings.value.splice(index, 1)
}
@@ -3781,6 +3840,7 @@ const resetForm = () => {
editWeeklyResetHour.value = null
editResetTimezone.value = null
modelMappings.value = []
openAICompactModelMappings.value = []
modelRestrictionMode.value = 'whitelist'
allowedModels.value = [...claudeModels] // Default fill related models
@@ -3797,6 +3857,7 @@ const resetForm = () => {
interceptWarmupRequests.value = false
autoPauseOnExpired.value = true
openaiPassthroughEnabled.value = false
openAICompactMode.value = 'auto'
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
codexCLIOnlyEnabled.value = false
@@ -3874,6 +3935,11 @@ const buildOpenAIExtra = (base?: Record<string, unknown>): Record<string, unknow
} else {
delete extra.codex_cli_only
}
if (openAICompactMode.value !== 'auto') {
extra.openai_compact_mode = openAICompactMode.value
} else {
delete extra.openai_compact_mode
}
return Object.keys(extra).length > 0 ? extra : undefined
}
@@ -4086,6 +4152,12 @@ const handleSubmit = async () => {
credentials.model_mapping = modelMapping
}
}
if (form.platform === 'openai') {
const compactModelMapping = buildOpenAICompactModelMapping()
if (compactModelMapping) {
credentials.compact_model_mapping = compactModelMapping
}
}
// Add pool mode if enabled
if (poolModeEnabled.value) {
@@ -4198,6 +4270,14 @@ const createAccountAndFinish = async (
finalExtra = quotaExtra
}
}
if (platform === 'openai') {
const compactModelMapping = buildOpenAICompactModelMapping()
if (compactModelMapping) {
credentials.compact_model_mapping = compactModelMapping
} else {
delete credentials.compact_model_mapping
}
}
await doCreateAccount({
name: form.name,
notes: form.notes,
@@ -4252,6 +4332,12 @@ const handleOpenAIExchange = async (authCode: string) => {
credentials.model_mapping = modelMapping
}
}
if (shouldCreateOpenAI) {
const compactModelMapping = buildOpenAICompactModelMapping()
if (compactModelMapping) {
credentials.compact_model_mapping = compactModelMapping
}
}
// 应用临时不可调度配置
if (!applyTempUnschedConfig(credentials)) {
@@ -4344,6 +4430,12 @@ const handleOpenAIBatchRT = async (refreshTokenInput: string, clientId?: string)
credentials.model_mapping = modelMapping
}
}
if (shouldCreateOpenAI) {
const compactModelMapping = buildOpenAICompactModelMapping()
if (compactModelMapping) {
credentials.compact_model_mapping = compactModelMapping
}
}
// Generate account name; fallback to email if name is empty (ent schema requires NotEmpty)
const baseName = form.name || tokenInfo.email || 'OpenAI OAuth Account'

View File

@@ -1306,6 +1306,64 @@
</div>
</div>
<div
v-if="account?.platform === 'openai' && (account?.type === 'oauth' || account?.type === 'apikey')"
class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4"
>
<div class="flex items-center justify-between">
<div>
<label class="input-label mb-0">{{ t('admin.accounts.openai.compactMode') }}</label>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.openai.compactModeDesc') }}
</p>
</div>
<div class="w-44">
<Select v-model="openAICompactMode" :options="openAICompactModeOptions" />
</div>
</div>
<div class="rounded-lg bg-gray-50 px-3 py-2 text-xs text-gray-600 dark:bg-dark-700 dark:text-gray-300">
<span class="font-medium">{{ t(openAICompactStatusKey) }}</span>
<span
v-if="account?.extra?.openai_compact_checked_at"
class="ml-2 text-gray-500 dark:text-gray-400"
>
{{ t('admin.accounts.openai.compactLastChecked') }}:
{{ formatDateTime(new Date(String(account.extra.openai_compact_checked_at))) }}
</span>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.openai.compactModelMapping') }}</label>
<p class="input-hint">{{ t('admin.accounts.openai.compactModelMappingDesc') }}</p>
<div v-if="openAICompactModelMappings.length > 0" class="mb-3 space-y-2">
<div
v-for="(mapping, index) in openAICompactModelMappings"
:key="getOpenAICompactModelMappingKey(mapping)"
class="flex items-center gap-2"
>
<input
v-model="mapping.from"
type="text"
class="input flex-1"
:placeholder="t('admin.accounts.fromModel')"
/>
<span class="text-gray-400"></span>
<input
v-model="mapping.to"
type="text"
class="input flex-1"
:placeholder="t('admin.accounts.toModel')"
/>
<button type="button" @click="removeOpenAICompactModelMapping(index)" class="text-red-500 hover:text-red-700">
<Icon name="trash" size="sm" />
</button>
</div>
</div>
<button type="button" @click="addOpenAICompactModelMapping" class="btn btn-secondary text-sm">
+ {{ t('admin.accounts.addMapping') }}
</button>
</div>
</div>
<div>
<div class="flex items-center justify-between">
<div>
@@ -1849,7 +1907,7 @@ import { useAppStore } from '@/stores/app'
import { useAuthStore } from '@/stores/auth'
import { adminAPI } from '@/api/admin'
import { useQuotaNotifyState } from '@/composables/useQuotaNotifyState'
import type { Account, Proxy, AdminGroup, CheckMixedChannelResponse } from '@/types'
import type { Account, Proxy, AdminGroup, CheckMixedChannelResponse, OpenAICompactMode } from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
import Select from '@/components/common/Select.vue'
@@ -1859,7 +1917,7 @@ import GroupSelector from '@/components/common/GroupSelector.vue'
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue'
import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
import { formatDateTime, formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
import {
OPENAI_WS_MODE_CTX_POOL,
@@ -1934,6 +1992,7 @@ const isBedrockAPIKeyMode = computed(() =>
(props.account?.credentials as Record<string, unknown>)?.auth_mode === 'apikey'
)
const modelMappings = ref<ModelMapping[]>([])
const openAICompactModelMappings = ref<ModelMapping[]>([])
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
const allowedModels = ref<string[]>([])
const DEFAULT_POOL_MODE_RETRY_COUNT = 3
@@ -1953,6 +2012,7 @@ const antigravityModelMappings = ref<ModelMapping[]>([])
const tempUnschedEnabled = ref(false)
const tempUnschedRules = ref<TempUnschedRuleForm[]>([])
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('edit-model-mapping')
const getOpenAICompactModelMappingKey = createStableObjectKeyResolver<ModelMapping>('edit-openai-compact-model-mapping')
const getAntigravityModelMappingKey = createStableObjectKeyResolver<ModelMapping>('edit-antigravity-model-mapping')
const getTempUnschedRuleKey = createStableObjectKeyResolver<TempUnschedRuleForm>('edit-temp-unsched-rule')
@@ -1992,6 +2052,7 @@ const customBaseUrl = ref('')
// OpenAI 自动透传开关OAuth/API Key
const openaiPassthroughEnabled = ref(false)
const openAICompactMode = ref<OpenAICompactMode>('auto')
const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
const openaiAPIKeyResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
const codexCLIOnlyEnabled = ref(false)
@@ -2045,9 +2106,27 @@ const openaiResponsesWebSocketV2Mode = computed({
const openAIWSModeConcurrencyHintKey = computed(() =>
resolveOpenAIWSModeConcurrencyHintKey(openaiResponsesWebSocketV2Mode.value)
)
const openAICompactModeOptions = computed(() => [
{ value: 'auto', label: t('admin.accounts.openai.compactModeAuto') },
{ value: 'force_on', label: t('admin.accounts.openai.compactModeForceOn') },
{ value: 'force_off', label: t('admin.accounts.openai.compactModeForceOff') }
])
const isOpenAIModelRestrictionDisabled = computed(() =>
props.account?.platform === 'openai' && openaiPassthroughEnabled.value
)
const openAICompactStatusKey = computed(() => {
const extra = props.account?.extra as Record<string, unknown> | undefined
if (!props.account || props.account.platform !== 'openai') return ''
const mode = typeof extra?.openai_compact_mode === 'string' ? extra.openai_compact_mode : 'auto'
if (mode === 'force_on') return 'admin.accounts.openai.compactSupported'
if (mode === 'force_off') return 'admin.accounts.openai.compactUnsupported'
if (typeof extra?.openai_compact_supported === 'boolean') {
return extra.openai_compact_supported
? 'admin.accounts.openai.compactSupported'
: 'admin.accounts.openai.compactUnsupported'
}
return 'admin.accounts.openai.compactUnknown'
})
// Computed: current preset mappings based on platform
const presetMappings = computed(() => getPresetMappingsByPlatform(props.account?.platform || 'anthropic'))
@@ -2177,6 +2256,8 @@ const syncFormFromAccount = (newAccount: Account | null) => {
// Load OpenAI passthrough toggle (OpenAI OAuth/API Key)
openaiPassthroughEnabled.value = false
openAICompactMode.value = 'auto'
openAICompactModelMappings.value = []
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
codexCLIOnlyEnabled.value = false
@@ -2184,6 +2265,7 @@ const syncFormFromAccount = (newAccount: Account | null) => {
webSearchEmulationMode.value = 'default'
if (newAccount.platform === 'openai' && (newAccount.type === 'oauth' || newAccount.type === 'apikey')) {
openaiPassthroughEnabled.value = extra?.openai_passthrough === true || extra?.openai_oauth_passthrough === true
openAICompactMode.value = (extra?.openai_compact_mode as OpenAICompactMode) || 'auto'
openaiOAuthResponsesWebSocketV2Mode.value = resolveOpenAIWSModeFromExtra(extra, {
modeKey: 'openai_oauth_responses_websockets_v2_mode',
enabledKey: 'openai_oauth_responses_websockets_v2_enabled',
@@ -2199,6 +2281,11 @@ const syncFormFromAccount = (newAccount: Account | null) => {
if (newAccount.type === 'oauth') {
codexCLIOnlyEnabled.value = extra?.codex_cli_only === true
}
const credentials = newAccount.credentials as Record<string, unknown> | undefined
const compactMappings = credentials?.compact_model_mapping as Record<string, string> | undefined
if (compactMappings && typeof compactMappings === 'object') {
openAICompactModelMappings.value = Object.entries(compactMappings).map(([from, to]) => ({ from, to }))
}
}
if (newAccount.platform === 'anthropic' && newAccount.type === 'apikey') {
anthropicPassthroughEnabled.value = extra?.anthropic_passthrough === true
@@ -2423,6 +2510,15 @@ const syncFormFromAccount = (newAccount: Account | null) => {
editApiKey.value = ''
}
async function loadTLSProfiles() {
try {
const profiles = await adminAPI.tlsFingerprintProfiles.list()
tlsFingerprintProfiles.value = profiles.map(p => ({ id: p.id, name: p.name }))
} catch {
tlsFingerprintProfiles.value = []
}
}
watch(
[() => props.show, () => props.account],
([show, newAccount], [wasShow, previousAccount]) => {
@@ -2437,15 +2533,6 @@ watch(
{ immediate: true }
)
const loadTLSProfiles = async () => {
try {
const profiles = await adminAPI.tlsFingerprintProfiles.list()
tlsFingerprintProfiles.value = profiles.map(p => ({ id: p.id, name: p.name }))
} catch {
tlsFingerprintProfiles.value = []
}
}
// Model mapping helpers
const addModelMapping = () => {
modelMappings.value.push({ from: '', to: '' })
@@ -2468,6 +2555,14 @@ const addAntigravityModelMapping = () => {
antigravityModelMappings.value.push({ from: '', to: '' })
}
const addOpenAICompactModelMapping = () => {
openAICompactModelMappings.value.push({ from: '', to: '' })
}
const removeOpenAICompactModelMapping = (index: number) => {
openAICompactModelMappings.value.splice(index, 1)
}
const removeAntigravityModelMapping = (index: number) => {
antigravityModelMappings.value.splice(index, 1)
}
@@ -2911,6 +3006,14 @@ const handleSubmit = async () => {
} else if (currentCredentials.model_mapping) {
newCredentials.model_mapping = currentCredentials.model_mapping
}
if (props.account.platform === 'openai') {
const compactModelMapping = buildModelMappingObject('mapping', [], openAICompactModelMappings.value)
if (compactModelMapping) {
newCredentials.compact_model_mapping = compactModelMapping
} else {
delete newCredentials.compact_model_mapping
}
}
// Add pool mode if enabled
if (poolModeEnabled.value) {
@@ -3036,6 +3139,12 @@ const handleSubmit = async () => {
// 透传模式保留现有映射
newCredentials.model_mapping = currentCredentials.model_mapping
}
const compactModelMapping = buildModelMappingObject('mapping', [], openAICompactModelMappings.value)
if (compactModelMapping) {
newCredentials.compact_model_mapping = compactModelMapping
} else {
delete newCredentials.compact_model_mapping
}
updatePayload.credentials = newCredentials
}
@@ -3208,6 +3317,11 @@ const handleSubmit = async () => {
delete newExtra.openai_passthrough
delete newExtra.openai_oauth_passthrough
}
if (openAICompactMode.value === 'auto') {
delete newExtra.openai_compact_mode
} else {
newExtra.openai_compact_mode = openAICompactMode.value
}
if (props.account.type === 'oauth') {
if (codexCLIOnlyEnabled.value) {

View File

@@ -122,7 +122,7 @@ describe('AccountStatusIndicator', () => {
}
})
expect(wrapper.text()).toContain('account.creditsExhausted')
expect(wrapper.text()).toContain('admin.accounts.status.creditsExhausted')
})
it('模型限流 + overages 启用 + AICredits key 生效 → 普通限流样式(积分耗尽,无 ⚡)', () => {
@@ -157,6 +157,6 @@ describe('AccountStatusIndicator', () => {
expect(wrapper.text()).toContain('CSon45')
expect(wrapper.text()).not.toContain('⚡')
// AICredits 积分耗尽状态应显示
expect(wrapper.text()).toContain('account.creditsExhausted')
expect(wrapper.text()).toContain('admin.accounts.status.creditsExhausted')
})
})

View File

@@ -0,0 +1,150 @@
import { describe, expect, it, vi, beforeEach, afterEach } from 'vitest'
import { flushPromises, mount } from '@vue/test-utils'
import { defineComponent } from 'vue'
import AccountTestModal from '../AccountTestModal.vue'
const { getAvailableModelsMock } = vi.hoisted(() => ({
getAvailableModelsMock: vi.fn()
}))
vi.mock('@/api/admin', () => ({
adminAPI: {
accounts: {
getAvailableModels: getAvailableModelsMock
}
}
}))
vi.mock('@/composables/useClipboard', () => ({
useClipboard: () => ({
copyToClipboard: vi.fn()
})
}))
vi.mock('vue-i18n', async () => {
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
return {
...actual,
useI18n: () => ({
t: (key: string) => key
})
}
})
const BaseDialogStub = defineComponent({
name: 'BaseDialog',
props: { show: { type: Boolean, default: false } },
template: '<div v-if="show"><slot /><slot name="footer" /></div>'
})
const SelectStub = defineComponent({
name: 'SelectStub',
props: {
modelValue: { type: [String, Number, Boolean, null], default: '' },
options: { type: Array, default: () => [] },
valueKey: { type: String, default: 'value' },
labelKey: { type: String, default: 'label' }
},
emits: ['update:modelValue'],
template: `
<select
v-bind="$attrs"
:value="modelValue"
@change="$emit('update:modelValue', $event.target.value)"
>
<option
v-for="option in options"
:key="option[valueKey]"
:value="option[valueKey]"
>
{{ option[labelKey] }}
</option>
</select>
`
})
const TextAreaStub = defineComponent({
name: 'TextArea',
props: {
modelValue: { type: String, default: '' }
},
emits: ['update:modelValue'],
template: `
<textarea
v-bind="$attrs"
:value="modelValue"
@input="$emit('update:modelValue', $event.target.value)"
/>
`
})
function buildAccount() {
return {
id: 1,
name: 'OpenAI OAuth',
platform: 'openai',
type: 'oauth',
status: 'active',
credentials: {},
extra: {},
concurrency: 1,
priority: 1,
proxy_id: null,
auto_pause_on_expired: false
} as any
}
describe('AccountTestModal', () => {
const originalFetch = global.fetch
beforeEach(() => {
getAvailableModelsMock.mockReset()
getAvailableModelsMock.mockResolvedValue([
{ id: 'gpt-5.4', display_name: 'GPT-5.4' }
])
global.fetch = vi.fn().mockResolvedValue({
ok: true,
body: {
getReader: () => ({
read: vi.fn().mockResolvedValue({ done: true, value: undefined })
})
}
} as any)
localStorage.setItem('auth_token', 'test-token')
})
afterEach(() => {
global.fetch = originalFetch
localStorage.clear()
})
it('posts compact mode for OpenAI compact probe', async () => {
const wrapper = mount(AccountTestModal, {
props: {
show: true,
account: buildAccount()
},
global: {
stubs: {
BaseDialog: BaseDialogStub,
Select: SelectStub,
TextArea: TextAreaStub,
Icon: true
}
}
})
await flushPromises()
;(wrapper.vm as any).selectedModelId = 'gpt-5.4'
;(wrapper.vm as any).testMode = 'compact'
await (wrapper.vm as any).startTest()
await flushPromises()
expect(global.fetch).toHaveBeenCalledTimes(1)
const [, options] = (global.fetch as any).mock.calls[0]
expect(JSON.parse(options.body)).toMatchObject({
model_id: 'gpt-5.4',
mode: 'compact'
})
})
})

View File

@@ -26,6 +26,13 @@ vi.mock('@/api/admin', () => ({
accounts: {
update: updateAccountMock,
checkMixedChannelRisk: checkMixedChannelRiskMock
},
settings: {
getWebSearchEmulationConfig: vi.fn().mockResolvedValue({ enabled: false, providers: [] }),
getSettings: vi.fn().mockResolvedValue({})
},
tlsFingerprintProfiles: {
list: vi.fn().mockResolvedValue([])
}
}
}))
@@ -82,6 +89,32 @@ const ModelWhitelistSelectorStub = defineComponent({
`
})
const SelectStub = defineComponent({
name: 'SelectStub',
props: {
modelValue: {
type: [String, Number, Boolean, null],
default: ''
},
options: {
type: Array,
default: () => []
}
},
emits: ['update:modelValue'],
template: `
<select
v-bind="$attrs"
:value="modelValue"
@change="$emit('update:modelValue', $event.target.value)"
>
<option v-for="option in options" :key="option.value" :value="option.value">
{{ option.label }}
</option>
</select>
`
})
function buildAccount() {
return {
id: 1,
@@ -119,7 +152,7 @@ function mountModal(account = buildAccount()) {
global: {
stubs: {
BaseDialog: BaseDialogStub,
Select: true,
Select: SelectStub,
Icon: true,
ProxySelector: true,
GroupSelector: true,
@@ -156,4 +189,31 @@ describe('EditAccountModal', () => {
'gpt-5.2': 'gpt-5.2'
})
})
it('submits OpenAI compact mode and compact-only model mapping', async () => {
const account = buildAccount()
account.extra = {
openai_compact_mode: 'force_on'
}
account.credentials = {
...account.credentials,
compact_model_mapping: {
'gpt-5.4': 'gpt-5.4-openai-compact'
}
}
updateAccountMock.mockReset()
checkMixedChannelRiskMock.mockReset()
checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false })
updateAccountMock.mockResolvedValue(account)
const wrapper = mountModal(account)
await wrapper.get('form#edit-account-form').trigger('submit.prevent')
expect(updateAccountMock).toHaveBeenCalledTimes(1)
expect(updateAccountMock.mock.calls[0]?.[1]?.extra?.openai_compact_mode).toBe('force_on')
expect(updateAccountMock.mock.calls[0]?.[1]?.credentials?.compact_model_mapping).toEqual({
'gpt-5.4': 'gpt-5.4-openai-compact'
})
})
})

View File

@@ -656,6 +656,7 @@ function buildSelfNavItems(withDashboard: boolean): NavItem[] {
{ path: '/purchase', label: t('nav.buySubscription'), icon: RechargeSubscriptionIcon, hideInSimpleMode: true, featureFlag: flagPayment },
{ path: '/orders', label: t('nav.myOrders'), icon: OrderListIcon, hideInSimpleMode: true, featureFlag: flagPayment },
{ path: '/redeem', label: t('nav.redeem'), icon: GiftIcon, hideInSimpleMode: true },
{ path: '/affiliate', label: t('nav.affiliate'), icon: UsersIcon, hideInSimpleMode: true },
{ path: '/profile', label: t('nav.profile'), icon: UserIcon },
...customMenuItemsForUser.value.map((item): NavItem => ({
path: `/custom/${item.id}`,

View File

@@ -33,7 +33,7 @@ function createOrderResult(overrides: Partial<CreateOrderResult> = {}): CreateOr
}
describe('getVisibleMethods', () => {
it('filters hidden provider methods and normalizes aliases', () => {
it('normalizes provider aliases and keeps stripe as a top-level method', () => {
const visible = getVisibleMethods({
alipay_direct: methodLimit({ single_min: 5 }),
wxpay: methodLimit({ single_max: 100 }),
@@ -43,6 +43,7 @@ describe('getVisibleMethods', () => {
expect(visible).toEqual({
alipay: methodLimit({ single_min: 5 }),
wxpay: methodLimit({ single_max: 100 }),
stripe: methodLimit({ fee_rate: 3 }),
})
})
@@ -76,6 +77,19 @@ describe('decidePaymentLaunch', () => {
expect(decision.recovery.outTradeNo).toBe('')
})
it('routes Stripe button click to the full Payment Element without a preselected sub-method', () => {
const decision = decidePaymentLaunch(createOrderResult({
client_secret: 'cs_test',
}), {
visibleMethod: 'stripe',
orderType: 'balance',
isMobile: false,
})
expect(decision.kind).toBe('stripe_route')
expect(decision.stripeMethod).toBeUndefined()
})
it('uses Stripe route flow for mobile WeChat client secret', () => {
const decision = decidePaymentLaunch(createOrderResult({
client_secret: 'cs_test',

View File

@@ -14,9 +14,10 @@ const VISIBLE_METHOD_ALIASES = {
alipay_direct: 'alipay',
wxpay: 'wxpay',
wxpay_direct: 'wxpay',
stripe: 'stripe',
} as const
export type VisiblePaymentMethod = 'alipay' | 'wxpay'
export type VisiblePaymentMethod = 'alipay' | 'wxpay' | 'stripe'
export type StripeVisibleMethod = 'alipay' | 'wechat_pay'
export type PaymentLaunchKind =
| 'qr_waiting'
@@ -144,7 +145,12 @@ export function decidePaymentLaunch(
}, context.now)
if (baseState.clientSecret) {
const stripeMethod: StripeVisibleMethod = visibleMethod === 'wxpay' ? 'wechat_pay' : 'alipay'
// visibleMethod === 'stripe' means the user clicked the dedicated Stripe button
// and should land on the full Payment Element to choose a sub-method themselves.
const isStripeButton = visibleMethod === 'stripe'
const stripeMethod: StripeVisibleMethod | undefined = isStripeButton
? undefined
: visibleMethod === 'wxpay' ? 'wechat_pay' : 'alipay'
const kind: PaymentLaunchKind = stripeMethod === 'alipay' && !context.isMobile
? 'stripe_popup'
: 'stripe_route'

View File

@@ -4,15 +4,6 @@
// OpenAI
const openaiModels = [
'gpt-3.5-turbo', 'gpt-3.5-turbo-0125', 'gpt-3.5-turbo-1106', 'gpt-3.5-turbo-16k',
'gpt-4', 'gpt-4-turbo', 'gpt-4-turbo-preview',
'gpt-4o', 'gpt-4o-2024-08-06', 'gpt-4o-2024-11-20',
'gpt-4o-mini', 'gpt-4o-mini-2024-07-18',
'gpt-4.5-preview',
'gpt-4.1', 'gpt-4.1-mini', 'gpt-4.1-nano',
'o1', 'o1-preview', 'o1-mini', 'o1-pro',
'o3', 'o3-mini', 'o3-pro',
'o4-mini',
// GPT-5.2 系列
'gpt-5.2', 'gpt-5.2-2025-12-11', 'gpt-5.2-chat-latest',
'gpt-5.2-pro', 'gpt-5.2-pro-2025-12-11',
@@ -22,7 +13,6 @@ const openaiModels = [
'gpt-5.4', 'gpt-5.4-mini', 'gpt-5.4-2026-03-05',
// GPT-5.3 系列
'gpt-5.3-codex', 'gpt-5.3-codex-spark',
'chatgpt-4o-latest',
'gpt-4o-audio-preview', 'gpt-4o-realtime-preview',
// GPT Image 系列
'gpt-image-1', 'gpt-image-1.5', 'gpt-image-2'
@@ -32,7 +22,6 @@ const openaiModels = [
export const claudeModels = [
'claude-3-5-sonnet-20241022', 'claude-3-5-sonnet-20240620',
'claude-3-5-haiku-20241022',
'claude-3-opus-20240229', 'claude-3-sonnet-20240229', 'claude-3-haiku-20240307',
'claude-3-7-sonnet-20250219',
'claude-sonnet-4-20250514', 'claude-opus-4-20250514',
'claude-opus-4-1-20250805',
@@ -40,8 +29,7 @@ export const claudeModels = [
'claude-opus-4-5-20251101',
'claude-opus-4-6',
'claude-opus-4-7',
'claude-sonnet-4-6',
'claude-2.1', 'claude-2.0', 'claude-instant-1.2'
'claude-sonnet-4-6'
]
// Google Gemini

View File

@@ -346,6 +346,7 @@ export default {
apiKeys: 'API Keys',
usage: 'Usage',
redeem: 'Redeem',
affiliate: 'Affiliate Rebates',
profile: 'Profile',
users: 'Users',
groups: 'Groups',
@@ -972,6 +973,47 @@ export default {
}
},
affiliate: {
title: 'Affiliate Rebates',
description: 'Invite new users and convert your rebate quota into account balance',
yourCode: 'Your Affiliate Code',
inviteLink: 'Invite Link',
copyCode: 'Copy Code',
copyLink: 'Copy Link',
codeCopied: 'Affiliate code copied',
linkCopied: 'Invite link copied',
loadFailed: 'Failed to load affiliate data',
transferFailed: 'Failed to transfer affiliate quota',
stats: {
invitedUsers: 'Invited Users',
availableQuota: 'Available Rebate Quota',
totalQuota: 'Historical Rebate Quota'
},
transfer: {
title: 'Transfer Rebate Quota',
description: 'Move available rebate quota into your account balance',
button: 'Transfer to Balance',
transferring: 'Transferring...',
empty: 'No available rebate quota',
success: '{amount} has been transferred to your balance'
},
invitees: {
title: 'Invited Users',
empty: 'No invited users yet',
columns: {
email: 'Email',
username: 'Username',
joinedAt: 'Joined At'
}
},
tips: {
title: 'How It Works',
line1: 'Share your affiliate code or invite link with new users.',
line2: 'When invitees recharge, you receive rebate quota based on the configured rate.',
line3: 'Transfer rebate quota to balance at any time.'
}
},
// Redeem
redeem: {
title: 'Redeem Code',
@@ -2806,6 +2848,22 @@ export default {
codexCLIOnly: 'Codex official clients only',
codexCLIOnlyDesc:
'Only applies to OpenAI OAuth. When enabled, only Codex official client families are allowed; when disabled, the gateway bypasses this restriction and keeps existing behavior.',
compactMode: 'Compact mode',
compactModeDesc:
'Controls how this account participates in /responses/compact routing. Auto follows probe results, Force On always allows, Force Off always excludes.',
compactModeAuto: 'Auto',
compactModeForceOn: 'Force On',
compactModeForceOff: 'Force Off',
compactModelMapping: 'Compact-only model mapping',
compactModelMappingDesc:
'Only applies to /responses/compact. Use this when the upstream compact endpoint requires a special compact model.',
compactSupported: 'Compact supported',
compactUnsupported: 'Compact unsupported',
compactUnknown: 'Compact unknown',
compactLastChecked: 'Last compact probe',
testMode: 'Test mode',
testModeDefault: 'Default request',
testModeCompact: 'Compact probe',
modelRestrictionDisabledByPassthrough: 'Automatic passthrough is enabled: model whitelist/mapping will not take effect.',
},
anthropic: {
@@ -4837,6 +4895,9 @@ export default {
description: 'Default values for new users',
defaultBalance: 'Default Balance',
defaultBalanceHint: 'Initial balance for new users',
affiliateRebateRate: 'Affiliate Rebate Rate',
affiliateRebateRateHint:
'Rebate percentage credited to inviter after recharge (0-100%, e.g. 10 means 10%)',
defaultConcurrency: 'Default Concurrency',
defaultConcurrencyHint: 'Maximum concurrent requests for new users',
defaultUserRpmLimit: 'Default User RPM Limit',

View File

@@ -346,6 +346,7 @@ export default {
apiKeys: 'API 密钥',
usage: '使用记录',
redeem: '兑换',
affiliate: '邀请返利',
profile: '个人资料',
users: '用户管理',
groups: '分组管理',
@@ -976,6 +977,47 @@ export default {
}
},
affiliate: {
title: '邀请返利',
description: '邀请新用户注册,并将返利额度转入账户余额',
yourCode: '我的邀请码',
inviteLink: '邀请链接',
copyCode: '复制邀请码',
copyLink: '复制链接',
codeCopied: '邀请码已复制',
linkCopied: '邀请链接已复制',
loadFailed: '加载邀请返利数据失败',
transferFailed: '转入余额失败',
stats: {
invitedUsers: '邀请人数',
availableQuota: '可转返利额度',
totalQuota: '历史返利额度'
},
transfer: {
title: '返利额度转余额',
description: '将当前可用返利额度一键转入账户余额',
button: '转入余额',
transferring: '转入中...',
empty: '当前没有可转入额度',
success: '已转入余额:{amount}'
},
invitees: {
title: '已邀请用户',
empty: '暂无邀请记录',
columns: {
email: '邮箱',
username: '用户名',
joinedAt: '注册时间'
}
},
tips: {
title: '使用说明',
line1: '将邀请码或邀请链接分享给新用户。',
line2: '被邀请用户充值后,你可获得对应比例的返利额度。',
line3: '返利额度可随时转入账户余额。'
}
},
// Redeem
redeem: {
title: '兑换码',
@@ -2951,6 +2993,22 @@ export default {
responsesWebsocketsV2PassthroughHint: '当前已开启自动透传:仅影响 HTTP 透传链路,不影响 WS mode。',
codexCLIOnly: '仅允许 Codex 官方客户端',
codexCLIOnlyDesc: '仅对 OpenAI OAuth 生效。开启后仅允许 Codex 官方客户端家族访问;关闭后完全绕过并保持原逻辑。',
compactMode: 'Compact 模式',
compactModeDesc:
'控制本账号在 /responses/compact 调度中的参与方式。Auto 跟随探测结果Force On 强制允许Force Off 强制排除。',
compactModeAuto: '自动',
compactModeForceOn: '强制开启',
compactModeForceOff: '强制关闭',
compactModelMapping: 'Compact 专属模型映射',
compactModelMappingDesc:
'仅在 /responses/compact 请求中生效。当上游 compact 端点需要特殊 compact 模型时使用。',
compactSupported: '支持 Compact',
compactUnsupported: '不支持 Compact',
compactUnknown: 'Compact 未知',
compactLastChecked: '最近探测',
testMode: '测试模式',
testModeDefault: '常规请求',
testModeCompact: 'Compact 探测',
modelRestrictionDisabledByPassthrough: '已开启自动透传:模型白名单/映射不会生效。',
},
anthropic: {
@@ -5000,6 +5058,8 @@ export default {
description: '新用户的默认值',
defaultBalance: '默认余额',
defaultBalanceHint: '新用户的初始余额',
affiliateRebateRate: '邀请返利比例',
affiliateRebateRateHint: '充值后返给邀请人的比例0-100%,例如填写 10 表示返利 10%',
defaultConcurrency: '默认并发数',
defaultConcurrencyHint: '新用户的最大并发请求数',
defaultUserRpmLimit: '默认用户 RPM 限制',

View File

@@ -197,6 +197,18 @@ const routes: RouteRecordRaw[] = [
descriptionKey: 'redeem.description'
}
},
{
path: '/affiliate',
name: 'Affiliate',
component: () => import('@/views/user/AffiliateView.vue'),
meta: {
requiresAuth: true,
requiresAdmin: false,
title: 'Affiliate',
titleKey: 'affiliate.title',
descriptionKey: 'affiliate.description'
}
},
{
path: '/available-channels',
name: 'UserAvailableChannels',

View File

@@ -122,6 +122,29 @@ export interface RegisterRequest {
turnstile_token?: string
promo_code?: string
invitation_code?: string
aff_code?: string
}
export interface AffiliateInvitee {
user_id: number
email: string
username: string
created_at?: string
}
export interface UserAffiliateDetail {
user_id: number
aff_code: string
inviter_id?: number | null
aff_count: number
aff_quota: number
aff_history_quota: number
invitees: AffiliateInvitee[]
}
export interface AffiliateTransferResponse {
transferred_quota: number
balance: number
}
export interface SendVerifyCodeRequest {
@@ -744,8 +767,8 @@ export interface Account {
platform: AccountPlatform
type: AccountType
credentials?: Record<string, unknown>
// Extra fields including Codex usage and model-level rate limits (Antigravity smart retry)
extra?: (CodexUsageSnapshot & {
// Extra fields including Codex usage, OpenAI compact capability, and model-level rate limits.
extra?: (CodexUsageSnapshot & OpenAICompactState & {
model_rate_limits?: Record<string, { rate_limited_at: string; rate_limit_reset_at: string }>
antigravity_credits_overages?: Record<string, { activated_at: string; active_until: string }>
} & Record<string, unknown>)
@@ -917,6 +940,16 @@ export interface CodexUsageSnapshot {
codex_usage_updated_at?: string // Last update timestamp
}
export type OpenAICompactMode = 'auto' | 'force_on' | 'force_off'
export interface OpenAICompactState {
openai_compact_mode?: OpenAICompactMode
openai_compact_supported?: boolean
openai_compact_checked_at?: string
openai_compact_last_status?: number
openai_compact_last_error?: string
}
export interface CreateAccountRequest {
name: string
notes?: string | null

View File

@@ -188,6 +188,13 @@
<template #cell-platform_type="{ row }">
<div class="flex flex-wrap items-center gap-1">
<PlatformTypeBadge :platform="row.platform" :type="row.type" :plan-type="row.credentials?.plan_type" :privacy-mode="row.extra?.privacy_mode" :subscription-expires-at="row.credentials?.subscription_expires_at" />
<span
v-if="getOpenAICompactLabel(row)"
:class="['inline-block rounded px-1.5 py-0.5 text-[10px] font-medium', getOpenAICompactClass(row)]"
:title="getOpenAICompactTitle(row)"
>
{{ getOpenAICompactLabel(row) }}
</span>
<span
v-if="getAntigravityTierLabel(row)"
:class="['inline-block rounded px-1.5 py-0.5 text-[10px] font-medium', getAntigravityTierClass(row)]"
@@ -932,6 +939,43 @@ function getAntigravityTierLabel(row: any): string | null {
}
}
function getOpenAICompactState(row: any): 'supported' | 'unsupported' | 'unknown' | null {
if (row.platform !== 'openai' || (row.type !== 'oauth' && row.type !== 'apikey')) return null
const extra = row.extra as Record<string, unknown> | undefined
const mode = typeof extra?.openai_compact_mode === 'string' ? extra.openai_compact_mode : 'auto'
if (mode === 'force_on') return 'supported'
if (mode === 'force_off') return 'unsupported'
if (typeof extra?.openai_compact_supported === 'boolean') {
return extra.openai_compact_supported ? 'supported' : 'unsupported'
}
return 'unknown'
}
function getOpenAICompactLabel(row: any): string | null {
switch (getOpenAICompactState(row)) {
case 'supported': return t('admin.accounts.openai.compactSupported')
case 'unsupported': return t('admin.accounts.openai.compactUnsupported')
case 'unknown': return t('admin.accounts.openai.compactUnknown')
default: return null
}
}
function getOpenAICompactClass(row: any): string {
switch (getOpenAICompactState(row)) {
case 'supported': return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/40 dark:text-emerald-300'
case 'unsupported': return 'bg-rose-100 text-rose-700 dark:bg-rose-900/40 dark:text-rose-300'
case 'unknown': return 'bg-amber-100 text-amber-700 dark:bg-amber-900/40 dark:text-amber-300'
default: return ''
}
}
function getOpenAICompactTitle(row: any): string {
const extra = row.extra as Record<string, unknown> | undefined
const checkedAt = typeof extra?.openai_compact_checked_at === 'string' ? extra.openai_compact_checked_at : ''
if (!checkedAt) return getOpenAICompactLabel(row) || ''
return `${getOpenAICompactLabel(row)} | ${t('admin.accounts.openai.compactLastChecked')}: ${formatDateTime(new Date(checkedAt))}`
}
function getAntigravityTierClass(row: any): string {
const tier = getAntigravityTierFromRow(row)
switch (tier) {

View File

@@ -2153,6 +2153,31 @@
{{ t("admin.settings.defaults.defaultBalanceHint") }}
</p>
</div>
<div>
<label
class="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300"
>
{{ t("admin.settings.defaults.affiliateRebateRate") }}
</label>
<div class="relative">
<input
v-model.number="form.affiliate_rebate_rate"
type="number"
step="0.01"
min="0"
max="100"
class="input pr-8"
placeholder="20"
/>
<span
class="pointer-events-none absolute right-3 top-1/2 -translate-y-1/2 text-gray-400"
>%</span
>
</div>
<p class="mt-1.5 text-xs text-gray-500 dark:text-gray-400">
{{ t("admin.settings.defaults.affiliateRebateRateHint") }}
</p>
</div>
<div>
<label
class="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300"
@@ -4972,6 +4997,7 @@ const form = reactive<SettingsForm>({
totp_enabled: false,
totp_encryption_key_configured: false,
default_balance: 0,
affiliate_rebate_rate: 20,
default_concurrency: 1,
default_subscriptions: [],
force_email_on_third_party_signup: false,
@@ -5894,6 +5920,10 @@ async function saveSettings() {
password_reset_enabled: form.password_reset_enabled,
totp_enabled: form.totp_enabled,
default_balance: form.default_balance,
affiliate_rebate_rate: Math.min(
100,
Math.max(0, Number(form.affiliate_rebate_rate) || 0),
),
default_concurrency: form.default_concurrency,
default_subscriptions: normalizedDefaultSubscriptions,
force_email_on_third_party_signup: form.force_email_on_third_party_signup,

View File

@@ -209,6 +209,7 @@ const password = ref<string>('')
const initialTurnstileToken = ref<string>('')
const promoCode = ref<string>('')
const invitationCode = ref<string>('')
const affCode = ref<string>('')
const pendingAuthToken = ref<string>('')
const pendingAuthTokenField = ref<PendingAuthTokenField>('pending_auth_token')
const pendingProvider = ref<string>('')
@@ -260,6 +261,7 @@ onMounted(async () => {
initialTurnstileToken.value = registerData.turnstile_token || ''
promoCode.value = registerData.promo_code || ''
invitationCode.value = registerData.invitation_code || ''
affCode.value = registerData.aff_code || ''
pendingAuthToken.value = registerData.pending_auth_token || activePendingSession?.token || ''
pendingAuthTokenField.value = registerData.pending_auth_token_field || activePendingSession?.token_field || 'pending_auth_token'
pendingProvider.value = registerData.pending_provider || activePendingSession?.provider || ''
@@ -524,7 +526,8 @@ async function handleVerify(): Promise<void> {
verify_code: verifyCode.value.trim(),
turnstile_token: initialTurnstileToken.value || undefined,
promo_code: promoCode.value || undefined,
invitation_code: invitationCode.value || undefined
invitation_code: invitationCode.value || undefined,
...(affCode.value ? { aff_code: affCode.value } : {})
})
}

View File

@@ -351,7 +351,8 @@ const formData = reactive({
email: '',
password: '',
promo_code: '',
invitation_code: ''
invitation_code: '',
aff_code: ''
})
const errors = reactive({
@@ -406,6 +407,10 @@ onMounted(async () => {
await validatePromoCodeDebounced(promoParam)
}
}
const affParam = (route.query.aff as string) || (route.query.aff_code as string)
if (affParam) {
formData.aff_code = affParam.trim()
}
} catch (error) {
console.error('Failed to load public settings:', error)
} finally {
@@ -707,7 +712,8 @@ async function handleRegister(): Promise<void> {
password: formData.password,
turnstile_token: turnstileToken.value,
promo_code: formData.promo_code || undefined,
invitation_code: formData.invitation_code || undefined
invitation_code: formData.invitation_code || undefined,
...(formData.aff_code ? { aff_code: formData.aff_code } : {})
})
)
@@ -722,7 +728,8 @@ async function handleRegister(): Promise<void> {
password: formData.password,
turnstile_token: turnstileEnabled.value ? turnstileToken.value : undefined,
promo_code: formData.promo_code || undefined,
invitation_code: formData.invitation_code || undefined
invitation_code: formData.invitation_code || undefined,
...(formData.aff_code ? { aff_code: formData.aff_code } : {})
})
// Show success toast

View File

@@ -0,0 +1,201 @@
<template>
<AppLayout>
<div class="space-y-6">
<div v-if="loading" class="flex justify-center py-12">
<div
class="h-8 w-8 animate-spin rounded-full border-2 border-primary-500 border-t-transparent"
></div>
</div>
<template v-else-if="detail">
<div class="grid gap-4 md:grid-cols-3">
<div class="card p-5">
<p class="text-sm text-gray-500 dark:text-dark-400">{{ t('affiliate.stats.invitedUsers') }}</p>
<p class="mt-2 text-2xl font-semibold text-gray-900 dark:text-white">
{{ formatCount(detail.aff_count) }}
</p>
</div>
<div class="card p-5">
<p class="text-sm text-gray-500 dark:text-dark-400">{{ t('affiliate.stats.availableQuota') }}</p>
<p class="mt-2 text-2xl font-semibold text-emerald-600 dark:text-emerald-400">
{{ formatCurrency(detail.aff_quota) }}
</p>
</div>
<div class="card p-5">
<p class="text-sm text-gray-500 dark:text-dark-400">{{ t('affiliate.stats.totalQuota') }}</p>
<p class="mt-2 text-2xl font-semibold text-gray-900 dark:text-white">
{{ formatCurrency(detail.aff_history_quota) }}
</p>
</div>
</div>
<div class="card p-6">
<h3 class="text-base font-semibold text-gray-900 dark:text-white">{{ t('affiliate.title') }}</h3>
<p class="mt-1 text-sm text-gray-500 dark:text-dark-400">{{ t('affiliate.description') }}</p>
<div class="mt-5 grid gap-4 md:grid-cols-2">
<div class="space-y-2">
<p class="text-sm font-medium text-gray-700 dark:text-gray-300">{{ t('affiliate.yourCode') }}</p>
<div class="flex items-center gap-2 rounded-xl border border-gray-200 bg-gray-50 px-3 py-2 dark:border-dark-700 dark:bg-dark-900">
<code class="flex-1 truncate text-sm font-semibold text-gray-900 dark:text-white">{{ detail.aff_code }}</code>
<button class="btn btn-secondary btn-sm" @click="copyCode">
<Icon name="copy" size="sm" />
<span>{{ t('affiliate.copyCode') }}</span>
</button>
</div>
</div>
<div class="space-y-2">
<p class="text-sm font-medium text-gray-700 dark:text-gray-300">{{ t('affiliate.inviteLink') }}</p>
<div class="flex items-center gap-2 rounded-xl border border-gray-200 bg-gray-50 px-3 py-2 dark:border-dark-700 dark:bg-dark-900">
<code class="flex-1 truncate text-sm text-gray-700 dark:text-gray-300">{{ inviteLink }}</code>
<button class="btn btn-secondary btn-sm" @click="copyInviteLink">
<Icon name="copy" size="sm" />
<span>{{ t('affiliate.copyLink') }}</span>
</button>
</div>
</div>
</div>
<div class="mt-5 rounded-xl border border-primary-200 bg-primary-50 p-4 dark:border-primary-900/40 dark:bg-primary-900/20">
<p class="text-sm font-medium text-primary-800 dark:text-primary-200">{{ t('affiliate.tips.title') }}</p>
<ul class="mt-2 space-y-1 text-sm text-primary-700 dark:text-primary-300">
<li>1. {{ t('affiliate.tips.line1') }}</li>
<li>2. {{ t('affiliate.tips.line2') }}</li>
<li>3. {{ t('affiliate.tips.line3') }}</li>
</ul>
</div>
</div>
<div class="card p-6">
<div class="flex flex-col gap-3 sm:flex-row sm:items-center sm:justify-between">
<div>
<h3 class="text-base font-semibold text-gray-900 dark:text-white">{{ t('affiliate.transfer.title') }}</h3>
<p class="mt-1 text-sm text-gray-500 dark:text-dark-400">{{ t('affiliate.transfer.description') }}</p>
</div>
<button
class="btn btn-primary"
:disabled="transferring || detail.aff_quota <= 0"
@click="transferQuota"
>
<Icon v-if="transferring" name="refresh" size="sm" class="animate-spin" />
<Icon v-else name="dollar" size="sm" />
<span>{{ transferring ? t('affiliate.transfer.transferring') : t('affiliate.transfer.button') }}</span>
</button>
</div>
<p v-if="detail.aff_quota <= 0" class="mt-3 text-sm text-amber-600 dark:text-amber-400">
{{ t('affiliate.transfer.empty') }}
</p>
</div>
<div class="card p-6">
<h3 class="text-base font-semibold text-gray-900 dark:text-white">{{ t('affiliate.invitees.title') }}</h3>
<div v-if="detail.invitees.length === 0" class="mt-4 rounded-xl border border-dashed border-gray-300 p-6 text-center text-sm text-gray-500 dark:border-dark-700 dark:text-dark-400">
{{ t('affiliate.invitees.empty') }}
</div>
<div v-else class="mt-4 overflow-x-auto">
<table class="w-full min-w-[560px] text-left text-sm">
<thead>
<tr class="border-b border-gray-200 text-gray-500 dark:border-dark-700 dark:text-dark-400">
<th class="px-3 py-2 font-medium">{{ t('affiliate.invitees.columns.email') }}</th>
<th class="px-3 py-2 font-medium">{{ t('affiliate.invitees.columns.username') }}</th>
<th class="px-3 py-2 font-medium">{{ t('affiliate.invitees.columns.joinedAt') }}</th>
</tr>
</thead>
<tbody>
<tr
v-for="item in detail.invitees"
:key="item.user_id"
class="border-b border-gray-100 last:border-b-0 dark:border-dark-800"
>
<td class="px-3 py-3 text-gray-900 dark:text-white">{{ item.email || '-' }}</td>
<td class="px-3 py-3 text-gray-700 dark:text-gray-300">{{ item.username || '-' }}</td>
<td class="px-3 py-3 text-gray-700 dark:text-gray-300">{{ formatDateTime(item.created_at) || '-' }}</td>
</tr>
</tbody>
</table>
</div>
</div>
</template>
</div>
</AppLayout>
</template>
<script setup lang="ts">
import { computed, onMounted, ref } from 'vue'
import { useI18n } from 'vue-i18n'
import AppLayout from '@/components/layout/AppLayout.vue'
import Icon from '@/components/icons/Icon.vue'
import userAPI from '@/api/user'
import type { UserAffiliateDetail } from '@/types'
import { useAppStore } from '@/stores/app'
import { useAuthStore } from '@/stores/auth'
import { useClipboard } from '@/composables/useClipboard'
import { formatCurrency, formatDateTime } from '@/utils/format'
import { extractApiErrorMessage } from '@/utils/apiError'
const { t } = useI18n()
const appStore = useAppStore()
const authStore = useAuthStore()
const { copyToClipboard } = useClipboard()
const loading = ref(true)
const transferring = ref(false)
const detail = ref<UserAffiliateDetail | null>(null)
const inviteLink = computed(() => {
if (!detail.value) return ''
if (typeof window === 'undefined') return `/register?aff=${encodeURIComponent(detail.value.aff_code)}`
return `${window.location.origin}/register?aff=${encodeURIComponent(detail.value.aff_code)}`
})
function formatCount(value: number): string {
return value.toLocaleString()
}
async function loadAffiliateDetail(silent = false): Promise<void> {
if (!silent) {
loading.value = true
}
try {
detail.value = await userAPI.getAffiliateDetail()
} catch (error) {
appStore.showError(extractApiErrorMessage(error, t('affiliate.loadFailed')))
} finally {
if (!silent) {
loading.value = false
}
}
}
async function copyCode(): Promise<void> {
if (!detail.value?.aff_code) return
await copyToClipboard(detail.value.aff_code, t('affiliate.codeCopied'))
}
async function copyInviteLink(): Promise<void> {
if (!inviteLink.value) return
await copyToClipboard(inviteLink.value, t('affiliate.linkCopied'))
}
async function transferQuota(): Promise<void> {
if (!detail.value || detail.value.aff_quota <= 0 || transferring.value) return
transferring.value = true
try {
const resp = await userAPI.transferAffiliateQuota()
appStore.showSuccess(t('affiliate.transfer.success', { amount: formatCurrency(resp.transferred_quota) }))
await Promise.all([
loadAffiliateDetail(true),
authStore.refreshUser().catch(() => undefined),
])
} catch (error) {
appStore.showError(extractApiErrorMessage(error, t('affiliate.transferFailed')))
} finally {
transferring.value = false
}
}
onMounted(() => {
void loadAffiliateDetail()
})
</script>

Some files were not shown because too many files have changed in this diff Show More