From 7bd1972f945a14fed4bf981541fbcf65ed55e4b2 Mon Sep 17 00:00:00 2001 From: Ethan0x0000 <3352979663@qq.com> Date: Sun, 15 Mar 2026 22:13:42 +0800 Subject: [PATCH] refactor: migrate all handlers to shared endpoint normalization middleware - Apply InboundEndpointMiddleware to all gateway route groups - Replace normalizedOpenAIInboundEndpoint/normalizedOpenAIUpstreamEndpoint and normalizedGatewayInboundEndpoint/normalizedGatewayUpstreamEndpoint with GetInboundEndpoint/GetUpstreamEndpoint - Remove 4 old constants and 4 old normalization functions (-70 lines) - Migrate existing endpoint normalization test to new API Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode) Co-authored-by: Sisyphus --- backend/internal/handler/gateway_handler.go | 10 ++- .../internal/handler/gemini_v1beta_handler.go | 4 + .../handler/openai_chat_completions.go | 4 +- ...nai_gateway_endpoint_normalization_test.go | 43 ++++++----- .../handler/openai_gateway_handler.go | 75 ++----------------- .../internal/handler/sora_gateway_handler.go | 4 + backend/internal/server/routes/gateway.go | 14 +++- 7 files changed, 56 insertions(+), 98 deletions(-) diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 09652ada..831029c4 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -442,6 +442,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) requestPayloadHash := service.HashUsageRequestPayload(body) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) if result.ReasoningEffort == nil { result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort) @@ -455,6 +457,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { User: apiKey.User, Account: account, Subscription: subscription, + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, UserAgent: userAgent, IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, @@ -757,6 +761,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) requestPayloadHash := service.HashUsageRequestPayload(body) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) if result.ReasoningEffort == nil { result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort) @@ -770,6 +776,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { User: currentAPIKey.User, Account: account, Subscription: currentSubscription, + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, UserAgent: userAgent, IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, @@ -935,7 +943,7 @@ func (h *GatewayHandler) parseUsageDateRange(c *gin.Context) (time.Time, time.Ti } if s := c.Query("end_date"); s != "" { if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil { - endTime = t.Add(24*time.Hour - time.Second) // end of day + endTime = t.AddDate(0, 0, 1) // half-open range upper bound } } return startTime, endTime diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 9a16ff3a..cfe80911 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -504,6 +504,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 requestPayloadHash := service.HashUsageRequestPayload(body) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{ Result: result, @@ -511,6 +513,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { User: apiKey.User, Account: account, Subscription: subscription, + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, UserAgent: userAgent, IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 82b11c10..4db5cadd 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -261,8 +261,8 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { User: apiKey.User, Account: account, Subscription: subscription, - InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointChatCompletions), - UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses), + InboundEndpoint: GetInboundEndpoint(c), + UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), UserAgent: userAgent, IPAddress: clientIP, APIKeyService: h.apiKeyService, diff --git a/backend/internal/handler/openai_gateway_endpoint_normalization_test.go b/backend/internal/handler/openai_gateway_endpoint_normalization_test.go index 6a055272..0dacd74d 100644 --- a/backend/internal/handler/openai_gateway_endpoint_normalization_test.go +++ b/backend/internal/handler/openai_gateway_endpoint_normalization_test.go @@ -5,42 +5,41 @@ import ( "net/http/httptest" "testing" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) -func TestNormalizedOpenAIUpstreamEndpoint(t *testing.T) { +// TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint verifies that the +// unified GetUpstreamEndpoint helper produces the same results as the +// former normalizedOpenAIUpstreamEndpoint for OpenAI platform requests. +func TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint(t *testing.T) { gin.SetMode(gin.TestMode) tests := []struct { - name string - path string - fallback string - want string + name string + path string + want string }{ { - name: "responses root maps to responses upstream", - path: "/v1/responses", - fallback: openAIUpstreamEndpointResponses, - want: "/v1/responses", + name: "responses root maps to responses upstream", + path: "/v1/responses", + want: EndpointResponses, }, { - name: "responses compact keeps compact suffix", - path: "/openai/v1/responses/compact", - fallback: openAIUpstreamEndpointResponses, - want: "/v1/responses/compact", + name: "responses compact keeps compact suffix", + path: "/openai/v1/responses/compact", + want: "/v1/responses/compact", }, { - name: "responses nested suffix preserved", - path: "/openai/v1/responses/compact/detail", - fallback: openAIUpstreamEndpointResponses, - want: "/v1/responses/compact/detail", + name: "responses nested suffix preserved", + path: "/openai/v1/responses/compact/detail", + want: "/v1/responses/compact/detail", }, { - name: "non responses path uses fallback", - path: "/v1/messages", - fallback: openAIUpstreamEndpointResponses, - want: "/v1/responses", + name: "non responses path uses platform fallback", + path: "/v1/messages", + want: EndpointResponses, }, } @@ -50,7 +49,7 @@ func TestNormalizedOpenAIUpstreamEndpoint(t *testing.T) { c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, tt.path, nil) - got := normalizedOpenAIUpstreamEndpoint(c, tt.fallback) + got := GetUpstreamEndpoint(c, service.PlatformOpenAI) require.Equal(t, tt.want, got) }) } diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index b2aa5c50..c681e61d 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -37,13 +37,6 @@ type OpenAIGatewayHandler struct { cfg *config.Config } -const ( - openAIInboundEndpointResponses = "/v1/responses" - openAIInboundEndpointMessages = "/v1/messages" - openAIInboundEndpointChatCompletions = "/v1/chat/completions" - openAIUpstreamEndpointResponses = "/v1/responses" -) - // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler func NewOpenAIGatewayHandler( gatewayService *service.OpenAIGatewayService, @@ -369,8 +362,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { User: apiKey.User, Account: account, Subscription: subscription, - InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointResponses), - UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses), + InboundEndpoint: GetInboundEndpoint(c), + UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), UserAgent: userAgent, IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, @@ -747,8 +740,8 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { User: apiKey.User, Account: account, Subscription: subscription, - InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointMessages), - UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses), + InboundEndpoint: GetInboundEndpoint(c), + UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), UserAgent: userAgent, IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, @@ -1246,8 +1239,8 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { User: apiKey.User, Account: account, Subscription: subscription, - InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointResponses), - UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses), + InboundEndpoint: GetInboundEndpoint(c), + UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), UserAgent: userAgent, IPAddress: clientIP, RequestPayloadHash: service.HashUsageRequestPayload(firstMessage), @@ -1543,62 +1536,6 @@ func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID) } -func normalizedOpenAIInboundEndpoint(c *gin.Context, fallback string) string { - path := strings.TrimSpace(fallback) - if c != nil { - if fullPath := strings.TrimSpace(c.FullPath()); fullPath != "" { - path = fullPath - } else if c.Request != nil && c.Request.URL != nil { - if requestPath := strings.TrimSpace(c.Request.URL.Path); requestPath != "" { - path = requestPath - } - } - } - - switch { - case strings.Contains(path, openAIInboundEndpointChatCompletions): - return openAIInboundEndpointChatCompletions - case strings.Contains(path, openAIInboundEndpointMessages): - return openAIInboundEndpointMessages - case strings.Contains(path, openAIInboundEndpointResponses): - return openAIInboundEndpointResponses - default: - return path - } -} - -func normalizedOpenAIUpstreamEndpoint(c *gin.Context, fallback string) string { - base := strings.TrimSpace(fallback) - if base == "" { - base = openAIUpstreamEndpointResponses - } - base = strings.TrimRight(base, "/") - - if c == nil || c.Request == nil || c.Request.URL == nil { - return base - } - - path := strings.TrimRight(strings.TrimSpace(c.Request.URL.Path), "/") - if path == "" { - return base - } - - idx := strings.LastIndex(path, "/responses") - if idx < 0 { - return base - } - - suffix := strings.TrimSpace(path[idx+len("/responses"):]) - if suffix == "" || suffix == "/" { - return base - } - if !strings.HasPrefix(suffix, "/") { - return base - } - - return base + suffix -} - func isOpenAIWSUpgradeRequest(r *http.Request) bool { if r == nil { return false diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index 06abdf60..dc301ce1 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -400,6 +400,8 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) requestPayloadHash := service.HashUsageRequestPayload(body) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 h.submitUsageRecordTask(func(ctx context.Context) { @@ -409,6 +411,8 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { User: apiKey.User, Account: account, Subscription: subscription, + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, UserAgent: userAgent, IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index ea40f2f1..fe820830 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -30,6 +30,7 @@ func RegisterGatewayRoutes( soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize) clientRequestID := middleware.ClientRequestID() opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService) + endpointNorm := handler.InboundEndpointMiddleware() // 未分组 Key 拦截中间件(按协议格式区分错误响应) requireGroupAnthropic := middleware.RequireGroupAssignment(settingService, middleware.AnthropicErrorWriter) @@ -40,6 +41,7 @@ func RegisterGatewayRoutes( gateway.Use(bodyLimit) gateway.Use(clientRequestID) gateway.Use(opsErrorLogger) + gateway.Use(endpointNorm) gateway.Use(gin.HandlerFunc(apiKeyAuth)) gateway.Use(requireGroupAnthropic) { @@ -80,6 +82,7 @@ func RegisterGatewayRoutes( gemini.Use(bodyLimit) gemini.Use(clientRequestID) gemini.Use(opsErrorLogger) + gemini.Use(endpointNorm) gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) gemini.Use(requireGroupGoogle) { @@ -90,11 +93,11 @@ func RegisterGatewayRoutes( } // OpenAI Responses API(不带v1前缀的别名) - r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) - r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) - r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket) + r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) + r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) + r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket) // OpenAI Chat Completions API(不带v1前缀的别名) - r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions) + r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions) // Antigravity 模型列表 r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels) @@ -104,6 +107,7 @@ func RegisterGatewayRoutes( antigravityV1.Use(bodyLimit) antigravityV1.Use(clientRequestID) antigravityV1.Use(opsErrorLogger) + antigravityV1.Use(endpointNorm) antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity)) antigravityV1.Use(gin.HandlerFunc(apiKeyAuth)) antigravityV1.Use(requireGroupAnthropic) @@ -118,6 +122,7 @@ func RegisterGatewayRoutes( antigravityV1Beta.Use(bodyLimit) antigravityV1Beta.Use(clientRequestID) antigravityV1Beta.Use(opsErrorLogger) + antigravityV1Beta.Use(endpointNorm) antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity)) antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) antigravityV1Beta.Use(requireGroupGoogle) @@ -132,6 +137,7 @@ func RegisterGatewayRoutes( soraV1.Use(soraBodyLimit) soraV1.Use(clientRequestID) soraV1.Use(opsErrorLogger) + soraV1.Use(endpointNorm) soraV1.Use(middleware.ForcePlatform(service.PlatformSora)) soraV1.Use(gin.HandlerFunc(apiKeyAuth)) soraV1.Use(requireGroupAnthropic)