Compare commits

...

51 Commits

Author SHA1 Message Date
Wesley Liddick
a0b5e5bfa0 Merge pull request #1973 from Nobody-Zhang/main
fix(payment): 修复 Zpay 退款接口调用
2026-04-26 13:11:42 +08:00
Wesley Liddick
41d0657330 Merge pull request #1970 from deqiying/fix-1754-claude-openai-cache-usage
fix(anthropic): 修正缓存 token 的 Anthropic 用量语义
2026-04-26 13:08:18 +08:00
Nobody-Zhang
1a0cabbfd6 Fix Zpay refund endpoint handling 2026-04-26 04:57:34 +00:00
shaw
9b6dcc57bd feat(affiliate): 完善邀请返利系统
- 修复返利不到账的根因:tryClaimAffiliateRebateAudit 中 PostgreSQL 参数类型推断冲突
  - 补全 OAuth 注册路径(LinuxDo/OIDC/WeChat/Pending Flow)的邀请码绑定
  - 前端 OAuth 注册页面传递 aff_code 参数
  - 新增返利冻结期机制:可配置冻结时间,到期后自动解冻(懒解冻)
  - 新增返利有效期:绑定后 N 天内有效,过期不再产生返利
  - 新增单人返利上限:超出上限部分精确截断
  - 增强返利流程 slog 结构化日志,便于排查问题
  - 已邀请用户列表增加返利明细列
2026-04-26 12:42:35 +08:00
deqiying
b17704d6ef fix(anthropic): 修正缓存 token 的 Anthropic 用量语义 2026-04-26 01:14:59 +08:00
shaw
496469ac4e fix(gateway): skip body mimicry for real Claude Code clients to restore prompt caching
PR #1914 unconditionally applied the full mimicry pipeline to all OAuth
accounts, including real Claude Code CLI clients. This replaced the
client's long system prompt (~10K+ tokens with stable cache_control
breakpoints) with a short ~45 token [billing, CC prompt] pair, which
falls below Anthropic's 1024-token minimum cacheable prefix threshold.
The result: every request created a new cache but never hit an existing
one.

Fix: restore the Claude Code client detection gate so that real CC
clients bypass body-level mimicry (system rewrite, message cache
management, tool name obfuscation). Non-CC third-party clients
(opencode, etc.) continue to receive full mimicry.

Also harden the detection logic:
- Make UA regex case-insensitive (align with claude_code_validator.go)
- Validate metadata.user_id format via ParseMetadataUserID() instead of
  just checking non-empty, preventing third-party tools from spoofing
  a claude-cli/* UA with an arbitrary user_id string to bypass mimicry
2026-04-25 22:50:35 +08:00
shaw
c1b52615be fix(payment): allow Stripe payment pages to bypass router auth guard
Stripe payment routes (/payment/stripe, /payment/stripe-popup) are
reached via hard navigation (window.location.href), which caused
the router guard to block access before the page could load.
Set requiresAuth and requiresPayment to false, consistent with
/payment/result. Backend API still enforces authentication.
2026-04-25 21:38:40 +08:00
shaw
3af9940b85 style: fix gofmt and ineffassign lint errors
- gofmt: realign AffiliateDetail struct tags in affiliate_service.go
- ineffassign: remove dead seenCompleted assignment before return in account_test_service.go
2026-04-25 20:37:42 +08:00
Wesley Liddick
22b1277572 Merge pull request #1948 from hungryboy1025/fix/openai-account-test-responses-stream
fix(openai): tighten responses stream account tests
2026-04-25 20:31:07 +08:00
Wesley Liddick
aff98d5ae1 Merge pull request #1960 from gaoren002/fix/openai-stream-keepalive-downstream-idle
fix(openai): keep responses stream alive during pre-output failover
2026-04-25 20:24:25 +08:00
shaw
4e1bb2b445 feat(affiliate): add feature toggle and per-user custom invite settings
- 在系统设置「功能开关」中新增邀请返利总开关,默认关闭;
  关闭态:菜单隐藏、注册忽略 aff、新充值不返利,但已有 quota 仍可转余额
- 支持管理员为指定用户设置专属邀请码(覆盖随机码,全局唯一)
- 支持管理员为指定用户设置专属返利比例(覆盖全局比例,可单条/批量调整)
- 在系统设置邀请返利卡片内嵌入专属用户管理表格(搜索/编辑/批量/删除),
  删除采用项目通用 ConfirmDialog,会同时清除专属比例并把邀请码重置为系统随机码
- /affiliate 用户页新增「我的返利比例」卡片与动态使用说明,让用户直观看到
  分享后能拿到多少(同源 resolveRebateRatePercent 计算,与实际充值一致)
- 新增数据库迁移 132 添加 aff_rebate_rate_percent 与 aff_code_custom 列
- 新增 admin 路由组 /api/v1/admin/affiliates/users/* 共 5 个端点
- AffiliateService 改为只依赖 *SettingService,去除冗余的 SettingRepository
- 邀请码格式校验放宽到 [A-Z0-9_-]{4,32},兼容旧 12 位系统码与新自定义码
- 补充单元测试与集成测试覆盖新方法、冲突路径与边界值
2026-04-25 20:22:07 +08:00
gaoren002
dac6e52091 fix(openai): keep responses stream alive during pre-output failover 2026-04-25 12:11:27 +00:00
hungryboy1025
8987e0ba67 fix(openai): tighten responses stream account tests 2026-04-25 16:56:50 +08:00
github-actions[bot]
9d1751ec57 chore: sync VERSION to 0.1.118 [skip ci] 2026-04-25 08:06:21 +00:00
Wesley Liddick
5d1c12e60e Merge pull request #1943 from AyeSt0/fix/openai-responses-preoutput-failover
fix(openai): 修复 Responses 流式失败前置事件导致无法 failover
2026-04-25 15:43:00 +08:00
AyeSt0
5b63a9b02d fix(openai): fail over before responses stream output 2026-04-25 15:09:40 +08:00
Wesley Liddick
641e61073f Merge pull request #1940 from 4fuu/fix/bump-codex-cli-version-to-0.125.0
fix(openai): bump codex CLI version from 0.104.0 to 0.125.0
2026-04-25 14:57:51 +08:00
shaw
095f457c57 feat(openai): port /responses/compact account support flow (PR #1555)
vansour/sub2api#1555 的 OpenAI compact 能力建模手工移植到当前 main:账号
级 compact 状态/auto-force_on-force_off 模式、compact-only 模型映射、调度器
tier 分层(已支持 > 未知 > 已知不支持)、管理后台 compact 主动探测,以及对应
i18n/状态徽章。普通 /responses 流量行为不变,无数据库迁移。
2026-04-25 14:52:58 +08:00
4fuu
1e57e88e43 fix(openai): bump codex CLI version from 0.104.0 to 0.125.0
The hardcoded codex CLI version (0.104.0) causes upstream rejection
when using gpt-5.5 with compact, as the server treats the request
as an outdated client and returns 400/502.

Update codexCLIVersion, codexCLIUserAgent, and openAICodexProbeVersion
to 0.125.0 to match the current Codex CLI release.

Fixes #1933, #1887, #1865
Related: #1609, #1298, #849
2026-04-25 05:26:33 +00:00
Wesley Liddick
b95ffce244 Merge pull request #1772 from KnowSky404/fix/openai-test-state-reconciliation
[codex] reconcile OpenAI admin test rate-limit state
2026-04-25 10:02:21 +08:00
shaw
8f28a834f8 fix(payment): 同时启用易支付和 Stripe 时显示 Stripe 按钮
VISIBLE_METHOD_ALIASES 漏了 stripe,导致 getVisibleMethods 把后端返回
的 stripe 过滤掉。点 Stripe 按钮时省略 method 查询参数,让落地页渲染
完整的 Payment Element。
2026-04-25 09:46:27 +08:00
shaw
7424c73b05 chore: remove unused model IDs 2026-04-25 09:04:34 +08:00
Wesley Liddick
1afd81b019 Merge pull request #1920 from Wuxie233/fix/responses-web-search-tool-types
fix(apicompat): recognize web_search_20250305 / google_search in Responses→Anthropic tool conversion
2026-04-25 09:00:37 +08:00
shaw
732d6495ea chore(gateway): fix lint issues from cc-mimicry-parity merge
- staticcheck QF1001: apply De Morgan's law to the OAuth-mimic header
  passthrough guard (`!(a && b)` → `a != ... || !b`).
- unused: drop `isClaudeCodeRequest`, which became dead after PR #1914
  switched both `/v1/messages` and `/count_tokens` paths to unconditional
  `account.IsOAuth()` mimicry. The lowercase helper `isClaudeCodeClient`
  is kept (still referenced by `TestIsClaudeCodeClient`).
2026-04-25 08:58:57 +08:00
Wesley Liddick
6d20ab8082 Merge pull request #1914 from keh4l/feat/cc-mimicry-parity
fix(claude): align Claude Code OAuth mimicry with real CLI traffic
2026-04-25 08:54:04 +08:00
shaw
aa8ee33b0a refactor(affiliate): tighten DI and harden inviter code validation
- Drop SetAffiliateService setters and ProvideAuthService /
  ProvidePaymentService / ProvideUserHandler wrappers in favor of direct
  Wire constructor injection. AffiliateService has no back-edge to
  Auth/Payment/User, so the indirection was never required.
- Change RegisterWithVerification's variadic affiliateCode to a fixed
  parameter; adjust all call sites.
- Validate aff_code length and charset in BindInviterByCode before any
  DB lookup, eliminating timing-side-channel and useless DB roundtrips
  on malformed input.
- Make affiliate cache invalidation synchronous; surface Redis errors
  via the project logger instead of swallowing them in a detached
  goroutine.
- Add an integration test guarding cross-layer tx propagation in
  AccrueQuota and a unit test pinning the aff_code format rules.
2026-04-25 08:44:18 +08:00
Wuxie233
5f630fbb19 fix(apicompat): recognize web_search_20250305 / google_search in Responses to Anthropic tool conversion 2026-04-25 01:09:51 +08:00
keh4l
bdbd2916f5 fix(gateway): skip client header passthrough on OAuth mimicry path
Root cause of persistent third-party detection: sub2api's
buildUpstreamRequest transparently forwards client headers via
allowedHeaders whitelist (addHeaderRaw) before applying mimicry
overrides. When third-party clients (opencode, etc.) send their own
anthropic-beta / user-agent / x-stainless-* / x-claude-code-session-id
values, these get appended to the request alongside our injected
headers, creating an inconsistent header set that Anthropic detects.

Parrot's build_upstream_headers constructs exactly 9 headers from
scratch and never forwards anything from the client. This is why
'same opencode version, some users work some don't' — different
opencode configs/versions send different header combinations.

Fix: when tokenType=oauth and mimicClaudeCode=true, skip the
client header passthrough loop entirely. The subsequent
applyClaudeCodeMimicHeaders + ApplyFingerprint + beta merge
pipeline constructs all necessary headers from our controlled values.

Also: remove systemIncludesClaudeCodePrompt gate — OAuth accounts
now unconditionally rewrite system (even if client already sent a
Claude Code-style prompt), ensuring billing attribution block is
always present.
2026-04-25 00:43:38 +08:00
keh4l
6dc89765fd fix(gateway): always apply full mimicry for OAuth accounts regardless of client identity
Before: isClaudeCodeRequest() checked whether the client looks like a
real Claude Code CLI (UA, system prompt, X-App header, metadata format).
If it looked like Claude Code, all mimicry was skipped — the assumption
being that a real CLI needs no help.

Problem: third-party tools like opencode partially impersonate Claude
Code (sending claude-cli UA + claude-code beta + CC system prompt) but
miss critical details (billing attribution block, tool-name obfuscation,
cache breakpoints, full beta set). Some users' opencode instances pass
the isClaudeCodeRequest check, causing sub2api to skip mimicry entirely,
while Anthropic still detects the request as third-party.

This explains why 'same opencode version, some users work, some don't'
— it depends on which opencode features/config trigger the validator.

Fix: OAuth accounts now unconditionally run the full mimicry pipeline,
matching Parrot's behavior (Parrot never checks client identity).
This is safe because our mimicry is strictly more complete than any
third-party client's partial impersonation.

Changed:
  - /v1/messages path: remove isClaudeCode gate
  - /v1/messages/count_tokens path: same
2026-04-25 00:26:37 +08:00
keh4l
f3233db01f fix(gateway): apply D/E/F mimicry to native /v1/messages and count_tokens paths
The previous commit only wired stripMessageCacheControl,
addMessageCacheBreakpoints, and tool-name obfuscation into
applyClaudeCodeOAuthMimicryToBody (used by /chat/completions and
/responses). The native /v1/messages path and count_tokens path
have their own independent mimicry code blocks and were missed.

Now all three entry points share the same D/E/F pipeline:
  - /v1/messages (gateway_service.go forwardAnthropic)
  - /v1/messages/count_tokens (gateway_service.go countTokens)
  - OpenAI compat (applyClaudeCodeOAuthMimicryToBody)
2026-04-24 23:16:32 +08:00
keh4l
6e12578bc5 feat(gateway): port Parrot tool-name obfuscation + message cache breakpoints
Implements the remaining three parity items with Parrot cc_mimicry:

  D) Tool-name obfuscation
     - Dynamic mapping when tools.length > 5 (matches Parrot threshold).
       Fake names follow {prefix}{name[:3]}{i:02d} (e.g. 'manage_bas00').
       Go port of random.Random(hash(tuple(names))) uses fnv64a seed +
       math/rand; byte-exact reproduction is impossible (Python hash vs
       Go hash), but the two invariants that matter are preserved:
         * same input tool_names yield identical mapping (cache hit)
         * prefix pool is shuffled (names look distributed)
     - Static prefix map (sessions_ -> cc_sess_, session_ -> cc_ses_)
       applied as fallback, matching Parrot TOOL_NAME_REWRITES verbatim.
     - Server tools (web_search_20250305, computer_*, etc.) are NOT
       renamed; only type=='function' and type=='custom' tools are.
     - tool_choice.name is rewritten in sync (only when type=='tool').
     - Response side: bytes-level replace on every SSE chunk / JSON
       body at 6 injection points (standard stream/non-stream,
       passthrough stream/non-stream, chat_completions stream +
       non-stream, responses stream + non-stream). Reverse mapping
       applied longest-fake-name-first to prevent substring conflicts
       (parity with Parrot _restore_tool_names_in_chunk).
     - tool_choice is no longer unconditionally deleted in
       normalizeClaudeOAuthRequestBody — Parrot passes it through.

  E) tools[-1] cache_control breakpoint
     - Injected as {type:ephemeral, ttl:<DefaultCacheControlTTL>} when
       the last tool has no cache_control. Client-provided ttl is
       passed through unchanged (repo-wide policy).

  F) messages cache_control strategy
     - stripMessageCacheControl removes every client-provided
       messages[*].content[*].cache_control (multi-turn stability).
     - addMessageCacheBreakpoints then injects two stable breakpoints:
       (1) last message, and (2) second-to-last user turn when
       messages.length >= 4.
     - Combined with the system block breakpoint and tools[-1]
       breakpoint, this gives exactly the 4 breakpoints Anthropic
       allows per request.

Non-trivial implementation details to be aware of when rebasing:

  * Two new files, no upstream collision:
      gateway_tool_rewrite.go       (D + E algorithms)
      gateway_messages_cache.go     (F strip + breakpoints)
  * Two new feature calls bolted onto the tail of
    applyClaudeCodeOAuthMimicryToBody in gateway_service.go — rebase
    conflicts will be ~10 lines maximum.
  * Response-side injection points all wrap their existing write with
    reverseToolNamesIfPresent(c, ...), preserving original behavior
    when no mapping is stored (static prefix rollback still runs).
  * Non-stream chat/responses switched from c.JSON to
    json.Marshal + c.Data so bytes-level replace is possible.
  * Retry bodies (FilterThinkingBlocksForRetry,
    FilterSignatureSensitiveBlocksForRetry, RectifyThinkingBudget)
    only prune blocks — they preserve the already-obfuscated tool
    names, so no extra mapping re-application is needed.

Manual QA: end-to-end scenario verified with 6 tools (above threshold)
and tool_choice.type=='tool'. Obfuscation + restore roundtrip shown
in test logs; then removed the temp test file.

Tests (16 new):
  - buildDynamicToolMap stability + below-threshold guard
  - sanitizeToolName precedence (dynamic > static)
  - restoreToolNamesInBytes longest-first + static rollback
  - applyToolNameRewriteToBody skips server tools + syncs tool_choice
  - applyToolsLastCacheBreakpoint defaults to 5m + passes client ttl
  - stripMessageCacheControl + addMessageCacheBreakpoints in the
    1/4/string-content cases + second-to-last user turn selection
  - buildToolNameRewriteFromBody ReverseOrdered is desc-by-fake-length
  - fake name shape follows Parrot {prefix}{head3}{i:02d}
2026-04-24 23:16:32 +08:00
keh4l
a25faecadd feat(gateway): align body shape with real Claude Code CLI defaults
Three field-level alignments in normalizeClaudeOAuthRequestBody to
match real Claude Code CLI traffic byte-for-byte:

  1. temperature: previously deleted unconditionally; now passes
     through client value, defaults to 1 when absent (real CLI
     always sends temperature, default 1).

  2. max_tokens: defaults to 128000 when absent (real CLI default).

  3. context_management: when thinking.type is enabled/adaptive
     and the client did not provide context_management, inject
     {"edits":[{"type":"clear_thinking_20251015","keep":"all"}]}
     to mirror real CLI behavior.

tool_choice removal is unchanged (Claude Code OAuth credentials
do not allow client-supplied tool_choice).

Tests updated:
  - gateway_body_order_test.go: temperature/max_tokens are now
    expected in output; tool_choice still removed.
  - gateway_prompt_test.go: system array is now 2 blocks
    (billing + cc prompt), assertions adjusted.
  - gateway_anthropic_apikey_passthrough_test.go: same 2-block
    assertion.
2026-04-24 23:16:32 +08:00
keh4l
5862e2d8d9 feat(gateway): add billing attribution block with cc_version fingerprint
Real Claude Code CLI always sends a 2-block system array:

  [0] {"type":"text", "text":"x-anthropic-billing-header: cc_version=X.Y.Z.{fp}; cc_entrypoint=cli; cch=00000;"}
  [1] {"type":"text", "text":"You are Claude Code...", "cache_control":{...}}

Before this commit, sub2api's mimicry path only produced block [1].
The missing billing block is one of the primary third-party detection
signals Anthropic uses for Claude-Code-scoped OAuth tokens.

New file gateway_billing_block.go ports the fingerprint algorithm
(byte-for-byte from Parrot cc_mimicry.py:compute_fingerprint):
pick chars at positions [4,7,20] of the first user text, then
`sha256(SALT + chars + cc_version)[:3]`.

  - claude/constants.go: CLICurrentVersion = "2.1.92" (must match UA)
  - gateway_billing_block.go: computeClaudeCodeFingerprint +
    buildBillingAttributionBlockJSON + extractFirstUserText
  - gateway_service.go: rewriteSystemForNonClaudeCode now emits both
    blocks in order; cch=00000 is filled in later by
    signBillingHeaderCCH in buildUpstreamRequest.

Downstream compat note: syncBillingHeaderVersion's regex
`cc_version=\d+\.\d+\.\d+` only matches the semver triple,
leaving the `.{fp}` suffix intact when rewriting in buildUpstreamRequest.
2026-04-24 23:16:32 +08:00
keh4l
66d6454535 feat(claude): add ttl to cache_control with default 5m
Real Claude CLI traffic sends cache_control as
`{"type":"ephemeral","ttl":"1h"}`. Our previous payload only
sent `{"type":"ephemeral"}`, which is a bytewise mismatch with
the official CLI and one more third-party detection signal.

Policy: client-provided ttl is always passed through unchanged.
Proxy-generated cache_control blocks default to 5m (vs Parrot's 1h)
to avoid burning the 1h cache budget on automatic breakpoints while
still aligning with the `ttl` field being present.

  - claude/constants.go: DefaultCacheControlTTL = "5m"
  - apicompat/types.go: new AnthropicCacheControl type with TTL field;
    AnthropicTool gains optional CacheControl pointer so the mimicry
    path can attach a cache breakpoint to tools[-1] later.
  - service/gateway_service.go: anthropicCacheControlPayload gains TTL;
    marshalAnthropicSystemTextBlock and rewriteSystemForNonClaudeCode
    emit ttl=5m by default.
2026-04-24 23:16:32 +08:00
keh4l
165553cfb0 fix(gateway): use full beta list in buildUpstreamRequest mimicry path
The previous commit added FullClaudeCodeMimicryBetas() but the two
call sites in buildUpstreamRequest still hardcoded the old 3-token
subset. Anthropic now checks the complete set of beta tokens to
decide if a request qualifies as Claude Code. Wire them up:

  - /v1/messages mimic path: requiredBetas = FullClaudeCodeMimicryBetas()
  - /v1/messages/count_tokens mimic path: same + BetaTokenCounting

Haiku models keep the 2-token exemption (BetaOAuth + InterleaveThinking).
2026-04-24 23:16:32 +08:00
keh4l
b5467d610a fix(gateway): apply full Claude Code mimicry on /chat/completions and /responses
Before: the OpenAI-compat forwarders only called injectClaudeCodePrompt,
which prepends the Claude Code banner but leaves the rest of the body
in its original non-Claude-Code shape. The codebase already admits this
is insufficient (see the comment on rewriteSystemForNonClaudeCode in
gateway_service.go: "仅前置追加 Claude Code 提示词无法通过检测").

Effect: OAuth accounts served through /v1/chat/completions or /v1/responses
were detected as third-party apps and bled plan quota with:

    Third-party apps now draw from your extra usage, not your plan limits.

Fix:
  - apicompat.AnthropicRequest: add Metadata json.RawMessage so metadata
    survives the OpenAI->Anthropic->Marshal round trip; without it the
    downstream rewrite has no user_id to work with.
  - service: extract applyClaudeCodeOAuthMimicryToBody, a ParsedRequest-free
    variant of the /v1/messages mimicry pipeline
    (rewriteSystemForNonClaudeCode + normalizeClaudeOAuthRequestBody +
    metadata.user_id injection) so the OpenAI-compat forwarders can reuse it.
  - service: add buildOAuthMetadataUserIDFromBody + hashBodyForSessionSeed
    for the same reason (no ParsedRequest at the call site).
  - ForwardAsChatCompletions / ForwardAsResponses: replace the 3-line
    prompt-prepend with the full mimicry pipeline.
  - applyClaudeCodeMimicHeaders: set x-client-request-id per-request
    (real Claude CLI always does); missing/duplicated values are one more
    third-party fingerprint signal.

No change to the native /v1/messages path: it already called the full
pipeline, we only lift those helpers into a reusable function.

Tests:
  - go build ./... passes
  - go test ./internal/service/... ./internal/pkg/apicompat/... passes
  - lsp_diagnostics clean on all touched files
  - pre-existing failures in internal/config are unrelated (env-sensitive
    tests that also fail on upstream main)
2026-04-24 23:16:32 +08:00
keh4l
57ff97960d chore(claude): bump mimicked CLI to 2.1.92 and extend anthropic-beta list
Align Claude Code mimicry constants with the latest real CLI traffic
(see Parrot's src/transform/cc_mimicry.py). Anthropic now uses the full
set of anthropic-beta tokens to decide whether a request counts as
"official Claude Code"; requests missing tokens that real CLI ships
today are demoted to third-party usage:

  Third-party apps now draw from your extra usage, not your plan limits.

Changes:
  - claude/constants.go: add new beta tokens (prompt-caching-scope,
    effort, redact-thinking, context-management, extended-cache-ttl) and
    expose FullClaudeCodeMimicryBetas() for the OAuth mimicry path.
  - claude/constants.go: bump default User-Agent to claude-cli/2.1.92.
  - identity_service.go: bump defaultFingerprint User-Agent accordingly.

No behavioral change for clients that already send a newer UA (fingerprint
merge still prefers the incoming value).
2026-04-24 23:16:32 +08:00
Wesley Liddick
5b5db88550 Merge pull request #1897 from VpSanta33/codex/invite-affiliate-rebate
feat: 新增邀请返利功能,并支持后台配置返利比例
2026-04-24 22:36:53 +08:00
VpSanta33
f03de00cb9 feat: add affiliate invite rebate flow and admin rebate-rate setting 2026-04-24 22:22:26 +08:00
Wesley Liddick
76aae5aa74 Merge pull request #1911 from gaoren002/fix/codex-responses-payload-normalization-mainbase
fix(openai): normalize codex responses payloads
2026-04-24 21:37:32 +08:00
gaoren002
27ee141c1e fix(openai): preserve mcp tool call ids 2026-04-24 13:24:21 +00:00
gaoren002
e65574dea9 fix(openai): normalize codex responses payloads 2026-04-24 12:03:19 +00:00
Wesley Liddick
1ce9dc03f9 Merge pull request #1895 from gaoren002/fix/codex-spark-limitations
fix(openai): handle codex spark model limitations
2026-04-24 19:57:42 +08:00
Wesley Liddick
15ce914a62 Merge pull request #1910 from slovx2/fix/codex-tool-call-ids
fix(openai): 修复 Codex 工具调用 call_id 处理
2026-04-24 19:56:03 +08:00
song
959af1c8f6 fix(openai): preserve codex tool call ids 2026-04-24 19:31:49 +08:00
gaoren002
c4d496da18 fix(openai): handle codex spark model limitations 2026-04-24 07:42:31 +00:00
KnowSky404
f3ea878ba2 chore: trigger PR checks 2026-04-24 11:32:41 +08:00
KnowSky404
d80469ea35 test: fix OpenAI account test helper calls after rebase 2026-04-24 11:32:41 +08:00
KnowSky404
5fc30ea964 test: cover openai admin test state transitions 2026-04-24 11:32:41 +08:00
KnowSky404
f68909a68b fix: reconcile openai admin test rate-limit state 2026-04-24 11:32:41 +08:00
github-actions[bot]
d162604f32 chore: sync VERSION to 0.1.117 [skip ci] 2026-04-24 01:40:02 +00:00
137 changed files with 10152 additions and 516 deletions

1
.gitignore vendored
View File

@@ -1,4 +1,5 @@
docs/claude-relay-service/
.codex
# ===================
# Go 后端

View File

@@ -33,7 +33,7 @@ func main() {
}()
userRepo := repository.NewUserRepository(client, sqlDB)
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

View File

@@ -1 +1 @@
0.1.116
0.1.118

View File

@@ -69,7 +69,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
affiliateRepository := repository.NewAffiliateRepository(client, db)
affiliateService := service.NewAffiliateService(affiliateRepository, settingService, apiKeyAuthCacheInvalidator, billingCacheService)
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService)
userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
@@ -80,7 +82,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
totpCache := repository.NewTotpCache(redisClient)
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache)
userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
@@ -91,6 +93,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
announcementReadRepository := repository.NewAnnouncementReadRepository(client)
announcementService := service.NewAnnouncementService(announcementRepository, announcementReadRepository, userRepository, userSubscriptionRepository)
announcementHandler := handler.NewAnnouncementHandler(announcementService)
channelMonitorRepository := repository.NewChannelMonitorRepository(client, db)
channelMonitorService := service.ProvideChannelMonitorService(channelMonitorRepository, secretEncryptor)
channelMonitorUserHandler := handler.NewChannelMonitorUserHandler(channelMonitorService, settingService)
dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig)
dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig)
@@ -192,7 +197,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
registry := payment.ProvideRegistry()
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository)
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
opsHandler := admin.NewOpsHandler(opsService)
updateCache := repository.NewUpdateCache(redisClient)
@@ -221,21 +226,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
channelHandler := admin.NewChannelHandler(channelService, billingService)
sqlDB, err := repository.ProvideSQLDB(client)
if err != nil {
return nil, err
}
channelMonitorRepository := repository.NewChannelMonitorRepository(client, sqlDB)
channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, sqlDB)
channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService)
channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, db)
channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository)
channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService)
channelMonitorService := service.ProvideChannelMonitorService(channelMonitorRepository, secretEncryptor)
channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService)
channelMonitorUserHandler := handler.NewChannelMonitorUserHandler(channelMonitorService, settingService)
channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
availableChannelUserHandler := handler.NewAvailableChannelHandler(channelService, apiKeyService, settingService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler)
affiliateHandler := admin.NewAffiliateHandler(affiliateService, adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler, affiliateHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
@@ -245,9 +242,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
totpHandler := handler.NewTotpHandler(totpService)
handlerPaymentHandler := handler.NewPaymentHandler(paymentService, paymentConfigService, channelService)
paymentWebhookHandler := handler.NewPaymentWebhookHandler(paymentService, registry)
availableChannelHandler := handler.NewAvailableChannelHandler(channelService, apiKeyService, settingService)
idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig)
idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, channelMonitorUserHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, availableChannelUserHandler, idempotencyCoordinator, idempotencyCleanupService)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, channelMonitorUserHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, availableChannelHandler, idempotencyCoordinator, idempotencyCleanupService)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
@@ -263,6 +261,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, channelMonitorRunner)
application := &Application{
Server: httpServer,

View File

@@ -652,6 +652,7 @@ func (h *AccountHandler) Delete(c *gin.Context) {
type TestAccountRequest struct {
ModelID string `json:"model_id"`
Prompt string `json:"prompt"`
Mode string `json:"mode"`
}
type SyncFromCRSRequest struct {
@@ -682,7 +683,7 @@ func (h *AccountHandler) Test(c *gin.Context) {
_ = c.ShouldBindJSON(&req)
// Use AccountTestService to test the account with SSE streaming
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt); err != nil {
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt, req.Mode); err != nil {
// Error already sent via SSE, just log
return
}

View File

@@ -0,0 +1,183 @@
package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AffiliateHandler handles admin affiliate (邀请返利) management:
// listing users with custom settings, updating per-user invite codes
// and exclusive rebate rates, and batch operations.
type AffiliateHandler struct {
affiliateService *service.AffiliateService
adminService service.AdminService
}
// NewAffiliateHandler creates a new admin affiliate handler.
func NewAffiliateHandler(affiliateService *service.AffiliateService, adminService service.AdminService) *AffiliateHandler {
return &AffiliateHandler{
affiliateService: affiliateService,
adminService: adminService,
}
}
// ListUsers returns paginated users with custom affiliate settings.
// GET /api/v1/admin/affiliates/users
func (h *AffiliateHandler) ListUsers(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
search := c.Query("search")
entries, total, err := h.affiliateService.AdminListCustomUsers(c.Request.Context(), service.AffiliateAdminFilter{
Search: search,
Page: page,
PageSize: pageSize,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, entries, total, page, pageSize)
}
// UpdateUserSettings updates a user's affiliate settings.
// PUT /api/v1/admin/affiliates/users/:user_id
//
// Both fields are optional and applied independently.
type UpdateAffiliateUserRequest struct {
AffCode *string `json:"aff_code"`
AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent"`
// ClearRebateRate explicitly clears the per-user rate (sets it to NULL).
// Used to disambiguate from "field not provided".
ClearRebateRate bool `json:"clear_rebate_rate"`
}
func (h *AffiliateHandler) UpdateUserSettings(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("user_id"), 10, 64)
if err != nil || userID <= 0 {
response.BadRequest(c, "Invalid user_id")
return
}
var req UpdateAffiliateUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.AffCode != nil {
if err := h.affiliateService.AdminUpdateUserAffCode(c.Request.Context(), userID, *req.AffCode); err != nil {
response.ErrorFrom(c, err)
return
}
}
if req.ClearRebateRate {
if err := h.affiliateService.AdminSetUserRebateRate(c.Request.Context(), userID, nil); err != nil {
response.ErrorFrom(c, err)
return
}
} else if req.AffRebateRatePercent != nil {
if err := h.affiliateService.AdminSetUserRebateRate(c.Request.Context(), userID, req.AffRebateRatePercent); err != nil {
response.ErrorFrom(c, err)
return
}
}
response.Success(c, gin.H{"user_id": userID})
}
// ClearUserSettings removes ALL of a user's custom affiliate settings — clears
// the exclusive rebate rate AND regenerates the invite code as a new system
// random one. Conceptually this "removes the user from the custom list".
//
// Both writes happen in this handler; failure of one leaves the other applied,
// but the operation is idempotent so the admin can re-run it safely.
// DELETE /api/v1/admin/affiliates/users/:user_id
func (h *AffiliateHandler) ClearUserSettings(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("user_id"), 10, 64)
if err != nil || userID <= 0 {
response.BadRequest(c, "Invalid user_id")
return
}
if err := h.affiliateService.AdminSetUserRebateRate(c.Request.Context(), userID, nil); err != nil {
response.ErrorFrom(c, err)
return
}
if _, err := h.affiliateService.AdminResetUserAffCode(c.Request.Context(), userID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"user_id": userID})
}
// BatchSetRate applies the same rebate rate (or clears it) to multiple users.
//
// Protocol: pass `clear: true` to clear rates (aff_rebate_rate_percent is
// ignored). Otherwise aff_rebate_rate_percent is required and applied to
// every user_id. The explicit `clear` flag exists because Go's JSON unmarshal
// can't distinguish a missing field from `null`, and a silent clear from a
// frontend that forgot to include the rate would be a footgun.
//
// POST /api/v1/admin/affiliates/users/batch-rate
type BatchSetRateRequest struct {
UserIDs []int64 `json:"user_ids" binding:"required"`
AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent"`
Clear bool `json:"clear"`
}
func (h *AffiliateHandler) BatchSetRate(c *gin.Context) {
var req BatchSetRateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.UserIDs) == 0 {
response.BadRequest(c, "user_ids cannot be empty")
return
}
if !req.Clear && req.AffRebateRatePercent == nil {
response.BadRequest(c, "aff_rebate_rate_percent is required unless clear=true")
return
}
rate := req.AffRebateRatePercent
if req.Clear {
rate = nil
}
if err := h.affiliateService.AdminBatchSetUserRebateRate(c.Request.Context(), req.UserIDs, rate); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"affected": len(req.UserIDs)})
}
// AffiliateUserSummary is the minimal user shape returned by LookupUsers,
// shared with the frontend's add-custom-user picker.
type AffiliateUserSummary struct {
ID int64 `json:"id"`
Email string `json:"email"`
Username string `json:"username"`
}
// LookupUsers searches users by email/username for the "add custom user" modal.
// GET /api/v1/admin/affiliates/users/lookup?q=
func (h *AffiliateHandler) LookupUsers(c *gin.Context) {
keyword := c.Query("q")
if keyword == "" {
response.Success(c, []AffiliateUserSummary{})
return
}
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 20, service.UserListFilters{Search: keyword}, "email", "asc")
if err != nil {
response.ErrorFrom(c, err)
return
}
result := make([]AffiliateUserSummary, len(users))
for i, u := range users {
result[i] = AffiliateUserSummary{ID: u.ID, Email: u.Email, Username: u.Username}
}
response.Success(c, result)
}

View File

@@ -185,6 +185,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
AffiliateRebateRate: settings.AffiliateRebateRate,
AffiliateRebateFreezeHours: settings.AffiliateRebateFreezeHours,
AffiliateRebateDurationDays: settings.AffiliateRebateDurationDays,
AffiliateRebatePerInviteeCap: settings.AffiliateRebatePerInviteeCap,
DefaultUserRPMLimit: settings.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: settings.EnableModelFallback,
@@ -241,6 +245,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
AffiliateEnabled: settings.AffiliateEnabled,
}
response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
}
@@ -338,6 +344,10 @@ type UpdateSettingsRequest struct {
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"`
AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"`
AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"`
AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"`
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
@@ -439,6 +449,9 @@ type UpdateSettingsRequest struct {
// Available Channels feature switch (user-facing)
AvailableChannelsEnabled *bool `json:"available_channels_enabled"`
// Affiliate (邀请返利) feature switch
AffiliateEnabled *bool `json:"affiliate_enabled"`
}
// UpdateSettings 更新系统设置
@@ -468,6 +481,43 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if req.DefaultBalance < 0 {
req.DefaultBalance = 0
}
affiliateRebateRate := previousSettings.AffiliateRebateRate
if req.AffiliateRebateRate != nil {
affiliateRebateRate = *req.AffiliateRebateRate
}
if affiliateRebateRate < service.AffiliateRebateRateMin {
affiliateRebateRate = service.AffiliateRebateRateMin
}
if affiliateRebateRate > service.AffiliateRebateRateMax {
affiliateRebateRate = service.AffiliateRebateRateMax
}
affiliateRebateFreezeHours := previousSettings.AffiliateRebateFreezeHours
if req.AffiliateRebateFreezeHours != nil {
affiliateRebateFreezeHours = *req.AffiliateRebateFreezeHours
}
if affiliateRebateFreezeHours < 0 {
affiliateRebateFreezeHours = service.AffiliateRebateFreezeHoursDefault
}
if affiliateRebateFreezeHours > service.AffiliateRebateFreezeHoursMax {
affiliateRebateFreezeHours = service.AffiliateRebateFreezeHoursMax
}
affiliateRebateDurationDays := previousSettings.AffiliateRebateDurationDays
if req.AffiliateRebateDurationDays != nil {
affiliateRebateDurationDays = *req.AffiliateRebateDurationDays
}
if affiliateRebateDurationDays < 0 {
affiliateRebateDurationDays = service.AffiliateRebateDurationDaysDefault
}
if affiliateRebateDurationDays > service.AffiliateRebateDurationDaysMax {
affiliateRebateDurationDays = service.AffiliateRebateDurationDaysMax
}
affiliateRebatePerInviteeCap := previousSettings.AffiliateRebatePerInviteeCap
if req.AffiliateRebatePerInviteeCap != nil {
affiliateRebatePerInviteeCap = *req.AffiliateRebatePerInviteeCap
}
if affiliateRebatePerInviteeCap < 0 {
affiliateRebatePerInviteeCap = service.AffiliateRebatePerInviteeCapDefault
}
// 通用表格配置:兼容旧客户端未传字段时保留当前值。
if req.TableDefaultPageSize <= 0 {
req.TableDefaultPageSize = previousSettings.TableDefaultPageSize
@@ -1119,6 +1169,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
CustomEndpoints: customEndpointsJSON,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
AffiliateRebateRate: affiliateRebateRate,
AffiliateRebateFreezeHours: affiliateRebateFreezeHours,
AffiliateRebateDurationDays: affiliateRebateDurationDays,
AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap,
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback,
@@ -1252,6 +1306,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.AvailableChannelsEnabled
}(),
AffiliateEnabled: func() bool {
if req.AffiliateEnabled != nil {
return *req.AffiliateEnabled
}
return previousSettings.AffiliateEnabled
}(),
}
authSourceDefaults := &service.AuthSourceDefaultSettings{
@@ -1433,6 +1493,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
AffiliateRebateRate: updatedSettings.AffiliateRebateRate,
AffiliateRebateFreezeHours: updatedSettings.AffiliateRebateFreezeHours,
AffiliateRebateDurationDays: updatedSettings.AffiliateRebateDurationDays,
AffiliateRebatePerInviteeCap: updatedSettings.AffiliateRebatePerInviteeCap,
DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit,
DefaultSubscriptions: updatedDefaultSubscriptions,
EnableModelFallback: updatedSettings.EnableModelFallback,
@@ -1488,6 +1552,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
ChannelMonitorDefaultIntervalSeconds: updatedSettings.ChannelMonitorDefaultIntervalSeconds,
AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled,
AffiliateEnabled: updatedSettings.AffiliateEnabled,
}
response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
}
@@ -1738,6 +1804,18 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.DefaultBalance != after.DefaultBalance {
changed = append(changed, "default_balance")
}
if before.AffiliateRebateRate != after.AffiliateRebateRate {
changed = append(changed, "affiliate_rebate_rate")
}
if before.AffiliateRebateFreezeHours != after.AffiliateRebateFreezeHours {
changed = append(changed, "affiliate_rebate_freeze_hours")
}
if before.AffiliateRebateDurationDays != after.AffiliateRebateDurationDays {
changed = append(changed, "affiliate_rebate_duration_days")
}
if before.AffiliateRebatePerInviteeCap != after.AffiliateRebatePerInviteeCap {
changed = append(changed, "affiliate_rebate_per_invitee_cap")
}
if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) {
changed = append(changed, "default_subscriptions")
}
@@ -1853,6 +1931,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.AvailableChannelsEnabled != after.AvailableChannelsEnabled {
changed = append(changed, "available_channels_enabled")
}
if before.AffiliateEnabled != after.AffiliateEnabled {
changed = append(changed, "affiliate_enabled")
}
changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults)
return changed
}

View File

@@ -48,6 +48,7 @@ type RegisterRequest struct {
TurnstileToken string `json:"turnstile_token"`
PromoCode string `json:"promo_code"` // 注册优惠码
InvitationCode string `json:"invitation_code"` // 邀请码
AffCode string `json:"aff_code"` // 邀请返利码
}
// SendVerifyCodeRequest 发送验证码请求
@@ -164,7 +165,15 @@ func (h *AuthHandler) Register(c *gin.Context) {
return
}
_, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
_, user, err := h.authService.RegisterWithVerification(
c.Request.Context(),
req.Email,
req.Password,
req.VerifyCode,
req.PromoCode,
req.InvitationCode,
req.AffCode,
)
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -435,6 +435,7 @@ func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession(
type completeLinuxDoOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"`
AffCode string `json:"aff_code,omitempty"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
@@ -518,7 +519,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -67,6 +67,7 @@ type createPendingOAuthAccountRequest struct {
VerifyCode string `json:"verify_code,omitempty"`
Password string `json:"password" binding:"required,min=6"`
InvitationCode string `json:"invitation_code,omitempty"`
AffCode string `json:"aff_code,omitempty"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
@@ -1751,6 +1752,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
user,
strings.TrimSpace(req.InvitationCode),
strings.TrimSpace(session.ProviderType),
strings.TrimSpace(req.AffCode),
); err != nil {
_ = tx.Rollback()
if rollbackCreatedUser(err) {

View File

@@ -2210,6 +2210,7 @@ CREATE TABLE IF NOT EXISTS user_avatars (
nil,
nil,
options.defaultSubAssigner,
nil,
)
userSvc := service.NewUserService(userRepo, nil, nil, nil)
var totpSvc *service.TotpService

View File

@@ -582,6 +582,7 @@ func (h *AuthHandler) createOIDCOAuthChoicePendingSession(
type completeOIDCOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"`
AffCode string `json:"aff_code,omitempty"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
@@ -665,7 +666,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -35,7 +35,7 @@ func TestAuthHandlerRevokeAllSessionsInvalidatesAccessTokens(t *testing.T) {
ExpireHour: 1,
},
}
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
handler := &AuthHandler{authService: authService}
recorder := httptest.NewRecorder()

View File

@@ -481,6 +481,7 @@ func (h *AuthHandler) wechatPaymentResumeService() *service.PaymentResumeService
type completeWeChatOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"`
AffCode string `json:"aff_code,omitempty"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
@@ -547,7 +548,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -1399,6 +1399,7 @@ func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool,
nil,
nil,
nil,
nil,
)
return &AuthHandler{

View File

@@ -106,10 +106,14 @@ type SystemSettings struct {
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
AffiliateRebateRate float64 `json:"affiliate_rebate_rate"`
AffiliateRebateFreezeHours int `json:"affiliate_rebate_freeze_hours"`
AffiliateRebateDurationDays int `json:"affiliate_rebate_duration_days"`
AffiliateRebatePerInviteeCap float64 `json:"affiliate_rebate_per_invitee_cap"`
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@@ -191,6 +195,9 @@ type SystemSettings struct {
// Available Channels feature switch (user-facing aggregate view)
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
// Affiliate (邀请返利) feature switch
AffiliateEnabled bool `json:"affiliate_enabled"`
}
type DefaultSubscriptionSetting struct {
@@ -243,6 +250,8 @@ type PublicSettings struct {
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
AffiliateEnabled bool `json:"affiliate_enabled"`
}
// OverloadCooldownSettings 529过载冷却配置 DTO

View File

@@ -34,6 +34,7 @@ type AdminHandlers struct {
ChannelMonitor *admin.ChannelMonitorHandler
ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler
Payment *admin.PaymentHandler
Affiliate *admin.AffiliateHandler
}
// Handlers contains all HTTP handlers

View File

@@ -130,6 +130,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
false,
)
if err != nil {
reqLog.Warn("openai_chat_completions.account_select_failed",
@@ -153,6 +154,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
defaultModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
false,
)
if err == nil && selection != nil {
c.Set("openai_chat_completions_fallback_model", defaultModel)

View File

@@ -116,7 +116,7 @@ func TestLogOpenAIRemoteCompactOutcome_Succeeded(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.125.0")
c.Set(opsModelKey, "gpt-5.3-codex")
c.Set(opsAccountIDKey, int64(123))
c.Header("x-request-id", "rid-compact-ok")
@@ -142,7 +142,7 @@ func TestLogOpenAIRemoteCompactOutcome_Failed(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact", nil)
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.125.0")
c.Status(http.StatusBadGateway)
h := &OpenAIGatewayHandler{}
@@ -180,7 +180,7 @@ func TestOpenAIResponses_CompactUnauthorizedLogsFailed(t *testing.T) {
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", strings.NewReader(`{"model":"gpt-5.3-codex"}`))
c.Request.Header.Set("Content-Type", "application/json")
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.125.0")
h := &OpenAIGatewayHandler{}
h.Responses(c)

View File

@@ -238,6 +238,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Generate session hash (header first; fallback to prompt_cache_key)
sessionHash := h.gatewayService.GenerateSessionHash(c, sessionHashBody)
requireCompact := isOpenAIRemoteCompactPath(c)
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
@@ -256,6 +257,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
requireCompact,
)
if err != nil {
reqLog.Warn("openai.account_select_failed",
@@ -263,6 +265,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
if errors.Is(err, service.ErrNoAvailableCompactAccounts) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "compact_not_supported", "No available OpenAI accounts support /responses/compact", streamStarted)
return
}
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
@@ -644,6 +650,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
currentRoutingModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
false,
)
if err != nil {
reqLog.Warn("openai_messages.account_select_failed",
@@ -1167,6 +1174,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
reqModel,
nil,
service.OpenAIUpstreamTransportResponsesWebsocketV2,
false,
)
if err != nil {
reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err))

View File

@@ -117,7 +117,7 @@ func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) {
Save(context.Background())
require.NoError(t, err)
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()
@@ -215,7 +215,7 @@ func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing
require.NoError(t, err)
configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil)
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()
@@ -302,7 +302,7 @@ func TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken(t *t
require.NoError(t, err)
configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil)
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()
@@ -342,7 +342,7 @@ func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()

View File

@@ -75,5 +75,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
AffiliateEnabled: settings.AffiliateEnabled,
})
}

View File

@@ -14,10 +14,11 @@ import (
// UserHandler handles user-related requests
type UserHandler struct {
userService *service.UserService
authService *service.AuthService
emailService *service.EmailService
emailCache service.EmailCache
userService *service.UserService
authService *service.AuthService
emailService *service.EmailService
emailCache service.EmailCache
affiliateService *service.AffiliateService
}
// NewUserHandler creates a new UserHandler
@@ -26,12 +27,14 @@ func NewUserHandler(
authService *service.AuthService,
emailService *service.EmailService,
emailCache service.EmailCache,
affiliateService *service.AffiliateService,
) *UserHandler {
return &UserHandler{
userService: userService,
authService: authService,
emailService: emailService,
emailCache: emailCache,
userService: userService,
authService: authService,
emailService: emailService,
emailCache: emailCache,
affiliateService: affiliateService,
}
}
@@ -159,6 +162,44 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
response.Success(c, profileResp)
}
// GetAffiliate returns the current user's affiliate details.
// GET /api/v1/user/aff
func (h *UserHandler) GetAffiliate(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
detail, err := h.affiliateService.GetAffiliateDetail(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, detail)
}
// TransferAffiliateQuota transfers all available affiliate quota into current balance.
// POST /api/v1/user/aff/transfer
func (h *UserHandler) TransferAffiliateQuota(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
transferred, balance, err := h.affiliateService.TransferAffiliateQuota(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{
"transferred_quota": transferred,
"balance": balance,
})
}
type StartIdentityBindingRequest struct {
Provider string `json:"provider" binding:"required"`
RedirectTo string `json:"redirect_to"`

View File

@@ -142,7 +142,7 @@ func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
Status: service.StatusActive,
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`)
recorder := httptest.NewRecorder()
@@ -200,7 +200,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@@ -283,7 +283,7 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@@ -362,7 +362,7 @@ func TestUserHandlerGetProfileDoesNotInferEditedProfileSourcesWithoutMatchingIde
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@@ -511,8 +511,8 @@ func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) {
},
}
emailService := service.NewEmailService(nil, emailCache)
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`)
recorder := httptest.NewRecorder()
@@ -566,7 +566,7 @@ func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) {
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@@ -625,8 +625,8 @@ func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigure
ExpireHour: 1,
},
}
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@@ -668,8 +668,8 @@ func TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound(t *
ExpireHour: 1,
},
}
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@@ -712,8 +712,8 @@ func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t
},
}
emailService := service.NewEmailService(nil, emailCache)
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`)
recorder := httptest.NewRecorder()
@@ -750,7 +750,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
Status: service.StatusActive,
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`)
recorder := httptest.NewRecorder()

View File

@@ -37,6 +37,7 @@ func ProvideAdminHandlers(
channelMonitorHandler *admin.ChannelMonitorHandler,
channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler,
paymentHandler *admin.PaymentHandler,
affiliateHandler *admin.AffiliateHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
@@ -67,6 +68,7 @@ func ProvideAdminHandlers(
ChannelMonitor: channelMonitorHandler,
ChannelMonitorTemplate: channelMonitorTemplateHandler,
Payment: paymentHandler,
Affiliate: affiliateHandler,
}
}
@@ -169,6 +171,7 @@ var ProviderSet = wire.NewSet(
admin.NewChannelMonitorHandler,
admin.NewChannelMonitorRequestTemplateHandler,
admin.NewPaymentHandler,
admin.NewAffiliateHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,

View File

@@ -25,6 +25,7 @@ const (
easypayStatusPaid = 1
easypayHTTPTimeout = 10 * time.Second
maxEasypayResponseSize = 1 << 20 // 1MB
maxEasypayErrorSummary = 512
tradeStatusSuccess = "TRADE_SUCCESS"
signTypeMD5 = "MD5"
paymentModePopup = "popup"
@@ -42,17 +43,55 @@ type EasyPay struct {
// config keys: pid, pkey, apiBase, notifyUrl, returnUrl, cid, cidAlipay, cidWxpay
func NewEasyPay(instanceID string, config map[string]string) (*EasyPay, error) {
for _, k := range []string{"pid", "pkey", "apiBase", "notifyUrl", "returnUrl"} {
if config[k] == "" {
if strings.TrimSpace(config[k]) == "" {
return nil, fmt.Errorf("easypay config missing required key: %s", k)
}
}
cfg := make(map[string]string, len(config))
for k, v := range config {
cfg[k] = v
}
cfg["apiBase"] = normalizeEasyPayAPIBase(cfg["apiBase"])
return &EasyPay{
instanceID: instanceID,
config: config,
config: cfg,
httpClient: &http.Client{Timeout: easypayHTTPTimeout},
}, nil
}
func normalizeEasyPayAPIBase(apiBase string) string {
base := strings.TrimSpace(apiBase)
if base == "" {
return ""
}
if parsed, err := url.Parse(base); err == nil && parsed.Scheme != "" && parsed.Host != "" {
parsed.RawQuery = ""
parsed.Fragment = ""
parsed.RawPath = ""
parsed.Path = trimEasyPayEndpointPath(parsed.Path)
return strings.TrimRight(parsed.String(), "/")
}
return strings.TrimRight(trimEasyPayEndpointPath(base), "/")
}
func trimEasyPayEndpointPath(path string) string {
path = strings.TrimRight(strings.TrimSpace(path), "/")
lower := strings.ToLower(path)
for _, endpoint := range []string{"/submit.php", "/mapi.php", "/api.php"} {
if strings.HasSuffix(lower, endpoint) {
return strings.TrimRight(path[:len(path)-len(endpoint)], "/")
}
}
return path
}
func (e *EasyPay) apiBase() string {
if e == nil {
return ""
}
return normalizeEasyPayAPIBase(e.config["apiBase"])
}
func (e *EasyPay) Name() string { return "EasyPay" }
func (e *EasyPay) ProviderKey() string { return payment.TypeEasyPay }
func (e *EasyPay) SupportedTypes() []payment.PaymentType {
@@ -104,8 +143,7 @@ func (e *EasyPay) createRedirectPayment(req payment.CreatePaymentRequest) (*paym
for k, v := range params {
q.Set(k, v)
}
base := strings.TrimRight(e.config["apiBase"], "/")
payURL := base + "/submit.php?" + q.Encode()
payURL := e.apiBase() + "/submit.php?" + q.Encode()
return &payment.CreatePaymentResponse{PayURL: payURL}, nil
}
@@ -127,7 +165,7 @@ func (e *EasyPay) createAPIPayment(ctx context.Context, req payment.CreatePaymen
params["sign"] = easyPaySign(params, e.config["pkey"])
params["sign_type"] = signTypeMD5
body, err := e.post(ctx, strings.TrimRight(e.config["apiBase"], "/")+"/mapi.php", params)
body, err := e.post(ctx, e.apiBase()+"/mapi.php", params)
if err != nil {
return nil, fmt.Errorf("easypay create: %w", err)
}
@@ -171,7 +209,7 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer
"act": "order", "pid": e.config["pid"],
"key": e.config["pkey"], "out_trade_no": tradeNo,
}
body, err := e.post(ctx, e.config["apiBase"]+"/api.php", params)
body, err := e.post(ctx, e.apiBase()+"/api.php", params)
if err != nil {
return nil, fmt.Errorf("easypay query: %w", err)
}
@@ -234,25 +272,128 @@ func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[st
}
func (e *EasyPay) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
params := map[string]string{
"pid": e.config["pid"], "key": e.config["pkey"],
"trade_no": req.TradeNo, "out_trade_no": req.OrderID, "money": req.Amount,
attempts := e.refundAttempts(req)
if len(attempts) == 0 {
return nil, fmt.Errorf("easypay refund missing order identifier")
}
body, err := e.post(ctx, e.config["apiBase"]+"/api.php?act=refund", params)
if err != nil {
return nil, fmt.Errorf("easypay refund: %w", err)
var firstErr error
for i, attempt := range attempts {
body, status, err := e.postRaw(ctx, e.apiBase()+"/api.php?act=refund", attempt.params)
if err != nil {
return nil, fmt.Errorf("easypay refund request: %w", err)
}
if err := parseEasyPayRefundResponse(status, body); err != nil {
if firstErr == nil {
firstErr = err
}
if i+1 < len(attempts) && isEasyPayRefundOrderNotFound(err) {
continue
}
return nil, err
}
return &payment.RefundResponse{RefundID: attempt.refundID, Status: payment.ProviderStatusSuccess}, nil
}
return nil, firstErr
}
type easyPayRefundAttempt struct {
params map[string]string
refundID string
}
func (e *EasyPay) refundAttempts(req payment.RefundRequest) []easyPayRefundAttempt {
base := map[string]string{
"pid": e.config["pid"], "key": e.config["pkey"], "money": req.Amount,
}
var attempts []easyPayRefundAttempt
if orderID := strings.TrimSpace(req.OrderID); orderID != "" {
params := cloneStringMap(base)
params["out_trade_no"] = orderID
attempts = append(attempts, easyPayRefundAttempt{params: params, refundID: orderID})
}
if tradeNo := strings.TrimSpace(req.TradeNo); tradeNo != "" {
params := cloneStringMap(base)
params["trade_no"] = tradeNo
attempts = append(attempts, easyPayRefundAttempt{params: params, refundID: tradeNo})
}
return attempts
}
func cloneStringMap(in map[string]string) map[string]string {
out := make(map[string]string, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func isEasyPayRefundOrderNotFound(err error) bool {
if err == nil {
return false
}
msg := err.Error()
lower := strings.ToLower(msg)
return strings.Contains(msg, "订单编号不存在") ||
strings.Contains(msg, "订单不存在") ||
strings.Contains(lower, "order not found") ||
strings.Contains(lower, "not exist")
}
func parseEasyPayRefundResponse(status int, body []byte) error {
summary := summarizeEasyPayResponse(body)
if status < http.StatusOK || status >= http.StatusMultipleChoices {
return fmt.Errorf("easypay refund HTTP %d: %s", status, summary)
}
trimmed := strings.TrimSpace(string(body))
if trimmed == "" {
return fmt.Errorf("easypay refund empty response (HTTP %d): %s", status, summary)
}
lower := strings.ToLower(trimmed)
if strings.HasPrefix(lower, "<!doctype html") || strings.HasPrefix(lower, "<html") ||
(strings.HasPrefix(lower, "<") && strings.Contains(lower, "html")) {
return fmt.Errorf("easypay refund non-JSON response (HTTP %d): %s", status, summary)
}
var resp struct {
Code int `json:"code"`
Code any `json:"code"`
Msg string `json:"msg"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return nil, fmt.Errorf("easypay parse refund: %w", err)
return fmt.Errorf("easypay refund non-JSON response (HTTP %d): %s", status, summary)
}
if resp.Code != easypayCodeSuccess {
return nil, fmt.Errorf("easypay refund failed: %s", resp.Msg)
if !easyPayResponseCodeIsSuccess(resp.Code) {
msg := strings.TrimSpace(resp.Msg)
if msg == "" {
msg = summary
}
return fmt.Errorf("easypay refund failed (HTTP %d): %s", status, msg)
}
return &payment.RefundResponse{RefundID: req.TradeNo, Status: payment.ProviderStatusSuccess}, nil
return nil
}
func easyPayResponseCodeIsSuccess(code any) bool {
switch v := code.(type) {
case float64:
return int(v) == easypayCodeSuccess
case string:
n, err := strconv.Atoi(strings.TrimSpace(v))
return err == nil && n == easypayCodeSuccess
default:
return false
}
}
func summarizeEasyPayResponse(body []byte) string {
summary := strings.Join(strings.Fields(string(body)), " ")
if summary == "" {
return "<empty>"
}
if len(summary) > maxEasypayErrorSummary {
return summary[:maxEasypayErrorSummary] + "..."
}
return summary
}
func (e *EasyPay) resolveCID(paymentType string) string {
@@ -269,21 +410,34 @@ func (e *EasyPay) resolveCID(paymentType string) string {
}
func (e *EasyPay) post(ctx context.Context, endpoint string, params map[string]string) ([]byte, error) {
body, _, err := e.postRaw(ctx, endpoint, params)
return body, err
}
func (e *EasyPay) postRaw(ctx context.Context, endpoint string, params map[string]string) ([]byte, int, error) {
form := url.Values{}
for k, v := range params {
form.Set(k, v)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
if err != nil {
return nil, err
return nil, 0, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := e.httpClient.Do(req)
client := e.httpClient
if client == nil {
client = &http.Client{Timeout: easypayHTTPTimeout}
}
resp, err := client.Do(req)
if err != nil {
return nil, err
return nil, 0, err
}
defer func() { _ = resp.Body.Close() }()
return io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize))
body, err := io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize))
if err != nil {
return nil, resp.StatusCode, err
}
return body, resp.StatusCode, nil
}
func easyPaySign(params map[string]string, pkey string) string {

View File

@@ -0,0 +1,196 @@
package provider
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
func TestNormalizeEasyPayAPIBase(t *testing.T) {
t.Parallel()
tests := []struct {
input string
want string
}{
{input: "https://zpayz.cn", want: "https://zpayz.cn"},
{input: "https://zpayz.cn/", want: "https://zpayz.cn"},
{input: "https://zpayz.cn/mapi.php", want: "https://zpayz.cn"},
{input: "https://zpayz.cn/submit.php", want: "https://zpayz.cn"},
{input: "https://zpayz.cn/api.php", want: "https://zpayz.cn"},
{input: "https://zpayz.cn/api.php?act=refund", want: "https://zpayz.cn"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
t.Parallel()
if got := normalizeEasyPayAPIBase(tt.input); got != tt.want {
t.Fatalf("normalizeEasyPayAPIBase(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestEasyPayRefundNormalizesAPIBaseAndSendsOutTradeNoOnly(t *testing.T) {
t.Parallel()
var gotPath string
var gotQuery url.Values
var gotForm url.Values
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotQuery = r.URL.Query()
if err := r.ParseForm(); err != nil {
t.Errorf("ParseForm: %v", err)
}
gotForm = r.PostForm
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`))
}))
defer server.Close()
provider := newTestEasyPay(t, server.URL+"/mapi.php")
resp, err := provider.Refund(context.Background(), payment.RefundRequest{
TradeNo: "trade-123",
OrderID: "out-456",
Amount: "1.50",
})
if err != nil {
t.Fatalf("Refund returned error: %v", err)
}
if resp == nil || resp.Status != payment.ProviderStatusSuccess {
t.Fatalf("Refund response = %+v, want success", resp)
}
if gotPath != "/api.php" {
t.Fatalf("refund path = %q, want /api.php", gotPath)
}
if gotQuery.Get("act") != "refund" {
t.Fatalf("refund act query = %q, want refund", gotQuery.Get("act"))
}
for key, want := range map[string]string{
"pid": "pid-1",
"key": "pkey-1",
"out_trade_no": "out-456",
"money": "1.50",
} {
if got := gotForm.Get(key); got != want {
t.Fatalf("form[%s] = %q, want %q (form=%v)", key, got, want, gotForm)
}
}
if got := gotForm.Get("trade_no"); got != "" {
t.Fatalf("form[trade_no] = %q, want empty (form=%v)", got, gotForm)
}
}
func TestEasyPayRefundRetriesWithTradeNoWhenOutTradeNoNotFound(t *testing.T) {
t.Parallel()
var gotForms []url.Values
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api.php" {
t.Errorf("refund path = %q, want /api.php", r.URL.Path)
}
if r.URL.Query().Get("act") != "refund" {
t.Errorf("refund act query = %q, want refund", r.URL.Query().Get("act"))
}
if err := r.ParseForm(); err != nil {
t.Errorf("ParseForm: %v", err)
}
gotForms = append(gotForms, r.PostForm)
w.Header().Set("Content-Type", "application/json")
if len(gotForms) == 1 {
_, _ = w.Write([]byte(`{"code":0,"msg":"订单编号不存在!"}`))
return
}
_, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`))
}))
defer server.Close()
provider := newTestEasyPay(t, server.URL+"/mapi.php")
resp, err := provider.Refund(context.Background(), payment.RefundRequest{
TradeNo: "trade-123",
OrderID: "out-456",
Amount: "1.50",
})
if err != nil {
t.Fatalf("Refund returned error: %v", err)
}
if resp == nil || resp.Status != payment.ProviderStatusSuccess || resp.RefundID != "trade-123" {
t.Fatalf("Refund response = %+v, want success with trade refund id", resp)
}
if len(gotForms) != 2 {
t.Fatalf("refund attempts = %d, want 2", len(gotForms))
}
if got := gotForms[0].Get("out_trade_no"); got != "out-456" {
t.Fatalf("first form[out_trade_no] = %q, want out-456 (form=%v)", got, gotForms[0])
}
if got := gotForms[0].Get("trade_no"); got != "" {
t.Fatalf("first form[trade_no] = %q, want empty (form=%v)", got, gotForms[0])
}
if got := gotForms[1].Get("trade_no"); got != "trade-123" {
t.Fatalf("second form[trade_no] = %q, want trade-123 (form=%v)", got, gotForms[1])
}
if got := gotForms[1].Get("out_trade_no"); got != "" {
t.Fatalf("second form[out_trade_no] = %q, want empty (form=%v)", got, gotForms[1])
}
}
func TestEasyPayRefundResponseErrors(t *testing.T) {
t.Parallel()
tests := []struct {
name string
statusCode int
body string
want string
}{
{name: "html response", statusCode: http.StatusOK, body: "<html>bad config</html>", want: "non-JSON response (HTTP 200): <html>bad config</html>"},
{name: "non json response", statusCode: http.StatusOK, body: "not json", want: "non-JSON response (HTTP 200): not json"},
{name: "non 2xx response", statusCode: http.StatusBadGateway, body: "bad gateway", want: "HTTP 502: bad gateway"},
{name: "empty response", statusCode: http.StatusOK, body: "", want: "empty response (HTTP 200): <empty>"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(tt.statusCode)
_, _ = w.Write([]byte(tt.body))
}))
defer server.Close()
provider := newTestEasyPay(t, server.URL)
_, err := provider.Refund(context.Background(), payment.RefundRequest{
OrderID: "out-456",
Amount: "1.50",
})
if err == nil {
t.Fatal("Refund returned nil error")
}
if !strings.Contains(err.Error(), tt.want) {
t.Fatalf("Refund error = %q, want substring %q", err.Error(), tt.want)
}
})
}
}
func newTestEasyPay(t *testing.T, apiBase string) *EasyPay {
t.Helper()
provider, err := NewEasyPay("test-instance", map[string]string{
"pid": "pid-1",
"pkey": "pkey-1",
"apiBase": apiBase,
"notifyUrl": "https://example.com/notify",
"returnUrl": "https://example.com/return",
})
if err != nil {
t.Fatalf("NewEasyPay: %v", err)
}
return provider
}

View File

@@ -181,6 +181,55 @@ func TestResponsesToAnthropic_TextOnly(t *testing.T) {
assert.Equal(t, 5, anth.Usage.OutputTokens)
}
func TestResponsesToAnthropic_CachedTokensUseAnthropicInputSemantics(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_cached",
Model: "gpt-5.2",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "message",
Content: []ResponsesContentPart{
{Type: "output_text", Text: "Cached response"},
},
},
},
Usage: &ResponsesUsage{
InputTokens: 54006,
OutputTokens: 123,
TotalTokens: 54129,
InputTokensDetails: &ResponsesInputTokensDetails{
CachedTokens: 50688,
},
},
}
anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929")
assert.Equal(t, 3318, anth.Usage.InputTokens)
assert.Equal(t, 50688, anth.Usage.CacheReadInputTokens)
assert.Equal(t, 123, anth.Usage.OutputTokens)
}
func TestResponsesToAnthropic_CachedTokensClampInputTokens(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_cached_clamp",
Model: "gpt-5.2",
Status: "completed",
Usage: &ResponsesUsage{
InputTokens: 100,
OutputTokens: 5,
InputTokensDetails: &ResponsesInputTokensDetails{
CachedTokens: 150,
},
},
}
anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929")
assert.Equal(t, 0, anth.Usage.InputTokens)
assert.Equal(t, 150, anth.Usage.CacheReadInputTokens)
assert.Equal(t, 5, anth.Usage.OutputTokens)
}
func TestResponsesToAnthropic_ToolUse(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_456",
@@ -343,6 +392,36 @@ func TestStreamingTextOnly(t *testing.T) {
assert.Equal(t, "message_stop", events[1].Type)
}
func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) {
state := NewResponsesEventToAnthropicState()
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.created",
Response: &ResponsesResponse{ID: "resp_cached_stream", Model: "gpt-5.2"},
}, state)
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.completed",
Response: &ResponsesResponse{
Status: "completed",
Usage: &ResponsesUsage{
InputTokens: 54006,
OutputTokens: 123,
TotalTokens: 54129,
InputTokensDetails: &ResponsesInputTokensDetails{
CachedTokens: 50688,
},
},
},
}, state)
require.Len(t, events, 2)
assert.Equal(t, "message_delta", events[0].Type)
assert.Equal(t, 3318, events[0].Usage.InputTokens)
assert.Equal(t, 50688, events[0].Usage.CacheReadInputTokens)
assert.Equal(t, 123, events[0].Usage.OutputTokens)
assert.Equal(t, "message_stop", events[1].Type)
}
func TestStreamingToolCall(t *testing.T) {
state := NewResponsesEventToAnthropicState()

View File

@@ -84,18 +84,34 @@ func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicRespo
out.StopReason = responsesStatusToAnthropicStopReason(resp.Status, resp.IncompleteDetails, blocks)
if resp.Usage != nil {
out.Usage = AnthropicUsage{
InputTokens: resp.Usage.InputTokens,
OutputTokens: resp.Usage.OutputTokens,
}
if resp.Usage.InputTokensDetails != nil {
out.Usage.CacheReadInputTokens = resp.Usage.InputTokensDetails.CachedTokens
}
out.Usage = anthropicUsageFromResponsesUsage(resp.Usage)
}
return out
}
func anthropicUsageFromResponsesUsage(usage *ResponsesUsage) AnthropicUsage {
if usage == nil {
return AnthropicUsage{}
}
cachedTokens := 0
if usage.InputTokensDetails != nil {
cachedTokens = usage.InputTokensDetails.CachedTokens
}
inputTokens := usage.InputTokens - cachedTokens
if inputTokens < 0 {
inputTokens = 0
}
return AnthropicUsage{
InputTokens: inputTokens,
OutputTokens: usage.OutputTokens,
CacheReadInputTokens: cachedTokens,
}
}
func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncompleteDetails, blocks []AnthropicContentBlock) string {
switch status {
case "incomplete":
@@ -466,11 +482,10 @@ func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo
stopReason := "end_turn"
if evt.Response != nil {
if evt.Response.Usage != nil {
state.InputTokens = evt.Response.Usage.InputTokens
state.OutputTokens = evt.Response.Usage.OutputTokens
if evt.Response.Usage.InputTokensDetails != nil {
state.CacheReadInputTokens = evt.Response.Usage.InputTokensDetails.CachedTokens
}
usage := anthropicUsageFromResponsesUsage(evt.Response.Usage)
state.InputTokens = usage.InputTokens
state.OutputTokens = usage.OutputTokens
state.CacheReadInputTokens = usage.CacheReadInputTokens
}
switch evt.Response.Status {
case "incomplete":

View File

@@ -390,7 +390,7 @@ func convertResponsesToAnthropicTools(tools []ResponsesTool) []AnthropicTool {
var out []AnthropicTool
for _, t := range tools {
switch t.Type {
case "web_search":
case "web_search", "google_search", "web_search_20250305":
out = append(out, AnthropicTool{
Type: "web_search_20250305",
Name: "web_search",

View File

@@ -12,17 +12,23 @@ import "encoding/json"
// AnthropicRequest is the request body for POST /v1/messages.
type AnthropicRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
Messages []AnthropicMessage `json:"messages"`
Tools []AnthropicTool `json:"tools,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
StopSeqs []string `json:"stop_sequences,omitempty"`
Thinking *AnthropicThinking `json:"thinking,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
Messages []AnthropicMessage `json:"messages"`
Tools []AnthropicTool `json:"tools,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
StopSeqs []string `json:"stop_sequences,omitempty"`
Thinking *AnthropicThinking `json:"thinking,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
// Metadata 会被原样透传给上游。OAuth/Claude-Code 路径依赖 metadata.user_id
// 参与上游的"是否为官方 Claude Code 请求"判定;如果经由本结构体重新序列化
// 时丢弃该字段,网关侧后续的 metadata 重写(ensureClaudeOAuthMetadataUserID/
// RewriteUserIDWithMasking) 在 body 里拿不到起点,就无法重建一个合法的
// user_id进而导致请求被归类为第三方 app。
Metadata json.RawMessage `json:"metadata,omitempty"`
OutputConfig *AnthropicOutputConfig `json:"output_config,omitempty"`
}
@@ -76,10 +82,18 @@ type AnthropicImageSource struct {
// AnthropicTool describes a tool available to the model.
type AnthropicTool struct {
Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object
Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object
CacheControl *AnthropicCacheControl `json:"cache_control,omitempty"`
}
// AnthropicCacheControl 对应 Anthropic API 的 cache_control 字段。
// ttl 默认由调用方决定;本项目策略见 claude.DefaultCacheControlTTL。
type AnthropicCacheControl struct {
Type string `json:"type"` // "ephemeral"
TTL string `json:"ttl,omitempty"` // "5m" / "1h" / 省略=默认 5m由 Anthropic 判定)
}
// AnthropicResponse is the non-streaming response from POST /v1/messages.

View File

@@ -4,6 +4,12 @@ package claude
// Claude Code 客户端相关常量
// Beta header 常量
//
// 这里的常量对齐真实 Claude Code CLI 的最新流量(截至 2026-04
// 选型参考:与 Parrot (src/transform/cc_mimicry.py) 的 BETAS 保持一致,
// 原因Anthropic 上游会基于 anthropic-beta 的完整集合判定请求来源;
// 缺少任何"官方 Claude Code 请求才会带"的 beta都会被降级到第三方额度
// 对应报错:`Third-party apps now draw from your extra usage, not your plan limits.`
const (
BetaOAuth = "oauth-2025-04-20"
BetaClaudeCode = "claude-code-20250219"
@@ -12,6 +18,13 @@ const (
BetaTokenCounting = "token-counting-2024-11-01"
BetaContext1M = "context-1m-2025-08-07"
BetaFastMode = "fast-mode-2026-02-01"
// 新增(对齐官方 CLI 2.1.9x 以来的流量)
BetaPromptCachingScope = "prompt-caching-scope-2026-01-05"
BetaEffort = "effort-2025-11-24"
BetaRedactThinking = "redact-thinking-2026-02-12"
BetaContextManagement = "context-management-2025-06-27"
BetaExtendedCacheTTL = "extended-cache-ttl-2025-04-11"
)
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
@@ -44,11 +57,43 @@ const APIKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," +
// APIKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header不包含 oauth / claude-code
const APIKeyHaikuBetaHeader = BetaInterleavedThinking
// DefaultCacheControlTTL 是网关代理为自己生成的 cache_control 块默认使用的 ttl。
// 真实 Claude Code CLI 当前使用 "1h",但本仓策略是"客户端透传 ttl 优先;
// 客户端缺省时统一使用 5m",这样既不浪费 1h 缓存额度,也保留客户端自定义能力。
const DefaultCacheControlTTL = "5m"
// CLICurrentVersion 是 sub2api 当前对外伪装的 Claude Code CLI 版本号(三段 semver
// 用于 billing attribution block 中的 cc_version=X.Y.Z.{fp} 前缀以及 fingerprint 计算。
// 必须与 DefaultHeaders["User-Agent"] 中的版本号严格一致;不一致会被 Anthropic 判第三方。
const CLICurrentVersion = "2.1.92"
// FullClaudeCodeMimicryBetas 返回最"像"真实 Claude Code CLI 的完整 beta 列表,
// 用于 OAuth 账号伪装成 Claude Code 时使用。
// 顺序与真实 CLI 抓包一致。
//
// 使用建议:
// - OAuth 账号 + 非 haiku追加这整份列表再按需保留 client 带来的 beta。
// - OAuth 账号 + haikuAnthropic 对 haiku 不做 third-party 判定,使用 HaikuBetaHeader 即可。
// - API-key 账号:不要使用本函数,参见 APIKeyBetaHeader。
func FullClaudeCodeMimicryBetas() []string {
return []string{
BetaClaudeCode,
BetaOAuth,
BetaInterleavedThinking,
BetaPromptCachingScope,
BetaEffort,
BetaRedactThinking,
BetaContextManagement,
BetaExtendedCacheTTL,
}
}
// DefaultHeaders 是 Claude Code 客户端默认请求头。
var DefaultHeaders = map[string]string{
// Keep these in sync with recent Claude CLI traffic to reduce the chance
// that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage.
"User-Agent": "claude-cli/2.1.22 (external, cli)",
// 版本参考:对齐 Parrot (src/transform/cc_mimicry.py:49) 的 CLI_USER_AGENT。
"User-Agent": "claude-cli/2.1.92 (external, cli)",
"X-Stainless-Lang": "js",
"X-Stainless-Package-Version": "0.70.0",
"X-Stainless-OS": "Linux",

View File

@@ -0,0 +1,14 @@
package repository
import "testing"
func TestShouldEnqueueSchedulerOutboxForExtraUpdates_CompactCapabilityKeysAreRelevant(t *testing.T) {
updates := map[string]any{
"openai_compact_supported": true,
"openai_compact_checked_at": "2026-04-10T10:00:00Z",
}
if !shouldEnqueueSchedulerOutboxForExtraUpdates(updates) {
t.Fatalf("expected compact capability updates to enqueue scheduler outbox")
}
}

View File

@@ -0,0 +1,762 @@
package repository
import (
"context"
"crypto/rand"
"database/sql"
"errors"
"fmt"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
const (
affiliateCodeLength = 12
affiliateCodeMaxAttempts = 12
)
var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789")
type affiliateQueryExecer interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
type affiliateRepository struct {
client *dbent.Client
}
func NewAffiliateRepository(client *dbent.Client, _ *sql.DB) service.AffiliateRepository {
return &affiliateRepository{client: client}
}
func (r *affiliateRepository) EnsureUserAffiliate(ctx context.Context, userID int64) (*service.AffiliateSummary, error) {
if userID <= 0 {
return nil, service.ErrUserNotFound
}
client := clientFromContext(ctx, r.client)
return ensureUserAffiliateWithClient(ctx, client, userID)
}
func (r *affiliateRepository) GetAffiliateByCode(ctx context.Context, code string) (*service.AffiliateSummary, error) {
client := clientFromContext(ctx, r.client)
return queryAffiliateByCode(ctx, client, code)
}
func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID int64) (bool, error) {
var bound bool
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
return err
}
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, inviterID); err != nil {
return err
}
res, err := txClient.ExecContext(txCtx,
"UPDATE user_affiliates SET inviter_id = $1, updated_at = NOW() WHERE user_id = $2 AND inviter_id IS NULL",
inviterID, userID,
)
if err != nil {
return fmt.Errorf("bind inviter: %w", err)
}
affected, _ := res.RowsAffected()
if affected == 0 {
bound = false
return nil
}
if _, err = txClient.ExecContext(txCtx,
"UPDATE user_affiliates SET aff_count = aff_count + 1, updated_at = NOW() WHERE user_id = $1",
inviterID,
); err != nil {
return fmt.Errorf("increment inviter aff_count: %w", err)
}
bound = true
return nil
})
if err != nil {
return false, err
}
return bound, nil
}
func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error) {
if amount <= 0 {
return false, nil
}
var applied bool
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
// freezeHours > 0: add to frozen quota; == 0: add to available quota directly
var updateSQL string
if freezeHours > 0 {
updateSQL = "UPDATE user_affiliates SET aff_frozen_quota = aff_frozen_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2"
} else {
updateSQL = "UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2"
}
res, err := txClient.ExecContext(txCtx, updateSQL, amount, inviterID)
if err != nil {
return err
}
affected, _ := res.RowsAffected()
if affected == 0 {
applied = false
return nil
}
if freezeHours > 0 {
if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, frozen_until, created_at, updated_at)
VALUES ($1, 'accrue', $2, $3, NOW() + make_interval(hours => $4), NOW(), NOW())`,
inviterID, amount, inviteeUserID, freezeHours); err != nil {
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
}
} else {
if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil {
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
}
}
applied = true
return nil
})
if err != nil {
return false, err
}
return applied, nil
}
func (r *affiliateRepository) GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error) {
client := clientFromContext(ctx, r.client)
rows, err := client.QueryContext(ctx,
`SELECT COALESCE(SUM(amount), 0)::double precision FROM user_affiliate_ledger WHERE user_id = $1 AND source_user_id = $2 AND action = 'accrue'`,
inviterID, inviteeUserID)
if err != nil {
return 0, fmt.Errorf("query accrued rebate from invitee: %w", err)
}
defer func() { _ = rows.Close() }()
var total float64
if rows.Next() {
if err := rows.Scan(&total); err != nil {
return 0, err
}
}
return total, rows.Close()
}
func (r *affiliateRepository) ThawFrozenQuota(ctx context.Context, userID int64) (float64, error) {
var thawed float64
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
var err error
thawed, err = thawFrozenQuotaTx(txCtx, txClient, userID)
return err
})
return thawed, err
}
// thawFrozenQuotaTx moves matured frozen quota to available quota within an existing tx.
func thawFrozenQuotaTx(txCtx context.Context, txClient *dbent.Client, userID int64) (float64, error) {
rows, err := txClient.QueryContext(txCtx, `
WITH matured AS (
UPDATE user_affiliate_ledger
SET frozen_until = NULL, updated_at = NOW()
WHERE user_id = $1
AND frozen_until IS NOT NULL
AND frozen_until <= NOW()
RETURNING amount
)
SELECT COALESCE(SUM(amount), 0) FROM matured`, userID)
if err != nil {
return 0, fmt.Errorf("thaw frozen quota: %w", err)
}
defer func() { _ = rows.Close() }()
var thawed float64
if rows.Next() {
if err := rows.Scan(&thawed); err != nil {
return 0, err
}
}
if err := rows.Close(); err != nil {
return 0, err
}
if thawed <= 0 {
return 0, nil
}
_, err = txClient.ExecContext(txCtx, `
UPDATE user_affiliates
SET aff_quota = aff_quota + $1,
aff_frozen_quota = GREATEST(aff_frozen_quota - $1, 0),
updated_at = NOW()
WHERE user_id = $2`, thawed, userID)
if err != nil {
return 0, fmt.Errorf("move thawed quota: %w", err)
}
return thawed, nil
}
func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) {
var transferred float64
var newBalance float64
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
return err
}
// Thaw any matured frozen quota before transfer.
if _, err := thawFrozenQuotaTx(txCtx, txClient, userID); err != nil {
return fmt.Errorf("thaw before transfer: %w", err)
}
rows, err := txClient.QueryContext(txCtx, `
WITH claimed AS (
SELECT aff_quota::double precision AS amount
FROM user_affiliates
WHERE user_id = $1
AND aff_quota > 0
FOR UPDATE
),
cleared AS (
UPDATE user_affiliates ua
SET aff_quota = 0,
updated_at = NOW()
FROM claimed c
WHERE ua.user_id = $1
RETURNING c.amount
)
SELECT amount
FROM cleared`, userID)
if err != nil {
return fmt.Errorf("claim affiliate quota: %w", err)
}
if !rows.Next() {
_ = rows.Close()
if err := rows.Err(); err != nil {
return err
}
return service.ErrAffiliateQuotaEmpty
}
if err := rows.Scan(&transferred); err != nil {
_ = rows.Close()
return err
}
if err := rows.Close(); err != nil {
return err
}
if transferred <= 0 {
return service.ErrAffiliateQuotaEmpty
}
affected, err := txClient.User.Update().
Where(user.IDEQ(userID)).
AddBalance(transferred).
AddTotalRecharged(transferred).
Save(txCtx)
if err != nil {
return fmt.Errorf("credit user balance by affiliate quota: %w", err)
}
if affected == 0 {
return service.ErrUserNotFound
}
newBalance, err = queryUserBalance(txCtx, txClient, userID)
if err != nil {
return err
}
if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
VALUES ($1, 'transfer', $2, NULL, NOW(), NOW())`, userID, transferred); err != nil {
return fmt.Errorf("insert affiliate transfer ledger: %w", err)
}
return nil
})
if err != nil {
return 0, 0, err
}
return transferred, newBalance, nil
}
func (r *affiliateRepository) ListInvitees(ctx context.Context, inviterID int64, limit int) ([]service.AffiliateInvitee, error) {
if limit <= 0 {
limit = 100
}
client := clientFromContext(ctx, r.client)
rows, err := client.QueryContext(ctx, `
SELECT ua.user_id,
COALESCE(u.email, ''),
COALESCE(u.username, ''),
ua.created_at,
COALESCE(SUM(ual.amount), 0)::double precision AS total_rebate
FROM user_affiliates ua
LEFT JOIN users u ON u.id = ua.user_id
LEFT JOIN user_affiliate_ledger ual
ON ual.user_id = $1
AND ual.source_user_id = ua.user_id
AND ual.action = 'accrue'
WHERE ua.inviter_id = $1
GROUP BY ua.user_id, u.email, u.username, ua.created_at
ORDER BY ua.created_at DESC
LIMIT $2`, inviterID, limit)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
invitees := make([]service.AffiliateInvitee, 0)
for rows.Next() {
var item service.AffiliateInvitee
var createdAt time.Time
if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt, &item.TotalRebate); err != nil {
return nil, err
}
item.CreatedAt = &createdAt
invitees = append(invitees, item)
}
if err := rows.Err(); err != nil {
return nil, err
}
return invitees, nil
}
func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error {
if tx := dbent.TxFromContext(ctx); tx != nil {
return fn(ctx, tx.Client())
}
tx, err := r.client.Tx(ctx)
if err != nil {
return fmt.Errorf("begin affiliate transaction: %w", err)
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
if err := fn(txCtx, tx.Client()); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("commit affiliate transaction: %w", err)
}
return nil
}
func ensureUserAffiliateWithClient(ctx context.Context, client affiliateQueryExecer, userID int64) (*service.AffiliateSummary, error) {
summary, err := queryAffiliateByUserID(ctx, client, userID)
if err == nil {
return summary, nil
}
if !errors.Is(err, service.ErrAffiliateProfileNotFound) {
return nil, err
}
for i := 0; i < affiliateCodeMaxAttempts; i++ {
code, codeErr := generateAffiliateCode()
if codeErr != nil {
return nil, codeErr
}
_, insertErr := client.ExecContext(ctx, `
INSERT INTO user_affiliates (user_id, aff_code, created_at, updated_at)
VALUES ($1, $2, NOW(), NOW())
ON CONFLICT (user_id) DO NOTHING`, userID, code)
if insertErr == nil {
break
}
if isAffiliateUniqueViolation(insertErr) {
continue
}
return nil, insertErr
}
return queryAffiliateByUserID(ctx, client, userID)
}
func queryAffiliateByUserID(ctx context.Context, client affiliateQueryExecer, userID int64) (*service.AffiliateSummary, error) {
rows, err := client.QueryContext(ctx, `
SELECT user_id,
aff_code,
aff_code_custom,
aff_rebate_rate_percent,
inviter_id,
aff_count,
aff_quota::double precision,
aff_frozen_quota::double precision,
aff_history_quota::double precision,
created_at,
updated_at
FROM user_affiliates
WHERE user_id = $1`, userID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
if err := rows.Err(); err != nil {
return nil, err
}
return nil, service.ErrAffiliateProfileNotFound
}
var out service.AffiliateSummary
var inviterID sql.NullInt64
var rebateRate sql.NullFloat64
if err := rows.Scan(
&out.UserID,
&out.AffCode,
&out.AffCodeCustom,
&rebateRate,
&inviterID,
&out.AffCount,
&out.AffQuota,
&out.AffFrozenQuota,
&out.AffHistoryQuota,
&out.CreatedAt,
&out.UpdatedAt,
); err != nil {
return nil, err
}
if inviterID.Valid {
out.InviterID = &inviterID.Int64
}
if rebateRate.Valid {
v := rebateRate.Float64
out.AffRebateRatePercent = &v
}
return &out, nil
}
func queryAffiliateByCode(ctx context.Context, client affiliateQueryExecer, code string) (*service.AffiliateSummary, error) {
rows, err := client.QueryContext(ctx, `
SELECT user_id,
aff_code,
aff_code_custom,
aff_rebate_rate_percent,
inviter_id,
aff_count,
aff_quota::double precision,
aff_frozen_quota::double precision,
aff_history_quota::double precision,
created_at,
updated_at
FROM user_affiliates
WHERE aff_code = $1
LIMIT 1`, strings.ToUpper(strings.TrimSpace(code)))
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
if err := rows.Err(); err != nil {
return nil, err
}
return nil, service.ErrAffiliateProfileNotFound
}
var out service.AffiliateSummary
var inviterID sql.NullInt64
var rebateRate sql.NullFloat64
if err := rows.Scan(
&out.UserID,
&out.AffCode,
&out.AffCodeCustom,
&rebateRate,
&inviterID,
&out.AffCount,
&out.AffQuota,
&out.AffFrozenQuota,
&out.AffHistoryQuota,
&out.CreatedAt,
&out.UpdatedAt,
); err != nil {
return nil, err
}
if inviterID.Valid {
out.InviterID = &inviterID.Int64
}
if rebateRate.Valid {
v := rebateRate.Float64
out.AffRebateRatePercent = &v
}
return &out, nil
}
func queryUserBalance(ctx context.Context, client affiliateQueryExecer, userID int64) (float64, error) {
rows, err := client.QueryContext(ctx,
"SELECT balance::double precision FROM users WHERE id = $1 LIMIT 1",
userID,
)
if err != nil {
return 0, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
if err := rows.Err(); err != nil {
return 0, err
}
return 0, service.ErrUserNotFound
}
var balance float64
if err := rows.Scan(&balance); err != nil {
return 0, err
}
return balance, nil
}
func generateAffiliateCode() (string, error) {
buf := make([]byte, affiliateCodeLength)
if _, err := rand.Read(buf); err != nil {
return "", fmt.Errorf("generate affiliate code: %w", err)
}
for i := range buf {
buf[i] = affiliateCodeCharset[int(buf[i])%len(affiliateCodeCharset)]
}
return string(buf), nil
}
func isAffiliateUniqueViolation(err error) bool {
var pqErr *pq.Error
if errors.As(err, &pqErr) {
return string(pqErr.Code) == "23505"
}
return false
}
// UpdateUserAffCode 改写用户的邀请码(自定义专属邀请码)。
// 唯一性冲突返回 ErrAffiliateCodeTaken。
func (r *affiliateRepository) UpdateUserAffCode(ctx context.Context, userID int64, newCode string) error {
if userID <= 0 {
return service.ErrUserNotFound
}
code := strings.ToUpper(strings.TrimSpace(newCode))
if code == "" {
return service.ErrAffiliateCodeInvalid
}
return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
return err
}
res, err := txClient.ExecContext(txCtx, `
UPDATE user_affiliates
SET aff_code = $1,
aff_code_custom = true,
updated_at = NOW()
WHERE user_id = $2`, code, userID)
if err != nil {
if isAffiliateUniqueViolation(err) {
return service.ErrAffiliateCodeTaken
}
return fmt.Errorf("update aff_code: %w", err)
}
affected, _ := res.RowsAffected()
if affected == 0 {
return service.ErrUserNotFound
}
return nil
})
}
// ResetUserAffCode 把 aff_code 还原为系统随机码,并清除 aff_code_custom 标记。
func (r *affiliateRepository) ResetUserAffCode(ctx context.Context, userID int64) (string, error) {
if userID <= 0 {
return "", service.ErrUserNotFound
}
var newCode string
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
return err
}
for i := 0; i < affiliateCodeMaxAttempts; i++ {
candidate, codeErr := generateAffiliateCode()
if codeErr != nil {
return codeErr
}
res, err := txClient.ExecContext(txCtx, `
UPDATE user_affiliates
SET aff_code = $1,
aff_code_custom = false,
updated_at = NOW()
WHERE user_id = $2`, candidate, userID)
if err != nil {
if isAffiliateUniqueViolation(err) {
continue
}
return fmt.Errorf("reset aff_code: %w", err)
}
affected, _ := res.RowsAffected()
if affected == 0 {
return service.ErrUserNotFound
}
newCode = candidate
return nil
}
return fmt.Errorf("reset aff_code: exhausted attempts")
})
if err != nil {
return "", err
}
return newCode, nil
}
// SetUserRebateRate 设置或清除用户专属返利比例。ratePercent==nil 表示清除(沿用全局)。
func (r *affiliateRepository) SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error {
if userID <= 0 {
return service.ErrUserNotFound
}
return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
return err
}
// nullableArg lets us use a single UPDATE for both "set value" and
// "clear" cases — database/sql converts nil interface{} to SQL NULL.
res, err := txClient.ExecContext(txCtx, `
UPDATE user_affiliates
SET aff_rebate_rate_percent = $1,
updated_at = NOW()
WHERE user_id = $2`, nullableArg(ratePercent), userID)
if err != nil {
return fmt.Errorf("set aff_rebate_rate_percent: %w", err)
}
affected, _ := res.RowsAffected()
if affected == 0 {
return service.ErrUserNotFound
}
return nil
})
}
// BatchSetUserRebateRate 批量为多个用户设置专属比例nil 清除)。
func (r *affiliateRepository) BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error {
if len(userIDs) == 0 {
return nil
}
return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
for _, uid := range userIDs {
if uid <= 0 {
continue
}
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, uid); err != nil {
return err
}
}
_, err := txClient.ExecContext(txCtx, `
UPDATE user_affiliates
SET aff_rebate_rate_percent = $1,
updated_at = NOW()
WHERE user_id = ANY($2)`, nullableArg(ratePercent), pq.Array(userIDs))
if err != nil {
return fmt.Errorf("batch set aff_rebate_rate_percent: %w", err)
}
return nil
})
}
// nullableArg unwraps a *float64 into an interface{} suitable for SQL parameter
// binding: nil pointer → SQL NULL, non-nil → the float value.
func nullableArg(v *float64) any {
if v == nil {
return nil
}
return *v
}
// ListUsersWithCustomSettings 列出有专属配置(自定义码或专属比例)的用户。
//
// 单一查询同时处理"无搜索"与"按邮箱/用户名模糊搜索"
// 空 search 时拼接出的 LIKE 模式为 "%%",匹配所有行;非空时按 ILIKE 子串匹配。
// 这避免了为两种情况维护两份 SQL 模板。
func (r *affiliateRepository) ListUsersWithCustomSettings(ctx context.Context, filter service.AffiliateAdminFilter) ([]service.AffiliateAdminEntry, int64, error) {
page := filter.Page
if page < 1 {
page = 1
}
pageSize := filter.PageSize
if pageSize <= 0 || pageSize > 200 {
pageSize = 20
}
offset := (page - 1) * pageSize
likePattern := "%" + strings.TrimSpace(filter.Search) + "%"
const baseFrom = `
FROM user_affiliates ua
JOIN users u ON u.id = ua.user_id
WHERE (ua.aff_code_custom = true OR ua.aff_rebate_rate_percent IS NOT NULL)
AND (u.email ILIKE $1 OR u.username ILIKE $1)`
client := clientFromContext(ctx, r.client)
total, err := scanInt64(ctx, client, "SELECT COUNT(*)"+baseFrom, likePattern)
if err != nil {
return nil, 0, fmt.Errorf("count affiliate admin entries: %w", err)
}
listQuery := `
SELECT ua.user_id,
COALESCE(u.email, ''),
COALESCE(u.username, ''),
ua.aff_code,
ua.aff_code_custom,
ua.aff_rebate_rate_percent,
ua.aff_count` + baseFrom + `
ORDER BY ua.updated_at DESC
LIMIT $2 OFFSET $3`
rows, err := client.QueryContext(ctx, listQuery, likePattern, pageSize, offset)
if err != nil {
return nil, 0, fmt.Errorf("list affiliate admin entries: %w", err)
}
defer func() { _ = rows.Close() }()
entries := make([]service.AffiliateAdminEntry, 0)
for rows.Next() {
var e service.AffiliateAdminEntry
var rebate sql.NullFloat64
if err := rows.Scan(&e.UserID, &e.Email, &e.Username, &e.AffCode,
&e.AffCodeCustom, &rebate, &e.AffCount); err != nil {
return nil, 0, err
}
if rebate.Valid {
v := rebate.Float64
e.AffRebateRatePercent = &v
}
entries = append(entries, e)
}
if err := rows.Err(); err != nil {
return nil, 0, err
}
return entries, total, nil
}
// scanInt64 runs a query expected to return a single int64 column (e.g. COUNT).
func scanInt64(ctx context.Context, client affiliateQueryExecer, query string, args ...any) (int64, error) {
rows, err := client.QueryContext(ctx, query, args...)
if err != nil {
return 0, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
if err := rows.Err(); err != nil {
return 0, err
}
return 0, nil
}
var v int64
if err := rows.Scan(&v); err != nil {
return 0, err
}
return v, nil
}

View File

@@ -0,0 +1,399 @@
//go:build integration
package repository
import (
"context"
"fmt"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func querySingleFloat(t *testing.T, ctx context.Context, client *dbent.Client, query string, args ...any) float64 {
t.Helper()
rows, err := client.QueryContext(ctx, query, args...)
require.NoError(t, err)
defer func() { _ = rows.Close() }()
require.True(t, rows.Next(), "expected one row")
var value float64
require.NoError(t, rows.Scan(&value))
require.NoError(t, rows.Err())
return value
}
func querySingleInt(t *testing.T, ctx context.Context, client *dbent.Client, query string, args ...any) int {
t.Helper()
rows, err := client.QueryContext(ctx, query, args...)
require.NoError(t, err)
defer func() { _ = rows.Close() }()
require.True(t, rows.Next(), "expected one row")
var value int
require.NoError(t, rows.Scan(&value))
require.NoError(t, rows.Err())
return value
}
func TestAffiliateRepository_TransferQuotaToBalance_UsesClaimedQuotaBeforeClear(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
txCtx := dbent.NewTxContext(ctx, tx)
client := tx.Client()
repo := NewAffiliateRepository(client, integrationDB)
u := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-transfer-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 5.5,
Concurrency: 5,
})
affCode := fmt.Sprintf("AFF%09d", time.Now().UnixNano()%1_000_000_000)
_, err := client.ExecContext(txCtx, `
INSERT INTO user_affiliates (user_id, aff_code, aff_quota, aff_history_quota, created_at, updated_at)
VALUES ($1, $2, $3, $3, NOW(), NOW())`, u.ID, affCode, 12.34)
require.NoError(t, err)
transferred, balance, err := repo.TransferQuotaToBalance(txCtx, u.ID)
require.NoError(t, err)
require.InDelta(t, 12.34, transferred, 1e-9)
require.InDelta(t, 17.84, balance, 1e-9)
affQuota := querySingleFloat(t, txCtx, client,
"SELECT aff_quota::double precision FROM user_affiliates WHERE user_id = $1", u.ID)
require.InDelta(t, 0.0, affQuota, 1e-9)
persistedBalance := querySingleFloat(t, txCtx, client,
"SELECT balance::double precision FROM users WHERE id = $1", u.ID)
require.InDelta(t, 17.84, persistedBalance, 1e-9)
ledgerCount := querySingleInt(t, txCtx, client,
"SELECT COUNT(*) FROM user_affiliate_ledger WHERE user_id = $1 AND action = 'transfer'", u.ID)
require.Equal(t, 1, ledgerCount)
}
// TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction guards the
// cross-layer tx propagation invariant: when AccrueQuota is called with a ctx
// that already carries a transaction (via dbent.NewTxContext), repo.withTx
// must reuse that tx rather than opening a nested one. If this invariant
// breaks, AccrueQuota would commit independently and survive a rollback of
// the outer tx, which would violate payment_fulfillment's all-or-nothing
// semantics.
func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) {
ctx := context.Background()
outerTx, err := integrationEntClient.Tx(ctx)
require.NoError(t, err, "begin outer tx")
// Defensive cleanup: if any require.* below fires before the explicit
// Rollback, this prevents the tx from leaking until container teardown.
// Rollback is idempotent at the driver level (extra rollback returns an
// error we ignore).
t.Cleanup(func() { _ = outerTx.Rollback() })
client := outerTx.Client()
txCtx := dbent.NewTxContext(ctx, outerTx)
inviter := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-inviter-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
})
invitee := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-invitee-%d@example.com", time.Now().UnixNano()+1),
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
})
repo := NewAffiliateRepository(client, integrationDB)
_, err = repo.EnsureUserAffiliate(txCtx, inviter.ID)
require.NoError(t, err)
_, err = repo.EnsureUserAffiliate(txCtx, invitee.ID)
require.NoError(t, err)
bound, err := repo.BindInviter(txCtx, invitee.ID, inviter.ID)
require.NoError(t, err)
require.True(t, bound, "invitee must bind to inviter")
applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0)
require.NoError(t, err)
require.True(t, applied, "AccrueQuota must report applied=true")
// Visible inside the outer tx.
innerQuota := querySingleFloat(t, txCtx, client,
"SELECT aff_quota::double precision FROM user_affiliates WHERE user_id = $1", inviter.ID)
require.InDelta(t, 3.5, innerQuota, 1e-9)
// Roll back the outer tx; if AccrueQuota had opened its own inner tx and
// committed it, the rows would still be visible to the global client.
require.NoError(t, outerTx.Rollback())
rows, err := integrationEntClient.QueryContext(ctx,
"SELECT COUNT(*) FROM user_affiliates WHERE user_id IN ($1, $2)",
inviter.ID, invitee.ID)
require.NoError(t, err)
defer func() { _ = rows.Close() }()
require.True(t, rows.Next())
var postRollbackCount int
require.NoError(t, rows.Scan(&postRollbackCount))
require.Equal(t, 0, postRollbackCount,
"AccrueQuota must propagate the outer tx — found persisted rows after rollback")
}
func TestAffiliateRepository_TransferQuotaToBalance_EmptyQuota(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
txCtx := dbent.NewTxContext(ctx, tx)
client := tx.Client()
repo := NewAffiliateRepository(client, integrationDB)
u := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-empty-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 3.21,
Concurrency: 5,
})
affCode := fmt.Sprintf("AFF%09d", time.Now().UnixNano()%1_000_000_000)
_, err := client.ExecContext(txCtx, `
INSERT INTO user_affiliates (user_id, aff_code, aff_quota, aff_history_quota, created_at, updated_at)
VALUES ($1, $2, 0, 0, NOW(), NOW())`, u.ID, affCode)
require.NoError(t, err)
transferred, balance, err := repo.TransferQuotaToBalance(txCtx, u.ID)
require.ErrorIs(t, err, service.ErrAffiliateQuotaEmpty)
require.InDelta(t, 0.0, transferred, 1e-9)
require.InDelta(t, 0.0, balance, 1e-9)
persistedBalance := querySingleFloat(t, txCtx, client,
"SELECT balance::double precision FROM users WHERE id = $1", u.ID)
require.InDelta(t, 3.21, persistedBalance, 1e-9)
}
// TestAffiliateRepository_AdminCustomCode covers the success path of admin
// invite-code rewrite + reset within a shared test transaction:
// - UpdateUserAffCode replaces aff_code, sets aff_code_custom=true, lookup works
// - the old code can no longer be found
// - ResetUserAffCode reverts aff_code_custom and assigns a new system-format code
//
// The conflict path (duplicate code → ErrAffiliateCodeTaken) lives in its own
// test because a unique-violation aborts the surrounding Postgres tx, which
// would poison subsequent assertions in the same transaction.
func TestAffiliateRepository_AdminCustomCode(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
txCtx := dbent.NewTxContext(ctx, tx)
client := tx.Client()
repo := NewAffiliateRepository(client, integrationDB)
u := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-custom-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
})
original, err := repo.EnsureUserAffiliate(txCtx, u.ID)
require.NoError(t, err)
require.False(t, original.AffCodeCustom, "system-generated codes start as non-custom")
originalCode := original.AffCode
// Rewrite to a custom code
customCode := fmt.Sprintf("VIP%09d", time.Now().UnixNano()%1_000_000_000)
require.NoError(t, repo.UpdateUserAffCode(txCtx, u.ID, customCode))
updated, err := repo.EnsureUserAffiliate(txCtx, u.ID)
require.NoError(t, err)
require.Equal(t, customCode, updated.AffCode)
require.True(t, updated.AffCodeCustom)
// Lookup by new custom code finds the user
byCode, err := repo.GetAffiliateByCode(txCtx, customCode)
require.NoError(t, err)
require.Equal(t, u.ID, byCode.UserID)
// Old system code should no longer match
_, err = repo.GetAffiliateByCode(txCtx, originalCode)
require.ErrorIs(t, err, service.ErrAffiliateProfileNotFound)
// Reset back to a fresh system code, clears custom flag
newSysCode, err := repo.ResetUserAffCode(txCtx, u.ID)
require.NoError(t, err)
require.NotEqual(t, customCode, newSysCode)
reset, err := repo.EnsureUserAffiliate(txCtx, u.ID)
require.NoError(t, err)
require.Equal(t, newSysCode, reset.AffCode)
require.False(t, reset.AffCodeCustom)
// The old custom code is now free again
_, err = repo.GetAffiliateByCode(txCtx, customCode)
require.ErrorIs(t, err, service.ErrAffiliateProfileNotFound)
}
// TestAffiliateRepository_AdminCustomCode_Conflict isolates the unique-violation
// path. PostgreSQL aborts the enclosing tx when a unique constraint fires, so
// this test must be the only assertion and run in its own tx — production
// callers each have their own outer tx, so this matches real behavior.
func TestAffiliateRepository_AdminCustomCode_Conflict(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
txCtx := dbent.NewTxContext(ctx, tx)
client := tx.Client()
repo := NewAffiliateRepository(client, integrationDB)
taker := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-conflict-taker-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Role: service.RoleUser, Status: service.StatusActive,
})
requester := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-conflict-req-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Role: service.RoleUser, Status: service.StatusActive,
})
takenCode := fmt.Sprintf("HOT%09d", time.Now().UnixNano()%1_000_000_000)
require.NoError(t, repo.UpdateUserAffCode(txCtx, taker.ID, takenCode))
// Now requester tries to grab the same code → conflict.
err := repo.UpdateUserAffCode(txCtx, requester.ID, takenCode)
require.ErrorIs(t, err, service.ErrAffiliateCodeTaken)
}
// TestAffiliateRepository_AdminRebateRate covers per-user exclusive rate
// set/clear and the Batch variant including NULL semantics.
func TestAffiliateRepository_AdminRebateRate(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
txCtx := dbent.NewTxContext(ctx, tx)
client := tx.Client()
repo := NewAffiliateRepository(client, integrationDB)
u1 := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-rate-%d-a@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
})
u2 := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-rate-%d-b@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
})
// Set exclusive rate for u1
rate := 42.5
require.NoError(t, repo.SetUserRebateRate(txCtx, u1.ID, &rate))
got, err := repo.EnsureUserAffiliate(txCtx, u1.ID)
require.NoError(t, err)
require.NotNil(t, got.AffRebateRatePercent)
require.InDelta(t, 42.5, *got.AffRebateRatePercent, 1e-9)
// Clear exclusive rate
require.NoError(t, repo.SetUserRebateRate(txCtx, u1.ID, nil))
cleared, err := repo.EnsureUserAffiliate(txCtx, u1.ID)
require.NoError(t, err)
require.Nil(t, cleared.AffRebateRatePercent)
// Batch set both users
batchRate := 15.0
require.NoError(t, repo.BatchSetUserRebateRate(txCtx, []int64{u1.ID, u2.ID}, &batchRate))
for _, uid := range []int64{u1.ID, u2.ID} {
v, err := repo.EnsureUserAffiliate(txCtx, uid)
require.NoError(t, err)
require.NotNil(t, v.AffRebateRatePercent)
require.InDelta(t, 15.0, *v.AffRebateRatePercent, 1e-9)
}
// Batch clear
require.NoError(t, repo.BatchSetUserRebateRate(txCtx, []int64{u1.ID, u2.ID}, nil))
for _, uid := range []int64{u1.ID, u2.ID} {
v, err := repo.EnsureUserAffiliate(txCtx, uid)
require.NoError(t, err)
require.Nil(t, v.AffRebateRatePercent)
}
}
// TestAffiliateRepository_ListUsersWithCustomSettings verifies the admin list
// only includes users with at least one override applied.
func TestAffiliateRepository_ListUsersWithCustomSettings(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
txCtx := dbent.NewTxContext(ctx, tx)
client := tx.Client()
repo := NewAffiliateRepository(client, integrationDB)
// User without any custom config — should NOT appear in the list.
plainEmail := fmt.Sprintf("affiliate-plain-%d@example.com", time.Now().UnixNano())
uPlain := mustCreateUser(t, client, &service.User{
Email: plainEmail, PasswordHash: "hash",
Role: service.RoleUser, Status: service.StatusActive,
})
_, err := repo.EnsureUserAffiliate(txCtx, uPlain.ID)
require.NoError(t, err)
// User with a custom code — should appear.
uCode := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-codeonly-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Role: service.RoleUser, Status: service.StatusActive,
})
require.NoError(t, repo.UpdateUserAffCode(txCtx, uCode.ID, fmt.Sprintf("VIP%09d", time.Now().UnixNano()%1_000_000_000)))
// User with only an exclusive rate — should appear.
uRate := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-rateonly-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Role: service.RoleUser, Status: service.StatusActive,
})
r := 33.3
require.NoError(t, repo.SetUserRebateRate(txCtx, uRate.ID, &r))
entries, total, err := repo.ListUsersWithCustomSettings(txCtx, service.AffiliateAdminFilter{
Page: 1, PageSize: 100,
})
require.NoError(t, err)
// Build a quick lookup to assert per-user attributes (other tests may have
// inserted custom rows in the same DB; we only care about our 3).
byUserID := make(map[int64]service.AffiliateAdminEntry, len(entries))
for _, e := range entries {
byUserID[e.UserID] = e
}
require.NotContains(t, byUserID, uPlain.ID, "users without overrides must not appear")
codeEntry, ok := byUserID[uCode.ID]
require.True(t, ok, "custom-code user missing from list")
require.True(t, codeEntry.AffCodeCustom)
require.Nil(t, codeEntry.AffRebateRatePercent)
rateEntry, ok := byUserID[uRate.ID]
require.True(t, ok, "custom-rate user missing from list")
require.False(t, rateEntry.AffCodeCustom)
require.NotNil(t, rateEntry.AffRebateRatePercent)
require.InDelta(t, 33.3, *rateEntry.AffRebateRatePercent, 1e-9)
require.GreaterOrEqual(t, total, int64(2), "total must include at least our 2 custom rows")
}

View File

@@ -91,6 +91,7 @@ var ProviderSet = wire.NewSet(
NewChannelRepository,
NewChannelMonitorRepository,
NewChannelMonitorRequestTemplateRepository,
NewAffiliateRepository,
// Cache implementations
NewGatewayCache,

View File

@@ -715,6 +715,10 @@ func TestAPIContracts(t *testing.T) {
"force_email_on_third_party_signup": false,
"default_concurrency": 5,
"default_balance": 1.25,
"affiliate_rebate_rate": 20,
"affiliate_rebate_freeze_hours": 0,
"affiliate_rebate_duration_days": 0,
"affiliate_rebate_per_invitee_cap": 0,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
@@ -774,6 +778,7 @@ func TestAPIContracts(t *testing.T) {
"channel_monitor_enabled": true,
"channel_monitor_default_interval_seconds": 60,
"available_channels_enabled": false,
"affiliate_enabled": false,
"wechat_connect_enabled": false,
"wechat_connect_app_id": "",
"wechat_connect_app_secret_configured": false,
@@ -895,6 +900,10 @@ func TestAPIContracts(t *testing.T) {
"custom_endpoints": [],
"default_concurrency": 0,
"default_balance": 0,
"affiliate_rebate_rate": 20,
"affiliate_rebate_freeze_hours": 0,
"affiliate_rebate_duration_days": 0,
"affiliate_rebate_per_invitee_cap": 0,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
@@ -949,6 +958,7 @@ func TestAPIContracts(t *testing.T) {
"channel_monitor_enabled": true,
"channel_monitor_default_interval_seconds": 60,
"available_channels_enabled": false,
"affiliate_enabled": false,
"wechat_connect_enabled": true,
"wechat_connect_app_id": "wx-open-config",
"wechat_connect_app_secret_configured": true,

View File

@@ -20,7 +20,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
admin := &service.User{
ID: 1,

View File

@@ -60,7 +60,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
cfg.JWT.AccessTokenExpireMinutes = 60
userRepo := &stubJWTUserRepo{users: users}
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
userSvc := service.NewUserService(userRepo, nil, nil, nil)
mw := NewJWTAuthMiddleware(authSvc, userSvc)
@@ -143,7 +143,7 @@ func TestJWTAuth_ValidToken_TouchesLastActive(t *testing.T) {
cfg.JWT.AccessTokenExpireMinutes = 60
userRepo := &stubJWTUserRepo{users: map[int64]*service.User{1: user}}
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
userSvc := service.NewUserService(userRepo, nil, nil, nil)
toucher := &recordingActivityToucher{}

View File

@@ -91,6 +91,9 @@ func RegisterAdminRoutes(
// 渠道监控
registerChannelMonitorRoutes(admin, h)
// 邀请返利(专属用户管理)
registerAffiliateRoutes(admin, h)
}
}
@@ -594,3 +597,18 @@ func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
templates.POST("/:id/apply", h.Admin.ChannelMonitorTemplate.Apply)
}
}
// registerAffiliateRoutes 注册邀请返利的管理端路由(专属用户配置)
func registerAffiliateRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
affiliates := admin.Group("/affiliates")
{
users := affiliates.Group("/users")
{
users.GET("", h.Admin.Affiliate.ListUsers)
users.GET("/lookup", h.Admin.Affiliate.LookupUsers)
users.POST("/batch-rate", h.Admin.Affiliate.BatchSetRate)
users.PUT("/:user_id", h.Admin.Affiliate.UpdateUserSettings)
users.DELETE("/:user_id", h.Admin.Affiliate.ClearUserSettings)
}
}
}

View File

@@ -25,6 +25,8 @@ func RegisterUserRoutes(
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
user.GET("/aff", h.User.GetAffiliate)
user.POST("/aff/transfer", h.User.TransferAffiliateQuota)
user.POST("/account-bindings/email/send-code", h.User.SendEmailBindingCode)
user.POST("/account-bindings/email", h.User.BindEmailIdentity)
user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity)

View File

@@ -393,6 +393,56 @@ func parseTempUnschedInt(value any) int {
return 0
}
const (
// OpenAICompactModeAuto follows compact-probe results when deciding compact eligibility.
OpenAICompactModeAuto = "auto"
// OpenAICompactModeForceOn always treats the account as compact-supported.
OpenAICompactModeForceOn = "force_on"
// OpenAICompactModeForceOff always treats the account as compact-unsupported.
OpenAICompactModeForceOff = "force_off"
)
func normalizeOpenAICompactMode(mode string) string {
switch strings.ToLower(strings.TrimSpace(mode)) {
case OpenAICompactModeForceOn:
return OpenAICompactModeForceOn
case OpenAICompactModeForceOff:
return OpenAICompactModeForceOff
default:
return OpenAICompactModeAuto
}
}
func stringMappingFromRaw(raw any) map[string]string {
switch mapping := raw.(type) {
case map[string]any:
if len(mapping) == 0 {
return nil
}
result := make(map[string]string, len(mapping))
for key, value := range mapping {
if str, ok := value.(string); ok {
result[key] = str
}
}
if len(result) == 0 {
return nil
}
return result
case map[string]string:
if len(mapping) == 0 {
return nil
}
result := make(map[string]string, len(mapping))
for key, value := range mapping {
result[key] = value
}
return result
default:
return nil
}
}
func (a *Account) GetModelMapping() map[string]string {
credentialsPtr := mapPtr(a.Credentials)
rawMapping, _ := a.Credentials["model_mapping"].(map[string]any)
@@ -598,6 +648,77 @@ func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string,
return requestedModel, false
}
// GetOpenAICompactMode returns the compact routing mode for an OpenAI account.
// Missing or invalid values fall back to "auto".
func (a *Account) GetOpenAICompactMode() string {
if a == nil || !a.IsOpenAI() || a.Extra == nil {
return OpenAICompactModeAuto
}
mode, _ := a.Extra["openai_compact_mode"].(string)
return normalizeOpenAICompactMode(mode)
}
// OpenAICompactSupportKnown reports whether compact capability is known for this
// account and, when known, whether it is supported.
func (a *Account) OpenAICompactSupportKnown() (supported bool, known bool) {
if a == nil || !a.IsOpenAI() {
return false, false
}
switch a.GetOpenAICompactMode() {
case OpenAICompactModeForceOn:
return true, true
case OpenAICompactModeForceOff:
return false, true
}
if a.Extra == nil {
return false, false
}
supported, ok := a.Extra["openai_compact_supported"].(bool)
if !ok {
return false, false
}
return supported, true
}
// AllowsOpenAICompact reports whether the account may be considered for compact
// requests. Unknown capability remains allowed to avoid breaking older accounts
// before an explicit probe has been run.
func (a *Account) AllowsOpenAICompact() bool {
if a == nil || !a.IsOpenAI() {
return false
}
supported, known := a.OpenAICompactSupportKnown()
if !known {
return true
}
return supported
}
// GetCompactModelMapping returns compact-only model remapping configuration.
// This mapping is intended for /responses/compact only and does not affect
// normal /responses traffic.
func (a *Account) GetCompactModelMapping() map[string]string {
if a == nil || a.Credentials == nil {
return nil
}
return stringMappingFromRaw(a.Credentials["compact_model_mapping"])
}
// ResolveCompactMappedModel resolves compact-only model remapping and reports
// whether a compact-specific mapping rule matched.
func (a *Account) ResolveCompactMappedModel(requestedModel string) (mappedModel string, matched bool) {
mapping := a.GetCompactModelMapping()
if len(mapping) == 0 {
return requestedModel, false
}
if mappedModel, matched := resolveRequestedModelInMapping(mapping, requestedModel); matched {
return mappedModel, true
}
return requestedModel, false
}
func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeAPIKey {
return ""

View File

@@ -0,0 +1,369 @@
package service
import "testing"
func TestAccountGetOpenAICompactMode(t *testing.T) {
tests := []struct {
name string
account *Account
want string
}{
{
name: "nil account defaults to auto",
want: OpenAICompactModeAuto,
},
{
name: "non openai account defaults to auto",
account: &Account{
Platform: PlatformAnthropic,
Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOn},
},
want: OpenAICompactModeAuto,
},
{
name: "missing extra defaults to auto",
account: &Account{
Platform: PlatformOpenAI,
},
want: OpenAICompactModeAuto,
},
{
name: "invalid mode falls back to auto",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_mode": " invalid "},
},
want: OpenAICompactModeAuto,
},
{
name: "force on is normalized",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_mode": " FORCE_ON "},
},
want: OpenAICompactModeForceOn,
},
{
name: "force off is normalized",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_mode": "force_off"},
},
want: OpenAICompactModeForceOff,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.account.GetOpenAICompactMode(); got != tt.want {
t.Fatalf("GetOpenAICompactMode() = %q, want %q", got, tt.want)
}
})
}
}
func TestAccountOpenAICompactSupportKnown(t *testing.T) {
tests := []struct {
name string
account *Account
wantSupported bool
wantKnown bool
}{
{
name: "nil account is unknown",
wantSupported: false,
wantKnown: false,
},
{
name: "non openai account is unknown",
account: &Account{
Platform: PlatformAnthropic,
Extra: map[string]any{"openai_compact_supported": true},
},
wantSupported: false,
wantKnown: false,
},
{
name: "force on overrides probe state",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{
"openai_compact_mode": OpenAICompactModeForceOn,
"openai_compact_supported": false,
},
},
wantSupported: true,
wantKnown: true,
},
{
name: "force off overrides probe state",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{
"openai_compact_mode": OpenAICompactModeForceOff,
"openai_compact_supported": true,
},
},
wantSupported: false,
wantKnown: true,
},
{
name: "auto true is known supported",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_supported": true},
},
wantSupported: true,
wantKnown: true,
},
{
name: "auto false is known unsupported",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_supported": false},
},
wantSupported: false,
wantKnown: true,
},
{
name: "auto without probe state remains unknown",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{},
},
wantSupported: false,
wantKnown: false,
},
{
name: "invalid probe field remains unknown",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_supported": "true"},
},
wantSupported: false,
wantKnown: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotSupported, gotKnown := tt.account.OpenAICompactSupportKnown()
if gotSupported != tt.wantSupported || gotKnown != tt.wantKnown {
t.Fatalf("OpenAICompactSupportKnown() = (%v, %v), want (%v, %v)", gotSupported, gotKnown, tt.wantSupported, tt.wantKnown)
}
})
}
}
func TestAccountAllowsOpenAICompact(t *testing.T) {
tests := []struct {
name string
account *Account
want bool
}{
{
name: "nil account does not allow compact",
want: false,
},
{
name: "non openai account does not allow compact",
account: &Account{
Platform: PlatformAnthropic,
},
want: false,
},
{
name: "unknown openai account remains allowed",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{},
},
want: true,
},
{
name: "supported openai account is allowed",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_supported": true},
},
want: true,
},
{
name: "unsupported openai account is rejected",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_supported": false},
},
want: false,
},
{
name: "force on is allowed",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOn},
},
want: true,
},
{
name: "force off is rejected",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOff},
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.account.AllowsOpenAICompact(); got != tt.want {
t.Fatalf("AllowsOpenAICompact() = %v, want %v", got, tt.want)
}
})
}
}
func TestAccountGetCompactModelMapping(t *testing.T) {
tests := []struct {
name string
account *Account
want map[string]string
}{
{
name: "nil account returns nil",
want: nil,
},
{
name: "missing credentials returns nil",
account: &Account{
Platform: PlatformOpenAI,
},
want: nil,
},
{
name: "map any is converted",
account: &Account{
Credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4-openai-compact",
"invalid": 1,
},
},
},
want: map[string]string{
"gpt-5.4": "gpt-5.4-openai-compact",
},
},
{
name: "map string string is copied",
account: &Account{
Credentials: map[string]any{
"compact_model_mapping": map[string]string{
"gpt-*": "compact-*",
},
},
},
want: map[string]string{
"gpt-*": "compact-*",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.account.GetCompactModelMapping()
if !equalStringMap(got, tt.want) {
t.Fatalf("GetCompactModelMapping() = %#v, want %#v", got, tt.want)
}
})
}
}
func TestAccountResolveCompactMappedModel(t *testing.T) {
tests := []struct {
name string
credentials map[string]any
requestedModel string
expectedModel string
expectedMatch bool
}{
{
name: "no compact mapping reports unmatched",
credentials: nil,
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: false,
},
{
name: "exact compact mapping matches",
credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4-openai-compact",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4-openai-compact",
expectedMatch: true,
},
{
name: "exact passthrough counts as match",
credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: true,
},
{
name: "longest wildcard wins",
credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-*": "fallback-compact",
"gpt-5.4*": "gpt-5.4-openai-compact",
"gpt-5.4-mini*": "gpt-5.4-mini-openai-compact",
},
},
requestedModel: "gpt-5.4-mini",
expectedModel: "gpt-5.4-mini-openai-compact",
expectedMatch: true,
},
{
name: "missing compact mapping reports unmatched",
credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-5.3": "gpt-5.3-openai-compact",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Credentials: tt.credentials,
}
gotModel, gotMatch := account.ResolveCompactMappedModel(tt.requestedModel)
if gotModel != tt.expectedModel || gotMatch != tt.expectedMatch {
t.Fatalf("ResolveCompactMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, gotModel, gotMatch, tt.expectedModel, tt.expectedMatch)
}
})
}
}
func equalStringMap(left, right map[string]string) bool {
if len(left) != len(right) {
return false
}
for key, want := range right {
if got, ok := left[key]; !ok || got != want {
return false
}
}
return true
}

View File

@@ -165,7 +165,8 @@ func createTestPayload(modelID string) (map[string]any, error) {
// TestAccountConnection tests an account's connection by sending a test request
// All account types use full Claude Code client characteristics, only auth header differs
// modelID is optional - if empty, defaults to claude.DefaultTestModel
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string) error {
// mode is optional - "compact" routes OpenAI accounts to the /responses/compact probe path
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string, mode string) error {
ctx := c.Request.Context()
// Get account
@@ -176,7 +177,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
// Route to platform-specific test method
if account.IsOpenAI() {
return s.testOpenAIAccountConnection(c, account, modelID, prompt)
return s.testOpenAIAccountConnection(c, account, modelID, prompt, normalizeAccountTestMode(mode))
}
if account.IsGemini() {
@@ -416,9 +417,10 @@ func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx co
}
// testOpenAIAccountConnection tests an OpenAI account's connection
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string, mode string) error {
ctx := c.Request.Context()
_ = prompt
mode = normalizeAccountTestMode(mode)
// Default to openai.DefaultTestModel for OpenAI testing
testModelID := modelID
@@ -426,14 +428,12 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
testModelID = openai.DefaultTestModel
}
// For API Key accounts with model mapping, map the model
if account.Type == "apikey" {
mapping := account.GetModelMapping()
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
testModelID = mappedModel
}
}
// Align test routing with gateway behavior: OpenAI accounts apply normal
// account model mapping, and compact mode applies compact-only mapping on top.
testModelID = account.GetMappedModel(testModelID)
if mode == AccountTestModeCompact {
testModelID = resolveOpenAICompactForwardModel(account, testModelID)
return s.testOpenAICompactConnection(c, account, testModelID)
}
// Route to image generation test if an image model is selected
@@ -538,6 +538,9 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode == http.StatusTooManyRequests {
s.reconcileOpenAI429State(ctx, account, resp.Header, body)
}
// 401 Unauthorized: 标记账号为永久错误
if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil {
errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body))
@@ -550,6 +553,154 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
return s.processOpenAIStream(c, resp.Body)
}
// testOpenAICompactConnection probes /responses/compact and persists the
// resulting capability state on the account.
func (s *AccountTestService) testOpenAICompactConnection(c *gin.Context, account *Account, testModelID string) error {
ctx := c.Request.Context()
authToken := ""
apiURL := ""
isOAuth := false
chatgptAccountID := ""
switch {
case account.IsOAuth():
isOAuth = true
authToken = account.GetOpenAIAccessToken()
if authToken == "" {
return s.sendErrorAndEnd(c, "No access token available")
}
apiURL = chatgptCodexAPIURL + "/compact"
chatgptAccountID = account.GetChatGPTAccountID()
case account.Type == AccountTypeAPIKey:
authToken = account.GetOpenAIApiKey()
if authToken == "" {
return s.sendErrorAndEnd(c, "No API key available")
}
baseURL := account.GetOpenAIBaseURL()
if baseURL == "" {
baseURL = "https://api.openai.com"
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
apiURL = appendOpenAIResponsesRequestPathSuffix(buildOpenAIResponsesURL(normalizedBaseURL), "/compact")
default:
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
payloadBytes, _ := json.Marshal(createOpenAICompactProbePayload(testModelID))
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
if err != nil {
return s.sendErrorAndEnd(c, "Failed to create request")
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("OpenAI-Beta", "responses=experimental")
req.Header.Set("Originator", "codex_cli_rs")
req.Header.Set("User-Agent", codexCLIUserAgent)
req.Header.Set("Version", codexCLIVersion)
probeSessionID := compactProbeSessionID(account.ID)
req.Header.Set("Session_ID", probeSessionID)
req.Header.Set("Conversation_ID", probeSessionID)
if isOAuth {
req.Host = "chatgpt.com"
if chatgptAccountID != "" {
req.Header.Set("chatgpt-account-id", chatgptAccountID)
}
}
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if err != nil {
if s.accountRepo != nil {
updates := buildOpenAICompactProbeExtraUpdates(nil, nil, err, time.Now())
_ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
mergeAccountExtra(account, updates)
}
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if s.accountRepo != nil {
updates := buildOpenAICompactProbeExtraUpdates(resp, body, nil, time.Now())
if codexUpdates, err := extractOpenAICodexProbeUpdates(resp); err == nil && len(codexUpdates) > 0 {
updates = mergeExtraUpdates(updates, codexUpdates)
}
if len(updates) > 0 {
_ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
mergeAccountExtra(account, updates)
}
// 探测如返回 429,主动同步限流状态,避免后续短时间内继续选中。
if resp.StatusCode == http.StatusTooManyRequests {
s.reconcileOpenAI429State(ctx, account, resp.Header, body)
}
}
if resp.StatusCode != http.StatusOK {
if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil {
errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body))
_ = s.accountRepo.SetError(ctx, account.ID, errMsg)
}
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
}
s.sendEvent(c, TestEvent{Type: "content", Text: "Compact probe succeeded"})
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
func (s *AccountTestService) reconcileOpenAI429State(ctx context.Context, account *Account, headers http.Header, body []byte) {
if s == nil || s.accountRepo == nil || account == nil {
return
}
var resetAt *time.Time
if calculated := calculateOpenAI429ResetTime(headers); calculated != nil {
resetAt = calculated
} else if unixTs := parseOpenAIRateLimitResetTime(body); unixTs != nil {
t := time.Unix(*unixTs, 0)
resetAt = &t
}
if resetAt == nil {
return
}
if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
return
}
now := time.Now()
account.RateLimitedAt = &now
account.RateLimitResetAt = resetAt
if account.Status == StatusError {
if err := s.accountRepo.ClearError(ctx, account.ID); err != nil {
return
}
account.Status = StatusActive
account.ErrorMessage = ""
}
}
// testGeminiAccountConnection tests a Gemini account's connection
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
ctx := c.Request.Context()
@@ -994,13 +1145,17 @@ func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader)
// processOpenAIStream processes the SSE stream from OpenAI Responses API
func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) error {
reader := bufio.NewReader(body)
seenCompleted := false
for {
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
if seenCompleted {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
return s.sendErrorAndEnd(c, "Stream ended before response.completed")
}
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
}
@@ -1012,8 +1167,11 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
jsonStr := sseDataPrefix.ReplaceAllString(line, "")
if jsonStr == "[DONE]" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
if seenCompleted {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
return s.sendErrorAndEnd(c, "Stream ended before response.completed")
}
var data map[string]any
@@ -1029,9 +1187,19 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
if delta, ok := data["delta"].(string); ok && delta != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: delta})
}
case "response.completed":
case "response.completed", "response.done":
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
case "response.failed":
errorMsg := "OpenAI response failed"
if responseData, ok := data["response"].(map[string]any); ok {
if errData, ok := responseData["error"].(map[string]any); ok {
if msg, ok := errData["message"].(string); ok && msg != "" {
errorMsg = msg
}
}
}
return s.sendErrorAndEnd(c, errorMsg)
case "error":
errorMsg := "Unknown error"
if errData, ok := data["error"].(map[string]any); ok {
@@ -1261,7 +1429,7 @@ func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID in
ginCtx, _ := gin.CreateTestContext(w)
ginCtx.Request = (&http.Request{}).WithContext(ctx)
testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "")
testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "", AccountTestModeDefault)
finishedAt := time.Now()
body := w.Body.String()

View File

@@ -0,0 +1,199 @@
package service
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestAccountTestService_TestAccountConnection_OpenAICompactOAuthSuccessPersistsSupport(t *testing.T) {
gin.SetMode(gin.TestMode)
updateCalls := make(chan map[string]any, 1)
account := Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
repo := &snapshotUpdateAccountRepo{
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
updateExtraCalls: updateCalls,
}
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-probe"}},
Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe","status":"completed"}`)),
}}
svc := &AccountTestService{
accountRepo: repo,
httpUpstream: upstream,
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", bytes.NewReader(nil))
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
require.NoError(t, err)
require.Equal(t, chatgptCodexAPIURL+"/compact", upstream.lastReq.URL.String())
require.Equal(t, "chatgpt.com", upstream.lastReq.Host)
require.Equal(t, "application/json", upstream.lastReq.Header.Get("Accept"))
require.Equal(t, codexCLIVersion, upstream.lastReq.Header.Get("Version"))
require.NotEmpty(t, upstream.lastReq.Header.Get("Session_Id"))
require.Equal(t, codexCLIUserAgent, upstream.lastReq.Header.Get("User-Agent"))
require.Equal(t, "chatgpt-acc", upstream.lastReq.Header.Get("chatgpt-account-id"))
require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String())
updates := <-updateCalls
require.Equal(t, true, updates["openai_compact_supported"])
require.Equal(t, http.StatusOK, updates["openai_compact_last_status"])
require.Contains(t, rec.Body.String(), `"type":"test_complete"`)
}
func TestAccountTestService_TestAccountConnection_OpenAICompactOAuth404MarksUnsupported(t *testing.T) {
gin.SetMode(gin.TestMode)
updateCalls := make(chan map[string]any, 1)
account := Account{
ID: 2,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
repo := &snapshotUpdateAccountRepo{
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
updateExtraCalls: updateCalls,
}
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusNotFound,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`404 page not found`)),
}}
svc := &AccountTestService{
accountRepo: repo,
httpUpstream: upstream,
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/2/test", bytes.NewReader(nil))
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
require.Error(t, err)
updates := <-updateCalls
require.Equal(t, false, updates["openai_compact_supported"])
require.Equal(t, http.StatusNotFound, updates["openai_compact_last_status"])
require.Contains(t, rec.Body.String(), `"type":"error"`)
}
func TestAccountTestService_TestAccountConnection_OpenAICompactAPIKeyUsesCompactPath(t *testing.T) {
gin.SetMode(gin.TestMode)
updateCalls := make(chan map[string]any, 1)
account := Account{
ID: 3,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": "https://example.com/v1",
"compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"},
},
}
repo := &snapshotUpdateAccountRepo{
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
updateExtraCalls: updateCalls,
}
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe_apikey","status":"completed"}`)),
}}
svc := &AccountTestService{
accountRepo: repo,
httpUpstream: upstream,
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/3/test", bytes.NewReader(nil))
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
require.NoError(t, err)
require.Equal(t, "https://example.com/v1/responses/compact", upstream.lastReq.URL.String())
require.Equal(t, "gpt-5.4-openai-compact", gjson.GetBytes(upstream.lastBody, "model").String())
updates := <-updateCalls
require.Equal(t, true, updates["openai_compact_supported"])
}
func TestAccountTestService_TestAccountConnection_OpenAICompactAPIKeyDefaultBaseURLUsesV1Path(t *testing.T) {
gin.SetMode(gin.TestMode)
updateCalls := make(chan map[string]any, 1)
account := Account{
ID: 4,
Name: "openai-apikey-default",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
},
}
repo := &snapshotUpdateAccountRepo{
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
updateExtraCalls: updateCalls,
}
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe_apikey_default","status":"completed"}`)),
}}
svc := &AccountTestService{
accountRepo: repo,
httpUpstream: upstream,
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/4/test", bytes.NewReader(nil))
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
require.NoError(t, err)
require.Equal(t, "https://api.openai.com/v1/responses/compact", upstream.lastReq.URL.String())
<-updateCalls
}

View File

@@ -61,9 +61,12 @@ func newTestContext() (*gin.Context, *httptest.ResponseRecorder) {
type openAIAccountTestRepo struct {
mockAccountRepoForGemini
updatedExtra map[string]any
rateLimitedID int64
rateLimitedAt *time.Time
updatedExtra map[string]any
rateLimitedID int64
rateLimitedAt *time.Time
clearedErrorID int64
setErrorID int64
setErrorMsg string
}
func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
@@ -77,6 +80,17 @@ func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, rese
return nil
}
func (r *openAIAccountTestRepo) ClearError(_ context.Context, id int64) error {
r.clearedErrorID = id
return nil
}
func (r *openAIAccountTestRepo) SetError(_ context.Context, id int64, errorMsg string) error {
r.setErrorID = id
r.setErrorMsg = errorMsg
return nil
}
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, recorder := newTestContext()
@@ -103,7 +117,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "")
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.NoError(t, err)
require.NotEmpty(t, repo.updatedExtra)
require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"])
@@ -111,11 +125,36 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
require.Contains(t, recorder.Body.String(), "test_complete")
}
func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing.T) {
func TestAccountTestService_OpenAIStreamEOFBeforeCompletedFails(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, recorder := newTestContext()
resp := newJSONResponse(http.StatusOK, "")
resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.output_text.delta","delta":"hi"}
`))
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 90,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.Contains(t, recorder.Body.String(), "response.completed")
require.NotContains(t, recorder.Body.String(), `"success":true`)
}
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimitState(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_at":1777283883}}`)
resp.Header.Set("x-codex-primary-used-percent", "100")
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
resp.Header.Set("x-codex-primary-window-minutes", "10080")
@@ -130,15 +169,132 @@ func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing
ID: 88,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusError,
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "")
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.NotEmpty(t, repo.updatedExtra)
require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"])
require.Equal(t, account.ID, repo.rateLimitedID)
require.NotNil(t, repo.rateLimitedAt)
require.Equal(t, account.ID, repo.clearedErrorID)
require.Equal(t, StatusActive, account.Status)
require.Empty(t, account.ErrorMessage)
require.NotNil(t, account.RateLimitResetAt)
}
func TestAccountTestService_OpenAI429BodyOnlyPersistsRateLimitAndClearsStaleError(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_at":"1777283883"}}`)
repo := &openAIAccountTestRepo{}
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
account := &Account{
ID: 77,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusError,
ErrorMessage: "Access forbidden (403): account may be suspended or lack permissions",
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.Equal(t, account.ID, repo.rateLimitedID)
require.NotNil(t, repo.rateLimitedAt)
require.Equal(t, account.ID, repo.clearedErrorID)
require.Equal(t, StatusActive, account.Status)
require.Empty(t, account.ErrorMessage)
require.NotNil(t, account.RateLimitResetAt)
require.Empty(t, repo.updatedExtra)
}
func TestAccountTestService_OpenAI429ActiveAccountDoesNotClearError(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_in_seconds":3600}}`)
repo := &openAIAccountTestRepo{}
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
account := &Account{
ID: 78,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.Equal(t, account.ID, repo.rateLimitedID)
require.NotNil(t, repo.rateLimitedAt)
require.Zero(t, repo.clearedErrorID)
require.Equal(t, StatusActive, account.Status)
require.NotNil(t, account.RateLimitResetAt)
}
func TestAccountTestService_OpenAI429WithoutResetSignalDoesNotMutateRuntimeState(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
repo := &openAIAccountTestRepo{}
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
account := &Account{
ID: 79,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusError,
ErrorMessage: "stale 403",
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.Zero(t, repo.rateLimitedID)
require.Nil(t, repo.rateLimitedAt)
require.Zero(t, repo.clearedErrorID)
require.Equal(t, StatusError, account.Status)
require.Equal(t, "stale 403", account.ErrorMessage)
require.Nil(t, account.RateLimitResetAt)
}
func TestAccountTestService_OpenAI401SetsPermanentErrorOnly(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusUnauthorized, `{"error":"bad token"}`)
repo := &openAIAccountTestRepo{}
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
account := &Account{
ID: 80,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.Equal(t, account.ID, repo.setErrorID)
require.Contains(t, repo.setErrorMsg, "Authentication failed (401)")
require.Zero(t, repo.rateLimitedID)
require.Zero(t, repo.clearedErrorID)
require.Nil(t, account.RateLimitResetAt)
}

View File

@@ -110,7 +110,7 @@ const (
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
windowStatsCacheTTL = 1 * time.Minute
openAIProbeCacheTTL = 10 * time.Minute
openAICodexProbeVersion = "0.104.0"
openAICodexProbeVersion = "0.125.0"
)
// UsageCache 封装账户使用量相关的缓存

View File

@@ -0,0 +1,490 @@
package service
import (
"context"
"errors"
"math"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
var (
ErrAffiliateProfileNotFound = infraerrors.NotFound("AFFILIATE_PROFILE_NOT_FOUND", "affiliate profile not found")
ErrAffiliateCodeInvalid = infraerrors.BadRequest("AFFILIATE_CODE_INVALID", "invalid affiliate code")
ErrAffiliateCodeTaken = infraerrors.Conflict("AFFILIATE_CODE_TAKEN", "affiliate code already in use")
ErrAffiliateAlreadyBound = infraerrors.Conflict("AFFILIATE_ALREADY_BOUND", "affiliate inviter already bound")
ErrAffiliateQuotaEmpty = infraerrors.BadRequest("AFFILIATE_QUOTA_EMPTY", "no affiliate quota available to transfer")
)
const (
affiliateInviteesLimit = 100
// AffiliateCodeMinLength / AffiliateCodeMaxLength bound both system-generated
// 12-char codes and admin-customized codes (e.g. "VIP2026").
AffiliateCodeMinLength = 4
AffiliateCodeMaxLength = 32
)
// affiliateCodeValidChar accepts uppercase letters, digits, underscore and dash.
// All input passes through strings.ToUpper before validation, so lowercase from
// users is normalized — admins may supply mixed case in their UI.
var affiliateCodeValidChar = func() [256]bool {
var tbl [256]bool
for c := byte('A'); c <= 'Z'; c++ {
tbl[c] = true
}
for c := byte('0'); c <= '9'; c++ {
tbl[c] = true
}
tbl['_'] = true
tbl['-'] = true
return tbl
}()
// isValidAffiliateCodeFormat validates code format for both binding (user input)
// and admin updates. Caller is expected to upper-case the input first.
func isValidAffiliateCodeFormat(code string) bool {
if len(code) < AffiliateCodeMinLength || len(code) > AffiliateCodeMaxLength {
return false
}
for i := 0; i < len(code); i++ {
if !affiliateCodeValidChar[code[i]] {
return false
}
}
return true
}
type AffiliateSummary struct {
UserID int64 `json:"user_id"`
AffCode string `json:"aff_code"`
AffCodeCustom bool `json:"aff_code_custom"`
AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent,omitempty"`
InviterID *int64 `json:"inviter_id,omitempty"`
AffCount int `json:"aff_count"`
AffQuota float64 `json:"aff_quota"`
AffFrozenQuota float64 `json:"aff_frozen_quota"`
AffHistoryQuota float64 `json:"aff_history_quota"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type AffiliateInvitee struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
Username string `json:"username"`
CreatedAt *time.Time `json:"created_at,omitempty"`
TotalRebate float64 `json:"total_rebate"`
}
type AffiliateDetail struct {
UserID int64 `json:"user_id"`
AffCode string `json:"aff_code"`
InviterID *int64 `json:"inviter_id,omitempty"`
AffCount int `json:"aff_count"`
AffQuota float64 `json:"aff_quota"`
AffFrozenQuota float64 `json:"aff_frozen_quota"`
AffHistoryQuota float64 `json:"aff_history_quota"`
// EffectiveRebateRatePercent 是当前用户作为邀请人时实际生效的返利比例:
// 优先用户自己的专属比例aff_rebate_rate_percent否则回退到全局比例。
// 用于在用户的 /affiliate 页面直观展示「分享后能拿到多少」。
EffectiveRebateRatePercent float64 `json:"effective_rebate_rate_percent"`
Invitees []AffiliateInvitee `json:"invitees"`
}
type AffiliateRepository interface {
EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error)
GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error)
BindInviter(ctx context.Context, userID, inviterID int64) (bool, error)
AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error)
GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error)
ThawFrozenQuota(ctx context.Context, userID int64) (float64, error)
TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error)
ListInvitees(ctx context.Context, inviterID int64, limit int) ([]AffiliateInvitee, error)
// 管理端:用户级专属配置
UpdateUserAffCode(ctx context.Context, userID int64, newCode string) error
ResetUserAffCode(ctx context.Context, userID int64) (string, error)
SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error
BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error
ListUsersWithCustomSettings(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error)
}
// AffiliateAdminFilter 列表筛选条件
type AffiliateAdminFilter struct {
Search string
Page int
PageSize int
}
// AffiliateAdminEntry 专属用户列表条目
type AffiliateAdminEntry struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
Username string `json:"username"`
AffCode string `json:"aff_code"`
AffCodeCustom bool `json:"aff_code_custom"`
AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent,omitempty"`
AffCount int `json:"aff_count"`
}
type AffiliateService struct {
repo AffiliateRepository
settingService *SettingService
authCacheInvalidator APIKeyAuthCacheInvalidator
billingCacheService *BillingCacheService
}
func NewAffiliateService(repo AffiliateRepository, settingService *SettingService, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCacheService *BillingCacheService) *AffiliateService {
return &AffiliateService{
repo: repo,
settingService: settingService,
authCacheInvalidator: authCacheInvalidator,
billingCacheService: billingCacheService,
}
}
// IsEnabled reports whether the affiliate (邀请返利) feature is turned on.
func (s *AffiliateService) IsEnabled(ctx context.Context) bool {
if s == nil || s.settingService == nil {
return AffiliateEnabledDefault
}
return s.settingService.IsAffiliateEnabled(ctx)
}
func (s *AffiliateService) EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error) {
if userID <= 0 {
return nil, infraerrors.BadRequest("INVALID_USER", "invalid user")
}
if s == nil || s.repo == nil {
return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
return s.repo.EnsureUserAffiliate(ctx, userID)
}
func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64) (*AffiliateDetail, error) {
// Lazy thaw: move any matured frozen quota to available before reading.
if s != nil && s.repo != nil {
// best-effort: thaw failure is non-fatal
_, _ = s.repo.ThawFrozenQuota(ctx, userID)
}
summary, err := s.EnsureUserAffiliate(ctx, userID)
if err != nil {
return nil, err
}
invitees, err := s.listInvitees(ctx, userID)
if err != nil {
return nil, err
}
return &AffiliateDetail{
UserID: summary.UserID,
AffCode: summary.AffCode,
InviterID: summary.InviterID,
AffCount: summary.AffCount,
AffQuota: summary.AffQuota,
AffFrozenQuota: summary.AffFrozenQuota,
AffHistoryQuota: summary.AffHistoryQuota,
EffectiveRebateRatePercent: s.resolveRebateRatePercent(ctx, summary),
Invitees: invitees,
}, nil
}
func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64, rawCode string) error {
code := strings.ToUpper(strings.TrimSpace(rawCode))
if code == "" {
return nil
}
if s == nil || s.repo == nil {
return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
// 总开关关闭时,注册阶段静默忽略 aff 参数(不报错,避免阻断注册流程)
if !s.IsEnabled(ctx) {
return nil
}
if !isValidAffiliateCodeFormat(code) {
return ErrAffiliateCodeInvalid
}
selfSummary, err := s.repo.EnsureUserAffiliate(ctx, userID)
if err != nil {
return err
}
if selfSummary.InviterID != nil {
return nil
}
inviterSummary, err := s.repo.GetAffiliateByCode(ctx, code)
if err != nil {
if errors.Is(err, ErrAffiliateProfileNotFound) {
return ErrAffiliateCodeInvalid
}
return err
}
if inviterSummary == nil || inviterSummary.UserID <= 0 || inviterSummary.UserID == userID {
return ErrAffiliateCodeInvalid
}
bound, err := s.repo.BindInviter(ctx, userID, inviterSummary.UserID)
if err != nil {
return err
}
if !bound {
return ErrAffiliateAlreadyBound
}
return nil
}
func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64) (float64, error) {
if s == nil || s.repo == nil {
return 0, nil
}
if inviteeUserID <= 0 || baseRechargeAmount <= 0 || math.IsNaN(baseRechargeAmount) || math.IsInf(baseRechargeAmount, 0) {
return 0, nil
}
// 总开关关闭时,新充值不再产生返利
if !s.IsEnabled(ctx) {
return 0, nil
}
inviteeSummary, err := s.repo.EnsureUserAffiliate(ctx, inviteeUserID)
if err != nil {
return 0, err
}
if inviteeSummary.InviterID == nil || *inviteeSummary.InviterID <= 0 {
return 0, nil
}
// 加载邀请人 profile优先使用专属比例覆盖全局
inviterSummary, err := s.repo.EnsureUserAffiliate(ctx, *inviteeSummary.InviterID)
if err != nil {
return 0, err
}
// 有效期检查:超过返利有效期后不再产生返利
if s.settingService != nil {
if durationDays := s.settingService.GetAffiliateRebateDurationDays(ctx); durationDays > 0 {
if time.Now().After(inviteeSummary.CreatedAt.AddDate(0, 0, durationDays)) {
return 0, nil
}
}
}
rebateRatePercent := s.resolveRebateRatePercent(ctx, inviterSummary)
rebate := roundTo(baseRechargeAmount*(rebateRatePercent/100), 8)
if rebate <= 0 {
return 0, nil
}
// 单人上限检查:精确截断到剩余额度
if s.settingService != nil {
if perInviteeCap := s.settingService.GetAffiliateRebatePerInviteeCap(ctx); perInviteeCap > 0 {
existing, err := s.repo.GetAccruedRebateFromInvitee(ctx, *inviteeSummary.InviterID, inviteeUserID)
if err != nil {
return 0, err
}
if existing >= perInviteeCap {
return 0, nil
}
if remaining := perInviteeCap - existing; rebate > remaining {
rebate = roundTo(remaining, 8)
}
}
}
var freezeHours int
if s.settingService != nil {
freezeHours = s.settingService.GetAffiliateRebateFreezeHours(ctx)
}
applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours)
if err != nil {
return 0, err
}
if !applied {
return 0, nil
}
return rebate, nil
}
// resolveRebateRatePercent returns the inviter's exclusive rate when set,
// otherwise the global setting value (clamped to [Min, Max]).
func (s *AffiliateService) resolveRebateRatePercent(ctx context.Context, inviter *AffiliateSummary) float64 {
if inviter != nil && inviter.AffRebateRatePercent != nil {
v := *inviter.AffRebateRatePercent
if math.IsNaN(v) || math.IsInf(v, 0) {
return s.globalRebateRatePercent(ctx)
}
return clampAffiliateRebateRate(v)
}
return s.globalRebateRatePercent(ctx)
}
// globalRebateRatePercent reads the system-wide rebate rate via SettingService,
// returning the documented default when SettingService is unavailable.
func (s *AffiliateService) globalRebateRatePercent(ctx context.Context) float64 {
if s == nil || s.settingService == nil {
return AffiliateRebateRateDefault
}
return s.settingService.GetAffiliateRebateRatePercent(ctx)
}
func (s *AffiliateService) TransferAffiliateQuota(ctx context.Context, userID int64) (float64, float64, error) {
if s == nil || s.repo == nil {
return 0, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
transferred, balance, err := s.repo.TransferQuotaToBalance(ctx, userID)
if err != nil {
return 0, 0, err
}
if transferred > 0 {
s.invalidateAffiliateCaches(ctx, userID)
}
return transferred, balance, nil
}
func (s *AffiliateService) listInvitees(ctx context.Context, inviterID int64) ([]AffiliateInvitee, error) {
if s == nil || s.repo == nil {
return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
invitees, err := s.repo.ListInvitees(ctx, inviterID, affiliateInviteesLimit)
if err != nil {
return nil, err
}
for i := range invitees {
invitees[i].Email = maskEmail(invitees[i].Email)
}
return invitees, nil
}
func roundTo(v float64, scale int) float64 {
factor := math.Pow10(scale)
return math.Round(v*factor) / factor
}
func maskEmail(email string) string {
email = strings.TrimSpace(email)
if email == "" {
return ""
}
at := strings.Index(email, "@")
if at <= 0 || at >= len(email)-1 {
return "***"
}
local := email[:at]
domain := email[at+1:]
dot := strings.LastIndex(domain, ".")
maskedLocal := maskSegment(local)
if dot <= 0 || dot >= len(domain)-1 {
return maskedLocal + "@" + maskSegment(domain)
}
domainName := domain[:dot]
tld := domain[dot:]
return maskedLocal + "@" + maskSegment(domainName) + tld
}
func maskSegment(s string) string {
r := []rune(s)
if len(r) == 0 {
return "***"
}
if len(r) == 1 {
return string(r[0]) + "***"
}
return string(r[0]) + "***"
}
func (s *AffiliateService) invalidateAffiliateCaches(ctx context.Context, userID int64) {
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
if s.billingCacheService != nil {
if err := s.billingCacheService.InvalidateUserBalance(ctx, userID); err != nil {
logger.LegacyPrintf("service.affiliate", "[Affiliate] Failed to invalidate billing cache for user %d: %v", userID, err)
}
}
}
// =========================
// Admin: 专属配置管理
// =========================
// validateExclusiveRate ensures a per-user override is finite and within
// [Min, Max]. nil is always valid (means "clear / fall back to global").
func validateExclusiveRate(ratePercent *float64) error {
if ratePercent == nil {
return nil
}
v := *ratePercent
if math.IsNaN(v) || math.IsInf(v, 0) {
return infraerrors.BadRequest("INVALID_RATE", "invalid rebate rate")
}
if v < AffiliateRebateRateMin || v > AffiliateRebateRateMax {
return infraerrors.BadRequest("INVALID_RATE", "rebate rate out of range")
}
return nil
}
// AdminUpdateUserAffCode 管理员改写用户的邀请码(专属邀请码)。
func (s *AffiliateService) AdminUpdateUserAffCode(ctx context.Context, userID int64, rawCode string) error {
if s == nil || s.repo == nil {
return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
code := strings.ToUpper(strings.TrimSpace(rawCode))
if !isValidAffiliateCodeFormat(code) {
return ErrAffiliateCodeInvalid
}
return s.repo.UpdateUserAffCode(ctx, userID, code)
}
// AdminResetUserAffCode 重置用户邀请码为系统随机码。
func (s *AffiliateService) AdminResetUserAffCode(ctx context.Context, userID int64) (string, error) {
if s == nil || s.repo == nil {
return "", infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
return s.repo.ResetUserAffCode(ctx, userID)
}
// AdminSetUserRebateRate 设置/清除用户专属返利比例。ratePercent==nil 表示清除。
func (s *AffiliateService) AdminSetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error {
if s == nil || s.repo == nil {
return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
if err := validateExclusiveRate(ratePercent); err != nil {
return err
}
return s.repo.SetUserRebateRate(ctx, userID, ratePercent)
}
// AdminBatchSetUserRebateRate 批量设置/清除用户专属返利比例。
func (s *AffiliateService) AdminBatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error {
if s == nil || s.repo == nil {
return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
if err := validateExclusiveRate(ratePercent); err != nil {
return err
}
cleaned := make([]int64, 0, len(userIDs))
for _, uid := range userIDs {
if uid > 0 {
cleaned = append(cleaned, uid)
}
}
if len(cleaned) == 0 {
return nil
}
return s.repo.BatchSetUserRebateRate(ctx, cleaned, ratePercent)
}
// AdminListCustomUsers 列出有专属配置的用户。
func (s *AffiliateService) AdminListCustomUsers(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error) {
if s == nil || s.repo == nil {
return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
return s.repo.ListUsersWithCustomSettings(ctx, filter)
}

View File

@@ -0,0 +1,131 @@
//go:build unit
package service
import (
"context"
"math"
"testing"
"github.com/stretchr/testify/require"
)
// TestResolveRebateRatePercent_PerUserOverride verifies that per-inviter
// AffRebateRatePercent overrides the global rate, that NULL falls back to the
// global rate, and that out-of-range exclusive rates are clamped silently.
//
// SettingService is left nil here so globalRebateRatePercent returns the
// documented default (AffiliateRebateRateDefault = 20%) — this exercises the
// fallback path without spinning up a settings stub.
func TestResolveRebateRatePercent_PerUserOverride(t *testing.T) {
t.Parallel()
svc := &AffiliateService{}
// nil exclusive rate → falls back to global default (20%)
require.InDelta(t, AffiliateRebateRateDefault,
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{}), 1e-9)
// exclusive rate set → overrides global
rate := 50.0
require.InDelta(t, 50.0,
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &rate}), 1e-9)
// exclusive rate 0 → returns 0 (no rebate, intentional)
zero := 0.0
require.InDelta(t, 0.0,
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &zero}), 1e-9)
// exclusive rate above max → clamped to Max
tooHigh := 250.0
require.InDelta(t, AffiliateRebateRateMax,
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &tooHigh}), 1e-9)
// exclusive rate below min → clamped to Min
tooLow := -5.0
require.InDelta(t, AffiliateRebateRateMin,
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &tooLow}), 1e-9)
}
// TestIsEnabled_NilSettingServiceReturnsDefault verifies that IsEnabled
// safely handles a nil settingService dependency by returning the default
// (off). This protects callers from nil-pointer crashes in misconfigured
// environments.
func TestIsEnabled_NilSettingServiceReturnsDefault(t *testing.T) {
t.Parallel()
svc := &AffiliateService{}
require.False(t, svc.IsEnabled(context.Background()))
require.Equal(t, AffiliateEnabledDefault, svc.IsEnabled(context.Background()))
}
// TestValidateExclusiveRate_BoundaryAndInvalid covers the validator used by
// admin-facing rate setters: nil is always valid (clear), in-range values
// are accepted, NaN/Inf and out-of-range values produce a typed BadRequest.
func TestValidateExclusiveRate_BoundaryAndInvalid(t *testing.T) {
t.Parallel()
require.NoError(t, validateExclusiveRate(nil))
for _, v := range []float64{0, 0.01, 50, 99.99, 100} {
v := v
require.NoError(t, validateExclusiveRate(&v), "value %v should be valid", v)
}
for _, v := range []float64{-0.01, 100.01, -100, 200} {
v := v
require.Error(t, validateExclusiveRate(&v), "value %v should be rejected", v)
}
nan := math.NaN()
require.Error(t, validateExclusiveRate(&nan))
posInf := math.Inf(1)
require.Error(t, validateExclusiveRate(&posInf))
negInf := math.Inf(-1)
require.Error(t, validateExclusiveRate(&negInf))
}
func TestMaskEmail(t *testing.T) {
t.Parallel()
require.Equal(t, "a***@g***.com", maskEmail("alice@gmail.com"))
require.Equal(t, "x***@d***", maskEmail("x@domain"))
require.Equal(t, "", maskEmail(""))
}
func TestIsValidAffiliateCodeFormat(t *testing.T) {
t.Parallel()
// 邀请码格式校验同时服务于:
// 1) 系统自动生成的 12 位随机码A-Z 去 I/O2-9 去 0/1
// 2) 管理员设置的自定义专属码(如 "VIP2026"、"NEW_USER-1"
// 因此校验放宽到 [A-Z0-9_-]{4,32}(要求调用方先 ToUpper
cases := []struct {
name string
in string
want bool
}{
{"valid canonical 12-char", "ABCDEFGHJKLM", true},
{"valid all digits 2-9", "234567892345", true},
{"valid mixed", "A2B3C4D5E6F7", true},
{"valid admin custom short", "VIP1", true},
{"valid admin custom with hyphen", "NEW-USER", true},
{"valid admin custom with underscore", "VIP_2026", true},
{"valid 32-char max", "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345", true},
// Previously-excluded chars (I/O/0/1) are now allowed since admins may use them.
{"letter I now allowed", "IBCDEFGHJKLM", true},
{"letter O now allowed", "OBCDEFGHJKLM", true},
{"digit 0 now allowed", "0BCDEFGHJKLM", true},
{"digit 1 now allowed", "1BCDEFGHJKLM", true},
{"too short (3 chars)", "ABC", false},
{"too long (33 chars)", "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456", false},
{"lowercase rejected (caller must ToUpper first)", "abcdefghjklm", false},
{"empty", "", false},
{"utf8 non-ascii", "ÄÄÄÄÄÄ", false}, // bytes out of charset
{"ascii punctuation .", "ABCDEFGHJK.M", false},
{"whitespace", "ABCDEFGHJK M", false},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tc.want, isValidAffiliateCodeFormat(tc.in))
})
}
}

View File

@@ -175,6 +175,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
user *User,
invitationCode string,
signupSource string,
affiliateCode string,
) error {
if s == nil || user == nil || user.ID <= 0 {
return ErrServiceUnavailable
@@ -194,6 +195,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
s.updateOAuthSignupSource(ctx, user.ID, signupSource)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
return nil
}

View File

@@ -137,6 +137,7 @@ func newOAuthEmailFlowAuthService(
nil,
nil,
nil,
nil,
)
}

View File

@@ -72,6 +72,7 @@ type AuthService struct {
turnstileService *TurnstileService
emailQueueService *EmailQueueService
promoService *PromoService
affiliateService *AffiliateService
defaultSubAssigner DefaultSubscriptionAssigner
}
@@ -98,6 +99,7 @@ func NewAuthService(
emailQueueService *EmailQueueService,
promoService *PromoService,
defaultSubAssigner DefaultSubscriptionAssigner,
affiliateService *AffiliateService,
) *AuthService {
return &AuthService{
entClient: entClient,
@@ -110,6 +112,7 @@ func NewAuthService(
turnstileService: turnstileService,
emailQueueService: emailQueueService,
promoService: promoService,
affiliateService: affiliateService,
defaultSubAssigner: defaultSubAssigner,
}
}
@@ -123,11 +126,11 @@ func (s *AuthService) EntClient() *dbent.Client {
// Register 用户注册返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "", "", "")
return s.RegisterWithVerification(ctx, email, password, "", "", "", "")
}
// RegisterWithVerification 用户注册支持邮件验证、优惠码和邀请码返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string) (string, *User, error) {
// RegisterWithVerification 用户注册(支持邮件验证、优惠码、邀请码和邀请返利返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode, affiliateCode string) (string, *User, error) {
// 检查是否开放注册默认关闭settingService 未配置时不允许注册)
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled
@@ -223,6 +226,17 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
}
s.postAuthUserBootstrap(ctx, user, "email", true)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
if s.affiliateService != nil {
if _, err := s.affiliateService.EnsureUserAffiliate(ctx, user.ID); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", user.ID, err)
}
if code := strings.TrimSpace(affiliateCode); code != "" {
if err := s.affiliateService.BindInviterByCode(ctx, user.ID, code); err != nil {
// 邀请返利码绑定失败不影响注册,只记录日志
logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", user.ID, err)
}
}
}
// 标记邀请码为已使用(如果使用了邀请码)
if invitationRedeemCode != nil {
@@ -549,7 +563,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。
// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode string) (*TokenPair, *User, error) {
// affiliateCode 用于邀请返利绑定,仅在新用户注册时使用。
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode, affiliateCode string) (*TokenPair, *User, error) {
// 检查 refreshTokenCache 是否可用
if s.refreshTokenCache == nil {
return nil, nil, errors.New("refresh token cache not configured")
@@ -652,6 +667,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
user = newUser
s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
}
} else {
if err := s.userRepo.Create(ctx, newUser); err != nil {
@@ -669,6 +685,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
user = newUser
s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
return nil, nil, ErrInvitationCodeInvalid
@@ -763,6 +780,22 @@ func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource
}
}
// bindOAuthAffiliate initializes the affiliate profile and binds the inviter
// for an OAuth-registered user. Failures are logged but never block registration.
func (s *AuthService) bindOAuthAffiliate(ctx context.Context, userID int64, affiliateCode string) {
if s.affiliateService == nil || userID <= 0 {
return
}
if _, err := s.affiliateService.EnsureUserAffiliate(ctx, userID); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", userID, err)
}
if code := strings.TrimSpace(affiliateCode); code != "" {
if err := s.affiliateService.BindInviterByCode(ctx, userID, code); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", userID, err)
}
}
}
func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
if user == nil || user.ID <= 0 {
return

View File

@@ -110,7 +110,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
emailSvc = service.NewEmailService(settingRepo, emailCache)
}
svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner)
svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner, nil)
return svc, repo, client
}
@@ -467,7 +467,7 @@ func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *t
},
}
emailService := service.NewEmailService(nil, cache)
svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil)
svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil, nil)
oldTokenPair, err := svc.GenerateTokenPair(ctx, &service.User{
ID: 41,

View File

@@ -137,7 +137,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
values: settings,
}, cfg)
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner)
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner, nil)
return svc, repo, client
}

View File

@@ -212,6 +212,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
nil,
nil, // promoService
nil, // defaultSubAssigner
nil, // affiliateService
)
}
@@ -243,7 +244,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
}, nil)
// 应返回服务不可用错误,而不是允许绕过验证
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "", "")
require.ErrorIs(t, err, ErrServiceUnavailable)
}
@@ -255,7 +256,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true",
}, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "", "")
require.ErrorIs(t, err, ErrEmailVerifyRequired)
}
@@ -269,7 +270,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true",
}, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "", "")
require.ErrorIs(t, err, ErrInvalidVerifyCode)
require.ErrorContains(t, err, "verify code")
}
@@ -621,7 +622,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefa
service.defaultSubAssigner = assigner
service.refreshTokenCache = &refreshTokenCacheStub{}
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "")
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "", "")
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.NotNil(t, user)
@@ -657,7 +658,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantA
service.defaultSubAssigner = assigner
service.refreshTokenCache = &refreshTokenCacheStub{}
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "")
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "", "")
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.Equal(t, existing.ID, user.ID)

View File

@@ -54,6 +54,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier
nil, // emailQueueService
nil, // promoService
nil, // defaultSubAssigner
nil, // affiliateService
)
}

View File

@@ -18,6 +18,19 @@ const (
RoleUser = domain.RoleUser
)
// Affiliate rebate settings
const (
AffiliateRebateRateDefault = 20.0
AffiliateRebateRateMin = 0.0
AffiliateRebateRateMax = 100.0
AffiliateEnabledDefault = false // 邀请返利总开关默认关闭
AffiliateRebateFreezeHoursDefault = 0 // 0 = 不冻结(向后兼容)
AffiliateRebateFreezeHoursMax = 720 // 最大 30 天
AffiliateRebateDurationDaysDefault = 0 // 0 = 永久有效
AffiliateRebateDurationDaysMax = 3650 // ~10 年
AffiliateRebatePerInviteeCapDefault = 0.0 // 0 = 无上限
)
// Platform constants
const (
PlatformAnthropic = domain.PlatformAnthropic
@@ -87,6 +100,11 @@ const (
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
SettingKeyFrontendURL = "frontend_url" // 前端基础URL用于生成邮件中的重置密码链接
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
SettingKeyAffiliateEnabled = "affiliate_enabled" // 邀请返利功能总开关
SettingKeyAffiliateRebateRate = "affiliate_rebate_rate" // 邀请返利比例百分比0-100
SettingKeyAffiliateRebateFreezeHours = "affiliate_rebate_freeze_hours" // 返利冻结期小时0=不冻结)
SettingKeyAffiliateRebateDurationDays = "affiliate_rebate_duration_days" // 返利有效期0=永久)
SettingKeyAffiliateRebatePerInviteeCap = "affiliate_rebate_per_invitee_cap" // 单人返利上限0=无上限)
// 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址

View File

@@ -762,8 +762,14 @@ func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock(
system := gjson.GetBytes(upstream.lastBody, "system")
require.True(t, system.Exists())
require.True(t, system.IsArray(), "system should be an array")
require.Equal(t, claudeCodeSystemPrompt, system.Array()[0].Get("text").String())
require.Equal(t, "ephemeral", system.Array()[0].Get("cache_control.type").String())
arr := system.Array()
require.Len(t, arr, 2, "system array should have billing block + cc prompt block")
require.Contains(t, arr[0].Get("text").String(), "x-anthropic-billing-header:")
require.Contains(t, arr[0].Get("text").String(), "cc_version=")
require.Equal(t, claudeCodeSystemPrompt, arr[1].Get("text").String())
require.Equal(t, "ephemeral", arr[1].Get("cache_control.type").String())
// 原始 system prompt 应迁移至 messages 中
messages := gjson.GetBytes(upstream.lastBody, "messages")

View File

@@ -0,0 +1,98 @@
package service
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"github.com/tidwall/gjson"
)
// fingerprintSalt 是计算 cc_version 后缀指纹的盐值。
//
// 来源:与 Parrot src/transform/cc_mimicry.py 的 FINGERPRINT_SALT 完全一致;
// 这是真实 Claude Code CLI 抓包推导出的常量,改动会导致 fp 与 CLI 不一致,
// 进一步触发 Anthropic 的第三方检测。
const fingerprintSalt = "59cf53e54c78"
// computeClaudeCodeFingerprint 复刻真实 Claude Code CLI 的 cc_version 指纹算法:
//
// 1. 取 messages 中第一条 role=user 的纯文本(首块 text
// 2. 取该文本的第 4、7、20 字符(不足以 '0' 补齐)
// 3. SHA256(SALT + chars + cc_version) 取 hex 前 3 字符
//
// 算法来自 Parrot src/transform/cc_mimicry.py:compute_fingerprint与官方 CLI 字节对齐。
// 任何偏差都会导致 cc_version=X.Y.Z.{fp} 在上游侧与真实 CLI 不一致。
func computeClaudeCodeFingerprint(body []byte, version string) string {
firstText := extractFirstUserText(body)
indices := []int{4, 7, 20}
chars := make([]byte, 0, 3)
for _, i := range indices {
if i < len(firstText) {
chars = append(chars, firstText[i])
} else {
chars = append(chars, '0')
}
}
sum := sha256.Sum256([]byte(fingerprintSalt + string(chars) + version))
return hex.EncodeToString(sum[:])[:3]
}
// extractFirstUserText 提取 messages 中第一条 user 消息的首段 text 内容。
// 兼容 string 和 []block 两种 content 格式。
func extractFirstUserText(body []byte) string {
messages := gjson.GetBytes(body, "messages")
if !messages.IsArray() {
return ""
}
first := ""
messages.ForEach(func(_, msg gjson.Result) bool {
if msg.Get("role").String() != "user" {
return true
}
content := msg.Get("content")
if content.Type == gjson.String {
first = content.String()
return false
}
if content.IsArray() {
content.ForEach(func(_, block gjson.Result) bool {
if block.Get("type").String() == "text" {
first = block.Get("text").String()
return false
}
return true
})
return false
}
return false
})
return first
}
// buildBillingAttributionBlockJSON 构造 system 数组的 billing attribution block。
//
// 形态严格对齐真实 Claude Code CLI
//
// {"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.92.{fp}; cc_entrypoint=cli; cch=00000;"}
//
// cch=00000 是签名占位符,由 signBillingHeaderCCH 在 buildUpstreamRequest 阶段
// 替换为基于完整 body 的 xxhash64 5 位十六进制摘要。
//
// 此 block 不带 cache_control与真实 CLI 一致cache breakpoint 由后续的
// Claude Code prompt block 承担)。
func buildBillingAttributionBlockJSON(body []byte, cliVersion string) ([]byte, error) {
if cliVersion == "" {
return nil, fmt.Errorf("cliVersion required")
}
fp := computeClaudeCodeFingerprint(body, cliVersion)
text := fmt.Sprintf(
"x-anthropic-billing-header: cc_version=%s.%s; cc_entrypoint=cli; cch=00000;",
cliVersion, fp,
)
return json.Marshal(map[string]string{
"type": "text",
"text": text,
})
}

View File

@@ -41,12 +41,13 @@ func TestNormalizeClaudeOAuthRequestBody_PreservesTopLevelFieldOrder(t *testing.
resultStr := string(result)
require.Equal(t, claude.NormalizeModelID("claude-3-5-sonnet-latest"), modelID)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"system"`, `"messages"`, `"omega"`, `"tools"`, `"metadata"`)
require.NotContains(t, resultStr, `"temperature"`)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"temperature"`, `"system"`, `"messages"`, `"omega"`, `"tools"`, `"metadata"`, `"max_tokens"`)
require.Contains(t, resultStr, `"temperature":0.2`)
require.NotContains(t, resultStr, `"tool_choice"`)
require.Contains(t, resultStr, `"system":"`+claudeCodeSystemPrompt+`"`)
require.Contains(t, resultStr, `"tools":[]`)
require.Contains(t, resultStr, `"metadata":{"user_id":"user-1"}`)
require.Contains(t, resultStr, `"max_tokens":128000`)
}
func TestInjectClaudeCodePrompt_PreservesFieldOrder(t *testing.T) {

View File

@@ -85,15 +85,16 @@ func (s *GatewayService) ForwardAsChatCompletions(
return nil, fmt.Errorf("marshal anthropic request: %w", err)
}
// 6. Apply Claude Code mimicry for OAuth accounts
isClaudeCode := false // CC API is never Claude Code
// 6. Apply Claude Code mimicry for OAuth accounts.
// Chat Completions 协议进来的请求永远不是 Claude Code 客户端,所以对 OAuth 账号
// 必须完整执行 /v1/messages 主路径上的伪装链路system 重写 + normalize + metadata 注入),
// 否则会被 Anthropic 判为第三方应用并扣 extra usage。
// 见 applyClaudeCodeOAuthMimicryToBody 的 godoc。
isClaudeCode := false
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
if !strings.Contains(strings.ToLower(mappedModel), "haiku") &&
!systemIncludesClaudeCodePrompt(anthropicReq.System) {
anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System)
}
anthropicBody = s.applyClaudeCodeOAuthMimicryToBody(ctx, c, account, anthropicBody, anthropicReq.System, mappedModel)
}
// 7. Enforce cache_control block limit
@@ -312,7 +313,14 @@ func (s *GatewayService) handleCCBufferedFromAnthropic(
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.JSON(http.StatusOK, ccResp)
// Marshal then bytes-replace so tool name mapping is reversed at byte level
// (parity with Parrot non-stream flow that marshals → restore → emit).
if respBytes, err := json.Marshal(ccResp); err == nil {
respBytes = reverseToolNamesIfPresent(c, respBytes)
c.Data(http.StatusOK, "application/json; charset=utf-8", respBytes)
} else {
c.JSON(http.StatusOK, ccResp)
}
return &ForwardResult{
RequestID: requestID,
@@ -383,7 +391,10 @@ func (s *GatewayService) handleCCStreamingFromAnthropic(
if err != nil {
return false
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
// Reverse tool name mapping: fake → real, per-chunk bytes.Replace.
// c 可能持有请求侧注入的 ToolNameRewrite无则仅做静态前缀还原。
out := string(reverseToolNamesIfPresent(c, []byte(sse)))
if _, err := fmt.Fprint(c.Writer, out); err != nil {
return true // client disconnected
}
return false

View File

@@ -82,15 +82,16 @@ func (s *GatewayService) ForwardAsResponses(
return nil, fmt.Errorf("marshal anthropic request: %w", err)
}
// 6. Apply Claude Code mimicry for OAuth accounts (non-Claude-Code endpoints)
isClaudeCode := false // Responses API is never Claude Code
// 6. Apply Claude Code mimicry for OAuth accounts (non-Claude-Code endpoints).
// OpenAI Responses 协议进来的请求永远不是 Claude Code 客户端,所以对 OAuth 账号
// 必须完整执行 /v1/messages 主路径上的伪装链路system 重写 + normalize + metadata 注入),
// 否则会被 Anthropic 判为第三方应用并扣 extra usage。
// 见 applyClaudeCodeOAuthMimicryToBody 的 godoc。
isClaudeCode := false
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
if !strings.Contains(strings.ToLower(mappedModel), "haiku") &&
!systemIncludesClaudeCodePrompt(anthropicReq.System) {
anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System)
}
anthropicBody = s.applyClaudeCodeOAuthMimicryToBody(ctx, c, account, anthropicBody, anthropicReq.System, mappedModel)
}
// 7. Enforce cache_control block limit
@@ -331,7 +332,12 @@ func (s *GatewayService) handleResponsesBufferedStreamingResponse(
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.JSON(http.StatusOK, responsesResp)
if respBytes, err := json.Marshal(responsesResp); err == nil {
respBytes = reverseToolNamesIfPresent(c, respBytes)
c.Data(http.StatusOK, "application/json; charset=utf-8", respBytes)
} else {
c.JSON(http.StatusOK, responsesResp)
}
return &ForwardResult{
RequestID: requestID,
@@ -419,7 +425,8 @@ func (s *GatewayService) handleResponsesStreamingResponse(
)
continue
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
out := string(reverseToolNamesIfPresent(c, []byte(sse)))
if _, err := fmt.Fprint(c.Writer, out); err != nil {
logger.L().Info("forward_as_responses stream: client disconnected",
zap.String("request_id", requestID),
)
@@ -439,7 +446,8 @@ func (s *GatewayService) handleResponsesStreamingResponse(
if err != nil {
continue
}
fmt.Fprint(c.Writer, sse) //nolint:errcheck
out := string(reverseToolNamesIfPresent(c, []byte(sse)))
fmt.Fprint(c.Writer, out) //nolint:errcheck
}
c.Writer.Flush()
}

View File

@@ -0,0 +1,141 @@
package service
import (
"fmt"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// stripMessageCacheControl 移除 $.messages[*].content[*].cache_control。
// 与 Parrot _strip_message_cache_control 语义一致。
//
// 为什么必须整体清空:客户端(特别是 Claude Code经常把 cache_control 打在
// "当前最后一条 user message" 上;下一轮对话 messages 追加后,原本的最后一条
// 变成中间某条cache_control 还挂着就导致"前缀签名变化",破坏缓存命中。
// 统一由代理重新打断点addMessageCacheBreakpoints才能在多轮间稳定。
func stripMessageCacheControl(body []byte) []byte {
messages := gjson.GetBytes(body, "messages")
if !messages.IsArray() {
return body
}
msgIdx := -1
messages.ForEach(func(_, msg gjson.Result) bool {
msgIdx++
content := msg.Get("content")
if !content.IsArray() {
return true
}
blockIdx := -1
content.ForEach(func(_, block gjson.Result) bool {
blockIdx++
if !block.Get("cache_control").Exists() {
return true
}
path := fmt.Sprintf("messages.%d.content.%d.cache_control", msgIdx, blockIdx)
if next, err := sjson.DeleteBytes(body, path); err == nil {
body = next
}
return true
})
return true
})
return body
}
// addMessageCacheBreakpoints 在 messages 上注入两个稳定的 cache 断点:
// 1. 最后一条 message
// 2. 当 messages 数量 ≥ 4 时,倒数第二个 role=user 的 message
//
// 与 Parrot add_cache_breakpoints 一致。两个断点 + system prompt block 的断点
// + tools[-1] 的断点共同构成最多 4 个断点Anthropic 上限)。
//
// cache_control ttl 策略:
// - 若目标 block 已有 cache_control.ttl → 不覆盖
// - 否则写入 {"type":"ephemeral","ttl": claude.DefaultCacheControlTTL}
//
// 调用前应先 stripMessageCacheControl 以保证幂等和稳定。
func addMessageCacheBreakpoints(body []byte) []byte {
messages := gjson.GetBytes(body, "messages")
if !messages.IsArray() {
return body
}
arr := messages.Array()
if len(arr) == 0 {
return body
}
body = injectCacheControlOnLastContentBlock(body, len(arr)-1, &arr[len(arr)-1])
if len(arr) >= 4 {
userCount := 0
for i := len(arr) - 1; i >= 0; i-- {
if arr[i].Get("role").String() != "user" {
continue
}
userCount++
if userCount == 2 {
body = injectCacheControlOnLastContentBlock(body, i, &arr[i])
break
}
}
}
return body
}
// injectCacheControlOnLastContentBlock 把 cache_control 断点打在 messages[idx]
// 的最后一个 content block 上。若 content 是 string先升级成单块 text 数组
// (对齐 Parrot _inject_cache_on_msg 的行为)。
//
// msg 是调用方已持有的 gjson.Result 快照,用于省一次 GetBytes。
func injectCacheControlOnLastContentBlock(body []byte, idx int, msg *gjson.Result) []byte {
content := msg.Get("content")
if content.Type == gjson.String {
text := content.String()
blockRaw := fmt.Sprintf(
`[{"type":"text","text":%s,"cache_control":{"type":"ephemeral","ttl":%q}}]`,
mustJSONString(text), claude.DefaultCacheControlTTL,
)
if next, err := sjson.SetRawBytes(body, fmt.Sprintf("messages.%d.content", idx), []byte(blockRaw)); err == nil {
body = next
}
return body
}
if !content.IsArray() {
return body
}
contentArr := content.Array()
if len(contentArr) == 0 {
return body
}
lastBlockIdx := len(contentArr) - 1
lastBlock := contentArr[lastBlockIdx]
if cc := lastBlock.Get("cache_control"); cc.Exists() && cc.Get("ttl").String() != "" {
return body
}
pathPrefix := fmt.Sprintf("messages.%d.content.%d.cache_control", idx, lastBlockIdx)
existingCC := lastBlock.Get("cache_control")
if existingCC.Exists() {
if next, err := sjson.SetBytes(body, pathPrefix+".ttl", claude.DefaultCacheControlTTL); err == nil {
body = next
}
return body
}
raw := fmt.Sprintf(`{"type":"ephemeral","ttl":%q}`, claude.DefaultCacheControlTTL)
if next, err := sjson.SetRawBytes(body, pathPrefix, []byte(raw)); err == nil {
body = next
}
return body
}
// mustJSONString 把一个 Go string 序列化为合法 JSON string含引号
// 用于 sjson.SetRawBytes 场景下手工拼 JSON。
func mustJSONString(s string) string {
return fmt.Sprintf("%q", s)
}

View File

@@ -9,6 +9,11 @@ import (
)
func TestIsClaudeCodeClient(t *testing.T) {
// 合法的 legacy 格式 metadata.user_id64位 hex + account uuid + session uuid
legacyUserID := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account_550e8400-e29b-41d4-a716-446655440000_session_123e4567-e89b-12d3-a456-426614174000"
// 合法的 JSON 格式 metadata.user_id2.1.78+ 版本)
jsonUserID := `{"device_id":"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2","account_uuid":"550e8400-e29b-41d4-a716-446655440000","session_id":"123e4567-e89b-12d3-a456-426614174000"}`
tests := []struct {
name string
userAgent string
@@ -16,15 +21,21 @@ func TestIsClaudeCodeClient(t *testing.T) {
want bool
}{
{
name: "Claude Code client",
name: "Claude Code client with legacy user_id",
userAgent: "claude-cli/1.0.62 (darwin; arm64)",
metadataUserID: "session_123e4567-e89b-12d3-a456-426614174000",
metadataUserID: legacyUserID,
want: true,
},
{
name: "Claude Code without version suffix",
userAgent: "claude-cli/2.0.0",
metadataUserID: "session_abc",
name: "Claude Code client with JSON user_id",
userAgent: "claude-cli/2.1.92 (external, cli)",
metadataUserID: jsonUserID,
want: true,
},
{
name: "Claude Code case insensitive UA",
userAgent: "Claude-CLI/2.0.0",
metadataUserID: legacyUserID,
want: true,
},
{
@@ -34,21 +45,33 @@ func TestIsClaudeCodeClient(t *testing.T) {
want: false,
},
{
name: "Different user agent",
name: "Claude CLI UA with invalid user_id format",
userAgent: "claude-cli/2.0.0",
metadataUserID: "fake-user-id-12345",
want: false,
},
{
name: "Different user agent with valid user_id",
userAgent: "curl/7.68.0",
metadataUserID: "user123",
metadataUserID: legacyUserID,
want: false,
},
{
name: "Empty user agent",
userAgent: "",
metadataUserID: "user123",
metadataUserID: legacyUserID,
want: false,
},
{
name: "Similar but not Claude CLI",
userAgent: "claude-api/1.0.0",
metadataUserID: "user123",
metadataUserID: legacyUserID,
want: false,
},
{
name: "Opencode spoofing UA with arbitrary user_id",
userAgent: "claude-cli/2.1.92",
metadataUserID: "session_abc",
want: false,
},
}
@@ -378,16 +401,27 @@ func TestRewriteSystemForNonClaudeCode(t *testing.T) {
err := json.Unmarshal(result, &parsed)
require.NoError(t, err)
// system 应为 array 格式: [{type: "text", text: "...", cache_control: {type: "ephemeral"}}]
// system 应为 array 格式,对齐真实 Claude Code CLI 的 2-block 形态:
// [0] billing attribution block (x-anthropic-billing-header: cc_version=...;)
// [1] Claude Code prompt block (带 cache_control)
systemArr, ok := parsed["system"].([]any)
require.True(t, ok, "system should be an array, got %T", parsed["system"])
require.Len(t, systemArr, 1, "system array should have exactly 1 block")
systemBlock, ok := systemArr[0].(map[string]any)
require.Len(t, systemArr, 2, "system array should have exactly 2 blocks (billing + cc prompt)")
billingBlock, ok := systemArr[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", billingBlock["type"])
require.Contains(t, billingBlock["text"], "x-anthropic-billing-header:")
require.Contains(t, billingBlock["text"], "cc_version=")
require.Contains(t, billingBlock["text"], "cc_entrypoint=cli")
require.Contains(t, billingBlock["text"], "cch=00000")
systemBlock, ok := systemArr[1].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", systemBlock["type"])
require.Equal(t, tt.wantSystemText, systemBlock["text"])
cc, ok := systemBlock["cache_control"].(map[string]any)
require.True(t, ok, "system block should have cache_control")
require.True(t, ok, "cc prompt block should have cache_control")
require.Equal(t, "ephemeral", cc["type"])
// 检查 messages

View File

@@ -119,7 +119,7 @@ func openAIStreamEventIsTerminal(data string) bool {
return true
}
switch gjson.Get(trimmed, "type").String() {
case "response.completed", "response.done", "response.failed":
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
return true
default:
return false
@@ -329,7 +329,7 @@ func isClaudeCodeCredentialScopeError(msg string) bool {
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var (
sseDataRe = regexp.MustCompile(`^data:\s*`)
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
claudeCliUserAgentRe = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// 支持多种变体标准版、Agent SDK 版、Explore Agent 版、Compact 版等
@@ -850,6 +850,7 @@ func (s *GatewayService) hashContent(content string) string {
type anthropicCacheControlPayload struct {
Type string `json:"type"`
TTL string `json:"ttl,omitempty"`
}
type anthropicSystemTextBlockPayload struct {
@@ -898,7 +899,10 @@ func marshalAnthropicSystemTextBlock(text string, includeCacheControl bool) ([]b
Text: text,
}
if includeCacheControl {
block.CacheControl = &anthropicCacheControlPayload{Type: "ephemeral"}
block.CacheControl = &anthropicCacheControlPayload{
Type: "ephemeral",
TTL: claude.DefaultCacheControlTTL,
}
}
return json.Marshal(block)
}
@@ -1074,19 +1078,52 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
}
}
if gjson.GetBytes(out, "temperature").Exists() {
if next, ok := deleteJSONPathBytes(out, "temperature"); ok {
// temperature真实 Claude Code CLI 总是发送 temperature默认 1客户端可覆盖
// 之前的实现直接 delete 会导致 payload 缺字段,与真实 CLI 字节级不一致。
// 策略:客户端传了什么就透传;没传则补默认 1。
if !gjson.GetBytes(out, "temperature").Exists() {
if next, ok := setJSONValueBytes(out, "temperature", 1); ok {
out = next
modified = true
}
}
if gjson.GetBytes(out, "tool_choice").Exists() {
if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok {
// max_tokens真实 CLI 的默认值是 128000。缺失时补齐以对齐指纹。
if !gjson.GetBytes(out, "max_tokens").Exists() {
if next, ok := setJSONValueBytes(out, "max_tokens", 128000); ok {
out = next
modified = true
}
}
// context_managementthinking.type 为 enabled/adaptive 时,真实 CLI 会自动
// 附带 {"edits":[{"type":"clear_thinking_20251015","keep":"all"}]}。
// 客户端显式传了就透传;否则按 CLI 行为补齐。
if !gjson.GetBytes(out, "context_management").Exists() {
thinkingType := gjson.GetBytes(out, "thinking.type").String()
if thinkingType == "enabled" || thinkingType == "adaptive" {
const cmDefault = `{"edits":[{"type":"clear_thinking_20251015","keep":"all"}]}`
if next, ok := setJSONRawBytes(out, "context_management", []byte(cmDefault)); ok {
out = next
modified = true
}
}
}
// tool_choice与 Parrot 对齐,不再无条件删除。
// - 客户端传了 {"type":"tool","name":"X"} → 保留结构name 由
// applyToolNameRewriteToBody 同步映射为假名
// - 其他形态auto/any/none原样透传
// 如果 body 里完全没有 tools空数组tool_choice 没意义时才删除
if !gjson.GetBytes(out, "tools").IsArray() || len(gjson.GetBytes(out, "tools").Array()) == 0 {
if gjson.GetBytes(out, "tool_choice").Exists() {
if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok {
out = next
modified = true
}
}
}
if !modified {
return body, modelID
}
@@ -1128,6 +1165,135 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion)
}
// applyClaudeCodeOAuthMimicryToBody 将"非 Claude Code 客户端 + Claude OAuth 账号"
// 路径上原本只在 /v1/messages 里做的完整伪装应用到任意 body 上。
//
// 这是 /v1/messages 主路径上 rewriteSystemForNonClaudeCode +
// normalizeClaudeOAuthRequestBody 流程的通用版,供 OpenAI 协议兼容层
// (ForwardAsChatCompletions / ForwardAsResponses) 复用。
//
// 未抽离之前OpenAI 协议兼容层仅做 injectClaudeCodePrompt前置追加
// 而仓内 /v1/messages 路径自己的注释明确说过"仅前置追加无法通过 Anthropic
// 第三方检测";那条注释就是本函数存在的根因。
//
// 参数:
// - ctx / c用于读取指纹和 gateway settingsc 可为 nil如 count_tokens
// - account必须是 OAuth 账号,且调用方已判断不是 Claude Code 客户端。
// - body已经 marshal 成 Anthropic /v1/messages 格式的请求体。
// - systemRawbody 中原始 system 字段(用于判断是否需要 rewrite
// - model最终会发给上游的模型 ID用于 haiku 旁路 + metadata 版本选择)。
//
// 返回:改写后的 body。即使中间任何一步失败也会退化成原 body不会 panic
func (s *GatewayService) applyClaudeCodeOAuthMimicryToBody(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
systemRaw any,
model string,
) []byte {
if account == nil || !account.IsOAuth() || len(body) == 0 {
return body
}
systemRewritten := false
if !strings.Contains(strings.ToLower(model), "haiku") {
body = rewriteSystemForNonClaudeCode(body, systemRaw)
systemRewritten = true
}
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten}
if s.identityService != nil && c != nil && c.Request != nil {
if fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header); err == nil && fp != nil {
mimicMPT := false
if s.settingService != nil {
_, mimicMPT, _ = s.settingService.GetGatewayForwardingSettings(ctx)
}
if !mimicMPT {
if uid := s.buildOAuthMetadataUserIDFromBody(ctx, account, fp, body); uid != "" {
normalizeOpts.injectMetadata = true
normalizeOpts.metadataUserID = uid
}
}
}
}
body, _ = normalizeClaudeOAuthRequestBody(body, model, normalizeOpts)
// Phase D+E+F: messages cache 策略 + 工具名混淆 + tools[-1] 断点
// 对齐 Parrot transform_request 里剩余的字段级改写。三步顺序有语义约束:
// 1) strip先清除客户端的 messages[*].cache_control多轮稳定性
// 2) breakpoints再注入 2 个断点(最后一条 + 倒数第二个 user turn
// 3) tool rewrite最后改 tools[*].name / tool_choice.name 并在 tools[-1]
// 上打断点mapping 存入 gin.Context 供响应侧 bytes.Replace 还原。
body = stripMessageCacheControl(body)
body = addMessageCacheBreakpoints(body)
if rw := buildToolNameRewriteFromBody(body); rw != nil {
body = applyToolNameRewriteToBody(body, rw)
if c != nil {
c.Set(toolNameRewriteKey, rw)
}
} else {
body = applyToolsLastCacheBreakpoint(body)
}
return body
}
// buildOAuthMetadataUserIDFromBody 是 buildOAuthMetadataUserID 的变体,
// 适用于调用方手上没有 ParsedRequest 的场景(如 OpenAI 协议兼容层)。
//
// 与 buildOAuthMetadataUserID 的唯一区别:
// - session hash 从 body 本体按同样规则重算,而不是读取 ParsedRequest 缓存值。
// - 如果 body 里已经存在 metadata.user_id则返回空由 ensureClaudeOAuthMetadataUserID
// 自行决定是否覆盖)。
func (s *GatewayService) buildOAuthMetadataUserIDFromBody(
ctx context.Context,
account *Account,
fp *Fingerprint,
body []byte,
) string {
_ = ctx
if account == nil {
return ""
}
if existing := gjson.GetBytes(body, "metadata.user_id").String(); existing != "" {
return ""
}
userID := strings.TrimSpace(account.GetClaudeUserID())
if userID == "" && fp != nil {
userID = fp.ClientID
}
if userID == "" {
userID = generateClientID()
}
sessionID := uuid.NewString()
if hash := hashBodyForSessionSeed(body); hash != "" {
sessionID = generateSessionUUID(fmt.Sprintf("%d::%s", account.ID, hash))
}
var uaVersion string
if fp != nil {
uaVersion = ExtractCLIVersion(fp.UserAgent)
}
accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid"))
return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion)
}
// hashBodyForSessionSeed 为 sessionID 提供一个稳定但仅对本次请求特征化的种子。
// 复用 SHA-256 + 截断,与 generateSessionUUID 的输入格式对齐。
func hashBodyForSessionSeed(body []byte) string {
if len(body) == 0 {
return ""
}
sum := sha256.Sum256(body)
return fmt.Sprintf("%x", sum[:16])
}
// GenerateSessionUUID creates a deterministic UUID4 from a seed string.
func GenerateSessionUUID(seed string) string {
return generateSessionUUID(seed)
@@ -3543,23 +3709,19 @@ func sleepWithContext(ctx context.Context, d time.Duration) error {
}
}
// isClaudeCodeClient 判断请求是否来自 Claude Code 客户端
// 简化判断User-Agent 匹配 + metadata.user_id 存在
// isClaudeCodeClient 判断请求是否来自真正的 Claude Code 客户端
// 判定条件:
// 1. User-Agent 匹配 claude-cli/X.Y.Z大小写不敏感
// 2. metadata.user_id 符合 Claude Code 格式legacy 或 JSON 格式)
//
// 只检查 metadata.user_id 非空不够严格第三方工具opencode 等)可能伪造 UA
// 并附带任意 metadata.user_id 字符串,从而绕过 mimicry。必须通过 ParseMetadataUserID
// 验证格式才能确认是真正的 Claude Code 客户端。
func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
if metadataUserID == "" {
if !claudeCliUserAgentRe.MatchString(userAgent) {
return false
}
return claudeCliUserAgentRe.MatchString(userAgent)
}
func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool {
if IsClaudeCodeClient(ctx) {
return true
}
if parsed == nil || c == nil {
return false
}
return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
return ParseMetadataUserID(metadataUserID) != nil
}
// normalizeSystemParam 将 json.RawMessage 类型的 system 参数转为标准 Go 类型string / []any / nil
@@ -3738,17 +3900,20 @@ func rewriteSystemForNonClaudeCode(body []byte, system any) []byte {
originalSystemText = strings.Join(parts, "\n\n")
}
// 2. system 替换为 Claude Code 标准提示词array 格式,与真实 Claude Code 一致)
// 真实 Claude Code 始终以 [{type: "text", text: "...", cache_control: {type: "ephemeral"}}] 发送 system。
// 使用 string 格式会被 Anthropic 检测为第三方应用。
claudeCodeSystemBlock := []map[string]any{
{
"type": "text",
"text": claudeCodeSystemPrompt,
"cache_control": map[string]string{"type": "ephemeral"},
},
// 2. 构造 system 数组,对齐真实 Claude Code CLI 的 2-block 形态:
// [0] billing attribution blockcc_version={cliVer}.{fp}; cc_entrypoint=cli; cch=00000;
// [1] "You are Claude Code..." prompt block带 cache_control 作为稳定缓存断点)
//
// billing block 的 cch=00000 是占位符,会被 buildUpstreamRequest 里的
// signBillingHeaderCCH 替换成 xxhash64 签名。缺失 billing block 的系统 payload
// 是 Anthropic 判定第三方的关键信号之一(真实 CLI 每个请求都带)。
billingBlock, billingErr := buildBillingAttributionBlockJSON(body, claude.CLICurrentVersion)
ccPromptBlock, ccErr := marshalAnthropicSystemTextBlock(claudeCodeSystemPrompt, true)
if billingErr != nil || ccErr != nil {
logger.LegacyPrintf("service.gateway", "Warning: failed to build system blocks (billing=%v, cc=%v)", billingErr, ccErr)
return body
}
out, ok := setJSONValueBytes(body, "system", claudeCodeSystemBlock)
out, ok := setJSONRawBytes(body, "system", buildJSONArrayRaw([][]byte{billingBlock, ccPromptBlock}))
if !ok {
logger.LegacyPrintf("service.gateway", "Warning: failed to set Claude Code system prompt")
return body
@@ -3985,15 +4150,24 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
})
}
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
// Claude Code 客户端判定UA 匹配 claude-cli/* 且携带 metadata.user_id。
// 真正的 Claude Code 客户端自带完整的 system prompt、cache_control 断点和 header
// 不需要代理做任何 body 级别的 mimicry强行替换反而会破坏客户端的缓存策略
// (长 system prompt 被替换为 ~45 tokens 的短 prompt低于 Anthropic 1024 token
// 最低缓存门槛,导致系统级缓存失效)。
//
// 对于非 Claude Code 的第三方客户端opencode 等),仍然走完整 mimicry。
isClaudeCode := IsClaudeCodeClient(ctx) || isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
// 非 Claude Code 客户端:将 system 替换为 Claude Code 标识,原始 system 迁移至 messages
// 条件1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
// 与 Parrot 对齐OAuth 账号无条件重写 system即使客户端已发了 Claude Code
// 风格的 system prompt。原因第三方工具opencode 等)会发 "You are Claude
// Code..." system prompt 但缺少 billing attribution block导致 Anthropic
// 检测到"有 CC prompt 但无 billing block"的不一致而判为 third-party。
// Parrot 的 transform_request 从不检查客户端 system 内容,直接覆盖。
systemRewritten := false
if !strings.Contains(strings.ToLower(reqModel), "haiku") &&
!systemIncludesClaudeCodePrompt(parsed.System) {
if !strings.Contains(strings.ToLower(reqModel), "haiku") {
body = rewriteSystemForNonClaudeCode(body, parsed.System)
systemRewritten = true
}
@@ -4017,6 +4191,18 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
// D/E/F: messages cache 策略 + 工具名混淆 + tools[-1] 断点
// 与 forward_as_chat_completions / forward_as_responses 路径对齐,
// 保证原生 /v1/messages 路径也经过完整的 Parrot 字段级改写。
body = stripMessageCacheControl(body)
body = addMessageCacheBreakpoints(body)
if rw := buildToolNameRewriteFromBody(body); rw != nil {
body = applyToolNameRewriteToBody(body, rw)
c.Set(toolNameRewriteKey, rw)
} else {
body = applyToolsLastCacheBreakpoint(body)
}
}
// 强制执行 cache_control 块数量限制(最多 4 个)
@@ -4955,7 +5141,8 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
}
if !clientDisconnected {
if _, err := io.WriteString(w, line); err != nil {
restored := string(reverseToolNamesIfPresent(c, []byte(line)))
if _, err := io.WriteString(w, restored); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
} else if _, err := io.WriteString(w, "\n"); err != nil {
@@ -5125,6 +5312,7 @@ func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough(
if contentType == "" {
contentType = "application/json"
}
body = reverseToolNamesIfPresent(c, body)
c.Data(resp.StatusCode, contentType, body)
return usage, nil
}
@@ -5580,13 +5768,19 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
setHeaderRaw(req.Header, "x-api-key", token)
}
// 白名单透传headers(恢复真实 wire casing
for key, values := range clientHeaders {
lowerKey := strings.ToLower(key)
if allowedHeaders[lowerKey] {
wireKey := resolveWireCasing(key)
for _, v := range values {
addHeaderRaw(req.Header, wireKey, v)
// 白名单透传 headers
// OAuth mimicry 路径:跳过客户端 header 透传,与 Parrot 对齐。
// Parrot 的 build_upstream_headers 只发 9 个精确 header不透传任何客户端 header。
// 透传客户端 header 会引入不一致的 x-stainless-* / anthropic-beta / user-agent /
// x-claude-code-session-id 等值,和我们注入的伪装 header 冲突,被 Anthropic 判 third-party。
if tokenType != "oauth" || !mimicClaudeCode {
for key, values := range clientHeaders {
lowerKey := strings.ToLower(key)
if allowedHeaders[lowerKey] {
wireKey := resolveWireCasing(key)
for _, v := range values {
addHeaderRaw(req.Header, wireKey, v)
}
}
}
}
@@ -5627,7 +5821,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// Haiku models are exempt from third-party detection and don't need it.
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
if !strings.Contains(strings.ToLower(modelID), "haiku") {
requiredBetas = []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking}
requiredBetas = claude.FullClaudeCodeMimicryBetas()
}
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropSet))
} else {
@@ -6099,6 +6293,11 @@ func applyClaudeCodeMimicHeaders(req *http.Request, isStream bool) {
if isStream {
setHeaderRaw(req.Header, "x-stainless-helper-method", "stream")
}
// Real Claude CLI 每个请求都会生成一个新的 UUID 放在 x-client-request-id。
// 上游会以此作为会话/请求指纹的一部分,缺失或重复都可能触发第三方判定。
if getHeaderRaw(req.Header, "x-client-request-id") == "" {
setHeaderRaw(req.Header, "x-client-request-id", uuid.NewString())
}
}
func truncateForLog(b []byte, maxBytes int) string {
@@ -6864,7 +7063,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
for _, block := range outputBlocks {
if !clientDisconnected {
if _, werr := fmt.Fprint(w, block); werr != nil {
restored := reverseToolNamesIfPresent(c, []byte(block))
if _, werr := fmt.Fprint(w, string(restored)); werr != nil {
clientDisconnected = true
logger.LegacyPrintf("service.gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
break
@@ -7206,6 +7406,8 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
}
}
body = reverseToolNamesIfPresent(c, body)
// 写入响应
c.Data(resp.StatusCode, contentType, body)
@@ -8194,12 +8396,20 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
// Pre-filter: strip empty text blocks to prevent upstream 400.
body = StripEmptyTextBlocks(body)
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
isClaudeCodeCT := IsClaudeCodeClient(ctx) || isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCodeCT
if shouldMimicClaudeCode {
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
body = stripMessageCacheControl(body)
body = addMessageCacheBreakpoints(body)
if rw := buildToolNameRewriteFromBody(body); rw != nil {
body = applyToolNameRewriteToBody(body, rw)
} else {
body = applyToolsLastCacheBreakpoint(body)
}
}
// Antigravity 账户不支持 count_tokens返回 404 让客户端 fallback 到本地估算。
@@ -8623,7 +8833,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
applyClaudeCodeMimicHeaders(req, false)
incomingBeta := getHeaderRaw(req.Header, "anthropic-beta")
requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting}
requiredBetas := append(claude.FullClaudeCodeMimicryBetas(), claude.BetaTokenCounting)
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, ctEffectiveDropSet))
} else {
clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta")

View File

@@ -0,0 +1,313 @@
package service
import (
"fmt"
"hash/fnv"
"math/rand"
"sort"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// toolNameRewriteKey 是 gin.Context 上存 ToolNameRewrite 映射的 key。
// 请求阶段写入,响应阶段读取,用于 bytes 级逆向还原假名 → 真名。
const toolNameRewriteKey = "claude_tool_name_rewrite"
// staticToolNameRewrites 是"静态前缀映射",与 Parrot src/transform/cc_mimicry.py
// TOOL_NAME_REWRITES 完全一致。只有以这些前缀开头的工具会被重写。
var staticToolNameRewrites = map[string]string{
"sessions_": "cc_sess_",
"session_": "cc_ses_",
}
// fakeToolNamePrefixes 是"动态映射"的前缀池,与 Parrot _FAKE_PREFIXES 一致。
// 当 tools 数量 > dynamicToolMapThreshold 时随机选用其中前缀生成可读假名。
var fakeToolNamePrefixes = []string{
"analyze_", "compute_", "fetch_", "generate_", "lookup_", "modify_",
"process_", "query_", "render_", "resolve_", "sync_", "update_",
"validate_", "convert_", "extract_", "manage_", "monitor_", "parse_",
"review_", "search_", "transform_", "handle_", "invoke_", "notify_",
}
// dynamicToolMapThreshold 与 Parrot 一致tools 数量超过 5 才启用动态映射。
// 少量工具不需要混淆(一般是 Claude Code 自己的核心工具 bash/edit/read 等)。
const dynamicToolMapThreshold = 5
// ToolNameRewrite 是单次请求内的工具名混淆映射。
// - Forward: real → fake请求阶段在 body 上应用。
// - Reverse: fake → real响应阶段对每个 chunk 做 bytes.Replace 还原。
//
// ReverseOrdered 是按假名长度倒序的 (fake, real) 列表,用于防止短假名是长假名的
// 子串时 bytes.Replace 先被吃掉(对齐 Parrot _restore_tool_names_in_chunk 的
// `sorted(..., key=lambda x: len(x[1]), reverse=True)`)。
type ToolNameRewrite struct {
Forward map[string]string
Reverse map[string]string
ReverseOrdered [][2]string
}
// buildDynamicToolMap 构造 tools 的动态假名映射。
//
// 与 Parrot _build_dynamic_tool_map 语义等价:
// - tools 数量 ≤ dynamicToolMapThreshold 时返回 nil不做动态映射走静态 fallback
// - 同一组 tool_names 在同进程内映射稳定(保证 cache 命中)
//
// Parrot 用 `random.Random(hash(tuple(tool_names)))` 作 seed + shuffle 前缀池;
// Go 无法字节级复刻 Python hash但"稳定性"和"前缀池打散"两个不变量都保留:
// 用 fnv64a(strings.Join(names, "\x00")) 作 seed 喂 math/rand.New。
// 字节级不同不影响上游判定Anthropic 不会验证我们的随机种子算法)。
func buildDynamicToolMap(toolNames []string) map[string]string {
if len(toolNames) <= dynamicToolMapThreshold {
return nil
}
h := fnv.New64a()
for i, n := range toolNames {
if i > 0 {
_, _ = h.Write([]byte{0})
}
_, _ = h.Write([]byte(n))
}
rng := rand.New(rand.NewSource(int64(h.Sum64())))
available := make([]string, len(fakeToolNamePrefixes))
copy(available, fakeToolNamePrefixes)
rng.Shuffle(len(available), func(i, j int) { available[i], available[j] = available[j], available[i] })
mapping := make(map[string]string, len(toolNames))
for i, name := range toolNames {
prefix := available[i%len(available)]
headLen := 3
if len(name) < 3 {
headLen = len(name)
}
fake := fmt.Sprintf("%s%s%02d", prefix, name[:headLen], i)
mapping[name] = fake
}
return mapping
}
// sanitizeToolName 把真名转成假名。
// 与 Parrot _sanitize_tool_name 语义一致:动态映射优先,再走静态前缀映射。
func sanitizeToolName(name string, dynamic map[string]string) string {
if dynamic != nil {
if fake, ok := dynamic[name]; ok {
return fake
}
}
for prefix, replacement := range staticToolNameRewrites {
if strings.HasPrefix(name, prefix) {
return replacement + name[len(prefix):]
}
}
return name
}
// shouldMimicToolName 指示某个 tool 是否需要重命名。
// server tooltype != "" 且不是 "function" / "custom")是 Anthropic 协议语义的一部分,
// 比如 "web_search_20250305" / "computer_20250124";误改会导致上游拒绝。
func shouldMimicToolName(toolType string) bool {
if toolType == "" || toolType == "function" || toolType == "custom" {
return true
}
return false
}
// buildToolNameRewriteFromBody 扫描 body 的 tools[*].name构造 ToolNameRewrite
// 并返回它。若不需要混淆tools 数量不足 + 没有匹配静态前缀的工具)返回 nil。
//
// 注意:只扫描,不改 body。真正的 body 改写在 applyToolNameRewriteToBody。
func buildToolNameRewriteFromBody(body []byte) *ToolNameRewrite {
tools := gjson.GetBytes(body, "tools")
if !tools.IsArray() {
return nil
}
mimicableNames := make([]string, 0)
toolsArr := tools.Array()
for _, t := range toolsArr {
if !shouldMimicToolName(t.Get("type").String()) {
continue
}
name := t.Get("name").String()
if name == "" {
continue
}
mimicableNames = append(mimicableNames, name)
}
dynamic := buildDynamicToolMap(mimicableNames)
rw := &ToolNameRewrite{
Forward: make(map[string]string),
Reverse: make(map[string]string),
}
for _, name := range mimicableNames {
fake := sanitizeToolName(name, dynamic)
if fake == name {
continue
}
rw.Forward[name] = fake
rw.Reverse[fake] = name
}
if len(rw.Forward) == 0 {
return nil
}
rw.ReverseOrdered = make([][2]string, 0, len(rw.Reverse))
for fake, real := range rw.Reverse {
rw.ReverseOrdered = append(rw.ReverseOrdered, [2]string{fake, real})
}
sort.SliceStable(rw.ReverseOrdered, func(i, j int) bool {
return len(rw.ReverseOrdered[i][0]) > len(rw.ReverseOrdered[j][0])
})
return rw
}
// applyToolNameRewriteToBody 把已构造的 ToolNameRewrite 应用到 body 上:
// - 改写 $.tools[*].name仅对 shouldMimicToolName 通过的 tool
// - 在 $.tools[last].cache_control 上打 ephemeral 缓存断点Parrot 行为对齐,
// ttl 客户端已有则透传,否则默认 claude.DefaultCacheControlTTL
// - 改写 $.tool_choice.name仅当 $.tool_choice.type == "tool"
//
// 历史 $.messages[*].content[*].nametool_use不在请求侧改写——这与 Parrot 一致;
// 响应侧 bytes.Replace 会连带还原它们。
func applyToolNameRewriteToBody(body []byte, rw *ToolNameRewrite) []byte {
if rw == nil || len(rw.Forward) == 0 {
body = applyToolsLastCacheBreakpoint(body)
return body
}
tools := gjson.GetBytes(body, "tools")
if tools.IsArray() {
idx := -1
tools.ForEach(func(_, t gjson.Result) bool {
idx++
if !shouldMimicToolName(t.Get("type").String()) {
return true
}
name := t.Get("name").String()
if name == "" {
return true
}
fake, ok := rw.Forward[name]
if !ok {
return true
}
if next, err := sjson.SetBytes(body, fmt.Sprintf("tools.%d.name", idx), fake); err == nil {
body = next
}
return true
})
}
if tc := gjson.GetBytes(body, "tool_choice"); tc.Exists() && tc.Get("type").String() == "tool" {
name := tc.Get("name").String()
if fake, ok := rw.Forward[name]; ok {
if next, err := sjson.SetBytes(body, "tool_choice.name", fake); err == nil {
body = next
}
}
}
body = applyToolsLastCacheBreakpoint(body)
return body
}
// applyToolsLastCacheBreakpoint 在 tools 数组最后一个工具上注入 cache_control
// 断点,对齐 Parrot `tools[-1]["cache_control"] = {"type":"ephemeral","ttl":"1h"}`
// 行为,但 ttl 按本仓规则:
// - 客户端已为该 tool 显式设置 cache_control.ttl → 完全透传不覆盖
// - 否则注入 {"type":"ephemeral","ttl": claude.DefaultCacheControlTTL}
//
// 纯副作用函数tools 不存在或为空数组时 no-op。
func applyToolsLastCacheBreakpoint(body []byte) []byte {
tools := gjson.GetBytes(body, "tools")
if !tools.IsArray() {
return body
}
arr := tools.Array()
if len(arr) == 0 {
return body
}
lastIdx := len(arr) - 1
existingCC := arr[lastIdx].Get("cache_control")
if existingCC.Exists() && existingCC.Get("ttl").String() != "" {
return body
}
if existingCC.Exists() {
if next, err := sjson.SetBytes(body, fmt.Sprintf("tools.%d.cache_control.ttl", lastIdx), claude.DefaultCacheControlTTL); err == nil {
body = next
}
return body
}
raw := fmt.Sprintf(`{"type":"ephemeral","ttl":%q}`, claude.DefaultCacheControlTTL)
if next, err := sjson.SetRawBytes(body, fmt.Sprintf("tools.%d.cache_control", lastIdx), []byte(raw)); err == nil {
body = next
}
return body
}
// restoreToolNamesInBytes 对 bytes chunk 做逆向还原:假名 → 真名。
// 按 ReverseOrdered 的假名长度倒序逐个 bytes.Replace防止子串冲突
// (与 Parrot _restore_tool_names_in_chunk 的 sorted(..., reverse=True) 等价)。
// 再做静态前缀还原cc_sess_ → sessions_ / cc_ses_ → session_
//
// rw 可为 nilnil 时仍会做静态前缀还原。
func restoreToolNamesInBytes(data []byte, rw *ToolNameRewrite) []byte {
if rw != nil {
for _, pair := range rw.ReverseOrdered {
fake, real := pair[0], pair[1]
if fake == "" || fake == real {
continue
}
data = replaceAllBytes(data, fake, real)
}
}
for prefix, replacement := range staticToolNameRewrites {
data = replaceAllBytes(data, replacement, prefix)
}
return data
}
// replaceAllBytes 是 bytes.ReplaceAll 的便捷封装,避免每个调用点各自做 []byte 转换。
func replaceAllBytes(data []byte, from, to string) []byte {
if len(data) == 0 || from == to || !strings.Contains(string(data), from) {
return data
}
return []byte(strings.ReplaceAll(string(data), from, to))
}
// toolNameRewriteFromContext 从 gin.Context 取出请求阶段保存的工具名映射。
// 找不到c==nil 或 key 不存在或类型不对)时返回 nil调用方必须能处理 nil。
func toolNameRewriteFromContext(c interface {
Get(string) (any, bool)
}) *ToolNameRewrite {
if c == nil {
return nil
}
raw, ok := c.Get(toolNameRewriteKey)
if !ok || raw == nil {
return nil
}
rw, _ := raw.(*ToolNameRewrite)
return rw
}
// reverseToolNamesIfPresent 是响应侧 5 处注入点的统一封装:从 c 取出 mapping
// 并对 chunk 做 bytes 级假名→真名替换。c 没有 mapping 时仍会做静态前缀还原。
func reverseToolNamesIfPresent(c interface {
Get(string) (any, bool)
}, chunk []byte) []byte {
rw := toolNameRewriteFromContext(c)
if rw == nil && len(staticToolNameRewrites) == 0 {
return chunk
}
return restoreToolNamesInBytes(chunk, rw)
}

View File

@@ -0,0 +1,185 @@
package service
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestBuildDynamicToolMap_BelowThreshold(t *testing.T) {
// Parrot 行为tools 数量 ≤ 5 时不做动态映射。
names := []string{"bash", "edit", "read", "write", "search"}
require.Nil(t, buildDynamicToolMap(names))
}
func TestBuildDynamicToolMap_AboveThresholdIsStable(t *testing.T) {
// Parrot 不变量:同一组 tool_names 在同进程内映射稳定(保证 cache 命中)。
names := []string{"alpha", "beta", "gamma", "delta", "epsilon", "zeta"}
a := buildDynamicToolMap(names)
b := buildDynamicToolMap(names)
require.NotNil(t, a)
require.Equal(t, a, b, "same input tool_names must yield identical mapping")
require.Len(t, a, 6)
for _, name := range names {
require.Contains(t, a, name)
require.NotEqual(t, name, a[name])
}
}
func TestSanitizeToolName_StaticPrefix(t *testing.T) {
require.Equal(t, "cc_sess_list", sanitizeToolName("sessions_list", nil))
require.Equal(t, "cc_ses_get", sanitizeToolName("session_get", nil))
require.Equal(t, "bash", sanitizeToolName("bash", nil))
}
func TestSanitizeToolName_DynamicTakesPrecedence(t *testing.T) {
dyn := map[string]string{"sessions_list": "analyze_ses00"}
got := sanitizeToolName("sessions_list", dyn)
require.Equal(t, "analyze_ses00", got, "dynamic mapping wins over static prefix")
}
func TestRestoreToolNamesInBytes_LongestFirst(t *testing.T) {
// 当假名 "abc_12" 是另一个更长假名的子串(真实场景极少但算法必须防御)时,
// 长的必须先替换。本测试用显式构造的映射来验证排序不变量。
rw := &ToolNameRewrite{
Forward: map[string]string{"foo": "abc_12", "bar": "abc_12_ext"},
Reverse: map[string]string{"abc_12": "foo", "abc_12_ext": "bar"},
}
// 手工构造 ReverseOrdered长的在前
rw.ReverseOrdered = [][2]string{
{"abc_12_ext", "bar"},
{"abc_12", "foo"},
}
data := []byte(`{"tool":"abc_12_ext","other":"abc_12"}`)
restored := string(restoreToolNamesInBytes(data, rw))
require.Equal(t, `{"tool":"bar","other":"foo"}`, restored)
}
func TestRestoreToolNamesInBytes_StaticPrefixRollback(t *testing.T) {
data := []byte(`{"name":"sessions_list","id":"cc_ses_xyz"}`)
got := string(restoreToolNamesInBytes(data, nil))
require.Equal(t, `{"name":"sessions_list","id":"session_xyz"}`, got)
}
func TestApplyToolNameRewriteToBody_RenamesToolsAndToolChoice(t *testing.T) {
body := []byte(`{"tools":[{"name":"sessions_list","input_schema":{}},{"name":"session_get","input_schema":{}},{"name":"web_search","type":"web_search_20250305"}],"tool_choice":{"type":"tool","name":"sessions_list"}}`)
rw := buildToolNameRewriteFromBody(body)
require.NotNil(t, rw)
require.Contains(t, rw.Forward, "sessions_list")
require.Contains(t, rw.Forward, "session_get")
// web_search is a server tool, not rewritten
require.NotContains(t, rw.Forward, "web_search")
out := applyToolNameRewriteToBody(body, rw)
// tools[0].name and tools[1].name rewritten; tools[2].name untouched
require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "tools.0.name").String())
require.Equal(t, "cc_ses_get", gjson.GetBytes(out, "tools.1.name").String())
require.Equal(t, "web_search", gjson.GetBytes(out, "tools.2.name").String())
// tool_choice.name rewritten
require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "tool_choice.name").String())
require.Equal(t, "tool", gjson.GetBytes(out, "tool_choice.type").String())
}
func TestApplyToolsLastCacheBreakpoint_InjectsDefault(t *testing.T) {
body := []byte(`{"tools":[{"name":"a","input_schema":{}},{"name":"b","input_schema":{}}]}`)
out := applyToolsLastCacheBreakpoint(body)
require.Equal(t, "ephemeral", gjson.GetBytes(out, "tools.1.cache_control.type").String())
require.Equal(t, "5m", gjson.GetBytes(out, "tools.1.cache_control.ttl").String())
// First tool untouched
require.False(t, gjson.GetBytes(out, "tools.0.cache_control").Exists())
}
func TestApplyToolsLastCacheBreakpoint_PassesThroughClientTTL(t *testing.T) {
body := []byte(`{"tools":[{"name":"a","input_schema":{},"cache_control":{"type":"ephemeral","ttl":"1h"}}]}`)
out := applyToolsLastCacheBreakpoint(body)
// User-provided ttl must be preserved.
require.Equal(t, "1h", gjson.GetBytes(out, "tools.0.cache_control.ttl").String())
}
func TestStripMessageCacheControl(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral"}}]}]}`)
out := stripMessageCacheControl(body)
require.False(t, gjson.GetBytes(out, "messages.0.content.0.cache_control").Exists())
}
func TestAddMessageCacheBreakpoints_LastMessageOnly(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
out := addMessageCacheBreakpoints(body)
require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.0.content.0.cache_control.type").String())
require.Equal(t, "5m", gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").String())
}
func TestAddMessageCacheBreakpoints_SecondToLastUserTurn(t *testing.T) {
// Parrot 不变量messages ≥ 4 时才打第二个断点,且位置是"倒数第二个 user turn"。
body := []byte(`{"messages":[
{"role":"user","content":[{"type":"text","text":"q1"}]},
{"role":"assistant","content":[{"type":"text","text":"a1"}]},
{"role":"user","content":[{"type":"text","text":"q2"}]},
{"role":"assistant","content":[{"type":"text","text":"a2"}]}
]}`)
out := addMessageCacheBreakpoints(body)
// 最后一条 assistant 被打断点
require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.3.content.0.cache_control.type").String())
// 倒数第二个 user turn = index 0唯一另一个 user
require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.0.content.0.cache_control.type").String())
// 其他不打断点
require.False(t, gjson.GetBytes(out, "messages.1.content.0.cache_control").Exists())
require.False(t, gjson.GetBytes(out, "messages.2.content.0.cache_control").Exists())
}
func TestAddMessageCacheBreakpoints_StringContentPromoted(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
out := addMessageCacheBreakpoints(body)
// content 升级成数组
require.True(t, gjson.GetBytes(out, "messages.0.content").IsArray())
require.Equal(t, "text", gjson.GetBytes(out, "messages.0.content.0.type").String())
require.Equal(t, "hi", gjson.GetBytes(out, "messages.0.content.0.text").String())
require.Equal(t, "5m", gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").String())
}
func TestBuildToolNameRewriteFromBody_ReverseOrderedByLengthDesc(t *testing.T) {
// 超过阈值触发动态映射,验证 ReverseOrdered 按假名长度倒序排列
body := []byte(`{"tools":[
{"name":"t1","input_schema":{}},
{"name":"t2","input_schema":{}},
{"name":"t3","input_schema":{}},
{"name":"t4","input_schema":{}},
{"name":"t5","input_schema":{}},
{"name":"t6","input_schema":{}}
]}`)
rw := buildToolNameRewriteFromBody(body)
require.NotNil(t, rw)
require.NotEmpty(t, rw.ReverseOrdered)
for i := 1; i < len(rw.ReverseOrdered); i++ {
require.GreaterOrEqual(t, len(rw.ReverseOrdered[i-1][0]), len(rw.ReverseOrdered[i][0]),
"ReverseOrdered must be sorted by fake-name length descending")
}
}
func TestRestoreToolNamesInBytes_NoMapping_NoStaticMatch_IsNoop(t *testing.T) {
data := []byte("plain text without any tool names")
require.Equal(t, string(data), string(restoreToolNamesInBytes(data, nil)))
}
// Ensure the fake name format follows Parrot's "{prefix}{name[:3]}{i:02d}".
func TestBuildDynamicToolMap_FakeNameShape(t *testing.T) {
names := []string{"alphabet", "bravo", "charlie", "delta", "echo", "foxtrot"}
m := buildDynamicToolMap(names)
require.NotNil(t, m)
for _, name := range names {
fake, ok := m[name]
require.True(t, ok)
// fake = prefix + head3 + "%02d"
// ends with two decimal digits
require.Regexp(t, `^[a-z]+_[a-z0-9]{1,3}\d{2}$`, fake)
head := name
if len(head) > 3 {
head = head[:3]
}
require.True(t, strings.Contains(fake, head), "fake %q should contain head3 %q of %q", fake, head, name)
}
}

View File

@@ -26,7 +26,7 @@ var (
// 默认指纹值(当客户端未提供时使用)
var defaultFingerprint = Fingerprint{
UserAgent: "claude-cli/2.1.22 (external, cli)",
UserAgent: "claude-cli/2.1.92 (external, cli)",
StainlessLang: "js",
StainlessPackageVersion: "0.70.0",
StainlessOS: "Linux",

View File

@@ -3,7 +3,6 @@ package service
import (
"container/heap"
"context"
"errors"
"fmt"
"hash/fnv"
"math"
@@ -45,6 +44,7 @@ type OpenAIAccountScheduleRequest struct {
RequestedModel string
RequiredTransport OpenAIUpstreamTransport
RequiredImageCapability OpenAIImagesCapability
RequireCompact bool
ExcludedIDs map[int64]struct{}
}
@@ -258,12 +258,16 @@ func (s *defaultOpenAIAccountScheduler) Select(
previousResponseID,
req.RequestedModel,
req.ExcludedIDs,
req.RequireCompact,
)
if err != nil {
return nil, decision, err
}
if selection != nil && selection.Account != nil {
if !s.isAccountTransportCompatible(selection.Account, req.RequiredTransport) {
if selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
selection = nil
}
}
@@ -348,8 +352,8 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel)
if account == nil {
account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel, req.RequireCompact)
if account == nil || !s.isAccountTransportCompatible(account, req.RequiredTransport) {
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
@@ -590,7 +594,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
return nil, 0, 0, 0, err
}
if len(accounts) == 0 {
return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
return nil, 0, 0, 0, noAvailableOpenAISelectionError(req.RequestedModel, false)
}
// require_privacy_set: 获取分组信息
@@ -630,7 +634,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
})
}
if len(filtered) == 0 {
return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
return nil, 0, 0, 0, noAvailableOpenAISelectionError(req.RequestedModel, false)
}
loadMap := map[int64]*AccountLoadInfo{}
@@ -640,45 +644,14 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
}
}
minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority
maxWaiting := 1
loadRateSum := 0.0
loadRateSumSquares := 0.0
minTTFT, maxTTFT := 0.0, 0.0
hasTTFTSample := false
candidates := make([]openAIAccountCandidateScore, 0, len(filtered))
allCandidates := make([]openAIAccountCandidateScore, 0, len(filtered))
for _, account := range filtered {
loadInfo := loadMap[account.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: account.ID}
}
if account.Priority < minPriority {
minPriority = account.Priority
}
if account.Priority > maxPriority {
maxPriority = account.Priority
}
if loadInfo.WaitingCount > maxWaiting {
maxWaiting = loadInfo.WaitingCount
}
errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID)
if hasTTFT && ttft > 0 {
if !hasTTFTSample {
minTTFT, maxTTFT = ttft, ttft
hasTTFTSample = true
} else {
if ttft < minTTFT {
minTTFT = ttft
}
if ttft > maxTTFT {
maxTTFT = ttft
}
}
}
loadRate := float64(loadInfo.LoadRate)
loadRateSum += loadRate
loadRateSumSquares += loadRate * loadRate
candidates = append(candidates, openAIAccountCandidateScore{
allCandidates = append(allCandidates, openAIAccountCandidateScore{
account: account,
loadInfo: loadInfo,
errorRate: errorRate,
@@ -686,53 +659,183 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
hasTTFT: hasTTFT,
})
}
loadSkew := calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
weights := s.service.openAIWSSchedulerWeights()
for i := range candidates {
item := &candidates[i]
priorityFactor := 1.0
if maxPriority > minPriority {
priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
// Compact 模式下把明确不支持 compact 的账号拆出,仅在 schedulerSnapshot 启用
// 时作为最后兜底snapshot 可能已陈旧)。
candidates := allCandidates
staleSnapshotCompactRetry := make([]openAIAccountCandidateScore, 0, len(allCandidates))
if req.RequireCompact {
candidates = make([]openAIAccountCandidateScore, 0, len(allCandidates))
for _, candidate := range allCandidates {
if openAICompactSupportTier(candidate.account) == 0 {
staleSnapshotCompactRetry = append(staleSnapshotCompactRetry, candidate)
continue
}
candidates = append(candidates, candidate)
}
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
errorFactor := 1 - clamp01(item.errorRate)
ttftFactor := 0.5
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
if len(candidates) == 0 && len(staleSnapshotCompactRetry) == 0 {
return nil, 0, 0, 0, ErrNoAvailableCompactAccounts
}
item.score = weights.Priority*priorityFactor +
weights.Load*loadFactor +
weights.Queue*queueFactor +
weights.ErrorRate*errorFactor +
weights.TTFT*ttftFactor
}
topK := s.service.openAIWSLBTopK()
if topK > len(candidates) {
topK = len(candidates)
}
if topK <= 0 {
topK = 1
}
rankedCandidates := selectTopKOpenAICandidates(candidates, topK)
selectionOrder := buildOpenAIWeightedSelectionOrder(rankedCandidates, req)
candidateCount := len(candidates)
loadSkew := 0.0
if len(candidates) > 0 {
minPriority, maxPriority := candidates[0].account.Priority, candidates[0].account.Priority
maxWaiting := 1
loadRateSum := 0.0
loadRateSumSquares := 0.0
minTTFT, maxTTFT := 0.0, 0.0
hasTTFTSample := false
for _, candidate := range candidates {
if candidate.account.Priority < minPriority {
minPriority = candidate.account.Priority
}
if candidate.account.Priority > maxPriority {
maxPriority = candidate.account.Priority
}
if candidate.loadInfo.WaitingCount > maxWaiting {
maxWaiting = candidate.loadInfo.WaitingCount
}
if candidate.hasTTFT && candidate.ttft > 0 {
if !hasTTFTSample {
minTTFT, maxTTFT = candidate.ttft, candidate.ttft
hasTTFTSample = true
} else {
if candidate.ttft < minTTFT {
minTTFT = candidate.ttft
}
if candidate.ttft > maxTTFT {
maxTTFT = candidate.ttft
}
}
}
loadRate := float64(candidate.loadInfo.LoadRate)
loadRateSum += loadRate
loadRateSumSquares += loadRate * loadRate
}
loadSkew = calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
weights := s.service.openAIWSSchedulerWeights()
for i := range candidates {
item := &candidates[i]
priorityFactor := 1.0
if maxPriority > minPriority {
priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
}
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
errorFactor := 1 - clamp01(item.errorRate)
ttftFactor := 0.5
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
}
item.score = weights.Priority*priorityFactor +
weights.Load*loadFactor +
weights.Queue*queueFactor +
weights.ErrorRate*errorFactor +
weights.TTFT*ttftFactor
}
}
topK := 0
if len(candidates) > 0 {
topK = s.service.openAIWSLBTopK()
if topK > len(candidates) {
topK = len(candidates)
}
if topK <= 0 {
topK = 1
}
}
buildSelectionOrder := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
if len(pool) == 0 || topK <= 0 {
return nil
}
groupTopK := topK
if groupTopK > len(pool) {
groupTopK = len(pool)
}
ranked := selectTopKOpenAICandidates(pool, groupTopK)
return buildOpenAIWeightedSelectionOrder(ranked, req)
}
sortCompactRetryCandidates := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
if len(pool) == 0 {
return nil
}
ordered := append([]openAIAccountCandidateScore(nil), pool...)
sort.SliceStable(ordered, func(i, j int) bool {
a, b := ordered[i], ordered[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
if a.loadInfo.WaitingCount != b.loadInfo.WaitingCount {
return a.loadInfo.WaitingCount < b.loadInfo.WaitingCount
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
return ordered
}
selectionOrder := make([]openAIAccountCandidateScore, 0, len(allCandidates))
if req.RequireCompact {
supported := make([]openAIAccountCandidateScore, 0, len(candidates))
unknown := make([]openAIAccountCandidateScore, 0, len(candidates))
for _, candidate := range candidates {
switch openAICompactSupportTier(candidate.account) {
case 2:
supported = append(supported, candidate)
case 1:
unknown = append(unknown, candidate)
}
}
if len(supported) == 0 && len(unknown) == 0 && s.service.schedulerSnapshot == nil {
return nil, candidateCount, topK, loadSkew, ErrNoAvailableCompactAccounts
}
selectionOrder = append(selectionOrder, buildSelectionOrder(supported)...)
selectionOrder = append(selectionOrder, buildSelectionOrder(unknown)...)
if len(staleSnapshotCompactRetry) > 0 && s.service.schedulerSnapshot != nil {
selectionOrder = append(selectionOrder, sortCompactRetryCandidates(staleSnapshotCompactRetry)...)
}
} else {
selectionOrder = buildSelectionOrder(candidates)
}
if len(selectionOrder) == 0 {
return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, req.RequireCompact && len(allCandidates) > 0)
}
compactBlocked := false
for i := 0; i < len(selectionOrder); i++ {
candidate := selectionOrder[i]
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
continue
}
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel)
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
continue
}
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
compactBlocked = true
continue
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if acquireErr != nil {
return nil, len(candidates), topK, loadSkew, acquireErr
return nil, candidateCount, topK, loadSkew, acquireErr
}
if result != nil && result.Acquired {
if req.SessionHash != "" {
@@ -742,17 +845,25 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
Account: fresh,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, len(candidates), topK, loadSkew, nil
}, candidateCount, topK, loadSkew, nil
}
}
cfg := s.service.schedulingConfig()
// WaitPlan.MaxConcurrency 使用 Concurrency非 EffectiveLoadFactor因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
for _, candidate := range selectionOrder {
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
continue
}
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
continue
}
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
compactBlocked = true
continue
}
return &AccountSelectionResult{
Account: fresh,
WaitPlan: &AccountWaitPlan{
@@ -761,10 +872,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, len(candidates), topK, loadSkew, nil
}, candidateCount, topK, loadSkew, nil
}
return nil, len(candidates), topK, loadSkew, ErrNoAvailableAccounts
return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, compactBlocked)
}
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
@@ -905,8 +1016,9 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
requestedModel string,
excludedIDs map[int64]struct{},
requiredTransport OpenAIUpstreamTransport,
requireCompact bool,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "")
return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "", requireCompact)
}
func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages(
@@ -917,13 +1029,13 @@ func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages(
excludedIDs map[int64]struct{},
requiredCapability OpenAIImagesCapability,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
selection, decision, err := s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability)
selection, decision, err := s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability, false)
if err == nil && selection != nil && selection.Account != nil {
return selection, decision, nil
}
// 如果要求 native 能力(如指定了模型)但没有可用的 APIKey 账号,回退到 basicOAuth 账号)
if requiredCapability == OpenAIImagesCapabilityNative {
return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, OpenAIImagesCapabilityBasic)
return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, OpenAIImagesCapabilityBasic, false)
}
return selection, decision, err
}
@@ -937,6 +1049,7 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
excludedIDs map[int64]struct{},
requiredTransport OpenAIUpstreamTransport,
requiredImageCapability OpenAIImagesCapability,
requireCompact bool,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
decision := OpenAIAccountScheduleDecision{}
scheduler := s.getOpenAIAccountScheduler(ctx)
@@ -945,7 +1058,7 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
for {
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs)
selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact)
if err != nil {
return nil, decision, err
}
@@ -970,7 +1083,7 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
for {
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs)
selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact)
if err != nil {
return nil, decision, err
}
@@ -1008,6 +1121,7 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
RequestedModel: requestedModel,
RequiredTransport: requiredTransport,
RequiredImageCapability: requiredImageCapability,
RequireCompact: requireCompact,
ExcludedIDs: excludedIDs,
})
}

View File

@@ -0,0 +1,195 @@
package service
import (
"context"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// TestOpenAIGatewayService_SelectAccountWithScheduler_CompactPrefersSupportedOverUnknown
// 验证 compact 调度时显式支持 (tier=2) 优先于未探测 (tier=1)。
func TestOpenAIGatewayService_SelectAccountWithScheduler_CompactPrefersSupportedOverUnknown(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
ctx := context.Background()
groupID := int64(91001)
accounts := []Account{
{
ID: 71001,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Extra: map[string]any{}, // unknown
},
{
ID: 71002,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Extra: map[string]any{"openai_compact_supported": true}, // tier=2
},
}
cfg := &config.Config{}
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, _, err := svc.SelectAccountWithScheduler(
ctx,
&groupID,
"",
"",
"gpt-5.4",
nil,
OpenAIUpstreamTransportAny,
true,
)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(71002), selection.Account.ID, "compact-supported account should win over unknown")
}
// TestOpenAIGatewayService_SelectAccountWithScheduler_CompactRejectsExplicitlyUnsupported
// 验证 force_off / 已探测不支持 (tier=0) 的账号不会被 compact 请求选中。
func TestOpenAIGatewayService_SelectAccountWithScheduler_CompactRejectsExplicitlyUnsupported(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
ctx := context.Background()
groupID := int64(91002)
accounts := []Account{
{
ID: 71010,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOff},
},
{
ID: 71011,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Extra: map[string]any{"openai_compact_supported": false},
},
}
cfg := &config.Config{}
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, _, err := svc.SelectAccountWithScheduler(
ctx,
&groupID,
"",
"",
"gpt-5.4",
nil,
OpenAIUpstreamTransportAny,
true,
)
require.Error(t, err)
require.True(t, errors.Is(err, ErrNoAvailableCompactAccounts), "compact-only accounts should rejected explicitly unsupported and return compact error")
require.Nil(t, selection)
}
// TestOpenAIGatewayService_SelectAccountWithScheduler_CompactFallsBackToUnknown
// 验证当没有"已知支持"账号时compact 请求会回退到"未探测"账号。
func TestOpenAIGatewayService_SelectAccountWithScheduler_CompactFallsBackToUnknown(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
ctx := context.Background()
groupID := int64(91003)
accounts := []Account{
{
ID: 71020,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Extra: map[string]any{"openai_compact_supported": false}, // tier=0
},
{
ID: 71021,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Extra: map[string]any{}, // unknown -> tier=1
},
}
cfg := &config.Config{}
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, _, err := svc.SelectAccountWithScheduler(
ctx,
&groupID,
"",
"",
"gpt-5.4",
nil,
OpenAIUpstreamTransportAny,
true,
)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(71021), selection.Account.ID, "unknown account should be picked when no supported account available")
}
// TestOpenAICompactSupportTier 验证 tier 分类逻辑。
func TestOpenAICompactSupportTier(t *testing.T) {
tests := []struct {
name string
account *Account
want int
}{
{name: "nil", account: nil, want: 0},
{name: "non openai", account: &Account{Platform: PlatformAnthropic}, want: 0},
{name: "openai unknown", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{}}, want: 1},
{name: "openai supported", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{"openai_compact_supported": true}}, want: 2},
{name: "openai unsupported", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{"openai_compact_supported": false}}, want: 0},
{name: "force on", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOn}}, want: 2},
{name: "force off overrides probe true", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOff, "openai_compact_supported": true}}, want: 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := openAICompactSupportTier(tt.account); got != tt.want {
t.Fatalf("openAICompactSupportTier(...) = %d, want %d", got, tt.want)
}
})
}
}

View File

@@ -289,6 +289,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLega
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -343,6 +344,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_Require
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -384,6 +386,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_Require
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
false,
)
require.ErrorContains(t, err, "no available OpenAI accounts")
require.Nil(t, selection)
@@ -445,6 +448,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPrev
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -486,7 +490,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimite
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny, false)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
@@ -540,7 +544,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeR
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny, false)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
@@ -616,6 +620,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -662,6 +667,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testin
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -740,6 +746,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -788,6 +795,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -857,6 +865,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -900,6 +909,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailabl
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
false,
)
require.Error(t, err)
require.Nil(t, selection)
@@ -976,6 +986,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -1014,7 +1025,7 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny, false)
require.NoError(t, err)
require.NotNil(t, selection)
svc.ReportOpenAIAccountScheduleResult(account.ID, true, intPtrForTest(120))
@@ -1218,6 +1229,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)

View File

@@ -54,6 +54,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapsh
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)

View File

@@ -1,6 +1,7 @@
package service
import (
"encoding/json"
"fmt"
"strings"
)
@@ -48,6 +49,8 @@ type codexTransformResult struct {
const (
codexImageGenerationBridgeMarker = "<sub2api-codex-image-generation>"
codexImageGenerationBridgeText = codexImageGenerationBridgeMarker + "\nWhen the user asks for raster image generation or editing, use the OpenAI Responses native `image_generation` tool attached to this request. The local Codex client may not expose an `image_gen` namespace, but that does not mean image generation is unavailable. Do not ask the user to switch to CLI fallback solely because `image_gen` is absent.\n</sub2api-codex-image-generation>"
codexSparkImageUnsupportedMarker = "<sub2api-codex-spark-image-unsupported>"
codexSparkImageUnsupportedText = codexSparkImageUnsupportedMarker + "\nThe current model is gpt-5.3-codex-spark, which does not support image generation, image editing, image input, the `image_generation` tool, or Codex `image_gen`/`$imagegen` workflows. If the user asks for image generation or image editing, clearly explain this model limitation and ask them to switch to a non-Spark Codex model such as gpt-5.3-codex or gpt-5.4. Do not claim that the local environment merely lacks image_gen tooling, and do not suggest CLI fallback as the primary fix while the model remains Spark.\n</sub2api-codex-spark-image-unsupported>"
)
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult {
@@ -151,6 +154,9 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
if normalizeCodexTools(reqBody) {
result.Modified = true
}
if normalizeCodexToolChoice(reqBody) {
result.Modified = true
}
if v, ok := reqBody["prompt_cache_key"].(string); ok {
result.PromptCacheKey = strings.TrimSpace(v)
@@ -165,9 +171,20 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
if applyInstructions(reqBody, isCodexCLI) {
result.Modified = true
}
if isCodexSparkModel(normalizedModel) && applyCodexSparkImageUnsupportedInstructions(reqBody) {
result.Modified = true
}
// 续链场景保留 item_reference 与 id避免 call_id 上下文丢失。
if input, ok := reqBody["input"].([]any); ok {
if normalizedInput, modified := normalizeCodexToolRoleMessages(input); modified {
input = normalizedInput
result.Modified = true
}
if normalizedInput, modified := normalizeCodexMessageContentText(input); modified {
input = normalizedInput
result.Modified = true
}
input = filterCodexInput(input, needsToolContinuation)
reqBody["input"] = input
result.Modified = true
@@ -192,6 +209,183 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
return result
}
func normalizeCodexToolChoice(reqBody map[string]any) bool {
choice, ok := reqBody["tool_choice"]
if !ok || choice == nil {
return false
}
choiceMap, ok := choice.(map[string]any)
if !ok {
return false
}
choiceType := strings.TrimSpace(firstNonEmptyString(choiceMap["type"]))
if choiceType == "" || codexToolsContainType(reqBody["tools"], choiceType) {
return false
}
reqBody["tool_choice"] = "auto"
return true
}
func codexToolsContainType(rawTools any, toolType string) bool {
tools, ok := rawTools.([]any)
if !ok || strings.TrimSpace(toolType) == "" {
return false
}
for _, rawTool := range tools {
tool, ok := rawTool.(map[string]any)
if !ok {
continue
}
if strings.TrimSpace(firstNonEmptyString(tool["type"])) == toolType {
return true
}
}
return false
}
func normalizeCodexToolRoleMessages(input []any) ([]any, bool) {
if len(input) == 0 {
return input, false
}
modified := false
normalized := make([]any, 0, len(input))
for _, item := range input {
m, ok := item.(map[string]any)
if !ok {
normalized = append(normalized, item)
continue
}
role, _ := m["role"].(string)
if strings.TrimSpace(role) != "tool" {
normalized = append(normalized, item)
continue
}
callID := firstNonEmptyString(m["call_id"], m["tool_call_id"], m["id"])
callID = strings.TrimSpace(callID)
if callID == "" {
// Responses does not accept role:"tool". If no call id is available,
// preserve the text as a user message instead of sending invalid input.
fallback := make(map[string]any, len(m))
for key, value := range m {
fallback[key] = value
}
fallback["role"] = "user"
delete(fallback, "tool_call_id")
normalized = append(normalized, fallback)
modified = true
continue
}
output := extractTextFromContent(m["content"])
if output == "" {
if value, ok := m["output"].(string); ok {
output = value
}
}
if output == "" && m["content"] != nil {
if b, err := json.Marshal(m["content"]); err == nil {
output = string(b)
}
}
normalized = append(normalized, map[string]any{
"type": "function_call_output",
"call_id": callID,
"output": output,
})
modified = true
}
if !modified {
return input, false
}
return normalized, true
}
func normalizeCodexMessageContentText(input []any) ([]any, bool) {
if len(input) == 0 {
return input, false
}
modified := false
normalized := make([]any, 0, len(input))
for _, item := range input {
m, ok := item.(map[string]any)
if !ok || strings.TrimSpace(firstNonEmptyString(m["type"])) != "message" {
normalized = append(normalized, item)
continue
}
parts, ok := m["content"].([]any)
if !ok {
normalized = append(normalized, item)
continue
}
var newItem map[string]any
var newParts []any
ensureItemCopy := func() {
if newItem != nil {
return
}
newItem = make(map[string]any, len(m))
for key, value := range m {
newItem[key] = value
}
newParts = make([]any, len(parts))
copy(newParts, parts)
}
for i, rawPart := range parts {
part, ok := rawPart.(map[string]any)
if !ok {
continue
}
text, hasText := part["text"]
if !hasText {
continue
}
if _, ok := text.(string); ok {
continue
}
ensureItemCopy()
newPart := make(map[string]any, len(part))
for key, value := range part {
newPart[key] = value
}
newPart["text"] = stringifyCodexContentText(text)
newParts[i] = newPart
modified = true
}
if newItem != nil {
newItem["content"] = newParts
normalized = append(normalized, newItem)
continue
}
normalized = append(normalized, item)
}
if !modified {
return input, false
}
return normalized, true
}
func stringifyCodexContentText(value any) string {
switch v := value.(type) {
case string:
return v
case nil:
return ""
default:
if b, err := json.Marshal(v); err == nil {
return string(b)
}
return fmt.Sprint(v)
}
}
func normalizeCodexModel(model string) string {
model = strings.TrimSpace(model)
if model == "" {
@@ -244,6 +438,10 @@ func normalizeCodexModel(model string) string {
return "gpt-5.4"
}
func isCodexSparkModel(model string) bool {
return normalizeCodexModel(model) == "gpt-5.3-codex-spark"
}
func hasOpenAIImageGenerationTool(reqBody map[string]any) bool {
rawTools, ok := reqBody["tools"]
if !ok || rawTools == nil {
@@ -265,6 +463,40 @@ func hasOpenAIImageGenerationTool(reqBody map[string]any) bool {
return false
}
func hasOpenAIInputImage(reqBody map[string]any) bool {
if reqBody == nil {
return false
}
return hasOpenAIInputImageValue(reqBody["input"]) || hasOpenAIInputImageValue(reqBody["messages"])
}
func hasOpenAIInputImageValue(value any) bool {
switch v := value.(type) {
case []any:
for _, item := range v {
if hasOpenAIInputImageValue(item) {
return true
}
}
case map[string]any:
if strings.TrimSpace(firstNonEmptyString(v["type"])) == "input_image" {
return true
}
if _, ok := v["image_url"]; ok {
return true
}
return hasOpenAIInputImageValue(v["content"])
}
return false
}
func validateCodexSparkInput(reqBody map[string]any, model string) error {
if !isCodexSparkModel(model) || !hasOpenAIInputImage(reqBody) {
return nil
}
return fmt.Errorf("model %q does not support image input", strings.TrimSpace(model))
}
func normalizeOpenAIResponsesImageGenerationTools(reqBody map[string]any) bool {
rawTools, ok := reqBody["tools"]
if !ok || rawTools == nil {
@@ -309,6 +541,9 @@ func ensureOpenAIResponsesImageGenerationTool(reqBody map[string]any) bool {
if len(reqBody) == 0 {
return false
}
if isCodexSparkModel(firstNonEmptyString(reqBody["model"])) {
return false
}
tool := map[string]any{
"type": "image_generation",
@@ -344,6 +579,9 @@ func applyCodexImageGenerationBridgeInstructions(reqBody map[string]any) bool {
if len(reqBody) == 0 || !hasOpenAIImageGenerationTool(reqBody) {
return false
}
if isCodexSparkModel(firstNonEmptyString(reqBody["model"])) {
return false
}
existing, _ := reqBody["instructions"].(string)
if strings.Contains(existing, codexImageGenerationBridgeMarker) {
@@ -360,6 +598,23 @@ func applyCodexImageGenerationBridgeInstructions(reqBody map[string]any) bool {
return true
}
func applyCodexSparkImageUnsupportedInstructions(reqBody map[string]any) bool {
if len(reqBody) == 0 {
return false
}
existing, _ := reqBody["instructions"].(string)
if strings.Contains(existing, codexSparkImageUnsupportedMarker) {
return false
}
existing = strings.TrimRight(existing, " \t\r\n")
if strings.TrimSpace(existing) == "" {
reqBody["instructions"] = codexSparkImageUnsupportedText
return true
}
reqBody["instructions"] = existing + "\n\n" + codexSparkImageUnsupportedText
return true
}
func validateOpenAIResponsesImageModel(reqBody map[string]any, model string) error {
if !hasOpenAIImageGenerationTool(reqBody) {
return nil
@@ -658,12 +913,30 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
}
}
if !isCodexToolCallItemType(typ) {
ensureCopy()
delete(newItem, "call_id")
}
if codexInputItemRequiresName(typ) {
if strings.TrimSpace(firstNonEmptyString(m["name"])) == "" {
name := firstNonEmptyString(m["tool_name"])
if name == "" {
if function, ok := m["function"].(map[string]any); ok {
name = firstNonEmptyString(function["name"])
}
}
if name == "" {
name = "tool"
}
ensureCopy()
newItem["name"] = name
}
}
if !preserveReferences {
ensureCopy()
delete(newItem, "id")
if !isCodexToolCallItemType(typ) {
delete(newItem, "call_id")
}
}
filtered = append(filtered, newItem)
@@ -672,10 +945,30 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
}
func isCodexToolCallItemType(typ string) bool {
if typ == "" {
switch typ {
case "function_call",
"tool_call",
"local_shell_call",
"tool_search_call",
"custom_tool_call",
"mcp_tool_call",
"function_call_output",
"mcp_tool_call_output",
"custom_tool_call_output",
"tool_search_output":
return true
default:
return false
}
}
func codexInputItemRequiresName(typ string) bool {
switch strings.TrimSpace(typ) {
case "function_call", "custom_tool_call", "mcp_tool_call":
return true
default:
return false
}
return strings.HasSuffix(typ, "_call") || strings.HasSuffix(typ, "_call_output")
}
func normalizeCodexTools(reqBody map[string]any) bool {

View File

@@ -92,6 +92,235 @@ func TestApplyCodexOAuthTransform_ToolContinuationNormalizesToolReferenceIDsOnly
require.Equal(t, "fc1", second["call_id"])
}
func TestApplyCodexOAuthTransform_ToolSearchOutputPreservesCallID(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.2",
"input": []any{
map[string]any{"type": "tool_search_output", "call_id": "call_1", "output": "ok"},
},
}
applyCodexOAuthTransform(reqBody, false, false)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 1)
first, ok := input[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "tool_search_output", first["type"])
require.Equal(t, "fc1", first["call_id"])
}
func TestApplyCodexOAuthTransform_CustomAndMCPToolOutputsPreserveCallID(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.2",
"input": []any{
map[string]any{"type": "custom_tool_call_output", "call_id": "call_custom", "output": "ok"},
map[string]any{"type": "mcp_tool_call_output", "call_id": "call_mcp", "output": "ok"},
},
}
applyCodexOAuthTransform(reqBody, false, false)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 2)
first, ok := input[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "fccustom", first["call_id"])
second, ok := input[1].(map[string]any)
require.True(t, ok)
require.Equal(t, "fcmcp", second["call_id"])
}
func TestApplyCodexOAuthTransform_ImageAndWebSearchCallsDoNotGainCallID(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.2",
"input": []any{
map[string]any{"type": "image_generation_call", "id": "ig_123", "status": "completed"},
map[string]any{"type": "web_search_call", "call_id": "call_bad", "status": "completed"},
},
"tool_choice": "auto",
}
applyCodexOAuthTransform(reqBody, false, false)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 2)
first, ok := input[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "ig_123", first["id"])
_, hasCallID := first["call_id"]
require.False(t, hasCallID)
second, ok := input[1].(map[string]any)
require.True(t, ok)
_, hasCallID = second["call_id"]
require.False(t, hasCallID)
}
func TestApplyCodexOAuthTransform_ConvertsToolRoleMessageToFunctionCallOutput(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"input": []any{
map[string]any{
"type": "message",
"role": "tool",
"tool_call_id": "call_1",
"content": "ok",
},
},
}
applyCodexOAuthTransform(reqBody, true, false)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 1)
item, ok := input[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "function_call_output", item["type"])
require.Equal(t, "fc1", item["call_id"])
require.Equal(t, "ok", item["output"])
_, hasRole := item["role"]
require.False(t, hasRole)
}
func TestApplyCodexOAuthTransform_StringifiesNonStringMessageContentText(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"input": []any{
map[string]any{
"type": "message",
"role": "user",
"content": []any{
map[string]any{"type": "input_text", "text": []any{"a", "b"}},
},
},
},
}
applyCodexOAuthTransform(reqBody, true, false)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
item, ok := input[0].(map[string]any)
require.True(t, ok)
content, ok := item["content"].([]any)
require.True(t, ok)
part, ok := content[0].(map[string]any)
require.True(t, ok)
require.Equal(t, `["a","b"]`, part["text"])
}
func TestApplyCodexOAuthTransform_DowngradesUnknownToolChoice(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"tools": []any{
map[string]any{"type": "function", "name": "shell"},
},
"tool_choice": map[string]any{"type": "custom"},
}
applyCodexOAuthTransform(reqBody, true, false)
require.Equal(t, "auto", reqBody["tool_choice"])
}
func TestApplyCodexOAuthTransform_PreservesKnownToolChoice(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"tools": []any{
map[string]any{"type": "custom", "name": "shell"},
},
"tool_choice": map[string]any{"type": "custom"},
}
applyCodexOAuthTransform(reqBody, true, false)
choice, ok := reqBody["tool_choice"].(map[string]any)
require.True(t, ok)
require.Equal(t, "custom", choice["type"])
}
func TestApplyCodexOAuthTransform_AddsFallbackNameForFunctionCallInput(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"input": []any{
map[string]any{"type": "message", "role": "user", "content": "run tool"},
map[string]any{"type": "function_call", "call_id": "call_1", "arguments": "{}"},
},
}
applyCodexOAuthTransform(reqBody, true, false)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 2)
item, ok := input[1].(map[string]any)
require.True(t, ok)
require.Equal(t, "function_call", item["type"])
require.Equal(t, "tool", item["name"])
require.Equal(t, "fc1", item["call_id"])
}
func TestApplyCodexOAuthTransform_PreservesFunctionCallInputName(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"input": []any{
map[string]any{"type": "custom_tool_call", "call_id": "call_1", "name": "shell", "input": "pwd"},
},
}
applyCodexOAuthTransform(reqBody, true, false)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 1)
item, ok := input[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "shell", item["name"])
require.Equal(t, "fc1", item["call_id"])
}
func TestApplyCodexOAuthTransform_PreservesMCPToolCallIDAndName(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"input": []any{
map[string]any{
"type": "mcp_tool_call",
"call_id": "call_abc",
"name": "remote_tool",
"arguments": "{}",
},
},
}
applyCodexOAuthTransform(reqBody, true, false)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 1)
item, ok := input[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "mcp_tool_call", item["type"])
require.Equal(t, "remote_tool", item["name"])
require.Equal(t, "fcabc", item["call_id"])
}
func TestCodexInputItemRequiresNameTypesAllowCallID(t *testing.T) {
for _, typ := range []string{"function_call", "custom_tool_call", "mcp_tool_call"} {
require.True(t, codexInputItemRequiresName(typ), typ)
require.True(t, isCodexToolCallItemType(typ), typ)
}
}
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
// 续链场景:显式 store=false 不再强制为 true保持 false。
@@ -261,6 +490,17 @@ func TestEnsureOpenAIResponsesImageGenerationTool_NoTools(t *testing.T) {
require.Equal(t, "png", tool["output_format"])
}
func TestEnsureOpenAIResponsesImageGenerationTool_SkipsSpark(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.3-codex-spark",
"input": "draw a cat",
}
modified := ensureOpenAIResponsesImageGenerationTool(reqBody)
require.False(t, modified)
require.NotContains(t, reqBody, "tools")
}
func TestEnsureOpenAIResponsesImageGenerationTool_AppendsToExistingTools(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
@@ -306,6 +546,7 @@ func TestEnsureOpenAIResponsesImageGenerationTool_PreservesExistingImageTool(t *
func TestApplyCodexImageGenerationBridgeInstructions_AppendsBridgeOnce(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"instructions": "existing instructions",
"tools": []any{
map[string]any{"type": "image_generation", "output_format": "png"},
@@ -325,6 +566,20 @@ func TestApplyCodexImageGenerationBridgeInstructions_AppendsBridgeOnce(t *testin
require.False(t, modified)
}
func TestApplyCodexImageGenerationBridgeInstructions_SkipsSpark(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.3-codex-spark",
"instructions": "existing instructions",
"tools": []any{
map[string]any{"type": "image_generation", "output_format": "png"},
},
}
modified := applyCodexImageGenerationBridgeInstructions(reqBody)
require.False(t, modified)
require.Equal(t, "existing instructions", reqBody["instructions"])
}
func TestApplyCodexImageGenerationBridgeInstructions_SkipsWithoutImageTool(t *testing.T) {
reqBody := map[string]any{
"instructions": "existing instructions",
@@ -338,6 +593,91 @@ func TestApplyCodexImageGenerationBridgeInstructions_SkipsWithoutImageTool(t *te
require.Equal(t, "existing instructions", reqBody["instructions"])
}
func TestValidateCodexSparkInputRejectsInputImage(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.3-codex-spark",
"input": []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "input_text", "text": "describe"},
map[string]any{"type": "input_image", "image_url": "data:image/png;base64,aGVsbG8="},
},
},
},
}
err := validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark")
require.Error(t, err)
require.Contains(t, err.Error(), "does not support image input")
}
func TestValidateCodexSparkInputRejectsChatImageURL(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.3-codex-spark",
"messages": []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "text", "text": "describe"},
map[string]any{"type": "image_url", "image_url": map[string]any{"url": "data:image/png;base64,aGVsbG8="}},
},
},
},
}
err := validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark")
require.Error(t, err)
}
func TestValidateCodexSparkInputAllowsTextOnly(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.3-codex-spark",
"input": []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "input_text", "text": "hello"},
},
},
},
}
require.NoError(t, validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark"))
}
func TestApplyCodexOAuthTransform_AddsSparkImageUnsupportedInstructions(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.3-codex-spark",
"instructions": "existing instructions",
"input": "hello",
}
result := applyCodexOAuthTransform(reqBody, true, false)
require.True(t, result.Modified)
instructions, ok := reqBody["instructions"].(string)
require.True(t, ok)
require.Contains(t, instructions, "existing instructions")
require.Contains(t, instructions, codexSparkImageUnsupportedMarker)
require.Contains(t, instructions, "does not support image generation")
require.Contains(t, instructions, "switch to a non-Spark Codex model")
require.NotContains(t, instructions, codexImageGenerationBridgeMarker)
}
func TestApplyCodexOAuthTransform_DoesNotAddSparkImageUnsupportedForNonSpark(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"instructions": "existing instructions",
"input": "hello",
}
applyCodexOAuthTransform(reqBody, true, false)
instructions, ok := reqBody["instructions"].(string)
require.True(t, ok)
require.NotContains(t, instructions, codexSparkImageUnsupportedMarker)
}
func TestNormalizeOpenAIResponsesImageOnlyModel_BuildsImageToolRequest(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-image-2",

View File

@@ -0,0 +1,135 @@
package service
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestOpenAIGatewayService_Forward_CompactOnlyModelMappingOverridesOAuthUpstreamModel(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4","stream":false,"instructions":"compact-test","input":"hello"}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-compact-map"}},
Body: io.NopCloser(strings.NewReader(`{"id":"resp_123","status":"completed","model":"gpt-5.4-openai-compact","output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)),
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
"compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"},
},
Status: StatusActive,
Schedulable: true,
}
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "gpt-5.4", result.Model)
require.Equal(t, "gpt-5.4-openai-compact", result.UpstreamModel)
require.Equal(t, "gpt-5.4-openai-compact", gjson.GetBytes(upstream.lastBody, "model").String())
}
func TestOpenAIGatewayService_Forward_NonCompactRequestIgnoresCompactOnlyModelMapping(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4","stream":false,"instructions":"normal-test","input":"hello"}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-normal-map"}},
Body: io.NopCloser(strings.NewReader(`{"id":"resp_124","status":"completed","model":"gpt-5.4","output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)),
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 2,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
"compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"},
},
Status: StatusActive,
Schedulable: true,
}
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "gpt-5.4", result.Model)
require.Equal(t, "gpt-5.4", result.UpstreamModel)
require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String())
}
func TestOpenAIGatewayService_OAuthPassthrough_CompactOnlyModelMappingOverridesUpstreamModel(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", bytes.NewReader(nil))
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
c.Request.Header.Set("Content-Type", "application/json")
originalBody := []byte(`{"model":"gpt-5.4","stream":true,"store":true,"instructions":"compact-pass","input":[{"type":"text","text":"compact me"}]}`)
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-compact-pass-map"}},
Body: io.NopCloser(strings.NewReader(`{"id":"cmp_124","model":"gpt-5.4-openai-compact","usage":{"input_tokens":2,"output_tokens":3}}`)),
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 3,
Name: "openai-oauth-pass",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
"compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"},
},
Extra: map[string]any{"openai_passthrough": true},
Status: StatusActive,
Schedulable: true,
}
result, err := svc.Forward(context.Background(), c, account, originalBody)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "gpt-5.4", result.Model)
require.Equal(t, "gpt-5.4-openai-compact", result.UpstreamModel)
require.Equal(t, "gpt-5.4-openai-compact", gjson.GetBytes(upstream.lastBody, "model").String())
require.Equal(t, "gpt-5.4", gjson.GetBytes(rec.Body.Bytes(), "model").String())
}

View File

@@ -0,0 +1,120 @@
package service
import (
"net/http"
"strconv"
"strings"
"time"
)
const (
// AccountTestModeDefault drives the standard /responses connection test.
AccountTestModeDefault = "default"
// AccountTestModeCompact drives the /responses/compact compact-probe test.
AccountTestModeCompact = "compact"
)
func normalizeAccountTestMode(mode string) string {
switch strings.ToLower(strings.TrimSpace(mode)) {
case AccountTestModeCompact:
return AccountTestModeCompact
default:
return AccountTestModeDefault
}
}
func createOpenAICompactProbePayload(model string) map[string]any {
return map[string]any{
"model": strings.TrimSpace(model),
"instructions": "You are a helpful coding assistant.",
"input": []any{
map[string]any{
"type": "message",
"role": "user",
"content": "Respond with OK.",
},
},
}
}
func shouldMarkOpenAICompactUnsupported(status int, body []byte) bool {
switch status {
case http.StatusNotFound, http.StatusMethodNotAllowed, http.StatusNotImplemented:
return true
case http.StatusBadRequest, http.StatusForbidden, http.StatusUnprocessableEntity:
lower := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(body) + " " + string(body)))
if strings.Contains(lower, "compact") {
for _, keyword := range []string{
"unsupported",
"not support",
"does not support",
"not available",
"disabled",
} {
if strings.Contains(lower, keyword) {
return true
}
}
}
}
return false
}
func buildOpenAICompactProbeExtraUpdates(resp *http.Response, body []byte, probeErr error, now time.Time) map[string]any {
updates := map[string]any{
"openai_compact_checked_at": now.Format(time.RFC3339),
"openai_compact_last_status": nil,
}
if resp != nil {
updates["openai_compact_last_status"] = resp.StatusCode
}
switch {
case probeErr != nil:
updates["openai_compact_last_error"] = truncateString(sanitizeUpstreamErrorMessage(probeErr.Error()), 2048)
case resp == nil:
updates["openai_compact_last_error"] = "compact probe failed"
default:
errMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
if errMsg == "" && len(body) > 0 {
errMsg = strings.TrimSpace(string(body))
}
if errMsg == "" && (resp.StatusCode < 200 || resp.StatusCode >= 300) {
errMsg = "HTTP " + strconv.Itoa(resp.StatusCode)
}
errMsg = truncateString(sanitizeUpstreamErrorMessage(errMsg), 2048)
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
updates["openai_compact_supported"] = true
updates["openai_compact_last_error"] = ""
} else {
if shouldMarkOpenAICompactUnsupported(resp.StatusCode, body) {
updates["openai_compact_supported"] = false
}
updates["openai_compact_last_error"] = errMsg
}
}
return updates
}
func mergeExtraUpdates(base map[string]any, more map[string]any) map[string]any {
if len(base) == 0 && len(more) == 0 {
return nil
}
out := make(map[string]any, len(base)+len(more))
for key, value := range base {
out[key] = value
}
for key, value := range more {
out[key] = value
}
return out
}
func compactProbeSessionID(accountID int64) string {
if accountID <= 0 {
return "probe_compact"
}
return "probe_compact_" + strconv.FormatInt(accountID, 10)
}

View File

@@ -0,0 +1,122 @@
package service
import (
"errors"
"net/http"
"testing"
"time"
)
func TestNormalizeAccountTestMode(t *testing.T) {
tests := []struct {
input string
want string
}{
{input: "", want: AccountTestModeDefault},
{input: "default", want: AccountTestModeDefault},
{input: " compact ", want: AccountTestModeCompact},
{input: "COMPACT", want: AccountTestModeCompact},
{input: "unknown", want: AccountTestModeDefault},
}
for _, tt := range tests {
if got := normalizeAccountTestMode(tt.input); got != tt.want {
t.Fatalf("normalizeAccountTestMode(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}
func TestBuildOpenAICompactProbeExtraUpdates_SuccessMarksSupported(t *testing.T) {
now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusOK}, []byte(`{"id":"cmp_1"}`), nil, now)
if got := updates["openai_compact_supported"]; got != true {
t.Fatalf("openai_compact_supported = %v, want true", got)
}
if got := updates["openai_compact_last_status"]; got != http.StatusOK {
t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusOK)
}
if got := updates["openai_compact_last_error"]; got != "" {
t.Fatalf("openai_compact_last_error = %v, want empty string", got)
}
if got := updates["openai_compact_checked_at"]; got != now.Format(time.RFC3339) {
t.Fatalf("openai_compact_checked_at = %v, want %s", got, now.Format(time.RFC3339))
}
}
func TestBuildOpenAICompactProbeExtraUpdates_404MarksUnsupported(t *testing.T) {
now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
body := []byte(`404 page not found`)
updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusNotFound}, body, nil, now)
if got := updates["openai_compact_supported"]; got != false {
t.Fatalf("openai_compact_supported = %v, want false", got)
}
if got := updates["openai_compact_last_status"]; got != http.StatusNotFound {
t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusNotFound)
}
}
func TestBuildOpenAICompactProbeExtraUpdates_502DoesNotMarkUnsupported(t *testing.T) {
now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusBadGateway}, []byte(`Upstream request failed`), nil, now)
if _, exists := updates["openai_compact_supported"]; exists {
t.Fatalf("did not expect openai_compact_supported for 502 response")
}
if got := updates["openai_compact_last_status"]; got != http.StatusBadGateway {
t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusBadGateway)
}
}
func TestBuildOpenAICompactProbeExtraUpdates_RequestErrorDoesNotMarkUnsupported(t *testing.T) {
now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
updates := buildOpenAICompactProbeExtraUpdates(nil, nil, errors.New("dial tcp timeout"), now)
if _, exists := updates["openai_compact_supported"]; exists {
t.Fatalf("did not expect openai_compact_supported for request error")
}
if got, exists := updates["openai_compact_last_status"]; !exists || got != nil {
t.Fatalf("openai_compact_last_status = %v, want nil key", got)
}
if got := updates["openai_compact_last_error"]; got == "" {
t.Fatalf("expected openai_compact_last_error to be populated")
}
}
func TestBuildOpenAICompactProbeExtraUpdates_NoResponseClearsLastStatus(t *testing.T) {
now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
updates := buildOpenAICompactProbeExtraUpdates(nil, nil, nil, now)
if got, exists := updates["openai_compact_last_status"]; !exists || got != nil {
t.Fatalf("openai_compact_last_status = %v, want nil key", got)
}
if got := updates["openai_compact_last_error"]; got != "compact probe failed" {
t.Fatalf("openai_compact_last_error = %v, want compact probe failed", got)
}
}
func TestBuildOpenAICompactProbeExtraUpdates_UnknownModelDoesNotMarkUnsupported(t *testing.T) {
now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
body := []byte(`{"error":{"message":"unknown model gpt-5.4-openai-compact"}}`)
updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusBadRequest}, body, nil, now)
if _, exists := updates["openai_compact_supported"]; exists {
t.Fatalf("did not expect openai_compact_supported for unknown-model diagnostics")
}
if got := updates["openai_compact_last_status"]; got != http.StatusBadRequest {
t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusBadRequest)
}
}
func TestBuildOpenAICompactProbeExtraUpdates_EmptyFailureBodyFallsBackToHTTPStatus(t *testing.T) {
now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusServiceUnavailable}, nil, nil, now)
if got := updates["openai_compact_last_status"]; got != http.StatusServiceUnavailable {
t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusServiceUnavailable)
}
if got := updates["openai_compact_last_error"]; got != "HTTP 503" {
t.Fatalf("openai_compact_last_error = %v, want HTTP 503", got)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -93,6 +93,13 @@ type cancelReadCloser struct{}
func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
func (c cancelReadCloser) Close() error { return nil }
type errReadCloser struct {
err error
}
func (r errReadCloser) Read([]byte) (int, error) { return 0, r.err }
func (r errReadCloser) Close() error { return nil }
type failingGinWriter struct {
gin.ResponseWriter
failAfter int
@@ -1003,6 +1010,190 @@ func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErr
}
}
func TestOpenAIStreamingReadErrorBeforeOutputReturnsFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: errReadCloser{err: io.ErrUnexpectedEOF},
Header: http.Header{"X-Request-Id": []string{"rid-disconnect"}},
}
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
require.False(t, c.Writer.Written())
require.Empty(t, rec.Body.String())
}
func TestOpenAIStreamingResponseFailedBeforeOutputReturnsFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
"event: response.created",
`data: {"type":"response.created","response":{"id":"resp_1"}}`,
"",
"event: response.in_progress",
`data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
"",
"event: response.failed",
`data: {"type":"response.failed","error":{"message":"An error occurred while processing your request."}}`,
"",
}, "\n"))),
Header: http.Header{"X-Request-Id": []string{"rid-failed"}},
}
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
require.Contains(t, string(failoverErr.ResponseBody), "An error occurred while processing your request")
require.False(t, c.Writer.Written())
require.Empty(t, rec.Body.String())
}
func TestOpenAIStreamingPreambleOnlyMissingTerminalReturnsFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
"event: response.created",
`data: {"type":"response.created","response":{"id":"resp_1"}}`,
"",
"event: response.in_progress",
`data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
"",
}, "\n"))),
Header: http.Header{"X-Request-Id": []string{"rid-missing-terminal"}},
}
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.False(t, c.Writer.Written())
require.Empty(t, rec.Body.String())
}
func TestOpenAIStreamingPreambleKeepaliveUsesDownstreamIdle(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 1,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\"}}\n\n"))
for i := 0; i < 6; i++ {
time.Sleep(250 * time.Millisecond)
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{\"id\":\"resp_1\"}}\n\n"))
}
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2}}}\n\n"))
}()
result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.Contains(t, rec.Body.String(), ":\n\n")
require.Contains(t, rec.Body.String(), "response.completed")
}
func TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
"event: response.created",
`data: {"type":"response.created","response":{"id":"resp_1"}}`,
"",
"event: response.failed",
`data: {"type":"response.failed","error":{"type":"safety_error","message":"This request has been flagged for potentially high-risk cyber activity."}}`,
"",
}, "\n"))),
Header: http.Header{"X-Request-Id": []string{"rid-policy-failed"}},
}
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.False(t, errors.As(err, &failoverErr))
require.True(t, c.Writer.Written())
require.Contains(t, rec.Body.String(), "response.failed")
require.Contains(t, rec.Body.String(), "high-risk cyber activity")
}
func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
@@ -1072,7 +1263,7 @@ func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError(t *testing.T)
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
_, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n"))
}()
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
@@ -1104,16 +1295,52 @@ func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
_, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n"))
}()
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "")
_ = pr.Close()
if err == nil || !strings.Contains(err.Error(), "missing terminal event") {
t.Fatalf("expected missing terminal event error, got %v", err)
}
}
func TestOpenAIStreamingPassthroughResponseFailedBeforeOutputReturnsFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
"event: response.created",
`data: {"type":"response.created","response":{"id":"resp_1"}}`,
"",
"event: response.failed",
`data: {"type":"response.failed","error":{"message":"upstream processing failed"}}`,
"",
}, "\n"))),
Header: http.Header{"X-Request-Id": []string{"rid-passthrough-failed"}},
}
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "", "")
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
require.Contains(t, string(failoverErr.ResponseBody), "upstream processing failed")
require.False(t, c.Writer.Written())
require.Empty(t, rec.Body.String())
}
func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
@@ -1139,7 +1366,42 @@ func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t
_, _ = pw.Write([]byte("data: {\"type\":\"response.done\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
}()
result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "")
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.usage)
require.Equal(t, 2, result.usage.InputTokens)
require.Equal(t, 3, result.usage.OutputTokens)
require.Equal(t, 1, result.usage.CacheReadInputTokens)
}
func TestOpenAIStreamingPassthroughResponseIncompleteWithoutDoneMarkerStillSucceeds(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.incomplete\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
}()
result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "")
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)

View File

@@ -1,5 +1,7 @@
package service
import "strings"
// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible
// forwarding. Group-level default mapping only applies when the account itself
// did not match any explicit model_mapping rule.
@@ -12,8 +14,47 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo
}
mappedModel, matched := account.ResolveMappedModel(requestedModel)
if !matched && defaultMappedModel != "" {
if !matched && defaultMappedModel != "" && !isExplicitCodexModel(requestedModel) {
return defaultMappedModel
}
return mappedModel
}
func isExplicitCodexModel(model string) bool {
model = strings.TrimSpace(model)
if model == "" {
return false
}
if strings.Contains(model, "/") {
parts := strings.Split(model, "/")
model = parts[len(parts)-1]
}
model = strings.ToLower(strings.TrimSpace(model))
if getNormalizedCodexModel(model) != "" {
return true
}
if strings.HasSuffix(model, "-openai-compact") {
base := strings.TrimSuffix(model, "-openai-compact")
return getNormalizedCodexModel(base) != ""
}
return false
}
// resolveOpenAICompactForwardModel determines the compact-only upstream model
// for /responses/compact requests. It never affects normal /responses traffic.
// When no compact-specific mapping matches, the input model is returned as-is.
func resolveOpenAICompactForwardModel(account *Account, model string) string {
trimmedModel := strings.TrimSpace(model)
if trimmedModel == "" || account == nil {
return trimmedModel
}
mappedModel, matched := account.ResolveCompactMappedModel(trimmedModel)
if !matched {
return trimmedModel
}
if trimmedMapped := strings.TrimSpace(mappedModel); trimmedMapped != "" {
return trimmedMapped
}
return trimmedModel
}

View File

@@ -15,10 +15,19 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "gpt-5.4",
requestedModel: "claude-opus-4-6",
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-4o-mini",
},
{
name: "preserves explicit gpt-5.4 instead of group default",
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "gpt-5.4",
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-5.4",
},
{
name: "preserves exact passthrough mapping instead of group default",
account: &Account{
@@ -58,6 +67,42 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-5.4",
},
{
name: "preserves codex spark instead of group default",
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "gpt-5.3-codex-spark",
defaultMappedModel: "gpt-5.4",
expectedModel: "gpt-5.3-codex-spark",
},
{
name: "preserves gpt-5.5 instead of group default",
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "gpt-5.5",
defaultMappedModel: "gpt-5.4",
expectedModel: "gpt-5.5",
},
{
name: "preserves openai namespaced gpt-5.5 instead of group default",
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "openai/gpt-5.5",
defaultMappedModel: "gpt-5.4",
expectedModel: "openai/gpt-5.5",
},
{
name: "preserves compact gpt-5.5 instead of group default",
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "gpt-5.5-openai-compact",
defaultMappedModel: "gpt-5.4",
expectedModel: "gpt-5.5-openai-compact",
},
}
for _, tt := range tests {
@@ -85,6 +130,74 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt54(t *
}
}
func TestResolveOpenAICompactForwardModel(t *testing.T) {
tests := []struct {
name string
account *Account
model string
expectedModel string
}{
{
name: "nil account keeps original model",
account: nil,
model: "gpt-5.4",
expectedModel: "gpt-5.4",
},
{
name: "missing compact mapping keeps original model",
account: &Account{
Credentials: map[string]any{},
},
model: "gpt-5.4",
expectedModel: "gpt-5.4",
},
{
name: "exact compact mapping overrides model",
account: &Account{
Credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4-openai-compact",
},
},
},
model: "gpt-5.4",
expectedModel: "gpt-5.4-openai-compact",
},
{
name: "wildcard compact mapping overrides model",
account: &Account{
Credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-5.*": "gpt-5-openai-compact",
},
},
},
model: "gpt-5.4",
expectedModel: "gpt-5-openai-compact",
},
{
name: "passthrough compact mapping remains unchanged",
account: &Account{
Credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4",
},
},
},
model: "gpt-5.4",
expectedModel: "gpt-5.4",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := resolveOpenAICompactForwardModel(tt.account, tt.model); got != tt.expectedModel {
t.Fatalf("resolveOpenAICompactForwardModel(...) = %q, want %q", got, tt.expectedModel)
}
})
}
}
func TestNormalizeCodexModel(t *testing.T) {
cases := map[string]string{
"gpt-5.3-codex-spark": "gpt-5.3-codex-spark",

View File

@@ -734,7 +734,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *te
require.NoError(t, err)
require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool())
require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool())
require.Equal(t, "codex_cli_rs/0.104.0", upstream.lastReq.Header.Get("User-Agent"))
require.Equal(t, "codex_cli_rs/0.125.0", upstream.lastReq.Header.Get("User-Agent"))
}
func TestOpenAIGatewayService_CodexCLIOnly_RejectsNonCodexClient(t *testing.T) {

View File

@@ -21,7 +21,7 @@ type FunctionCallOutputValidation struct {
}
// NeedsToolContinuation 判定请求是否需要工具调用续链处理。
// 满足以下任一信号即视为续链previous_response_id、input 内包含 function_call_output/item_reference、
// 满足以下任一信号即视为续链previous_response_id、input 内包含工具输出/item_reference、
// 或显式声明 tools/tool_choice。
func NeedsToolContinuation(reqBody map[string]any) bool {
if reqBody == nil {
@@ -46,7 +46,7 @@ func NeedsToolContinuation(reqBody map[string]any) bool {
continue
}
itemType, _ := itemMap["type"].(string)
if itemType == "function_call_output" || itemType == "item_reference" {
if isCodexToolCallItemType(itemType) || itemType == "item_reference" {
return true
}
}

View File

@@ -17,6 +17,9 @@ func TestNeedsToolContinuationSignals(t *testing.T) {
{name: "previous_response_id", body: map[string]any{"previous_response_id": "resp_1"}, want: true},
{name: "previous_response_id_blank", body: map[string]any{"previous_response_id": " "}, want: false},
{name: "function_call_output", body: map[string]any{"input": []any{map[string]any{"type": "function_call_output"}}}, want: true},
{name: "tool_search_output", body: map[string]any{"input": []any{map[string]any{"type": "tool_search_output"}}}, want: true},
{name: "custom_tool_call_output", body: map[string]any{"input": []any{map[string]any{"type": "custom_tool_call_output"}}}, want: true},
{name: "mcp_tool_call_output", body: map[string]any{"input": []any{map[string]any{"type": "mcp_tool_call_output"}}}, want: true},
{name: "item_reference", body: map[string]any{"input": []any{map[string]any{"type": "item_reference"}}}, want: true},
{name: "tools", body: map[string]any{"tools": []any{map[string]any{"type": "function"}}}, want: true},
{name: "tools_empty", body: map[string]any{"tools": []any{}}, want: false},

View File

@@ -37,7 +37,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit(t *testing.T
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_1", account.ID, time.Hour))
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_1", "gpt-5.1", nil)
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_1", "gpt-5.1", nil, false)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
@@ -77,7 +77,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_rl", account.ID, time.Hour))
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_rl", "gpt-5.1", nil)
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_rl", "gpt-5.1", nil, false)
require.NoError(t, err)
require.Nil(t, selection, "限额中的账号不应继续命中 previous_response_id 粘连")
boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_rl")
@@ -129,7 +129,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_DBRuntimeRecheck
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_db_rl", dbAccount.ID, time.Hour))
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_db_rl", "gpt-5.1", nil)
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_db_rl", "gpt-5.1", nil, false)
require.NoError(t, err)
require.Nil(t, selection, "DB 中已限流的账号不应继续命中 previous_response_id 粘连")
boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_db_rl")
@@ -164,7 +164,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *test
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_2", account.ID, time.Hour))
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_2", "gpt-5.1", map[int64]struct{}{account.ID: {}})
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_2", "gpt-5.1", map[int64]struct{}{account.ID: {}}, false)
require.NoError(t, err)
require.Nil(t, selection)
}
@@ -197,7 +197,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_ForceHTTPIgnored
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_force_http", account.ID, time.Hour))
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_force_http", "gpt-5.1", nil)
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_force_http", "gpt-5.1", nil, false)
require.NoError(t, err)
require.Nil(t, selection, "force_http 场景应忽略 previous_response_id 粘连")
}
@@ -258,7 +258,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_BusyKeepsSticky(
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_busy", 21, time.Hour))
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_busy", "gpt-5.1", nil)
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_busy", "gpt-5.1", nil, false)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)

View File

@@ -3800,6 +3800,7 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
previousResponseID string,
requestedModel string,
excludedIDs map[int64]struct{},
requireCompact bool,
) (*AccountSelectionResult, error) {
if s == nil {
return nil, nil
@@ -3840,11 +3841,16 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
return nil, nil
}
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
if account == nil {
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
return nil, nil
}
// 兜底:若上游 compact 能力刚被探测为不支持,但 sticky 还在,需要主动放弃。
if requireCompact && openAICompactSupportTier(account) == 0 {
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
return nil, nil
}
result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if acquireErr == nil && result.Acquired {

View File

@@ -2,6 +2,7 @@ package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
@@ -268,6 +269,9 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e
switch action {
case redeemActionSkipCompleted:
if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil {
return err
}
// Code already created and redeemed — just mark completed
return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
case redeemActionCreate:
@@ -281,6 +285,9 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e
if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil {
return fmt.Errorf("redeem balance: %w", err)
}
if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil {
return err
}
return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
}
@@ -358,6 +365,142 @@ func (s *PaymentService) hasAuditLog(ctx context.Context, orderID int64, action
return c > 0
}
func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *dbent.PaymentOrder) error {
if o == nil || o.OrderType != payment.OrderTypeBalance || o.Amount <= 0 {
return nil
}
if s.affiliateService == nil {
return nil
}
tx, err := s.entClient.Tx(ctx)
if err != nil {
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
"error": fmt.Sprintf("begin affiliate rebate tx: %v", err),
})
return fmt.Errorf("begin affiliate rebate tx: %w", err)
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
claimed, err := s.tryClaimAffiliateRebateAudit(txCtx, tx.Client(), o.ID, o.Amount)
if err != nil {
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
"error": err.Error(),
})
return fmt.Errorf("claim affiliate rebate audit: %w", err)
}
if !claimed {
return nil
}
rebateAmount, err := s.affiliateService.AccrueInviteRebate(txCtx, o.UserID, o.Amount)
if err != nil {
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
"error": err.Error(),
})
return fmt.Errorf("accrue affiliate rebate: %w", err)
}
if rebateAmount <= 0 {
if err := s.updateClaimedAffiliateRebateAudit(txCtx, tx.Client(), o.ID, "AFFILIATE_REBATE_SKIPPED", map[string]any{
"baseAmount": o.Amount,
"reason": "no inviter bound or rebate amount <= 0",
}); err != nil {
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
"error": err.Error(),
})
return fmt.Errorf("update affiliate rebate skipped audit: %w", err)
}
if err := tx.Commit(); err != nil {
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
"error": fmt.Sprintf("commit affiliate rebate tx: %v", err),
})
return fmt.Errorf("commit affiliate rebate tx: %w", err)
}
return nil
}
if err := s.updateClaimedAffiliateRebateAudit(txCtx, tx.Client(), o.ID, "AFFILIATE_REBATE_APPLIED", map[string]any{
"baseAmount": o.Amount,
"rebateAmount": rebateAmount,
}); err != nil {
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
"error": err.Error(),
})
return fmt.Errorf("update affiliate rebate applied audit: %w", err)
}
if err := tx.Commit(); err != nil {
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
"error": fmt.Sprintf("commit affiliate rebate tx: %v", err),
})
return fmt.Errorf("commit affiliate rebate tx: %w", err)
}
return nil
}
func (s *PaymentService) tryClaimAffiliateRebateAudit(ctx context.Context, client *dbent.Client, orderID int64, baseAmount float64) (bool, error) {
if client == nil {
return false, errors.New("nil payment client")
}
oid := strconv.FormatInt(orderID, 10)
detail, _ := json.Marshal(map[string]any{
"baseAmount": baseAmount,
"status": "reserved",
})
rows, err := client.QueryContext(ctx, `
INSERT INTO payment_audit_logs (order_id, action, detail, operator, created_at)
SELECT $1::text, 'AFFILIATE_REBATE_APPLIED', $2::text, 'system', NOW()
WHERE NOT EXISTS (
SELECT 1
FROM payment_audit_logs
WHERE order_id = $1::text
AND action IN ('AFFILIATE_REBATE_APPLIED', 'AFFILIATE_REBATE_SKIPPED')
)
ON CONFLICT (order_id, action) DO NOTHING
RETURNING id`, oid, string(detail))
if err != nil {
return false, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
if err := rows.Err(); err != nil {
return false, err
}
return false, nil
}
var claimID int64
if err := rows.Scan(&claimID); err != nil {
return false, err
}
return true, nil
}
func (s *PaymentService) updateClaimedAffiliateRebateAudit(ctx context.Context, client *dbent.Client, orderID int64, action string, detail map[string]any) error {
if client == nil {
return errors.New("nil payment client")
}
oid := strconv.FormatInt(orderID, 10)
detailJSON, _ := json.Marshal(detail)
updated, err := client.PaymentAuditLog.Update().
Where(
paymentauditlog.OrderIDEQ(oid),
paymentauditlog.ActionEQ("AFFILIATE_REBATE_APPLIED"),
).
SetAction(action).
SetDetail(string(detailJSON)).
SetOperator("system").
Save(ctx)
if err != nil {
return err
}
if updated == 0 {
return errors.New("affiliate rebate claim log not found")
}
return nil
}
func (s *PaymentService) markFailed(ctx context.Context, oid int64, cause error) {
now := time.Now()
r := psErrMsg(cause)

View File

@@ -170,21 +170,22 @@ type TopUserStat struct {
// --- Service ---
type PaymentService struct {
providerMu sync.Mutex
providersLoaded bool
entClient *dbent.Client
registry *payment.Registry
loadBalancer payment.LoadBalancer
redeemService *RedeemService
subscriptionSvc *SubscriptionService
configService *PaymentConfigService
userRepo UserRepository
groupRepo GroupRepository
resumeService *PaymentResumeService
providerMu sync.Mutex
providersLoaded bool
entClient *dbent.Client
registry *payment.Registry
loadBalancer payment.LoadBalancer
redeemService *RedeemService
subscriptionSvc *SubscriptionService
configService *PaymentConfigService
userRepo UserRepository
groupRepo GroupRepository
resumeService *PaymentResumeService
affiliateService *AffiliateService
}
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository, affiliateService *AffiliateService) *PaymentService {
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo, affiliateService: affiliateService}
svc.resumeService = psNewPaymentResumeService(configService)
return svc
}

View File

@@ -931,7 +931,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// calculateOpenAI429ResetTime 从 OpenAI 429 响应头计算正确的重置时间
// 返回 nil 表示无法从响应头中确定重置时间
func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time {
func calculateOpenAI429ResetTime(headers http.Header) *time.Time {
snapshot := ParseCodexRateLimitHeaders(headers)
if snapshot == nil {
return nil
@@ -977,6 +977,10 @@ func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *tim
return nil
}
func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time {
return calculateOpenAI429ResetTime(headers)
}
// anthropic429Result holds the parsed Anthropic 429 rate-limit information.
type anthropic429Result struct {
resetAt time.Time // The correct reset time to use for SetRateLimited

View File

@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"log/slog"
"math"
"net/url"
"sort"
"strconv"
@@ -453,6 +454,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyChannelMonitorEnabled,
SettingKeyChannelMonitorDefaultIntervalSeconds,
SettingKeyAvailableChannelsEnabled,
SettingKeyAffiliateEnabled,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -540,6 +542,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
ChannelMonitorDefaultIntervalSeconds: parseChannelMonitorInterval(settings[SettingKeyChannelMonitorDefaultIntervalSeconds]),
AvailableChannelsEnabled: settings[SettingKeyAvailableChannelsEnabled] == "true",
AffiliateEnabled: settings[SettingKeyAffiliateEnabled] == "true",
}, nil
}
@@ -686,6 +690,7 @@ type PublicSettingsInjectionPayload struct {
ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
AffiliateEnabled bool `json:"affiliate_enabled"`
}
// GetPublicSettingsForInjection returns public settings in a format suitable for HTML injection.
@@ -738,6 +743,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
ChannelMonitorEnabled: settings.ChannelMonitorEnabled,
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
AffiliateEnabled: settings.AffiliateEnabled,
}, nil
}
@@ -1167,6 +1173,26 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
// 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
settings.AffiliateRebateRate = clampAffiliateRebateRate(settings.AffiliateRebateRate)
updates[SettingKeyAffiliateRebateRate] = strconv.FormatFloat(settings.AffiliateRebateRate, 'f', 8, 64)
if settings.AffiliateRebateFreezeHours < 0 {
settings.AffiliateRebateFreezeHours = AffiliateRebateFreezeHoursDefault
}
if settings.AffiliateRebateFreezeHours > AffiliateRebateFreezeHoursMax {
settings.AffiliateRebateFreezeHours = AffiliateRebateFreezeHoursMax
}
updates[SettingKeyAffiliateRebateFreezeHours] = strconv.Itoa(settings.AffiliateRebateFreezeHours)
if settings.AffiliateRebateDurationDays < 0 {
settings.AffiliateRebateDurationDays = AffiliateRebateDurationDaysDefault
}
if settings.AffiliateRebateDurationDays > AffiliateRebateDurationDaysMax {
settings.AffiliateRebateDurationDays = AffiliateRebateDurationDaysMax
}
updates[SettingKeyAffiliateRebateDurationDays] = strconv.Itoa(settings.AffiliateRebateDurationDays)
if settings.AffiliateRebatePerInviteeCap < 0 {
settings.AffiliateRebatePerInviteeCap = AffiliateRebatePerInviteeCapDefault
}
updates[SettingKeyAffiliateRebatePerInviteeCap] = strconv.FormatFloat(settings.AffiliateRebatePerInviteeCap, 'f', 8, 64)
updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit)
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
if err != nil {
@@ -1202,6 +1228,9 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
// Available channels feature switch
updates[SettingKeyAvailableChannelsEnabled] = strconv.FormatBool(settings.AvailableChannelsEnabled)
// Affiliate (邀请返利) feature switch
updates[SettingKeyAffiliateEnabled] = strconv.FormatBool(settings.AffiliateEnabled)
// Claude Code version check
updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion
updates[SettingKeyMaxClaudeCodeVersion] = settings.MaxClaudeCodeVersion
@@ -1477,6 +1506,78 @@ func (s *SettingService) IsInvitationCodeEnabled(ctx context.Context) bool {
return value == "true"
}
// IsAffiliateEnabled 检查是否启用邀请返利功能(总开关)
func (s *SettingService) IsAffiliateEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateEnabled)
if err != nil {
return false // 默认关闭
}
return value == "true"
}
// GetAffiliateRebateRatePercent 读取并 clamp 全局返利比例。
// 解析失败、缺失或越界都回退到 AffiliateRebateRateDefault — 该比例从不抛错,
// 调用方只关心一个可用的数值。
func (s *SettingService) GetAffiliateRebateRatePercent(ctx context.Context) float64 {
raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateRate)
if err != nil {
return AffiliateRebateRateDefault
}
rate, err := strconv.ParseFloat(strings.TrimSpace(raw), 64)
if err != nil || math.IsNaN(rate) || math.IsInf(rate, 0) {
return AffiliateRebateRateDefault
}
return clampAffiliateRebateRate(rate)
}
// GetAffiliateRebateFreezeHours 返回返利冻结期(小时)。
// 返回 0 表示不冻结(向后兼容)。
func (s *SettingService) GetAffiliateRebateFreezeHours(ctx context.Context) int {
raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateFreezeHours)
if err != nil {
return AffiliateRebateFreezeHoursDefault
}
hours, err := strconv.Atoi(strings.TrimSpace(raw))
if err != nil || hours < 0 {
return AffiliateRebateFreezeHoursDefault
}
if hours > AffiliateRebateFreezeHoursMax {
return AffiliateRebateFreezeHoursMax
}
return hours
}
// GetAffiliateRebateDurationDays 返回返利有效期(天)。
// 返回 0 表示永久有效。
func (s *SettingService) GetAffiliateRebateDurationDays(ctx context.Context) int {
raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateDurationDays)
if err != nil {
return AffiliateRebateDurationDaysDefault
}
days, err := strconv.Atoi(strings.TrimSpace(raw))
if err != nil || days < 0 {
return AffiliateRebateDurationDaysDefault
}
if days > AffiliateRebateDurationDaysMax {
return AffiliateRebateDurationDaysMax
}
return days
}
// GetAffiliateRebatePerInviteeCap 返回单人返利上限。
// 返回 0 表示无上限。
func (s *SettingService) GetAffiliateRebatePerInviteeCap(ctx context.Context) float64 {
raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebatePerInviteeCap)
if err != nil {
return AffiliateRebatePerInviteeCapDefault
}
cap, err := strconv.ParseFloat(strings.TrimSpace(raw), 64)
if err != nil || cap < 0 || math.IsNaN(cap) || math.IsInf(cap, 0) {
return AffiliateRebatePerInviteeCapDefault
}
return cap
}
// IsPasswordResetEnabled 检查是否启用密码重置功能
// 要求:必须同时开启邮件验证
func (s *SettingService) IsPasswordResetEnabled(ctx context.Context) bool {
@@ -1719,6 +1820,10 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyOIDCConnectUserInfoUsernamePath: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeyAffiliateRebateRate: strconv.FormatFloat(AffiliateRebateRateDefault, 'f', 8, 64),
SettingKeyAffiliateRebateFreezeHours: strconv.Itoa(AffiliateRebateFreezeHoursDefault),
SettingKeyAffiliateRebateDurationDays: strconv.Itoa(AffiliateRebateDurationDaysDefault),
SettingKeyAffiliateRebatePerInviteeCap: strconv.FormatFloat(AffiliateRebatePerInviteeCapDefault, 'f', 2, 64),
SettingKeyDefaultUserRPMLimit: "0",
SettingKeyDefaultSubscriptions: "[]",
SettingKeyAuthSourceDefaultEmailBalance: "0",
@@ -1767,6 +1872,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// Available channels feature (default disabled; opt-in)
SettingKeyAvailableChannelsEnabled: "false",
// Affiliate (邀请返利) feature (default disabled; opt-in)
SettingKeyAffiliateEnabled: "false",
// Claude Code version check (default: empty = disabled)
SettingKeyMinClaudeCodeVersion: "",
SettingKeyMaxClaudeCodeVersion: "",
@@ -1846,6 +1954,26 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
} else {
result.DefaultBalance = s.cfg.Default.UserBalance
}
if rebateRate, err := strconv.ParseFloat(settings[SettingKeyAffiliateRebateRate], 64); err == nil {
result.AffiliateRebateRate = clampAffiliateRebateRate(rebateRate)
} else {
result.AffiliateRebateRate = AffiliateRebateRateDefault
}
if freezeHours, err := strconv.Atoi(settings[SettingKeyAffiliateRebateFreezeHours]); err == nil && freezeHours >= 0 {
if freezeHours > AffiliateRebateFreezeHoursMax {
freezeHours = AffiliateRebateFreezeHoursMax
}
result.AffiliateRebateFreezeHours = freezeHours
}
if durationDays, err := strconv.Atoi(settings[SettingKeyAffiliateRebateDurationDays]); err == nil && durationDays >= 0 {
if durationDays > AffiliateRebateDurationDaysMax {
durationDays = AffiliateRebateDurationDaysMax
}
result.AffiliateRebateDurationDays = durationDays
}
if perInviteeCap, err := strconv.ParseFloat(settings[SettingKeyAffiliateRebatePerInviteeCap], 64); err == nil && perInviteeCap >= 0 {
result.AffiliateRebatePerInviteeCap = perInviteeCap
}
result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions])
// 敏感信息直接返回,方便测试连接时使用
@@ -2082,6 +2210,9 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
// Available channels feature (default: disabled; strict true)
result.AvailableChannelsEnabled = settings[SettingKeyAvailableChannelsEnabled] == "true"
// Affiliate (邀请返利) feature (default: disabled; strict true)
result.AffiliateEnabled = settings[SettingKeyAffiliateEnabled] == "true"
// Claude Code version check
result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion]
result.MaxClaudeCodeVersion = settings[SettingKeyMaxClaudeCodeVersion]
@@ -2130,6 +2261,19 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
return result
}
func clampAffiliateRebateRate(value float64) float64 {
if math.IsNaN(value) || math.IsInf(value, 0) {
return AffiliateRebateRateDefault
}
if value < AffiliateRebateRateMin {
return AffiliateRebateRateMin
}
if value > AffiliateRebateRateMax {
return AffiliateRebateRateMax
}
return value
}
func isFalseSettingValue(value string) bool {
switch strings.ToLower(strings.TrimSpace(value)) {
case "false", "0", "off", "disabled":

View File

@@ -104,10 +104,15 @@ type SystemSettings struct {
CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints
DefaultConcurrency int
DefaultBalance float64
DefaultUserRPMLimit int
DefaultSubscriptions []DefaultSubscriptionSetting
DefaultConcurrency int
DefaultBalance float64
AffiliateEnabled bool
AffiliateRebateRate float64
AffiliateRebateFreezeHours int
AffiliateRebateDurationDays int
AffiliateRebatePerInviteeCap float64
DefaultUserRPMLimit int
DefaultSubscriptions []DefaultSubscriptionSetting
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@@ -224,6 +229,9 @@ type PublicSettings struct {
// Available Channels feature (user-facing aggregate view)
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
// Affiliate (邀请返利) feature toggle
AffiliateEnabled bool `json:"affiliate_enabled"`
}
type WeChatConnectOAuthConfig struct {

View File

@@ -486,6 +486,7 @@ var ProviderSet = wire.NewSet(
NewGroupCapacityService,
NewChannelService,
NewModelPricingResolver,
NewAffiliateService,
ProvidePaymentConfigService,
NewPaymentService,
ProvidePaymentOrderExpiryService,

View File

@@ -0,0 +1,20 @@
CREATE TABLE IF NOT EXISTS user_affiliates (
user_id BIGINT PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE,
aff_code VARCHAR(32) NOT NULL UNIQUE,
inviter_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL,
aff_count INTEGER NOT NULL DEFAULT 0,
aff_quota DECIMAL(20,8) NOT NULL DEFAULT 0,
aff_history_quota DECIMAL(20,8) NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_user_affiliates_inviter_id ON user_affiliates(inviter_id);
CREATE INDEX IF NOT EXISTS idx_user_affiliates_aff_quota ON user_affiliates(aff_quota);
COMMENT ON TABLE user_affiliates IS '用户邀请返利信息';
COMMENT ON COLUMN user_affiliates.aff_code IS '用户邀请代码';
COMMENT ON COLUMN user_affiliates.inviter_id IS '邀请人用户ID';
COMMENT ON COLUMN user_affiliates.aff_count IS '累计邀请人数';
COMMENT ON COLUMN user_affiliates.aff_quota IS '当前可提取返利金额';
COMMENT ON COLUMN user_affiliates.aff_history_quota IS '累计返利历史金额';

View File

@@ -0,0 +1,58 @@
-- 1) Normalize historical affiliate rebate rate values.
-- Legacy compatibility treated 0<x<=1 as fractional inputs (e.g. 0.2 => 20%).
-- We now use pure percentage semantics, so convert persisted fractional values once.
UPDATE settings
SET value = to_char((value::numeric * 100), 'FM999999990.########'),
updated_at = NOW()
WHERE key = 'affiliate_rebate_rate'
AND value ~ '^-?[0-9]+(\\.[0-9]+)?$'
AND value::numeric > 0
AND value::numeric <= 1;
-- 2) Affiliate ledger for accrual/transfer traceability.
CREATE TABLE IF NOT EXISTS user_affiliate_ledger (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
action VARCHAR(32) NOT NULL,
amount DECIMAL(20,8) NOT NULL,
source_user_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_user_id ON user_affiliate_ledger(user_id);
CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_action ON user_affiliate_ledger(action);
COMMENT ON TABLE user_affiliate_ledger IS '邀请返利资金流水(累计/转入)';
COMMENT ON COLUMN user_affiliate_ledger.action IS 'accrue|transfer';
-- 3) Enforce idempotency at DB layer for payment audit actions.
WITH ranked AS (
SELECT id,
ROW_NUMBER() OVER (PARTITION BY order_id, action ORDER BY id) AS rn
FROM payment_audit_logs
)
DELETE FROM payment_audit_logs p
USING ranked r
WHERE p.id = r.id
AND r.rn > 1;
CREATE UNIQUE INDEX IF NOT EXISTS idx_payment_audit_logs_order_action_uniq
ON payment_audit_logs(order_id, action);
-- 4) Prevent retroactive affiliate rebate issuance for legacy completed balance orders.
INSERT INTO payment_audit_logs (order_id, action, detail, operator, created_at)
SELECT po.id::text,
'AFFILIATE_REBATE_SKIPPED',
'{"reason":"baseline before affiliate rebate idempotency rollout"}',
'system',
NOW()
FROM payment_orders po
WHERE po.order_type = 'balance'
AND po.status = 'COMPLETED'
AND NOT EXISTS (
SELECT 1
FROM payment_audit_logs pal
WHERE pal.order_id = po.id::text
AND pal.action IN ('AFFILIATE_REBATE_APPLIED', 'AFFILIATE_REBATE_SKIPPED')
);

View File

@@ -0,0 +1,16 @@
-- 邀请返利:用户专属配置增强
-- 1) aff_rebate_rate_percent: 用户作为邀请人时的专属返利比例百分比NULL 表示沿用全局比例)
-- 2) aff_code_custom: 标记当前 aff_code 是否被管理员手动改写过(用于"专属用户"列表筛选)
ALTER TABLE user_affiliates
ADD COLUMN IF NOT EXISTS aff_rebate_rate_percent DECIMAL(5,2);
ALTER TABLE user_affiliates
ADD COLUMN IF NOT EXISTS aff_code_custom BOOLEAN NOT NULL DEFAULT false;
CREATE INDEX IF NOT EXISTS idx_user_affiliates_admin_settings
ON user_affiliates (updated_at)
WHERE aff_code_custom = true OR aff_rebate_rate_percent IS NOT NULL;
COMMENT ON COLUMN user_affiliates.aff_rebate_rate_percent IS '专属返利比例(百分比 0-100NULL 表示沿用全局)';
COMMENT ON COLUMN user_affiliates.aff_code_custom IS '邀请码是否由管理员改写过(用于专属用户筛选)';

View File

@@ -0,0 +1,17 @@
-- 1) Add frozen quota column to user_affiliates for rebate freeze period.
ALTER TABLE user_affiliates
ADD COLUMN IF NOT EXISTS aff_frozen_quota DECIMAL(20,8) NOT NULL DEFAULT 0;
COMMENT ON COLUMN user_affiliates.aff_frozen_quota IS 'Rebate quota currently frozen (pending thaw after freeze period)';
-- 2) Add frozen_until column to user_affiliate_ledger for per-entry freeze tracking.
-- NULL = no freeze (or already thawed); non-NULL = frozen until this timestamp.
ALTER TABLE user_affiliate_ledger
ADD COLUMN IF NOT EXISTS frozen_until TIMESTAMPTZ NULL;
COMMENT ON COLUMN user_affiliate_ledger.frozen_until IS 'Rebate frozen until this time; NULL means already thawed or never frozen';
-- 3) Partial index for efficient thaw queries (only rows still frozen).
CREATE INDEX IF NOT EXISTS idx_ual_frozen_thaw
ON user_affiliate_ledger (user_id, frozen_until)
WHERE frozen_until IS NOT NULL;

View File

@@ -74,6 +74,26 @@ describe('oauth adoption auth api', () => {
})
})
it('posts affiliate code when completing linuxdo oauth registration', async () => {
const { completeLinuxDoOAuthRegistration } = await import('@/api/auth')
await completeLinuxDoOAuthRegistration(
'invite-code',
{
adoptDisplayName: true,
adoptAvatar: false
},
' AFF123 '
)
expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', {
invitation_code: 'invite-code',
aff_code: 'AFF123',
adopt_display_name: true,
adopt_avatar: false
})
})
it('posts oidc invitation completion with adoption decisions', async () => {
const { completeOIDCOAuthRegistration } = await import('@/api/auth')
@@ -134,6 +154,26 @@ describe('oauth adoption auth api', () => {
})
})
it('posts affiliate code when creating pending wechat oauth account', async () => {
const { createPendingWeChatOAuthAccount } = await import('@/api/auth')
await createPendingWeChatOAuthAccount(
'invite-code',
{
adoptDisplayName: false,
adoptAvatar: true
},
'WXAFF'
)
expect(post).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', {
invitation_code: 'invite-code',
aff_code: 'WXAFF',
adopt_display_name: false,
adopt_avatar: true
})
})
it('classifies oauth completion results as login or bind', async () => {
const { getOAuthCompletionKind } = await import('@/api/auth')

View File

@@ -0,0 +1,108 @@
/**
* Admin Affiliate API endpoints
* Manage per-user affiliate (邀请返利) configurations:
* exclusive invite codes (overrides aff_code) and exclusive rebate rates.
*/
import { apiClient } from '../client'
import type { PaginatedResponse } from '@/types'
export interface AffiliateAdminEntry {
user_id: number
email: string
username: string
aff_code: string
aff_code_custom: boolean
aff_rebate_rate_percent?: number | null
aff_count: number
}
export interface ListAffiliateUsersParams {
page?: number
page_size?: number
search?: string
}
export interface UpdateAffiliateUserRequest {
aff_code?: string
aff_rebate_rate_percent?: number | null
/** Set true to explicitly clear the per-user rate (sets it to NULL). */
clear_rebate_rate?: boolean
}
export interface BatchSetRateRequest {
user_ids: number[]
aff_rebate_rate_percent?: number | null
/** Set true to clear rates instead of setting. */
clear?: boolean
}
export interface SimpleUser {
id: number
email: string
username: string
}
export async function listUsers(
params: ListAffiliateUsersParams = {},
): Promise<PaginatedResponse<AffiliateAdminEntry>> {
const { data } = await apiClient.get<PaginatedResponse<AffiliateAdminEntry>>(
'/admin/affiliates/users',
{
params: {
page: params.page ?? 1,
page_size: params.page_size ?? 20,
search: params.search ?? '',
},
},
)
return data
}
export async function lookupUsers(q: string): Promise<SimpleUser[]> {
const { data } = await apiClient.get<SimpleUser[]>(
'/admin/affiliates/users/lookup',
{ params: { q } },
)
return data
}
export async function updateUserSettings(
userId: number,
payload: UpdateAffiliateUserRequest,
): Promise<{ user_id: number }> {
const { data } = await apiClient.put<{ user_id: number }>(
`/admin/affiliates/users/${userId}`,
payload,
)
return data
}
export async function clearUserSettings(
userId: number,
): Promise<{ user_id: number }> {
const { data } = await apiClient.delete<{ user_id: number }>(
`/admin/affiliates/users/${userId}`,
)
return data
}
export async function batchSetRate(
payload: BatchSetRateRequest,
): Promise<{ affected: number }> {
const { data } = await apiClient.post<{ affected: number }>(
'/admin/affiliates/users/batch-rate',
payload,
)
return data
}
export const affiliatesAPI = {
listUsers,
lookupUsers,
updateUserSettings,
clearUserSettings,
batchSetRate,
}
export default affiliatesAPI

View File

@@ -29,6 +29,7 @@ import channelsAPI from './channels'
import channelMonitorAPI from './channelMonitor'
import channelMonitorTemplateAPI from './channelMonitorTemplate'
import adminPaymentAPI from './payment'
import affiliatesAPI from './affiliates'
/**
* Unified admin API object for convenient access
@@ -59,7 +60,8 @@ export const adminAPI = {
channels: channelsAPI,
channelMonitor: channelMonitorAPI,
channelMonitorTemplate: channelMonitorTemplateAPI,
payment: adminPaymentAPI
payment: adminPaymentAPI,
affiliates: affiliatesAPI
}
export {
@@ -88,7 +90,8 @@ export {
channelsAPI,
channelMonitorAPI,
channelMonitorTemplateAPI,
adminPaymentAPI
adminPaymentAPI,
affiliatesAPI
}
export default adminAPI

View File

@@ -308,6 +308,10 @@ export interface SystemSettings {
totp_encryption_key_configured: boolean; // TOTP 加密密钥是否已配置
// Default settings
default_balance: number;
affiliate_rebate_rate: number;
affiliate_rebate_freeze_hours: number;
affiliate_rebate_duration_days: number;
affiliate_rebate_per_invitee_cap: number;
default_concurrency: number;
default_user_rpm_limit: number;
default_subscriptions: DefaultSubscriptionSetting[];
@@ -477,6 +481,9 @@ export interface SystemSettings {
// Available Channels feature switch
available_channels_enabled: boolean;
// Affiliate (邀请返利) feature switch
affiliate_enabled: boolean;
}
export interface UpdateSettingsRequest {
@@ -489,6 +496,10 @@ export interface UpdateSettingsRequest {
invitation_code_enabled?: boolean;
totp_enabled?: boolean; // TOTP 双因素认证
default_balance?: number;
affiliate_rebate_rate?: number;
affiliate_rebate_freeze_hours?: number;
affiliate_rebate_duration_days?: number;
affiliate_rebate_per_invitee_cap?: number;
default_concurrency?: number;
default_user_rpm_limit?: number;
default_subscriptions?: DefaultSubscriptionSetting[];
@@ -634,6 +645,9 @@ export interface UpdateSettingsRequest {
// Available Channels feature switch
available_channels_enabled?: boolean;
// Affiliate (邀请返利) feature switch
affiliate_enabled?: boolean;
}
/**

Some files were not shown because too many files have changed in this diff Show More