Compare commits

...

23 Commits

Author SHA1 Message Date
Wesley Liddick
c129825f9b Merge pull request #2116 from KnowSky404/fix/openai-bulk-edit-compact-config
fix: add OpenAI compact bulk edit fields
2026-05-04 00:14:46 +08:00
Wesley Liddick
ff50b8b6ea Merge pull request #2170 from deqiying/fix/openai-ws-passthrough-reasoning-effort
fix(openai): fix WS passthrough usage records missing reasoning effort and User-Agent
2026-05-04 00:13:42 +08:00
shaw
4cbf518f0a fix: preserve raw chat completions usage billing 2026-05-03 23:31:43 +08:00
Wesley Liddick
dc09b367dc Merge pull request #2143 from alfadb/fix/openai-apikey-cc-default-routing
fix: Chat Completions routing fallback when an APIKey account's upstream does not support the OpenAI Responses API
2026-05-03 22:58:26 +08:00
deqiying
11fe29223d Merge branch 'main' into fix/openai-ws-passthrough-reasoning-effort 2026-05-03 22:18:46 +08:00
shaw
0b84d12dbb fix: correct affiliate audit record sources 2026-05-03 22:12:57 +08:00
Wesley Liddick
76e2503d5e Merge pull request #2169 from lyen1688/feat/admin-affiliate-records
feat: add admin affiliate rebate record pages
2026-05-03 21:10:33 +08:00
lyen1688
3ab40269b4 Improve display of rebate transfer-to-balance history 2026-05-03 20:33:14 +08:00
lyen1688
650ddb2e39 fix: make affiliate record users clickable 2026-05-03 20:33:14 +08:00
lyen1688
0a914e034c fix: include matured affiliate quota in admin overview 2026-05-03 20:33:13 +08:00
lyen1688
6a41cf6a51 feat: add admin affiliate record pages 2026-05-03 20:33:13 +08:00
deqiying
23555be380 fix(openai): fix WS passthrough usage records missing reasoning effort and User-Agent
- Backfill per-turn reasoning_effort metadata for the OpenAI Responses WebSocket v2 passthrough
- Pass through the first-frame model as seen before channel mapping, preserving the ability to derive reasoning effort from the model suffix
- Add an end-to-end usage log regression covering inbound User-Agent, explicit effort, and channel-mapping scenarios
2026-05-03 19:33:09 +08:00
shaw
47fb38bca1 fix: record zero OpenAI usage logs 2026-05-03 17:43:56 +08:00
shaw
72d5ee4cd1 fix: drain OpenAI compat streams for usage 2026-05-03 17:11:27 +08:00
shaw
b2bdba78dd stabilize image request handling 2026-05-03 14:56:09 +08:00
alfadb
3930bebaf9 Merge branch 'fix/raw-cc-upstream-endpoint-log' into fix/openai-apikey-cc-default-routing 2026-05-02 10:53:11 +08:00
alfadb
e736de1ed9 fix(handler): log correct upstream endpoint for raw CC path
DeriveUpstreamEndpoint hard-codes /v1/responses for PlatformOpenAI,
but APIKey accounts probed as not supporting the Responses API are forwarded
directly to /v1/chat/completions via forwardAsRawChatCompletions.

Add resolveRawCCUpstreamEndpoint which returns /v1/chat/completions
when the account's extra.openai_responses_supported is explicitly false.
2026-05-02 10:31:57 +08:00
alfadb
57099a6af6 fix(openai-gateway): extract reasoning_effort in raw Chat Completions path
The forwardAsRawChatCompletions path (used when APIKey accounts target
upstreams that don't support Responses API, e.g. DeepSeek) was missing
reasoning_effort and service_tier extraction, causing all reasoning
effort values to be silently dropped.

Extract both from the raw Chat Completions body before forwarding, and
propagate them through streamRawChatCompletions / bufferRawChatCompletions
to the OpenAIForwardResult.
2026-05-02 10:22:16 +08:00
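The missing extraction the commit describes can be sketched roughly like this. This is an illustrative sketch, not the repository's actual helper; the JSON field names follow the public OpenAI Chat Completions API:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// extractCCReasoningFields pulls reasoning_effort and service_tier out of a
// raw Chat Completions request body before forwarding. Illustrative only —
// the real code propagates these through streamRawChatCompletions /
// bufferRawChatCompletions to the OpenAIForwardResult.
func extractCCReasoningFields(body []byte) (reasoningEffort, serviceTier string) {
	var req struct {
		ReasoningEffort string `json:"reasoning_effort"`
		ServiceTier     string `json:"service_tier"`
	}
	if err := json.Unmarshal(body, &req); err != nil {
		return "", "" // malformed body: record nothing rather than fail the forward
	}
	return req.ReasoningEffort, req.ServiceTier
}

func main() {
	body := []byte(`{"model":"deepseek-reasoner","reasoning_effort":"high","service_tier":"default"}`)
	effort, tier := extractCCReasoningFields(body)
	fmt.Println(effort, tier) // high default
}
```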
alfadb
adf01ac880 fix(openai-gateway): address PR review — probe URL /v1 prefix, Create trigger, tests
Fix four issues flagged by copilot-pull-request-reviewer on PR #2143:

1. Probe URL missing /v1 prefix (openai_apikey_responses_probe.go)
   Replaced bare TrimSuffix + "/responses" with buildOpenAIResponsesURL(),
   which handles bare domain → /v1/responses correctly. Affected:
   - ProbeOpenAIAPIKeyResponsesSupport (probe URL)
   - TestAccount endpoint (apiURL for APIKey accounts)

2. Create endpoint not triggering probe (account_handler.go)
   Capture created account from idempotent closure and call
   scheduleOpenAIResponsesProbe after success, same pattern as
   BatchCreate and Update.

3. Tests (openai_gateway_chat_completions_raw_test.go)
   Added TestBuildOpenAIChatCompletionsURL (7 cases covering
   bare domain, /v1 suffix, trailing slash, third-party domains,
   whitespace) and TestBuildOpenAIResponsesURL_ProbeURL (6 cases
   locking the probe URL construction for bare-domain inputs).

All unit tests pass; go build ./cmd/server/ clean.
2026-04-30 21:46:46 +08:00
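The bare-domain normalization described in item 1 can be sketched as follows — a minimal approximation of what buildOpenAIResponsesURL is said to do (the repository's actual rules for third-party domains may differ):

```go
package main

import (
	"fmt"
	"strings"
)

// buildResponsesURL normalizes an account base_url into a Responses API URL:
// whitespace and trailing slashes are stripped, a bare domain gains a /v1
// prefix, and /responses is appended. Illustrative sketch of the fix's intent.
func buildResponsesURL(baseURL string) string {
	base := strings.TrimRight(strings.TrimSpace(baseURL), "/")
	if !strings.HasSuffix(base, "/v1") {
		base += "/v1"
	}
	return base + "/responses"
}

func main() {
	fmt.Println(buildResponsesURL("https://api.deepseek.com"))   // bare domain case
	fmt.Println(buildResponsesURL("https://api.openai.com/v1/")) // trailing slash case
}
```

The bug the commit fixes was the naive form of this: a bare `TrimSuffix` + `"/responses"` that skipped the `/v1` insertion for bare domains.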
alfadb-bot
4d145300c3 fixup! fix(openai-gateway): route APIKey accounts to /v1/chat/completions when upstream lacks Responses API
Address self-review findings:

R7: Use a narrow per-trust-domain header allowlist for CC raw forwarding.
The previously reused openaiAllowedHeaders contains Codex client-only headers
(originator/session_id/x-codex-turn-state/x-codex-turn-metadata/conversation_id)
that must not leak to third-party OpenAI-compatible upstreams (DeepSeek/Kimi/
GLM/Qwen). Strict upstreams may 400 with 'unknown parameter'; lenient ones
silently pollute their request statistics. New openaiCCRawAllowedHeaders only
allows generic HTTP headers (accept-language, user-agent); content-type/
authorization/accept are set explicitly by callers.

R4: Drop the dead includeUsage parameter from streamRawChatCompletions.
The CC pass-through path doesn't need to inspect the client's stream_options
flag — the upstream handles it and we only extract usage when it appears in
chunks. Killing the unused parameter removes a misleading 'parameter read
but discarded' code smell.

Sediment refs:
- pensieve/short-term/maxims/dont-reuse-shared-headers-whitelist-across-different-upstream-trust-domains
- pensieve/short-term/knowledge/openai-gateway-shared-state-quirks
- pensieve/short-term/pipelines/run-when-self-reviewing-forwarder-implementation
2026-04-30 20:16:44 +08:00
alfadb-bot
4e4cc80971 fix(openai-gateway): route APIKey accounts to /v1/chat/completions when upstream lacks Responses API
OpenAI APIKey accounts with base_url pointing to third-party OpenAI-compatible
upstreams (DeepSeek, Kimi, GLM, Qwen, etc.) were failing because the gateway
unconditionally converted Chat Completions requests to Responses format and
forwarded to {base_url}/v1/responses, which only exists on OpenAI's official
endpoint.

Detection-based routing:
- Probe upstream capability on account create/update via a minimal POST to
  /v1/responses; HTTP 404/405 means 'unsupported', any other response means
  'supported'.
- Persist result as accounts.extra.openai_responses_supported (bool).
- ForwardAsChatCompletions branches at function entry: APIKey accounts with
  explicit support=false go through the new forwardAsRawChatCompletions path,
  which forwards the CC body unchanged to /v1/chat/completions without
  protocol conversion.

Default behavior for accounts without the marker preserves the legacy
'always Responses' path — existing OpenAI APIKey accounts that were working
before this change continue to work without modification (the 'reality is
evidence' principle: an account that has been running implies upstream
capability).

Probe is fired async after Create / Update / BatchCreate; failures only log,
never block the admin flow. BulkUpdate is omitted (base_url changes there are
a low-signal event; the probe can be added if needed).

Implementation:
- New pkg internal/pkg/openai_compat: marker key + ShouldUseResponsesAPI
- New service file openai_apikey_responses_probe.go: probe + persist
- New service file openai_gateway_chat_completions_raw.go: CC pass-through
- Account test endpoint short-circuits with explicit message for
  probed-unsupported accounts (full CC test path is a TODO)

Zero schema changes, zero migrations, zero frontend changes, zero wire
modifications — all wired through existing AccountTestService injection.

Closes: DeepSeek-OpenAI account (id=128) production failure
2026-04-30 19:25:45 +08:00
github-actions[bot]
48912014a1 chore: sync VERSION to 0.1.121 [skip ci] 2026-04-30 06:06:12 +00:00
KnowSky404
3953dc9ce4 fix: add OpenAI compact bulk edit fields 2026-04-30 10:19:59 +08:00
58 changed files with 4940 additions and 270 deletions

View File

@@ -1 +1 @@
-0.1.120
+0.1.121

View File

@@ -528,6 +528,10 @@ func (h *AccountHandler) Create(c *gin.Context) {
// Decide whether to skip the mixed-channel check
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
// Capture the account created inside the closure so a successful create can trigger the async probe.
// On idempotent replay the closure does not run → createdAccount stays nil → no duplicate scheduling.
var createdAccount *service.Account
result, err := executeAdminIdempotent(c, "admin.accounts.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
account, execErr := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
Name: req.Name,
@@ -549,6 +553,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
if execErr != nil {
return nil, execErr
}
createdAccount = account
// Antigravity OAuth: set privacy directly for new accounts
h.adminService.ForceAntigravityPrivacy(ctx, account)
// OpenAI OAuth: set privacy directly for new accounts
@@ -577,6 +582,9 @@ func (h *AccountHandler) Create(c *gin.Context) {
if result != nil && result.Replayed {
c.Header("X-Idempotency-Replayed", "true")
}
// After an OpenAI APIKey account is created, asynchronously probe the upstream /v1/responses capability.
// A probe failure does not affect the account-creation response.
h.scheduleOpenAIResponsesProbe(createdAccount)
response.Success(c, result.Data)
}
@@ -637,9 +645,39 @@ func (h *AccountHandler) Update(c *gin.Context) {
return
}
// OpenAI APIKey: re-probe upstream capability after credentials change (base_url/api_key may have changed).
// Runs asynchronously; a probe failure does not affect the account-update response.
if len(req.Credentials) > 0 {
h.scheduleOpenAIResponsesProbe(account)
}
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// scheduleOpenAIResponsesProbe asynchronously triggers the Responses API capability probe for an OpenAI APIKey account.
//
// Only platform=openai && type=apikey accounts are affected; all others are a no-op.
// The probe itself runs in a goroutine (it makes one HTTP request to the upstream) and does not block
// the current request. Probe errors are only logged, never propagated: on failure the marker stays
// absent and the gateway defaults to the Responses path ("reality is evidence").
func (h *AccountHandler) scheduleOpenAIResponsesProbe(account *service.Account) {
if account == nil || account.Platform != service.PlatformOpenAI || account.Type != service.AccountTypeAPIKey {
return
}
if h.accountTestService == nil {
return
}
accountID := account.ID
go func() {
defer func() {
if r := recover(); r != nil {
slog.Error("openai_responses_probe_panic", "account_id", accountID, "recover", r)
}
}()
h.accountTestService.ProbeOpenAIAPIKeyResponsesSupport(context.Background(), accountID)
}()
}
// Delete handles deleting an account
// DELETE /api/v1/admin/accounts/:id
func (h *AccountHandler) Delete(c *gin.Context) {
@@ -1231,6 +1269,8 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
openaiPrivacyAccounts = append(openaiPrivacyAccounts, account)
}
}
// OpenAI APIKey accounts: asynchronously probe /v1/responses capability.
h.scheduleOpenAIResponsesProbe(account)
success++
results = append(results, gin.H{
"name": item.Name,

View File

@@ -2,8 +2,11 @@ package admin
import (
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -181,3 +184,108 @@ func (h *AffiliateHandler) LookupUsers(c *gin.Context) {
}
response.Success(c, result)
}
// GetUserOverview returns one user's affiliate overview.
// GET /api/v1/admin/affiliates/users/:user_id/overview
func (h *AffiliateHandler) GetUserOverview(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("user_id"), 10, 64)
if err != nil || userID <= 0 {
response.BadRequest(c, "Invalid user_id")
return
}
overview, err := h.affiliateService.AdminGetUserOverview(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, overview)
}
// ListInviteRecords returns all inviter-invitee relationships.
// GET /api/v1/admin/affiliates/invites
func (h *AffiliateHandler) ListInviteRecords(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
filter := parseAffiliateRecordFilter(c, page, pageSize)
items, total, err := h.affiliateService.AdminListInviteRecords(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, items, total, filter.Page, filter.PageSize)
}
// ListRebateRecords returns all order-level affiliate rebate records.
// GET /api/v1/admin/affiliates/rebates
func (h *AffiliateHandler) ListRebateRecords(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
filter := parseAffiliateRecordFilter(c, page, pageSize)
items, total, err := h.affiliateService.AdminListRebateRecords(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, items, total, filter.Page, filter.PageSize)
}
// ListTransferRecords returns all affiliate quota-to-balance transfer records.
// GET /api/v1/admin/affiliates/transfers
func (h *AffiliateHandler) ListTransferRecords(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
filter := parseAffiliateRecordFilter(c, page, pageSize)
items, total, err := h.affiliateService.AdminListTransferRecords(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, items, total, filter.Page, filter.PageSize)
}
func parseAffiliateRecordFilter(c *gin.Context, page, pageSize int) service.AffiliateRecordFilter {
filter := service.AffiliateRecordFilter{
Search: c.Query("search"),
Page: page,
PageSize: pageSize,
SortBy: c.Query("sort_by"),
SortDesc: c.Query("sort_order") != "asc",
}
if filter.PageSize > 100 {
filter.PageSize = 100
}
userTZ := c.Query("timezone")
if t := parseAffiliateRecordStartTime(c.Query("start_at"), userTZ); t != nil {
filter.StartAt = t
}
if t := parseAffiliateRecordEndTime(c.Query("end_at"), userTZ); t != nil {
filter.EndAt = t
}
return filter
}
func parseAffiliateRecordStartTime(raw string, userTZ string) *time.Time {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
if parsed, err := time.Parse(time.RFC3339, raw); err == nil {
return &parsed
}
if parsed, err := timezone.ParseInUserLocation("2006-01-02", raw, userTZ); err == nil {
return &parsed
}
return nil
}
func parseAffiliateRecordEndTime(raw string, userTZ string) *time.Time {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
if parsed, err := time.Parse(time.RFC3339, raw); err == nil {
return &parsed
}
if parsed, err := timezone.ParseInUserLocation("2006-01-02", raw, userTZ); err == nil {
end := parsed.AddDate(0, 0, 1).Add(-time.Nanosecond)
return &end
}
return nil
}

View File

@@ -390,7 +390,7 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) {
// GetBalanceHistory handles getting user's balance/concurrency change history
// GET /api/v1/admin/users/:id/balance-history
// Query params:
// - type: filter by record type (balance, affiliate_balance, admin_balance, concurrency, admin_concurrency, subscription)
func (h *UserHandler) GetBalanceHistory(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {

View File

@@ -10,6 +10,7 @@ import (
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -276,7 +277,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
Account: account,
Subscription: subscription,
InboundEndpoint: GetInboundEndpoint(c),
- UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
+ UpstreamEndpoint: resolveRawCCUpstreamEndpoint(c, account),
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
@@ -299,3 +300,16 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
return
}
}
// resolveRawCCUpstreamEndpoint returns the actual upstream endpoint for
// OpenAI Chat Completions requests. For APIKey accounts whose upstream
// has been probed to not support the Responses API, the request is
// forwarded directly to /v1/chat/completions — not through the default
// CC→Responses conversion path.
func resolveRawCCUpstreamEndpoint(c *gin.Context, account *service.Account) string {
if account != nil && account.Type == service.AccountTypeAPIKey &&
!openai_compat.ShouldUseResponsesAPI(account.Extra) {
return "/v1/chat/completions"
}
return GetUpstreamEndpoint(c, account.Platform)
}

View File

@@ -1233,6 +1233,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
)
hooks := &service.OpenAIWSIngressHooks{
InitialRequestModel: reqModel,
BeforeTurn: func(turn int) error {
if turn == 1 {
return nil

View File

@@ -10,6 +10,7 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -651,6 +652,46 @@ func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailu
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
}
func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) {
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"HIGH"}}`,
userAgent: testStringPtr("codex_cli_rs/0.125.0 test"),
})
require.NotNil(t, got.log.UserAgent)
require.Equal(t, "codex_cli_rs/0.125.0 test", *got.log.UserAgent)
require.NotNil(t, got.log.ReasoningEffort)
require.Equal(t, "high", *got.log.ReasoningEffort)
require.True(t, got.log.OpenAIWSMode)
}
func TestOpenAIResponsesWebSocket_PassthroughUsageLogInfersReasoningFromInitialRequestModel(t *testing.T) {
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
firstPayload: `{"type":"response.create","model":"gpt-5.4-xhigh","stream":false}`,
userAgent: testStringPtr("codex_cli_rs/0.125.0 mapped"),
channelMapping: map[string]string{
"gpt-5.4-xhigh": "gpt-5.4",
},
})
require.Equal(t, "gpt-5.4", gjson.GetBytes(got.upstreamFirstPayload, "model").String(),
"upstream first frame should use the channel-mapped model")
require.NotNil(t, got.log.ReasoningEffort)
require.Equal(t, "xhigh", *got.log.ReasoningEffort,
"usage log reasoning effort must be derived from the pre-channel-mapping first-frame model suffix")
}
func TestOpenAIResponsesWebSocket_PassthroughUsageLogLeavesUserAgentNilWhenMissing(t *testing.T) {
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"medium"}}`,
userAgent: testStringPtr(""),
})
require.Nil(t, got.log.UserAgent, "an empty inbound User-Agent must not fall back to the upstream handshake UA or a default UA")
require.NotNil(t, got.log.ReasoningEffort)
require.Equal(t, "medium", *got.log.ReasoningEffort)
}
func TestSetOpenAIClientTransportHTTP(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -796,3 +837,278 @@ func newOpenAIWSHandlerTestServer(t *testing.T, h *OpenAIGatewayHandler, subject
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
return httptest.NewServer(router)
}
type openAIResponsesWSUsageLogCase struct {
firstPayload string
userAgent *string
channelMapping map[string]string
}
type openAIResponsesWSUsageLogResult struct {
log *service.UsageLog
upstreamFirstPayload []byte
}
type openAIWSUsageHandlerAccountRepoStub struct {
service.AccountRepository
account service.Account
}
func (s *openAIWSUsageHandlerAccountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
if s.account.Platform != platform {
return nil, nil
}
return []service.Account{s.account}, nil
}
func (s *openAIWSUsageHandlerAccountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
return s.ListSchedulableByPlatform(ctx, platform)
}
func (s *openAIWSUsageHandlerAccountRepoStub) GetByID(ctx context.Context, id int64) (*service.Account, error) {
if s.account.ID != id {
return nil, nil
}
account := s.account
return &account, nil
}
type openAIWSUsageHandlerUsageLogRepoStub struct {
service.UsageLogRepository
created chan *service.UsageLog
}
func (s *openAIWSUsageHandlerUsageLogRepoStub) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
if s.created != nil {
s.created <- log
}
return true, nil
}
type openAIWSUsageHandlerChannelRepoStub struct {
service.ChannelRepository
channels []service.Channel
groupPlatforms map[int64]string
}
func (s *openAIWSUsageHandlerChannelRepoStub) ListAll(ctx context.Context) ([]service.Channel, error) {
return s.channels, nil
}
func (s *openAIWSUsageHandlerChannelRepoStub) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
out := make(map[int64]string, len(groupIDs))
for _, groupID := range groupIDs {
if platform := strings.TrimSpace(s.groupPlatforms[groupID]); platform != "" {
out[groupID] = platform
}
}
return out, nil
}
func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSUsageLogCase) openAIResponsesWSUsageLogResult {
t.Helper()
gin.SetMode(gin.TestMode)
upstreamPayloadCh := make(chan []byte, 1)
upstreamErrCh := make(chan error, 1)
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
CompressionMode: coderws.CompressionContextTakeover,
})
if err != nil {
upstreamErrCh <- err
return
}
defer func() {
_ = conn.CloseNow()
}()
readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
msgType, payload, readErr := conn.Read(readCtx)
cancelRead()
if readErr != nil {
upstreamErrCh <- readErr
return
}
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
upstreamErrCh <- errors.New("unexpected upstream websocket message type")
return
}
upstreamPayloadCh <- payload
writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
writeErr := conn.Write(writeCtx, coderws.MessageText, []byte(
`{"type":"response.completed","response":{"id":"resp_usage_e2e","model":"gpt-5.4","usage":{"input_tokens":2,"output_tokens":1}}}`,
))
cancelWrite()
if writeErr != nil {
upstreamErrCh <- writeErr
return
}
_ = conn.Close(coderws.StatusNormalClosure, "done")
upstreamErrCh <- nil
}))
defer upstreamServer.Close()
groupID := int64(4201)
account := service.Account{
ID: 9901,
Name: "openai-ws-passthrough-usage-e2e",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeAPIKey,
Status: service.StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": upstreamServer.URL,
},
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
"openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
},
}
cfg := &config.Config{}
cfg.RunMode = config.RunModeSimple
cfg.Default.RateMultiplier = 1
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
accountRepo := &openAIWSUsageHandlerAccountRepoStub{account: account}
usageRepo := &openAIWSUsageHandlerUsageLogRepoStub{created: make(chan *service.UsageLog, 1)}
var channelSvc *service.ChannelService
if len(tc.channelMapping) > 0 {
channelSvc = service.NewChannelService(&openAIWSUsageHandlerChannelRepoStub{
channels: []service.Channel{{
ID: 7701,
Name: "openai-ws-e2e-channel",
Status: service.StatusActive,
GroupIDs: []int64{groupID},
ModelMapping: map[string]map[string]string{service.PlatformOpenAI: tc.channelMapping},
}},
groupPlatforms: map[int64]string{groupID: service.PlatformOpenAI},
}, nil, nil, nil)
}
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
gatewaySvc := service.NewOpenAIGatewayService(
accountRepo,
usageRepo,
nil,
nil,
nil,
nil,
nil,
cfg,
nil,
nil,
service.NewBillingService(cfg, nil),
nil,
billingCacheSvc,
nil,
&service.DeferredService{},
nil,
nil,
channelSvc,
nil,
nil,
)
cache := &concurrencyCacheMock{
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
},
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
},
}
h := &OpenAIGatewayHandler{
gatewayService: gatewaySvc,
billingCacheService: billingCacheSvc,
apiKeyService: &service.APIKeyService{},
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
}
apiKey := &service.APIKey{
ID: 1801,
GroupID: &groupID,
User: &service.User{ID: 1701, Status: service.StatusActive},
}
router := gin.New()
router.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.User.ID, Concurrency: 1})
c.Next()
})
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
handlerServer := httptest.NewServer(router)
defer handlerServer.Close()
headers := http.Header{}
if tc.userAgent != nil {
headers.Set("User-Agent", *tc.userAgent)
}
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(
dialCtx,
"ws"+strings.TrimPrefix(handlerServer.URL, "http")+"/openai/v1/responses",
&coderws.DialOptions{HTTPHeader: headers, CompressionMode: coderws.CompressionContextTakeover},
)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(tc.firstPayload))
cancelWrite()
require.NoError(t, err)
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
_, event, err := clientConn.Read(readCtx)
cancelRead()
require.NoError(t, err)
require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
var usageLog *service.UsageLog
select {
case usageLog = <-usageRepo.created:
require.NotNil(t, usageLog)
case <-time.After(3 * time.Second):
t.Fatal("timed out waiting for the WebSocket usage log write")
}
var upstreamFirstPayload []byte
select {
case upstreamFirstPayload = <-upstreamPayloadCh:
case <-time.After(3 * time.Second):
t.Fatal("timed out waiting for the upstream WebSocket first frame")
}
select {
case upstreamErr := <-upstreamErrCh:
require.NoError(t, upstreamErr)
case <-time.After(3 * time.Second):
t.Fatal("timed out waiting for the upstream WebSocket to finish")
}
return openAIResponsesWSUsageLogResult{
log: usageLog,
upstreamFirstPayload: upstreamFirstPayload,
}
}
func testStringPtr(v string) *string {
return &v
}

View File

@@ -434,6 +434,45 @@ func TestStreamingTextOnly(t *testing.T) {
assert.Equal(t, "message_stop", events[1].Type)
}
func TestResponsesEventToAnthropicEvents_ResponseDone(t *testing.T) {
state := NewResponsesEventToAnthropicState()
state.Model = "gpt-4o"
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.done",
Response: &ResponsesResponse{
Status: "completed",
Usage: &ResponsesUsage{InputTokens: 12, OutputTokens: 4},
},
}, state)
require.Len(t, events, 2)
assert.Equal(t, "message_delta", events[0].Type)
assert.Equal(t, "end_turn", events[0].Delta.StopReason)
assert.Equal(t, 12, events[0].Usage.InputTokens)
assert.Equal(t, 4, events[0].Usage.OutputTokens)
assert.Equal(t, "message_stop", events[1].Type)
assert.Nil(t, FinalizeResponsesAnthropicStream(state))
}
func TestResponsesEventToAnthropicEvents_ResponseDoneIncomplete(t *testing.T) {
state := NewResponsesEventToAnthropicState()
state.Model = "gpt-4o"
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.done",
Response: &ResponsesResponse{
Status: "incomplete",
IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
Usage: &ResponsesUsage{InputTokens: 12, OutputTokens: 4},
},
}, state)
require.Len(t, events, 2)
assert.Equal(t, "message_delta", events[0].Type)
assert.Equal(t, "max_tokens", events[0].Delta.StopReason)
assert.Equal(t, "message_stop", events[1].Type)
assert.Nil(t, FinalizeResponsesAnthropicStream(state))
}
func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) {
state := NewResponsesEventToAnthropicState()
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{

View File

@@ -720,6 +720,49 @@ func TestResponsesEventToChatChunks_Completed(t *testing.T) {
assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens)
}
func TestResponsesEventToChatChunks_ResponseDone(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
state.IncludeUsage = true
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.done",
Response: &ResponsesResponse{
Status: "completed",
Usage: &ResponsesUsage{InputTokens: 13, OutputTokens: 7},
},
}, state)
require.Len(t, chunks, 2)
require.NotNil(t, chunks[0].Choices[0].FinishReason)
assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)
require.NotNil(t, chunks[1].Usage)
assert.Equal(t, 13, chunks[1].Usage.PromptTokens)
assert.Equal(t, 7, chunks[1].Usage.CompletionTokens)
assert.Nil(t, FinalizeResponsesChatStream(state))
}
func TestResponsesEventToChatChunks_ResponseDoneIncomplete(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
state.IncludeUsage = true
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.done",
Response: &ResponsesResponse{
Status: "incomplete",
IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
Usage: &ResponsesUsage{InputTokens: 13, OutputTokens: 7},
},
}, state)
require.Len(t, chunks, 2)
require.NotNil(t, chunks[0].Choices[0].FinishReason)
assert.Equal(t, "length", *chunks[0].Choices[0].FinishReason)
require.NotNil(t, chunks[1].Usage)
assert.Equal(t, 13, chunks[1].Usage.PromptTokens)
assert.Equal(t, 7, chunks[1].Usage.CompletionTokens)
assert.Nil(t, FinalizeResponsesChatStream(state))
}
func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"


@@ -212,7 +212,9 @@ func ResponsesEventToAnthropicEvents(
return resToAnthHandleReasoningDelta(evt, state)
case "response.reasoning_summary_text.done":
return resToAnthHandleBlockDone(state)
// response.done is the terminal-event alias used by the Realtime/WS and
// project passthrough paths; plain Responses HTTP SSE still terminates
// with the public response.completed event.
case "response.completed", "response.done", "response.incomplete", "response.failed":
return resToAnthHandleCompleted(evt, state)
default:
return nil


@@ -160,7 +160,9 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent
return resToChatHandleReasoningDelta(evt, state)
case "response.reasoning_summary_text.done":
return nil
// response.done is the terminal-event alias used by the Realtime/WS and
// project passthrough paths; plain Responses HTTP SSE still terminates
// with the public response.completed event.
case "response.completed", "response.done", "response.incomplete", "response.failed":
return resToChatHandleCompleted(evt, state)
default:
return nil


@@ -314,7 +314,7 @@ type ResponsesOutputTokensDetails struct {
type ResponsesStreamEvent struct {
Type string `json:"type"`
// response.created / response.completed / response.done / response.failed / response.incomplete
Response *ResponsesResponse `json:"response,omitempty"`
// response.output_item.added / response.output_item.done


@@ -0,0 +1,75 @@
// Package openai_compat provides capability-difference checks for the OpenAI
// protocol family across different upstreams.
//
// Background: sub2api's OpenAI APIKey accounts connect via base_url to a
// variety of third-party OpenAI-compatible upstreams (DeepSeek, Kimi, GLM,
// Qwen, etc.). These upstreams generally support only /v1/chat/completions
// and have no /v1/responses endpoint, yet the gateway's legacy code
// unconditionally ran the CC→Responses conversion and hit /v1/responses,
// causing 404s on compatible upstreams.
//
// This package provides a capability check based on a per-account probe flag,
// working together with internal/service/openai_apikey_responses_probe.go,
// which probes once when an account is created or modified and persists the
// result.
//
// Design trade-offs:
//   - No static host allowlist is maintained, so adding a new vendor never
//     requires a code change (discussion recorded in
//     pensieve/short-term/knowledge/upstream-capability-detection-design-tradeoffs).
//   - A missing flag defaults to true (i.e. "use Responses"), keeping existing
//     accounts on exactly the pre-refactor behavior (the "current behavior is
//     evidence" principle; see
//     pensieve/short-term/maxims/preserve-existing-runtime-behavior-when-replacing-logic-in-stateful-systems).
package openai_compat
// AccountResponsesSupport describes whether an account's upstream supports
// the OpenAI Responses API.
//
// It is only meaningful for platform=openai + type=apikey accounts; other
// account types should not call into this package.
type AccountResponsesSupport int
const (
// ResponsesSupportUnknown means the account has not been probed yet (the
// extra field is missing). The upstream routing layer should default to
// Responses per the "current behavior is evidence" principle, matching
// pre-refactor behavior.
ResponsesSupportUnknown AccountResponsesSupport = iota
// ResponsesSupportYes means the probe confirmed the upstream supports
// /v1/responses.
ResponsesSupportYes
// ResponsesSupportNo means the probe confirmed the upstream does not
// support /v1/responses; requests should take the direct
// /v1/chat/completions path.
ResponsesSupportNo
)
// ExtraKeyResponsesSupported is the key under which the probe result is
// stored in the accounts.extra JSON. The value is a bool: true = supported,
// false = unsupported, missing = not probed.
const ExtraKeyResponsesSupported = "openai_responses_supported"
// ResolveResponsesSupport reads the probe flag from an account's extra map.
//
// When the flag is missing or has the wrong type it returns
// ResponsesSupportUnknown; callers should treat "not probed" as "keep the old
// behavior, i.e. use Responses" (see ShouldUseResponsesAPI).
func ResolveResponsesSupport(extra map[string]any) AccountResponsesSupport {
if extra == nil {
return ResponsesSupportUnknown
}
v, ok := extra[ExtraKeyResponsesSupported]
if !ok {
return ResponsesSupportUnknown
}
supported, ok := v.(bool)
if !ok {
return ResponsesSupportUnknown
}
if supported {
return ResponsesSupportYes
}
return ResponsesSupportNo
}
// ShouldUseResponsesAPI reports whether an inbound /v1/chat/completions
// request on an OpenAI APIKey account should take the "CC→Responses
// conversion + upstream /v1/responses" path.
//
// It returns true in two cases:
//  1. The probe confirmed the account supports Responses.
//  2. The account has not been probed (flag missing); per the "current
//     behavior is evidence" principle the old behavior is preserved.
//
// It returns false only when the account was probed and confirmed not to
// support Responses; the caller should then take the direct CC path
// (see internal/service/openai_gateway_chat_completions_raw.go).
func ShouldUseResponsesAPI(extra map[string]any) bool {
return ResolveResponsesSupport(extra) != ResponsesSupportNo
}


@@ -0,0 +1,55 @@
package openai_compat
import "testing"
func TestResolveResponsesSupport(t *testing.T) {
tests := []struct {
name string
extra map[string]any
want AccountResponsesSupport
}{
{"nil extra", nil, ResponsesSupportUnknown},
{"empty extra", map[string]any{}, ResponsesSupportUnknown},
{"key missing", map[string]any{"other": "value"}, ResponsesSupportUnknown},
{"value true", map[string]any{ExtraKeyResponsesSupported: true}, ResponsesSupportYes},
{"value false", map[string]any{ExtraKeyResponsesSupported: false}, ResponsesSupportNo},
{"value wrong type string", map[string]any{ExtraKeyResponsesSupported: "true"}, ResponsesSupportUnknown},
{"value wrong type number", map[string]any{ExtraKeyResponsesSupported: 1}, ResponsesSupportUnknown},
{"value nil", map[string]any{ExtraKeyResponsesSupported: nil}, ResponsesSupportUnknown},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := ResolveResponsesSupport(tc.extra)
if got != tc.want {
t.Errorf("ResolveResponsesSupport(%v) = %v, want %v", tc.extra, got, tc.want)
}
})
}
}
func TestShouldUseResponsesAPI(t *testing.T) {
tests := []struct {
name string
extra map[string]any
want bool
}{
// Key invariant: a not-yet-probed account must return true (preserve the old behavior).
{"unknown defaults to true (preserve old behavior)", nil, true},
{"unknown empty defaults to true", map[string]any{}, true},
{"unknown wrong type defaults to true", map[string]any{ExtraKeyResponsesSupported: "yes"}, true},
// Probed: the flag decides.
{"explicitly supported", map[string]any{ExtraKeyResponsesSupported: true}, true},
{"explicitly unsupported", map[string]any{ExtraKeyResponsesSupported: false}, false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := ShouldUseResponsesAPI(tc.extra)
if got != tc.want {
t.Errorf("ShouldUseResponsesAPI(%v) = %v, want %v", tc.extra, got, tc.want)
}
})
}
}


@@ -22,6 +22,34 @@ const (
var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789")
const affiliateUserOverviewSQL = `
SELECT ua.user_id,
COALESCE(u.email, ''),
COALESCE(u.username, ''),
ua.aff_code,
COALESCE(ua.aff_rebate_rate_percent, 0)::double precision,
(ua.aff_rebate_rate_percent IS NOT NULL) AS has_custom_rate,
ua.aff_count,
COALESCE(rebated.rebated_invitee_count, 0),
(ua.aff_quota + COALESCE(matured.matured_frozen_quota, 0))::double precision,
ua.aff_history_quota::double precision
FROM user_affiliates ua
JOIN users u ON u.id = ua.user_id
LEFT JOIN (
SELECT user_id, COUNT(DISTINCT source_user_id)::integer AS rebated_invitee_count
FROM user_affiliate_ledger
WHERE action = 'accrue' AND source_user_id IS NOT NULL
GROUP BY user_id
) rebated ON rebated.user_id = ua.user_id
LEFT JOIN (
SELECT user_id, COALESCE(SUM(amount), 0)::double precision AS matured_frozen_quota
FROM user_affiliate_ledger
WHERE action = 'accrue' AND frozen_until IS NOT NULL AND frozen_until <= NOW()
GROUP BY user_id
) matured ON matured.user_id = ua.user_id
WHERE ua.user_id = $1
LIMIT 1`
type affiliateQueryExecer interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
@@ -86,7 +114,7 @@ func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID
return bound, nil
}
func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int, sourceOrderID *int64) (bool, error) {
if amount <= 0 {
return false, nil
}
@@ -112,15 +140,15 @@ func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, invite
if freezeHours > 0 {
if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, source_order_id, frozen_until, created_at, updated_at)
VALUES ($1, 'accrue', $2, $3, $4, NOW() + make_interval(hours => $5), NOW(), NOW())`,
inviterID, amount, inviteeUserID, nullableInt64Arg(sourceOrderID), freezeHours); err != nil {
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
}
} else {
if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, source_order_id, created_at, updated_at)
VALUES ($1, 'accrue', $2, $3, $4, NOW(), NOW())`, inviterID, amount, inviteeUserID, nullableInt64Arg(sourceOrderID)); err != nil {
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
}
}
@@ -275,9 +303,32 @@ FROM cleared`, userID)
return err
}
snapshot, err := queryAffiliateTransferSnapshot(txCtx, txClient, userID)
if err != nil {
return err
}
if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (
user_id,
action,
amount,
source_user_id,
balance_after,
aff_quota_after,
aff_frozen_quota_after,
aff_history_quota_after,
created_at,
updated_at
)
VALUES ($1, 'transfer', $2, NULL, $3, $4, $5, $6, NOW(), NOW())`,
userID,
transferred,
snapshot.BalanceAfter,
snapshot.AvailableQuotaAfter,
snapshot.FrozenQuotaAfter,
snapshot.HistoryQuotaAfter,
); err != nil {
return fmt.Errorf("insert affiliate transfer ledger: %w", err)
}
@@ -332,6 +383,349 @@ LIMIT $2`, inviterID, limit)
return invitees, nil
}
func (r *affiliateRepository) ListAffiliateInviteRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateInviteRecord, int64, error) {
client := clientFromContext(ctx, r.client)
where, args := buildAffiliateRecordWhere(filter, "ua.created_at", []string{
"inviter.email", "inviter.username", "invitee.email", "invitee.username",
"ua.inviter_id::text", "ua.user_id::text", "inviter_aff.aff_code",
})
total, err := queryAffiliateRecordCount(ctx, client, `
SELECT COUNT(*)
FROM user_affiliates ua
JOIN users invitee ON invitee.id = ua.user_id
JOIN users inviter ON inviter.id = ua.inviter_id
JOIN user_affiliates inviter_aff ON inviter_aff.user_id = ua.inviter_id
`+where, args...)
if err != nil {
return nil, 0, err
}
orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{
"inviter": "inviter.email",
"invitee": "invitee.email",
"aff_code": "inviter_aff.aff_code",
"total_rebate": "total_rebate",
"created_at": "ua.created_at",
}, "ua.created_at")
args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize)
rows, err := client.QueryContext(ctx, `
SELECT ua.inviter_id,
COALESCE(inviter.email, ''),
COALESCE(inviter.username, ''),
ua.user_id,
COALESCE(invitee.email, ''),
COALESCE(invitee.username, ''),
COALESCE(inviter_aff.aff_code, ''),
COALESCE(SUM(ual.amount), 0)::double precision AS total_rebate,
ua.created_at
FROM user_affiliates ua
JOIN users invitee ON invitee.id = ua.user_id
JOIN users inviter ON inviter.id = ua.inviter_id
JOIN user_affiliates inviter_aff ON inviter_aff.user_id = ua.inviter_id
LEFT JOIN user_affiliate_ledger ual
ON ual.user_id = ua.inviter_id
AND ual.source_user_id = ua.user_id
AND ual.action = 'accrue'
`+where+`
GROUP BY ua.inviter_id, inviter.email, inviter.username, ua.user_id, invitee.email, invitee.username, inviter_aff.aff_code, ua.created_at
`+orderBy+`
LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...)
if err != nil {
return nil, 0, err
}
defer func() { _ = rows.Close() }()
items := make([]service.AffiliateInviteRecord, 0)
for rows.Next() {
var item service.AffiliateInviteRecord
if err := rows.Scan(
&item.InviterID,
&item.InviterEmail,
&item.InviterUsername,
&item.InviteeID,
&item.InviteeEmail,
&item.InviteeUsername,
&item.AffCode,
&item.TotalRebate,
&item.CreatedAt,
); err != nil {
return nil, 0, err
}
items = append(items, item)
}
if err := rows.Err(); err != nil {
return nil, 0, err
}
return items, total, nil
}
func (r *affiliateRepository) ListAffiliateRebateRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateRebateRecord, int64, error) {
client := clientFromContext(ctx, r.client)
where, args := buildAffiliateRecordWhere(filter, "ual.created_at", []string{
"inviter.email", "inviter.username", "invitee.email", "invitee.username",
"po.id::text", "po.out_trade_no", "po.payment_type", "po.status",
})
baseJoin := `
FROM user_affiliate_ledger ual
JOIN payment_orders po ON po.id = ual.source_order_id
JOIN users invitee ON invitee.id = ual.source_user_id
JOIN users inviter ON inviter.id = ual.user_id
WHERE ual.action = 'accrue'
AND ual.source_order_id IS NOT NULL`
if where != "" {
where = strings.Replace(where, "WHERE ", " AND ", 1)
}
total, err := queryAffiliateRecordCount(ctx, client, "SELECT COUNT(*) "+baseJoin+where, args...)
if err != nil {
return nil, 0, err
}
orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{
"order": "po.id",
"inviter": "inviter.email",
"invitee": "invitee.email",
"order_amount": "po.amount",
"pay_amount": "po.pay_amount",
"rebate_amount": "ual.amount",
"payment_type": "po.payment_type",
"order_status": "po.status",
"created_at": "ual.created_at",
}, "ual.created_at")
args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize)
rows, err := client.QueryContext(ctx, `
SELECT po.id,
po.out_trade_no,
ual.user_id,
COALESCE(inviter.email, ''),
COALESCE(inviter.username, ''),
ual.source_user_id,
COALESCE(invitee.email, ''),
COALESCE(invitee.username, ''),
po.amount::double precision,
po.pay_amount::double precision,
ual.amount::double precision,
po.payment_type,
po.status,
ual.created_at
`+baseJoin+where+`
`+orderBy+`
LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...)
if err != nil {
return nil, 0, err
}
defer func() { _ = rows.Close() }()
items := make([]service.AffiliateRebateRecord, 0)
for rows.Next() {
var item service.AffiliateRebateRecord
if err := rows.Scan(
&item.OrderID,
&item.OutTradeNo,
&item.InviterID,
&item.InviterEmail,
&item.InviterUsername,
&item.InviteeID,
&item.InviteeEmail,
&item.InviteeUsername,
&item.OrderAmount,
&item.PayAmount,
&item.RebateAmount,
&item.PaymentType,
&item.OrderStatus,
&item.CreatedAt,
); err != nil {
return nil, 0, err
}
items = append(items, item)
}
if err := rows.Err(); err != nil {
return nil, 0, err
}
return items, total, nil
}
func (r *affiliateRepository) ListAffiliateTransferRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateTransferRecord, int64, error) {
client := clientFromContext(ctx, r.client)
where, args := buildAffiliateRecordWhere(filter, "ual.created_at", []string{
"u.email", "u.username", "u.id::text",
})
baseJoin := `
FROM user_affiliate_ledger ual
JOIN users u ON u.id = ual.user_id
WHERE ual.action = 'transfer'`
if where != "" {
where = strings.Replace(where, "WHERE ", " AND ", 1)
}
total, err := queryAffiliateRecordCount(ctx, client, "SELECT COUNT(*) "+baseJoin+where, args...)
if err != nil {
return nil, 0, err
}
orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{
"user": "u.email",
"amount": "ual.amount",
"balance_after": "ual.balance_after",
"available_quota_after": "ual.aff_quota_after",
"frozen_quota_after": "ual.aff_frozen_quota_after",
"history_quota_after": "ual.aff_history_quota_after",
"created_at": "ual.created_at",
}, "ual.created_at")
args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize)
rows, err := client.QueryContext(ctx, `
SELECT ual.id,
ual.user_id,
COALESCE(u.email, ''),
COALESCE(u.username, ''),
ual.amount::double precision,
ual.balance_after::double precision,
ual.aff_quota_after::double precision,
ual.aff_frozen_quota_after::double precision,
ual.aff_history_quota_after::double precision,
ual.created_at
`+baseJoin+where+`
`+orderBy+`
LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...)
if err != nil {
return nil, 0, err
}
defer func() { _ = rows.Close() }()
items := make([]service.AffiliateTransferRecord, 0)
for rows.Next() {
var item service.AffiliateTransferRecord
var balanceAfter sql.NullFloat64
var availableQuotaAfter sql.NullFloat64
var frozenQuotaAfter sql.NullFloat64
var historyQuotaAfter sql.NullFloat64
if err := rows.Scan(
&item.LedgerID,
&item.UserID,
&item.UserEmail,
&item.Username,
&item.Amount,
&balanceAfter,
&availableQuotaAfter,
&frozenQuotaAfter,
&historyQuotaAfter,
&item.CreatedAt,
); err != nil {
return nil, 0, err
}
item.BalanceAfter = nullableFloat64Ptr(balanceAfter)
item.AvailableQuotaAfter = nullableFloat64Ptr(availableQuotaAfter)
item.FrozenQuotaAfter = nullableFloat64Ptr(frozenQuotaAfter)
item.HistoryQuotaAfter = nullableFloat64Ptr(historyQuotaAfter)
item.SnapshotAvailable = balanceAfter.Valid &&
availableQuotaAfter.Valid &&
frozenQuotaAfter.Valid &&
historyQuotaAfter.Valid
items = append(items, item)
}
if err := rows.Err(); err != nil {
return nil, 0, err
}
return items, total, nil
}
func (r *affiliateRepository) GetAffiliateUserOverview(ctx context.Context, userID int64) (*service.AffiliateUserOverview, error) {
if userID <= 0 {
return nil, service.ErrUserNotFound
}
client := clientFromContext(ctx, r.client)
rows, err := client.QueryContext(ctx, affiliateUserOverviewSQL, userID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
if err := rows.Err(); err != nil {
return nil, err
}
return nil, service.ErrUserNotFound
}
var overview service.AffiliateUserOverview
var customRate float64
var hasCustomRate bool
if err := rows.Scan(
&overview.UserID,
&overview.Email,
&overview.Username,
&overview.AffCode,
&customRate,
&hasCustomRate,
&overview.InvitedCount,
&overview.RebatedInviteeCount,
&overview.AvailableQuota,
&overview.HistoryQuota,
); err != nil {
return nil, err
}
if hasCustomRate {
overview.RebateRatePercent = customRate
overview.RebateRateCustom = true
}
return &overview, rows.Err()
}
func buildAffiliateRecordWhere(filter service.AffiliateRecordFilter, timeColumn string, searchColumns []string) (string, []any) {
clauses := make([]string, 0, 3)
args := make([]any, 0, 3)
if filter.StartAt != nil {
args = append(args, *filter.StartAt)
clauses = append(clauses, fmt.Sprintf("%s >= $%d", timeColumn, len(args)))
}
if filter.EndAt != nil {
args = append(args, *filter.EndAt)
clauses = append(clauses, fmt.Sprintf("%s <= $%d", timeColumn, len(args)))
}
search := strings.TrimSpace(filter.Search)
if search != "" && len(searchColumns) > 0 {
args = append(args, "%"+strings.ToLower(search)+"%")
parts := make([]string, 0, len(searchColumns))
for _, col := range searchColumns {
parts = append(parts, fmt.Sprintf("LOWER(%s) LIKE $%d", col, len(args)))
}
clauses = append(clauses, "("+strings.Join(parts, " OR ")+")")
}
if len(clauses) == 0 {
return "", args
}
return "WHERE " + strings.Join(clauses, " AND "), args
}
func buildAffiliateRecordOrderBy(filter service.AffiliateRecordFilter, sortColumns map[string]string, fallbackColumn string) string {
column := sortColumns[filter.SortBy]
if column == "" {
column = fallbackColumn
}
direction := "DESC"
if !filter.SortDesc {
direction = "ASC"
}
return "ORDER BY " + column + " " + direction + " NULLS LAST"
}
func queryAffiliateRecordCount(ctx context.Context, client affiliateQueryExecer, query string, args ...any) (int64, error) {
rows, err := client.QueryContext(ctx, query, args...)
if err != nil {
return 0, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
return 0, rows.Err()
}
var total int64
if err := rows.Scan(&total); err != nil {
return 0, err
}
return total, rows.Err()
}
func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error {
if tx := dbent.TxFromContext(ctx); tx != nil {
return fn(ctx, tx.Client())
@@ -516,6 +910,54 @@ func queryUserBalance(ctx context.Context, client affiliateQueryExecer, userID i
return balance, nil
}
type affiliateTransferSnapshot struct {
BalanceAfter float64
AvailableQuotaAfter float64
FrozenQuotaAfter float64
HistoryQuotaAfter float64
}
func queryAffiliateTransferSnapshot(ctx context.Context, client affiliateQueryExecer, userID int64) (*affiliateTransferSnapshot, error) {
rows, err := client.QueryContext(ctx, `
SELECT u.balance::double precision,
ua.aff_quota::double precision,
ua.aff_frozen_quota::double precision,
ua.aff_history_quota::double precision
FROM users u
JOIN user_affiliates ua ON ua.user_id = u.id
WHERE u.id = $1
LIMIT 1`, userID)
if err != nil {
return nil, fmt.Errorf("query affiliate transfer snapshot: %w", err)
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
if err := rows.Err(); err != nil {
return nil, err
}
return nil, service.ErrUserNotFound
}
var snapshot affiliateTransferSnapshot
if err := rows.Scan(
&snapshot.BalanceAfter,
&snapshot.AvailableQuotaAfter,
&snapshot.FrozenQuotaAfter,
&snapshot.HistoryQuotaAfter,
); err != nil {
return nil, err
}
return &snapshot, rows.Err()
}
func nullableFloat64Ptr(v sql.NullFloat64) *float64 {
if !v.Valid {
return nil
}
return &v.Float64
}
func generateAffiliateCode() (string, error) {
buf := make([]byte, affiliateCodeLength)
if _, err := rand.Read(buf); err != nil {
@@ -674,6 +1116,13 @@ func nullableArg(v *float64) any {
return *v
}
func nullableInt64Arg(v *int64) any {
if v == nil {
return nil
}
return *v
}
// ListUsersWithCustomSettings lists users that have dedicated settings (a
// custom code or a dedicated rebate rate).
//
// A single query handles both "no search" and "fuzzy search by email/username".


@@ -78,6 +78,26 @@ VALUES ($1, $2, $3, $3, NOW(), NOW())`, u.ID, affCode, 12.34)
ledgerCount := querySingleInt(t, txCtx, client,
"SELECT COUNT(*) FROM user_affiliate_ledger WHERE user_id = $1 AND action = 'transfer'", u.ID)
require.Equal(t, 1, ledgerCount)
rows, err := client.QueryContext(txCtx, `
SELECT amount::double precision,
balance_after::double precision,
aff_quota_after::double precision,
aff_frozen_quota_after::double precision,
aff_history_quota_after::double precision
FROM user_affiliate_ledger
WHERE user_id = $1 AND action = 'transfer'
LIMIT 1`, u.ID)
require.NoError(t, err)
defer func() { _ = rows.Close() }()
require.True(t, rows.Next(), "expected transfer ledger")
var amount, balanceAfter, quotaAfter, frozenAfter, historyAfter float64
require.NoError(t, rows.Scan(&amount, &balanceAfter, &quotaAfter, &frozenAfter, &historyAfter))
require.InDelta(t, 12.34, amount, 1e-9)
require.InDelta(t, 17.84, balanceAfter, 1e-9)
require.InDelta(t, 0.0, quotaAfter, 1e-9)
require.InDelta(t, 0.0, frozenAfter, 1e-9)
require.InDelta(t, 12.34, historyAfter, 1e-9)
}
// TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction guards the
@@ -125,7 +145,7 @@ func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) {
require.NoError(t, err)
require.True(t, bound, "invitee must bind to inviter")
applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0, nil)
require.NoError(t, err)
require.True(t, applied, "AccrueQuota must report applied=true")


@@ -0,0 +1,28 @@
package repository
import (
"os"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestAffiliateUserOverviewSQLIncludesMaturedFrozenQuota(t *testing.T) {
query := strings.Join(strings.Fields(affiliateUserOverviewSQL), " ")
require.Contains(t, query, "ua.aff_quota + COALESCE(matured.matured_frozen_quota, 0)")
require.Contains(t, query, "frozen_until <= NOW()")
}
func TestAffiliateRecordQueriesUseLedgerAuditFields(t *testing.T) {
source, err := os.ReadFile("affiliate_repo.go")
require.NoError(t, err)
content := string(source)
require.Contains(t, content, "JOIN payment_orders po ON po.id = ual.source_order_id")
require.Contains(t, content, "ual.amount::double precision")
require.Contains(t, content, "ual.balance_after::double precision")
require.NotContains(t, content, "parseAffiliateRebateAmount")
require.NotContains(t, content, `"current_balance": "u.balance"`)
}


@@ -602,11 +602,16 @@ func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
func registerAffiliateRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
affiliates := admin.Group("/affiliates")
{
affiliates.GET("/invites", h.Admin.Affiliate.ListInviteRecords)
affiliates.GET("/rebates", h.Admin.Affiliate.ListRebateRecords)
affiliates.GET("/transfers", h.Admin.Affiliate.ListTransferRecords)
users := affiliates.Group("/users")
{
users.GET("", h.Admin.Affiliate.ListUsers)
users.GET("/lookup", h.Admin.Affiliate.LookupUsers)
users.POST("/batch-rate", h.Admin.Affiliate.BatchSetRate)
users.GET("/:user_id/overview", h.Admin.Affiliate.GetUserOverview)
users.PUT("/:user_id", h.Admin.Affiliate.UpdateUserSettings)
users.DELETE("/:user_id", h.Admin.Affiliate.ClearUserSettings)
}

View File

@@ -21,6 +21,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -554,7 +555,16 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
// When the account has been probed as not supporting Responses (e.g.
// DeepSeek/Kimi), surface a clear message. The account itself is usable
// (the gateway falls back to the direct CC path); only this test entry
// point still lacks CC SSE handling.
// TODO: implement a CC-format account test path (needs a dedicated CC SSE handler).
if !openai_compat.ShouldUseResponsesAPI(account.Extra) {
return s.sendErrorAndEnd(c,
"This account was probed as not supporting the OpenAI Responses API (e.g. DeepSeek/Kimi and other third-party compatible upstreams). "+
"The account itself works normally, but the current test endpoint only supports the Responses API path; please verify it through a real API call instead.",
)
}
apiURL = buildOpenAIResponsesURL(normalizedBaseURL)
} else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}


@@ -0,0 +1,86 @@
package service
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
func TestMergeBalanceHistoryCodesIncludesAffiliateTransfersByDefault(t *testing.T) {
t.Parallel()
now := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC)
older := now.Add(-2 * time.Hour)
newer := now.Add(time.Hour)
usedBy := int64(10)
redeemCodes := []RedeemCode{
{
ID: 1,
Type: RedeemTypeBalance,
Value: 8,
Status: StatusUsed,
UsedBy: &usedBy,
UsedAt: &now,
CreatedAt: now,
},
{
ID: 2,
Type: RedeemTypeConcurrency,
Value: 1,
Status: StatusUsed,
UsedBy: &usedBy,
UsedAt: &older,
CreatedAt: older,
},
}
affiliateCodes := []RedeemCode{
{
ID: -20,
Type: RedeemTypeAffiliateBalance,
Value: 3.5,
Status: StatusUsed,
UsedBy: &usedBy,
UsedAt: &newer,
CreatedAt: newer,
},
}
got := mergeBalanceHistoryCodes(redeemCodes, affiliateCodes, pagination.PaginationParams{
Page: 1,
PageSize: 2,
})
require.Len(t, got, 2)
require.Equal(t, RedeemTypeAffiliateBalance, got[0].Type)
require.Equal(t, RedeemTypeBalance, got[1].Type)
}
func TestMergeBalanceHistoryCodesPaginatesAfterCombiningSources(t *testing.T) {
t.Parallel()
base := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC)
usedBy := int64(10)
at := func(hours int) *time.Time {
v := base.Add(time.Duration(hours) * time.Hour)
return &v
}
got := mergeBalanceHistoryCodes(
[]RedeemCode{
{ID: 1, Type: RedeemTypeBalance, UsedBy: &usedBy, UsedAt: at(4), CreatedAt: *at(4)},
{ID: 2, Type: RedeemTypeConcurrency, UsedBy: &usedBy, UsedAt: at(2), CreatedAt: *at(2)},
},
[]RedeemCode{
{ID: -3, Type: RedeemTypeAffiliateBalance, UsedBy: &usedBy, UsedAt: at(3), CreatedAt: *at(3)},
{ID: -4, Type: RedeemTypeAffiliateBalance, UsedBy: &usedBy, UsedAt: at(1), CreatedAt: *at(1)},
},
pagination.PaginationParams{Page: 2, PageSize: 2},
)
require.Len(t, got, 2)
require.Equal(t, RedeemTypeConcurrency, got[0].Type)
require.Equal(t, int64(-4), got[1].ID)
}


@@ -2,6 +2,7 @@ package service
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
@@ -973,16 +974,213 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
if codeType == RedeemTypeAffiliateBalance {
codes, total, err := s.listAffiliateBalanceHistory(ctx, userID, params)
if err != nil {
return nil, 0, 0, err
}
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
if err != nil {
return nil, 0, 0, err
}
return codes, total, totalRecharged, nil
}
if codeType == "" {
return s.getAllUserBalanceHistory(ctx, userID, params)
}
codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, codeType)
if err != nil {
return nil, 0, 0, err
}
total := result.Total
// Aggregate total recharged amount (only once, regardless of type filter)
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
if err != nil {
return nil, 0, 0, err
}
return codes, total, totalRecharged, nil
}
func (s *adminServiceImpl) getAllUserBalanceHistory(ctx context.Context, userID int64, params pagination.PaginationParams) ([]RedeemCode, int64, float64, error) {
needed := params.Offset() + params.Limit()
if needed < params.Limit() {
needed = params.Limit()
}
redeemCodes, redeemTotal, err := s.listRedeemBalanceHistoryForMerge(ctx, userID, needed)
if err != nil {
return nil, 0, 0, err
}
affiliateCodes, affiliateTotal, err := s.listAffiliateBalanceHistoryForMerge(ctx, userID, needed)
if err != nil {
return nil, 0, 0, err
}
codes := mergeBalanceHistoryCodes(redeemCodes, affiliateCodes, params)
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
if err != nil {
return nil, 0, 0, err
}
return codes, redeemTotal + affiliateTotal, totalRecharged, nil
}
func (s *adminServiceImpl) listRedeemBalanceHistoryForMerge(ctx context.Context, userID int64, needed int) ([]RedeemCode, int64, error) {
if needed <= 0 {
return nil, 0, nil
}
var (
out []RedeemCode
total int64
)
for page := 1; len(out) < needed; page++ {
params := pagination.PaginationParams{Page: page, PageSize: 1000}
codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, "")
if err != nil {
return nil, 0, err
}
if result != nil {
total = result.Total
}
out = append(out, codes...)
if len(codes) < params.Limit() || int64(len(out)) >= total {
break
}
}
if len(out) > needed {
out = out[:needed]
}
return out, total, nil
}
func (s *adminServiceImpl) listAffiliateBalanceHistoryForMerge(ctx context.Context, userID int64, needed int) ([]RedeemCode, int64, error) {
if needed <= 0 {
return nil, 0, nil
}
var (
out []RedeemCode
total int64
)
for page := 1; len(out) < needed; page++ {
params := pagination.PaginationParams{Page: page, PageSize: 1000}
codes, currentTotal, err := s.listAffiliateBalanceHistory(ctx, userID, params)
if err != nil {
return nil, 0, err
}
total = currentTotal
out = append(out, codes...)
if len(codes) < params.Limit() || int64(len(out)) >= total {
break
}
}
if len(out) > needed {
out = out[:needed]
}
return out, total, nil
}
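The fetch-until-needed loop above pulls fixed-size pages from the repository until enough rows are buffered for the requested window. A minimal sketch of the same pattern, with a hypothetical in-memory `fetchPage` standing in for the repository call:

```go
package main

import "fmt"

// fetchPage is a stub data source standing in for the repository call:
// it returns the slice for 1-based page p of size limit, plus the total count.
func fetchPage(all []int, p, limit int) ([]int, int) {
	start := (p - 1) * limit
	if start >= len(all) {
		return nil, len(all)
	}
	end := start + limit
	if end > len(all) {
		end = len(all)
	}
	return all[start:end], len(all)
}

// collect mirrors the loop above: pull fixed-size pages until `needed`
// items are buffered or the source is exhausted, then trim the excess.
func collect(all []int, needed, pageSize int) ([]int, int) {
	var out []int
	var total int
	for page := 1; len(out) < needed; page++ {
		items, t := fetchPage(all, page, pageSize)
		total = t
		out = append(out, items...)
		if len(items) < pageSize || len(out) >= total {
			break
		}
	}
	if len(out) > needed {
		out = out[:needed]
	}
	return out, total
}

func main() {
	got, total := collect([]int{1, 2, 3, 4, 5, 6, 7}, 5, 3)
	fmt.Println(got, total) // [1 2 3 4 5] 7
}
```

The short-circuit on `len(items) < pageSize` avoids an extra round trip once the source runs dry, matching the production loop.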
func (s *adminServiceImpl) listAffiliateBalanceHistory(ctx context.Context, userID int64, params pagination.PaginationParams) ([]RedeemCode, int64, error) {
if s == nil || s.entClient == nil || userID <= 0 {
return nil, 0, nil
}
rows, err := s.entClient.QueryContext(ctx, `
SELECT id,
amount::double precision,
created_at
FROM user_affiliate_ledger
WHERE user_id = $1
AND action = 'transfer'
ORDER BY created_at DESC, id DESC
OFFSET $2
LIMIT $3`, userID, params.Offset(), params.Limit())
if err != nil {
return nil, 0, err
}
defer func() { _ = rows.Close() }()
codes := make([]RedeemCode, 0, params.Limit())
for rows.Next() {
var id int64
var amount float64
var createdAt time.Time
if err := rows.Scan(&id, &amount, &createdAt); err != nil {
return nil, 0, err
}
usedBy := userID
usedAt := createdAt
codes = append(codes, RedeemCode{
ID: -id,
Code: fmt.Sprintf("AFF-%d", id),
Type: RedeemTypeAffiliateBalance,
Value: amount,
Status: StatusUsed,
UsedBy: &usedBy,
UsedAt: &usedAt,
CreatedAt: createdAt,
})
}
if err := rows.Err(); err != nil {
return nil, 0, err
}
total, err := countAffiliateBalanceHistory(ctx, s.entClient, userID)
if err != nil {
return nil, 0, err
}
return codes, total, nil
}
func countAffiliateBalanceHistory(ctx context.Context, client *dbent.Client, userID int64) (int64, error) {
rows, err := client.QueryContext(ctx, `
SELECT COUNT(*)
FROM user_affiliate_ledger
WHERE user_id = $1
AND action = 'transfer'`, userID)
if err != nil {
return 0, err
}
defer func() { _ = rows.Close() }()
var total sql.NullInt64
if rows.Next() {
if err := rows.Scan(&total); err != nil {
return 0, err
}
}
if err := rows.Err(); err != nil {
return 0, err
}
if !total.Valid {
return 0, nil
}
return total.Int64, nil
}
func mergeBalanceHistoryCodes(redeemCodes, affiliateCodes []RedeemCode, params pagination.PaginationParams) []RedeemCode {
combined := append(append([]RedeemCode{}, redeemCodes...), affiliateCodes...)
sort.SliceStable(combined, func(i, j int) bool {
return redeemCodeHistoryTime(combined[i]).After(redeemCodeHistoryTime(combined[j]))
})
offset := params.Offset()
if offset >= len(combined) {
return []RedeemCode{}
}
end := offset + params.Limit()
if end > len(combined) {
end = len(combined)
}
return combined[offset:end]
}
func redeemCodeHistoryTime(code RedeemCode) time.Time {
if code.UsedAt != nil {
return *code.UsedAt
}
return code.CreatedAt
}
func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) {

View File

@@ -98,7 +98,7 @@ type AffiliateRepository interface {
EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error)
GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error)
BindInviter(ctx context.Context, userID, inviterID int64) (bool, error)
AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int, sourceOrderID *int64) (bool, error)
GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error)
ThawFrozenQuota(ctx context.Context, userID int64) (float64, error)
TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error)
@@ -110,6 +110,10 @@ type AffiliateRepository interface {
SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error
BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error
ListUsersWithCustomSettings(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error)
ListAffiliateInviteRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateInviteRecord, int64, error)
ListAffiliateRebateRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateRebateRecord, int64, error)
ListAffiliateTransferRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateTransferRecord, int64, error)
GetAffiliateUserOverview(ctx context.Context, userID int64) (*AffiliateUserOverview, error)
}
// AffiliateAdminFilter holds the list filter criteria
@@ -130,6 +134,76 @@ type AffiliateAdminEntry struct {
AffCount int `json:"aff_count"`
}
type AffiliateRecordFilter struct {
Search string
Page int
PageSize int
StartAt *time.Time
EndAt *time.Time
SortBy string
SortDesc bool
}
type AffiliateInviteRecord struct {
InviterID int64 `json:"inviter_id"`
InviterEmail string `json:"inviter_email"`
InviterUsername string `json:"inviter_username"`
InviteeID int64 `json:"invitee_id"`
InviteeEmail string `json:"invitee_email"`
InviteeUsername string `json:"invitee_username"`
AffCode string `json:"aff_code"`
TotalRebate float64 `json:"total_rebate"`
CreatedAt time.Time `json:"created_at"`
}
type AffiliateRebateRecord struct {
OrderID int64 `json:"order_id"`
OutTradeNo string `json:"out_trade_no"`
InviterID int64 `json:"inviter_id"`
InviterEmail string `json:"inviter_email"`
InviterUsername string `json:"inviter_username"`
InviteeID int64 `json:"invitee_id"`
InviteeEmail string `json:"invitee_email"`
InviteeUsername string `json:"invitee_username"`
OrderAmount float64 `json:"order_amount"`
PayAmount float64 `json:"pay_amount"`
RebateAmount float64 `json:"rebate_amount"`
PaymentType string `json:"payment_type"`
OrderStatus string `json:"order_status"`
CreatedAt time.Time `json:"created_at"`
}
type AffiliateTransferRecord struct {
LedgerID int64 `json:"ledger_id"`
UserID int64 `json:"user_id"`
UserEmail string `json:"user_email"`
Username string `json:"username"`
Amount float64 `json:"amount"`
BalanceAfter *float64 `json:"balance_after,omitempty"`
AvailableQuotaAfter *float64 `json:"available_quota_after,omitempty"`
FrozenQuotaAfter *float64 `json:"frozen_quota_after,omitempty"`
HistoryQuotaAfter *float64 `json:"history_quota_after,omitempty"`
SnapshotAvailable bool `json:"snapshot_available"`
CurrentBalance float64 `json:"-"`
RemainingQuota float64 `json:"-"`
FrozenQuota float64 `json:"-"`
HistoryQuota float64 `json:"-"`
CreatedAt time.Time `json:"created_at"`
}
type AffiliateUserOverview struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
Username string `json:"username"`
AffCode string `json:"aff_code"`
RebateRatePercent float64 `json:"rebate_rate_percent"`
RebateRateCustom bool `json:"-"`
InvitedCount int `json:"invited_count"`
RebatedInviteeCount int `json:"rebated_invitee_count"`
AvailableQuota float64 `json:"available_quota"`
HistoryQuota float64 `json:"history_quota"`
}
type AffiliateService struct {
repo AffiliateRepository
settingService *SettingService
@@ -238,6 +312,10 @@ func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64,
}
func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64) (float64, error) {
return s.AccrueInviteRebateForOrder(ctx, inviteeUserID, baseRechargeAmount, nil)
}
func (s *AffiliateService) AccrueInviteRebateForOrder(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64, sourceOrderID *int64) (float64, error) {
if s == nil || s.repo == nil {
return 0, nil
}
@@ -298,7 +376,7 @@ func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID
freezeHours = s.settingService.GetAffiliateRebateFreezeHours(ctx)
}
applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours, sourceOrderID)
if err != nil {
return 0, err
}
@@ -488,3 +566,59 @@ func (s *AffiliateService) AdminListCustomUsers(ctx context.Context, filter Affi
}
return s.repo.ListUsersWithCustomSettings(ctx, filter)
}
func (s *AffiliateService) AdminListInviteRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateInviteRecord, int64, error) {
if s == nil || s.repo == nil {
return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
return s.repo.ListAffiliateInviteRecords(ctx, normalizeAffiliateRecordFilter(filter))
}
func (s *AffiliateService) AdminListRebateRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateRebateRecord, int64, error) {
if s == nil || s.repo == nil {
return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
return s.repo.ListAffiliateRebateRecords(ctx, normalizeAffiliateRecordFilter(filter))
}
func (s *AffiliateService) AdminListTransferRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateTransferRecord, int64, error) {
if s == nil || s.repo == nil {
return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
return s.repo.ListAffiliateTransferRecords(ctx, normalizeAffiliateRecordFilter(filter))
}
func (s *AffiliateService) AdminGetUserOverview(ctx context.Context, userID int64) (*AffiliateUserOverview, error) {
if userID <= 0 {
return nil, infraerrors.BadRequest("INVALID_USER", "invalid user")
}
if s == nil || s.repo == nil {
return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
overview, err := s.repo.GetAffiliateUserOverview(ctx, userID)
if err != nil {
return nil, err
}
if overview != nil {
if !overview.RebateRateCustom {
overview.RebateRatePercent = s.globalRebateRatePercent(ctx)
}
overview.RebateRatePercent = clampAffiliateRebateRate(overview.RebateRatePercent)
}
return overview, nil
}
func normalizeAffiliateRecordFilter(filter AffiliateRecordFilter) AffiliateRecordFilter {
if filter.Page <= 0 {
filter.Page = 1
}
if filter.PageSize <= 0 {
filter.PageSize = 20
}
if filter.PageSize > 100 {
filter.PageSize = 100
}
filter.Search = strings.TrimSpace(filter.Search)
filter.SortBy = strings.TrimSpace(filter.SortBy)
return filter
}
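The clamping in `normalizeAffiliateRecordFilter` can be exercised standalone; the `filter` struct below is a hypothetical reduction carrying only the normalized fields:

```go
package main

import (
	"fmt"
	"strings"
)

// filter is a simplified stand-in for AffiliateRecordFilter.
type filter struct {
	Search   string
	Page     int
	PageSize int
}

// normalize mirrors the clamping above: page defaults to 1, pageSize
// defaults to 20 and is capped at 100, and search is whitespace-trimmed.
func normalize(f filter) filter {
	if f.Page <= 0 {
		f.Page = 1
	}
	if f.PageSize <= 0 {
		f.PageSize = 20
	}
	if f.PageSize > 100 {
		f.PageSize = 100
	}
	f.Search = strings.TrimSpace(f.Search)
	return f
}

func main() {
	fmt.Println(normalize(filter{Search: "  alice ", Page: 0, PageSize: 500}))
	// {alice 1 100}
}
```

Normalizing at the service boundary keeps repository implementations free of defensive pagination checks.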

View File

@@ -51,10 +51,11 @@ const (
// Redeem type constants
const (
RedeemTypeBalance = domain.RedeemTypeBalance
RedeemTypeConcurrency = domain.RedeemTypeConcurrency
RedeemTypeSubscription = domain.RedeemTypeSubscription
RedeemTypeInvitation = domain.RedeemTypeInvitation
RedeemTypeAffiliateBalance = "affiliate_balance"
)
// PromoCode status constants

View File

@@ -8174,9 +8174,16 @@ func detachedBillingContext(ctx context.Context) (context.Context, context.Cance
}
func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
if ctx == nil {
return context.Background(), func() {}
}
if !stream {
return ctx, func() {}
}
return context.WithoutCancel(ctx), func() {}
}
func detachUpstreamContext(ctx context.Context) (context.Context, context.CancelFunc) {
if ctx == nil {
return context.Background(), func() {}
}

View File

@@ -13,6 +13,8 @@ import (
"github.com/stretchr/testify/require"
)
type upstreamContextTestKey string
func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
@@ -50,3 +52,14 @@ func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testi
require.Equal(t, 3, result.usage.InputTokens)
require.Equal(t, 7, result.usage.OutputTokens)
}
func TestDetachUpstreamContextIgnoresClientCancel(t *testing.T) {
parent, cancel := context.WithCancel(context.WithValue(context.Background(), upstreamContextTestKey("test-key"), "test-value"))
upstreamCtx, release := detachUpstreamContext(parent)
defer release()
cancel()
require.NoError(t, upstreamCtx.Err())
require.Equal(t, "test-value", upstreamCtx.Value(upstreamContextTestKey("test-key")))
}

View File

@@ -0,0 +1,149 @@
package service
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
)
// openaiResponsesProbeTimeout is the timeout for the probe request.
// The probe must fail fast: a timeout must not block account create/update flows.
const openaiResponsesProbeTimeout = 8 * time.Second
// openaiResponsesProbePayload is the minimal Responses request body used for probing.
// It is a capability check only (response quality is irrelevant), and Stream=false avoids SSE parsing overhead.
//
// Note: the probe's goal is to distinguish "endpoint exists" from "endpoint does not exist".
// Any non-404 4xx/5xx (e.g. 400 invalid_request_error / 401 unauthorized / 422)
// counts as "endpoint exists, so Responses is supported". Only 404 / 405 mean "endpoint does not exist".
func openaiResponsesProbePayload(modelID string) []byte {
if strings.TrimSpace(modelID) == "" {
modelID = openai.DefaultTestModel
}
body, _ := json.Marshal(map[string]any{
"model": modelID,
"input": []map[string]any{
{
"role": "user",
"content": []map[string]any{
{"type": "input_text", "text": "hi"},
},
},
},
"instructions": openai.DefaultInstructions,
"stream": false,
})
return body
}
// ProbeOpenAIAPIKeyResponsesSupport probes whether the upstream of an OpenAI APIKey
// account supports the /v1/responses endpoint, and persists the result to
// accounts.extra.openai_responses_supported.
//
// Called after account create/update, and only when platform=openai && type=apikey.
//
// Probe strategy (see the package docs in internal/pkg/openai_compat):
// - upstream 404 / 405 -> unsupported, write false
// - upstream 2xx / other 4xx (401/422/400, etc.) / 5xx -> supported, write true
// - network-level failure (connection error, timeout) -> write no flag, stay unknown
//   (subsequent requests still default to Responses under "status quo as evidence")
//
// The method is idempotent: repeated calls overwrite the flag with the latest probe result.
//
// On failure handling: a failed probe must not block account creation; a successful
// create/update is enough, and the probe result only affects later routing optimization.
// All errors are logged only and never propagated to the caller.
func (s *AccountTestService) ProbeOpenAIAPIKeyResponsesSupport(ctx context.Context, accountID int64) {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
logger.LegacyPrintf("service.openai_probe", "probe_load_account_failed: account_id=%d err=%v", accountID, err)
return
}
if account.Platform != PlatformOpenAI || account.Type != AccountTypeAPIKey {
// Only OpenAI APIKey accounts need probing; other account types have no capability differences.
return
}
apiKey := account.GetOpenAIApiKey()
if apiKey == "" {
logger.LegacyPrintf("service.openai_probe", "probe_skip_no_apikey: account_id=%d", accountID)
return
}
baseURL := account.GetOpenAIBaseURL()
if baseURL == "" {
baseURL = "https://api.openai.com"
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
logger.LegacyPrintf("service.openai_probe", "probe_invalid_baseurl: account_id=%d base_url=%q err=%v", accountID, baseURL, err)
return
}
probeURL := buildOpenAIResponsesURL(normalizedBaseURL)
probeCtx, cancel := context.WithTimeout(ctx, openaiResponsesProbeTimeout)
defer cancel()
req, err := http.NewRequestWithContext(probeCtx, http.MethodPost, probeURL, bytes.NewReader(openaiResponsesProbePayload("")))
if err != nil {
logger.LegacyPrintf("service.openai_probe", "probe_build_request_failed: account_id=%d err=%v", accountID, err)
return
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("Accept", "application/json")
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if err != nil {
// Network-level failure: write no flag and stay unknown; retry next time or let the gateway fallback handle it
logger.LegacyPrintf("service.openai_probe", "probe_request_failed: account_id=%d url=%s err=%v", accountID, probeURL, err)
return
}
defer func() {
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 1<<20))
_ = resp.Body.Close()
}()
supported := isResponsesEndpointSupportedByStatus(resp.StatusCode)
if err := s.accountRepo.UpdateExtra(ctx, accountID, map[string]any{
openai_compat.ExtraKeyResponsesSupported: supported,
}); err != nil {
logger.LegacyPrintf("service.openai_probe", "probe_persist_failed: account_id=%d supported=%v err=%v", accountID, supported, err)
return
}
logger.LegacyPrintf("service.openai_probe",
"probe_done: account_id=%d base_url=%s status=%d supported=%v",
accountID, normalizedBaseURL, resp.StatusCode, supported,
)
}
// isResponsesEndpointSupportedByStatus decides, from the probe response's HTTP status
// code, whether the upstream exposes the /v1/responses endpoint.
//
// Key observation: third-party OpenAI-compatible upstreams (DeepSeek/Kimi, etc.) uniformly
// return 404 or 405 for unknown endpoints, while OpenAI proper (or any upstream with a
// Responses implementation) returns a business error such as 400/422 because the probe
// body is minimal (missing fields), even though the endpoint itself exists.
//
// Therefore: only 404 and 405 mean "endpoint does not exist"; any other status means "endpoint exists".
//
// 5xx also counts as "endpoint exists": a transient upstream failure must not be misjudged as unsupported.
func isResponsesEndpointSupportedByStatus(status int) bool {
switch status {
case http.StatusNotFound, http.StatusMethodNotAllowed:
return false
}
return true
}

View File

@@ -3,13 +3,16 @@ package service
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
@@ -18,6 +21,51 @@ import (
"github.com/tidwall/gjson"
)
type openAICompatFailingWriter struct {
gin.ResponseWriter
failAfter int
writes int
}
func (w *openAICompatFailingWriter) Write(p []byte) (int, error) {
if w.writes >= w.failAfter {
return 0, errors.New("write failed: client disconnected")
}
w.writes++
return w.ResponseWriter.Write(p)
}
type openAICompatBlockingReadCloser struct {
data []byte
offset int
closed chan struct{}
closeOnce sync.Once
}
func newOpenAICompatBlockingReadCloser(data []byte) *openAICompatBlockingReadCloser {
return &openAICompatBlockingReadCloser{
data: data,
closed: make(chan struct{}),
}
}
func (r *openAICompatBlockingReadCloser) Read(p []byte) (int, error) {
if r.offset < len(r.data) {
n := copy(p, r.data[r.offset:])
r.offset += n
return n, nil
}
<-r.closed
return 0, io.EOF
}
func (r *openAICompatBlockingReadCloser) Close() error {
r.closeOnce.Do(func() {
close(r.closed)
})
return nil
}
func TestNormalizeOpenAICompatRequestedModel(t *testing.T) {
t.Parallel()
@@ -228,3 +276,242 @@ func TestForwardAsAnthropic_ForcedCodexInstructionsTemplateUsesCachedTemplateCon
require.NotNil(t, result)
require.Equal(t, "cached-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String())
}
func TestForwardAsAnthropic_ClientDisconnectDrainsUpstreamUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Writer = &openAICompatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := strings.Join([]string{
`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`,
"",
`data: {"type":"response.output_text.delta","delta":"ok"}`,
"",
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":9,"output_tokens":4,"total_tokens":13,"input_tokens_details":{"cached_tokens":3}}}}`,
"",
"data: [DONE]",
"",
}, "\n")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_disconnect"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 9, result.Usage.InputTokens)
require.Equal(t, 4, result.Usage.OutputTokens)
require.Equal(t, 3, result.Usage.CacheReadInputTokens)
}
func TestForwardAsAnthropic_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Writer = &openAICompatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n")
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
defer func() {
require.NoError(t, upstreamStream.Close())
}()
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_terminal_no_close"}},
Body: upstreamStream,
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
type forwardResult struct {
result *OpenAIForwardResult
err error
}
resultCh := make(chan forwardResult, 1)
go func() {
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
resultCh <- forwardResult{result: result, err: err}
}()
select {
case got := <-resultCh:
require.NoError(t, got.err)
require.NotNil(t, got.result)
require.Equal(t, 15, got.result.Usage.InputTokens)
require.Equal(t, 6, got.result.Usage.OutputTokens)
require.Equal(t, 5, got.result.Usage.CacheReadInputTokens)
case <-time.After(time.Second):
require.Fail(t, "ForwardAsAnthropic should return after terminal usage event even if upstream keeps the connection open")
}
}
func TestForwardAsAnthropic_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n")
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
defer func() {
require.NoError(t, upstreamStream.Close())
}()
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_buffered_terminal_no_close"}},
Body: upstreamStream,
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
type forwardResult struct {
result *OpenAIForwardResult
err error
}
resultCh := make(chan forwardResult, 1)
go func() {
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
resultCh <- forwardResult{result: result, err: err}
}()
select {
case got := <-resultCh:
require.NoError(t, got.err)
require.NotNil(t, got.result)
require.Equal(t, 15, got.result.Usage.InputTokens)
require.Equal(t, 6, got.result.Usage.OutputTokens)
require.Equal(t, 5, got.result.Usage.CacheReadInputTokens)
require.Contains(t, rec.Body.String(), `"stop_reason":"end_turn"`)
case <-time.After(time.Second):
require.Fail(t, "ForwardAsAnthropic buffered response should return after terminal usage event even if upstream keeps the connection open")
}
}
func TestForwardAsAnthropic_DoneSentinelWithoutTerminalReturnsError(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := "data: [DONE]\n\n"
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_missing_terminal"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
require.Error(t, err)
require.Contains(t, err.Error(), "missing terminal event")
require.NotNil(t, result)
require.Zero(t, result.Usage.InputTokens)
require.Zero(t, result.Usage.OutputTokens)
}
func TestForwardAsAnthropic_UpstreamRequestIgnoresClientCancel(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
reqCtx, cancel := context.WithCancel(context.Background())
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)).WithContext(reqCtx)
c.Request.Header.Set("Content-Type", "application/json")
cancel()
upstreamBody := strings.Join([]string{
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
"",
"data: [DONE]",
"",
}, "\n")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_ctx"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
result, err := svc.ForwardAsAnthropic(reqCtx, c, account, body, "", "gpt-5.1")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, upstream.lastReq)
require.NoError(t, upstream.lastReq.Context().Err())
}

View File

@@ -972,6 +972,62 @@ func TestPassthroughBilling_MultiTurnServiceTierFollowsFilteredFrames(t *testing
"turn 3: response.create without service_tier overwrites billing to nil to match upstream default") "turn 3: response.create without service_tier overwrites billing to nil to match upstream default")
} }
func TestPassthroughUsageMeta_TracksReasoningEffortAcrossTurns(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
firstFrame := []byte(`{"type":"response.create","model":"gpt-5.5","reasoning":{"effort":"medium"},"service_tier":"priority"}`)
meta := newOpenAIWSPassthroughUsageMeta("", firstFrame)
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstFrame)
firstOut, firstBlocked, firstErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, capturedSessionModel, firstFrame)
require.NoError(t, firstErr)
require.Nil(t, firstBlocked)
meta.initFromFirstFrame(firstOut)
require.NotNil(t, meta.reasoningEffort.Load())
require.Equal(t, "medium", *meta.reasoningEffort.Load())
process := func(payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
capturedSessionModel = updated
}
meta.updateSessionRequestModel(payload)
requestModelForThisFrame := meta.requestModelForFrame(payload)
model := openAIWSPassthroughPolicyModelForFrame(account, payload)
if model == "" {
model = capturedSessionModel
}
out, blocked, policyErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
if policyErr == nil && blocked == nil &&
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
meta.updateFromResponseCreate(out, requestModelForThisFrame)
}
return out, blocked, policyErr
}
_, blockedSession, errSession := process([]byte(`{"type":"session.update","session":{"model":"gpt-5-high"}}`))
require.NoError(t, errSession)
require.Nil(t, blockedSession)
require.NotNil(t, meta.reasoningEffort.Load())
require.Equal(t, "medium", *meta.reasoningEffort.Load(), "session.update 只刷新后续 fallback model不覆盖当前 turn metadata")
_, blockedCancel, errCancel := process([]byte(`{"type":"response.cancel","reasoning_effort":"x-high"}`))
require.NoError(t, errCancel)
require.Nil(t, blockedCancel)
require.NotNil(t, meta.reasoningEffort.Load())
require.Equal(t, "medium", *meta.reasoningEffort.Load(), "非 response.create 帧不能污染当前 turn metadata")
_, blockedFlat, errFlat := process([]byte(`{"type":"response.create","reasoning_effort":"x-high"}`))
require.NoError(t, errFlat)
require.Nil(t, blockedFlat)
require.NotNil(t, meta.reasoningEffort.Load())
require.Equal(t, "xhigh", *meta.reasoningEffort.Load(), "flat reasoning_effort 必须进入 passthrough usage metadata")
_, blockedClear, errClear := process([]byte(`{"type":"response.create","model":"gpt-4o"}`))
require.NoError(t, errClear)
require.Nil(t, blockedClear)
require.Nil(t, meta.reasoningEffort.Load(), "新的 response.create 无 effort 且无可推导后缀时必须清空旧值")
}
// TestPassthroughBilling_BlockedFrameDoesNotMutateServiceTier locks in the
// "block keeps previous" semantic: when policy returns block on a
// response.create frame, that frame is never sent upstream, so billing tier

View File

@@ -20,20 +20,29 @@ func (s *openAI403CounterResetStub) ResetOpenAI403Count(_ context.Context, accou
return nil
}
-func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterBeforeZeroUsageReturn(t *testing.T) {
func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterForZeroUsage(t *testing.T) {
counter := &openAI403CounterResetStub{}
rateLimitSvc := NewRateLimitService(nil, nil, nil, nil, nil)
rateLimitSvc.SetOpenAI403CounterCache(counter)
-svc := &OpenAIGatewayService{
-rateLimitService: rateLimitSvc,
-}
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
svc.rateLimitService = rateLimitSvc
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
-Result: &OpenAIForwardResult{},
Result: &OpenAIForwardResult{
RequestID: "resp_zero_usage_reset_403",
Model: "gpt-5.1",
},
APIKey: &APIKey{ID: 1001, Group: &Group{RateMultiplier: 1}},
User: &User{ID: 2001},
Account: &Account{ID: 777, Platform: PlatformOpenAI},
})
require.NoError(t, err)
require.Equal(t, []int64{777}, counter.resetCalls)
require.Equal(t, 1, usageRepo.calls)
}

View File

@@ -10,10 +10,12 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
@@ -39,9 +41,18 @@ var cursorResponsesUnsupportedFields = []string{
// ForwardAsChatCompletions accepts a Chat Completions request body, converts it
// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts
-// the response back to Chat Completions format. All account types (OAuth and API
-// Key) go through the Responses API conversion path since the upstream only
-// exposes the /v1/responses endpoint.
// the response back to Chat Completions format.
//
// History: this function originally sent every OpenAI account through the
// CC→Responses conversion and the /v1/responses endpoint. That was correct for
// OAuth (the ChatGPT internal API only supports Responses) and for official
// API-key accounts, but the assumption broke once sub2api added third-party
// OpenAI-compatible upstreams such as DeepSeek/Kimi/GLM: those upstreams
// generally support only /v1/chat/completions and have no /v1/responses endpoint.
//
// Current routing policy (driven by the per-account probe flag; see
// openai_compat.ShouldUseResponsesAPI):
//   - API-key account + probe confirmed Responses unsupported → use
//     forwardAsRawChatCompletions to hit the upstream /v1/chat/completions
//     directly, with no protocol conversion
//   - everything else (OAuth, API key with probe-confirmed support, or not yet
//     probed) → keep the original CC→Responses conversion path (old behavior
//     preserved; zero compatibility break for existing unprobed accounts)
func (s *OpenAIGatewayService) ForwardAsChatCompletions(
ctx context.Context,
c *gin.Context,
@@ -50,6 +61,12 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
promptCacheKey string,
defaultMappedModel string,
) (*OpenAIForwardResult, error) {
// Entry dispatch: an API-key account that has been probed and confirmed not to
// support the Responses API goes through the raw CC passthrough. A missing
// flag (not yet probed) keeps taking the original Responses conversion path
// below, on the "current behavior is the evidence" principle.
if account.Type == AccountTypeAPIKey && !openai_compat.ShouldUseResponsesAPI(account.Extra) {
return s.forwardAsRawChatCompletions(ctx, c, account, body, defaultMappedModel)
}
startTime := time.Now()
// 1. Parse Chat Completions request
@@ -189,7 +206,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
}
// 6. Build upstream request
-upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false)
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, promptCacheKey, false)
releaseUpstreamCtx()
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
@@ -348,59 +367,9 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
-scanner := bufio.NewScanner(resp.Body)
-maxLineSize := defaultMaxLineSize
-if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
-maxLineSize = s.cfg.Gateway.MaxLineSize
-}
-scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
-var finalResponse *apicompat.ResponsesResponse
-var usage OpenAIUsage
-acc := apicompat.NewBufferedResponseAccumulator()
-for scanner.Scan() {
-line := scanner.Text()
-if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
-continue
-}
-payload := line[6:]
-var event apicompat.ResponsesStreamEvent
-if err := json.Unmarshal([]byte(payload), &event); err != nil {
-logger.L().Warn("openai chat_completions buffered: failed to parse event",
-zap.Error(err),
-zap.String("request_id", requestID),
-)
-continue
-}
-// Accumulate delta content for fallback when terminal output is empty.
-acc.ProcessEvent(&event)
-if (event.Type == "response.completed" || event.Type == "response.done" ||
-event.Type == "response.incomplete" || event.Type == "response.failed") &&
-event.Response != nil {
-finalResponse = event.Response
-if event.Response.Usage != nil {
-usage = OpenAIUsage{
-InputTokens: event.Response.Usage.InputTokens,
-OutputTokens: event.Response.Usage.OutputTokens,
-}
-if event.Response.Usage.InputTokensDetails != nil {
-usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
-}
-}
-}
-}
-if err := scanner.Err(); err != nil {
-if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
-logger.L().Warn("openai chat_completions buffered: read error",
-zap.Error(err),
-zap.String("request_id", requestID),
-)
-}
-}
finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai chat_completions buffered", requestID)
if err != nil {
return nil, err
}
if finalResponse == nil {
@@ -459,6 +428,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
var usage OpenAIUsage
var firstTokenMs *int
firstChunk := true
clientDisconnected := false
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
@@ -467,6 +437,20 @@
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
resultWithUsage := func() *OpenAIForwardResult {
return &OpenAIForwardResult{
RequestID: requestID,
@@ -496,54 +480,66 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
return false
}
-// Extract usage from completion events
-if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
-event.Response != nil && event.Response.Usage != nil {
-usage = OpenAIUsage{
-InputTokens: event.Response.Usage.InputTokens,
-OutputTokens: event.Response.Usage.OutputTokens,
-}
-if event.Response.Usage.InputTokensDetails != nil {
-usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
-}
-}
// Only extract usage on terminal events the compat converter supports, to
// avoid silently widening the event semantics.
isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type)
if isTerminalEvent && event.Response != nil && event.Response.Usage != nil {
usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
}
chunks := apicompat.ResponsesEventToChatChunks(&event, state)
-for _, chunk := range chunks {
-sse, err := apicompat.ChatChunkToSSE(chunk)
-if err != nil {
-logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
-zap.Error(err),
-zap.String("request_id", requestID),
-)
-continue
-}
-if _, err := fmt.Fprint(c.Writer, sse); err != nil {
-logger.L().Info("openai chat_completions stream: client disconnected",
-zap.String("request_id", requestID),
-)
-return true
-}
-}
-if len(chunks) > 0 {
-c.Writer.Flush()
-}
-return false
if !clientDisconnected {
for _, chunk := range chunks {
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
clientDisconnected = true
logger.L().Info("openai chat_completions stream: client disconnected, continuing to drain upstream for billing",
zap.String("request_id", requestID),
)
break
}
}
}
if len(chunks) > 0 && !clientDisconnected {
c.Writer.Flush()
}
return isTerminalEvent
}
finalizeStream := func() (*OpenAIForwardResult, error) {
-if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 {
if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 && !clientDisconnected {
for _, chunk := range finalChunks {
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
continue
}
-fmt.Fprint(c.Writer, sse) //nolint:errcheck
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
clientDisconnected = true
logger.L().Info("openai chat_completions stream: client disconnected during final flush",
zap.String("request_id", requestID),
)
break
}
}
}
// Send [DONE] sentinel
-fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck
-c.Writer.Flush()
if !clientDisconnected {
if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil {
clientDisconnected = true
logger.L().Info("openai chat_completions stream: client disconnected during done flush",
zap.String("request_id", requestID),
)
}
}
if !clientDisconnected {
c.Writer.Flush()
}
return resultWithUsage(), nil
}
@@ -555,6 +551,9 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
)
}
}
missingTerminalErr := func() (*OpenAIForwardResult, error) {
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
}
// Determine keepalive interval
keepaliveInterval := time.Duration(0)
@@ -563,18 +562,25 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
}
// No keepalive: fast synchronous path
-if keepaliveInterval <= 0 {
if streamInterval <= 0 && keepaliveInterval <= 0 {
for scanner.Scan() {
line := scanner.Text()
-if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
payload, ok := extractOpenAISSEDataLine(line)
if !ok {
continue
}
-if processDataLine(line[6:]) {
-return resultWithUsage(), nil
if strings.TrimSpace(payload) == "[DONE]" {
return missingTerminalErr()
}
if processDataLine(payload) {
return finalizeStream()
}
}
-handleScanErr(scanner.Err())
-return finalizeStream()
if err := scanner.Err(); err != nil {
handleScanErr(err)
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
}
return missingTerminalErr()
}
// With keepalive: goroutine + channel + select // With keepalive: goroutine + channel + select
@@ -584,6 +590,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
}
events := make(chan scanEvent, 16)
done := make(chan struct{})
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
@@ -595,6 +603,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
@@ -605,30 +614,59 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
}()
defer close(done)
-keepaliveTicker := time.NewTicker(keepaliveInterval)
-defer keepaliveTicker.Stop()
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now()
for {
select {
case ev, ok := <-events:
if !ok {
-return finalizeStream()
return missingTerminalErr()
}
if ev.err != nil {
handleScanErr(ev.err)
-return finalizeStream()
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err)
}
lastDataAt = time.Now()
line := ev.line
-if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
payload, ok := extractOpenAISSEDataLine(line)
if !ok {
continue
}
-if processDataLine(line[6:]) {
-return resultWithUsage(), nil
if strings.TrimSpace(payload) == "[DONE]" {
return missingTerminalErr()
}
if processDataLine(payload) {
return finalizeStream()
}
-case <-keepaliveTicker.C:
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
if clientDisconnected {
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
}
logger.L().Warn("openai chat_completions stream: data interval timeout",
zap.String("request_id", requestID),
zap.String("model", originalModel),
zap.Duration("interval", streamInterval),
)
return resultWithUsage(), fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if clientDisconnected {
continue
}
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
@@ -637,7 +675,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
logger.L().Info("openai chat_completions stream: client disconnected during keepalive", logger.L().Info("openai chat_completions stream: client disconnected during keepalive",
zap.String("request_id", requestID), zap.String("request_id", requestID),
) )
return resultWithUsage(), nil clientDisconnected = true
continue
} }
c.Writer.Flush() c.Writer.Flush()
} }

View File

@@ -0,0 +1,437 @@
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.uber.org/zap"
)
// openaiCCRawAllowedHeaders is the client header passthrough whitelist used
// only by the raw CC path.
//
// Critical: do NOT reuse openaiAllowedHeaders. That list contains
// Codex-client-specific headers (originator / session_id / x-codex-turn-state /
// x-codex-turn-metadata / conversation_id) which the ChatGPT OAuth upstream
// requires, but which, when passed through to third-party OpenAI-compatible
// upstreams such as DeepSeek/Kimi/GLM, cause either:
//   - silent ignoring (most lenient vendors), quietly polluting upstream stats
//   - 400 "unknown parameter" (strict upstreams), a visible error
//
// Only generic HTTP headers are allowed through here; content-type /
// authorization / accept are set explicitly by this code path and do not
// depend on passthrough.
//
// See the decision record:
// pensieve/short-term/maxims/dont-reuse-shared-headers-whitelist-across-different-upstream-trust-domains
var openaiCCRawAllowedHeaders = map[string]bool{
"accept-language": true,
"user-agent": true,
}
// forwardAsRawChatCompletions forwards the client's Chat Completions request
// straight to the upstream `{base_url}/v1/chat/completions`, with NO
// CC↔Responses protocol conversion.
//
// Applies when: account.platform=openai && account.type=apikey && the upstream
// has been probed and confirmed not to support the /v1/responses endpoint
// (e.g. third-party OpenAI-compatible upstreams such as DeepSeek/Kimi/GLM/Qwen).
//
// Key differences from ForwardAsChatCompletions:
//
//   - does not call apicompat.ChatCompletionsToResponses; the body is only
//     rewritten for the model ID
//   - the upstream URL joins to /v1/chat/completions instead of /v1/responses
//   - streaming responses pass the SSE through to the client as-is (upstream
//     chunks are already in CC format)
//   - non-streaming responses pass the JSON through as-is, extracting usage
//     only as needed
//   - does not apply the codex OAuth transform (the API-key path has no OAuth)
//   - does not inject prompt_cache_key (an OAuth-only mechanism)
//
// Call site: openai_gateway_chat_completions.go::ForwardAsChatCompletions
// dispatches at the top of the function via openai_compat.ShouldUseResponsesAPI.
func (s *OpenAIGatewayService) forwardAsRawChatCompletions(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
defaultMappedModel string,
) (*OpenAIForwardResult, error) {
startTime := time.Now()
// 1. Parse minimal fields needed for routing/billing
originalModel := gjson.GetBytes(body, "model").String()
if originalModel == "" {
writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return nil, fmt.Errorf("missing model in request")
}
clientStream := gjson.GetBytes(body, "stream").Bool()
// 1b. Extract reasoning effort and service tier from the raw body before any transformation.
reasoningEffort := extractOpenAIReasoningEffortFromBody(body, originalModel)
serviceTier := extractOpenAIServiceTierFromBody(body)
// 2. Resolve model mapping (same as ForwardAsChatCompletions)
billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)
// 3. Rewrite model in body (no protocol conversion)
upstreamBody := body
if upstreamModel != originalModel {
upstreamBody = ReplaceModelInBody(body, upstreamModel)
}
// 4. Apply OpenAI fast policy on the CC body
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, upstreamBody)
if policyErr != nil {
var blocked *OpenAIFastBlockedError
if errors.As(policyErr, &blocked) {
writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message)
}
return nil, policyErr
}
upstreamBody = updatedBody
if clientStream {
var usageErr error
upstreamBody, usageErr = ensureOpenAIChatStreamUsage(upstreamBody)
if usageErr != nil {
return nil, fmt.Errorf("enable stream usage: %w", usageErr)
}
}
logger.L().Debug("openai chat_completions raw: forwarding without protocol conversion",
zap.Int64("account_id", account.ID),
zap.String("original_model", originalModel),
zap.String("billing_model", billingModel),
zap.String("upstream_model", upstreamModel),
zap.Bool("stream", clientStream),
)
// 5. Build upstream request
apiKey := account.GetOpenAIApiKey()
if apiKey == "" {
return nil, fmt.Errorf("account %d missing api_key", account.ID)
}
baseURL := account.GetOpenAIBaseURL()
if baseURL == "" {
baseURL = "https://api.openai.com"
}
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, fmt.Errorf("invalid base_url: %w", err)
}
targetURL := buildOpenAIChatCompletionsURL(validatedURL)
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody))
releaseUpstreamCtx()
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
if clientStream {
upstreamReq.Header.Set("Accept", "text/event-stream")
} else {
upstreamReq.Header.Set("Accept", "application/json")
}
// Pass through whitelisted client headers. See the design notes on openaiCCRawAllowedHeaders.
for key, values := range c.Request.Header {
lowerKey := strings.ToLower(key)
if openaiCCRawAllowedHeaders[lowerKey] {
for _, v := range values {
upstreamReq.Header.Add(key, v)
}
}
}
customUA := account.GetOpenAIUserAgent()
if customUA != "" {
upstreamReq.Header.Set("user-agent", customUA)
}
// 6. Send request
proxyURL := ""
if account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
// 7. Handle error response with failover
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
}
}
return s.handleChatCompletionsErrorResponse(resp, c, account)
}
// 8. Forward response
if clientStream {
return s.streamRawChatCompletions(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime)
}
return s.bufferRawChatCompletions(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime)
}
// streamRawChatCompletions passes the upstream CC SSE stream through to the
// client and extracts usage (per the OpenAI CC protocol, from the usage field
// of the chunk that precedes the trailing [DONE]).
//
// The usage field only appears in the upstream response when the request sets
// stream_options.include_usage=true. The gateway forces include_usage on for
// the upstream to keep billing complete, and forwards the usage downstream
// unchanged so cascaded proxies or downstream billing systems also receive
// the full usage.
func (s *OpenAIGatewayService) streamRawChatCompletions(
c *gin.Context,
resp *http.Response,
originalModel string,
billingModel string,
upstreamModel string,
reasoningEffort *string,
serviceTier *string,
startTime time.Time,
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
var usage OpenAIUsage
var firstTokenMs *int
clientDisconnected := false
for scanner.Scan() {
line := scanner.Text()
if payload, ok := extractOpenAISSEDataLine(line); ok {
trimmedPayload := strings.TrimSpace(payload)
if trimmedPayload != "[DONE]" {
usageOnlyChunk := isOpenAIChatUsageOnlyStreamChunk(payload)
if u := extractCCStreamUsage(payload); u != nil {
usage = *u
}
if firstTokenMs == nil && !usageOnlyChunk {
elapsed := int(time.Since(startTime).Milliseconds())
firstTokenMs = &elapsed
}
}
}
if !clientDisconnected {
if _, werr := c.Writer.WriteString(line + "\n"); werr != nil {
clientDisconnected = true
logger.L().Debug("openai chat_completions raw: client disconnected, continuing to drain upstream for billing",
zap.Error(werr),
zap.String("request_id", requestID),
)
}
}
if line == "" {
if !clientDisconnected {
c.Writer.Flush()
}
continue
}
if !clientDisconnected {
c.Writer.Flush()
}
}
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai chat_completions raw: stream read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
return &OpenAIForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: billingModel,
UpstreamModel: upstreamModel,
ReasoningEffort: reasoningEffort,
ServiceTier: serviceTier,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
// ensureOpenAIChatStreamUsage makes sure a raw Chat Completions streaming
// request asks the upstream to return usage. The usage is also passed through
// downstream, supporting cascaded proxies and downstream billing systems.
func ensureOpenAIChatStreamUsage(body []byte) ([]byte, error) {
updated, err := sjson.SetBytes(body, "stream_options.include_usage", true)
if err != nil {
return body, err
}
return updated, nil
}
func isOpenAIChatUsageOnlyStreamChunk(payload string) bool {
if strings.TrimSpace(payload) == "" {
return false
}
if !gjson.Get(payload, "usage").Exists() {
return false
}
choices := gjson.Get(payload, "choices")
return choices.Exists() && choices.IsArray() && len(choices.Array()) == 0
}
// extractCCStreamUsage extracts the usage field from a single CC stream chunk
// payload. In the CC protocol, usage only appears in the final chunk (and only
// when include_usage is in effect), but an upstream may repeat it across
// chunks, so always take the latest value.
func extractCCStreamUsage(payload string) *OpenAIUsage {
usageResult := gjson.Get(payload, "usage")
if !usageResult.Exists() || !usageResult.IsObject() {
return nil
}
u := OpenAIUsage{
InputTokens: int(gjson.Get(payload, "usage.prompt_tokens").Int()),
OutputTokens: int(gjson.Get(payload, "usage.completion_tokens").Int()),
}
if cached := gjson.Get(payload, "usage.prompt_tokens_details.cached_tokens"); cached.Exists() {
u.CacheReadInputTokens = int(cached.Int())
}
return &u
}
// bufferRawChatCompletions passes the upstream CC non-streaming JSON response through.
func (s *OpenAIGatewayService) bufferRawChatCompletions(
c *gin.Context,
resp *http.Response,
originalModel string,
billingModel string,
upstreamModel string,
reasoningEffort *string,
serviceTier *string,
startTime time.Time,
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
if err != nil {
if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Failed to read upstream response")
}
return nil, fmt.Errorf("read upstream body: %w", err)
}
var ccResp apicompat.ChatCompletionsResponse
var usage OpenAIUsage
if err := json.Unmarshal(respBody, &ccResp); err == nil && ccResp.Usage != nil {
usage = OpenAIUsage{
InputTokens: ccResp.Usage.PromptTokens,
OutputTokens: ccResp.Usage.CompletionTokens,
}
if ccResp.Usage.PromptTokensDetails != nil {
usage.CacheReadInputTokens = ccResp.Usage.PromptTokensDetails.CachedTokens
}
}
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
if ct := resp.Header.Get("Content-Type"); ct != "" {
c.Writer.Header().Set("Content-Type", ct)
} else {
c.Writer.Header().Set("Content-Type", "application/json")
}
c.Writer.WriteHeader(http.StatusOK)
_, _ = c.Writer.Write(respBody)
return &OpenAIForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: billingModel,
UpstreamModel: upstreamModel,
ReasoningEffort: reasoningEffort,
ServiceTier: serviceTier,
Stream: false,
Duration: time.Since(startTime),
}, nil
}
// buildOpenAIChatCompletionsURL builds the upstream Chat Completions endpoint URL.
//
// - base already ends with /chat/completions: return it as-is
// - base ends with /v1: append /chat/completions
// - otherwise: append /v1/chat/completions
//
// Sibling function of buildOpenAIResponsesURL.
func buildOpenAIChatCompletionsURL(base string) string {
normalized := strings.TrimRight(strings.TrimSpace(base), "/")
if strings.HasSuffix(normalized, "/chat/completions") {
return normalized
}
if strings.HasSuffix(normalized, "/v1") {
return normalized + "/chat/completions"
}
return normalized + "/v1/chat/completions"
}


@@ -0,0 +1,260 @@
//go:build unit
package service
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestBuildOpenAIChatCompletionsURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
base string
want string
}{
// already /chat/completions: returned as-is
{"already chat/completions", "https://api.openai.com/v1/chat/completions", "https://api.openai.com/v1/chat/completions"},
// ends with /v1: append /chat/completions
{"bare /v1", "https://api.openai.com/v1", "https://api.openai.com/v1/chat/completions"},
// otherwise: append /v1/chat/completions
{"bare domain", "https://api.openai.com", "https://api.openai.com/v1/chat/completions"},
{"domain with trailing slash", "https://api.openai.com/", "https://api.openai.com/v1/chat/completions"},
// common third-party upstream forms
{"third-party bare domain", "https://api.deepseek.com", "https://api.deepseek.com/v1/chat/completions"},
{"third-party with path prefix", "https://api.gptgod.online/api", "https://api.gptgod.online/api/v1/chat/completions"},
// surrounding whitespace is trimmed
{"whitespace trimmed", " https://api.openai.com/v1 ", "https://api.openai.com/v1/chat/completions"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := buildOpenAIChatCompletionsURL(tt.base)
require.Equal(t, tt.want, got)
})
}
}
// TestBuildOpenAIResponsesURL_ProbeURL pins the URL-building logic used by the probe/test endpoint,
// ensuring buildOpenAIResponsesURL produces `/v1/responses` for all standard OpenAI base_url formats.
func TestBuildOpenAIResponsesURL_ProbeURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
base string
want string
}{
{"bare domain", "https://api.openai.com", "https://api.openai.com/v1/responses"},
{"domain trailing slash", "https://api.openai.com/", "https://api.openai.com/v1/responses"},
{"bare /v1", "https://api.openai.com/v1", "https://api.openai.com/v1/responses"},
{"already /responses", "https://api.openai.com/v1/responses", "https://api.openai.com/v1/responses"},
{"third-party bare domain", "https://api.deepseek.com", "https://api.deepseek.com/v1/responses"},
{"only domain, no scheme", "api.gptgod.online", "api.gptgod.online/v1/responses"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := buildOpenAIResponsesURL(tt.base)
require.Equal(t, tt.want, got)
})
}
}
func TestForwardAsRawChatCompletions_ForcesStreamUsageUpstreamAndPassesUsageDownstream(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := strings.Join([]string{
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"ok"}}]}`,
"",
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":9,"completion_tokens":4,"total_tokens":13,"prompt_tokens_details":{"cached_tokens":3}}}`,
"",
"data: [DONE]",
"",
}, "\n")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_usage"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{
cfg: rawChatCompletionsTestConfig(),
httpUpstream: upstream,
}
account := rawChatCompletionsTestAccount()
result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "")
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 9, result.Usage.InputTokens)
require.Equal(t, 4, result.Usage.OutputTokens)
require.Equal(t, 3, result.Usage.CacheReadInputTokens)
require.NotNil(t, upstream.lastReq)
require.NoError(t, upstream.lastReq.Context().Err())
require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool())
require.Contains(t, rec.Body.String(), `"usage"`)
require.Contains(t, rec.Body.String(), "data: [DONE]")
}
func TestForwardAsRawChatCompletions_ClientDisconnectDrainsUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := strings.Join([]string{
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"ok"}}]}`,
"",
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":17,"completion_tokens":8,"total_tokens":25,"prompt_tokens_details":{"cached_tokens":6}}}`,
"",
"data: [DONE]",
"",
}, "\n")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_disconnect"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{
cfg: rawChatCompletionsTestConfig(),
httpUpstream: upstream,
}
account := rawChatCompletionsTestAccount()
result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "")
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 17, result.Usage.InputTokens)
require.Equal(t, 8, result.Usage.OutputTokens)
require.Equal(t, 6, result.Usage.CacheReadInputTokens)
require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool())
}
func TestForwardAsRawChatCompletions_UpstreamRequestIgnoresClientCancel(t *testing.T) {
gin.SetMode(gin.TestMode)
reqCtx, cancel := context.WithCancel(context.Background())
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)).WithContext(reqCtx)
c.Request.Header.Set("Content-Type", "application/json")
cancel()
upstreamBody := strings.Join([]string{
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`,
"",
"data: [DONE]",
"",
}, "\n")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_ctx"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{
cfg: rawChatCompletionsTestConfig(),
httpUpstream: upstream,
}
account := rawChatCompletionsTestAccount()
result, err := svc.forwardAsRawChatCompletions(reqCtx, c, account, body, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, upstream.lastReq)
require.NoError(t, upstream.lastReq.Context().Err())
}
func TestIsOpenAIChatUsageOnlyStreamChunk(t *testing.T) {
t.Parallel()
require.True(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[],"usage":{"prompt_tokens":1,"completion_tokens":2}}`))
require.False(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[{"index":0}],"usage":{"prompt_tokens":1,"completion_tokens":2}}`))
require.False(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[]}`))
require.False(t, isOpenAIChatUsageOnlyStreamChunk(``))
}
func TestEnsureOpenAIChatStreamUsage(t *testing.T) {
t.Parallel()
body, err := ensureOpenAIChatStreamUsage([]byte(`{"model":"gpt-5.4"}`))
require.NoError(t, err)
require.True(t, gjson.GetBytes(body, "stream_options.include_usage").Bool())
body, err = ensureOpenAIChatStreamUsage([]byte(`{"model":"gpt-5.4","stream_options":{"include_usage":false}}`))
require.NoError(t, err)
require.True(t, gjson.GetBytes(body, "stream_options.include_usage").Bool())
}
func TestBufferRawChatCompletions_RejectsOversizedResponse(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader("toolong")),
}
svc := &OpenAIGatewayService{cfg: rawChatCompletionsTestConfig()}
svc.cfg.Gateway.UpstreamResponseReadMaxBytes = 3
result, err := svc.bufferRawChatCompletions(c, resp, "gpt-5.4", "gpt-5.4", "gpt-5.4", nil, nil, time.Now())
require.ErrorIs(t, err, ErrUpstreamResponseBodyTooLarge)
require.Nil(t, result)
require.Equal(t, http.StatusBadGateway, rec.Code)
}
func rawChatCompletionsTestConfig() *config.Config {
return &config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{
Enabled: false,
AllowInsecureHTTP: true,
},
},
}
}
func rawChatCompletionsTestAccount() *Account {
return &Account{
ID: 101,
Name: "raw-openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": "http://upstream.example",
},
}
}


@@ -1,13 +1,36 @@
package service
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
type openAIChatFailingWriter struct {
gin.ResponseWriter
failAfter int
writes int
}
func (w *openAIChatFailingWriter) Write(p []byte) (int, error) {
if w.writes >= w.failAfter {
return 0, errors.New("write failed: client disconnected")
}
w.writes++
return w.ResponseWriter.Write(p)
}
func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
t.Parallel()
@@ -73,3 +96,242 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
require.Empty(t, tier)
require.False(t, gjson.GetBytes(body, "service_tier").Exists())
}
func TestForwardAsChatCompletions_ClientDisconnectDrainsUpstreamUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := strings.Join([]string{
`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`,
"",
`data: {"type":"response.output_text.delta","delta":"ok"}`,
"",
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":11,"output_tokens":5,"total_tokens":16,"input_tokens_details":{"cached_tokens":4}}}}`,
"",
"data: [DONE]",
"",
}, "\n")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_disconnect"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 11, result.Usage.InputTokens)
require.Equal(t, 5, result.Usage.OutputTokens)
require.Equal(t, 4, result.Usage.CacheReadInputTokens)
}
func TestForwardAsChatCompletions_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n")
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
defer func() {
require.NoError(t, upstreamStream.Close())
}()
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_terminal_no_close"}},
Body: upstreamStream,
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
type forwardResult struct {
result *OpenAIForwardResult
err error
}
resultCh := make(chan forwardResult, 1)
go func() {
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
resultCh <- forwardResult{result: result, err: err}
}()
select {
case got := <-resultCh:
require.NoError(t, got.err)
require.NotNil(t, got.result)
require.Equal(t, 17, got.result.Usage.InputTokens)
require.Equal(t, 8, got.result.Usage.OutputTokens)
require.Equal(t, 6, got.result.Usage.CacheReadInputTokens)
case <-time.After(time.Second):
require.Fail(t, "ForwardAsChatCompletions should return after terminal usage event even if upstream keeps the connection open")
}
}
func TestForwardAsChatCompletions_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n")
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
defer func() {
require.NoError(t, upstreamStream.Close())
}()
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_buffered_terminal_no_close"}},
Body: upstreamStream,
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
type forwardResult struct {
result *OpenAIForwardResult
err error
}
resultCh := make(chan forwardResult, 1)
go func() {
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
resultCh <- forwardResult{result: result, err: err}
}()
select {
case got := <-resultCh:
require.NoError(t, got.err)
require.NotNil(t, got.result)
require.Equal(t, 17, got.result.Usage.InputTokens)
require.Equal(t, 8, got.result.Usage.OutputTokens)
require.Equal(t, 6, got.result.Usage.CacheReadInputTokens)
require.Contains(t, rec.Body.String(), `"finish_reason":"stop"`)
case <-time.After(time.Second):
require.Fail(t, "ForwardAsChatCompletions buffered response should return after terminal usage event even if upstream keeps the connection open")
}
}
func TestForwardAsChatCompletions_DoneSentinelWithoutTerminalReturnsError(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := "data: [DONE]\n\n"
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_missing_terminal"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
require.Error(t, err)
require.Contains(t, err.Error(), "missing terminal event")
require.NotNil(t, result)
require.Zero(t, result.Usage.InputTokens)
require.Zero(t, result.Usage.OutputTokens)
}
func TestForwardAsChatCompletions_UpstreamRequestIgnoresClientCancel(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
reqCtx, cancel := context.WithCancel(context.Background())
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)).WithContext(reqCtx)
c.Request.Header.Set("Content-Type", "application/json")
cancel()
upstreamBody := strings.Join([]string{
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
"",
"data: [DONE]",
"",
}, "\n")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_ctx"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
result, err := svc.ForwardAsChatCompletions(reqCtx, c, account, body, "", "gpt-5.1")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, upstream.lastReq)
require.NoError(t, upstream.lastReq.Context().Err())
}


@@ -10,6 +10,7 @@ import (
"io"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
@@ -163,7 +164,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
}
// 6. Build upstream request
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, isStream, promptCacheKey, false)
releaseUpstreamCtx()
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
@@ -296,61 +299,9 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai messages buffered", requestID)
if err != nil {
return nil, err
}
if finalResponse == nil {
@@ -380,6 +331,153 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
}, nil
}
func isOpenAICompatResponsesTerminalEvent(eventType string) bool {
switch strings.TrimSpace(eventType) {
case "response.completed", "response.done", "response.incomplete", "response.failed":
return true
default:
return false
}
}
func isOpenAICompatDoneSentinelLine(line string) bool {
payload, ok := extractOpenAISSEDataLine(line)
return ok && strings.TrimSpace(payload) == "[DONE]"
}
func (s *OpenAIGatewayService) readOpenAICompatBufferedTerminal(
resp *http.Response,
logPrefix string,
requestID string,
) (*apicompat.ResponsesResponse, OpenAIUsage, *apicompat.BufferedResponseAccumulator, error) {
acc := apicompat.NewBufferedResponseAccumulator()
var usage OpenAIUsage
if resp == nil || resp.Body == nil {
return nil, usage, acc, errors.New("upstream response body is nil")
}
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var timeoutCh <-chan time.Time
var timeoutTimer *time.Timer
resetTimeout := func() {
if streamInterval <= 0 {
return
}
if timeoutTimer == nil {
timeoutTimer = time.NewTimer(streamInterval)
timeoutCh = timeoutTimer.C
return
}
if !timeoutTimer.Stop() {
select {
case <-timeoutTimer.C:
default:
}
}
timeoutTimer.Reset(streamInterval)
}
stopTimeout := func() {
if timeoutTimer == nil {
return
}
if !timeoutTimer.Stop() {
select {
case <-timeoutTimer.C:
default:
}
}
}
resetTimeout()
defer stopTimeout()
type scanEvent struct {
line string
err error
}
events := make(chan scanEvent, 16)
done := make(chan struct{})
go func() {
defer close(events)
for scanner.Scan() {
select {
case events <- scanEvent{line: scanner.Text()}:
case <-done:
return
}
}
if err := scanner.Err(); err != nil {
select {
case events <- scanEvent{err: err}:
case <-done:
}
}
}()
defer close(done)
for {
select {
case ev, ok := <-events:
if !ok {
return nil, usage, acc, nil
}
resetTimeout()
if ev.err != nil {
if !errors.Is(ev.err, context.Canceled) && !errors.Is(ev.err, context.DeadlineExceeded) {
logger.L().Warn(logPrefix+": read error",
zap.Error(ev.err),
zap.String("request_id", requestID),
)
}
return nil, usage, acc, ev.err
}
if isOpenAICompatDoneSentinelLine(ev.line) {
return nil, usage, acc, nil
}
payload, ok := extractOpenAISSEDataLine(ev.line)
if !ok || payload == "" {
continue
}
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn(logPrefix+": failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
acc.ProcessEvent(&event)
if isOpenAICompatResponsesTerminalEvent(event.Type) && event.Response != nil {
if event.Response.Usage != nil {
usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
}
return event.Response, usage, acc, nil
}
case <-timeoutCh:
_ = resp.Body.Close()
logger.L().Warn(logPrefix+": data interval timeout",
zap.String("request_id", requestID),
zap.Duration("interval", streamInterval),
)
return nil, usage, acc, fmt.Errorf("stream data interval timeout")
}
}
}
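The function above pairs a reader goroutine with a resettable inactivity timer so a silent upstream cannot block the scanner forever. A minimal stdlib sketch of that shape, with illustrative names rather than the service's real types:

```go
package main

import (
	"bufio"
	"errors"
	"io"
	"strings"
	"time"
)

// readLinesWithIdleTimeout reads lines from r, failing when no new line
// arrives within idle. It mirrors, in reduced form, the goroutine-plus-timer
// shape of readOpenAICompatBufferedTerminal.
func readLinesWithIdleTimeout(r io.Reader, idle time.Duration) ([]string, error) {
	scanner := bufio.NewScanner(r)
	lines := make(chan string)
	done := make(chan struct{})
	defer close(done)
	go func() {
		defer close(lines)
		for scanner.Scan() {
			select {
			case lines <- scanner.Text():
			case <-done:
				return // consumer gave up; stop reading
			}
		}
	}()
	timer := time.NewTimer(idle)
	defer timer.Stop()
	var out []string
	for {
		select {
		case line, ok := <-lines:
			if !ok {
				return out, nil // upstream closed cleanly
			}
			// Reset safely: Stop, drain a pending fire, then Reset.
			if !timer.Stop() {
				select {
				case <-timer.C:
				default:
				}
			}
			timer.Reset(idle)
			out = append(out, line)
		case <-timer.C:
			return out, errors.New("stream data interval timeout")
		}
	}
}

// demo feeds a short in-memory stream through the reader.
func demo() ([]string, error) {
	return readLinesWithIdleTimeout(strings.NewReader("a\nb\n"), time.Second)
}
```

The Stop/drain/Reset dance matters: `time.Timer.Reset` on a timer that already fired without draining its channel leaves a stale tick that would trigger a spurious timeout on the next select.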
// handleAnthropicStreamingResponse reads Responses SSE events from upstream,
// converts each to Anthropic SSE events, and writes them to the client.
// When StreamKeepaliveInterval is configured, it uses a goroutine + channel
@@ -409,6 +507,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
var usage OpenAIUsage
var firstTokenMs *int
firstChunk := true
clientDisconnected := false
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
@@ -417,6 +516,20 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
// resultWithUsage builds the final result snapshot.
resultWithUsage := func() *OpenAIForwardResult {
return &OpenAIForwardResult{
@@ -432,7 +545,6 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
}
// processDataLine handles a single "data: ..." SSE line from upstream.
processDataLine := func(payload string) bool {
if firstChunk {
firstChunk = false
@@ -449,53 +561,58 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
return false
}
// Extract usage only from the terminal events the compat converter supports,
// to avoid silently widening the event semantics.
isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type)
if isTerminalEvent && event.Response != nil && event.Response.Usage != nil {
usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
}
// Convert to Anthropic events
events := apicompat.ResponsesEventToAnthropicEvents(&event, state)
if !clientDisconnected {
for _, evt := range events {
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
if err != nil {
logger.L().Warn("openai messages stream: failed to marshal event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
clientDisconnected = true
logger.L().Info("openai messages stream: client disconnected, continuing to drain upstream for billing",
zap.String("request_id", requestID),
)
break
}
}
}
if len(events) > 0 && !clientDisconnected {
c.Writer.Flush()
}
return isTerminalEvent
}
// finalizeStream sends any remaining Anthropic events and returns the result.
finalizeStream := func() (*OpenAIForwardResult, error) {
if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 && !clientDisconnected {
for _, evt := range finalEvents {
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
if err != nil {
continue
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
clientDisconnected = true
logger.L().Info("openai messages stream: client disconnected during final flush",
zap.String("request_id", requestID),
)
break
}
}
if !clientDisconnected {
c.Writer.Flush()
}
}
return resultWithUsage(), nil
}
@@ -509,6 +626,9 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
) )
} }
} }
missingTerminalErr := func() (*OpenAIForwardResult, error) {
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
}
// ── Determine keepalive interval ── // ── Determine keepalive interval ──
keepaliveInterval := time.Duration(0) keepaliveInterval := time.Duration(0)
@@ -517,18 +637,25 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
} }
// ── No keepalive: fast synchronous path (no goroutine overhead) ── // ── No keepalive: fast synchronous path (no goroutine overhead) ──
if keepaliveInterval <= 0 { if streamInterval <= 0 && keepaliveInterval <= 0 {
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { if isOpenAICompatDoneSentinelLine(line) {
return missingTerminalErr()
}
payload, ok := extractOpenAISSEDataLine(line)
if !ok {
continue continue
} }
if processDataLine(line[6:]) { if processDataLine(payload) {
return resultWithUsage(), nil return finalizeStream()
} }
} }
handleScanErr(scanner.Err()) if err := scanner.Err(); err != nil {
return finalizeStream() handleScanErr(err)
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
}
return missingTerminalErr()
} }
// ── With keepalive: goroutine + channel + select ── // ── With keepalive: goroutine + channel + select ──
@@ -538,6 +665,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
} }
events := make(chan scanEvent, 16) events := make(chan scanEvent, 16)
done := make(chan struct{}) done := make(chan struct{})
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
sendEvent := func(ev scanEvent) bool { sendEvent := func(ev scanEvent) bool {
select { select {
case events <- ev: case events <- ev:
@@ -549,6 +678,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
go func() { go func() {
defer close(events) defer close(events)
for scanner.Scan() { for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) { if !sendEvent(scanEvent{line: scanner.Text()}) {
return return
} }
@@ -559,8 +689,15 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
}() }()
defer close(done) defer close(done)
keepaliveTicker := time.NewTicker(keepaliveInterval) var keepaliveTicker *time.Ticker
defer keepaliveTicker.Stop() if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now() lastDataAt := time.Now()
for { for {
@@ -568,22 +705,44 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
case ev, ok := <-events: case ev, ok := <-events:
if !ok { if !ok {
// Upstream closed // Upstream closed
return finalizeStream() return missingTerminalErr()
} }
if ev.err != nil { if ev.err != nil {
handleScanErr(ev.err) handleScanErr(ev.err)
return finalizeStream() return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err)
} }
lastDataAt = time.Now() lastDataAt = time.Now()
line := ev.line line := ev.line
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { if isOpenAICompatDoneSentinelLine(line) {
return missingTerminalErr()
}
payload, ok := extractOpenAISSEDataLine(line)
if !ok {
continue continue
} }
if processDataLine(line[6:]) { if processDataLine(payload) {
return resultWithUsage(), nil return finalizeStream()
} }
case <-keepaliveTicker.C: case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
if clientDisconnected {
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
}
logger.L().Warn("openai messages stream: data interval timeout",
zap.String("request_id", requestID),
zap.String("model", originalModel),
zap.Duration("interval", streamInterval),
)
return resultWithUsage(), fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if clientDisconnected {
continue
}
if time.Since(lastDataAt) < keepaliveInterval { if time.Since(lastDataAt) < keepaliveInterval {
continue continue
} }
@@ -593,7 +752,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
logger.L().Info("openai messages stream: client disconnected during keepalive", logger.L().Info("openai messages stream: client disconnected during keepalive",
zap.String("request_id", requestID), zap.String("request_id", requestID),
) )
return resultWithUsage(), nil clientDisconnected = true
continue
} }
c.Writer.Flush() c.Writer.Flush()
} }
@@ -610,3 +770,17 @@ func writeAnthropicError(c *gin.Context, statusCode int, errType, message string
}, },
}) })
} }
func copyOpenAIUsageFromResponsesUsage(usage *apicompat.ResponsesUsage) OpenAIUsage {
if usage == nil {
return OpenAIUsage{}
}
result := OpenAIUsage{
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
}
if usage.InputTokensDetails != nil {
result.CacheReadInputTokens = usage.InputTokensDetails.CachedTokens
}
return result
}


@@ -186,6 +186,56 @@ func max(a, b int) int {
     return b
 }
+func TestOpenAIGatewayServiceRecordUsage_ZeroUsageStillWritesUsageLog(t *testing.T) {
+    usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
+    billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
+    userRepo := &openAIRecordUsageUserRepoStub{}
+    subRepo := &openAIRecordUsageSubRepoStub{}
+    quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
+    svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
+    err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+        Result: &OpenAIForwardResult{
+            RequestID: "resp_zero_usage",
+            Usage:     OpenAIUsage{},
+            Model:     "gpt-5.1",
+            Duration:  time.Second,
+        },
+        APIKey:        &APIKey{ID: 1000, Quota: 100, Group: &Group{RateMultiplier: 1}},
+        User:          &User{ID: 2000},
+        Account:       &Account{ID: 3000, Type: AccountTypeAPIKey},
+        APIKeyService: quotaSvc,
+    })
+    require.NoError(t, err)
+    require.Equal(t, 1, billingRepo.calls)
+    require.Equal(t, 1, usageRepo.calls)
+    require.Equal(t, 0, userRepo.deductCalls)
+    require.Equal(t, 0, subRepo.incrementCalls)
+    require.Equal(t, 0, quotaSvc.quotaCalls)
+    require.Equal(t, 0, quotaSvc.rateLimitCalls)
+    require.NotNil(t, usageRepo.lastLog)
+    require.Equal(t, "resp_zero_usage", usageRepo.lastLog.RequestID)
+    require.Zero(t, usageRepo.lastLog.InputTokens)
+    require.Zero(t, usageRepo.lastLog.OutputTokens)
+    require.Zero(t, usageRepo.lastLog.CacheCreationTokens)
+    require.Zero(t, usageRepo.lastLog.CacheReadTokens)
+    require.Zero(t, usageRepo.lastLog.ImageOutputTokens)
+    require.Zero(t, usageRepo.lastLog.ImageCount)
+    require.Zero(t, usageRepo.lastLog.InputCost)
+    require.Zero(t, usageRepo.lastLog.OutputCost)
+    require.Zero(t, usageRepo.lastLog.TotalCost)
+    require.Zero(t, usageRepo.lastLog.ActualCost)
+    require.NotNil(t, billingRepo.lastCmd)
+    require.Zero(t, billingRepo.lastCmd.BalanceCost)
+    require.Zero(t, billingRepo.lastCmd.SubscriptionCost)
+    require.Zero(t, billingRepo.lastCmd.APIKeyQuotaCost)
+    require.Zero(t, billingRepo.lastCmd.APIKeyRateLimitCost)
+    require.Zero(t, billingRepo.lastCmd.AccountQuotaCost)
+}
 func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) {
     groupID := int64(11)
     groupRate := 1.4


@@ -2601,7 +2601,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
 httpInvalidEncryptedContentRetryTried := false
 for {
     // Build upstream request
-    upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
+    upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
     upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
     releaseUpstreamCtx()
     if err != nil {
@@ -2852,7 +2852,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
     return nil, err
 }
-upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
+upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
 upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
 releaseUpstreamCtx()
 if err != nil {
@@ -5041,13 +5041,6 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
     s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID)
 }
-// Skip usage records whose token counts are all zero -- when upstream
-// returns no usage, nothing should be written to the database.
-if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 &&
-    result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 &&
-    result.Usage.ImageOutputTokens == 0 && result.ImageCount == 0 {
-    return nil
-}
 apiKey := input.APIKey
 user := input.User
 account := input.Account


@@ -596,7 +596,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
 var usage OpenAIUsage
 imageCount := parsed.N
 var firstTokenMs *int
-if parsed.Stream {
+if parsed.Stream && isEventStreamResponse(resp.Header) {
     streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime)
     if err != nil {
         return nil, err
@@ -811,6 +811,11 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
 usage := OpenAIUsage{}
 imageCount := 0
 var firstTokenMs *int
+var fallbackBody bytes.Buffer
+fallbackBytes := int64(0)
+fallbackLimit := resolveUpstreamResponseReadLimit(s.cfg)
+seenSSEData := false
+fallbackTooLarge := false
 for {
     line, err := reader.ReadBytes('\n')
@@ -824,11 +829,24 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
     }
     flusher.Flush()
-    if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok && data != "" && data != "[DONE]" {
-        dataBytes := []byte(data)
-        mergeOpenAIUsage(&usage, dataBytes)
-        if count := extractOpenAIImageCountFromJSONBytes(dataBytes); count > imageCount {
-            imageCount = count
-        }
-    }
+    if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok {
+        if data != "" && data != "[DONE]" {
+            seenSSEData = true
+            fallbackBody.Reset()
+            fallbackBytes = 0
+            dataBytes := []byte(data)
+            mergeOpenAIUsage(&usage, dataBytes)
+            if count := extractOpenAIImagesBillableCountFromJSONBytes(dataBytes); count > imageCount {
+                imageCount = count
+            }
+        }
+    } else if !seenSSEData && !fallbackTooLarge {
+        fallbackBytes += int64(len(line))
+        if fallbackBytes <= fallbackLimit {
+            _, _ = fallbackBody.Write(line)
+        } else {
+            fallbackTooLarge = true
+            fallbackBody.Reset()
+        }
+    }
 }
@@ -839,9 +857,41 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
         return OpenAIUsage{}, 0, firstTokenMs, err
     }
 }
+if !seenSSEData && fallbackBody.Len() > 0 {
+    body := bytes.TrimSpace(fallbackBody.Bytes())
+    if len(body) > 0 {
+        mergeOpenAIUsage(&usage, body)
+        if count := extractOpenAIImagesBillableCountFromJSONBytes(body); count > imageCount {
+            imageCount = count
+        }
+    }
+}
 return usage, imageCount, firstTokenMs, nil
 }
+func extractOpenAIImagesBillableCountFromJSONBytes(body []byte) int {
+    if count := extractOpenAIImageCountFromJSONBytes(body); count > 0 {
+        return count
+    }
+    if len(body) == 0 || !gjson.ValidBytes(body) {
+        return 0
+    }
+    if count := int(gjson.GetBytes(body, "usage.images").Int()); count > 0 {
+        return count
+    }
+    if count := int(gjson.GetBytes(body, "tool_usage.image_gen.images").Int()); count > 0 {
+        return count
+    }
+    eventType := strings.TrimSpace(gjson.GetBytes(body, "type").String())
+    if eventType == "" || !strings.HasSuffix(eventType, ".completed") {
+        return 0
+    }
+    if gjson.GetBytes(body, "b64_json").Exists() || gjson.GetBytes(body, "url").Exists() {
+        return 1
+    }
+    return 0
+}
 func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) {
     if dst == nil {
         return


@@ -446,6 +446,109 @@ func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseU
     require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
 }
+func TestOpenAIGatewayServiceForwardImages_APIKeyStreamJSONResponseBillsImage(t *testing.T) {
+    gin.SetMode(gin.TestMode)
+    body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
+    req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
+    req.Header.Set("Content-Type", "application/json")
+    rec := httptest.NewRecorder()
+    c, _ := gin.CreateTestContext(rec)
+    c.Request = req
+    svc := &OpenAIGatewayService{
+        cfg: &config.Config{},
+        httpUpstream: &httpUpstreamRecorder{
+            resp: &http.Response{
+                StatusCode: http.StatusOK,
+                Header: http.Header{
+                    "Content-Type": []string{"application/json"},
+                    "X-Request-Id": []string{"req_img_stream_json"},
+                },
+                Body: io.NopCloser(strings.NewReader(`{"created":1710000008,"usage":{"input_tokens":12,"output_tokens":21,"output_tokens_details":{"image_tokens":9}},"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)),
+            },
+        },
+    }
+    parsed, err := svc.ParseOpenAIImagesRequest(c, body)
+    require.NoError(t, err)
+    account := &Account{
+        ID:       7,
+        Name:     "openai-apikey",
+        Platform: PlatformOpenAI,
+        Type:     AccountTypeAPIKey,
+        Credentials: map[string]any{
+            "api_key":  "test-api-key",
+            "base_url": "https://image-upstream.example/v1",
+        },
+    }
+    result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
+    require.NoError(t, err)
+    require.NotNil(t, result)
+    require.True(t, result.Stream)
+    require.Equal(t, 1, result.ImageCount)
+    require.Equal(t, 12, result.Usage.InputTokens)
+    require.Equal(t, 21, result.Usage.OutputTokens)
+    require.Equal(t, 9, result.Usage.ImageOutputTokens)
+    require.Equal(t, http.StatusOK, rec.Code)
+    require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
+}
+func TestOpenAIGatewayServiceForwardImages_APIKeyStreamRawJSONEventStreamFallbackBillsImage(t *testing.T) {
+    gin.SetMode(gin.TestMode)
+    body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
+    req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
+    req.Header.Set("Content-Type", "application/json")
+    rec := httptest.NewRecorder()
+    c, _ := gin.CreateTestContext(rec)
+    c.Request = req
+    svc := &OpenAIGatewayService{
+        cfg: &config.Config{},
+        httpUpstream: &httpUpstreamRecorder{
+            resp: &http.Response{
+                StatusCode: http.StatusOK,
+                Header: http.Header{
+                    "Content-Type": []string{"text/event-stream"},
+                    "X-Request-Id": []string{"req_img_stream_json_mislabeled"},
+                },
+                Body: io.NopCloser(strings.NewReader(`{"created":1710000009,"usage":{"input_tokens":10,"output_tokens":18,"output_tokens_details":{"image_tokens":8}},"data":[{"b64_json":"ZmluYWw="}]}`)),
+            },
+        },
+    }
+    parsed, err := svc.ParseOpenAIImagesRequest(c, body)
+    require.NoError(t, err)
+    account := &Account{
+        ID:       8,
+        Name:     "openai-apikey",
+        Platform: PlatformOpenAI,
+        Type:     AccountTypeAPIKey,
+        Credentials: map[string]any{
+            "api_key":  "test-api-key",
+            "base_url": "https://image-upstream.example/v1",
+        },
+    }
+    result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
+    require.NoError(t, err)
+    require.NotNil(t, result)
+    require.True(t, result.Stream)
+    require.Equal(t, 1, result.ImageCount)
+    require.Equal(t, 10, result.Usage.InputTokens)
+    require.Equal(t, 18, result.Usage.OutputTokens)
+    require.Equal(t, 8, result.Usage.ImageOutputTokens)
+    require.Equal(t, "ZmluYWw=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
+}
+func TestExtractOpenAIImagesBillableCountFromJSONBytes_CompletedEvent(t *testing.T) {
+    body := []byte(`{"type":"image_generation.completed","b64_json":"ZmluYWw=","usage":{"input_tokens":10,"output_tokens":18}}`)
+    require.Equal(t, 1, extractOpenAIImagesBillableCountFromJSONBytes(body))
+}
 func TestOpenAIGatewayServiceForwardImages_APIKeyEditUsesConfiguredV1BaseURL(t *testing.T) {
     gin.SetMode(gin.TestMode)


@@ -307,6 +307,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_CompactUsesJSONAndKeepsNonStreami
     require.Contains(t, rec.Body.String(), `"id":"cmp_123"`)
 }
+func TestOpenAIGatewayService_OAuthPassthrough_UpstreamRequestIgnoresClientCancel(t *testing.T) {
+    gin.SetMode(gin.TestMode)
+    rec := httptest.NewRecorder()
+    c, _ := gin.CreateTestContext(rec)
+    reqCtx, cancel := context.WithCancel(context.Background())
+    c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx)
+    c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
+    cancel()
+    originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`)
+    upstream := &httpUpstreamRecorder{resp: &http.Response{
+        StatusCode: http.StatusOK,
+        Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_passthrough_ctx"}},
+        Body: io.NopCloser(strings.NewReader(strings.Join([]string{
+            `data: {"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1}}}`,
+            "",
+            "data: [DONE]",
+            "",
+        }, "\n"))),
+    }}
+    svc := &OpenAIGatewayService{
+        cfg:          &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
+        httpUpstream: upstream,
+    }
+    account := &Account{
+        ID:             123,
+        Name:           "acc",
+        Platform:       PlatformOpenAI,
+        Type:           AccountTypeOAuth,
+        Concurrency:    1,
+        Credentials:    map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
+        Extra:          map[string]any{"openai_passthrough": true, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff},
+        Status:         StatusActive,
+        Schedulable:    true,
+        RateMultiplier: f64p(1),
+    }
+    result, err := svc.Forward(reqCtx, c, account, originalBody)
+    require.NoError(t, err)
+    require.NotNil(t, result)
+    require.NotNil(t, upstream.lastReq)
+    require.NoError(t, upstream.lastReq.Context().Err())
+}
 func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) {
     gin.SetMode(gin.TestMode)
     logSink, restore := captureStructuredLog(t)
@@ -405,6 +451,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *te
     require.Contains(t, string(upstream.lastBody), `"stream":true`)
 }
+func TestOpenAIGatewayService_OAuthLegacy_UpstreamRequestIgnoresClientCancel(t *testing.T) {
+    gin.SetMode(gin.TestMode)
+    rec := httptest.NewRecorder()
+    c, _ := gin.CreateTestContext(rec)
+    reqCtx, cancel := context.WithCancel(context.Background())
+    c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx)
+    c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
+    cancel()
+    originalBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`)
+    upstream := &httpUpstreamRecorder{resp: &http.Response{
+        StatusCode: http.StatusOK,
+        Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_legacy_ctx"}},
+        Body: io.NopCloser(strings.NewReader(strings.Join([]string{
+            `data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`,
+            "",
+            "data: [DONE]",
+            "",
+        }, "\n"))),
+    }}
+    svc := &OpenAIGatewayService{
+        cfg:          &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
+        httpUpstream: upstream,
+    }
+    account := &Account{
+        ID:             123,
+        Name:           "acc",
+        Platform:       PlatformOpenAI,
+        Type:           AccountTypeOAuth,
+        Concurrency:    1,
+        Credentials:    map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
+        Extra:          map[string]any{"openai_passthrough": false, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff},
+        Status:         StatusActive,
+        Schedulable:    true,
+        RateMultiplier: f64p(1),
+    }
+    result, err := svc.Forward(reqCtx, c, account, originalBody)
+    require.NoError(t, err)
+    require.NotNil(t, result)
+    require.NotNil(t, upstream.lastReq)
+    require.NoError(t, upstream.lastReq.Context().Err())
+}
 func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) {
     gin.SetMode(gin.TestMode)


@@ -219,8 +219,11 @@ func (e *OpenAIWSClientCloseError) Reason() string {
 // OpenAIWSIngressHooks defines per-turn lifecycle callbacks for inbound WS.
 type OpenAIWSIngressHooks struct {
-    BeforeTurn func(turn int) error
-    AfterTurn  func(turn int, result *OpenAIForwardResult, turnErr error)
+    // InitialRequestModel is the request model from the first frame before
+    // channel mapping. It is used only to derive the reasoning-effort suffix
+    // for usage metadata; it must never be used for the upstream request or
+    // the billing model.
+    InitialRequestModel string
+    BeforeTurn          func(turn int) error
+    AfterTurn           func(turn int, result *OpenAIForwardResult, turnErr error)
 }
 func normalizeOpenAIWSLogValue(value string) string {


@@ -399,7 +399,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
 }()
 writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
-err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"service_tier":"fast"}`))
+err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"service_tier":"fast","reasoning":{"effort":"HIGH"}}`))
 cancelWrite()
 require.NoError(t, err)
@@ -431,6 +431,8 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
 require.Equal(t, 3, result.Usage.OutputTokens)
 require.NotNil(t, result.ServiceTier)
 require.Equal(t, "priority", *result.ServiceTier)
+require.NotNil(t, result.ReasoningEffort)
+require.Equal(t, "high", *result.ReasoningEffort)
 case <-time.After(2 * time.Second):
     t.Fatal("passthrough turn result callback was not received")
 }


@@ -124,6 +124,73 @@ func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload []
return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original)) return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
} }
type openAIWSPassthroughUsageMeta struct {
serviceTier atomic.Pointer[string]
reasoningEffort atomic.Pointer[string]
// 仅在 client->upstream filter goroutine 中读写Load 侧通过上方原子指针同步。
sessionRequestModel string
}
func newOpenAIWSPassthroughUsageMeta(initialRequestModel string, firstFrame []byte) *openAIWSPassthroughUsageMeta {
meta := &openAIWSPassthroughUsageMeta{
sessionRequestModel: strings.TrimSpace(initialRequestModel),
}
if meta.sessionRequestModel == "" {
meta.sessionRequestModel = openAIWSPassthroughRequestModelForFrame(firstFrame)
}
return meta
}
func (m *openAIWSPassthroughUsageMeta) initFromFirstFrame(policyOutput []byte) {
if m == nil {
return
}
m.serviceTier.Store(extractOpenAIServiceTierFromBody(policyOutput))
m.reasoningEffort.Store(extractOpenAIReasoningEffortFromBody(policyOutput, m.sessionRequestModel))
}
func (m *openAIWSPassthroughUsageMeta) updateSessionRequestModel(payload []byte) {
if m == nil {
return
}
if model := openAIWSPassthroughRequestModelFromSessionFrame(payload); model != "" {
m.sessionRequestModel = model
}
}
func (m *openAIWSPassthroughUsageMeta) requestModelForFrame(payload []byte) string {
if m == nil {
return openAIWSPassthroughRequestModelForFrame(payload)
}
if model := openAIWSPassthroughRequestModelForFrame(payload); model != "" {
return model
}
return m.sessionRequestModel
}
func (m *openAIWSPassthroughUsageMeta) updateFromResponseCreate(policyOutput []byte, requestModelForFrame string) {
if m == nil {
return
}
m.serviceTier.Store(extractOpenAIServiceTierFromBody(policyOutput))
m.reasoningEffort.Store(extractOpenAIReasoningEffortFromBody(policyOutput, requestModelForFrame))
}
func openAIWSPassthroughRequestModelForFrame(payload []byte) string {
if len(payload) == 0 || strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "response.create" {
return ""
}
return strings.TrimSpace(gjson.GetBytes(payload, "model").String())
}
func openAIWSPassthroughRequestModelFromSessionFrame(payload []byte) string {
if len(payload) == 0 || strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "session.update" {
return ""
}
return strings.TrimSpace(gjson.GetBytes(payload, "session.model").String())
}
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2" const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil) var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
@@ -204,6 +271,11 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
// silently passed through, defeating the policy on every frame after // silently passed through, defeating the policy on every frame after
// the first. // the first.
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage) capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage)
initialRequestModel := ""
if hooks != nil {
initialRequestModel = hooks.InitialRequestModel
}
usageMeta := newOpenAIWSPassthroughUsageMeta(initialRequestModel, firstClientMessage)
updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage) updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage)
if policyErr != nil { if policyErr != nil {
return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr) return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr)
@@ -226,7 +298,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
} }
firstClientMessage = updatedFirst firstClientMessage = updatedFirst
// 在 policy filter 之后再提取 service_tier 用于 billing 上报filter // 在 policy filter 之后再提取 service_tier / reasoning_effort 用于
// usage 上报filter
// 命中时 service_tier 已经从 firstClientMessage 中删除billing 应当 // 命中时 service_tier 已经从 firstClientMessage 中删除billing 应当
// 反映上游实际处理的 tiernil = default而不是用户最初请求的 // 反映上游实际处理的 tiernil = default而不是用户最初请求的
// "priority"。HTTP 入口line ~2728 extractOpenAIServiceTier(reqBody) // "priority"。HTTP 入口line ~2728 extractOpenAIServiceTier(reqBody)
@@ -237,11 +310,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
// codex-rs/core/src/client.rs build_responses_request 每次重新填值)。 // codex-rs/core/src/client.rs build_responses_request 每次重新填值)。
// 因此使用 atomic.Pointer[string] 在 filterrunClientToUpstream // 因此使用 atomic.Pointer[string] 在 filterrunClientToUpstream
// goroutine和 OnTurnComplete / final resultrunUpstreamToClient // goroutine和 OnTurnComplete / final resultrunUpstreamToClient
// goroutine之间同步当前 turn 的 service_tier // goroutine之间同步当前 turn 的 usage metadata
// extractOpenAIServiceTierFromBody 返回 *string本身是指针类型 usageMeta.initFromFirstFrame(firstClientMessage)
// 可直接 Store/Load 而无需额外封装。
var requestServiceTierPtr atomic.Pointer[string]
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstClientMessage))
wsURL, err := s.buildOpenAIResponsesWSURL(account) wsURL, err := s.buildOpenAIResponsesWSURL(account)
if err != nil { if err != nil {
@@ -327,6 +397,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" { if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
capturedSessionModel = updated capturedSessionModel = updated
} }
usageMeta.updateSessionRequestModel(payload)
requestModelForThisFrame := usageMeta.requestModelForFrame(payload)
// Per-frame model first; if the client omits "model" on a // Per-frame model first; if the client omits "model" on a
// follow-up frame (legal in Realtime), fall back to the // follow-up frame (legal in Realtime), fall back to the
// session-level model captured from the first frame so the // session-level model captured from the first frame so the
@@ -337,14 +409,14 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
model = capturedSessionModel
}
out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload)
// Multi-turn passthrough usage: only update usageMeta on successful
// (non-block / non-err) response.create frames, using the
// filter-processed payload, consistent with the first frame's
// policy-after-extract semantics (see the
// extractOpenAIServiceTierFromBody comment above).
// - Non-response.create frames (response.cancel /
//   conversation.item.create / session.update, etc.) carry no
//   per-response metadata and must not overwrite the previous turn's value.
// - blocked != nil: the frame is never sent upstream; usage metadata
//   keeps the previous turn's value.
// - policyErr != nil: error path; keep the previous turn's value.
// - A response.create without service_tier will make
@@ -353,7 +425,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
// treated as default when service_tier is absent; billing should reflect that.
if policyErr == nil && blocked == nil &&
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
usageMeta.updateFromResponseCreate(out, requestModelForThisFrame)
}
return out, blocked, policyErr
},
@@ -397,7 +469,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
},
Model: turn.RequestModel,
ServiceTier: usageMeta.serviceTier.Load(),
ReasoningEffort: usageMeta.reasoningEffort.Load(),
Stream: true,
OpenAIWSMode: true,
ResponseHeaders: cloneHeader(handshakeHeaders),
@@ -445,7 +518,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
},
Model: relayResult.RequestModel,
ServiceTier: usageMeta.serviceTier.Load(),
ReasoningEffort: usageMeta.reasoningEffort.Load(),
Stream: true,
OpenAIWSMode: true,
ResponseHeaders: cloneHeader(handshakeHeaders),
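The comments in this hunk describe when usage metadata may change across a multi-turn passthrough session. A minimal sketch of that rule, with illustrative names (`usageMetaSketch`, `maybeUpdate` are assumptions, not the gateway's real types):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// usageMetaSketch is an illustrative stand-in for the gateway's usage
// metadata holder; field names are assumptions, not the real type.
type usageMetaSketch struct {
	serviceTier string
}

// maybeUpdate applies the rule from the diff comments: only a successfully
// filtered (non-blocked, non-error) response.create frame may overwrite the
// previous turn's service tier; all other frames keep the prior value.
func (m *usageMetaSketch) maybeUpdate(frame []byte, blocked bool, policyErr error) {
	if blocked || policyErr != nil {
		return // keep the previous turn's value
	}
	var f struct {
		Type     string `json:"type"`
		Response struct {
			ServiceTier string `json:"service_tier"`
		} `json:"response"`
	}
	if json.Unmarshal(frame, &f) != nil || f.Type != "response.create" {
		return // non-create frames carry no per-response metadata
	}
	m.serviceTier = f.Response.ServiceTier // empty means "default"
}

func main() {
	m := &usageMetaSketch{serviceTier: "flex"}
	m.maybeUpdate([]byte(`{"type":"response.cancel"}`), false, nil)
	fmt.Println(m.serviceTier) // still "flex"
	m.maybeUpdate([]byte(`{"type":"response.create","response":{"service_tier":"priority"}}`), false, nil)
	fmt.Println(m.serviceTier) // "priority"
}
```

The same keep-previous-value semantics apply to the reasoning effort added by this PR.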


@@ -394,7 +394,8 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db
return nil
}
sourceOrderID := o.ID
rebateAmount, err := s.affiliateService.AccrueInviteRebateForOrder(txCtx, o.UserID, o.Amount, &sourceOrderID)
if err != nil {
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
"error": err.Error(),


@@ -0,0 +1,85 @@
-- Affiliate rebate ledger: add source-order linkage and transfer-to-balance snapshots.
-- These columns are for audit display only; fields that cannot be reliably
-- reconstructed for historical entries stay NULL, to avoid presenting current
-- state as historical state.
ALTER TABLE user_affiliate_ledger
ADD COLUMN IF NOT EXISTS source_order_id BIGINT NULL REFERENCES payment_orders(id) ON DELETE SET NULL;
ALTER TABLE user_affiliate_ledger
ADD COLUMN IF NOT EXISTS balance_after DECIMAL(20,8) NULL;
ALTER TABLE user_affiliate_ledger
ADD COLUMN IF NOT EXISTS aff_quota_after DECIMAL(20,8) NULL;
ALTER TABLE user_affiliate_ledger
ADD COLUMN IF NOT EXISTS aff_frozen_quota_after DECIMAL(20,8) NULL;
ALTER TABLE user_affiliate_ledger
ADD COLUMN IF NOT EXISTS aff_history_quota_after DECIMAL(20,8) NULL;
COMMENT ON COLUMN user_affiliate_ledger.source_order_id IS 'Recharge order that produced this rebate entry; NULL for balance transfers or historical rows that cannot be reliably backfilled';
COMMENT ON COLUMN user_affiliate_ledger.balance_after IS 'User balance snapshot after transferring rebate into balance; NULL when unavailable';
COMMENT ON COLUMN user_affiliate_ledger.aff_quota_after IS 'Available rebate quota snapshot after the transfer; NULL when unavailable';
COMMENT ON COLUMN user_affiliate_ledger.aff_frozen_quota_after IS 'Frozen rebate quota snapshot after the transfer; NULL when unavailable';
COMMENT ON COLUMN user_affiliate_ledger.aff_history_quota_after IS 'Total historical rebate snapshot after the transfer; NULL when unavailable';
CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_source_order_id
ON user_affiliate_ledger(source_order_id)
WHERE source_order_id IS NOT NULL;
CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_rebate_lookup
ON user_affiliate_ledger(action, source_order_id, user_id, source_user_id, created_at)
WHERE action = 'accrue';
-- Best-effort backfill for rebate entries created after PR #2169 was merged but
-- before this migration ran. Only backfill when an order matches exactly one
-- rebate entry, to avoid binding several same-amount entries to the wrong order.
WITH rebate_audits AS (
SELECT po.id AS order_id,
po.user_id AS invitee_user_id,
invitee_aff.inviter_id,
rebate_detail.rebate_amount,
pal.created_at AS audit_created_at
FROM payment_audit_logs pal
CROSS JOIN LATERAL (
SELECT substring(
pal.detail
FROM '"rebateAmount"[[:space:]]*:[[:space:]]*(-?[0-9]+(\.[0-9]+)?)'
)::numeric AS rebate_amount
) rebate_detail
JOIN payment_orders po ON po.id::text = pal.order_id
JOIN user_affiliates invitee_aff ON invitee_aff.user_id = po.user_id
WHERE pal.action = 'AFFILIATE_REBATE_APPLIED'
AND rebate_detail.rebate_amount IS NOT NULL
),
ranked_matches AS (
SELECT ual.id AS ledger_id,
ra.order_id,
COUNT(*) OVER (PARTITION BY ra.order_id) AS order_match_count,
COUNT(*) OVER (PARTITION BY ual.id) AS ledger_match_count,
ROW_NUMBER() OVER (
PARTITION BY ual.id
ORDER BY ABS(EXTRACT(EPOCH FROM (ual.created_at - ra.audit_created_at))), ra.order_id
) AS ledger_rank
FROM rebate_audits ra
JOIN user_affiliate_ledger ual
ON ual.action = 'accrue'
AND ual.source_order_id IS NULL
AND ual.user_id = ra.inviter_id
AND ual.source_user_id = ra.invitee_user_id
AND ABS(ual.amount - ra.rebate_amount) < 0.00000001
AND ual.created_at BETWEEN ra.audit_created_at - INTERVAL '10 minutes'
AND ra.audit_created_at + INTERVAL '10 minutes'
)
UPDATE user_affiliate_ledger ual
SET source_order_id = ranked_matches.order_id,
updated_at = NOW()
FROM ranked_matches
WHERE ual.id = ranked_matches.ledger_id
AND ranked_matches.order_match_count = 1
AND ranked_matches.ledger_match_count = 1
AND ranked_matches.ledger_rank = 1
AND NOT EXISTS (
SELECT 1
FROM user_affiliate_ledger existing
WHERE existing.source_order_id = ranked_matches.order_id
AND existing.action = 'accrue'
);
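The backfill above only binds a ledger row to an order when both sides match exactly once (`order_match_count = 1` and `ledger_match_count = 1`). A small sketch of that ambiguity guard outside SQL; the `match`/`backfillable` names are illustrative, not project identifiers:

```go
package main

import "fmt"

// match pairs a candidate rebate ledger row with a payment order.
type match struct{ ledgerID, orderID int64 }

// backfillable keeps only unambiguous pairs, mirroring the migration's
// guard: an order must match exactly one ledger row, and that ledger row
// must match exactly one order; everything else is left untouched.
func backfillable(matches []match) []match {
	perOrder := map[int64]int{}
	perLedger := map[int64]int{}
	for _, m := range matches {
		perOrder[m.orderID]++
		perLedger[m.ledgerID]++
	}
	var out []match
	for _, m := range matches {
		if perOrder[m.orderID] == 1 && perLedger[m.ledgerID] == 1 {
			out = append(out, m)
		}
	}
	return out
}

func main() {
	// Order 20 matches two ledger rows, so both of its pairs are skipped.
	fmt.Println(backfillable([]match{{1, 10}, {2, 20}, {3, 20}})) // [{1 10}]
}
```

Leaving ambiguous rows with a NULL `source_order_id` is the conservative choice: a missing link is merely unaudited, while a wrong link would misattribute a rebate.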


@@ -127,3 +127,18 @@ func TestMigration124BackfillsLegacyOIDCSecurityFlagsSafely(t *testing.T) {
require.Contains(t, sql, "oidc_connect_enabled")
require.Contains(t, sql, "'false'")
}
func TestMigration134AddsAffiliateLedgerAuditFieldsWithoutJSONCast(t *testing.T) {
content, err := FS.ReadFile("134_affiliate_ledger_audit_snapshots.sql")
require.NoError(t, err)
sql := string(content)
require.Contains(t, sql, "ADD COLUMN IF NOT EXISTS source_order_id BIGINT")
require.Contains(t, sql, "ADD COLUMN IF NOT EXISTS balance_after DECIMAL(20,8)")
require.Contains(t, sql, "ADD COLUMN IF NOT EXISTS aff_quota_after DECIMAL(20,8)")
require.Contains(t, sql, "substring(")
require.Contains(t, sql, `"rebateAmount"`)
require.Contains(t, sql, "COUNT(*) OVER (PARTITION BY ra.order_id) AS order_match_count")
require.Contains(t, sql, "COUNT(*) OVER (PARTITION BY ual.id) AS ledger_match_count")
require.NotContains(t, sql, "detail::jsonb")
}


@@ -23,6 +23,72 @@ export interface ListAffiliateUsersParams {
search?: string
}
export interface ListAffiliateRecordsParams {
page?: number
page_size?: number
search?: string
start_at?: string
end_at?: string
sort_by?: string
sort_order?: 'asc' | 'desc'
timezone?: string
}
export interface AffiliateInviteRecord {
inviter_id: number
inviter_email: string
inviter_username: string
invitee_id: number
invitee_email: string
invitee_username: string
aff_code: string
total_rebate: number
created_at: string
}
export interface AffiliateRebateRecord {
order_id: number
out_trade_no: string
inviter_id: number
inviter_email: string
inviter_username: string
invitee_id: number
invitee_email: string
invitee_username: string
order_amount: number
pay_amount: number
rebate_amount: number
payment_type: string
order_status: string
created_at: string
}
export interface AffiliateTransferRecord {
ledger_id: number
user_id: number
user_email: string
username: string
amount: number
balance_after?: number | null
available_quota_after?: number | null
frozen_quota_after?: number | null
history_quota_after?: number | null
snapshot_available: boolean
created_at: string
}
export interface AffiliateUserOverview {
user_id: number
email: string
username: string
aff_code: string
rebate_rate_percent: number
invited_count: number
rebated_invitee_count: number
available_quota: number
history_quota: number
}
export interface UpdateAffiliateUserRequest {
aff_code?: string
aff_rebate_rate_percent?: number | null
@@ -97,12 +163,68 @@ export async function batchSetRate(
return data
}
function recordParams(params: ListAffiliateRecordsParams = {}) {
return {
page: params.page ?? 1,
page_size: params.page_size ?? 20,
search: params.search ?? '',
start_at: params.start_at || undefined,
end_at: params.end_at || undefined,
sort_by: params.sort_by || undefined,
sort_order: params.sort_order || undefined,
timezone: params.timezone || undefined,
}
}
export async function listInviteRecords(
params: ListAffiliateRecordsParams = {},
): Promise<PaginatedResponse<AffiliateInviteRecord>> {
const { data } = await apiClient.get<PaginatedResponse<AffiliateInviteRecord>>(
'/admin/affiliates/invites',
{ params: recordParams(params) },
)
return data
}
export async function listRebateRecords(
params: ListAffiliateRecordsParams = {},
): Promise<PaginatedResponse<AffiliateRebateRecord>> {
const { data } = await apiClient.get<PaginatedResponse<AffiliateRebateRecord>>(
'/admin/affiliates/rebates',
{ params: recordParams(params) },
)
return data
}
export async function listTransferRecords(
params: ListAffiliateRecordsParams = {},
): Promise<PaginatedResponse<AffiliateTransferRecord>> {
const { data } = await apiClient.get<PaginatedResponse<AffiliateTransferRecord>>(
'/admin/affiliates/transfers',
{ params: recordParams(params) },
)
return data
}
export async function getUserOverview(
userId: number,
): Promise<AffiliateUserOverview> {
const { data } = await apiClient.get<AffiliateUserOverview>(
`/admin/affiliates/users/${userId}/overview`,
)
return data
}
export const affiliatesAPI = {
listUsers,
lookupUsers,
updateUserSettings,
clearUserSettings,
batchSetRate,
listInviteRecords,
listRebateRecords,
listTransferRecords,
getUserOverview,
}
export default affiliatesAPI


@@ -249,7 +249,7 @@ export interface BalanceHistoryResponse extends PaginatedResponse<BalanceHistory
* @param id - User ID
* @param page - Page number
* @param pageSize - Items per page
* @param type - Optional type filter (balance, affiliate_balance, admin_balance, concurrency, admin_concurrency, subscription)
* @returns Paginated balance history with total_recharged
*/
export async function getUserBalanceHistory(


@@ -779,6 +779,110 @@
</div>
</div>
<!-- OpenAI Compact mode -->
<div v-if="allOpenAIPassthroughCapable" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="mb-3 flex items-center justify-between">
<div class="flex-1 pr-4">
<label
id="bulk-edit-openai-compact-mode-label"
class="input-label mb-0"
for="bulk-edit-openai-compact-mode-enabled"
>
{{ t('admin.accounts.openai.compactMode') }}
</label>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.openai.compactModeDesc') }}
</p>
</div>
<input
v-model="enableOpenAICompactMode"
id="bulk-edit-openai-compact-mode-enabled"
type="checkbox"
aria-controls="bulk-edit-openai-compact-mode"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
</div>
<div
id="bulk-edit-openai-compact-mode"
:class="!enableOpenAICompactMode && 'pointer-events-none opacity-50'"
>
<Select
v-model="openAICompactMode"
data-testid="bulk-edit-openai-compact-mode-select"
:options="openAICompactModeOptions"
aria-labelledby="bulk-edit-openai-compact-mode-label"
/>
</div>
</div>
<!-- OpenAI Compact model mapping -->
<div v-if="allOpenAIPassthroughCapable" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="mb-3 flex items-center justify-between">
<div class="flex-1 pr-4">
<label
id="bulk-edit-openai-compact-model-mapping-label"
class="input-label mb-0"
for="bulk-edit-openai-compact-model-mapping-enabled"
>
{{ t('admin.accounts.openai.compactModelMapping') }}
</label>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.openai.compactModelMappingDesc') }}
</p>
</div>
<input
v-model="enableOpenAICompactModelMapping"
id="bulk-edit-openai-compact-model-mapping-enabled"
type="checkbox"
aria-controls="bulk-edit-openai-compact-model-mapping"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
</div>
<div
id="bulk-edit-openai-compact-model-mapping"
:class="!enableOpenAICompactModelMapping && 'pointer-events-none opacity-50'"
>
<div v-if="openAICompactModelMappings.length > 0" class="mb-3 space-y-2">
<div
v-for="(mapping, index) in openAICompactModelMappings"
:key="index"
class="flex items-center gap-2"
>
<input
v-model="mapping.from"
type="text"
class="input flex-1"
:placeholder="t('admin.accounts.fromModel')"
data-testid="bulk-edit-openai-compact-model-mapping-input"
/>
<span class="text-gray-400">→</span>
<input
v-model="mapping.to"
type="text"
class="input flex-1"
:placeholder="t('admin.accounts.toModel')"
data-testid="bulk-edit-openai-compact-model-mapping-input"
/>
<button
type="button"
class="rounded-lg p-2 text-red-500 transition-colors hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20"
@click="removeOpenAICompactModelMapping(index)"
>
<Icon name="trash" size="sm" />
</button>
</div>
</div>
<button
type="button"
class="mb-3 w-full rounded-lg border-2 border-dashed border-gray-300 px-4 py-2 text-gray-600 transition-colors hover:border-gray-400 hover:text-gray-700 dark:border-dark-500 dark:text-gray-400 dark:hover:border-dark-400 dark:hover:text-gray-300"
data-testid="bulk-edit-openai-compact-model-mapping-add"
@click="addOpenAICompactModelMapping"
>
+ {{ t('admin.accounts.addMapping') }}
</button>
</div>
</div>
<!-- RPM Limit (shown only when all selected accounts are Anthropic OAuth/SetupToken) -->
<div v-if="allAnthropicOAuthOrSetupToken" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="mb-3 flex items-center justify-between">
@@ -989,7 +1093,7 @@ import { ref, watch, computed } from 'vue'
import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app'
import { adminAPI } from '@/api/admin'
import type { Proxy as ProxyConfig, AdminGroup, AccountPlatform, AccountType, OpenAICompactMode } from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
import Select from '@/components/common/Select.vue'
@@ -1115,6 +1219,8 @@ const enableOpenAIPassthrough = ref(false)
const enableOpenAIWSMode = ref(false)
const enableOpenAIAPIKeyWSMode = ref(false)
const enableCodexCLIOnly = ref(false)
const enableOpenAICompactMode = ref(false)
const enableOpenAICompactModelMapping = ref(false)
const enableRpmLimit = ref(false)
// State - field values
@@ -1140,6 +1246,8 @@ const openaiPassthroughEnabled = ref(false)
const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
const openaiAPIKeyResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
const codexCLIOnlyEnabled = ref(false)
const openAICompactMode = ref<OpenAICompactMode>('auto')
const openAICompactModelMappings = ref<ModelMapping[]>([])
const rpmLimitEnabled = ref(false)
const bulkBaseRpm = ref<number | null>(null)
const bulkRpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered')
@@ -1178,6 +1286,11 @@ const openAIWSModeOptions = computed(() => [
{ value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') },
{ value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') }
])
const openAICompactModeOptions = computed(() => [
{ value: 'auto', label: t('admin.accounts.openai.compactModeAuto') },
{ value: 'force_on', label: t('admin.accounts.openai.compactModeForceOn') },
{ value: 'force_off', label: t('admin.accounts.openai.compactModeForceOff') }
])
const openAIWSModeConcurrencyHintKey = computed(() =>
resolveOpenAIWSModeConcurrencyHintKey(openaiOAuthResponsesWebSocketV2Mode.value)
)
@@ -1194,6 +1307,14 @@ const removeModelMapping = (index: number) => {
modelMappings.value.splice(index, 1)
}
const addOpenAICompactModelMapping = () => {
openAICompactModelMappings.value.push({ from: '', to: '' })
}
const removeOpenAICompactModelMapping = (index: number) => {
openAICompactModelMappings.value.splice(index, 1)
}
const addPresetMapping = (from: string, to: string) => {
const exists = modelMappings.value.some((m) => m.from === from)
if (exists) {
@@ -1262,6 +1383,10 @@ const buildModelMappingObject = (): Record<string, string> | null => {
)
}
const buildOpenAICompactModelMapping = (): Record<string, string> | null => {
return buildModelMappingPayload('mapping', [], openAICompactModelMappings.value)
}
const buildUpdatePayload = (): Record<string, unknown> | null => {
const updates: Record<string, unknown> = {}
const credentials: Record<string, unknown> = {}
@@ -1350,10 +1475,6 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
credentialsChanged = true
}
if (credentialsChanged) {
updates.credentials = credentials
}
if (enableOpenAIWSMode.value) {
const extra = ensureExtra()
extra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value
@@ -1375,6 +1496,16 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
extra.codex_cli_only = codexCLIOnlyEnabled.value
}
if (enableOpenAICompactMode.value) {
const extra = ensureExtra()
extra.openai_compact_mode = openAICompactMode.value
}
if (enableOpenAICompactModelMapping.value) {
credentials.compact_model_mapping = buildOpenAICompactModelMapping() ?? {}
credentialsChanged = true
}
// RPM limit settings (written to the extra field)
if (enableRpmLimit.value) {
const extra = ensureExtra()
@@ -1402,6 +1533,10 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
umqExtra.user_msg_queue_enabled = false // clean up the legacy field (JSONB merge)
}
if (credentialsChanged) {
updates.credentials = credentials
}
return Object.keys(updates).length > 0 ? updates : null
}
@@ -1467,6 +1602,8 @@ const handleSubmit = async () => {
enableOpenAIWSMode.value ||
enableOpenAIAPIKeyWSMode.value ||
enableCodexCLIOnly.value ||
enableOpenAICompactMode.value ||
enableOpenAICompactModelMapping.value ||
enableRpmLimit.value ||
userMsgQueueMode.value !== null
@@ -1567,6 +1704,8 @@ watch(
enableOpenAIWSMode.value = false
enableOpenAIAPIKeyWSMode.value = false
enableCodexCLIOnly.value = false
enableOpenAICompactMode.value = false
enableOpenAICompactModelMapping.value = false
enableRpmLimit.value = false
// Reset all values
@@ -1588,6 +1727,8 @@ watch(
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
codexCLIOnlyEnabled.value = false
openAICompactMode.value = 'auto'
openAICompactModelMappings.value = []
rpmLimitEnabled.value = false
bulkBaseRpm.value = null
bulkRpmStrategy.value = 'tiered'


@@ -217,6 +217,44 @@ describe('BulkEditAccountModal', () => {
})
})
it('filtered OpenAI account bulk edit should submit Compact mode and dedicated model mappings', async () => {
const wrapper = mountModal({
accountIds: [],
selectedPlatforms: [],
selectedTypes: [],
target: {
mode: 'filtered',
filters: { platform: 'openai' },
previewCount: 12,
selectedPlatforms: ['openai'],
selectedTypes: ['oauth', 'apikey']
}
})
await wrapper.get('#bulk-edit-openai-compact-mode-enabled').setValue(true)
await wrapper.get('[data-testid="bulk-edit-openai-compact-mode-select"]').setValue('force_on')
await wrapper.get('#bulk-edit-openai-compact-model-mapping-enabled').setValue(true)
await wrapper.get('[data-testid="bulk-edit-openai-compact-model-mapping-add"]').trigger('click')
const inputs = wrapper.findAll('[data-testid="bulk-edit-openai-compact-model-mapping-input"]')
await inputs[0].setValue('gpt-5.4')
await inputs[1].setValue('gpt-5.4-openai-compact')
await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent')
await flushPromises()
expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1)
expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith({
filters: { platform: 'openai' },
extra: {
openai_compact_mode: 'force_on'
},
credentials: {
compact_model_mapping: {
'gpt-5.4': 'gpt-5.4-openai-compact'
}
}
})
})
it('OpenAI account bulk edit can disable automatic passthrough', async () => {
const wrapper = mountModal({
selectedPlatforms: ['openai'],


@@ -196,6 +196,7 @@ const totalPages = computed(() => Math.ceil(total.value / pageSize) || 1)
const typeOptions = computed(() => [
{ value: '', label: t('admin.users.allTypes') },
{ value: 'balance', label: t('admin.users.typeBalance') },
{ value: 'affiliate_balance', label: t('admin.users.typeAffiliateBalance') },
{ value: 'admin_balance', label: t('admin.users.typeAdminBalance') },
{ value: 'concurrency', label: t('admin.users.typeConcurrency') },
{ value: 'admin_concurrency', label: t('admin.users.typeAdminConcurrency') },
@@ -235,7 +236,7 @@ const loadHistory = async (page: number) => {
const isAdminType = (type: string) => type === 'admin_balance' || type === 'admin_concurrency'
// Helper: check if balance type (includes admin_balance)
const isBalanceType = (type: string) => type === 'balance' || type === 'admin_balance' || type === 'affiliate_balance'
// Helper: check if subscription type
const isSubscriptionType = (type: string) => type === 'subscription'
@@ -291,6 +292,8 @@ const getItemTitle = (item: BalanceHistoryItem) => {
switch (item.type) {
case 'balance':
return t('redeem.balanceAddedRedeem')
case 'affiliate_balance':
return t('redeem.balanceAddedAffiliate')
case 'admin_balance':
return item.value >= 0 ? t('redeem.balanceAddedAdmin') : t('redeem.balanceDeductedAdmin')
case 'concurrency':


@@ -721,6 +721,19 @@ const adminNavItems = computed((): NavItem[] => {
{ path: '/admin/proxies', label: t('nav.proxies'), icon: ServerIcon },
{ path: '/admin/redeem', label: t('nav.redeemCodes'), icon: TicketIcon, hideInSimpleMode: true },
{ path: '/admin/promo-codes', label: t('nav.promoCodes'), icon: GiftIcon, hideInSimpleMode: true },
{
path: '/admin/affiliates',
label: t('nav.affiliateManagement'),
icon: UsersIcon,
hideInSimpleMode: true,
expandOnly: true,
featureFlag: flagAffiliate,
children: [
{ path: '/admin/affiliates/invites', label: t('nav.affiliateInviteRecords'), icon: UsersIcon },
{ path: '/admin/affiliates/rebates', label: t('nav.affiliateRebateRecords'), icon: OrderIcon },
{ path: '/admin/affiliates/transfers', label: t('nav.affiliateTransferRecords'), icon: CreditCardIcon },
],
},
{
path: '/admin/orders',
label: t('nav.orderManagement'),


@@ -347,6 +347,10 @@ export default {
usage: 'Usage',
redeem: 'Redeem',
affiliate: 'Affiliate Rebates',
affiliateManagement: 'Affiliate Rebates',
affiliateInviteRecords: 'Invite Records',
affiliateRebateRecords: 'Rebate Records',
affiliateTransferRecords: 'Transfer Records',
profile: 'Profile',
users: 'Users',
groups: 'Groups',
@@ -1046,6 +1050,7 @@ export default {
recentActivity: 'Recent Activity',
historyWillAppear: 'Your redemption history will appear here',
balanceAddedRedeem: 'Balance Added (Redeem)',
balanceAddedAffiliate: 'Balance Added (Affiliate Transfer)',
balanceAddedAdmin: 'Balance Added (Admin)',
balanceDeductedAdmin: 'Balance Deducted (Admin)',
concurrencyAddedRedeem: 'Concurrency Added (Redeem)',
@@ -1635,6 +1640,49 @@ export default {
}
},
affiliates: {
invitesDescription: 'View site-wide inviter and invitee relationships',
rebatesDescription: 'View recharge orders that generated affiliate rebates',
transfersDescription: 'View affiliate quota transfers into account balance',
errors: {
loadFailed: 'Failed to load affiliate records'
},
records: {
search: 'Search',
searchPlaceholder: 'Email, username, user ID, or order number',
startAt: 'Start date',
endAt: 'End date',
inviter: 'Inviter',
invitee: 'Invitee',
user: 'User',
affCode: 'Invite Code',
order: 'Order',
totalRebate: 'Total Rebate',
orderAmount: 'Top-up Amount',
payAmount: 'Paid Amount',
rebateAmount: 'Rebate Amount',
paymentType: 'Payment Method',
orderStatus: 'Order Status',
transferAmount: 'Transfer Amount',
balanceAfter: 'Balance After',
availableQuotaAfter: 'Available After',
frozenQuotaAfter: 'Frozen After',
historyQuotaAfter: 'Historical Rebate After',
invitedAt: 'Invited At',
rebatedAt: 'Rebated At',
transferredAt: 'Transferred At'
},
overview: {
title: 'Affiliate User Overview',
affCode: 'Invite Code',
rebateRate: 'Rebate Rate',
invitedCount: 'Invited Users',
rebatedInviteeCount: 'Rebated Invitees',
availableQuota: 'Available Quota',
historyQuota: 'Historical Rebate'
}
},
// Users
users: {
title: 'User Management',
@@ -1787,6 +1835,7 @@ export default {
noBalanceHistory: 'No records found for this user',
allTypes: 'All Types',
typeBalance: 'Balance (Redeem)',
typeAffiliateBalance: 'Balance (Affiliate Transfer)',
typeAdminBalance: 'Balance (Admin)',
typeConcurrency: 'Concurrency (Redeem)',
typeAdminConcurrency: 'Concurrency (Admin)',


@@ -347,6 +347,10 @@ export default {
usage: '使用记录',
redeem: '兑换',
affiliate: '邀请返利',
affiliateManagement: '邀请返利',
affiliateInviteRecords: '邀请记录',
affiliateRebateRecords: '返利记录',
affiliateTransferRecords: '提取记录',
profile: '个人资料',
users: '用户管理',
groups: '分组管理',
@@ -1050,6 +1054,7 @@ export default {
recentActivity: '最近活动',
historyWillAppear: '您的兑换历史将显示在这里',
balanceAddedRedeem: '余额充值(兑换)',
balanceAddedAffiliate: '余额充值(返利转入)',
balanceAddedAdmin: '余额充值(管理员)',
balanceDeductedAdmin: '余额扣除(管理员)',
concurrencyAddedRedeem: '并发增加(兑换)',
@@ -1656,6 +1661,49 @@ export default {
}
},
affiliates: {
invitesDescription: '查看全站邀请关系和被邀请用户累计返利',
rebatesDescription: '查看每一笔产生返利的充值订单',
transfersDescription: '查看返利额度转入账户余额的提取流水',
errors: {
loadFailed: '加载邀请返利记录失败'
},
records: {
search: '搜索',
searchPlaceholder: '邮箱、用户名、用户 ID、订单号',
startAt: '开始日期',
endAt: '结束日期',
inviter: '邀请人',
invitee: '被邀请人',
user: '用户',
affCode: '邀请码',
order: '订单',
totalRebate: '累计返利',
orderAmount: '充值金额',
payAmount: '支付金额',
rebateAmount: '返利金额',
paymentType: '支付方式',
orderStatus: '订单状态',
transferAmount: '提取金额',
balanceAfter: '提取后余额',
availableQuotaAfter: '提取后可提',
frozenQuotaAfter: '提取后冻结',
historyQuotaAfter: '提取后历史返利',
invitedAt: '邀请时间',
rebatedAt: '返利时间',
transferredAt: '提取时间'
},
overview: {
title: '用户返利概览',
affCode: '邀请码',
rebateRate: '返利比例',
invitedCount: '邀请人数',
rebatedInviteeCount: '已产生返利人数',
availableQuota: '可提余额',
historyQuota: '历史返利'
}
},
// Users Management
users: {
title: '用户管理',
@@ -1844,6 +1892,7 @@ export default {
noBalanceHistory: '暂无变动记录',
allTypes: '全部类型',
typeBalance: '余额(兑换码)',
typeAffiliateBalance: '余额(返利转入)',
typeAdminBalance: '余额(管理员调整)',
typeConcurrency: '并发(兑换码)',
typeAdminConcurrency: '并发(管理员调整)',

@@ -517,6 +517,46 @@ const routes: RouteRecordRaw[] = [
descriptionKey: 'admin.usage.description'
}
},
{
path: '/admin/affiliates',
redirect: '/admin/affiliates/invites'
},
{
path: '/admin/affiliates/invites',
name: 'AdminAffiliateInvites',
component: () => import('@/views/admin/affiliates/AdminAffiliateInvitesView.vue'),
meta: {
requiresAuth: true,
requiresAdmin: true,
title: 'Affiliate Invite Records',
titleKey: 'nav.affiliateInviteRecords',
descriptionKey: 'admin.affiliates.invitesDescription'
}
},
{
path: '/admin/affiliates/rebates',
name: 'AdminAffiliateRebates',
component: () => import('@/views/admin/affiliates/AdminAffiliateRebatesView.vue'),
meta: {
requiresAuth: true,
requiresAdmin: true,
title: 'Affiliate Rebate Records',
titleKey: 'nav.affiliateRebateRecords',
descriptionKey: 'admin.affiliates.rebatesDescription'
}
},
{
path: '/admin/affiliates/transfers',
name: 'AdminAffiliateTransfers',
component: () => import('@/views/admin/affiliates/AdminAffiliateTransfersView.vue'),
meta: {
requiresAuth: true,
requiresAdmin: true,
title: 'Affiliate Transfer Records',
titleKey: 'nav.affiliateTransferRecords',
descriptionKey: 'admin.affiliates.transfersDescription'
}
},
// ==================== Payment Admin Routes ====================

@@ -0,0 +1,7 @@
<template>
<AdminAffiliateRecordsTable type="invites" />
</template>
<script setup lang="ts">
import AdminAffiliateRecordsTable from './AdminAffiliateRecordsTable.vue'
</script>

@@ -0,0 +1,7 @@
<template>
<AdminAffiliateRecordsTable type="rebates" />
</template>
<script setup lang="ts">
import AdminAffiliateRecordsTable from './AdminAffiliateRecordsTable.vue'
</script>

@@ -0,0 +1,407 @@
<template>
<AppLayout>
<TablePageLayout>
<template #filters>
<div class="flex flex-wrap items-center gap-3">
<div class="relative w-full md:w-80">
<Icon name="search" size="md" class="absolute left-3 top-1/2 -translate-y-1/2 text-gray-400" />
<input v-model="filters.search" type="text" class="input pl-10" :placeholder="t('admin.affiliates.records.searchPlaceholder')" @input="debounceLoad" />
</div>
<input v-model="filters.start_at" type="date" class="input w-full sm:w-44" :title="t('admin.affiliates.records.startAt')" @change="reloadFromFirstPage" />
<input v-model="filters.end_at" type="date" class="input w-full sm:w-44" :title="t('admin.affiliates.records.endAt')" @change="reloadFromFirstPage" />
<button class="btn btn-secondary px-2 md:px-3" :disabled="loading" :title="t('common.refresh')" @click="loadRecords">
<Icon name="refresh" size="md" :class="loading ? 'animate-spin' : ''" />
</button>
</div>
</template>
<template #table>
<DataTable
:columns="columns"
:data="records"
:loading="loading"
:server-side-sort="true"
default-sort-key="created_at"
default-sort-order="desc"
:sort-storage-key="sortStorageKey"
@sort="handleSort"
>
<template #cell-inviter="{ row }">
<UserCell
:id="row.inviter_id"
:email="row.inviter_email"
:username="row.inviter_username"
:clickable="props.type !== 'transfers'"
@open="openUserOverview"
/>
</template>
<template #cell-invitee="{ row }">
<UserCell
:id="row.invitee_id"
:email="row.invitee_email"
:username="row.invitee_username"
:clickable="props.type !== 'transfers'"
@open="openUserOverview"
/>
</template>
<template #cell-user="{ row }">
<UserCell
:id="row.user_id"
:email="row.user_email"
:username="row.username"
:clickable="true"
@open="openUserOverview"
/>
</template>
<template #cell-aff_code="{ row }">
<span class="font-mono text-sm text-gray-700 dark:text-gray-300">{{ row.aff_code || '-' }}</span>
</template>
<template #cell-order="{ row }">
<div class="space-y-0.5">
<div class="font-mono text-sm text-gray-900 dark:text-white">#{{ row.order_id }}</div>
<div class="max-w-56 truncate text-sm text-gray-500 dark:text-dark-400">{{ row.out_trade_no }}</div>
</div>
</template>
<template #cell-payment_type="{ row }">
{{ t('payment.methods.' + row.payment_type, row.payment_type || '-') }}
</template>
<template #cell-order_status="{ row }">
<OrderStatusBadge :status="row.order_status" />
</template>
<template #cell-total_rebate="{ row }">
<AmountText :value="row.total_rebate" />
</template>
<template #cell-order_amount="{ row }">
<AmountText :value="row.order_amount" />
</template>
<template #cell-pay_amount="{ row }">
<span class="text-sm text-gray-900 dark:text-white">¥{{ formatAmount(row.pay_amount) }}</span>
</template>
<template #cell-rebate_amount="{ row }">
<AmountText :value="row.rebate_amount" strong />
</template>
<template #cell-amount="{ row }">
<AmountText :value="row.amount" strong />
</template>
<template #cell-balance_after="{ row }">
<NullableAmountText :value="row.balance_after" />
</template>
<template #cell-available_quota_after="{ row }">
<NullableAmountText :value="row.available_quota_after" />
</template>
<template #cell-frozen_quota_after="{ row }">
<NullableAmountText :value="row.frozen_quota_after" />
</template>
<template #cell-history_quota_after="{ row }">
<NullableAmountText :value="row.history_quota_after" />
</template>
<template #cell-created_at="{ row }">
<span class="text-sm text-gray-700 dark:text-gray-300">{{ formatDateTime(row.created_at) }}</span>
</template>
</DataTable>
</template>
<template #pagination>
<Pagination
v-if="pagination.total > 0"
:page="pagination.page"
:total="pagination.total"
:page-size="pagination.page_size"
@update:page="handlePageChange"
@update:pageSize="handlePageSizeChange"
/>
</template>
</TablePageLayout>
<BaseDialog
:show="overviewDialog"
:title="t('admin.affiliates.overview.title')"
width="normal"
@close="overviewDialog = false"
>
<div v-if="overviewLoading" class="flex justify-center py-8">
<div class="h-6 w-6 animate-spin rounded-full border-2 border-primary-500 border-t-transparent"></div>
</div>
<div v-else-if="selectedOverview" class="space-y-4">
<div class="rounded-lg border border-gray-100 bg-gray-50 p-4 dark:border-dark-700 dark:bg-dark-800">
<div class="font-mono text-sm text-gray-900 dark:text-white">#{{ selectedOverview.user_id }}</div>
<div class="mt-1 text-sm font-medium text-gray-900 dark:text-white">{{ selectedOverview.email || '-' }}</div>
<div class="mt-0.5 text-sm text-gray-500 dark:text-dark-400">{{ selectedOverview.username || '-' }}</div>
</div>
<div class="grid gap-3 sm:grid-cols-2">
<OverviewStat :label="t('admin.affiliates.overview.affCode')" :value="selectedOverview.aff_code || '-'" mono />
<OverviewStat :label="t('admin.affiliates.overview.rebateRate')" :value="formatPercent(selectedOverview.rebate_rate_percent)" />
<OverviewStat :label="t('admin.affiliates.overview.invitedCount')" :value="String(selectedOverview.invited_count)" />
<OverviewStat :label="t('admin.affiliates.overview.rebatedInviteeCount')" :value="String(selectedOverview.rebated_invitee_count)" />
<OverviewStat :label="t('admin.affiliates.overview.availableQuota')" :value="'$' + formatAmount(selectedOverview.available_quota)" />
<OverviewStat :label="t('admin.affiliates.overview.historyQuota')" :value="'$' + formatAmount(selectedOverview.history_quota)" />
</div>
</div>
</BaseDialog>
</AppLayout>
</template>
<script setup lang="ts">
import { computed, defineComponent, h, onMounted, reactive, ref, type PropType } from 'vue'
import { useI18n } from 'vue-i18n'
import AppLayout from '@/components/layout/AppLayout.vue'
import TablePageLayout from '@/components/layout/TablePageLayout.vue'
import DataTable from '@/components/common/DataTable.vue'
import Pagination from '@/components/common/Pagination.vue'
import BaseDialog from '@/components/common/BaseDialog.vue'
import Icon from '@/components/icons/Icon.vue'
import OrderStatusBadge from '@/components/payment/OrderStatusBadge.vue'
import type { Column } from '@/components/common/types'
import { useAppStore } from '@/stores/app'
import { affiliatesAPI, type AffiliateInviteRecord, type AffiliateRebateRecord, type AffiliateTransferRecord, type AffiliateUserOverview, type ListAffiliateRecordsParams } from '@/api/admin/affiliates'
import type { PaginatedResponse } from '@/types'
import { extractI18nErrorMessage } from '@/utils/apiError'
import { formatDateTime as formatDisplayDateTime } from '@/utils/format'
type RecordType = 'invites' | 'rebates' | 'transfers'
type AffiliateRecord = AffiliateInviteRecord | AffiliateRebateRecord | AffiliateTransferRecord
const props = defineProps<{
type: RecordType
}>()
const { t } = useI18n()
const appStore = useAppStore()
const loading = ref(false)
const records = ref<AffiliateRecord[]>([])
const filters = reactive({ search: '', start_at: '', end_at: '' })
const pagination = reactive({ page: 1, page_size: 20, total: 0 })
const overviewDialog = ref(false)
const overviewLoading = ref(false)
const selectedOverview = ref<AffiliateUserOverview | null>(null)
let debounceTimer: ReturnType<typeof setTimeout> | null = null
const columns = computed<Column[]>(() => {
if (props.type === 'invites') {
return [
{ key: 'inviter', label: t('admin.affiliates.records.inviter'), sortable: true },
{ key: 'invitee', label: t('admin.affiliates.records.invitee'), sortable: true },
{ key: 'aff_code', label: t('admin.affiliates.records.affCode'), sortable: true },
{ key: 'total_rebate', label: t('admin.affiliates.records.totalRebate'), sortable: true },
{ key: 'created_at', label: t('admin.affiliates.records.invitedAt'), sortable: true },
]
}
if (props.type === 'rebates') {
return [
{ key: 'order', label: t('admin.affiliates.records.order'), sortable: true },
{ key: 'inviter', label: t('admin.affiliates.records.inviter'), sortable: true },
{ key: 'invitee', label: t('admin.affiliates.records.invitee'), sortable: true },
{ key: 'order_amount', label: t('admin.affiliates.records.orderAmount'), sortable: true },
{ key: 'pay_amount', label: t('admin.affiliates.records.payAmount'), sortable: true },
{ key: 'rebate_amount', label: t('admin.affiliates.records.rebateAmount') },
{ key: 'payment_type', label: t('admin.affiliates.records.paymentType'), sortable: true },
{ key: 'order_status', label: t('admin.affiliates.records.orderStatus'), sortable: true },
{ key: 'created_at', label: t('admin.affiliates.records.rebatedAt'), sortable: true },
]
}
return [
{ key: 'user', label: t('admin.affiliates.records.user'), sortable: true },
{ key: 'amount', label: t('admin.affiliates.records.transferAmount'), sortable: true },
{ key: 'balance_after', label: t('admin.affiliates.records.balanceAfter'), sortable: true },
{ key: 'available_quota_after', label: t('admin.affiliates.records.availableQuotaAfter'), sortable: true },
{ key: 'frozen_quota_after', label: t('admin.affiliates.records.frozenQuotaAfter'), sortable: true },
{ key: 'history_quota_after', label: t('admin.affiliates.records.historyQuotaAfter'), sortable: true },
{ key: 'created_at', label: t('admin.affiliates.records.transferredAt'), sortable: true },
]
})
const sortStorageKey = computed(() => `admin-affiliate-${props.type}-table-sort`)
function loadInitialSortState(): { sort_by: string; sort_order: 'asc' | 'desc' } {
const fallback = { sort_by: 'created_at', sort_order: 'desc' as 'asc' | 'desc' }
try {
const raw = localStorage.getItem(sortStorageKey.value)
if (!raw) return fallback
const parsed = JSON.parse(raw) as { key?: string; order?: string }
const key = typeof parsed.key === 'string' ? parsed.key : ''
if (!columns.value.some((column) => column.key === key && column.sortable)) return fallback
return {
sort_by: key,
sort_order: parsed.order === 'asc' ? 'asc' : 'desc',
}
} catch {
return fallback
}
}
const sortState = reactive(loadInitialSortState())
function userTimezone(): string {
try {
return Intl.DateTimeFormat().resolvedOptions().timeZone
} catch {
return 'UTC'
}
}
function buildParams(): ListAffiliateRecordsParams {
return {
page: pagination.page,
page_size: pagination.page_size,
search: filters.search.trim() || undefined,
start_at: filters.start_at || undefined,
end_at: filters.end_at || undefined,
sort_by: sortState.sort_by,
sort_order: sortState.sort_order,
timezone: userTimezone(),
}
}
async function fetchRecords(params: ListAffiliateRecordsParams): Promise<PaginatedResponse<AffiliateRecord>> {
if (props.type === 'invites') {
return affiliatesAPI.listInviteRecords(params)
}
if (props.type === 'rebates') {
return affiliatesAPI.listRebateRecords(params)
}
return affiliatesAPI.listTransferRecords(params)
}
async function loadRecords() {
loading.value = true
try {
const res = await fetchRecords(buildParams())
records.value = res.items || []
pagination.total = res.total || 0
} catch (error) {
appStore.showError(extractI18nErrorMessage(error, t, 'admin.affiliates.errors', t('common.error')))
} finally {
loading.value = false
}
}
function debounceLoad() {
if (debounceTimer) clearTimeout(debounceTimer)
debounceTimer = setTimeout(() => reloadFromFirstPage(), 300)
}
function reloadFromFirstPage() {
pagination.page = 1
void loadRecords()
}
function handlePageChange(page: number) {
pagination.page = page
void loadRecords()
}
function handlePageSizeChange(size: number) {
pagination.page_size = size
pagination.page = 1
void loadRecords()
}
function handleSort(key: string, order: 'asc' | 'desc') {
sortState.sort_by = key
sortState.sort_order = order
pagination.page = 1
void loadRecords()
}
function formatAmount(value: number | null | undefined): string {
return Number(value || 0).toFixed(2)
}
function formatPercent(value: number | null | undefined): string {
  // Round to 2 decimal places; template-literal interpolation already drops
  // trailing zeros from the number, so no integer special-casing is needed.
  const rounded = Math.round(Number(value || 0) * 100) / 100
  return `${rounded}%`
}
function formatDateTime(value: string | null | undefined): string {
return value ? formatDisplayDateTime(value) : '-'
}
async function openUserOverview(userId: number) {
if (!userId) return
overviewDialog.value = true
overviewLoading.value = true
selectedOverview.value = null
try {
selectedOverview.value = await affiliatesAPI.getUserOverview(userId)
} catch (error) {
overviewDialog.value = false
appStore.showError(extractI18nErrorMessage(error, t, 'admin.affiliates.errors', t('common.error')))
} finally {
overviewLoading.value = false
}
}
const UserCell = defineComponent({
props: {
id: { type: Number, required: true },
email: { type: String, default: '' },
username: { type: String, default: '' },
clickable: { type: Boolean, default: false },
},
emits: ['open'],
setup(cellProps, { emit }) {
return () => h('div', { class: 'space-y-0.5' }, [
h('div', { class: 'font-mono text-sm text-gray-900 dark:text-white' }, `#${cellProps.id}`),
h(cellProps.clickable ? 'button' : 'div', {
class: cellProps.clickable
? 'max-w-56 truncate text-left text-sm font-medium text-primary-600 hover:text-primary-700 hover:underline dark:text-primary-400 dark:hover:text-primary-300'
: 'max-w-56 truncate text-sm text-gray-700 dark:text-gray-300',
type: cellProps.clickable ? 'button' : undefined,
onClick: cellProps.clickable ? () => emit('open', cellProps.id) : undefined,
}, cellProps.email || '-'),
h('div', { class: 'max-w-56 truncate text-sm text-gray-500 dark:text-dark-400' }, cellProps.username || '-'),
])
},
})
const AmountText = defineComponent({
props: {
value: { type: Number, default: 0 },
strong: { type: Boolean, default: false },
},
setup(amountProps) {
return () => h('span', {
class: amountProps.strong
? 'text-sm font-semibold text-emerald-600 dark:text-emerald-400'
: 'text-sm text-gray-900 dark:text-white',
}, `$${formatAmount(amountProps.value)}`)
},
})
const NullableAmountText = defineComponent({
props: {
value: { type: Number as PropType<number | null | undefined>, default: null },
},
setup(amountProps) {
return () => {
const value = amountProps.value
if (value === null || value === undefined) {
return h('span', { class: 'text-sm text-gray-400 dark:text-dark-500' }, '-')
}
return h(AmountText, { value })
}
},
})
const OverviewStat = defineComponent({
props: {
label: { type: String, required: true },
value: { type: String, required: true },
mono: { type: Boolean, default: false },
},
setup(statProps) {
return () => h('div', { class: 'rounded-lg border border-gray-100 bg-white p-3 dark:border-dark-700 dark:bg-dark-900' }, [
h('div', { class: 'text-sm text-gray-500 dark:text-dark-400' }, statProps.label),
h('div', {
class: statProps.mono
? 'mt-1 font-mono text-base font-semibold text-gray-900 dark:text-white'
: 'mt-1 text-base font-semibold text-gray-900 dark:text-white',
}, statProps.value),
])
},
})
onMounted(() => {
void loadRecords()
})
</script>
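The component above persists the last-used sort in `localStorage` and validates it against the current column set before applying it. A standalone sketch of that restore logic (the `restoreSortState` name and its explicit parameters are ours, not from the component) which makes the fallback behavior easy to unit-test:

```typescript
type SortOrder = 'asc' | 'desc'
interface SortState { sort_by: string; sort_order: SortOrder }

// Pure variant of loadInitialSortState: takes the raw stored payload and the
// sortable column keys as inputs instead of reading localStorage and columns.
function restoreSortState(raw: string | null, sortableKeys: string[]): SortState {
  const fallback: SortState = { sort_by: 'created_at', sort_order: 'desc' }
  if (!raw) return fallback
  try {
    const parsed = JSON.parse(raw) as { key?: string; order?: string }
    const key = typeof parsed.key === 'string' ? parsed.key : ''
    // Reject keys that are not sortable columns of the current table type
    // (the invites/rebates/transfers tables expose different columns).
    if (!sortableKeys.includes(key)) return fallback
    return { sort_by: key, sort_order: parsed.order === 'asc' ? 'asc' : 'desc' }
  } catch {
    // Corrupt JSON in storage falls back to the default sort.
    return fallback
  }
}
```

Keeping the storage key per table type (`admin-affiliate-${type}-table-sort`) means a stale key from one table can never leak into another, but the column check above still guards against renamed columns in future versions.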

View File

@@ -0,0 +1,7 @@
<template>
<AdminAffiliateRecordsTable type="transfers" />
</template>
<script setup lang="ts">
import AdminAffiliateRecordsTable from './AdminAffiliateRecordsTable.vue'
</script>
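The render helpers in the shared table split amount display into two cases: regular amounts always format as `$X.XX` (with null coerced to 0 by `AmountText`'s default), while the transfer-only snapshot fields render `-` when absent (`NullableAmountText`). A minimal sketch of that rule as a single pure function (the `renderNullableAmount` name is ours, for illustration only):

```typescript
// Mirrors NullableAmountText: absent values render as a dash, present
// values as a dollar amount with two decimal places.
function renderNullableAmount(value: number | null | undefined): string {
  if (value === null || value === undefined) return '-'
  return `$${Number(value).toFixed(2)}`
}
```

The distinction matters for the transfers table: `balance_after` and the quota snapshots can legitimately be missing on older records, and showing `$0.00` there would misreport a real balance of zero.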