修复 OAuth/SetupToken 转发请求体重排并增加调试开关

This commit is contained in:
shaw
2026-03-19 16:44:39 +08:00
parent 9f6ab6b817
commit a6764e82f2
8 changed files with 742 additions and 388 deletions

View File

@@ -275,21 +275,6 @@ func filterOpenCodePrompt(text string) string {
return ""
}
// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表
var systemBlockFilterPrefixes = []string{
"x-anthropic-billing-header",
}
// filterSystemBlockByPrefix 如果文本匹配过滤前缀,返回空字符串
func filterSystemBlockByPrefix(text string) string {
for _, prefix := range systemBlockFilterPrefixes {
if strings.HasPrefix(text, prefix) {
return ""
}
}
return text
}
// buildSystemInstruction 构建 systemInstruction与 Antigravity-Manager 保持一致)
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
var parts []GeminiPart
@@ -306,8 +291,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
if strings.Contains(sysStr, "You are Antigravity") {
userHasAntigravityIdentity = true
}
// 过滤 OpenCode 默认提示词和黑名单前缀
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(sysStr))
// 过滤 OpenCode 默认提示词
filtered := filterOpenCodePrompt(sysStr)
if filtered != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
}
@@ -321,8 +306,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
if strings.Contains(block.Text, "You are Antigravity") {
userHasAntigravityIdentity = true
}
// 过滤 OpenCode 默认提示词和黑名单前缀
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(block.Text))
// 过滤 OpenCode 默认提示词
filtered := filterOpenCodePrompt(block.Text)
if filtered != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
}

View File

@@ -2,7 +2,10 @@ package antigravity
import (
"encoding/json"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
@@ -349,3 +352,51 @@ func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) {
})
}
}
func TestTransformClaudeToGeminiWithOptions_PreservesBillingHeaderSystemBlock(t *testing.T) {
tests := []struct {
name string
system json.RawMessage
}{
{
name: "system array",
system: json.RawMessage(`[{"type":"text","text":"x-anthropic-billing-header keep"}]`),
},
{
name: "system string",
system: json.RawMessage(`"x-anthropic-billing-header keep"`),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
claudeReq := &ClaudeRequest{
Model: "claude-3-5-sonnet-latest",
System: tt.system,
Messages: []ClaudeMessage{
{
Role: "user",
Content: json.RawMessage(`[{"type":"text","text":"hello"}]`),
},
},
}
body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "gemini-2.5-flash", DefaultTransformOptions())
require.NoError(t, err)
var req V1InternalRequest
require.NoError(t, json.Unmarshal(body, &req))
require.NotNil(t, req.Request.SystemInstruction)
found := false
for _, part := range req.Request.SystemInstruction.Parts {
if strings.Contains(part.Text, "x-anthropic-billing-header keep") {
found = true
break
}
}
require.True(t, found, "转换后的 systemInstruction 应保留 x-anthropic-billing-header 内容")
})
}
}

View File

@@ -688,6 +688,83 @@ func TestGatewayService_AnthropicOAuth_NotAffectedByAPIKeyPassthroughToggle(t *t
require.Contains(t, req.Header.Get("anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta")
}
func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
body string
}{
{
name: "system array",
body: `{"model":"claude-3-5-sonnet-latest","system":[{"type":"text","text":"x-anthropic-billing-header keep"}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`,
},
{
name: "system string",
body: `{"model":"claude-3-5-sonnet-latest","system":"x-anthropic-billing-header keep","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
parsed, err := ParseGatewayRequest([]byte(tt.body), PlatformAnthropic)
require.NoError(t, err)
upstream := &anthropicHTTPUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"application/json"},
"x-request-id": []string{"rid-oauth-preserve"},
},
Body: io.NopCloser(strings.NewReader(`{"id":"msg_1","type":"message","role":"assistant","model":"claude-3-5-sonnet-20241022","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":12,"output_tokens":7}}`)),
},
}
cfg := &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
}
svc := &GatewayService{
cfg: cfg,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
deferredService: &DeferredService{},
}
account := &Account{
ID: 301,
Name: "anthropic-oauth-preserve",
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
},
Status: StatusActive,
Schedulable: true,
}
result, err := svc.Forward(context.Background(), c, account, parsed)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, upstream.lastReq)
require.Equal(t, "Bearer oauth-token", upstream.lastReq.Header.Get("authorization"))
require.Contains(t, upstream.lastReq.Header.Get("anthropic-beta"), claude.BetaOAuth)
system := gjson.GetBytes(upstream.lastBody, "system")
require.True(t, system.Exists())
require.Contains(t, system.Raw, "x-anthropic-billing-header keep")
})
}
}
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAfterClientDisconnect(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@@ -0,0 +1,72 @@
package service
import (
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/stretchr/testify/require"
)
func assertJSONTokenOrder(t *testing.T, body string, tokens ...string) {
t.Helper()
last := -1
for _, token := range tokens {
pos := strings.Index(body, token)
require.NotEqualf(t, -1, pos, "missing token %s in body %s", token, body)
require.Greaterf(t, pos, last, "token %s should appear after previous tokens in body %s", token, body)
last = pos
}
}
func TestReplaceModelInBody_PreservesTopLevelFieldOrder(t *testing.T) {
svc := &GatewayService{}
body := []byte(`{"alpha":1,"model":"claude-3-5-sonnet-latest","messages":[],"omega":2}`)
result := svc.replaceModelInBody(body, "claude-3-5-sonnet-20241022")
resultStr := string(result)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"messages"`, `"omega"`)
require.Contains(t, resultStr, `"model":"claude-3-5-sonnet-20241022"`)
}
func TestNormalizeClaudeOAuthRequestBody_PreservesTopLevelFieldOrder(t *testing.T) {
body := []byte(`{"alpha":1,"model":"claude-3-5-sonnet-latest","temperature":0.2,"system":"You are OpenCode, the best coding agent on the planet.","messages":[],"tool_choice":{"type":"auto"},"omega":2}`)
result, modelID := normalizeClaudeOAuthRequestBody(body, "claude-3-5-sonnet-latest", claudeOAuthNormalizeOptions{
injectMetadata: true,
metadataUserID: "user-1",
})
resultStr := string(result)
require.Equal(t, claude.NormalizeModelID("claude-3-5-sonnet-latest"), modelID)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"system"`, `"messages"`, `"omega"`, `"tools"`, `"metadata"`)
require.NotContains(t, resultStr, `"temperature"`)
require.NotContains(t, resultStr, `"tool_choice"`)
require.Contains(t, resultStr, `"system":"`+claudeCodeSystemPrompt+`"`)
require.Contains(t, resultStr, `"tools":[]`)
require.Contains(t, resultStr, `"metadata":{"user_id":"user-1"}`)
}
func TestInjectClaudeCodePrompt_PreservesFieldOrder(t *testing.T) {
body := []byte(`{"alpha":1,"system":[{"id":"block-1","type":"text","text":"Custom"}],"messages":[],"omega":2}`)
result := injectClaudeCodePrompt(body, []any{
map[string]any{"id": "block-1", "type": "text", "text": "Custom"},
})
resultStr := string(result)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`)
require.Contains(t, resultStr, `{"id":"block-1","type":"text","text":"`+claudeCodeSystemPrompt+`\n\nCustom"}`)
}
func TestEnforceCacheControlLimit_PreservesTopLevelFieldOrder(t *testing.T) {
body := []byte(`{"alpha":1,"system":[{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"s2","cache_control":{"type":"ephemeral"}}],"messages":[{"role":"user","content":[{"type":"text","text":"m1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"m2","cache_control":{"type":"ephemeral"}},{"type":"text","text":"m3","cache_control":{"type":"ephemeral"}}]}],"omega":2}`)
result := enforceCacheControlLimit(body)
resultStr := string(result)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`)
require.Equal(t, 4, strings.Count(resultStr, `"cache_control"`))
}

View File

@@ -0,0 +1,34 @@
package service
import "testing"
func TestDebugGatewayBodyLoggingEnabled(t *testing.T) {
t.Run("default disabled", func(t *testing.T) {
t.Setenv(debugGatewayBodyEnv, "")
if debugGatewayBodyLoggingEnabled() {
t.Fatalf("expected debug gateway body logging to be disabled by default")
}
})
t.Run("enabled with true-like values", func(t *testing.T) {
for _, value := range []string{"1", "true", "TRUE", "yes", "on"} {
t.Run(value, func(t *testing.T) {
t.Setenv(debugGatewayBodyEnv, value)
if !debugGatewayBodyLoggingEnabled() {
t.Fatalf("expected debug gateway body logging to be enabled for %q", value)
}
})
}
})
t.Run("disabled with other values", func(t *testing.T) {
for _, value := range []string{"0", "false", "off", "debug"} {
t.Run(value, func(t *testing.T) {
t.Setenv(debugGatewayBodyEnv, value)
if debugGatewayBodyLoggingEnabled() {
t.Fatalf("expected debug gateway body logging to be disabled for %q", value)
}
})
}
})
}

View File

@@ -51,6 +51,7 @@ const (
defaultUserGroupRateCacheTTL = 30 * time.Second
defaultModelsListCacheTTL = 15 * time.Second
postUsageBillingTimeout = 15 * time.Second
debugGatewayBodyEnv = "SUB2API_DEBUG_GATEWAY_BODY"
)
const (
@@ -339,12 +340,6 @@ var (
}
)
// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表
// OAuth/SetupToken 账号转发时,匹配这些前缀的 system 元素会被移除
var systemBlockFilterPrefixes = []string{
"x-anthropic-billing-header",
}
// ErrNoAvailableAccounts 表示没有可用的账号
var ErrNoAvailableAccounts = errors.New("no available accounts")
@@ -840,20 +835,30 @@ func (s *GatewayService) hashContent(content string) string {
return strconv.FormatUint(h, 36)
}
type anthropicCacheControlPayload struct {
Type string `json:"type"`
}
type anthropicSystemTextBlockPayload struct {
Type string `json:"type"`
Text string `json:"text"`
CacheControl *anthropicCacheControlPayload `json:"cache_control,omitempty"`
}
type anthropicMetadataPayload struct {
UserID string `json:"user_id"`
}
// replaceModelInBody 替换请求体中的model字段
// 使用 json.RawMessage 保留其他字段的原始字节,避免 thinking 块等内容被修改
// 优先使用定点修改,尽量保持客户端原始字段顺序。
func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte {
var req map[string]json.RawMessage
if err := json.Unmarshal(body, &req); err != nil {
if len(body) == 0 {
return body
}
// 只序列化 model 字段
modelBytes, err := json.Marshal(newModel)
if err != nil {
if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel {
return body
}
req["model"] = modelBytes
newBody, err := json.Marshal(req)
newBody, err := sjson.SetBytes(body, "model", newModel)
if err != nil {
return body
}
@@ -884,24 +889,146 @@ func sanitizeSystemText(text string) string {
return text
}
func stripCacheControlFromSystemBlocks(system any) bool {
blocks, ok := system.([]any)
if !ok {
return false
func marshalAnthropicSystemTextBlock(text string, includeCacheControl bool) ([]byte, error) {
block := anthropicSystemTextBlockPayload{
Type: "text",
Text: text,
}
changed := false
for _, item := range blocks {
block, ok := item.(map[string]any)
if !ok {
continue
}
if _, exists := block["cache_control"]; !exists {
continue
}
delete(block, "cache_control")
changed = true
if includeCacheControl {
block.CacheControl = &anthropicCacheControlPayload{Type: "ephemeral"}
}
return changed
return json.Marshal(block)
}
func marshalAnthropicMetadata(userID string) ([]byte, error) {
return json.Marshal(anthropicMetadataPayload{UserID: userID})
}
func buildJSONArrayRaw(items [][]byte) []byte {
if len(items) == 0 {
return []byte("[]")
}
total := 2
for _, item := range items {
total += len(item)
}
total += len(items) - 1
buf := make([]byte, 0, total)
buf = append(buf, '[')
for i, item := range items {
if i > 0 {
buf = append(buf, ',')
}
buf = append(buf, item...)
}
buf = append(buf, ']')
return buf
}
func setJSONValueBytes(body []byte, path string, value any) ([]byte, bool) {
next, err := sjson.SetBytes(body, path, value)
if err != nil {
return body, false
}
return next, true
}
func setJSONRawBytes(body []byte, path string, raw []byte) ([]byte, bool) {
next, err := sjson.SetRawBytes(body, path, raw)
if err != nil {
return body, false
}
return next, true
}
func deleteJSONPathBytes(body []byte, path string) ([]byte, bool) {
next, err := sjson.DeleteBytes(body, path)
if err != nil {
return body, false
}
return next, true
}
func normalizeClaudeOAuthSystemBody(body []byte, opts claudeOAuthNormalizeOptions) ([]byte, bool) {
sys := gjson.GetBytes(body, "system")
if !sys.Exists() {
return body, false
}
out := body
modified := false
switch {
case sys.Type == gjson.String:
sanitized := sanitizeSystemText(sys.String())
if sanitized != sys.String() {
if next, ok := setJSONValueBytes(out, "system", sanitized); ok {
out = next
modified = true
}
}
case sys.IsArray():
index := 0
sys.ForEach(func(_, item gjson.Result) bool {
if item.Get("type").String() == "text" {
textResult := item.Get("text")
if textResult.Exists() && textResult.Type == gjson.String {
text := textResult.String()
sanitized := sanitizeSystemText(text)
if sanitized != text {
if next, ok := setJSONValueBytes(out, fmt.Sprintf("system.%d.text", index), sanitized); ok {
out = next
modified = true
}
}
}
}
if opts.stripSystemCacheControl && item.Get("cache_control").Exists() {
if next, ok := deleteJSONPathBytes(out, fmt.Sprintf("system.%d.cache_control", index)); ok {
out = next
modified = true
}
}
index++
return true
})
}
return out, modified
}
func ensureClaudeOAuthMetadataUserID(body []byte, userID string) ([]byte, bool) {
if strings.TrimSpace(userID) == "" {
return body, false
}
metadata := gjson.GetBytes(body, "metadata")
if !metadata.Exists() || metadata.Type == gjson.Null {
raw, err := marshalAnthropicMetadata(userID)
if err != nil {
return body, false
}
return setJSONRawBytes(body, "metadata", raw)
}
trimmedRaw := strings.TrimSpace(metadata.Raw)
if strings.HasPrefix(trimmedRaw, "{") {
existing := metadata.Get("user_id")
if existing.Exists() && existing.Type == gjson.String && existing.String() != "" {
return body, false
}
return setJSONValueBytes(body, "metadata.user_id", userID)
}
raw, err := marshalAnthropicMetadata(userID)
if err != nil {
return body, false
}
return setJSONRawBytes(body, "metadata", raw)
}
func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string) {
@@ -909,96 +1036,59 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
return body, modelID
}
// 解析为 map[string]any 用于修改字段
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return body, modelID
}
out := body
modified := false
if system, ok := req["system"]; ok {
switch v := system.(type) {
case string:
sanitized := sanitizeSystemText(v)
if sanitized != v {
req["system"] = sanitized
modified = true
}
case []any:
for _, item := range v {
block, ok := item.(map[string]any)
if !ok {
continue
}
if blockType, _ := block["type"].(string); blockType != "text" {
continue
}
text, ok := block["text"].(string)
if !ok || text == "" {
continue
}
sanitized := sanitizeSystemText(text)
if sanitized != text {
block["text"] = sanitized
modified = true
}
}
}
if next, changed := normalizeClaudeOAuthSystemBody(out, opts); changed {
out = next
modified = true
}
if rawModel, ok := req["model"].(string); ok {
normalized := claude.NormalizeModelID(rawModel)
if normalized != rawModel {
req["model"] = normalized
rawModel := gjson.GetBytes(out, "model")
if rawModel.Exists() && rawModel.Type == gjson.String {
normalized := claude.NormalizeModelID(rawModel.String())
if normalized != rawModel.String() {
if next, ok := setJSONValueBytes(out, "model", normalized); ok {
out = next
modified = true
}
modelID = normalized
modified = true
}
}
// 确保 tools 字段存在(即使为空数组)
if _, exists := req["tools"]; !exists {
req["tools"] = []any{}
modified = true
}
if opts.stripSystemCacheControl {
if system, ok := req["system"]; ok {
_ = stripCacheControlFromSystemBlocks(system)
if !gjson.GetBytes(out, "tools").Exists() {
if next, ok := setJSONRawBytes(out, "tools", []byte("[]")); ok {
out = next
modified = true
}
}
if opts.injectMetadata && opts.metadataUserID != "" {
metadata, ok := req["metadata"].(map[string]any)
if !ok {
metadata = map[string]any{}
req["metadata"] = metadata
}
if existing, ok := metadata["user_id"].(string); !ok || existing == "" {
metadata["user_id"] = opts.metadataUserID
if next, changed := ensureClaudeOAuthMetadataUserID(out, opts.metadataUserID); changed {
out = next
modified = true
}
}
if _, hasTemp := req["temperature"]; hasTemp {
delete(req, "temperature")
modified = true
if gjson.GetBytes(out, "temperature").Exists() {
if next, ok := deleteJSONPathBytes(out, "temperature"); ok {
out = next
modified = true
}
}
if _, hasChoice := req["tool_choice"]; hasChoice {
delete(req, "tool_choice")
modified = true
if gjson.GetBytes(out, "tool_choice").Exists() {
if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok {
out = next
modified = true
}
}
if !modified {
return body, modelID
}
newBody, err := json.Marshal(req)
if err != nil {
return body, modelID
}
return newBody, modelID
return out, modelID
}
func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string {
@@ -3676,82 +3766,28 @@ func hasClaudeCodePrefix(text string) bool {
return false
}
// matchesFilterPrefix 检查文本是否匹配任一过滤前缀
func matchesFilterPrefix(text string) bool {
for _, prefix := range systemBlockFilterPrefixes {
if strings.HasPrefix(text, prefix) {
return true
}
}
return false
}
// filterSystemBlocksByPrefix 从 body 的 system 中移除文本匹配 systemBlockFilterPrefixes 前缀的元素
// 直接从 body 解析 system不依赖外部传入的 parsed.System因为前置步骤可能已修改 body 中的 system
func filterSystemBlocksByPrefix(body []byte) []byte {
sys := gjson.GetBytes(body, "system")
if !sys.Exists() {
return body
}
switch {
case sys.Type == gjson.String:
if matchesFilterPrefix(sys.Str) {
result, err := sjson.DeleteBytes(body, "system")
if err != nil {
return body
}
return result
}
case sys.IsArray():
var parsed []any
if err := json.Unmarshal([]byte(sys.Raw), &parsed); err != nil {
return body
}
filtered := make([]any, 0, len(parsed))
changed := false
for _, item := range parsed {
if m, ok := item.(map[string]any); ok {
if text, ok := m["text"].(string); ok && matchesFilterPrefix(text) {
changed = true
continue
}
}
filtered = append(filtered, item)
}
if changed {
result, err := sjson.SetBytes(body, "system", filtered)
if err != nil {
return body
}
return result
}
}
return body
}
// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词
// 处理 null、字符串、数组三种格式
func injectClaudeCodePrompt(body []byte, system any) []byte {
claudeCodeBlock := map[string]any{
"type": "text",
"text": claudeCodeSystemPrompt,
"cache_control": map[string]string{"type": "ephemeral"},
claudeCodeBlock, err := marshalAnthropicSystemTextBlock(claudeCodeSystemPrompt, true)
if err != nil {
logger.LegacyPrintf("service.gateway", "Warning: failed to build Claude Code prompt block: %v", err)
return body
}
// Opencode plugin applies an extra safeguard: it not only prepends the Claude Code
// banner, it also prefixes the next system instruction with the same banner plus
// a blank line. This helps when upstream concatenates system instructions.
claudeCodePrefix := strings.TrimSpace(claudeCodeSystemPrompt)
var newSystem []any
var items [][]byte
switch v := system.(type) {
case nil:
newSystem = []any{claudeCodeBlock}
items = [][]byte{claudeCodeBlock}
case string:
// Be tolerant of older/newer clients that may differ only by trailing whitespace/newlines.
if strings.TrimSpace(v) == "" || strings.TrimSpace(v) == strings.TrimSpace(claudeCodeSystemPrompt) {
newSystem = []any{claudeCodeBlock}
items = [][]byte{claudeCodeBlock}
} else {
// Mirror opencode behavior: keep the banner as a separate system entry,
// but also prefix the next system text with the banner.
@@ -3759,18 +3795,54 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
if !strings.HasPrefix(v, claudeCodePrefix) {
merged = claudeCodePrefix + "\n\n" + v
}
newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": merged}}
nextBlock, buildErr := marshalAnthropicSystemTextBlock(merged, false)
if buildErr != nil {
logger.LegacyPrintf("service.gateway", "Warning: failed to build prefixed Claude Code system block: %v", buildErr)
return body
}
items = [][]byte{claudeCodeBlock, nextBlock}
}
case []any:
newSystem = make([]any, 0, len(v)+1)
newSystem = append(newSystem, claudeCodeBlock)
items = make([][]byte, 0, len(v)+1)
items = append(items, claudeCodeBlock)
prefixedNext := false
for _, item := range v {
if m, ok := item.(map[string]any); ok {
systemResult := gjson.GetBytes(body, "system")
if systemResult.IsArray() {
systemResult.ForEach(func(_, item gjson.Result) bool {
textResult := item.Get("text")
if textResult.Exists() && textResult.Type == gjson.String &&
strings.TrimSpace(textResult.String()) == strings.TrimSpace(claudeCodeSystemPrompt) {
return true
}
raw := []byte(item.Raw)
// Prefix the first subsequent text system block once.
if !prefixedNext && item.Get("type").String() == "text" && textResult.Exists() && textResult.Type == gjson.String {
text := textResult.String()
if strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) {
next, setErr := sjson.SetBytes(raw, "text", claudeCodePrefix+"\n\n"+text)
if setErr == nil {
raw = next
prefixedNext = true
}
}
}
items = append(items, raw)
return true
})
} else {
for _, item := range v {
m, ok := item.(map[string]any)
if !ok {
raw, marshalErr := json.Marshal(item)
if marshalErr == nil {
items = append(items, raw)
}
continue
}
if text, ok := m["text"].(string); ok && strings.TrimSpace(text) == strings.TrimSpace(claudeCodeSystemPrompt) {
continue
}
// Prefix the first subsequent text system block once.
if !prefixedNext {
if blockType, _ := m["type"].(string); blockType == "text" {
if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) {
@@ -3779,197 +3851,150 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
}
}
}
raw, marshalErr := json.Marshal(m)
if marshalErr == nil {
items = append(items, raw)
}
}
newSystem = append(newSystem, item)
}
default:
newSystem = []any{claudeCodeBlock}
items = [][]byte{claudeCodeBlock}
}
result, err := sjson.SetBytes(body, "system", newSystem)
if err != nil {
logger.LegacyPrintf("service.gateway", "Warning: failed to inject Claude Code prompt: %v", err)
result, ok := setJSONRawBytes(body, "system", buildJSONArrayRaw(items))
if !ok {
logger.LegacyPrintf("service.gateway", "Warning: failed to inject Claude Code prompt")
return body
}
return result
}
type cacheControlPath struct {
path string
log string
}
func collectCacheControlPaths(body []byte) (invalidThinking []cacheControlPath, messagePaths []string, systemPaths []string) {
system := gjson.GetBytes(body, "system")
if system.IsArray() {
sysIndex := 0
system.ForEach(func(_, item gjson.Result) bool {
if item.Get("cache_control").Exists() {
path := fmt.Sprintf("system.%d.cache_control", sysIndex)
if item.Get("type").String() == "thinking" {
invalidThinking = append(invalidThinking, cacheControlPath{
path: path,
log: "[Warning] Removed illegal cache_control from thinking block in system",
})
} else {
systemPaths = append(systemPaths, path)
}
}
sysIndex++
return true
})
}
messages := gjson.GetBytes(body, "messages")
if messages.IsArray() {
msgIndex := 0
messages.ForEach(func(_, msg gjson.Result) bool {
content := msg.Get("content")
if content.IsArray() {
contentIndex := 0
content.ForEach(func(_, item gjson.Result) bool {
if item.Get("cache_control").Exists() {
path := fmt.Sprintf("messages.%d.content.%d.cache_control", msgIndex, contentIndex)
if item.Get("type").String() == "thinking" {
invalidThinking = append(invalidThinking, cacheControlPath{
path: path,
log: fmt.Sprintf("[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIndex, contentIndex),
})
} else {
messagePaths = append(messagePaths, path)
}
}
contentIndex++
return true
})
}
msgIndex++
return true
})
}
return invalidThinking, messagePaths, systemPaths
}
// enforceCacheControlLimit 强制执行 cache_control 块数量限制(最多 4 个)
// 超限时优先从 messages 中移除 cache_control保护 system 中的缓存控制
func enforceCacheControlLimit(body []byte) []byte {
var data map[string]any
if err := json.Unmarshal(body, &data); err != nil {
if len(body) == 0 {
return body
}
// 清理 thinking 块中的非法 cache_controlthinking 块不支持该字段)
removeCacheControlFromThinkingBlocks(data)
invalidThinking, messagePaths, systemPaths := collectCacheControlPaths(body)
out := body
modified := false
// 计算当前 cache_control 块数量
count := countCacheControlBlocks(data)
// 先清理 thinking 块中的非法 cache_controlthinking 块不支持该字段)
for _, item := range invalidThinking {
if !gjson.GetBytes(out, item.path).Exists() {
continue
}
next, ok := deleteJSONPathBytes(out, item.path)
if !ok {
continue
}
out = next
modified = true
logger.LegacyPrintf("service.gateway", "%s", item.log)
}
count := len(messagePaths) + len(systemPaths)
if count <= maxCacheControlBlocks {
if modified {
return out
}
return body
}
// 超限:优先从 messages 中移除,再从 system 中移除
for count > maxCacheControlBlocks {
if removeCacheControlFromMessages(data) {
count--
remaining := count - maxCacheControlBlocks
for _, path := range messagePaths {
if remaining <= 0 {
break
}
if !gjson.GetBytes(out, path).Exists() {
continue
}
if removeCacheControlFromSystem(data) {
count--
continue
}
break
}
result, err := json.Marshal(data)
if err != nil {
return body
}
return result
}
// countCacheControlBlocks 统计 system 和 messages 中的 cache_control 块数量
// 注意thinking 块不支持 cache_control统计时跳过
func countCacheControlBlocks(data map[string]any) int {
count := 0
// 统计 system 中的块
if system, ok := data["system"].([]any); ok {
for _, item := range system {
if m, ok := item.(map[string]any); ok {
// thinking 块不支持 cache_control跳过
if blockType, _ := m["type"].(string); blockType == "thinking" {
continue
}
if _, has := m["cache_control"]; has {
count++
}
}
}
}
// 统计 messages 中的块
if messages, ok := data["messages"].([]any); ok {
for _, msg := range messages {
if msgMap, ok := msg.(map[string]any); ok {
if content, ok := msgMap["content"].([]any); ok {
for _, item := range content {
if m, ok := item.(map[string]any); ok {
// thinking 块不支持 cache_control跳过
if blockType, _ := m["type"].(string); blockType == "thinking" {
continue
}
if _, has := m["cache_control"]; has {
count++
}
}
}
}
}
}
}
return count
}
// removeCacheControlFromMessages 从 messages 中移除一个 cache_control从头开始
// 返回 true 表示成功移除false 表示没有可移除的
// 注意:跳过 thinking 块(它不支持 cache_control
func removeCacheControlFromMessages(data map[string]any) bool {
messages, ok := data["messages"].([]any)
if !ok {
return false
}
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
next, ok := deleteJSONPathBytes(out, path)
if !ok {
continue
}
content, ok := msgMap["content"].([]any)
out = next
modified = true
remaining--
}
for i := len(systemPaths) - 1; i >= 0 && remaining > 0; i-- {
path := systemPaths[i]
if !gjson.GetBytes(out, path).Exists() {
continue
}
next, ok := deleteJSONPathBytes(out, path)
if !ok {
continue
}
for _, item := range content {
if m, ok := item.(map[string]any); ok {
// thinking 块不支持 cache_control跳过
if blockType, _ := m["type"].(string); blockType == "thinking" {
continue
}
if _, has := m["cache_control"]; has {
delete(m, "cache_control")
return true
}
}
}
}
return false
}
// removeCacheControlFromSystem 从 system 中移除一个 cache_control从尾部开始保护注入的 prompt
// 返回 true 表示成功移除false 表示没有可移除的
// 注意:跳过 thinking 块(它不支持 cache_control
func removeCacheControlFromSystem(data map[string]any) bool {
system, ok := data["system"].([]any)
if !ok {
return false
out = next
modified = true
remaining--
}
// 从尾部开始移除,保护开头注入的 Claude Code prompt
for i := len(system) - 1; i >= 0; i-- {
if m, ok := system[i].(map[string]any); ok {
// thinking 块不支持 cache_control跳过
if blockType, _ := m["type"].(string); blockType == "thinking" {
continue
}
if _, has := m["cache_control"]; has {
delete(m, "cache_control")
return true
}
}
}
return false
}
// removeCacheControlFromThinkingBlocks 强制清理所有 thinking 块中的非法 cache_control
// thinking 块不支持 cache_control 字段,这个函数确保所有 thinking 块都不含该字段
func removeCacheControlFromThinkingBlocks(data map[string]any) {
// 清理 system 中的 thinking 块
if system, ok := data["system"].([]any); ok {
for _, item := range system {
if m, ok := item.(map[string]any); ok {
if blockType, _ := m["type"].(string); blockType == "thinking" {
if _, has := m["cache_control"]; has {
delete(m, "cache_control")
logger.LegacyPrintf("service.gateway", "[Warning] Removed illegal cache_control from thinking block in system")
}
}
}
}
}
// 清理 messages 中的 thinking 块
if messages, ok := data["messages"].([]any); ok {
for msgIdx, msg := range messages {
if msgMap, ok := msg.(map[string]any); ok {
if content, ok := msgMap["content"].([]any); ok {
for contentIdx, item := range content {
if m, ok := item.(map[string]any); ok {
if blockType, _ := m["type"].(string); blockType == "thinking" {
if _, has := m["cache_control"]; has {
delete(m, "cache_control")
logger.LegacyPrintf("service.gateway", "[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIdx, contentIdx)
}
}
}
}
}
}
}
if modified {
return out
}
return body
}
// Forward 转发请求到Claude API
@@ -4021,6 +4046,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
reqStream := parsed.Stream
originalModel := reqModel
// === DEBUG: 打印客户端原始请求 body ===
debugLogRequestBody("CLIENT_ORIGINAL", body)
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
@@ -4046,12 +4074,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
}
// OAuth/SetupToken 账号:移除黑名单前缀匹配的 system 元素(如客户端注入的计费元数据)
// 放在 inject/normalize 之后,确保不会被覆盖
if account.IsOAuth() {
body = filterSystemBlocksByPrefix(body)
}
// 强制执行 cache_control 块数量限制(最多 4 个)
body = enforceCacheControlLimit(body)
@@ -5573,6 +5595,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
// === DEBUG: 打印转发给上游的 bodymetadata 已重写) ===
debugLogRequestBody("UPSTREAM_FORWARD", body)
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
@@ -8447,3 +8472,43 @@ func reconcileCachedTokens(usage map[string]any) bool {
usage["cache_read_input_tokens"] = cached
return true
}
func debugGatewayBodyLoggingEnabled() bool {
raw := strings.TrimSpace(os.Getenv(debugGatewayBodyEnv))
if raw == "" {
return false
}
switch strings.ToLower(raw) {
case "1", "true", "yes", "on":
return true
default:
return false
}
}
// debugLogRequestBody 打印请求 body 用于调试 metadata.user_id 重写。
// 默认关闭,仅在设置环境变量时启用:
//
// SUB2API_DEBUG_GATEWAY_BODY=1
func debugLogRequestBody(tag string, body []byte) {
if !debugGatewayBodyLoggingEnabled() {
return
}
if len(body) == 0 {
logger.LegacyPrintf("service.gateway", "[DEBUG_%s] body is empty", tag)
return
}
// 提取 metadata 字段完整打印
metadataResult := gjson.GetBytes(body, "metadata")
if metadataResult.Exists() {
logger.LegacyPrintf("service.gateway", "[DEBUG_%s] metadata = %s", tag, metadataResult.Raw)
} else {
logger.LegacyPrintf("service.gateway", "[DEBUG_%s] metadata field not found", tag)
}
// 全量打印 body
logger.LegacyPrintf("service.gateway", "[DEBUG_%s] body (%d bytes) = %s", tag, len(body), string(body))
}

View File

@@ -5,7 +5,6 @@ import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"log/slog"
"net/http"
@@ -15,6 +14,8 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// 预编译正则表达式(避免每次调用重新编译)
@@ -215,25 +216,20 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
return body, nil
}
// 使用 RawMessage 保留其他字段的原始字节
var reqMap map[string]json.RawMessage
if err := json.Unmarshal(body, &reqMap); err != nil {
metadata := gjson.GetBytes(body, "metadata")
if !metadata.Exists() || metadata.Type == gjson.Null {
return body, nil
}
if !strings.HasPrefix(strings.TrimSpace(metadata.Raw), "{") {
return body, nil
}
// 解析 metadata 字段
metadataRaw, ok := reqMap["metadata"]
if !ok {
userIDResult := metadata.Get("user_id")
if !userIDResult.Exists() || userIDResult.Type != gjson.String {
return body, nil
}
var metadata map[string]any
if err := json.Unmarshal(metadataRaw, &metadata); err != nil {
return body, nil
}
userID, ok := metadata["user_id"].(string)
if !ok || userID == "" {
userID := userIDResult.String()
if userID == "" {
return body, nil
}
@@ -252,17 +248,15 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
// 根据客户端版本选择输出格式
version := ExtractCLIVersion(fingerprintUA)
newUserID := FormatMetadataUserID(cachedClientID, accountUUID, newSessionHash, version)
if newUserID == userID {
return body, nil
}
metadata["user_id"] = newUserID
// 只重新序列化 metadata 字段
newMetadataRaw, err := json.Marshal(metadata)
newBody, err := sjson.SetBytes(body, "metadata.user_id", newUserID)
if err != nil {
return body, nil
}
reqMap["metadata"] = newMetadataRaw
return json.Marshal(reqMap)
return newBody, nil
}
// RewriteUserIDWithMasking 重写body中的metadata.user_id支持会话ID伪装
@@ -283,25 +277,20 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
return newBody, nil
}
// 使用 RawMessage 保留其他字段的原始字节
var reqMap map[string]json.RawMessage
if err := json.Unmarshal(newBody, &reqMap); err != nil {
metadata := gjson.GetBytes(newBody, "metadata")
if !metadata.Exists() || metadata.Type == gjson.Null {
return newBody, nil
}
if !strings.HasPrefix(strings.TrimSpace(metadata.Raw), "{") {
return newBody, nil
}
// 解析 metadata 字段
metadataRaw, ok := reqMap["metadata"]
if !ok {
userIDResult := metadata.Get("user_id")
if !userIDResult.Exists() || userIDResult.Type != gjson.String {
return newBody, nil
}
var metadata map[string]any
if err := json.Unmarshal(metadataRaw, &metadata); err != nil {
return newBody, nil
}
userID, ok := metadata["user_id"].(string)
if !ok || userID == "" {
userID := userIDResult.String()
if userID == "" {
return newBody, nil
}
@@ -339,16 +328,15 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
"after", newUserID,
)
metadata["user_id"] = newUserID
// 只重新序列化 metadata 字段
newMetadataRaw, marshalErr := json.Marshal(metadata)
if marshalErr != nil {
if newUserID == userID {
return newBody, nil
}
reqMap["metadata"] = newMetadataRaw
return json.Marshal(reqMap)
maskedBody, setErr := sjson.SetBytes(newBody, "metadata.user_id", newUserID)
if setErr != nil {
return newBody, nil
}
return maskedBody, nil
}
// generateRandomUUID 生成随机 UUID v4 格式字符串

View File

@@ -0,0 +1,82 @@
package service
import (
"context"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
type identityCacheStub struct {
maskedSessionID string
}
func (s *identityCacheStub) GetFingerprint(_ context.Context, _ int64) (*Fingerprint, error) {
return nil, nil
}
func (s *identityCacheStub) SetFingerprint(_ context.Context, _ int64, _ *Fingerprint) error {
return nil
}
func (s *identityCacheStub) GetMaskedSessionID(_ context.Context, _ int64) (string, error) {
return s.maskedSessionID, nil
}
func (s *identityCacheStub) SetMaskedSessionID(_ context.Context, _ int64, sessionID string) error {
s.maskedSessionID = sessionID
return nil
}
func TestIdentityService_RewriteUserID_PreservesTopLevelFieldOrder(t *testing.T) {
cache := &identityCacheStub{}
svc := NewIdentityService(cache)
originalUserID := FormatMetadataUserID(
"d61f76d0730d2b920763648949bad5c79742155c27037fc77ac3f9805cb90169",
"",
"7578cf37-aaca-46e4-a45c-71285d9dbb83",
"2.1.78",
)
body := []byte(`{"alpha":1,"messages":[],"metadata":{"user_id":` + strconvQuote(originalUserID) + `},"max_tokens":64000,"thinking":{"type":"adaptive"},"output_config":{"effort":"high"},"stream":true}`)
result, err := svc.RewriteUserID(body, 123, "acc-uuid", "client-xyz", "claude-cli/2.1.78 (external, cli)")
require.NoError(t, err)
resultStr := string(result)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"messages"`, `"metadata"`, `"max_tokens"`, `"thinking"`, `"output_config"`, `"stream"`)
require.NotContains(t, resultStr, originalUserID)
require.Contains(t, resultStr, `"metadata":{"user_id":"`)
}
func TestIdentityService_RewriteUserIDWithMasking_PreservesTopLevelFieldOrder(t *testing.T) {
cache := &identityCacheStub{maskedSessionID: "11111111-2222-4333-8444-555555555555"}
svc := NewIdentityService(cache)
originalUserID := FormatMetadataUserID(
"d61f76d0730d2b920763648949bad5c79742155c27037fc77ac3f9805cb90169",
"",
"7578cf37-aaca-46e4-a45c-71285d9dbb83",
"2.1.78",
)
body := []byte(`{"alpha":1,"messages":[],"metadata":{"user_id":` + strconvQuote(originalUserID) + `},"max_tokens":64000,"thinking":{"type":"adaptive"},"output_config":{"effort":"high"},"stream":true}`)
account := &Account{
ID: 123,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"session_id_masking_enabled": true,
},
}
result, err := svc.RewriteUserIDWithMasking(context.Background(), body, account, "acc-uuid", "client-xyz", "claude-cli/2.1.78 (external, cli)")
require.NoError(t, err)
resultStr := string(result)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"messages"`, `"metadata"`, `"max_tokens"`, `"thinking"`, `"output_config"`, `"stream"`)
require.Contains(t, resultStr, cache.maskedSessionID)
require.True(t, strings.Contains(resultStr, `"metadata":{"user_id":"`))
}
func strconvQuote(v string) string {
return `"` + strings.ReplaceAll(strings.ReplaceAll(v, `\`, `\\`), `"`, `\"`) + `"`
}