Compare commits

...

52 Commits

Author SHA1 Message Date
shaw
8bf2a7b88a fix(scheduler): resolve SetSnapshot race conditions and remove usage throttle
Backend: Fix three race conditions in SetSnapshot that caused account
scheduling anomalies and broken sticky sessions:
- Use Lua CAS script for atomic version activation, preventing version
  rollback when concurrent goroutines write snapshots simultaneously
- Add UnlockBucket to release rebuild lock immediately after completion
  instead of waiting 30s TTL expiry
- Replace immediate DEL of old snapshots with 60s EXPIRE grace period,
  preventing readers from hitting empty ZRANGE during version switches

Frontend: Remove serial queue throttle (1-2s delay per request) from
usage loading since backend now uses passive sampling. All usage
requests execute immediately in parallel.
2026-04-29 22:48:39 +08:00
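The atomic activation rule described above can be sketched in pure Go. The actual fix runs an equivalent compare-and-set as a Lua script inside Redis; `versionGate` and `Activate` here are illustrative names, not the project's API:

```go
package main

import (
	"fmt"
	"sync"
)

// versionGate models the CAS rule the Lua script enforces: a snapshot
// version may only be activated if it is strictly newer than the
// currently active one, so a slow goroutine finishing late cannot roll
// the active pointer back to an older snapshot.
type versionGate struct {
	mu     sync.Mutex
	active int64
}

// Activate returns true only when v supersedes the active version.
func (g *versionGate) Activate(v int64) bool {
	g.mu.Lock()
	defer g.mu.Unlock()
	if v <= g.active {
		return false // stale writer: keep the newer snapshot
	}
	g.active = v
	return true
}

func main() {
	g := &versionGate{}
	fmt.Println(g.Activate(2)) // true: first activation
	fmt.Println(g.Activate(1)) // false: stale writer rejected
	fmt.Println(g.Activate(3)) // true: newer version wins
}
```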
shaw
40feb86ba4 fix(httputil): add decompression bomb guard and fix errcheck lint 2026-04-29 22:11:45 +08:00
Wesley Liddick
f972a2faf2 Merge pull request #1990 from haha1903/feat/zstd-request-decompression
feat(httputil): decode zstd/gzip/deflate request bodies
2026-04-29 22:08:28 +08:00
Wesley Liddick
55a7fa1e07 Merge pull request #2005 from gaoren002/pr/openai-strip-passthrough-fields
fix(openai): strip unsupported passthrough fields
2026-04-29 21:46:19 +08:00
shaw
5e54d492be fix(lint): check type assertion error in codex transform test
The errcheck linter flagged an unchecked type assertion on
item["type"].(string). Use the two-value form with require.True
to satisfy the linter and fail clearly on unexpected types.
2026-04-29 21:35:18 +08:00
Wesley Liddick
8d6d31545f Merge pull request #2068 from Ogannesson/fix/openai-drop-reasoning-items-from-input
fix(openai): drop reasoning items from /v1/responses input on OAuth path
2026-04-29 21:32:52 +08:00
Wesley Liddick
17ced6b73a Merge pull request #2027 from hansnow/codex/fix-api-key-rate-limit-reset
fix(api-key): reset rate limit usage cache
2026-04-29 21:27:52 +08:00
Wesley Liddick
7f8f3fe0dd Merge pull request #2100 from KnowSky404/fix/codex-cli-edit-resend-tool-continuation
[codex] fix WS continuation inference for explicit tool replay
2026-04-29 21:14:55 +08:00
Wesley Liddick
46f06b2498 Merge pull request #2050 from zvensmoluya/fix/openai-compact-payload-fields
fix(openai): preserve current Codex compact payload fields
2026-04-29 21:03:48 +08:00
shaw
7ce5b83215 chore: remove superpowers docs 2026-04-29 21:00:30 +08:00
Wesley Liddick
27cad10d30 Merge pull request #2030 from KnowSky404/feature/account-bulk-edit-scope-and-compact
feat: support filtered account bulk edit and align compact OpenAI bulk fields
2026-04-29 20:56:43 +08:00
Wesley Liddick
ff6fa0203d Merge pull request #2058 from ivanvolt-labs/fix-responses-function-tool-choice
fix: use Responses-compatible function tool_choice format
2026-04-29 20:43:43 +08:00
KnowSky404
f7c13af11f fix: format ingress continuation test 2026-04-29 18:02:19 +08:00
KnowSky404
28dc34b6a3 fix(openai): avoid inferred WS continuation on explicit tool replay 2026-04-29 17:38:08 +08:00
Wesley Liddick
4d676dddd1 Merge pull request #2066 from alfadb/fix/anthropic-stream-eof-failover
fix(gateway): Anthropic streaming EOF failover handoff + SSE error frame standardization
2026-04-29 17:09:47 +08:00
shaw
93d91e20b9 fix(vertex): audit fixes for Vertex Service Account feature (#1977)
- Security: force token_uri to Google default, preventing SSRF via crafted service account JSON
- Dedup: extract shared getVertexServiceAccountAccessToken() to eliminate ~35 lines of duplication between ClaudeTokenProvider and GeminiTokenProvider
- Fix: apply model mapping + Vertex model ID normalization in forward_as_responses and forward_as_chat_completions paths
- Fix: exclude service_account from AI Studio endpoint selection (Vertex cannot serve generativelanguage.googleapis.com)
- Feature: add model restriction/mapping UI for service_account in EditAccountModal
- Dedup: extract VERTEX_LOCATION_OPTIONS to shared constants
- i18n: replace all hardcoded Chinese strings in Vertex UI with translation keys
2026-04-29 16:53:09 +08:00
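The SSRF hardening in the first bullet can be sketched as below; `forceTokenURI` is an illustrative helper, not the project's function, but the rule is the one described: whatever `token_uri` a crafted service-account JSON carries, the token exchange always targets Google's default endpoint.

```go
package main

import "fmt"

const googleTokenURI = "https://oauth2.googleapis.com/token"

// forceTokenURI overwrites any attacker-supplied token_uri with the
// Google default, so a crafted service-account JSON cannot redirect
// the OAuth token exchange to an internal address (SSRF).
func forceTokenURI(sa map[string]string) string {
	sa["token_uri"] = googleTokenURI
	return sa["token_uri"]
}

func main() {
	crafted := map[string]string{"token_uri": "http://169.254.169.254/latest"}
	fmt.Println(forceTokenURI(crafted)) // https://oauth2.googleapis.com/token
}
```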
Wesley Liddick
63ef23108c Merge pull request #1977 from sholiverlee/vertex
feat: support Vertex Service Account (Anthropic / Gemini)
2026-04-29 15:48:26 +08:00
alfadb
d78478e866 fix(gateway): sanitize stream errors to avoid leaking infrastructure topology
(*net.OpError).Error() concatenates Source/Addr fields, so the previous
disconnectMsg surfaced internal source IP/port and upstream server address
to clients via SSE error frames and UpstreamFailoverError.ResponseBody
(reported by @Wei-Shaw on PR #2066).

- Add sanitizeStreamError that maps known errors (io.ErrUnexpectedEOF,
  context.Canceled, syscall.ECONNRESET/EPIPE/ETIMEDOUT/...) to fixed
  descriptions and falls back to a generic placeholder, with an explicit
  *net.OpError branch that drops Source/Addr fields entirely.
- Use sanitized message in client-facing disconnectMsg; full ev.err is
  still preserved in the existing operator log line for diagnosis.
- Tests cover net.OpError redaction, the failover ResponseBody path, and
  every known sanitized error mapping.
2026-04-29 15:44:54 +08:00
Wesley Liddick
bf43fb4e38 Merge pull request #2044 from VitalyAnkh/fix/openai-image-apikey-versioned-base-url
fix(openai): honor versioned image base URLs
2026-04-29 15:20:14 +08:00
Wesley Liddick
a16c66500f Merge pull request #2090 from touwaeriol/feat/ops-retention-zero
feat(ops): allow retention days = 0 to wipe table on each scheduled cleanup
2026-04-29 15:12:30 +08:00
erio
4b6954f9f0 feat(ops): allow retention days = 0 to wipe table on each scheduled cleanup
Background

The ops cleanup task currently rejects retention days < 1 in both validate
and normalize, so operators who want minimal-history setups (e.g. high-churn
deployments that prefer near-realtime cleanup) cannot express that
intent through the UI. The only options are 1+ days, which keeps at least
24h of history regardless of cron frequency.

Purpose

Let admins set retention days to 0, meaning "every scheduled cleanup
run wipes the corresponding table(s) entirely". Combined with a more
frequent cron (e.g. `0 * * * *`) this yields effectively rolling cleanup.

Changes

Backend

- service/ops_settings.go: validate accepts [0, 365]; normalize only
  refills default 30 when value is < 0 (negative is treated as legacy
  bad data, 0 is honoured)
- service/ops_cleanup_service.go: introduce `opsCleanupPlan(now, days)`
  returning `(cutoff, truncate, ok)`. days==0 returns truncate=true and
  short-circuits to a new `truncateOpsTable` helper that uses
  `TRUNCATE TABLE` (O(1), no WAL, no VACUUM pressure). days>0 keeps
  the existing batched DELETE path unchanged. Empty tables skip
  TRUNCATE to avoid the ACCESS EXCLUSIVE lock entirely
- Extract `isMissingRelationError` helper to dedupe the "table not
  yet created" tolerance shared by both delete and truncate paths
- Add unit tests for `opsCleanupPlan` (three branches) and
  `isMissingRelationError`

Frontend

- OpsSettingsDialog.vue: validation accepts [0, 365]; input min=0
- i18n (zh/en): hint mentions "0 = wipe all on every cleanup",
  validation message updated to 0-365 range

Trade-offs

- TRUNCATE requires ACCESS EXCLUSIVE lock briefly, but ops tables only
  have the cleanup task as a writer, so the lock is invisible to other
  workloads
- Empty-table guard avoids the lock when there is nothing to clean
- Negative values are still treated as legacy bad data and replaced
  with default 30 to preserve compatibility
2026-04-29 15:01:02 +08:00
shaw
da4b078df2 chore: update sponsors 2026-04-29 14:41:35 +08:00
Oganneson
7452fad820 fix(openai): drop reasoning items from /v1/responses input on OAuth path
Closes #1957

The OAuth path forwards client requests to chatgpt.com/backend-api/codex/responses,
where applyCodexOAuthTransform forces store=false (chatgpt.com's codex backend
rejects store=true). Reasoning items emitted under store=false are NEVER
persisted upstream, so any rs_* reference that a client carries forward in a
subsequent input[] array triggers a guaranteed upstream 404:

    Item with id 'rs_...' not found. Items are not persisted when `store` is
    set to false. Try again with `store` set to true, or remove this item
    from your input.

sub2api wraps this as 502 "Upstream request failed" and the conversation
breaks on every multi-turn /v1/responses request that uses reasoning + tools
(reproducible with gpt-5.5; gpt-5.4 happens to dodge it because the upstream
does not emit reasoning items for that model).

Affected clients include any that follow the OpenAI Responses API spec and
replay prior assistant items verbatim — in practice this hit OpenClaw and
similar agent harnesses on every turn ≥2 with tool use.

The fix: in filterCodexInput, drop input items with type == "reasoning"
entirely. The model never reads reasoning summary text from input (only
encrypted_content can carry reasoning context across turns, and chatgpt.com
under store=false does not emit it), so this is a no-op for the model itself
and a clean removal of unreachable upstream lookups.

Scope is intentionally narrow:
  * Only OAuth account requests (account.Type == AccountTypeOAuth) reach
    applyCodexOAuthTransform / filterCodexInput.
  * API-key accounts going to api.openai.com/v1/responses are unaffected
    (store=true works there, rs_* persists, multi-turn already works).
  * Anthropic / Gemini platform groups go through different transforms and
    are unaffected.
  * /v1/chat/completions is unaffected (no reasoning items).
  * item_reference items (different type) are unaffected — only type ==
    "reasoning" is dropped.

Verification:
  * Existing tests pass: go test ./internal/service/ -run Codex|Tool|OAuth
  * New regression test asserts reasoning items are dropped under both
    preserveReferences=true and preserveReferences=false.
  * End-to-end repro on gpt-5.5 multi-turn + tools: pre-patch 502, post-patch
    200. Repro on gpt-5.4 unchanged. Three-turn deep loop on gpt-5.5 passes.
2026-04-28 20:36:50 +08:00
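The filtering rule in `filterCodexInput` can be sketched as below. `dropReasoningItems` is an illustrative name for the relevant piece of logic, not the project's function: only items whose `type` is exactly `"reasoning"` are removed, leaving `item_reference` and everything else untouched.

```go
package main

import "fmt"

// dropReasoningItems removes every input item with type == "reasoning".
// Under store=false these items are never persisted upstream, so any
// rs_* reference replayed later would trigger a guaranteed 404.
func dropReasoningItems(input []map[string]any) []map[string]any {
	kept := input[:0]
	for _, item := range input {
		if t, _ := item["type"].(string); t == "reasoning" {
			continue // unreachable upstream lookup; safe to drop
		}
		kept = append(kept, item)
	}
	return kept
}

func main() {
	input := []map[string]any{
		{"type": "message", "role": "user"},
		{"type": "reasoning", "id": "rs_abc"},
		{"type": "item_reference", "id": "msg_1"},
	}
	for _, item := range dropReasoningItems(input) {
		fmt.Println(item["type"])
	}
	// message
	// item_reference
}
```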
alfadb
4c474616b9 fix(gateway): emit Anthropic-standard SSE error events and failover body
Two follow-ups to PR #2066's failover-wrap fix:

1. Failover ResponseBody (`UpstreamFailoverError.ResponseBody`) was encoded
   as `{"error": "<msg>"}` (string field). `ExtractUpstreamErrorMessage`
   probes for `error.message`, `detail`, or top-level `message` only — so
   `handleFailoverExhausted` and downstream passthrough rules saw an empty
   message, losing the EOF root cause in ops logs. Re-encode as the
   Anthropic standard shape `{"type":"error","error":{"type":"upstream_disconnected","message":"..."}}`.
   (Addresses the inline review comment from copilot-pull-request-reviewer
   on Wei-Shaw/sub2api#2066.)

2. The streaming `event: error` SSE frame for `response_too_large`,
   `stream_read_error`, and `stream_timeout` was non-standard
   (`{"error":"<reason>"}`). Anthropic SDKs (and Claude Code) expect
   `{"type":"error","error":{"type":"...","message":"..."}}` and parse
   `error.type`/`error.message` accordingly. Refactor `sendErrorEvent` to
   take both reason and message, and emit the standard frame so client
   SDKs surface a real diagnostic message instead of a generic stream error.

This does not by itself prevent task interruption on long-stream EOF
(SSE has no resume; client-side retry remains the only complete fix), but
it gives both server-side ops logs and client-side error UIs a meaningful
upstream message so users know the next step is to retry.

Tests updated to assert the new body shape on both branches plus a new
assertion that `ExtractUpstreamErrorMessage` returns a non-empty string.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 20:24:17 +08:00
alfadb
6327573534 fix(gateway): wrap Anthropic stream EOF as failover error before client output
Anthropic streaming path (gateway_service.go) returned a plain error on
upstream SSE read failure, so the handler-level UpstreamFailoverError check
never fired and the client received a bare `stream_read_error` event,
breaking long-running tasks even when no bytes had been written yet.

The most common trigger is HTTP/2 GOAWAY from api.anthropic.com edge
backends doing graceful rotation: Go's http.Transport surfaces this as
`unexpected EOF` and never auto-retries.

Mirror what the OpenAI and antigravity gateways already do: when the read
error happens before any byte has reached the client (`!c.Writer.Written()`),
return `*UpstreamFailoverError{StatusCode: 502, RetryableOnSameAccount: true}`
so the handler can retry on the same or another account. After client
output has begun, SSE has no resume protocol — keep the existing passthrough
behavior.

Tests cover both branches via streamReadCloser-based fixtures.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 19:12:48 +08:00
ivanvolt
04b2866f65 fix: use Responses-compatible function tool_choice format 2026-04-28 16:26:09 +08:00
Wesley Liddick
b0a2252ed1 Merge pull request #2051 from DaydreamCoding/openai-fast-flex-policy
feat(openai): full OpenAI Fast/Flex Policy implementation (HTTP + WebSocket + Admin)
2026-04-28 12:14:43 +08:00
DaydreamCoding
30f55a1f72 feat(openai): full OpenAI Fast/Flex Policy implementation (HTTP + WebSocket + Admin)
Mirroring the Claude BetaPolicy fast-mode filtering implementation, add a
three-state pass / filter / block policy for the OpenAI upstream service_tier
field (priority / flex, including client "fast" → "priority" normalization),
covering every OpenAI entry point plus the admin configuration entry.

Backend core
- Add the SettingKeyOpenAIFastPolicySettings, OpenAIFastPolicyRule, and
  OpenAIFastPolicySettings config models, with rule dimensions
  service_tier × action × scope × model whitelist × fallback action.
- SettingService.Get/SetOpenAIFastPolicySettings; when the setting is
  missing, return a built-in default policy (priority is filtered for all
  models, whitelist empty, fallback=pass). Rationale: service_tier=fast is
  a user-level switch orthogonal to the model field, so locking the default
  to specific model slugs would leave a bypass path of "use gpt-4 + fast to
  pass priority through to the upstream". JSON parse failures no longer fall
  back silently; slog.Warn logs the dirty data so operators can locate it.
- service_tier normalization (trim + ToLower + fast→priority + priority/flex
  whitelist) and policy evaluation (evaluateOpenAIFastPolicy) serve as the
  single source of truth, shared by HTTP / WS. Extract the pure function
  evaluateOpenAIFastPolicyWithSettings, paired with a ctx-bound settings
  snapshot (withOpenAIFastPolicyContext / openAIFastPolicySettingsFromContext);
  the WS long-session entry prefetches once and all frames reuse it,
  avoiding a settingService hit per frame.

HTTP entry points (4)
- Chat Completions, Anthropic-compatible (Messages, including a second
  BetaFastMode→priority hit), native Responses, and Passthrough Responses
  all go through applyOpenAIFastPolicyToBody; filter deletes the top-level
  service_tier via sjson, block returns a 403 forbidden_error JSON.
- All 4 entry points uniformly use the upstream-view model (the slug after
  GetMappedModel + normalizeOpenAIModelForUpstream + Codex OAuth
  normalization), avoiding whitelist hit differences between
  chat / messages / native /responses / passthrough caused by differing
  model dimensions.
- On the pass path, the client "fast" alias is also normalized to
  "priority" and written back into the body; otherwise the native
  /responses and passthrough entries would forward "fast" verbatim to the
  upstream, causing 400s/rejections (the chat-completions entry's
  normalizeResponsesBodyServiceTier already behaved this way).

WebSocket entry
- Add applyOpenAIFastPolicyToWSResponseCreate: strictly matches
  type="response.create" and handles only the top-level service_tier;
  filter deletes the field via sjson, block returns a typed
  *OpenAIFastBlockedError.
- The ingress path invokes it inside parseClientPayload; on a block hit it
  first writes a Realtime-style error event and then returns
  OpenAIWSClientCloseError (StatusPolicyViolation=1008), relying on the
  underlying WebSocket Conn.Write's synchronous flush to guarantee the
  error arrives before the close.
- The passthrough path applies the policy to firstClientMessage before
  RunEntry, and wraps ReadFrame with openAIWSPolicyEnforcingFrameConn to
  enforce the policy on every client→upstream frame; subsequent frames
  without a model field fall back to capturedSessionModel. The filter
  closure also watches the session.model field of session.update /
  session.created frames to refresh capturedSessionModel, closing the
  mid-session bypass path of "first frame model=gpt-4o (pass) →
  session.update switches to gpt-5.5 → a model-less response.create falls
  back to gpt-4o".
- Passthrough billing: requestServiceTier is extracted from
  firstClientMessage only after the policy filter; on a filter hit,
  OpenAIForwardResult.ServiceTier reports nil (default tier), consistent
  with the semantics of the HTTP entries (reqBody comes from the
  post-filter map) / WS ingress (payload comes from the post-filter bytes).
- Error event schema: {event_id: "evt_<32hex>", type: "error",
  error: {type: "forbidden_error", code: "policy_violation", message}},
  compatible with OpenAI codex client error-event parsing.

Admin / Frontend
- dto.SystemSettings / UpdateSettingsRequest gain an
  openai_fast_policy_settings field (omitempty), wired into bulk GET/PUT.
- The Settings page's Gateway tab gains a Fast/Flex Policy form card:
  full configuration across service_tier × action × scope × model
  whitelist × fallback action.
- Frontend guard: the openaiFastPolicyLoaded flag permits write-back only
  when the GET actually returns the field, preventing rollouts or errors
  from overwriting the default rules with an empty set; the saveSettings
  write-back loop skips this field, which is handled by dedicated refresh
  logic; error_message is sent only when action=block, matching the
  backend's omitempty behavior.

Tests
- HTTP path: openai_fast_policy_test.go covers the default config
  (whitelist=[], priority filtered for all models) / block with a custom
  error / scope distinction / filter deletes the field / block leaves the
  body untouched / block short-circuits the upstream / Anthropic
  BetaFastMode triggering the OpenAI fast policy.
- WebSocket path: openai_fast_policy_ws_test.go covers
    helper units (filter / fast→priority normalization / flex passthrough /
    block typed error / bytes unchanged when service_tier is absent /
    non-response.create frames untouched / empty-type frames untouched /
    event_id+code field assertions / non-string service_tier tolerance) +
    a pass-path fast alias normalization regression +
    ingress end-to-end (upstream carries no service_tier after filter / on
    block the client receives the error event first, then close 1008, with
    zero upstream writes) +
    passthrough capturedSessionModel fallback cases (established on the
    first frame under a whitelist policy, fallback hit when model is
    missing, documented leak when fallback is absent) +
    a passthrough session.update / session.created capturedSessionModel
    rotation mid-session bypass regression +
    passthrough billing post-filter ServiceTier and idempotent filter
    regressions.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 11:15:09 +08:00
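The service_tier normalization described in that commit can be sketched as below. `normalizeServiceTier` is an illustrative stand-in, but the rule is the one stated: trim + lowercase, map the client "fast" alias to "priority", and accept only the priority/flex whitelist.

```go
package main

import (
	"fmt"
	"strings"
)

// normalizeServiceTier canonicalizes the client-supplied service_tier:
// trim + lowercase, "fast" aliases to "priority", and only the
// priority/flex whitelist survives (anything else -> empty string).
func normalizeServiceTier(raw string) string {
	t := strings.ToLower(strings.TrimSpace(raw))
	if t == "fast" {
		t = "priority"
	}
	switch t {
	case "priority", "flex":
		return t
	}
	return ""
}

func main() {
	fmt.Println(normalizeServiceTier("  Fast "))  // priority
	fmt.Println(normalizeServiceTier("FLEX"))     // flex
	fmt.Println(normalizeServiceTier("standard")) // (empty line)
}
```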
Zven
3d4ca5e8d1 fix(openai): preserve current Codex compact payload fields 2026-04-28 10:55:29 +08:00
Oliver Li
0537a490f0 Merge branch 'Wei-Shaw:main' into vertex 2026-04-27 20:25:11 -04:00
VitalyR
ca5d029e7c fix(openai): honor versioned image base URLs 2026-04-28 04:53:29 +08:00
KnowSky404
1eca03432a fix: format bulk update account request 2026-04-27 18:36:05 +08:00
KnowSky404
53b24bc2d8 fix: tighten account bulk edit target typing 2026-04-27 18:20:36 +08:00
KnowSky404
a161f9d045 feat: align OpenAI bulk edit compact settings 2026-04-27 18:15:23 +08:00
KnowSky404
c5a1a82223 test: cover missing OpenAI bulk edit fields 2026-04-27 18:13:14 +08:00
KnowSky404
2ab6b34fd1 feat: add filtered-result account bulk edit 2026-04-27 18:12:24 +08:00
KnowSky404
764afbe37a test: cover account bulk edit target scopes 2026-04-27 18:08:22 +08:00
KnowSky404
25c7b0d9f4 feat: support filter-target account bulk update 2026-04-27 17:59:49 +08:00
KnowSky404
f422ac6dcc test: cover filter-target account bulk update 2026-04-27 17:32:34 +08:00
KnowSky404
54de4e008c docs: add account bulk edit implementation plan 2026-04-27 17:26:57 +08:00
KnowSky404
65c27d2c69 docs: add account bulk edit scope design 2026-04-27 17:21:11 +08:00
hansnow
53f919f8f0 fix(api-key): reset rate limit usage cache 2026-04-27 16:47:44 +08:00
Wesley Liddick
c92b88e34a Merge pull request #1996 from Cloud370/fix/claude-code-read-empty-pages
fix(anthropic): drop empty Read.pages in responses-to-anthropic tool input
2026-04-27 08:47:13 +08:00
Wesley Liddick
ed0c85a17e Merge pull request #2006 from gaoren002/pr/openai-images-explicit-session
fix(openai): avoid implicit image sticky sessions
2026-04-27 08:43:40 +08:00
gaoren002
9fe02bba7e fix(openai): strip unsupported passthrough fields 2026-04-27 00:39:06 +00:00
gaoren002
615557ec20 fix(openai): avoid implicit image sticky sessions 2026-04-26 17:09:41 +00:00
Oliver Li
3f05ef2ae3 Merge branch 'Wei-Shaw:main' into vertex 2026-04-26 08:39:41 -04:00
Cloud370
3022090365 fix(anthropic): drop empty Read.pages in responses-to-anthropic tool input 2026-04-26 20:21:38 +08:00
Hai Chang
798fd673e9 feat(httputil): decode compressed request bodies (zstd/gzip/deflate)
Codex CLI 0.125+ defaults to sending request bodies with
Content-Encoding: zstd. Without server-side decompression the gateway
returns 'Failed to parse request body' on /v1/responses (and any other
JSON endpoint) because gjson sees raw zstd bytes.

ReadRequestBodyWithPrealloc now inspects Content-Encoding and
transparently decodes zstd, gzip/x-gzip, and deflate bodies before
returning them, then strips the encoding headers and updates
ContentLength so downstream code can reuse the bytes safely.
Unsupported encodings produce a clear error.

Adds unit tests covering identity, zstd, gzip, deflate, unsupported
encoding, corrupt zstd payloads, nil bodies, and explicit identity.
2026-04-26 20:52:45 +10:00
github-actions[bot]
c056db740d chore: sync VERSION to 0.1.119 [skip ci] 2026-04-26 05:24:11 +00:00
Oliver
6d11f9ed77 Add Vertex service account support 2026-04-25 20:39:58 -04:00
Oliver
489a4d934e Show today stats for Vertex usage window 2026-04-25 19:46:32 -04:00
91 changed files with 7000 additions and 341 deletions


@@ -101,6 +101,13 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
 <td>Thanks to Bestproxy for sponsoring this project! <a href="https://bestproxy.com/?keyword=a2e8iuol">Bestproxy</a> provides high-purity residential IPs with dedicated one-IP-per-account support. By combining real home networks with fingerprint isolation, it enables link environment isolation and reduces the probability of association-based risk control.</td>
 </tr>
+<tr>
+<td width="180"><a href="https://pateway.ai/?ch=1tsfr51"><img src="assets/partners/logos/pateway.png" alt="pateway" width="150"></a></td>
+<td>Thanks to PatewayAI for sponsoring this project! PatewayAI is a premium model API relay service provider built for heavy AI developers, focused on direct official connections. Offering the full Claude series and Codex series models, 100% sourced directly from official providers — no dilution, no substitution, open to verification. Billing is fully transparent with token-level invoices that can be audited line by line.
+Enterprise-grade high concurrency is also supported, with a dedicated management platform for enterprise clients. Enterprise customers can sign formal contracts and receive invoices. Visit the official website for more details and contact information.
+Register now via <a href="https://pateway.ai/?ch=1tsfr51">this link</a> to receive $3 in trial credits. User top-ups start as low as 60% off, and referring friends earns both parties rewards — referral bonuses up to $150.</td>
+</tr>
 </table>
 ## Ecosystem


@@ -100,6 +100,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
 <td>感谢 Bestproxy 赞助了本项目!<a href="https://bestproxy.com/?keyword=a2e8iuol">Bestproxy</a> 是一家提供高纯度住宅IP支持一号一IP独享结合真实家庭网络与指纹隔离可实现链路环境隔离降低关联风控概率。</td>
 </tr>
+<tr>
+<td width="180"><a href="https://pateway.ai/?ch=1tsfr51"><img src="assets/partners/logos/pateway.png" alt="pateway" width="150"></a></td>
+<td>感谢 PatewayAI 赞助了本项目PatewayAI 是一家面向重度 AI 开发者、专注官方直连的高品质模型 API 中转服务商。提供 Claude 全系列与 Codex 系列模型100% 官方源直供不掺假不注水欢迎检验。计费透明Token 级账单可逐笔核验。
+同时支持企业级高并发,并为企业客户提供了专业的管理平台,企业客户可签订正式合同并开具发票,更多详情进入官网获取联系方式。
+现在通过 <a href="https://pateway.ai/?ch=1tsfr51">此链接</a> 注册即送 $3 试用额度,用户充值低至 6 折,邀请好友双向赠送,邀请奖励可达 $150。</td>
+</tr>
 </table>
 ## 生态项目


@@ -100,6 +100,13 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
 <td>Bestproxy のご支援に感謝します!<a href="https://bestproxy.com/?keyword=a2e8iuol">Bestproxy</a> は高純度の住宅IPを提供し、1アカウント1IP専有をサポートしています。実際の家庭ネットワークとフィンガープリント分離を組み合わせることで、リンク環境の分離を実現し、関連付けによるリスク管理の確率を低減します。</td>
 </tr>
+<tr>
+<td width="180"><a href="https://pateway.ai/?ch=1tsfr51"><img src="assets/partners/logos/pateway.png" alt="pateway" width="150"></a></td>
+<td>PatewayAI のご支援に感謝しますPatewayAI は、ヘビーAI開発者向けに公式直結を重視した高品質モデルAPIリレーサービスプロバイダーです。Claude 全シリーズおよび Codex シリーズモデルを提供し、100%公式ソースから直接供給 — 偽りなし、水増しなし、検証歓迎。課金は完全透明で、トークン単位の請求書を1件ずつ監査可能です。
+エンタープライズ級の高同時接続にも対応し、法人顧客向けに専用管理プラットフォームを提供しています。法人顧客は正式な契約を締結し、請求書の発行が可能です。詳細は公式サイトでお問い合わせください。
+<a href="https://pateway.ai/?ch=1tsfr51">こちらのリンク</a>から登録すると、$3 のトライアルクレジットがもらえます。チャージは最大40%オフ、友達紹介で双方にボーナス付与 — 紹介報酬は最大 $150。</td>
+</tr>
 </table>
 ## エコシステム

Binary file not shown (new image, 8.0 KiB)


@@ -1 +1 @@
-0.1.118
+0.1.119


@@ -65,7 +65,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userGroupRateRepository := repository.NewUserGroupRateRepository(db) userGroupRateRepository := repository.NewUserGroupRateRepository(db)
billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig) billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig)
apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyCache := repository.NewAPIKeyCache(redisClient)
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig) apiKeyService := service.ProvideAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig, billingCacheService)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
@@ -145,13 +145,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService) accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache) oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI) geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
gatewayCache := repository.NewGatewayCache(redisClient) gatewayCache := repository.NewGatewayCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache) antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
internal500CounterCache := repository.NewInternal500CounterCache(redisClient) internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache) antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService) accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService) adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
@@ -178,7 +179,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
-claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
digestSessionStore := service.NewDigestSessionStore()
channelRepository := repository.NewChannelRepository(db)
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
@@ -186,7 +186,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
-openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService)
+openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)

View File

@@ -26,11 +26,12 @@ const (
// Account type constants
const (
AccountTypeOAuth = "oauth" // OAuth account (full scope: profile + inference)
AccountTypeSetupToken = "setup-token" // Setup Token account (inference-only scope)
AccountTypeAPIKey = "apikey" // API key account
AccountTypeUpstream = "upstream" // upstream passthrough account (connects to the upstream via base URL + API key)
AccountTypeBedrock = "bedrock" // AWS Bedrock account (connects via SigV4 signing or API key, distinguished by credentials.auth_mode)
AccountTypeServiceAccount = "service_account" // Google Service Account (for Vertex AI)
)
// Redeem type constants

View File

@@ -98,7 +98,7 @@ type CreateAccountRequest struct {
Name string `json:"name" binding:"required"`
Notes *string `json:"notes"`
Platform string `json:"platform" binding:"required"`
-Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"`
+Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock service_account"`
Credentials map[string]any `json:"credentials" binding:"required"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
@@ -117,7 +117,7 @@ type CreateAccountRequest struct {
type UpdateAccountRequest struct {
Name string `json:"name"`
Notes *string `json:"notes"`
-Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"`
+Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock service_account"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
@@ -134,19 +134,29 @@ type UpdateAccountRequest struct {
// BulkUpdateAccountsRequest represents the payload for bulk editing accounts
type BulkUpdateAccountsRequest struct {
-AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
+AccountIDs []int64 `json:"account_ids"`
+Filters *BulkUpdateAccountFilters `json:"filters"`
Name string `json:"name"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
RateMultiplier *float64 `json:"rate_multiplier"`
LoadFactor *int `json:"load_factor"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
Schedulable *bool `json:"schedulable"`
GroupIDs *[]int64 `json:"group_ids"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // user confirms the mixed-channel risk
}
type BulkUpdateAccountFilters struct {
Platform string `json:"platform"`
Type string `json:"type"`
Status string `json:"status"`
Group string `json:"group"`
Search string `json:"search"`
PrivacyMode string `json:"privacy_mode"`
}
// CheckMixedChannelRequest represents a mixed-channel risk check request
@@ -1369,6 +1379,10 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
if len(req.AccountIDs) == 0 && req.Filters == nil {
response.BadRequest(c, "account_ids or filters is required")
return
}
// base_rpm input validation: negative values are zeroed, values above 10000 are capped
sanitizeExtraBaseRPM(req.Extra)
@@ -1394,6 +1408,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
result, err := h.adminService.BulkUpdateAccounts(c.Request.Context(), &service.BulkUpdateAccountsInput{
AccountIDs: req.AccountIDs,
Filters: toServiceBulkUpdateAccountFilters(req.Filters),
Name: req.Name,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
@@ -1429,6 +1444,20 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
response.Success(c, result)
}
func toServiceBulkUpdateAccountFilters(filters *BulkUpdateAccountFilters) *service.BulkUpdateAccountFilters {
if filters == nil {
return nil
}
return &service.BulkUpdateAccountFilters{
Platform: filters.Platform,
Type: filters.Type,
Status: filters.Status,
Group: filters.Group,
Search: filters.Search,
PrivacyMode: filters.PrivacyMode,
}
}
// ========== OAuth Handlers ==========
// GenerateAuthURLRequest represents the request for generating auth URL

View File

@@ -196,3 +196,29 @@ func TestAccountHandlerBulkUpdateMixedChannelConfirmSkips(t *testing.T) {
require.Equal(t, float64(2), data["success"])
require.Equal(t, float64(0), data["failed"])
}
func TestBulkUpdateAcceptsFilterTargetRequest(t *testing.T) {
adminSvc := newStubAdminService()
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"filters": map[string]any{
"platform": "openai",
"type": "oauth",
"status": "active",
"group": "12",
"privacy_mode": "blocked",
"search": "bulk-target",
},
"schedulable": true,
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, float64(0), resp["code"])
}

View File

@@ -8,6 +8,7 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
@@ -222,3 +223,66 @@ func TestOpsWSHelpers(t *testing.T) {
require.True(t, isAddrInTrustedProxies(addr, prefixes))
require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes))
}
// TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier verifies that the
// admin write path normalizes empty, whitespace-only, and mixed-case ServiceTier
// values to service.OpenAIFastTierAny ("all"), so that "" and "all" never persist
// as two spellings of the same meaning.
func TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier(t *testing.T) {
t.Run("nil input returns nil", func(t *testing.T) {
require.Nil(t, openaiFastPolicySettingsFromDTO(nil))
})
t.Run("empty service_tier becomes 'all'", func(t *testing.T) {
in := &dto.OpenAIFastPolicySettings{
Rules: []dto.OpenAIFastPolicyRule{{
ServiceTier: "",
Action: "filter",
Scope: "all",
}},
}
out := openaiFastPolicySettingsFromDTO(in)
require.NotNil(t, out)
require.Len(t, out.Rules, 1)
require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier)
require.Equal(t, "all", out.Rules[0].ServiceTier)
})
t.Run("whitespace-only service_tier becomes 'all'", func(t *testing.T) {
in := &dto.OpenAIFastPolicySettings{
Rules: []dto.OpenAIFastPolicyRule{{
ServiceTier: " ",
Action: "pass",
Scope: "all",
}},
}
out := openaiFastPolicySettingsFromDTO(in)
require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier)
})
t.Run("uppercase service_tier is lowercased", func(t *testing.T) {
in := &dto.OpenAIFastPolicySettings{
Rules: []dto.OpenAIFastPolicyRule{{
ServiceTier: "PRIORITY",
Action: "filter",
Scope: "all",
}},
}
out := openaiFastPolicySettingsFromDTO(in)
require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier)
})
t.Run("non-empty values pass through (lowercased)", func(t *testing.T) {
in := &dto.OpenAIFastPolicySettings{
Rules: []dto.OpenAIFastPolicyRule{
{ServiceTier: "priority", Action: "filter", Scope: "all"},
{ServiceTier: "flex", Action: "block", Scope: "oauth"},
{ServiceTier: "all", Action: "pass", Scope: "apikey"},
},
}
out := openaiFastPolicySettingsFromDTO(in)
require.Len(t, out.Rules, 3)
require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier)
require.Equal(t, service.OpenAIFastTierFlex, out.Rules[1].ServiceTier)
require.Equal(t, service.OpenAIFastTierAny, out.Rules[2].ServiceTier)
})
}

View File

@@ -565,6 +565,22 @@ func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
return nil, service.ErrAPIKeyNotFound
}
func (s *stubAdminService) AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*service.APIKey, error) {
for i := range s.apiKeys {
if s.apiKeys[i].ID == keyID {
s.apiKeys[i].Usage5h = 0
s.apiKeys[i].Usage1d = 0
s.apiKeys[i].Usage7d = 0
s.apiKeys[i].Window5hStart = nil
s.apiKeys[i].Window1dStart = nil
s.apiKeys[i].Window7dStart = nil
k := s.apiKeys[i]
return &k, nil
}
}
return nil, service.ErrAPIKeyNotFound
}
func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) error {
return nil
}

View File

@@ -22,12 +22,13 @@ func NewAdminAPIKeyHandler(adminService service.AdminService) *AdminAPIKeyHandle
}
}
-// AdminUpdateAPIKeyGroupRequest represents the request to update an API key's group
+// AdminUpdateAPIKeyGroupRequest represents the request to update an API key.
type AdminUpdateAPIKeyGroupRequest struct {
GroupID *int64 `json:"group_id"` // nil = leave unchanged, 0 = unbind, >0 = bind to the target group
ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // true = reset the 5h/1d/7d rate-limit usage
}
-// UpdateGroup handles updating an API key's group binding
+// UpdateGroup handles updating an API key's admin-managed fields.
// PUT /api/v1/admin/api-keys/:id
func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
@@ -42,11 +43,23 @@ func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
return
}
var resetKey *service.APIKey
if req.ResetRateLimitUsage != nil && *req.ResetRateLimitUsage {
resetKey, err = h.adminService.AdminResetAPIKeyRateLimitUsage(c.Request.Context(), keyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
}
result, err := h.adminService.AdminUpdateAPIKeyGroupID(c.Request.Context(), keyID, req.GroupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
if resetKey != nil && req.GroupID == nil {
result.APIKey = resetKey
}
resp := struct {
APIKey *dto.APIKey `json:"api_key"`

View File

@@ -8,6 +8,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -117,6 +118,45 @@ func TestAdminAPIKeyHandler_UpdateGroup_Unbind(t *testing.T) {
require.Nil(t, resp.Data.APIKey.GroupID)
}
func TestAdminAPIKeyHandler_ResetRateLimitUsage(t *testing.T) {
svc := newStubAdminService()
now := time.Now()
svc.apiKeys[0].Usage5h = 1.2
svc.apiKeys[0].Usage1d = 3.4
svc.apiKeys[0].Usage7d = 5.6
svc.apiKeys[0].Window5hStart = &now
svc.apiKeys[0].Window1dStart = &now
svc.apiKeys[0].Window7dStart = &now
router := setupAPIKeyHandler(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"reset_rate_limit_usage":true}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data struct {
APIKey struct {
Usage5h float64 `json:"usage_5h"`
Usage1d float64 `json:"usage_1d"`
Usage7d float64 `json:"usage_7d"`
Window5hStart *time.Time `json:"window_5h_start"`
Window1dStart *time.Time `json:"window_1d_start"`
Window7dStart *time.Time `json:"window_7d_start"`
} `json:"api_key"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Zero(t, resp.Data.APIKey.Usage5h)
require.Zero(t, resp.Data.APIKey.Usage1d)
require.Zero(t, resp.Data.APIKey.Usage7d)
require.Nil(t, resp.Data.APIKey.Window5hStart)
require.Nil(t, resp.Data.APIKey.Window1dStart)
require.Nil(t, resp.Data.APIKey.Window7dStart)
}
func TestAdminAPIKeyHandler_UpdateGroup_ServiceError(t *testing.T) {
svc := &failingUpdateGroupService{
stubAdminService: newStubAdminService(),

View File

@@ -248,9 +248,51 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
AffiliateEnabled: settings.AffiliateEnabled,
}
// OpenAI fast policy (stored under a dedicated setting key)
if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
slog.Error("openai_fast_policy_settings_get_failed", "error", err)
} else if fastPolicy != nil {
payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
}
response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
}
// openaiFastPolicySettingsToDTO converts service -> dto for OpenAI fast policy.
func openaiFastPolicySettingsToDTO(s *service.OpenAIFastPolicySettings) *dto.OpenAIFastPolicySettings {
if s == nil {
return nil
}
rules := make([]dto.OpenAIFastPolicyRule, len(s.Rules))
for i, r := range s.Rules {
rules[i] = dto.OpenAIFastPolicyRule(r)
}
return &dto.OpenAIFastPolicySettings{Rules: rules}
}
// openaiFastPolicySettingsFromDTO converts dto -> service for OpenAI fast policy.
//
// Normalize ServiceTier: before the DTO reaches the service layer, map the
// empty string to service.OpenAIFastTierAny ("all"), so an admin save cannot
// leave both "" and "all" in the database meaning "match any tier". Other
// non-empty values pass through unchanged; service.SetOpenAIFastPolicySettings
// is responsible for validating allowed values.
func openaiFastPolicySettingsFromDTO(s *dto.OpenAIFastPolicySettings) *service.OpenAIFastPolicySettings {
if s == nil {
return nil
}
rules := make([]service.OpenAIFastPolicyRule, len(s.Rules))
for i, r := range s.Rules {
rules[i] = service.OpenAIFastPolicyRule(r)
tier := strings.ToLower(strings.TrimSpace(rules[i].ServiceTier))
if tier == "" {
tier = service.OpenAIFastTierAny
}
rules[i].ServiceTier = tier
}
return &service.OpenAIFastPolicySettings{Rules: rules}
}
// UpdateSettingsRequest is the update-settings request.
type UpdateSettingsRequest struct {
// Registration settings
@@ -452,6 +494,9 @@ type UpdateSettingsRequest struct {
// Affiliate (referral rebate) feature switch
AffiliateEnabled *bool `json:"affiliate_enabled"`
// OpenAI fast/flex policy (optional, only updated when provided)
OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
}

// UpdateSettings updates the system settings.
@@ -1350,6 +1395,14 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return
}
// Update OpenAI fast policy (stored under dedicated key, only when provided).
if req.OpenAIFastPolicySettings != nil {
if err := h.settingService.SetOpenAIFastPolicySettings(c.Request.Context(), openaiFastPolicySettingsFromDTO(req.OpenAIFastPolicySettings)); err != nil {
response.BadRequest(c, err.Error())
return
}
}
// Update payment configuration (integrated into system settings).
// Skip if no payment fields were provided (prevents accidental wipe).
if h.paymentConfigService != nil && hasPaymentFields(req) {
@@ -1555,6 +1608,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
AffiliateEnabled: updatedSettings.AffiliateEnabled,
}
if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
slog.Error("openai_fast_policy_settings_get_failed", "error", err)
} else if fastPolicy != nil {
payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
}
response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
}

View File

@@ -26,7 +26,12 @@ func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service.
}
func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) {
-panic("unexpected GetValue call")
+if s.values != nil {
+if value, ok := s.values[key]; ok {
+return value, nil
+}
+}
+return "", nil
}
func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error {

View File

@@ -198,6 +198,9 @@ type SystemSettings struct {
// Affiliate (referral rebate) feature switch
AffiliateEnabled bool `json:"affiliate_enabled"`
// OpenAI fast/flex policy
OpenAIFastPolicySettings *OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
}

type DefaultSubscriptionSetting struct {
@@ -294,6 +297,22 @@ type BetaPolicySettings struct {
Rules []BetaPolicyRule `json:"rules"`
}
// OpenAIFastPolicyRule is the OpenAI fast/flex policy rule DTO.
type OpenAIFastPolicyRule struct {
ServiceTier string `json:"service_tier"`
Action string `json:"action"`
Scope string `json:"scope"`
ErrorMessage string `json:"error_message,omitempty"`
ModelWhitelist []string `json:"model_whitelist,omitempty"`
FallbackAction string `json:"fallback_action,omitempty"`
FallbackErrorMessage string `json:"fallback_error_message,omitempty"`
}
// OpenAIFastPolicySettings is the OpenAI fast policy configuration DTO.
type OpenAIFastPolicySettings struct {
Rules []OpenAIFastPolicyRule `json:"rules"`
}
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
// Returns empty slice on empty/invalid input.
func ParseCustomMenuItems(raw string) []CustomMenuItem {

View File

@@ -50,6 +50,9 @@ func (f *fakeSchedulerCache) UpdateLastUsed(_ context.Context, _ map[int64]time.
func (f *fakeSchedulerCache) TryLockBucket(_ context.Context, _ service.SchedulerBucket, _ time.Duration) (bool, error) {
return true, nil
}
func (f *fakeSchedulerCache) UnlockBucket(_ context.Context, _ service.SchedulerBucket) error {
return nil
}
func (f *fakeSchedulerCache) ListBuckets(_ context.Context) ([]service.SchedulerBucket, error) {
return nil, nil
}

View File

@@ -117,12 +117,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
return
}
-sessionHash := ""
-if parsed.Multipart {
-sessionHash = h.gatewayService.GenerateSessionHashWithFallback(c, nil, parsed.StickySessionSeed())
-} else {
-sessionHash = h.gatewayService.GenerateSessionHash(c, body)
-}
+sessionHash := h.gatewayService.GenerateExplicitSessionHash(c, body)
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0

View File

@@ -258,6 +258,48 @@ func TestResponsesToAnthropic_ToolUse(t *testing.T) {
assert.Equal(t, "tool_use", anth.Content[1].Type)
assert.Equal(t, "call_1", anth.Content[1].ID)
assert.Equal(t, "get_weather", anth.Content[1].Name)
assert.JSONEq(t, `{"city":"NYC"}`, string(anth.Content[1].Input))
}
func TestResponsesToAnthropic_ReadToolDropsEmptyPages(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_read",
Model: "gpt-5.5",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "function_call",
CallID: "call_read",
Name: "Read",
Arguments: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
},
},
}
anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
require.Len(t, anth.Content, 1)
assert.Equal(t, "tool_use", anth.Content[0].Type)
assert.JSONEq(t, `{"file_path":"/tmp/demo.py","limit":2000,"offset":0}`, string(anth.Content[0].Input))
}
func TestResponsesToAnthropic_PreservesEmptyStringsForOtherTools(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_other",
Model: "gpt-5.5",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "function_call",
CallID: "call_other",
Name: "Search",
Arguments: `{"query":""}`,
},
},
}
anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
require.Len(t, anth.Content, 1)
assert.JSONEq(t, `{"query":""}`, string(anth.Content[0].Input))
}

func TestResponsesToAnthropic_Reasoning(t *testing.T) {
@@ -472,6 +514,41 @@ func TestStreamingToolCall(t *testing.T) {
assert.Equal(t, "tool_use", events[0].Delta.StopReason)
}
func TestStreamingReadToolDropsEmptyPages(t *testing.T) {
state := NewResponsesEventToAnthropicState()
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.created",
Response: &ResponsesResponse{ID: "resp_read_stream", Model: "gpt-5.5"},
}, state)
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.output_item.added",
OutputIndex: 0,
Item: &ResponsesOutput{Type: "function_call", CallID: "call_read", Name: "Read"},
}, state)
require.Len(t, events, 1)
assert.Equal(t, "content_block_start", events[0].Type)
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.function_call_arguments.delta",
OutputIndex: 0,
Delta: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
}, state)
assert.Len(t, events, 0)
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.function_call_arguments.done",
OutputIndex: 0,
Arguments: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
}, state)
require.Len(t, events, 2)
assert.Equal(t, "content_block_delta", events[0].Type)
assert.Equal(t, "input_json_delta", events[0].Delta.Type)
assert.JSONEq(t, `{"file_path":"/tmp/demo.py","limit":2000,"offset":0}`, events[0].Delta.PartialJSON)
assert.Equal(t, "content_block_stop", events[1].Type)
}
func TestStreamingReasoning(t *testing.T) {
state := NewResponsesEventToAnthropicState()
@@ -914,9 +991,40 @@ func TestAnthropicToResponses_ToolChoiceSpecific(t *testing.T) {
var tc map[string]any
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
assert.Equal(t, "function", tc["type"])
-fn, ok := tc["function"].(map[string]any)
-require.True(t, ok)
-assert.Equal(t, "get_weather", fn["name"])
+assert.Equal(t, "get_weather", tc["name"])
+assert.NotContains(t, tc, "function")
}
func TestResponsesToAnthropicRequest_ToolChoiceFunctionName(t *testing.T) {
req := &ResponsesRequest{
Model: "gpt-5.2",
Input: json.RawMessage(`[{"role":"user","content":"Hello"}]`),
ToolChoice: json.RawMessage(`{"type":"function","name":"get_weather"}`),
}
resp, err := ResponsesToAnthropicRequest(req)
require.NoError(t, err)
var tc map[string]string
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
assert.Equal(t, "tool", tc["type"])
assert.Equal(t, "get_weather", tc["name"])
}
func TestResponsesToAnthropicRequest_ToolChoiceLegacyFunctionName(t *testing.T) {
req := &ResponsesRequest{
Model: "gpt-5.2",
Input: json.RawMessage(`[{"role":"user","content":"Hello"}]`),
ToolChoice: json.RawMessage(`{"type":"function","function":{"name":"get_weather"}}`),
}
resp, err := ResponsesToAnthropicRequest(req)
require.NoError(t, err)
var tc map[string]string
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
assert.Equal(t, "tool", tc["type"])
assert.Equal(t, "get_weather", tc["name"])
}

// ---------------------------------------------------------------------------

View File

@@ -75,7 +75,7 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
// {"type":"auto"} → "auto"
// {"type":"any"} → "required"
// {"type":"none"} → "none"
-// {"type":"tool","name":"X"} → {"type":"function","function":{"name":"X"}}
+// {"type":"tool","name":"X"} → {"type":"function","name":"X"}
func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage, error) {
var tc struct {
Type string `json:"type"`
@@ -94,8 +94,8 @@ func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage
return json.Marshal("none")
case "tool":
return json.Marshal(map[string]any{
"type": "function",
-"function": map[string]string{"name": tc.Name},
+"name": tc.Name,
})
default:
// Pass through unknown types as-is

View File

@@ -281,6 +281,8 @@ func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) {
var tc map[string]any
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
assert.Equal(t, "function", tc["type"])
assert.Equal(t, "get_weather", tc["name"])
assert.NotContains(t, tc, "function")
} }
func TestChatCompletionsToResponses_ServiceTier(t *testing.T) { func TestChatCompletionsToResponses_ServiceTier(t *testing.T) {

View File

@@ -420,7 +420,7 @@ func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []R
 //
 // "auto" → "auto"
 // "none" → "none"
-// {"name":"X"} → {"type":"function","function":{"name":"X"}}
+// {"name":"X"} → {"type":"function","name":"X"}
 func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) {
 	// Try string first ("auto", "none", etc.) — pass through as-is.
 	var s string
@@ -436,7 +436,7 @@ func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage,
 		return nil, err
 	}
 	return json.Marshal(map[string]any{
 		"type": "function",
-		"function": map[string]string{"name": obj.Name},
+		"name": obj.Name,
 	})
 }

View File

@@ -52,7 +52,7 @@ func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicRespo
 			Type: "tool_use",
 			ID:   fromResponsesCallID(item.CallID),
 			Name: item.Name,
-			Input: json.RawMessage(item.Arguments),
+			Input: sanitizeAnthropicToolUseInput(item.Name, item.Arguments),
 		})
 	case "web_search_call":
 		toolUseID := "srvtoolu_" + item.ID
@@ -129,6 +129,28 @@ func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncom
 	}
 }
+func sanitizeAnthropicToolUseInput(name string, raw string) json.RawMessage {
+	if name != "Read" || raw == "" {
+		return json.RawMessage(raw)
+	}
+	var input map[string]json.RawMessage
+	if err := json.Unmarshal([]byte(raw), &input); err != nil {
+		return json.RawMessage(raw)
+	}
+	if pages, ok := input["pages"]; !ok || string(pages) != `""` {
+		return json.RawMessage(raw)
+	}
+	delete(input, "pages")
+	sanitized, err := json.Marshal(input)
+	if err != nil {
+		return json.RawMessage(raw)
+	}
+	return sanitized
+}
 // ---------------------------------------------------------------------------
 // Streaming: ResponsesStreamEvent → []AnthropicStreamEvent (stateful converter)
 // ---------------------------------------------------------------------------
@@ -142,6 +164,8 @@ type ResponsesEventToAnthropicState struct {
 	ContentBlockIndex int
 	ContentBlockOpen  bool
 	CurrentBlockType  string // "text" | "thinking" | "tool_use"
+	CurrentToolName   string
+	CurrentToolArgs   string
 	// OutputIndexToBlockIdx maps Responses output_index → Anthropic content block index.
 	OutputIndexToBlockIdx map[int]int
@@ -181,7 +205,7 @@ func ResponsesEventToAnthropicEvents(
 	case "response.function_call_arguments.delta":
 		return resToAnthHandleFuncArgsDelta(evt, state)
 	case "response.function_call_arguments.done":
-		return resToAnthHandleBlockDone(state)
+		return resToAnthHandleFuncArgsDone(evt, state)
 	case "response.output_item.done":
 		return resToAnthHandleOutputItemDone(evt, state)
 	case "response.reasoning_summary_text.delta":
@@ -278,6 +302,8 @@ func resToAnthHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesE
 	state.OutputIndexToBlockIdx[evt.OutputIndex] = idx
 	state.ContentBlockOpen = true
 	state.CurrentBlockType = "tool_use"
+	state.CurrentToolName = evt.Item.Name
+	state.CurrentToolArgs = ""
 	events = append(events, AnthropicStreamEvent{
 		Type: "content_block_start",
@@ -358,6 +384,11 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve
 		return nil
 	}
+	if state.CurrentBlockType == "tool_use" && state.CurrentToolName == "Read" {
+		state.CurrentToolArgs += evt.Delta
+		return nil
+	}
 	blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex]
 	if !ok {
 		return nil
@@ -373,6 +404,33 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve
 	}}
 }
+func resToAnthHandleFuncArgsDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
+	if state.CurrentBlockType != "tool_use" || state.CurrentToolName != "Read" {
+		return resToAnthHandleBlockDone(state)
+	}
+	raw := evt.Arguments
+	if raw == "" {
+		raw = state.CurrentToolArgs
+	}
+	sanitized := sanitizeAnthropicToolUseInput(state.CurrentToolName, raw)
+	if len(sanitized) == 0 {
+		return closeCurrentBlock(state)
+	}
+	idx := state.ContentBlockIndex
+	events := []AnthropicStreamEvent{{
+		Type:  "content_block_delta",
+		Index: &idx,
+		Delta: &AnthropicDelta{
+			Type:        "input_json_delta",
+			PartialJSON: string(sanitized),
+		},
+	}}
+	events = append(events, closeCurrentBlock(state)...)
+	return events
+}
 func resToAnthHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
 	if evt.Delta == "" {
 		return nil
@@ -524,6 +582,8 @@ func closeCurrentBlock(state *ResponsesEventToAnthropicState) []AnthropicStreamE
 	idx := state.ContentBlockIndex
 	state.ContentBlockOpen = false
 	state.ContentBlockIndex++
+	state.CurrentToolName = ""
+	state.CurrentToolArgs = ""
 	return []AnthropicStreamEvent{{
 		Type: "content_block_stop",
 		Index: &idx,
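The Read/pages sanitization in this file is narrow enough to sketch on its own: only a `Read` tool call whose arguments carry a literal empty-string `pages` field is rewritten; every other payload passes through untouched. A standalone sketch (hypothetical `sanitizeReadInput` helper mirroring `sanitizeAnthropicToolUseInput`'s logic, not the project's code):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// sanitizeReadInput drops a `"pages": ""` field from a Read tool call's
// arguments and passes every other payload through unchanged, mirroring
// sanitizeAnthropicToolUseInput from the diff.
func sanitizeReadInput(name, raw string) string {
	if name != "Read" || raw == "" {
		return raw
	}
	var input map[string]json.RawMessage
	if err := json.Unmarshal([]byte(raw), &input); err != nil {
		return raw // not valid JSON: pass through
	}
	if pages, ok := input["pages"]; !ok || string(pages) != `""` {
		return raw // only the literal empty-string form is stripped
	}
	delete(input, "pages")
	sanitized, err := json.Marshal(input)
	if err != nil {
		return raw
	}
	return string(sanitized)
}

func main() {
	fmt.Println(sanitizeReadInput("Read", `{"file_path":"a.txt","pages":""}`))  // pages dropped
	fmt.Println(sanitizeReadInput("Read", `{"file_path":"a.txt","pages":"1-3"}`)) // untouched
	fmt.Println(sanitizeReadInput("Write", `{"pages":""}`))                       // other tools untouched
}
```

Falling back to the raw string on any parse or marshal error keeps the converter fail-open: a malformed payload reaches the Anthropic side unmodified rather than being silently dropped.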

View File

@@ -428,7 +428,8 @@ func normalizeAnthropicInputSchema(schema json.RawMessage) json.RawMessage {
 // "auto"     → {"type":"auto"}
 // "required" → {"type":"any"}
 // "none"     → {"type":"none"}
-// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"}
+// {"type":"function","name":"X"} → {"type":"tool","name":"X"}
+// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"} // legacy
 func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage, error) {
 	// Try as string first
 	var s string
@@ -448,14 +449,22 @@ func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage
 	// Try as object with type=function
 	var tc struct {
 		Type string `json:"type"`
+		Name string `json:"name"`
 		Function struct {
 			Name string `json:"name"`
 		} `json:"function"`
 	}
-	if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" && tc.Function.Name != "" {
+	if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" {
+		name := strings.TrimSpace(tc.Name)
+		if name == "" {
+			name = strings.TrimSpace(tc.Function.Name)
+		}
+		if name == "" {
+			return raw, nil
+		}
 		return json.Marshal(map[string]string{
 			"type": "tool",
-			"name": tc.Function.Name,
+			"name": name,
 		})
 	}
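Taken together with the Chat/Responses conversion hunks earlier in this compare, the change means two `tool_choice` shapes must decode to the same Anthropic choice. A minimal sketch of that dual acceptance (hypothetical `toAnthropicToolChoice`, modeled on `convertResponsesToAnthropicToolChoice` above):

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// toAnthropicToolChoice accepts both the flat {"type":"function","name":"X"}
// shape and the legacy nested {"type":"function","function":{"name":"X"}}
// shape, producing {"type":"tool","name":"X"} for either.
func toAnthropicToolChoice(raw []byte) ([]byte, error) {
	var tc struct {
		Type     string `json:"type"`
		Name     string `json:"name"`
		Function struct {
			Name string `json:"name"`
		} `json:"function"`
	}
	if err := json.Unmarshal(raw, &tc); err != nil || tc.Type != "function" {
		return raw, err // pass through anything that is not a function choice
	}
	name := strings.TrimSpace(tc.Name) // prefer the flat field
	if name == "" {
		name = strings.TrimSpace(tc.Function.Name) // fall back to legacy nesting
	}
	if name == "" {
		return raw, nil
	}
	return json.Marshal(map[string]string{"type": "tool", "name": name})
}

func main() {
	flat, _ := toAnthropicToolChoice([]byte(`{"type":"function","name":"get_weather"}`))
	legacy, _ := toAnthropicToolChoice([]byte(`{"type":"function","function":{"name":"get_weather"}}`))
	fmt.Println(string(flat))
	fmt.Println(string(legacy))
}
```

Preferring the flat `name` and only then falling back to the nested field keeps old clients working while matching the Responses API's current shape.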

View File

@@ -2,16 +2,28 @@ package httputil
 import (
 	"bytes"
+	"compress/gzip"
+	"compress/zlib"
+	"errors"
+	"fmt"
 	"io"
 	"net/http"
+	"strings"
+	"github.com/klauspost/compress/zstd"
 )
 const (
 	requestBodyReadInitCap    = 512
 	requestBodyReadMaxInitCap = 1 << 20
+	// maxDecompressedBodySize limits the decompressed request body to 64 MB
+	// to prevent decompression bomb attacks.
+	maxDecompressedBodySize = 64 << 20
 )
-// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length.
+// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based
+// on content length, transparently decoding any Content-Encoding the upstream
+// client used to compress the body (zstd, gzip, deflate).
 func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
 	if req == nil || req.Body == nil {
 		return nil, nil
@@ -33,5 +45,49 @@ func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
 	if _, err := io.Copy(buf, req.Body); err != nil {
 		return nil, err
 	}
-	return buf.Bytes(), nil
+	raw := buf.Bytes()
+	enc := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Encoding")))
+	if enc == "" || enc == "identity" {
+		return raw, nil
+	}
+	decoded, err := decompressRequestBody(enc, raw)
+	if err != nil {
+		return nil, fmt.Errorf("decode Content-Encoding %q: %w", enc, err)
+	}
+	req.Header.Del("Content-Encoding")
+	req.Header.Del("Content-Length")
+	req.ContentLength = int64(len(decoded))
+	return decoded, nil
+}
+func decompressRequestBody(encoding string, raw []byte) ([]byte, error) {
+	switch encoding {
+	case "zstd":
+		dec, err := zstd.NewReader(bytes.NewReader(raw))
+		if err != nil {
+			return nil, err
+		}
+		defer dec.Close()
+		return io.ReadAll(io.LimitReader(dec, maxDecompressedBodySize))
+	case "gzip", "x-gzip":
+		gr, err := gzip.NewReader(bytes.NewReader(raw))
+		if err != nil {
+			return nil, err
+		}
+		defer func() { _ = gr.Close() }()
+		return io.ReadAll(io.LimitReader(gr, maxDecompressedBodySize))
+	case "deflate":
+		zr, err := zlib.NewReader(bytes.NewReader(raw))
+		if err != nil {
+			return nil, err
+		}
+		defer func() { _ = zr.Close() }()
+		return io.ReadAll(io.LimitReader(zr, maxDecompressedBodySize))
+	default:
+		return nil, errors.New("unsupported Content-Encoding")
+	}
 }

View File

@@ -0,0 +1,143 @@
package httputil
import (
	"bytes"
	"compress/gzip"
	"compress/zlib"
	"net/http"
	"strings"
	"testing"
	"github.com/klauspost/compress/zstd"
)
const samplePayload = `{"model":"gpt-5.5","input":"hi","stream":false}`
func newRequestWithBody(t *testing.T, body []byte, encoding string) *http.Request {
	t.Helper()
	req, err := http.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
	if err != nil {
		t.Fatalf("NewRequest: %v", err)
	}
	if encoding != "" {
		req.Header.Set("Content-Encoding", encoding)
	}
	req.ContentLength = int64(len(body))
	return req
}
func TestReadRequestBodyWithPrealloc_PassesThroughIdentity(t *testing.T) {
	req := newRequestWithBody(t, []byte(samplePayload), "")
	got, err := ReadRequestBodyWithPrealloc(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if string(got) != samplePayload {
		t.Fatalf("body mismatch: got %q", got)
	}
}
func TestReadRequestBodyWithPrealloc_DecodesZstd(t *testing.T) {
	enc, _ := zstd.NewWriter(nil)
	compressed := enc.EncodeAll([]byte(samplePayload), nil)
	_ = enc.Close()
	req := newRequestWithBody(t, compressed, "zstd")
	got, err := ReadRequestBodyWithPrealloc(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if string(got) != samplePayload {
		t.Fatalf("body mismatch: got %q", got)
	}
	if req.Header.Get("Content-Encoding") != "" {
		t.Fatalf("Content-Encoding should be cleared after decoding")
	}
	if req.ContentLength != int64(len(samplePayload)) {
		t.Fatalf("ContentLength not updated: %d", req.ContentLength)
	}
}
func TestReadRequestBodyWithPrealloc_DecodesGzip(t *testing.T) {
	var buf bytes.Buffer
	gw := gzip.NewWriter(&buf)
	if _, err := gw.Write([]byte(samplePayload)); err != nil {
		t.Fatalf("gzip write: %v", err)
	}
	if err := gw.Close(); err != nil {
		t.Fatalf("gzip close: %v", err)
	}
	req := newRequestWithBody(t, buf.Bytes(), "gzip")
	got, err := ReadRequestBodyWithPrealloc(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if string(got) != samplePayload {
		t.Fatalf("body mismatch: got %q", got)
	}
}
func TestReadRequestBodyWithPrealloc_DecodesDeflate(t *testing.T) {
	var buf bytes.Buffer
	zw := zlib.NewWriter(&buf)
	if _, err := zw.Write([]byte(samplePayload)); err != nil {
		t.Fatalf("zlib write: %v", err)
	}
	if err := zw.Close(); err != nil {
		t.Fatalf("zlib close: %v", err)
	}
	req := newRequestWithBody(t, buf.Bytes(), "deflate")
	got, err := ReadRequestBodyWithPrealloc(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if string(got) != samplePayload {
		t.Fatalf("body mismatch: got %q", got)
	}
}
func TestReadRequestBodyWithPrealloc_RejectsUnsupportedEncoding(t *testing.T) {
	req := newRequestWithBody(t, []byte(samplePayload), "br")
	_, err := ReadRequestBodyWithPrealloc(req)
	if err == nil {
		t.Fatal("expected error for unsupported encoding, got nil")
	}
	if !strings.Contains(err.Error(), "br") {
		t.Fatalf("error should mention encoding, got %v", err)
	}
}
func TestReadRequestBodyWithPrealloc_RejectsCorruptZstd(t *testing.T) {
	req := newRequestWithBody(t, []byte("not actually zstd"), "zstd")
	_, err := ReadRequestBodyWithPrealloc(req)
	if err == nil {
		t.Fatal("expected error for corrupt zstd body, got nil")
	}
}
func TestReadRequestBodyWithPrealloc_NilBody(t *testing.T) {
	req, err := http.NewRequest(http.MethodPost, "/v1/responses", nil)
	if err != nil {
		t.Fatalf("NewRequest: %v", err)
	}
	got, err := ReadRequestBodyWithPrealloc(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != nil {
		t.Fatalf("expected nil body, got %q", got)
	}
}
func TestReadRequestBodyWithPrealloc_RespectsIdentityEncoding(t *testing.T) {
	req := newRequestWithBody(t, []byte(samplePayload), "identity")
	got, err := ReadRequestBodyWithPrealloc(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if string(got) != samplePayload {
		t.Fatalf("body mismatch: got %q", got)
	}
}

View File

@@ -64,6 +64,10 @@ func (s *schedulerCacheRecorder) TryLockBucket(ctx context.Context, bucket servi
 	return true, nil
 }
+func (s *schedulerCacheRecorder) UnlockBucket(ctx context.Context, bucket service.SchedulerBucket) error {
+	return nil
+}
 func (s *schedulerCacheRecorder) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
 	return nil, nil
 }

View File

@@ -24,6 +24,49 @@ const (
 	defaultSchedulerSnapshotMGetChunkSize  = 128
 	defaultSchedulerSnapshotWriteChunkSize = 256
+	// snapshotGraceTTLSeconds is the grace period (in seconds) before an old
+	// snapshot expires. It replaces an immediate DEL so readers still on the
+	// previous version have enough time to finish their ZRANGE.
+	snapshotGraceTTLSeconds = 60
+)
+var (
+	// activateSnapshotScript atomically switches the snapshot version with a CAS.
+	// It only activates when the new version >= the currently active version,
+	// preventing rollback under concurrent writers. Old snapshots are expired
+	// with a grace period instead of an immediate DEL to avoid racing readers.
+	//
+	// KEYS[1] = activeKey (sched:active:{bucket})
+	// KEYS[2] = readyKey (sched:ready:{bucket})
+	// KEYS[3] = bucketSetKey (sched:buckets)
+	// KEYS[4] = snapshotKey (the freshly written snapshot key)
+	// ARGV[1] = new version number as a string
+	// ARGV[2] = bucket string (for SADD)
+	// ARGV[3] = snapshot key prefix (to build the old snapshot key)
+	// ARGV[4] = grace TTL in seconds
+	//
+	// Returns 1 = activated, 0 = version too old, not activated
+	activateSnapshotScript = redis.NewScript(`
+local currentActive = redis.call('GET', KEYS[1])
+local newVersion = tonumber(ARGV[1])
+if currentActive ~= false then
+  local curVersion = tonumber(currentActive)
+  if curVersion and newVersion < curVersion then
+    redis.call('DEL', KEYS[4])
+    return 0
+  end
+end
+redis.call('SET', KEYS[1], ARGV[1])
+redis.call('SET', KEYS[2], '1')
+redis.call('SADD', KEYS[3], ARGV[2])
+if currentActive ~= false and currentActive ~= ARGV[1] then
+  redis.call('EXPIRE', ARGV[3] .. currentActive, tonumber(ARGV[4]))
+end
+return 1
+`)
 )
 type schedulerCache struct {
@@ -108,9 +151,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul
 }
 func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
-	activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
-	oldActive, _ := c.rdb.Get(ctx, activeKey).Result()
+	// Phase 1: allocate a new version number and write the snapshot data.
+	// INCR guarantees every caller gets a unique, monotonically increasing version.
+	// The snapshotKey written here is a fresh versioned key that readers do not
+	// know about yet, so there is no race.
 	versionKey := schedulerBucketKey(schedulerVersionPrefix, bucket)
 	version, err := c.rdb.Incr(ctx, versionKey).Result()
 	if err != nil {
@@ -124,7 +167,6 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
 		return err
 	}
-	pipe := c.rdb.Pipeline()
 	if len(accounts) > 0 {
 		// Use the ordinal as the score to preserve the ordering returned by the database.
 		members := make([]redis.Z, 0, len(accounts))
@@ -134,6 +176,7 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
 			Member: strconv.FormatInt(account.ID, 10),
 		})
 	}
+	pipe := c.rdb.Pipeline()
 	for start := 0; start < len(members); start += c.writeChunkSize {
 		end := start + c.writeChunkSize
 		if end > len(members) {
@@ -141,18 +184,25 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
 		}
 		pipe.ZAdd(ctx, snapshotKey, members[start:end]...)
 	}
-	} else {
-		pipe.Del(ctx, snapshotKey)
-	}
-	pipe.Set(ctx, activeKey, versionStr, 0)
-	pipe.Set(ctx, schedulerBucketKey(schedulerReadyPrefix, bucket), "1", 0)
-	pipe.SAdd(ctx, schedulerBucketSetKey, bucket.String())
-	if _, err := pipe.Exec(ctx); err != nil {
-		return err
-	}
-	if oldActive != "" && oldActive != versionStr {
-		_ = c.rdb.Del(ctx, schedulerSnapshotKey(bucket, oldActive)).Err()
-	}
+		if _, err := pipe.Exec(ctx); err != nil {
+			return err
+		}
+	}
+	// Phase 2: atomically activate the version with a CAS.
+	// The Lua script only moves the active pointer when the new version >= the
+	// currently active one, preventing rollback under concurrent writes.
+	// Old snapshots get an EXPIRE grace period instead of an immediate DEL to
+	// avoid racing readers.
+	activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
+	readyKey := schedulerBucketKey(schedulerReadyPrefix, bucket)
+	snapshotKeyPrefix := fmt.Sprintf("%s%d:%s:%s:v", schedulerSnapshotPrefix, bucket.GroupID, bucket.Platform, bucket.Mode)
+	keys := []string{activeKey, readyKey, schedulerBucketSetKey, snapshotKey}
+	args := []any{versionStr, bucket.String(), snapshotKeyPrefix, snapshotGraceTTLSeconds}
+	_, err = activateSnapshotScript.Run(ctx, c.rdb, keys, args...).Result()
+	if err != nil {
+		return err
 	}
 	return nil
@@ -232,6 +282,11 @@ func (c *schedulerCache) TryLockBucket(ctx context.Context, bucket service.Sched
 	return c.rdb.SetNX(ctx, key, time.Now().UnixNano(), ttl).Result()
 }
+func (c *schedulerCache) UnlockBucket(ctx context.Context, bucket service.SchedulerBucket) error {
+	key := schedulerBucketKey(schedulerLockPrefix, bucket)
+	return c.rdb.Del(ctx, key).Err()
+}
 func (c *schedulerCache) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
 	raw, err := c.rdb.SMembers(ctx, schedulerBucketSetKey).Result()
 	if err != nil {
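The CAS rule the Lua script enforces can be modeled in plain Go to show the race it closes: with the old read-then-write sequence, a slow writer holding a stale version could overwrite a newer active pointer; under the CAS rule it simply loses. A sketch ignoring the EXPIRE/SADD bookkeeping (hypothetical `activeState` type, not project code):

```go
package main

import "fmt"

// activeState models the active-version pointer the Lua script guards.
type activeState struct {
	active   int
	hasValue bool
}

// activate applies the script's CAS rule: only a version >= the currently
// active one may win, so a writer holding an older version can never roll
// the pointer back.
func (s *activeState) activate(newVersion int) bool {
	if s.hasValue && newVersion < s.active {
		return false // stale writer: its snapshot is discarded, pointer untouched
	}
	s.active = newVersion
	s.hasValue = true
	return true
}

func main() {
	var s activeState
	fmt.Println(s.activate(7)) // first writer wins
	fmt.Println(s.activate(9)) // newer version wins
	fmt.Println(s.activate(8)) // slow writer with an older version is rejected
	fmt.Println(s.active)      // 9: the pointer never moved backwards
}
```

Doing this comparison inside a single Lua script is what makes it atomic: Redis runs the script without interleaving other commands, so the read of the current version and the conditional SET cannot be split by a concurrent writer.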

View File

@@ -748,6 +748,16 @@ func TestAPIContracts(t *testing.T) {
 	"payment_visible_method_alipay_enabled": true,
 	"payment_visible_method_wxpay_enabled": false,
 	"openai_advanced_scheduler_enabled": true,
+	"openai_fast_policy_settings": {
+		"rules": [
+			{
+				"service_tier": "priority",
+				"action": "filter",
+				"scope": "all",
+				"fallback_action": "pass"
+			}
+		]
+	},
 	"custom_menu_items": [],
 	"custom_endpoints": [],
 	"payment_enabled": false,
@@ -930,6 +940,16 @@ func TestAPIContracts(t *testing.T) {
 	"payment_visible_method_alipay_enabled": false,
 	"payment_visible_method_wxpay_enabled": false,
 	"openai_advanced_scheduler_enabled": false,
+	"openai_fast_policy_settings": {
+		"rules": [
+			{
+				"service_tier": "priority",
+				"action": "filter",
+				"scope": "all",
+				"fallback_action": "pass"
+			}
+		]
+	},
 	"payment_enabled": false,
 	"payment_min_amount": 0,
 	"payment_max_amount": 0,
"payment_max_amount": 0, "payment_max_amount": 0,

View File

@@ -64,6 +64,7 @@ func isOpenAIImageModel(model string) bool {
 type AccountTestService struct {
 	accountRepo               AccountRepository
 	geminiTokenProvider       *GeminiTokenProvider
+	claudeTokenProvider       *ClaudeTokenProvider
 	antigravityGatewayService *AntigravityGatewayService
 	httpUpstream              HTTPUpstream
 	cfg                       *config.Config
@@ -74,6 +75,7 @@ type AccountTestService struct {
 func NewAccountTestService(
 	accountRepo AccountRepository,
 	geminiTokenProvider *GeminiTokenProvider,
+	claudeTokenProvider *ClaudeTokenProvider,
 	antigravityGatewayService *AntigravityGatewayService,
 	httpUpstream HTTPUpstream,
 	cfg *config.Config,
@@ -82,6 +84,7 @@ func NewAccountTestService(
 	return &AccountTestService{
 		accountRepo:               accountRepo,
 		geminiTokenProvider:       geminiTokenProvider,
+		claudeTokenProvider:       claudeTokenProvider,
 		antigravityGatewayService: antigravityGatewayService,
 		httpUpstream:              httpUpstream,
 		cfg:                       cfg,
@@ -210,6 +213,9 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
 	if account.IsBedrock() {
 		return s.testBedrockAccountConnection(c, ctx, account, testModelID)
 	}
+	if account.Type == AccountTypeServiceAccount {
+		return s.testClaudeVertexServiceAccountConnection(c, ctx, account, testModelID)
+	}
 	// Determine authentication method and API URL
 	var authToken string
@@ -313,6 +319,74 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
 	return s.processClaudeStream(c, resp.Body)
 }
+func (s *AccountTestService) testClaudeVertexServiceAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
+	if mappedModel, matched := account.ResolveMappedModel(testModelID); matched {
+		testModelID = mappedModel
+	} else {
+		testModelID = normalizeVertexAnthropicModelID(claude.NormalizeModelID(testModelID))
+	}
+	c.Writer.Header().Set("Content-Type", "text/event-stream")
+	c.Writer.Header().Set("Cache-Control", "no-cache")
+	c.Writer.Header().Set("Connection", "keep-alive")
+	c.Writer.Header().Set("X-Accel-Buffering", "no")
+	c.Writer.Flush()
+	payload, err := createTestPayload(testModelID)
+	if err != nil {
+		return s.sendErrorAndEnd(c, "Failed to create test payload")
+	}
+	payloadBytes, _ := json.Marshal(payload)
+	vertexBody, err := buildVertexAnthropicRequestBody(payloadBytes)
+	if err != nil {
+		return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create Vertex request body: %s", err.Error()))
+	}
+	if s.claudeTokenProvider == nil {
+		return s.sendErrorAndEnd(c, "Claude token provider not configured")
+	}
+	accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
+	if err != nil {
+		return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to get service account access token: %s", err.Error()))
+	}
+	fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(testModelID), testModelID, true)
+	if err != nil {
+		return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build Vertex URL: %s", err.Error()))
+	}
+	s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(vertexBody))
+	if err != nil {
+		return s.sendErrorAndEnd(c, "Failed to create request")
+	}
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Authorization", "Bearer "+accessToken)
+	proxyURL := ""
+	if account.ProxyID != nil && account.Proxy != nil {
+		proxyURL = account.Proxy.URL()
+	}
+	resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
+	if err != nil {
+		return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
+	}
+	defer func() { _ = resp.Body.Close() }()
+	if resp.StatusCode != http.StatusOK {
+		body, _ := io.ReadAll(resp.Body)
+		errMsg := fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))
+		if resp.StatusCode == http.StatusForbidden {
+			_ = s.accountRepo.SetError(ctx, account.ID, errMsg)
+		}
+		return s.sendErrorAndEnd(c, errMsg)
+	}
+	return s.processClaudeStream(c, resp.Body)
+}
 // testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke
 func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
 	region := bedrockRuntimeRegion(account)
@@ -711,8 +785,8 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
 		testModelID = geminicli.DefaultTestModel
 	}
-	// For API Key accounts with model mapping, map the model
-	if account.Type == AccountTypeAPIKey {
+	// For static upstream credentials with model mapping, map the model
+	if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
 		mapping := account.GetModelMapping()
 		if len(mapping) > 0 {
 			if mappedModel, exists := mapping[testModelID]; exists {
@@ -740,6 +814,8 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
 		req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
 	case AccountTypeOAuth:
 		req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
+	case AccountTypeServiceAccount:
+		req, err = s.buildGeminiServiceAccountRequest(ctx, account, testModelID, payload)
 	default:
 		return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
 	}
@@ -893,6 +969,27 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
 	return s.buildCodeAssistRequest(ctx, accessToken, projectID, modelID, payload)
 }
+func (s *AccountTestService) buildGeminiServiceAccountRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) {
+	if s.geminiTokenProvider == nil {
+		return nil, fmt.Errorf("gemini token provider not configured")
+	}
+	accessToken, err := s.geminiTokenProvider.GetAccessToken(ctx, account)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get service account access token: %w", err)
+	}
+	fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, "streamGenerateContent", true)
+	if err != nil {
+		return nil, err
+	}
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Authorization", "Bearer "+accessToken)
+	return req, nil
+}
 // buildCodeAssistRequest builds request for Google Code Assist API (used by Gemini CLI and Antigravity)
 func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessToken, projectID, modelID string, payload []byte) (*http.Request, error) {
 	var inner map[string]any
@@ -1227,7 +1324,7 @@ func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.C
 	if err != nil {
 		return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
 	}
-	apiURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/images/generations"
+	apiURL := buildOpenAIImagesURL(normalizedBaseURL, openAIImagesGenerationsEndpoint)
 	// Set SSE headers
 	c.Writer.Header().Set("Content-Type", "text/event-stream")
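The diff calls `buildVertexAnthropicURL` without showing its body. For orientation, the publicly documented Vertex AI endpoint shape for Anthropic models looks roughly like the sketch below (hypothetical `vertexAnthropicURL` helper following Google's documented URL scheme; treat it as an illustration, not this project's implementation):

```go
package main

import "fmt"

// vertexAnthropicURL sketches the documented Vertex AI endpoint shape for
// Anthropic models: a regional host plus a publishers/anthropic model path,
// with rawPredict for one-shot calls and streamRawPredict for SSE streaming.
func vertexAnthropicURL(project, location, model string, stream bool) string {
	action := "rawPredict"
	if stream {
		action = "streamRawPredict"
	}
	host := fmt.Sprintf("%s-aiplatform.googleapis.com", location)
	if location == "global" {
		host = "aiplatform.googleapis.com" // the global endpoint has no regional prefix
	}
	return fmt.Sprintf("https://%s/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
		host, project, location, model, action)
}

func main() {
	fmt.Println(vertexAnthropicURL("my-proj", "us-east5", "claude-sonnet-4", true))
}
```

This also shows why the connection test authenticates with a `Bearer` OAuth token from the service-account provider rather than an `x-api-key` header: Vertex fronts the Anthropic models behind Google Cloud IAM.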

View File

@@ -8,6 +8,7 @@ import (
 	"strings"
 	"testing"
+	"github.com/Wei-Shaw/sub2api/internal/config"
 	"github.com/gin-gonic/gin"
 	"github.com/stretchr/testify/require"
 )
@@ -48,3 +49,42 @@ func TestAccountTestService_OpenAIImageOAuthHandlesOutputItemDoneFallback(t *tes
 	require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=")
 	require.Contains(t, rec.Body.String(), "\"success\":true")
 }
+func TestAccountTestService_OpenAIImageAPIKeyUsesConfiguredV1BaseURL(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	rec := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(rec)
+	c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
+	upstream := &httpUpstreamRecorder{
+		resp: &http.Response{
+			StatusCode: http.StatusOK,
+			Header: http.Header{
+				"Content-Type": []string{"application/json"},
+			},
+			Body: io.NopCloser(strings.NewReader(`{"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)),
+		},
+	}
+	svc := &AccountTestService{
+		httpUpstream: upstream,
+		cfg:          &config.Config{},
+	}
+	account := &Account{
+		ID:       54,
+		Name:     "openai-apikey",
+		Platform: PlatformOpenAI,
+		Type:     AccountTypeAPIKey,
+		Credentials: map[string]any{
+			"api_key":  "test-api-key",
+			"base_url": "https://image-upstream.example/v1",
+		},
+	}
+	err := svc.testOpenAIImageAPIKey(c, context.Background(), account, "gpt-image-2", "draw a cat")
+	require.NoError(t, err)
+	require.NotNil(t, upstream.lastReq)
+	require.Equal(t, "https://image-upstream.example/v1/images/generations", upstream.lastReq.URL.String())
+	require.Equal(t, "Bearer test-api-key", upstream.lastReq.Header.Get("Authorization"))
+	require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=")
+	require.Contains(t, rec.Body.String(), "\"success\":true")
+}

View File

@@ -9,6 +9,7 @@ import (
"log/slog" "log/slog"
"net/http" "net/http"
"sort" "sort"
"strconv"
"strings" "strings"
"time" "time"
@@ -58,6 +59,7 @@ type AdminService interface {
// API Key management (admin) // API Key management (admin)
AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error)
AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*APIKey, error)
// ReplaceUserGroup 替换用户的专属分组:授予新分组权限、迁移 Key、移除旧分组权限 // ReplaceUserGroup 替换用户的专属分组:授予新分组权限、迁移 Key、移除旧分组权限
ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error)
@@ -291,6 +293,7 @@ type UpdateAccountInput struct {
// BulkUpdateAccountsInput describes the payload for bulk updating accounts. // BulkUpdateAccountsInput describes the payload for bulk updating accounts.
type BulkUpdateAccountsInput struct { type BulkUpdateAccountsInput struct {
AccountIDs []int64 AccountIDs []int64
Filters *BulkUpdateAccountFilters
Name string Name string
ProxyID *int64 ProxyID *int64
Concurrency *int Concurrency *int
@@ -307,6 +310,15 @@ type BulkUpdateAccountsInput struct {
SkipMixedChannelCheck bool SkipMixedChannelCheck bool
} }
type BulkUpdateAccountFilters struct {
Platform string
Type string
Status string
Group string
Search string
PrivacyMode string
}
// BulkUpdateAccountResult captures the result for a single account update. // BulkUpdateAccountResult captures the result for a single account update.
type BulkUpdateAccountResult struct { type BulkUpdateAccountResult struct {
AccountID int64 `json:"account_id"` AccountID int64 `json:"account_id"`
@@ -1961,6 +1973,30 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
return result, nil return result, nil
} }
// AdminResetAPIKeyRateLimitUsage resets all rate-limit usage windows for an API key.
func (s *adminServiceImpl) AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*APIKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID)
if err != nil {
return nil, err
}
apiKey.Usage5h = 0
apiKey.Usage1d = 0
apiKey.Usage7d = 0
apiKey.Window5hStart = nil
apiKey.Window1dStart = nil
apiKey.Window7dStart = nil
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
return nil, fmt.Errorf("reset api key rate limit usage: %w", err)
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key)
}
if s.billingCacheService != nil {
_ = s.billingCacheService.InvalidateAPIKeyRateLimit(ctx, apiKey.ID)
}
return apiKey, nil
}
// ReplaceUserGroup 替换用户的专属分组 // ReplaceUserGroup 替换用户的专属分组
func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) { func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) {
if oldGroupID == newGroupID { if oldGroupID == newGroupID {
@@ -2286,6 +2322,14 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
// BulkUpdateAccounts updates multiple accounts in one request. // BulkUpdateAccounts updates multiple accounts in one request.
// It merges credentials/extra keys instead of overwriting the whole object. // It merges credentials/extra keys instead of overwriting the whole object.
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) { func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
if len(input.AccountIDs) == 0 && input.Filters != nil {
accountIDs, err := s.resolveBulkUpdateTargetIDs(ctx, input.Filters)
if err != nil {
return nil, err
}
input.AccountIDs = accountIDs
}
result := &BulkUpdateAccountsResult{ result := &BulkUpdateAccountsResult{
SuccessIDs: make([]int64, 0, len(input.AccountIDs)), SuccessIDs: make([]int64, 0, len(input.AccountIDs)),
FailedIDs: make([]int64, 0, len(input.AccountIDs)), FailedIDs: make([]int64, 0, len(input.AccountIDs)),
@@ -2401,6 +2445,55 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
return result, nil return result, nil
} }
func (s *adminServiceImpl) resolveBulkUpdateTargetIDs(ctx context.Context, filters *BulkUpdateAccountFilters) ([]int64, error) {
if filters == nil {
return nil, nil
}
groupID := int64(0)
switch strings.TrimSpace(filters.Group) {
case "":
case "ungrouped":
groupID = AccountListGroupUngrouped
default:
parsedGroupID, err := strconv.ParseInt(strings.TrimSpace(filters.Group), 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid group filter: %w", err)
}
groupID = parsedGroupID
}
const pageSize = 500
page := 1
accountIDs := make([]int64, 0, pageSize)
for {
accounts, total, err := s.ListAccounts(
ctx,
page,
pageSize,
filters.Platform,
filters.Type,
filters.Status,
filters.Search,
groupID,
filters.PrivacyMode,
"",
"",
)
if err != nil {
return nil, err
}
for _, account := range accounts {
accountIDs = append(accountIDs, account.ID)
}
if int64(len(accountIDs)) >= total || len(accounts) == 0 {
return accountIDs, nil
}
page++
}
}
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error { func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
if err := s.accountRepo.Delete(ctx, id); err != nil { if err := s.accountRepo.Delete(ctx, id); err != nil {
return err return err

View File

@@ -5,8 +5,10 @@ package service
import ( import (
"context" "context"
"errors" "errors"
"reflect"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -25,6 +27,19 @@ type accountRepoStubForBulkUpdate struct {
getByIDCalled []int64 getByIDCalled []int64
listByGroupData map[int64][]Account listByGroupData map[int64][]Account
listByGroupErr map[int64]error listByGroupErr map[int64]error
listData []Account
listResult *pagination.PaginationResult
listErr error
listCalled bool
lastListParams pagination.PaginationParams
lastListFilters struct {
platform string
accountType string
status string
search string
groupID int64
privacyMode string
}
} }
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) { func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
@@ -73,6 +88,24 @@ func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID in
return nil, nil return nil, nil
} }
func (s *accountRepoStubForBulkUpdate) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
s.listCalled = true
s.lastListParams = params
s.lastListFilters.platform = platform
s.lastListFilters.accountType = accountType
s.lastListFilters.status = status
s.lastListFilters.search = search
s.lastListFilters.groupID = groupID
s.lastListFilters.privacyMode = privacyMode
if s.listErr != nil {
return nil, nil, s.listErr
}
if s.listResult != nil {
return s.listData, s.listResult, nil
}
return s.listData, &pagination.PaginationResult{Total: int64(len(s.listData))}, nil
}
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。 // TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) { func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{} repo := &accountRepoStubForBulkUpdate{}
@@ -170,3 +203,46 @@ func TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingCon
// No BindGroups should have been called since the check runs before any write. // No BindGroups should have been called since the check runs before any write.
require.Empty(t, repo.bindGroupsCalls) require.Empty(t, repo.bindGroupsCalls)
} }
func TestAdminServiceBulkUpdateAccounts_ResolvesIDsFromFilters(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{
listData: []Account{
{ID: 7},
{ID: 11},
},
listResult: &pagination.PaginationResult{Total: 2},
}
svc := &adminServiceImpl{accountRepo: repo}
schedulable := true
input := &BulkUpdateAccountsInput{
Schedulable: &schedulable,
}
filtersField := reflect.ValueOf(input).Elem().FieldByName("Filters")
require.True(t, filtersField.IsValid(), "BulkUpdateAccountsInput should expose Filters for filter-target bulk update")
require.Equal(t, reflect.Ptr, filtersField.Kind(), "BulkUpdateAccountsInput.Filters should be a pointer field")
filtersValue := reflect.New(filtersField.Type().Elem())
filtersValue.Elem().FieldByName("Platform").SetString(PlatformOpenAI)
filtersValue.Elem().FieldByName("Type").SetString(AccountTypeOAuth)
filtersValue.Elem().FieldByName("Status").SetString(StatusActive)
filtersValue.Elem().FieldByName("Group").SetString("12")
filtersValue.Elem().FieldByName("PrivacyMode").SetString(PrivacyModeCFBlocked)
filtersValue.Elem().FieldByName("Search").SetString("bulk-target")
filtersField.Set(filtersValue)
result, err := svc.BulkUpdateAccounts(context.Background(), input)
require.NoError(t, err)
require.True(t, repo.listCalled, "expected filter-target bulk update to resolve matching IDs via account list filters")
require.Equal(t, PlatformOpenAI, repo.lastListFilters.platform)
require.Equal(t, AccountTypeOAuth, repo.lastListFilters.accountType)
require.Equal(t, StatusActive, repo.lastListFilters.status)
require.Equal(t, "bulk-target", repo.lastListFilters.search)
require.Equal(t, int64(12), repo.lastListFilters.groupID)
require.Equal(t, PrivacyModeCFBlocked, repo.lastListFilters.privacyMode)
require.Equal(t, []int64{7, 11}, repo.bulkUpdateIDs)
require.Equal(t, 2, result.Success)
require.Equal(t, 0, result.Failed)
require.Equal(t, []int64{7, 11}, result.SuccessIDs)
}

View File

@@ -508,6 +508,18 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
return nil return nil
} }
// InvalidateAPIKeyRateLimit invalidates the Redis rate-limit usage cache for an API key.
func (s *BillingCacheService) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
if s.cache == nil {
return nil
}
if err := s.cache.InvalidateAPIKeyRateLimit(ctx, keyID); err != nil {
logger.LegacyPrintf("service.billing_cache", "Warning: invalidate api key rate limit cache failed for key %d: %v", keyID, err)
return err
}
return nil
}
// ============================================ // ============================================
// API Key 限速缓存方法 // API Key 限速缓存方法
// ============================================ // ============================================

View File

@@ -17,7 +17,7 @@ const (
// ClaudeTokenCache token cache interface. // ClaudeTokenCache token cache interface.
type ClaudeTokenCache = GeminiTokenCache type ClaudeTokenCache = GeminiTokenCache
// ClaudeTokenProvider manages access_token for Claude OAuth accounts. // ClaudeTokenProvider manages access_token for Claude OAuth and Vertex service accounts.
type ClaudeTokenProvider struct { type ClaudeTokenProvider struct {
accountRepo AccountRepository accountRepo AccountRepository
tokenCache ClaudeTokenCache tokenCache ClaudeTokenCache
@@ -56,8 +56,11 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
if account == nil { if account == nil {
return "", errors.New("account is nil") return "", errors.New("account is nil")
} }
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth { if account.Platform != PlatformAnthropic || (account.Type != AccountTypeOAuth && account.Type != AccountTypeServiceAccount) {
return "", errors.New("not an anthropic oauth account") return "", errors.New("not an anthropic oauth or service account")
}
if account.Type == AccountTypeServiceAccount {
return p.getServiceAccountAccessToken(ctx, account)
} }
cacheKey := ClaudeTokenCacheKey(account) cacheKey := ClaudeTokenCacheKey(account)
@@ -157,3 +160,7 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return accessToken, nil return accessToken, nil
} }
func (p *ClaudeTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) {
return getVertexServiceAccountAccessToken(ctx, p.tokenCache, account)
}

View File

@@ -137,7 +137,7 @@ func (p *testClaudeTokenProvider) GetAccessToken(ctx context.Context, account *A
return "", errors.New("account is nil") return "", errors.New("account is nil")
} }
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth { if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
return "", errors.New("not an anthropic oauth account") return "", errors.New("not an anthropic oauth or service account")
} }
cacheKey := ClaudeTokenCacheKey(account) cacheKey := ClaudeTokenCacheKey(account)
@@ -371,7 +371,7 @@ func TestClaudeTokenProvider_WrongPlatform(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), account) token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "not an anthropic oauth account") require.Contains(t, err.Error(), "not an anthropic oauth or service account")
require.Empty(t, token) require.Empty(t, token)
} }
@@ -385,7 +385,7 @@ func TestClaudeTokenProvider_WrongAccountType(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), account) token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "not an anthropic oauth account") require.Contains(t, err.Error(), "not an anthropic oauth or service account")
require.Empty(t, token) require.Empty(t, token)
} }
@@ -399,7 +399,7 @@ func TestClaudeTokenProvider_SetupTokenType(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), account) token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "not an anthropic oauth account") require.Contains(t, err.Error(), "not an anthropic oauth or service account")
require.Empty(t, token) require.Empty(t, token)
} }

View File

@@ -41,11 +41,12 @@ const (
// Account type constants // Account type constants
const ( const (
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号full scope: profile + inference AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号inference only scope AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号inference only scope
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号 AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游) AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock由 credentials.auth_mode 区分) AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock由 credentials.auth_mode 区分)
AccountTypeServiceAccount = domain.AccountTypeServiceAccount // Google Service Account type (for Vertex AI)
) )
// Redeem type constants // Redeem type constants
@@ -306,6 +307,12 @@ const (
// SettingKeyBetaPolicySettings stores JSON config for beta policy rules. // SettingKeyBetaPolicySettings stores JSON config for beta policy rules.
SettingKeyBetaPolicySettings = "beta_policy_settings" SettingKeyBetaPolicySettings = "beta_policy_settings"
// SettingKeyOpenAIFastPolicySettings stores JSON config for OpenAI
// service_tier (fast/flex) policy rules. Mirrors BetaPolicySettings but
// targets OpenAI's body-level service_tier field instead of Claude's
// anthropic-beta header.
SettingKeyOpenAIFastPolicySettings = "openai_fast_policy_settings"
// ========================= // =========================
// Claude Code Version Check // Claude Code Version Check
// ========================= // =========================

View File

@@ -0,0 +1,68 @@
package service
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestGatewayService_BuildAnthropicVertexServiceAccountRequest(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
c.Request.Header.Set("Authorization", "Bearer inbound-token")
c.Request.Header.Set("X-Api-Key", "inbound-api-key")
c.Request.Header.Set("Anthropic-Version", "2023-06-01")
c.Request.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14")
account := &Account{
ID: 301,
Platform: PlatformAnthropic,
Type: AccountTypeServiceAccount,
Credentials: map[string]any{
"project_id": "vertex-proj",
"location": "us-east5",
},
}
body := []byte(`{"model":"claude-sonnet-4-5","stream":false,"max_tokens":32,"messages":[{"role":"user","content":"hello"}]}`)
svc := &GatewayService{}
req, err := svc.buildUpstreamRequest(
context.Background(),
c,
account,
body,
"vertex-token",
"service_account",
"claude-sonnet-4-5@20250929",
false,
false,
)
require.NoError(t, err)
require.Equal(t, "https://us-east5-aiplatform.googleapis.com/v1/projects/vertex-proj/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5@20250929:rawPredict", req.URL.String())
require.Equal(t, "Bearer vertex-token", getHeaderRaw(req.Header, "authorization"))
require.Empty(t, getHeaderRaw(req.Header, "x-api-key"))
require.Empty(t, getHeaderRaw(req.Header, "anthropic-version"))
require.Equal(t, "interleaved-thinking-2025-05-14", getHeaderRaw(req.Header, "anthropic-beta"))
got := readRequestBodyForTest(t, req)
require.Equal(t, "", gjson.GetBytes(got, "model").String())
require.Equal(t, vertexAnthropicVersion, gjson.GetBytes(got, "anthropic_version").String())
require.Equal(t, "hello", gjson.GetBytes(got, "messages.0.content").String())
}
func readRequestBodyForTest(t *testing.T, req *http.Request) []byte {
t.Helper()
require.NotNil(t, req.Body)
body, err := io.ReadAll(req.Body)
require.NoError(t, err)
return body
}
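
The test above pins down the non-streaming endpoint shape. A minimal sketch of what buildVertexAnthropicURL might look like, consistent with that assertion — the parameter validation and the `:streamRawPredict` streaming verb are assumptions based on the public Vertex AI Anthropic API, not taken from this diff:

```go
package main

import (
	"errors"
	"fmt"
)

// buildVertexAnthropicURL builds the regional Vertex AI host plus the publisher
// model path: :rawPredict for non-streaming calls, :streamRawPredict for
// streaming calls (the streaming verb is an assumption; the test only covers
// the non-streaming form).
func buildVertexAnthropicURL(projectID, location, modelID string, stream bool) (string, error) {
	if projectID == "" || location == "" || modelID == "" {
		return "", errors.New("vertex service account requires project_id, location, and model")
	}
	verb := "rawPredict"
	if stream {
		verb = "streamRawPredict"
	}
	return fmt.Sprintf(
		"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
		location, projectID, location, modelID, verb,
	), nil
}

func main() {
	u, err := buildVertexAnthropicURL("vertex-proj", "us-east5", "claude-sonnet-4-5@20250929", false)
	if err != nil {
		panic(err)
	}
	fmt.Println(u)
}
```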

View File

@@ -61,10 +61,15 @@ func (s *GatewayService) ForwardAsChatCompletions(
// 4. Model mapping // 4. Model mapping
mappedModel := originalModel mappedModel := originalModel
if account.Type == AccountTypeAPIKey { if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(originalModel) mappedModel = account.GetMappedModel(originalModel)
} }
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(originalModel))
if normalized != originalModel {
mappedModel = normalized
}
} else if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
normalized := claude.NormalizeModelID(originalModel) normalized := claude.NormalizeModelID(originalModel)
if normalized != originalModel { if normalized != originalModel {
mappedModel = normalized mappedModel = normalized

View File

@@ -58,10 +58,15 @@ func (s *GatewayService) ForwardAsResponses(
// 4. Model mapping // 4. Model mapping
mappedModel := originalModel mappedModel := originalModel
reasoningEffort := ExtractResponsesReasoningEffortFromBody(body) reasoningEffort := ExtractResponsesReasoningEffortFromBody(body)
if account.Type == AccountTypeAPIKey { if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(originalModel) mappedModel = account.GetMappedModel(originalModel)
} }
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(originalModel))
if normalized != originalModel {
mappedModel = normalized
}
} else if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
normalized := claude.NormalizeModelID(originalModel) normalized := claude.NormalizeModelID(originalModel)
if normalized != originalModel { if normalized != originalModel {
mappedModel = normalized mappedModel = normalized

View File

@@ -11,6 +11,7 @@ import (
"io" "io"
"log/slog" "log/slog"
mathrand "math/rand" mathrand "math/rand"
"net"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@@ -20,6 +21,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync/atomic" "sync/atomic"
"syscall"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
@@ -3597,7 +3599,11 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
} }
// OAuth/SetupToken 账号使用 Anthropic 标准映射短ID → 长ID // OAuth/SetupToken 账号使用 Anthropic 标准映射短ID → 长ID
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
requestedModel = claude.NormalizeModelID(requestedModel) if account.Type == AccountTypeServiceAccount {
requestedModel = normalizeVertexAnthropicModelID(claude.NormalizeModelID(requestedModel))
} else {
requestedModel = claude.NormalizeModelID(requestedModel)
}
} }
// 其他平台使用账户的模型支持检查 // 其他平台使用账户的模型支持检查
return account.IsModelSupported(requestedModel) return account.IsModelSupported(requestedModel)
@@ -3617,6 +3623,18 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
return apiKey, "apikey", nil return apiKey, "apikey", nil
case AccountTypeBedrock: case AccountTypeBedrock:
return "", "bedrock", nil // Bedrock 使用 SigV4 签名或 API Key由 forwardBedrock 处理 return "", "bedrock", nil // Bedrock 使用 SigV4 签名或 API Key由 forwardBedrock 处理
case AccountTypeServiceAccount:
if account.Platform != PlatformAnthropic {
return "", "", fmt.Errorf("unsupported service account platform: %s", account.Platform)
}
if s.claudeTokenProvider == nil {
return "", "", errors.New("claude token provider not configured")
}
accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
if err != nil {
return "", "", err
}
return accessToken, "service_account", nil
default: default:
return "", "", fmt.Errorf("unsupported account type: %s", account.Type) return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
} }
@@ -4219,6 +4237,18 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
mappingSource = "account" mappingSource = "account"
} }
} }
if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
if candidate, matched := account.ResolveMappedModel(reqModel); matched {
mappedModel = candidate
mappingSource = "account"
} else {
normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(reqModel))
if normalized != reqModel {
mappedModel = normalized
mappingSource = "vertex"
}
}
}
if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
normalized := claude.NormalizeModelID(reqModel) normalized := claude.NormalizeModelID(reqModel)
if normalized != reqModel { if normalized != reqModel {
@@ -5688,6 +5718,10 @@ func (s *GatewayService) handleBedrockNonStreamingResponse(
} }
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) { func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) {
if account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
return s.buildUpstreamRequestAnthropicVertex(ctx, c, account, body, token, modelID, reqStream)
}
// 确定目标URL // 确定目标URL
targetURL := claudeAPIURL targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey { if account.Type == AccountTypeAPIKey {
@@ -5874,6 +5908,60 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
return req, nil return req, nil
} }
func (s *GatewayService) buildUpstreamRequestAnthropicVertex(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
token string,
modelID string,
reqStream bool,
) (*http.Request, error) {
vertexBody, err := buildVertexAnthropicRequestBody(body)
if err != nil {
return nil, err
}
setOpsUpstreamRequestBody(c, vertexBody)
fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, reqStream)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(vertexBody))
if err != nil {
return nil, err
}
if c != nil && c.Request != nil {
for key, values := range c.Request.Header {
lowerKey := strings.ToLower(strings.TrimSpace(key))
if !allowedHeaders[lowerKey] || lowerKey == "anthropic-version" {
continue
}
wireKey := resolveWireCasing(key)
for _, v := range values {
addHeaderRaw(req.Header, wireKey, v)
}
}
}
req.Header.Del("authorization")
req.Header.Del("x-api-key")
req.Header.Del("x-goog-api-key")
req.Header.Del("cookie")
req.Header.Del("anthropic-version")
setHeaderRaw(req.Header, "authorization", "Bearer "+token)
setHeaderRaw(req.Header, "content-type", "application/json")
s.debugLogGatewaySnapshot("UPSTREAM_FORWARD_VERTEX_ANTHROPIC", req.Header, vertexBody, map[string]string{
"url": req.URL.String(),
"token_type": "service_account",
"model": modelID,
"stream": strconv.FormatBool(reqStream),
})
return req, nil
}
// getBetaHeader 处理anthropic-beta header // getBetaHeader 处理anthropic-beta header
// 对于OAuth账号需要确保包含oauth-2025-04-20 // 对于OAuth账号需要确保包含oauth-2025-04-20
func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string { func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string {
@@ -6434,6 +6522,49 @@ func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
return false return false
} }
// sanitizeStreamError returns a client-visible error description with all network addresses stripped.
// The default (*net.OpError).Error() concatenates the Source/Addr fields, leaking internal IP/port
// pairs and the upstream server address (e.g. "read tcp 10.0.0.1:54321->52.1.2.3:443: read: connection
// reset by peer"). This function keeps only a recognizable error category; the original err is still
// written to the logs at the call site.
func sanitizeStreamError(err error) string {
if err == nil {
return ""
}
switch {
case errors.Is(err, io.ErrUnexpectedEOF):
return "unexpected EOF"
case errors.Is(err, io.EOF):
return "EOF"
case errors.Is(err, context.Canceled):
return "canceled"
case errors.Is(err, context.DeadlineExceeded):
return "deadline exceeded"
case errors.Is(err, syscall.ECONNRESET):
return "connection reset by peer"
case errors.Is(err, syscall.ECONNABORTED):
return "connection aborted"
case errors.Is(err, syscall.ETIMEDOUT):
return "connection timed out"
case errors.Is(err, syscall.EPIPE):
return "broken pipe"
case errors.Is(err, syscall.ECONNREFUSED):
return "connection refused"
}
var netErr *net.OpError
if errors.As(err, &netErr) {
if netErr.Timeout() {
if netErr.Op != "" {
return netErr.Op + " timeout"
}
return "i/o timeout"
}
if netErr.Op != "" {
return netErr.Op + " network error"
}
}
return "upstream connection error"
}
// ExtractUpstreamErrorMessage 从上游响应体中提取错误消息 // ExtractUpstreamErrorMessage 从上游响应体中提取错误消息
// 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}} // 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}}
func ExtractUpstreamErrorMessage(body []byte) string { func ExtractUpstreamErrorMessage(body []byte) string {
@@ -6871,14 +7002,31 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
} }
lastDataAt := time.Now() lastDataAt := time.Now()
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端) // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
// Error events follow the Anthropic SSE standard: {"type":"error","error":{"type":<reason>,"message":<message>}}.
// This lets the Anthropic SDK / Claude Code and similar clients parse it as a standard error type and
// surface the concrete message in their UI, and lets the server-side ExtractUpstreamErrorMessage
// recover the message from a passed-through body.
errorEventSent := false errorEventSent := false
sendErrorEvent := func(reason string) { sendErrorEvent := func(reason, message string) {
if errorEventSent { if errorEventSent {
return return
} }
errorEventSent = true errorEventSent = true
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason) if message == "" {
message = reason
}
body, err := json.Marshal(map[string]any{
"type": "error",
"error": map[string]string{
"type": reason,
"message": message,
},
})
if err != nil {
// json.Marshal cannot fail on this known string-only input; keep a conservative fallback anyway
body = []byte(fmt.Sprintf(`{"type":"error","error":{"type":%q,"message":%q}}`, reason, message))
}
_, _ = fmt.Fprintf(w, "event: error\ndata: %s\n\n", body)
flusher.Flush() flusher.Flush()
} }
@@ -7038,10 +7186,32 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
// 客户端未断开,正常的错误处理 // 客户端未断开,正常的错误处理
if errors.Is(ev.err, bufio.ErrTooLong) { if errors.Is(ev.err, bufio.ErrTooLong) {
logger.LegacyPrintf("service.gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) logger.LegacyPrintf("service.gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large") sendErrorEvent("response_too_large", fmt.Sprintf("upstream SSE line exceeded %d bytes", maxLineSize))
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
} }
sendErrorEvent("stream_read_error") // 上游中途读错误unexpected EOF / connection reset 等,常见于 HTTP/2 GOAWAY
// 若尚未向客户端写过任何字节,包成 UpstreamFailoverError 让 handler 层走 failover/重试。
// 已经开始写流时 SSE 协议无 resume只能透传错误事件给客户端。
// 注意:面向客户端的 disconnectMsg 必须用 sanitizeStreamError 剥离地址,
// 默认 *net.OpError 的 Error() 会泄露内部 IP/端口和上游地址。完整 ev.err
// 仅在下方 LegacyPrintf 内部日志中保留供运维诊断。
disconnectMsg := "upstream stream disconnected: " + sanitizeStreamError(ev.err)
if !c.Writer.Written() {
logger.LegacyPrintf("service.gateway", "Upstream stream read error before any client output (account=%d), failing over: %v", account.ID, ev.err)
body, _ := json.Marshal(map[string]any{
"type": "error",
"error": map[string]string{
"type": "upstream_disconnected",
"message": disconnectMsg,
},
})
return nil, &UpstreamFailoverError{
StatusCode: http.StatusBadGateway,
ResponseBody: body,
RetryableOnSameAccount: true,
}
}
sendErrorEvent("stream_read_error", disconnectMsg)
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
} }
line := ev.line line := ev.line
@@ -7100,7 +7270,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
if s.rateLimitService != nil {
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
}
sendErrorEvent("stream_timeout", fmt.Sprintf("upstream stream idle for %s", streamInterval))
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:


@@ -4,9 +4,12 @@ package service
import (
"context"
"errors"
"io"
"net"
"net/http"
"net/http/httptest"
"syscall"
"testing"
"time"
@@ -218,3 +221,175 @@ func TestHandleStreamingResponse_SpecialCharactersInJSON(t *testing.T) {
body := rec.Body.String()
require.Contains(t, body, "content_block_delta", "response should contain the forwarded SSE event")
}
// A mid-stream upstream read error (e.g. unexpected EOF triggered by HTTP/2 GOAWAY) before any bytes reach the client:
// the gateway should return *UpstreamFailoverError to trigger account failover/retry instead of emitting an error event to the client.
func TestHandleStreamingResponse_StreamReadErrorBeforeOutput_TriggersFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newMinimalGatewayService()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
Body: &streamReadCloser{err: io.ErrUnexpectedEOF},
}
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
require.Error(t, err)
require.Nil(t, result, "no streamingResult should be returned in the failover handoff case")
var failoverErr *UpstreamFailoverError
require.True(t, errors.As(err, &failoverErr), "a stream read error before any output must be wrapped as UpstreamFailoverError, got: %v", err)
require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
require.True(t, failoverErr.RetryableOnSameAccount, "GOAWAY-class errors should allow retry on the same account")
// ResponseBody must use the standard Anthropic error format:
// 1) ExtractUpstreamErrorMessage can extract the message from error.message (relied on by handleFailoverExhausted / ops logs)
// 2) error.type is marked upstream_disconnected
extractedMsg := ExtractUpstreamErrorMessage(failoverErr.ResponseBody)
require.NotEmpty(t, extractedMsg, "ExtractUpstreamErrorMessage must extract a non-empty message from ResponseBody, otherwise ops logs lose diagnostics")
require.Contains(t, extractedMsg, "upstream stream disconnected")
require.Contains(t, string(failoverErr.ResponseBody), `"type":"error"`)
require.Contains(t, string(failoverErr.ResponseBody), `"upstream_disconnected"`)
// The client must not receive any stream_read_error event; the handler layer decides based on the failover outcome.
require.NotContains(t, rec.Body.String(), "stream_read_error")
}
// A read error after the upstream has already emitted events (c.Writer has written bytes):
// SSE has no resume, so the gateway can only pass a stream_read_error event through to the client and must not fail over.
func TestHandleStreamingResponse_StreamReadErrorAfterOutput_PassesThrough(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newMinimalGatewayService()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
// The first Read returns a complete SSE event so the gateway writes bytes to the client; the second Read fails with unexpected EOF.
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
Body: &streamReadCloser{
payload: []byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":5}}}\n\n"),
err: io.ErrUnexpectedEOF,
},
}
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
require.Error(t, err)
require.Contains(t, err.Error(), "stream read error", "after the stream has started, a plain stream read error should be passed through")
require.NotNil(t, result, "the pass-through case should return the streamingResult collected so far")
// Must not be mistakenly wrapped as a failover error.
var failoverErr *UpstreamFailoverError
require.False(t, errors.As(err, &failoverErr), "must not fail over once bytes have been written to the client")
// The client must receive an SSE error event in the standard Anthropic format (error.type=stream_read_error),
// with error.message carrying the concrete root cause (so SDKs can parse it and UIs can display it).
body := rec.Body.String()
require.Contains(t, body, "event: error\n", "must send the error event frame per the Anthropic SSE standard")
require.Contains(t, body, `"type":"error"`, "data must contain the top-level type:error field (Anthropic standard)")
require.Contains(t, body, `"stream_read_error"`, "error.type must be stream_read_error")
require.Contains(t, body, "upstream stream disconnected", "error.message must include the concrete root cause so clients like Claude Code can show a useful message")
}
// By default (*net.OpError).Error() concatenates the Source/Addr fields, leaking internal IP/port and the
// upstream server address. sanitizeStreamError must strip this information so infrastructure topology
// does not reach clients via the failover ResponseBody or SSE error frames.
func TestSanitizeStreamError_StripsNetworkAddresses(t *testing.T) {
src, err := net.ResolveTCPAddr("tcp", "10.0.0.1:54321")
require.NoError(t, err)
dst, err := net.ResolveTCPAddr("tcp", "52.1.2.3:443")
require.NoError(t, err)
raw := &net.OpError{
Op: "read",
Net: "tcp",
Source: src,
Addr: dst,
Err: syscall.ECONNRESET,
}
// Precondition: the raw Error() really does contain the leaking fields (so the test won't pass silently if Go's behavior changes).
require.Contains(t, raw.Error(), "10.0.0.1")
require.Contains(t, raw.Error(), "52.1.2.3")
got := sanitizeStreamError(raw)
require.NotContains(t, got, "10.0.0.1", "must not leak the internal source IP")
require.NotContains(t, got, "54321", "must not leak the source port")
require.NotContains(t, got, "52.1.2.3", "must not leak the upstream destination IP")
require.NotContains(t, got, "443", "must not leak the upstream port")
require.Equal(t, "connection reset by peer", got)
}
func TestSanitizeStreamError_KnownErrors(t *testing.T) {
cases := []struct {
name string
err error
want string
}{
{"unexpected EOF", io.ErrUnexpectedEOF, "unexpected EOF"},
{"EOF", io.EOF, "EOF"},
{"context canceled", context.Canceled, "canceled"},
{"deadline exceeded", context.DeadlineExceeded, "deadline exceeded"},
{"ECONNRESET direct", syscall.ECONNRESET, "connection reset by peer"},
{"EPIPE", syscall.EPIPE, "broken pipe"},
{"ETIMEDOUT", syscall.ETIMEDOUT, "connection timed out"},
{"unrecognized error fallback", errors.New("weird internal error"), "upstream connection error"},
{"nil yields empty string", nil, ""},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
require.Equal(t, tc.want, sanitizeStreamError(tc.err))
})
}
}
// The failover ResponseBody must use the sanitized message, so internal address information is not
// leaked to clients or written into ops logs.
func TestHandleStreamingResponse_FailoverBodyDoesNotLeakAddresses(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newMinimalGatewayService()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
src, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:54321")
dst, _ := net.ResolveTCPAddr("tcp", "52.1.2.3:443")
netErr := &net.OpError{
Op: "read",
Net: "tcp",
Source: src,
Addr: dst,
Err: syscall.ECONNRESET,
}
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
Body: &streamReadCloser{err: netErr},
}
_, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.True(t, errors.As(err, &failoverErr))
body := string(failoverErr.ResponseBody)
require.NotContains(t, body, "10.0.0.1", "failover ResponseBody must not leak the internal source IP")
require.NotContains(t, body, "54321")
require.NotContains(t, body, "52.1.2.3", "failover ResponseBody must not leak the upstream IP")
require.NotContains(t, body, "443")
// The diagnosable root cause is still present.
require.Contains(t, body, "connection reset by peer")
require.Contains(t, body, "upstream stream disconnected")
}


@@ -515,6 +515,10 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
}
// Code Assist OAuth tokens often lack AI Studio scopes for models listing.
return 3
case AccountTypeServiceAccount:
// Vertex service accounts use aiplatform.googleapis.com, not the AI Studio
// endpoint (generativelanguage.googleapis.com), so they cannot serve these requests.
return 999
default:
return 10
}
@@ -579,7 +583,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
originalModel := req.Model
mappedModel := req.Model
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(req.Model)
}
@@ -712,6 +716,36 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
requestIDHeader = "x-request-id"
case AccountTypeServiceAccount:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
if s.tokenProvider == nil {
return nil, "", errors.New("gemini token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, "", err
}
action := "generateContent"
if req.Stream {
action = "streamGenerateContent"
}
fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, action, req.Stream)
if err != nil {
return nil, "", err
}
restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq)
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
return upstreamReq, "x-request-id", nil
}
requestIDHeader = "x-request-id"
default:
return nil, fmt.Errorf("unsupported account type: %s", account.Type)
}
@@ -1094,7 +1128,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
body = ensureGeminiFunctionCallThoughtSignatures(body)
mappedModel := originalModel
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(originalModel)
}
@@ -1213,6 +1247,31 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
requestIDHeader = "x-request-id"
case AccountTypeServiceAccount:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
if s.tokenProvider == nil {
return nil, "", errors.New("gemini token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, "", err
}
fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, upstreamAction, useUpstreamStream)
if err != nil {
return nil, "", err
}
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
return upstreamReq, "x-request-id", nil
}
requestIDHeader = "x-request-id"
default:
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Unsupported account type: "+account.Type)
}
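buildVertexGeminiURL is visible here only through its call sites. For reference, Vertex AI's REST surface addresses publisher models as projects/{project}/locations/{location}/publishers/google/models/{model}:{action} on a regional aiplatform.googleapis.com host, with alt=sse selecting server-sent events for streaming. A hedged sketch of what the helper presumably produces (the real one may validate and handle regions differently):

```go
package main

import (
	"fmt"
	"net/url"
)

// buildVertexGeminiURL sketches the Vertex endpoint URL implied by the call
// sites above. action is "generateContent" or "streamGenerateContent".
func buildVertexGeminiURL(projectID, location, model, action string, stream bool) (string, error) {
	if projectID == "" || location == "" || model == "" {
		return "", fmt.Errorf("vertex url: project/location/model must be non-empty")
	}
	host := location + "-aiplatform.googleapis.com"
	if location == "global" {
		// The global endpoint has no regional prefix.
		host = "aiplatform.googleapis.com"
	}
	u := fmt.Sprintf("https://%s/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
		host,
		url.PathEscape(projectID), url.PathEscape(location), url.PathEscape(model), action)
	if stream {
		u += "?alt=sse" // stream responses as server-sent events
	}
	return u, nil
}

func main() {
	u, _ := buildVertexGeminiURL("my-proj", "us-central1", "gemini-2.5-pro", "streamGenerateContent", true)
	fmt.Println(u)
}
```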


@@ -15,7 +15,7 @@ const (
geminiTokenCacheSkew = 5 * time.Minute
)
// GeminiTokenProvider manages access_token for Gemini OAuth and Vertex service account accounts.
type GeminiTokenProvider struct {
accountRepo AccountRepository
tokenCache GeminiTokenCache
@@ -53,8 +53,11 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformGemini || (account.Type != AccountTypeOAuth && account.Type != AccountTypeServiceAccount) {
return "", errors.New("not a gemini oauth or service account")
}
if account.Type == AccountTypeServiceAccount {
return p.getServiceAccountAccessToken(ctx, account)
}
cacheKey := GeminiTokenCacheKey(account)
@@ -168,7 +171,16 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return accessToken, nil
}
func (p *GeminiTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) {
return getVertexServiceAccountAccessToken(ctx, p.tokenCache, account)
}
func GeminiTokenCacheKey(account *Account) string {
if account != nil && account.Type == AccountTypeServiceAccount {
if key, err := parseVertexServiceAccountKey(account); err == nil {
return vertexServiceAccountCacheKey(account, key)
}
}
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {
return "gemini:" + projectID
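vertexServiceAccountCacheKey and parseVertexServiceAccountKey are not shown in this hunk. Whatever their exact shape, the cache key must be stable per credential so tokens for different service accounts (or a rotated key) never collide in the shared token cache. A hypothetical stdlib-only sketch of that idea; the names and fields here are illustrative, not the project's actual helpers:

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// serviceAccountCacheKey derives a stable cache key from the credential
// identity. client_email and private_key_id are fields every Google
// service-account JSON key carries; hashing them keeps the key short and
// avoids leaking the email into cache key listings. Hypothetical helper.
func serviceAccountCacheKey(platform, clientEmail, privateKeyID string) string {
	sum := sha256.Sum256([]byte(clientEmail + "\x00" + privateKeyID))
	return platform + ":sa:" + hex.EncodeToString(sum[:8])
}

func main() {
	fmt.Println(serviceAccountCacheKey("gemini", "svc@proj.iam.gserviceaccount.com", "key-1"))
}
```

Including private_key_id in the hash means a key rotation produces a fresh cache entry instead of serving a token minted from the revoked key.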


@@ -53,6 +53,23 @@ const (
codexSparkImageUnsupportedText = codexSparkImageUnsupportedMarker + "\nThe current model is gpt-5.3-codex-spark, which does not support image generation, image editing, image input, the `image_generation` tool, or Codex `image_gen`/`$imagegen` workflows. If the user asks for image generation or image editing, clearly explain this model limitation and ask them to switch to a non-Spark Codex model such as gpt-5.3-codex or gpt-5.4. Do not claim that the local environment merely lacks image_gen tooling, and do not suggest CLI fallback as the primary fix while the model remains Spark.\n</sub2api-codex-spark-image-unsupported>"
)
var openAIChatGPTInternalUnsupportedFields = []string{
"user",
"metadata",
"prompt_cache_retention",
"safety_identifier",
"stream_options",
}
var openAICodexOAuthUnsupportedFields = append([]string{
"max_output_tokens",
"max_completion_tokens",
"temperature",
"top_p",
"frequency_penalty",
"presence_penalty",
}, openAIChatGPTInternalUnsupportedFields...)
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult {
result := codexTransformResult{}
// Tool-call chaining requirements affect the storage policy and input filtering logic.
@@ -93,23 +110,8 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
}
}
// Strip parameters unsupported by ChatGPT internal Codex endpoint.
for _, key := range openAICodexOAuthUnsupportedFields {
if _, ok := reqBody[key]; ok {
delete(reqBody, key)
result.Modified = true
@@ -141,9 +143,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
if name, ok := fcObj["name"].(string); ok && strings.TrimSpace(name) != "" {
reqBody["tool_choice"] = map[string]any{
"type": "function",
"name": name,
}
}
}
@@ -219,9 +219,38 @@ func normalizeCodexToolChoice(reqBody map[string]any) bool {
return false
}
choiceType := strings.TrimSpace(firstNonEmptyString(choiceMap["type"]))
if choiceType == "" {
return false
}
modified := false
if choiceType == "function" {
name := strings.TrimSpace(firstNonEmptyString(choiceMap["name"]))
if name == "" {
if function, ok := choiceMap["function"].(map[string]any); ok {
name = strings.TrimSpace(firstNonEmptyString(function["name"]))
}
}
if name == "" {
reqBody["tool_choice"] = "auto"
return true
}
if strings.TrimSpace(firstNonEmptyString(choiceMap["name"])) != name {
choiceMap["name"] = name
modified = true
}
if _, ok := choiceMap["function"]; ok {
delete(choiceMap, "function")
modified = true
}
if !codexToolsContainFunctionName(reqBody["tools"], name) {
reqBody["tool_choice"] = "auto"
return true
}
return modified
}
if codexToolsContainType(reqBody["tools"], choiceType) {
return modified
}
reqBody["tool_choice"] = "auto"
return true
}
@@ -243,6 +272,33 @@ func codexToolsContainType(rawTools any, toolType string) bool {
return false
}
func codexToolsContainFunctionName(rawTools any, name string) bool {
tools, ok := rawTools.([]any)
if !ok || strings.TrimSpace(name) == "" {
return false
}
normalizedName := strings.TrimSpace(name)
for _, rawTool := range tools {
tool, ok := rawTool.(map[string]any)
if !ok {
continue
}
if strings.TrimSpace(firstNonEmptyString(tool["type"])) != "function" {
continue
}
toolName := strings.TrimSpace(firstNonEmptyString(tool["name"]))
if toolName == "" {
if function, ok := tool["function"].(map[string]any); ok {
toolName = strings.TrimSpace(firstNonEmptyString(function["name"]))
}
}
if toolName == normalizedName {
return true
}
}
return false
}
func normalizeCodexToolRoleMessages(input []any) ([]any, bool) {
if len(input) == 0 {
return input, false
@@ -853,6 +909,14 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
}
typ, _ := m["type"].(string)
// chatgpt.com codex backend (OAuth path) does not persist reasoning
// items because applyCodexOAuthTransform forces store=false. Any rs_*
// reference replayed in input is guaranteed to 404 upstream
// ("Item with id 'rs_...' not found"). Drop reasoning items entirely.
if typ == "reasoning" {
continue
}
// Only fix genuine tool/function call identifiers, to avoid mangling ordinary message/reasoning ids.
// If an item_reference points at a legacy call_* identifier, fix only that reference itself.
fixCallIDPrefix := func(id string) string {
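The tool_choice changes above converge on the flat Responses-API shape {"type":"function","name":...} in place of the legacy Chat Completions nesting {"type":"function","function":{"name":...}}. A standalone sketch of just that hoist, simplified from the diff (it omits the tools-whitelist check that downgrades to "auto"):

```go
package main

import "fmt"

// flattenFunctionToolChoice hoists a nested function.name to the top level
// and drops the legacy "function" wrapper. Simplified sketch: the code in
// the diff additionally falls back to "auto" when the named tool is absent
// from tools, which is not reproduced here.
func flattenFunctionToolChoice(choice map[string]any) map[string]any {
	if t, _ := choice["type"].(string); t != "function" {
		return choice // only function-typed choices use the legacy nesting
	}
	if _, ok := choice["name"].(string); !ok {
		if fn, ok := choice["function"].(map[string]any); ok {
			if name, ok := fn["name"].(string); ok {
				choice["name"] = name
			}
		}
	}
	delete(choice, "function")
	return choice
}

func main() {
	legacy := map[string]any{
		"type":     "function",
		"function": map[string]any{"name": "shell"},
	}
	fmt.Println(flattenFunctionToolChoice(legacy)) // map[name:shell type:function]
}
```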


@@ -1,6 +1,8 @@
package service
import (
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/require"
@@ -249,6 +251,44 @@ func TestApplyCodexOAuthTransform_PreservesKnownToolChoice(t *testing.T) {
require.Equal(t, "custom", choice["type"])
}
func TestApplyCodexOAuthTransform_NormalizesLegacyFunctionToolChoice(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"tools": []any{
map[string]any{"type": "function", "name": "shell"},
},
"tool_choice": map[string]any{
"type": "function",
"function": map[string]any{"name": "shell"},
},
}
applyCodexOAuthTransform(reqBody, true, false)
choice, ok := reqBody["tool_choice"].(map[string]any)
require.True(t, ok)
require.Equal(t, "function", choice["type"])
require.Equal(t, "shell", choice["name"])
require.NotContains(t, choice, "function")
}
func TestApplyCodexOAuthTransform_DowngradesMissingFunctionToolChoice(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"tools": []any{
map[string]any{"type": "function", "name": "shell"},
},
"tool_choice": map[string]any{
"type": "function",
"function": map[string]any{"name": "missing"},
},
}
applyCodexOAuthTransform(reqBody, true, false)
require.Equal(t, "auto", reqBody["tool_choice"])
}
func TestApplyCodexOAuthTransform_AddsFallbackNameForFunctionCallInput(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
@@ -1048,6 +1088,27 @@ func TestApplyCodexOAuthTransform_StripsPromptCacheRetention(t *testing.T) {
"prompt_cache_retention must be stripped before forwarding to Codex upstream")
}
func TestApplyCodexOAuthTransform_StripsChatGPTInternalUnsupportedFields(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"user": "user_123",
"metadata": map[string]any{"trace_id": "abc"},
"prompt_cache_retention": "24h",
"safety_identifier": "sid",
"stream_options": map[string]any{"include_usage": true},
"input": []any{
map[string]any{"role": "user", "content": "hi"},
},
}
result := applyCodexOAuthTransform(reqBody, true, false)
require.True(t, result.Modified)
for _, field := range openAIChatGPTInternalUnsupportedFields {
require.NotContains(t, reqBody, field)
}
}
func TestApplyCodexOAuthTransform_ExtractsSystemMessages(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.1",
@@ -1094,3 +1155,56 @@ func TestIsInstructionsEmpty(t *testing.T) {
})
}
}
func TestFilterCodexInput_DropsReasoningItemsRegardlessOfPreserveReferences(t *testing.T) {
// Reasoning items in input[] reference rs_* IDs that were emitted by
// chatgpt.com under store=false (forced by applyCodexOAuthTransform).
// They are never persisted upstream, so forwarding them produces a
// guaranteed 404 ("Item with id 'rs_...' not found"). Drop them
// regardless of preserveReferences. See: Wei-Shaw/sub2api issue #1957.
build := func() []any {
return []any{
map[string]any{"type": "message", "id": "msg_0", "role": "user", "content": "hi"},
map[string]any{
"type": "reasoning",
"id": "rs_0672f12450da0b9c0169f07220a6c08198b68c2455ced99344",
"summary": []any{},
},
map[string]any{"type": "function_call", "id": "fc_1", "call_id": "call_1", "name": "tool"},
map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "{}"},
}
}
for _, preserve := range []bool{true, false} {
preserve := preserve
t.Run(fmt.Sprintf("preserveReferences=%v", preserve), func(t *testing.T) {
filtered := filterCodexInput(build(), preserve)
for _, raw := range filtered {
item, ok := raw.(map[string]any)
require.True(t, ok)
require.NotEqual(t, "reasoning", item["type"],
"reasoning items must be dropped from input on the OAuth path")
if id, ok := item["id"].(string); ok {
require.False(t, strings.HasPrefix(id, "rs_"),
"no item carrying an rs_* id should survive the filter")
}
}
// Sanity check: the non-reasoning items should still be present.
gotTypes := make(map[string]int)
for _, raw := range filtered {
item, ok := raw.(map[string]any)
require.True(t, ok)
typ, ok := item["type"].(string)
require.True(t, ok)
gotTypes[typ]++
}
require.Equal(t, 1, gotTypes["message"])
require.Equal(t, 1, gotTypes["function_call"])
require.Equal(t, 1, gotTypes["function_call_output"])
require.Equal(t, 0, gotTypes["reasoning"])
})
}
}


@@ -0,0 +1,286 @@
package service
import (
"context"
"encoding/json"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type openAIFastPolicyRepoStub struct {
values map[string]string
}
func (s *openAIFastPolicyRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
panic("unexpected Get call")
}
func (s *openAIFastPolicyRepoStub) GetValue(ctx context.Context, key string) (string, error) {
if v, ok := s.values[key]; ok {
return v, nil
}
return "", ErrSettingNotFound
}
func (s *openAIFastPolicyRepoStub) Set(ctx context.Context, key, value string) error {
if s.values == nil {
s.values = map[string]string{}
}
s.values[key] = value
return nil
}
func (s *openAIFastPolicyRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
}
func (s *openAIFastPolicyRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *openAIFastPolicyRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *openAIFastPolicyRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
func newOpenAIGatewayServiceWithSettings(t *testing.T, settings *OpenAIFastPolicySettings) *OpenAIGatewayService {
t.Helper()
repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
if settings != nil {
raw, err := json.Marshal(settings)
require.NoError(t, err)
repo.values[SettingKeyOpenAIFastPolicySettings] = string(raw)
}
return &OpenAIGatewayService{
settingService: NewSettingService(repo, &config.Config{}),
}
}
func TestEvaluateOpenAIFastPolicy_DefaultFiltersAllModelsPriority(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// The default policy applies to all models (the whitelist is empty), because codex's
// service_tier=fast is a user-level toggle, orthogonal to the model.
// gpt-5.5 + priority → filter
action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
require.Equal(t, BetaPolicyActionFilter, action)
// gpt-5.5-turbo → filter
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5-turbo", OpenAIFastTierPriority)
require.Equal(t, BetaPolicyActionFilter, action)
// gpt-4 + priority → filter (the default policy covers all models)
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-4", OpenAIFastTierPriority)
require.Equal(t, BetaPolicyActionFilter, action)
// gpt-5.5 + flex → pass (tier doesn't match)
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierFlex)
require.Equal(t, BetaPolicyActionPass, action)
// empty tier → pass
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", "")
require.Equal(t, BetaPolicyActionPass, action)
}
func TestEvaluateOpenAIFastPolicy_BlockRuleCarriesMessage(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "fast mode is not allowed",
ModelWhitelist: []string{"gpt-5.5"},
FallbackAction: BetaPolicyActionPass,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
action, msg := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
require.Equal(t, BetaPolicyActionBlock, action)
require.Equal(t, "fast mode is not allowed", msg)
}
func TestEvaluateOpenAIFastPolicy_ScopeFiltersOAuth(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierAny,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeOAuth,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
// OAuth account → rule matches
oauthAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth}
action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), oauthAccount, "gpt-4", OpenAIFastTierPriority)
require.Equal(t, BetaPolicyActionFilter, action)
// API Key account → rule skipped → pass
apiKeyAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), apiKeyAccount, "gpt-4", OpenAIFastTierPriority)
require.Equal(t, BetaPolicyActionPass, action)
}
func TestApplyOpenAIFastPolicyToBody_FilterRemovesField(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// gpt-5.5 fast → service_tier stripped
body := []byte(`{"model":"gpt-5.5","service_tier":"priority","messages":[]}`)
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.NoError(t, err)
require.NotContains(t, string(updated), `"service_tier"`)
// Client sending "fast" (alias for priority) also filtered
body = []byte(`{"model":"gpt-5.5","service_tier":"fast"}`)
updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.NoError(t, err)
require.NotContains(t, string(updated), `"service_tier"`)
// gpt-4 priority → the default policy filters all models; service_tier is removed
body = []byte(`{"model":"gpt-4","service_tier":"priority"}`)
updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
require.NoError(t, err)
require.NotContains(t, string(updated), `"service_tier"`)
// No service_tier → no-op
body = []byte(`{"model":"gpt-5.5"}`)
updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.NoError(t, err)
require.Equal(t, string(body), string(updated))
}
// TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule verifies that, with the extended whitelist,
// official OpenAI tiers (auto/default/scale) explicitly sent by clients pass through to upstream instead of
// being silently stripped. The default policy only targets priority, so these tiers hit the fall-through pass branch.
func TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
for _, tier := range []string{"auto", "default", "scale"} {
body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.NoError(t, err, "tier %q should pass without error", tier)
require.Contains(t, string(updated), `"service_tier":"`+tier+`"`,
"tier %q should be preserved in body under default rule", tier)
}
// The evaluate layer should also judge these as pass (the default rule's ServiceTier=priority does not match auto/default/scale).
for _, tier := range []string{"auto", "default", "scale"} {
action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", tier)
require.Equal(t, BetaPolicyActionPass, action, "tier %q should evaluate to pass", tier)
}
}
// TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers verifies that once an admin explicitly configures
// a ServiceTier=all + Action=filter rule, official tiers such as auto/default/scale are stripped as well.
// This is expected: the first matching rule short-circuits, and "all" covers any recognized tier.
func TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierAny,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
for _, tier := range []string{"auto", "default", "scale", "priority", "flex"} {
body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.NoError(t, err)
require.NotContains(t, string(updated), `"service_tier"`,
"tier %q should be stripped under ServiceTier=all + filter rule", tier)
}
}
// TestApplyOpenAIFastPolicyToBody_UnknownTierStripped verifies that truly
// unknown tiers are still stripped: normalize returns nil, so
// normalizeResponsesBodyServiceTier deletes the field; applyOpenAIFastPolicyToBody
// is a direct no-op when normTier is empty, since the field can no longer exist
// in a request that went through the upfront normalization. Here we call apply
// directly to verify it does not misbehave on an unrecognized value.
func TestApplyOpenAIFastPolicyToBody_UnknownTierStripped(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// The normalize stage strips unknown values.
require.Nil(t, normalizeOpenAIServiceTier("xxx"))
// applyOpenAIFastPolicyToBody does not error on an unrecognized tier; the body
// passes through unchanged (stripping is not this function's job —
// normalizeResponsesBodyServiceTier upstream has already removed the field).
body := []byte(`{"model":"gpt-5.5","service_tier":"xxx"}`)
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.NoError(t, err)
require.Equal(t, string(body), string(updated))
}
func TestApplyOpenAIFastPolicyToBody_BlockReturnsTypedError(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "fast mode is blocked for gpt-5.5",
ModelWhitelist: []string{"gpt-5.5"},
FallbackAction: BetaPolicyActionPass,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
body := []byte(`{"model":"gpt-5.5","service_tier":"priority"}`)
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.Error(t, err)
var blocked *OpenAIFastBlockedError
require.True(t, errors.As(err, &blocked))
require.Contains(t, blocked.Message, "fast mode is blocked")
require.Equal(t, string(body), string(updated)) // body not mutated on block
}
func TestSetOpenAIFastPolicySettings_Validation(t *testing.T) {
repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
svc := NewSettingService(repo, &config.Config{})
// Invalid action rejected
err := svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: "bogus",
Scope: BetaPolicyScopeAll,
}},
})
require.Error(t, err)
// Invalid service_tier rejected
err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: "turbo",
Action: BetaPolicyActionPass,
Scope: BetaPolicyScopeAll,
}},
})
require.Error(t, err)
// Valid settings persisted
err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
}},
})
require.NoError(t, err)
got, err := svc.GetOpenAIFastPolicySettings(context.Background())
require.NoError(t, err)
require.Len(t, got.Rules, 1)
require.Equal(t, OpenAIFastTierPriority, got.Rules[0].ServiceTier)
}

File diff suppressed because it is too large


@@ -171,6 +171,17 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
}
}
// 4b. Apply OpenAI fast policy (may filter service_tier or block the request).
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody)
if policyErr != nil {
var blocked *OpenAIFastBlockedError
if errors.As(policyErr, &blocked) {
writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message)
}
return nil, policyErr
}
responsesBody = updatedBody
// 5. Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {


@@ -19,8 +19,22 @@ func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
normalizeResponsesRequestServiceTier(req)
require.Equal(t, "flex", req.ServiceTier)
// OpenAI's officially documented tiers should be passed through unchanged.
req.ServiceTier = "auto"
normalizeResponsesRequestServiceTier(req)
require.Equal(t, "auto", req.ServiceTier)
req.ServiceTier = "default"
normalizeResponsesRequestServiceTier(req)
require.Equal(t, "default", req.ServiceTier)
req.ServiceTier = "scale"
normalizeResponsesRequestServiceTier(req)
require.Equal(t, "scale", req.ServiceTier)
// Truly unknown values are still stripped.
req.ServiceTier = "turbo"
normalizeResponsesRequestServiceTier(req)
require.Empty(t, req.ServiceTier)
}
@@ -37,8 +51,25 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
require.Equal(t, "flex", tier)
require.Equal(t, "flex", gjson.GetBytes(body, "service_tier").String())
// Official OpenAI tiers stay in the body as-is (passed through upstream).
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"auto"}`))
require.NoError(t, err)
require.Equal(t, "auto", tier)
require.Equal(t, "auto", gjson.GetBytes(body, "service_tier").String())
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"default"}`))
require.NoError(t, err)
require.Equal(t, "default", tier)
require.Equal(t, "default", gjson.GetBytes(body, "service_tier").String())
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"scale"}`))
require.NoError(t, err)
require.Equal(t, "scale", tier)
require.Equal(t, "scale", gjson.GetBytes(body, "service_tier").String())
// Only truly unknown values are deleted.
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"turbo"}`))
require.NoError(t, err)
require.Empty(t, tier)
require.False(t, gjson.GetBytes(body, "service_tier").Exists())
}


@@ -143,6 +143,19 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
}
}
// 4c. Apply OpenAI fast policy (may filter service_tier or block the request).
// Mirrors the Claude anthropic-beta "fast-mode-2026-02-01" filter, but keyed
// on the body-level service_tier field (priority/flex).
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody)
if policyErr != nil {
var blocked *OpenAIFastBlockedError
if errors.As(policyErr, &blocked) {
writeAnthropicError(c, http.StatusForbidden, "forbidden_error", blocked.Message)
}
return nil, policyErr
}
responsesBody = updatedBody
// 5. Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {


@@ -148,6 +148,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
nil,
nil,
nil,
nil,
)
svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo,
@@ -826,18 +827,29 @@ func TestNormalizeOpenAIServiceTier(t *testing.T) {
require.Equal(t, "priority", *got)
})
t.Run("default ignored", func(t *testing.T) {
t.Run("openai official tiers preserved", func(t *testing.T) {
require.Nil(t, normalizeOpenAIServiceTier("default"))
// Every tier value defined in the official OpenAI docs should be passed
// through, so an over-narrow whitelist cannot silently strip a legitimate
// field the client sent explicitly. The Codex client only sends
// priority/flex, so widening the whitelist has zero impact on Codex
// traffic (see codex-rs/core/src/client.rs).
for _, tier := range []string{"priority", "flex", "auto", "default", "scale"} {
got := normalizeOpenAIServiceTier(tier)
require.NotNil(t, got, "tier %q should not be normalized to nil", tier)
require.Equal(t, tier, *got)
}
})
t.Run("invalid ignored", func(t *testing.T) {
require.Nil(t, normalizeOpenAIServiceTier("turbo"))
require.Nil(t, normalizeOpenAIServiceTier("xxx"))
})
}
func TestExtractOpenAIServiceTier(t *testing.T) {
require.Equal(t, "priority", *extractOpenAIServiceTier(map[string]any{"service_tier": "fast"}))
require.Equal(t, "flex", *extractOpenAIServiceTier(map[string]any{"service_tier": "flex"}))
require.Equal(t, "auto", *extractOpenAIServiceTier(map[string]any{"service_tier": "auto"}))
require.Equal(t, "default", *extractOpenAIServiceTier(map[string]any{"service_tier": "default"}))
require.Equal(t, "scale", *extractOpenAIServiceTier(map[string]any{"service_tier": "scale"}))
require.Nil(t, extractOpenAIServiceTier(map[string]any{"service_tier": 1}))
require.Nil(t, extractOpenAIServiceTier(nil))
}
@@ -845,7 +857,10 @@ func TestExtractOpenAIServiceTier(t *testing.T) {
func TestExtractOpenAIServiceTierFromBody(t *testing.T) {
require.Equal(t, "priority", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"fast"}`)))
require.Equal(t, "flex", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"flex"}`)))
require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`)))
require.Equal(t, "auto", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"auto"}`)))
require.Equal(t, "default", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`)))
require.Equal(t, "scale", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"scale"}`)))
require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"turbo"}`)))
require.Nil(t, extractOpenAIServiceTierFromBody(nil))
}


@@ -334,6 +334,7 @@ type OpenAIGatewayService struct {
resolver *ModelPricingResolver
channelService *ChannelService
balanceNotifyService *BalanceNotifyService
settingService *SettingService
openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once
@@ -372,6 +373,7 @@ func NewOpenAIGatewayService(
resolver *ModelPricingResolver,
channelService *ChannelService,
balanceNotifyService *BalanceNotifyService,
settingService *SettingService,
) *OpenAIGatewayService {
svc := &OpenAIGatewayService{
accountRepo: accountRepo,
@@ -402,6 +404,7 @@ func NewOpenAIGatewayService(
resolver: resolver,
channelService: channelService,
balanceNotifyService: balanceNotifyService,
settingService: settingService,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
}
@@ -1125,6 +1128,35 @@ func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) str
return sessionID
}
func explicitOpenAISessionID(c *gin.Context, body []byte) string {
if c == nil {
return ""
}
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
if sessionID == "" {
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
}
if sessionID == "" && len(body) > 0 {
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
}
return sessionID
}
// GenerateExplicitSessionHash generates a sticky-session hash only from explicit
// client session signals. It intentionally skips content-derived fallback and is
// used by stateless endpoints such as /v1/images.
func (s *OpenAIGatewayService) GenerateExplicitSessionHash(c *gin.Context, body []byte) string {
sessionID := explicitOpenAISessionID(c, body)
if sessionID == "" {
return ""
}
currentHash, legacyHash := deriveOpenAISessionHashes(sessionID)
attachOpenAILegacySessionHashToGin(c, legacyHash)
return currentHash
}
// GenerateSessionHash generates a sticky-session hash for OpenAI requests.
//
// Priority:
@@ -1137,13 +1169,7 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte)
return ""
}
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
sessionID := explicitOpenAISessionID(c, body)
if sessionID == "" {
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
}
if sessionID == "" && len(body) > 0 {
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
}
if sessionID == "" && len(body) > 0 {
sessionID = deriveOpenAIContentSessionSeed(body)
}
@@ -2287,6 +2313,48 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
disablePatch()
}
// Apply OpenAI fast policy (modeled on the Claude BetaPolicy fast-mode filter).
// Targets the body's service_tier field ("priority", i.e. fast, and "flex"),
// executing filter (delete the field) or block (reject the request) per policy.
// Blocking fast mode for models such as gpt-5.5 takes effect here.
//
// Notes:
// 1. This consistently uses upstreamModel (already passed through GetMappedModel +
//    normalizeOpenAIModelForUpstream + Codex OAuth normalize) to match the
//    chat-completions / messages entry points, avoiding whitelist-hit
//    differences caused by entry points seeing different model dimensions.
// 2. Even when action=pass, a raw "fast" must be normalized to "priority" and
//    written back to the body; otherwise the native /responses entry point would
//    forward "fast" verbatim and the upstream would reject it. The chat-
//    completions entry point gets the same behavior from
//    normalizeResponsesBodyServiceTier; here the equivalent logic is done by hand.
if rawTier, ok := reqBody["service_tier"].(string); ok {
if normTier := normalizedOpenAIServiceTierValue(rawTier); normTier != "" {
action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, upstreamModel, normTier)
switch action {
case BetaPolicyActionBlock:
msg := errMsg
if msg == "" {
msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, upstreamModel)
}
blocked := &OpenAIFastBlockedError{Message: msg}
writeOpenAIFastPolicyBlockedResponse(c, blocked)
return nil, blocked
case BetaPolicyActionFilter:
delete(reqBody, "service_tier")
bodyModified = true
disablePatch()
default:
// pass: if the client sent the alias "fast", normalize it to "priority"
// and write it back so the upstream receives a canonical value it recognizes.
if normTier != rawTier {
reqBody["service_tier"] = normTier
bodyModified = true
markPatchSet("service_tier", normTier)
}
}
}
}
// Re-serialize body only if modified
if bodyModified {
serializedByPatch := false
@@ -2735,6 +2803,26 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
body = sanitizedBody
}
// Apply OpenAI fast policy to the passthrough body (filter/block by service_tier).
// Consistently use the upstream-view model: on the passthrough path the body
// has already gone through compact mapping + OAuth normalize, so the body's
// model field is the slug the upstream will actually see. This keeps parity
// with the upstreamModel used by the chat-completions / messages / native
// /responses entry points and avoids whitelist-hit differences. When the
// body has no model field, fall back to reqModel.
policyModel := strings.TrimSpace(gjson.GetBytes(body, "model").String())
if policyModel == "" {
policyModel = reqModel
}
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, policyModel, body)
if policyErr != nil {
var blocked *OpenAIFastBlockedError
if errors.As(policyErr, &blocked) {
writeOpenAIFastPolicyBlockedResponse(c, blocked)
}
return nil, policyErr
}
body = updatedBody
logger.LegacyPrintf("service.openai_gateway",
"[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
account.ID,
@@ -4841,7 +4929,18 @@ func normalizeOpenAICompactRequestBody(body []byte) ([]byte, bool, error) {
}
normalized := []byte(`{}`)
for _, field := range []string{"model", "input", "instructions", "previous_response_id"} {
// Keep the current Codex /compact schema while still dropping request-scoped
// fields such as prompt_cache_key, store, and stream.
for _, field := range []string{
"model",
"input",
"instructions",
"tools",
"parallel_tool_calls",
"reasoning",
"text",
"previous_response_id",
} {
value := gjson.GetBytes(body, field)
if !value.Exists() {
continue
@@ -5454,7 +5553,8 @@ func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, p
}
// normalizeOpenAIPassthroughOAuthBody converges the passthrough OAuth request body to the legacy pipeline's key behaviors:
// 1) store=false 2) keep stream=true when not compact; compact forces stream=false
// 1) drop top-level Responses parameters the ChatGPT internal API does not support
// 2) store=false 3) keep stream=true when not compact; compact forces stream=false
func normalizeOpenAIPassthroughOAuthBody(body []byte, compact bool) ([]byte, bool, error) {
if len(body) == 0 {
return body, false, nil
@@ -5463,6 +5563,18 @@ func normalizeOpenAIPassthroughOAuthBody(body []byte, compact bool) ([]byte, boo
normalized := body
changed := false
for _, field := range openAIChatGPTInternalUnsupportedFields {
if value := gjson.GetBytes(normalized, field); !value.Exists() {
continue
}
next, err := sjson.DeleteBytes(normalized, field)
if err != nil {
return body, false, fmt.Errorf("normalize passthrough body delete %s: %w", field, err)
}
normalized = next
changed = true
}
if compact {
if store := gjson.GetBytes(normalized, "store"); store.Exists() {
next, err := sjson.DeleteBytes(normalized, "store")
@@ -5567,14 +5679,319 @@ func normalizeOpenAIServiceTier(raw string) *string {
if value == "fast" {
value = "priority"
}
// Allow every tier value defined in the official OpenAI docs: priority/flex/auto/default/scale.
// Zero impact on the Codex client, which only sends priority or flex (see codex-rs/core/src/client.rs),
// but this lets users hitting OpenAI directly via the SDK pass auto/default/scale through for capture/debugging.
// Truly unknown values still return nil; normalizeResponsesBodyServiceTier then deletes them from the body.
switch value {
case "priority", "flex":
case "priority", "flex", "auto", "default", "scale":
return &value
default:
return nil
}
}
// OpenAIFastBlockedError indicates a request was rejected by the OpenAI fast
// policy (action=block). Mirrors BetaBlockedError on the Claude side.
type OpenAIFastBlockedError struct {
Message string
}
func (e *OpenAIFastBlockedError) Error() string { return e.Message }
// evaluateOpenAIFastPolicy returns the action and error message that should be
// applied for a request with the given account/model/service_tier. When the
// policy service is unavailable or no rule matches, it returns
// (BetaPolicyActionPass, "") so callers can short-circuit safely.
//
// Matching rules:
// - Scope filters by account type (all / oauth / apikey / bedrock)
// - ServiceTier must be empty (= any), "all", or equal the normalized tier
// - ModelWhitelist narrows the rule to specific models; FallbackAction
// handles the non-matching case (default: pass)
//
// Differences from the Claude BetaPolicy (first-match short-circuit retained):
// - BetaPolicy operates on the set of tokens in the anthropic-beta header;
//   different rules may target different tokens, so filter results must
//   accumulate into a set while block is first-match.
// - The OpenAI fast policy operates on a single field, service_tier: filter
//   just deletes the field, so there is nothing to accumulate. A request
//   carries exactly one service_tier, so the tier dimension of rules is
//   naturally mutually exclusive; if model whitelists overlap for the same
//   (scope, tier), the admin expresses intent via rule order. Hence
//   first-match rather than BetaPolicy's "block overrides filter overrides
//   pass" semantics.
func (s *OpenAIGatewayService) evaluateOpenAIFastPolicy(ctx context.Context, account *Account, model, serviceTier string) (action, errMsg string) {
if s == nil || s.settingService == nil {
return BetaPolicyActionPass, ""
}
tier := strings.ToLower(strings.TrimSpace(serviceTier))
if tier == "" {
return BetaPolicyActionPass, ""
}
settings := openAIFastPolicySettingsFromContext(ctx)
if settings == nil {
fetched, err := s.settingService.GetOpenAIFastPolicySettings(ctx)
if err != nil || fetched == nil {
return BetaPolicyActionPass, ""
}
settings = fetched
}
return evaluateOpenAIFastPolicyWithSettings(settings, account, model, tier)
}
// evaluateOpenAIFastPolicyWithSettings is the pure-function core extracted so
// long-lived sessions (e.g. WS) can prefetch settings once and avoid hitting
// the settingService on every frame. See WSSession entry and
// openAIFastPolicySettingsFromContext for the caching glue.
func evaluateOpenAIFastPolicyWithSettings(settings *OpenAIFastPolicySettings, account *Account, model, tier string) (action, errMsg string) {
if settings == nil {
return BetaPolicyActionPass, ""
}
isOAuth := account != nil && account.IsOAuth()
isBedrock := account != nil && account.IsBedrock()
for _, rule := range settings.Rules {
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
continue
}
ruleTier := strings.ToLower(strings.TrimSpace(rule.ServiceTier))
if ruleTier != "" && ruleTier != OpenAIFastTierAny && ruleTier != tier {
continue
}
eff := BetaPolicyRule{
Action: rule.Action,
ErrorMessage: rule.ErrorMessage,
ModelWhitelist: rule.ModelWhitelist,
FallbackAction: rule.FallbackAction,
FallbackErrorMessage: rule.FallbackErrorMessage,
}
return resolveRuleAction(eff, model)
}
return BetaPolicyActionPass, ""
}
// openAIFastPolicyCtxKey is the context key for a prefetched
// OpenAIFastPolicySettings snapshot. It exists only so a long-lived
// WebSocket session can reuse one policy snapshot across frames instead
// of hitting the DB on every frame.
//
// Trade-off: policy changes do not affect in-flight WS sessions, only new
// ones. This is deliberate: for long sessions, policy consistency matters
// more than immediate effect, and the Claude BetaPolicy's gin.Context
// cache makes the same trade-off. When a hot reload is needed, an admin
// can force a refresh by terminating the session.
type openAIFastPolicyCtxKeyType struct{}
var openAIFastPolicyCtxKey = openAIFastPolicyCtxKeyType{}
// withOpenAIFastPolicyContext binds a settings snapshot to the context so
// evaluateOpenAIFastPolicy calls in goroutines derived from this ctx can reuse it.
func withOpenAIFastPolicyContext(ctx context.Context, settings *OpenAIFastPolicySettings) context.Context {
if ctx == nil || settings == nil {
return ctx
}
return context.WithValue(ctx, openAIFastPolicyCtxKey, settings)
}
func openAIFastPolicySettingsFromContext(ctx context.Context) *OpenAIFastPolicySettings {
if ctx == nil {
return nil
}
if v, ok := ctx.Value(openAIFastPolicyCtxKey).(*OpenAIFastPolicySettings); ok {
return v
}
return nil
}
// applyOpenAIFastPolicyToBody applies the OpenAI fast policy to a raw request
// body. When action=filter it removes the service_tier field; when
// action=block it returns (body, *OpenAIFastBlockedError). On pass it
// normalizes the service_tier value (e.g. client alias "fast" → "priority"),
// rewriting the body so the upstream receives a slug it recognizes.
//
// Rationale for normalize-on-pass: the chat-completions / messages entry
// points already normalize service_tier to an upstream-recognized value via
// normalizeResponsesBodyServiceTier before calling this function. The
// passthrough (OpenAI auto-passthrough) and native /responses entry points
// have no such pre-step, so on the pass path an un-normalized "fast" would
// be forwarded verbatim to the OpenAI upstream and rejected with a 400.
// Centralizing the normalization here keeps all entry points consistent.
func (s *OpenAIGatewayService) applyOpenAIFastPolicyToBody(ctx context.Context, account *Account, model string, body []byte) ([]byte, error) {
if len(body) == 0 {
return body, nil
}
rawTier := gjson.GetBytes(body, "service_tier").String()
if rawTier == "" {
return body, nil
}
normTier := normalizedOpenAIServiceTierValue(rawTier)
if normTier == "" {
return body, nil
}
action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, model, normTier)
switch action {
case BetaPolicyActionBlock:
msg := errMsg
if msg == "" {
msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, model)
}
return body, &OpenAIFastBlockedError{Message: msg}
case BetaPolicyActionFilter:
trimmed, err := sjson.DeleteBytes(body, "service_tier")
if err != nil {
return body, fmt.Errorf("strip service_tier from body: %w", err)
}
return trimmed, nil
default:
// pass: rewrite aliases (e.g. "fast") to the canonical value ("priority").
if normTier == rawTier {
return body, nil
}
updated, err := sjson.SetBytes(body, "service_tier", normTier)
if err != nil {
return body, fmt.Errorf("normalize service_tier on pass: %w", err)
}
return updated, nil
}
}
// writeOpenAIFastPolicyBlockedResponse writes a 403 JSON response for a
// request blocked by the OpenAI fast policy.
func writeOpenAIFastPolicyBlockedResponse(c *gin.Context, err *OpenAIFastBlockedError) {
if c == nil || err == nil {
return
}
c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{
"type": "permission_error",
"message": err.Message,
},
})
}
// applyOpenAIFastPolicyToWSResponseCreate evaluates the OpenAI fast policy
// against a single client→upstream WebSocket frame whose top-level
// "type"=="response.create". It mirrors the HTTP-side
// applyOpenAIFastPolicyToBody contract but operates on a Realtime/Responses
// WS payload:
//
// - pass: returns frame unchanged (newBytes == frame, blocked == nil)
// - filter: returns a copy with top-level service_tier removed
// - block: returns (frame, *OpenAIFastBlockedError)
//
// Only frames whose "type" field strictly equals "response.create" are
// inspected/mutated. Any other frame type — including the empty string —
// passes through untouched. The OpenAI Realtime client-event spec requires
// "type" to be set, so an empty type is treated as a malformed frame we do
// not police; the upstream is the source of truth for rejecting it.
//
// service_tier lives at the top level of response.create — same as the
// Responses HTTP body shape (see openai_gateway_chat_completions.go:304 +
// extractOpenAIServiceTierFromBody at line 5593, and the test fixture at
// openai_ws_forwarder_ingress_session_test.go:402). We therefore only need
// to inspect / strip the top-level field; there is no nested form in the
// schema today.
//
// The caller is responsible for choosing the upstream model passed in —
// this helper does not re-derive it.
func (s *OpenAIGatewayService) applyOpenAIFastPolicyToWSResponseCreate(
ctx context.Context,
account *Account,
model string,
frame []byte,
) ([]byte, *OpenAIFastBlockedError, error) {
if len(frame) == 0 {
return frame, nil, nil
}
if !gjson.ValidBytes(frame) {
return frame, nil, nil
}
frameType := strings.TrimSpace(gjson.GetBytes(frame, "type").String())
// Strict match: only response.create is policy-checked. Empty / other
// types pass through untouched so we never accidentally strip fields
// from response.cancel, conversation.item.create, or any future
// client-event the spec adds. The Realtime spec requires "type" on
// every client event, so an empty type is malformed input — let the
// upstream reject it rather than guessing at our layer.
if frameType != "response.create" {
return frame, nil, nil
}
rawTier := gjson.GetBytes(frame, "service_tier").String()
if rawTier == "" {
return frame, nil, nil
}
normTier := normalizedOpenAIServiceTierValue(rawTier)
if normTier == "" {
return frame, nil, nil
}
action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, model, normTier)
switch action {
case BetaPolicyActionBlock:
msg := errMsg
if msg == "" {
msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, model)
}
return frame, &OpenAIFastBlockedError{Message: msg}, nil
case BetaPolicyActionFilter:
trimmed, err := sjson.DeleteBytes(frame, "service_tier")
if err != nil {
return frame, nil, fmt.Errorf("strip service_tier from ws frame: %w", err)
}
return trimmed, nil, nil
default:
return frame, nil, nil
}
}
// newOpenAIFastPolicyWSEventID returns a Realtime-style event_id for a
// server-emitted error event. Matches the loose "evt_<rand>" convention used
// by upstream Realtime servers; the exact value is not load-bearing and is
// only required for client-side log correlation. We reuse the existing
// google/uuid dependency rather than pulling a new one.
func newOpenAIFastPolicyWSEventID() string {
id, err := uuid.NewRandom()
if err != nil {
// Extremely unlikely; fall back to a fixed prefix so the field is
// still non-empty and the schema stays self-consistent.
return "evt_openai_fast_policy"
}
// Strip dashes so it visually matches "evt_<hex>" rather than UUID v4
// canonical form, mirroring what real Realtime traces look like.
return "evt_" + strings.ReplaceAll(id.String(), "-", "")
}
// buildOpenAIFastPolicyBlockedWSEvent renders an OpenAI Realtime/Responses
// style "error" event payload for a request blocked by the OpenAI fast
// policy. The shape mirrors Realtime error events as observed in upstream
// traces and per the spec's server "error" event:
//
// {
// "event_id": "evt_<random>",
// "type": "error",
// "error": {
// "type": "invalid_request_error",
// "code": "policy_violation",
// "message": "..."
// }
// }
//
// event_id lets clients correlate the rejection in their logs; "code" gives
// programmatic clients a stable identifier (HTTP-side equivalent is the
// 403 permission_error JSON body).
func buildOpenAIFastPolicyBlockedWSEvent(err *OpenAIFastBlockedError) []byte {
if err == nil {
return nil
}
eventID := newOpenAIFastPolicyWSEventID()
payload, mErr := json.Marshal(map[string]any{
"event_id": eventID,
"type": "error",
"error": map[string]any{
"type": "invalid_request_error",
"code": "policy_violation",
"message": err.Message,
},
})
if mErr != nil {
// Fallback to a minimal hand-rolled payload; Marshal of the literal
// shape above should never fail in practice.
return []byte(`{"event_id":"` + eventID + `","type":"error","error":{"type":"invalid_request_error","code":"policy_violation","message":"openai fast policy blocked this request"}}`)
}
return payload
}
func sanitizeEmptyBase64InputImagesInOpenAIBody(body []byte) ([]byte, bool, error) {
if len(body) == 0 || !bytes.Contains(body, []byte(`"image_url"`)) || !bytes.Contains(body, []byte(`base64,`)) {
return body, false, nil


@@ -227,6 +227,41 @@ func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext(t
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
}
func TestOpenAIGatewayService_GenerateExplicitSessionHash_SkipsContentFallback(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := &OpenAIGatewayService{}
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat"}`)
t.Run("stateless image body stays unstuck", func(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
require.Empty(t, svc.GenerateExplicitSessionHash(c, body))
require.Empty(t, openAILegacySessionHashFromContext(c.Request.Context()))
})
t.Run("prompt_cache_key is explicit", func(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
got := svc.GenerateExplicitSessionHash(c, []byte(`{"model":"gpt-image-2","prompt_cache_key":"image-session"}`))
require.Equal(t, fmt.Sprintf("%016x", xxhash.Sum64String("image-session")), got)
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
})
t.Run("header overrides body", func(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
c.Request.Header.Set("session_id", "header-session")
got := svc.GenerateExplicitSessionHash(c, []byte(`{"prompt_cache_key":"body-session"}`))
require.Equal(t, fmt.Sprintf("%016x", xxhash.Sum64String("header-session")), got)
})
}
func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
@@ -1732,6 +1767,24 @@ func TestOpenAIResponsesRequestPathSuffix(t *testing.T) {
}
}
func TestNormalizeOpenAICompactRequestBodyPreservesCurrentCodexPayloadFields(t *testing.T) {
body := []byte(`{"model":"gpt-5.5","input":[{"type":"message","role":"user","content":"compact me"}],"instructions":"compact-test","tools":[{"type":"function","name":"shell"}],"parallel_tool_calls":true,"reasoning":{"effort":"high"},"text":{"verbosity":"low"},"previous_response_id":"resp_123","store":true,"stream":true,"prompt_cache_key":"cache_123"}`)
normalized, changed, err := normalizeOpenAICompactRequestBody(body)
require.NoError(t, err)
require.True(t, changed)
require.Equal(t, "gpt-5.5", gjson.GetBytes(normalized, "model").String())
require.True(t, gjson.GetBytes(normalized, "tools").Exists())
require.True(t, gjson.GetBytes(normalized, "parallel_tool_calls").Bool())
require.Equal(t, "high", gjson.GetBytes(normalized, "reasoning.effort").String())
require.Equal(t, "low", gjson.GetBytes(normalized, "text.verbosity").String())
require.Equal(t, "resp_123", gjson.GetBytes(normalized, "previous_response_id").String())
require.False(t, gjson.GetBytes(normalized, "store").Exists())
require.False(t, gjson.GetBytes(normalized, "stream").Exists())
require.False(t, gjson.GetBytes(normalized, "prompt_cache_key").Exists())
}
func TestOpenAIBuildUpstreamRequestOpenAIPassthroughPreservesCompactPath(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()

View File

@@ -11,6 +11,7 @@ import (
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
@@ -258,6 +259,25 @@ func TestAccountSupportsOpenAIImageCapability_OAuthSupportsNative(t *testing.T)
require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityNative))
}
func TestBuildOpenAIImagesURL_HandlesVersionedBaseURL(t *testing.T) {
require.Equal(t,
"https://image-upstream.example/v1/images/generations",
buildOpenAIImagesURL("https://image-upstream.example/v1", openAIImagesGenerationsEndpoint),
)
require.Equal(t,
"https://image-upstream.example/v1/images/edits",
buildOpenAIImagesURL("https://image-upstream.example/v1/", openAIImagesEditsEndpoint),
)
require.Equal(t,
"https://image-upstream.example/v1/images/generations",
buildOpenAIImagesURL("https://image-upstream.example", openAIImagesGenerationsEndpoint),
)
require.Equal(t,
"https://image-upstream.example/v1/images/generations",
buildOpenAIImagesURL("https://image-upstream.example/v1/images/generations", openAIImagesGenerationsEndpoint),
)
}
type openAIImageTestSSEEvent struct {
Name string
Data string
@@ -371,6 +391,124 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
}
func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseURL(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","response_format":"b64_json"}`)
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = req
svc := &OpenAIGatewayService{
cfg: &config.Config{},
httpUpstream: &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"application/json"},
"X-Request-Id": []string{"req_img_apikey"},
},
Body: io.NopCloser(strings.NewReader(`{"created":1710000007,"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)),
},
},
}
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
require.NoError(t, err)
account := &Account{
ID: 6,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Credentials: map[string]any{
"api_key": "test-api-key",
"base_url": "https://image-upstream.example/v1",
},
}
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 1, result.ImageCount)
require.Equal(t, "gpt-image-2", result.Model)
require.Equal(t, "gpt-image-2", result.UpstreamModel)
upstream, ok := svc.httpUpstream.(*httpUpstreamRecorder)
require.True(t, ok)
require.NotNil(t, upstream.lastReq)
require.Equal(t, "https://image-upstream.example/v1/images/generations", upstream.lastReq.URL.String())
require.Equal(t, "Bearer test-api-key", upstream.lastReq.Header.Get("Authorization"))
require.Equal(t, "application/json", upstream.lastReq.Header.Get("Content-Type"))
require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "model").String())
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
}
func TestOpenAIGatewayServiceForwardImages_APIKeyEditUsesConfiguredV1BaseURL(t *testing.T) {
gin.SetMode(gin.TestMode)
var body bytes.Buffer
writer := multipart.NewWriter(&body)
require.NoError(t, writer.WriteField("model", "gpt-image-2"))
require.NoError(t, writer.WriteField("prompt", "replace background"))
imagePart, err := writer.CreateFormFile("image", "source.png")
require.NoError(t, err)
_, err = imagePart.Write([]byte("png-image-content"))
require.NoError(t, err)
require.NoError(t, writer.Close())
req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body.Bytes()))
req.Header.Set("Content-Type", writer.FormDataContentType())
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = req
svc := &OpenAIGatewayService{
cfg: &config.Config{},
httpUpstream: &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"application/json"},
"X-Request-Id": []string{"req_img_edit_apikey"},
},
Body: io.NopCloser(strings.NewReader(`{"created":1710000008,"data":[{"b64_json":"ZWRpdGVk","revised_prompt":"replace background"}]}`)),
},
},
}
parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes())
require.NoError(t, err)
account := &Account{
ID: 7,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Credentials: map[string]any{
"api_key": "test-api-key",
"base_url": "https://image-upstream.example/v1/",
},
}
result, err := svc.ForwardImages(context.Background(), c, account, body.Bytes(), parsed, "")
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 1, result.ImageCount)
upstream, ok := svc.httpUpstream.(*httpUpstreamRecorder)
require.True(t, ok)
require.NotNil(t, upstream.lastReq)
require.Equal(t, "https://image-upstream.example/v1/images/edits", upstream.lastReq.URL.String())
require.Equal(t, "Bearer test-api-key", upstream.lastReq.Header.Get("Authorization"))
require.Contains(t, upstream.lastReq.Header.Get("Content-Type"), "multipart/form-data")
require.Contains(t, string(upstream.lastBody), `name="model"`)
require.Contains(t, string(upstream.lastBody), "gpt-image-2")
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "ZWRpdGVk", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
}
func TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`)

View File

@@ -0,0 +1,33 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestNormalizeOpenAIPassthroughOAuthBody_RemovesUnsupportedUser(t *testing.T) {
body := []byte(`{"model":"gpt-5.4","input":"hello","user":"user_123","metadata":{"user_id":"user_123"},"prompt_cache_retention":"24h","safety_identifier":"sid","stream_options":{"include_usage":true}}`)
normalized, changed, err := normalizeOpenAIPassthroughOAuthBody(body, false)
require.NoError(t, err)
require.True(t, changed)
for _, field := range openAIChatGPTInternalUnsupportedFields {
require.False(t, gjson.GetBytes(normalized, field).Exists(), "%s should be stripped", field)
}
require.True(t, gjson.GetBytes(normalized, "stream").Bool())
require.False(t, gjson.GetBytes(normalized, "store").Bool())
}
func TestNormalizeOpenAIPassthroughOAuthBody_CompactRemovesUnsupportedUser(t *testing.T) {
body := []byte(`{"model":"gpt-5.4","input":"hello","user":"user_123","metadata":{"user_id":"user_123"},"stream":true,"store":true}`)
normalized, changed, err := normalizeOpenAIPassthroughOAuthBody(body, true)
require.NoError(t, err)
require.True(t, changed)
require.False(t, gjson.GetBytes(normalized, "user").Exists())
require.False(t, gjson.GetBytes(normalized, "metadata").Exists())
require.False(t, gjson.GetBytes(normalized, "stream").Exists())
require.False(t, gjson.GetBytes(normalized, "store").Exists())
}

View File

@@ -1366,16 +1366,25 @@ func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string
func shouldInferIngressFunctionCallOutputPreviousResponseID(
storeDisabled bool,
turn int,
signals ToolContinuationSignals,
currentPreviousResponseID string,
expectedPreviousResponseID string,
) bool {
if !storeDisabled || turn <= 1 || !signals.HasFunctionCallOutput {
return false
}
if strings.TrimSpace(currentPreviousResponseID) != "" {
return false
}
if signals.HasFunctionCallOutputMissingCallID {
return false
}
// If the client already sent tool-call context or item_reference anchors,
// treat this as a full replay / self-contained continuation payload rather
// than downgrading it into an inferred delta continuation.
if signals.HasToolCallContext || signals.HasItemReferenceForAllCallIDs {
return false
}
return strings.TrimSpace(expectedPreviousResponseID) != ""
}
@@ -2366,6 +2375,15 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
return errors.New("token is empty")
}
// Prefetch the OpenAI Fast Policy settings once and bind them to ctx so that
// every evaluateOpenAIFastPolicy call within this WS session reuses the same
// snapshot instead of hitting the DB / settingRepo per frame. See the
// withOpenAIFastPolicyContext comment for the trade-off.
if s.settingService != nil {
if settings, err := s.settingService.GetOpenAIFastPolicySettings(ctx); err == nil && settings != nil {
ctx = withOpenAIFastPolicyContext(ctx, settings)
}
}
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
ingressMode := OpenAIWSIngressModeCtxPool
@@ -2524,6 +2542,44 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
normalized = next
}
// Apply OpenAI Fast Policy on the response.create frame using the same
// evaluator/normalize/scope rules as the HTTP entrypoints. This is the
// single integration point for all WS ingress turns (first + follow-up
// frames flow through here).
//
// Model fallback: parseClientPayload above rejects any frame whose
// "model" field is missing (line ~2493-2500), so by the time we
// reach this point upstreamModel is always derived from a non-empty
// per-frame model. The capturedSessionModel fallback used in the
// passthrough adapter is therefore not needed in this path.
policyApplied, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, upstreamModel, normalized)
if policyErr != nil {
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", policyErr)
}
if blocked != nil {
// Send a Realtime-style error event to the client first, then
// signal the handler to close the connection with PolicyViolation.
// We intentionally do NOT forward this frame upstream.
//
// coder/websocket@v1.8.14 Conn.Write is synchronous and flushes
// the underlying bufio writer before returning (write.go:42 →
// 307-311), and the subsequent close handshake re-acquires the
// same writeFrameMu, so the error event is guaranteed to reach
// the kernel send buffer before any close frame is queued.
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
if eventBytes != nil {
writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
cancel()
}
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
blocked.Message,
blocked,
)
}
normalized = policyApplied
return openAIWSClientPayload{
payloadRaw: normalized,
rawForHash: trimmed,
@@ -3132,13 +3188,22 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
skipBeforeTurn = false
currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")
expectedPrev := strings.TrimSpace(lastTurnResponseID)
toolSignals := ToolContinuationSignals{
HasFunctionCallOutput: gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists(),
}
if toolSignals.HasFunctionCallOutput {
var currentReqBody map[string]any
if err := json.Unmarshal(currentPayload, &currentReqBody); err == nil {
toolSignals = AnalyzeToolContinuationSignals(currentReqBody)
}
}
hasFunctionCallOutput := toolSignals.HasFunctionCallOutput
// store=false + function_call_output turns must carry a continuation anchor.
// If the client omitted previous_response_id, backfill the previous turn's
// response ID so the upstream can correlate the call_id.
if shouldInferIngressFunctionCallOutputPreviousResponseID(
storeDisabled,
turn,
toolSignals,
currentPreviousResponseID,
expectedPrev,
) {

View File

@@ -1354,6 +1354,274 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFun
require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "should not auto-attach previous_response_id when the previous turn's response.id is missing")
}
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenToolCallContextPresent(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
captureConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ctx_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ctx_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}
captureDialer := &openAIWSQueueDialer{
conns: []openAIWSClientConn{captureConn},
}
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(captureDialer)
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPool: pool,
}
account := &Account{
ID: 114,
Name: "openai-ingress-tool-context",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
serverErrCh := make(chan error, 1)
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
CompressionMode: coderws.CompressionContextTakeover,
})
if err != nil {
serverErrCh <- err
return
}
defer func() {
_ = conn.CloseNow()
}()
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := r.Clone(r.Context())
req.Header = req.Header.Clone()
req.Header.Set("User-Agent", "unit-test-agent/1.0")
ginCtx.Request = req
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
msgType, firstMessage, readErr := conn.Read(readCtx)
cancel()
if readErr != nil {
serverErrCh <- readErr
return
}
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
serverErrCh <- errors.New("unsupported websocket client message type")
return
}
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
}))
defer wsServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()
writeMessage := func(payload string) {
writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
}
readMessage := func() []byte {
readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
msgType, message, readErr := clientConn.Read(readCtx)
require.NoError(t, readErr)
require.Equal(t, coderws.MessageText, msgType)
return message
}
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
firstTurn := readMessage()
require.Equal(t, "resp_auto_prev_ctx_1", gjson.GetBytes(firstTurn, "response.id").String())
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call","call_id":"call_ctx_1","name":"shell","arguments":"{}"},{"type":"function_call_output","call_id":"call_ctx_1","output":"ok"},{"type":"message","role":"user","content":[{"type":"input_text","text":"retry"}]}]}`)
secondTurn := readMessage()
require.Equal(t, "resp_auto_prev_ctx_2", gjson.GetBytes(secondTurn, "response.id").String())
require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
select {
case serverErr := <-serverErrCh:
require.NoError(t, serverErr)
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for the ingress websocket to finish")
}
require.Equal(t, 1, captureDialer.DialCount())
require.Len(t, captureConn.writes, 2)
require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "should not auto-attach previous_response_id when the request already carries function_call context")
}
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenItemReferencesPresent(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
captureConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ref_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ref_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}
captureDialer := &openAIWSQueueDialer{
conns: []openAIWSClientConn{captureConn},
}
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(captureDialer)
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPool: pool,
}
account := &Account{
ID: 115,
Name: "openai-ingress-item-reference",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
serverErrCh := make(chan error, 1)
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
CompressionMode: coderws.CompressionContextTakeover,
})
if err != nil {
serverErrCh <- err
return
}
defer func() {
_ = conn.CloseNow()
}()
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := r.Clone(r.Context())
req.Header = req.Header.Clone()
req.Header.Set("User-Agent", "unit-test-agent/1.0")
ginCtx.Request = req
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
msgType, firstMessage, readErr := conn.Read(readCtx)
cancel()
if readErr != nil {
serverErrCh <- readErr
return
}
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
serverErrCh <- errors.New("unsupported websocket client message type")
return
}
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
}))
defer wsServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()
writeMessage := func(payload string) {
writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
}
readMessage := func() []byte {
readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
msgType, message, readErr := clientConn.Read(readCtx)
require.NoError(t, readErr)
require.Equal(t, coderws.MessageText, msgType)
return message
}
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
firstTurn := readMessage()
require.Equal(t, "resp_auto_prev_ref_1", gjson.GetBytes(firstTurn, "response.id").String())
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"item_reference","id":"call_ref_1"},{"type":"function_call_output","call_id":"call_ref_1","output":"ok"},{"type":"message","role":"user","content":[{"type":"input_text","text":"retry"}]}]}`)
secondTurn := readMessage()
require.Equal(t, "resp_auto_prev_ref_2", gjson.GetBytes(secondTurn, "response.id").String())
require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
select {
case serverErr := <-serverErrCh:
require.NoError(t, serverErr)
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for the ingress websocket to finish")
}
require.Equal(t, 1, captureDialer.DialCount())
require.Len(t, captureConn.writes, 2)
require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "should not auto-attach previous_response_id when the request already carries item_reference anchors")
}
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreflightPingFailReconnectsBeforeTurn(t *testing.T) {
gin.SetMode(gin.TestMode)
prevPreflightPingIdle := openAIWSIngressPreflightPingIdle

View File

@@ -232,67 +232,91 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
name string
storeDisabled bool
turn int
signals ToolContinuationSignals
currentPreviousResponse string
expectedPrevious string
want bool
}{
{
name: "infer_when_all_conditions_match",
storeDisabled: true,
turn: 2,
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
expectedPrevious: "resp_1",
want: true,
},
{
name: "skip_when_store_enabled",
storeDisabled: false,
turn: 2,
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
expectedPrevious: "resp_1",
want: false,
},
{
name: "skip_on_first_turn",
storeDisabled: true,
turn: 1,
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
expectedPrevious: "resp_1",
want: false,
},
{
name: "skip_without_function_call_output",
storeDisabled: true,
turn: 2,
signals: ToolContinuationSignals{},
expectedPrevious: "resp_1",
want: false,
},
{
name: "skip_when_request_already_has_previous_response_id",
storeDisabled: true,
turn: 2,
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
currentPreviousResponse: "resp_client",
expectedPrevious: "resp_1",
want: false,
},
{
name: "skip_when_last_turn_response_id_missing",
storeDisabled: true,
turn: 2,
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
expectedPrevious: "",
want: false,
},
{
name: "trim_whitespace_before_judgement",
storeDisabled: true,
turn: 2,
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
expectedPrevious: " resp_2 ",
want: true,
},
{
name: "skip_when_tool_call_context_already_present",
storeDisabled: true,
turn: 2,
signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasToolCallContext: true},
expectedPrevious: "resp_2",
want: false,
},
{
name: "skip_when_item_reference_already_covers_all_call_ids",
storeDisabled: true,
turn: 2,
signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasItemReferenceForAllCallIDs: true},
expectedPrevious: "resp_2",
want: false,
},
{
name: "skip_when_function_call_output_missing_call_id",
storeDisabled: true,
turn: 2,
signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasFunctionCallOutputMissingCallID: true},
expectedPrevious: "resp_2",
want: false,
}, },
}
@@ -303,7 +327,7 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
got := shouldInferIngressFunctionCallOutputPreviousResponseID(
tt.storeDisabled,
tt.turn,
tt.signals,
tt.currentPreviousResponse,
tt.expectedPrevious,
)

View File

@@ -618,6 +618,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil,
nil,
nil,
nil,
)
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)

View File

@@ -21,6 +21,109 @@ type openAIWSClientFrameConn struct {
conn *coderws.Conn
}
// openAIWSPolicyEnforcingFrameConn wraps a client-side FrameConn and runs
// every client→upstream frame through the OpenAI Fast Policy. It is the
// passthrough-relay equivalent of the parseClientPayload integration in the
// ingress session path. filter returns:
// - newPayload, nil, nil: forward the (possibly mutated) payload
// - _, *OpenAIFastBlockedError, nil: block — the wrapper sends an error
// event via onBlock and surfaces a transport-level error so the relay
// stops reading from the client.
// - _, _, err: a transport error other than block.
type openAIWSPolicyEnforcingFrameConn struct {
inner openaiwsv2.FrameConn
filter func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error)
onBlock func(blocked *OpenAIFastBlockedError)
}
var _ openaiwsv2.FrameConn = (*openAIWSPolicyEnforcingFrameConn)(nil)
func (c *openAIWSPolicyEnforcingFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
if c == nil || c.inner == nil {
return coderws.MessageText, nil, errOpenAIWSConnClosed
}
msgType, payload, err := c.inner.ReadFrame(ctx)
if err != nil {
return msgType, payload, err
}
if c.filter == nil {
return msgType, payload, nil
}
updated, blocked, filterErr := c.filter(msgType, payload)
if filterErr != nil {
return msgType, payload, filterErr
}
if blocked != nil {
if c.onBlock != nil {
c.onBlock(blocked)
}
return msgType, nil, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
}
return msgType, updated, nil
}
func (c *openAIWSPolicyEnforcingFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
if c == nil || c.inner == nil {
return errOpenAIWSConnClosed
}
return c.inner.WriteFrame(ctx, msgType, payload)
}
func (c *openAIWSPolicyEnforcingFrameConn) Close() error {
if c == nil || c.inner == nil {
return nil
}
return c.inner.Close()
}
// openAIWSPassthroughPolicyModelForFrame returns the upstream-perspective
// model name that should be passed to evaluateOpenAIFastPolicy for a single
// passthrough WS frame. Mirrors the HTTP-side normalization
// (account.GetMappedModel + normalizeOpenAIModelForUpstream) so the WS path
// matches model whitelists identically.
func openAIWSPassthroughPolicyModelForFrame(account *Account, payload []byte) string {
if account == nil || len(payload) == 0 {
return ""
}
original := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
if original == "" {
return ""
}
return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
}
// openAIWSPassthroughPolicyModelFromSessionFrame returns the upstream model
// derived from a session.update frame's session.model field. Returns "" when
// the frame is not a session.update event or carries no session.model. Used
// by the per-frame policy filter (client→upstream direction) to keep
// capturedSessionModel in sync with the session-level model the client may
// rotate mid-session.
//
// Realtime / Responses WS lets the client change the session model after
// the WS handshake via:
//
// {"type":"session.update","session":{"model":"gpt-5.5", ...}}
//
// If we only capture the model from the very first frame, a client can ship
// gpt-4o on the first response.create (whitelisted as pass), then
// session.update to gpt-5.5, then send response.create without "model" so
// the per-frame resolver returns "" and the stale capturedSessionModel falls
// back to gpt-4o — defeating the gpt-5.5 fast-policy filter.
func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload []byte) string {
if account == nil || len(payload) == 0 {
return ""
}
frameType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
if frameType != "session.update" {
return ""
}
original := strings.TrimSpace(gjson.GetBytes(payload, "session.model").String())
if original == "" {
return ""
}
return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
}
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
@@ -77,7 +180,6 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
return errors.New("token is empty")
}
requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())
requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String())
logOpenAIWSV2Passthrough(
"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d",
@@ -88,6 +190,59 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
len(firstClientMessage),
)
// Apply OpenAI Fast Policy on the first response.create frame. Subsequent
// frames are filtered via a wrapping FrameConn below so every client→
// upstream frame goes through the same policy evaluator/normalize/scope as
// HTTP entrypoints.
//
// We capture the session-level model from the first frame here so the
// per-frame filter (below) can fall back to it when a follow-up frame
// omits "model" — Realtime clients are allowed to send response.create
// without re-stating the model, in which case the upstream uses the model
// negotiated at session.update time. Without this fallback, an empty
// model would miss the default ["gpt-5.5","gpt-5.5*"] whitelist and be
// silently passed through, defeating the policy on every frame after
// the first.
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage)
updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage)
if policyErr != nil {
return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr)
}
if blocked != nil {
// coder/websocket@v1.8.14 Conn.Write is synchronous: it acquires
// writeFrameMu, writes the entire frame, and Flushes the underlying
// bufio writer before returning (write.go:42 → write.go:307-311).
// The subsequent close handshake re-acquires the same writeFrameMu
// to send the close frame, so the error event is guaranteed to
// reach the kernel send buffer before any close frame is queued.
// No explicit flush hop is required here.
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
if eventBytes != nil {
writeCtx, cancelWrite := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
cancelWrite()
}
return NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
}
firstClientMessage = updatedFirst
// Extract service_tier for billing reporting only after the policy
// filter: when the filter matches, service_tier has already been
// stripped from firstClientMessage, and billing should reflect the tier
// the upstream actually processed (nil = default), not the "priority"
// the user originally requested. This matches the semantics of the HTTP
// entrypoint (line ~2728, extractOpenAIServiceTier(reqBody)) and the WS
// ingress path (openai_ws_forwarder.go:2991, taken from the payload).
//
// Multi-turn passthrough: the OpenAI Realtime / Responses WS protocol
// allows the client to send a different service_tier on each
// response.create frame of the same connection (see
// codex-rs/core/src/client.rs build_responses_request, which re-fills
// the value on every request). We therefore use atomic.Pointer[string]
// to synchronize the current turn's service_tier between the filter
// (runClientToUpstream goroutine) and OnTurnComplete / the final result
// (runUpstreamToClient goroutine).
// extractOpenAIServiceTierFromBody returns *string, itself a pointer
// type, so it can be Stored/Loaded directly without extra wrapping.
var requestServiceTierPtr atomic.Pointer[string]
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstClientMessage))
wsURL, err := s.buildOpenAIResponsesWSURL(account)
if err != nil {
return fmt.Errorf("build ws url: %w", err)
@@ -152,9 +307,72 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
}
completedTurns := atomic.Int32{}
policyClientConn := &openAIWSPolicyEnforcingFrameConn{
inner: &openAIWSClientFrameConn{conn: clientConn},
// Thread-safety note: filter is only ever invoked from the single
// runClientToUpstream goroutine (the ReadFrame loop in
// passthrough_relay.go). All reads and writes of capturedSessionModel
// happen inside that goroutine, so no locking or atomics are needed.
filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
if msgType != coderws.MessageText {
return payload, nil, nil
}
// Refresh capturedSessionModel before evaluating the policy: the client
// may change the session-level model via session.update (allowed by the
// Realtime / Responses WS protocol). Without this refresh, the bypass
// path "first frame model=gpt-4o (pass) → session.update to gpt-5.5 →
// response.create without model falls back to gpt-4o" would open up.
// Only the session.model field of session.update events is consulted
// here; response.create's model is still decided by its own frame field.
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
capturedSessionModel = updated
}
// Per-frame model first; if the client omits "model" on a
// follow-up frame (legal in Realtime), fall back to the
// session-level model captured from the first frame so the
// model whitelist still resolves. An empty model would miss
// any whitelist and silently fall back to pass.
model := openAIWSPassthroughPolicyModelForFrame(account, payload)
if model == "" {
model = capturedSessionModel
}
out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload)
// Multi-turn passthrough billing: only update requestServiceTierPtr on
// successful (non-block / non-err) response.create frames, using the
// filter-processed payload, consistent with the first frame's
// extract-after-policy semantics (see the
// extractOpenAIServiceTierFromBody comment above).
// - Non-response.create frames (response.cancel /
//   conversation.item.create / session.update, etc.) carry no
//   per-response service_tier and must not overwrite the previous
//   turn's value.
// - blocked != nil: the frame is never sent upstream; the billing
//   tier keeps the previous turn's value.
// - policyErr != nil: error path; keep the previous turn's value.
// - A response.create without service_tier makes
//   extractOpenAIServiceTierFromBody return nil; we intentionally
//   overwrite (Store(nil)) because the OpenAI upstream treats such a
//   frame as default tier, and billing should reflect that.
if policyErr == nil && blocked == nil &&
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
}
return out, blocked, policyErr
},
onBlock: func(blocked *OpenAIFastBlockedError) {
// See note above on Conn.Write being synchronous w.r.t. flush;
// no explicit flush is required to ensure the error event lands
// before the close frame.
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
if eventBytes == nil {
return
}
writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
cancel()
},
}
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
Ctx: ctx,
ClientConn: policyClientConn,
UpstreamConn: upstreamFrameConn,
FirstClientMessage: firstClientMessage,
Options: openaiwsv2.RelayOptions{
@@ -179,7 +397,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
},
Model: turn.RequestModel,
ServiceTier: requestServiceTierPtr.Load(),
Stream: true,
OpenAIWSMode: true,
ResponseHeaders: cloneHeader(handshakeHeaders),
@@ -227,7 +445,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
},
Model: relayResult.RequestModel,
ServiceTier: requestServiceTierPtr.Load(),
Stream: true,
OpenAIWSMode: true,
ResponseHeaders: cloneHeader(handshakeHeaders),

View File

@@ -184,6 +184,25 @@ func (c opsCleanupDeletedCounts) String() string {
)
}
// opsCleanupPlan translates "retention days" into a concrete cleanup action:
// - days < 0  → skip this cleanup (ok=false), kept for old-data compatibility
// - days == 0 → TRUNCATE TABLE, an O(1) full wipe (truncate=true)
// - days > 0  → batched DELETE of rows older than now-N days (cutoff = now - N days)
//
// Why days==0 uses TRUNCATE rather than "cutoff=now+24h + DELETE":
// - speed drops from O(N) to O(1): milliseconds even on million-row tables
// - no WAL writes and no follow-up VACUUM pressure
// - only the cleanup job itself writes to these ops tables, so the
//   ACCESS EXCLUSIVE lock taken by TRUNCATE has negligible impact
func opsCleanupPlan(now time.Time, days int) (cutoff time.Time, truncate, ok bool) {
if days < 0 {
return time.Time{}, false, false
}
if days == 0 {
return time.Time{}, true, true
}
return now.AddDate(0, 0, -days), false, true
}
func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDeletedCounts, error) {
out := opsCleanupDeletedCounts{}
if s == nil || s.db == nil || s.cfg == nil {
@@ -194,34 +213,42 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet
now := time.Now().UTC()
// runOne wraps "truncate? cutoff? batched delete?" in one place, so the
// three cleanup groups (error-log-like tables / minute-level metrics /
// hourly+daily pre-aggregation) only need to supply table and column names.
runOne := func(truncate bool, cutoff time.Time, table, timeCol string, castDate bool) (int64, error) {
if truncate {
return truncateOpsTable(ctx, s.db, table)
}
return deleteOldRowsByID(ctx, s.db, table, timeCol, cutoff, batchSize, castDate)
}
// Error-like tables: error logs / retry attempts / alert events / system logs / cleanup audits.
if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.ErrorLogRetentionDays); ok {
n, err := runOne(truncate, cutoff, "ops_error_logs", "created_at", false)
if err != nil {
return out, err
}
out.errorLogs = n
n, err = runOne(truncate, cutoff, "ops_retry_attempts", "created_at", false)
if err != nil {
return out, err
}
out.retryAttempts = n
n, err = runOne(truncate, cutoff, "ops_alert_events", "created_at", false)
if err != nil {
return out, err
}
out.alertEvents = n
n, err = runOne(truncate, cutoff, "ops_system_logs", "created_at", false)
if err != nil {
return out, err
}
out.systemLogs = n
n, err = runOne(truncate, cutoff, "ops_system_log_cleanup_audits", "created_at", false)
if err != nil {
return out, err
}
@@ -229,9 +256,8 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet
}
// Minute-level metrics snapshots.
if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.MinuteMetricsRetentionDays); ok {
n, err := runOne(truncate, cutoff, "ops_system_metrics", "created_at", false)
if err != nil {
return out, err
}
@@ -239,15 +265,14 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet
}
// Pre-aggregation tables (hourly/daily).
if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.HourlyMetricsRetentionDays); ok {
n, err := runOne(truncate, cutoff, "ops_metrics_hourly", "bucket_start", false)
if err != nil {
return out, err
}
out.hourlyPreagg = n
n, err = runOne(truncate, cutoff, "ops_metrics_daily", "bucket_date", true)
if err != nil {
return out, err
}
@@ -303,7 +328,7 @@ WHERE id IN (SELECT id FROM batch)
res, err := db.ExecContext(ctx, q, cutoff, batchSize)
if err != nil {
// If ops tables aren't present yet (partial deployments), treat as no-op.
if isMissingRelationError(err) {
return total, nil
}
return total, err
@@ -320,6 +345,46 @@ WHERE id IN (SELECT id FROM batch)
return total, nil
}
// truncateOpsTable empties the given table with TRUNCATE TABLE, first running
// SELECT COUNT(*) to record the pre-wipe row count for the heartbeat.
//
// Differences from deleteOldRowsByID:
// - no WHERE clause is possible; used only for the days==0 "wipe everything" semantics
// - O(1): releases the table's physical storage pages in milliseconds, with no
//   WAL writes and no VACUUM pressure
// - requires an ACCESS EXCLUSIVE lock, but only the cleanup job writes to the
//   ops tables, so the momentary lock impact is negligible
//
// A missing table (partial deployments) silently returns 0, matching deleteOldRowsByID.
func truncateOpsTable(ctx context.Context, db *sql.DB, table string) (int64, error) {
if db == nil {
return 0, nil
}
var count int64
if err := db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", table)).Scan(&count); err != nil {
if isMissingRelationError(err) {
return 0, nil
}
return 0, fmt.Errorf("count %s: %w", table, err)
}
if count == 0 {
return 0, nil
}
if _, err := db.ExecContext(ctx, fmt.Sprintf("TRUNCATE TABLE %s", table)); err != nil {
if isMissingRelationError(err) {
return 0, nil
}
return 0, fmt.Errorf("truncate %s: %w", table, err)
}
return count, nil
}
// isMissingRelationError reports whether a Postgres error means "relation does
// not exist", letting the cleanup job silently skip tables in partial deployments.
func isMissingRelationError(err error) bool {
if err == nil {
return false
}
s := strings.ToLower(err.Error())
return strings.Contains(s, "does not exist") && strings.Contains(s, "relation")
}
func (s *OpsCleanupService) tryAcquireLeaderLock(ctx context.Context) (func(), bool) {
if s == nil {
return nil, false

View File

@@ -0,0 +1,64 @@
package service
import (
"testing"
"time"
)
func TestOpsCleanupPlan(t *testing.T) {
now := time.Date(2026, 4, 29, 12, 0, 0, 0, time.UTC)
cases := []struct {
name string
days int
wantOK bool
wantTruncate bool
wantCutoff time.Time
}{
{name: "negative skips", days: -1, wantOK: false},
{name: "zero truncates", days: 0, wantOK: true, wantTruncate: true},
{name: "positive yields past cutoff", days: 7, wantOK: true, wantCutoff: now.AddDate(0, 0, -7)},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
cutoff, truncate, ok := opsCleanupPlan(now, tc.days)
if ok != tc.wantOK {
t.Fatalf("ok = %v, want %v", ok, tc.wantOK)
}
if !ok {
return
}
if truncate != tc.wantTruncate {
t.Fatalf("truncate = %v, want %v", truncate, tc.wantTruncate)
}
if !tc.wantTruncate && !cutoff.Equal(tc.wantCutoff) {
t.Fatalf("cutoff = %v, want %v", cutoff, tc.wantCutoff)
}
})
}
}
func TestIsMissingRelationError(t *testing.T) {
cases := []struct {
name string
err error
want bool
}{
{name: "nil is not missing", err: nil, want: false},
{name: "match relation does not exist", err: fakeErr(`pq: relation "ops_error_logs" does not exist`), want: true},
{name: "match case-insensitive", err: fakeErr(`ERROR: Relation "x" Does Not Exist`), want: true},
{name: "non-matching error", err: fakeErr("connection refused"), want: false},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
if got := isMissingRelationError(tc.err); got != tc.want {
t.Fatalf("got %v, want %v", got, tc.want)
}
})
}
}
type fakeErr string
func (e fakeErr) Error() string { return string(e) }

View File

@@ -387,13 +387,15 @@ func normalizeOpsAdvancedSettings(cfg *OpsAdvancedSettings) {
if cfg.DataRetention.CleanupSchedule == "" {
cfg.DataRetention.CleanupSchedule = "0 2 * * *"
}
// Retention days: 0 means each scheduled run wipes everything; > 0 means
// keep that many days. Only backfill the default on an invalid negative
// value, so a user-chosen 0 is not overwritten.
if cfg.DataRetention.ErrorLogRetentionDays < 0 {
cfg.DataRetention.ErrorLogRetentionDays = 30
}
if cfg.DataRetention.MinuteMetricsRetentionDays < 0 {
cfg.DataRetention.MinuteMetricsRetentionDays = 30
}
if cfg.DataRetention.HourlyMetricsRetentionDays < 0 {
cfg.DataRetention.HourlyMetricsRetentionDays = 30
}
// Normalize auto refresh interval (default 30 seconds)
@@ -406,14 +408,15 @@ func validateOpsAdvancedSettings(cfg *OpsAdvancedSettings) error {
if cfg == nil {
return errors.New("invalid config")
}
// Retention days: 0 means wipe everything on each run; 1-365 means keep that many days.
if cfg.DataRetention.ErrorLogRetentionDays < 0 || cfg.DataRetention.ErrorLogRetentionDays > 365 {
return errors.New("error_log_retention_days must be between 0 and 365")
}
if cfg.DataRetention.MinuteMetricsRetentionDays < 0 || cfg.DataRetention.MinuteMetricsRetentionDays > 365 {
return errors.New("minute_metrics_retention_days must be between 0 and 365")
}
if cfg.DataRetention.HourlyMetricsRetentionDays < 0 || cfg.DataRetention.HourlyMetricsRetentionDays > 365 {
return errors.New("hourly_metrics_retention_days must be between 0 and 365")
}
if cfg.AutoRefreshIntervalSec < 15 || cfg.AutoRefreshIntervalSec > 300 {
return errors.New("auto_refresh_interval_seconds must be between 15 and 300")

View File

@@ -59,6 +59,8 @@ type SchedulerCache interface {
UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
// TryLockBucket attempts to acquire the bucket rebuild lock.
TryLockBucket(ctx context.Context, bucket SchedulerBucket, ttl time.Duration) (bool, error)
// UnlockBucket releases the bucket rebuild lock.
UnlockBucket(ctx context.Context, bucket SchedulerBucket) error
// ListBuckets returns the set of registered buckets.
ListBuckets(ctx context.Context) ([]SchedulerBucket, error)
// GetOutboxWatermark reads the outbox watermark.

View File

@@ -44,6 +44,10 @@ func (c *snapshotHydrationCache) TryLockBucket(ctx context.Context, bucket Sched
return true, nil
}
func (c *snapshotHydrationCache) UnlockBucket(ctx context.Context, bucket SchedulerBucket) error {
return nil
}
func (c *snapshotHydrationCache) ListBuckets(ctx context.Context) ([]SchedulerBucket, error) {
return nil, nil
}

View File

@@ -544,6 +544,9 @@ func (s *SchedulerSnapshotService) rebuildBucket(ctx context.Context, bucket Sch
if !ok {
return nil
}
defer func() {
_ = s.cache.UnlockBucket(ctx, bucket)
}()
rebuildCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()

View File

@@ -3259,6 +3259,84 @@ func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *Be
return s.settingRepo.Set(ctx, SettingKeyBetaPolicySettings, string(data))
}
// GetOpenAIFastPolicySettings returns the OpenAI fast policy settings.
func (s *SettingService) GetOpenAIFastPolicySettings(ctx context.Context) (*OpenAIFastPolicySettings, error) {
value, err := s.settingRepo.GetValue(ctx, SettingKeyOpenAIFastPolicySettings)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return DefaultOpenAIFastPolicySettings(), nil
}
return nil, fmt.Errorf("get openai fast policy settings: %w", err)
}
if value == "" {
return DefaultOpenAIFastPolicySettings(), nil
}
var settings OpenAIFastPolicySettings
if err := json.Unmarshal([]byte(value), &settings); err != nil {
// Silently falling back to defaults on corrupted JSON would make the
// policy unexpectedly inert (admin-configured block/filter rules get
// ignored). Log a Warn so operators can trace anomalous behavior back
// to dirty data in the settings table.
slog.Warn("failed to unmarshal openai fast policy settings, falling back to defaults",
"error", err,
"key", SettingKeyOpenAIFastPolicySettings)
return DefaultOpenAIFastPolicySettings(), nil
}
return &settings, nil
}
// SetOpenAIFastPolicySettings stores the OpenAI fast policy settings.
func (s *SettingService) SetOpenAIFastPolicySettings(ctx context.Context, settings *OpenAIFastPolicySettings) error {
if settings == nil {
return fmt.Errorf("settings cannot be nil")
}
validActions := map[string]bool{
BetaPolicyActionPass: true, BetaPolicyActionFilter: true, BetaPolicyActionBlock: true,
}
validScopes := map[string]bool{
BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true, BetaPolicyScopeBedrock: true,
}
validTiers := map[string]bool{
OpenAIFastTierAny: true, OpenAIFastTierPriority: true, OpenAIFastTierFlex: true,
}
for i, rule := range settings.Rules {
tier := strings.ToLower(strings.TrimSpace(rule.ServiceTier))
if tier == "" {
tier = OpenAIFastTierAny
}
if !validTiers[tier] {
return fmt.Errorf("rule[%d]: invalid service_tier %q", i, rule.ServiceTier)
}
settings.Rules[i].ServiceTier = tier
if !validActions[rule.Action] {
return fmt.Errorf("rule[%d]: invalid action %q", i, rule.Action)
}
if !validScopes[rule.Scope] {
return fmt.Errorf("rule[%d]: invalid scope %q", i, rule.Scope)
}
for j, pattern := range rule.ModelWhitelist {
trimmed := strings.TrimSpace(pattern)
if trimmed == "" {
return fmt.Errorf("rule[%d]: model_whitelist[%d] cannot be empty", i, j)
}
settings.Rules[i].ModelWhitelist[j] = trimmed
}
if rule.FallbackAction != "" && !validActions[rule.FallbackAction] {
return fmt.Errorf("rule[%d]: invalid fallback_action %q", i, rule.FallbackAction)
}
}
data, err := json.Marshal(settings)
if err != nil {
return fmt.Errorf("marshal openai fast policy settings: %w", err)
}
return s.settingRepo.Set(ctx, SettingKeyOpenAIFastPolicySettings, string(data))
}
// SetStreamTimeoutSettings stores the stream timeout handling settings.
func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error {
if settings == nil {

View File

@@ -405,3 +405,57 @@ func DefaultBetaPolicySettings() *BetaPolicySettings {
},
}
}
// OpenAI Fast Policy constants.
// OpenAI's "fast mode" is identified by the service_tier field in the request body:
// - "priority" (clients may send "fast", normalized to "priority"): fast mode
// - "flex": low-priority mode
// - omitted: normal (default)
//
// This policy reuses the BetaPolicyAction*/BetaPolicyScope* constant semantics;
// only the match key changes from the anthropic-beta header to the body's
// service_tier field.
const (
OpenAIFastTierAny = "all" // matches any recognized service_tier
OpenAIFastTierPriority = "priority" // matches only fast (priority)
OpenAIFastTierFlex = "flex" // matches only flex
)
// OpenAIFastPolicyRule is a single OpenAI fast/flex policy rule.
type OpenAIFastPolicyRule struct {
ServiceTier string `json:"service_tier"` // "priority" | "flex" | "auto" | "default" | "scale" | "all"
Action string `json:"action"` // "pass" | "filter" | "block"
Scope string `json:"scope"` // "all" | "oauth" | "apikey" | "bedrock"
ErrorMessage string `json:"error_message,omitempty"` // custom error message (effective when action=block)
ModelWhitelist []string `json:"model_whitelist,omitempty"` // model match patterns (empty = applies to all models)
FallbackAction string `json:"fallback_action,omitempty"` // how to handle models that miss the whitelist
FallbackErrorMessage string `json:"fallback_error_message,omitempty"` // custom error message for whitelist misses (effective when fallback_action=block)
}
// OpenAIFastPolicySettings is the OpenAI fast policy configuration.
type OpenAIFastPolicySettings struct {
Rules []OpenAIFastPolicyRule `json:"rules"`
}
// DefaultOpenAIFastPolicySettings returns the default OpenAI fast policy.
// By default, priority (fast) requests for all models get the filter action,
// i.e. the service_tier field is stripped so the upstream processes the
// request at normal priority.
//
// Why ModelWhitelist is empty (= applies to all models): the codex client's
// service_tier=fast is a user-level switch, orthogonal to the model field.
// Even with gpt-4 + fast, priority quota is still consumed. If the default
// rule only locked gpt-5.5*, the "gpt-4 + fast passes priority upstream"
// path would bypass the policy. To align with codex's real semantics, the
// default applies to all models; admins who need model-specific rules can
// configure model_whitelist explicitly in the admin UI.
func DefaultOpenAIFastPolicySettings() *OpenAIFastPolicySettings {
return &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{
{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
ModelWhitelist: []string{},
FallbackAction: BetaPolicyActionPass,
},
},
}
}

View File

@@ -0,0 +1,345 @@
package service
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
)
const (
vertexDefaultLocation = "us-central1"
vertexDefaultTokenURL = "https://oauth2.googleapis.com/token"
vertexCloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
vertexServiceAccountCacheSkew = 5 * time.Minute
vertexLockWaitTime = 200 * time.Millisecond
vertexAnthropicVersion = "vertex-2023-10-16"
)
var (
vertexLocationPattern = regexp.MustCompile(`^[a-z0-9-]+$`)
vertexAnthropicDatedModelIDPattern = regexp.MustCompile(`^(.+)-([0-9]{8})$`)
vertexAnthropicAlreadyDatedIDPattern = regexp.MustCompile(`^.+@[0-9]{8}$`)
)
type vertexServiceAccountKey struct {
Type string `json:"type"`
ProjectID string `json:"project_id"`
PrivateKeyID string `json:"private_key_id"`
PrivateKey string `json:"private_key"`
ClientEmail string `json:"client_email"`
TokenURI string `json:"token_uri"`
}
type vertexTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
Error string `json:"error"`
ErrorDesc string `json:"error_description"`
}
func (a *Account) IsVertexServiceAccount() bool {
return a != nil && a.Type == AccountTypeServiceAccount
}
func (a *Account) VertexProjectID() string {
if a == nil {
return ""
}
if v := strings.TrimSpace(a.GetCredential("project_id")); v != "" {
return v
}
key, err := parseVertexServiceAccountKey(a)
if err == nil {
return strings.TrimSpace(key.ProjectID)
}
return ""
}
func (a *Account) VertexLocation(model string) string {
if a == nil {
return vertexDefaultLocation
}
if model != "" && a.Credentials != nil {
if raw, ok := a.Credentials["vertex_model_locations"].(map[string]any); ok {
if loc, ok := raw[model].(string); ok && strings.TrimSpace(loc) != "" {
return strings.TrimSpace(loc)
}
}
}
if v := strings.TrimSpace(a.GetCredential("location")); v != "" {
return v
}
if v := strings.TrimSpace(a.GetCredential("vertex_location")); v != "" {
return v
}
return vertexDefaultLocation
}
func parseVertexServiceAccountKey(account *Account) (*vertexServiceAccountKey, error) {
if account == nil || account.Credentials == nil {
return nil, errors.New("service account credentials not configured")
}
if raw := strings.TrimSpace(account.GetCredential("service_account_json")); raw != "" {
return parseVertexServiceAccountJSON([]byte(raw))
}
if raw := strings.TrimSpace(account.GetCredential("service_account")); raw != "" {
return parseVertexServiceAccountJSON([]byte(raw))
}
if nested, ok := account.Credentials["service_account_json"].(map[string]any); ok {
b, _ := json.Marshal(nested)
return parseVertexServiceAccountJSON(b)
}
if nested, ok := account.Credentials["service_account"].(map[string]any); ok {
b, _ := json.Marshal(nested)
return parseVertexServiceAccountJSON(b)
}
return nil, errors.New("service_account_json not found in credentials")
}
func parseVertexServiceAccountJSON(raw []byte) (*vertexServiceAccountKey, error) {
var key vertexServiceAccountKey
if err := json.Unmarshal(raw, &key); err != nil {
return nil, fmt.Errorf("invalid service account json: %w", err)
}
if strings.TrimSpace(key.ClientEmail) == "" {
return nil, errors.New("service account json missing client_email")
}
if strings.TrimSpace(key.PrivateKey) == "" {
return nil, errors.New("service account json missing private_key")
}
if strings.TrimSpace(key.ProjectID) == "" {
return nil, errors.New("service account json missing project_id")
}
// Always use the well-known Google token endpoint to prevent SSRF via crafted token_uri.
key.TokenURI = vertexDefaultTokenURL
return &key, nil
}
func vertexServiceAccountCacheKey(account *Account, key *vertexServiceAccountKey) string {
fingerprint := ""
if key != nil {
sum := sha256.Sum256([]byte(key.ClientEmail + "\x00" + key.PrivateKeyID))
fingerprint = hex.EncodeToString(sum[:8])
}
if fingerprint == "" && account != nil {
fingerprint = fmt.Sprintf("account:%d", account.ID)
}
return "vertex:service_account:" + fingerprint
}
// getVertexServiceAccountAccessToken obtains an access token for a Vertex service account,
// using the shared cache and distributed lock to avoid redundant exchanges.
func getVertexServiceAccountAccessToken(ctx context.Context, cache GeminiTokenCache, account *Account) (string, error) {
key, err := parseVertexServiceAccountKey(account)
if err != nil {
return "", err
}
cacheKey := vertexServiceAccountCacheKey(account, key)
if cache != nil {
if token, err := cache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
locked := false
if cache != nil {
var lockErr error
locked, lockErr = cache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked {
defer func() { _ = cache.ReleaseRefreshLock(ctx, cacheKey) }()
} else if lockErr != nil {
slog.Warn("vertex_service_account_token_lock_failed", "account_id", account.ID, "error", lockErr)
} else {
time.Sleep(vertexLockWaitTime)
if token, err := cache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
}
accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key)
if err != nil {
return "", err
}
if cache != nil {
_ = cache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
return accessToken, nil
}
func exchangeVertexServiceAccountToken(ctx context.Context, key *vertexServiceAccountKey) (string, time.Duration, error) {
now := time.Now()
claims := jwt.MapClaims{
"iss": key.ClientEmail,
"scope": vertexCloudPlatformScope,
"aud": key.TokenURI,
"iat": now.Unix(),
"exp": now.Add(time.Hour).Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
if strings.TrimSpace(key.PrivateKeyID) != "" {
token.Header["kid"] = key.PrivateKeyID
}
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(key.PrivateKey))
if err != nil {
return "", 0, fmt.Errorf("parse service account private key: %w", err)
}
assertion, err := token.SignedString(privateKey)
if err != nil {
return "", 0, fmt.Errorf("sign service account assertion: %w", err)
}
values := url.Values{}
values.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
values.Set("assertion", assertion)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, key.TokenURI, strings.NewReader(values.Encode()))
if err != nil {
return "", 0, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
client := &http.Client{Timeout: 15 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", 0, fmt.Errorf("service account token request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
var parsed vertexTokenResponse
_ = json.Unmarshal(body, &parsed)
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
msg := strings.TrimSpace(parsed.ErrorDesc)
if msg == "" {
msg = strings.TrimSpace(parsed.Error)
}
if msg == "" {
msg = string(bytes.TrimSpace(body))
}
return "", 0, fmt.Errorf("service account token request returned %d: %s", resp.StatusCode, msg)
}
if strings.TrimSpace(parsed.AccessToken) == "" {
return "", 0, errors.New("service account token response missing access_token")
}
ttl := time.Duration(parsed.ExpiresIn) * time.Second
if ttl <= 0 {
ttl = time.Hour
}
if ttl > vertexServiceAccountCacheSkew {
ttl -= vertexServiceAccountCacheSkew
}
return parsed.AccessToken, ttl, nil
}
func buildVertexGeminiURL(projectID, location, model, action string, stream bool) (string, error) {
projectID = strings.TrimSpace(projectID)
location = strings.TrimSpace(location)
model = strings.TrimSpace(model)
action = strings.TrimSpace(action)
if projectID == "" {
return "", errors.New("vertex project_id is required")
}
if location == "" {
location = vertexDefaultLocation
}
if !vertexLocationPattern.MatchString(location) {
return "", fmt.Errorf("invalid vertex location: %s", location)
}
if model == "" {
return "", errors.New("vertex model is required")
}
switch action {
case "generateContent", "streamGenerateContent", "countTokens":
default:
return "", fmt.Errorf("unsupported vertex gemini action: %s", action)
}
host := fmt.Sprintf("%s-aiplatform.googleapis.com", location)
if location == "global" {
host = "aiplatform.googleapis.com"
}
u := fmt.Sprintf(
"https://%s/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
host,
url.PathEscape(projectID),
url.PathEscape(location),
url.PathEscape(model),
action,
)
if stream {
u += "?alt=sse"
}
return u, nil
}
func buildVertexAnthropicURL(projectID, location, model string, stream bool) (string, error) {
projectID = strings.TrimSpace(projectID)
location = strings.TrimSpace(location)
model = strings.TrimSpace(model)
if projectID == "" {
return "", errors.New("vertex project_id is required")
}
if location == "" {
location = vertexDefaultLocation
}
if !vertexLocationPattern.MatchString(location) {
return "", fmt.Errorf("invalid vertex location: %s", location)
}
if model == "" {
return "", errors.New("vertex model is required")
}
action := "rawPredict"
if stream {
action = "streamRawPredict"
}
host := fmt.Sprintf("%s-aiplatform.googleapis.com", location)
if location == "global" {
host = "aiplatform.googleapis.com"
}
escapedModel := strings.ReplaceAll(url.PathEscape(model), "%40", "@")
return fmt.Sprintf(
"https://%s/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
host,
url.PathEscape(projectID),
url.PathEscape(location),
escapedModel,
action,
), nil
}
func normalizeVertexAnthropicModelID(model string) string {
model = strings.TrimSpace(model)
if model == "" || vertexAnthropicAlreadyDatedIDPattern.MatchString(model) {
return model
}
if m := vertexAnthropicDatedModelIDPattern.FindStringSubmatch(model); len(m) == 3 {
return m[1] + "@" + m[2]
}
return model
}
func buildVertexAnthropicRequestBody(body []byte) ([]byte, error) {
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return nil, fmt.Errorf("parse anthropic vertex request body: %w", err)
}
delete(payload, "model")
payload["anthropic_version"] = vertexAnthropicVersion
return json.Marshal(payload)
}


@@ -0,0 +1,77 @@
package service
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestBuildVertexGeminiURL(t *testing.T) {
got, err := buildVertexGeminiURL("my-project", "us-central1", "gemini-3-pro", "streamGenerateContent", true)
require.NoError(t, err)
require.Equal(t, "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-3-pro:streamGenerateContent?alt=sse", got)
}
func TestBuildVertexGeminiURLUsesGlobalEndpointHost(t *testing.T) {
got, err := buildVertexGeminiURL("my-project", "global", "gemini-3-flash-preview", "streamGenerateContent", true)
require.NoError(t, err)
require.Equal(t, "https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/publishers/google/models/gemini-3-flash-preview:streamGenerateContent?alt=sse", got)
}
func TestBuildVertexAnthropicURL(t *testing.T) {
got, err := buildVertexAnthropicURL("my-project", "us-east5", "claude-sonnet-4-5@20250929", false)
require.NoError(t, err)
require.Equal(t, "https://us-east5-aiplatform.googleapis.com/v1/projects/my-project/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5@20250929:rawPredict", got)
}
func TestBuildVertexAnthropicURLUsesGlobalEndpointHost(t *testing.T) {
got, err := buildVertexAnthropicURL("my-project", "global", "claude-haiku-4-5@20251001", true)
require.NoError(t, err)
require.Equal(t, "https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/publishers/anthropic/models/claude-haiku-4-5@20251001:streamRawPredict", got)
}
func TestNormalizeVertexAnthropicModelID(t *testing.T) {
require.Equal(t, "claude-sonnet-4-5@20250929", normalizeVertexAnthropicModelID("claude-sonnet-4-5-20250929"))
require.Equal(t, "claude-sonnet-4-5@20250929", normalizeVertexAnthropicModelID("claude-sonnet-4-5@20250929"))
require.Equal(t, "claude-sonnet-4-6", normalizeVertexAnthropicModelID("claude-sonnet-4-6"))
}
func TestBuildVertexAnthropicRequestBody(t *testing.T) {
got, err := buildVertexAnthropicRequestBody([]byte(`{"model":"claude-sonnet-4-5","anthropic_version":"2023-06-01","max_tokens":64,"messages":[{"role":"user","content":"hi"}]}`))
require.NoError(t, err)
require.Equal(t, "", gjson.GetBytes(got, "model").String())
require.Equal(t, vertexAnthropicVersion, gjson.GetBytes(got, "anthropic_version").String())
require.Equal(t, int64(64), gjson.GetBytes(got, "max_tokens").Int())
require.Equal(t, "hi", gjson.GetBytes(got, "messages.0.content").String())
}
func TestBuildVertexGeminiURLRejectsInvalidLocation(t *testing.T) {
_, err := buildVertexGeminiURL("my-project", "us-central1/path", "gemini-3-pro", "generateContent", false)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid vertex location")
}
func TestParseVertexServiceAccountKey(t *testing.T) {
raw := `{
"type": "service_account",
"project_id": "vertex-proj",
"private_key_id": "kid",
"private_key": "-----BEGIN PRIVATE KEY-----\nabc\n-----END PRIVATE KEY-----\n",
"client_email": "svc@vertex-proj.iam.gserviceaccount.com"
}`
account := &Account{
Type: AccountTypeServiceAccount,
Platform: PlatformGemini,
Credentials: map[string]any{
"service_account_json": raw,
},
}
key, err := parseVertexServiceAccountKey(account)
require.NoError(t, err)
require.Equal(t, "vertex-proj", key.ProjectID)
require.Equal(t, "svc@vertex-proj.iam.gserviceaccount.com", key.ClientEmail)
require.Equal(t, vertexDefaultTokenURL, key.TokenURI)
require.True(t, strings.Contains(key.PrivateKey, "BEGIN PRIVATE KEY"))
}


@@ -404,12 +404,28 @@ func ProvideBillingCacheService(
return NewBillingCacheService(cache, userRepo, subRepo, apiKeyRepo, rpmCache, rateRepo, cfg)
}
// ProvideAPIKeyService wires APIKeyService and connects rate-limit cache invalidation.
func ProvideAPIKeyService(
apiKeyRepo APIKeyRepository,
userRepo UserRepository,
groupRepo GroupRepository,
userSubRepo UserSubscriptionRepository,
userGroupRateRepo UserGroupRateRepository,
cache APIKeyCache,
cfg *config.Config,
billingCacheService *BillingCacheService,
) *APIKeyService {
svc := NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, userGroupRateRepo, cache, cfg)
svc.SetRateLimitCacheInvalidator(billingCacheService)
return svc
}
// ProviderSet is the Wire provider set for all services
var ProviderSet = wire.NewSet(
// Core services
NewAuthService,
NewUserService,
- NewAPIKeyService,
+ ProvideAPIKeyService,
ProvideAPIKeyAuthCacheInvalidator,
NewGroupService,
NewAccountService,


@@ -370,8 +370,8 @@ export async function batchUpdateCredentials(request: {
* @returns Success confirmation
*/
export async function bulkUpdate(
- accountIds: number[],
- updates: Record<string, unknown>
+ accountIdsOrPayload: number[] | Record<string, unknown>,
+ updates?: Record<string, unknown>
): Promise<{
success: number
failed: number
@@ -379,16 +379,19 @@ export async function bulkUpdate(
failed_ids?: number[]
results: Array<{ account_id: number; success: boolean; error?: string }>
}> {
const payload = Array.isArray(accountIdsOrPayload)
? {
account_ids: accountIdsOrPayload,
...(updates ?? {})
}
: accountIdsOrPayload
const { data } = await apiClient.post<{
success: number
failed: number
success_ids?: number[]
failed_ids?: number[]
results: Array<{ account_id: number; success: boolean; error?: string }>
- }>('/admin/accounts/bulk-update', {
- account_ids: accountIds,
- ...updates
- })
+ }>('/admin/accounts/bulk-update', payload)
return data
}


@@ -484,6 +484,9 @@ export interface SystemSettings {
// Affiliate (referral rebate) feature switch
affiliate_enabled: boolean;
// OpenAI fast/flex policy
openai_fast_policy_settings?: OpenAIFastPolicySettings;
}
export interface UpdateSettingsRequest {
@@ -648,6 +651,9 @@ export interface UpdateSettingsRequest {
// Affiliate (referral rebate) feature switch
affiliate_enabled?: boolean;
// OpenAI fast/flex policy
openai_fast_policy_settings?: OpenAIFastPolicySettings;
}
/**
@@ -875,6 +881,29 @@ export async function updateRectifierSettings(
return data;
}
// ==================== OpenAI Fast Policy Settings ====================
/**
* OpenAI fast/flex policy rule interface.
* Matches backend dto.OpenAIFastPolicyRule.
*/
export interface OpenAIFastPolicyRule {
service_tier: "all" | "priority" | "flex";
action: "pass" | "filter" | "block";
scope: "all" | "oauth" | "apikey" | "bedrock";
error_message?: string;
model_whitelist?: string[];
fallback_action?: "pass" | "filter" | "block";
fallback_error_message?: string;
}
/**
* OpenAI fast/flex policy settings interface.
*/
export interface OpenAIFastPolicySettings {
rules: OpenAIFastPolicyRule[];
}
// ==================== Beta Policy Settings ====================
/**


@@ -332,6 +332,37 @@
<!-- Usage data or unlimited flow -->
<div class="space-y-1">
<div
v-if="showGeminiTodayStats && todayStats"
class="mb-0.5 flex items-center"
>
<div class="flex items-center gap-1.5 text-[9px] text-gray-500 dark:text-gray-400">
<span class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800">
{{ formatKeyRequests }} req
</span>
<span class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800">
{{ formatKeyTokens }}
</span>
<span class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800" :title="t('usage.accountBilled')">
A ${{ formatKeyCost }}
</span>
<span
v-if="todayStats.user_cost != null"
class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800"
:title="t('usage.userBilled')"
>
U ${{ formatKeyUserCost }}
</span>
</div>
</div>
<div
v-else-if="showGeminiTodayStats && todayStatsLoading"
class="mb-0.5 flex items-center gap-1"
>
<div class="h-3 w-10 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
<div class="h-3 w-8 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
<div class="h-3 w-12 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
</div>
<div v-if="loading" class="space-y-1">
<div class="flex items-center gap-1">
<div class="h-3 w-[32px] animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
@@ -512,6 +543,10 @@ const shouldFetchUsage = computed(() => {
return false
})
const showGeminiTodayStats = computed(() => {
return props.account.platform === 'gemini' && props.account.type === 'service_account'
})
const geminiUsageAvailable = computed(() => {
return (
!!usageInfo.value?.gemini_shared_daily ||


@@ -17,7 +17,7 @@
d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
/>
</svg>
- {{ t('admin.accounts.bulkEdit.selectionInfo', { count: accountIds.length }) }}
+ {{ t('admin.accounts.bulkEdit.selectionInfo', { count: targetMode === 'filtered' ? targetPreviewCount : accountIds.length }) }}
</p>
</div>
@@ -27,7 +27,7 @@
<svg class="mr-1.5 inline h-5 w-5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z" />
</svg>
- {{ t('admin.accounts.bulkEdit.mixedPlatformWarning', { platforms: selectedPlatforms.join(', ') }) }}
+ {{ t('admin.accounts.bulkEdit.mixedPlatformWarning', { platforms: targetSelectedPlatforms.join(', ') }) }}
</p>
</div>
@@ -227,7 +227,7 @@
<ModelWhitelistSelector
v-model="allowedModels"
- :platforms="selectedPlatforms"
+ :platforms="targetSelectedPlatforms"
/>
<p class="text-xs text-gray-500 dark:text-gray-400">
@@ -698,6 +698,87 @@
</div>
</div>
<!-- OpenAI OAuth Codex CLI only -->
<div v-if="allOpenAIOAuth" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="mb-3 flex items-center justify-between">
<label
id="bulk-edit-openai-codex-cli-only-label"
class="input-label mb-0"
for="bulk-edit-openai-codex-cli-only-enabled"
>
{{ t('admin.accounts.openai.codexCLIOnly') }}
</label>
<input
v-model="enableCodexCLIOnly"
id="bulk-edit-openai-codex-cli-only-enabled"
type="checkbox"
aria-controls="bulk-edit-openai-codex-cli-only"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
</div>
<div
id="bulk-edit-openai-codex-cli-only"
:class="!enableCodexCLIOnly && 'pointer-events-none opacity-50'"
>
<p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.openai.codexCLIOnlyDesc') }}
</p>
<button
id="bulk-edit-openai-codex-cli-only-toggle"
type="button"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
codexCLIOnlyEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
@click="codexCLIOnlyEnabled = !codexCLIOnlyEnabled"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
codexCLIOnlyEnabled ? 'translate-x-5' : 'translate-x-0'
]"
/>
</button>
</div>
</div>
<!-- OpenAI API Key WS mode -->
<div v-if="allOpenAIAPIKey" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="mb-3 flex items-center justify-between">
<label
id="bulk-edit-openai-apikey-ws-mode-label"
class="input-label mb-0"
for="bulk-edit-openai-apikey-ws-mode-enabled"
>
{{ t('admin.accounts.openai.wsMode') }}
</label>
<input
v-model="enableOpenAIAPIKeyWSMode"
id="bulk-edit-openai-apikey-ws-mode-enabled"
type="checkbox"
aria-controls="bulk-edit-openai-apikey-ws-mode"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
</div>
<div
id="bulk-edit-openai-apikey-ws-mode"
:class="!enableOpenAIAPIKeyWSMode && 'pointer-events-none opacity-50'"
>
<p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.openai.wsModeDesc') }}
</p>
<p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
{{ t(openAIAPIKeyWSModeConcurrencyHintKey) }}
</p>
<Select
v-model="openaiAPIKeyResponsesWebSocketV2Mode"
data-testid="bulk-edit-openai-apikey-ws-mode-select"
:options="openAIWSModeOptions"
aria-labelledby="bulk-edit-openai-apikey-ws-mode-label"
/>
</div>
</div>
<!-- RPM Limit (shown only when all selected accounts are Anthropic OAuth/SetupToken) -->
<div v-if="allAnthropicOAuthOrSetupToken" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="mb-3 flex items-center justify-between">
<div class="mb-3 flex items-center justify-between"> <div class="mb-3 flex items-center justify-between">
@@ -933,6 +1014,13 @@ interface Props {
accountIds: number[]
selectedPlatforms: AccountPlatform[]
selectedTypes: AccountType[]
target?: {
mode: 'selected' | 'filtered'
filters?: Record<string, unknown>
previewCount?: number
selectedPlatforms?: AccountPlatform[]
selectedTypes?: AccountType[]
}
proxies: ProxyConfig[]
groups: AdminGroup[]
}
@@ -947,40 +1035,53 @@ const { t } = useI18n()
const appStore = useAppStore()
// Platform awareness
- const isMixedPlatform = computed(() => props.selectedPlatforms.length > 1)
+ const targetMode = computed(() => props.target?.mode ?? 'selected')
const targetPreviewCount = computed(() => props.target?.previewCount ?? props.accountIds.length)
const targetSelectedPlatforms = computed(() => props.target?.selectedPlatforms ?? props.selectedPlatforms)
const targetSelectedTypes = computed(() => props.target?.selectedTypes ?? props.selectedTypes)
const isMixedPlatform = computed(() => targetSelectedPlatforms.value.length > 1)
const allOpenAIPassthroughCapable = computed(() => {
return (
- props.selectedPlatforms.length === 1 &&
- props.selectedPlatforms[0] === 'openai' &&
- props.selectedTypes.length > 0 &&
- props.selectedTypes.every(t => t === 'oauth' || t === 'apikey')
+ targetSelectedPlatforms.value.length === 1 &&
+ targetSelectedPlatforms.value[0] === 'openai' &&
+ targetSelectedTypes.value.length > 0 &&
+ targetSelectedTypes.value.every(t => t === 'oauth' || t === 'apikey')
)
})
const allOpenAIOAuth = computed(() => {
return (
- props.selectedPlatforms.length === 1 &&
- props.selectedPlatforms[0] === 'openai' &&
- props.selectedTypes.length > 0 &&
- props.selectedTypes.every(t => t === 'oauth')
+ targetSelectedPlatforms.value.length === 1 &&
+ targetSelectedPlatforms.value[0] === 'openai' &&
+ targetSelectedTypes.value.length > 0 &&
+ targetSelectedTypes.value.every(t => t === 'oauth')
)
})
const allOpenAIAPIKey = computed(() => {
return (
targetSelectedPlatforms.value.length === 1 &&
targetSelectedPlatforms.value[0] === 'openai' &&
targetSelectedTypes.value.length > 0 &&
targetSelectedTypes.value.every(t => t === 'apikey')
)
})
// Whether all accounts are Anthropic OAuth/SetupToken (the RPM settings are shown only in that case)
const allAnthropicOAuthOrSetupToken = computed(() => {
return (
- props.selectedPlatforms.length === 1 &&
- props.selectedPlatforms[0] === 'anthropic' &&
- props.selectedTypes.every(t => t === 'oauth' || t === 'setup-token')
+ targetSelectedPlatforms.value.length === 1 &&
+ targetSelectedPlatforms.value[0] === 'anthropic' &&
+ targetSelectedTypes.value.every(t => t === 'oauth' || t === 'setup-token')
)
})
const filteredPresets = computed(() => {
- if (props.selectedPlatforms.length === 0) return []
+ if (targetSelectedPlatforms.value.length === 0) return []
const dedupedPresets = new Map<string, ReturnType<typeof getPresetMappingsByPlatform>[number]>()
- for (const platform of props.selectedPlatforms) {
+ for (const platform of targetSelectedPlatforms.value) {
for (const preset of getPresetMappingsByPlatform(platform)) {
const key = `${preset.from}=>${preset.to}`
if (!dedupedPresets.has(key)) {
@@ -1012,6 +1113,8 @@ const enableStatus = ref(false)
const enableGroups = ref(false)
const enableOpenAIPassthrough = ref(false)
const enableOpenAIWSMode = ref(false)
const enableOpenAIAPIKeyWSMode = ref(false)
const enableCodexCLIOnly = ref(false)
const enableRpmLimit = ref(false)
// State - field values
@@ -1035,6 +1138,8 @@ const status = ref<'active' | 'inactive'>('active')
const groupIds = ref<number[]>([])
const openaiPassthroughEnabled = ref(false)
const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
const openaiAPIKeyResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
const codexCLIOnlyEnabled = ref(false)
const rpmLimitEnabled = ref(false)
const bulkBaseRpm = ref<number | null>(null)
const bulkRpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered')
@@ -1076,6 +1181,9 @@ const openAIWSModeOptions = computed(() => [
const openAIWSModeConcurrencyHintKey = computed(() =>
resolveOpenAIWSModeConcurrencyHintKey(openaiOAuthResponsesWebSocketV2Mode.value)
)
const openAIAPIKeyWSModeConcurrencyHintKey = computed(() =>
resolveOpenAIWSModeConcurrencyHintKey(openaiAPIKeyResponsesWebSocketV2Mode.value)
)
// Model mapping helpers
const addModelMapping = () => {
@@ -1254,6 +1362,19 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
)
}
if (enableOpenAIAPIKeyWSMode.value) {
const extra = ensureExtra()
extra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value
extra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(
openaiAPIKeyResponsesWebSocketV2Mode.value
)
}
if (enableCodexCLIOnly.value) {
const extra = ensureExtra()
extra.codex_cli_only = codexCLIOnlyEnabled.value
}
// RPM limit settings (written into the extra field)
if (enableRpmLimit.value) {
const extra = ensureExtra()
@@ -1291,8 +1412,8 @@ const mixedChannelConfirmed = ref(false)
const canPreCheck = () =>
enableGroups.value &&
groupIds.value.length > 0 &&
- props.selectedPlatforms.length === 1 &&
- (props.selectedPlatforms[0] === 'antigravity' || props.selectedPlatforms[0] === 'anthropic')
+ targetSelectedPlatforms.value.length === 1 &&
+ (targetSelectedPlatforms.value[0] === 'antigravity' || targetSelectedPlatforms.value[0] === 'anthropic')
const handleClose = () => {
showMixedChannelWarning.value = false
@@ -1309,7 +1430,7 @@ const preCheckMixedChannelRisk = async (built: Record<string, unknown>): Promise
try {
const result = await adminAPI.accounts.checkMixedChannelRisk({
- platform: props.selectedPlatforms[0],
+ platform: targetSelectedPlatforms.value[0],
group_ids: groupIds.value
})
if (!result.has_risk) return true
@@ -1325,7 +1446,7 @@ const preCheckMixedChannelRisk = async (built: Record<string, unknown>): Promise
}
const handleSubmit = async () => {
- if (props.accountIds.length === 0) {
+ if (targetMode.value === 'selected' && props.accountIds.length === 0) {
appStore.showError(t('admin.accounts.bulkEdit.noSelection'))
return
}
@@ -1344,6 +1465,8 @@ const handleSubmit = async () => {
enableStatus.value ||
enableGroups.value ||
enableOpenAIWSMode.value ||
enableOpenAIAPIKeyWSMode.value ||
enableCodexCLIOnly.value ||
enableRpmLimit.value ||
userMsgQueueMode.value !== null
@@ -1373,7 +1496,12 @@ const submitBulkUpdate = async (baseUpdates: Record<string, unknown>) => {
submitting.value = true
try {
- const res = await adminAPI.accounts.bulkUpdate(props.accountIds, updates)
+ const res = targetMode.value === 'filtered' && props.target?.filters
? await adminAPI.accounts.bulkUpdate({
filters: props.target.filters,
...updates
})
: await adminAPI.accounts.bulkUpdate(props.accountIds, updates)
const success = res.success || 0
const failed = res.failed || 0
@@ -1437,6 +1565,8 @@ watch(
enableGroups.value = false
enableOpenAIPassthrough.value = false
enableOpenAIWSMode.value = false
enableOpenAIAPIKeyWSMode.value = false
enableCodexCLIOnly.value = false
enableRpmLimit.value = false
// Reset all values
@@ -1456,6 +1586,8 @@ watch(
status.value = 'active'
groupIds.value = []
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
codexCLIOnlyEnabled.value = false
rpmLimitEnabled.value = false
bulkBaseRpm.value = null
bulkRpmStrategy.value = 'tiered'


@@ -153,7 +153,7 @@
<!-- Account Type Selection (Anthropic) -->
<div v-if="form.platform === 'anthropic'">
<label class="input-label">{{ t('admin.accounts.accountType') }}</label>
- <div class="mt-2 grid grid-cols-3 gap-3" data-tour="account-form-type">
+ <div class="mt-2 grid grid-cols-2 gap-3 sm:grid-cols-4" data-tour="account-form-type">
<button
type="button"
@click="accountCategory = 'oauth-based'"
@@ -244,6 +244,39 @@
</div>
</button>
<button
type="button"
@click="accountCategory = 'service_account'"
:class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
accountCategory === 'service_account'
? 'border-sky-500 bg-sky-50 dark:bg-sky-900/20'
: 'border-gray-200 hover:border-sky-300 dark:border-dark-600 dark:hover:border-sky-700'
]"
>
<div
:class="[
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
accountCategory === 'service_account'
? 'bg-sky-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
]"
>
<Icon name="cloud" size="sm" />
</div>
<div>
<span class="block text-sm font-medium text-gray-900 dark:text-white">Vertex</span>
<span class="text-xs text-gray-500 dark:text-gray-400">Service Account</span>
</div>
</button>
</div>
<div
v-if="accountCategory === 'service_account'"
class="mt-3 rounded-lg border border-sky-200 bg-sky-50 px-3 py-2 text-xs text-sky-800 dark:border-sky-800/40 dark:bg-sky-900/20 dark:text-sky-200"
>
<p>{{ t('admin.accounts.vertexAnthropicHint') }}</p>
</div>
</div>
@@ -302,6 +335,7 @@
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.types.responsesApi') }}</span>
</div>
</button>
</div>
</div>
@@ -320,7 +354,7 @@
{{ t('admin.accounts.gemini.helpButton') }}
</button>
</div>
<div class="mt-2 grid grid-cols-3 gap-3" data-tour="account-form-type">
<button
type="button"
@click="accountCategory = 'oauth-based'"
@@ -392,6 +426,36 @@
</span>
</div>
</button>
<button
type="button"
@click="accountCategory = 'service_account'"
:class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
accountCategory === 'service_account'
? 'border-sky-500 bg-sky-50 dark:bg-sky-900/20'
: 'border-gray-200 hover:border-sky-300 dark:border-dark-600 dark:hover:border-sky-700'
]"
>
<div
:class="[
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
accountCategory === 'service_account'
? 'bg-sky-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
]"
>
<Icon name="cloud" size="sm" />
</div>
<div>
<span class="block text-sm font-medium text-gray-900 dark:text-white">
Vertex
</span>
<span class="text-xs text-gray-500 dark:text-gray-400">
Service Account
</span>
</div>
</button>
</div>
<div
@@ -411,6 +475,13 @@
</div>
</div>
<div
v-if="accountCategory === 'service_account'"
class="mt-3 rounded-lg border border-sky-200 bg-sky-50 px-3 py-2 text-xs text-sky-800 dark:border-sky-800/40 dark:bg-sky-900/20 dark:text-sky-200"
>
<p>{{ t('admin.accounts.vertexGeminiHint') }}</p>
</div>
<!-- OAuth Type Selection (only show when oauth-based is selected) -->
<div v-if="accountCategory === 'oauth-based'" class="mt-4">
<label class="input-label">{{ t('admin.accounts.oauth.gemini.oauthTypeLabel') }}</label>
@@ -610,7 +681,7 @@
</div>
<!-- Tier selection (used as fallback when auto-detection is unavailable/fails) -->
<div v-if="accountCategory !== 'service_account'" class="mt-4">
<label class="input-label">{{ t('admin.accounts.gemini.tier.label') }}</label>
<div class="mt-2">
<select
@@ -729,6 +800,96 @@
</div>
</div>
<!-- Vertex Service Account -->
<div v-if="(form.platform === 'gemini' || form.platform === 'anthropic') && accountCategory === 'service_account'" class="space-y-4">
<div>
<label class="input-label">Service Account JSON</label>
<input
ref="vertexServiceAccountFileInput"
type="file"
accept="application/json,.json"
class="hidden"
@change="handleVertexServiceAccountFile"
/>
<div
:class="[
'rounded-lg border-2 border-dashed px-4 py-5 transition-colors',
vertexServiceAccountDragActive
? 'border-sky-500 bg-sky-50 dark:border-sky-500 dark:bg-sky-900/20'
: 'border-gray-300 bg-gray-50 hover:border-sky-400 hover:bg-sky-50/60 dark:border-dark-500 dark:bg-dark-700/40 dark:hover:border-sky-600 dark:hover:bg-sky-900/10'
]"
@dragenter.prevent="vertexServiceAccountDragActive = true"
@dragover.prevent="vertexServiceAccountDragActive = true"
@dragleave.prevent="vertexServiceAccountDragActive = false"
@drop.prevent="handleVertexServiceAccountDrop"
>
<div class="flex flex-col gap-3 sm:flex-row sm:items-center sm:justify-between">
<div class="min-w-0">
<div class="flex items-center gap-2 text-sm font-medium text-gray-900 dark:text-white">
<Icon name="upload" size="sm" />
<span>{{ vertexClientEmail ? t('admin.accounts.vertexSaJsonLoaded') : t('admin.accounts.vertexSaJsonDrop') }}</span>
</div>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ vertexClientEmail ? t('admin.accounts.vertexSaJsonKeyHidden') : t('admin.accounts.vertexSaJsonDropHint') }}
</p>
</div>
<button
type="button"
class="btn btn-secondary shrink-0"
@click="vertexServiceAccountFileInput?.click()"
>
<Icon name="upload" size="sm" />
{{ t('admin.accounts.vertexSaJsonSelectBtn') }}
</button>
</div>
<div
v-if="vertexClientEmail"
class="mt-3 rounded-md border border-sky-200 bg-white px-3 py-2 text-xs text-sky-900 dark:border-sky-800/50 dark:bg-dark-800 dark:text-sky-200"
>
<div class="truncate">Project ID: <span class="font-mono">{{ vertexProjectId }}</span></div>
<div class="truncate">Client Email: <span class="font-mono">{{ vertexClientEmail }}</span></div>
</div>
</div>
<p class="input-hint">{{ t('admin.accounts.vertexSaJsonUploadHint') }}</p>
</div>
<div class="grid grid-cols-1 gap-4 sm:grid-cols-2">
<div>
<label class="input-label">Project ID</label>
<input
v-model="vertexProjectId"
type="text"
class="input font-mono"
readonly
:placeholder="t('admin.accounts.vertexProjectIdPlaceholder')"
/>
</div>
<div>
<label class="input-label">Location</label>
<select
v-model="vertexLocation"
required
class="input font-mono"
>
<optgroup
v-for="group in VERTEX_LOCATION_OPTIONS"
:key="group.label"
:label="group.label"
>
<option
v-for="option in group.options"
:key="option.value"
:value="option.value"
>
{{ option.label }}
</option>
</optgroup>
</select>
<p class="input-hint">{{ t('admin.accounts.vertexLocationHint') }}</p>
</div>
</div>
</div>
<!-- Antigravity model restriction (applies to OAuth + Upstream) -->
<!-- Antigravity only supports model-mapping mode, not whitelist mode -->
<div v-if="form.platform === 'antigravity'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
@@ -2971,6 +3132,7 @@ import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue'
import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
import { VERTEX_LOCATION_OPTIONS } from '@/constants/account'
import {
OPENAI_WS_MODE_CTX_POOL,
OPENAI_WS_MODE_OFF,
@@ -3085,7 +3247,7 @@ interface TempUnschedRuleForm {
// State
const step = ref(1)
const submitting = ref(false)
const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock' | 'service_account'>('oauth-based') // UI selection for account category
const addMethod = ref<AddMethod>('oauth') // For oauth-based: 'oauth' or 'setup-token'
const apiKeyBaseUrl = ref('https://api.anthropic.com')
const apiKeyValue = ref('')
@@ -3151,6 +3313,12 @@ const bedrockSessionToken = ref('')
const bedrockRegion = ref('us-east-1')
const bedrockForceGlobal = ref(false)
const bedrockApiKeyValue = ref('')
const vertexServiceAccountFileInput = ref<HTMLInputElement | null>(null)
const vertexServiceAccountJson = ref('')
const vertexProjectId = ref('')
const vertexClientEmail = ref('')
const vertexLocation = ref('global')
const vertexServiceAccountDragActive = ref(false)
const tempUnschedEnabled = ref(false)
const tempUnschedRules = ref<TempUnschedRuleForm[]>([])
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-model-mapping')
@@ -3397,7 +3565,7 @@ watch(
// Sync form.type based on accountCategory, addMethod, and platform-specific type
watch(
[accountCategory, addMethod, antigravityAccountType, () => form.platform],
([category, method, agType]) => {
// Antigravity upstream type (actually created as apikey)
if (form.platform === 'antigravity' && agType === 'upstream') {
@@ -3409,7 +3577,9 @@ watch(
form.type = 'bedrock' as AccountType
return
}
if ((form.platform === 'gemini' || form.platform === 'anthropic') && category === 'service_account') {
form.type = 'service_account' as AccountType
} else if (category === 'oauth-based') {
form.type = method as AccountType // 'oauth' or 'setup-token'
} else {
form.type = 'apikey'
@@ -3447,6 +3617,12 @@ watch(
antigravityModelMappings.value = []
antigravityModelRestrictionMode.value = 'mapping'
}
if (newPlatform !== 'gemini' && newPlatform !== 'anthropic' && accountCategory.value === 'service_account') {
accountCategory.value = 'oauth-based'
}
if (newPlatform !== 'anthropic' && accountCategory.value === 'bedrock') {
accountCategory.value = 'oauth-based'
}
// Reset Bedrock fields when switching platforms
bedrockAccessKeyId.value = ''
bedrockSecretAccessKey.value = ''
@@ -3455,6 +3631,10 @@ watch(
bedrockForceGlobal.value = false
bedrockAuthMode.value = 'sigv4'
bedrockApiKeyValue.value = ''
vertexServiceAccountJson.value = ''
vertexProjectId.value = ''
vertexClientEmail.value = ''
vertexLocation.value = 'global'
// Reset Anthropic/Antigravity-specific settings when switching to other platforms
if (newPlatform !== 'anthropic' && newPlatform !== 'antigravity') {
interceptWarmupRequests.value = false
@@ -3886,6 +4066,10 @@ const resetForm = () => {
antigravityAccountType.value = 'oauth'
upstreamBaseUrl.value = ''
upstreamApiKey.value = ''
vertexServiceAccountJson.value = ''
vertexProjectId.value = ''
vertexClientEmail.value = ''
vertexLocation.value = 'global'
tempUnschedEnabled.value = false
tempUnschedRules.value = []
geminiOAuthType.value = 'code_assist'
@@ -4009,6 +4193,52 @@ const normalizePoolModeRetryCount = (value: number) => {
return normalized
}
const applyVertexServiceAccountJson = (value: string) => {
const raw = value.trim()
if (!raw) {
vertexProjectId.value = ''
vertexClientEmail.value = ''
return false
}
try {
const parsed = JSON.parse(raw) as Record<string, unknown>
const projectId = typeof parsed.project_id === 'string' ? parsed.project_id.trim() : ''
const clientEmail = typeof parsed.client_email === 'string' ? parsed.client_email.trim() : ''
const privateKey = typeof parsed.private_key === 'string' ? parsed.private_key.trim() : ''
if (!projectId || !clientEmail || !privateKey) {
appStore.showError(t('admin.accounts.vertexSaJsonMissingFields'))
return false
}
vertexProjectId.value = projectId
vertexClientEmail.value = clientEmail
vertexServiceAccountJson.value = JSON.stringify(parsed)
return true
} catch {
appStore.showError(t('admin.accounts.vertexSaJsonInvalid'))
return false
}
}
const parseVertexServiceAccountJson = () => applyVertexServiceAccountJson(vertexServiceAccountJson.value)
const handleVertexServiceAccountFile = async (event: Event) => {
const input = event.target as HTMLInputElement
const file = input.files?.[0]
if (!file) return
try {
applyVertexServiceAccountJson(await file.text())
} finally {
input.value = ''
}
}
const handleVertexServiceAccountDrop = async (event: DragEvent) => {
vertexServiceAccountDragActive.value = false
const file = event.dataTransfer?.files?.[0]
if (!file) return
applyVertexServiceAccountJson(await file.text())
}
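Both the file-picker and drag-and-drop handlers above funnel into `applyVertexServiceAccountJson`, which accepts a key file only when `project_id`, `client_email`, and `private_key` are all non-empty strings. A minimal framework-free sketch of that validation (function and type names here are illustrative, not from the component):

```typescript
interface VertexServiceAccount {
  projectId: string;
  clientEmail: string;
}

// Mirrors the component's check: parse the JSON, require the three
// string fields, and return the extracted identifiers or null on any failure.
function parseServiceAccountJson(raw: string): VertexServiceAccount | null {
  const trimmed = raw.trim();
  if (!trimmed) return null;
  try {
    const parsed = JSON.parse(trimmed) as Record<string, unknown>;
    const projectId = typeof parsed.project_id === "string" ? parsed.project_id.trim() : "";
    const clientEmail = typeof parsed.client_email === "string" ? parsed.client_email.trim() : "";
    const privateKey = typeof parsed.private_key === "string" ? parsed.private_key.trim() : "";
    if (!projectId || !clientEmail || !privateKey) return null;
    return { projectId, clientEmail };
  } catch {
    return null; // invalid JSON
  }
}
```

The component additionally surfaces a localized error toast for each failure branch; the sketch collapses those into a single `null` result.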
const handleSubmit = async () => {
// For OAuth-based type, handle OAuth flow (goes to step 2)
if (isOAuthFlow.value) {
@@ -4122,6 +4352,29 @@ const handleSubmit = async () => {
return
}
if ((form.platform === 'gemini' || form.platform === 'anthropic') && accountCategory.value === 'service_account') {
if (!form.name.trim()) {
appStore.showError(t('admin.accounts.pleaseEnterAccountName'))
return
}
if (!parseVertexServiceAccountJson()) {
return
}
if (!vertexLocation.value.trim()) {
appStore.showError(t('admin.accounts.vertexLocationRequired'))
return
}
const credentials: Record<string, unknown> = {
service_account_json: vertexServiceAccountJson.value.trim(),
project_id: vertexProjectId.value.trim(),
client_email: vertexClientEmail.value.trim(),
location: vertexLocation.value.trim(),
tier_id: 'vertex'
}
await createAccountAndFinish(form.platform, 'service_account' as AccountType, credentials)
return
}
// For apikey type, create directly
if (!apiKeyValue.value.trim()) {
appStore.showError(t('admin.accounts.pleaseEnterApiKey'))

View File

@@ -567,6 +567,221 @@
</div>
</div>
<!-- Vertex Service Account -->
<div v-if="(account.platform === 'gemini' || account.platform === 'anthropic') && account.type === 'service_account'" class="space-y-4">
<div class="grid grid-cols-1 gap-4 sm:grid-cols-2">
<div>
<label class="input-label">Project ID</label>
<input
v-model="editVertexProjectId"
type="text"
class="input font-mono"
readonly
:placeholder="t('admin.accounts.vertexProjectIdPlaceholder')"
/>
<p class="input-hint">{{ t('admin.accounts.vertexSaJsonEditHint') }}</p>
</div>
<div>
<label class="input-label">Location</label>
<select
v-model="editVertexLocation"
required
class="input font-mono"
>
<optgroup
v-for="group in VERTEX_LOCATION_OPTIONS"
:key="group.label"
:label="group.label"
>
<option
v-for="option in group.options"
:key="option.value"
:value="option.value"
>
{{ option.label }}
</option>
</optgroup>
</select>
<p class="input-hint">{{ t('admin.accounts.vertexLocationHint') }}</p>
</div>
</div>
<!-- Model Restriction Section for Service Account -->
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
<!-- Mode Toggle -->
<div class="mb-4 flex gap-2">
<button
type="button"
@click="modelRestrictionMode = 'whitelist'"
:class="[
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
modelRestrictionMode === 'whitelist'
? 'bg-primary-100 text-primary-700 dark:bg-primary-900/30 dark:text-primary-400'
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
]"
>
<svg
class="mr-1.5 inline h-4 w-4"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z"
/>
</svg>
{{ t('admin.accounts.modelWhitelist') }}
</button>
<button
type="button"
@click="modelRestrictionMode = 'mapping'"
:class="[
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
modelRestrictionMode === 'mapping'
? 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
]"
>
<svg
class="mr-1.5 inline h-4 w-4"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M8 7h12m0 0l-4-4m4 4l-4 4m0 6H4m0 0l4 4m-4-4l4-4"
/>
</svg>
{{ t('admin.accounts.modelMapping') }}
</button>
</div>
<!-- Whitelist Mode -->
<div v-if="modelRestrictionMode === 'whitelist'">
<ModelWhitelistSelector v-model="allowedModels" :platform="account?.platform || 'anthropic'" />
<p class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
<span v-if="allowedModels.length === 0">{{
t('admin.accounts.supportsAllModels')
}}</span>
</p>
</div>
<!-- Mapping Mode -->
<div v-else>
<div class="mb-3 rounded-lg bg-purple-50 p-3 dark:bg-purple-900/20">
<p class="text-xs text-purple-700 dark:text-purple-400">
<svg
class="mr-1 inline h-4 w-4"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
/>
</svg>
{{ t('admin.accounts.mapRequestModels') }}
</p>
</div>
<!-- Model Mapping List -->
<div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
<div
v-for="(mapping, index) in modelMappings"
:key="getModelMappingKey(mapping)"
class="flex items-center gap-2"
>
<input
v-model="mapping.from"
type="text"
class="input flex-1"
:placeholder="t('admin.accounts.requestModel')"
/>
<svg
class="h-4 w-4 flex-shrink-0 text-gray-400"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M14 5l7 7m0 0l-7 7m7-7H3"
/>
</svg>
<input
v-model="mapping.to"
type="text"
class="input flex-1"
:placeholder="t('admin.accounts.actualModel')"
/>
<button
type="button"
@click="removeModelMapping(index)"
class="rounded-lg p-2 text-red-500 transition-colors hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20"
>
<svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"
/>
</svg>
</button>
</div>
</div>
<button
type="button"
@click="addModelMapping"
class="mb-3 w-full rounded-lg border-2 border-dashed border-gray-300 px-4 py-2 text-gray-600 transition-colors hover:border-gray-400 hover:text-gray-700 dark:border-dark-500 dark:text-gray-400 dark:hover:border-dark-400 dark:hover:text-gray-300"
>
<svg
class="mr-1 inline h-4 w-4"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M12 4v16m8-8H4"
/>
</svg>
{{ t('admin.accounts.addMapping') }}
</button>
<!-- Quick Add Buttons -->
<div class="flex flex-wrap gap-2">
<button
v-for="preset in presetMappings"
:key="preset.label"
type="button"
@click="addPresetMapping(preset.from, preset.to)"
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
>
+ {{ preset.label }}
</button>
</div>
</div>
</div>
</div>
<!-- Bedrock fields (for bedrock type, both SigV4 and API Key modes) -->
<div v-if="account.type === 'bedrock'" class="space-y-4">
<!-- SigV4 fields -->
@@ -1919,6 +2134,7 @@ import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue'
import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
import { formatDateTime, formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
import { VERTEX_LOCATION_OPTIONS } from '@/constants/account'
import {
OPENAI_WS_MODE_CTX_POOL,
OPENAI_WS_MODE_OFF,
@@ -1987,6 +2203,9 @@ const editBedrockSessionToken = ref('')
const editBedrockRegion = ref('')
const editBedrockForceGlobal = ref(false)
const editBedrockApiKeyValue = ref('')
const editVertexProjectId = ref('')
const editVertexClientEmail = ref('')
const editVertexLocation = ref('us-central1')
const isBedrockAPIKeyMode = computed(() =>
props.account?.type === 'bedrock' &&
(props.account?.credentials as Record<string, unknown>)?.auth_mode === 'apikey'
@@ -2246,6 +2465,9 @@ const syncFormFromAccount = (newAccount: Account | null) => {
const credentials = newAccount.credentials as Record<string, unknown> | undefined
interceptWarmupRequests.value = credentials?.intercept_warmup_requests === true
autoPauseOnExpired.value = newAccount.auto_pause_on_expired === true
editVertexProjectId.value = ''
editVertexClientEmail.value = ''
editVertexLocation.value = 'us-central1'
// Load mixed scheduling setting (only for antigravity accounts)
mixedScheduling.value = false
@@ -2467,6 +2689,31 @@ const syncFormFromAccount = (newAccount: Account | null) => {
} else if (newAccount.type === 'upstream' && newAccount.credentials) {
const credentials = newAccount.credentials as Record<string, unknown>
editBaseUrl.value = (credentials.base_url as string) || ''
} else if ((newAccount.platform === 'gemini' || newAccount.platform === 'anthropic') && newAccount.type === 'service_account' && newAccount.credentials) {
const credentials = newAccount.credentials as Record<string, unknown>
editVertexProjectId.value = (credentials.project_id as string) || ''
editVertexClientEmail.value = (credentials.client_email as string) || ''
editVertexLocation.value = (credentials.location as string) || (credentials.vertex_location as string) || 'us-central1'
// Load model mappings for service_account
const existingMappings = credentials.model_mapping as Record<string, string> | undefined
if (existingMappings && typeof existingMappings === 'object') {
const entries = Object.entries(existingMappings)
const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to)
if (isWhitelistMode) {
modelRestrictionMode.value = 'whitelist'
allowedModels.value = entries.map(([from]) => from)
modelMappings.value = []
} else {
modelRestrictionMode.value = 'mapping'
modelMappings.value = entries.map(([from, to]) => ({ from, to }))
allowedModels.value = []
}
} else {
modelRestrictionMode.value = 'whitelist'
modelMappings.value = []
allowedModels.value = []
}
} else {
const platformDefaultUrl =
newAccount.platform === 'openai'
@@ -3057,6 +3304,46 @@ const handleSubmit = async () => {
return
}
updatePayload.credentials = newCredentials
} else if ((props.account.platform === 'gemini' || props.account.platform === 'anthropic') && props.account.type === 'service_account') {
const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}
const newCredentials: Record<string, unknown> = { ...currentCredentials }
if (!editVertexProjectId.value.trim()) {
appStore.showError(t('admin.accounts.vertexSaJsonMissingProjectId'))
return
}
if (!editVertexClientEmail.value.trim()) {
appStore.showError(t('admin.accounts.vertexSaJsonMissingClientEmail'))
return
}
if (!editVertexLocation.value.trim()) {
appStore.showError(t('admin.accounts.vertexLocationRequired'))
return
}
if (!currentCredentials.service_account_json && !currentCredentials.service_account) {
appStore.showError(t('admin.accounts.vertexSaJsonRequired'))
return
}
newCredentials.project_id = editVertexProjectId.value.trim()
newCredentials.client_email = editVertexClientEmail.value.trim()
newCredentials.location = editVertexLocation.value.trim()
newCredentials.tier_id = 'vertex'
// Add model mapping if configured
const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value)
if (modelMapping) {
newCredentials.model_mapping = modelMapping
} else {
delete newCredentials.model_mapping
}
applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit')
if (!applyTempUnschedConfig(newCredentials)) {
return
}
updatePayload.credentials = newCredentials
} else if (props.account.type === 'bedrock') {
const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}
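When the edit modal loads a `service_account` record, it reconstructs the restriction UI from the stored `model_mapping`: a non-empty set of identity pairs (`from === to`) is shown as a whitelist, anything else as explicit mappings, and an absent mapping defaults to an empty whitelist. A standalone sketch of that classification, with illustrative names:

```typescript
type RestrictionMode = "whitelist" | "mapping";

interface Classified {
  mode: RestrictionMode;
  allowedModels: string[];
  mappings: Array<{ from: string; to: string }>;
}

// Classify stored entries the way the edit modal does: every pair being
// an identity mapping means the user configured a whitelist.
function classifyModelMapping(mapping: Record<string, string>): Classified {
  const entries = Object.entries(mapping);
  const isWhitelist = entries.length > 0 && entries.every(([from, to]) => from === to);
  if (isWhitelist) {
    return { mode: "whitelist", allowedModels: entries.map(([from]) => from), mappings: [] };
  }
  return { mode: "mapping", allowedModels: [], mappings: entries.map(([from, to]) => ({ from, to })) };
}
```

This round-trips with the submit path, which stores a whitelist as identity pairs inside the same `model_mapping` object.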

View File

@@ -57,6 +57,19 @@ function makeAccount(overrides: Partial<Account>): Account {
describe('AccountUsageCell', () => {
beforeEach(() => {
getUsage.mockReset()
Object.defineProperty(window, 'matchMedia', {
writable: true,
value: vi.fn().mockImplementation(() => ({
matches: true,
media: '(min-width: 768px)',
onchange: null,
addListener: vi.fn(),
removeListener: vi.fn(),
addEventListener: vi.fn(),
removeEventListener: vi.fn(),
dispatchEvent: vi.fn(),
}))
})
})
it('Antigravity image usage aggregates old and new image models', async () => {
@@ -603,4 +616,43 @@ describe('AccountUsageCell', () => {
expect(wrapper.text().trim()).toBe('-')
})
it('Vertex accounts show the today-stats badge in the Gemini usage window', async () => {
const wrapper = mount(AccountUsageCell, {
props: {
account: makeAccount({
id: 4001,
platform: 'gemini',
type: 'service_account',
credentials: {
tier_id: 'vertex',
project_id: 'vertex-proj',
client_email: 'svc@vertex-proj.iam.gserviceaccount.com',
location: 'global'
},
extra: {}
}),
todayStats: {
requests: 0,
tokens: 0,
cost: 0,
standard_cost: 0,
user_cost: 0
}
},
global: {
stubs: {
UsageProgressBar: true,
AccountQuotaInfo: true
}
}
})
await flushPromises()
expect(wrapper.text()).toContain('0 req')
expect(wrapper.text()).toContain('0')
expect(wrapper.text()).toContain('A $0.00')
expect(wrapper.text()).toContain('U $0.00')
})
})

View File

@@ -178,6 +178,45 @@ describe('BulkEditAccountModal', () => {
expect(wrapper.find('#bulk-edit-openai-ws-mode-enabled').exists()).toBe(false)
})
it('OpenAI OAuth bulk edit should submit the codex_cli_only field', async () => {
const wrapper = mountModal({
selectedPlatforms: ['openai'],
selectedTypes: ['oauth']
})
await wrapper.get('#bulk-edit-openai-codex-cli-only-enabled').setValue(true)
await wrapper.get('#bulk-edit-openai-codex-cli-only-toggle').trigger('click')
await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent')
await flushPromises()
expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1)
expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith([1, 2], {
extra: {
codex_cli_only: true
}
})
})
it('OpenAI API Key bulk edit should submit the API-Key-specific WS mode field', async () => {
const wrapper = mountModal({
selectedPlatforms: ['openai'],
selectedTypes: ['apikey']
})
await wrapper.get('#bulk-edit-openai-apikey-ws-mode-enabled').setValue(true)
await wrapper.get('[data-testid="bulk-edit-openai-apikey-ws-mode-select"]').setValue('ctx_pool')
await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent')
await flushPromises()
expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1)
expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith([1, 2], {
extra: {
openai_apikey_responses_websockets_v2_mode: 'ctx_pool',
openai_apikey_responses_websockets_v2_enabled: true
}
})
})
it('OpenAI account bulk edit can disable auto passthrough', async () => {
const wrapper = mountModal({
selectedPlatforms: ['openai'],
@@ -217,4 +256,41 @@ describe('BulkEditAccountModal', () => {
})
expect(wrapper.text()).toContain('admin.accounts.openai.modelRestrictionDisabledByPassthrough')
})
it('filtered-results mode should submit filters instead of account_ids', async () => {
const wrapper = mountModal({
accountIds: [],
target: {
mode: 'filtered',
filters: {
platform: 'openai',
type: 'oauth',
status: 'active',
group: '12',
search: 'bulk-target',
privacy_mode: 'training_set_cf_blocked'
},
previewCount: 5,
selectedPlatforms: ['openai'],
selectedTypes: ['oauth']
}
})
await wrapper.get('#bulk-edit-status-enabled').setValue(true)
await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent')
await flushPromises()
expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1)
expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith({
filters: {
platform: 'openai',
type: 'oauth',
status: 'active',
group: '12',
search: 'bulk-target',
privacy_mode: 'training_set_cf_blocked'
},
status: 'active'
})
})
})

View File

@@ -1,9 +1,13 @@
<template>
<div class="mb-4 flex items-center justify-between rounded-lg bg-primary-50 p-3 dark:bg-primary-900/20">
<div class="flex flex-wrap items-center gap-2">
<span v-if="selectedIds.length > 0" class="text-sm font-medium text-primary-900 dark:text-primary-100">
{{ t('admin.accounts.bulkActions.selected', { count: selectedIds.length }) }}
</span>
<span v-else class="text-sm font-medium text-primary-900 dark:text-primary-100">
{{ t('admin.accounts.bulkEdit.title') }}
</span>
<template v-if="selectedIds.length > 0">
<button
@click="$emit('select-page')"
class="text-xs font-medium text-primary-700 hover:text-primary-800 dark:text-primary-300 dark:hover:text-primary-200"
@@ -17,19 +21,25 @@
>
{{ t('admin.accounts.bulkActions.clear') }}
</button>
</template>
</div>
<div class="flex gap-2">
+<template v-if="selectedIds.length > 0">
<button @click="$emit('delete')" class="btn btn-danger btn-sm">{{ t('admin.accounts.bulkActions.delete') }}</button>
<button @click="$emit('reset-status')" class="btn btn-secondary btn-sm">{{ t('admin.accounts.bulkActions.resetStatus') }}</button>
<button @click="$emit('refresh-token')" class="btn btn-secondary btn-sm">{{ t('admin.accounts.bulkActions.refreshToken') }}</button>
<button @click="$emit('toggle-schedulable', true)" class="btn btn-success btn-sm">{{ t('admin.accounts.bulkActions.enableScheduling') }}</button>
<button @click="$emit('toggle-schedulable', false)" class="btn btn-warning btn-sm">{{ t('admin.accounts.bulkActions.disableScheduling') }}</button>
-<button @click="$emit('edit')" class="btn btn-primary btn-sm">{{ t('admin.accounts.bulkActions.edit') }}</button>
<button @click="$emit('edit-selected')" class="btn btn-primary btn-sm">{{ t('admin.accounts.bulkActions.edit') }}</button>
</template>
<button @click="$emit('edit-filtered')" class="btn btn-primary btn-sm">
{{ t('admin.accounts.bulkEdit.submit') }}
</button>
</div>
</div>
</template>
<script setup lang="ts">
import { useI18n } from 'vue-i18n'
-defineProps(['selectedIds']); defineEmits(['delete', 'edit', 'clear', 'select-page', 'toggle-schedulable', 'reset-status', 'refresh-token']); const { t } = useI18n()
+defineProps(['selectedIds']); defineEmits(['delete', 'edit-selected', 'edit-filtered', 'clear', 'select-page', 'toggle-schedulable', 'reset-status', 'refresh-token']); const { t } = useI18n()
</script>


@@ -25,6 +25,7 @@
<!-- Setup Token icon -->
<Icon v-else-if="type === 'setup-token'" name="shield" size="xs" />
<!-- API Key icon -->
<Icon v-else-if="type === 'service_account'" name="cloud" size="xs" />
<Icon v-else name="key" size="xs" />
<span>{{ typeLabel }}</span>
</span>
@@ -88,6 +89,8 @@ const typeLabel = computed(() => {
return 'Key'
case 'bedrock':
return 'AWS'
case 'service_account':
return 'Vertex'
default:
return props.type
}


@@ -13,3 +13,51 @@ export type QuotaThresholdType = typeof QUOTA_THRESHOLD_TYPE_FIXED | typeof QUOT
export const QUOTA_RESET_MODE_ROLLING = 'rolling' as const
export const QUOTA_RESET_MODE_FIXED = 'fixed' as const
export type QuotaResetMode = typeof QUOTA_RESET_MODE_ROLLING | typeof QUOTA_RESET_MODE_FIXED
/** Vertex AI location options for Service Account accounts */
export const VERTEX_LOCATION_OPTIONS = [
{
label: 'Common',
options: [
{ value: 'us-central1', label: 'us-central1 (Iowa)' },
{ value: 'global', label: 'global' },
{ value: 'us', label: 'us' },
{ value: 'eu', label: 'eu' }
]
},
{
label: 'United States',
options: [
{ value: 'us-east1', label: 'us-east1 (South Carolina)' },
{ value: 'us-east4', label: 'us-east4 (Northern Virginia)' },
{ value: 'us-east5', label: 'us-east5 (Columbus)' },
{ value: 'us-south1', label: 'us-south1 (Dallas)' },
{ value: 'us-west1', label: 'us-west1 (Oregon)' },
{ value: 'us-west4', label: 'us-west4 (Las Vegas)' }
]
},
{
label: 'Europe',
options: [
{ value: 'europe-west1', label: 'europe-west1 (Belgium)' },
{ value: 'europe-west2', label: 'europe-west2 (London)' },
{ value: 'europe-west3', label: 'europe-west3 (Frankfurt)' },
{ value: 'europe-west4', label: 'europe-west4 (Netherlands)' },
{ value: 'europe-west6', label: 'europe-west6 (Zurich)' },
{ value: 'europe-west8', label: 'europe-west8 (Milan)' },
{ value: 'europe-west9', label: 'europe-west9 (Paris)' }
]
},
{
label: 'Asia Pacific',
options: [
{ value: 'asia-east1', label: 'asia-east1 (Taiwan)' },
{ value: 'asia-east2', label: 'asia-east2 (Hong Kong)' },
{ value: 'asia-northeast1', label: 'asia-northeast1 (Tokyo)' },
{ value: 'asia-northeast3', label: 'asia-northeast3 (Seoul)' },
{ value: 'asia-south1', label: 'asia-south1 (Mumbai)' },
{ value: 'asia-southeast1', label: 'asia-southeast1 (Singapore)' },
{ value: 'australia-southeast1', label: 'australia-southeast1 (Sydney)' }
]
}
] as const


@@ -2815,6 +2815,26 @@ export default {
claudeConsole: 'Claude Console',
bedrockLabel: 'AWS Bedrock',
bedrockDesc: 'SigV4 / API Key',
vertexLabel: 'Vertex',
vertexDesc: 'Service Account',
vertexAnthropicHint: 'Use a Google Cloud Service Account JSON to call Anthropic Claude via Vertex AI. It is recommended to configure model mapping to map client Claude model names to Vertex model IDs.',
vertexGeminiHint: 'Use a Google Cloud Service Account JSON to access Vertex AI Gemini. It is recommended to place Vertex accounts in a separate group to avoid mixing with AI Studio/Gemini OAuth on the same models.',
vertexSaJsonLabel: 'Service Account JSON',
vertexSaJsonLoaded: 'Service Account JSON loaded',
vertexSaJsonDrop: 'Drop Service Account JSON here',
vertexSaJsonKeyHidden: 'Key content is not displayed in the form.',
vertexSaJsonDropHint: 'Drag a .json file here, or click the button to select one.',
vertexSaJsonSelectBtn: 'Select JSON',
vertexSaJsonUploadHint: 'After uploading or dropping a JSON file, the project_id will be auto-extracted. Key content is only used for account creation.',
vertexSaJsonEditHint: 'Service Account JSON is not shown on the edit page; to change the JSON, delete the account and recreate it.',
vertexProjectIdPlaceholder: 'Auto-extracted from JSON',
vertexLocationHint: 'Available locations vary by Vertex model. Select the default endpoint location for this account.',
vertexLocationRequired: 'Please enter a Vertex location',
vertexSaJsonMissingFields: 'Service Account JSON is missing project_id, client_email, or private_key',
vertexSaJsonMissingProjectId: 'Service Account JSON is missing project_id',
vertexSaJsonMissingClientEmail: 'Service Account JSON is missing client_email',
vertexSaJsonInvalid: 'Service Account JSON format is invalid',
vertexSaJsonRequired: 'Please upload a Service Account JSON',
oauthSetupToken: 'OAuth / Setup Token',
addMethod: 'Add Method',
setupTokenLongLived: 'Setup Token (Long-lived)',
@@ -4648,7 +4668,7 @@ export default {
errorLogRetentionDays: 'Error Log Retention Days',
minuteMetricsRetentionDays: 'Minute Metrics Retention Days',
hourlyMetricsRetentionDays: 'Hourly Metrics Retention Days',
-retentionDaysHint: 'Recommended 7-90 days, longer periods will consume more storage',
+retentionDaysHint: 'Recommended 7-90 days; longer periods consume more storage. Set to 0 to wipe all history on every scheduled cleanup',
aggregation: 'Pre-aggregation Tasks',
enableAggregation: 'Enable Pre-aggregation',
aggregationHint: 'Pre-aggregation improves query performance for long time windows',
@@ -4678,7 +4698,7 @@ export default {
autoRefreshCountdown: 'Auto refresh: {seconds}s',
validation: {
title: 'Please fix the following issues',
-retentionDaysRange: 'Retention days must be between 1-365 days',
+retentionDaysRange: 'Retention days must be between 0 and 365 (0 = wipe all on every cleanup)',
slaMinPercentRange: 'SLA minimum percentage must be between 0 and 100',
ttftP99MaxRange: 'TTFT P99 maximum must be a number ≥ 0',
requestErrorRateMaxRange: 'Request error rate maximum must be between 0 and 100',
@@ -5535,6 +5555,38 @@ export default {
presetOpusOnlyDesc: 'Pass for Opus, filter others',
commonPatterns: 'Common patterns'
},
openaiFastPolicy: {
title: 'OpenAI Fast/Flex Policy',
description: 'Intercept, filter, or pass OpenAI fast(priority) / flex requests based on the request body service_tier field. Applies to the OpenAI gateway only.',
empty: 'No rules configured. Click the button below to add one.',
ruleHeader: 'Rule #{index}',
removeRule: 'Remove rule',
addRule: 'Add rule',
saveHint: 'Saved together with system settings (click the global Save button at the bottom of the page).',
serviceTier: 'service_tier match',
tierAll: 'All tiers',
tierPriority: 'priority (fast)',
tierFlex: 'flex',
action: 'Action',
actionPass: 'Pass (keep service_tier)',
actionFilter: 'Filter (remove service_tier)',
actionBlock: 'Block (reject request)',
scope: 'Scope',
scopeAll: 'All accounts',
scopeOAuth: 'OAuth only',
scopeAPIKey: 'API Key only',
scopeBedrock: 'Bedrock only',
errorMessage: 'Error message',
errorMessagePlaceholder: 'Custom error message when blocked',
errorMessageHint: 'Leave empty for the default message.',
modelWhitelist: 'Model whitelist',
modelWhitelistHint: 'Leave empty to apply to all models. Supports exact match and wildcard prefix (e.g., gpt-5.5*).',
modelPatternPlaceholder: 'e.g., gpt-5.5 or gpt-5.5*',
addModelPattern: 'Add model pattern',
fallbackAction: 'Fallback action',
fallbackActionHint: 'Action for models not matching the whitelist.',
fallbackErrorMessagePlaceholder: 'Custom error message when non-whitelisted models are blocked'
},
wechatConnect: {
title: 'WeChat Connect',
description: 'Third-party login configuration for WeChat Open Platform or Official Account / Mini Program.',
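
The `modelWhitelistHint` string above documents exact-match plus wildcard-prefix patterns such as `gpt-5.5*`. The matching logic itself lives on the backend and is not part of this diff; the following is only a minimal TypeScript sketch of the semantics the hint describes (function names are illustrative, not from the repo):

```typescript
// Hypothetical sketch of the whitelist semantics described by the hint text:
// a pattern matches a model name exactly, or, when it ends with '*',
// matches any model name starting with the prefix before the '*'.
function matchesPattern(model: string, pattern: string): boolean {
  if (pattern.endsWith('*')) {
    return model.startsWith(pattern.slice(0, -1))
  }
  return model === pattern
}

function inWhitelist(model: string, whitelist: string[]): boolean {
  // Per the hint, an empty whitelist means the rule applies to all models.
  if (whitelist.length === 0) return true
  return whitelist.some((p) => matchesPattern(model, p))
}
```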


@@ -2963,6 +2963,26 @@ export default {
claudeConsole: 'Claude Console',
bedrockLabel: 'AWS Bedrock',
bedrockDesc: 'SigV4 / API Key',
vertexLabel: 'Vertex',
vertexDesc: 'Service Account',
vertexAnthropicHint: '使用 Google Cloud Service Account JSON 通过 Vertex AI 调用 Anthropic Claude。建议配置模型映射将客户端 Claude 模型名映射到 Vertex 模型 ID。',
vertexGeminiHint: '使用 Google Cloud Service Account JSON 访问 Vertex AI Gemini。建议将 Vertex 账号放入独立分组,避免和 AI Studio/Gemini OAuth 同模型混调。',
vertexSaJsonLabel: 'Service Account JSON',
vertexSaJsonLoaded: '已读取 Service Account JSON',
vertexSaJsonDrop: '拖入 Service Account JSON',
vertexSaJsonKeyHidden: '密钥内容不会在表单中显示。',
vertexSaJsonDropHint: '把 .json 文件拖到这里,或点击按钮选择文件。',
vertexSaJsonSelectBtn: '选择 JSON',
vertexSaJsonUploadHint: '上传或拖入 JSON 后会自动读取 project_id,密钥内容仅用于创建账号提交。',
vertexSaJsonEditHint: 'Service Account JSON 不在编辑页显示;需要更换 JSON 时请删除账号后重新创建。',
vertexProjectIdPlaceholder: '从 JSON 自动读取',
vertexLocationHint: '不同 Vertex 模型可用 location 可能不同,这里选择账号默认 endpoint location。',
vertexLocationRequired: '请填写 Vertex location',
vertexSaJsonMissingFields: 'Service Account JSON 缺少 project_id、client_email 或 private_key',
vertexSaJsonMissingProjectId: 'Service Account JSON 缺少 project_id',
vertexSaJsonMissingClientEmail: 'Service Account JSON 缺少 client_email',
vertexSaJsonInvalid: 'Service Account JSON 格式无效',
vertexSaJsonRequired: '请上传 Service Account JSON',
oauthSetupToken: 'OAuth / Setup Token',
addMethod: '添加方式',
setupTokenLongLived: 'Setup Token(长期有效)',
@@ -4810,7 +4830,7 @@ export default {
errorLogRetentionDays: '错误日志保留天数',
minuteMetricsRetentionDays: '分钟指标保留天数',
hourlyMetricsRetentionDays: '小时指标保留天数',
-retentionDaysHint: '建议保留7-90天,过长会占用存储空间',
+retentionDaysHint: '建议保留 7-90 天,过长会占用存储空间;填 0 表示每次定时清理时清空所有历史',
aggregation: '预聚合任务',
enableAggregation: '启用预聚合任务',
aggregationHint: '预聚合可提升长时间窗口查询性能',
@@ -4841,7 +4861,7 @@ export default {
autoRefreshCountdown: '自动刷新:{seconds}s',
validation: {
title: '请先修正以下问题',
-retentionDaysRange: '保留天数必须在1-365天之间',
+retentionDaysRange: '保留天数必须在 0-365 天之间(0 = 每次清理时清空所有)',
slaMinPercentRange: 'SLA最低百分比必须在0-100之间',
ttftP99MaxRange: 'TTFT P99最大值必须大于等于0',
requestErrorRateMaxRange: '请求错误率最大值必须在0-100之间',
@@ -5695,6 +5715,38 @@ export default {
presetOpusOnlyDesc: 'Opus 透传,其他模型过滤',
commonPatterns: '常用模式'
},
openaiFastPolicy: {
title: 'OpenAI Fast/Flex 策略',
description: '基于请求体 service_tier 字段拦截/过滤/透传 OpenAI fast(priority) 与 flex 请求;仅作用于 OpenAI 网关。',
empty: '尚未配置任何规则。点击下方按钮新增。',
ruleHeader: '规则 #{index}',
removeRule: '删除规则',
addRule: '新增规则',
saveHint: '保存时随系统设置一起提交(点击页面底部「保存」按钮)。',
serviceTier: 'service_tier 匹配',
tierAll: '全部 tier',
tierPriority: 'priority(fast)',
tierFlex: 'flex',
action: '处理方式',
actionPass: '透传(保留 service_tier)',
actionFilter: '过滤(移除 service_tier)',
actionBlock: '拦截(拒绝请求)',
scope: '生效范围',
scopeAll: '全部账号',
scopeOAuth: '仅 OAuth 账号',
scopeAPIKey: '仅 API Key 账号',
scopeBedrock: '仅 Bedrock 账号',
errorMessage: '错误消息',
errorMessagePlaceholder: '拦截时返回的自定义错误消息',
errorMessageHint: '留空则使用默认错误消息。',
modelWhitelist: '模型白名单',
modelWhitelistHint: '留空表示对所有模型生效;支持精确匹配与通配符(如 gpt-5.5*)。',
modelPatternPlaceholder: '例如: gpt-5.5 或 gpt-5.5*',
addModelPattern: '添加模型规则',
fallbackAction: '未匹配模型处理方式',
fallbackActionHint: '当请求模型不在白名单中时的处理方式。',
fallbackErrorMessagePlaceholder: '未匹配模型被拦截时返回的自定义错误消息'
},
wechatConnect: {
title: '微信登录',
description: '用于微信开放平台或公众号/小程序的第三方登录配置。',


@@ -643,7 +643,7 @@ export interface UpdateGroupRequest {
// ==================== Account & Proxy Types ====================
export type AccountPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity'
-export type AccountType = 'oauth' | 'setup-token' | 'apikey' | 'upstream' | 'bedrock'
+export type AccountType = 'oauth' | 'setup-token' | 'apikey' | 'upstream' | 'bedrock' | 'service_account'
export type OAuthAddMethod = 'oauth' | 'setup-token'
export type ProxyProtocol = 'http' | 'https' | 'socks5' | 'socks5h'


@@ -1,93 +1,18 @@
/**
- * Usage request scheduler — throttles Anthropic API calls by proxy exit.
+ * Usage request scheduler.
 *
- * Anthropic OAuth/setup-token accounts sharing the same proxy exit are placed
- * into a serial queue with a random 1-2s delay between requests, preventing
- * upstream 429 rate-limit errors.
- *
- * Proxy identity = host:port:username — two proxy records pointing to the
- * same exit share a single queue. Accounts without a proxy go into a
- * "direct" queue.
- *
- * All other platforms bypass the queue and execute immediately.
+ * All platforms execute immediately without queuing; the backend uses
+ * passive sampling, so upstream 429 rate-limit errors are no longer a concern.
 */
import type { Account } from '@/types'
-const GROUP_DELAY_MIN_MS = 1000
-const GROUP_DELAY_MAX_MS = 2000
-type Task<T> = {
-  fn: () => Promise<T>
-  resolve: (value: T) => void
-  reject: (reason: unknown) => void
-}
-const queues = new Map<string, Task<unknown>[]>()
-const running = new Set<string>()
-/** Whether this account needs throttled queuing. */
-function needsThrottle(account: Account): boolean {
-  return (
-    account.platform === 'anthropic' &&
-    (account.type === 'oauth' || account.type === 'setup-token')
-  )
-}
-/** Build a queue key from proxy connection details. */
-function buildGroupKey(account: Account): string {
-  const proxy = account.proxy
-  const proxyIdentity = proxy
-    ? `${proxy.host}:${proxy.port}:${proxy.username || ''}`
-    : 'direct'
-  return `anthropic:${proxyIdentity}`
-}
-async function drain(groupKey: string) {
-  if (running.has(groupKey)) return
-  running.add(groupKey)
-  const queue = queues.get(groupKey)
-  while (queue && queue.length > 0) {
-    const task = queue.shift()!
-    try {
-      const result = await task.fn()
-      task.resolve(result)
-    } catch (err) {
-      task.reject(err)
-    }
-    if (queue.length > 0) {
-      const jitter = GROUP_DELAY_MIN_MS + Math.random() * (GROUP_DELAY_MAX_MS - GROUP_DELAY_MIN_MS)
-      await new Promise((r) => setTimeout(r, jitter))
-    }
-  }
-  running.delete(groupKey)
-  queues.delete(groupKey)
-}
/**
- * Schedule a usage fetch. Anthropic accounts are queued by proxy exit;
- * all other platforms execute immediately.
+ * Schedule a usage fetch. All requests execute immediately.
 */
export function enqueueUsageRequest<T>(
-  account: Account,
+  _account: Account,
  fn: () => Promise<T>
): Promise<T> {
-  // Non-Anthropic → fire immediately, no queuing
-  if (!needsThrottle(account)) {
-    return fn()
-  }
-  const key = buildGroupKey(account)
-  return new Promise<T>((resolve, reject) => {
-    let queue = queues.get(key)
-    if (!queue) {
-      queue = []
-      queues.set(key, queue)
-    }
-    queue.push({ fn, resolve, reject } as Task<unknown>)
-    drain(key)
-  })
+  return fn()
}
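
After this change `enqueueUsageRequest` is a plain pass-through, so callers can start every usage fetch at once instead of waiting on per-proxy serial queues. A sketch of a hypothetical caller (the `loadAllUsage` and `fetchUsage` names are illustrative, not from the repo):

```typescript
// Pass-through scheduler, mirroring the simplified implementation in this diff.
type Account = { id: number; platform: string; type: string }

function enqueueUsageRequest<T>(_account: Account, fn: () => Promise<T>): Promise<T> {
  return fn()
}

// Hypothetical fan-out: all usage requests start immediately and settle
// independently, so one failing account does not block the others.
async function loadAllUsage(
  accounts: Account[],
  fetchUsage: (a: Account) => Promise<number>
): Promise<PromiseSettledResult<number>[]> {
  return Promise.allSettled(
    accounts.map((a) => enqueueUsageRequest(a, () => fetchUsage(a)))
  )
}
```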


@@ -141,7 +141,17 @@
</div>
</template>
<template #table>
-<AccountBulkActionsBar :selected-ids="selIds" @delete="handleBulkDelete" @reset-status="handleBulkResetStatus" @refresh-token="handleBulkRefreshToken" @edit="showBulkEdit = true" @clear="clearSelection" @select-page="selectPage" @toggle-schedulable="handleBulkToggleSchedulable" />
+<AccountBulkActionsBar
:selected-ids="selIds"
@delete="handleBulkDelete"
@reset-status="handleBulkResetStatus"
@refresh-token="handleBulkRefreshToken"
@edit-selected="openBulkEditSelected"
@edit-filtered="openBulkEditFiltered"
@clear="clearSelection"
@select-page="selectPage"
@toggle-schedulable="handleBulkToggleSchedulable"
/>
<div ref="accountTableRef" class="flex min-h-0 flex-1 flex-col overflow-hidden">
<DataTable
ref="dataTableRef"
@@ -303,7 +313,17 @@
<AccountActionMenu :show="menu.show" :account="menu.acc" :position="menu.pos" @close="menu.show = false" @test="handleTest" @stats="handleViewStats" @schedule="handleSchedule" @reauth="handleReAuth" @refresh-token="handleRefresh" @recover-state="handleRecoverState" @reset-quota="handleResetQuota" @set-privacy="handleSetPrivacy" />
<SyncFromCrsModal :show="showSync" @close="showSync = false" @synced="reload" />
<ImportDataModal :show="showImportData" @close="showImportData = false" @imported="handleDataImported" />
-<BulkEditAccountModal :show="showBulkEdit" :account-ids="selIds" :selected-platforms="selPlatforms" :selected-types="selTypes" :proxies="proxies" :groups="groups" @close="showBulkEdit = false" @updated="handleBulkUpdated" />
+<BulkEditAccountModal
:show="showBulkEdit"
:account-ids="selIds"
:selected-platforms="selPlatforms"
:selected-types="selTypes"
:target="bulkEditTarget ?? undefined"
:proxies="proxies"
:groups="groups"
@close="showBulkEdit = false"
@updated="handleBulkUpdated"
/>
<TempUnschedStatusModal :show="showTempUnsched" :account="tempUnschedAcc" @close="showTempUnsched = false" @reset="handleTempUnschedReset" />
<ConfirmDialog :show="showDeleteDialog" :title="t('admin.accounts.deleteAccount')" :message="t('admin.accounts.deleteConfirm', { name: deletingAcc?.name })" :confirm-text="t('common.delete')" :cancel-text="t('common.cancel')" :danger="true" @confirm="confirmDelete" @cancel="showDeleteDialog = false" />
<ConfirmDialog :show="showExportDataDialog" :title="t('admin.accounts.dataExport')" :message="t('admin.accounts.dataExportConfirmMessage')" :confirm-text="t('admin.accounts.dataExportConfirm')" :cancel-text="t('common.cancel')" @confirm="handleExportData" @cancel="showExportDataDialog = false">
@@ -364,6 +384,29 @@ const proxies = ref<AccountProxy[]>([])
const groups = ref<AdminGroup[]>([])
const accountTableRef = ref<HTMLElement | null>(null)
const dataTableRef = ref<InstanceType<typeof DataTable> | null>(null)
type AccountBulkEditTarget =
| {
mode: 'selected'
accountIds: number[]
selectedPlatforms: AccountPlatform[]
selectedTypes: AccountType[]
}
| {
mode: 'filtered'
filters: {
platform?: string
type?: string
status?: string
group?: string
search?: string
privacy_mode?: string
sort_by?: string
sort_order?: AccountSortOrder
}
previewCount: number
selectedPlatforms: AccountPlatform[]
selectedTypes: AccountType[]
}
const selPlatforms = computed<AccountPlatform[]>(() => {
const platforms = new Set(
accounts.value
@@ -387,6 +430,7 @@ const showImportData = ref(false)
const showExportDataDialog = ref(false)
const includeProxyOnExport = ref(true)
const showBulkEdit = ref(false)
const bulkEditTarget = ref<AccountBulkEditTarget | null>(null)
const showTempUnsched = ref(false)
const showDeleteDialog = ref(false)
const showReAuth = ref(false)
@@ -1216,7 +1260,57 @@ const handleBulkToggleSchedulable = async (schedulable: boolean) => {
appStore.showError(t('common.error'))
}
}
-const handleBulkUpdated = () => { showBulkEdit.value = false; clearSelection(); reload() }
+const buildBulkEditFilterSnapshot = () => {
const rawParams = toRaw(params) as Record<string, unknown>
const sortOrder: AccountSortOrder = rawParams.sort_order === 'desc' ? 'desc' : 'asc'
return {
platform: typeof rawParams.platform === 'string' ? rawParams.platform : '',
type: typeof rawParams.type === 'string' ? rawParams.type : '',
status: typeof rawParams.status === 'string' ? rawParams.status : '',
group: typeof rawParams.group === 'string' ? rawParams.group : '',
search: typeof rawParams.search === 'string' ? rawParams.search : '',
privacy_mode: typeof rawParams.privacy_mode === 'string' ? rawParams.privacy_mode : '',
sort_by: typeof rawParams.sort_by === 'string' ? rawParams.sort_by : '',
sort_order: sortOrder
}
}
const collectSelectionMetadata = (rows: Account[]) => {
const selectedPlatforms = Array.from(new Set(rows.map(account => account.platform)))
const selectedTypes = Array.from(new Set(rows.map(account => account.type)))
return { selectedPlatforms, selectedTypes }
}
const openBulkEditSelected = () => {
bulkEditTarget.value = {
mode: 'selected',
accountIds: [...selIds.value],
selectedPlatforms: [...selPlatforms.value],
selectedTypes: [...selTypes.value]
}
showBulkEdit.value = true
}
const openBulkEditFiltered = async () => {
const filters = buildBulkEditFilterSnapshot()
const preview = await adminAPI.accounts.list(1, 100, filters)
const { selectedPlatforms, selectedTypes } = collectSelectionMetadata(preview.items)
bulkEditTarget.value = {
mode: 'filtered',
filters,
previewCount: preview.total,
selectedPlatforms,
selectedTypes
}
showBulkEdit.value = true
}
const handleBulkUpdated = () => {
showBulkEdit.value = false
bulkEditTarget.value = null
clearSelection()
reload()
}
const handleDataImported = () => { showImportData.value = false; reload() }
const ACCOUNT_UNGROUPED_GROUP_QUERY_VALUE = 'ungrouped'
const ACCOUNT_PRIVACY_MODE_UNSET_QUERY_VALUE = '__unset__'


@@ -949,6 +949,285 @@
</template>
</div>
</div>
<!-- OpenAI Fast/Flex Policy Settings -->
<div class="card">
<div
class="border-b border-gray-100 px-6 py-4 dark:border-dark-700"
>
<h2 class="text-lg font-semibold text-gray-900 dark:text-white">
{{ t("admin.settings.openaiFastPolicy.title") }}
</h2>
<p class="mt-1 text-sm text-gray-500 dark:text-gray-400">
{{ t("admin.settings.openaiFastPolicy.description") }}
</p>
</div>
<div class="space-y-5 p-6">
<!-- Empty state -->
<div
v-if="openaiFastPolicyForm.rules.length === 0"
class="rounded-lg border border-dashed border-gray-200 p-6 text-center text-sm text-gray-500 dark:border-dark-600 dark:text-gray-400"
>
{{ t("admin.settings.openaiFastPolicy.empty") }}
</div>
<!-- Rule Cards -->
<div
v-for="(rule, ruleIndex) in openaiFastPolicyForm.rules"
:key="ruleIndex"
class="rounded-lg border border-gray-200 p-4 dark:border-dark-600"
>
<div class="mb-3 flex items-center justify-between">
<span
class="text-sm font-medium text-gray-900 dark:text-white"
>
{{
t("admin.settings.openaiFastPolicy.ruleHeader", {
index: ruleIndex + 1,
})
}}
</span>
<button
type="button"
@click="removeOpenAIFastPolicyRule(ruleIndex)"
class="rounded p-1 text-red-400 transition-colors hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20"
:title="t('admin.settings.openaiFastPolicy.removeRule')"
>
<svg
class="h-4 w-4"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M6 18L18 6M6 6l12 12"
/>
</svg>
</button>
</div>
<div class="grid grid-cols-1 gap-4 md:grid-cols-3">
<!-- Service Tier -->
<div>
<label
class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
>
{{ t("admin.settings.openaiFastPolicy.serviceTier") }}
</label>
<Select
:modelValue="rule.service_tier"
@update:modelValue="
rule.service_tier = $event as
| 'all'
| 'priority'
| 'flex'
"
:options="openaiFastPolicyTierOptions"
/>
</div>
<!-- Action -->
<div>
<label
class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
>
{{ t("admin.settings.openaiFastPolicy.action") }}
</label>
<Select
:modelValue="rule.action"
@update:modelValue="
rule.action = $event as 'pass' | 'filter' | 'block'
"
:options="openaiFastPolicyActionOptions"
/>
</div>
<!-- Scope -->
<div>
<label
class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
>
{{ t("admin.settings.openaiFastPolicy.scope") }}
</label>
<Select
:modelValue="rule.scope"
@update:modelValue="
rule.scope = $event as
| 'all'
| 'oauth'
| 'apikey'
| 'bedrock'
"
:options="openaiFastPolicyScopeOptions"
/>
</div>
</div>
<!-- Error Message (only when action=block) -->
<div v-if="rule.action === 'block'" class="mt-3">
<label
class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
>
{{ t("admin.settings.openaiFastPolicy.errorMessage") }}
</label>
<input
v-model="rule.error_message"
type="text"
class="input"
:placeholder="
t(
'admin.settings.openaiFastPolicy.errorMessagePlaceholder',
)
"
/>
<p class="mt-1 text-xs text-gray-400 dark:text-gray-500">
{{ t("admin.settings.openaiFastPolicy.errorMessageHint") }}
</p>
</div>
<!-- Model Whitelist -->
<div class="mt-3">
<label
class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
>
{{ t("admin.settings.openaiFastPolicy.modelWhitelist") }}
</label>
<p class="mb-2 text-xs text-gray-400 dark:text-gray-500">
{{
t("admin.settings.openaiFastPolicy.modelWhitelistHint")
}}
</p>
<div
v-for="(_, patternIdx) in rule.model_whitelist || []"
:key="patternIdx"
class="mb-1.5 flex items-center gap-2"
>
<input
v-model="rule.model_whitelist![patternIdx]"
type="text"
class="input input-sm flex-1"
:placeholder="
t(
'admin.settings.openaiFastPolicy.modelPatternPlaceholder',
)
"
/>
<button
type="button"
@click="
removeOpenAIFastPolicyModelPattern(rule, patternIdx)
"
class="shrink-0 rounded p-1 text-red-400 transition-colors hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20"
>
<svg
class="h-4 w-4"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M6 18L18 6M6 6l12 12"
/>
</svg>
</button>
</div>
<button
type="button"
@click="addOpenAIFastPolicyModelPattern(rule)"
class="mb-2 inline-flex items-center gap-1 text-xs text-primary-600 transition-colors hover:text-primary-700 dark:text-primary-400 dark:hover:text-primary-300"
>
<svg
class="h-3.5 w-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M12 4v16m8-8H4"
/>
</svg>
{{ t("admin.settings.openaiFastPolicy.addModelPattern") }}
</button>
</div>
<!-- Fallback Action (only when model_whitelist is non-empty) -->
<div
v-if="
rule.model_whitelist && rule.model_whitelist.length > 0
"
class="mt-3"
>
<label
class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
>
{{ t("admin.settings.openaiFastPolicy.fallbackAction") }}
</label>
<Select
:modelValue="rule.fallback_action || 'pass'"
@update:modelValue="
rule.fallback_action = $event as
| 'pass'
| 'filter'
| 'block'
"
:options="openaiFastPolicyActionOptions"
/>
<p class="mt-1 text-xs text-gray-400 dark:text-gray-500">
{{
t("admin.settings.openaiFastPolicy.fallbackActionHint")
}}
</p>
<div v-if="rule.fallback_action === 'block'" class="mt-2">
<input
v-model="rule.fallback_error_message"
type="text"
class="input"
:placeholder="
t(
'admin.settings.openaiFastPolicy.fallbackErrorMessagePlaceholder',
)
"
/>
</div>
</div>
</div>
<!-- Add Rule Button -->
<div>
<button
type="button"
@click="addOpenAIFastPolicyRule"
class="btn btn-secondary btn-sm inline-flex items-center gap-1"
>
<svg
class="h-4 w-4"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M12 4v16m8-8H4"
/>
</svg>
{{ t("admin.settings.openaiFastPolicy.addRule") }}
</button>
<p class="mt-2 text-xs text-gray-400 dark:text-gray-500">
{{ t("admin.settings.openaiFastPolicy.saveHint") }}
</p>
</div>
</div>
</div>
</div>
<!-- /Tab: Gateway -->
@@ -5199,6 +5478,7 @@ import type {
SystemSettings,
UpdateSettingsRequest,
DefaultSubscriptionSetting,
OpenAIFastPolicyRule,
WeChatConnectMode,
WebSearchEmulationConfig,
WebSearchProviderConfig,
@@ -5337,6 +5617,14 @@ const betaPolicyForm = reactive({
}>,
});
// OpenAI Fast/Flex Policy state
const openaiFastPolicyForm = reactive({
rules: [] as OpenAIFastPolicyRule[],
});
// Tracks whether openai_fast_policy_settings was successfully loaded from the
// backend, so saving cannot overwrite the default rules with an empty array
// when the GET failed or the field was missing.
const openaiFastPolicyLoaded = ref(false);
const tablePageSizeMin = 5;
const tablePageSizeMax = 1000;
const tablePageSizeDefault = 20;
@@ -6116,6 +6404,23 @@ async function loadSettings() {
);
form.oidc_connect_client_secret = "";
// Load OpenAI fast/flex policy rules from bulk settings.
// Only populate the form and mark it loaded when the payload actually
// contains this field; otherwise leave the form empty so saveSettings skips
// the field and does not overwrite the backend's default rules.
if (
settings.openai_fast_policy_settings &&
Array.isArray(settings.openai_fast_policy_settings.rules)
) {
openaiFastPolicyForm.rules =
settings.openai_fast_policy_settings.rules.map((rule) => ({
...rule,
model_whitelist: rule.model_whitelist
? [...rule.model_whitelist]
: [],
}));
openaiFastPolicyLoaded.value = true;
}
// Load web search emulation config separately
await loadWebSearchConfig();
} catch (error: unknown) {
@@ -6460,10 +6765,39 @@ async function saveSettings() {
affiliate_enabled: form.affiliate_enabled,
};
// Only write openai_fast_policy_settings back when it was successfully
// loaded from the backend; otherwise omit the entire field so the backend
// keeps its existing rules (including defaults).
if (openaiFastPolicyLoaded.value) {
payload.openai_fast_policy_settings = {
rules: openaiFastPolicyForm.rules.map((rule) => {
const whitelist = (rule.model_whitelist || [])
.map((p) => p.trim())
.filter((p) => p !== "");
const hasWhitelist = whitelist.length > 0;
return {
service_tier: rule.service_tier,
action: rule.action,
scope: rule.scope,
error_message:
rule.action === "block" ? rule.error_message : undefined,
model_whitelist: hasWhitelist ? whitelist : undefined,
fallback_action: hasWhitelist
? rule.fallback_action || "pass"
: undefined,
fallback_error_message:
hasWhitelist && rule.fallback_action === "block"
? rule.fallback_error_message
: undefined,
};
}),
};
}
appendAuthSourceDefaultsToUpdateRequest(payload, authSourceDefaults);
const updated = await adminAPI.settings.updateSettings(payload);
for (const [key, value] of Object.entries(updated)) {
if (key === "openai_fast_policy_settings") continue;
if (value !== null && value !== undefined) {
  (form as Record<string, unknown>)[key] = value;
}
@@ -6507,6 +6841,20 @@ async function saveSettings() {
form.wechat_connect_mode,
);
form.oidc_connect_client_secret = "";
// Refresh OpenAI fast/flex policy from server response
if (
updated.openai_fast_policy_settings &&
Array.isArray(updated.openai_fast_policy_settings.rules)
) {
openaiFastPolicyForm.rules =
updated.openai_fast_policy_settings.rules.map((rule) => ({
...rule,
model_whitelist: rule.model_whitelist
? [...rule.model_whitelist]
: [],
}));
openaiFastPolicyLoaded.value = true;
}
// Save web search emulation config separately (errors handled internally)
const wsOk = await saveWebSearchConfig();
// Refresh cached settings so sidebar/header update immediately
@@ -6846,6 +7194,61 @@ async function loadBetaPolicySettings() {
}
}
// ==================== OpenAI Fast/Flex Policy ====================
const openaiFastPolicyTierOptions = computed(() => [
{ value: "all", label: t("admin.settings.openaiFastPolicy.tierAll") },
{
value: "priority",
label: t("admin.settings.openaiFastPolicy.tierPriority"),
},
{ value: "flex", label: t("admin.settings.openaiFastPolicy.tierFlex") },
]);
const openaiFastPolicyActionOptions = computed(() => [
{ value: "pass", label: t("admin.settings.openaiFastPolicy.actionPass") },
{ value: "filter", label: t("admin.settings.openaiFastPolicy.actionFilter") },
{ value: "block", label: t("admin.settings.openaiFastPolicy.actionBlock") },
]);
const openaiFastPolicyScopeOptions = computed(() => [
{ value: "all", label: t("admin.settings.openaiFastPolicy.scopeAll") },
{ value: "oauth", label: t("admin.settings.openaiFastPolicy.scopeOAuth") },
{ value: "apikey", label: t("admin.settings.openaiFastPolicy.scopeAPIKey") },
{
value: "bedrock",
label: t("admin.settings.openaiFastPolicy.scopeBedrock"),
},
]);
function addOpenAIFastPolicyRule() {
openaiFastPolicyForm.rules.push({
service_tier: "priority",
action: "filter",
scope: "all",
error_message: "",
model_whitelist: [],
fallback_action: "pass",
fallback_error_message: "",
});
}
function removeOpenAIFastPolicyRule(index: number) {
openaiFastPolicyForm.rules.splice(index, 1);
}
function addOpenAIFastPolicyModelPattern(rule: OpenAIFastPolicyRule) {
if (!rule.model_whitelist) rule.model_whitelist = [];
rule.model_whitelist.push("");
}
function removeOpenAIFastPolicyModelPattern(
rule: OpenAIFastPolicyRule,
idx: number,
) {
rule.model_whitelist?.splice(idx, 1);
}
async function saveBetaPolicySettings() {
  betaPolicySaving.value = true;
  try {
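The payload-building branch inside `if (openaiFastPolicyLoaded.value)` above can be sketched as a standalone normalizer. This is a minimal sketch for illustration: the type and function names here are hypothetical, not the project's actual exports, and the shape of `FastPolicyRule` is inferred from the fields the diff touches.

```typescript
// Hypothetical standalone version of the rule normalization done in
// saveSettings: trim whitelist patterns, drop empties, and omit
// whitelist-dependent fields when no patterns remain.
type FastPolicyAction = "pass" | "filter" | "block";

interface FastPolicyRule {
  service_tier: string;
  action: FastPolicyAction;
  scope: string;
  error_message?: string;
  model_whitelist?: string[];
  fallback_action?: FastPolicyAction;
  fallback_error_message?: string;
}

function normalizeFastPolicyRule(rule: FastPolicyRule): FastPolicyRule {
  const whitelist = (rule.model_whitelist ?? [])
    .map((p) => p.trim())
    .filter((p) => p !== "");
  const hasWhitelist = whitelist.length > 0;
  return {
    service_tier: rule.service_tier,
    action: rule.action,
    scope: rule.scope,
    // error_message only applies when the rule blocks outright.
    error_message: rule.action === "block" ? rule.error_message : undefined,
    model_whitelist: hasWhitelist ? whitelist : undefined,
    // Whitelist-dependent fields are omitted when no patterns survive trimming.
    fallback_action: hasWhitelist ? rule.fallback_action || "pass" : undefined,
    fallback_error_message:
      hasWhitelist && rule.fallback_action === "block"
        ? rule.fallback_error_message
        : undefined,
  };
}
```

Stripping the unused optional fields to `undefined` (rather than sending stale values) is what lets the backend treat an absent whitelist and an empty one identically.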


@@ -0,0 +1,152 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { flushPromises, mount } from '@vue/test-utils'
import AccountsView from '../AccountsView.vue'
const {
listAccounts,
listWithEtag,
getBatchTodayStats,
getAllProxies,
getAllGroups
} = vi.hoisted(() => ({
listAccounts: vi.fn(),
listWithEtag: vi.fn(),
getBatchTodayStats: vi.fn(),
getAllProxies: vi.fn(),
getAllGroups: vi.fn()
}))
vi.mock('@/api/admin', () => ({
adminAPI: {
accounts: {
list: listAccounts,
listWithEtag,
getBatchTodayStats,
delete: vi.fn(),
batchClearError: vi.fn(),
batchRefresh: vi.fn(),
toggleSchedulable: vi.fn()
},
proxies: {
getAll: getAllProxies
},
groups: {
getAll: getAllGroups
}
}
}))
vi.mock('@/stores/app', () => ({
useAppStore: () => ({
showError: vi.fn(),
showSuccess: vi.fn(),
showInfo: vi.fn()
})
}))
vi.mock('@/stores/auth', () => ({
useAuthStore: () => ({
token: 'test-token'
})
}))
vi.mock('vue-i18n', async () => {
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
return {
...actual,
useI18n: () => ({
t: (key: string) => key
})
}
})
const DataTableStub = {
props: ['columns', 'data'],
template: '<div data-test="data-table"></div>'
}
const AccountBulkActionsBarStub = {
props: ['selectedIds'],
emits: ['edit-filtered'],
template: '<button data-test="edit-filtered" @click="$emit(\'edit-filtered\')">edit filtered</button>'
}
const BulkEditAccountModalStub = {
props: ['show', 'target'],
template: '<div data-test="bulk-edit-modal" :data-show="String(show)" :data-target-mode="target?.mode ?? \'\'"></div>'
}
describe('admin AccountsView bulk edit scope', () => {
beforeEach(() => {
localStorage.clear()
listAccounts.mockReset()
listWithEtag.mockReset()
getBatchTodayStats.mockReset()
getAllProxies.mockReset()
getAllGroups.mockReset()
listAccounts.mockResolvedValue({
items: [],
total: 0,
page: 1,
page_size: 20,
pages: 0
})
listWithEtag.mockResolvedValue({
notModified: true,
etag: null,
data: null
})
getBatchTodayStats.mockResolvedValue({ stats: {} })
getAllProxies.mockResolvedValue([])
getAllGroups.mockResolvedValue([])
})
it('opens bulk edit in filtered-results mode from the bulk actions dropdown', async () => {
const wrapper = mount(AccountsView, {
global: {
stubs: {
AppLayout: { template: '<div><slot /></div>' },
TablePageLayout: {
template: '<div><slot name="filters" /><slot name="table" /><slot name="pagination" /></div>'
},
DataTable: DataTableStub,
Pagination: true,
ConfirmDialog: true,
AccountTableActions: { template: '<div><slot name="beforeCreate" /><slot name="after" /></div>' },
AccountTableFilters: { template: '<div></div>' },
AccountBulkActionsBar: AccountBulkActionsBarStub,
AccountActionMenu: true,
ImportDataModal: true,
ReAuthAccountModal: true,
AccountTestModal: true,
AccountStatsModal: true,
ScheduledTestsPanel: true,
SyncFromCrsModal: true,
TempUnschedStatusModal: true,
ErrorPassthroughRulesModal: true,
TLSFingerprintProfilesModal: true,
CreateAccountModal: true,
EditAccountModal: true,
BulkEditAccountModal: BulkEditAccountModalStub,
PlatformTypeBadge: true,
AccountCapacityCell: true,
AccountStatusIndicator: true,
AccountTodayStatsCell: true,
AccountGroupsCell: true,
AccountUsageCell: true,
Icon: true
}
}
})
await flushPromises()
await wrapper.get('[data-test="edit-filtered"]').trigger('click')
await flushPromises()
expect(wrapper.get('[data-test="bulk-edit-modal"]').attributes('data-show')).toBe('true')
expect(wrapper.get('[data-test="bulk-edit-modal"]').attributes('data-target-mode')).toBe('filtered')
})
})


@@ -136,13 +136,13 @@ const validation = computed(() => {
// Validate advanced settings
if (advancedSettings.value) {
  const { error_log_retention_days, minute_metrics_retention_days, hourly_metrics_retention_days } = advancedSettings.value.data_retention
-  if (error_log_retention_days < 1 || error_log_retention_days > 365) {
+  if (error_log_retention_days < 0 || error_log_retention_days > 365) {
    errors.push(t('admin.ops.settings.validation.retentionDaysRange'))
  }
-  if (minute_metrics_retention_days < 1 || minute_metrics_retention_days > 365) {
+  if (minute_metrics_retention_days < 0 || minute_metrics_retention_days > 365) {
    errors.push(t('admin.ops.settings.validation.retentionDaysRange'))
  }
-  if (hourly_metrics_retention_days < 1 || hourly_metrics_retention_days > 365) {
+  if (hourly_metrics_retention_days < 0 || hourly_metrics_retention_days > 365) {
    errors.push(t('admin.ops.settings.validation.retentionDaysRange'))
  }
}
@@ -431,7 +431,7 @@ async function saveAllSettings() {
<input
  v-model.number="advancedSettings.data_retention.error_log_retention_days"
  type="number"
-  min="1"
+  min="0"
  max="365"
  class="input"
/>
@@ -441,7 +441,7 @@ async function saveAllSettings() {
<input
  v-model.number="advancedSettings.data_retention.minute_metrics_retention_days"
  type="number"
-  min="1"
+  min="0"
  max="365"
  class="input"
/>
@@ -451,7 +451,7 @@ async function saveAllSettings() {
<input
  v-model.number="advancedSettings.data_retention.hourly_metrics_retention_days"
  type="number"
-  min="1"
+  min="0"
  max="365"
  class="input"
/>
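The widened bounds above (lower bound moved from 1 to 0 for all three retention settings) can be captured as one predicate. This is a sketch under my own assumptions — the function name is hypothetical, and the reading that 0 disables retention pruning is inferred from the change, not stated in the diff:

```typescript
// Hypothetical helper mirroring the new validation range for the three
// data-retention settings: 0..365 inclusive (0 presumably meaning
// "never prune"). The component itself only checks the bounds, so no
// integer check is added here either.
function isValidRetentionDays(days: number): boolean {
  return days >= 0 && days <= 365;
}
```

Centralizing the check like this would keep the three duplicated `if` blocks in `validation` from drifting apart the next time the range changes.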