mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-02 22:42:14 +08:00
Merge pull request #1134 from yasu-dev221/fix/openai-compat-prompt-cache-key
fix(openai): add fallback prompt_cache_key for compat codex OAuth requests
This commit is contained in:
81
backend/internal/service/openai_compat_prompt_cache_key.go
Normal file
81
backend/internal/service/openai_compat_prompt_cache_key.go
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||||
|
)
|
||||||
|
|
||||||
|
// compatPromptCacheKeyPrefix namespaces prompt_cache_key values that were
// auto-derived for OpenAI-compat requests, so they are distinguishable from
// client-supplied keys.
const compatPromptCacheKeyPrefix = "compat_cc_"
|
||||||
|
|
||||||
|
func shouldAutoInjectPromptCacheKeyForCompat(model string) bool {
|
||||||
|
switch normalizeCodexModel(strings.TrimSpace(model)) {
|
||||||
|
case "gpt-5.4", "gpt-5.3-codex":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func deriveCompatPromptCacheKey(req *apicompat.ChatCompletionsRequest, mappedModel string) string {
|
||||||
|
if req == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizedModel := normalizeCodexModel(strings.TrimSpace(mappedModel))
|
||||||
|
if normalizedModel == "" {
|
||||||
|
normalizedModel = normalizeCodexModel(strings.TrimSpace(req.Model))
|
||||||
|
}
|
||||||
|
if normalizedModel == "" {
|
||||||
|
normalizedModel = strings.TrimSpace(req.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
seedParts := []string{"model=" + normalizedModel}
|
||||||
|
if req.ReasoningEffort != "" {
|
||||||
|
seedParts = append(seedParts, "reasoning_effort="+strings.TrimSpace(req.ReasoningEffort))
|
||||||
|
}
|
||||||
|
if len(req.ToolChoice) > 0 {
|
||||||
|
seedParts = append(seedParts, "tool_choice="+normalizeCompatSeedJSON(req.ToolChoice))
|
||||||
|
}
|
||||||
|
if len(req.Tools) > 0 {
|
||||||
|
if raw, err := json.Marshal(req.Tools); err == nil {
|
||||||
|
seedParts = append(seedParts, "tools="+normalizeCompatSeedJSON(raw))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(req.Functions) > 0 {
|
||||||
|
if raw, err := json.Marshal(req.Functions); err == nil {
|
||||||
|
seedParts = append(seedParts, "functions="+normalizeCompatSeedJSON(raw))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
firstUserCaptured := false
|
||||||
|
for _, msg := range req.Messages {
|
||||||
|
switch strings.TrimSpace(msg.Role) {
|
||||||
|
case "system":
|
||||||
|
seedParts = append(seedParts, "system="+normalizeCompatSeedJSON(msg.Content))
|
||||||
|
case "user":
|
||||||
|
if !firstUserCaptured {
|
||||||
|
seedParts = append(seedParts, "first_user="+normalizeCompatSeedJSON(msg.Content))
|
||||||
|
firstUserCaptured = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return compatPromptCacheKeyPrefix + hashSensitiveValueForLog(strings.Join(seedParts, "|"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeCompatSeedJSON(v json.RawMessage) string {
|
||||||
|
if len(v) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var tmp any
|
||||||
|
if err := json.Unmarshal(v, &tmp); err != nil {
|
||||||
|
return string(v)
|
||||||
|
}
|
||||||
|
out, err := json.Marshal(tmp)
|
||||||
|
if err != nil {
|
||||||
|
return string(v)
|
||||||
|
}
|
||||||
|
return string(out)
|
||||||
|
}
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mustRawJSON wraps a JSON literal as json.RawMessage for use in test fixtures.
func mustRawJSON(t *testing.T, s string) json.RawMessage {
	t.Helper()
	return []byte(s)
}
|
||||||
|
|
||||||
|
func TestShouldAutoInjectPromptCacheKeyForCompat(t *testing.T) {
|
||||||
|
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.4"))
|
||||||
|
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3"))
|
||||||
|
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex"))
|
||||||
|
require.False(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-4o"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeriveCompatPromptCacheKey_StableAcrossLaterTurns(t *testing.T) {
|
||||||
|
base := &apicompat.ChatCompletionsRequest{
|
||||||
|
Model: "gpt-5.4",
|
||||||
|
Messages: []apicompat.ChatMessage{
|
||||||
|
{Role: "system", Content: mustRawJSON(t, `"You are helpful."`)},
|
||||||
|
{Role: "user", Content: mustRawJSON(t, `"Hello"`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
extended := &apicompat.ChatCompletionsRequest{
|
||||||
|
Model: "gpt-5.4",
|
||||||
|
Messages: []apicompat.ChatMessage{
|
||||||
|
{Role: "system", Content: mustRawJSON(t, `"You are helpful."`)},
|
||||||
|
{Role: "user", Content: mustRawJSON(t, `"Hello"`)},
|
||||||
|
{Role: "assistant", Content: mustRawJSON(t, `"Hi there!"`)},
|
||||||
|
{Role: "user", Content: mustRawJSON(t, `"How are you?"`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
k1 := deriveCompatPromptCacheKey(base, "gpt-5.4")
|
||||||
|
k2 := deriveCompatPromptCacheKey(extended, "gpt-5.4")
|
||||||
|
require.Equal(t, k1, k2, "cache key should be stable across later turns")
|
||||||
|
require.NotEmpty(t, k1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeriveCompatPromptCacheKey_DiffersAcrossSessions(t *testing.T) {
|
||||||
|
req1 := &apicompat.ChatCompletionsRequest{
|
||||||
|
Model: "gpt-5.4",
|
||||||
|
Messages: []apicompat.ChatMessage{
|
||||||
|
{Role: "user", Content: mustRawJSON(t, `"Question A"`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
req2 := &apicompat.ChatCompletionsRequest{
|
||||||
|
Model: "gpt-5.4",
|
||||||
|
Messages: []apicompat.ChatMessage{
|
||||||
|
{Role: "user", Content: mustRawJSON(t, `"Question B"`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
k1 := deriveCompatPromptCacheKey(req1, "gpt-5.4")
|
||||||
|
k2 := deriveCompatPromptCacheKey(req2, "gpt-5.4")
|
||||||
|
require.NotEqual(t, k1, k2, "different first user messages should yield different keys")
|
||||||
|
}
|
||||||
@@ -43,23 +43,38 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
|||||||
clientStream := chatReq.Stream
|
clientStream := chatReq.Stream
|
||||||
includeUsage := chatReq.StreamOptions != nil && chatReq.StreamOptions.IncludeUsage
|
includeUsage := chatReq.StreamOptions != nil && chatReq.StreamOptions.IncludeUsage
|
||||||
|
|
||||||
// 2. Convert to Responses and forward
|
// 2. Resolve model mapping early so compat prompt_cache_key injection can
|
||||||
|
// derive a stable seed from the final upstream model family.
|
||||||
|
mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
|
||||||
|
|
||||||
|
promptCacheKey = strings.TrimSpace(promptCacheKey)
|
||||||
|
compatPromptCacheInjected := false
|
||||||
|
if promptCacheKey == "" && account.Type == AccountTypeOAuth && shouldAutoInjectPromptCacheKeyForCompat(mappedModel) {
|
||||||
|
promptCacheKey = deriveCompatPromptCacheKey(&chatReq, mappedModel)
|
||||||
|
compatPromptCacheInjected = promptCacheKey != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Convert to Responses and forward
|
||||||
// ChatCompletionsToResponses always sets Stream=true (upstream always streams).
|
// ChatCompletionsToResponses always sets Stream=true (upstream always streams).
|
||||||
responsesReq, err := apicompat.ChatCompletionsToResponses(&chatReq)
|
responsesReq, err := apicompat.ChatCompletionsToResponses(&chatReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
|
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. Model mapping
|
|
||||||
mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
|
|
||||||
responsesReq.Model = mappedModel
|
responsesReq.Model = mappedModel
|
||||||
|
|
||||||
logger.L().Debug("openai chat_completions: model mapping applied",
|
logFields := []zap.Field{
|
||||||
zap.Int64("account_id", account.ID),
|
zap.Int64("account_id", account.ID),
|
||||||
zap.String("original_model", originalModel),
|
zap.String("original_model", originalModel),
|
||||||
zap.String("mapped_model", mappedModel),
|
zap.String("mapped_model", mappedModel),
|
||||||
zap.Bool("stream", clientStream),
|
zap.Bool("stream", clientStream),
|
||||||
)
|
}
|
||||||
|
if compatPromptCacheInjected {
|
||||||
|
logFields = append(logFields,
|
||||||
|
zap.Bool("compat_prompt_cache_key_injected", true),
|
||||||
|
zap.String("compat_prompt_cache_key_sha256", hashSensitiveValueForLog(promptCacheKey)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
logger.L().Debug("openai chat_completions: model mapping applied", logFields...)
|
||||||
|
|
||||||
// 4. Marshal Responses request body, then apply OAuth codex transform
|
// 4. Marshal Responses request body, then apply OAuth codex transform
|
||||||
responsesBody, err := json.Marshal(responsesReq)
|
responsesBody, err := json.Marshal(responsesReq)
|
||||||
|
|||||||
Reference in New Issue
Block a user