mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-21 07:04:45 +08:00
Merge tag 'v0.1.90' into merge/upstream-v0.1.90
注册邮箱域名白名单策略上线,后台大数据场景性能大幅优化。 - 注册邮箱域名白名单:支持管理员配置允许注册的邮箱域名策略 - Keys 页面表单筛选:用户 /keys 页面支持按条件筛选 API Key - Settings 页面分 Tab 拆分:管理后台设置页面按功能模块分 Tab 展示 - 后台大数据场景加载性能优化:仪表盘/用户/账号/Ops 页面大数据集加载显著提速 - Usage 大表分页优化:默认避免全量 COUNT(*),大幅降低分页查询耗时 - 消除重复的 normalizeAccountIDList,补充新增组件的单元测试 - 清理无用文件和过时文档,精简项目结构 - EmailVerifyView 硬编码英文字符串替换为 i18n 调用 - 修复 Anthropic 平台无限流重置时间的 429 误标记账号限流问题 - 修复自定义菜单页面管理员视角菜单不生效问题 - 修复 Ops 错误详情弹窗未展示真实上游 payload 的问题 - 修复充值/订阅菜单 icon 显示问题 # Conflicts: # .gitignore # backend/cmd/server/VERSION # backend/ent/group.go # backend/ent/runtime/runtime.go # backend/ent/schema/group.go # backend/go.sum # backend/internal/handler/admin/account_handler.go # backend/internal/handler/admin/dashboard_handler.go # backend/internal/pkg/usagestats/usage_log_types.go # backend/internal/repository/group_repo.go # backend/internal/repository/usage_log_repo.go # backend/internal/server/middleware/security_headers.go # backend/internal/server/router.go # backend/internal/service/account_usage_service.go # backend/internal/service/admin_service_bulk_update_test.go # backend/internal/service/dashboard_service.go # backend/internal/service/gateway_service.go # frontend/src/api/admin/dashboard.ts # frontend/src/components/account/BulkEditAccountModal.vue # frontend/src/components/charts/GroupDistributionChart.vue # frontend/src/components/layout/AppSidebar.vue # frontend/src/i18n/locales/en.ts # frontend/src/i18n/locales/zh.ts # frontend/src/views/admin/GroupsView.vue # frontend/src/views/admin/SettingsView.vue # frontend/src/views/admin/UsageView.vue # frontend/src/views/user/PurchaseSubscriptionView.vue
This commit is contained in:
@@ -152,6 +152,7 @@ var claudeModels = []modelDef{
|
||||
{ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"},
|
||||
{ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"},
|
||||
{ID: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-05T00:00:00Z"},
|
||||
{ID: "claude-opus-4-6-thinking", DisplayName: "Claude Opus 4.6 Thinking", CreatedAt: "2026-02-05T00:00:00Z"},
|
||||
{ID: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", CreatedAt: "2026-02-17T00:00:00Z"},
|
||||
}
|
||||
|
||||
@@ -165,6 +166,8 @@ var geminiModels = []modelDef{
|
||||
{ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3.1-pro-low", DisplayName: "Gemini 3.1 Pro Low", CreatedAt: "2026-02-19T00:00:00Z"},
|
||||
{ID: "gemini-3.1-pro-high", DisplayName: "Gemini 3.1 Pro High", CreatedAt: "2026-02-19T00:00:00Z"},
|
||||
{ID: "gemini-3.1-flash-image", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: "2026-02-19T00:00:00Z"},
|
||||
{ID: "gemini-3.1-flash-image-preview", DisplayName: "Gemini 3.1 Flash Image Preview", CreatedAt: "2026-02-19T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
}
|
||||
|
||||
26
backend/internal/pkg/antigravity/claude_types_test.go
Normal file
26
backend/internal/pkg/antigravity/claude_types_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package antigravity
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
models := DefaultModels()
|
||||
byID := make(map[string]ClaudeModel, len(models))
|
||||
for _, m := range models {
|
||||
byID[m.ID] = m
|
||||
}
|
||||
|
||||
requiredIDs := []string{
|
||||
"claude-opus-4-6-thinking",
|
||||
"gemini-3.1-flash-image",
|
||||
"gemini-3.1-flash-image-preview",
|
||||
"gemini-3-pro-image", // legacy compatibility
|
||||
}
|
||||
|
||||
for _, id := range requiredIDs {
|
||||
if _, ok := byID[id]; !ok {
|
||||
t.Fatalf("expected model %q to be exposed in DefaultModels", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,9 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||
)
|
||||
|
||||
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
|
||||
@@ -149,22 +152,26 @@ type Client struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewClient(proxyURL string) *Client {
|
||||
func NewClient(proxyURL string) (*Client, error) {
|
||||
client := &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
if strings.TrimSpace(proxyURL) != "" {
|
||||
if proxyURLParsed, err := url.Parse(proxyURL); err == nil {
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURLParsed),
|
||||
}
|
||||
_, parsed, err := proxyurl.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parsed != nil {
|
||||
transport := &http.Transport{}
|
||||
if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil {
|
||||
return nil, fmt.Errorf("configure proxy: %w", err)
|
||||
}
|
||||
client.Transport = transport
|
||||
}
|
||||
|
||||
return &Client{
|
||||
httpClient: client,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// isConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝)
|
||||
|
||||
@@ -228,8 +228,20 @@ func TestGetTier_两者都为nil(t *testing.T) {
|
||||
// NewClient
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func mustNewClient(t *testing.T, proxyURL string) *Client {
|
||||
t.Helper()
|
||||
client, err := NewClient(proxyURL)
|
||||
if err != nil {
|
||||
t.Fatalf("NewClient(%q) failed: %v", proxyURL, err)
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
func TestNewClient_无代理(t *testing.T) {
|
||||
client := NewClient("")
|
||||
client, err := NewClient("")
|
||||
if err != nil {
|
||||
t.Fatalf("NewClient 返回错误: %v", err)
|
||||
}
|
||||
if client == nil {
|
||||
t.Fatal("NewClient 返回 nil")
|
||||
}
|
||||
@@ -246,7 +258,10 @@ func TestNewClient_无代理(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewClient_有代理(t *testing.T) {
|
||||
client := NewClient("http://proxy.example.com:8080")
|
||||
client, err := NewClient("http://proxy.example.com:8080")
|
||||
if err != nil {
|
||||
t.Fatalf("NewClient 返回错误: %v", err)
|
||||
}
|
||||
if client == nil {
|
||||
t.Fatal("NewClient 返回 nil")
|
||||
}
|
||||
@@ -256,7 +271,10 @@ func TestNewClient_有代理(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewClient_空格代理(t *testing.T) {
|
||||
client := NewClient(" ")
|
||||
client, err := NewClient(" ")
|
||||
if err != nil {
|
||||
t.Fatalf("NewClient 返回错误: %v", err)
|
||||
}
|
||||
if client == nil {
|
||||
t.Fatal("NewClient 返回 nil")
|
||||
}
|
||||
@@ -267,15 +285,13 @@ func TestNewClient_空格代理(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewClient_无效代理URL(t *testing.T) {
|
||||
// 无效 URL 时 url.Parse 不一定返回错误(Go 的 url.Parse 很宽容),
|
||||
// 但 ://invalid 会导致解析错误
|
||||
client := NewClient("://invalid")
|
||||
if client == nil {
|
||||
t.Fatal("NewClient 返回 nil")
|
||||
// 无效 URL 应返回 error
|
||||
_, err := NewClient("://invalid")
|
||||
if err == nil {
|
||||
t.Fatal("无效代理 URL 应返回错误")
|
||||
}
|
||||
// 无效 URL 解析失败时,Transport 应保持 nil
|
||||
if client.httpClient.Transport != nil {
|
||||
t.Error("无效代理 URL 时 Transport 应为 nil")
|
||||
if !strings.Contains(err.Error(), "invalid proxy URL") {
|
||||
t.Errorf("错误信息应包含 'invalid proxy URL': got %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -499,7 +515,7 @@ func TestClient_ExchangeCode_无ClientSecret(t *testing.T) {
|
||||
defaultClientSecret = ""
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
client := NewClient("")
|
||||
client := mustNewClient(t, "")
|
||||
_, err := client.ExchangeCode(context.Background(), "code", "verifier")
|
||||
if err == nil {
|
||||
t.Fatal("缺少 client_secret 时应返回错误")
|
||||
@@ -602,7 +618,7 @@ func TestClient_RefreshToken_无ClientSecret(t *testing.T) {
|
||||
defaultClientSecret = ""
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
client := NewClient("")
|
||||
client := mustNewClient(t, "")
|
||||
_, err := client.RefreshToken(context.Background(), "refresh-tok")
|
||||
if err == nil {
|
||||
t.Fatal("缺少 client_secret 时应返回错误")
|
||||
@@ -1242,7 +1258,7 @@ func TestClient_LoadCodeAssist_Success_RealCall(t *testing.T) {
|
||||
|
||||
withMockBaseURLs(t, []string{server.URL})
|
||||
|
||||
client := NewClient("")
|
||||
client := mustNewClient(t, "")
|
||||
resp, rawResp, err := client.LoadCodeAssist(context.Background(), "test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadCodeAssist 失败: %v", err)
|
||||
@@ -1277,7 +1293,7 @@ func TestClient_LoadCodeAssist_HTTPError_RealCall(t *testing.T) {
|
||||
|
||||
withMockBaseURLs(t, []string{server.URL})
|
||||
|
||||
client := NewClient("")
|
||||
client := mustNewClient(t, "")
|
||||
_, _, err := client.LoadCodeAssist(context.Background(), "bad-token")
|
||||
if err == nil {
|
||||
t.Fatal("服务器返回 403 时应返回错误")
|
||||
@@ -1300,7 +1316,7 @@ func TestClient_LoadCodeAssist_InvalidJSON_RealCall(t *testing.T) {
|
||||
|
||||
withMockBaseURLs(t, []string{server.URL})
|
||||
|
||||
client := NewClient("")
|
||||
client := mustNewClient(t, "")
|
||||
_, _, err := client.LoadCodeAssist(context.Background(), "token")
|
||||
if err == nil {
|
||||
t.Fatal("无效 JSON 响应应返回错误")
|
||||
@@ -1333,7 +1349,7 @@ func TestClient_LoadCodeAssist_URLFallback_RealCall(t *testing.T) {
|
||||
|
||||
withMockBaseURLs(t, []string{server1.URL, server2.URL})
|
||||
|
||||
client := NewClient("")
|
||||
client := mustNewClient(t, "")
|
||||
resp, _, err := client.LoadCodeAssist(context.Background(), "token")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadCodeAssist 应在 fallback 后成功: %v", err)
|
||||
@@ -1361,7 +1377,7 @@ func TestClient_LoadCodeAssist_AllURLsFail_RealCall(t *testing.T) {
|
||||
|
||||
withMockBaseURLs(t, []string{server1.URL, server2.URL})
|
||||
|
||||
client := NewClient("")
|
||||
client := mustNewClient(t, "")
|
||||
_, _, err := client.LoadCodeAssist(context.Background(), "token")
|
||||
if err == nil {
|
||||
t.Fatal("所有 URL 都失败时应返回错误")
|
||||
@@ -1377,7 +1393,7 @@ func TestClient_LoadCodeAssist_ContextCanceled_RealCall(t *testing.T) {
|
||||
|
||||
withMockBaseURLs(t, []string{server.URL})
|
||||
|
||||
client := NewClient("")
|
||||
client := mustNewClient(t, "")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
@@ -1441,7 +1457,7 @@ func TestClient_FetchAvailableModels_Success_RealCall(t *testing.T) {
|
||||
|
||||
withMockBaseURLs(t, []string{server.URL})
|
||||
|
||||
client := NewClient("")
|
||||
client := mustNewClient(t, "")
|
||||
resp, rawResp, err := client.FetchAvailableModels(context.Background(), "test-token", "project-abc")
|
||||
if err != nil {
|
||||
t.Fatalf("FetchAvailableModels 失败: %v", err)
|
||||
@@ -1496,7 +1512,7 @@ func TestClient_FetchAvailableModels_HTTPError_RealCall(t *testing.T) {
|
||||
|
||||
withMockBaseURLs(t, []string{server.URL})
|
||||
|
||||
client := NewClient("")
|
||||
client := mustNewClient(t, "")
|
||||
_, _, err := client.FetchAvailableModels(context.Background(), "bad-token", "proj")
|
||||
if err == nil {
|
||||
t.Fatal("服务器返回 403 时应返回错误")
|
||||
@@ -1516,7 +1532,7 @@ func TestClient_FetchAvailableModels_InvalidJSON_RealCall(t *testing.T) {
|
||||
|
||||
withMockBaseURLs(t, []string{server.URL})
|
||||
|
||||
client := NewClient("")
|
||||
client := mustNewClient(t, "")
|
||||
_, _, err := client.FetchAvailableModels(context.Background(), "token", "proj")
|
||||
if err == nil {
|
||||
t.Fatal("无效 JSON 响应应返回错误")
|
||||
@@ -1546,7 +1562,7 @@ func TestClient_FetchAvailableModels_URLFallback_RealCall(t *testing.T) {
|
||||
|
||||
withMockBaseURLs(t, []string{server1.URL, server2.URL})
|
||||
|
||||
client := NewClient("")
|
||||
client := mustNewClient(t, "")
|
||||
resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj")
|
||||
if err != nil {
|
||||
t.Fatalf("FetchAvailableModels 应在 fallback 后成功: %v", err)
|
||||
@@ -1574,7 +1590,7 @@ func TestClient_FetchAvailableModels_AllURLsFail_RealCall(t *testing.T) {
|
||||
|
||||
withMockBaseURLs(t, []string{server1.URL, server2.URL})
|
||||
|
||||
client := NewClient("")
|
||||
client := mustNewClient(t, "")
|
||||
_, _, err := client.FetchAvailableModels(context.Background(), "token", "proj")
|
||||
if err == nil {
|
||||
t.Fatal("所有 URL 都失败时应返回错误")
|
||||
@@ -1590,7 +1606,7 @@ func TestClient_FetchAvailableModels_ContextCanceled_RealCall(t *testing.T) {
|
||||
|
||||
withMockBaseURLs(t, []string{server.URL})
|
||||
|
||||
client := NewClient("")
|
||||
client := mustNewClient(t, "")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
@@ -1610,7 +1626,7 @@ func TestClient_FetchAvailableModels_EmptyModels_RealCall(t *testing.T) {
|
||||
|
||||
withMockBaseURLs(t, []string{server.URL})
|
||||
|
||||
client := NewClient("")
|
||||
client := mustNewClient(t, "")
|
||||
resp, rawResp, err := client.FetchAvailableModels(context.Background(), "token", "proj")
|
||||
if err != nil {
|
||||
t.Fatalf("FetchAvailableModels 失败: %v", err)
|
||||
@@ -1646,7 +1662,7 @@ func TestClient_LoadCodeAssist_408Fallback_RealCall(t *testing.T) {
|
||||
|
||||
withMockBaseURLs(t, []string{server1.URL, server2.URL})
|
||||
|
||||
client := NewClient("")
|
||||
client := mustNewClient(t, "")
|
||||
resp, _, err := client.LoadCodeAssist(context.Background(), "token")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadCodeAssist 应在 408 fallback 后成功: %v", err)
|
||||
@@ -1672,7 +1688,7 @@ func TestClient_FetchAvailableModels_404Fallback_RealCall(t *testing.T) {
|
||||
|
||||
withMockBaseURLs(t, []string{server1.URL, server2.URL})
|
||||
|
||||
client := NewClient("")
|
||||
client := mustNewClient(t, "")
|
||||
resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj")
|
||||
if err != nil {
|
||||
t.Fatalf("FetchAvailableModels 应在 404 fallback 后成功: %v", err)
|
||||
|
||||
@@ -70,7 +70,7 @@ type GeminiGenerationConfig struct {
|
||||
ImageConfig *GeminiImageConfig `json:"imageConfig,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiImageConfig Gemini 图片生成配置(仅 gemini-3-pro-image 支持)
|
||||
// GeminiImageConfig Gemini 图片生成配置(gemini-3-pro-image / gemini-3.1-flash-image 等图片模型支持)
|
||||
type GeminiImageConfig struct {
|
||||
AspectRatio string `json:"aspectRatio,omitempty"` // "1:1", "16:9", "9:16", "4:3", "3:4"
|
||||
ImageSize string `json:"imageSize,omitempty"` // "1K", "2K", "4K"
|
||||
|
||||
@@ -612,14 +612,14 @@ func TestBuildAuthorizationURL_参数验证(t *testing.T) {
|
||||
|
||||
expectedParams := map[string]string{
|
||||
"client_id": ClientID,
|
||||
"redirect_uri": RedirectURI,
|
||||
"response_type": "code",
|
||||
"scope": Scopes,
|
||||
"state": state,
|
||||
"code_challenge": codeChallenge,
|
||||
"code_challenge_method": "S256",
|
||||
"access_type": "offline",
|
||||
"prompt": "consent",
|
||||
"redirect_uri": RedirectURI,
|
||||
"response_type": "code",
|
||||
"scope": Scopes,
|
||||
"state": state,
|
||||
"code_challenge": codeChallenge,
|
||||
"code_challenge_method": "S256",
|
||||
"access_type": "offline",
|
||||
"prompt": "consent",
|
||||
"include_granted_scopes": "true",
|
||||
}
|
||||
|
||||
|
||||
@@ -52,4 +52,7 @@ const (
|
||||
// PrefetchedStickyGroupID 标识上游预取 sticky session 时所使用的分组 ID。
|
||||
// Service 层仅在分组匹配时复用 PrefetchedStickyAccountID,避免分组切换重试误用旧 sticky。
|
||||
PrefetchedStickyGroupID Key = "ctx_prefetched_sticky_group_id"
|
||||
|
||||
// ClaudeCodeVersion stores the extracted Claude Code version from User-Agent (e.g. "2.1.22")
|
||||
ClaudeCodeVersion Key = "ctx_claude_code_version"
|
||||
)
|
||||
|
||||
@@ -166,3 +166,18 @@ func TestToHTTP(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToHTTP_MetadataDeepCopy(t *testing.T) {
|
||||
md := map[string]string{"k": "v"}
|
||||
appErr := BadRequest("BAD_REQUEST", "invalid").WithMetadata(md)
|
||||
|
||||
code, body := ToHTTP(appErr)
|
||||
require.Equal(t, http.StatusBadRequest, code)
|
||||
require.Equal(t, "v", body.Metadata["k"])
|
||||
|
||||
md["k"] = "changed"
|
||||
require.Equal(t, "v", body.Metadata["k"])
|
||||
|
||||
appErr.Metadata["k"] = "changed-again"
|
||||
require.Equal(t, "v", body.Metadata["k"])
|
||||
}
|
||||
|
||||
@@ -16,6 +16,16 @@ func ToHTTP(err error) (statusCode int, body Status) {
|
||||
return http.StatusOK, Status{Code: int32(http.StatusOK)}
|
||||
}
|
||||
|
||||
cloned := Clone(appErr)
|
||||
return int(cloned.Code), cloned.Status
|
||||
body = Status{
|
||||
Code: appErr.Code,
|
||||
Reason: appErr.Reason,
|
||||
Message: appErr.Message,
|
||||
}
|
||||
if appErr.Metadata != nil {
|
||||
body.Metadata = make(map[string]string, len(appErr.Metadata))
|
||||
for k, v := range appErr.Metadata {
|
||||
body.Metadata[k] = v
|
||||
}
|
||||
}
|
||||
return int(appErr.Code), body
|
||||
}
|
||||
|
||||
@@ -18,11 +18,11 @@ package httpclient
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
)
|
||||
@@ -32,6 +32,7 @@ const (
|
||||
defaultMaxIdleConns = 100 // 最大空闲连接数
|
||||
defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数
|
||||
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间(建议小于上游 LB 超时)
|
||||
validatedHostTTL = 30 * time.Second // DNS Rebinding 校验缓存 TTL
|
||||
)
|
||||
|
||||
// Options 定义共享 HTTP 客户端的构建参数
|
||||
@@ -40,7 +41,6 @@ type Options struct {
|
||||
Timeout time.Duration // 请求总超时时间
|
||||
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
|
||||
InsecureSkipVerify bool // 是否跳过 TLS 证书验证(已禁用,不允许设置为 true)
|
||||
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
|
||||
ValidateResolvedIP bool // 是否校验解析后的 IP(防止 DNS Rebinding)
|
||||
AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用)
|
||||
|
||||
@@ -53,6 +53,9 @@ type Options struct {
|
||||
// sharedClients 存储按配置参数缓存的 http.Client 实例
|
||||
var sharedClients sync.Map
|
||||
|
||||
// 允许测试替换校验函数,生产默认指向真实实现。
|
||||
var validateResolvedIP = urlvalidator.ValidateResolvedIP
|
||||
|
||||
// GetClient 返回共享的 HTTP 客户端实例
|
||||
// 性能优化:相同配置复用同一客户端,避免重复创建 Transport
|
||||
// 安全说明:代理配置失败时直接返回错误,不会回退到直连,避免 IP 关联风险
|
||||
@@ -84,7 +87,7 @@ func buildClient(opts Options) (*http.Client, error) {
|
||||
|
||||
var rt http.RoundTripper = transport
|
||||
if opts.ValidateResolvedIP && !opts.AllowPrivateHosts {
|
||||
rt = &validatedTransport{base: transport}
|
||||
rt = newValidatedTransport(transport)
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: rt,
|
||||
@@ -116,15 +119,13 @@ func buildTransport(opts Options) (*http.Transport, error) {
|
||||
return nil, fmt.Errorf("insecure_skip_verify is not allowed; install a trusted certificate instead")
|
||||
}
|
||||
|
||||
proxyURL := strings.TrimSpace(opts.ProxyURL)
|
||||
if proxyURL == "" {
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(proxyURL)
|
||||
_, parsed, err := proxyurl.Parse(opts.ProxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parsed == nil {
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil {
|
||||
return nil, err
|
||||
@@ -134,12 +135,11 @@ func buildTransport(opts Options) (*http.Transport, error) {
|
||||
}
|
||||
|
||||
func buildClientKey(opts Options) string {
|
||||
return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%t|%d|%d|%d",
|
||||
return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%d|%d|%d",
|
||||
strings.TrimSpace(opts.ProxyURL),
|
||||
opts.Timeout.String(),
|
||||
opts.ResponseHeaderTimeout.String(),
|
||||
opts.InsecureSkipVerify,
|
||||
opts.ProxyStrict,
|
||||
opts.ValidateResolvedIP,
|
||||
opts.AllowPrivateHosts,
|
||||
opts.MaxIdleConns,
|
||||
@@ -149,17 +149,56 @@ func buildClientKey(opts Options) string {
|
||||
}
|
||||
|
||||
type validatedTransport struct {
|
||||
base http.RoundTripper
|
||||
base http.RoundTripper
|
||||
validatedHosts sync.Map // map[string]time.Time, value 为过期时间
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
func newValidatedTransport(base http.RoundTripper) *validatedTransport {
|
||||
return &validatedTransport{
|
||||
base: base,
|
||||
now: time.Now,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *validatedTransport) isValidatedHost(host string, now time.Time) bool {
|
||||
if t == nil {
|
||||
return false
|
||||
}
|
||||
raw, ok := t.validatedHosts.Load(host)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
expireAt, ok := raw.(time.Time)
|
||||
if !ok {
|
||||
t.validatedHosts.Delete(host)
|
||||
return false
|
||||
}
|
||||
if now.Before(expireAt) {
|
||||
return true
|
||||
}
|
||||
t.validatedHosts.Delete(host)
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if req != nil && req.URL != nil {
|
||||
host := strings.TrimSpace(req.URL.Hostname())
|
||||
host := strings.ToLower(strings.TrimSpace(req.URL.Hostname()))
|
||||
if host != "" {
|
||||
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
|
||||
return nil, err
|
||||
now := time.Now()
|
||||
if t != nil && t.now != nil {
|
||||
now = t.now()
|
||||
}
|
||||
if !t.isValidatedHost(host, now) {
|
||||
if err := validateResolvedIP(host); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.validatedHosts.Store(host, now.Add(validatedHostTTL))
|
||||
}
|
||||
}
|
||||
}
|
||||
if t == nil || t.base == nil {
|
||||
return nil, fmt.Errorf("validated transport base is nil")
|
||||
}
|
||||
return t.base.RoundTrip(req)
|
||||
}
|
||||
|
||||
115
backend/internal/pkg/httpclient/pool_test.go
Normal file
115
backend/internal/pkg/httpclient/pool_test.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package httpclient
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func TestValidatedTransport_CacheHostValidation(t *testing.T) {
|
||||
originalValidate := validateResolvedIP
|
||||
defer func() { validateResolvedIP = originalValidate }()
|
||||
|
||||
var validateCalls int32
|
||||
validateResolvedIP = func(host string) error {
|
||||
atomic.AddInt32(&validateCalls, 1)
|
||||
require.Equal(t, "api.openai.com", host)
|
||||
return nil
|
||||
}
|
||||
|
||||
var baseCalls int32
|
||||
base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
|
||||
atomic.AddInt32(&baseCalls, 1)
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(`{}`)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
})
|
||||
|
||||
now := time.Unix(1730000000, 0)
|
||||
transport := newValidatedTransport(base)
|
||||
transport.now = func() time.Time { return now }
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = transport.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
_, err = transport.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&validateCalls))
|
||||
require.Equal(t, int32(2), atomic.LoadInt32(&baseCalls))
|
||||
}
|
||||
|
||||
func TestValidatedTransport_ExpiredCacheTriggersRevalidation(t *testing.T) {
|
||||
originalValidate := validateResolvedIP
|
||||
defer func() { validateResolvedIP = originalValidate }()
|
||||
|
||||
var validateCalls int32
|
||||
validateResolvedIP = func(_ string) error {
|
||||
atomic.AddInt32(&validateCalls, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(`{}`)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
})
|
||||
|
||||
now := time.Unix(1730001000, 0)
|
||||
transport := newValidatedTransport(base)
|
||||
transport.now = func() time.Time { return now }
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = transport.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
now = now.Add(validatedHostTTL + time.Second)
|
||||
_, err = transport.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int32(2), atomic.LoadInt32(&validateCalls))
|
||||
}
|
||||
|
||||
func TestValidatedTransport_ValidationErrorStopsRoundTrip(t *testing.T) {
|
||||
originalValidate := validateResolvedIP
|
||||
defer func() { validateResolvedIP = originalValidate }()
|
||||
|
||||
expectedErr := errors.New("dns rebinding rejected")
|
||||
validateResolvedIP = func(_ string) error {
|
||||
return expectedErr
|
||||
}
|
||||
|
||||
var baseCalls int32
|
||||
base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
|
||||
atomic.AddInt32(&baseCalls, 1)
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(`{}`))}, nil
|
||||
})
|
||||
|
||||
transport := newValidatedTransport(base)
|
||||
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = transport.RoundTrip(req)
|
||||
require.ErrorIs(t, err, expectedErr)
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&baseCalls))
|
||||
}
|
||||
37
backend/internal/pkg/httputil/body.go
Normal file
37
backend/internal/pkg/httputil/body.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package httputil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
requestBodyReadInitCap = 512
|
||||
requestBodyReadMaxInitCap = 1 << 20
|
||||
)
|
||||
|
||||
// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length.
|
||||
func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
|
||||
if req == nil || req.Body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
capHint := requestBodyReadInitCap
|
||||
if req.ContentLength > 0 {
|
||||
switch {
|
||||
case req.ContentLength < int64(requestBodyReadInitCap):
|
||||
capHint = requestBodyReadInitCap
|
||||
case req.ContentLength > int64(requestBodyReadMaxInitCap):
|
||||
capHint = requestBodyReadMaxInitCap
|
||||
default:
|
||||
capHint = int(req.ContentLength)
|
||||
}
|
||||
}
|
||||
|
||||
buf := bytes.NewBuffer(make([]byte, 0, capHint))
|
||||
if _, err := io.Copy(buf, req.Body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
@@ -67,6 +67,14 @@ func normalizeIP(ip string) string {
|
||||
// privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析
|
||||
var privateNets []*net.IPNet
|
||||
|
||||
// CompiledIPRules 表示预编译的 IP 匹配规则。
|
||||
// PatternCount 记录原始规则数量,用于保留“规则存在但全无效”时的行为语义。
|
||||
type CompiledIPRules struct {
|
||||
CIDRs []*net.IPNet
|
||||
IPs []net.IP
|
||||
PatternCount int
|
||||
}
|
||||
|
||||
func init() {
|
||||
for _, cidr := range []string{
|
||||
"10.0.0.0/8",
|
||||
@@ -84,6 +92,53 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
// CompileIPRules 将 IP/CIDR 字符串规则预编译为可复用结构。
|
||||
// 非法规则会被忽略,但 PatternCount 会保留原始规则条数。
|
||||
func CompileIPRules(patterns []string) *CompiledIPRules {
|
||||
compiled := &CompiledIPRules{
|
||||
CIDRs: make([]*net.IPNet, 0, len(patterns)),
|
||||
IPs: make([]net.IP, 0, len(patterns)),
|
||||
PatternCount: len(patterns),
|
||||
}
|
||||
for _, pattern := range patterns {
|
||||
normalized := strings.TrimSpace(pattern)
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(normalized, "/") {
|
||||
_, cidr, err := net.ParseCIDR(normalized)
|
||||
if err != nil || cidr == nil {
|
||||
continue
|
||||
}
|
||||
compiled.CIDRs = append(compiled.CIDRs, cidr)
|
||||
continue
|
||||
}
|
||||
parsedIP := net.ParseIP(normalized)
|
||||
if parsedIP == nil {
|
||||
continue
|
||||
}
|
||||
compiled.IPs = append(compiled.IPs, parsedIP)
|
||||
}
|
||||
return compiled
|
||||
}
|
||||
|
||||
func matchesCompiledRules(parsedIP net.IP, rules *CompiledIPRules) bool {
|
||||
if parsedIP == nil || rules == nil {
|
||||
return false
|
||||
}
|
||||
for _, cidr := range rules.CIDRs {
|
||||
if cidr.Contains(parsedIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, ruleIP := range rules.IPs {
|
||||
if parsedIP.Equal(ruleIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isPrivateIP 检查 IP 是否为私有地址。
|
||||
func isPrivateIP(ipStr string) bool {
|
||||
ip := net.ParseIP(ipStr)
|
||||
@@ -142,19 +197,32 @@ func MatchesAnyPattern(clientIP string, patterns []string) bool {
|
||||
// 2. 如果白名单不为空,IP 必须在白名单中
|
||||
// 3. 如果白名单为空,允许访问(除非被黑名单拒绝)
|
||||
func CheckIPRestriction(clientIP string, whitelist, blacklist []string) (bool, string) {
|
||||
return CheckIPRestrictionWithCompiledRules(
|
||||
clientIP,
|
||||
CompileIPRules(whitelist),
|
||||
CompileIPRules(blacklist),
|
||||
)
|
||||
}
|
||||
|
||||
// CheckIPRestrictionWithCompiledRules 使用预编译规则检查 IP 是否允许访问。
|
||||
func CheckIPRestrictionWithCompiledRules(clientIP string, whitelist, blacklist *CompiledIPRules) (bool, string) {
|
||||
// 规范化 IP
|
||||
clientIP = normalizeIP(clientIP)
|
||||
if clientIP == "" {
|
||||
return false, "access denied"
|
||||
}
|
||||
parsedIP := net.ParseIP(clientIP)
|
||||
if parsedIP == nil {
|
||||
return false, "access denied"
|
||||
}
|
||||
|
||||
// 1. 检查黑名单
|
||||
if len(blacklist) > 0 && MatchesAnyPattern(clientIP, blacklist) {
|
||||
if blacklist != nil && blacklist.PatternCount > 0 && matchesCompiledRules(parsedIP, blacklist) {
|
||||
return false, "access denied"
|
||||
}
|
||||
|
||||
// 2. 检查白名单(如果设置了白名单,IP 必须在其中)
|
||||
if len(whitelist) > 0 && !MatchesAnyPattern(clientIP, whitelist) {
|
||||
if whitelist != nil && whitelist.PatternCount > 0 && !matchesCompiledRules(parsedIP, whitelist) {
|
||||
return false, "access denied"
|
||||
}
|
||||
|
||||
|
||||
@@ -73,3 +73,24 @@ func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) {
|
||||
require.Equal(t, 200, w.Code)
|
||||
require.Equal(t, "9.9.9.9", w.Body.String())
|
||||
}
|
||||
|
||||
func TestCheckIPRestrictionWithCompiledRules(t *testing.T) {
|
||||
whitelist := CompileIPRules([]string{"10.0.0.0/8", "192.168.1.2"})
|
||||
blacklist := CompileIPRules([]string{"10.1.1.1"})
|
||||
|
||||
allowed, reason := CheckIPRestrictionWithCompiledRules("10.2.3.4", whitelist, blacklist)
|
||||
require.True(t, allowed)
|
||||
require.Equal(t, "", reason)
|
||||
|
||||
allowed, reason = CheckIPRestrictionWithCompiledRules("10.1.1.1", whitelist, blacklist)
|
||||
require.False(t, allowed)
|
||||
require.Equal(t, "access denied", reason)
|
||||
}
|
||||
|
||||
func TestCheckIPRestrictionWithCompiledRules_InvalidWhitelistStillDenies(t *testing.T) {
|
||||
// 与旧实现保持一致:白名单有配置但全无效时,最终应拒绝访问。
|
||||
invalidWhitelist := CompileIPRules([]string{"not-a-valid-pattern"})
|
||||
allowed, reason := CheckIPRestrictionWithCompiledRules("8.8.8.8", invalidWhitelist, nil)
|
||||
require.False(t, allowed)
|
||||
require.Equal(t, "access denied", reason)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
@@ -42,15 +43,19 @@ type LogEvent struct {
|
||||
|
||||
var (
|
||||
mu sync.RWMutex
|
||||
global *zap.Logger
|
||||
sugar *zap.SugaredLogger
|
||||
global atomic.Pointer[zap.Logger]
|
||||
sugar atomic.Pointer[zap.SugaredLogger]
|
||||
atomicLevel zap.AtomicLevel
|
||||
initOptions InitOptions
|
||||
currentSink Sink
|
||||
currentSink atomic.Value // sinkState
|
||||
stdLogUndo func()
|
||||
bootstrapOnce sync.Once
|
||||
)
|
||||
|
||||
type sinkState struct {
|
||||
sink Sink
|
||||
}
|
||||
|
||||
func InitBootstrap() {
|
||||
bootstrapOnce.Do(func() {
|
||||
if err := Init(bootstrapOptions()); err != nil {
|
||||
@@ -72,9 +77,9 @@ func initLocked(options InitOptions) error {
|
||||
return err
|
||||
}
|
||||
|
||||
prev := global
|
||||
global = zl
|
||||
sugar = zl.Sugar()
|
||||
prev := global.Load()
|
||||
global.Store(zl)
|
||||
sugar.Store(zl.Sugar())
|
||||
atomicLevel = al
|
||||
initOptions = normalized
|
||||
|
||||
@@ -115,24 +120,32 @@ func SetLevel(level string) error {
|
||||
func CurrentLevel() string {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
if global == nil {
|
||||
if global.Load() == nil {
|
||||
return "info"
|
||||
}
|
||||
return atomicLevel.Level().String()
|
||||
}
|
||||
|
||||
func SetSink(sink Sink) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
currentSink = sink
|
||||
currentSink.Store(sinkState{sink: sink})
|
||||
}
|
||||
|
||||
func loadSink() Sink {
|
||||
v := currentSink.Load()
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
state, ok := v.(sinkState)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return state.sink
|
||||
}
|
||||
|
||||
// WriteSinkEvent 直接写入日志 sink,不经过全局日志级别门控。
|
||||
// 用于需要“可观测性入库”与“业务输出级别”解耦的场景(例如 ops 系统日志索引)。
|
||||
func WriteSinkEvent(level, component, message string, fields map[string]any) {
|
||||
mu.RLock()
|
||||
sink := currentSink
|
||||
mu.RUnlock()
|
||||
sink := loadSink()
|
||||
if sink == nil {
|
||||
return
|
||||
}
|
||||
@@ -168,19 +181,15 @@ func WriteSinkEvent(level, component, message string, fields map[string]any) {
|
||||
}
|
||||
|
||||
func L() *zap.Logger {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
if global != nil {
|
||||
return global
|
||||
if l := global.Load(); l != nil {
|
||||
return l
|
||||
}
|
||||
return zap.NewNop()
|
||||
}
|
||||
|
||||
func S() *zap.SugaredLogger {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
if sugar != nil {
|
||||
return sugar
|
||||
if s := sugar.Load(); s != nil {
|
||||
return s
|
||||
}
|
||||
return zap.NewNop().Sugar()
|
||||
}
|
||||
@@ -190,9 +199,7 @@ func With(fields ...zap.Field) *zap.Logger {
|
||||
}
|
||||
|
||||
func Sync() {
|
||||
mu.RLock()
|
||||
l := global
|
||||
mu.RUnlock()
|
||||
l := global.Load()
|
||||
if l != nil {
|
||||
_ = l.Sync()
|
||||
}
|
||||
@@ -210,7 +217,11 @@ func bridgeStdLogLocked() {
|
||||
|
||||
log.SetFlags(0)
|
||||
log.SetPrefix("")
|
||||
log.SetOutput(newStdLogBridge(global.Named("stdlog")))
|
||||
base := global.Load()
|
||||
if base == nil {
|
||||
base = zap.NewNop()
|
||||
}
|
||||
log.SetOutput(newStdLogBridge(base.Named("stdlog")))
|
||||
|
||||
stdLogUndo = func() {
|
||||
log.SetOutput(prevWriter)
|
||||
@@ -220,7 +231,11 @@ func bridgeStdLogLocked() {
|
||||
}
|
||||
|
||||
func bridgeSlogLocked() {
|
||||
slog.SetDefault(slog.New(newSlogZapHandler(global.Named("slog"))))
|
||||
base := global.Load()
|
||||
if base == nil {
|
||||
base = zap.NewNop()
|
||||
}
|
||||
slog.SetDefault(slog.New(newSlogZapHandler(base.Named("slog"))))
|
||||
}
|
||||
|
||||
func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) {
|
||||
@@ -363,9 +378,7 @@ func (s *sinkCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore
|
||||
func (s *sinkCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
|
||||
// Only handle sink forwarding — the inner cores write via their own
|
||||
// Write methods (added to CheckedEntry by s.core.Check above).
|
||||
mu.RLock()
|
||||
sink := currentSink
|
||||
mu.RUnlock()
|
||||
sink := loadSink()
|
||||
if sink == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -454,7 +467,7 @@ func inferStdLogLevel(msg string) Level {
|
||||
if strings.Contains(lower, " failed") || strings.Contains(lower, "error") || strings.Contains(lower, "panic") || strings.Contains(lower, "fatal") {
|
||||
return LevelError
|
||||
}
|
||||
if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " retry") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") {
|
||||
if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") {
|
||||
return LevelWarn
|
||||
}
|
||||
return LevelInfo
|
||||
@@ -467,9 +480,7 @@ func LegacyPrintf(component, format string, args ...any) {
|
||||
return
|
||||
}
|
||||
|
||||
mu.RLock()
|
||||
initialized := global != nil
|
||||
mu.RUnlock()
|
||||
initialized := global.Load() != nil
|
||||
if !initialized {
|
||||
// 在日志系统未初始化前,回退到标准库 log,避免测试/工具链丢日志。
|
||||
log.Print(msg)
|
||||
|
||||
@@ -48,16 +48,15 @@ func (h *slogZapHandler) Handle(_ context.Context, record slog.Record) error {
|
||||
return true
|
||||
})
|
||||
|
||||
entry := h.logger.With(fields...)
|
||||
switch {
|
||||
case record.Level >= slog.LevelError:
|
||||
entry.Error(record.Message)
|
||||
h.logger.Error(record.Message, fields...)
|
||||
case record.Level >= slog.LevelWarn:
|
||||
entry.Warn(record.Message)
|
||||
h.logger.Warn(record.Message, fields...)
|
||||
case record.Level <= slog.LevelDebug:
|
||||
entry.Debug(record.Message)
|
||||
h.logger.Debug(record.Message, fields...)
|
||||
default:
|
||||
entry.Info(record.Message)
|
||||
h.logger.Info(record.Message, fields...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ func TestInferStdLogLevel(t *testing.T) {
|
||||
{msg: "Warning: queue full", want: LevelWarn},
|
||||
{msg: "Forward request failed: timeout", want: LevelError},
|
||||
{msg: "[ERROR] upstream unavailable", want: LevelError},
|
||||
{msg: "[OpenAI WS Mode] reconnect_retry account_id=22 retry=1 max_retries=5", want: LevelInfo},
|
||||
{msg: "service started", want: LevelInfo},
|
||||
{msg: "debug: cache miss", want: LevelDebug},
|
||||
}
|
||||
|
||||
@@ -36,10 +36,18 @@ const (
|
||||
SessionTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
const (
|
||||
// OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client.
|
||||
OAuthPlatformOpenAI = "openai"
|
||||
// OAuthPlatformSora uses Sora OAuth client.
|
||||
OAuthPlatformSora = "sora"
|
||||
)
|
||||
|
||||
// OAuthSession stores OAuth flow state for OpenAI
|
||||
type OAuthSession struct {
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
@@ -174,13 +182,20 @@ func base64URLEncode(data []byte) string {
|
||||
|
||||
// BuildAuthorizationURL builds the OpenAI OAuth authorization URL
|
||||
func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
|
||||
return BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, OAuthPlatformOpenAI)
|
||||
}
|
||||
|
||||
// BuildAuthorizationURLForPlatform builds authorization URL by platform.
|
||||
func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platform string) string {
|
||||
if redirectURI == "" {
|
||||
redirectURI = DefaultRedirectURI
|
||||
}
|
||||
|
||||
clientID, codexFlow := OAuthClientConfigByPlatform(platform)
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("client_id", clientID)
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("scope", DefaultScopes)
|
||||
params.Set("state", state)
|
||||
@@ -188,11 +203,25 @@ func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
|
||||
params.Set("code_challenge_method", "S256")
|
||||
// OpenAI specific parameters
|
||||
params.Set("id_token_add_organizations", "true")
|
||||
params.Set("codex_cli_simplified_flow", "true")
|
||||
if codexFlow {
|
||||
params.Set("codex_cli_simplified_flow", "true")
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
|
||||
}
|
||||
|
||||
// OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled.
|
||||
// Sora 授权流程复用 Codex CLI 的 client_id(支持 localhost redirect_uri),
|
||||
// 但不启用 codex_cli_simplified_flow;拿到的 access_token 绑定同一 OpenAI 账号,对 Sora API 同样可用。
|
||||
func OAuthClientConfigByPlatform(platform string) (clientID string, codexFlow bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(platform)) {
|
||||
case OAuthPlatformSora:
|
||||
return ClientID, false
|
||||
default:
|
||||
return ClientID, true
|
||||
}
|
||||
}
|
||||
|
||||
// TokenRequest represents the token exchange request body
|
||||
type TokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
@@ -296,9 +325,11 @@ func (r *RefreshTokenRequest) ToFormData() string {
|
||||
return params.Encode()
|
||||
}
|
||||
|
||||
// ParseIDToken parses the ID Token JWT and extracts claims
|
||||
// Note: This does NOT verify the signature - it only decodes the payload
|
||||
// For production, you should verify the token signature using OpenAI's public keys
|
||||
// ParseIDToken parses the ID Token JWT and extracts claims.
|
||||
// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。
|
||||
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
|
||||
//
|
||||
// https://auth.openai.com/.well-known/jwks.json
|
||||
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
parts := strings.Split(idToken, ".")
|
||||
if len(parts) != 3 {
|
||||
@@ -329,6 +360,13 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
}
|
||||
|
||||
// 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌)
|
||||
const clockSkewTolerance = 120 // 秒
|
||||
now := time.Now().Unix()
|
||||
if claims.Exp > 0 && now > claims.Exp+clockSkewTolerance {
|
||||
return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance)
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -41,3 +42,41 @@ func TestSessionStore_Stop_Concurrent(t *testing.T) {
|
||||
t.Fatal("stopCh 未关闭")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) {
|
||||
authURL := BuildAuthorizationURLForPlatform("state-1", "challenge-1", DefaultRedirectURI, OAuthPlatformOpenAI)
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Parse URL failed: %v", err)
|
||||
}
|
||||
q := parsed.Query()
|
||||
if got := q.Get("client_id"); got != ClientID {
|
||||
t.Fatalf("client_id mismatch: got=%q want=%q", got, ClientID)
|
||||
}
|
||||
if got := q.Get("codex_cli_simplified_flow"); got != "true" {
|
||||
t.Fatalf("codex flow mismatch: got=%q want=true", got)
|
||||
}
|
||||
if got := q.Get("id_token_add_organizations"); got != "true" {
|
||||
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id,
|
||||
// 但不启用 codex_cli_simplified_flow。
|
||||
func TestBuildAuthorizationURLForPlatform_Sora(t *testing.T) {
|
||||
authURL := BuildAuthorizationURLForPlatform("state-2", "challenge-2", DefaultRedirectURI, OAuthPlatformSora)
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Parse URL failed: %v", err)
|
||||
}
|
||||
q := parsed.Query()
|
||||
if got := q.Get("client_id"); got != ClientID {
|
||||
t.Fatalf("client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)", got, ClientID)
|
||||
}
|
||||
if got := q.Get("codex_cli_simplified_flow"); got != "" {
|
||||
t.Fatalf("codex flow should be empty for sora, got=%q", got)
|
||||
}
|
||||
if got := q.Get("id_token_add_organizations"); got != "true" {
|
||||
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
|
||||
}
|
||||
}
|
||||
|
||||
66
backend/internal/pkg/proxyurl/parse.go
Normal file
66
backend/internal/pkg/proxyurl/parse.go
Normal file
@@ -0,0 +1,66 @@
|
||||
// Package proxyurl 提供代理 URL 的统一验证(fail-fast,无效代理不回退直连)
|
||||
//
|
||||
// 所有需要解析代理 URL 的地方必须通过此包的 Parse 函数。
|
||||
// 直接使用 url.Parse 处理代理 URL 是被禁止的。
|
||||
// 这确保了 fail-fast 行为:无效代理配置在创建时立即失败,
|
||||
// 而不是在运行时静默回退到直连(产生 IP 关联风险)。
|
||||
package proxyurl
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// allowedSchemes 代理协议白名单
|
||||
var allowedSchemes = map[string]bool{
|
||||
"http": true,
|
||||
"https": true,
|
||||
"socks5": true,
|
||||
"socks5h": true,
|
||||
}
|
||||
|
||||
// Parse 解析并验证代理 URL。
|
||||
//
|
||||
// 语义:
|
||||
// - 空字符串 → ("", nil, nil),表示直连
|
||||
// - 非空且有效 → (trimmed, *url.URL, nil)
|
||||
// - 非空但无效 → ("", nil, error),fail-fast 不回退
|
||||
//
|
||||
// 验证规则:
|
||||
// - TrimSpace 后为空视为直连
|
||||
// - url.Parse 失败返回 error(不含原始 URL,防凭据泄露)
|
||||
// - Host 为空返回 error(用 Redacted() 脱敏)
|
||||
// - Scheme 必须为 http/https/socks5/socks5h
|
||||
// - socks5:// 自动升级为 socks5h://(确保 DNS 由代理端解析,防止 DNS 泄漏)
|
||||
func Parse(raw string) (trimmed string, parsed *url.URL, err error) {
|
||||
trimmed = strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
parsed, err = url.Parse(trimmed)
|
||||
if err != nil {
|
||||
// 不使用 %w 包装,避免 url.Parse 的底层错误消息泄漏原始 URL(可能含凭据)
|
||||
return "", nil, fmt.Errorf("invalid proxy URL: %v", err)
|
||||
}
|
||||
|
||||
if parsed.Host == "" || parsed.Hostname() == "" {
|
||||
return "", nil, fmt.Errorf("proxy URL missing host: %s", parsed.Redacted())
|
||||
}
|
||||
|
||||
scheme := strings.ToLower(parsed.Scheme)
|
||||
if !allowedSchemes[scheme] {
|
||||
return "", nil, fmt.Errorf("unsupported proxy scheme %q (allowed: http, https, socks5, socks5h)", scheme)
|
||||
}
|
||||
|
||||
// 自动升级 socks5 → socks5h,确保 DNS 由代理端解析,防止 DNS 泄漏。
|
||||
// Go 的 golang.org/x/net/proxy 对 socks5:// 默认在客户端本地解析 DNS,
|
||||
// 仅 socks5h:// 才将域名发送给代理端做远程 DNS 解析。
|
||||
if scheme == "socks5" {
|
||||
parsed.Scheme = "socks5h"
|
||||
trimmed = parsed.String()
|
||||
}
|
||||
|
||||
return trimmed, parsed, nil
|
||||
}
|
||||
215
backend/internal/pkg/proxyurl/parse_test.go
Normal file
215
backend/internal/pkg/proxyurl/parse_test.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package proxyurl
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParse_空字符串直连(t *testing.T) {
|
||||
trimmed, parsed, err := Parse("")
|
||||
if err != nil {
|
||||
t.Fatalf("空字符串应直连: %v", err)
|
||||
}
|
||||
if trimmed != "" {
|
||||
t.Errorf("trimmed 应为空: got %q", trimmed)
|
||||
}
|
||||
if parsed != nil {
|
||||
t.Errorf("parsed 应为 nil: got %v", parsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_空白字符串直连(t *testing.T) {
|
||||
trimmed, parsed, err := Parse(" ")
|
||||
if err != nil {
|
||||
t.Fatalf("空白字符串应直连: %v", err)
|
||||
}
|
||||
if trimmed != "" {
|
||||
t.Errorf("trimmed 应为空: got %q", trimmed)
|
||||
}
|
||||
if parsed != nil {
|
||||
t.Errorf("parsed 应为 nil: got %v", parsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_有效HTTP代理(t *testing.T) {
|
||||
trimmed, parsed, err := Parse("http://proxy.example.com:8080")
|
||||
if err != nil {
|
||||
t.Fatalf("有效 HTTP 代理应成功: %v", err)
|
||||
}
|
||||
if trimmed != "http://proxy.example.com:8080" {
|
||||
t.Errorf("trimmed 不匹配: got %q", trimmed)
|
||||
}
|
||||
if parsed == nil {
|
||||
t.Fatal("parsed 不应为 nil")
|
||||
}
|
||||
if parsed.Host != "proxy.example.com:8080" {
|
||||
t.Errorf("Host 不匹配: got %q", parsed.Host)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_有效HTTPS代理(t *testing.T) {
|
||||
_, parsed, err := Parse("https://proxy.example.com:443")
|
||||
if err != nil {
|
||||
t.Fatalf("有效 HTTPS 代理应成功: %v", err)
|
||||
}
|
||||
if parsed.Scheme != "https" {
|
||||
t.Errorf("Scheme 不匹配: got %q", parsed.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_有效SOCKS5代理_自动升级为SOCKS5H(t *testing.T) {
|
||||
trimmed, parsed, err := Parse("socks5://127.0.0.1:1080")
|
||||
if err != nil {
|
||||
t.Fatalf("有效 SOCKS5 代理应成功: %v", err)
|
||||
}
|
||||
// socks5 自动升级为 socks5h,确保 DNS 由代理端解析
|
||||
if trimmed != "socks5h://127.0.0.1:1080" {
|
||||
t.Errorf("trimmed 应升级为 socks5h: got %q", trimmed)
|
||||
}
|
||||
if parsed.Scheme != "socks5h" {
|
||||
t.Errorf("Scheme 应升级为 socks5h: got %q", parsed.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_无效URL(t *testing.T) {
|
||||
_, _, err := Parse("://invalid")
|
||||
if err == nil {
|
||||
t.Fatal("无效 URL 应返回错误")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid proxy URL") {
|
||||
t.Errorf("错误信息应包含 'invalid proxy URL': got %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_缺少Host(t *testing.T) {
|
||||
_, _, err := Parse("http://")
|
||||
if err == nil {
|
||||
t.Fatal("缺少 host 应返回错误")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "missing host") {
|
||||
t.Errorf("错误信息应包含 'missing host': got %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_不支持的Scheme(t *testing.T) {
|
||||
_, _, err := Parse("ftp://proxy.example.com:21")
|
||||
if err == nil {
|
||||
t.Fatal("不支持的 scheme 应返回错误")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unsupported proxy scheme") {
|
||||
t.Errorf("错误信息应包含 'unsupported proxy scheme': got %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_含密码URL脱敏(t *testing.T) {
|
||||
// 场景 1: 带密码的 socks5 URL 应成功解析并升级为 socks5h
|
||||
trimmed, parsed, err := Parse("socks5://user:secret_password@proxy.local:1080")
|
||||
if err != nil {
|
||||
t.Fatalf("含密码的有效 URL 应成功: %v", err)
|
||||
}
|
||||
if trimmed == "" || parsed == nil {
|
||||
t.Fatal("应返回非空结果")
|
||||
}
|
||||
if parsed.Scheme != "socks5h" {
|
||||
t.Errorf("Scheme 应升级为 socks5h: got %q", parsed.Scheme)
|
||||
}
|
||||
if !strings.HasPrefix(trimmed, "socks5h://") {
|
||||
t.Errorf("trimmed 应以 socks5h:// 开头: got %q", trimmed)
|
||||
}
|
||||
if parsed.User == nil {
|
||||
t.Error("升级后应保留 UserInfo")
|
||||
}
|
||||
|
||||
// 场景 2: 带密码但缺少 host(触发 Redacted 脱敏路径)
|
||||
_, _, err = Parse("http://user:secret_password@:0/")
|
||||
if err == nil {
|
||||
t.Fatal("缺少 host 应返回错误")
|
||||
}
|
||||
if strings.Contains(err.Error(), "secret_password") {
|
||||
t.Error("错误信息不应包含明文密码")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "missing host") {
|
||||
t.Errorf("错误信息应包含 'missing host': got %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_带空白的有效URL(t *testing.T) {
|
||||
trimmed, parsed, err := Parse(" http://proxy.example.com:8080 ")
|
||||
if err != nil {
|
||||
t.Fatalf("带空白的有效 URL 应成功: %v", err)
|
||||
}
|
||||
if trimmed != "http://proxy.example.com:8080" {
|
||||
t.Errorf("trimmed 应去除空白: got %q", trimmed)
|
||||
}
|
||||
if parsed == nil {
|
||||
t.Fatal("parsed 不应为 nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_Scheme大小写不敏感(t *testing.T) {
|
||||
// 大写 SOCKS5 应被接受并升级为 socks5h
|
||||
trimmed, parsed, err := Parse("SOCKS5://proxy.example.com:1080")
|
||||
if err != nil {
|
||||
t.Fatalf("大写 SOCKS5 应被接受: %v", err)
|
||||
}
|
||||
if parsed.Scheme != "socks5h" {
|
||||
t.Errorf("大写 SOCKS5 Scheme 应升级为 socks5h: got %q", parsed.Scheme)
|
||||
}
|
||||
if !strings.HasPrefix(trimmed, "socks5h://") {
|
||||
t.Errorf("大写 SOCKS5 trimmed 应升级为 socks5h://: got %q", trimmed)
|
||||
}
|
||||
|
||||
// 大写 HTTP 应被接受(不变)
|
||||
_, _, err = Parse("HTTP://proxy.example.com:8080")
|
||||
if err != nil {
|
||||
t.Fatalf("大写 HTTP 应被接受: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_带认证的有效代理(t *testing.T) {
|
||||
trimmed, parsed, err := Parse("http://user:pass@proxy.example.com:8080")
|
||||
if err != nil {
|
||||
t.Fatalf("带认证的代理 URL 应成功: %v", err)
|
||||
}
|
||||
if parsed.User == nil {
|
||||
t.Error("应保留 UserInfo")
|
||||
}
|
||||
if trimmed != "http://user:pass@proxy.example.com:8080" {
|
||||
t.Errorf("trimmed 不匹配: got %q", trimmed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_IPv6地址(t *testing.T) {
|
||||
trimmed, parsed, err := Parse("http://[::1]:8080")
|
||||
if err != nil {
|
||||
t.Fatalf("IPv6 代理 URL 应成功: %v", err)
|
||||
}
|
||||
if parsed.Hostname() != "::1" {
|
||||
t.Errorf("Hostname 不匹配: got %q", parsed.Hostname())
|
||||
}
|
||||
if trimmed != "http://[::1]:8080" {
|
||||
t.Errorf("trimmed 不匹配: got %q", trimmed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_SOCKS5H保持不变(t *testing.T) {
|
||||
trimmed, parsed, err := Parse("socks5h://proxy.local:1080")
|
||||
if err != nil {
|
||||
t.Fatalf("有效 SOCKS5H 代理应成功: %v", err)
|
||||
}
|
||||
// socks5h 不需要升级,应保持原样
|
||||
if trimmed != "socks5h://proxy.local:1080" {
|
||||
t.Errorf("trimmed 不应变化: got %q", trimmed)
|
||||
}
|
||||
if parsed.Scheme != "socks5h" {
|
||||
t.Errorf("Scheme 应保持 socks5h: got %q", parsed.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_无Scheme裸地址(t *testing.T) {
|
||||
// 无 scheme 的裸地址,Go url.Parse 将其视为 path,Host 为空
|
||||
_, _, err := Parse("proxy.example.com:8080")
|
||||
if err == nil {
|
||||
t.Fatal("无 scheme 的裸地址应返回错误")
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,11 @@
|
||||
//
|
||||
// 支持的代理协议:
|
||||
// - HTTP/HTTPS: 通过 Transport.Proxy 设置
|
||||
// - SOCKS5/SOCKS5H: 通过 Transport.DialContext 设置(服务端解析 DNS)
|
||||
// - SOCKS5: 通过 Transport.DialContext 设置(客户端本地解析 DNS)
|
||||
// - SOCKS5H: 通过 Transport.DialContext 设置(代理端远程解析 DNS,推荐)
|
||||
//
|
||||
// 注意:proxyurl.Parse() 会自动将 socks5:// 升级为 socks5h://,
|
||||
// 确保 DNS 也由代理端解析,防止 DNS 泄漏。
|
||||
package proxyutil
|
||||
|
||||
import (
|
||||
@@ -20,7 +24,8 @@ import (
|
||||
//
|
||||
// 支持的协议:
|
||||
// - http/https: 设置 transport.Proxy
|
||||
// - socks5/socks5h: 设置 transport.DialContext(由代理服务端解析 DNS)
|
||||
// - socks5: 设置 transport.DialContext(客户端本地解析 DNS)
|
||||
// - socks5h: 设置 transport.DialContext(代理端远程解析 DNS,推荐)
|
||||
//
|
||||
// 参数:
|
||||
// - transport: 需要配置的 http.Transport
|
||||
|
||||
@@ -29,10 +29,10 @@ func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, P
|
||||
t.Helper()
|
||||
// 先用 raw json 解析,因为 Data 是 any 类型
|
||||
var raw struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw))
|
||||
|
||||
|
||||
@@ -268,8 +268,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
|
||||
"cipher_suites", len(spec.CipherSuites),
|
||||
"extensions", len(spec.Extensions),
|
||||
"compression_methods", spec.CompressionMethods,
|
||||
"tls_vers_max", fmt.Sprintf("0x%04x", spec.TLSVersMax),
|
||||
"tls_vers_min", fmt.Sprintf("0x%04x", spec.TLSVersMin))
|
||||
"tls_vers_max", spec.TLSVersMax,
|
||||
"tls_vers_min", spec.TLSVersMin)
|
||||
|
||||
if d.profile != nil {
|
||||
slog.Debug("tls_fingerprint_socks5_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
|
||||
@@ -294,8 +294,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
|
||||
|
||||
state := tlsConn.ConnectionState()
|
||||
slog.Debug("tls_fingerprint_socks5_handshake_success",
|
||||
"version", fmt.Sprintf("0x%04x", state.Version),
|
||||
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
|
||||
"version", state.Version,
|
||||
"cipher_suite", state.CipherSuite,
|
||||
"alpn", state.NegotiatedProtocol)
|
||||
|
||||
return tlsConn, nil
|
||||
@@ -404,8 +404,8 @@ func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr stri
|
||||
|
||||
state := tlsConn.ConnectionState()
|
||||
slog.Debug("tls_fingerprint_http_proxy_handshake_success",
|
||||
"version", fmt.Sprintf("0x%04x", state.Version),
|
||||
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
|
||||
"version", state.Version,
|
||||
"cipher_suite", state.CipherSuite,
|
||||
"alpn", state.NegotiatedProtocol)
|
||||
|
||||
return tlsConn, nil
|
||||
@@ -470,8 +470,8 @@ func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.
|
||||
// Log successful handshake details
|
||||
state := tlsConn.ConnectionState()
|
||||
slog.Debug("tls_fingerprint_handshake_success",
|
||||
"version", fmt.Sprintf("0x%04x", state.Version),
|
||||
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
|
||||
"version", state.Version,
|
||||
"cipher_suite", state.CipherSuite,
|
||||
"alpn", state.NegotiatedProtocol)
|
||||
|
||||
return tlsConn, nil
|
||||
|
||||
@@ -80,12 +80,12 @@ type ModelStat struct {
|
||||
|
||||
// GroupStat represents usage statistics for a single group
|
||||
type GroupStat struct {
|
||||
GroupID int64 `json:"group_id"`
|
||||
GroupName string `json:"group_name"`
|
||||
Requests int64 `json:"requests"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
GroupID int64 `json:"group_id"`
|
||||
GroupName string `json:"group_name"`
|
||||
Requests int64 `json:"requests"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// UserUsageTrendPoint represents user usage trend data point
|
||||
@@ -149,10 +149,13 @@ type UsageLogFilters struct {
|
||||
AccountID int64
|
||||
GroupID int64
|
||||
Model string
|
||||
RequestType *int16
|
||||
Stream *bool
|
||||
BillingType *int8
|
||||
StartTime *time.Time
|
||||
EndTime *time.Time
|
||||
// ExactTotal requests exact COUNT(*) for pagination. Default false for fast large-table paging.
|
||||
ExactTotal bool
|
||||
}
|
||||
|
||||
// UsageStats represents usage statistics
|
||||
|
||||
Reference in New Issue
Block a user