Compare commits

...

33 Commits

Author SHA1 Message Date
Wesley Liddick
6a2cf09ee0 Merge pull request #1349 from touwaeriol/feat/antigravity-internal500-penalty
feat(antigravity): progressive penalty for consecutive INTERNAL 500 errors
2026-03-30 15:54:04 +08:00
Wesley Liddick
c6fd88116b Merge pull request #1354 from wucm667/fix/billing-use-requested-model
fix(billing): 计费始终使用用户请求的原始模型,而非映射后的上游模型
2026-03-30 15:52:31 +08:00
Wesley Liddick
8f0dbdeaba Merge pull request #1343 from yilinyo/fix/api-key-unique-conflict-after-soft-delete
fix(api-key):软删除apikey后key没有被释放后续无法再自定义相同的key
2026-03-30 15:47:28 +08:00
Wesley Liddick
007c09b84e Merge pull request #1338 from LvyuanW/fix/safari-ops-log-select
fix(admin): fix Safari system log select height
2026-03-30 15:45:35 +08:00
Wesley Liddick
73f3c068ef Merge pull request #1344 from 7836246/fix/i18n-sora-storage-missing-keys
fix(i18n): 修复 Sora 存储配置页面表格列头「存储桶」翻译缺失
2026-03-30 15:45:03 +08:00
Wesley Liddick
9a92fa4a60 Merge pull request #1370 from YanzheL/fix/1320-openai-messages-gpt54-xhigh
fix(gateway): normalize gpt-5.4-xhigh for /v1/messages
2026-03-30 15:44:34 +08:00
Wesley Liddick
576af710be Merge pull request #1352 from StarryKira/feat/add-file-upload-oauth-scope
Feat/add file upload oauth scope
2026-03-30 15:41:18 +08:00
Wesley Liddick
b5642bd068 Merge pull request #1377 from DaydreamCoding/fix/lifecycle-stop-duplicate-close
fix(lifecycle): TokenRefreshService Stop() 防重复 close
2026-03-30 15:38:39 +08:00
Wesley Liddick
128f322252 Merge pull request #1376 from weak-fox/fix/privacy-without-refresh-token
修复缺少 refresh_token 时被临时停调度
2026-03-30 15:38:27 +08:00
Wesley Liddick
17d7e57a2e Merge pull request #1375 from weak-fox/fix/batch-reset-temp-unsched
修复重置状态时未清理临时停调度
2026-03-30 15:37:58 +08:00
shaw
50288e6b01 fix: 修复模型定价文件更新url 2026-03-30 15:36:53 +08:00
shaw
ab3e44e4bd fix: 适配X-Claude-Code-Session-Id头 2026-03-30 11:43:07 +08:00
QTom
61607990c8 fix(lifecycle): TokenRefreshService Stop() 防重复 close
使用 sync.Once 包裹 close(stopCh),避免多次调用 Stop() 时
触发 panic: close of closed channel。
2026-03-30 10:33:06 +08:00
shaw
b65275235f feat: Anthropic oauth/setup-token账号支持自定义转发URL 2026-03-30 09:10:57 +08:00
weak-fox
e298a71834 fix: clear temp unsched when resetting account status 2026-03-30 00:22:02 +08:00
weak-fox
3f6fa1e3db fix: avoid temp unsched when refresh token is missing 2026-03-30 00:21:51 +08:00
YanzheL
f2c2abe628 fix(openai): keep xhigh normalization scoped to messages 2026-03-29 21:09:19 +08:00
YanzheL
ff5b467fbe fix(handler): normalize compat model for message routing 2026-03-29 20:53:14 +08:00
YanzheL
8c10941142 fix(openai): normalize gpt-5.4-xhigh compat mapping 2026-03-29 20:52:29 +08:00
wucm667
f5764d8dc6 fix(billing): 计费始终使用用户请求的原始模型,而非映射后的上游模型
当账号配置了模型映射(如 claude-sonnet-4-6 → glm-5.0)时,系统错误地
使用映射后的上游模型名计算费用。由于上游模型(如 glm-5.0)在定价系统中
没有价格配置,导致计费失败后被静默置为 0,用户不被扣费。

修改 forwardResultBillingModel 优先返回请求模型名,并移除 OpenAI 路径
中 BillingModel 字段对计费模型的覆盖逻辑。
2026-03-28 16:22:06 +08:00
Elysia
81ca4f12dd 修复误删的url 2026-03-28 00:55:55 +08:00
Elysia
941c469ab9 fix: use standard PKCE code verifier generation
Replace charset→base64url double-encoding with standard random
bytes→base64url approach to match official client behavior and avoid
risk control detection.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-28 00:47:31 +08:00
Elysia
8fcd819e6f feat: add user:file_upload OAuth scope
Align OAuth scopes with upstream Claude Code client which now includes
the user:file_upload scope for file upload support.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-28 00:40:36 +08:00
erio
9abdaed20c style: gofmt antigravity_internal500_penalty.go 2026-03-27 20:18:07 +08:00
erio
eb94342f78 chore: adjust internal500 penalty durations to 30m / 2h 2026-03-27 20:11:24 +08:00
erio
d563eb2336 test: add unit tests for INTERNAL 500 progressive penalty
Cover isAntigravityInternalServerError body matching,
applyInternal500Penalty tier escalation, handleInternal500RetryExhausted
nil-safety and error handling, and resetInternal500Counter paths.
2026-03-27 20:11:24 +08:00
erio
3ee6f085db refactor: extract internal500 penalty logic to dedicated file
Move constants, detection, and penalty functions from
antigravity_gateway_service.go to antigravity_internal500_penalty.go.
Fix gofmt alignment and replace hardcoded duration strings with
constant references.
2026-03-27 20:11:24 +08:00
erio
7cca69a136 fix: move internal500 counter reset to cover all success paths
Move the reset logic after urlFallbackLoop so it covers both direct
success and smart retry (429/503) success paths.
2026-03-27 20:11:24 +08:00
erio
093a5a260e feat(antigravity): progressive penalty for consecutive INTERNAL 500 errors
When an antigravity account returns 500 "Internal error encountered."
on all 3 retry attempts, increment a Redis counter and apply escalating
penalties:
- 1st round: temp unschedulable 10 minutes
- 2nd round: temp unschedulable 10 hours
- 3rd round: permanently mark as error

Counter resets on any successful response (< 400).
2026-03-27 20:11:24 +08:00
小海
2c072c0ed6 fix(i18n): add missing bucket column translation key for Sora S3 storage settings
The `admin.settings.soraS3.columns.bucket` key was used in
DataManagementView.vue but missing from both en.ts and zh.ts locale
files, causing the raw translation key to be displayed as a column
header instead of the localized text.
2026-03-27 16:44:14 +08:00
YilinMacAir
1f39bf8a78 fix:修复由于数据库唯一键导致软删除apikey后key没有被释放后续无法再自定义相同的key 2026-03-27 16:37:10 +08:00
github-actions[bot]
fdd8499ffc chore: sync VERSION to 0.1.105 [skip ci] 2026-03-27 08:04:27 +00:00
Wang Lvyuan
c7f4a649df fix(admin): use custom select for ops log filters 2026-03-27 14:07:12 +08:00
38 changed files with 1321 additions and 116 deletions

View File

@@ -1 +1 @@
0.1.104
0.1.105

View File

@@ -137,7 +137,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client)
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)

View File

@@ -1281,8 +1281,8 @@ func setDefaults() {
viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 10)
// Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit避免分支漂移
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json")
viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256")
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/main/model_prices_and_context_window.json")
viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/main/model_prices_and_context_window.sha256")
viper.SetDefault("pricing.data_dir", "./data")
viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json")
viper.SetDefault("pricing.update_interval_hours", 24)

View File

@@ -268,6 +268,14 @@ func AccountFromServiceShallow(a *service.Account) *Account {
target := a.GetCacheTTLOverrideTarget()
out.CacheTTLOverrideTarget = &target
}
// 自定义 Base URL 中继转发
if a.IsCustomBaseURLEnabled() {
enabled := true
out.CustomBaseURLEnabled = &enabled
if customURL := a.GetCustomBaseURL(); customURL != "" {
out.CustomBaseURL = &customURL
}
}
}
// 提取账号配额限制apikey / bedrock 类型有效)

View File

@@ -198,6 +198,10 @@ type Account struct {
CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"`
CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
// 自定义 Base URL 中继转发(仅 Anthropic OAuth/SetupToken 账号有效)
CustomBaseURLEnabled *bool `json:"custom_base_url_enabled,omitempty"`
CustomBaseURL *string `json:"custom_base_url,omitempty"`
// API Key 账号配额限制
QuotaLimit *float64 `json:"quota_limit,omitempty"`
QuotaUsed *float64 `json:"quota_used,omitempty"`

View File

@@ -541,6 +541,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
return
}
reqModel := modelResult.String()
routingModel := service.NormalizeOpenAICompatRequestedModel(reqModel)
reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
@@ -606,7 +607,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
apiKey.GroupID,
"", // no previous_response_id
sessionHash,
reqModel,
routingModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
)
@@ -621,7 +622,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
if apiKey.Group != nil {
defaultModel = apiKey.Group.DefaultMappedModel
}
if defaultModel != "" && defaultModel != reqModel {
if defaultModel != "" && defaultModel != routingModel {
reqLog.Info("openai_messages.fallback_to_default_model",
zap.String("default_mapped_model", defaultModel),
)

View File

@@ -24,20 +24,18 @@ const (
RedirectURI = "https://platform.claude.com/oauth/code/callback"
// Scopes - Browser URL (includes org:create_api_key for user authorization)
ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers"
ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers user:file_upload"
// Scopes - Internal API call (org:create_api_key not supported in API)
ScopeAPI = "user:profile user:inference user:sessions:claude_code user:mcp_servers"
ScopeAPI = "user:profile user:inference user:sessions:claude_code user:mcp_servers user:file_upload"
// Scopes - Setup token (inference only)
ScopeInference = "user:inference"
// Code Verifier character set (RFC 7636 compliant)
codeVerifierCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
// Session TTL
SessionTTL = 30 * time.Minute
)
// OAuthSession stores OAuth flow state
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
@@ -147,30 +145,14 @@ func GenerateSessionID() (string, error) {
return hex.EncodeToString(bytes), nil
}
// GenerateCodeVerifier generates a PKCE code verifier using character set method
// GenerateCodeVerifier generates a PKCE code verifier (RFC 7636).
// Uses 32 random bytes → base64url-no-pad, producing a 43-char verifier.
func GenerateCodeVerifier() (string, error) {
const targetLen = 32
charsetLen := len(codeVerifierCharset)
limit := 256 - (256 % charsetLen)
result := make([]byte, 0, targetLen)
randBuf := make([]byte, targetLen*2)
for len(result) < targetLen {
if _, err := rand.Read(randBuf); err != nil {
return "", err
}
for _, b := range randBuf {
if int(b) < limit {
result = append(result, codeVerifierCharset[int(b)%charsetLen])
if len(result) >= targetLen {
break
}
}
}
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return base64URLEncode(result), nil
return base64URLEncode(bytes), nil
}
// GenerateCodeChallenge generates a PKCE code challenge using S256 method

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"database/sql"
"fmt"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -257,9 +258,12 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
}
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
// 存在唯一键约束 生成tombstone key 用来释放原key长度远小于 128满足 schema 限制
tombstoneKey := fmt.Sprintf("__deleted__%d__%d", id, time.Now().UnixNano())
// 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。
affected, err := r.client.APIKey.Update().
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
SetKey(tombstoneKey).
SetDeletedAt(time.Now()).
Save(ctx)
if err != nil {

View File

@@ -151,6 +151,31 @@ func (s *APIKeyRepoSuite) TestDelete() {
s.Require().Error(err, "expected error after delete")
}
func (s *APIKeyRepoSuite) TestCreate_AfterSoftDelete_AllowsSameKey() {
user := s.mustCreateUser("recreate-after-soft-delete@test.com")
const reusedKey = "sk-reuse-after-soft-delete"
first := &service.APIKey{
UserID: user.ID,
Key: reusedKey,
Name: "First Key",
Status: service.StatusActive,
}
s.Require().NoError(s.repo.Create(s.ctx, first), "create first key")
s.Require().NoError(s.repo.Delete(s.ctx, first.ID), "soft delete first key")
second := &service.APIKey{
UserID: user.ID,
Key: reusedKey,
Name: "Second Key",
Status: service.StatusActive,
}
s.Require().NoError(s.repo.Create(s.ctx, second), "create second key with same key")
s.Require().NotZero(second.ID)
s.Require().NotEqual(first.ID, second.ID, "recreated key should be a new row")
}
// --- ListByUserID / CountByUserID ---
func (s *APIKeyRepoSuite) TestListByUserID() {

View File

@@ -0,0 +1,55 @@
package repository
import (
"context"
"fmt"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
internal500CounterPrefix = "internal500_count:account:"
internal500CounterTTLSeconds = 86400 // 24 小时兜底
)
// internal500CounterIncrScript 使用 Lua 脚本原子性地增加计数并返回当前值
// 如果 key 不存在,则创建并设置过期时间
var internal500CounterIncrScript = redis.NewScript(`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local count = redis.call('INCR', key)
if count == 1 then
redis.call('EXPIRE', key, ttl)
end
return count
`)
type internal500CounterCache struct {
rdb *redis.Client
}
// NewInternal500CounterCache 创建 INTERNAL 500 连续失败计数器缓存实例
func NewInternal500CounterCache(rdb *redis.Client) service.Internal500CounterCache {
return &internal500CounterCache{rdb: rdb}
}
// IncrementInternal500Count 原子递增计数并返回当前值
func (c *internal500CounterCache) IncrementInternal500Count(ctx context.Context, accountID int64) (int64, error) {
key := fmt.Sprintf("%s%d", internal500CounterPrefix, accountID)
result, err := internal500CounterIncrScript.Run(ctx, c.rdb, []string{key}, internal500CounterTTLSeconds).Int64()
if err != nil {
return 0, fmt.Errorf("increment internal500 count: %w", err)
}
return result, nil
}
// ResetInternal500Count 清零计数器(成功响应时调用)
func (c *internal500CounterCache) ResetInternal500Count(ctx context.Context, accountID int64) error {
key := fmt.Sprintf("%s%d", internal500CounterPrefix, accountID)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -81,6 +81,7 @@ var ProviderSet = wire.NewSet(
NewAPIKeyCache,
NewTempUnschedCache,
NewTimeoutCounterCache,
NewInternal500CounterCache,
ProvideConcurrencyCache,
ProvideSessionLimitCache,
NewRPMCache,

View File

@@ -1229,6 +1229,28 @@ func (a *Account) IsSessionIDMaskingEnabled() bool {
return false
}
// IsCustomBaseURLEnabled 检查是否启用自定义 base URL 中继转发
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
func (a *Account) IsCustomBaseURLEnabled() bool {
if !a.IsAnthropicOAuthOrSetupToken() {
return false
}
if a.Extra == nil {
return false
}
if v, ok := a.Extra["custom_base_url_enabled"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
// GetCustomBaseURL 返回自定义中继服务的 base URL
func (a *Account) GetCustomBaseURL() string {
return a.GetExtraString("custom_base_url")
}
// IsCacheTTLOverrideEnabled 检查是否启用缓存 TTL 强制替换
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
// 启用后将所有 cache creation tokens 归入指定的 TTL 类型5m 或 1h

View File

@@ -1866,6 +1866,18 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac
if err := s.accountRepo.ClearError(ctx, id); err != nil {
return nil, err
}
if err := s.accountRepo.ClearRateLimit(ctx, id); err != nil {
return nil, err
}
if err := s.accountRepo.ClearAntigravityQuotaScopes(ctx, id); err != nil {
return nil, err
}
if err := s.accountRepo.ClearModelRateLimits(ctx, id); err != nil {
return nil, err
}
if err := s.accountRepo.ClearTempUnschedulable(ctx, id); err != nil {
return nil, err
}
return s.accountRepo.GetByID(ctx, id)
}

View File

@@ -0,0 +1,86 @@
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type accountRepoStubForClearAccountError struct {
mockAccountRepoForGemini
account *Account
clearErrorCalls int
clearRateLimitCalls int
clearAntigravityCalls int
clearModelRateLimitCalls int
clearTempUnschedCalls int
}
func (r *accountRepoStubForClearAccountError) GetByID(ctx context.Context, id int64) (*Account, error) {
return r.account, nil
}
func (r *accountRepoStubForClearAccountError) ClearError(ctx context.Context, id int64) error {
r.clearErrorCalls++
r.account.Status = StatusActive
r.account.ErrorMessage = ""
return nil
}
func (r *accountRepoStubForClearAccountError) ClearRateLimit(ctx context.Context, id int64) error {
r.clearRateLimitCalls++
r.account.RateLimitedAt = nil
r.account.RateLimitResetAt = nil
return nil
}
func (r *accountRepoStubForClearAccountError) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
r.clearAntigravityCalls++
return nil
}
func (r *accountRepoStubForClearAccountError) ClearModelRateLimits(ctx context.Context, id int64) error {
r.clearModelRateLimitCalls++
return nil
}
func (r *accountRepoStubForClearAccountError) ClearTempUnschedulable(ctx context.Context, id int64) error {
r.clearTempUnschedCalls++
r.account.TempUnschedulableUntil = nil
r.account.TempUnschedulableReason = ""
return nil
}
func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *testing.T) {
until := time.Now().Add(10 * time.Minute)
resetAt := time.Now().Add(5 * time.Minute)
repo := &accountRepoStubForClearAccountError{
account: &Account{
ID: 31,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusError,
ErrorMessage: "refresh failed",
RateLimitResetAt: &resetAt,
TempUnschedulableUntil: &until,
TempUnschedulableReason: "missing refresh token",
},
}
svc := &adminServiceImpl{accountRepo: repo}
updated, err := svc.ClearAccountError(context.Background(), 31)
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, 1, repo.clearErrorCalls)
require.Equal(t, 1, repo.clearRateLimitCalls)
require.Equal(t, 1, repo.clearAntigravityCalls)
require.Equal(t, 1, repo.clearModelRateLimitCalls)
require.Equal(t, 1, repo.clearTempUnschedCalls)
require.Nil(t, updated.RateLimitResetAt)
require.Nil(t, updated.TempUnschedulableUntil)
require.Empty(t, updated.TempUnschedulableReason)
}

View File

@@ -614,6 +614,7 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP
urlFallbackLoop:
for urlIdx, baseURL := range availableURLs {
usedBaseURL = baseURL
allAttemptsInternal500 := true // 追踪本轮所有 attempt 是否全部命中 INTERNAL 500
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
select {
case <-p.ctx.Done():
@@ -766,10 +767,19 @@ urlFallbackLoop:
logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_backoff", p.prefix)
return nil, p.ctx.Err()
}
// 追踪 INTERNAL 500非匹配的 attempt 清除标记
if !isAntigravityInternalServerError(resp.StatusCode, respBody) {
allAttemptsInternal500 = false
}
continue
}
}
// INTERNAL 500 渐进惩罚3 次重试全部命中特定 500 时递增计数器并惩罚
if allAttemptsInternal500 && isAntigravityInternalServerError(resp.StatusCode, respBody) {
s.handleInternal500RetryExhausted(p.ctx, p.prefix, p.account)
}
// 其他 4xx 错误或重试用尽,直接返回
resp = &http.Response{
StatusCode: resp.StatusCode,
@@ -788,6 +798,11 @@ urlFallbackLoop:
antigravity.DefaultURLAvailability.MarkSuccess(usedBaseURL)
}
// 成功响应时清零 INTERNAL 500 连续失败计数器(覆盖所有成功路径,含 smart retry
if resp != nil && resp.StatusCode < 400 {
s.resetInternal500Counter(p.ctx, p.prefix, p.account.ID)
}
return &antigravityRetryLoopResult{resp: resp}, nil
}
@@ -862,6 +877,7 @@ type AntigravityGatewayService struct {
settingService *SettingService
cache GatewayCache // 用于模型级限流时清除粘性会话绑定
schedulerSnapshot *SchedulerSnapshotService
internal500Cache Internal500CounterCache // INTERNAL 500 渐进惩罚计数器
}
func NewAntigravityGatewayService(
@@ -872,6 +888,7 @@ func NewAntigravityGatewayService(
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
settingService *SettingService,
internal500Cache Internal500CounterCache,
) *AntigravityGatewayService {
return &AntigravityGatewayService{
accountRepo: accountRepo,
@@ -881,6 +898,7 @@ func NewAntigravityGatewayService(
settingService: settingService,
cache: cache,
schedulerSnapshot: schedulerSnapshot,
internal500Cache: internal500Cache,
}
}

View File

@@ -0,0 +1,97 @@
package service
import (
"context"
"fmt"
"log/slog"
"net/http"
"time"
"github.com/tidwall/gjson"
)
// INTERNAL 500 渐进惩罚:连续多轮全部返回特定 500 错误时的惩罚时长
const (
internal500PenaltyTier1Duration = 30 * time.Minute // 第 1 轮:临时不可调度 30 分钟
internal500PenaltyTier2Duration = 2 * time.Hour // 第 2 轮:临时不可调度 2 小时
internal500PenaltyTier3Threshold = 3 // 第 3+ 轮:永久禁用
)
// isAntigravityInternalServerError 检测特定的 INTERNAL 500 错误
// 必须同时匹配 error.code==500, error.message=="Internal error encountered.", error.status=="INTERNAL"
func isAntigravityInternalServerError(statusCode int, body []byte) bool {
if statusCode != http.StatusInternalServerError {
return false
}
return gjson.GetBytes(body, "error.code").Int() == 500 &&
gjson.GetBytes(body, "error.message").String() == "Internal error encountered." &&
gjson.GetBytes(body, "error.status").String() == "INTERNAL"
}
// applyInternal500Penalty 根据连续 INTERNAL 500 轮次数应用渐进惩罚
// count=1: temp_unschedulable 10 分钟
// count=2: temp_unschedulable 10 小时
// count>=3: SetError 永久禁用
func (s *AntigravityGatewayService) applyInternal500Penalty(
ctx context.Context, prefix string, account *Account, count int64,
) {
switch {
case count >= int64(internal500PenaltyTier3Threshold):
reason := fmt.Sprintf("INTERNAL 500 consecutive failures: %d rounds", count)
if err := s.accountRepo.SetError(ctx, account.ID, reason); err != nil {
slog.Error("internal500_set_error_failed", "account_id", account.ID, "error", err)
return
}
slog.Warn("internal500_account_disabled",
"account_id", account.ID, "account_name", account.Name, "consecutive_count", count)
case count == 2:
until := time.Now().Add(internal500PenaltyTier2Duration)
reason := fmt.Sprintf("INTERNAL 500 x%d (temp unsched %v)", count, internal500PenaltyTier2Duration)
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
slog.Error("internal500_temp_unsched_failed", "account_id", account.ID, "error", err)
return
}
slog.Warn("internal500_temp_unschedulable",
"account_id", account.ID, "account_name", account.Name,
"duration", internal500PenaltyTier2Duration, "consecutive_count", count)
case count == 1:
until := time.Now().Add(internal500PenaltyTier1Duration)
reason := fmt.Sprintf("INTERNAL 500 x%d (temp unsched %v)", count, internal500PenaltyTier1Duration)
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
slog.Error("internal500_temp_unsched_failed", "account_id", account.ID, "error", err)
return
}
slog.Info("internal500_temp_unschedulable",
"account_id", account.ID, "account_name", account.Name,
"duration", internal500PenaltyTier1Duration, "consecutive_count", count)
}
}
// handleInternal500RetryExhausted 处理 INTERNAL 500 重试耗尽:递增计数器并应用惩罚
func (s *AntigravityGatewayService) handleInternal500RetryExhausted(
ctx context.Context, prefix string, account *Account,
) {
if s.internal500Cache == nil {
return
}
count, err := s.internal500Cache.IncrementInternal500Count(ctx, account.ID)
if err != nil {
slog.Error("internal500_counter_increment_failed",
"prefix", prefix, "account_id", account.ID, "error", err)
return
}
s.applyInternal500Penalty(ctx, prefix, account, count)
}
// resetInternal500Counter 成功响应时清零 INTERNAL 500 计数器
func (s *AntigravityGatewayService) resetInternal500Counter(
ctx context.Context, prefix string, accountID int64,
) {
if s.internal500Cache == nil {
return
}
if err := s.internal500Cache.ResetInternal500Count(ctx, accountID); err != nil {
slog.Error("internal500_counter_reset_failed",
"prefix", prefix, "account_id", accountID, "error", err)
}
}

View File

@@ -0,0 +1,321 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// --- mock: Internal500CounterCache ---
type mockInternal500Cache struct {
incrementCount int64
incrementErr error
resetErr error
incrementCalls []int64 // 记录 IncrementInternal500Count 被调用时的 accountID
resetCalls []int64 // 记录 ResetInternal500Count 被调用时的 accountID
}
func (m *mockInternal500Cache) IncrementInternal500Count(_ context.Context, accountID int64) (int64, error) {
m.incrementCalls = append(m.incrementCalls, accountID)
return m.incrementCount, m.incrementErr
}
func (m *mockInternal500Cache) ResetInternal500Count(_ context.Context, accountID int64) error {
m.resetCalls = append(m.resetCalls, accountID)
return m.resetErr
}
// --- mock: 专用于 internal500 惩罚测试的 AccountRepository ---
type internal500AccountRepoStub struct {
AccountRepository // 嵌入接口,未实现的方法会 panic不应被调用
tempUnschedCalls []tempUnschedCall
setErrorCalls []setErrorCall
}
type tempUnschedCall struct {
accountID int64
until time.Time
reason string
}
type setErrorCall struct {
accountID int64
reason string
}
func (r *internal500AccountRepoStub) SetTempUnschedulable(_ context.Context, id int64, until time.Time, reason string) error {
r.tempUnschedCalls = append(r.tempUnschedCalls, tempUnschedCall{accountID: id, until: until, reason: reason})
return nil
}
func (r *internal500AccountRepoStub) SetError(_ context.Context, id int64, errorMsg string) error {
r.setErrorCalls = append(r.setErrorCalls, setErrorCall{accountID: id, reason: errorMsg})
return nil
}
// =============================================================================
// TestIsAntigravityInternalServerError
// =============================================================================
func TestIsAntigravityInternalServerError(t *testing.T) {
t.Run("匹配完整的 INTERNAL 500 body", func(t *testing.T) {
body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"INTERNAL"}}`)
require.True(t, isAntigravityInternalServerError(500, body))
})
t.Run("statusCode 不是 500", func(t *testing.T) {
body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"INTERNAL"}}`)
require.False(t, isAntigravityInternalServerError(429, body))
require.False(t, isAntigravityInternalServerError(503, body))
require.False(t, isAntigravityInternalServerError(200, body))
})
t.Run("body 中 message 不匹配", func(t *testing.T) {
body := []byte(`{"error":{"code":500,"message":"Some other error","status":"INTERNAL"}}`)
require.False(t, isAntigravityInternalServerError(500, body))
})
t.Run("body 中 status 不匹配", func(t *testing.T) {
body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"UNAVAILABLE"}}`)
require.False(t, isAntigravityInternalServerError(500, body))
})
t.Run("body 中 code 不匹配", func(t *testing.T) {
body := []byte(`{"error":{"code":503,"message":"Internal error encountered.","status":"INTERNAL"}}`)
require.False(t, isAntigravityInternalServerError(500, body))
})
t.Run("空 body", func(t *testing.T) {
require.False(t, isAntigravityInternalServerError(500, []byte{}))
require.False(t, isAntigravityInternalServerError(500, nil))
})
t.Run("其他 500 错误格式(纯文本)", func(t *testing.T) {
body := []byte(`Internal Server Error`)
require.False(t, isAntigravityInternalServerError(500, body))
})
t.Run("其他 500 错误格式(不同 JSON 结构)", func(t *testing.T) {
body := []byte(`{"message":"Internal Server Error","statusCode":500}`)
require.False(t, isAntigravityInternalServerError(500, body))
})
}
// =============================================================================
// TestApplyInternal500Penalty
// =============================================================================
func TestApplyInternal500Penalty(t *testing.T) {
t.Run("count=1 → SetTempUnschedulable 10 分钟", func(t *testing.T) {
repo := &internal500AccountRepoStub{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 1, Name: "acc-1"}
before := time.Now()
svc.applyInternal500Penalty(context.Background(), "[test]", account, 1)
after := time.Now()
require.Len(t, repo.tempUnschedCalls, 1)
require.Empty(t, repo.setErrorCalls)
call := repo.tempUnschedCalls[0]
require.Equal(t, int64(1), call.accountID)
require.Contains(t, call.reason, "INTERNAL 500")
// until 应在 [before+10m, after+10m] 范围内
require.True(t, call.until.After(before.Add(internal500PenaltyTier1Duration).Add(-time.Second)))
require.True(t, call.until.Before(after.Add(internal500PenaltyTier1Duration).Add(time.Second)))
})
t.Run("count=2 → SetTempUnschedulable 10 小时", func(t *testing.T) {
repo := &internal500AccountRepoStub{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 2, Name: "acc-2"}
before := time.Now()
svc.applyInternal500Penalty(context.Background(), "[test]", account, 2)
after := time.Now()
require.Len(t, repo.tempUnschedCalls, 1)
require.Empty(t, repo.setErrorCalls)
call := repo.tempUnschedCalls[0]
require.Equal(t, int64(2), call.accountID)
require.Contains(t, call.reason, "INTERNAL 500")
require.True(t, call.until.After(before.Add(internal500PenaltyTier2Duration).Add(-time.Second)))
require.True(t, call.until.Before(after.Add(internal500PenaltyTier2Duration).Add(time.Second)))
})
t.Run("count=3 → SetError 永久禁用", func(t *testing.T) {
repo := &internal500AccountRepoStub{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 3, Name: "acc-3"}
svc.applyInternal500Penalty(context.Background(), "[test]", account, 3)
require.Empty(t, repo.tempUnschedCalls)
require.Len(t, repo.setErrorCalls, 1)
call := repo.setErrorCalls[0]
require.Equal(t, int64(3), call.accountID)
require.Contains(t, call.reason, "INTERNAL 500 consecutive failures: 3")
})
t.Run("count=5 → SetError 永久禁用(>=3 都走永久禁用)", func(t *testing.T) {
repo := &internal500AccountRepoStub{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 5, Name: "acc-5"}
svc.applyInternal500Penalty(context.Background(), "[test]", account, 5)
require.Empty(t, repo.tempUnschedCalls)
require.Len(t, repo.setErrorCalls, 1)
call := repo.setErrorCalls[0]
require.Equal(t, int64(5), call.accountID)
require.Contains(t, call.reason, "INTERNAL 500 consecutive failures: 5")
})
t.Run("count=0 → 不调用任何方法", func(t *testing.T) {
repo := &internal500AccountRepoStub{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 10, Name: "acc-10"}
svc.applyInternal500Penalty(context.Background(), "[test]", account, 0)
require.Empty(t, repo.tempUnschedCalls)
require.Empty(t, repo.setErrorCalls)
})
}
// =============================================================================
// TestHandleInternal500RetryExhausted
// =============================================================================
func TestHandleInternal500RetryExhausted(t *testing.T) {
t.Run("internal500Cache 为 nil → 不 panic不调用任何方法", func(t *testing.T) {
repo := &internal500AccountRepoStub{}
svc := &AntigravityGatewayService{
accountRepo: repo,
internal500Cache: nil,
}
account := &Account{ID: 1, Name: "acc-1"}
// 不应 panic
require.NotPanics(t, func() {
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
})
require.Empty(t, repo.tempUnschedCalls)
require.Empty(t, repo.setErrorCalls)
})
t.Run("IncrementInternal500Count 返回 error → 不调用惩罚方法", func(t *testing.T) {
repo := &internal500AccountRepoStub{}
cache := &mockInternal500Cache{
incrementErr: errors.New("redis connection error"),
}
svc := &AntigravityGatewayService{
accountRepo: repo,
internal500Cache: cache,
}
account := &Account{ID: 2, Name: "acc-2"}
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
require.Len(t, cache.incrementCalls, 1)
require.Equal(t, int64(2), cache.incrementCalls[0])
require.Empty(t, repo.tempUnschedCalls)
require.Empty(t, repo.setErrorCalls)
})
t.Run("IncrementInternal500Count 返回 count=1 → 触发 tier1 惩罚", func(t *testing.T) {
repo := &internal500AccountRepoStub{}
cache := &mockInternal500Cache{
incrementCount: 1,
}
svc := &AntigravityGatewayService{
accountRepo: repo,
internal500Cache: cache,
}
account := &Account{ID: 3, Name: "acc-3"}
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
require.Len(t, cache.incrementCalls, 1)
require.Equal(t, int64(3), cache.incrementCalls[0])
// tier1: SetTempUnschedulable
require.Len(t, repo.tempUnschedCalls, 1)
require.Equal(t, int64(3), repo.tempUnschedCalls[0].accountID)
require.Empty(t, repo.setErrorCalls)
})
t.Run("IncrementInternal500Count 返回 count=3 → 触发 tier3 永久禁用", func(t *testing.T) {
repo := &internal500AccountRepoStub{}
cache := &mockInternal500Cache{
incrementCount: 3,
}
svc := &AntigravityGatewayService{
accountRepo: repo,
internal500Cache: cache,
}
account := &Account{ID: 4, Name: "acc-4"}
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
require.Len(t, cache.incrementCalls, 1)
require.Empty(t, repo.tempUnschedCalls)
require.Len(t, repo.setErrorCalls, 1)
require.Equal(t, int64(4), repo.setErrorCalls[0].accountID)
})
}
// =============================================================================
// TestResetInternal500Counter
// =============================================================================
func TestResetInternal500Counter(t *testing.T) {
t.Run("internal500Cache 为 nil → 不 panic", func(t *testing.T) {
svc := &AntigravityGatewayService{
internal500Cache: nil,
}
require.NotPanics(t, func() {
svc.resetInternal500Counter(context.Background(), "[test]", 1)
})
})
t.Run("ResetInternal500Count 返回 error → 不 panic仅日志", func(t *testing.T) {
cache := &mockInternal500Cache{
resetErr: errors.New("redis timeout"),
}
svc := &AntigravityGatewayService{
internal500Cache: cache,
}
require.NotPanics(t, func() {
svc.resetInternal500Counter(context.Background(), "[test]", 42)
})
require.Len(t, cache.resetCalls, 1)
require.Equal(t, int64(42), cache.resetCalls[0])
})
t.Run("正常调用 → 调用 ResetInternal500Count", func(t *testing.T) {
cache := &mockInternal500Cache{}
svc := &AntigravityGatewayService{
internal500Cache: cache,
}
svc.resetInternal500Counter(context.Background(), "[test]", 99)
require.Len(t, cache.resetCalls, 1)
require.Equal(t, int64(99), cache.resetCalls[0])
})
}

View File

@@ -12,6 +12,7 @@ import (
"log/slog"
mathrand "math/rand"
"net/http"
"net/url"
"os"
"path/filepath"
"regexp"
@@ -368,6 +369,8 @@ var allowedHeaders = map[string]bool{
"user-agent": true,
"content-type": true,
"accept-encoding": true,
"x-claude-code-session-id": true,
"x-client-request-id": true,
}
// GatewayCache 定义网关服务的缓存操作接口。
@@ -4150,10 +4153,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return nil, err
}
// 获取代理URL
// 获取代理URL(自定义 base URL 模式下proxy 通过 buildCustomRelayURL 作为查询参数传递)
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" {
proxyURL = account.Proxy.URL()
}
}
// 解析 TLS 指纹 profile同一请求生命周期内不变避免重试循环中重复解析
@@ -5628,6 +5633,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
targetURL = validatedURL + "/v1/messages?beta=true"
}
} else if account.IsCustomBaseURLEnabled() {
customURL := account.GetCustomBaseURL()
if customURL == "" {
return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID)
}
validatedURL, err := s.validateUpstreamBaseURL(customURL)
if err != nil {
return nil, err
}
targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages", account)
}
clientHeaders := http.Header{}
@@ -5743,6 +5758,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
if parsed := ParseMetadataUserID(uid); parsed != nil {
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
}
}
}
// === DEBUG: 打印上游转发请求headers + body 摘要),与 CLIENT_ORIGINAL 对比 ===
s.debugLogGatewaySnapshot("UPSTREAM_FORWARD", req.Header, body, map[string]string{
"url": req.URL.String(),
@@ -8063,10 +8087,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
return err
}
// 获取代理URL
// 获取代理URL(自定义 base URL 模式下proxy 通过 buildCustomRelayURL 作为查询参数传递)
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" {
proxyURL = account.Proxy.URL()
}
}
// 发送请求
@@ -8345,6 +8371,16 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
}
} else if account.IsCustomBaseURLEnabled() {
customURL := account.GetCustomBaseURL()
if customURL == "" {
return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID)
}
validatedURL, err := s.validateUpstreamBaseURL(customURL)
if err != nil {
return nil, err
}
targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages/count_tokens", account)
}
clientHeaders := http.Header{}
@@ -8450,6 +8486,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
if parsed := ParseMetadataUserID(uid); parsed != nil {
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
}
}
}
if c != nil && tokenType == "oauth" {
c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode))
}
@@ -8471,6 +8516,19 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
})
}
// buildCustomRelayURL 构建自定义中继转发 URL
// 在 path 后附加 beta=true 和可选的 proxy 查询参数
func (s *GatewayService) buildCustomRelayURL(baseURL, path string, account *Account) string {
u := strings.TrimRight(baseURL, "/") + path + "?beta=true"
if account.ProxyID != nil && account.Proxy != nil {
proxyURL := account.Proxy.URL()
if proxyURL != "" {
u += "&proxy=" + url.QueryEscape(proxyURL)
}
}
return u
}
func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)

View File

@@ -36,6 +36,11 @@ var headerWireCasing = map[string]string{
"sec-fetch-mode": "sec-fetch-mode",
"accept-encoding": "accept-encoding",
"authorization": "authorization",
// Claude Code 2.1.87+ 新增 header
"x-claude-code-session-id": "X-Claude-Code-Session-Id",
"x-client-request-id": "x-client-request-id",
"content-length": "content-length",
}
// headerWireOrder 定义真实 Claude CLI 发送 header 的顺序(基于抓包)。
@@ -55,11 +60,14 @@ var headerWireOrder = []string{
"authorization",
"x-app",
"User-Agent",
"X-Claude-Code-Session-Id",
"content-type",
"anthropic-beta",
"x-client-request-id",
"accept-language",
"sec-fetch-mode",
"accept-encoding",
"content-length",
"x-stainless-helper-method",
}

View File

@@ -0,0 +1,11 @@
package service
import "context"
// Internal500CounterCache 追踪 Antigravity 账号连续 INTERNAL 500 失败轮数
type Internal500CounterCache interface {
// IncrementInternal500Count 原子递增计数并返回当前值
IncrementInternal500Count(ctx context.Context, accountID int64) (int64, error)
// ResetInternal500Count 清零计数器(成功响应时调用)
ResetInternal500Count(ctx context.Context, accountID int64) error
}

View File

@@ -0,0 +1,103 @@
package service
import (
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
)
func NormalizeOpenAICompatRequestedModel(model string) string {
trimmed := strings.TrimSpace(model)
if trimmed == "" {
return ""
}
normalized, _, ok := splitOpenAICompatReasoningModel(trimmed)
if !ok || normalized == "" {
return trimmed
}
return normalized
}
func applyOpenAICompatModelNormalization(req *apicompat.AnthropicRequest) {
if req == nil {
return
}
originalModel := strings.TrimSpace(req.Model)
if originalModel == "" {
return
}
normalizedModel, derivedEffort, hasReasoningSuffix := splitOpenAICompatReasoningModel(originalModel)
if hasReasoningSuffix && normalizedModel != "" {
req.Model = normalizedModel
}
if req.OutputConfig != nil && strings.TrimSpace(req.OutputConfig.Effort) != "" {
return
}
claudeEffort := openAIReasoningEffortToClaudeOutputEffort(derivedEffort)
if claudeEffort == "" {
return
}
if req.OutputConfig == nil {
req.OutputConfig = &apicompat.AnthropicOutputConfig{}
}
req.OutputConfig.Effort = claudeEffort
}
func splitOpenAICompatReasoningModel(model string) (normalizedModel string, reasoningEffort string, ok bool) {
trimmed := strings.TrimSpace(model)
if trimmed == "" {
return "", "", false
}
modelID := trimmed
if strings.Contains(modelID, "/") {
parts := strings.Split(modelID, "/")
modelID = parts[len(parts)-1]
}
modelID = strings.TrimSpace(modelID)
if !strings.HasPrefix(strings.ToLower(modelID), "gpt-") {
return trimmed, "", false
}
parts := strings.FieldsFunc(strings.ToLower(modelID), func(r rune) bool {
switch r {
case '-', '_', ' ':
return true
default:
return false
}
})
if len(parts) == 0 {
return trimmed, "", false
}
last := strings.NewReplacer("-", "", "_", "", " ", "").Replace(parts[len(parts)-1])
switch last {
case "none", "minimal":
case "low", "medium", "high":
reasoningEffort = last
case "xhigh", "extrahigh":
reasoningEffort = "xhigh"
default:
return trimmed, "", false
}
return normalizeCodexModel(modelID), reasoningEffort, true
}
func openAIReasoningEffortToClaudeOutputEffort(effort string) string {
switch strings.TrimSpace(effort) {
case "low", "medium", "high":
return effort
case "xhigh":
return "max"
default:
return ""
}
}

View File

@@ -0,0 +1,129 @@
package service
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestNormalizeOpenAICompatRequestedModel(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want string
}{
{name: "gpt reasoning alias strips xhigh", input: "gpt-5.4-xhigh", want: "gpt-5.4"},
{name: "gpt reasoning alias strips none", input: "gpt-5.4-none", want: "gpt-5.4"},
{name: "codex max model stays intact", input: "gpt-5.1-codex-max", want: "gpt-5.1-codex-max"},
{name: "non openai model unchanged", input: "claude-opus-4-6", want: "claude-opus-4-6"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, NormalizeOpenAICompatRequestedModel(tt.input))
})
}
}
func TestApplyOpenAICompatModelNormalization(t *testing.T) {
t.Parallel()
t.Run("derives xhigh from model suffix when output config missing", func(t *testing.T) {
req := &apicompat.AnthropicRequest{Model: "gpt-5.4-xhigh"}
applyOpenAICompatModelNormalization(req)
require.Equal(t, "gpt-5.4", req.Model)
require.NotNil(t, req.OutputConfig)
require.Equal(t, "max", req.OutputConfig.Effort)
})
t.Run("explicit output config wins over model suffix", func(t *testing.T) {
req := &apicompat.AnthropicRequest{
Model: "gpt-5.4-xhigh",
OutputConfig: &apicompat.AnthropicOutputConfig{Effort: "low"},
}
applyOpenAICompatModelNormalization(req)
require.Equal(t, "gpt-5.4", req.Model)
require.NotNil(t, req.OutputConfig)
require.Equal(t, "low", req.OutputConfig.Effort)
})
t.Run("non openai model is untouched", func(t *testing.T) {
req := &apicompat.AnthropicRequest{Model: "claude-opus-4-6"}
applyOpenAICompatModelNormalization(req)
require.Equal(t, "claude-opus-4-6", req.Model)
require.Nil(t, req.OutputConfig)
})
}
func TestForwardAsAnthropic_NormalizesRoutingAndEffortForGpt54XHigh(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4-xhigh","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := strings.Join([]string{
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
"",
"data: [DONE]",
"",
}, "\n")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_compat"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
"model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4",
},
},
}
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "gpt-5.4-xhigh", result.Model)
require.Equal(t, "gpt-5.4", result.UpstreamModel)
require.Equal(t, "gpt-5.4", result.BillingModel)
require.NotNil(t, result.ReasoningEffort)
require.Equal(t, "xhigh", *result.ReasoningEffort)
require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String())
require.Equal(t, "xhigh", gjson.GetBytes(upstream.lastBody, "reasoning.effort").String())
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "gpt-5.4-xhigh", gjson.GetBytes(rec.Body.Bytes(), "model").String())
require.Equal(t, "ok", gjson.GetBytes(rec.Body.Bytes(), "content.0.text").String())
t.Logf("upstream body: %s", string(upstream.lastBody))
t.Logf("response body: %s", rec.Body.String())
}

View File

@@ -40,6 +40,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
return nil, fmt.Errorf("parse anthropic request: %w", err)
}
originalModel := anthropicReq.Model
applyOpenAICompatModelNormalization(&anthropicReq)
clientStream := anthropicReq.Stream // client's original stream preference
// 2. Convert Anthropic → Responses
@@ -59,7 +60,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
}
// 3. Model mapping
mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
mappedModel := resolveOpenAIForwardModel(account, anthropicReq.Model, defaultMappedModel)
responsesReq.Model = mappedModel
logger.L().Debug("openai messages: model mapping applied",

View File

@@ -895,14 +895,16 @@ func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetad
require.Equal(t, 1, userRepo.deductCalls)
}
func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingUpstreamModelFallback(t *testing.T) {
func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingRequestedModel(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
expectedCost, err := svc.billingService.CalculateCost("gpt-5.1-codex", UsageTokens{
// Billing should use the requested model ("gpt-5.1"), not the upstream mapped model ("gpt-5.1-codex").
// This ensures pricing is always based on the model the user requested.
expectedCost, err := svc.billingService.CalculateCost("gpt-5.1", UsageTokens{
InputTokens: 20,
OutputTokens: 10,
}, 1.1)

View File

@@ -4153,9 +4153,6 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
}
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if result.BillingModel != "" {
billingModel = strings.TrimSpace(result.BillingModel)
}
serviceTier := ""
if result.ServiceTier != nil {
serviceTier = strings.TrimSpace(*result.ServiceTier)

View File

@@ -502,6 +502,25 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A
refreshToken := account.GetCredential("refresh_token")
if refreshToken == "" {
accessToken := account.GetCredential("access_token")
if accessToken != "" {
tokenInfo := &OpenAITokenInfo{
AccessToken: accessToken,
RefreshToken: "",
IDToken: account.GetCredential("id_token"),
ClientID: account.GetCredential("client_id"),
Email: account.GetCredential("email"),
ChatGPTAccountID: account.GetCredential("chatgpt_account_id"),
ChatGPTUserID: account.GetCredential("chatgpt_user_id"),
OrganizationID: account.GetCredential("organization_id"),
PlanType: account.GetCredential("plan_type"),
}
if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil {
tokenInfo.ExpiresAt = expiresAt.Unix()
tokenInfo.ExpiresIn = int64(time.Until(*expiresAt).Seconds())
}
return tokenInfo, nil
}
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
}

View File

@@ -0,0 +1,54 @@
package service
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
)
type openaiOAuthClientRefreshStub struct {
refreshCalls int32
}
func (s *openaiOAuthClientRefreshStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
return nil, errors.New("not implemented")
}
func (s *openaiOAuthClientRefreshStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
atomic.AddInt32(&s.refreshCalls, 1)
return nil, errors.New("not implemented")
}
func (s *openaiOAuthClientRefreshStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
atomic.AddInt32(&s.refreshCalls, 1)
return nil, errors.New("not implemented")
}
func TestOpenAIOAuthService_RefreshAccountToken_NoRefreshTokenUsesExistingAccessToken(t *testing.T) {
client := &openaiOAuthClientRefreshStub{}
svc := NewOpenAIOAuthService(nil, client)
expiresAt := time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339)
account := &Account{
ID: 77,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "existing-access-token",
"expires_at": expiresAt,
"client_id": "client-id-1",
},
}
info, err := svc.RefreshAccountToken(context.Background(), account)
require.NoError(t, err)
require.NotNil(t, info)
require.Equal(t, "existing-access-token", info.AccessToken)
require.Equal(t, "client-id-1", info.ClientID)
require.Zero(t, atomic.LoadInt32(&client.refreshCalls), "existing access token should be reused without calling refresh")
}

View File

@@ -189,10 +189,38 @@ func (s *PricingService) checkAndUpdatePricing() error {
return s.downloadPricingData()
}
// 检查文件是否过期
// 先加载本地文件(确保服务可用),再检查是否需要更新
if err := s.loadPricingData(pricingFile); err != nil {
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to load local file, downloading: %v", err)
return s.downloadPricingData()
}
// 如果配置了哈希URL通过远程哈希检查是否有更新
if s.cfg.Pricing.HashURL != "" {
remoteHash, err := s.fetchRemoteHash()
if err != nil {
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to fetch remote hash on startup: %v", err)
return nil // 已加载本地文件,哈希获取失败不影响启动
}
s.mu.RLock()
localHash := s.localHash
s.mu.RUnlock()
if localHash == "" || remoteHash != localHash {
logger.LegacyPrintf("service.pricing", "[Pricing] Remote hash differs on startup (local=%s remote=%s), downloading...",
localHash[:min(8, len(localHash))], remoteHash[:min(8, len(remoteHash))])
if err := s.downloadPricingData(); err != nil {
logger.LegacyPrintf("service.pricing", "[Pricing] Download failed, using existing file: %v", err)
}
}
return nil
}
// 没有哈希URL时基于文件年龄检查
info, err := os.Stat(pricingFile)
if err != nil {
return s.downloadPricingData()
return nil // 已加载本地文件
}
fileAge := time.Since(info.ModTime())
@@ -205,21 +233,11 @@ func (s *PricingService) checkAndUpdatePricing() error {
}
}
// 加载本地文件
return s.loadPricingData(pricingFile)
return nil
}
// syncWithRemote 与远程同步(基于哈希校验)
func (s *PricingService) syncWithRemote() error {
pricingFile := s.getPricingFilePath()
// 计算本地文件哈希
localHash, err := s.computeFileHash(pricingFile)
if err != nil {
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to compute local hash: %v", err)
return s.downloadPricingData()
}
// 如果配置了哈希URL从远程获取哈希进行比对
if s.cfg.Pricing.HashURL != "" {
remoteHash, err := s.fetchRemoteHash()
@@ -228,8 +246,13 @@ func (s *PricingService) syncWithRemote() error {
return nil // 哈希获取失败不影响正常使用
}
if remoteHash != localHash {
logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Remote hash differs, downloading new version...")
s.mu.RLock()
localHash := s.localHash
s.mu.RUnlock()
if localHash == "" || remoteHash != localHash {
logger.LegacyPrintf("service.pricing", "[Pricing] Remote hash differs (local=%s remote=%s), downloading new version...",
localHash[:min(8, len(localHash))], remoteHash[:min(8, len(remoteHash))])
return s.downloadPricingData()
}
logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Hash check passed, no update needed")
@@ -237,6 +260,7 @@ func (s *PricingService) syncWithRemote() error {
}
// 没有哈希URL时基于时间检查
pricingFile := s.getPricingFilePath()
info, err := os.Stat(pricingFile)
if err != nil {
return s.downloadPricingData()
@@ -264,11 +288,12 @@ func (s *PricingService) downloadPricingData() error {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
var expectedHash string
// 获取远程哈希(用于同步锚点,不作为完整性校验)
var remoteHash string
if strings.TrimSpace(s.cfg.Pricing.HashURL) != "" {
expectedHash, err = s.fetchRemoteHash()
remoteHash, err = s.fetchRemoteHash()
if err != nil {
return fmt.Errorf("fetch remote hash: %w", err)
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to fetch remote hash (continuing): %v", err)
}
}
@@ -277,11 +302,13 @@ func (s *PricingService) downloadPricingData() error {
return fmt.Errorf("download failed: %w", err)
}
if expectedHash != "" {
actualHash := sha256.Sum256(body)
if !strings.EqualFold(expectedHash, hex.EncodeToString(actualHash[:])) {
return fmt.Errorf("pricing hash mismatch")
}
// 哈希校验:不匹配时仅告警,不阻止更新
// 远程哈希文件可能与数据文件不同步(如维护者更新了数据但未更新哈希文件)
dataHash := sha256.Sum256(body)
dataHashStr := hex.EncodeToString(dataHash[:])
if remoteHash != "" && !strings.EqualFold(remoteHash, dataHashStr) {
logger.LegacyPrintf("service.pricing", "[Pricing] Hash mismatch warning: remote=%s data=%s (hash file may be out of sync)",
remoteHash[:min(8, len(remoteHash))], dataHashStr[:8])
}
// 解析JSON数据使用灵活的解析方式
@@ -296,11 +323,14 @@ func (s *PricingService) downloadPricingData() error {
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to save file: %v", err)
}
// 保存哈希
hash := sha256.Sum256(body)
hashStr := hex.EncodeToString(hash[:])
// 使用远程哈希作为同步锚点,防止重复下载
// 当远程哈希不可用时,回退到数据本身的哈希
syncHash := dataHashStr
if remoteHash != "" {
syncHash = remoteHash
}
hashFile := s.getHashFilePath()
if err := os.WriteFile(hashFile, []byte(hashStr+"\n"), 0644); err != nil {
if err := os.WriteFile(hashFile, []byte(syncHash+"\n"), 0644); err != nil {
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to save hash: %v", err)
}
@@ -308,7 +338,7 @@ func (s *PricingService) downloadPricingData() error {
s.mu.Lock()
s.pricingData = data
s.lastUpdated = time.Now()
s.localHash = hashStr
s.localHash = syncHash
s.mu.Unlock()
logger.LegacyPrintf("service.pricing", "[Pricing] Downloaded %d models successfully", len(data))
@@ -486,16 +516,6 @@ func (s *PricingService) validatePricingURL(raw string) (string, error) {
return normalized, nil
}
// computeFileHash 计算文件哈希
func (s *PricingService) computeFileHash(filePath string) (string, error) {
data, err := os.ReadFile(filePath)
if err != nil {
return "", err
}
hash := sha256.Sum256(data)
return hex.EncodeToString(hash[:]), nil
}
// GetModelPricing 获取模型价格(带模糊匹配)
func (s *PricingService) GetModelPricing(modelName string) *LiteLLMModelPricing {
s.mu.RLock()

View File

@@ -32,8 +32,9 @@ type TokenRefreshService struct {
privacyClientFactory PrivacyClientFactory
proxyRepo ProxyRepository
stopCh chan struct{}
wg sync.WaitGroup
stopCh chan struct{}
stopOnce sync.Once
wg sync.WaitGroup
}
// NewTokenRefreshService 创建token刷新服务
@@ -130,7 +131,9 @@ func (s *TokenRefreshService) Start() {
// Stop 停止刷新服务(可安全多次调用)
func (s *TokenRefreshService) Stop() {
close(s.stopCh)
s.stopOnce.Do(func() {
close(s.stopCh)
})
s.wg.Wait()
slog.Info("token_refresh.service_stopped")
}
@@ -430,6 +433,7 @@ func isNonRetryableRefreshError(err error) bool {
"unauthorized_client", // 客户端未授权
"access_denied", // 访问被拒绝
"missing_project_id", // 缺少 project_id
"no refresh token available",
}
for _, needle := range nonRetryable {
if strings.Contains(msg, needle) {

View File

@@ -19,6 +19,7 @@ type tokenRefreshAccountRepo struct {
updateCredentialsCalls int
setErrorCalls int
clearTempCalls int
setTempUnschedCalls int
lastAccount *Account
updateErr error
}
@@ -58,6 +59,11 @@ func (r *tokenRefreshAccountRepo) ClearTempUnschedulable(ctx context.Context, id
return nil
}
func (r *tokenRefreshAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
r.setTempUnschedCalls++
return nil
}
type tokenCacheInvalidatorStub struct {
calls int
err error
@@ -490,6 +496,31 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t
}
}
func TestTokenRefreshService_RefreshWithRetry_NoRefreshTokenDoesNotTempUnschedule(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 2,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
account := &Account{
ID: 18,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
err: errors.New("no refresh token available"),
}
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.Error(t, err)
require.Equal(t, 0, repo.updateCalls)
require.Equal(t, 0, repo.setTempUnschedCalls, "missing refresh token should not mark the account temp unschedulable")
require.Equal(t, 1, repo.setErrorCalls, "missing refresh token should be treated as a non-retryable credential state")
}
// TestIsNonRetryableRefreshError 测试不可重试错误判断
func TestIsNonRetryableRefreshError(t *testing.T) {
tests := []struct {
@@ -503,6 +534,7 @@ func TestIsNonRetryableRefreshError(t *testing.T) {
{name: "invalid_client", err: errors.New("invalid_client"), expected: true},
{name: "unauthorized_client", err: errors.New("unauthorized_client"), expected: true},
{name: "access_denied", err: errors.New("access_denied"), expected: true},
{name: "no_refresh_token", err: errors.New("no refresh token available"), expected: true},
{name: "invalid_grant_with_desc", err: errors.New("Error: invalid_grant - token revoked"), expected: true},
{name: "case_insensitive", err: errors.New("INVALID_GRANT"), expected: true},
}

View File

@@ -21,8 +21,8 @@ func optionalNonEqualStringPtr(value, compare string) *string {
}
func forwardResultBillingModel(requestedModel, upstreamModel string) string {
if trimmedUpstream := strings.TrimSpace(upstreamModel); trimmedUpstream != "" {
return trimmedUpstream
if trimmed := strings.TrimSpace(requestedModel); trimmed != "" {
return trimmed
}
return strings.TrimSpace(requestedModel)
return strings.TrimSpace(upstreamModel)
}

View File

@@ -865,10 +865,10 @@ rate_limit:
pricing:
# URL to fetch model pricing data (default: pinned model-price-repo commit)
# 获取模型定价数据的 URL默认固定 commit 的 model-price-repo
remote_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json"
remote_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/refs/heads/main//model_prices_and_context_window.json"
# Hash verification URL (optional)
# 哈希校验 URL可选
hash_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256"
hash_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/refs/heads/main//model_prices_and_context_window.sha256"
# Local data directory for caching
# 本地数据缓存目录
data_dir: "./data"

View File

@@ -2245,6 +2245,41 @@
</p>
</div>
</div>
<!-- Custom Base URL Relay -->
<div class="rounded-lg border border-gray-200 p-4 dark:border-dark-600">
<div class="flex items-center justify-between">
<div>
<label class="input-label mb-0">{{ t('admin.accounts.quotaControl.customBaseUrl.label') }}</label>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.quotaControl.customBaseUrl.hint') }}
</p>
</div>
<button
type="button"
@click="customBaseUrlEnabled = !customBaseUrlEnabled"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
customBaseUrlEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
customBaseUrlEnabled ? 'translate-x-5' : 'translate-x-0'
]"
/>
</button>
</div>
<div v-if="customBaseUrlEnabled" class="mt-3">
<input
v-model="customBaseUrl"
type="text"
class="input"
:placeholder="t('admin.accounts.quotaControl.customBaseUrl.urlHint')"
/>
</div>
</div>
</div>
<div>
@@ -3095,6 +3130,8 @@ const tlsFingerprintProfiles = ref<{ id: number; name: string }[]>([])
const sessionIdMaskingEnabled = ref(false)
const cacheTTLOverrideEnabled = ref(false)
const cacheTTLOverrideTarget = ref<string>('5m')
const customBaseUrlEnabled = ref(false)
const customBaseUrl = ref('')
// Gemini tier selection (used as fallback when auto-detection is unavailable/fails)
const geminiTierGoogleOne = ref<'google_one_free' | 'google_ai_pro' | 'google_ai_ultra'>('google_one_free')
@@ -3765,6 +3802,8 @@ const resetForm = () => {
sessionIdMaskingEnabled.value = false
cacheTTLOverrideEnabled.value = false
cacheTTLOverrideTarget.value = '5m'
customBaseUrlEnabled.value = false
customBaseUrl.value = ''
allowOverages.value = false
antigravityAccountType.value = 'oauth'
upstreamBaseUrl.value = ''
@@ -4856,6 +4895,12 @@ const handleAnthropicExchange = async (authCode: string) => {
extra.cache_ttl_override_target = cacheTTLOverrideTarget.value
}
// Add custom base URL settings
if (customBaseUrlEnabled.value && customBaseUrl.value.trim()) {
extra.custom_base_url_enabled = true
extra.custom_base_url = customBaseUrl.value.trim()
}
const credentials: Record<string, unknown> = { ...tokenInfo }
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
await createAccountAndFinish(form.platform, addMethod.value as AccountType, credentials, extra)
@@ -4974,6 +5019,12 @@ const handleCookieAuth = async (sessionKey: string) => {
extra.cache_ttl_override_target = cacheTTLOverrideTarget.value
}
// Add custom base URL settings
if (customBaseUrlEnabled.value && customBaseUrl.value.trim()) {
extra.custom_base_url_enabled = true
extra.custom_base_url = customBaseUrl.value.trim()
}
const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name
const credentials: Record<string, unknown> = { ...tokenInfo }

View File

@@ -1580,6 +1580,41 @@
</p>
</div>
</div>
<!-- Custom Base URL Relay -->
<div class="rounded-lg border border-gray-200 p-4 dark:border-dark-600">
<div class="flex items-center justify-between">
<div>
<label class="input-label mb-0">{{ t('admin.accounts.quotaControl.customBaseUrl.label') }}</label>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.quotaControl.customBaseUrl.hint') }}
</p>
</div>
<button
type="button"
@click="customBaseUrlEnabled = !customBaseUrlEnabled"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
customBaseUrlEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
customBaseUrlEnabled ? 'translate-x-5' : 'translate-x-0'
]"
/>
</button>
</div>
<div v-if="customBaseUrlEnabled" class="mt-3">
<input
v-model="customBaseUrl"
type="text"
class="input"
:placeholder="t('admin.accounts.quotaControl.customBaseUrl.urlHint')"
/>
</div>
</div>
</div>
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
@@ -1854,6 +1889,8 @@ const tlsFingerprintProfiles = ref<{ id: number; name: string }[]>([])
const sessionIdMaskingEnabled = ref(false)
const cacheTTLOverrideEnabled = ref(false)
const cacheTTLOverrideTarget = ref<string>('5m')
const customBaseUrlEnabled = ref(false)
const customBaseUrl = ref('')
// OpenAI 自动透传开关OAuth/API Key
const openaiPassthroughEnabled = ref(false)
@@ -2482,6 +2519,8 @@ function loadQuotaControlSettings(account: Account) {
sessionIdMaskingEnabled.value = false
cacheTTLOverrideEnabled.value = false
cacheTTLOverrideTarget.value = '5m'
customBaseUrlEnabled.value = false
customBaseUrl.value = ''
// Only applies to Anthropic OAuth/SetupToken accounts
if (account.platform !== 'anthropic' || (account.type !== 'oauth' && account.type !== 'setup-token')) {
@@ -2528,6 +2567,12 @@ function loadQuotaControlSettings(account: Account) {
cacheTTLOverrideEnabled.value = true
cacheTTLOverrideTarget.value = account.cache_ttl_override_target || '5m'
}
// Load custom base URL setting
if (account.custom_base_url_enabled === true) {
customBaseUrlEnabled.value = true
customBaseUrl.value = account.custom_base_url || ''
}
}
function formatTempUnschedKeywords(value: unknown) {
@@ -2980,6 +3025,15 @@ const handleSubmit = async () => {
delete newExtra.cache_ttl_override_target
}
// Custom base URL relay setting
if (customBaseUrlEnabled.value && customBaseUrl.value.trim()) {
newExtra.custom_base_url_enabled = true
newExtra.custom_base_url = customBaseUrl.value.trim()
} else {
delete newExtra.custom_base_url_enabled
delete newExtra.custom_base_url
}
updatePayload.extra = newExtra
}

View File

@@ -2318,6 +2318,11 @@ export default {
target: 'Target TTL',
targetHint: 'Select the TTL tier for billing'
},
customBaseUrl: {
label: 'Custom Relay URL',
hint: 'Forward requests to a custom relay service. Proxy URL will be passed as a query parameter.',
urlHint: 'Relay service URL (e.g., https://relay.example.com)',
},
clientAffinity: {
label: 'Client Affinity Scheduling',
hint: 'When enabled, new sessions prefer accounts previously used by this client to reduce account switching'
@@ -4378,6 +4383,7 @@ export default {
provider: 'Type',
active: 'Active',
endpoint: 'Endpoint',
bucket: 'Bucket',
storagePath: 'Storage Path',
capacityUsage: 'Capacity / Used',
capacityUnlimited: 'Unlimited',

View File

@@ -2462,6 +2462,11 @@ export default {
target: '目标 TTL',
targetHint: '选择计费使用的 TTL 类型'
},
customBaseUrl: {
label: '自定义转发地址',
hint: '启用后将请求转发到自定义中继服务,代理地址将作为 URL 参数传递给中继服务',
urlHint: '中继服务地址(如 https://relay.example.com',
},
clientAffinity: {
label: '客户端亲和调度',
hint: '启用后,新会话会优先调度到该客户端之前使用过的账号,避免频繁切换账号'
@@ -4542,6 +4547,7 @@ export default {
provider: '存储类型',
active: '生效状态',
endpoint: '端点',
bucket: '存储桶',
storagePath: '存储路径',
capacityUsage: '容量 / 已用',
capacityUnlimited: '无限制',

View File

@@ -734,6 +734,10 @@ export interface Account {
cache_ttl_override_enabled?: boolean | null
cache_ttl_override_target?: string | null
// 自定义 Base URL 中继转发(仅 Anthropic OAuth/SetupToken 账号有效)
custom_base_url_enabled?: boolean | null
custom_base_url?: string | null
// 客户端亲和调度(仅 Anthropic/Antigravity 平台有效)
// 启用后新会话会优先调度到客户端之前使用过的账号
client_affinity_enabled?: boolean | null

View File

@@ -2,6 +2,7 @@
import { computed, onMounted, reactive, ref, watch } from 'vue'
import { opsAPI, type OpsRuntimeLogConfig, type OpsSystemLog, type OpsSystemLogSinkHealth } from '@/api/admin/ops'
import Pagination from '@/components/common/Pagination.vue'
import Select from '@/components/common/Select.vue'
import { useAppStore } from '@/stores'
const appStore = useAppStore()
@@ -56,6 +57,37 @@ const filters = reactive({
q: ''
})
const runtimeLevelOptions = [
{ value: 'debug', label: 'debug' },
{ value: 'info', label: 'info' },
{ value: 'warn', label: 'warn' },
{ value: 'error', label: 'error' }
]
const stacktraceLevelOptions = [
{ value: 'none', label: 'none' },
{ value: 'error', label: 'error' },
{ value: 'fatal', label: 'fatal' }
]
const timeRangeOptions = [
{ value: '5m', label: '5m' },
{ value: '30m', label: '30m' },
{ value: '1h', label: '1h' },
{ value: '6h', label: '6h' },
{ value: '24h', label: '24h' },
{ value: '7d', label: '7d' },
{ value: '30d', label: '30d' }
]
const filterLevelOptions = [
{ value: '', label: '全部' },
{ value: 'debug', label: 'debug' },
{ value: 'info', label: 'info' },
{ value: 'warn', label: 'warn' },
{ value: 'error', label: 'error' }
]
const levelBadgeClass = (level: string) => {
const v = String(level || '').toLowerCase()
if (v === 'error' || v === 'fatal') return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-300'
@@ -347,20 +379,11 @@ onMounted(async () => {
<div class="grid grid-cols-1 gap-3 md:grid-cols-2 xl:grid-cols-6">
<label class="text-xs text-gray-600 dark:text-gray-300">
级别
<select v-model="runtimeConfig.level" class="input mt-1">
<option value="debug">debug</option>
<option value="info">info</option>
<option value="warn">warn</option>
<option value="error">error</option>
</select>
<Select v-model="runtimeConfig.level" class="mt-1" :options="runtimeLevelOptions" />
</label>
<label class="text-xs text-gray-600 dark:text-gray-300">
堆栈阈值
<select v-model="runtimeConfig.stacktrace_level" class="input mt-1">
<option value="none">none</option>
<option value="error">error</option>
<option value="fatal">fatal</option>
</select>
<Select v-model="runtimeConfig.stacktrace_level" class="mt-1" :options="stacktraceLevelOptions" />
</label>
<label class="text-xs text-gray-600 dark:text-gray-300">
采样初始
@@ -403,15 +426,7 @@ onMounted(async () => {
<div class="mb-4 grid grid-cols-1 gap-3 md:grid-cols-5">
<label class="text-xs text-gray-600 dark:text-gray-300">
时间范围
<select v-model="filters.time_range" class="input mt-1">
<option value="5m">5m</option>
<option value="30m">30m</option>
<option value="1h">1h</option>
<option value="6h">6h</option>
<option value="24h">24h</option>
<option value="7d">7d</option>
<option value="30d">30d</option>
</select>
<Select v-model="filters.time_range" class="mt-1" :options="timeRangeOptions" />
</label>
<label class="text-xs text-gray-600 dark:text-gray-300">
开始时间可选
@@ -423,13 +438,7 @@ onMounted(async () => {
</label>
<label class="text-xs text-gray-600 dark:text-gray-300">
级别
<select v-model="filters.level" class="input mt-1">
<option value="">全部</option>
<option value="debug">debug</option>
<option value="info">info</option>
<option value="warn">warn</option>
<option value="error">error</option>
</select>
<Select v-model="filters.level" class="mt-1" :options="filterLevelOptions" />
</label>
<label class="text-xs text-gray-600 dark:text-gray-300">
组件