mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-12 10:54:45 +08:00
Merge branch 'test' into release
This commit is contained in:
@@ -308,6 +308,12 @@ type GatewayConfig struct {
|
||||
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
|
||||
// 请求体最大字节数,用于网关请求体大小限制
|
||||
MaxBodySize int64 `mapstructure:"max_body_size"`
|
||||
// 非流式上游响应体读取上限(字节),用于防止无界读取导致内存放大
|
||||
UpstreamResponseReadMaxBytes int64 `mapstructure:"upstream_response_read_max_bytes"`
|
||||
// 代理探测响应体读取上限(字节)
|
||||
ProxyProbeResponseReadMaxBytes int64 `mapstructure:"proxy_probe_response_read_max_bytes"`
|
||||
// Gemini 上游响应头调试日志开关(默认关闭,避免高频日志开销)
|
||||
GeminiDebugResponseHeaders bool `mapstructure:"gemini_debug_response_headers"`
|
||||
// ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy)
|
||||
ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"`
|
||||
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
|
||||
@@ -1059,6 +1065,9 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
|
||||
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
|
||||
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
|
||||
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
|
||||
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
|
||||
viper.SetDefault("gateway.gemini_debug_response_headers", false)
|
||||
viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024))
|
||||
viper.SetDefault("gateway.sora_stream_timeout_seconds", 900)
|
||||
viper.SetDefault("gateway.sora_request_timeout_seconds", 180)
|
||||
@@ -1465,6 +1474,12 @@ func (c *Config) Validate() error {
|
||||
if c.Gateway.MaxBodySize <= 0 {
|
||||
return fmt.Errorf("gateway.max_body_size must be positive")
|
||||
}
|
||||
if c.Gateway.UpstreamResponseReadMaxBytes <= 0 {
|
||||
return fmt.Errorf("gateway.upstream_response_read_max_bytes must be positive")
|
||||
}
|
||||
if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 {
|
||||
return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive")
|
||||
}
|
||||
if c.Gateway.SoraMaxBodySize < 0 {
|
||||
return fmt.Errorf("gateway.sora_max_body_size must be non-negative")
|
||||
}
|
||||
|
||||
@@ -1106,7 +1106,13 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Rate limit cleared successfully"})
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
}
|
||||
|
||||
// GetTempUnschedulable handles getting temporary unschedulable status
|
||||
|
||||
@@ -418,8 +418,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
continue
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
reqLog.Error("gateway.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("gateway.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -683,8 +687,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
continue
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
reqLog.Error("gateway.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("gateway.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1117,6 +1125,15 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
|
||||
h.errorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
|
||||
func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
|
||||
if c == nil || c.Writer == nil || c.Writer.Written() {
|
||||
return false
|
||||
}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
|
||||
return true
|
||||
}
|
||||
|
||||
// errorResponse 返回Claude API格式的错误响应
|
||||
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGatewayEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &GatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.True(t, wrote)
|
||||
require.Equal(t, http.StatusBadGateway, w.Code)
|
||||
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "error", parsed["type"])
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "upstream_error", errorObj["type"])
|
||||
assert.Equal(t, "Upstream request failed", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.String(http.StatusTeapot, "already written")
|
||||
|
||||
h := &GatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.False(t, wrote)
|
||||
require.Equal(t, http.StatusTeapot, w.Code)
|
||||
assert.Equal(t, "already written", w.Body.String())
|
||||
}
|
||||
@@ -365,8 +365,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
)
|
||||
continue
|
||||
}
|
||||
// Error response already handled in Forward, just log
|
||||
reqLog.Error("openai.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("openai.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -521,6 +525,15 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
|
||||
h.errorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
|
||||
func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
|
||||
if c == nil || c.Writer == nil || c.Writer.Written() {
|
||||
return false
|
||||
}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
|
||||
return true
|
||||
}
|
||||
|
||||
// errorResponse returns OpenAI API format error response
|
||||
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
|
||||
@@ -105,6 +105,42 @@ func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) {
|
||||
assert.Equal(t, "test error", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.True(t, wrote)
|
||||
require.Equal(t, http.StatusBadGateway, w.Code)
|
||||
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||
require.NoError(t, err)
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "upstream_error", errorObj["type"])
|
||||
assert.Equal(t, "Upstream request failed", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.String(http.StatusTeapot, "already written")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.False(t, wrote)
|
||||
require.Equal(t, http.StatusTeapot, w.Code)
|
||||
assert.Equal(t, "already written", w.Body.String())
|
||||
}
|
||||
|
||||
// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性
|
||||
func TestOpenAIHandler_GjsonExtraction(t *testing.T) {
|
||||
tests := []struct {
|
||||
|
||||
@@ -44,6 +44,16 @@ func GetClientIP(c *gin.Context) string {
|
||||
return normalizeIP(c.ClientIP())
|
||||
}
|
||||
|
||||
// GetTrustedClientIP 从 Gin 的可信代理解析链提取客户端 IP。
|
||||
// 该方法依赖 gin.Engine.SetTrustedProxies 配置,不会优先直接信任原始转发头值。
|
||||
// 适用于 ACL / 风控等安全敏感场景。
|
||||
func GetTrustedClientIP(c *gin.Context) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
return normalizeIP(c.ClientIP())
|
||||
}
|
||||
|
||||
// normalizeIP 规范化 IP 地址,去除端口号和空格。
|
||||
func normalizeIP(ip string) string {
|
||||
ip = strings.TrimSpace(ip)
|
||||
|
||||
@@ -3,8 +3,10 @@
|
||||
package ip
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -49,3 +51,25 @@ func TestIsPrivateIP(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
require.NoError(t, r.SetTrustedProxies(nil))
|
||||
|
||||
r.GET("/t", func(c *gin.Context) {
|
||||
c.String(200, GetTrustedClientIP(c))
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/t", nil)
|
||||
req.RemoteAddr = "9.9.9.9:12345"
|
||||
req.Header.Set("X-Forwarded-For", "1.2.3.4")
|
||||
req.Header.Set("X-Real-IP", "1.2.3.4")
|
||||
req.Header.Set("CF-Connecting-IP", "1.2.3.4")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, 200, w.Code)
|
||||
require.Equal(t, "9.9.9.9", w.Body.String())
|
||||
}
|
||||
|
||||
@@ -1194,7 +1194,7 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
||||
}
|
||||
// Keep list endpoints scoped to client errors unless explicitly filtering upstream phase.
|
||||
if phaseFilter != "upstream" {
|
||||
clauses = append(clauses, "COALESCE(status_code, 0) >= 400")
|
||||
clauses = append(clauses, "COALESCE(e.status_code, 0) >= 400")
|
||||
}
|
||||
|
||||
if filter.StartTime != nil && !filter.StartTime.IsZero() {
|
||||
@@ -1208,33 +1208,33 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
||||
}
|
||||
if p := strings.TrimSpace(filter.Platform); p != "" {
|
||||
args = append(args, p)
|
||||
clauses = append(clauses, "platform = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "e.platform = $"+itoa(len(args)))
|
||||
}
|
||||
if filter.GroupID != nil && *filter.GroupID > 0 {
|
||||
args = append(args, *filter.GroupID)
|
||||
clauses = append(clauses, "group_id = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "e.group_id = $"+itoa(len(args)))
|
||||
}
|
||||
if filter.AccountID != nil && *filter.AccountID > 0 {
|
||||
args = append(args, *filter.AccountID)
|
||||
clauses = append(clauses, "account_id = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "e.account_id = $"+itoa(len(args)))
|
||||
}
|
||||
if phase := phaseFilter; phase != "" {
|
||||
args = append(args, phase)
|
||||
clauses = append(clauses, "error_phase = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "e.error_phase = $"+itoa(len(args)))
|
||||
}
|
||||
if filter != nil {
|
||||
if owner := strings.TrimSpace(strings.ToLower(filter.Owner)); owner != "" {
|
||||
args = append(args, owner)
|
||||
clauses = append(clauses, "LOWER(COALESCE(error_owner,'')) = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "LOWER(COALESCE(e.error_owner,'')) = $"+itoa(len(args)))
|
||||
}
|
||||
if source := strings.TrimSpace(strings.ToLower(filter.Source)); source != "" {
|
||||
args = append(args, source)
|
||||
clauses = append(clauses, "LOWER(COALESCE(error_source,'')) = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "LOWER(COALESCE(e.error_source,'')) = $"+itoa(len(args)))
|
||||
}
|
||||
}
|
||||
if resolvedFilter != nil {
|
||||
args = append(args, *resolvedFilter)
|
||||
clauses = append(clauses, "COALESCE(resolved,false) = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "COALESCE(e.resolved,false) = $"+itoa(len(args)))
|
||||
}
|
||||
|
||||
// View filter: errors vs excluded vs all.
|
||||
@@ -1246,46 +1246,46 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
||||
}
|
||||
switch view {
|
||||
case "", "errors":
|
||||
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
|
||||
clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false")
|
||||
case "excluded":
|
||||
clauses = append(clauses, "COALESCE(is_business_limited,false) = true")
|
||||
clauses = append(clauses, "COALESCE(e.is_business_limited,false) = true")
|
||||
case "all":
|
||||
// no-op
|
||||
default:
|
||||
// treat unknown as default 'errors'
|
||||
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
|
||||
clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false")
|
||||
}
|
||||
if len(filter.StatusCodes) > 0 {
|
||||
args = append(args, pq.Array(filter.StatusCodes))
|
||||
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+")")
|
||||
clauses = append(clauses, "COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+")")
|
||||
} else if filter.StatusCodesOther {
|
||||
// "Other" means: status codes not in the common list.
|
||||
known := []int{400, 401, 403, 404, 409, 422, 429, 500, 502, 503, 504, 529}
|
||||
args = append(args, pq.Array(known))
|
||||
clauses = append(clauses, "NOT (COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+"))")
|
||||
clauses = append(clauses, "NOT (COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+"))")
|
||||
}
|
||||
// Exact correlation keys (preferred for request↔upstream linkage).
|
||||
if rid := strings.TrimSpace(filter.RequestID); rid != "" {
|
||||
args = append(args, rid)
|
||||
clauses = append(clauses, "COALESCE(request_id,'') = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "COALESCE(e.request_id,'') = $"+itoa(len(args)))
|
||||
}
|
||||
if crid := strings.TrimSpace(filter.ClientRequestID); crid != "" {
|
||||
args = append(args, crid)
|
||||
clauses = append(clauses, "COALESCE(client_request_id,'') = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "COALESCE(e.client_request_id,'') = $"+itoa(len(args)))
|
||||
}
|
||||
|
||||
if q := strings.TrimSpace(filter.Query); q != "" {
|
||||
like := "%" + q + "%"
|
||||
args = append(args, like)
|
||||
n := itoa(len(args))
|
||||
clauses = append(clauses, "(request_id ILIKE $"+n+" OR client_request_id ILIKE $"+n+" OR error_message ILIKE $"+n+")")
|
||||
clauses = append(clauses, "(e.request_id ILIKE $"+n+" OR e.client_request_id ILIKE $"+n+" OR e.error_message ILIKE $"+n+")")
|
||||
}
|
||||
|
||||
if userQuery := strings.TrimSpace(filter.UserQuery); userQuery != "" {
|
||||
like := "%" + userQuery + "%"
|
||||
args = append(args, like)
|
||||
n := itoa(len(args))
|
||||
clauses = append(clauses, "u.email ILIKE $"+n)
|
||||
clauses = append(clauses, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $"+n+")")
|
||||
}
|
||||
|
||||
return "WHERE " + strings.Join(clauses, " AND "), args
|
||||
|
||||
48
backend/internal/repository/ops_repo_error_where_test.go
Normal file
48
backend/internal/repository/ops_repo_error_where_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func TestBuildOpsErrorLogsWhere_QueryUsesQualifiedColumns(t *testing.T) {
|
||||
filter := &service.OpsErrorLogFilter{
|
||||
Query: "ACCESS_DENIED",
|
||||
}
|
||||
|
||||
where, args := buildOpsErrorLogsWhere(filter)
|
||||
if where == "" {
|
||||
t.Fatalf("where should not be empty")
|
||||
}
|
||||
if len(args) != 1 {
|
||||
t.Fatalf("args len = %d, want 1", len(args))
|
||||
}
|
||||
if !strings.Contains(where, "e.request_id ILIKE $") {
|
||||
t.Fatalf("where should include qualified request_id condition: %s", where)
|
||||
}
|
||||
if !strings.Contains(where, "e.client_request_id ILIKE $") {
|
||||
t.Fatalf("where should include qualified client_request_id condition: %s", where)
|
||||
}
|
||||
if !strings.Contains(where, "e.error_message ILIKE $") {
|
||||
t.Fatalf("where should include qualified error_message condition: %s", where)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpsErrorLogsWhere_UserQueryUsesExistsSubquery(t *testing.T) {
|
||||
filter := &service.OpsErrorLogFilter{
|
||||
UserQuery: "admin@",
|
||||
}
|
||||
|
||||
where, args := buildOpsErrorLogsWhere(filter)
|
||||
if where == "" {
|
||||
t.Fatalf("where should not be empty")
|
||||
}
|
||||
if len(args) != 1 {
|
||||
t.Fatalf("args len = %d, want 1", len(args))
|
||||
}
|
||||
if !strings.Contains(where, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $") {
|
||||
t.Fatalf("where should include EXISTS user email condition: %s", where)
|
||||
}
|
||||
}
|
||||
@@ -19,10 +19,14 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
||||
insecure := false
|
||||
allowPrivate := false
|
||||
validateResolvedIP := true
|
||||
maxResponseBytes := defaultProxyProbeResponseMaxBytes
|
||||
if cfg != nil {
|
||||
insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
|
||||
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
|
||||
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
|
||||
if cfg.Gateway.ProxyProbeResponseReadMaxBytes > 0 {
|
||||
maxResponseBytes = cfg.Gateway.ProxyProbeResponseReadMaxBytes
|
||||
}
|
||||
}
|
||||
if insecure {
|
||||
log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
|
||||
@@ -31,11 +35,13 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
||||
insecureSkipVerify: insecure,
|
||||
allowPrivateHosts: allowPrivate,
|
||||
validateResolvedIP: validateResolvedIP,
|
||||
maxResponseBytes: maxResponseBytes,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
defaultProxyProbeTimeout = 30 * time.Second
|
||||
defaultProxyProbeTimeout = 30 * time.Second
|
||||
defaultProxyProbeResponseMaxBytes = int64(1024 * 1024)
|
||||
)
|
||||
|
||||
// probeURLs 按优先级排列的探测 URL 列表
|
||||
@@ -52,6 +58,7 @@ type proxyProbeService struct {
|
||||
insecureSkipVerify bool
|
||||
allowPrivateHosts bool
|
||||
validateResolvedIP bool
|
||||
maxResponseBytes int64
|
||||
}
|
||||
|
||||
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
||||
@@ -98,10 +105,17 @@ func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Clien
|
||||
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
maxResponseBytes := s.maxResponseBytes
|
||||
if maxResponseBytes <= 0 {
|
||||
maxResponseBytes = defaultProxyProbeResponseMaxBytes
|
||||
}
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1))
|
||||
if err != nil {
|
||||
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
if int64(len(body)) > maxResponseBytes {
|
||||
return nil, latencyMs, fmt.Errorf("proxy probe response exceeds limit: %d", maxResponseBytes)
|
||||
}
|
||||
|
||||
switch parser {
|
||||
case "ip-api":
|
||||
|
||||
@@ -51,6 +51,9 @@ func ProvideRouter(
|
||||
if err := r.SetTrustedProxies(nil); err != nil {
|
||||
log.Printf("Failed to disable trusted proxies: %v", err)
|
||||
}
|
||||
if cfg.Server.Mode == "release" {
|
||||
log.Printf("Warning: server.trusted_proxies is empty in release mode; client IP trust chain is disabled")
|
||||
}
|
||||
}
|
||||
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
|
||||
|
||||
@@ -96,7 +96,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
// 检查 IP 限制(白名单/黑名单)
|
||||
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
|
||||
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
|
||||
clientIP := ip.GetClientIP(c)
|
||||
clientIP := ip.GetTrustedClientIP(c)
|
||||
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
|
||||
if !allowed {
|
||||
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
|
||||
|
||||
@@ -300,6 +300,57 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
IPWhitelist: []string{"1.2.3.4"},
|
||||
}
|
||||
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
router := gin.New()
|
||||
require.NoError(t, router.SetTrustedProxies(nil))
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.RemoteAddr = "9.9.9.9:12345"
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
req.Header.Set("X-Forwarded-For", "1.2.3.4")
|
||||
req.Header.Set("X-Real-IP", "1.2.3.4")
|
||||
req.Header.Set("CF-Connecting-IP", "1.2.3.4")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusForbidden, w.Code)
|
||||
require.Contains(t, w.Body.String(), "ACCESS_DENIED")
|
||||
}
|
||||
|
||||
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
|
||||
|
||||
@@ -24,10 +24,19 @@ func RegisterAuthRoutes(
|
||||
// 公开接口
|
||||
auth := v1.Group("/auth")
|
||||
{
|
||||
auth.POST("/register", h.Auth.Register)
|
||||
auth.POST("/login", h.Auth.Login)
|
||||
auth.POST("/login/2fa", h.Auth.Login2FA)
|
||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||
// 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close)
|
||||
auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.Register)
|
||||
auth.POST("/login", rateLimiter.LimitWithOptions("auth-login", 20, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.Login)
|
||||
auth.POST("/login/2fa", rateLimiter.LimitWithOptions("auth-login-2fa", 20, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.Login2FA)
|
||||
auth.POST("/send-verify-code", rateLimiter.LimitWithOptions("auth-send-verify-code", 5, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.SendVerifyCode)
|
||||
// Token刷新接口添加速率限制:每分钟最多 30 次(Redis 故障时 fail-close)
|
||||
auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
//go:build integration
|
||||
|
||||
package routes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
|
||||
)
|
||||
|
||||
const authRouteRedisImageTag = "redis:8.4-alpine"
|
||||
|
||||
func TestAuthRegisterRateLimitThresholdHitReturns429(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rdb := startAuthRouteRedis(t, ctx)
|
||||
|
||||
router := newAuthRoutesTestRouter(rdb)
|
||||
const path = "/api/v1/auth/register"
|
||||
|
||||
for i := 1; i <= 6; i++ {
|
||||
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.RemoteAddr = "198.51.100.10:23456"
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if i <= 5 {
|
||||
require.Equal(t, http.StatusBadRequest, w.Code, "第 %d 次请求应先进入业务校验", i)
|
||||
continue
|
||||
}
|
||||
require.Equal(t, http.StatusTooManyRequests, w.Code, "第 6 次请求应命中限流")
|
||||
require.Contains(t, w.Body.String(), "rate limit exceeded")
|
||||
}
|
||||
}
|
||||
|
||||
func startAuthRouteRedis(t *testing.T, ctx context.Context) *redis.Client {
|
||||
t.Helper()
|
||||
ensureAuthRouteDockerAvailable(t)
|
||||
|
||||
redisContainer, err := tcredis.Run(ctx, authRouteRedisImageTag)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = redisContainer.Terminate(ctx)
|
||||
})
|
||||
|
||||
redisHost, err := redisContainer.Host(ctx)
|
||||
require.NoError(t, err)
|
||||
redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
|
||||
require.NoError(t, err)
|
||||
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()),
|
||||
DB: 0,
|
||||
})
|
||||
require.NoError(t, rdb.Ping(ctx).Err())
|
||||
t.Cleanup(func() {
|
||||
_ = rdb.Close()
|
||||
})
|
||||
return rdb
|
||||
}
|
||||
|
||||
func ensureAuthRouteDockerAvailable(t *testing.T) {
|
||||
t.Helper()
|
||||
if authRouteDockerAvailable() {
|
||||
return
|
||||
}
|
||||
t.Skip("Docker 未启用,跳过认证限流集成测试")
|
||||
}
|
||||
|
||||
func authRouteDockerAvailable() bool {
|
||||
if os.Getenv("DOCKER_HOST") != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
socketCandidates := []string{
|
||||
"/var/run/docker.sock",
|
||||
filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"),
|
||||
filepath.Join(authRouteUserHomeDir(), ".docker", "run", "docker.sock"),
|
||||
filepath.Join(authRouteUserHomeDir(), ".docker", "desktop", "docker.sock"),
|
||||
filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"),
|
||||
}
|
||||
|
||||
for _, socket := range socketCandidates {
|
||||
if socket == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := os.Stat(socket); err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func authRouteUserHomeDir() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return home
|
||||
}
|
||||
67
backend/internal/server/routes/auth_rate_limit_test.go
Normal file
67
backend/internal/server/routes/auth_rate_limit_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
v1 := router.Group("/api/v1")
|
||||
|
||||
RegisterAuthRoutes(
|
||||
v1,
|
||||
&handler.Handlers{
|
||||
Auth: &handler.AuthHandler{},
|
||||
Setting: &handler.SettingHandler{},
|
||||
},
|
||||
servermiddleware.JWTAuthMiddleware(func(c *gin.Context) {
|
||||
c.Next()
|
||||
}),
|
||||
redisClient,
|
||||
)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable(t *testing.T) {
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "127.0.0.1:1",
|
||||
DialTimeout: 50 * time.Millisecond,
|
||||
ReadTimeout: 50 * time.Millisecond,
|
||||
WriteTimeout: 50 * time.Millisecond,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = rdb.Close()
|
||||
})
|
||||
|
||||
router := newAuthRoutesTestRouter(rdb)
|
||||
paths := []string{
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/login",
|
||||
"/api/v1/auth/login/2fa",
|
||||
"/api/v1/auth/send-verify-code",
|
||||
}
|
||||
|
||||
for _, path := range paths {
|
||||
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.RemoteAddr = "203.0.113.10:12345"
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusTooManyRequests, w.Code, "path=%s", path)
|
||||
require.Contains(t, w.Body.String(), "rate limit exceeded", "path=%s", path)
|
||||
}
|
||||
}
|
||||
@@ -3332,7 +3332,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
|
||||
// 不需要重试(成功或不可重试的错误),跳出循环
|
||||
// DEBUG: 输出响应 headers(用于检测 rate limit 信息)
|
||||
if account.Platform == PlatformGemini && resp.StatusCode < 400 {
|
||||
if account.Platform == PlatformGemini && resp.StatusCode < 400 && s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
|
||||
logger.LegacyPrintf("service.gateway", "[DEBUG] Gemini API Response Headers for account %d:", account.ID)
|
||||
for k, v := range resp.Header {
|
||||
logger.LegacyPrintf("service.gateway", "[DEBUG] %s: %v", k, v)
|
||||
@@ -4467,8 +4467,19 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
// 更新5h窗口状态
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream response too large",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -4990,9 +5001,15 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// 读取响应体
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
|
||||
return err
|
||||
}
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||
return err
|
||||
}
|
||||
@@ -5007,9 +5024,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr == nil {
|
||||
resp = retryResp
|
||||
respBody, err = io.ReadAll(resp.Body)
|
||||
respBody, err = readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
|
||||
return err
|
||||
}
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2358,29 +2358,36 @@ type UpstreamHTTPResult struct {
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) {
|
||||
// Log response headers for debugging
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========")
|
||||
for key, values := range resp.Header {
|
||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
|
||||
if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========")
|
||||
for key, values := range resp.Header {
|
||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
|
||||
}
|
||||
}
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================")
|
||||
}
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================")
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream response too large",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if isOAuth {
|
||||
unwrappedBody, uwErr := unwrapGeminiResponse(respBody)
|
||||
if uwErr == nil {
|
||||
respBody = unwrappedBody
|
||||
}
|
||||
_ = json.Unmarshal(respBody, &parsed)
|
||||
} else {
|
||||
_ = json.Unmarshal(respBody, &parsed)
|
||||
}
|
||||
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
@@ -2398,14 +2405,15 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) {
|
||||
// Log response headers for debugging
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========")
|
||||
for key, values := range resp.Header {
|
||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
|
||||
if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========")
|
||||
for key, values := range resp.Header {
|
||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
|
||||
}
|
||||
}
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================")
|
||||
}
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================")
|
||||
|
||||
if s.cfg != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
|
||||
@@ -3,10 +3,15 @@ package service
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -133,6 +138,38 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLogs(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
GeminiDebugResponseHeaders: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"X-RateLimit-Limit": []string{"60"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(`{"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":2}}`)),
|
||||
}
|
||||
|
||||
usage, err := svc.handleNativeNonStreamingResponse(c, resp, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.False(t, logSink.ContainsMessage("[GeminiAPI]"), "debug 关闭时不应输出 Gemini 响应头日志")
|
||||
}
|
||||
|
||||
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
|
||||
claudeReq := map[string]any{
|
||||
"model": "claude-haiku-4-5-20251001",
|
||||
|
||||
@@ -313,7 +313,6 @@ func logCodexCLIOnlyDetection(ctx context.Context, c *gin.Context, account *Acco
|
||||
}
|
||||
log := logger.FromContext(ctx).With(fields...)
|
||||
if result.Matched {
|
||||
log.Warn("OpenAI codex_cli_only 允许官方客户端请求")
|
||||
return
|
||||
}
|
||||
log.Warn("OpenAI codex_cli_only 拒绝非官方客户端请求")
|
||||
@@ -1277,6 +1276,29 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
startTime time.Time,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
if account != nil && account.Type == AccountTypeOAuth {
|
||||
if rejectReason := detectOpenAIPassthroughInstructionsRejectReason(reqModel, body); rejectReason != "" {
|
||||
rejectMsg := "OpenAI codex passthrough requires a non-empty instructions field"
|
||||
setOpsUpstreamError(c, http.StatusForbidden, rejectMsg, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: http.StatusForbidden,
|
||||
Passthrough: true,
|
||||
Kind: "request_error",
|
||||
Message: rejectMsg,
|
||||
Detail: rejectReason,
|
||||
})
|
||||
logOpenAIPassthroughInstructionsRejected(ctx, c, account, reqModel, rejectReason, body)
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "forbidden_error",
|
||||
"message": rejectMsg,
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("openai passthrough rejected before upstream: %s", rejectReason)
|
||||
}
|
||||
|
||||
normalizedBody, normalized, err := normalizeOpenAIPassthroughOAuthBody(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1396,6 +1418,37 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
}, nil
|
||||
}
|
||||
|
||||
func logOpenAIPassthroughInstructionsRejected(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
reqModel string,
|
||||
rejectReason string,
|
||||
body []byte,
|
||||
) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
accountID := int64(0)
|
||||
accountName := ""
|
||||
accountType := ""
|
||||
if account != nil {
|
||||
accountID = account.ID
|
||||
accountName = strings.TrimSpace(account.Name)
|
||||
accountType = strings.TrimSpace(string(account.Type))
|
||||
}
|
||||
fields := []zap.Field{
|
||||
zap.String("component", "service.openai_gateway"),
|
||||
zap.Int64("account_id", accountID),
|
||||
zap.String("account_name", accountName),
|
||||
zap.String("account_type", accountType),
|
||||
zap.String("request_model", strings.TrimSpace(reqModel)),
|
||||
zap.String("reject_reason", strings.TrimSpace(rejectReason)),
|
||||
}
|
||||
fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, body)
|
||||
logger.FromContext(ctx).With(fields...).Warn("OpenAI passthrough 本地拦截:Codex 请求缺少有效 instructions")
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
@@ -1688,8 +1741,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
) (*OpenAIUsage, error) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream response too large",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -2318,8 +2381,18 @@ func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream response too large",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -2877,6 +2950,25 @@ func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) {
|
||||
return normalized, changed, nil
|
||||
}
|
||||
|
||||
func detectOpenAIPassthroughInstructionsRejectReason(reqModel string, body []byte) string {
|
||||
model := strings.ToLower(strings.TrimSpace(reqModel))
|
||||
if !strings.Contains(model, "codex") {
|
||||
return ""
|
||||
}
|
||||
|
||||
instructions := gjson.GetBytes(body, "instructions")
|
||||
if !instructions.Exists() {
|
||||
return "instructions_missing"
|
||||
}
|
||||
if instructions.Type != gjson.String {
|
||||
return "instructions_not_string"
|
||||
}
|
||||
if strings.TrimSpace(instructions.String()) == "" {
|
||||
return "instructions_empty"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string {
|
||||
reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
|
||||
if reasoningEffort == "" {
|
||||
|
||||
@@ -103,7 +103,7 @@ func TestLogCodexCLIOnlyDetection_NilSafety(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestLogCodexCLIOnlyDetection_LogsBothMatchedAndRejected(t *testing.T) {
|
||||
func TestLogCodexCLIOnlyDetection_OnlyLogsRejected(t *testing.T) {
|
||||
logSink, restore := captureStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
@@ -119,7 +119,7 @@ func TestLogCodexCLIOnlyDetection_LogsBothMatchedAndRejected(t *testing.T) {
|
||||
Reason: CodexClientRestrictionReasonNotMatchedUA,
|
||||
}, nil)
|
||||
|
||||
require.True(t, logSink.ContainsMessage("OpenAI codex_cli_only 允许官方客户端请求"))
|
||||
require.False(t, logSink.ContainsMessage("OpenAI codex_cli_only 允许官方客户端请求"))
|
||||
require.True(t, logSink.ContainsMessage("OpenAI codex_cli_only 拒绝非官方客户端请求"))
|
||||
}
|
||||
|
||||
@@ -131,7 +131,7 @@ func TestLogCodexCLIOnlyDetection_RejectedIncludesRequestDetails(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil))
|
||||
c.Request.Header.Set("User-Agent", "curl/8.0")
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown")
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Request.Header.Set("OpenAI-Beta", "assistants=v2")
|
||||
|
||||
@@ -143,7 +143,7 @@ func TestLogCodexCLIOnlyDetection_RejectedIncludesRequestDetails(t *testing.T) {
|
||||
Reason: CodexClientRestrictionReasonNotMatchedUA,
|
||||
}, body)
|
||||
|
||||
require.True(t, logSink.ContainsFieldValue("request_user_agent", "curl/8.0"))
|
||||
require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown"))
|
||||
require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.2"))
|
||||
require.True(t, logSink.ContainsFieldValue("request_query", "trace=1"))
|
||||
require.True(t, logSink.ContainsFieldValue("request_prompt_cache_key_sha256", hashSensitiveValueForLog("pc-123")))
|
||||
|
||||
@@ -164,7 +164,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormali
|
||||
c.Request.Header.Set("Proxy-Authorization", "Basic abc")
|
||||
c.Request.Header.Set("X-Test", "keep")
|
||||
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"input":[{"type":"text","text":"hi"}]}`)
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`)
|
||||
|
||||
upstreamSSE := strings.Join([]string{
|
||||
`data: {"type":"response.output_item.added","item":{"type":"tool_call","tool_calls":[{"function":{"name":"apply_patch"}}]}}`,
|
||||
@@ -211,6 +211,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormali
|
||||
// 1) 透传 OAuth 请求体与旧链路关键行为保持一致:store=false + stream=true。
|
||||
require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool())
|
||||
require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool())
|
||||
require.Equal(t, "local-test-instructions", strings.TrimSpace(gjson.GetBytes(upstream.lastBody, "instructions").String()))
|
||||
// 其余关键字段保持原值。
|
||||
require.Equal(t, "gpt-5.2", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
require.Equal(t, "hi", gjson.GetBytes(upstream.lastBody, "input.0.text").String())
|
||||
@@ -235,6 +236,59 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormali
|
||||
require.NotContains(t, body, "\"name\":\"edit\"")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil))
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown")
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Request.Header.Set("OpenAI-Beta", "responses=experimental")
|
||||
|
||||
// Codex 模型且缺少 instructions,应在本地直接 403 拒绝,不触达上游。
|
||||
originalBody := []byte(`{"model":"gpt-5.1-codex-max","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`)
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)),
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Name: "acc",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||
Extra: map[string]any{"openai_passthrough": true},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateMultiplier: f64p(1),
|
||||
}
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, originalBody)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Equal(t, http.StatusForbidden, rec.Code)
|
||||
require.Contains(t, rec.Body.String(), "requires a non-empty instructions field")
|
||||
require.Nil(t, upstream.lastReq)
|
||||
|
||||
require.True(t, logSink.ContainsMessage("OpenAI passthrough 本地拦截:Codex 请求缺少有效 instructions"))
|
||||
require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown"))
|
||||
require.True(t, logSink.ContainsFieldValue("reject_reason", "instructions_missing"))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
38
backend/internal/service/upstream_response_limit.go
Normal file
38
backend/internal/service/upstream_response_limit.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large")
|
||||
|
||||
const defaultUpstreamResponseReadMaxBytes int64 = 8 * 1024 * 1024
|
||||
|
||||
func resolveUpstreamResponseReadLimit(cfg *config.Config) int64 {
|
||||
if cfg != nil && cfg.Gateway.UpstreamResponseReadMaxBytes > 0 {
|
||||
return cfg.Gateway.UpstreamResponseReadMaxBytes
|
||||
}
|
||||
return defaultUpstreamResponseReadMaxBytes
|
||||
}
|
||||
|
||||
func readUpstreamResponseBodyLimited(reader io.Reader, maxBytes int64) ([]byte, error) {
|
||||
if reader == nil {
|
||||
return nil, errors.New("response body is nil")
|
||||
}
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = defaultUpstreamResponseReadMaxBytes
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(reader, maxBytes+1))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if int64(len(body)) > maxBytes {
|
||||
return nil, fmt.Errorf("%w: limit=%d", ErrUpstreamResponseBodyTooLarge, maxBytes)
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
37
backend/internal/service/upstream_response_limit_test.go
Normal file
37
backend/internal/service/upstream_response_limit_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestResolveUpstreamResponseReadLimit(t *testing.T) {
|
||||
t.Run("use default when config missing", func(t *testing.T) {
|
||||
require.Equal(t, defaultUpstreamResponseReadMaxBytes, resolveUpstreamResponseReadLimit(nil))
|
||||
})
|
||||
|
||||
t.Run("use configured value", func(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
cfg.Gateway.UpstreamResponseReadMaxBytes = 1234
|
||||
require.Equal(t, int64(1234), resolveUpstreamResponseReadLimit(cfg))
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadUpstreamResponseBodyLimited(t *testing.T) {
|
||||
t.Run("within limit", func(t *testing.T) {
|
||||
body, err := readUpstreamResponseBodyLimited(bytes.NewReader([]byte("ok")), 2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("ok"), body)
|
||||
})
|
||||
|
||||
t.Run("exceeds limit", func(t *testing.T) {
|
||||
body, err := readUpstreamResponseBodyLimited(bytes.NewReader([]byte("toolong")), 3)
|
||||
require.Nil(t, body)
|
||||
require.Error(t, err)
|
||||
require.True(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge))
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user