diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 7fd24f97..23844508 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -121,7 +121,6 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { var lastFailoverErr *service.UpstreamFailoverError for { - c.Set("openai_chat_completions_fallback_model", "") reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( c.Request.Context(), @@ -139,32 +138,8 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { zap.Int("excluded_account_count", len(failedAccountIDs)), ) if len(failedAccountIDs) == 0 { - defaultModel := "" - if apiKey.Group != nil { - defaultModel = apiKey.Group.DefaultMappedModel - } - if defaultModel != "" && defaultModel != reqModel { - reqLog.Info("openai_chat_completions.fallback_to_default_model", - zap.String("default_mapped_model", defaultModel), - ) - selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler( - c.Request.Context(), - apiKey.GroupID, - "", - sessionHash, - defaultModel, - failedAccountIDs, - service.OpenAIUpstreamTransportAny, - false, - ) - if err == nil && selection != nil { - c.Set("openai_chat_completions_fallback_model", defaultModel) - } - } - if err != nil { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) - return - } + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) + return } else { if lastFailoverErr != nil { h.handleFailoverExhausted(c, lastFailoverErr, streamStarted) @@ -192,12 +167,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) forwardStart := time.Now() - defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model")) forwardBody := body if channelMapping.Mapped { forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) } - result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel) + result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "") forwardDurationMs := time.Since(forwardStart).Milliseconds() if accountReleaseFunc != nil { diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index c0de4476..b5eec393 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -37,16 +37,6 @@ type OpenAIGatewayHandler struct { cfg *config.Config } -func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackModel string) string { - if fallbackModel = strings.TrimSpace(fallbackModel); fallbackModel != "" { - return fallbackModel - } - if apiKey == nil || apiKey.Group == nil { - return "" - } - return strings.TrimSpace(apiKey.Group.DefaultMappedModel) -} - func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string { if apiKey == nil || apiKey.Group == nil { return "" diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go index 0e21dc08..2744e0cc 100644 --- a/backend/internal/handler/openai_gateway_handler_test.go +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -353,30 +353,6 @@ func TestOpenAIEnsureResponsesDependencies(t *testing.T) { }) } -func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) { - t.Run("prefers_explicit_fallback_model", func(t *testing.T) { - apiKey := &service.APIKey{ - Group: &service.Group{DefaultMappedModel: "gpt-5.4"}, - } - require.Equal(t, "gpt-5.2", resolveOpenAIForwardDefaultMappedModel(apiKey, " gpt-5.2 ")) - }) - - t.Run("uses_group_default_when_explicit_fallback_absent", func(t *testing.T) { - apiKey := &service.APIKey{ - Group: &service.Group{DefaultMappedModel: "gpt-5.4"}, - } - require.Equal(t, "gpt-5.4", resolveOpenAIForwardDefaultMappedModel(apiKey, "")) - }) - - t.Run("returns_empty_without_group_default", func(t *testing.T) { - require.Empty(t, resolveOpenAIForwardDefaultMappedModel(nil, "")) - require.Empty(t, resolveOpenAIForwardDefaultMappedModel(&service.APIKey{}, "")) - require.Empty(t, resolveOpenAIForwardDefaultMappedModel(&service.APIKey{ - Group: &service.Group{}, - }, "")) - }) -} - func TestResolveOpenAIMessagesDispatchMappedModel(t *testing.T) { t.Run("exact_claude_model_override_wins", func(t *testing.T) { apiKey := &service.APIKey{ diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 392b3e0b..cb502a2e 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -226,6 +226,12 @@ func (s *BillingService) initFallbackPricing() { CacheReadPricePerToken: 7.5e-8, SupportsCacheBreakdown: false, } + s.fallbackPrices["gpt-5.4-nano"] = &ModelPricing{ + InputPricePerToken: 2e-7, + OutputPricePerToken: 1.25e-6, + CacheReadPricePerToken: 2e-8, + SupportsCacheBreakdown: false, + } // OpenAI GPT-5.2(本地兜底) s.fallbackPrices["gpt-5.2"] = &ModelPricing{ InputPricePerToken: 1.75e-6, @@ -295,6 +301,8 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing { return s.fallbackPrices["gpt-5.5"] case "gpt-5.4-mini": return s.fallbackPrices["gpt-5.4-mini"] + case "gpt-5.4-nano": + return s.fallbackPrices["gpt-5.4-nano"] case "gpt-5.4": return s.fallbackPrices["gpt-5.4"] case "gpt-5.2": diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index b256f1c7..de98b50d 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -38,6 +38,29 @@ var codexModelMap = map[string]string{ "gpt-5.2-medium": "gpt-5.2", "gpt-5.2-high": "gpt-5.2", "gpt-5.2-xhigh": "gpt-5.2", + "gpt-5": "gpt-5.4", + "gpt-5-mini": "gpt-5.4", + "gpt-5-nano": "gpt-5.4", + "gpt-5.1": "gpt-5.4", + "gpt-5.1-codex": "gpt-5.3-codex", + "gpt-5.1-codex-max": "gpt-5.3-codex", + "gpt-5.1-codex-mini": "gpt-5.3-codex", + "gpt-5.2-codex": "gpt-5.2", + "codex-mini-latest": "gpt-5.3-codex", + "gpt-5-codex": "gpt-5.3-codex", +} + +var codexVersionModelPrefixes = []struct { + prefix string + target string +}{ + {prefix: "gpt-5.3-codex-spark", target: "gpt-5.3-codex-spark"}, + {prefix: "gpt-5.3-codex", target: "gpt-5.3-codex"}, + {prefix: "gpt-5.4-mini", target: "gpt-5.4-mini"}, + {prefix: "gpt-5.4-nano", target: "gpt-5.4-nano"}, + {prefix: "gpt-5.5", target: "gpt-5.5"}, + {prefix: "gpt-5.4", target: "gpt-5.4"}, + {prefix: "gpt-5.2", target: "gpt-5.2"}, } type codexTransformResult struct { @@ -447,8 +470,19 @@ func normalizeCodexModel(model string) string { if model == "" { return "gpt-5.4" } + if mapped, ok := normalizeKnownCodexModel(model); ok { + return mapped + } + return model +} + +func normalizeKnownCodexModel(model string) (string, bool) { + model = strings.TrimSpace(model) + if model == "" { + return "", false + } if isOpenAIImageGenerationModel(model) { - return model + return model, true } modelID := model @@ -457,41 +491,58 @@ func normalizeCodexModel(model string) string { modelID = parts[len(parts)-1] } - if mapped := getNormalizedCodexModel(modelID); mapped != "" { - return mapped + key := codexModelLookupKey(modelID) + if key == "" { + return "", false } + if mapped := getNormalizedCodexModel(key); mapped != "" { + return mapped, true + } + for _, item := range codexVersionModelPrefixes { + if key == item.prefix { + return item.target, true + } + suffix, ok := strings.CutPrefix(key, item.prefix+"-") + if ok && isKnownCodexModelSuffix(suffix) { + return item.target, true + } + } + return "", false +} - normalized := strings.ToLower(modelID) +func codexModelLookupKey(modelID string) string { + modelID = strings.TrimSpace(modelID) + if modelID == "" { + return "" + } + if strings.Contains(modelID, "/") { + parts := strings.Split(modelID, "/") + modelID = parts[len(parts)-1] + } + return strings.ToLower(strings.Join(strings.Fields(modelID), "-")) +} - if strings.Contains(normalized, "gpt-5.5") || strings.Contains(normalized, "gpt 5.5") { - return "gpt-5.5" - } - if strings.Contains(normalized, "gpt-5.4-mini") || strings.Contains(normalized, "gpt 5.4 mini") { - return "gpt-5.4-mini" - } - if strings.Contains(normalized, "gpt-5.4") || strings.Contains(normalized, "gpt 5.4") { - return "gpt-5.4" - } - if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") { - return "gpt-5.2" - } - if strings.Contains(normalized, "gpt-5.3-codex-spark") || strings.Contains(normalized, "gpt 5.3 codex spark") { - return "gpt-5.3-codex-spark" - } - if strings.Contains(normalized, "gpt-5.3-codex") || strings.Contains(normalized, "gpt 5.3 codex") { - return "gpt-5.3-codex" - } - if strings.Contains(normalized, "gpt-5.3") || strings.Contains(normalized, "gpt 5.3") { - return "gpt-5.3-codex" - } - if strings.Contains(normalized, "codex") { - return "gpt-5.3-codex" - } - if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") { - return "gpt-5.4" +func isKnownCodexModelSuffix(suffix string) bool { + switch suffix { + case "none", "minimal", "low", "medium", "high", "xhigh": + return true } + return isCodexDateSuffix(suffix) +} - return "gpt-5.4" +func isCodexDateSuffix(suffix string) bool { + parts := strings.Split(suffix, "-") + if len(parts) != 3 || len(parts[0]) != 4 || len(parts[1]) != 2 || len(parts[2]) != 2 { + return false + } + for _, part := range parts { + for _, r := range part { + if r < '0' || r > '9' { + return false + } + } + } + return true } func isCodexSparkModel(model string) bool { @@ -789,18 +840,13 @@ func SupportsVerbosity(model string) bool { } func getNormalizedCodexModel(modelID string) string { - if modelID == "" { + key := codexModelLookupKey(modelID) + if key == "" { return "" } - if mapped, ok := codexModelMap[modelID]; ok { + if mapped, ok := codexModelMap[key]; ok { return mapped } - lower := strings.ToLower(modelID) - for key, value := range codexModelMap { - if strings.ToLower(key) == lower { - return value - } - } return "" } diff --git a/backend/internal/service/openai_gateway_chat_completions_test.go b/backend/internal/service/openai_gateway_chat_completions_test.go index c129a4df..b0d1fa31 100644 --- a/backend/internal/service/openai_gateway_chat_completions_test.go +++ b/backend/internal/service/openai_gateway_chat_completions_test.go @@ -97,6 +97,42 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) { require.False(t, gjson.GetBytes(body, "service_tier").Exists()) } +func TestForwardAsChatCompletions_UnknownModelDoesNotUseDefaultMappedModel(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt6","messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusBadRequest, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_chat_unknown_model"}}, + Body: io.NopCloser(strings.NewReader(`{"error":{"type":"invalid_request_error","message":"model not found"}}`)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.4") + require.Error(t, err) + require.Nil(t, result) + require.Equal(t, "gpt6", gjson.GetBytes(upstream.lastBody, "model").String()) + require.NotEqual(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + func TestForwardAsChatCompletions_ClientDisconnectDrainsUpstreamUsage(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index 76fbb794..4722c82d 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -1006,9 +1006,8 @@ func TestOpenAIGatewayServiceRecordUsage_ChannelMappedDoesNotOverrideBillingMode svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10} - // When channel did NOT map the model (ChannelMappedModel == OriginalModel), - // billing should use result.BillingModel (the actual model used after group - // DefaultMappedModel resolution), not the unmapped original model. + // 渠道未发生模型映射时,应使用 result.BillingModel 中记录的实际上游计费模型, + // 而不是未映射的原始请求模型。 expectedCost, err := svc.billingService.CalculateCost("gpt-5.1", UsageTokens{ InputTokens: 20, OutputTokens: 10, diff --git a/backend/internal/service/openai_model_mapping.go b/backend/internal/service/openai_model_mapping.go index f332633c..7cec5212 100644 --- a/backend/internal/service/openai_model_mapping.go +++ b/backend/internal/service/openai_model_mapping.go @@ -2,44 +2,24 @@ package service import "strings" -// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible -// forwarding. Group-level default mapping only applies when the account itself -// did not match any explicit model_mapping rule. +// resolveOpenAIForwardModel 解析 OpenAI 兼容转发使用的模型。 +// defaultMappedModel 只服务于 /v1/messages 的 Claude 系列显式调度映射, +// 不作为普通 OpenAI 请求的未知模型兜底。 func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string { if account == nil { - if defaultMappedModel != "" { + if defaultMappedModel != "" && claudeMessagesDispatchFamily(requestedModel) != "" { return defaultMappedModel } return requestedModel } mappedModel, matched := account.ResolveMappedModel(requestedModel) - if !matched && defaultMappedModel != "" && !isExplicitCodexModel(requestedModel) { + if !matched && defaultMappedModel != "" && claudeMessagesDispatchFamily(requestedModel) != "" { return defaultMappedModel } return mappedModel } -func isExplicitCodexModel(model string) bool { - model = strings.TrimSpace(model) - if model == "" { - return false - } - if strings.Contains(model, "/") { - parts := strings.Split(model, "/") - model = parts[len(parts)-1] - } - model = strings.ToLower(strings.TrimSpace(model)) - if getNormalizedCodexModel(model) != "" { - return true - } - if strings.HasSuffix(model, "-openai-compact") { - base := strings.TrimSuffix(model, "-openai-compact") - return getNormalizedCodexModel(base) != "" - } - return false -} - // resolveOpenAICompactForwardModel determines the compact-only upstream model // for /responses/compact requests. It never affects normal /responses traffic. // When no compact-specific mapping matches, the input model is returned as-is. diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go index 4802c089..5c3e1ae0 100644 --- a/backend/internal/service/openai_model_mapping_test.go +++ b/backend/internal/service/openai_model_mapping_test.go @@ -11,7 +11,7 @@ func TestResolveOpenAIForwardModel(t *testing.T) { expectedModel string }{ { - name: "falls back to group default when account has no mapping", + name: "uses messages dispatch default for claude model", account: &Account{ Credentials: map[string]any{}, }, @@ -19,6 +19,15 @@ func TestResolveOpenAIForwardModel(t *testing.T) { defaultMappedModel: "gpt-4o-mini", expectedModel: "gpt-4o-mini", }, + { + name: "does not fall back to group default for invalid gpt model", + account: &Account{ + Credentials: map[string]any{}, + }, + requestedModel: "gpt6", + defaultMappedModel: "gpt-5.4", + expectedModel: "gpt6", + }, { name: "preserves explicit gpt-5.4 instead of group default", account: &Account{ @@ -119,14 +128,14 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt54(t * Credentials: map[string]any{}, } - withoutDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "")) - if withoutDefault != "gpt-5.4" { - t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.4") + withoutDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "") + if withoutDefault != "claude-opus-4-6" { + t.Fatalf("resolveOpenAIForwardModel(...) = %q, want %q", withoutDefault, "claude-opus-4-6") } - withDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4")) + withDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4") if withDefault != "gpt-5.4" { - t.Fatalf("normalizeCodexModel(...) = %q, want %q", withDefault, "gpt-5.4") + t.Fatalf("resolveOpenAIForwardModel(...) = %q, want %q", withDefault, "gpt-5.4") } } @@ -205,6 +214,10 @@ func TestNormalizeCodexModel(t *testing.T) { "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark", "gpt-5.3": "gpt-5.3-codex", "gpt-image-2": "gpt-image-2", + "gpt-5.4-nano": "gpt-5.4-nano", + "gpt-5.4-nano-high": "gpt-5.4-nano", + "gpt6": "gpt6", + "claude-opus-4-6": "claude-opus-4-6", } for input, expected := range cases { @@ -222,9 +235,21 @@ func TestNormalizeOpenAIModelForUpstream(t *testing.T) { want string }{ { - name: "oauth keeps codex normalization behavior", + name: "oauth preserves unknown non codex model", account: &Account{Type: AccountTypeOAuth}, model: "gemini-3-flash-preview", + want: "gemini-3-flash-preview", + }, + { + name: "oauth preserves invalid gpt model", + account: &Account{Type: AccountTypeOAuth}, + model: "gpt6", + want: "gpt6", + }, + { + name: "oauth normalizes known codex alias", + account: &Account{Type: AccountTypeOAuth}, + model: "gpt-5.4-high", want: "gpt-5.4", }, { diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index 87a05b14..cc9fc572 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -48,6 +48,49 @@ func (u *httpUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, acc return u.Do(req, proxyURL, accountID, accountConcurrency) } +func TestOpenAIGatewayService_ResponsesUnknownModelDoesNotFallbackToGPT54(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + originalBody := []byte(`{"model":"gpt6","stream":false,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(originalBody)) + c.Request.Header.Set("Content-Type", "application/json") + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusBadRequest, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_unknown_model"}}, + Body: io.NopCloser(strings.NewReader(`{"error":{"type":"invalid_request_error","message":"model not found"}}`)), + }} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + httpUpstream: upstream, + } + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + Status: StatusActive, + Schedulable: true, + } + + result, err := svc.Forward(context.Background(), c, account, originalBody) + require.Error(t, err) + require.Nil(t, result) + require.NotNil(t, upstream.lastReq) + require.Equal(t, "https://chatgpt.com/backend-api/codex/responses", upstream.lastReq.URL.String()) + require.Equal(t, "gpt6", gjson.GetBytes(upstream.lastBody, "model").String()) + require.NotEqual(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String()) + require.True(t, rec.Code >= http.StatusBadRequest) +} + type openAIPassthroughFailoverRepo struct { stubOpenAIAccountRepo rateLimitCalls []time.Time