mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-05 05:30:44 +08:00
Compare commits
33 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6a2cf09ee0 | ||
|
|
c6fd88116b | ||
|
|
8f0dbdeaba | ||
|
|
007c09b84e | ||
|
|
73f3c068ef | ||
|
|
9a92fa4a60 | ||
|
|
576af710be | ||
|
|
b5642bd068 | ||
|
|
128f322252 | ||
|
|
17d7e57a2e | ||
|
|
50288e6b01 | ||
|
|
ab3e44e4bd | ||
|
|
61607990c8 | ||
|
|
b65275235f | ||
|
|
e298a71834 | ||
|
|
3f6fa1e3db | ||
|
|
f2c2abe628 | ||
|
|
ff5b467fbe | ||
|
|
8c10941142 | ||
|
|
f5764d8dc6 | ||
|
|
81ca4f12dd | ||
|
|
941c469ab9 | ||
|
|
8fcd819e6f | ||
|
|
9abdaed20c | ||
|
|
eb94342f78 | ||
|
|
d563eb2336 | ||
|
|
3ee6f085db | ||
|
|
7cca69a136 | ||
|
|
093a5a260e | ||
|
|
2c072c0ed6 | ||
|
|
1f39bf8a78 | ||
|
|
fdd8499ffc | ||
|
|
c7f4a649df |
@@ -1 +1 @@
|
||||
0.1.104
|
||||
0.1.105
|
||||
|
||||
@@ -137,7 +137,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
|
||||
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
|
||||
tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client)
|
||||
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
|
||||
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
|
||||
|
||||
@@ -1281,8 +1281,8 @@ func setDefaults() {
|
||||
viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 10)
|
||||
|
||||
// Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit,避免分支漂移)
|
||||
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json")
|
||||
viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256")
|
||||
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/main/model_prices_and_context_window.json")
|
||||
viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/main/model_prices_and_context_window.sha256")
|
||||
viper.SetDefault("pricing.data_dir", "./data")
|
||||
viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json")
|
||||
viper.SetDefault("pricing.update_interval_hours", 24)
|
||||
|
||||
@@ -268,6 +268,14 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
target := a.GetCacheTTLOverrideTarget()
|
||||
out.CacheTTLOverrideTarget = &target
|
||||
}
|
||||
// 自定义 Base URL 中继转发
|
||||
if a.IsCustomBaseURLEnabled() {
|
||||
enabled := true
|
||||
out.CustomBaseURLEnabled = &enabled
|
||||
if customURL := a.GetCustomBaseURL(); customURL != "" {
|
||||
out.CustomBaseURL = &customURL
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 提取账号配额限制(apikey / bedrock 类型有效)
|
||||
|
||||
@@ -198,6 +198,10 @@ type Account struct {
|
||||
CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"`
|
||||
CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
|
||||
|
||||
// 自定义 Base URL 中继转发(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||
CustomBaseURLEnabled *bool `json:"custom_base_url_enabled,omitempty"`
|
||||
CustomBaseURL *string `json:"custom_base_url,omitempty"`
|
||||
|
||||
// API Key 账号配额限制
|
||||
QuotaLimit *float64 `json:"quota_limit,omitempty"`
|
||||
QuotaUsed *float64 `json:"quota_used,omitempty"`
|
||||
|
||||
@@ -541,6 +541,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
routingModel := service.NormalizeOpenAICompatRequestedModel(reqModel)
|
||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
@@ -606,7 +607,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
apiKey.GroupID,
|
||||
"", // no previous_response_id
|
||||
sessionHash,
|
||||
reqModel,
|
||||
routingModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
@@ -621,7 +622,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
if apiKey.Group != nil {
|
||||
defaultModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if defaultModel != "" && defaultModel != reqModel {
|
||||
if defaultModel != "" && defaultModel != routingModel {
|
||||
reqLog.Info("openai_messages.fallback_to_default_model",
|
||||
zap.String("default_mapped_model", defaultModel),
|
||||
)
|
||||
|
||||
@@ -24,20 +24,18 @@ const (
|
||||
RedirectURI = "https://platform.claude.com/oauth/code/callback"
|
||||
|
||||
// Scopes - Browser URL (includes org:create_api_key for user authorization)
|
||||
ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers"
|
||||
ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers user:file_upload"
|
||||
// Scopes - Internal API call (org:create_api_key not supported in API)
|
||||
ScopeAPI = "user:profile user:inference user:sessions:claude_code user:mcp_servers"
|
||||
ScopeAPI = "user:profile user:inference user:sessions:claude_code user:mcp_servers user:file_upload"
|
||||
// Scopes - Setup token (inference only)
|
||||
ScopeInference = "user:inference"
|
||||
|
||||
// Code Verifier character set (RFC 7636 compliant)
|
||||
codeVerifierCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
|
||||
|
||||
// Session TTL
|
||||
SessionTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
// OAuthSession stores OAuth flow state
|
||||
|
||||
type OAuthSession struct {
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
@@ -147,30 +145,14 @@ func GenerateSessionID() (string, error) {
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeVerifier generates a PKCE code verifier using character set method
|
||||
// GenerateCodeVerifier generates a PKCE code verifier (RFC 7636).
|
||||
// Uses 32 random bytes → base64url-no-pad, producing a 43-char verifier.
|
||||
func GenerateCodeVerifier() (string, error) {
|
||||
const targetLen = 32
|
||||
charsetLen := len(codeVerifierCharset)
|
||||
limit := 256 - (256 % charsetLen)
|
||||
|
||||
result := make([]byte, 0, targetLen)
|
||||
randBuf := make([]byte, targetLen*2)
|
||||
|
||||
for len(result) < targetLen {
|
||||
if _, err := rand.Read(randBuf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, b := range randBuf {
|
||||
if int(b) < limit {
|
||||
result = append(result, codeVerifierCharset[int(b)%charsetLen])
|
||||
if len(result) >= targetLen {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return base64URLEncode(result), nil
|
||||
return base64URLEncode(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
|
||||
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
@@ -257,9 +258,12 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
// 存在唯一键约束 生成tombstone key 用来释放原key,长度远小于 128,满足 schema 限制
|
||||
tombstoneKey := fmt.Sprintf("__deleted__%d__%d", id, time.Now().UnixNano())
|
||||
// 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。
|
||||
affected, err := r.client.APIKey.Update().
|
||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||
SetKey(tombstoneKey).
|
||||
SetDeletedAt(time.Now()).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
|
||||
@@ -151,6 +151,31 @@ func (s *APIKeyRepoSuite) TestDelete() {
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestCreate_AfterSoftDelete_AllowsSameKey() {
|
||||
user := s.mustCreateUser("recreate-after-soft-delete@test.com")
|
||||
const reusedKey = "sk-reuse-after-soft-delete"
|
||||
|
||||
first := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: reusedKey,
|
||||
Name: "First Key",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, first), "create first key")
|
||||
|
||||
s.Require().NoError(s.repo.Delete(s.ctx, first.ID), "soft delete first key")
|
||||
|
||||
second := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: reusedKey,
|
||||
Name: "Second Key",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, second), "create second key with same key")
|
||||
s.Require().NotZero(second.ID)
|
||||
s.Require().NotEqual(first.ID, second.ID, "recreated key should be a new row")
|
||||
}
|
||||
|
||||
// --- ListByUserID / CountByUserID ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestListByUserID() {
|
||||
|
||||
55
backend/internal/repository/internal500_counter_cache.go
Normal file
55
backend/internal/repository/internal500_counter_cache.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
internal500CounterPrefix = "internal500_count:account:"
|
||||
internal500CounterTTLSeconds = 86400 // 24 小时兜底
|
||||
)
|
||||
|
||||
// internal500CounterIncrScript 使用 Lua 脚本原子性地增加计数并返回当前值
|
||||
// 如果 key 不存在,则创建并设置过期时间
|
||||
var internal500CounterIncrScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
|
||||
local count = redis.call('INCR', key)
|
||||
if count == 1 then
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
end
|
||||
|
||||
return count
|
||||
`)
|
||||
|
||||
type internal500CounterCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewInternal500CounterCache 创建 INTERNAL 500 连续失败计数器缓存实例
|
||||
func NewInternal500CounterCache(rdb *redis.Client) service.Internal500CounterCache {
|
||||
return &internal500CounterCache{rdb: rdb}
|
||||
}
|
||||
|
||||
// IncrementInternal500Count 原子递增计数并返回当前值
|
||||
func (c *internal500CounterCache) IncrementInternal500Count(ctx context.Context, accountID int64) (int64, error) {
|
||||
key := fmt.Sprintf("%s%d", internal500CounterPrefix, accountID)
|
||||
|
||||
result, err := internal500CounterIncrScript.Run(ctx, c.rdb, []string{key}, internal500CounterTTLSeconds).Int64()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("increment internal500 count: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ResetInternal500Count 清零计数器(成功响应时调用)
|
||||
func (c *internal500CounterCache) ResetInternal500Count(ctx context.Context, accountID int64) error {
|
||||
key := fmt.Sprintf("%s%d", internal500CounterPrefix, accountID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
@@ -81,6 +81,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAPIKeyCache,
|
||||
NewTempUnschedCache,
|
||||
NewTimeoutCounterCache,
|
||||
NewInternal500CounterCache,
|
||||
ProvideConcurrencyCache,
|
||||
ProvideSessionLimitCache,
|
||||
NewRPMCache,
|
||||
|
||||
@@ -1229,6 +1229,28 @@ func (a *Account) IsSessionIDMaskingEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsCustomBaseURLEnabled 检查是否启用自定义 base URL 中继转发
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
|
||||
func (a *Account) IsCustomBaseURLEnabled() bool {
|
||||
if !a.IsAnthropicOAuthOrSetupToken() {
|
||||
return false
|
||||
}
|
||||
if a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Extra["custom_base_url_enabled"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetCustomBaseURL 返回自定义中继服务的 base URL
|
||||
func (a *Account) GetCustomBaseURL() string {
|
||||
return a.GetExtraString("custom_base_url")
|
||||
}
|
||||
|
||||
// IsCacheTTLOverrideEnabled 检查是否启用缓存 TTL 强制替换
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
|
||||
// 启用后将所有 cache creation tokens 归入指定的 TTL 类型(5m 或 1h)
|
||||
|
||||
@@ -1866,6 +1866,18 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac
|
||||
if err := s.accountRepo.ClearError(ctx, id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.accountRepo.ClearRateLimit(ctx, id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.accountRepo.ClearAntigravityQuotaScopes(ctx, id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.accountRepo.ClearModelRateLimits(ctx, id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.accountRepo.ClearTempUnschedulable(ctx, id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.accountRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
|
||||
86
backend/internal/service/admin_service_clear_error_test.go
Normal file
86
backend/internal/service/admin_service_clear_error_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type accountRepoStubForClearAccountError struct {
|
||||
mockAccountRepoForGemini
|
||||
account *Account
|
||||
clearErrorCalls int
|
||||
clearRateLimitCalls int
|
||||
clearAntigravityCalls int
|
||||
clearModelRateLimitCalls int
|
||||
clearTempUnschedCalls int
|
||||
}
|
||||
|
||||
func (r *accountRepoStubForClearAccountError) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
return r.account, nil
|
||||
}
|
||||
|
||||
func (r *accountRepoStubForClearAccountError) ClearError(ctx context.Context, id int64) error {
|
||||
r.clearErrorCalls++
|
||||
r.account.Status = StatusActive
|
||||
r.account.ErrorMessage = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepoStubForClearAccountError) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
r.clearRateLimitCalls++
|
||||
r.account.RateLimitedAt = nil
|
||||
r.account.RateLimitResetAt = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepoStubForClearAccountError) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||
r.clearAntigravityCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepoStubForClearAccountError) ClearModelRateLimits(ctx context.Context, id int64) error {
|
||||
r.clearModelRateLimitCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepoStubForClearAccountError) ClearTempUnschedulable(ctx context.Context, id int64) error {
|
||||
r.clearTempUnschedCalls++
|
||||
r.account.TempUnschedulableUntil = nil
|
||||
r.account.TempUnschedulableReason = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *testing.T) {
|
||||
until := time.Now().Add(10 * time.Minute)
|
||||
resetAt := time.Now().Add(5 * time.Minute)
|
||||
repo := &accountRepoStubForClearAccountError{
|
||||
account: &Account{
|
||||
ID: 31,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusError,
|
||||
ErrorMessage: "refresh failed",
|
||||
RateLimitResetAt: &resetAt,
|
||||
TempUnschedulableUntil: &until,
|
||||
TempUnschedulableReason: "missing refresh token",
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
updated, err := svc.ClearAccountError(context.Background(), 31)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
require.Equal(t, 1, repo.clearErrorCalls)
|
||||
require.Equal(t, 1, repo.clearRateLimitCalls)
|
||||
require.Equal(t, 1, repo.clearAntigravityCalls)
|
||||
require.Equal(t, 1, repo.clearModelRateLimitCalls)
|
||||
require.Equal(t, 1, repo.clearTempUnschedCalls)
|
||||
require.Nil(t, updated.RateLimitResetAt)
|
||||
require.Nil(t, updated.TempUnschedulableUntil)
|
||||
require.Empty(t, updated.TempUnschedulableReason)
|
||||
}
|
||||
@@ -614,6 +614,7 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP
|
||||
urlFallbackLoop:
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
usedBaseURL = baseURL
|
||||
allAttemptsInternal500 := true // 追踪本轮所有 attempt 是否全部命中 INTERNAL 500
|
||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
@@ -766,10 +767,19 @@ urlFallbackLoop:
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_backoff", p.prefix)
|
||||
return nil, p.ctx.Err()
|
||||
}
|
||||
// 追踪 INTERNAL 500:非匹配的 attempt 清除标记
|
||||
if !isAntigravityInternalServerError(resp.StatusCode, respBody) {
|
||||
allAttemptsInternal500 = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// INTERNAL 500 渐进惩罚:3 次重试全部命中特定 500 时递增计数器并惩罚
|
||||
if allAttemptsInternal500 && isAntigravityInternalServerError(resp.StatusCode, respBody) {
|
||||
s.handleInternal500RetryExhausted(p.ctx, p.prefix, p.account)
|
||||
}
|
||||
|
||||
// 其他 4xx 错误或重试用尽,直接返回
|
||||
resp = &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
@@ -788,6 +798,11 @@ urlFallbackLoop:
|
||||
antigravity.DefaultURLAvailability.MarkSuccess(usedBaseURL)
|
||||
}
|
||||
|
||||
// 成功响应时清零 INTERNAL 500 连续失败计数器(覆盖所有成功路径,含 smart retry)
|
||||
if resp != nil && resp.StatusCode < 400 {
|
||||
s.resetInternal500Counter(p.ctx, p.prefix, p.account.ID)
|
||||
}
|
||||
|
||||
return &antigravityRetryLoopResult{resp: resp}, nil
|
||||
}
|
||||
|
||||
@@ -862,6 +877,7 @@ type AntigravityGatewayService struct {
|
||||
settingService *SettingService
|
||||
cache GatewayCache // 用于模型级限流时清除粘性会话绑定
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
internal500Cache Internal500CounterCache // INTERNAL 500 渐进惩罚计数器
|
||||
}
|
||||
|
||||
func NewAntigravityGatewayService(
|
||||
@@ -872,6 +888,7 @@ func NewAntigravityGatewayService(
|
||||
rateLimitService *RateLimitService,
|
||||
httpUpstream HTTPUpstream,
|
||||
settingService *SettingService,
|
||||
internal500Cache Internal500CounterCache,
|
||||
) *AntigravityGatewayService {
|
||||
return &AntigravityGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -881,6 +898,7 @@ func NewAntigravityGatewayService(
|
||||
settingService: settingService,
|
||||
cache: cache,
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
internal500Cache: internal500Cache,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
97
backend/internal/service/antigravity_internal500_penalty.go
Normal file
97
backend/internal/service/antigravity_internal500_penalty.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// INTERNAL 500 渐进惩罚:连续多轮全部返回特定 500 错误时的惩罚时长
|
||||
const (
|
||||
internal500PenaltyTier1Duration = 30 * time.Minute // 第 1 轮:临时不可调度 30 分钟
|
||||
internal500PenaltyTier2Duration = 2 * time.Hour // 第 2 轮:临时不可调度 2 小时
|
||||
internal500PenaltyTier3Threshold = 3 // 第 3+ 轮:永久禁用
|
||||
)
|
||||
|
||||
// isAntigravityInternalServerError 检测特定的 INTERNAL 500 错误
|
||||
// 必须同时匹配 error.code==500, error.message=="Internal error encountered.", error.status=="INTERNAL"
|
||||
func isAntigravityInternalServerError(statusCode int, body []byte) bool {
|
||||
if statusCode != http.StatusInternalServerError {
|
||||
return false
|
||||
}
|
||||
return gjson.GetBytes(body, "error.code").Int() == 500 &&
|
||||
gjson.GetBytes(body, "error.message").String() == "Internal error encountered." &&
|
||||
gjson.GetBytes(body, "error.status").String() == "INTERNAL"
|
||||
}
|
||||
|
||||
// applyInternal500Penalty 根据连续 INTERNAL 500 轮次数应用渐进惩罚
|
||||
// count=1: temp_unschedulable 10 分钟
|
||||
// count=2: temp_unschedulable 10 小时
|
||||
// count>=3: SetError 永久禁用
|
||||
func (s *AntigravityGatewayService) applyInternal500Penalty(
|
||||
ctx context.Context, prefix string, account *Account, count int64,
|
||||
) {
|
||||
switch {
|
||||
case count >= int64(internal500PenaltyTier3Threshold):
|
||||
reason := fmt.Sprintf("INTERNAL 500 consecutive failures: %d rounds", count)
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, reason); err != nil {
|
||||
slog.Error("internal500_set_error_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Warn("internal500_account_disabled",
|
||||
"account_id", account.ID, "account_name", account.Name, "consecutive_count", count)
|
||||
case count == 2:
|
||||
until := time.Now().Add(internal500PenaltyTier2Duration)
|
||||
reason := fmt.Sprintf("INTERNAL 500 x%d (temp unsched %v)", count, internal500PenaltyTier2Duration)
|
||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||
slog.Error("internal500_temp_unsched_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Warn("internal500_temp_unschedulable",
|
||||
"account_id", account.ID, "account_name", account.Name,
|
||||
"duration", internal500PenaltyTier2Duration, "consecutive_count", count)
|
||||
case count == 1:
|
||||
until := time.Now().Add(internal500PenaltyTier1Duration)
|
||||
reason := fmt.Sprintf("INTERNAL 500 x%d (temp unsched %v)", count, internal500PenaltyTier1Duration)
|
||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||
slog.Error("internal500_temp_unsched_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Info("internal500_temp_unschedulable",
|
||||
"account_id", account.ID, "account_name", account.Name,
|
||||
"duration", internal500PenaltyTier1Duration, "consecutive_count", count)
|
||||
}
|
||||
}
|
||||
|
||||
// handleInternal500RetryExhausted 处理 INTERNAL 500 重试耗尽:递增计数器并应用惩罚
|
||||
func (s *AntigravityGatewayService) handleInternal500RetryExhausted(
|
||||
ctx context.Context, prefix string, account *Account,
|
||||
) {
|
||||
if s.internal500Cache == nil {
|
||||
return
|
||||
}
|
||||
count, err := s.internal500Cache.IncrementInternal500Count(ctx, account.ID)
|
||||
if err != nil {
|
||||
slog.Error("internal500_counter_increment_failed",
|
||||
"prefix", prefix, "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
s.applyInternal500Penalty(ctx, prefix, account, count)
|
||||
}
|
||||
|
||||
// resetInternal500Counter 成功响应时清零 INTERNAL 500 计数器
|
||||
func (s *AntigravityGatewayService) resetInternal500Counter(
|
||||
ctx context.Context, prefix string, accountID int64,
|
||||
) {
|
||||
if s.internal500Cache == nil {
|
||||
return
|
||||
}
|
||||
if err := s.internal500Cache.ResetInternal500Count(ctx, accountID); err != nil {
|
||||
slog.Error("internal500_counter_reset_failed",
|
||||
"prefix", prefix, "account_id", accountID, "error", err)
|
||||
}
|
||||
}
|
||||
321
backend/internal/service/antigravity_internal500_penalty_test.go
Normal file
321
backend/internal/service/antigravity_internal500_penalty_test.go
Normal file
@@ -0,0 +1,321 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- mock: Internal500CounterCache ---
|
||||
|
||||
type mockInternal500Cache struct {
|
||||
incrementCount int64
|
||||
incrementErr error
|
||||
resetErr error
|
||||
|
||||
incrementCalls []int64 // 记录 IncrementInternal500Count 被调用时的 accountID
|
||||
resetCalls []int64 // 记录 ResetInternal500Count 被调用时的 accountID
|
||||
}
|
||||
|
||||
func (m *mockInternal500Cache) IncrementInternal500Count(_ context.Context, accountID int64) (int64, error) {
|
||||
m.incrementCalls = append(m.incrementCalls, accountID)
|
||||
return m.incrementCount, m.incrementErr
|
||||
}
|
||||
|
||||
func (m *mockInternal500Cache) ResetInternal500Count(_ context.Context, accountID int64) error {
|
||||
m.resetCalls = append(m.resetCalls, accountID)
|
||||
return m.resetErr
|
||||
}
|
||||
|
||||
// --- mock: 专用于 internal500 惩罚测试的 AccountRepository ---
|
||||
|
||||
type internal500AccountRepoStub struct {
|
||||
AccountRepository // 嵌入接口,未实现的方法会 panic(不应被调用)
|
||||
|
||||
tempUnschedCalls []tempUnschedCall
|
||||
setErrorCalls []setErrorCall
|
||||
}
|
||||
|
||||
type tempUnschedCall struct {
|
||||
accountID int64
|
||||
until time.Time
|
||||
reason string
|
||||
}
|
||||
|
||||
type setErrorCall struct {
|
||||
accountID int64
|
||||
reason string
|
||||
}
|
||||
|
||||
func (r *internal500AccountRepoStub) SetTempUnschedulable(_ context.Context, id int64, until time.Time, reason string) error {
|
||||
r.tempUnschedCalls = append(r.tempUnschedCalls, tempUnschedCall{accountID: id, until: until, reason: reason})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *internal500AccountRepoStub) SetError(_ context.Context, id int64, errorMsg string) error {
|
||||
r.setErrorCalls = append(r.setErrorCalls, setErrorCall{accountID: id, reason: errorMsg})
|
||||
return nil
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TestIsAntigravityInternalServerError
|
||||
// =============================================================================
|
||||
|
||||
func TestIsAntigravityInternalServerError(t *testing.T) {
|
||||
t.Run("匹配完整的 INTERNAL 500 body", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"INTERNAL"}}`)
|
||||
require.True(t, isAntigravityInternalServerError(500, body))
|
||||
})
|
||||
|
||||
t.Run("statusCode 不是 500", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"INTERNAL"}}`)
|
||||
require.False(t, isAntigravityInternalServerError(429, body))
|
||||
require.False(t, isAntigravityInternalServerError(503, body))
|
||||
require.False(t, isAntigravityInternalServerError(200, body))
|
||||
})
|
||||
|
||||
t.Run("body 中 message 不匹配", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"code":500,"message":"Some other error","status":"INTERNAL"}}`)
|
||||
require.False(t, isAntigravityInternalServerError(500, body))
|
||||
})
|
||||
|
||||
t.Run("body 中 status 不匹配", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"UNAVAILABLE"}}`)
|
||||
require.False(t, isAntigravityInternalServerError(500, body))
|
||||
})
|
||||
|
||||
t.Run("body 中 code 不匹配", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"code":503,"message":"Internal error encountered.","status":"INTERNAL"}}`)
|
||||
require.False(t, isAntigravityInternalServerError(500, body))
|
||||
})
|
||||
|
||||
t.Run("空 body", func(t *testing.T) {
|
||||
require.False(t, isAntigravityInternalServerError(500, []byte{}))
|
||||
require.False(t, isAntigravityInternalServerError(500, nil))
|
||||
})
|
||||
|
||||
t.Run("其他 500 错误格式(纯文本)", func(t *testing.T) {
|
||||
body := []byte(`Internal Server Error`)
|
||||
require.False(t, isAntigravityInternalServerError(500, body))
|
||||
})
|
||||
|
||||
t.Run("其他 500 错误格式(不同 JSON 结构)", func(t *testing.T) {
|
||||
body := []byte(`{"message":"Internal Server Error","statusCode":500}`)
|
||||
require.False(t, isAntigravityInternalServerError(500, body))
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TestApplyInternal500Penalty
|
||||
// =============================================================================
|
||||
|
||||
func TestApplyInternal500Penalty(t *testing.T) {
|
||||
t.Run("count=1 → SetTempUnschedulable 10 分钟", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 1, Name: "acc-1"}
|
||||
|
||||
before := time.Now()
|
||||
svc.applyInternal500Penalty(context.Background(), "[test]", account, 1)
|
||||
after := time.Now()
|
||||
|
||||
require.Len(t, repo.tempUnschedCalls, 1)
|
||||
require.Empty(t, repo.setErrorCalls)
|
||||
|
||||
call := repo.tempUnschedCalls[0]
|
||||
require.Equal(t, int64(1), call.accountID)
|
||||
require.Contains(t, call.reason, "INTERNAL 500")
|
||||
// until 应在 [before+10m, after+10m] 范围内
|
||||
require.True(t, call.until.After(before.Add(internal500PenaltyTier1Duration).Add(-time.Second)))
|
||||
require.True(t, call.until.Before(after.Add(internal500PenaltyTier1Duration).Add(time.Second)))
|
||||
})
|
||||
|
||||
t.Run("count=2 → SetTempUnschedulable 10 小时", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 2, Name: "acc-2"}
|
||||
|
||||
before := time.Now()
|
||||
svc.applyInternal500Penalty(context.Background(), "[test]", account, 2)
|
||||
after := time.Now()
|
||||
|
||||
require.Len(t, repo.tempUnschedCalls, 1)
|
||||
require.Empty(t, repo.setErrorCalls)
|
||||
|
||||
call := repo.tempUnschedCalls[0]
|
||||
require.Equal(t, int64(2), call.accountID)
|
||||
require.Contains(t, call.reason, "INTERNAL 500")
|
||||
require.True(t, call.until.After(before.Add(internal500PenaltyTier2Duration).Add(-time.Second)))
|
||||
require.True(t, call.until.Before(after.Add(internal500PenaltyTier2Duration).Add(time.Second)))
|
||||
})
|
||||
|
||||
t.Run("count=3 → SetError 永久禁用", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 3, Name: "acc-3"}
|
||||
|
||||
svc.applyInternal500Penalty(context.Background(), "[test]", account, 3)
|
||||
|
||||
require.Empty(t, repo.tempUnschedCalls)
|
||||
require.Len(t, repo.setErrorCalls, 1)
|
||||
|
||||
call := repo.setErrorCalls[0]
|
||||
require.Equal(t, int64(3), call.accountID)
|
||||
require.Contains(t, call.reason, "INTERNAL 500 consecutive failures: 3")
|
||||
})
|
||||
|
||||
t.Run("count=5 → SetError 永久禁用(>=3 都走永久禁用)", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 5, Name: "acc-5"}
|
||||
|
||||
svc.applyInternal500Penalty(context.Background(), "[test]", account, 5)
|
||||
|
||||
require.Empty(t, repo.tempUnschedCalls)
|
||||
require.Len(t, repo.setErrorCalls, 1)
|
||||
|
||||
call := repo.setErrorCalls[0]
|
||||
require.Equal(t, int64(5), call.accountID)
|
||||
require.Contains(t, call.reason, "INTERNAL 500 consecutive failures: 5")
|
||||
})
|
||||
|
||||
t.Run("count=0 → 不调用任何方法", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 10, Name: "acc-10"}
|
||||
|
||||
svc.applyInternal500Penalty(context.Background(), "[test]", account, 0)
|
||||
|
||||
require.Empty(t, repo.tempUnschedCalls)
|
||||
require.Empty(t, repo.setErrorCalls)
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TestHandleInternal500RetryExhausted
|
||||
// =============================================================================
|
||||
|
||||
func TestHandleInternal500RetryExhausted(t *testing.T) {
|
||||
t.Run("internal500Cache 为 nil → 不 panic,不调用任何方法", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
svc := &AntigravityGatewayService{
|
||||
accountRepo: repo,
|
||||
internal500Cache: nil,
|
||||
}
|
||||
account := &Account{ID: 1, Name: "acc-1"}
|
||||
|
||||
// 不应 panic
|
||||
require.NotPanics(t, func() {
|
||||
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
|
||||
})
|
||||
require.Empty(t, repo.tempUnschedCalls)
|
||||
require.Empty(t, repo.setErrorCalls)
|
||||
})
|
||||
|
||||
t.Run("IncrementInternal500Count 返回 error → 不调用惩罚方法", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
cache := &mockInternal500Cache{
|
||||
incrementErr: errors.New("redis connection error"),
|
||||
}
|
||||
svc := &AntigravityGatewayService{
|
||||
accountRepo: repo,
|
||||
internal500Cache: cache,
|
||||
}
|
||||
account := &Account{ID: 2, Name: "acc-2"}
|
||||
|
||||
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
|
||||
|
||||
require.Len(t, cache.incrementCalls, 1)
|
||||
require.Equal(t, int64(2), cache.incrementCalls[0])
|
||||
require.Empty(t, repo.tempUnschedCalls)
|
||||
require.Empty(t, repo.setErrorCalls)
|
||||
})
|
||||
|
||||
t.Run("IncrementInternal500Count 返回 count=1 → 触发 tier1 惩罚", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
cache := &mockInternal500Cache{
|
||||
incrementCount: 1,
|
||||
}
|
||||
svc := &AntigravityGatewayService{
|
||||
accountRepo: repo,
|
||||
internal500Cache: cache,
|
||||
}
|
||||
account := &Account{ID: 3, Name: "acc-3"}
|
||||
|
||||
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
|
||||
|
||||
require.Len(t, cache.incrementCalls, 1)
|
||||
require.Equal(t, int64(3), cache.incrementCalls[0])
|
||||
// tier1: SetTempUnschedulable
|
||||
require.Len(t, repo.tempUnschedCalls, 1)
|
||||
require.Equal(t, int64(3), repo.tempUnschedCalls[0].accountID)
|
||||
require.Empty(t, repo.setErrorCalls)
|
||||
})
|
||||
|
||||
t.Run("IncrementInternal500Count 返回 count=3 → 触发 tier3 永久禁用", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
cache := &mockInternal500Cache{
|
||||
incrementCount: 3,
|
||||
}
|
||||
svc := &AntigravityGatewayService{
|
||||
accountRepo: repo,
|
||||
internal500Cache: cache,
|
||||
}
|
||||
account := &Account{ID: 4, Name: "acc-4"}
|
||||
|
||||
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
|
||||
|
||||
require.Len(t, cache.incrementCalls, 1)
|
||||
require.Empty(t, repo.tempUnschedCalls)
|
||||
require.Len(t, repo.setErrorCalls, 1)
|
||||
require.Equal(t, int64(4), repo.setErrorCalls[0].accountID)
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TestResetInternal500Counter
|
||||
// =============================================================================
|
||||
|
||||
func TestResetInternal500Counter(t *testing.T) {
|
||||
t.Run("internal500Cache 为 nil → 不 panic", func(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{
|
||||
internal500Cache: nil,
|
||||
}
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
svc.resetInternal500Counter(context.Background(), "[test]", 1)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("ResetInternal500Count 返回 error → 不 panic(仅日志)", func(t *testing.T) {
|
||||
cache := &mockInternal500Cache{
|
||||
resetErr: errors.New("redis timeout"),
|
||||
}
|
||||
svc := &AntigravityGatewayService{
|
||||
internal500Cache: cache,
|
||||
}
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
svc.resetInternal500Counter(context.Background(), "[test]", 42)
|
||||
})
|
||||
require.Len(t, cache.resetCalls, 1)
|
||||
require.Equal(t, int64(42), cache.resetCalls[0])
|
||||
})
|
||||
|
||||
t.Run("正常调用 → 调用 ResetInternal500Count", func(t *testing.T) {
|
||||
cache := &mockInternal500Cache{}
|
||||
svc := &AntigravityGatewayService{
|
||||
internal500Cache: cache,
|
||||
}
|
||||
|
||||
svc.resetInternal500Counter(context.Background(), "[test]", 99)
|
||||
|
||||
require.Len(t, cache.resetCalls, 1)
|
||||
require.Equal(t, int64(99), cache.resetCalls[0])
|
||||
})
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"log/slog"
|
||||
mathrand "math/rand"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
@@ -368,6 +369,8 @@ var allowedHeaders = map[string]bool{
|
||||
"user-agent": true,
|
||||
"content-type": true,
|
||||
"accept-encoding": true,
|
||||
"x-claude-code-session-id": true,
|
||||
"x-client-request-id": true,
|
||||
}
|
||||
|
||||
// GatewayCache 定义网关服务的缓存操作接口。
|
||||
@@ -4150,10 +4153,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取代理URL
|
||||
// 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递)
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
// 解析 TLS 指纹 profile(同一请求生命周期内不变,避免重试循环中重复解析)
|
||||
@@ -5628,6 +5633,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
targetURL = validatedURL + "/v1/messages?beta=true"
|
||||
}
|
||||
} else if account.IsCustomBaseURLEnabled() {
|
||||
customURL := account.GetCustomBaseURL()
|
||||
if customURL == "" {
|
||||
return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID)
|
||||
}
|
||||
validatedURL, err := s.validateUpstreamBaseURL(customURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages", account)
|
||||
}
|
||||
|
||||
clientHeaders := http.Header{}
|
||||
@@ -5743,6 +5758,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
}
|
||||
|
||||
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
|
||||
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
|
||||
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
|
||||
if parsed := ParseMetadataUserID(uid); parsed != nil {
|
||||
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === DEBUG: 打印上游转发请求(headers + body 摘要),与 CLIENT_ORIGINAL 对比 ===
|
||||
s.debugLogGatewaySnapshot("UPSTREAM_FORWARD", req.Header, body, map[string]string{
|
||||
"url": req.URL.String(),
|
||||
@@ -8063,10 +8087,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取代理URL
|
||||
// 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递)
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
@@ -8345,6 +8371,16 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
}
|
||||
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
|
||||
}
|
||||
} else if account.IsCustomBaseURLEnabled() {
|
||||
customURL := account.GetCustomBaseURL()
|
||||
if customURL == "" {
|
||||
return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID)
|
||||
}
|
||||
validatedURL, err := s.validateUpstreamBaseURL(customURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages/count_tokens", account)
|
||||
}
|
||||
|
||||
clientHeaders := http.Header{}
|
||||
@@ -8450,6 +8486,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
}
|
||||
}
|
||||
|
||||
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
|
||||
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
|
||||
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
|
||||
if parsed := ParseMetadataUserID(uid); parsed != nil {
|
||||
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if c != nil && tokenType == "oauth" {
|
||||
c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode))
|
||||
}
|
||||
@@ -8471,6 +8516,19 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
|
||||
})
|
||||
}
|
||||
|
||||
// buildCustomRelayURL 构建自定义中继转发 URL
|
||||
// 在 path 后附加 beta=true 和可选的 proxy 查询参数
|
||||
func (s *GatewayService) buildCustomRelayURL(baseURL, path string, account *Account) string {
|
||||
u := strings.TrimRight(baseURL, "/") + path + "?beta=true"
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL := account.Proxy.URL()
|
||||
if proxyURL != "" {
|
||||
u += "&proxy=" + url.QueryEscape(proxyURL)
|
||||
}
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
|
||||
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
||||
|
||||
@@ -36,6 +36,11 @@ var headerWireCasing = map[string]string{
|
||||
"sec-fetch-mode": "sec-fetch-mode",
|
||||
"accept-encoding": "accept-encoding",
|
||||
"authorization": "authorization",
|
||||
|
||||
// Claude Code 2.1.87+ 新增 header
|
||||
"x-claude-code-session-id": "X-Claude-Code-Session-Id",
|
||||
"x-client-request-id": "x-client-request-id",
|
||||
"content-length": "content-length",
|
||||
}
|
||||
|
||||
// headerWireOrder 定义真实 Claude CLI 发送 header 的顺序(基于抓包)。
|
||||
@@ -55,11 +60,14 @@ var headerWireOrder = []string{
|
||||
"authorization",
|
||||
"x-app",
|
||||
"User-Agent",
|
||||
"X-Claude-Code-Session-Id",
|
||||
"content-type",
|
||||
"anthropic-beta",
|
||||
"x-client-request-id",
|
||||
"accept-language",
|
||||
"sec-fetch-mode",
|
||||
"accept-encoding",
|
||||
"content-length",
|
||||
"x-stainless-helper-method",
|
||||
}
|
||||
|
||||
|
||||
11
backend/internal/service/internal500_counter.go
Normal file
11
backend/internal/service/internal500_counter.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package service
|
||||
|
||||
import "context"
|
||||
|
||||
// Internal500CounterCache 追踪 Antigravity 账号连续 INTERNAL 500 失败轮数
|
||||
type Internal500CounterCache interface {
|
||||
// IncrementInternal500Count 原子递增计数并返回当前值
|
||||
IncrementInternal500Count(ctx context.Context, accountID int64) (int64, error)
|
||||
// ResetInternal500Count 清零计数器(成功响应时调用)
|
||||
ResetInternal500Count(ctx context.Context, accountID int64) error
|
||||
}
|
||||
103
backend/internal/service/openai_compat_model.go
Normal file
103
backend/internal/service/openai_compat_model.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
)
|
||||
|
||||
func NormalizeOpenAICompatRequestedModel(model string) string {
|
||||
trimmed := strings.TrimSpace(model)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
normalized, _, ok := splitOpenAICompatReasoningModel(trimmed)
|
||||
if !ok || normalized == "" {
|
||||
return trimmed
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func applyOpenAICompatModelNormalization(req *apicompat.AnthropicRequest) {
|
||||
if req == nil {
|
||||
return
|
||||
}
|
||||
|
||||
originalModel := strings.TrimSpace(req.Model)
|
||||
if originalModel == "" {
|
||||
return
|
||||
}
|
||||
|
||||
normalizedModel, derivedEffort, hasReasoningSuffix := splitOpenAICompatReasoningModel(originalModel)
|
||||
if hasReasoningSuffix && normalizedModel != "" {
|
||||
req.Model = normalizedModel
|
||||
}
|
||||
|
||||
if req.OutputConfig != nil && strings.TrimSpace(req.OutputConfig.Effort) != "" {
|
||||
return
|
||||
}
|
||||
|
||||
claudeEffort := openAIReasoningEffortToClaudeOutputEffort(derivedEffort)
|
||||
if claudeEffort == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if req.OutputConfig == nil {
|
||||
req.OutputConfig = &apicompat.AnthropicOutputConfig{}
|
||||
}
|
||||
req.OutputConfig.Effort = claudeEffort
|
||||
}
|
||||
|
||||
func splitOpenAICompatReasoningModel(model string) (normalizedModel string, reasoningEffort string, ok bool) {
|
||||
trimmed := strings.TrimSpace(model)
|
||||
if trimmed == "" {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
modelID := trimmed
|
||||
if strings.Contains(modelID, "/") {
|
||||
parts := strings.Split(modelID, "/")
|
||||
modelID = parts[len(parts)-1]
|
||||
}
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if !strings.HasPrefix(strings.ToLower(modelID), "gpt-") {
|
||||
return trimmed, "", false
|
||||
}
|
||||
|
||||
parts := strings.FieldsFunc(strings.ToLower(modelID), func(r rune) bool {
|
||||
switch r {
|
||||
case '-', '_', ' ':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
})
|
||||
if len(parts) == 0 {
|
||||
return trimmed, "", false
|
||||
}
|
||||
|
||||
last := strings.NewReplacer("-", "", "_", "", " ", "").Replace(parts[len(parts)-1])
|
||||
switch last {
|
||||
case "none", "minimal":
|
||||
case "low", "medium", "high":
|
||||
reasoningEffort = last
|
||||
case "xhigh", "extrahigh":
|
||||
reasoningEffort = "xhigh"
|
||||
default:
|
||||
return trimmed, "", false
|
||||
}
|
||||
|
||||
return normalizeCodexModel(modelID), reasoningEffort, true
|
||||
}
|
||||
|
||||
func openAIReasoningEffortToClaudeOutputEffort(effort string) string {
|
||||
switch strings.TrimSpace(effort) {
|
||||
case "low", "medium", "high":
|
||||
return effort
|
||||
case "xhigh":
|
||||
return "max"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
129
backend/internal/service/openai_compat_model_test.go
Normal file
129
backend/internal/service/openai_compat_model_test.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestNormalizeOpenAICompatRequestedModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{name: "gpt reasoning alias strips xhigh", input: "gpt-5.4-xhigh", want: "gpt-5.4"},
|
||||
{name: "gpt reasoning alias strips none", input: "gpt-5.4-none", want: "gpt-5.4"},
|
||||
{name: "codex max model stays intact", input: "gpt-5.1-codex-max", want: "gpt-5.1-codex-max"},
|
||||
{name: "non openai model unchanged", input: "claude-opus-4-6", want: "claude-opus-4-6"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, NormalizeOpenAICompatRequestedModel(tt.input))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyOpenAICompatModelNormalization(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("derives xhigh from model suffix when output config missing", func(t *testing.T) {
|
||||
req := &apicompat.AnthropicRequest{Model: "gpt-5.4-xhigh"}
|
||||
|
||||
applyOpenAICompatModelNormalization(req)
|
||||
|
||||
require.Equal(t, "gpt-5.4", req.Model)
|
||||
require.NotNil(t, req.OutputConfig)
|
||||
require.Equal(t, "max", req.OutputConfig.Effort)
|
||||
})
|
||||
|
||||
t.Run("explicit output config wins over model suffix", func(t *testing.T) {
|
||||
req := &apicompat.AnthropicRequest{
|
||||
Model: "gpt-5.4-xhigh",
|
||||
OutputConfig: &apicompat.AnthropicOutputConfig{Effort: "low"},
|
||||
}
|
||||
|
||||
applyOpenAICompatModelNormalization(req)
|
||||
|
||||
require.Equal(t, "gpt-5.4", req.Model)
|
||||
require.NotNil(t, req.OutputConfig)
|
||||
require.Equal(t, "low", req.OutputConfig.Effort)
|
||||
})
|
||||
|
||||
t.Run("non openai model is untouched", func(t *testing.T) {
|
||||
req := &apicompat.AnthropicRequest{Model: "claude-opus-4-6"}
|
||||
|
||||
applyOpenAICompatModelNormalization(req)
|
||||
|
||||
require.Equal(t, "claude-opus-4-6", req.Model)
|
||||
require.Nil(t, req.OutputConfig)
|
||||
})
|
||||
}
|
||||
|
||||
func TestForwardAsAnthropic_NormalizesRoutingAndEffortForGpt54XHigh(t *testing.T) {
|
||||
t.Parallel()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
body := []byte(`{"model":"gpt-5.4-xhigh","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_compat"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5.4": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "gpt-5.4-xhigh", result.Model)
|
||||
require.Equal(t, "gpt-5.4", result.UpstreamModel)
|
||||
require.Equal(t, "gpt-5.4", result.BillingModel)
|
||||
require.NotNil(t, result.ReasoningEffort)
|
||||
require.Equal(t, "xhigh", *result.ReasoningEffort)
|
||||
|
||||
require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
require.Equal(t, "xhigh", gjson.GetBytes(upstream.lastBody, "reasoning.effort").String())
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, "gpt-5.4-xhigh", gjson.GetBytes(rec.Body.Bytes(), "model").String())
|
||||
require.Equal(t, "ok", gjson.GetBytes(rec.Body.Bytes(), "content.0.text").String())
|
||||
t.Logf("upstream body: %s", string(upstream.lastBody))
|
||||
t.Logf("response body: %s", rec.Body.String())
|
||||
}
|
||||
@@ -40,6 +40,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
return nil, fmt.Errorf("parse anthropic request: %w", err)
|
||||
}
|
||||
originalModel := anthropicReq.Model
|
||||
applyOpenAICompatModelNormalization(&anthropicReq)
|
||||
clientStream := anthropicReq.Stream // client's original stream preference
|
||||
|
||||
// 2. Convert Anthropic → Responses
|
||||
@@ -59,7 +60,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
}
|
||||
|
||||
// 3. Model mapping
|
||||
mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
|
||||
mappedModel := resolveOpenAIForwardModel(account, anthropicReq.Model, defaultMappedModel)
|
||||
responsesReq.Model = mappedModel
|
||||
|
||||
logger.L().Debug("openai messages: model mapping applied",
|
||||
|
||||
@@ -895,14 +895,16 @@ func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetad
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingUpstreamModelFallback(t *testing.T) {
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingRequestedModel(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
|
||||
|
||||
expectedCost, err := svc.billingService.CalculateCost("gpt-5.1-codex", UsageTokens{
|
||||
// Billing should use the requested model ("gpt-5.1"), not the upstream mapped model ("gpt-5.1-codex").
|
||||
// This ensures pricing is always based on the model the user requested.
|
||||
expectedCost, err := svc.billingService.CalculateCost("gpt-5.1", UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
}, 1.1)
|
||||
|
||||
@@ -4153,9 +4153,6 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
}
|
||||
|
||||
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
||||
if result.BillingModel != "" {
|
||||
billingModel = strings.TrimSpace(result.BillingModel)
|
||||
}
|
||||
serviceTier := ""
|
||||
if result.ServiceTier != nil {
|
||||
serviceTier = strings.TrimSpace(*result.ServiceTier)
|
||||
|
||||
@@ -502,6 +502,25 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
||||
|
||||
refreshToken := account.GetCredential("refresh_token")
|
||||
if refreshToken == "" {
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if accessToken != "" {
|
||||
tokenInfo := &OpenAITokenInfo{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: "",
|
||||
IDToken: account.GetCredential("id_token"),
|
||||
ClientID: account.GetCredential("client_id"),
|
||||
Email: account.GetCredential("email"),
|
||||
ChatGPTAccountID: account.GetCredential("chatgpt_account_id"),
|
||||
ChatGPTUserID: account.GetCredential("chatgpt_user_id"),
|
||||
OrganizationID: account.GetCredential("organization_id"),
|
||||
PlanType: account.GetCredential("plan_type"),
|
||||
}
|
||||
if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil {
|
||||
tokenInfo.ExpiresAt = expiresAt.Unix()
|
||||
tokenInfo.ExpiresIn = int64(time.Until(*expiresAt).Seconds())
|
||||
}
|
||||
return tokenInfo, nil
|
||||
}
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openaiOAuthClientRefreshStub struct {
|
||||
refreshCalls int32
|
||||
}
|
||||
|
||||
func (s *openaiOAuthClientRefreshStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *openaiOAuthClientRefreshStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||
atomic.AddInt32(&s.refreshCalls, 1)
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *openaiOAuthClientRefreshStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
|
||||
atomic.AddInt32(&s.refreshCalls, 1)
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_RefreshAccountToken_NoRefreshTokenUsesExistingAccessToken(t *testing.T) {
|
||||
client := &openaiOAuthClientRefreshStub{}
|
||||
svc := NewOpenAIOAuthService(nil, client)
|
||||
|
||||
expiresAt := time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 77,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "existing-access-token",
|
||||
"expires_at": expiresAt,
|
||||
"client_id": "client-id-1",
|
||||
},
|
||||
}
|
||||
|
||||
info, err := svc.RefreshAccountToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, info)
|
||||
require.Equal(t, "existing-access-token", info.AccessToken)
|
||||
require.Equal(t, "client-id-1", info.ClientID)
|
||||
require.Zero(t, atomic.LoadInt32(&client.refreshCalls), "existing access token should be reused without calling refresh")
|
||||
}
|
||||
@@ -189,10 +189,38 @@ func (s *PricingService) checkAndUpdatePricing() error {
|
||||
return s.downloadPricingData()
|
||||
}
|
||||
|
||||
// 检查文件是否过期
|
||||
// 先加载本地文件(确保服务可用),再检查是否需要更新
|
||||
if err := s.loadPricingData(pricingFile); err != nil {
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to load local file, downloading: %v", err)
|
||||
return s.downloadPricingData()
|
||||
}
|
||||
|
||||
// 如果配置了哈希URL,通过远程哈希检查是否有更新
|
||||
if s.cfg.Pricing.HashURL != "" {
|
||||
remoteHash, err := s.fetchRemoteHash()
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to fetch remote hash on startup: %v", err)
|
||||
return nil // 已加载本地文件,哈希获取失败不影响启动
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
localHash := s.localHash
|
||||
s.mu.RUnlock()
|
||||
|
||||
if localHash == "" || remoteHash != localHash {
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Remote hash differs on startup (local=%s remote=%s), downloading...",
|
||||
localHash[:min(8, len(localHash))], remoteHash[:min(8, len(remoteHash))])
|
||||
if err := s.downloadPricingData(); err != nil {
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Download failed, using existing file: %v", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 没有哈希URL时,基于文件年龄检查
|
||||
info, err := os.Stat(pricingFile)
|
||||
if err != nil {
|
||||
return s.downloadPricingData()
|
||||
return nil // 已加载本地文件
|
||||
}
|
||||
|
||||
fileAge := time.Since(info.ModTime())
|
||||
@@ -205,21 +233,11 @@ func (s *PricingService) checkAndUpdatePricing() error {
|
||||
}
|
||||
}
|
||||
|
||||
// 加载本地文件
|
||||
return s.loadPricingData(pricingFile)
|
||||
return nil
|
||||
}
|
||||
|
||||
// syncWithRemote 与远程同步(基于哈希校验)
|
||||
func (s *PricingService) syncWithRemote() error {
|
||||
pricingFile := s.getPricingFilePath()
|
||||
|
||||
// 计算本地文件哈希
|
||||
localHash, err := s.computeFileHash(pricingFile)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to compute local hash: %v", err)
|
||||
return s.downloadPricingData()
|
||||
}
|
||||
|
||||
// 如果配置了哈希URL,从远程获取哈希进行比对
|
||||
if s.cfg.Pricing.HashURL != "" {
|
||||
remoteHash, err := s.fetchRemoteHash()
|
||||
@@ -228,8 +246,13 @@ func (s *PricingService) syncWithRemote() error {
|
||||
return nil // 哈希获取失败不影响正常使用
|
||||
}
|
||||
|
||||
if remoteHash != localHash {
|
||||
logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Remote hash differs, downloading new version...")
|
||||
s.mu.RLock()
|
||||
localHash := s.localHash
|
||||
s.mu.RUnlock()
|
||||
|
||||
if localHash == "" || remoteHash != localHash {
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Remote hash differs (local=%s remote=%s), downloading new version...",
|
||||
localHash[:min(8, len(localHash))], remoteHash[:min(8, len(remoteHash))])
|
||||
return s.downloadPricingData()
|
||||
}
|
||||
logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Hash check passed, no update needed")
|
||||
@@ -237,6 +260,7 @@ func (s *PricingService) syncWithRemote() error {
|
||||
}
|
||||
|
||||
// 没有哈希URL时,基于时间检查
|
||||
pricingFile := s.getPricingFilePath()
|
||||
info, err := os.Stat(pricingFile)
|
||||
if err != nil {
|
||||
return s.downloadPricingData()
|
||||
@@ -264,11 +288,12 @@ func (s *PricingService) downloadPricingData() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var expectedHash string
|
||||
// 获取远程哈希(用于同步锚点,不作为完整性校验)
|
||||
var remoteHash string
|
||||
if strings.TrimSpace(s.cfg.Pricing.HashURL) != "" {
|
||||
expectedHash, err = s.fetchRemoteHash()
|
||||
remoteHash, err = s.fetchRemoteHash()
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetch remote hash: %w", err)
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to fetch remote hash (continuing): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -277,11 +302,13 @@ func (s *PricingService) downloadPricingData() error {
|
||||
return fmt.Errorf("download failed: %w", err)
|
||||
}
|
||||
|
||||
if expectedHash != "" {
|
||||
actualHash := sha256.Sum256(body)
|
||||
if !strings.EqualFold(expectedHash, hex.EncodeToString(actualHash[:])) {
|
||||
return fmt.Errorf("pricing hash mismatch")
|
||||
}
|
||||
// 哈希校验:不匹配时仅告警,不阻止更新
|
||||
// 远程哈希文件可能与数据文件不同步(如维护者更新了数据但未更新哈希文件)
|
||||
dataHash := sha256.Sum256(body)
|
||||
dataHashStr := hex.EncodeToString(dataHash[:])
|
||||
if remoteHash != "" && !strings.EqualFold(remoteHash, dataHashStr) {
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Hash mismatch warning: remote=%s data=%s (hash file may be out of sync)",
|
||||
remoteHash[:min(8, len(remoteHash))], dataHashStr[:8])
|
||||
}
|
||||
|
||||
// 解析JSON数据(使用灵活的解析方式)
|
||||
@@ -296,11 +323,14 @@ func (s *PricingService) downloadPricingData() error {
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to save file: %v", err)
|
||||
}
|
||||
|
||||
// 保存哈希
|
||||
hash := sha256.Sum256(body)
|
||||
hashStr := hex.EncodeToString(hash[:])
|
||||
// 使用远程哈希作为同步锚点,防止重复下载
|
||||
// 当远程哈希不可用时,回退到数据本身的哈希
|
||||
syncHash := dataHashStr
|
||||
if remoteHash != "" {
|
||||
syncHash = remoteHash
|
||||
}
|
||||
hashFile := s.getHashFilePath()
|
||||
if err := os.WriteFile(hashFile, []byte(hashStr+"\n"), 0644); err != nil {
|
||||
if err := os.WriteFile(hashFile, []byte(syncHash+"\n"), 0644); err != nil {
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to save hash: %v", err)
|
||||
}
|
||||
|
||||
@@ -308,7 +338,7 @@ func (s *PricingService) downloadPricingData() error {
|
||||
s.mu.Lock()
|
||||
s.pricingData = data
|
||||
s.lastUpdated = time.Now()
|
||||
s.localHash = hashStr
|
||||
s.localHash = syncHash
|
||||
s.mu.Unlock()
|
||||
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Downloaded %d models successfully", len(data))
|
||||
@@ -486,16 +516,6 @@ func (s *PricingService) validatePricingURL(raw string) (string, error) {
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// computeFileHash 计算文件哈希
|
||||
func (s *PricingService) computeFileHash(filePath string) (string, error) {
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
hash := sha256.Sum256(data)
|
||||
return hex.EncodeToString(hash[:]), nil
|
||||
}
|
||||
|
||||
// GetModelPricing 获取模型价格(带模糊匹配)
|
||||
func (s *PricingService) GetModelPricing(modelName string) *LiteLLMModelPricing {
|
||||
s.mu.RLock()
|
||||
|
||||
@@ -32,8 +32,9 @@ type TokenRefreshService struct {
|
||||
privacyClientFactory PrivacyClientFactory
|
||||
proxyRepo ProxyRepository
|
||||
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
stopCh chan struct{}
|
||||
stopOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewTokenRefreshService 创建token刷新服务
|
||||
@@ -130,7 +131,9 @@ func (s *TokenRefreshService) Start() {
|
||||
|
||||
// Stop 停止刷新服务(可安全多次调用)
|
||||
func (s *TokenRefreshService) Stop() {
|
||||
close(s.stopCh)
|
||||
s.stopOnce.Do(func() {
|
||||
close(s.stopCh)
|
||||
})
|
||||
s.wg.Wait()
|
||||
slog.Info("token_refresh.service_stopped")
|
||||
}
|
||||
@@ -430,6 +433,7 @@ func isNonRetryableRefreshError(err error) bool {
|
||||
"unauthorized_client", // 客户端未授权
|
||||
"access_denied", // 访问被拒绝
|
||||
"missing_project_id", // 缺少 project_id
|
||||
"no refresh token available",
|
||||
}
|
||||
for _, needle := range nonRetryable {
|
||||
if strings.Contains(msg, needle) {
|
||||
|
||||
@@ -19,6 +19,7 @@ type tokenRefreshAccountRepo struct {
|
||||
updateCredentialsCalls int
|
||||
setErrorCalls int
|
||||
clearTempCalls int
|
||||
setTempUnschedCalls int
|
||||
lastAccount *Account
|
||||
updateErr error
|
||||
}
|
||||
@@ -58,6 +59,11 @@ func (r *tokenRefreshAccountRepo) ClearTempUnschedulable(ctx context.Context, id
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *tokenRefreshAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
r.setTempUnschedCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
type tokenCacheInvalidatorStub struct {
|
||||
calls int
|
||||
err error
|
||||
@@ -490,6 +496,31 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenRefreshService_RefreshWithRetry_NoRefreshTokenDoesNotTempUnschedule(t *testing.T) {
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
cfg := &config.Config{
|
||||
TokenRefresh: config.TokenRefreshConfig{
|
||||
MaxRetries: 2,
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 18,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
refresher := &tokenRefresherStub{
|
||||
err: errors.New("no refresh token available"),
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 0, repo.updateCalls)
|
||||
require.Equal(t, 0, repo.setTempUnschedCalls, "missing refresh token should not mark the account temp unschedulable")
|
||||
require.Equal(t, 1, repo.setErrorCalls, "missing refresh token should be treated as a non-retryable credential state")
|
||||
}
|
||||
|
||||
// TestIsNonRetryableRefreshError 测试不可重试错误判断
|
||||
func TestIsNonRetryableRefreshError(t *testing.T) {
|
||||
tests := []struct {
|
||||
@@ -503,6 +534,7 @@ func TestIsNonRetryableRefreshError(t *testing.T) {
|
||||
{name: "invalid_client", err: errors.New("invalid_client"), expected: true},
|
||||
{name: "unauthorized_client", err: errors.New("unauthorized_client"), expected: true},
|
||||
{name: "access_denied", err: errors.New("access_denied"), expected: true},
|
||||
{name: "no_refresh_token", err: errors.New("no refresh token available"), expected: true},
|
||||
{name: "invalid_grant_with_desc", err: errors.New("Error: invalid_grant - token revoked"), expected: true},
|
||||
{name: "case_insensitive", err: errors.New("INVALID_GRANT"), expected: true},
|
||||
}
|
||||
|
||||
@@ -21,8 +21,8 @@ func optionalNonEqualStringPtr(value, compare string) *string {
|
||||
}
|
||||
|
||||
func forwardResultBillingModel(requestedModel, upstreamModel string) string {
|
||||
if trimmedUpstream := strings.TrimSpace(upstreamModel); trimmedUpstream != "" {
|
||||
return trimmedUpstream
|
||||
if trimmed := strings.TrimSpace(requestedModel); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
return strings.TrimSpace(requestedModel)
|
||||
return strings.TrimSpace(upstreamModel)
|
||||
}
|
||||
|
||||
@@ -865,10 +865,10 @@ rate_limit:
|
||||
pricing:
|
||||
# URL to fetch model pricing data (default: pinned model-price-repo commit)
|
||||
# 获取模型定价数据的 URL(默认:固定 commit 的 model-price-repo)
|
||||
remote_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json"
|
||||
remote_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/refs/heads/main//model_prices_and_context_window.json"
|
||||
# Hash verification URL (optional)
|
||||
# 哈希校验 URL(可选)
|
||||
hash_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256"
|
||||
hash_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/refs/heads/main//model_prices_and_context_window.sha256"
|
||||
# Local data directory for caching
|
||||
# 本地数据缓存目录
|
||||
data_dir: "./data"
|
||||
|
||||
@@ -2245,6 +2245,41 @@
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Custom Base URL Relay -->
|
||||
<div class="rounded-lg border border-gray-200 p-4 dark:border-dark-600">
|
||||
<div class="flex items-center justify-between">
|
||||
<div>
|
||||
<label class="input-label mb-0">{{ t('admin.accounts.quotaControl.customBaseUrl.label') }}</label>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.quotaControl.customBaseUrl.hint') }}
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
@click="customBaseUrlEnabled = !customBaseUrlEnabled"
|
||||
:class="[
|
||||
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
|
||||
customBaseUrlEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
|
||||
]"
|
||||
>
|
||||
<span
|
||||
:class="[
|
||||
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
||||
customBaseUrlEnabled ? 'translate-x-5' : 'translate-x-0'
|
||||
]"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
<div v-if="customBaseUrlEnabled" class="mt-3">
|
||||
<input
|
||||
v-model="customBaseUrl"
|
||||
type="text"
|
||||
class="input"
|
||||
:placeholder="t('admin.accounts.quotaControl.customBaseUrl.urlHint')"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
@@ -3095,6 +3130,8 @@ const tlsFingerprintProfiles = ref<{ id: number; name: string }[]>([])
|
||||
const sessionIdMaskingEnabled = ref(false)
|
||||
const cacheTTLOverrideEnabled = ref(false)
|
||||
const cacheTTLOverrideTarget = ref<string>('5m')
|
||||
const customBaseUrlEnabled = ref(false)
|
||||
const customBaseUrl = ref('')
|
||||
|
||||
// Gemini tier selection (used as fallback when auto-detection is unavailable/fails)
|
||||
const geminiTierGoogleOne = ref<'google_one_free' | 'google_ai_pro' | 'google_ai_ultra'>('google_one_free')
|
||||
@@ -3765,6 +3802,8 @@ const resetForm = () => {
|
||||
sessionIdMaskingEnabled.value = false
|
||||
cacheTTLOverrideEnabled.value = false
|
||||
cacheTTLOverrideTarget.value = '5m'
|
||||
customBaseUrlEnabled.value = false
|
||||
customBaseUrl.value = ''
|
||||
allowOverages.value = false
|
||||
antigravityAccountType.value = 'oauth'
|
||||
upstreamBaseUrl.value = ''
|
||||
@@ -4856,6 +4895,12 @@ const handleAnthropicExchange = async (authCode: string) => {
|
||||
extra.cache_ttl_override_target = cacheTTLOverrideTarget.value
|
||||
}
|
||||
|
||||
// Add custom base URL settings
|
||||
if (customBaseUrlEnabled.value && customBaseUrl.value.trim()) {
|
||||
extra.custom_base_url_enabled = true
|
||||
extra.custom_base_url = customBaseUrl.value.trim()
|
||||
}
|
||||
|
||||
const credentials: Record<string, unknown> = { ...tokenInfo }
|
||||
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
|
||||
await createAccountAndFinish(form.platform, addMethod.value as AccountType, credentials, extra)
|
||||
@@ -4974,6 +5019,12 @@ const handleCookieAuth = async (sessionKey: string) => {
|
||||
extra.cache_ttl_override_target = cacheTTLOverrideTarget.value
|
||||
}
|
||||
|
||||
// Add custom base URL settings
|
||||
if (customBaseUrlEnabled.value && customBaseUrl.value.trim()) {
|
||||
extra.custom_base_url_enabled = true
|
||||
extra.custom_base_url = customBaseUrl.value.trim()
|
||||
}
|
||||
|
||||
const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name
|
||||
|
||||
const credentials: Record<string, unknown> = { ...tokenInfo }
|
||||
|
||||
@@ -1580,6 +1580,41 @@
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Custom Base URL Relay -->
|
||||
<div class="rounded-lg border border-gray-200 p-4 dark:border-dark-600">
|
||||
<div class="flex items-center justify-between">
|
||||
<div>
|
||||
<label class="input-label mb-0">{{ t('admin.accounts.quotaControl.customBaseUrl.label') }}</label>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.quotaControl.customBaseUrl.hint') }}
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
@click="customBaseUrlEnabled = !customBaseUrlEnabled"
|
||||
:class="[
|
||||
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
|
||||
customBaseUrlEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
|
||||
]"
|
||||
>
|
||||
<span
|
||||
:class="[
|
||||
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
||||
customBaseUrlEnabled ? 'translate-x-5' : 'translate-x-0'
|
||||
]"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
<div v-if="customBaseUrlEnabled" class="mt-3">
|
||||
<input
|
||||
v-model="customBaseUrl"
|
||||
type="text"
|
||||
class="input"
|
||||
:placeholder="t('admin.accounts.quotaControl.customBaseUrl.urlHint')"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
@@ -1854,6 +1889,8 @@ const tlsFingerprintProfiles = ref<{ id: number; name: string }[]>([])
|
||||
const sessionIdMaskingEnabled = ref(false)
|
||||
const cacheTTLOverrideEnabled = ref(false)
|
||||
const cacheTTLOverrideTarget = ref<string>('5m')
|
||||
const customBaseUrlEnabled = ref(false)
|
||||
const customBaseUrl = ref('')
|
||||
|
||||
// OpenAI 自动透传开关(OAuth/API Key)
|
||||
const openaiPassthroughEnabled = ref(false)
|
||||
@@ -2482,6 +2519,8 @@ function loadQuotaControlSettings(account: Account) {
|
||||
sessionIdMaskingEnabled.value = false
|
||||
cacheTTLOverrideEnabled.value = false
|
||||
cacheTTLOverrideTarget.value = '5m'
|
||||
customBaseUrlEnabled.value = false
|
||||
customBaseUrl.value = ''
|
||||
|
||||
// Only applies to Anthropic OAuth/SetupToken accounts
|
||||
if (account.platform !== 'anthropic' || (account.type !== 'oauth' && account.type !== 'setup-token')) {
|
||||
@@ -2528,6 +2567,12 @@ function loadQuotaControlSettings(account: Account) {
|
||||
cacheTTLOverrideEnabled.value = true
|
||||
cacheTTLOverrideTarget.value = account.cache_ttl_override_target || '5m'
|
||||
}
|
||||
|
||||
// Load custom base URL setting
|
||||
if (account.custom_base_url_enabled === true) {
|
||||
customBaseUrlEnabled.value = true
|
||||
customBaseUrl.value = account.custom_base_url || ''
|
||||
}
|
||||
}
|
||||
|
||||
function formatTempUnschedKeywords(value: unknown) {
|
||||
@@ -2980,6 +3025,15 @@ const handleSubmit = async () => {
|
||||
delete newExtra.cache_ttl_override_target
|
||||
}
|
||||
|
||||
// Custom base URL relay setting
|
||||
if (customBaseUrlEnabled.value && customBaseUrl.value.trim()) {
|
||||
newExtra.custom_base_url_enabled = true
|
||||
newExtra.custom_base_url = customBaseUrl.value.trim()
|
||||
} else {
|
||||
delete newExtra.custom_base_url_enabled
|
||||
delete newExtra.custom_base_url
|
||||
}
|
||||
|
||||
updatePayload.extra = newExtra
|
||||
}
|
||||
|
||||
|
||||
@@ -2318,6 +2318,11 @@ export default {
|
||||
target: 'Target TTL',
|
||||
targetHint: 'Select the TTL tier for billing'
|
||||
},
|
||||
customBaseUrl: {
|
||||
label: 'Custom Relay URL',
|
||||
hint: 'Forward requests to a custom relay service. Proxy URL will be passed as a query parameter.',
|
||||
urlHint: 'Relay service URL (e.g., https://relay.example.com)',
|
||||
},
|
||||
clientAffinity: {
|
||||
label: 'Client Affinity Scheduling',
|
||||
hint: 'When enabled, new sessions prefer accounts previously used by this client to reduce account switching'
|
||||
@@ -4378,6 +4383,7 @@ export default {
|
||||
provider: 'Type',
|
||||
active: 'Active',
|
||||
endpoint: 'Endpoint',
|
||||
bucket: 'Bucket',
|
||||
storagePath: 'Storage Path',
|
||||
capacityUsage: 'Capacity / Used',
|
||||
capacityUnlimited: 'Unlimited',
|
||||
|
||||
@@ -2462,6 +2462,11 @@ export default {
|
||||
target: '目标 TTL',
|
||||
targetHint: '选择计费使用的 TTL 类型'
|
||||
},
|
||||
customBaseUrl: {
|
||||
label: '自定义转发地址',
|
||||
hint: '启用后将请求转发到自定义中继服务,代理地址将作为 URL 参数传递给中继服务',
|
||||
urlHint: '中继服务地址(如 https://relay.example.com)',
|
||||
},
|
||||
clientAffinity: {
|
||||
label: '客户端亲和调度',
|
||||
hint: '启用后,新会话会优先调度到该客户端之前使用过的账号,避免频繁切换账号'
|
||||
@@ -4542,6 +4547,7 @@ export default {
|
||||
provider: '存储类型',
|
||||
active: '生效状态',
|
||||
endpoint: '端点',
|
||||
bucket: '存储桶',
|
||||
storagePath: '存储路径',
|
||||
capacityUsage: '容量 / 已用',
|
||||
capacityUnlimited: '无限制',
|
||||
|
||||
@@ -734,6 +734,10 @@ export interface Account {
|
||||
cache_ttl_override_enabled?: boolean | null
|
||||
cache_ttl_override_target?: string | null
|
||||
|
||||
// 自定义 Base URL 中继转发(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||
custom_base_url_enabled?: boolean | null
|
||||
custom_base_url?: string | null
|
||||
|
||||
// 客户端亲和调度(仅 Anthropic/Antigravity 平台有效)
|
||||
// 启用后新会话会优先调度到客户端之前使用过的账号
|
||||
client_affinity_enabled?: boolean | null
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
import { computed, onMounted, reactive, ref, watch } from 'vue'
|
||||
import { opsAPI, type OpsRuntimeLogConfig, type OpsSystemLog, type OpsSystemLogSinkHealth } from '@/api/admin/ops'
|
||||
import Pagination from '@/components/common/Pagination.vue'
|
||||
import Select from '@/components/common/Select.vue'
|
||||
import { useAppStore } from '@/stores'
|
||||
|
||||
const appStore = useAppStore()
|
||||
@@ -56,6 +57,37 @@ const filters = reactive({
|
||||
q: ''
|
||||
})
|
||||
|
||||
const runtimeLevelOptions = [
|
||||
{ value: 'debug', label: 'debug' },
|
||||
{ value: 'info', label: 'info' },
|
||||
{ value: 'warn', label: 'warn' },
|
||||
{ value: 'error', label: 'error' }
|
||||
]
|
||||
|
||||
const stacktraceLevelOptions = [
|
||||
{ value: 'none', label: 'none' },
|
||||
{ value: 'error', label: 'error' },
|
||||
{ value: 'fatal', label: 'fatal' }
|
||||
]
|
||||
|
||||
const timeRangeOptions = [
|
||||
{ value: '5m', label: '5m' },
|
||||
{ value: '30m', label: '30m' },
|
||||
{ value: '1h', label: '1h' },
|
||||
{ value: '6h', label: '6h' },
|
||||
{ value: '24h', label: '24h' },
|
||||
{ value: '7d', label: '7d' },
|
||||
{ value: '30d', label: '30d' }
|
||||
]
|
||||
|
||||
const filterLevelOptions = [
|
||||
{ value: '', label: '全部' },
|
||||
{ value: 'debug', label: 'debug' },
|
||||
{ value: 'info', label: 'info' },
|
||||
{ value: 'warn', label: 'warn' },
|
||||
{ value: 'error', label: 'error' }
|
||||
]
|
||||
|
||||
const levelBadgeClass = (level: string) => {
|
||||
const v = String(level || '').toLowerCase()
|
||||
if (v === 'error' || v === 'fatal') return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-300'
|
||||
@@ -347,20 +379,11 @@ onMounted(async () => {
|
||||
<div class="grid grid-cols-1 gap-3 md:grid-cols-2 xl:grid-cols-6">
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
级别
|
||||
<select v-model="runtimeConfig.level" class="input mt-1">
|
||||
<option value="debug">debug</option>
|
||||
<option value="info">info</option>
|
||||
<option value="warn">warn</option>
|
||||
<option value="error">error</option>
|
||||
</select>
|
||||
<Select v-model="runtimeConfig.level" class="mt-1" :options="runtimeLevelOptions" />
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
堆栈阈值
|
||||
<select v-model="runtimeConfig.stacktrace_level" class="input mt-1">
|
||||
<option value="none">none</option>
|
||||
<option value="error">error</option>
|
||||
<option value="fatal">fatal</option>
|
||||
</select>
|
||||
<Select v-model="runtimeConfig.stacktrace_level" class="mt-1" :options="stacktraceLevelOptions" />
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
采样初始
|
||||
@@ -403,15 +426,7 @@ onMounted(async () => {
|
||||
<div class="mb-4 grid grid-cols-1 gap-3 md:grid-cols-5">
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
时间范围
|
||||
<select v-model="filters.time_range" class="input mt-1">
|
||||
<option value="5m">5m</option>
|
||||
<option value="30m">30m</option>
|
||||
<option value="1h">1h</option>
|
||||
<option value="6h">6h</option>
|
||||
<option value="24h">24h</option>
|
||||
<option value="7d">7d</option>
|
||||
<option value="30d">30d</option>
|
||||
</select>
|
||||
<Select v-model="filters.time_range" class="mt-1" :options="timeRangeOptions" />
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
开始时间(可选)
|
||||
@@ -423,13 +438,7 @@ onMounted(async () => {
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
级别
|
||||
<select v-model="filters.level" class="input mt-1">
|
||||
<option value="">全部</option>
|
||||
<option value="debug">debug</option>
|
||||
<option value="info">info</option>
|
||||
<option value="warn">warn</option>
|
||||
<option value="error">error</option>
|
||||
</select>
|
||||
<Select v-model="filters.level" class="mt-1" :options="filterLevelOptions" />
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
组件
|
||||
|
||||
Reference in New Issue
Block a user