From a6764e82f252c8f1d9699256ac30f6774a46cb1a Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 19 Mar 2026 16:44:39 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=20OAuth/SetupToken=20?= =?UTF-8?q?=E8=BD=AC=E5=8F=91=E8=AF=B7=E6=B1=82=E4=BD=93=E9=87=8D=E6=8E=92?= =?UTF-8?q?=E5=B9=B6=E5=A2=9E=E5=8A=A0=E8=B0=83=E8=AF=95=E5=BC=80=E5=85=B3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../pkg/antigravity/request_transformer.go | 23 +- .../antigravity/request_transformer_test.go | 51 ++ ...teway_anthropic_apikey_passthrough_test.go | 77 ++ .../service/gateway_body_order_test.go | 72 ++ .../service/gateway_debug_env_test.go | 34 + backend/internal/service/gateway_service.go | 717 ++++++++++-------- backend/internal/service/identity_service.go | 74 +- .../service/identity_service_order_test.go | 82 ++ 8 files changed, 742 insertions(+), 388 deletions(-) create mode 100644 backend/internal/service/gateway_body_order_test.go create mode 100644 backend/internal/service/gateway_debug_env_test.go create mode 100644 backend/internal/service/identity_service_order_test.go diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index 55cdd786..1b45e507 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -275,21 +275,6 @@ func filterOpenCodePrompt(text string) string { return "" } -// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表 -var systemBlockFilterPrefixes = []string{ - "x-anthropic-billing-header", -} - -// filterSystemBlockByPrefix 如果文本匹配过滤前缀,返回空字符串 -func filterSystemBlockByPrefix(text string) string { - for _, prefix := range systemBlockFilterPrefixes { - if strings.HasPrefix(text, prefix) { - return "" - } - } - return text -} - // buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致) func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent { var parts []GeminiPart @@ -306,8 +291,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans if strings.Contains(sysStr, "You are Antigravity") { userHasAntigravityIdentity = true } - // 过滤 OpenCode 默认提示词和黑名单前缀 - filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(sysStr)) + // 过滤 OpenCode 默认提示词 + filtered := filterOpenCodePrompt(sysStr) if filtered != "" { userSystemParts = append(userSystemParts, GeminiPart{Text: filtered}) } @@ -321,8 +306,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans if strings.Contains(block.Text, "You are Antigravity") { userHasAntigravityIdentity = true } - // 过滤 OpenCode 默认提示词和黑名单前缀 - filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(block.Text)) + // 过滤 OpenCode 默认提示词 + filtered := filterOpenCodePrompt(block.Text) if filtered != "" { userSystemParts = append(userSystemParts, GeminiPart{Text: filtered}) } diff --git a/backend/internal/pkg/antigravity/request_transformer_test.go b/backend/internal/pkg/antigravity/request_transformer_test.go index f267e0e1..9e46295a 100644 --- a/backend/internal/pkg/antigravity/request_transformer_test.go +++ b/backend/internal/pkg/antigravity/request_transformer_test.go @@ -2,7 +2,10 @@ package antigravity import ( "encoding/json" + "strings" "testing" + + "github.com/stretchr/testify/require" ) // TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理 @@ -349,3 +352,51 @@ func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) { }) } } + +func TestTransformClaudeToGeminiWithOptions_PreservesBillingHeaderSystemBlock(t *testing.T) { + tests := []struct { + name string + system json.RawMessage + }{ + { + name: "system array", + system: json.RawMessage(`[{"type":"text","text":"x-anthropic-billing-header keep"}]`), + }, + { + name: "system string", + system: json.RawMessage(`"x-anthropic-billing-header keep"`), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + claudeReq := &ClaudeRequest{ + Model: "claude-3-5-sonnet-latest", + System: tt.system, + Messages: []ClaudeMessage{ + { + Role: "user", + Content: json.RawMessage(`[{"type":"text","text":"hello"}]`), + }, + }, + } + + body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "gemini-2.5-flash", DefaultTransformOptions()) + require.NoError(t, err) + + var req V1InternalRequest + require.NoError(t, json.Unmarshal(body, &req)) + require.NotNil(t, req.Request.SystemInstruction) + + found := false + for _, part := range req.Request.SystemInstruction.Parts { + if strings.Contains(part.Text, "x-anthropic-billing-header keep") { + found = true + break + } + } + + require.True(t, found, "转换后的 systemInstruction 应保留 x-anthropic-billing-header 内容") + }) + } +} diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go index c534a9b7..a01dd02a 100644 --- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -688,6 +688,83 @@ func TestGatewayService_AnthropicOAuth_NotAffectedByAPIKeyPassthroughToggle(t *t require.Contains(t, req.Header.Get("anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta") } +func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + body string + }{ + { + name: "system array", + body: `{"model":"claude-3-5-sonnet-latest","system":[{"type":"text","text":"x-anthropic-billing-header keep"}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`, + }, + { + name: "system string", + body: `{"model":"claude-3-5-sonnet-latest","system":"x-anthropic-billing-header keep","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + parsed, err := ParseGatewayRequest([]byte(tt.body), PlatformAnthropic) + require.NoError(t, err) + + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "x-request-id": []string{"rid-oauth-preserve"}, + }, + Body: io.NopCloser(strings.NewReader(`{"id":"msg_1","type":"message","role":"assistant","model":"claude-3-5-sonnet-20241022","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":12,"output_tokens":7}}`)), + }, + } + + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &GatewayService{ + cfg: cfg, + responseHeaderFilter: compileResponseHeaderFilter(cfg), + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + deferredService: &DeferredService{}, + } + + account := &Account{ + ID: 301, + Name: "anthropic-oauth-preserve", + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + }, + Status: StatusActive, + Schedulable: true, + } + + result, err := svc.Forward(context.Background(), c, account, parsed) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + require.Equal(t, "Bearer oauth-token", upstream.lastReq.Header.Get("authorization")) + require.Contains(t, upstream.lastReq.Header.Get("anthropic-beta"), claude.BetaOAuth) + + system := gjson.GetBytes(upstream.lastBody, "system") + require.True(t, system.Exists()) + require.Contains(t, system.Raw, "x-anthropic-billing-header keep") + }) + } +} + func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAfterClientDisconnect(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/gateway_body_order_test.go b/backend/internal/service/gateway_body_order_test.go new file mode 100644 index 00000000..641522f0 --- /dev/null +++ b/backend/internal/service/gateway_body_order_test.go @@ -0,0 +1,72 @@ +package service + +import ( + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/stretchr/testify/require" +) + +func assertJSONTokenOrder(t *testing.T, body string, tokens ...string) { + t.Helper() + + last := -1 + for _, token := range tokens { + pos := strings.Index(body, token) + require.NotEqualf(t, -1, pos, "missing token %s in body %s", token, body) + require.Greaterf(t, pos, last, "token %s should appear after previous tokens in body %s", token, body) + last = pos + } +} + +func TestReplaceModelInBody_PreservesTopLevelFieldOrder(t *testing.T) { + svc := &GatewayService{} + body := []byte(`{"alpha":1,"model":"claude-3-5-sonnet-latest","messages":[],"omega":2}`) + + result := svc.replaceModelInBody(body, "claude-3-5-sonnet-20241022") + resultStr := string(result) + + assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"messages"`, `"omega"`) + require.Contains(t, resultStr, `"model":"claude-3-5-sonnet-20241022"`) +} + +func TestNormalizeClaudeOAuthRequestBody_PreservesTopLevelFieldOrder(t *testing.T) { + body := []byte(`{"alpha":1,"model":"claude-3-5-sonnet-latest","temperature":0.2,"system":"You are OpenCode, the best coding agent on the planet.","messages":[],"tool_choice":{"type":"auto"},"omega":2}`) + + result, modelID := normalizeClaudeOAuthRequestBody(body, "claude-3-5-sonnet-latest", claudeOAuthNormalizeOptions{ + injectMetadata: true, + metadataUserID: "user-1", + }) + resultStr := string(result) + + require.Equal(t, claude.NormalizeModelID("claude-3-5-sonnet-latest"), modelID) + assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"system"`, `"messages"`, `"omega"`, `"tools"`, `"metadata"`) + require.NotContains(t, resultStr, `"temperature"`) + require.NotContains(t, resultStr, `"tool_choice"`) + require.Contains(t, resultStr, `"system":"`+claudeCodeSystemPrompt+`"`) + require.Contains(t, resultStr, `"tools":[]`) + require.Contains(t, resultStr, `"metadata":{"user_id":"user-1"}`) +} + +func TestInjectClaudeCodePrompt_PreservesFieldOrder(t *testing.T) { + body := []byte(`{"alpha":1,"system":[{"id":"block-1","type":"text","text":"Custom"}],"messages":[],"omega":2}`) + + result := injectClaudeCodePrompt(body, []any{ + map[string]any{"id": "block-1", "type": "text", "text": "Custom"}, + }) + resultStr := string(result) + + assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`) + require.Contains(t, resultStr, `{"id":"block-1","type":"text","text":"`+claudeCodeSystemPrompt+`\n\nCustom"}`) +} + +func TestEnforceCacheControlLimit_PreservesTopLevelFieldOrder(t *testing.T) { + body := []byte(`{"alpha":1,"system":[{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"s2","cache_control":{"type":"ephemeral"}}],"messages":[{"role":"user","content":[{"type":"text","text":"m1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"m2","cache_control":{"type":"ephemeral"}},{"type":"text","text":"m3","cache_control":{"type":"ephemeral"}}]}],"omega":2}`) + + result := enforceCacheControlLimit(body) + resultStr := string(result) + + assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`) + require.Equal(t, 4, strings.Count(resultStr, `"cache_control"`)) +} diff --git a/backend/internal/service/gateway_debug_env_test.go b/backend/internal/service/gateway_debug_env_test.go new file mode 100644 index 00000000..4f48dc70 --- /dev/null +++ b/backend/internal/service/gateway_debug_env_test.go @@ -0,0 +1,34 @@ +package service + +import "testing" + +func TestDebugGatewayBodyLoggingEnabled(t *testing.T) { + t.Run("default disabled", func(t *testing.T) { + t.Setenv(debugGatewayBodyEnv, "") + if debugGatewayBodyLoggingEnabled() { + t.Fatalf("expected debug gateway body logging to be disabled by default") + } + }) + + t.Run("enabled with true-like values", func(t *testing.T) { + for _, value := range []string{"1", "true", "TRUE", "yes", "on"} { + t.Run(value, func(t *testing.T) { + t.Setenv(debugGatewayBodyEnv, value) + if !debugGatewayBodyLoggingEnabled() { + t.Fatalf("expected debug gateway body logging to be enabled for %q", value) + } + }) + } + }) + + t.Run("disabled with other values", func(t *testing.T) { + for _, value := range []string{"0", "false", "off", "debug"} { + t.Run(value, func(t *testing.T) { + t.Setenv(debugGatewayBodyEnv, value) + if debugGatewayBodyLoggingEnabled() { + t.Fatalf("expected debug gateway body logging to be disabled for %q", value) + } + }) + } + }) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 7e962f7f..e23d24de 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -51,6 +51,7 @@ const ( defaultUserGroupRateCacheTTL = 30 * time.Second defaultModelsListCacheTTL = 15 * time.Second postUsageBillingTimeout = 15 * time.Second + debugGatewayBodyEnv = "SUB2API_DEBUG_GATEWAY_BODY" ) const ( @@ -339,12 +340,6 @@ var ( } ) -// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表 -// OAuth/SetupToken 账号转发时,匹配这些前缀的 system 元素会被移除 -var systemBlockFilterPrefixes = []string{ - "x-anthropic-billing-header", -} - // ErrNoAvailableAccounts 表示没有可用的账号 var ErrNoAvailableAccounts = errors.New("no available accounts") @@ -840,20 +835,30 @@ func (s *GatewayService) hashContent(content string) string { return strconv.FormatUint(h, 36) } +type anthropicCacheControlPayload struct { + Type string `json:"type"` +} + +type anthropicSystemTextBlockPayload struct { + Type string `json:"type"` + Text string `json:"text"` + CacheControl *anthropicCacheControlPayload `json:"cache_control,omitempty"` +} + +type anthropicMetadataPayload struct { + UserID string `json:"user_id"` +} + // replaceModelInBody 替换请求体中的model字段 -// 使用 json.RawMessage 保留其他字段的原始字节,避免 thinking 块等内容被修改 +// 优先使用定点修改,尽量保持客户端原始字段顺序。 func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte { - var req map[string]json.RawMessage - if err := json.Unmarshal(body, &req); err != nil { + if len(body) == 0 { return body } - // 只序列化 model 字段 - modelBytes, err := json.Marshal(newModel) - if err != nil { + if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel { return body } - req["model"] = modelBytes - newBody, err := json.Marshal(req) + newBody, err := sjson.SetBytes(body, "model", newModel) if err != nil { return body } @@ -884,24 +889,146 @@ func sanitizeSystemText(text string) string { return text } -func stripCacheControlFromSystemBlocks(system any) bool { - blocks, ok := system.([]any) - if !ok { - return false +func marshalAnthropicSystemTextBlock(text string, includeCacheControl bool) ([]byte, error) { + block := anthropicSystemTextBlockPayload{ + Type: "text", + Text: text, } - changed := false - for _, item := range blocks { - block, ok := item.(map[string]any) - if !ok { - continue - } - if _, exists := block["cache_control"]; !exists { - continue - } - delete(block, "cache_control") - changed = true + if includeCacheControl { + block.CacheControl = &anthropicCacheControlPayload{Type: "ephemeral"} } - return changed + return json.Marshal(block) +} + +func marshalAnthropicMetadata(userID string) ([]byte, error) { + return json.Marshal(anthropicMetadataPayload{UserID: userID}) +} + +func buildJSONArrayRaw(items [][]byte) []byte { + if len(items) == 0 { + return []byte("[]") + } + + total := 2 + for _, item := range items { + total += len(item) + } + total += len(items) - 1 + + buf := make([]byte, 0, total) + buf = append(buf, '[') + for i, item := range items { + if i > 0 { + buf = append(buf, ',') + } + buf = append(buf, item...) + } + buf = append(buf, ']') + return buf +} + +func setJSONValueBytes(body []byte, path string, value any) ([]byte, bool) { + next, err := sjson.SetBytes(body, path, value) + if err != nil { + return body, false + } + return next, true +} + +func setJSONRawBytes(body []byte, path string, raw []byte) ([]byte, bool) { + next, err := sjson.SetRawBytes(body, path, raw) + if err != nil { + return body, false + } + return next, true +} + +func deleteJSONPathBytes(body []byte, path string) ([]byte, bool) { + next, err := sjson.DeleteBytes(body, path) + if err != nil { + return body, false + } + return next, true +} + +func normalizeClaudeOAuthSystemBody(body []byte, opts claudeOAuthNormalizeOptions) ([]byte, bool) { + sys := gjson.GetBytes(body, "system") + if !sys.Exists() { + return body, false + } + + out := body + modified := false + + switch { + case sys.Type == gjson.String: + sanitized := sanitizeSystemText(sys.String()) + if sanitized != sys.String() { + if next, ok := setJSONValueBytes(out, "system", sanitized); ok { + out = next + modified = true + } + } + case sys.IsArray(): + index := 0 + sys.ForEach(func(_, item gjson.Result) bool { + if item.Get("type").String() == "text" { + textResult := item.Get("text") + if textResult.Exists() && textResult.Type == gjson.String { + text := textResult.String() + sanitized := sanitizeSystemText(text) + if sanitized != text { + if next, ok := setJSONValueBytes(out, fmt.Sprintf("system.%d.text", index), sanitized); ok { + out = next + modified = true + } + } + } + } + + if opts.stripSystemCacheControl && item.Get("cache_control").Exists() { + if next, ok := deleteJSONPathBytes(out, fmt.Sprintf("system.%d.cache_control", index)); ok { + out = next + modified = true + } + } + + index++ + return true + }) + } + + return out, modified +} + +func ensureClaudeOAuthMetadataUserID(body []byte, userID string) ([]byte, bool) { + if strings.TrimSpace(userID) == "" { + return body, false + } + + metadata := gjson.GetBytes(body, "metadata") + if !metadata.Exists() || metadata.Type == gjson.Null { + raw, err := marshalAnthropicMetadata(userID) + if err != nil { + return body, false + } + return setJSONRawBytes(body, "metadata", raw) + } + + trimmedRaw := strings.TrimSpace(metadata.Raw) + if strings.HasPrefix(trimmedRaw, "{") { + existing := metadata.Get("user_id") + if existing.Exists() && existing.Type == gjson.String && existing.String() != "" { + return body, false + } + return setJSONValueBytes(body, "metadata.user_id", userID) + } + + raw, err := marshalAnthropicMetadata(userID) + if err != nil { + return body, false + } + return setJSONRawBytes(body, "metadata", raw) } func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string) { @@ -909,96 +1036,59 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu return body, modelID } - // 解析为 map[string]any 用于修改字段 - var req map[string]any - if err := json.Unmarshal(body, &req); err != nil { - return body, modelID - } - + out := body modified := false - if system, ok := req["system"]; ok { - switch v := system.(type) { - case string: - sanitized := sanitizeSystemText(v) - if sanitized != v { - req["system"] = sanitized - modified = true - } - case []any: - for _, item := range v { - block, ok := item.(map[string]any) - if !ok { - continue - } - if blockType, _ := block["type"].(string); blockType != "text" { - continue - } - text, ok := block["text"].(string) - if !ok || text == "" { - continue - } - sanitized := sanitizeSystemText(text) - if sanitized != text { - block["text"] = sanitized - modified = true - } - } - } + if next, changed := normalizeClaudeOAuthSystemBody(out, opts); changed { + out = next + modified = true } - if rawModel, ok := req["model"].(string); ok { - normalized := claude.NormalizeModelID(rawModel) - if normalized != rawModel { - req["model"] = normalized + rawModel := gjson.GetBytes(out, "model") + if rawModel.Exists() && rawModel.Type == gjson.String { + normalized := claude.NormalizeModelID(rawModel.String()) + if normalized != rawModel.String() { + if next, ok := setJSONValueBytes(out, "model", normalized); ok { + out = next + modified = true + } modelID = normalized - modified = true } } // 确保 tools 字段存在(即使为空数组) - if _, exists := req["tools"]; !exists { - req["tools"] = []any{} - modified = true - } - - if opts.stripSystemCacheControl { - if system, ok := req["system"]; ok { - _ = stripCacheControlFromSystemBlocks(system) + if !gjson.GetBytes(out, "tools").Exists() { + if next, ok := setJSONRawBytes(out, "tools", []byte("[]")); ok { + out = next modified = true } } if opts.injectMetadata && opts.metadataUserID != "" { - metadata, ok := req["metadata"].(map[string]any) - if !ok { - metadata = map[string]any{} - req["metadata"] = metadata - } - if existing, ok := metadata["user_id"].(string); !ok || existing == "" { - metadata["user_id"] = opts.metadataUserID + if next, changed := ensureClaudeOAuthMetadataUserID(out, opts.metadataUserID); changed { + out = next modified = true } } - if _, hasTemp := req["temperature"]; hasTemp { - delete(req, "temperature") - modified = true + if gjson.GetBytes(out, "temperature").Exists() { + if next, ok := deleteJSONPathBytes(out, "temperature"); ok { + out = next + modified = true + } } - if _, hasChoice := req["tool_choice"]; hasChoice { - delete(req, "tool_choice") - modified = true + if gjson.GetBytes(out, "tool_choice").Exists() { + if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok { + out = next + modified = true + } } if !modified { return body, modelID } - newBody, err := json.Marshal(req) - if err != nil { - return body, modelID - } - return newBody, modelID + return out, modelID } func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string { @@ -3676,82 +3766,28 @@ func hasClaudeCodePrefix(text string) bool { return false } -// matchesFilterPrefix 检查文本是否匹配任一过滤前缀 -func matchesFilterPrefix(text string) bool { - for _, prefix := range systemBlockFilterPrefixes { - if strings.HasPrefix(text, prefix) { - return true - } - } - return false -} - -// filterSystemBlocksByPrefix 从 body 的 system 中移除文本匹配 systemBlockFilterPrefixes 前缀的元素 -// 直接从 body 解析 system,不依赖外部传入的 parsed.System(因为前置步骤可能已修改 body 中的 system) -func filterSystemBlocksByPrefix(body []byte) []byte { - sys := gjson.GetBytes(body, "system") - if !sys.Exists() { - return body - } - - switch { - case sys.Type == gjson.String: - if matchesFilterPrefix(sys.Str) { - result, err := sjson.DeleteBytes(body, "system") - if err != nil { - return body - } - return result - } - case sys.IsArray(): - var parsed []any - if err := json.Unmarshal([]byte(sys.Raw), &parsed); err != nil { - return body - } - filtered := make([]any, 0, len(parsed)) - changed := false - for _, item := range parsed { - if m, ok := item.(map[string]any); ok { - if text, ok := m["text"].(string); ok && matchesFilterPrefix(text) { - changed = true - continue - } - } - filtered = append(filtered, item) - } - if changed { - result, err := sjson.SetBytes(body, "system", filtered) - if err != nil { - return body - } - return result - } - } - return body -} - // injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词 // 处理 null、字符串、数组三种格式 func injectClaudeCodePrompt(body []byte, system any) []byte { - claudeCodeBlock := map[string]any{ - "type": "text", - "text": claudeCodeSystemPrompt, - "cache_control": map[string]string{"type": "ephemeral"}, + claudeCodeBlock, err := marshalAnthropicSystemTextBlock(claudeCodeSystemPrompt, true) + if err != nil { + logger.LegacyPrintf("service.gateway", "Warning: failed to build Claude Code prompt block: %v", err) + return body } // Opencode plugin applies an extra safeguard: it not only prepends the Claude Code // banner, it also prefixes the next system instruction with the same banner plus // a blank line. This helps when upstream concatenates system instructions. claudeCodePrefix := strings.TrimSpace(claudeCodeSystemPrompt) - var newSystem []any + var items [][]byte switch v := system.(type) { case nil: - newSystem = []any{claudeCodeBlock} + items = [][]byte{claudeCodeBlock} case string: // Be tolerant of older/newer clients that may differ only by trailing whitespace/newlines. if strings.TrimSpace(v) == "" || strings.TrimSpace(v) == strings.TrimSpace(claudeCodeSystemPrompt) { - newSystem = []any{claudeCodeBlock} + items = [][]byte{claudeCodeBlock} } else { // Mirror opencode behavior: keep the banner as a separate system entry, // but also prefix the next system text with the banner. @@ -3759,18 +3795,54 @@ func injectClaudeCodePrompt(body []byte, system any) []byte { if !strings.HasPrefix(v, claudeCodePrefix) { merged = claudeCodePrefix + "\n\n" + v } - newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": merged}} + nextBlock, buildErr := marshalAnthropicSystemTextBlock(merged, false) + if buildErr != nil { + logger.LegacyPrintf("service.gateway", "Warning: failed to build prefixed Claude Code system block: %v", buildErr) + return body + } + items = [][]byte{claudeCodeBlock, nextBlock} } case []any: - newSystem = make([]any, 0, len(v)+1) - newSystem = append(newSystem, claudeCodeBlock) + items = make([][]byte, 0, len(v)+1) + items = append(items, claudeCodeBlock) prefixedNext := false - for _, item := range v { - if m, ok := item.(map[string]any); ok { + systemResult := gjson.GetBytes(body, "system") + if systemResult.IsArray() { + systemResult.ForEach(func(_, item gjson.Result) bool { + textResult := item.Get("text") + if textResult.Exists() && textResult.Type == gjson.String && + strings.TrimSpace(textResult.String()) == strings.TrimSpace(claudeCodeSystemPrompt) { + return true + } + + raw := []byte(item.Raw) + // Prefix the first subsequent text system block once. + if !prefixedNext && item.Get("type").String() == "text" && textResult.Exists() && textResult.Type == gjson.String { + text := textResult.String() + if strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) { + next, setErr := sjson.SetBytes(raw, "text", claudeCodePrefix+"\n\n"+text) + if setErr == nil { + raw = next + prefixedNext = true + } + } + } + items = append(items, raw) + return true + }) + } else { + for _, item := range v { + m, ok := item.(map[string]any) + if !ok { + raw, marshalErr := json.Marshal(item) + if marshalErr == nil { + items = append(items, raw) + } + continue + } if text, ok := m["text"].(string); ok && strings.TrimSpace(text) == strings.TrimSpace(claudeCodeSystemPrompt) { continue } - // Prefix the first subsequent text system block once. if !prefixedNext { if blockType, _ := m["type"].(string); blockType == "text" { if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) { @@ -3779,197 +3851,150 @@ func injectClaudeCodePrompt(body []byte, system any) []byte { } } } + raw, marshalErr := json.Marshal(m) + if marshalErr == nil { + items = append(items, raw) + } } - newSystem = append(newSystem, item) } default: - newSystem = []any{claudeCodeBlock} + items = [][]byte{claudeCodeBlock} } - result, err := sjson.SetBytes(body, "system", newSystem) - if err != nil { - logger.LegacyPrintf("service.gateway", "Warning: failed to inject Claude Code prompt: %v", err) + result, ok := setJSONRawBytes(body, "system", buildJSONArrayRaw(items)) + if !ok { + logger.LegacyPrintf("service.gateway", "Warning: failed to inject Claude Code prompt") return body } return result } +type cacheControlPath struct { + path string + log string +} + +func collectCacheControlPaths(body []byte) (invalidThinking []cacheControlPath, messagePaths []string, systemPaths []string) { + system := gjson.GetBytes(body, "system") + if system.IsArray() { + sysIndex := 0 + system.ForEach(func(_, item gjson.Result) bool { + if item.Get("cache_control").Exists() { + path := fmt.Sprintf("system.%d.cache_control", sysIndex) + if item.Get("type").String() == "thinking" { + invalidThinking = append(invalidThinking, cacheControlPath{ + path: path, + log: "[Warning] Removed illegal cache_control from thinking block in system", + }) + } else { + systemPaths = append(systemPaths, path) + } + } + sysIndex++ + return true + }) + } + + messages := gjson.GetBytes(body, "messages") + if messages.IsArray() { + msgIndex := 0 + messages.ForEach(func(_, msg gjson.Result) bool { + content := msg.Get("content") + if content.IsArray() { + contentIndex := 0 + content.ForEach(func(_, item gjson.Result) bool { + if item.Get("cache_control").Exists() { + path := fmt.Sprintf("messages.%d.content.%d.cache_control", msgIndex, contentIndex) + if item.Get("type").String() == "thinking" { + invalidThinking = append(invalidThinking, cacheControlPath{ + path: path, + log: fmt.Sprintf("[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIndex, contentIndex), + }) + } else { + messagePaths = append(messagePaths, path) + } + } + contentIndex++ + return true + }) + } + msgIndex++ + return true + }) + } + + return invalidThinking, messagePaths, systemPaths +} + // enforceCacheControlLimit 强制执行 cache_control 块数量限制(最多 4 个) // 超限时优先从 messages 中移除 cache_control,保护 system 中的缓存控制 func enforceCacheControlLimit(body []byte) []byte { - var data map[string]any - if err := json.Unmarshal(body, &data); err != nil { + if len(body) == 0 { return body } - // 清理 thinking 块中的非法 cache_control(thinking 块不支持该字段) - removeCacheControlFromThinkingBlocks(data) + invalidThinking, messagePaths, systemPaths := collectCacheControlPaths(body) + out := body + modified := false - // 计算当前 cache_control 块数量 - count := countCacheControlBlocks(data) + // 先清理 thinking 块中的非法 cache_control(thinking 块不支持该字段) + for _, item := range invalidThinking { + if !gjson.GetBytes(out, item.path).Exists() { + continue + } + next, ok := deleteJSONPathBytes(out, item.path) + if !ok { + continue + } + out = next + modified = true + logger.LegacyPrintf("service.gateway", "%s", item.log) + } + + count := len(messagePaths) + len(systemPaths) if count <= maxCacheControlBlocks { + if modified { + return out + } return body } // 超限:优先从 messages 中移除,再从 system 中移除 - for count > maxCacheControlBlocks { - if removeCacheControlFromMessages(data) { - count-- + remaining := count - maxCacheControlBlocks + for _, path := range messagePaths { + if remaining <= 0 { + break + } + if !gjson.GetBytes(out, path).Exists() { continue } - if removeCacheControlFromSystem(data) { - count-- - continue - } - break - } - - result, err := json.Marshal(data) - if err != nil { - return body - } - return result -} - -// countCacheControlBlocks 统计 system 和 messages 中的 cache_control 块数量 -// 注意:thinking 块不支持 cache_control,统计时跳过 -func countCacheControlBlocks(data map[string]any) int { - count := 0 - - // 统计 system 中的块 - if system, ok := data["system"].([]any); ok { - for _, item := range system { - if m, ok := item.(map[string]any); ok { - // thinking 块不支持 cache_control,跳过 - if blockType, _ := m["type"].(string); blockType == "thinking" { - continue - } - if _, has := m["cache_control"]; has { - count++ - } - } - } - } - - // 统计 messages 中的块 - if messages, ok := data["messages"].([]any); ok { - for _, msg := range messages { - if msgMap, ok := msg.(map[string]any); ok { - if content, ok := msgMap["content"].([]any); ok { - for _, item := range content { - if m, ok := item.(map[string]any); ok { - // thinking 块不支持 cache_control,跳过 - if blockType, _ := m["type"].(string); blockType == "thinking" { - continue - } - if _, has := m["cache_control"]; has { - count++ - } - } - } - } - } - } - } - - return count -} - -// removeCacheControlFromMessages 从 messages 中移除一个 cache_control(从头开始) -// 返回 true 表示成功移除,false 表示没有可移除的 -// 注意:跳过 thinking 块(它不支持 cache_control) -func removeCacheControlFromMessages(data map[string]any) bool { - messages, ok := data["messages"].([]any) - if !ok { - return false - } - - for _, msg := range messages { - msgMap, ok := msg.(map[string]any) + next, ok := deleteJSONPathBytes(out, path) if !ok { continue } - content, ok := msgMap["content"].([]any) + out = next + modified = true + remaining-- + } + + for i := len(systemPaths) - 1; i >= 0 && remaining > 0; i-- { + path := systemPaths[i] + if !gjson.GetBytes(out, path).Exists() { + continue + } + next, ok := deleteJSONPathBytes(out, path) if !ok { continue } - for _, item := range content { - if m, ok := item.(map[string]any); ok { - // thinking 块不支持 cache_control,跳过 - if blockType, _ := m["type"].(string); blockType == "thinking" { - continue - } - if _, has := m["cache_control"]; has { - delete(m, "cache_control") - return true - } - } - } - } - return false -} - -// removeCacheControlFromSystem 从 system 中移除一个 cache_control(从尾部开始,保护注入的 prompt) -// 返回 true 表示成功移除,false 表示没有可移除的 -// 注意:跳过 thinking 块(它不支持 cache_control) -func removeCacheControlFromSystem(data map[string]any) bool { - system, ok := data["system"].([]any) - if !ok { - return false + out = next + modified = true + remaining-- } - // 从尾部开始移除,保护开头注入的 Claude Code prompt - for i := len(system) - 1; i >= 0; i-- { - if m, ok := system[i].(map[string]any); ok { - // thinking 块不支持 cache_control,跳过 - if blockType, _ := m["type"].(string); blockType == "thinking" { - continue - } - if _, has := m["cache_control"]; has { - delete(m, "cache_control") - return true - } - } - } - return false -} - -// removeCacheControlFromThinkingBlocks 强制清理所有 thinking 块中的非法 cache_control -// thinking 块不支持 cache_control 字段,这个函数确保所有 thinking 块都不含该字段 -func removeCacheControlFromThinkingBlocks(data map[string]any) { - // 清理 system 中的 thinking 块 - if system, ok := data["system"].([]any); ok { - for _, item := range system { - if m, ok := item.(map[string]any); ok { - if blockType, _ := m["type"].(string); blockType == "thinking" { - if _, has := m["cache_control"]; has { - delete(m, "cache_control") - logger.LegacyPrintf("service.gateway", "[Warning] Removed illegal cache_control from thinking block in system") - } - } - } - } - } - - // 清理 messages 中的 thinking 块 - if messages, ok := data["messages"].([]any); ok { - for msgIdx, msg := range messages { - if msgMap, ok := msg.(map[string]any); ok { - if content, ok := msgMap["content"].([]any); ok { - for contentIdx, item := range content { - if m, ok := item.(map[string]any); ok { - if blockType, _ := m["type"].(string); blockType == "thinking" { - if _, has := m["cache_control"]; has { - delete(m, "cache_control") - logger.LegacyPrintf("service.gateway", "[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIdx, contentIdx) - } - } - } - } - } - } - } + if modified { + return out } + return body } // Forward 转发请求到Claude API @@ -4021,6 +4046,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A reqStream := parsed.Stream originalModel := reqModel + // === DEBUG: 打印客户端原始请求 body === + debugLogRequestBody("CLIENT_ORIGINAL", body) + isClaudeCode := isClaudeCodeRequest(ctx, c, parsed) shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode @@ -4046,12 +4074,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) } - // OAuth/SetupToken 账号:移除黑名单前缀匹配的 system 元素(如客户端注入的计费元数据) - // 放在 inject/normalize 之后,确保不会被覆盖 - if account.IsOAuth() { - body = filterSystemBlocksByPrefix(body) - } - // 强制执行 cache_control 块数量限制(最多 4 个) body = enforceCacheControlLimit(body) @@ -5573,6 +5595,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } } + // === DEBUG: 打印转发给上游的 body(metadata 已重写) === + debugLogRequestBody("UPSTREAM_FORWARD", body) + req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) if err != nil { return nil, err @@ -8447,3 +8472,43 @@ func reconcileCachedTokens(usage map[string]any) bool { usage["cache_read_input_tokens"] = cached return true } + +func debugGatewayBodyLoggingEnabled() bool { + raw := strings.TrimSpace(os.Getenv(debugGatewayBodyEnv)) + if raw == "" { + return false + } + + switch strings.ToLower(raw) { + case "1", "true", "yes", "on": + return true + default: + return false + } +} + +// debugLogRequestBody 打印请求 body 用于调试 metadata.user_id 重写。 +// 默认关闭,仅在设置环境变量时启用: +// +// SUB2API_DEBUG_GATEWAY_BODY=1 +func debugLogRequestBody(tag string, body []byte) { + if !debugGatewayBodyLoggingEnabled() { + return + } + + if len(body) == 0 { + logger.LegacyPrintf("service.gateway", "[DEBUG_%s] body is empty", tag) + return + } + + // 提取 metadata 字段完整打印 + metadataResult := gjson.GetBytes(body, "metadata") + if metadataResult.Exists() { + logger.LegacyPrintf("service.gateway", "[DEBUG_%s] metadata = %s", tag, metadataResult.Raw) + } else { + logger.LegacyPrintf("service.gateway", "[DEBUG_%s] metadata field not found", tag) + } + + // 全量打印 body + logger.LegacyPrintf("service.gateway", "[DEBUG_%s] body (%d bytes) = %s", tag, len(body), string(body)) +} diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go index 8d464a8b..428f5bfd 100644 --- a/backend/internal/service/identity_service.go +++ b/backend/internal/service/identity_service.go @@ -5,7 +5,6 @@ import ( "crypto/rand" "crypto/sha256" "encoding/hex" - "encoding/json" "fmt" "log/slog" "net/http" @@ -15,6 +14,8 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) // 预编译正则表达式(避免每次调用重新编译) @@ -215,25 +216,20 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI return body, nil } - // 使用 RawMessage 保留其他字段的原始字节 - var reqMap map[string]json.RawMessage - if err := json.Unmarshal(body, &reqMap); err != nil { + metadata := gjson.GetBytes(body, "metadata") + if !metadata.Exists() || metadata.Type == gjson.Null { + return body, nil + } + if !strings.HasPrefix(strings.TrimSpace(metadata.Raw), "{") { return body, nil } - // 解析 metadata 字段 - metadataRaw, ok := reqMap["metadata"] - if !ok { + userIDResult := metadata.Get("user_id") + if !userIDResult.Exists() || userIDResult.Type != gjson.String { return body, nil } - - var metadata map[string]any - if err := json.Unmarshal(metadataRaw, &metadata); err != nil { - return body, nil - } - - userID, ok := metadata["user_id"].(string) - if !ok || userID == "" { + userID := userIDResult.String() + if userID == "" { return body, nil } @@ -252,17 +248,15 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI // 根据客户端版本选择输出格式 version := ExtractCLIVersion(fingerprintUA) newUserID := FormatMetadataUserID(cachedClientID, accountUUID, newSessionHash, version) + if newUserID == userID { + return body, nil + } - metadata["user_id"] = newUserID - - // 只重新序列化 metadata 字段 - newMetadataRaw, err := json.Marshal(metadata) + newBody, err := sjson.SetBytes(body, "metadata.user_id", newUserID) if err != nil { return body, nil } - reqMap["metadata"] = newMetadataRaw - - return json.Marshal(reqMap) + return newBody, nil } // RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装 @@ -283,25 +277,20 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b return newBody, nil } - // 使用 RawMessage 保留其他字段的原始字节 - var reqMap map[string]json.RawMessage - if err := json.Unmarshal(newBody, &reqMap); err != nil { + metadata := gjson.GetBytes(newBody, "metadata") + if !metadata.Exists() || metadata.Type == gjson.Null { + return newBody, nil + } + if !strings.HasPrefix(strings.TrimSpace(metadata.Raw), "{") { return newBody, nil } - // 解析 metadata 字段 - metadataRaw, ok := reqMap["metadata"] - if !ok { + userIDResult := metadata.Get("user_id") + if !userIDResult.Exists() || userIDResult.Type != gjson.String { return newBody, nil } - - var metadata map[string]any - if err := json.Unmarshal(metadataRaw, &metadata); err != nil { - return newBody, nil - } - - userID, ok := metadata["user_id"].(string) - if !ok || userID == "" { + userID := userIDResult.String() + if userID == "" { return newBody, nil } @@ -339,16 +328,15 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b "after", newUserID, ) - metadata["user_id"] = newUserID - - // 只重新序列化 metadata 字段 - newMetadataRaw, marshalErr := json.Marshal(metadata) - if marshalErr != nil { + if newUserID == userID { return newBody, nil } - reqMap["metadata"] = newMetadataRaw - return json.Marshal(reqMap) + maskedBody, setErr := sjson.SetBytes(newBody, "metadata.user_id", newUserID) + if setErr != nil { + return newBody, nil + } + return maskedBody, nil } // generateRandomUUID 生成随机 UUID v4 格式字符串 diff --git a/backend/internal/service/identity_service_order_test.go b/backend/internal/service/identity_service_order_test.go new file mode 100644 index 00000000..d1e12274 --- /dev/null +++ b/backend/internal/service/identity_service_order_test.go @@ -0,0 +1,82 @@ +package service + +import ( + "context" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +type identityCacheStub struct { + maskedSessionID string +} + +func (s *identityCacheStub) GetFingerprint(_ context.Context, _ int64) (*Fingerprint, error) { + return nil, nil +} +func (s *identityCacheStub) SetFingerprint(_ context.Context, _ int64, _ *Fingerprint) error { + return nil +} +func (s *identityCacheStub) GetMaskedSessionID(_ context.Context, _ int64) (string, error) { + return s.maskedSessionID, nil +} +func (s *identityCacheStub) SetMaskedSessionID(_ context.Context, _ int64, sessionID string) error { + s.maskedSessionID = sessionID + return nil +} + +func TestIdentityService_RewriteUserID_PreservesTopLevelFieldOrder(t *testing.T) { + cache := &identityCacheStub{} + svc := NewIdentityService(cache) + + originalUserID := FormatMetadataUserID( + "d61f76d0730d2b920763648949bad5c79742155c27037fc77ac3f9805cb90169", + "", + "7578cf37-aaca-46e4-a45c-71285d9dbb83", + "2.1.78", + ) + body := []byte(`{"alpha":1,"messages":[],"metadata":{"user_id":` + strconvQuote(originalUserID) + `},"max_tokens":64000,"thinking":{"type":"adaptive"},"output_config":{"effort":"high"},"stream":true}`) + + result, err := svc.RewriteUserID(body, 123, "acc-uuid", "client-xyz", "claude-cli/2.1.78 (external, cli)") + require.NoError(t, err) + resultStr := string(result) + + assertJSONTokenOrder(t, resultStr, `"alpha"`, `"messages"`, `"metadata"`, `"max_tokens"`, `"thinking"`, `"output_config"`, `"stream"`) + require.NotContains(t, resultStr, originalUserID) + require.Contains(t, resultStr, `"metadata":{"user_id":"`) +} + +func TestIdentityService_RewriteUserIDWithMasking_PreservesTopLevelFieldOrder(t *testing.T) { + cache := &identityCacheStub{maskedSessionID: "11111111-2222-4333-8444-555555555555"} + svc := NewIdentityService(cache) + + originalUserID := FormatMetadataUserID( + "d61f76d0730d2b920763648949bad5c79742155c27037fc77ac3f9805cb90169", + "", + "7578cf37-aaca-46e4-a45c-71285d9dbb83", + "2.1.78", + ) + body := []byte(`{"alpha":1,"messages":[],"metadata":{"user_id":` + strconvQuote(originalUserID) + `},"max_tokens":64000,"thinking":{"type":"adaptive"},"output_config":{"effort":"high"},"stream":true}`) + + account := &Account{ + ID: 123, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "session_id_masking_enabled": true, + }, + } + + result, err := svc.RewriteUserIDWithMasking(context.Background(), body, account, "acc-uuid", "client-xyz", "claude-cli/2.1.78 (external, cli)") + require.NoError(t, err) + resultStr := string(result) + + assertJSONTokenOrder(t, resultStr, `"alpha"`, `"messages"`, `"metadata"`, `"max_tokens"`, `"thinking"`, `"output_config"`, `"stream"`) + require.Contains(t, resultStr, cache.maskedSessionID) + require.True(t, strings.Contains(resultStr, `"metadata":{"user_id":"`)) +} + +func strconvQuote(v string) string { + return `"` + strings.ReplaceAll(strings.ReplaceAll(v, `\`, `\\`), `"`, `\"`) + `"` +}