refactor: migrate all handlers to shared endpoint normalization middleware

- Apply InboundEndpointMiddleware to all gateway route groups
- Replace normalizedOpenAIInboundEndpoint/normalizedOpenAIUpstreamEndpoint and normalizedGatewayInboundEndpoint/normalizedGatewayUpstreamEndpoint with GetInboundEndpoint/GetUpstreamEndpoint
- Remove 4 old constants and 4 old normalization functions (-70 lines)
- Migrate existing endpoint normalization test to new API

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
Ethan0x0000
2026-03-15 22:13:42 +08:00
parent 2c9dcfe27b
commit 7bd1972f94
7 changed files with 56 additions and 98 deletions

View File

@@ -442,6 +442,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c) clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body) requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
if result.ReasoningEffort == nil { if result.ReasoningEffort == nil {
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort) result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
@@ -455,6 +457,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
User: apiKey.User, User: apiKey.User,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
@@ -757,6 +761,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c) clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body) requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
if result.ReasoningEffort == nil { if result.ReasoningEffort == nil {
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort) result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
@@ -770,6 +776,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
User: currentAPIKey.User, User: currentAPIKey.User,
Account: account, Account: account,
Subscription: currentSubscription, Subscription: currentSubscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
@@ -935,7 +943,7 @@ func (h *GatewayHandler) parseUsageDateRange(c *gin.Context) (time.Time, time.Ti
} }
if s := c.Query("end_date"); s != "" { if s := c.Query("end_date"); s != "" {
if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil { if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil {
endTime = t.Add(24*time.Hour - time.Second) // end of day endTime = t.AddDate(0, 0, 1) // half-open range upper bound
} }
} }
return startTime, endTime return startTime, endTime

View File

@@ -504,6 +504,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
requestPayloadHash := service.HashUsageRequestPayload(body) requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitUsageRecordTask(func(ctx context.Context) { h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{ if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
Result: result, Result: result,
@@ -511,6 +513,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
User: apiKey.User, User: apiKey.User,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,

View File

@@ -261,8 +261,8 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
User: apiKey.User, User: apiKey.User,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointChatCompletions), InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses), UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
UserAgent: userAgent, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,

View File

@@ -5,42 +5,41 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestNormalizedOpenAIUpstreamEndpoint(t *testing.T) { // TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint verifies that the
// unified GetUpstreamEndpoint helper produces the same results as the
// former normalizedOpenAIUpstreamEndpoint for OpenAI platform requests.
func TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
tests := []struct { tests := []struct {
name string name string
path string path string
fallback string want string
want string
}{ }{
{ {
name: "responses root maps to responses upstream", name: "responses root maps to responses upstream",
path: "/v1/responses", path: "/v1/responses",
fallback: openAIUpstreamEndpointResponses, want: EndpointResponses,
want: "/v1/responses",
}, },
{ {
name: "responses compact keeps compact suffix", name: "responses compact keeps compact suffix",
path: "/openai/v1/responses/compact", path: "/openai/v1/responses/compact",
fallback: openAIUpstreamEndpointResponses, want: "/v1/responses/compact",
want: "/v1/responses/compact",
}, },
{ {
name: "responses nested suffix preserved", name: "responses nested suffix preserved",
path: "/openai/v1/responses/compact/detail", path: "/openai/v1/responses/compact/detail",
fallback: openAIUpstreamEndpointResponses, want: "/v1/responses/compact/detail",
want: "/v1/responses/compact/detail",
}, },
{ {
name: "non responses path uses fallback", name: "non responses path uses platform fallback",
path: "/v1/messages", path: "/v1/messages",
fallback: openAIUpstreamEndpointResponses, want: EndpointResponses,
want: "/v1/responses",
}, },
} }
@@ -50,7 +49,7 @@ func TestNormalizedOpenAIUpstreamEndpoint(t *testing.T) {
c, _ := gin.CreateTestContext(rec) c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, tt.path, nil) c.Request = httptest.NewRequest(http.MethodPost, tt.path, nil)
got := normalizedOpenAIUpstreamEndpoint(c, tt.fallback) got := GetUpstreamEndpoint(c, service.PlatformOpenAI)
require.Equal(t, tt.want, got) require.Equal(t, tt.want, got)
}) })
} }

View File

@@ -37,13 +37,6 @@ type OpenAIGatewayHandler struct {
cfg *config.Config cfg *config.Config
} }
const (
openAIInboundEndpointResponses = "/v1/responses"
openAIInboundEndpointMessages = "/v1/messages"
openAIInboundEndpointChatCompletions = "/v1/chat/completions"
openAIUpstreamEndpointResponses = "/v1/responses"
)
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
func NewOpenAIGatewayHandler( func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService, gatewayService *service.OpenAIGatewayService,
@@ -369,8 +362,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
User: apiKey.User, User: apiKey.User,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointResponses), InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses), UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
UserAgent: userAgent, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
@@ -747,8 +740,8 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
User: apiKey.User, User: apiKey.User,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointMessages), InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses), UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
UserAgent: userAgent, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
@@ -1246,8 +1239,8 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
User: apiKey.User, User: apiKey.User,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointResponses), InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses), UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
UserAgent: userAgent, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage), RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
@@ -1543,62 +1536,6 @@ func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64)
return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID) return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID)
} }
func normalizedOpenAIInboundEndpoint(c *gin.Context, fallback string) string {
path := strings.TrimSpace(fallback)
if c != nil {
if fullPath := strings.TrimSpace(c.FullPath()); fullPath != "" {
path = fullPath
} else if c.Request != nil && c.Request.URL != nil {
if requestPath := strings.TrimSpace(c.Request.URL.Path); requestPath != "" {
path = requestPath
}
}
}
switch {
case strings.Contains(path, openAIInboundEndpointChatCompletions):
return openAIInboundEndpointChatCompletions
case strings.Contains(path, openAIInboundEndpointMessages):
return openAIInboundEndpointMessages
case strings.Contains(path, openAIInboundEndpointResponses):
return openAIInboundEndpointResponses
default:
return path
}
}
func normalizedOpenAIUpstreamEndpoint(c *gin.Context, fallback string) string {
base := strings.TrimSpace(fallback)
if base == "" {
base = openAIUpstreamEndpointResponses
}
base = strings.TrimRight(base, "/")
if c == nil || c.Request == nil || c.Request.URL == nil {
return base
}
path := strings.TrimRight(strings.TrimSpace(c.Request.URL.Path), "/")
if path == "" {
return base
}
idx := strings.LastIndex(path, "/responses")
if idx < 0 {
return base
}
suffix := strings.TrimSpace(path[idx+len("/responses"):])
if suffix == "" || suffix == "/" {
return base
}
if !strings.HasPrefix(suffix, "/") {
return base
}
return base + suffix
}
func isOpenAIWSUpgradeRequest(r *http.Request) bool { func isOpenAIWSUpgradeRequest(r *http.Request) bool {
if r == nil { if r == nil {
return false return false

View File

@@ -400,6 +400,8 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c) clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body) requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) { h.submitUsageRecordTask(func(ctx context.Context) {
@@ -409,6 +411,8 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
User: apiKey.User, User: apiKey.User,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,

View File

@@ -30,6 +30,7 @@ func RegisterGatewayRoutes(
soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize) soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize)
clientRequestID := middleware.ClientRequestID() clientRequestID := middleware.ClientRequestID()
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService) opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
endpointNorm := handler.InboundEndpointMiddleware()
// 未分组 Key 拦截中间件(按协议格式区分错误响应) // 未分组 Key 拦截中间件(按协议格式区分错误响应)
requireGroupAnthropic := middleware.RequireGroupAssignment(settingService, middleware.AnthropicErrorWriter) requireGroupAnthropic := middleware.RequireGroupAssignment(settingService, middleware.AnthropicErrorWriter)
@@ -40,6 +41,7 @@ func RegisterGatewayRoutes(
gateway.Use(bodyLimit) gateway.Use(bodyLimit)
gateway.Use(clientRequestID) gateway.Use(clientRequestID)
gateway.Use(opsErrorLogger) gateway.Use(opsErrorLogger)
gateway.Use(endpointNorm)
gateway.Use(gin.HandlerFunc(apiKeyAuth)) gateway.Use(gin.HandlerFunc(apiKeyAuth))
gateway.Use(requireGroupAnthropic) gateway.Use(requireGroupAnthropic)
{ {
@@ -80,6 +82,7 @@ func RegisterGatewayRoutes(
gemini.Use(bodyLimit) gemini.Use(bodyLimit)
gemini.Use(clientRequestID) gemini.Use(clientRequestID)
gemini.Use(opsErrorLogger) gemini.Use(opsErrorLogger)
gemini.Use(endpointNorm)
gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
gemini.Use(requireGroupGoogle) gemini.Use(requireGroupGoogle)
{ {
@@ -90,11 +93,11 @@ func RegisterGatewayRoutes(
} }
// OpenAI Responses API不带v1前缀的别名 // OpenAI Responses API不带v1前缀的别名
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket) r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
// OpenAI Chat Completions API不带v1前缀的别名 // OpenAI Chat Completions API不带v1前缀的别名
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions) r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions)
// Antigravity 模型列表 // Antigravity 模型列表
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels) r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
@@ -104,6 +107,7 @@ func RegisterGatewayRoutes(
antigravityV1.Use(bodyLimit) antigravityV1.Use(bodyLimit)
antigravityV1.Use(clientRequestID) antigravityV1.Use(clientRequestID)
antigravityV1.Use(opsErrorLogger) antigravityV1.Use(opsErrorLogger)
antigravityV1.Use(endpointNorm)
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity)) antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth)) antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
antigravityV1.Use(requireGroupAnthropic) antigravityV1.Use(requireGroupAnthropic)
@@ -118,6 +122,7 @@ func RegisterGatewayRoutes(
antigravityV1Beta.Use(bodyLimit) antigravityV1Beta.Use(bodyLimit)
antigravityV1Beta.Use(clientRequestID) antigravityV1Beta.Use(clientRequestID)
antigravityV1Beta.Use(opsErrorLogger) antigravityV1Beta.Use(opsErrorLogger)
antigravityV1Beta.Use(endpointNorm)
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity)) antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
antigravityV1Beta.Use(requireGroupGoogle) antigravityV1Beta.Use(requireGroupGoogle)
@@ -132,6 +137,7 @@ func RegisterGatewayRoutes(
soraV1.Use(soraBodyLimit) soraV1.Use(soraBodyLimit)
soraV1.Use(clientRequestID) soraV1.Use(clientRequestID)
soraV1.Use(opsErrorLogger) soraV1.Use(opsErrorLogger)
soraV1.Use(endpointNorm)
soraV1.Use(middleware.ForcePlatform(service.PlatformSora)) soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
soraV1.Use(gin.HandlerFunc(apiKeyAuth)) soraV1.Use(gin.HandlerFunc(apiKeyAuth))
soraV1.Use(requireGroupAnthropic) soraV1.Use(requireGroupAnthropic)