Merge pull request #1134 from yasu-dev221/fix/openai-compat-prompt-cache-key

fix(openai): add fallback prompt_cache_key for compat codex OAuth requests
Wesley Liddick · 2026-03-19 22:02:08 +08:00 · committed by GitHub
3 changed files with 166 additions and 6 deletions
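
The derived key is seeded only from the session-stable parts of a request (model family, reasoning effort, tool definitions, system prompt, and first user message), so later turns of the same conversation keep mapping to the same upstream cache bucket. A minimal standalone sketch of that idea, assuming the hashing helper produces a SHA-256 hex digest (the actual hashSensitiveValueForLog implementation is not part of this diff):

package main

import (
    "crypto/sha256"
    "encoding/hex"
    "fmt"
    "strings"
)

func main() {
    // Only session-stable fields feed the seed; later assistant/user turns are
    // ignored, so the key does not change mid-conversation.
    seedParts := []string{
        "model=gpt-5.4",
        `system="You are helpful."`,
        `first_user="Hello"`,
    }
    seed := strings.Join(seedParts, "|")
    sum := sha256.Sum256([]byte(seed)) // assumed digest; the real helper is hashSensitiveValueForLog
    fmt.Println("compat_cc_" + hex.EncodeToString(sum[:]))
}

Because follow-up turns never enter the seed, re-deriving the key on every request of a session yields the same value, which is what the stability test below asserts.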


@@ -0,0 +1,81 @@
package service

import (
    "encoding/json"
    "strings"

    "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
)

// compatPromptCacheKeyPrefix marks cache keys derived by this fallback path.
const compatPromptCacheKeyPrefix = "compat_cc_"

// shouldAutoInjectPromptCacheKeyForCompat reports whether the model belongs to
// a codex family that should receive an auto-injected prompt_cache_key.
func shouldAutoInjectPromptCacheKeyForCompat(model string) bool {
    switch normalizeCodexModel(strings.TrimSpace(model)) {
    case "gpt-5.4", "gpt-5.3-codex":
        return true
    default:
        return false
    }
}

// deriveCompatPromptCacheKey builds a deterministic cache key from the parts of
// a chat completions request that stay stable across turns of one session:
// the model, reasoning effort, tool definitions, the system prompt, and the
// first user message.
func deriveCompatPromptCacheKey(req *apicompat.ChatCompletionsRequest, mappedModel string) string {
    if req == nil {
        return ""
    }
    normalizedModel := normalizeCodexModel(strings.TrimSpace(mappedModel))
    if normalizedModel == "" {
        normalizedModel = normalizeCodexModel(strings.TrimSpace(req.Model))
    }
    if normalizedModel == "" {
        normalizedModel = strings.TrimSpace(req.Model)
    }
    seedParts := []string{"model=" + normalizedModel}
    if req.ReasoningEffort != "" {
        seedParts = append(seedParts, "reasoning_effort="+strings.TrimSpace(req.ReasoningEffort))
    }
    if len(req.ToolChoice) > 0 {
        seedParts = append(seedParts, "tool_choice="+normalizeCompatSeedJSON(req.ToolChoice))
    }
    if len(req.Tools) > 0 {
        if raw, err := json.Marshal(req.Tools); err == nil {
            seedParts = append(seedParts, "tools="+normalizeCompatSeedJSON(raw))
        }
    }
    if len(req.Functions) > 0 {
        if raw, err := json.Marshal(req.Functions); err == nil {
            seedParts = append(seedParts, "functions="+normalizeCompatSeedJSON(raw))
        }
    }
    firstUserCaptured := false
    for _, msg := range req.Messages {
        switch strings.TrimSpace(msg.Role) {
        case "system":
            seedParts = append(seedParts, "system="+normalizeCompatSeedJSON(msg.Content))
        case "user":
            if !firstUserCaptured {
                seedParts = append(seedParts, "first_user="+normalizeCompatSeedJSON(msg.Content))
                firstUserCaptured = true
            }
        }
    }
    return compatPromptCacheKeyPrefix + hashSensitiveValueForLog(strings.Join(seedParts, "|"))
}

// normalizeCompatSeedJSON re-marshals raw JSON into a canonical form so that
// equivalent payloads produce identical seed fragments; invalid JSON is used as-is.
func normalizeCompatSeedJSON(v json.RawMessage) string {
    if len(v) == 0 {
        return ""
    }
    var tmp any
    if err := json.Unmarshal(v, &tmp); err != nil {
        return string(v)
    }
    out, err := json.Marshal(tmp)
    if err != nil {
        return string(v)
    }
    return string(out)
}


@@ -0,0 +1,64 @@
package service

import (
    "encoding/json"
    "testing"

    "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
    "github.com/stretchr/testify/require"
)

func mustRawJSON(t *testing.T, s string) json.RawMessage {
    t.Helper()
    return json.RawMessage(s)
}

func TestShouldAutoInjectPromptCacheKeyForCompat(t *testing.T) {
    require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.4"))
    require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3"))
    require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex"))
    require.False(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-4o"))
}

func TestDeriveCompatPromptCacheKey_StableAcrossLaterTurns(t *testing.T) {
    base := &apicompat.ChatCompletionsRequest{
        Model: "gpt-5.4",
        Messages: []apicompat.ChatMessage{
            {Role: "system", Content: mustRawJSON(t, `"You are helpful."`)},
            {Role: "user", Content: mustRawJSON(t, `"Hello"`)},
        },
    }
    extended := &apicompat.ChatCompletionsRequest{
        Model: "gpt-5.4",
        Messages: []apicompat.ChatMessage{
            {Role: "system", Content: mustRawJSON(t, `"You are helpful."`)},
            {Role: "user", Content: mustRawJSON(t, `"Hello"`)},
            {Role: "assistant", Content: mustRawJSON(t, `"Hi there!"`)},
            {Role: "user", Content: mustRawJSON(t, `"How are you?"`)},
        },
    }
    k1 := deriveCompatPromptCacheKey(base, "gpt-5.4")
    k2 := deriveCompatPromptCacheKey(extended, "gpt-5.4")
    require.Equal(t, k1, k2, "cache key should be stable across later turns")
    require.NotEmpty(t, k1)
}

func TestDeriveCompatPromptCacheKey_DiffersAcrossSessions(t *testing.T) {
    req1 := &apicompat.ChatCompletionsRequest{
        Model: "gpt-5.4",
        Messages: []apicompat.ChatMessage{
            {Role: "user", Content: mustRawJSON(t, `"Question A"`)},
        },
    }
    req2 := &apicompat.ChatCompletionsRequest{
        Model: "gpt-5.4",
        Messages: []apicompat.ChatMessage{
            {Role: "user", Content: mustRawJSON(t, `"Question B"`)},
        },
    }
    k1 := deriveCompatPromptCacheKey(req1, "gpt-5.4")
    k2 := deriveCompatPromptCacheKey(req2, "gpt-5.4")
    require.NotEqual(t, k1, k2, "different first user messages should yield different keys")
}


@@ -43,23 +43,38 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
 	clientStream := chatReq.Stream
 	includeUsage := chatReq.StreamOptions != nil && chatReq.StreamOptions.IncludeUsage
 
-	// 2. Convert to Responses and forward
+	// 2. Resolve model mapping early so compat prompt_cache_key injection can
+	//    derive a stable seed from the final upstream model family.
+	mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
+
+	promptCacheKey = strings.TrimSpace(promptCacheKey)
+	compatPromptCacheInjected := false
+	if promptCacheKey == "" && account.Type == AccountTypeOAuth && shouldAutoInjectPromptCacheKeyForCompat(mappedModel) {
+		promptCacheKey = deriveCompatPromptCacheKey(&chatReq, mappedModel)
+		compatPromptCacheInjected = promptCacheKey != ""
+	}
+
+	// 3. Convert to Responses and forward
 	// ChatCompletionsToResponses always sets Stream=true (upstream always streams).
 	responsesReq, err := apicompat.ChatCompletionsToResponses(&chatReq)
 	if err != nil {
 		return nil, fmt.Errorf("convert chat completions to responses: %w", err)
 	}
 
-	// 3. Model mapping
-	mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
 	responsesReq.Model = mappedModel
-	logger.L().Debug("openai chat_completions: model mapping applied",
+	logFields := []zap.Field{
 		zap.Int64("account_id", account.ID),
 		zap.String("original_model", originalModel),
 		zap.String("mapped_model", mappedModel),
 		zap.Bool("stream", clientStream),
-	)
+	}
+	if compatPromptCacheInjected {
+		logFields = append(logFields,
+			zap.Bool("compat_prompt_cache_key_injected", true),
+			zap.String("compat_prompt_cache_key_sha256", hashSensitiveValueForLog(promptCacheKey)),
+		)
+	}
+	logger.L().Debug("openai chat_completions: model mapping applied", logFields...)
 
 	// 4. Marshal Responses request body, then apply OAuth codex transform
 	responsesBody, err := json.Marshal(responsesReq)