diff --git a/backend/internal/service/openai_compat_prompt_cache_key.go b/backend/internal/service/openai_compat_prompt_cache_key.go new file mode 100644 index 00000000..88e16a4d --- /dev/null +++ b/backend/internal/service/openai_compat_prompt_cache_key.go @@ -0,0 +1,81 @@ +package service + +import ( + "encoding/json" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" +) + +const compatPromptCacheKeyPrefix = "compat_cc_" + +func shouldAutoInjectPromptCacheKeyForCompat(model string) bool { + switch normalizeCodexModel(strings.TrimSpace(model)) { + case "gpt-5.4", "gpt-5.3-codex": + return true + default: + return false + } +} + +func deriveCompatPromptCacheKey(req *apicompat.ChatCompletionsRequest, mappedModel string) string { + if req == nil { + return "" + } + + normalizedModel := normalizeCodexModel(strings.TrimSpace(mappedModel)) + if normalizedModel == "" { + normalizedModel = normalizeCodexModel(strings.TrimSpace(req.Model)) + } + if normalizedModel == "" { + normalizedModel = strings.TrimSpace(req.Model) + } + + seedParts := []string{"model=" + normalizedModel} + if req.ReasoningEffort != "" { + seedParts = append(seedParts, "reasoning_effort="+strings.TrimSpace(req.ReasoningEffort)) + } + if len(req.ToolChoice) > 0 { + seedParts = append(seedParts, "tool_choice="+normalizeCompatSeedJSON(req.ToolChoice)) + } + if len(req.Tools) > 0 { + if raw, err := json.Marshal(req.Tools); err == nil { + seedParts = append(seedParts, "tools="+normalizeCompatSeedJSON(raw)) + } + } + if len(req.Functions) > 0 { + if raw, err := json.Marshal(req.Functions); err == nil { + seedParts = append(seedParts, "functions="+normalizeCompatSeedJSON(raw)) + } + } + + firstUserCaptured := false + for _, msg := range req.Messages { + switch strings.TrimSpace(msg.Role) { + case "system": + seedParts = append(seedParts, "system="+normalizeCompatSeedJSON(msg.Content)) + case "user": + if !firstUserCaptured { + seedParts = append(seedParts, "first_user="+normalizeCompatSeedJSON(msg.Content)) + firstUserCaptured = true + } + } + } + + return compatPromptCacheKeyPrefix + hashSensitiveValueForLog(strings.Join(seedParts, "|")) +} + +func normalizeCompatSeedJSON(v json.RawMessage) string { + if len(v) == 0 { + return "" + } + var tmp any + if err := json.Unmarshal(v, &tmp); err != nil { + return string(v) + } + out, err := json.Marshal(tmp) + if err != nil { + return string(v) + } + return string(out) +} diff --git a/backend/internal/service/openai_compat_prompt_cache_key_test.go b/backend/internal/service/openai_compat_prompt_cache_key_test.go new file mode 100644 index 00000000..eb9148de --- /dev/null +++ b/backend/internal/service/openai_compat_prompt_cache_key_test.go @@ -0,0 +1,64 @@ +package service + +import ( + "encoding/json" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/stretchr/testify/require" +) + +func mustRawJSON(t *testing.T, s string) json.RawMessage { + t.Helper() + return json.RawMessage(s) +} + +func TestShouldAutoInjectPromptCacheKeyForCompat(t *testing.T) { + require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.4")) + require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3")) + require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex")) + require.False(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-4o")) +} + +func TestDeriveCompatPromptCacheKey_StableAcrossLaterTurns(t *testing.T) { + base := &apicompat.ChatCompletionsRequest{ + Model: "gpt-5.4", + Messages: []apicompat.ChatMessage{ + {Role: "system", Content: mustRawJSON(t, `"You are helpful."`)}, + {Role: "user", Content: mustRawJSON(t, `"Hello"`)}, + }, + } + extended := &apicompat.ChatCompletionsRequest{ + Model: "gpt-5.4", + Messages: []apicompat.ChatMessage{ + {Role: "system", Content: mustRawJSON(t, `"You are helpful."`)}, + {Role: "user", Content: mustRawJSON(t, `"Hello"`)}, + {Role: "assistant", Content: mustRawJSON(t, `"Hi there!"`)}, + {Role: "user", Content: mustRawJSON(t, `"How are you?"`)}, + }, + } + + k1 := deriveCompatPromptCacheKey(base, "gpt-5.4") + k2 := deriveCompatPromptCacheKey(extended, "gpt-5.4") + require.Equal(t, k1, k2, "cache key should be stable across later turns") + require.NotEmpty(t, k1) +} + +func TestDeriveCompatPromptCacheKey_DiffersAcrossSessions(t *testing.T) { + req1 := &apicompat.ChatCompletionsRequest{ + Model: "gpt-5.4", + Messages: []apicompat.ChatMessage{ + {Role: "user", Content: mustRawJSON(t, `"Question A"`)}, + }, + } + req2 := &apicompat.ChatCompletionsRequest{ + Model: "gpt-5.4", + Messages: []apicompat.ChatMessage{ + {Role: "user", Content: mustRawJSON(t, `"Question B"`)}, + }, + } + + k1 := deriveCompatPromptCacheKey(req1, "gpt-5.4") + k2 := deriveCompatPromptCacheKey(req2, "gpt-5.4") + require.NotEqual(t, k1, k2, "different first user messages should yield different keys") +} diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index 7202f7cb..a442da33 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -43,23 +43,38 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( clientStream := chatReq.Stream includeUsage := chatReq.StreamOptions != nil && chatReq.StreamOptions.IncludeUsage - // 2. Convert to Responses and forward + // 2. Resolve model mapping early so compat prompt_cache_key injection can + // derive a stable seed from the final upstream model family. + mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) + + promptCacheKey = strings.TrimSpace(promptCacheKey) + compatPromptCacheInjected := false + if promptCacheKey == "" && account.Type == AccountTypeOAuth && shouldAutoInjectPromptCacheKeyForCompat(mappedModel) { + promptCacheKey = deriveCompatPromptCacheKey(&chatReq, mappedModel) + compatPromptCacheInjected = promptCacheKey != "" + } + + // 3. Convert to Responses and forward // ChatCompletionsToResponses always sets Stream=true (upstream always streams). responsesReq, err := apicompat.ChatCompletionsToResponses(&chatReq) if err != nil { return nil, fmt.Errorf("convert chat completions to responses: %w", err) } - - // 3. Model mapping - mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) responsesReq.Model = mappedModel - logger.L().Debug("openai chat_completions: model mapping applied", + logFields := []zap.Field{ zap.Int64("account_id", account.ID), zap.String("original_model", originalModel), zap.String("mapped_model", mappedModel), zap.Bool("stream", clientStream), - ) + } + if compatPromptCacheInjected { + logFields = append(logFields, + zap.Bool("compat_prompt_cache_key_injected", true), + zap.String("compat_prompt_cache_key_sha256", hashSensitiveValueForLog(promptCacheKey)), + ) + } + logger.L().Debug("openai chat_completions: model mapping applied", logFields...) // 4. Marshal Responses request body, then apply OAuth codex transform responsesBody, err := json.Marshal(responsesReq)