mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-03 06:52:13 +08:00
修复 OAuth/SetupToken 转发请求体重排并增加调试开关
This commit is contained in:
@@ -275,21 +275,6 @@ func filterOpenCodePrompt(text string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表
|
||||
var systemBlockFilterPrefixes = []string{
|
||||
"x-anthropic-billing-header",
|
||||
}
|
||||
|
||||
// filterSystemBlockByPrefix 如果文本匹配过滤前缀,返回空字符串
|
||||
func filterSystemBlockByPrefix(text string) string {
|
||||
for _, prefix := range systemBlockFilterPrefixes {
|
||||
if strings.HasPrefix(text, prefix) {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
// buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致)
|
||||
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
|
||||
var parts []GeminiPart
|
||||
@@ -306,8 +291,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
||||
if strings.Contains(sysStr, "You are Antigravity") {
|
||||
userHasAntigravityIdentity = true
|
||||
}
|
||||
// 过滤 OpenCode 默认提示词和黑名单前缀
|
||||
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(sysStr))
|
||||
// 过滤 OpenCode 默认提示词
|
||||
filtered := filterOpenCodePrompt(sysStr)
|
||||
if filtered != "" {
|
||||
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
||||
}
|
||||
@@ -321,8 +306,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
||||
if strings.Contains(block.Text, "You are Antigravity") {
|
||||
userHasAntigravityIdentity = true
|
||||
}
|
||||
// 过滤 OpenCode 默认提示词和黑名单前缀
|
||||
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(block.Text))
|
||||
// 过滤 OpenCode 默认提示词
|
||||
filtered := filterOpenCodePrompt(block.Text)
|
||||
if filtered != "" {
|
||||
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
||||
}
|
||||
|
||||
@@ -2,7 +2,10 @@ package antigravity
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
|
||||
@@ -349,3 +352,51 @@ func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransformClaudeToGeminiWithOptions_PreservesBillingHeaderSystemBlock(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
system json.RawMessage
|
||||
}{
|
||||
{
|
||||
name: "system array",
|
||||
system: json.RawMessage(`[{"type":"text","text":"x-anthropic-billing-header keep"}]`),
|
||||
},
|
||||
{
|
||||
name: "system string",
|
||||
system: json.RawMessage(`"x-anthropic-billing-header keep"`),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
claudeReq := &ClaudeRequest{
|
||||
Model: "claude-3-5-sonnet-latest",
|
||||
System: tt.system,
|
||||
Messages: []ClaudeMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: json.RawMessage(`[{"type":"text","text":"hello"}]`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "gemini-2.5-flash", DefaultTransformOptions())
|
||||
require.NoError(t, err)
|
||||
|
||||
var req V1InternalRequest
|
||||
require.NoError(t, json.Unmarshal(body, &req))
|
||||
require.NotNil(t, req.Request.SystemInstruction)
|
||||
|
||||
found := false
|
||||
for _, part := range req.Request.SystemInstruction.Parts {
|
||||
if strings.Contains(part.Text, "x-anthropic-billing-header keep") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.True(t, found, "转换后的 systemInstruction 应保留 x-anthropic-billing-header 内容")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -688,6 +688,83 @@ func TestGatewayService_AnthropicOAuth_NotAffectedByAPIKeyPassthroughToggle(t *t
|
||||
require.Contains(t, req.Header.Get("anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta")
|
||||
}
|
||||
|
||||
func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
}{
|
||||
{
|
||||
name: "system array",
|
||||
body: `{"model":"claude-3-5-sonnet-latest","system":[{"type":"text","text":"x-anthropic-billing-header keep"}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`,
|
||||
},
|
||||
{
|
||||
name: "system string",
|
||||
body: `{"model":"claude-3-5-sonnet-latest","system":"x-anthropic-billing-header keep","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
parsed, err := ParseGatewayRequest([]byte(tt.body), PlatformAnthropic)
|
||||
require.NoError(t, err)
|
||||
|
||||
upstream := &anthropicHTTPUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"x-request-id": []string{"rid-oauth-preserve"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"msg_1","type":"message","role":"assistant","model":"claude-3-5-sonnet-20241022","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":12,"output_tokens":7}}`)),
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &GatewayService{
|
||||
cfg: cfg,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
deferredService: &DeferredService{},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 301,
|
||||
Name: "anthropic-oauth-preserve",
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, parsed)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.Equal(t, "Bearer oauth-token", upstream.lastReq.Header.Get("authorization"))
|
||||
require.Contains(t, upstream.lastReq.Header.Get("anthropic-beta"), claude.BetaOAuth)
|
||||
|
||||
system := gjson.GetBytes(upstream.lastBody, "system")
|
||||
require.True(t, system.Exists())
|
||||
require.Contains(t, system.Raw, "x-anthropic-billing-header keep")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAfterClientDisconnect(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
72
backend/internal/service/gateway_body_order_test.go
Normal file
72
backend/internal/service/gateway_body_order_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func assertJSONTokenOrder(t *testing.T, body string, tokens ...string) {
|
||||
t.Helper()
|
||||
|
||||
last := -1
|
||||
for _, token := range tokens {
|
||||
pos := strings.Index(body, token)
|
||||
require.NotEqualf(t, -1, pos, "missing token %s in body %s", token, body)
|
||||
require.Greaterf(t, pos, last, "token %s should appear after previous tokens in body %s", token, body)
|
||||
last = pos
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceModelInBody_PreservesTopLevelFieldOrder(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
body := []byte(`{"alpha":1,"model":"claude-3-5-sonnet-latest","messages":[],"omega":2}`)
|
||||
|
||||
result := svc.replaceModelInBody(body, "claude-3-5-sonnet-20241022")
|
||||
resultStr := string(result)
|
||||
|
||||
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"messages"`, `"omega"`)
|
||||
require.Contains(t, resultStr, `"model":"claude-3-5-sonnet-20241022"`)
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeOAuthRequestBody_PreservesTopLevelFieldOrder(t *testing.T) {
|
||||
body := []byte(`{"alpha":1,"model":"claude-3-5-sonnet-latest","temperature":0.2,"system":"You are OpenCode, the best coding agent on the planet.","messages":[],"tool_choice":{"type":"auto"},"omega":2}`)
|
||||
|
||||
result, modelID := normalizeClaudeOAuthRequestBody(body, "claude-3-5-sonnet-latest", claudeOAuthNormalizeOptions{
|
||||
injectMetadata: true,
|
||||
metadataUserID: "user-1",
|
||||
})
|
||||
resultStr := string(result)
|
||||
|
||||
require.Equal(t, claude.NormalizeModelID("claude-3-5-sonnet-latest"), modelID)
|
||||
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"system"`, `"messages"`, `"omega"`, `"tools"`, `"metadata"`)
|
||||
require.NotContains(t, resultStr, `"temperature"`)
|
||||
require.NotContains(t, resultStr, `"tool_choice"`)
|
||||
require.Contains(t, resultStr, `"system":"`+claudeCodeSystemPrompt+`"`)
|
||||
require.Contains(t, resultStr, `"tools":[]`)
|
||||
require.Contains(t, resultStr, `"metadata":{"user_id":"user-1"}`)
|
||||
}
|
||||
|
||||
func TestInjectClaudeCodePrompt_PreservesFieldOrder(t *testing.T) {
|
||||
body := []byte(`{"alpha":1,"system":[{"id":"block-1","type":"text","text":"Custom"}],"messages":[],"omega":2}`)
|
||||
|
||||
result := injectClaudeCodePrompt(body, []any{
|
||||
map[string]any{"id": "block-1", "type": "text", "text": "Custom"},
|
||||
})
|
||||
resultStr := string(result)
|
||||
|
||||
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`)
|
||||
require.Contains(t, resultStr, `{"id":"block-1","type":"text","text":"`+claudeCodeSystemPrompt+`\n\nCustom"}`)
|
||||
}
|
||||
|
||||
func TestEnforceCacheControlLimit_PreservesTopLevelFieldOrder(t *testing.T) {
|
||||
body := []byte(`{"alpha":1,"system":[{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"s2","cache_control":{"type":"ephemeral"}}],"messages":[{"role":"user","content":[{"type":"text","text":"m1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"m2","cache_control":{"type":"ephemeral"}},{"type":"text","text":"m3","cache_control":{"type":"ephemeral"}}]}],"omega":2}`)
|
||||
|
||||
result := enforceCacheControlLimit(body)
|
||||
resultStr := string(result)
|
||||
|
||||
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`)
|
||||
require.Equal(t, 4, strings.Count(resultStr, `"cache_control"`))
|
||||
}
|
||||
34
backend/internal/service/gateway_debug_env_test.go
Normal file
34
backend/internal/service/gateway_debug_env_test.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDebugGatewayBodyLoggingEnabled(t *testing.T) {
|
||||
t.Run("default disabled", func(t *testing.T) {
|
||||
t.Setenv(debugGatewayBodyEnv, "")
|
||||
if debugGatewayBodyLoggingEnabled() {
|
||||
t.Fatalf("expected debug gateway body logging to be disabled by default")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("enabled with true-like values", func(t *testing.T) {
|
||||
for _, value := range []string{"1", "true", "TRUE", "yes", "on"} {
|
||||
t.Run(value, func(t *testing.T) {
|
||||
t.Setenv(debugGatewayBodyEnv, value)
|
||||
if !debugGatewayBodyLoggingEnabled() {
|
||||
t.Fatalf("expected debug gateway body logging to be enabled for %q", value)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("disabled with other values", func(t *testing.T) {
|
||||
for _, value := range []string{"0", "false", "off", "debug"} {
|
||||
t.Run(value, func(t *testing.T) {
|
||||
t.Setenv(debugGatewayBodyEnv, value)
|
||||
if debugGatewayBodyLoggingEnabled() {
|
||||
t.Fatalf("expected debug gateway body logging to be disabled for %q", value)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -51,6 +51,7 @@ const (
|
||||
defaultUserGroupRateCacheTTL = 30 * time.Second
|
||||
defaultModelsListCacheTTL = 15 * time.Second
|
||||
postUsageBillingTimeout = 15 * time.Second
|
||||
debugGatewayBodyEnv = "SUB2API_DEBUG_GATEWAY_BODY"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -339,12 +340,6 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表
|
||||
// OAuth/SetupToken 账号转发时,匹配这些前缀的 system 元素会被移除
|
||||
var systemBlockFilterPrefixes = []string{
|
||||
"x-anthropic-billing-header",
|
||||
}
|
||||
|
||||
// ErrNoAvailableAccounts 表示没有可用的账号
|
||||
var ErrNoAvailableAccounts = errors.New("no available accounts")
|
||||
|
||||
@@ -840,20 +835,30 @@ func (s *GatewayService) hashContent(content string) string {
|
||||
return strconv.FormatUint(h, 36)
|
||||
}
|
||||
|
||||
type anthropicCacheControlPayload struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type anthropicSystemTextBlockPayload struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
CacheControl *anthropicCacheControlPayload `json:"cache_control,omitempty"`
|
||||
}
|
||||
|
||||
type anthropicMetadataPayload struct {
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
// replaceModelInBody 替换请求体中的model字段
|
||||
// 使用 json.RawMessage 保留其他字段的原始字节,避免 thinking 块等内容被修改
|
||||
// 优先使用定点修改,尽量保持客户端原始字段顺序。
|
||||
func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte {
|
||||
var req map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
if len(body) == 0 {
|
||||
return body
|
||||
}
|
||||
// 只序列化 model 字段
|
||||
modelBytes, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel {
|
||||
return body
|
||||
}
|
||||
req["model"] = modelBytes
|
||||
newBody, err := json.Marshal(req)
|
||||
newBody, err := sjson.SetBytes(body, "model", newModel)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
@@ -884,24 +889,146 @@ func sanitizeSystemText(text string) string {
|
||||
return text
|
||||
}
|
||||
|
||||
func stripCacheControlFromSystemBlocks(system any) bool {
|
||||
blocks, ok := system.([]any)
|
||||
if !ok {
|
||||
return false
|
||||
func marshalAnthropicSystemTextBlock(text string, includeCacheControl bool) ([]byte, error) {
|
||||
block := anthropicSystemTextBlockPayload{
|
||||
Type: "text",
|
||||
Text: text,
|
||||
}
|
||||
changed := false
|
||||
for _, item := range blocks {
|
||||
block, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := block["cache_control"]; !exists {
|
||||
continue
|
||||
}
|
||||
delete(block, "cache_control")
|
||||
changed = true
|
||||
if includeCacheControl {
|
||||
block.CacheControl = &anthropicCacheControlPayload{Type: "ephemeral"}
|
||||
}
|
||||
return changed
|
||||
return json.Marshal(block)
|
||||
}
|
||||
|
||||
func marshalAnthropicMetadata(userID string) ([]byte, error) {
|
||||
return json.Marshal(anthropicMetadataPayload{UserID: userID})
|
||||
}
|
||||
|
||||
func buildJSONArrayRaw(items [][]byte) []byte {
|
||||
if len(items) == 0 {
|
||||
return []byte("[]")
|
||||
}
|
||||
|
||||
total := 2
|
||||
for _, item := range items {
|
||||
total += len(item)
|
||||
}
|
||||
total += len(items) - 1
|
||||
|
||||
buf := make([]byte, 0, total)
|
||||
buf = append(buf, '[')
|
||||
for i, item := range items {
|
||||
if i > 0 {
|
||||
buf = append(buf, ',')
|
||||
}
|
||||
buf = append(buf, item...)
|
||||
}
|
||||
buf = append(buf, ']')
|
||||
return buf
|
||||
}
|
||||
|
||||
func setJSONValueBytes(body []byte, path string, value any) ([]byte, bool) {
|
||||
next, err := sjson.SetBytes(body, path, value)
|
||||
if err != nil {
|
||||
return body, false
|
||||
}
|
||||
return next, true
|
||||
}
|
||||
|
||||
func setJSONRawBytes(body []byte, path string, raw []byte) ([]byte, bool) {
|
||||
next, err := sjson.SetRawBytes(body, path, raw)
|
||||
if err != nil {
|
||||
return body, false
|
||||
}
|
||||
return next, true
|
||||
}
|
||||
|
||||
func deleteJSONPathBytes(body []byte, path string) ([]byte, bool) {
|
||||
next, err := sjson.DeleteBytes(body, path)
|
||||
if err != nil {
|
||||
return body, false
|
||||
}
|
||||
return next, true
|
||||
}
|
||||
|
||||
func normalizeClaudeOAuthSystemBody(body []byte, opts claudeOAuthNormalizeOptions) ([]byte, bool) {
|
||||
sys := gjson.GetBytes(body, "system")
|
||||
if !sys.Exists() {
|
||||
return body, false
|
||||
}
|
||||
|
||||
out := body
|
||||
modified := false
|
||||
|
||||
switch {
|
||||
case sys.Type == gjson.String:
|
||||
sanitized := sanitizeSystemText(sys.String())
|
||||
if sanitized != sys.String() {
|
||||
if next, ok := setJSONValueBytes(out, "system", sanitized); ok {
|
||||
out = next
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
case sys.IsArray():
|
||||
index := 0
|
||||
sys.ForEach(func(_, item gjson.Result) bool {
|
||||
if item.Get("type").String() == "text" {
|
||||
textResult := item.Get("text")
|
||||
if textResult.Exists() && textResult.Type == gjson.String {
|
||||
text := textResult.String()
|
||||
sanitized := sanitizeSystemText(text)
|
||||
if sanitized != text {
|
||||
if next, ok := setJSONValueBytes(out, fmt.Sprintf("system.%d.text", index), sanitized); ok {
|
||||
out = next
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if opts.stripSystemCacheControl && item.Get("cache_control").Exists() {
|
||||
if next, ok := deleteJSONPathBytes(out, fmt.Sprintf("system.%d.cache_control", index)); ok {
|
||||
out = next
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
index++
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
return out, modified
|
||||
}
|
||||
|
||||
func ensureClaudeOAuthMetadataUserID(body []byte, userID string) ([]byte, bool) {
|
||||
if strings.TrimSpace(userID) == "" {
|
||||
return body, false
|
||||
}
|
||||
|
||||
metadata := gjson.GetBytes(body, "metadata")
|
||||
if !metadata.Exists() || metadata.Type == gjson.Null {
|
||||
raw, err := marshalAnthropicMetadata(userID)
|
||||
if err != nil {
|
||||
return body, false
|
||||
}
|
||||
return setJSONRawBytes(body, "metadata", raw)
|
||||
}
|
||||
|
||||
trimmedRaw := strings.TrimSpace(metadata.Raw)
|
||||
if strings.HasPrefix(trimmedRaw, "{") {
|
||||
existing := metadata.Get("user_id")
|
||||
if existing.Exists() && existing.Type == gjson.String && existing.String() != "" {
|
||||
return body, false
|
||||
}
|
||||
return setJSONValueBytes(body, "metadata.user_id", userID)
|
||||
}
|
||||
|
||||
raw, err := marshalAnthropicMetadata(userID)
|
||||
if err != nil {
|
||||
return body, false
|
||||
}
|
||||
return setJSONRawBytes(body, "metadata", raw)
|
||||
}
|
||||
|
||||
func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string) {
|
||||
@@ -909,96 +1036,59 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
return body, modelID
|
||||
}
|
||||
|
||||
// 解析为 map[string]any 用于修改字段
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return body, modelID
|
||||
}
|
||||
|
||||
out := body
|
||||
modified := false
|
||||
|
||||
if system, ok := req["system"]; ok {
|
||||
switch v := system.(type) {
|
||||
case string:
|
||||
sanitized := sanitizeSystemText(v)
|
||||
if sanitized != v {
|
||||
req["system"] = sanitized
|
||||
modified = true
|
||||
}
|
||||
case []any:
|
||||
for _, item := range v {
|
||||
block, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if blockType, _ := block["type"].(string); blockType != "text" {
|
||||
continue
|
||||
}
|
||||
text, ok := block["text"].(string)
|
||||
if !ok || text == "" {
|
||||
continue
|
||||
}
|
||||
sanitized := sanitizeSystemText(text)
|
||||
if sanitized != text {
|
||||
block["text"] = sanitized
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if next, changed := normalizeClaudeOAuthSystemBody(out, opts); changed {
|
||||
out = next
|
||||
modified = true
|
||||
}
|
||||
|
||||
if rawModel, ok := req["model"].(string); ok {
|
||||
normalized := claude.NormalizeModelID(rawModel)
|
||||
if normalized != rawModel {
|
||||
req["model"] = normalized
|
||||
rawModel := gjson.GetBytes(out, "model")
|
||||
if rawModel.Exists() && rawModel.Type == gjson.String {
|
||||
normalized := claude.NormalizeModelID(rawModel.String())
|
||||
if normalized != rawModel.String() {
|
||||
if next, ok := setJSONValueBytes(out, "model", normalized); ok {
|
||||
out = next
|
||||
modified = true
|
||||
}
|
||||
modelID = normalized
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
// 确保 tools 字段存在(即使为空数组)
|
||||
if _, exists := req["tools"]; !exists {
|
||||
req["tools"] = []any{}
|
||||
modified = true
|
||||
}
|
||||
|
||||
if opts.stripSystemCacheControl {
|
||||
if system, ok := req["system"]; ok {
|
||||
_ = stripCacheControlFromSystemBlocks(system)
|
||||
if !gjson.GetBytes(out, "tools").Exists() {
|
||||
if next, ok := setJSONRawBytes(out, "tools", []byte("[]")); ok {
|
||||
out = next
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
if opts.injectMetadata && opts.metadataUserID != "" {
|
||||
metadata, ok := req["metadata"].(map[string]any)
|
||||
if !ok {
|
||||
metadata = map[string]any{}
|
||||
req["metadata"] = metadata
|
||||
}
|
||||
if existing, ok := metadata["user_id"].(string); !ok || existing == "" {
|
||||
metadata["user_id"] = opts.metadataUserID
|
||||
if next, changed := ensureClaudeOAuthMetadataUserID(out, opts.metadataUserID); changed {
|
||||
out = next
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
if _, hasTemp := req["temperature"]; hasTemp {
|
||||
delete(req, "temperature")
|
||||
modified = true
|
||||
if gjson.GetBytes(out, "temperature").Exists() {
|
||||
if next, ok := deleteJSONPathBytes(out, "temperature"); ok {
|
||||
out = next
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if _, hasChoice := req["tool_choice"]; hasChoice {
|
||||
delete(req, "tool_choice")
|
||||
modified = true
|
||||
if gjson.GetBytes(out, "tool_choice").Exists() {
|
||||
if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok {
|
||||
out = next
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
if !modified {
|
||||
return body, modelID
|
||||
}
|
||||
|
||||
newBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return body, modelID
|
||||
}
|
||||
return newBody, modelID
|
||||
return out, modelID
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string {
|
||||
@@ -3676,82 +3766,28 @@ func hasClaudeCodePrefix(text string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// matchesFilterPrefix 检查文本是否匹配任一过滤前缀
|
||||
func matchesFilterPrefix(text string) bool {
|
||||
for _, prefix := range systemBlockFilterPrefixes {
|
||||
if strings.HasPrefix(text, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// filterSystemBlocksByPrefix 从 body 的 system 中移除文本匹配 systemBlockFilterPrefixes 前缀的元素
|
||||
// 直接从 body 解析 system,不依赖外部传入的 parsed.System(因为前置步骤可能已修改 body 中的 system)
|
||||
func filterSystemBlocksByPrefix(body []byte) []byte {
|
||||
sys := gjson.GetBytes(body, "system")
|
||||
if !sys.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
switch {
|
||||
case sys.Type == gjson.String:
|
||||
if matchesFilterPrefix(sys.Str) {
|
||||
result, err := sjson.DeleteBytes(body, "system")
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return result
|
||||
}
|
||||
case sys.IsArray():
|
||||
var parsed []any
|
||||
if err := json.Unmarshal([]byte(sys.Raw), &parsed); err != nil {
|
||||
return body
|
||||
}
|
||||
filtered := make([]any, 0, len(parsed))
|
||||
changed := false
|
||||
for _, item := range parsed {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
if text, ok := m["text"].(string); ok && matchesFilterPrefix(text) {
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
filtered = append(filtered, item)
|
||||
}
|
||||
if changed {
|
||||
result, err := sjson.SetBytes(body, "system", filtered)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return result
|
||||
}
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词
|
||||
// 处理 null、字符串、数组三种格式
|
||||
func injectClaudeCodePrompt(body []byte, system any) []byte {
|
||||
claudeCodeBlock := map[string]any{
|
||||
"type": "text",
|
||||
"text": claudeCodeSystemPrompt,
|
||||
"cache_control": map[string]string{"type": "ephemeral"},
|
||||
claudeCodeBlock, err := marshalAnthropicSystemTextBlock(claudeCodeSystemPrompt, true)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Warning: failed to build Claude Code prompt block: %v", err)
|
||||
return body
|
||||
}
|
||||
// Opencode plugin applies an extra safeguard: it not only prepends the Claude Code
|
||||
// banner, it also prefixes the next system instruction with the same banner plus
|
||||
// a blank line. This helps when upstream concatenates system instructions.
|
||||
claudeCodePrefix := strings.TrimSpace(claudeCodeSystemPrompt)
|
||||
|
||||
var newSystem []any
|
||||
var items [][]byte
|
||||
|
||||
switch v := system.(type) {
|
||||
case nil:
|
||||
newSystem = []any{claudeCodeBlock}
|
||||
items = [][]byte{claudeCodeBlock}
|
||||
case string:
|
||||
// Be tolerant of older/newer clients that may differ only by trailing whitespace/newlines.
|
||||
if strings.TrimSpace(v) == "" || strings.TrimSpace(v) == strings.TrimSpace(claudeCodeSystemPrompt) {
|
||||
newSystem = []any{claudeCodeBlock}
|
||||
items = [][]byte{claudeCodeBlock}
|
||||
} else {
|
||||
// Mirror opencode behavior: keep the banner as a separate system entry,
|
||||
// but also prefix the next system text with the banner.
|
||||
@@ -3759,18 +3795,54 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
|
||||
if !strings.HasPrefix(v, claudeCodePrefix) {
|
||||
merged = claudeCodePrefix + "\n\n" + v
|
||||
}
|
||||
newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": merged}}
|
||||
nextBlock, buildErr := marshalAnthropicSystemTextBlock(merged, false)
|
||||
if buildErr != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Warning: failed to build prefixed Claude Code system block: %v", buildErr)
|
||||
return body
|
||||
}
|
||||
items = [][]byte{claudeCodeBlock, nextBlock}
|
||||
}
|
||||
case []any:
|
||||
newSystem = make([]any, 0, len(v)+1)
|
||||
newSystem = append(newSystem, claudeCodeBlock)
|
||||
items = make([][]byte, 0, len(v)+1)
|
||||
items = append(items, claudeCodeBlock)
|
||||
prefixedNext := false
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
systemResult := gjson.GetBytes(body, "system")
|
||||
if systemResult.IsArray() {
|
||||
systemResult.ForEach(func(_, item gjson.Result) bool {
|
||||
textResult := item.Get("text")
|
||||
if textResult.Exists() && textResult.Type == gjson.String &&
|
||||
strings.TrimSpace(textResult.String()) == strings.TrimSpace(claudeCodeSystemPrompt) {
|
||||
return true
|
||||
}
|
||||
|
||||
raw := []byte(item.Raw)
|
||||
// Prefix the first subsequent text system block once.
|
||||
if !prefixedNext && item.Get("type").String() == "text" && textResult.Exists() && textResult.Type == gjson.String {
|
||||
text := textResult.String()
|
||||
if strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) {
|
||||
next, setErr := sjson.SetBytes(raw, "text", claudeCodePrefix+"\n\n"+text)
|
||||
if setErr == nil {
|
||||
raw = next
|
||||
prefixedNext = true
|
||||
}
|
||||
}
|
||||
}
|
||||
items = append(items, raw)
|
||||
return true
|
||||
})
|
||||
} else {
|
||||
for _, item := range v {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
raw, marshalErr := json.Marshal(item)
|
||||
if marshalErr == nil {
|
||||
items = append(items, raw)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if text, ok := m["text"].(string); ok && strings.TrimSpace(text) == strings.TrimSpace(claudeCodeSystemPrompt) {
|
||||
continue
|
||||
}
|
||||
// Prefix the first subsequent text system block once.
|
||||
if !prefixedNext {
|
||||
if blockType, _ := m["type"].(string); blockType == "text" {
|
||||
if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) {
|
||||
@@ -3779,197 +3851,150 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
|
||||
}
|
||||
}
|
||||
}
|
||||
raw, marshalErr := json.Marshal(m)
|
||||
if marshalErr == nil {
|
||||
items = append(items, raw)
|
||||
}
|
||||
}
|
||||
newSystem = append(newSystem, item)
|
||||
}
|
||||
default:
|
||||
newSystem = []any{claudeCodeBlock}
|
||||
items = [][]byte{claudeCodeBlock}
|
||||
}
|
||||
|
||||
result, err := sjson.SetBytes(body, "system", newSystem)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Warning: failed to inject Claude Code prompt: %v", err)
|
||||
result, ok := setJSONRawBytes(body, "system", buildJSONArrayRaw(items))
|
||||
if !ok {
|
||||
logger.LegacyPrintf("service.gateway", "Warning: failed to inject Claude Code prompt")
|
||||
return body
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type cacheControlPath struct {
|
||||
path string
|
||||
log string
|
||||
}
|
||||
|
||||
func collectCacheControlPaths(body []byte) (invalidThinking []cacheControlPath, messagePaths []string, systemPaths []string) {
|
||||
system := gjson.GetBytes(body, "system")
|
||||
if system.IsArray() {
|
||||
sysIndex := 0
|
||||
system.ForEach(func(_, item gjson.Result) bool {
|
||||
if item.Get("cache_control").Exists() {
|
||||
path := fmt.Sprintf("system.%d.cache_control", sysIndex)
|
||||
if item.Get("type").String() == "thinking" {
|
||||
invalidThinking = append(invalidThinking, cacheControlPath{
|
||||
path: path,
|
||||
log: "[Warning] Removed illegal cache_control from thinking block in system",
|
||||
})
|
||||
} else {
|
||||
systemPaths = append(systemPaths, path)
|
||||
}
|
||||
}
|
||||
sysIndex++
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if messages.IsArray() {
|
||||
msgIndex := 0
|
||||
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||
content := msg.Get("content")
|
||||
if content.IsArray() {
|
||||
contentIndex := 0
|
||||
content.ForEach(func(_, item gjson.Result) bool {
|
||||
if item.Get("cache_control").Exists() {
|
||||
path := fmt.Sprintf("messages.%d.content.%d.cache_control", msgIndex, contentIndex)
|
||||
if item.Get("type").String() == "thinking" {
|
||||
invalidThinking = append(invalidThinking, cacheControlPath{
|
||||
path: path,
|
||||
log: fmt.Sprintf("[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIndex, contentIndex),
|
||||
})
|
||||
} else {
|
||||
messagePaths = append(messagePaths, path)
|
||||
}
|
||||
}
|
||||
contentIndex++
|
||||
return true
|
||||
})
|
||||
}
|
||||
msgIndex++
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
return invalidThinking, messagePaths, systemPaths
|
||||
}
|
||||
|
||||
// enforceCacheControlLimit 强制执行 cache_control 块数量限制(最多 4 个)
|
||||
// 超限时优先从 messages 中移除 cache_control,保护 system 中的缓存控制
|
||||
func enforceCacheControlLimit(body []byte) []byte {
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
if len(body) == 0 {
|
||||
return body
|
||||
}
|
||||
|
||||
// 清理 thinking 块中的非法 cache_control(thinking 块不支持该字段)
|
||||
removeCacheControlFromThinkingBlocks(data)
|
||||
invalidThinking, messagePaths, systemPaths := collectCacheControlPaths(body)
|
||||
out := body
|
||||
modified := false
|
||||
|
||||
// 计算当前 cache_control 块数量
|
||||
count := countCacheControlBlocks(data)
|
||||
// 先清理 thinking 块中的非法 cache_control(thinking 块不支持该字段)
|
||||
for _, item := range invalidThinking {
|
||||
if !gjson.GetBytes(out, item.path).Exists() {
|
||||
continue
|
||||
}
|
||||
next, ok := deleteJSONPathBytes(out, item.path)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
out = next
|
||||
modified = true
|
||||
logger.LegacyPrintf("service.gateway", "%s", item.log)
|
||||
}
|
||||
|
||||
count := len(messagePaths) + len(systemPaths)
|
||||
if count <= maxCacheControlBlocks {
|
||||
if modified {
|
||||
return out
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// 超限:优先从 messages 中移除,再从 system 中移除
|
||||
for count > maxCacheControlBlocks {
|
||||
if removeCacheControlFromMessages(data) {
|
||||
count--
|
||||
remaining := count - maxCacheControlBlocks
|
||||
for _, path := range messagePaths {
|
||||
if remaining <= 0 {
|
||||
break
|
||||
}
|
||||
if !gjson.GetBytes(out, path).Exists() {
|
||||
continue
|
||||
}
|
||||
if removeCacheControlFromSystem(data) {
|
||||
count--
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
result, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// countCacheControlBlocks 统计 system 和 messages 中的 cache_control 块数量
|
||||
// 注意:thinking 块不支持 cache_control,统计时跳过
|
||||
func countCacheControlBlocks(data map[string]any) int {
|
||||
count := 0
|
||||
|
||||
// 统计 system 中的块
|
||||
if system, ok := data["system"].([]any); ok {
|
||||
for _, item := range system {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
// thinking 块不支持 cache_control,跳过
|
||||
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||
continue
|
||||
}
|
||||
if _, has := m["cache_control"]; has {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 统计 messages 中的块
|
||||
if messages, ok := data["messages"].([]any); ok {
|
||||
for _, msg := range messages {
|
||||
if msgMap, ok := msg.(map[string]any); ok {
|
||||
if content, ok := msgMap["content"].([]any); ok {
|
||||
for _, item := range content {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
// thinking 块不支持 cache_control,跳过
|
||||
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||
continue
|
||||
}
|
||||
if _, has := m["cache_control"]; has {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
// removeCacheControlFromMessages 从 messages 中移除一个 cache_control(从头开始)
|
||||
// 返回 true 表示成功移除,false 表示没有可移除的
|
||||
// 注意:跳过 thinking 块(它不支持 cache_control)
|
||||
func removeCacheControlFromMessages(data map[string]any) bool {
|
||||
messages, ok := data["messages"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, msg := range messages {
|
||||
msgMap, ok := msg.(map[string]any)
|
||||
next, ok := deleteJSONPathBytes(out, path)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
content, ok := msgMap["content"].([]any)
|
||||
out = next
|
||||
modified = true
|
||||
remaining--
|
||||
}
|
||||
|
||||
for i := len(systemPaths) - 1; i >= 0 && remaining > 0; i-- {
|
||||
path := systemPaths[i]
|
||||
if !gjson.GetBytes(out, path).Exists() {
|
||||
continue
|
||||
}
|
||||
next, ok := deleteJSONPathBytes(out, path)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, item := range content {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
// thinking 块不支持 cache_control,跳过
|
||||
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||
continue
|
||||
}
|
||||
if _, has := m["cache_control"]; has {
|
||||
delete(m, "cache_control")
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// removeCacheControlFromSystem 从 system 中移除一个 cache_control(从尾部开始,保护注入的 prompt)
|
||||
// 返回 true 表示成功移除,false 表示没有可移除的
|
||||
// 注意:跳过 thinking 块(它不支持 cache_control)
|
||||
func removeCacheControlFromSystem(data map[string]any) bool {
|
||||
system, ok := data["system"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
out = next
|
||||
modified = true
|
||||
remaining--
|
||||
}
|
||||
|
||||
// 从尾部开始移除,保护开头注入的 Claude Code prompt
|
||||
for i := len(system) - 1; i >= 0; i-- {
|
||||
if m, ok := system[i].(map[string]any); ok {
|
||||
// thinking 块不支持 cache_control,跳过
|
||||
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||
continue
|
||||
}
|
||||
if _, has := m["cache_control"]; has {
|
||||
delete(m, "cache_control")
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// removeCacheControlFromThinkingBlocks 强制清理所有 thinking 块中的非法 cache_control
|
||||
// thinking 块不支持 cache_control 字段,这个函数确保所有 thinking 块都不含该字段
|
||||
func removeCacheControlFromThinkingBlocks(data map[string]any) {
|
||||
// 清理 system 中的 thinking 块
|
||||
if system, ok := data["system"].([]any); ok {
|
||||
for _, item := range system {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||
if _, has := m["cache_control"]; has {
|
||||
delete(m, "cache_control")
|
||||
logger.LegacyPrintf("service.gateway", "[Warning] Removed illegal cache_control from thinking block in system")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 清理 messages 中的 thinking 块
|
||||
if messages, ok := data["messages"].([]any); ok {
|
||||
for msgIdx, msg := range messages {
|
||||
if msgMap, ok := msg.(map[string]any); ok {
|
||||
if content, ok := msgMap["content"].([]any); ok {
|
||||
for contentIdx, item := range content {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||
if _, has := m["cache_control"]; has {
|
||||
delete(m, "cache_control")
|
||||
logger.LegacyPrintf("service.gateway", "[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIdx, contentIdx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if modified {
|
||||
return out
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// Forward 转发请求到Claude API
|
||||
@@ -4021,6 +4046,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
reqStream := parsed.Stream
|
||||
originalModel := reqModel
|
||||
|
||||
// === DEBUG: 打印客户端原始请求 body ===
|
||||
debugLogRequestBody("CLIENT_ORIGINAL", body)
|
||||
|
||||
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
|
||||
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
|
||||
|
||||
@@ -4046,12 +4074,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||
}
|
||||
|
||||
// OAuth/SetupToken 账号:移除黑名单前缀匹配的 system 元素(如客户端注入的计费元数据)
|
||||
// 放在 inject/normalize 之后,确保不会被覆盖
|
||||
if account.IsOAuth() {
|
||||
body = filterSystemBlocksByPrefix(body)
|
||||
}
|
||||
|
||||
// 强制执行 cache_control 块数量限制(最多 4 个)
|
||||
body = enforceCacheControlLimit(body)
|
||||
|
||||
@@ -5573,6 +5595,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
}
|
||||
|
||||
// === DEBUG: 打印转发给上游的 body(metadata 已重写) ===
|
||||
debugLogRequestBody("UPSTREAM_FORWARD", body)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -8447,3 +8472,43 @@ func reconcileCachedTokens(usage map[string]any) bool {
|
||||
usage["cache_read_input_tokens"] = cached
|
||||
return true
|
||||
}
|
||||
|
||||
func debugGatewayBodyLoggingEnabled() bool {
|
||||
raw := strings.TrimSpace(os.Getenv(debugGatewayBodyEnv))
|
||||
if raw == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
switch strings.ToLower(raw) {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// debugLogRequestBody 打印请求 body 用于调试 metadata.user_id 重写。
|
||||
// 默认关闭,仅在设置环境变量时启用:
|
||||
//
|
||||
// SUB2API_DEBUG_GATEWAY_BODY=1
|
||||
func debugLogRequestBody(tag string, body []byte) {
|
||||
if !debugGatewayBodyLoggingEnabled() {
|
||||
return
|
||||
}
|
||||
|
||||
if len(body) == 0 {
|
||||
logger.LegacyPrintf("service.gateway", "[DEBUG_%s] body is empty", tag)
|
||||
return
|
||||
}
|
||||
|
||||
// 提取 metadata 字段完整打印
|
||||
metadataResult := gjson.GetBytes(body, "metadata")
|
||||
if metadataResult.Exists() {
|
||||
logger.LegacyPrintf("service.gateway", "[DEBUG_%s] metadata = %s", tag, metadataResult.Raw)
|
||||
} else {
|
||||
logger.LegacyPrintf("service.gateway", "[DEBUG_%s] metadata field not found", tag)
|
||||
}
|
||||
|
||||
// 全量打印 body
|
||||
logger.LegacyPrintf("service.gateway", "[DEBUG_%s] body (%d bytes) = %s", tag, len(body), string(body))
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
@@ -15,6 +14,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// 预编译正则表达式(避免每次调用重新编译)
|
||||
@@ -215,25 +216,20 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// 使用 RawMessage 保留其他字段的原始字节
|
||||
var reqMap map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &reqMap); err != nil {
|
||||
metadata := gjson.GetBytes(body, "metadata")
|
||||
if !metadata.Exists() || metadata.Type == gjson.Null {
|
||||
return body, nil
|
||||
}
|
||||
if !strings.HasPrefix(strings.TrimSpace(metadata.Raw), "{") {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// 解析 metadata 字段
|
||||
metadataRaw, ok := reqMap["metadata"]
|
||||
if !ok {
|
||||
userIDResult := metadata.Get("user_id")
|
||||
if !userIDResult.Exists() || userIDResult.Type != gjson.String {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
var metadata map[string]any
|
||||
if err := json.Unmarshal(metadataRaw, &metadata); err != nil {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
userID, ok := metadata["user_id"].(string)
|
||||
if !ok || userID == "" {
|
||||
userID := userIDResult.String()
|
||||
if userID == "" {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
@@ -252,17 +248,15 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
|
||||
// 根据客户端版本选择输出格式
|
||||
version := ExtractCLIVersion(fingerprintUA)
|
||||
newUserID := FormatMetadataUserID(cachedClientID, accountUUID, newSessionHash, version)
|
||||
if newUserID == userID {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
metadata["user_id"] = newUserID
|
||||
|
||||
// 只重新序列化 metadata 字段
|
||||
newMetadataRaw, err := json.Marshal(metadata)
|
||||
newBody, err := sjson.SetBytes(body, "metadata.user_id", newUserID)
|
||||
if err != nil {
|
||||
return body, nil
|
||||
}
|
||||
reqMap["metadata"] = newMetadataRaw
|
||||
|
||||
return json.Marshal(reqMap)
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装
|
||||
@@ -283,25 +277,20 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
// 使用 RawMessage 保留其他字段的原始字节
|
||||
var reqMap map[string]json.RawMessage
|
||||
if err := json.Unmarshal(newBody, &reqMap); err != nil {
|
||||
metadata := gjson.GetBytes(newBody, "metadata")
|
||||
if !metadata.Exists() || metadata.Type == gjson.Null {
|
||||
return newBody, nil
|
||||
}
|
||||
if !strings.HasPrefix(strings.TrimSpace(metadata.Raw), "{") {
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
// 解析 metadata 字段
|
||||
metadataRaw, ok := reqMap["metadata"]
|
||||
if !ok {
|
||||
userIDResult := metadata.Get("user_id")
|
||||
if !userIDResult.Exists() || userIDResult.Type != gjson.String {
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
var metadata map[string]any
|
||||
if err := json.Unmarshal(metadataRaw, &metadata); err != nil {
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
userID, ok := metadata["user_id"].(string)
|
||||
if !ok || userID == "" {
|
||||
userID := userIDResult.String()
|
||||
if userID == "" {
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
@@ -339,16 +328,15 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
|
||||
"after", newUserID,
|
||||
)
|
||||
|
||||
metadata["user_id"] = newUserID
|
||||
|
||||
// 只重新序列化 metadata 字段
|
||||
newMetadataRaw, marshalErr := json.Marshal(metadata)
|
||||
if marshalErr != nil {
|
||||
if newUserID == userID {
|
||||
return newBody, nil
|
||||
}
|
||||
reqMap["metadata"] = newMetadataRaw
|
||||
|
||||
return json.Marshal(reqMap)
|
||||
maskedBody, setErr := sjson.SetBytes(newBody, "metadata.user_id", newUserID)
|
||||
if setErr != nil {
|
||||
return newBody, nil
|
||||
}
|
||||
return maskedBody, nil
|
||||
}
|
||||
|
||||
// generateRandomUUID 生成随机 UUID v4 格式字符串
|
||||
|
||||
82
backend/internal/service/identity_service_order_test.go
Normal file
82
backend/internal/service/identity_service_order_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type identityCacheStub struct {
|
||||
maskedSessionID string
|
||||
}
|
||||
|
||||
func (s *identityCacheStub) GetFingerprint(_ context.Context, _ int64) (*Fingerprint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *identityCacheStub) SetFingerprint(_ context.Context, _ int64, _ *Fingerprint) error {
|
||||
return nil
|
||||
}
|
||||
func (s *identityCacheStub) GetMaskedSessionID(_ context.Context, _ int64) (string, error) {
|
||||
return s.maskedSessionID, nil
|
||||
}
|
||||
func (s *identityCacheStub) SetMaskedSessionID(_ context.Context, _ int64, sessionID string) error {
|
||||
s.maskedSessionID = sessionID
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestIdentityService_RewriteUserID_PreservesTopLevelFieldOrder(t *testing.T) {
|
||||
cache := &identityCacheStub{}
|
||||
svc := NewIdentityService(cache)
|
||||
|
||||
originalUserID := FormatMetadataUserID(
|
||||
"d61f76d0730d2b920763648949bad5c79742155c27037fc77ac3f9805cb90169",
|
||||
"",
|
||||
"7578cf37-aaca-46e4-a45c-71285d9dbb83",
|
||||
"2.1.78",
|
||||
)
|
||||
body := []byte(`{"alpha":1,"messages":[],"metadata":{"user_id":` + strconvQuote(originalUserID) + `},"max_tokens":64000,"thinking":{"type":"adaptive"},"output_config":{"effort":"high"},"stream":true}`)
|
||||
|
||||
result, err := svc.RewriteUserID(body, 123, "acc-uuid", "client-xyz", "claude-cli/2.1.78 (external, cli)")
|
||||
require.NoError(t, err)
|
||||
resultStr := string(result)
|
||||
|
||||
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"messages"`, `"metadata"`, `"max_tokens"`, `"thinking"`, `"output_config"`, `"stream"`)
|
||||
require.NotContains(t, resultStr, originalUserID)
|
||||
require.Contains(t, resultStr, `"metadata":{"user_id":"`)
|
||||
}
|
||||
|
||||
func TestIdentityService_RewriteUserIDWithMasking_PreservesTopLevelFieldOrder(t *testing.T) {
|
||||
cache := &identityCacheStub{maskedSessionID: "11111111-2222-4333-8444-555555555555"}
|
||||
svc := NewIdentityService(cache)
|
||||
|
||||
originalUserID := FormatMetadataUserID(
|
||||
"d61f76d0730d2b920763648949bad5c79742155c27037fc77ac3f9805cb90169",
|
||||
"",
|
||||
"7578cf37-aaca-46e4-a45c-71285d9dbb83",
|
||||
"2.1.78",
|
||||
)
|
||||
body := []byte(`{"alpha":1,"messages":[],"metadata":{"user_id":` + strconvQuote(originalUserID) + `},"max_tokens":64000,"thinking":{"type":"adaptive"},"output_config":{"effort":"high"},"stream":true}`)
|
||||
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"session_id_masking_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.RewriteUserIDWithMasking(context.Background(), body, account, "acc-uuid", "client-xyz", "claude-cli/2.1.78 (external, cli)")
|
||||
require.NoError(t, err)
|
||||
resultStr := string(result)
|
||||
|
||||
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"messages"`, `"metadata"`, `"max_tokens"`, `"thinking"`, `"output_config"`, `"stream"`)
|
||||
require.Contains(t, resultStr, cache.maskedSessionID)
|
||||
require.True(t, strings.Contains(resultStr, `"metadata":{"user_id":"`))
|
||||
}
|
||||
|
||||
func strconvQuote(v string) string {
|
||||
return `"` + strings.ReplaceAll(strings.ReplaceAll(v, `\`, `\\`), `"`, `\"`) + `"`
|
||||
}
|
||||
Reference in New Issue
Block a user