merge: 合并 upstream/main 并保留本地图片计费功能

This commit is contained in:
song
2026-01-06 10:49:26 +08:00
187 changed files with 17081 additions and 19407 deletions

View File

@@ -57,19 +57,24 @@ jobs:
- name: Checkout
uses: actions/checkout@v4
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 9
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '20'
cache: 'npm'
cache-dependency-path: frontend/package-lock.json
cache: 'pnpm'
cache-dependency-path: frontend/pnpm-lock.yaml
- name: Install dependencies
run: npm ci
run: pnpm install --frozen-lockfile
working-directory: frontend
- name: Build frontend
run: npm run build
run: pnpm run build
working-directory: frontend
- name: Upload frontend artifact

2
.gitignore vendored
View File

@@ -33,6 +33,7 @@ frontend/dist/
*.local
*.tsbuildinfo
vite.config.d.ts
vite.config.js.timestamp-*
# 日志
npm-debug.log*
@@ -121,3 +122,4 @@ AGENTS.md
backend/cmd/server/server
deploy/docker-compose.override.yml
.gocache/
vite.config.js

View File

@@ -19,13 +19,16 @@ FROM ${NODE_IMAGE} AS frontend-builder
WORKDIR /app/frontend
# Install pnpm
RUN corepack enable && corepack prepare pnpm@latest --activate
# Install dependencies first (better caching)
COPY frontend/package*.json ./
RUN npm ci
COPY frontend/package.json frontend/pnpm-lock.yaml ./
RUN pnpm install --frozen-lockfile
# Copy frontend source and build
COPY frontend/ ./
RUN npm run build
RUN pnpm run build
# -----------------------------------------------------------------------------
# Stage 2: Backend Builder

View File

@@ -1,4 +1,4 @@
.PHONY: build build-backend build-frontend
.PHONY: build build-backend build-frontend test test-backend test-frontend
# 一键编译前后端
build: build-backend build-frontend
@@ -10,3 +10,13 @@ build-backend:
# 编译前端(需要已安装依赖)
build-frontend:
@npm --prefix frontend run build
# 运行测试(后端 + 前端)
test: test-backend test-frontend
test-backend:
@$(MAKE) -C backend test
test-frontend:
@npm --prefix frontend run lint:check
@npm --prefix frontend run typecheck

View File

@@ -218,20 +218,23 @@ Build and run from source code for development or customization.
git clone https://github.com/Wei-Shaw/sub2api.git
cd sub2api
# 2. Build frontend
# 2. Install pnpm (if not already installed)
npm install -g pnpm
# 3. Build frontend
cd frontend
npm install
npm run build
pnpm install
pnpm run build
# Output will be in ../backend/internal/web/dist/
# 3. Build backend with embedded frontend
# 4. Build backend with embedded frontend
cd ../backend
go build -tags embed -o sub2api ./cmd/server
# 4. Create configuration file
# 5. Create configuration file
cp ../deploy/config.example.yaml ./config.yaml
# 5. Edit configuration
# 6. Edit configuration
nano config.yaml
```
@@ -268,6 +271,24 @@ default:
rate_multiplier: 1.0
```
Additional security-related options are available in `config.yaml`:
- `cors.allowed_origins` for CORS allowlist
- `security.url_allowlist` for upstream/pricing/CRS host allowlists
- `security.url_allowlist.enabled` to disable URL validation (use with caution)
- `security.url_allowlist.allow_insecure_http` to allow http URLs when validation is disabled
- `security.response_headers.enabled` to enable configurable response header filtering (disabled uses default allowlist)
- `security.csp` to control Content-Security-Policy headers
- `billing.circuit_breaker` to fail closed on billing errors
- `server.trusted_proxies` to enable X-Forwarded-For parsing
- `turnstile.required` to require Turnstile in release mode
If you disable URL validation or response header filtering, harden your network layer:
- Enforce an egress allowlist for upstream domains/IPs
- Block private/loopback/link-local ranges
- Enforce TLS-only outbound traffic
- Strip sensitive upstream response headers at the proxy
```bash
# 6. Run the application
./sub2api
@@ -282,7 +303,7 @@ go run ./cmd/server
# Frontend (with hot reload)
cd frontend
npm run dev
pnpm run dev
```
#### Code Generation

View File

@@ -218,20 +218,23 @@ docker-compose logs -f
git clone https://github.com/Wei-Shaw/sub2api.git
cd sub2api
# 2. 编译前端
# 2. 安装 pnpm如果还没有安装
npm install -g pnpm
# 3. 编译前端
cd frontend
npm install
npm run build
pnpm install
pnpm run build
# 构建产物输出到 ../backend/internal/web/dist/
# 3. 编译后端(嵌入前端)
# 4. 编译后端(嵌入前端)
cd ../backend
go build -tags embed -o sub2api ./cmd/server
# 4. 创建配置文件
# 5. 创建配置文件
cp ../deploy/config.example.yaml ./config.yaml
# 5. 编辑配置
# 6. 编辑配置
nano config.yaml
```
@@ -268,6 +271,24 @@ default:
rate_multiplier: 1.0
```
`config.yaml` 还支持以下安全相关配置:
- `cors.allowed_origins` 配置 CORS 白名单
- `security.url_allowlist` 配置上游/价格数据/CRS 主机白名单
- `security.url_allowlist.enabled` 可关闭 URL 校验(慎用)
- `security.url_allowlist.allow_insecure_http` 关闭校验时允许 http URL
- `security.response_headers.enabled` 可启用可配置响应头过滤(关闭时使用默认白名单)
- `security.csp` 配置 Content-Security-Policy
- `billing.circuit_breaker` 计费异常时 fail-closed
- `server.trusted_proxies` 启用可信代理解析 X-Forwarded-For
- `turnstile.required` 在 release 模式强制启用 Turnstile
如关闭 URL 校验或响应头过滤,请加强网络层防护:
- 出站访问白名单限制上游域名/IP
- 阻断私网/回环/链路本地地址
- 强制仅允许 TLS 出站
- 在反向代理层移除敏感响应头
```bash
# 6. 运行应用
./sub2api
@@ -282,7 +303,7 @@ go run ./cmd/server
# 前端(支持热重载)
cd frontend
npm run dev
pnpm run dev
```
#### 代码生成

View File

@@ -1,8 +1,12 @@
.PHONY: build test-unit test-integration test-e2e
.PHONY: build test test-unit test-integration test-e2e
build:
go build -o bin/server ./cmd/server
test:
go test ./...
golangci-lint run ./...
test-unit:
go test -tags=unit ./...

View File

@@ -86,7 +86,8 @@ func main() {
func runSetupServer() {
r := gin.New()
r.Use(middleware.Recovery())
r.Use(middleware.CORS())
r.Use(middleware.CORS(config.CORSConfig{}))
r.Use(middleware.SecurityHeaders(config.CSPConfig{Enabled: true, Policy: config.DefaultCSPPolicy}))
// Register setup routes
setup.RegisterRoutes(r)

View File

@@ -76,7 +76,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
dashboardHandler := admin.NewDashboardHandler(dashboardService)
accountRepository := repository.NewAccountRepository(client, db)
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber()
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService)
@@ -101,10 +101,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
httpUpstream := repository.NewHTTPUpstream(configConfig)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService)
concurrencyService := service.NewConcurrencyService(concurrencyCache)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
@@ -125,7 +125,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
pricingRemoteClient := repository.NewPricingRemoteClient()
pricingRemoteClient := repository.NewPricingRemoteClient(configConfig)
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil {
return nil, err
@@ -136,10 +136,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
timingWheelService := service.ProvideTimingWheelService()
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)

View File

@@ -27,6 +27,8 @@ type Account struct {
DeletedAt *time.Time `json:"deleted_at,omitempty"`
// Name holds the value of the "name" field.
Name string `json:"name,omitempty"`
// Notes holds the value of the "notes" field.
Notes *string `json:"notes,omitempty"`
// Platform holds the value of the "platform" field.
Platform string `json:"platform,omitempty"`
// Type holds the value of the "type" field.
@@ -131,7 +133,7 @@ func (*Account) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullBool)
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority:
values[i] = new(sql.NullInt64)
case account.FieldName, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldSessionWindowStatus:
case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldSessionWindowStatus:
values[i] = new(sql.NullString)
case account.FieldCreatedAt, account.FieldUpdatedAt, account.FieldDeletedAt, account.FieldLastUsedAt, account.FieldRateLimitedAt, account.FieldRateLimitResetAt, account.FieldOverloadUntil, account.FieldSessionWindowStart, account.FieldSessionWindowEnd:
values[i] = new(sql.NullTime)
@@ -181,6 +183,13 @@ func (_m *Account) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.Name = value.String
}
case account.FieldNotes:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field notes", values[i])
} else if value.Valid {
_m.Notes = new(string)
*_m.Notes = value.String
}
case account.FieldPlatform:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field platform", values[i])
@@ -366,6 +375,11 @@ func (_m *Account) String() string {
builder.WriteString("name=")
builder.WriteString(_m.Name)
builder.WriteString(", ")
if v := _m.Notes; v != nil {
builder.WriteString("notes=")
builder.WriteString(*v)
}
builder.WriteString(", ")
builder.WriteString("platform=")
builder.WriteString(_m.Platform)
builder.WriteString(", ")

View File

@@ -23,6 +23,8 @@ const (
FieldDeletedAt = "deleted_at"
// FieldName holds the string denoting the name field in the database.
FieldName = "name"
// FieldNotes holds the string denoting the notes field in the database.
FieldNotes = "notes"
// FieldPlatform holds the string denoting the platform field in the database.
FieldPlatform = "platform"
// FieldType holds the string denoting the type field in the database.
@@ -102,6 +104,7 @@ var Columns = []string{
FieldUpdatedAt,
FieldDeletedAt,
FieldName,
FieldNotes,
FieldPlatform,
FieldType,
FieldCredentials,
@@ -203,6 +206,11 @@ func ByName(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldName, opts...).ToFunc()
}
// ByNotes orders the results by the notes field.
func ByNotes(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldNotes, opts...).ToFunc()
}
// ByPlatform orders the results by the platform field.
func ByPlatform(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldPlatform, opts...).ToFunc()

View File

@@ -75,6 +75,11 @@ func Name(v string) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldName, v))
}
// Notes applies equality check predicate on the "notes" field. It's identical to NotesEQ.
func Notes(v string) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldNotes, v))
}
// Platform applies equality check predicate on the "platform" field. It's identical to PlatformEQ.
func Platform(v string) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldPlatform, v))
@@ -345,6 +350,81 @@ func NameContainsFold(v string) predicate.Account {
return predicate.Account(sql.FieldContainsFold(FieldName, v))
}
// NotesEQ applies the EQ predicate on the "notes" field.
func NotesEQ(v string) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldNotes, v))
}
// NotesNEQ applies the NEQ predicate on the "notes" field.
func NotesNEQ(v string) predicate.Account {
return predicate.Account(sql.FieldNEQ(FieldNotes, v))
}
// NotesIn applies the In predicate on the "notes" field.
func NotesIn(vs ...string) predicate.Account {
return predicate.Account(sql.FieldIn(FieldNotes, vs...))
}
// NotesNotIn applies the NotIn predicate on the "notes" field.
func NotesNotIn(vs ...string) predicate.Account {
return predicate.Account(sql.FieldNotIn(FieldNotes, vs...))
}
// NotesGT applies the GT predicate on the "notes" field.
func NotesGT(v string) predicate.Account {
return predicate.Account(sql.FieldGT(FieldNotes, v))
}
// NotesGTE applies the GTE predicate on the "notes" field.
func NotesGTE(v string) predicate.Account {
return predicate.Account(sql.FieldGTE(FieldNotes, v))
}
// NotesLT applies the LT predicate on the "notes" field.
func NotesLT(v string) predicate.Account {
return predicate.Account(sql.FieldLT(FieldNotes, v))
}
// NotesLTE applies the LTE predicate on the "notes" field.
func NotesLTE(v string) predicate.Account {
return predicate.Account(sql.FieldLTE(FieldNotes, v))
}
// NotesContains applies the Contains predicate on the "notes" field.
func NotesContains(v string) predicate.Account {
return predicate.Account(sql.FieldContains(FieldNotes, v))
}
// NotesHasPrefix applies the HasPrefix predicate on the "notes" field.
func NotesHasPrefix(v string) predicate.Account {
return predicate.Account(sql.FieldHasPrefix(FieldNotes, v))
}
// NotesHasSuffix applies the HasSuffix predicate on the "notes" field.
func NotesHasSuffix(v string) predicate.Account {
return predicate.Account(sql.FieldHasSuffix(FieldNotes, v))
}
// NotesIsNil applies the IsNil predicate on the "notes" field.
func NotesIsNil() predicate.Account {
return predicate.Account(sql.FieldIsNull(FieldNotes))
}
// NotesNotNil applies the NotNil predicate on the "notes" field.
func NotesNotNil() predicate.Account {
return predicate.Account(sql.FieldNotNull(FieldNotes))
}
// NotesEqualFold applies the EqualFold predicate on the "notes" field.
func NotesEqualFold(v string) predicate.Account {
return predicate.Account(sql.FieldEqualFold(FieldNotes, v))
}
// NotesContainsFold applies the ContainsFold predicate on the "notes" field.
func NotesContainsFold(v string) predicate.Account {
return predicate.Account(sql.FieldContainsFold(FieldNotes, v))
}
// PlatformEQ applies the EQ predicate on the "platform" field.
func PlatformEQ(v string) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldPlatform, v))

View File

@@ -73,6 +73,20 @@ func (_c *AccountCreate) SetName(v string) *AccountCreate {
return _c
}
// SetNotes sets the "notes" field.
func (_c *AccountCreate) SetNotes(v string) *AccountCreate {
_c.mutation.SetNotes(v)
return _c
}
// SetNillableNotes sets the "notes" field if the given value is not nil.
func (_c *AccountCreate) SetNillableNotes(v *string) *AccountCreate {
if v != nil {
_c.SetNotes(*v)
}
return _c
}
// SetPlatform sets the "platform" field.
func (_c *AccountCreate) SetPlatform(v string) *AccountCreate {
_c.mutation.SetPlatform(v)
@@ -501,6 +515,10 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) {
_spec.SetField(account.FieldName, field.TypeString, value)
_node.Name = value
}
if value, ok := _c.mutation.Notes(); ok {
_spec.SetField(account.FieldNotes, field.TypeString, value)
_node.Notes = &value
}
if value, ok := _c.mutation.Platform(); ok {
_spec.SetField(account.FieldPlatform, field.TypeString, value)
_node.Platform = value
@@ -712,6 +730,24 @@ func (u *AccountUpsert) UpdateName() *AccountUpsert {
return u
}
// SetNotes sets the "notes" field.
func (u *AccountUpsert) SetNotes(v string) *AccountUpsert {
u.Set(account.FieldNotes, v)
return u
}
// UpdateNotes sets the "notes" field to the value that was provided on create.
func (u *AccountUpsert) UpdateNotes() *AccountUpsert {
u.SetExcluded(account.FieldNotes)
return u
}
// ClearNotes clears the value of the "notes" field.
func (u *AccountUpsert) ClearNotes() *AccountUpsert {
u.SetNull(account.FieldNotes)
return u
}
// SetPlatform sets the "platform" field.
func (u *AccountUpsert) SetPlatform(v string) *AccountUpsert {
u.Set(account.FieldPlatform, v)
@@ -1076,6 +1112,27 @@ func (u *AccountUpsertOne) UpdateName() *AccountUpsertOne {
})
}
// SetNotes sets the "notes" field.
func (u *AccountUpsertOne) SetNotes(v string) *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
s.SetNotes(v)
})
}
// UpdateNotes sets the "notes" field to the value that was provided on create.
func (u *AccountUpsertOne) UpdateNotes() *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
s.UpdateNotes()
})
}
// ClearNotes clears the value of the "notes" field.
func (u *AccountUpsertOne) ClearNotes() *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
s.ClearNotes()
})
}
// SetPlatform sets the "platform" field.
func (u *AccountUpsertOne) SetPlatform(v string) *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
@@ -1651,6 +1708,27 @@ func (u *AccountUpsertBulk) UpdateName() *AccountUpsertBulk {
})
}
// SetNotes sets the "notes" field.
func (u *AccountUpsertBulk) SetNotes(v string) *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {
s.SetNotes(v)
})
}
// UpdateNotes sets the "notes" field to the value that was provided on create.
func (u *AccountUpsertBulk) UpdateNotes() *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {
s.UpdateNotes()
})
}
// ClearNotes clears the value of the "notes" field.
func (u *AccountUpsertBulk) ClearNotes() *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {
s.ClearNotes()
})
}
// SetPlatform sets the "platform" field.
func (u *AccountUpsertBulk) SetPlatform(v string) *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {

View File

@@ -71,6 +71,26 @@ func (_u *AccountUpdate) SetNillableName(v *string) *AccountUpdate {
return _u
}
// SetNotes sets the "notes" field.
func (_u *AccountUpdate) SetNotes(v string) *AccountUpdate {
_u.mutation.SetNotes(v)
return _u
}
// SetNillableNotes sets the "notes" field if the given value is not nil.
func (_u *AccountUpdate) SetNillableNotes(v *string) *AccountUpdate {
if v != nil {
_u.SetNotes(*v)
}
return _u
}
// ClearNotes clears the value of the "notes" field.
func (_u *AccountUpdate) ClearNotes() *AccountUpdate {
_u.mutation.ClearNotes()
return _u
}
// SetPlatform sets the "platform" field.
func (_u *AccountUpdate) SetPlatform(v string) *AccountUpdate {
_u.mutation.SetPlatform(v)
@@ -545,6 +565,12 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.Name(); ok {
_spec.SetField(account.FieldName, field.TypeString, value)
}
if value, ok := _u.mutation.Notes(); ok {
_spec.SetField(account.FieldNotes, field.TypeString, value)
}
if _u.mutation.NotesCleared() {
_spec.ClearField(account.FieldNotes, field.TypeString)
}
if value, ok := _u.mutation.Platform(); ok {
_spec.SetField(account.FieldPlatform, field.TypeString, value)
}
@@ -814,6 +840,26 @@ func (_u *AccountUpdateOne) SetNillableName(v *string) *AccountUpdateOne {
return _u
}
// SetNotes sets the "notes" field.
func (_u *AccountUpdateOne) SetNotes(v string) *AccountUpdateOne {
_u.mutation.SetNotes(v)
return _u
}
// SetNillableNotes sets the "notes" field if the given value is not nil.
func (_u *AccountUpdateOne) SetNillableNotes(v *string) *AccountUpdateOne {
if v != nil {
_u.SetNotes(*v)
}
return _u
}
// ClearNotes clears the value of the "notes" field.
func (_u *AccountUpdateOne) ClearNotes() *AccountUpdateOne {
_u.mutation.ClearNotes()
return _u
}
// SetPlatform sets the "platform" field.
func (_u *AccountUpdateOne) SetPlatform(v string) *AccountUpdateOne {
_u.mutation.SetPlatform(v)
@@ -1318,6 +1364,12 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er
if value, ok := _u.mutation.Name(); ok {
_spec.SetField(account.FieldName, field.TypeString, value)
}
if value, ok := _u.mutation.Notes(); ok {
_spec.SetField(account.FieldNotes, field.TypeString, value)
}
if _u.mutation.NotesCleared() {
_spec.ClearField(account.FieldNotes, field.TypeString)
}
if value, ok := _u.mutation.Platform(); ok {
_spec.SetField(account.FieldPlatform, field.TypeString, value)
}

View File

@@ -70,6 +70,7 @@ var (
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "name", Type: field.TypeString, Size: 100},
{Name: "notes", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "platform", Type: field.TypeString, Size: 50},
{Name: "type", Type: field.TypeString, Size: 20},
{Name: "credentials", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
@@ -96,7 +97,7 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "accounts_proxies_proxy",
Columns: []*schema.Column{AccountsColumns[21]},
Columns: []*schema.Column{AccountsColumns[22]},
RefColumns: []*schema.Column{ProxiesColumns[0]},
OnDelete: schema.SetNull,
},
@@ -105,52 +106,52 @@ var (
{
Name: "account_platform",
Unique: false,
Columns: []*schema.Column{AccountsColumns[5]},
Columns: []*schema.Column{AccountsColumns[6]},
},
{
Name: "account_type",
Unique: false,
Columns: []*schema.Column{AccountsColumns[6]},
Columns: []*schema.Column{AccountsColumns[7]},
},
{
Name: "account_status",
Unique: false,
Columns: []*schema.Column{AccountsColumns[11]},
Columns: []*schema.Column{AccountsColumns[12]},
},
{
Name: "account_proxy_id",
Unique: false,
Columns: []*schema.Column{AccountsColumns[21]},
Columns: []*schema.Column{AccountsColumns[22]},
},
{
Name: "account_priority",
Unique: false,
Columns: []*schema.Column{AccountsColumns[10]},
Columns: []*schema.Column{AccountsColumns[11]},
},
{
Name: "account_last_used_at",
Unique: false,
Columns: []*schema.Column{AccountsColumns[13]},
Columns: []*schema.Column{AccountsColumns[14]},
},
{
Name: "account_schedulable",
Unique: false,
Columns: []*schema.Column{AccountsColumns[14]},
Columns: []*schema.Column{AccountsColumns[15]},
},
{
Name: "account_rate_limited_at",
Unique: false,
Columns: []*schema.Column{AccountsColumns[15]},
Columns: []*schema.Column{AccountsColumns[16]},
},
{
Name: "account_rate_limit_reset_at",
Unique: false,
Columns: []*schema.Column{AccountsColumns[16]},
Columns: []*schema.Column{AccountsColumns[17]},
},
{
Name: "account_overload_until",
Unique: false,
Columns: []*schema.Column{AccountsColumns[17]},
Columns: []*schema.Column{AccountsColumns[18]},
},
{
Name: "account_deleted_at",

View File

@@ -994,6 +994,7 @@ type AccountMutation struct {
updated_at *time.Time
deleted_at *time.Time
name *string
notes *string
platform *string
_type *string
credentials *map[string]interface{}
@@ -1281,6 +1282,55 @@ func (m *AccountMutation) ResetName() {
m.name = nil
}
// SetNotes sets the "notes" field.
func (m *AccountMutation) SetNotes(s string) {
m.notes = &s
}
// Notes returns the value of the "notes" field in the mutation.
func (m *AccountMutation) Notes() (r string, exists bool) {
v := m.notes
if v == nil {
return
}
return *v, true
}
// OldNotes returns the old "notes" field's value of the Account entity.
// If the Account object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *AccountMutation) OldNotes(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldNotes is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldNotes requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldNotes: %w", err)
}
return oldValue.Notes, nil
}
// ClearNotes clears the value of the "notes" field.
func (m *AccountMutation) ClearNotes() {
m.notes = nil
m.clearedFields[account.FieldNotes] = struct{}{}
}
// NotesCleared returns if the "notes" field was cleared in this mutation.
func (m *AccountMutation) NotesCleared() bool {
_, ok := m.clearedFields[account.FieldNotes]
return ok
}
// ResetNotes resets all changes to the "notes" field.
func (m *AccountMutation) ResetNotes() {
m.notes = nil
delete(m.clearedFields, account.FieldNotes)
}
// SetPlatform sets the "platform" field.
func (m *AccountMutation) SetPlatform(s string) {
m.platform = &s
@@ -2219,7 +2269,7 @@ func (m *AccountMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *AccountMutation) Fields() []string {
fields := make([]string, 0, 21)
fields := make([]string, 0, 22)
if m.created_at != nil {
fields = append(fields, account.FieldCreatedAt)
}
@@ -2232,6 +2282,9 @@ func (m *AccountMutation) Fields() []string {
if m.name != nil {
fields = append(fields, account.FieldName)
}
if m.notes != nil {
fields = append(fields, account.FieldNotes)
}
if m.platform != nil {
fields = append(fields, account.FieldPlatform)
}
@@ -2299,6 +2352,8 @@ func (m *AccountMutation) Field(name string) (ent.Value, bool) {
return m.DeletedAt()
case account.FieldName:
return m.Name()
case account.FieldNotes:
return m.Notes()
case account.FieldPlatform:
return m.Platform()
case account.FieldType:
@@ -2350,6 +2405,8 @@ func (m *AccountMutation) OldField(ctx context.Context, name string) (ent.Value,
return m.OldDeletedAt(ctx)
case account.FieldName:
return m.OldName(ctx)
case account.FieldNotes:
return m.OldNotes(ctx)
case account.FieldPlatform:
return m.OldPlatform(ctx)
case account.FieldType:
@@ -2421,6 +2478,13 @@ func (m *AccountMutation) SetField(name string, value ent.Value) error {
}
m.SetName(v)
return nil
case account.FieldNotes:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetNotes(v)
return nil
case account.FieldPlatform:
v, ok := value.(string)
if !ok {
@@ -2600,6 +2664,9 @@ func (m *AccountMutation) ClearedFields() []string {
if m.FieldCleared(account.FieldDeletedAt) {
fields = append(fields, account.FieldDeletedAt)
}
if m.FieldCleared(account.FieldNotes) {
fields = append(fields, account.FieldNotes)
}
if m.FieldCleared(account.FieldProxyID) {
fields = append(fields, account.FieldProxyID)
}
@@ -2644,6 +2711,9 @@ func (m *AccountMutation) ClearField(name string) error {
case account.FieldDeletedAt:
m.ClearDeletedAt()
return nil
case account.FieldNotes:
m.ClearNotes()
return nil
case account.FieldProxyID:
m.ClearProxyID()
return nil
@@ -2691,6 +2761,9 @@ func (m *AccountMutation) ResetField(name string) error {
case account.FieldName:
m.ResetName()
return nil
case account.FieldNotes:
m.ResetNotes()
return nil
case account.FieldPlatform:
m.ResetPlatform()
return nil

View File

@@ -124,7 +124,7 @@ func init() {
}
}()
// accountDescPlatform is the schema descriptor for platform field.
accountDescPlatform := accountFields[1].Descriptor()
accountDescPlatform := accountFields[2].Descriptor()
// account.PlatformValidator is a validator for the "platform" field. It is called by the builders before save.
account.PlatformValidator = func() func(string) error {
validators := accountDescPlatform.Validators
@@ -142,7 +142,7 @@ func init() {
}
}()
// accountDescType is the schema descriptor for type field.
accountDescType := accountFields[2].Descriptor()
accountDescType := accountFields[3].Descriptor()
// account.TypeValidator is a validator for the "type" field. It is called by the builders before save.
account.TypeValidator = func() func(string) error {
validators := accountDescType.Validators
@@ -160,33 +160,33 @@ func init() {
}
}()
// accountDescCredentials is the schema descriptor for credentials field.
accountDescCredentials := accountFields[3].Descriptor()
accountDescCredentials := accountFields[4].Descriptor()
// account.DefaultCredentials holds the default value on creation for the credentials field.
account.DefaultCredentials = accountDescCredentials.Default.(func() map[string]interface{})
// accountDescExtra is the schema descriptor for extra field.
accountDescExtra := accountFields[4].Descriptor()
accountDescExtra := accountFields[5].Descriptor()
// account.DefaultExtra holds the default value on creation for the extra field.
account.DefaultExtra = accountDescExtra.Default.(func() map[string]interface{})
// accountDescConcurrency is the schema descriptor for concurrency field.
accountDescConcurrency := accountFields[6].Descriptor()
accountDescConcurrency := accountFields[7].Descriptor()
// account.DefaultConcurrency holds the default value on creation for the concurrency field.
account.DefaultConcurrency = accountDescConcurrency.Default.(int)
// accountDescPriority is the schema descriptor for priority field.
accountDescPriority := accountFields[7].Descriptor()
accountDescPriority := accountFields[8].Descriptor()
// account.DefaultPriority holds the default value on creation for the priority field.
account.DefaultPriority = accountDescPriority.Default.(int)
// accountDescStatus is the schema descriptor for status field.
accountDescStatus := accountFields[8].Descriptor()
accountDescStatus := accountFields[9].Descriptor()
// account.DefaultStatus holds the default value on creation for the status field.
account.DefaultStatus = accountDescStatus.Default.(string)
// account.StatusValidator is a validator for the "status" field. It is called by the builders before save.
account.StatusValidator = accountDescStatus.Validators[0].(func(string) error)
// accountDescSchedulable is the schema descriptor for schedulable field.
accountDescSchedulable := accountFields[11].Descriptor()
accountDescSchedulable := accountFields[12].Descriptor()
// account.DefaultSchedulable holds the default value on creation for the schedulable field.
account.DefaultSchedulable = accountDescSchedulable.Default.(bool)
// accountDescSessionWindowStatus is the schema descriptor for session_window_status field.
accountDescSessionWindowStatus := accountFields[17].Descriptor()
accountDescSessionWindowStatus := accountFields[18].Descriptor()
// account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save.
account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error)
accountgroupFields := schema.AccountGroup{}.Fields()

View File

@@ -54,6 +54,11 @@ func (Account) Fields() []ent.Field {
field.String("name").
MaxLen(100).
NotEmpty(),
// notes: 管理员备注(可为空)
field.String("notes").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "text"}),
// platform: 所属平台,如 "claude", "gemini", "openai" 等
field.String("platform").

View File

@@ -2,7 +2,11 @@
package config
import (
"crypto/rand"
"encoding/hex"
"fmt"
"log"
"os"
"strings"
"time"
@@ -14,6 +18,8 @@ const (
RunModeSimple = "simple"
)
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' https://challenges.cloudflare.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
// 连接池隔离策略常量
// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗
const (
@@ -30,6 +36,10 @@ const (
type Config struct {
Server ServerConfig `mapstructure:"server"`
CORS CORSConfig `mapstructure:"cors"`
Security SecurityConfig `mapstructure:"security"`
Billing BillingConfig `mapstructure:"billing"`
Turnstile TurnstileConfig `mapstructure:"turnstile"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
JWT JWTConfig `mapstructure:"jwt"`
@@ -37,6 +47,7 @@ type Config struct {
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
@@ -95,11 +106,65 @@ type PricingConfig struct {
}
type ServerConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Mode string `mapstructure:"mode"` // debug/release
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Mode string `mapstructure:"mode"` // debug/release
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表CIDR/IP
}
type CORSConfig struct {
AllowedOrigins []string `mapstructure:"allowed_origins"`
AllowCredentials bool `mapstructure:"allow_credentials"`
}
type SecurityConfig struct {
URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"`
ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"`
CSP CSPConfig `mapstructure:"csp"`
ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"`
}
type URLAllowlistConfig struct {
Enabled bool `mapstructure:"enabled"`
UpstreamHosts []string `mapstructure:"upstream_hosts"`
PricingHosts []string `mapstructure:"pricing_hosts"`
CRSHosts []string `mapstructure:"crs_hosts"`
AllowPrivateHosts bool `mapstructure:"allow_private_hosts"`
// 关闭 URL 白名单校验时,是否允许 http URL默认只允许 https
AllowInsecureHTTP bool `mapstructure:"allow_insecure_http"`
}
type ResponseHeaderConfig struct {
Enabled bool `mapstructure:"enabled"`
AdditionalAllowed []string `mapstructure:"additional_allowed"`
ForceRemove []string `mapstructure:"force_remove"`
}
type CSPConfig struct {
Enabled bool `mapstructure:"enabled"`
Policy string `mapstructure:"policy"`
}
type ProxyProbeConfig struct {
InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"`
}
type BillingConfig struct {
CircuitBreaker CircuitBreakerConfig `mapstructure:"circuit_breaker"`
}
type CircuitBreakerConfig struct {
Enabled bool `mapstructure:"enabled"`
FailureThreshold int `mapstructure:"failure_threshold"`
ResetTimeoutSeconds int `mapstructure:"reset_timeout_seconds"`
HalfOpenRequests int `mapstructure:"half_open_requests"`
}
type ConcurrencyConfig struct {
// PingInterval: 并发等待期间的 SSE ping 间隔(秒)
PingInterval int `mapstructure:"ping_interval"`
}
// GatewayConfig API网关相关配置
@@ -134,6 +199,13 @@ type GatewayConfig struct {
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
// StreamDataIntervalTimeout: 流数据间隔超时0表示禁用
StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"`
// StreamKeepaliveInterval: 流式 keepalive 间隔0表示禁用
StreamKeepaliveInterval int `mapstructure:"stream_keepalive_interval"`
// MaxLineSize: 上游 SSE 单行最大字节数0使用默认值
MaxLineSize int `mapstructure:"max_line_size"`
// 是否记录上游错误响应体摘要(避免输出请求内容)
LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"`
// 上游错误响应体记录最大字节数(超过会截断)
@@ -237,6 +309,10 @@ type JWTConfig struct {
ExpireHour int `mapstructure:"expire_hour"`
}
type TurnstileConfig struct {
Required bool `mapstructure:"required"`
}
type DefaultConfig struct {
AdminEmail string `mapstructure:"admin_email"`
AdminPassword string `mapstructure:"admin_password"`
@@ -263,8 +339,19 @@ func NormalizeRunMode(value string) string {
func Load() (*Config, error) {
viper.SetConfigName("config")
viper.SetConfigType("yaml")
// Add config paths in priority order
// 1. DATA_DIR environment variable (highest priority)
if dataDir := os.Getenv("DATA_DIR"); dataDir != "" {
viper.AddConfigPath(dataDir)
}
// 2. Docker data directory
viper.AddConfigPath("/app/data")
// 3. Current directory
viper.AddConfigPath(".")
// 4. Config subdirectory
viper.AddConfigPath("./config")
// 5. System config directory
viper.AddConfigPath("/etc/sub2api")
// 环境变量支持
@@ -287,11 +374,46 @@ func Load() (*Config, error) {
}
cfg.RunMode = NormalizeRunMode(cfg.RunMode)
cfg.Server.Mode = strings.ToLower(strings.TrimSpace(cfg.Server.Mode))
if cfg.Server.Mode == "" {
cfg.Server.Mode = "debug"
}
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy)
if cfg.JWT.Secret == "" {
secret, err := generateJWTSecret(64)
if err != nil {
return nil, fmt.Errorf("generate jwt secret error: %w", err)
}
cfg.JWT.Secret = secret
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
}
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("validate config error: %w", err)
}
if !cfg.Security.URLAllowlist.Enabled {
log.Println("Warning: security.url_allowlist.enabled=false; allowlist/SSRF checks disabled (minimal format validation only).")
}
if !cfg.Security.ResponseHeaders.Enabled {
log.Println("Warning: security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).")
}
if cfg.JWT.Secret != "" && isWeakJWTSecret(cfg.JWT.Secret) {
log.Println("Warning: JWT secret appears weak; use a 32+ character random secret in production.")
}
if len(cfg.Security.ResponseHeaders.AdditionalAllowed) > 0 || len(cfg.Security.ResponseHeaders.ForceRemove) > 0 {
log.Printf("AUDIT: response header policy configured additional_allowed=%v force_remove=%v",
cfg.Security.ResponseHeaders.AdditionalAllowed,
cfg.Security.ResponseHeaders.ForceRemove,
)
}
return &cfg, nil
}
@@ -304,6 +426,45 @@ func setDefaults() {
viper.SetDefault("server.mode", "debug")
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
viper.SetDefault("server.trusted_proxies", []string{})
// CORS
viper.SetDefault("cors.allowed_origins", []string{})
viper.SetDefault("cors.allow_credentials", true)
// Security
viper.SetDefault("security.url_allowlist.enabled", false)
viper.SetDefault("security.url_allowlist.upstream_hosts", []string{
"api.openai.com",
"api.anthropic.com",
"api.kimi.com",
"open.bigmodel.cn",
"api.minimaxi.com",
"generativelanguage.googleapis.com",
"cloudcode-pa.googleapis.com",
"*.openai.azure.com",
})
viper.SetDefault("security.url_allowlist.pricing_hosts", []string{
"raw.githubusercontent.com",
})
viper.SetDefault("security.url_allowlist.crs_hosts", []string{})
viper.SetDefault("security.url_allowlist.allow_private_hosts", false)
viper.SetDefault("security.url_allowlist.allow_insecure_http", false)
viper.SetDefault("security.response_headers.enabled", false)
viper.SetDefault("security.response_headers.additional_allowed", []string{})
viper.SetDefault("security.response_headers.force_remove", []string{})
viper.SetDefault("security.csp.enabled", true)
viper.SetDefault("security.csp.policy", DefaultCSPPolicy)
viper.SetDefault("security.proxy_probe.insecure_skip_verify", false)
// Billing
viper.SetDefault("billing.circuit_breaker.enabled", true)
viper.SetDefault("billing.circuit_breaker.failure_threshold", 5)
viper.SetDefault("billing.circuit_breaker.reset_timeout_seconds", 30)
viper.SetDefault("billing.circuit_breaker.half_open_requests", 3)
// Turnstile
viper.SetDefault("turnstile.required", false)
// Database
viper.SetDefault("database.host", "localhost")
@@ -329,7 +490,7 @@ func setDefaults() {
viper.SetDefault("redis.min_idle_conns", 10)
// JWT
viper.SetDefault("jwt.secret", "change-me-in-production")
viper.SetDefault("jwt.secret", "")
viper.SetDefault("jwt.expire_hour", 24)
// Default
@@ -357,7 +518,7 @@ func setDefaults() {
viper.SetDefault("timezone", "Asia/Shanghai")
// Gateway
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头LLM高负载时可能排队较久
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", false)
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
viper.SetDefault("gateway.inject_beta_for_apikey", false)
@@ -365,19 +526,23 @@ func setDefaults() {
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数HTTP/2 场景默认)
viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接HTTP/2 场景默认)
viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数含活跃HTTP/2 场景默认)
viper.SetDefault("gateway.idle_conn_timeout_seconds", 300) // 空闲连接超时(秒)
viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数HTTP/2 场景默认)
viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接HTTP/2 场景默认)
viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数含活跃HTTP/2 场景默认)
viper.SetDefault("gateway.idle_conn_timeout_seconds", 90) // 空闲连接超时(秒)
viper.SetDefault("gateway.max_upstream_clients", 5000)
viper.SetDefault("gateway.client_idle_ttl_seconds", 900)
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
viper.SetDefault("gateway.stream_keepalive_interval", 10)
viper.SetDefault("gateway.max_line_size", 10*1024*1024)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second)
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
viper.SetDefault("concurrency.ping_interval", 10)
// TokenRefresh
viper.SetDefault("token_refresh.enabled", true)
@@ -396,11 +561,28 @@ func setDefaults() {
}
func (c *Config) Validate() error {
if c.JWT.Secret == "" {
return fmt.Errorf("jwt.secret is required")
if c.JWT.ExpireHour <= 0 {
return fmt.Errorf("jwt.expire_hour must be positive")
}
if c.JWT.Secret == "change-me-in-production" && c.Server.Mode == "release" {
return fmt.Errorf("jwt.secret must be changed in production")
if c.JWT.ExpireHour > 168 {
return fmt.Errorf("jwt.expire_hour must be <= 168 (7 days)")
}
if c.JWT.ExpireHour > 24 {
log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour)
}
if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" {
return fmt.Errorf("security.csp.policy is required when CSP is enabled")
}
if c.Billing.CircuitBreaker.Enabled {
if c.Billing.CircuitBreaker.FailureThreshold <= 0 {
return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive")
}
if c.Billing.CircuitBreaker.ResetTimeoutSeconds <= 0 {
return fmt.Errorf("billing.circuit_breaker.reset_timeout_seconds must be positive")
}
if c.Billing.CircuitBreaker.HalfOpenRequests <= 0 {
return fmt.Errorf("billing.circuit_breaker.half_open_requests must be positive")
}
}
if c.Database.MaxOpenConns <= 0 {
return fmt.Errorf("database.max_open_conns must be positive")
@@ -458,6 +640,9 @@ func (c *Config) Validate() error {
if c.Gateway.IdleConnTimeoutSeconds <= 0 {
return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive")
}
if c.Gateway.IdleConnTimeoutSeconds > 180 {
log.Printf("Warning: gateway.idle_conn_timeout_seconds is %d (> 180). Consider 60-120 seconds for better connection reuse.", c.Gateway.IdleConnTimeoutSeconds)
}
if c.Gateway.MaxUpstreamClients <= 0 {
return fmt.Errorf("gateway.max_upstream_clients must be positive")
}
@@ -467,6 +652,26 @@ func (c *Config) Validate() error {
if c.Gateway.ConcurrencySlotTTLMinutes <= 0 {
return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive")
}
if c.Gateway.StreamDataIntervalTimeout < 0 {
return fmt.Errorf("gateway.stream_data_interval_timeout must be non-negative")
}
if c.Gateway.StreamDataIntervalTimeout != 0 &&
(c.Gateway.StreamDataIntervalTimeout < 30 || c.Gateway.StreamDataIntervalTimeout > 300) {
return fmt.Errorf("gateway.stream_data_interval_timeout must be 0 or between 30-300 seconds")
}
if c.Gateway.StreamKeepaliveInterval < 0 {
return fmt.Errorf("gateway.stream_keepalive_interval must be non-negative")
}
if c.Gateway.StreamKeepaliveInterval != 0 &&
(c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) {
return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds")
}
if c.Gateway.MaxLineSize < 0 {
return fmt.Errorf("gateway.max_line_size must be non-negative")
}
if c.Gateway.MaxLineSize != 0 && c.Gateway.MaxLineSize < 1024*1024 {
return fmt.Errorf("gateway.max_line_size must be at least 1MB")
}
if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive")
}
@@ -482,9 +687,57 @@ func (c *Config) Validate() error {
if c.Gateway.Scheduling.SlotCleanupInterval < 0 {
return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative")
}
if c.Concurrency.PingInterval < 5 || c.Concurrency.PingInterval > 30 {
return fmt.Errorf("concurrency.ping_interval must be between 5-30 seconds")
}
return nil
}
func normalizeStringSlice(values []string) []string {
if len(values) == 0 {
return values
}
normalized := make([]string, 0, len(values))
for _, v := range values {
trimmed := strings.TrimSpace(v)
if trimmed == "" {
continue
}
normalized = append(normalized, trimmed)
}
return normalized
}
func isWeakJWTSecret(secret string) bool {
lower := strings.ToLower(strings.TrimSpace(secret))
if lower == "" {
return true
}
weak := map[string]struct{}{
"change-me-in-production": {},
"changeme": {},
"secret": {},
"password": {},
"123456": {},
"12345678": {},
"admin": {},
"jwt-secret": {},
}
_, exists := weak[lower]
return exists
}
func generateJWTSecret(byteLength int) (string, error) {
if byteLength <= 0 {
byteLength = 32
}
buf := make([]byte, byteLength)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return hex.EncodeToString(buf), nil
}
// GetServerAddress returns the server address (host:port) from config file or environment variable.
// This is a lightweight function that can be used before full config validation,
// such as during setup wizard startup.

View File

@@ -68,3 +68,22 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) {
t.Fatalf("StickySessionMaxWaiting = %d, want 5", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
}
}
func TestLoadDefaultSecurityToggles(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Security.URLAllowlist.Enabled {
t.Fatalf("URLAllowlist.Enabled = true, want false")
}
if cfg.Security.URLAllowlist.AllowInsecureHTTP {
t.Fatalf("URLAllowlist.AllowInsecureHTTP = true, want false")
}
if cfg.Security.ResponseHeaders.Enabled {
t.Fatalf("ResponseHeaders.Enabled = true, want false")
}
}

View File

@@ -76,6 +76,7 @@ func NewAccountHandler(
// CreateAccountRequest represents create account request
type CreateAccountRequest struct {
Name string `json:"name" binding:"required"`
Notes *string `json:"notes"`
Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
Credentials map[string]any `json:"credentials" binding:"required"`
@@ -91,6 +92,7 @@ type CreateAccountRequest struct {
// 使用指针类型来区分"未提供"和"设置为0"
type UpdateAccountRequest struct {
Name string `json:"name"`
Notes *string `json:"notes"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
@@ -193,6 +195,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: req.Name,
Notes: req.Notes,
Platform: req.Platform,
Type: req.Type,
Credentials: req.Credentials,
@@ -249,6 +252,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
account, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Name: req.Name,
Notes: req.Notes,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
@@ -357,7 +361,8 @@ func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
SyncProxies: syncProxies,
})
if err != nil {
response.ErrorFrom(c, err)
// Provide detailed error message for CRS sync failures
response.InternalError(c, "CRS sync failed: "+err.Error())
return
}

View File

@@ -1,8 +1,12 @@
package admin
import (
"log"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -34,31 +38,33 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
}
response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername,
SMTPPassword: settings.SMTPPassword,
SMTPFrom: settings.SMTPFrom,
SMTPFromName: settings.SMTPFromName,
SMTPUseTLS: settings.SMTPUseTLS,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
TurnstileSecretKey: settings.TurnstileSecretKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
EnableModelFallback: settings.EnableModelFallback,
FallbackModelAnthropic: settings.FallbackModelAnthropic,
FallbackModelOpenAI: settings.FallbackModelOpenAI,
FallbackModelGemini: settings.FallbackModelGemini,
FallbackModelAntigravity: settings.FallbackModelAntigravity,
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername,
SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
SMTPFrom: settings.SMTPFrom,
SMTPFromName: settings.SMTPFromName,
SMTPUseTLS: settings.SMTPUseTLS,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
EnableModelFallback: settings.EnableModelFallback,
FallbackModelAnthropic: settings.FallbackModelAnthropic,
FallbackModelOpenAI: settings.FallbackModelOpenAI,
FallbackModelGemini: settings.FallbackModelGemini,
FallbackModelAntigravity: settings.FallbackModelAntigravity,
EnableIdentityPatch: settings.EnableIdentityPatch,
IdentityPatchPrompt: settings.IdentityPatchPrompt,
})
}
@@ -100,6 +106,10 @@ type UpdateSettingsRequest struct {
FallbackModelOpenAI string `json:"fallback_model_openai"`
FallbackModelGemini string `json:"fallback_model_gemini"`
FallbackModelAntigravity string `json:"fallback_model_antigravity"`
// Identity patch configuration (Claude -> Gemini)
EnableIdentityPatch bool `json:"enable_identity_patch"`
IdentityPatchPrompt string `json:"identity_patch_prompt"`
}
// UpdateSettings 更新系统设置
@@ -111,6 +121,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return
}
previousSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
// 验证参数
if req.DefaultConcurrency < 1 {
req.DefaultConcurrency = 1
@@ -129,21 +145,18 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.BadRequest(c, "Turnstile Site Key is required when enabled")
return
}
// 如果未提供 secret key使用已保存的值留空保留当前值
if req.TurnstileSecretKey == "" {
response.BadRequest(c, "Turnstile Secret Key is required when enabled")
return
}
// 获取当前设置,检查参数是否有变化
currentSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
if previousSettings.TurnstileSecretKey == "" {
response.BadRequest(c, "Turnstile Secret Key is required when enabled")
return
}
req.TurnstileSecretKey = previousSettings.TurnstileSecretKey
}
// 当 site_key 或 secret_key 任一变化时验证(避免配置错误导致无法登录)
siteKeyChanged := currentSettings.TurnstileSiteKey != req.TurnstileSiteKey
secretKeyChanged := currentSettings.TurnstileSecretKey != req.TurnstileSecretKey
siteKeyChanged := previousSettings.TurnstileSiteKey != req.TurnstileSiteKey
secretKeyChanged := previousSettings.TurnstileSecretKey != req.TurnstileSecretKey
if siteKeyChanged || secretKeyChanged {
if err := h.turnstileService.ValidateSecretKey(c.Request.Context(), req.TurnstileSecretKey); err != nil {
response.ErrorFrom(c, err)
@@ -178,6 +191,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
}
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
@@ -185,6 +200,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return
}
h.auditSettingsUpdate(c, previousSettings, settings, req)
// 重新获取设置返回
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
@@ -193,34 +210,136 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername,
SMTPPassword: updatedSettings.SMTPPassword,
SMTPFrom: updatedSettings.SMTPFrom,
SMTPFromName: updatedSettings.SMTPFromName,
SMTPUseTLS: updatedSettings.SMTPUseTLS,
TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
TurnstileSecretKey: updatedSettings.TurnstileSecretKey,
SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle,
APIBaseURL: updatedSettings.APIBaseURL,
ContactInfo: updatedSettings.ContactInfo,
DocURL: updatedSettings.DocURL,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
EnableModelFallback: updatedSettings.EnableModelFallback,
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
FallbackModelGemini: updatedSettings.FallbackModelGemini,
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername,
SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
SMTPFrom: updatedSettings.SMTPFrom,
SMTPFromName: updatedSettings.SMTPFromName,
SMTPUseTLS: updatedSettings.SMTPUseTLS,
TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle,
APIBaseURL: updatedSettings.APIBaseURL,
ContactInfo: updatedSettings.ContactInfo,
DocURL: updatedSettings.DocURL,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
EnableModelFallback: updatedSettings.EnableModelFallback,
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
FallbackModelGemini: updatedSettings.FallbackModelGemini,
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
})
}
func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) {
if before == nil || after == nil {
return
}
changed := diffSettings(before, after, req)
if len(changed) == 0 {
return
}
subject, _ := middleware.GetAuthSubjectFromContext(c)
role, _ := middleware.GetUserRoleFromContext(c)
log.Printf("AUDIT: settings updated at=%s user_id=%d role=%s changed=%v",
time.Now().UTC().Format(time.RFC3339),
subject.UserID,
role,
changed,
)
}
func diffSettings(before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) []string {
changed := make([]string, 0, 20)
if before.RegistrationEnabled != after.RegistrationEnabled {
changed = append(changed, "registration_enabled")
}
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
changed = append(changed, "email_verify_enabled")
}
if before.SMTPHost != after.SMTPHost {
changed = append(changed, "smtp_host")
}
if before.SMTPPort != after.SMTPPort {
changed = append(changed, "smtp_port")
}
if before.SMTPUsername != after.SMTPUsername {
changed = append(changed, "smtp_username")
}
if req.SMTPPassword != "" {
changed = append(changed, "smtp_password")
}
if before.SMTPFrom != after.SMTPFrom {
changed = append(changed, "smtp_from_email")
}
if before.SMTPFromName != after.SMTPFromName {
changed = append(changed, "smtp_from_name")
}
if before.SMTPUseTLS != after.SMTPUseTLS {
changed = append(changed, "smtp_use_tls")
}
if before.TurnstileEnabled != after.TurnstileEnabled {
changed = append(changed, "turnstile_enabled")
}
if before.TurnstileSiteKey != after.TurnstileSiteKey {
changed = append(changed, "turnstile_site_key")
}
if req.TurnstileSecretKey != "" {
changed = append(changed, "turnstile_secret_key")
}
if before.SiteName != after.SiteName {
changed = append(changed, "site_name")
}
if before.SiteLogo != after.SiteLogo {
changed = append(changed, "site_logo")
}
if before.SiteSubtitle != after.SiteSubtitle {
changed = append(changed, "site_subtitle")
}
if before.APIBaseURL != after.APIBaseURL {
changed = append(changed, "api_base_url")
}
if before.ContactInfo != after.ContactInfo {
changed = append(changed, "contact_info")
}
if before.DocURL != after.DocURL {
changed = append(changed, "doc_url")
}
if before.DefaultConcurrency != after.DefaultConcurrency {
changed = append(changed, "default_concurrency")
}
if before.DefaultBalance != after.DefaultBalance {
changed = append(changed, "default_balance")
}
if before.EnableModelFallback != after.EnableModelFallback {
changed = append(changed, "enable_model_fallback")
}
if before.FallbackModelAnthropic != after.FallbackModelAnthropic {
changed = append(changed, "fallback_model_anthropic")
}
if before.FallbackModelOpenAI != after.FallbackModelOpenAI {
changed = append(changed, "fallback_model_openai")
}
if before.FallbackModelGemini != after.FallbackModelGemini {
changed = append(changed, "fallback_model_gemini")
}
if before.FallbackModelAntigravity != after.FallbackModelAntigravity {
changed = append(changed, "fallback_model_antigravity")
}
return changed
}
// TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host" binding:"required"`

View File

@@ -109,6 +109,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
return &Account{
ID: a.ID,
Name: a.Name,
Notes: a.Notes,
Platform: a.Platform,
Type: a.Type,
Credentials: a.Credentials,

View File

@@ -5,17 +5,17 @@ type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password,omitempty"`
SMTPFrom string `json:"smtp_from_email"`
SMTPFromName string `json:"smtp_from_name"`
SMTPUseTLS bool `json:"smtp_use_tls"`
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"`
SMTPPasswordConfigured bool `json:"smtp_password_configured"`
SMTPFrom string `json:"smtp_from_email"`
SMTPFromName string `json:"smtp_from_name"`
SMTPUseTLS bool `json:"smtp_use_tls"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
@@ -33,6 +33,10 @@ type SystemSettings struct {
FallbackModelOpenAI string `json:"fallback_model_openai"`
FallbackModelGemini string `json:"fallback_model_gemini"`
FallbackModelAntigravity string `json:"fallback_model_antigravity"`
// Identity patch configuration (Claude -> Gemini)
EnableIdentityPatch bool `json:"enable_identity_patch"`
IdentityPatchPrompt string `json:"identity_patch_prompt"`
}
type PublicSettings struct {

View File

@@ -62,6 +62,7 @@ type Group struct {
type Account struct {
ID int64 `json:"id"`
Name string `json:"name"`
Notes *string `json:"notes"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`

View File

@@ -11,8 +11,10 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -38,14 +40,19 @@ func NewGatewayHandler(
userService *service.UserService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
cfg *config.Config,
) *GatewayHandler {
pingInterval := time.Duration(0)
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
}
return &GatewayHandler{
gatewayService: gatewayService,
geminiCompatService: geminiCompatService,
antigravityGatewayService: antigravityGatewayService,
userService: userService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
}
}
@@ -121,6 +128,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
// 在请求结束或 Context 取消时确保释放槽位,避免客户端断开造成泄漏
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
@@ -128,7 +137,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
@@ -220,6 +230,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
@@ -344,6 +357,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
@@ -674,7 +690,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 校验 billing eligibility订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error())
status, code, message := billingErrorDetails(err)
h.errorResponse(c, status, code, message)
return
}
@@ -800,3 +817,18 @@ func sendMockWarmupResponse(c *gin.Context, model string) {
},
})
}
func billingErrorDetails(err error) (status int, code, message string) {
if errors.Is(err, service.ErrBillingServiceUnavailable) {
msg := pkgerrors.Message(err)
if msg == "" {
msg = "Billing service temporarily unavailable. Please retry later."
}
return http.StatusServiceUnavailable, "billing_service_error", msg
}
msg := pkgerrors.Message(err)
if msg == "" {
msg = err.Error()
}
return http.StatusForbidden, "billing_error", msg
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"math/rand"
"net/http"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -26,8 +27,8 @@ import (
const (
// maxConcurrencyWait 等待并发槽位的最大时间
maxConcurrencyWait = 30 * time.Second
// pingInterval 流式响应等待时发送 ping 的间隔
pingInterval = 15 * time.Second
// defaultPingInterval 流式响应等待时发送 ping 的默认间隔
defaultPingInterval = 10 * time.Second
// initialBackoff 初始退避时间
initialBackoff = 100 * time.Millisecond
// backoffMultiplier 退避时间乘数(指数退避)
@@ -44,6 +45,8 @@ const (
SSEPingFormatClaude SSEPingFormat = "data: {\"type\": \"ping\"}\n\n"
// SSEPingFormatNone indicates no ping should be sent (e.g., OpenAI has no ping spec)
SSEPingFormatNone SSEPingFormat = ""
// SSEPingFormatComment is an SSE comment ping for OpenAI/Codex CLI clients
SSEPingFormatComment SSEPingFormat = ":\n\n"
)
// ConcurrencyError represents a concurrency limit error with context
@@ -63,16 +66,38 @@ func (e *ConcurrencyError) Error() string {
type ConcurrencyHelper struct {
concurrencyService *service.ConcurrencyService
pingFormat SSEPingFormat
pingInterval time.Duration
}
// NewConcurrencyHelper creates a new ConcurrencyHelper
func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat) *ConcurrencyHelper {
func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat, pingInterval time.Duration) *ConcurrencyHelper {
if pingInterval <= 0 {
pingInterval = defaultPingInterval
}
return &ConcurrencyHelper{
concurrencyService: concurrencyService,
pingFormat: pingFormat,
pingInterval: pingInterval,
}
}
// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation.
// 用于避免客户端断开或上游超时导致的并发槽位泄漏。
func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() {
if releaseFunc == nil {
return nil
}
var once sync.Once
wrapped := func() {
once.Do(releaseFunc)
}
go func() {
<-ctx.Done()
wrapped()
}()
return wrapped
}
// IncrementWaitCount increments the wait count for a user
func (h *ConcurrencyHelper) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
return h.concurrencyService.IncrementWaitCount(ctx, userID, maxWait)
@@ -174,7 +199,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
// Only create ping ticker if ping is needed
var pingCh <-chan time.Time
if needPing {
pingTicker := time.NewTicker(pingInterval)
pingTicker := time.NewTicker(h.pingInterval)
defer pingTicker.Stop()
pingCh = pingTicker.C
}

View File

@@ -165,7 +165,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
subscription, _ := middleware.GetSubscriptionFromContext(c)
// For Gemini native API, do not send Claude-style ping frames.
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone)
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0)
// 0) wait queue check
maxWait := service.CalculateMaxWait(authSubject.Concurrency)
@@ -185,13 +185,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
googleError(c, http.StatusForbidden, err.Error())
status, _, message := billingErrorDetails(err)
googleError(c, status, message)
return
}
@@ -260,6 +263,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// 5) forward (根据平台分流)
var result *service.ForwardResult

View File

@@ -10,6 +10,7 @@ import (
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -29,11 +30,16 @@ func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
cfg *config.Config,
) *OpenAIGatewayHandler {
pingInterval := time.Duration(0)
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
}
return &OpenAIGatewayHandler{
gatewayService: gatewayService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatNone),
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
}
}
@@ -124,6 +130,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
@@ -131,7 +139,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
@@ -201,6 +210,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// Forward request
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)

View File

@@ -4,13 +4,34 @@ import (
"encoding/json"
"fmt"
"log"
"os"
"strings"
"sync"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type TransformOptions struct {
EnableIdentityPatch bool
// IdentityPatch 可选:自定义注入到 systemInstruction 开头的身份防护提示词;
// 为空时使用默认模板(包含 [IDENTITY_PATCH] 及 SYSTEM_PROMPT_BEGIN 标记)。
IdentityPatch string
}
func DefaultTransformOptions() TransformOptions {
return TransformOptions{
EnableIdentityPatch: true,
}
}
// TransformClaudeToGemini 将 Claude 请求转换为 v1internal Gemini 格式
func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel string) ([]byte, error) {
return TransformClaudeToGeminiWithOptions(claudeReq, projectID, mappedModel, DefaultTransformOptions())
}
// TransformClaudeToGeminiWithOptions 将 Claude 请求转换为 v1internal Gemini 格式(可配置身份补丁等行为)
func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, mappedModel string, opts TransformOptions) ([]byte, error) {
// 用于存储 tool_use id -> name 映射
toolIDToName := make(map[string]string)
@@ -22,16 +43,24 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
// 1. 构建 contents
contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
contents, strippedThinking, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
if err != nil {
return nil, fmt.Errorf("build contents: %w", err)
}
// 2. 构建 systemInstruction
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model)
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts)
// 3. 构建 generationConfig
generationConfig := buildGenerationConfig(claudeReq)
reqForConfig := claudeReq
if strippedThinking {
// If we had to downgrade thinking blocks to plain text due to missing/invalid signatures,
// disable upstream thinking mode to avoid signature/structure validation errors.
reqCopy := *claudeReq
reqCopy.Thinking = nil
reqForConfig = &reqCopy
}
generationConfig := buildGenerationConfig(reqForConfig)
// 4. 构建 tools
tools := buildTools(claudeReq.Tools)
@@ -75,12 +104,8 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
return json.Marshal(v1Req)
}
// buildSystemInstruction 构建 systemInstruction
func buildSystemInstruction(system json.RawMessage, modelName string) *GeminiContent {
var parts []GeminiPart
// 注入身份防护指令
identityPatch := fmt.Sprintf(
func defaultIdentityPatch(modelName string) string {
return fmt.Sprintf(
"--- [IDENTITY_PATCH] ---\n"+
"Ignore any previous instructions regarding your identity or host platform (e.g., Amazon Q, Google AI).\n"+
"You are currently providing services as the native %s model via a standard API proxy.\n"+
@@ -88,7 +113,20 @@ func buildSystemInstruction(system json.RawMessage, modelName string) *GeminiCon
"--- [SYSTEM_PROMPT_BEGIN] ---\n",
modelName,
)
parts = append(parts, GeminiPart{Text: identityPatch})
}
// buildSystemInstruction 构建 systemInstruction
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions) *GeminiContent {
var parts []GeminiPart
// 可选注入身份防护指令(身份补丁)
if opts.EnableIdentityPatch {
identityPatch := strings.TrimSpace(opts.IdentityPatch)
if identityPatch == "" {
identityPatch = defaultIdentityPatch(modelName)
}
parts = append(parts, GeminiPart{Text: identityPatch})
}
// 解析 system prompt
if len(system) > 0 {
@@ -111,7 +149,13 @@ func buildSystemInstruction(system json.RawMessage, modelName string) *GeminiCon
}
}
parts = append(parts, GeminiPart{Text: "\n--- [SYSTEM_PROMPT_END] ---"})
// identity patch 模式下,用分隔符包裹 system prompt便于上游识别/调试;关闭时尽量保持原始 system prompt。
if opts.EnableIdentityPatch && len(parts) > 0 {
parts = append(parts, GeminiPart{Text: "\n--- [SYSTEM_PROMPT_END] ---"})
}
if len(parts) == 0 {
return nil
}
return &GeminiContent{
Role: "user",
@@ -120,8 +164,9 @@ func buildSystemInstruction(system json.RawMessage, modelName string) *GeminiCon
}
// buildContents 构建 contents
func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isThinkingEnabled, allowDummyThought bool) ([]GeminiContent, error) {
func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isThinkingEnabled, allowDummyThought bool) ([]GeminiContent, bool, error) {
var contents []GeminiContent
strippedThinking := false
for i, msg := range messages {
role := msg.Role
@@ -129,9 +174,12 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
role = "model"
}
parts, err := buildParts(msg.Content, toolIDToName, allowDummyThought)
parts, strippedThisMsg, err := buildParts(msg.Content, toolIDToName, allowDummyThought)
if err != nil {
return nil, fmt.Errorf("build parts for message %d: %w", i, err)
return nil, false, fmt.Errorf("build parts for message %d: %w", i, err)
}
if strippedThisMsg {
strippedThinking = true
}
// 只有 Gemini 模型支持 dummy thinking block workaround
@@ -165,7 +213,7 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
})
}
return contents, nil
return contents, strippedThinking, nil
}
// dummyThoughtSignature 用于跳过 Gemini 3 thought_signature 验证
@@ -174,8 +222,9 @@ const dummyThoughtSignature = "skip_thought_signature_validator"
// buildParts 构建消息的 parts
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) {
func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, bool, error) {
var parts []GeminiPart
strippedThinking := false
// 尝试解析为字符串
var textContent string
@@ -183,13 +232,13 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
if textContent != "(no content)" && strings.TrimSpace(textContent) != "" {
parts = append(parts, GeminiPart{Text: strings.TrimSpace(textContent)})
}
return parts, nil
return parts, false, nil
}
// 解析为内容块数组
var blocks []ContentBlock
if err := json.Unmarshal(content, &blocks); err != nil {
return nil, fmt.Errorf("parse content blocks: %w", err)
return nil, false, fmt.Errorf("parse content blocks: %w", err)
}
for _, block := range blocks {
@@ -208,8 +257,11 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
if block.Signature != "" {
part.ThoughtSignature = block.Signature
} else if !allowDummyThought {
// Claude 模型需要有效 signature,跳过无 signature 的 thinking block
log.Printf("Warning: skipping thinking block without signature for Claude model")
// Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。
if strings.TrimSpace(block.Thinking) != "" {
parts = append(parts, GeminiPart{Text: block.Thinking})
}
strippedThinking = true
continue
} else {
// Gemini 模型使用 dummy signature
@@ -276,7 +328,7 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
}
}
return parts, nil
return parts, strippedThinking, nil
}
// parseToolResultContent 解析 tool_result 的 content
@@ -446,7 +498,7 @@ func cleanJSONSchema(schema map[string]any) map[string]any {
if schema == nil {
return nil
}
cleaned := cleanSchemaValue(schema)
cleaned := cleanSchemaValue(schema, "$")
result, ok := cleaned.(map[string]any)
if !ok {
return nil
@@ -484,6 +536,56 @@ func cleanJSONSchema(schema map[string]any) map[string]any {
return result
}
var schemaValidationKeys = map[string]bool{
"minLength": true,
"maxLength": true,
"pattern": true,
"minimum": true,
"maximum": true,
"exclusiveMinimum": true,
"exclusiveMaximum": true,
"multipleOf": true,
"uniqueItems": true,
"minItems": true,
"maxItems": true,
"minProperties": true,
"maxProperties": true,
"patternProperties": true,
"propertyNames": true,
"dependencies": true,
"dependentSchemas": true,
"dependentRequired": true,
}
var warnedSchemaKeys sync.Map
func schemaCleaningWarningsEnabled() bool {
// 可通过环境变量强制开关方便排查SUB2API_SCHEMA_CLEAN_WARN=true/false
if v := strings.TrimSpace(os.Getenv("SUB2API_SCHEMA_CLEAN_WARN")); v != "" {
switch strings.ToLower(v) {
case "1", "true", "yes", "on":
return true
case "0", "false", "no", "off":
return false
}
}
// 默认:非 release 模式下输出debug/test
return gin.Mode() != gin.ReleaseMode
}
func warnSchemaKeyRemovedOnce(key, path string) {
if !schemaCleaningWarningsEnabled() {
return
}
if !schemaValidationKeys[key] {
return
}
if _, loaded := warnedSchemaKeys.LoadOrStore(key, struct{}{}); loaded {
return
}
log.Printf("[SchemaClean] removed unsupported JSON Schema validation field key=%q path=%q", key, path)
}
// excludedSchemaKeys 不支持的 schema 字段
// 基于 Claude API (Vertex AI) 的实际支持情况
// 支持: type, description, enum, properties, required, additionalProperties, items
@@ -546,13 +648,14 @@ var excludedSchemaKeys = map[string]bool{
}
// cleanSchemaValue 递归清理 schema 值
func cleanSchemaValue(value any) any {
func cleanSchemaValue(value any, path string) any {
switch v := value.(type) {
case map[string]any:
result := make(map[string]any)
for k, val := range v {
// 跳过不支持的字段
if excludedSchemaKeys[k] {
warnSchemaKeyRemovedOnce(k, path)
continue
}
@@ -586,15 +689,15 @@ func cleanSchemaValue(value any) any {
}
// 递归清理所有值
result[k] = cleanSchemaValue(val)
result[k] = cleanSchemaValue(val, path+"."+k)
}
return result
case []any:
// 递归处理数组中的每个元素
cleaned := make([]any, 0, len(v))
for _, item := range v {
cleaned = append(cleaned, cleanSchemaValue(item))
for i, item := range v {
cleaned = append(cleaned, cleanSchemaValue(item, fmt.Sprintf("%s[%d]", path, i)))
}
return cleaned

View File

@@ -15,15 +15,15 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
description string
}{
{
name: "Claude model - drop thinking without signature",
name: "Claude model - downgrade thinking to text without signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 2, // thinking 内容被丢弃
description: "Claude模型应丢弃无signaturethinking block内容",
expectedParts: 3, // thinking 内容降级为普通 text part
description: "Claude模型缺少signature时应将thinking降级为text并在上层禁用thinking mode",
},
{
name: "Claude model - preserve thinking block with signature",
@@ -52,7 +52,7 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought)
parts, _, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
@@ -71,6 +71,17 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
t.Fatalf("expected thought part with signature sig_real_123, got thought=%v signature=%q",
parts[1].Thought, parts[1].ThoughtSignature)
}
case "Claude model - downgrade thinking to text without signature":
if len(parts) != 3 {
t.Fatalf("expected 3 parts, got %d", len(parts))
}
if parts[1].Thought {
t.Fatalf("expected downgraded text part, got thought=%v signature=%q",
parts[1].Thought, parts[1].ThoughtSignature)
}
if parts[1].Text != "Let me think..." {
t.Fatalf("expected downgraded text %q, got %q", "Let me think...", parts[1].Text)
}
case "Gemini model - use dummy signature":
if len(parts) != 3 {
t.Fatalf("expected 3 parts, got %d", len(parts))
@@ -91,7 +102,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, err := buildParts(json.RawMessage(content), toolIDToName, true)
parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
@@ -105,7 +116,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
t.Run("Claude model - preserve valid signature for tool_use", func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, err := buildParts(json.RawMessage(content), toolIDToName, false)
parts, _, err := buildParts(json.RawMessage(content), toolIDToName, false)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}

View File

@@ -25,13 +25,14 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
// Transport 连接池默认配置
const (
defaultMaxIdleConns = 100 // 最大空闲连接数
defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间(建议小于上游 LB 超时)
)
// Options 定义共享 HTTP 客户端的构建参数
@@ -40,6 +41,9 @@ type Options struct {
Timeout time.Duration // 请求总超时时间
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
InsecureSkipVerify bool // 是否跳过 TLS 证书验证
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
ValidateResolvedIP bool // 是否校验解析后的 IP防止 DNS Rebinding
AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用)
// 可选的连接池参数(不设置则使用默认值)
MaxIdleConns int // 最大空闲连接总数(默认 100
@@ -79,8 +83,12 @@ func buildClient(opts Options) (*http.Client, error) {
return nil, err
}
var rt http.RoundTripper = transport
if opts.ValidateResolvedIP && !opts.AllowPrivateHosts {
rt = &validatedTransport{base: transport}
}
return &http.Client{
Transport: transport,
Transport: rt,
Timeout: opts.Timeout,
}, nil
}
@@ -126,13 +134,32 @@ func buildTransport(opts Options) (*http.Transport, error) {
}
func buildClientKey(opts Options) string {
return fmt.Sprintf("%s|%s|%s|%t|%d|%d|%d",
return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%t|%d|%d|%d",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.ResponseHeaderTimeout.String(),
opts.InsecureSkipVerify,
opts.ProxyStrict,
opts.ValidateResolvedIP,
opts.AllowPrivateHosts,
opts.MaxIdleConns,
opts.MaxIdleConnsPerHost,
opts.MaxConnsPerHost,
)
}
type validatedTransport struct {
base http.RoundTripper
}
func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if req != nil && req.URL != nil {
host := strings.TrimSpace(req.URL.Hostname())
if host != "" {
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
return nil, err
}
}
}
return t.base.RoundTrip(req)
}

View File

@@ -67,6 +67,7 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
builder := r.client.Account.Create().
SetName(account.Name).
SetNillableNotes(account.Notes).
SetPlatform(account.Platform).
SetType(account.Type).
SetCredentials(normalizeJSONMap(account.Credentials)).
@@ -270,6 +271,7 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
builder := r.client.Account.UpdateOneID(account.ID).
SetName(account.Name).
SetNillableNotes(account.Notes).
SetPlatform(account.Platform).
SetType(account.Type).
SetCredentials(normalizeJSONMap(account.Credentials)).
@@ -320,6 +322,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
} else {
builder.ClearSessionWindowStatus()
}
if account.Notes == nil {
builder.ClearNotes()
}
updated, err := builder.Save(ctx)
if err != nil {
@@ -768,9 +773,14 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
idx++
}
if updates.ProxyID != nil {
setClauses = append(setClauses, "proxy_id = $"+itoa(idx))
args = append(args, *updates.ProxyID)
idx++
// 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图)
if *updates.ProxyID == 0 {
setClauses = append(setClauses, "proxy_id = NULL")
} else {
setClauses = append(setClauses, "proxy_id = $"+itoa(idx))
args = append(args, *updates.ProxyID)
idx++
}
}
if updates.Concurrency != nil {
setClauses = append(setClauses, "concurrency = $"+itoa(idx))
@@ -1065,6 +1075,7 @@ func accountEntityToService(m *dbent.Account) *service.Account {
return &service.Account{
ID: m.ID,
Name: m.Name,
Notes: m.Notes,
Platform: m.Platform,
Type: m.Type,
Credentials: copyJSONMap(m.Credentials),

View File

@@ -12,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
"github.com/imroc/req/v3"
)
@@ -54,7 +55,7 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
log.Printf("[OAuth] Step 1 Response - Status: %d", resp.StatusCode)
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
@@ -84,8 +85,8 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
"code_challenge_method": "S256",
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
var result struct {
@@ -113,7 +114,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
@@ -141,7 +142,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
fullCode = authCode + "#" + responseState
}
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", prefix(authCode, 20))
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code")
return fullCode, nil
}
@@ -173,8 +174,8 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
var tokenResp oauth.TokenResponse
@@ -191,7 +192,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
return nil, fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
@@ -245,13 +246,3 @@ func createReqClient(proxyURL string) *req.Client {
return client
}
func prefix(s string, n int) string {
if n <= 0 {
return ""
}
if len(s) <= n {
return s
}
return s[:n]
}

View File

@@ -15,7 +15,8 @@ import (
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
type claudeUsageService struct {
usageURL string
usageURL string
allowPrivateHosts bool
}
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
@@ -24,8 +25,10 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 30 * time.Second,
ProxyURL: proxyURL,
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: s.allowPrivateHosts,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}

View File

@@ -45,7 +45,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
}`)
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
require.NoError(s.T(), err, "FetchUsage")
@@ -64,7 +67,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
_, _ = io.WriteString(w, "nope")
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err)
@@ -78,7 +84,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
_, _ = io.WriteString(w, "not-json")
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err)
@@ -91,7 +100,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
<-r.Context().Done()
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately

View File

@@ -56,7 +56,7 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
// 确保数据库 schema 已准备就绪。
// SQL 迁移文件是 schema 的权威来源source of truth
// 这种方式比 Ent 的自动迁移更可控,支持复杂的迁移场景。
migrationCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
migrationCtx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
if err := applyMigrationsFS(migrationCtx, drv.DB(), migrations.FS); err != nil {
_ = drv.Close() // 迁移失败时关闭驱动,避免资源泄露

View File

@@ -14,18 +14,23 @@ import (
)
type githubReleaseClient struct {
httpClient *http.Client
httpClient *http.Client
allowPrivateHosts bool
}
func NewGitHubReleaseClient() service.GitHubReleaseClient {
allowPrivate := false
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: allowPrivate,
})
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}
}
return &githubReleaseClient{
httpClient: sharedClient,
httpClient: sharedClient,
allowPrivateHosts: allowPrivate,
}
}
@@ -64,7 +69,9 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
}
downloadClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Minute,
Timeout: 10 * time.Minute,
ValidateResolvedIP: true,
AllowPrivateHosts: c.allowPrivateHosts,
})
if err != nil {
downloadClient = &http.Client{Timeout: 10 * time.Minute}

View File

@@ -37,6 +37,13 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return http.DefaultTransport.RoundTrip(newReq)
}
func newTestGitHubReleaseClient() *githubReleaseClient {
return &githubReleaseClient{
httpClient: &http.Client{},
allowPrivateHosts: true,
}
}
func (s *GitHubReleaseServiceSuite) SetupTest() {
s.tempDir = s.T().TempDir()
}
@@ -55,9 +62,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLeng
_, _ = w.Write(bytes.Repeat([]byte("a"), 100))
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "file1.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
@@ -82,9 +87,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
}
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "file2.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
@@ -108,9 +111,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
}
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "file3.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 200)
@@ -127,9 +128,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
w.WriteHeader(http.StatusNotFound)
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "notfound.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
@@ -145,9 +144,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
_, _ = w.Write([]byte("sum"))
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
body, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
require.NoError(s.T(), err, "FetchChecksumFile")
@@ -159,9 +156,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
w.WriteHeader(http.StatusInternalServerError)
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
_, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
require.Error(s.T(), err, "expected error for non-200")
@@ -172,9 +167,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
<-r.Context().Done()
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
ctx, cancel := context.WithCancel(context.Background())
cancel()
@@ -185,9 +178,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "invalid.bin")
err := s.client.DownloadFile(context.Background(), "://invalid-url", dest, 100)
@@ -200,9 +191,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
_, _ = w.Write([]byte("content"))
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
// Use a path that cannot be created (directory doesn't exist)
dest := filepath.Join(s.tempDir, "nonexistent", "subdir", "file.bin")
@@ -211,9 +200,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_InvalidURL() {
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
_, err := s.client.FetchChecksumFile(context.Background(), "://invalid-url")
require.Error(s.T(), err, "expected error for invalid URL")
@@ -247,6 +234,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
}
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
@@ -266,6 +254,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
@@ -283,6 +272,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
@@ -298,6 +288,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
}
ctx, cancel := context.WithCancel(context.Background())
@@ -312,9 +303,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
<-r.Context().Done()
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
ctx, cancel := context.WithCancel(context.Background())
cancel()

View File

@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
// 默认配置常量
@@ -30,9 +31,9 @@ const (
// defaultMaxConnsPerHost: 默认每主机最大连接数(含活跃连接)
// 达到上限后新请求会等待,而非无限创建连接
defaultMaxConnsPerHost = 240
// defaultIdleConnTimeout: 默认空闲连接超时时间(5分钟
// 超时后连接会被关闭,释放系统资源
defaultIdleConnTimeout = 300 * time.Second
// defaultIdleConnTimeout: 默认空闲连接超时时间(90秒
// 超时后连接会被关闭,释放系统资源(建议小于上游 LB 超时)
defaultIdleConnTimeout = 90 * time.Second
// defaultResponseHeaderTimeout: 默认等待响应头超时时间5分钟
// LLM 请求可能排队较久,需要较长超时
defaultResponseHeaderTimeout = 300 * time.Second
@@ -120,6 +121,10 @@ func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
// - 调用方必须关闭 resp.Body否则会导致 inFlight 计数泄漏
// - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
if err := s.validateRequestHost(req); err != nil {
return nil, err
}
// 获取或创建对应的客户端,并标记请求占用
entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency)
if err != nil {
@@ -145,6 +150,40 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
return resp, nil
}
func (s *httpUpstreamService) shouldValidateResolvedIP() bool {
if s.cfg == nil {
return false
}
if !s.cfg.Security.URLAllowlist.Enabled {
return false
}
return !s.cfg.Security.URLAllowlist.AllowPrivateHosts
}
func (s *httpUpstreamService) validateRequestHost(req *http.Request) error {
if !s.shouldValidateResolvedIP() {
return nil
}
if req == nil || req.URL == nil {
return errors.New("request url is nil")
}
host := strings.TrimSpace(req.URL.Hostname())
if host == "" {
return errors.New("request host is empty")
}
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
return err
}
return nil
}
func (s *httpUpstreamService) redirectChecker(req *http.Request, via []*http.Request) error {
if len(via) >= 10 {
return errors.New("stopped after 10 redirects")
}
return s.validateRequestHost(req)
}
// acquireClient 获取或创建客户端,并标记为进行中请求
// 用于请求路径,避免在获取后被淘汰
func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
@@ -232,6 +271,9 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
return nil, fmt.Errorf("build transport: %w", err)
}
client := &http.Client{Transport: transport}
if s.shouldValidateResolvedIP() {
client.CheckRedirect = s.redirectChecker
}
entry := &upstreamClientEntry{
client: client,
proxyKey: proxyKey,

View File

@@ -22,7 +22,13 @@ type HTTPUpstreamSuite struct {
// SetupTest 每个测试用例执行前的初始化
// 创建空配置,各测试用例可按需覆盖
func (s *HTTPUpstreamSuite) SetupTest() {
s.cfg = &config.Config{}
s.cfg = &config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{
AllowPrivateHosts: true,
},
},
}
}
// newService 创建测试用的 httpUpstreamService 实例

View File

@@ -26,6 +26,7 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn(t, tx, "users", "notes", "text", 0, false)
// accounts: schedulable and rate-limit fields
requireColumn(t, tx, "accounts", "notes", "text", 0, true)
requireColumn(t, tx, "accounts", "schedulable", "boolean", 0, false)
requireColumn(t, tx, "accounts", "rate_limited_at", "timestamp with time zone", 0, true)
requireColumn(t, tx, "accounts", "rate_limit_reset_at", "timestamp with time zone", 0, true)

View File

@@ -8,6 +8,7 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
@@ -16,9 +17,17 @@ type pricingRemoteClient struct {
httpClient *http.Client
}
func NewPricingRemoteClient() service.PricingRemoteClient {
func NewPricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
allowPrivate := false
validateResolvedIP := true
if cfg != nil {
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
}
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
Timeout: 30 * time.Second,
ValidateResolvedIP: validateResolvedIP,
AllowPrivateHosts: allowPrivate,
})
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}

View File

@@ -6,6 +6,7 @@ import (
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
@@ -19,7 +20,13 @@ type PricingServiceSuite struct {
func (s *PricingServiceSuite) SetupTest() {
s.ctx = context.Background()
client, ok := NewPricingRemoteClient().(*pricingRemoteClient)
client, ok := NewPricingRemoteClient(&config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{
AllowPrivateHosts: true,
},
},
}).(*pricingRemoteClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
}

View File

@@ -5,28 +5,52 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func NewProxyExitInfoProber() service.ProxyExitInfoProber {
return &proxyProbeService{ipInfoURL: defaultIPInfoURL}
func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
insecure := false
allowPrivate := false
validateResolvedIP := true
if cfg != nil {
insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
}
if insecure {
log.Printf("[ProxyProbe] Warning: TLS verification is disabled for proxy probing.")
}
return &proxyProbeService{
ipInfoURL: defaultIPInfoURL,
insecureSkipVerify: insecure,
allowPrivateHosts: allowPrivate,
validateResolvedIP: validateResolvedIP,
}
}
const defaultIPInfoURL = "https://ipinfo.io/json"
type proxyProbeService struct {
ipInfoURL string
ipInfoURL string
insecureSkipVerify bool
allowPrivateHosts bool
validateResolvedIP bool
}
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 15 * time.Second,
InsecureSkipVerify: true,
InsecureSkipVerify: s.insecureSkipVerify,
ProxyStrict: true,
ValidateResolvedIP: s.validateResolvedIP,
AllowPrivateHosts: s.allowPrivateHosts,
})
if err != nil {
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)

View File

@@ -20,7 +20,10 @@ type ProxyProbeServiceSuite struct {
func (s *ProxyProbeServiceSuite) SetupTest() {
s.ctx = context.Background()
s.prober = &proxyProbeService{ipInfoURL: "http://ipinfo.test/json"}
s.prober = &proxyProbeService{
ipInfoURL: "http://ipinfo.test/json",
allowPrivateHosts: true,
}
}
func (s *ProxyProbeServiceSuite) TearDownTest() {

View File

@@ -22,7 +22,8 @@ type turnstileVerifier struct {
func NewTurnstileVerifier() service.TurnstileVerifier {
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Second,
Timeout: 10 * time.Second,
ValidateResolvedIP: true,
})
if err != nil {
sharedClient = &http.Client{Timeout: 10 * time.Second}

View File

@@ -329,17 +329,20 @@ func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount flo
return nil
}
// DeductBalance 扣除用户余额
// 透支策略:允许余额变为负数,确保当前请求能够完成
// 中间件会阻止余额 <= 0 的用户发起后续请求
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
client := clientFromContext(ctx, r.client)
n, err := client.User.Update().
Where(dbuser.IDEQ(id), dbuser.BalanceGTE(amount)).
Where(dbuser.IDEQ(id)).
AddBalance(-amount).
Save(ctx)
if err != nil {
return err
}
if n == 0 {
return service.ErrInsufficientBalance
return service.ErrUserNotFound
}
return nil
}

View File

@@ -290,9 +290,14 @@ func (s *UserRepoSuite) TestDeductBalance() {
func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
user := s.mustCreateUser(&service.User{Email: "insuf@test.com", Balance: 5})
// 透支策略:允许扣除超过余额的金额
err := s.repo.DeductBalance(s.ctx, user.ID, 999)
s.Require().Error(err, "expected error for insufficient balance")
s.Require().ErrorIs(err, service.ErrInsufficientBalance)
s.Require().NoError(err, "DeductBalance should allow overdraft")
// 验证余额变为负数
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().InDelta(-994.0, got.Balance, 1e-6, "Balance should be negative after overdraft")
}
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
@@ -306,6 +311,19 @@ func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
s.Require().InDelta(0.0, got.Balance, 1e-6)
}
func (s *UserRepoSuite) TestDeductBalance_AllowsOverdraft() {
user := s.mustCreateUser(&service.User{Email: "overdraft@test.com", Balance: 5.0})
// 扣除超过余额的金额 - 应该成功
err := s.repo.DeductBalance(s.ctx, user.ID, 10.0)
s.Require().NoError(err, "DeductBalance should allow overdraft")
// 验证余额为负
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().InDelta(-5.0, got.Balance, 1e-6, "Balance should be -5.0 after overdraft")
}
// --- Concurrency ---
func (s *UserRepoSuite) TestUpdateConcurrency() {
@@ -477,9 +495,12 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
s.Require().NoError(err, "GetByID after DeductBalance")
s.Require().InDelta(7.5, got4.Balance, 1e-6)
// 透支策略:允许扣除超过余额的金额
err = s.repo.DeductBalance(s.ctx, user1.ID, 999)
s.Require().Error(err, "DeductBalance expected error for insufficient balance")
s.Require().ErrorIs(err, service.ErrInsufficientBalance, "DeductBalance unexpected error")
s.Require().NoError(err, "DeductBalance should allow overdraft")
gotOverdraft, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after overdraft")
s.Require().Less(gotOverdraft.Balance, 0.0, "Balance should be negative after overdraft")
s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency")
got5, err := s.repo.GetByID(s.ctx, user1.ID)
@@ -511,6 +532,6 @@ func (s *UserRepoSuite) TestUpdateConcurrency_NotFound() {
func (s *UserRepoSuite) TestDeductBalance_NotFound() {
err := s.repo.DeductBalance(s.ctx, 999999, 5)
s.Require().Error(err, "expected error for non-existent user")
// DeductBalance 在用户不存在时返回 ErrInsufficientBalance 因为 WHERE 条件不匹配
s.Require().ErrorIs(err, service.ErrInsufficientBalance)
// DeductBalance 在用户不存在时返回 ErrUserNotFound
s.Require().ErrorIs(err, service.ErrUserNotFound)
}

View File

@@ -296,13 +296,13 @@ func TestAPIContracts(t *testing.T) {
"smtp_host": "smtp.example.com",
"smtp_port": 587,
"smtp_username": "user",
"smtp_password": "secret",
"smtp_password_configured": true,
"smtp_from_email": "no-reply@example.com",
"smtp_from_name": "Sub2API",
"smtp_use_tls": true,
"turnstile_enabled": true,
"turnstile_site_key": "site-key",
"turnstile_secret_key": "secret-key",
"turnstile_secret_key_configured": true,
"site_name": "Sub2API",
"site_logo": "",
"site_subtitle": "Subtitle",
@@ -315,7 +315,9 @@ func TestAPIContracts(t *testing.T) {
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
"fallback_model_antigravity": "gemini-2.5-pro",
"fallback_model_gemini": "gemini-2.5-pro",
"fallback_model_openai": "gpt-4o"
"fallback_model_openai": "gpt-4o",
"enable_identity_patch": true,
"identity_patch_prompt": ""
}
}`,
},

View File

@@ -2,6 +2,7 @@
package server
import (
"log"
"net/http"
"time"
@@ -36,6 +37,15 @@ func ProvideRouter(
r := gin.New()
r.Use(middleware2.Recovery())
if len(cfg.Server.TrustedProxies) > 0 {
if err := r.SetTrustedProxies(cfg.Server.TrustedProxies); err != nil {
log.Printf("Failed to set trusted proxies: %v", err)
}
} else {
if err := r.SetTrustedProxies(nil); err != nil {
log.Printf("Failed to disable trusted proxies: %v", err)
}
}
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
}

View File

@@ -19,6 +19,13 @@ func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionS
// apiKeyAuthWithSubscription API Key认证中间件支持订阅验证
func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
queryKey := strings.TrimSpace(c.Query("key"))
queryApiKey := strings.TrimSpace(c.Query("api_key"))
if queryKey != "" || queryApiKey != "" {
AbortWithError(c, 400, "api_key_in_query_deprecated", "API key in query parameter is deprecated. Please use Authorization header instead.")
return
}
// 尝试从Authorization header中提取API key (Bearer scheme)
authHeader := c.GetHeader("Authorization")
var apiKeyString string
@@ -41,19 +48,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
apiKeyString = c.GetHeader("x-goog-api-key")
}
// 如果header中没有尝试从query参数中提取Google API key风格
if apiKeyString == "" {
apiKeyString = c.Query("key")
}
// 兼容常见别名
if apiKeyString == "" {
apiKeyString = c.Query("api_key")
}
// 如果所有header都没有API key
if apiKeyString == "" {
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter")
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, or x-goog-api-key header")
return
}

View File

@@ -22,6 +22,10 @@ func APIKeyAuthGoogle(apiKeyService *service.APIKeyService, cfg *config.Config)
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
if v := strings.TrimSpace(c.Query("api_key")); v != "" {
abortWithGoogleError(c, 400, "Query parameter api_key is deprecated. Use Authorization header or key instead.")
return
}
apiKeyString := extractAPIKeyFromRequest(c)
if apiKeyString == "" {
abortWithGoogleError(c, 401, "API key is required")
@@ -116,15 +120,18 @@ func extractAPIKeyFromRequest(c *gin.Context) string {
if v := strings.TrimSpace(c.GetHeader("x-goog-api-key")); v != "" {
return v
}
if v := strings.TrimSpace(c.Query("key")); v != "" {
return v
}
if v := strings.TrimSpace(c.Query("api_key")); v != "" {
return v
if allowGoogleQueryKey(c.Request.URL.Path) {
if v := strings.TrimSpace(c.Query("key")); v != "" {
return v
}
}
return ""
}
func allowGoogleQueryKey(path string) bool {
return strings.HasPrefix(path, "/v1beta") || strings.HasPrefix(path, "/antigravity/v1beta")
}
func abortWithGoogleError(c *gin.Context, status int, message string) {
c.JSON(status, gin.H{
"error": gin.H{

View File

@@ -109,6 +109,58 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, errors.New("should not be called")
},
})
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test?api_key=legacy", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
var resp googleErrorResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, http.StatusBadRequest, resp.Error.Code)
require.Equal(t, "Query parameter api_key is deprecated. Use Authorization header or key instead.", resp.Error.Message)
require.Equal(t, "INVALID_ARGUMENT", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return &service.APIKey{
ID: 1,
Key: key,
Status: service.StatusActive,
User: &service.User{
ID: 123,
Status: service.StatusActive,
},
}, nil
},
})
cfg := &config.Config{RunMode: config.RunModeSimple}
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test?key=valid", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@@ -1,24 +1,103 @@
package middleware
import (
"log"
"net/http"
"strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
var corsWarningOnce sync.Once
// CORS 跨域中间件
func CORS() gin.HandlerFunc {
func CORS(cfg config.CORSConfig) gin.HandlerFunc {
allowedOrigins := normalizeOrigins(cfg.AllowedOrigins)
allowAll := false
for _, origin := range allowedOrigins {
if origin == "*" {
allowAll = true
break
}
}
wildcardWithSpecific := allowAll && len(allowedOrigins) > 1
if wildcardWithSpecific {
allowedOrigins = []string{"*"}
}
allowCredentials := cfg.AllowCredentials
corsWarningOnce.Do(func() {
if len(allowedOrigins) == 0 {
log.Println("Warning: CORS allowed_origins not configured; cross-origin requests will be rejected.")
}
if wildcardWithSpecific {
log.Println("Warning: CORS allowed_origins includes '*'; wildcard will take precedence over explicit origins.")
}
if allowAll && allowCredentials {
log.Println("Warning: CORS allowed_origins set to '*', disabling allow_credentials.")
}
})
if allowAll && allowCredentials {
allowCredentials = false
}
allowedSet := make(map[string]struct{}, len(allowedOrigins))
for _, origin := range allowedOrigins {
if origin == "" || origin == "*" {
continue
}
allowedSet[origin] = struct{}{}
}
return func(c *gin.Context) {
// 设置允许跨域的响应头
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
origin := strings.TrimSpace(c.GetHeader("Origin"))
originAllowed := allowAll
if origin != "" && !allowAll {
_, originAllowed = allowedSet[origin]
}
if originAllowed {
if allowAll {
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
} else if origin != "" {
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
c.Writer.Header().Add("Vary", "Origin")
}
if allowCredentials {
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
}
}
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
// 处理预检请求
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
if c.Request.Method == http.MethodOptions {
if originAllowed {
c.AbortWithStatus(http.StatusNoContent)
} else {
c.AbortWithStatus(http.StatusForbidden)
}
return
}
c.Next()
}
}
func normalizeOrigins(values []string) []string {
if len(values) == 0 {
return nil
}
normalized := make([]string, 0, len(values))
for _, value := range values {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
continue
}
normalized = append(normalized, trimmed)
}
return normalized
}

View File

@@ -0,0 +1,26 @@
package middleware
import (
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
// SecurityHeaders sets baseline security headers for all responses.
func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
policy := strings.TrimSpace(cfg.Policy)
if policy == "" {
policy = config.DefaultCSPPolicy
}
return func(c *gin.Context) {
c.Header("X-Content-Type-Options", "nosniff")
c.Header("X-Frame-Options", "DENY")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
if cfg.Enabled {
c.Header("Content-Security-Policy", policy)
}
c.Next()
}
}

View File

@@ -24,7 +24,8 @@ func SetupRouter(
) *gin.Engine {
// 应用中间件
r.Use(middleware2.Logger())
r.Use(middleware2.CORS())
r.Use(middleware2.CORS(cfg.CORS))
r.Use(middleware2.SecurityHeaders(cfg.Security.CSP))
// Serve embedded frontend if available
if web.HasEmbeddedFrontend() {

View File

@@ -11,6 +11,7 @@ import (
type Account struct {
ID int64
Name string
Notes *string
Platform string
Type string
Credentials map[string]any
@@ -262,6 +263,17 @@ func parseTempUnschedStrings(value any) []string {
return out
}
func normalizeAccountNotes(value *string) *string {
if value == nil {
return nil
}
trimmed := strings.TrimSpace(*value)
if trimmed == "" {
return nil
}
return &trimmed
}
func parseTempUnschedInt(value any) int {
switch v := value.(type) {
case int:

View File

@@ -72,6 +72,7 @@ type AccountBulkUpdate struct {
// CreateAccountRequest 创建账号请求
type CreateAccountRequest struct {
Name string `json:"name"`
Notes *string `json:"notes"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
@@ -85,6 +86,7 @@ type CreateAccountRequest struct {
// UpdateAccountRequest 更新账号请求
type UpdateAccountRequest struct {
Name *string `json:"name"`
Notes *string `json:"notes"`
Credentials *map[string]any `json:"credentials"`
Extra *map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
@@ -123,6 +125,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
// 创建账号
account := &Account{
Name: req.Name,
Notes: normalizeAccountNotes(req.Notes),
Platform: req.Platform,
Type: req.Type,
Credentials: req.Credentials,
@@ -194,6 +197,9 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
if req.Name != nil {
account.Name = *req.Name
}
if req.Notes != nil {
account.Notes = normalizeAccountNotes(req.Notes)
}
if req.Credentials != nil {
account.Credentials = *req.Credentials

View File

@@ -7,6 +7,7 @@ import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
@@ -14,9 +15,11 @@ import (
"regexp"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
@@ -45,6 +48,7 @@ type AccountTestService struct {
geminiTokenProvider *GeminiTokenProvider
antigravityGatewayService *AntigravityGatewayService
httpUpstream HTTPUpstream
cfg *config.Config
}
// NewAccountTestService creates a new AccountTestService
@@ -53,15 +57,35 @@ func NewAccountTestService(
geminiTokenProvider *GeminiTokenProvider,
antigravityGatewayService *AntigravityGatewayService,
httpUpstream HTTPUpstream,
cfg *config.Config,
) *AccountTestService {
return &AccountTestService{
accountRepo: accountRepo,
geminiTokenProvider: geminiTokenProvider,
antigravityGatewayService: antigravityGatewayService,
httpUpstream: httpUpstream,
cfg: cfg,
}
}
func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg == nil {
return "", errors.New("config is not available")
}
if !s.cfg.Security.URLAllowlist.Enabled {
return urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
}
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", err
}
return normalized, nil
}
// generateSessionString generates a Claude Code style session string
func generateSessionString() (string, error) {
bytes := make([]byte, 32)
@@ -183,11 +207,15 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
return s.sendErrorAndEnd(c, "No API key available")
}
apiURL = account.GetBaseURL()
if apiURL == "" {
apiURL = "https://api.anthropic.com"
baseURL := account.GetBaseURL()
if baseURL == "" {
baseURL = "https://api.anthropic.com"
}
apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages"
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages"
} else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
@@ -300,7 +328,11 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
if baseURL == "" {
baseURL = "https://api.openai.com"
}
apiURL = strings.TrimSuffix(baseURL, "/") + "/responses"
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/responses"
} else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
@@ -480,10 +512,14 @@ func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, accou
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
// Use streamGenerateContent for real-time feedback
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse",
strings.TrimRight(baseURL, "/"), modelID)
strings.TrimRight(normalizedBaseURL, "/"), modelID)
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(payload))
if err != nil {
@@ -515,7 +551,11 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
if strings.TrimSpace(baseURL) == "" {
baseURL = geminicli.AIStudioBaseURL
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(baseURL, "/"), modelID)
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(normalizedBaseURL, "/"), modelID)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
if err != nil {
@@ -544,7 +584,11 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT
}
wrappedBytes, _ := json.Marshal(wrapped)
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", geminicli.GeminiCliBaseURL)
normalizedBaseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, err
}
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", normalizedBaseURL)
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(wrappedBytes))
if err != nil {

View File

@@ -123,6 +123,7 @@ type UpdateGroupInput struct {
type CreateAccountInput struct {
Name string
Notes *string
Platform string
Type string
Credentials map[string]any
@@ -138,6 +139,7 @@ type CreateAccountInput struct {
type UpdateAccountInput struct {
Name string
Notes *string
Type string // Account type: oauth, setup-token, apikey
Credentials map[string]any
Extra map[string]any
@@ -687,6 +689,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
account := &Account{
Name: input.Name,
Notes: normalizeAccountNotes(input.Notes),
Platform: input.Platform,
Type: input.Type,
Credentials: input.Credentials,
@@ -723,6 +726,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if input.Type != "" {
account.Type = input.Type
}
if input.Notes != nil {
account.Notes = normalizeAccountNotes(input.Notes)
}
if len(input.Credentials) > 0 {
account.Credentials = input.Credentials
}
@@ -730,7 +736,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.Extra = input.Extra
}
if input.ProxyID != nil {
account.ProxyID = input.ProxyID
// 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图)
if *input.ProxyID == 0 {
account.ProxyID = nil
} else {
account.ProxyID = input.ProxyID
}
account.Proxy = nil // 清除关联对象,防止 GORM Save 时根据 Proxy.ID 覆盖 ProxyID
}
// 只在指针非 nil 时更新 Concurrency支持设置为 0

View File

@@ -9,8 +9,10 @@ import (
"fmt"
"io"
"log"
mathrand "math/rand"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
@@ -255,6 +257,16 @@ func (s *AntigravityGatewayService) buildClaudeTestRequest(projectID, mappedMode
return antigravity.TransformClaudeToGemini(claudeReq, projectID, mappedModel)
}
func (s *AntigravityGatewayService) getClaudeTransformOptions(ctx context.Context) antigravity.TransformOptions {
opts := antigravity.DefaultTransformOptions()
if s.settingService == nil {
return opts
}
opts.EnableIdentityPatch = s.settingService.IsIdentityPatchEnabled(ctx)
opts.IdentityPatch = s.settingService.GetIdentityPatchPrompt(ctx)
return opts
}
// extractGeminiResponseText 从 Gemini 响应中提取文本
func extractGeminiResponseText(respBody []byte) string {
var resp map[string]any
@@ -380,7 +392,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}
// 转换 Claude 请求为 Gemini 格式
geminiBody, err := antigravity.TransformClaudeToGemini(&claudeReq, projectID, mappedModel)
geminiBody, err := antigravity.TransformClaudeToGeminiWithOptions(&claudeReq, projectID, mappedModel, s.getClaudeTransformOptions(ctx))
if err != nil {
return nil, fmt.Errorf("transform request: %w", err)
}
@@ -394,6 +406,14 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 重试循环
var resp *http.Response
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
// 检查 context 是否已取消(客户端断开连接)
select {
case <-ctx.Done():
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
return nil, ctx.Err()
default:
}
upstreamReq, err := antigravity.NewAPIRequest(ctx, action, accessToken, geminiBody)
if err != nil {
return nil, err
@@ -403,7 +423,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
if err != nil {
if attempt < antigravityMaxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
sleepAntigravityBackoff(attempt)
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue
}
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
@@ -416,7 +439,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
if attempt < antigravityMaxRetries {
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
sleepAntigravityBackoff(attempt)
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue
}
// 所有重试都失败,标记限流状态
@@ -443,35 +469,70 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
// 当历史消息携带的 signature 不合法时会直接 400去除 thinking 后可继续完成请求。
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) {
retryClaudeReq := claudeReq
retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...)
// Conservative two-stage fallback:
// 1) Disable top-level thinking + thinking->text
// 2) Only if still signature-related 400: also downgrade tool_use/tool_result to text.
stripped, stripErr := stripThinkingFromClaudeRequest(&retryClaudeReq)
if stripErr == nil && stripped {
log.Printf("Antigravity account %d: detected signature-related 400, retrying once without thinking blocks", account.ID)
retryStages := []struct {
name string
strip func(*antigravity.ClaudeRequest) (bool, error)
}{
{name: "thinking-only", strip: stripThinkingFromClaudeRequest},
{name: "thinking+tools", strip: stripSignatureSensitiveBlocksFromClaudeRequest},
}
retryGeminiBody, txErr := antigravity.TransformClaudeToGemini(&retryClaudeReq, projectID, mappedModel)
if txErr == nil {
retryReq, buildErr := antigravity.NewAPIRequest(ctx, action, accessToken, retryGeminiBody)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
// Retry success: continue normal success flow with the new response.
if retryResp.StatusCode < 400 {
_ = resp.Body.Close()
resp = retryResp
respBody = nil
} else {
// Retry still errored: replace error context with retry response.
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
respBody = retryBody
resp = retryResp
}
} else {
log.Printf("Antigravity account %d: signature retry request failed: %v", account.ID, retryErr)
}
for _, stage := range retryStages {
retryClaudeReq := claudeReq
retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...)
stripped, stripErr := stage.strip(&retryClaudeReq)
if stripErr != nil || !stripped {
continue
}
log.Printf("Antigravity account %d: detected signature-related 400, retrying once (%s)", account.ID, stage.name)
retryGeminiBody, txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, s.getClaudeTransformOptions(ctx))
if txErr != nil {
continue
}
retryReq, buildErr := antigravity.NewAPIRequest(ctx, action, accessToken, retryGeminiBody)
if buildErr != nil {
continue
}
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr != nil {
log.Printf("Antigravity account %d: signature retry request failed (%s): %v", account.ID, stage.name, retryErr)
continue
}
if retryResp.StatusCode < 400 {
_ = resp.Body.Close()
resp = retryResp
respBody = nil
break
}
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
// If this stage fixed the signature issue, we stop; otherwise we may try the next stage.
if retryResp.StatusCode != http.StatusBadRequest || !isSignatureRelatedError(retryBody) {
respBody = retryBody
resp = &http.Response{
StatusCode: retryResp.StatusCode,
Header: retryResp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(retryBody)),
}
break
}
// Still signature-related; capture context and allow next stage.
respBody = retryBody
resp = &http.Response{
StatusCode: retryResp.StatusCode,
Header: retryResp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(retryBody)),
}
}
}
@@ -528,7 +589,17 @@ func isSignatureRelatedError(respBody []byte) bool {
}
// Keep this intentionally broad: different upstreams may use "signature" or "thought_signature".
return strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature")
if strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") {
return true
}
// Also detect thinking block structural errors:
// "Expected `thinking` or `redacted_thinking`, but found `text`"
if strings.Contains(msg, "expected") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) {
return true
}
return false
}
func extractAntigravityErrorMessage(body []byte) string {
@@ -555,7 +626,7 @@ func extractAntigravityErrorMessage(body []byte) string {
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
// This preserves the thinking content while avoiding signature validation errors.
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
// It also disables top-level `thinking` to prevent dummy-thought injection during retry.
// It also disables top-level `thinking` to avoid upstream structural constraints for thinking mode.
func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error) {
if req == nil {
return false, nil
@@ -585,6 +656,92 @@ func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error
continue
}
filtered := make([]map[string]any, 0, len(blocks))
modifiedAny := false
for _, block := range blocks {
t, _ := block["type"].(string)
switch t {
case "thinking":
thinkingText, _ := block["thinking"].(string)
if thinkingText != "" {
filtered = append(filtered, map[string]any{
"type": "text",
"text": thinkingText,
})
}
modifiedAny = true
case "redacted_thinking":
modifiedAny = true
case "":
if thinkingText, hasThinking := block["thinking"].(string); hasThinking {
if thinkingText != "" {
filtered = append(filtered, map[string]any{
"type": "text",
"text": thinkingText,
})
}
modifiedAny = true
} else {
filtered = append(filtered, block)
}
default:
filtered = append(filtered, block)
}
}
if !modifiedAny {
continue
}
if len(filtered) == 0 {
filtered = append(filtered, map[string]any{
"type": "text",
"text": "(content removed)",
})
}
newRaw, err := json.Marshal(filtered)
if err != nil {
return changed, err
}
req.Messages[i].Content = newRaw
changed = true
}
return changed, nil
}
// stripSignatureSensitiveBlocksFromClaudeRequest is a stronger retry degradation that additionally converts
// tool blocks to plain text. Use this only after a thinking-only retry still fails with signature errors.
func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error) {
if req == nil {
return false, nil
}
changed := false
if req.Thinking != nil {
req.Thinking = nil
changed = true
}
for i := range req.Messages {
raw := req.Messages[i].Content
if len(raw) == 0 {
continue
}
// If content is a string, nothing to strip.
var str string
if json.Unmarshal(raw, &str) == nil {
continue
}
// Otherwise treat as an array of blocks and convert signature-sensitive blocks to text.
var blocks []map[string]any
if err := json.Unmarshal(raw, &blocks); err != nil {
continue
}
filtered := make([]map[string]any, 0, len(blocks))
modifiedAny := false
for _, block := range blocks {
@@ -603,6 +760,49 @@ func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error
case "redacted_thinking":
// Remove redacted_thinking (cannot convert encrypted content)
modifiedAny = true
case "tool_use":
// Convert tool_use to text to avoid upstream signature/thought_signature validation errors.
// This is a retry-only degradation path, so we prioritise request validity over tool semantics.
name, _ := block["name"].(string)
id, _ := block["id"].(string)
input := block["input"]
inputJSON, _ := json.Marshal(input)
text := "(tool_use)"
if name != "" {
text += " name=" + name
}
if id != "" {
text += " id=" + id
}
if len(inputJSON) > 0 && string(inputJSON) != "null" {
text += " input=" + string(inputJSON)
}
filtered = append(filtered, map[string]any{
"type": "text",
"text": text,
})
modifiedAny = true
case "tool_result":
// Convert tool_result to text so it stays consistent when tool_use is downgraded.
toolUseID, _ := block["tool_use_id"].(string)
isError, _ := block["is_error"].(bool)
content := block["content"]
contentJSON, _ := json.Marshal(content)
text := "(tool_result)"
if toolUseID != "" {
text += " tool_use_id=" + toolUseID
}
if isError {
text += " is_error=true"
}
if len(contentJSON) > 0 && string(contentJSON) != "null" {
text += "\n" + string(contentJSON)
}
filtered = append(filtered, map[string]any{
"type": "text",
"text": text,
})
modifiedAny = true
case "":
// Handle untyped block with "thinking" field
if thinkingText, hasThinking := block["thinking"].(string); hasThinking {
@@ -625,6 +825,14 @@ func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error
continue
}
if len(filtered) == 0 {
// Keep request valid: upstream rejects empty content arrays.
filtered = append(filtered, map[string]any{
"type": "text",
"text": "(content removed)",
})
}
newRaw, err := json.Marshal(filtered)
if err != nil {
return changed, err
@@ -711,6 +919,14 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// 重试循环
var resp *http.Response
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
// 检查 context 是否已取消(客户端断开连接)
select {
case <-ctx.Done():
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
return nil, ctx.Err()
default:
}
upstreamReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, wrappedBody)
if err != nil {
return nil, err
@@ -720,7 +936,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if err != nil {
if attempt < antigravityMaxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
sleepAntigravityBackoff(attempt)
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue
}
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
@@ -733,7 +952,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if attempt < antigravityMaxRetries {
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
sleepAntigravityBackoff(attempt)
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue
}
// 所有重试都失败,标记限流状态
@@ -750,11 +972,18 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
break
}
defer func() { _ = resp.Body.Close() }()
defer func() {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
}()
// 处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
// 尽早关闭原始响应体,释放连接;后续逻辑仍可能需要读取 body因此用内存副本重新包装。
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
// 模型兜底:模型不存在且开启 fallback 时,自动用 fallback 模型重试一次
if s.settingService != nil && s.settingService.IsModelFallbackEnabled(ctx) &&
@@ -763,15 +992,13 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if fallbackModel != "" && fallbackModel != mappedModel {
log.Printf("[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
// 关闭原始响应释放连接respBody 已读取到内存)
_ = resp.Body.Close()
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, body)
if err == nil {
fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackWrapped)
if err == nil {
fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency)
if err == nil && fallbackResp.StatusCode < 400 {
_ = resp.Body.Close()
resp = fallbackResp
} else if fallbackResp != nil {
_ = fallbackResp.Body.Close()
@@ -872,8 +1099,28 @@ func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int)
}
}
func sleepAntigravityBackoff(attempt int) {
sleepGeminiBackoff(attempt) // 复用 Gemini 的退避逻辑
// sleepAntigravityBackoffWithContext 带 context 取消检查的退避等待
// 返回 true 表示正常完成等待false 表示 context 已取消
func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
delay := geminiRetryBaseDelay * time.Duration(1<<uint(attempt-1))
if delay > geminiRetryMaxDelay {
delay = geminiRetryMaxDelay
}
// +/- 20% jitter
r := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
jitter := time.Duration(float64(delay) * 0.2 * (r.Float64()*2 - 1))
sleepFor := delay + jitter
if sleepFor < 0 {
sleepFor = 0
}
select {
case <-ctx.Done():
return false
case <-time.After(sleepFor):
return true
}
}
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte) {
@@ -928,57 +1175,145 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
return nil, errors.New("streaming not supported")
}
reader := bufio.NewReader(resp.Body)
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
usage := &ClaudeUsage{}
var firstTokenMs *int
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
streamInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
for {
line, err := reader.ReadString('\n')
if len(line) > 0 {
select {
case ev, ok := <-events:
if !ok {
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return nil, ev.err
}
line := ev.line
trimmed := strings.TrimRight(line, "\r\n")
if strings.HasPrefix(trimmed, "data:") {
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
if payload == "" || payload == "[DONE]" {
_, _ = io.WriteString(c.Writer, line)
flusher.Flush()
} else {
// 解包 v1internal 响应
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
if parseErr == nil && inner != nil {
payload = string(inner)
}
// 解析 usage
var parsed map[string]any
if json.Unmarshal(inner, &parsed) == nil {
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", payload)
if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
continue
}
} else {
_, _ = io.WriteString(c.Writer, line)
flusher.Flush()
}
}
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, err
// 解包 v1internal 响应
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
if parseErr == nil && inner != nil {
payload = string(inner)
}
// 解析 usage
var parsed map[string]any
if json.Unmarshal(inner, &parsed) == nil {
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil {
sendErrorEvent("write_failed")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
continue
}
if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout (antigravity)")
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
}
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
func (s *AntigravityGatewayService) handleGeminiNonStreamingResponse(c *gin.Context, resp *http.Response) (*ClaudeUsage, error) {
@@ -1117,7 +1452,13 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
processor := antigravity.NewStreamingProcessor(originalModel)
var firstTokenMs *int
reader := bufio.NewReader(resp.Body)
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage {
@@ -1132,13 +1473,85 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
}
}
for {
line, err := reader.ReadString('\n')
if err != nil && !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("stream read error: %w", err)
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
if len(line) > 0 {
streamInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
for {
select {
case ev, ok := <-events:
if !ok {
// 发送结束事件
finalEvents, agUsage := processor.Finish()
if len(finalEvents) > 0 {
_, _ = c.Writer.Write(finalEvents)
flusher.Flush()
}
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return nil, fmt.Errorf("stream read error: %w", ev.err)
}
line := ev.line
// 处理 SSE 行,转换为 Claude 格式
claudeEvents := processor.ProcessLine(strings.TrimRight(line, "\r\n"))
@@ -1153,25 +1566,23 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
if len(finalEvents) > 0 {
_, _ = c.Writer.Write(finalEvents)
}
sendErrorEvent("write_failed")
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, writeErr
}
flusher.Flush()
}
}
if errors.Is(err, io.EOF) {
break
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout (antigravity)")
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
}
// 发送结束事件
finalEvents, agUsage := processor.Finish()
if len(finalEvents) > 0 {
_, _ = c.Writer.Write(finalEvents)
flusher.Flush()
}
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil
}
// extractImageSize 从 Gemini 请求中提取 image_size 参数

View File

@@ -0,0 +1,83 @@
package service
import (
"encoding/json"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/stretchr/testify/require"
)
func TestStripSignatureSensitiveBlocksFromClaudeRequest(t *testing.T) {
req := &antigravity.ClaudeRequest{
Model: "claude-sonnet-4-5",
Thinking: &antigravity.ThinkingConfig{
Type: "enabled",
BudgetTokens: 1024,
},
Messages: []antigravity.ClaudeMessage{
{
Role: "assistant",
Content: json.RawMessage(`[
{"type":"thinking","thinking":"secret plan","signature":""},
{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}}
]`),
},
{
Role: "user",
Content: json.RawMessage(`[
{"type":"tool_result","tool_use_id":"t1","content":"ok","is_error":false},
{"type":"redacted_thinking","data":"..."}
]`),
},
},
}
changed, err := stripSignatureSensitiveBlocksFromClaudeRequest(req)
require.NoError(t, err)
require.True(t, changed)
require.Nil(t, req.Thinking)
require.Len(t, req.Messages, 2)
var blocks0 []map[string]any
require.NoError(t, json.Unmarshal(req.Messages[0].Content, &blocks0))
require.Len(t, blocks0, 2)
require.Equal(t, "text", blocks0[0]["type"])
require.Equal(t, "secret plan", blocks0[0]["text"])
require.Equal(t, "text", blocks0[1]["type"])
var blocks1 []map[string]any
require.NoError(t, json.Unmarshal(req.Messages[1].Content, &blocks1))
require.Len(t, blocks1, 1)
require.Equal(t, "text", blocks1[0]["type"])
require.NotEmpty(t, blocks1[0]["text"])
}
func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) {
req := &antigravity.ClaudeRequest{
Model: "claude-sonnet-4-5",
Thinking: &antigravity.ThinkingConfig{
Type: "enabled",
BudgetTokens: 1024,
},
Messages: []antigravity.ClaudeMessage{
{
Role: "assistant",
Content: json.RawMessage(`[{"type":"thinking","thinking":"secret plan"},{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}}]`),
},
},
}
changed, err := stripThinkingFromClaudeRequest(req)
require.NoError(t, err)
require.True(t, changed)
require.Nil(t, req.Thinking)
var blocks []map[string]any
require.NoError(t, json.Unmarshal(req.Messages[0].Content, &blocks))
require.Len(t, blocks, 2)
require.Equal(t, "text", blocks[0]["type"])
require.Equal(t, "secret plan", blocks[0]["text"])
require.Equal(t, "tool_use", blocks[1]["type"])
}

View File

@@ -221,9 +221,33 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
// VerifyTurnstile 验证Turnstile token
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required
if required {
if s.settingService == nil {
log.Println("[Auth] Turnstile required but settings service is not configured")
return ErrTurnstileNotConfigured
}
enabled := s.settingService.IsTurnstileEnabled(ctx)
secretConfigured := s.settingService.GetTurnstileSecretKey(ctx) != ""
if !enabled || !secretConfigured {
log.Printf("[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured)
return ErrTurnstileNotConfigured
}
}
if s.turnstileService == nil {
if required {
log.Println("[Auth] Turnstile required but service not configured")
return ErrTurnstileNotConfigured
}
return nil // 服务未配置则跳过验证
}
if !required && s.settingService != nil && s.settingService.IsTurnstileEnabled(ctx) && s.settingService.GetTurnstileSecretKey(ctx) == "" {
log.Println("[Auth] Turnstile enabled but secret key not configured")
}
return s.turnstileService.VerifyToken(ctx, token, remoteIP)
}

View File

@@ -16,7 +16,8 @@ import (
// 注ErrInsufficientBalance在redeem_service.go中定义
// 注ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
var (
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
)
// subscriptionCacheData 订阅缓存数据结构(内部使用)
@@ -72,10 +73,11 @@ type cacheWriteTask struct {
// BillingCacheService 计费缓存服务
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type BillingCacheService struct {
cache BillingCache
userRepo UserRepository
subRepo UserSubscriptionRepository
cfg *config.Config
cache BillingCache
userRepo UserRepository
subRepo UserSubscriptionRepository
cfg *config.Config
circuitBreaker *billingCircuitBreaker
cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup
@@ -95,6 +97,7 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
subRepo: subRepo,
cfg: cfg,
}
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
svc.startCacheWriteWorkers()
return svc
}
@@ -450,6 +453,9 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
if s.cfg.RunMode == config.RunModeSimple {
return nil
}
if s.circuitBreaker != nil && !s.circuitBreaker.Allow() {
return ErrBillingServiceUnavailable
}
// 判断计费模式
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
@@ -465,9 +471,14 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error {
balance, err := s.GetUserBalance(ctx, userID)
if err != nil {
// 缓存/数据库错误,允许通过(降级处理)
log.Printf("Warning: get user balance failed, allowing request: %v", err)
return nil
if s.circuitBreaker != nil {
s.circuitBreaker.OnFailure(err)
}
log.Printf("ALERT: billing balance check failed for user %d: %v", userID, err)
return ErrBillingServiceUnavailable.WithCause(err)
}
if s.circuitBreaker != nil {
s.circuitBreaker.OnSuccess()
}
if balance <= 0 {
@@ -482,9 +493,14 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
// 获取订阅缓存数据
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
if err != nil {
// 缓存/数据库错误降级使用传入的subscription进行检查
log.Printf("Warning: get subscription cache failed, using fallback: %v", err)
return s.checkSubscriptionLimitsFallback(subscription, group)
if s.circuitBreaker != nil {
s.circuitBreaker.OnFailure(err)
}
log.Printf("ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err)
return ErrBillingServiceUnavailable.WithCause(err)
}
if s.circuitBreaker != nil {
s.circuitBreaker.OnSuccess()
}
// 检查订阅状态
@@ -513,27 +529,133 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
return nil
}
// checkSubscriptionLimitsFallback 降级检查订阅限额
func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *UserSubscription, group *Group) error {
if subscription == nil {
return ErrSubscriptionInvalid
}
type billingCircuitBreakerState int
if !subscription.IsActive() {
return ErrSubscriptionInvalid
}
const (
billingCircuitClosed billingCircuitBreakerState = iota
billingCircuitOpen
billingCircuitHalfOpen
)
if !subscription.CheckDailyLimit(group, 0) {
return ErrDailyLimitExceeded
}
if !subscription.CheckWeeklyLimit(group, 0) {
return ErrWeeklyLimitExceeded
}
if !subscription.CheckMonthlyLimit(group, 0) {
return ErrMonthlyLimitExceeded
}
return nil
type billingCircuitBreaker struct {
mu sync.Mutex
state billingCircuitBreakerState
failures int
openedAt time.Time
failureThreshold int
resetTimeout time.Duration
halfOpenRequests int
halfOpenRemaining int
}
func newBillingCircuitBreaker(cfg config.CircuitBreakerConfig) *billingCircuitBreaker {
if !cfg.Enabled {
return nil
}
resetTimeout := time.Duration(cfg.ResetTimeoutSeconds) * time.Second
if resetTimeout <= 0 {
resetTimeout = 30 * time.Second
}
halfOpen := cfg.HalfOpenRequests
if halfOpen <= 0 {
halfOpen = 1
}
threshold := cfg.FailureThreshold
if threshold <= 0 {
threshold = 5
}
return &billingCircuitBreaker{
state: billingCircuitClosed,
failureThreshold: threshold,
resetTimeout: resetTimeout,
halfOpenRequests: halfOpen,
}
}
func (b *billingCircuitBreaker) Allow() bool {
b.mu.Lock()
defer b.mu.Unlock()
switch b.state {
case billingCircuitClosed:
return true
case billingCircuitOpen:
if time.Since(b.openedAt) < b.resetTimeout {
return false
}
b.state = billingCircuitHalfOpen
b.halfOpenRemaining = b.halfOpenRequests
log.Printf("ALERT: billing circuit breaker entering half-open state")
fallthrough
case billingCircuitHalfOpen:
if b.halfOpenRemaining <= 0 {
return false
}
b.halfOpenRemaining--
return true
default:
return false
}
}
func (b *billingCircuitBreaker) OnFailure(err error) {
if b == nil {
return
}
b.mu.Lock()
defer b.mu.Unlock()
switch b.state {
case billingCircuitOpen:
return
case billingCircuitHalfOpen:
b.state = billingCircuitOpen
b.openedAt = time.Now()
b.halfOpenRemaining = 0
log.Printf("ALERT: billing circuit breaker opened after half-open failure: %v", err)
return
default:
b.failures++
if b.failures >= b.failureThreshold {
b.state = billingCircuitOpen
b.openedAt = time.Now()
b.halfOpenRemaining = 0
log.Printf("ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err)
}
}
}
func (b *billingCircuitBreaker) OnSuccess() {
if b == nil {
return
}
b.mu.Lock()
defer b.mu.Unlock()
previousState := b.state
previousFailures := b.failures
b.state = billingCircuitClosed
b.failures = 0
b.halfOpenRemaining = 0
// 只有状态真正发生变化时才记录日志
if previousState != billingCircuitClosed {
log.Printf("ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState))
} else if previousFailures > 0 {
log.Printf("INFO: billing circuit breaker failures reset from %d", previousFailures)
}
}
func circuitStateString(state billingCircuitBreakerState) string {
switch state {
case billingCircuitClosed:
return "closed"
case billingCircuitOpen:
return "open"
case billingCircuitHalfOpen:
return "half-open"
default:
return "unknown"
}
}

View File

@@ -8,12 +8,13 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
type CRSSyncService struct {
@@ -22,6 +23,7 @@ type CRSSyncService struct {
oauthService *OAuthService
openaiOAuthService *OpenAIOAuthService
geminiOAuthService *GeminiOAuthService
cfg *config.Config
}
func NewCRSSyncService(
@@ -30,6 +32,7 @@ func NewCRSSyncService(
oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
cfg *config.Config,
) *CRSSyncService {
return &CRSSyncService{
accountRepo: accountRepo,
@@ -37,6 +40,7 @@ func NewCRSSyncService(
oauthService: oauthService,
openaiOAuthService: openaiOAuthService,
geminiOAuthService: geminiOAuthService,
cfg: cfg,
}
}
@@ -187,16 +191,31 @@ type crsGeminiAPIKeyAccount struct {
}
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
baseURL, err := normalizeBaseURL(input.BaseURL)
if err != nil {
return nil, err
if s.cfg == nil {
return nil, errors.New("config is not available")
}
baseURL := strings.TrimSpace(input.BaseURL)
if s.cfg.Security.URLAllowlist.Enabled {
normalized, err := normalizeBaseURL(baseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
if err != nil {
return nil, err
}
baseURL = normalized
} else {
normalized, err := urlvalidator.ValidateURLFormat(baseURL, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
if err != nil {
return nil, fmt.Errorf("invalid base_url: %w", err)
}
baseURL = normalized
}
if strings.TrimSpace(input.Username) == "" || strings.TrimSpace(input.Password) == "" {
return nil, errors.New("username and password are required")
}
client, err := httpclient.GetClient(httpclient.Options{
Timeout: 20 * time.Second,
Timeout: 20 * time.Second,
ValidateResolvedIP: s.cfg.Security.URLAllowlist.Enabled,
AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
client = &http.Client{Timeout: 20 * time.Second}
@@ -1055,17 +1074,18 @@ func mapCRSStatus(isActive bool, status string) string {
return "active"
}
func normalizeBaseURL(raw string) (string, error) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return "", errors.New("base_url is required")
func normalizeBaseURL(raw string, allowlist []string, allowPrivate bool) (string, error) {
// 当 allowlist 为空时,不强制要求白名单(只进行基本的 URL 和 SSRF 验证)
requireAllowlist := len(allowlist) > 0
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: allowlist,
RequireAllowlist: requireAllowlist,
AllowPrivate: allowPrivate,
})
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
u, err := url.Parse(trimmed)
if err != nil || u.Scheme == "" || u.Host == "" {
return "", fmt.Errorf("invalid base_url: %s", trimmed)
}
u.Path = strings.TrimRight(u.Path, "/")
return strings.TrimRight(u.String(), "/"), nil
return normalized, nil
}
// cleanBaseURL removes trailing suffix from base_url in credentials

View File

@@ -101,6 +101,10 @@ const (
SettingKeyFallbackModelOpenAI = "fallback_model_openai"
SettingKeyFallbackModelGemini = "fallback_model_gemini"
SettingKeyFallbackModelAntigravity = "fallback_model_antigravity"
// Request identity patch (Claude -> Gemini systemInstruction injection)
SettingKeyEnableIdentityPatch = "enable_identity_patch"
SettingKeyIdentityPatchPrompt = "identity_patch_prompt"
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).

View File

@@ -84,25 +84,37 @@ func FilterThinkingBlocks(body []byte) []byte {
return filterThinkingBlocksInternal(body, false)
}
// FilterThinkingBlocksForRetry removes thinking blocks from HISTORICAL messages for retry scenarios.
// This is used when upstream returns signature-related 400 errors.
// FilterThinkingBlocksForRetry strips thinking-related constructs for retry scenarios.
//
// Key insight:
// - User's thinking.type = "enabled" should be PRESERVED (user's intent)
// - Only HISTORICAL assistant messages have thinking blocks with signatures
// - These signatures may be invalid when switching accounts/platforms
// - New responses will generate fresh thinking blocks without signature issues
// Why:
// - Upstreams may reject historical `thinking`/`redacted_thinking` blocks due to invalid/missing signatures.
// - Anthropic extended thinking has a structural constraint: when top-level `thinking` is enabled and the
// final message is an assistant prefill, the assistant content must start with a thinking block.
// - If we remove thinking blocks but keep top-level `thinking` enabled, we can trigger:
// "Expected `thinking` or `redacted_thinking`, but found `text`"
//
// Strategy:
// - Keep thinking.type = "enabled" (preserve user intent)
// - Remove thinking/redacted_thinking blocks from historical assistant messages
// - Ensure no message has empty content after filtering
// Strategy (B: preserve content as text):
// - Disable top-level `thinking` (remove `thinking` field).
// - Convert `thinking` blocks to `text` blocks (preserve the thinking content).
// - Remove `redacted_thinking` blocks (cannot be converted to text).
// - Ensure no message ends up with empty content.
func FilterThinkingBlocksForRetry(body []byte) []byte {
// Fast path: check for presence of thinking-related keys in messages
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "thinking"`)) &&
!bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) {
hasThinkingContent := bytes.Contains(body, []byte(`"type":"thinking"`)) ||
bytes.Contains(body, []byte(`"type": "thinking"`)) ||
bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) ||
bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) ||
bytes.Contains(body, []byte(`"thinking":`)) ||
bytes.Contains(body, []byte(`"thinking" :`))
// Also check for empty content arrays that need fixing.
// Note: This is a heuristic check; the actual empty content handling is done below.
hasEmptyContent := bytes.Contains(body, []byte(`"content":[]`)) ||
bytes.Contains(body, []byte(`"content": []`)) ||
bytes.Contains(body, []byte(`"content" : []`)) ||
bytes.Contains(body, []byte(`"content" :[]`))
// Fast path: nothing to process
if !hasThinkingContent && !hasEmptyContent {
return body
}
@@ -111,15 +123,19 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
return body
}
// DO NOT modify thinking.type - preserve user's intent to use thinking mode
// The issue is with historical message signatures, not the thinking mode itself
modified := false
messages, ok := req["messages"].([]any)
if !ok {
return body
}
modified := false
// Disable top-level thinking mode for retry to avoid structural/signature constraints upstream.
if _, exists := req["thinking"]; exists {
delete(req, "thinking")
modified = true
}
newMessages := make([]any, 0, len(messages))
for _, msg := range messages {
@@ -149,33 +165,59 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
blockType, _ := blockMap["type"].(string)
// Remove thinking/redacted_thinking blocks from historical messages
// These have signatures that may be invalid across different accounts
if blockType == "thinking" || blockType == "redacted_thinking" {
// Convert thinking blocks to text (preserve content) and drop redacted_thinking.
switch blockType {
case "thinking":
modifiedThisMsg = true
thinkingText, _ := blockMap["thinking"].(string)
if thinkingText == "" {
continue
}
newContent = append(newContent, map[string]any{
"type": "text",
"text": thinkingText,
})
continue
case "redacted_thinking":
modifiedThisMsg = true
continue
}
// Handle blocks without type discriminator but with a "thinking" field.
if blockType == "" {
if rawThinking, hasThinking := blockMap["thinking"]; hasThinking {
modifiedThisMsg = true
switch v := rawThinking.(type) {
case string:
if v != "" {
newContent = append(newContent, map[string]any{"type": "text", "text": v})
}
default:
if b, err := json.Marshal(v); err == nil && len(b) > 0 {
newContent = append(newContent, map[string]any{"type": "text", "text": string(b)})
}
}
continue
}
}
newContent = append(newContent, block)
}
if modifiedThisMsg {
// Handle empty content: either from filtering or originally empty
if len(newContent) == 0 {
modified = true
// Handle empty content after filtering
if len(newContent) == 0 {
// For assistant messages, skip entirely (remove from conversation)
// For user messages, add placeholder to avoid empty content error
if role == "user" {
newContent = append(newContent, map[string]any{
"type": "text",
"text": "(content removed)",
})
msgMap["content"] = newContent
newMessages = append(newMessages, msgMap)
}
// Skip assistant messages with empty content (don't append)
continue
placeholder := "(content removed)"
if role == "assistant" {
placeholder = "(assistant content removed)"
}
newContent = append(newContent, map[string]any{
"type": "text",
"text": placeholder,
})
msgMap["content"] = newContent
} else if modifiedThisMsg {
modified = true
msgMap["content"] = newContent
}
newMessages = append(newMessages, msgMap)
@@ -183,6 +225,9 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
if modified {
req["messages"] = newMessages
} else {
// Avoid rewriting JSON when no changes are needed.
return body
}
newBody, err := json.Marshal(req)
@@ -192,6 +237,172 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
return newBody
}
// FilterSignatureSensitiveBlocksForRetry is a stronger retry filter for cases where upstream errors indicate
// signature/thought_signature validation issues involving tool blocks.
//
// This performs everything in FilterThinkingBlocksForRetry, plus:
// - Convert `tool_use` blocks to text (name/id/input) so we stop sending structured tool calls.
// - Convert `tool_result` blocks to text so we keep tool results visible without tool semantics.
//
// Use this only when needed: converting tool blocks to text changes model behaviour and can increase the
// risk of prompt injection (tool output becomes plain conversation text).
func FilterSignatureSensitiveBlocksForRetry(body []byte) []byte {
// Fast path: only run when we see likely relevant constructs.
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "thinking"`)) &&
!bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) &&
!bytes.Contains(body, []byte(`"type":"tool_use"`)) &&
!bytes.Contains(body, []byte(`"type": "tool_use"`)) &&
!bytes.Contains(body, []byte(`"type":"tool_result"`)) &&
!bytes.Contains(body, []byte(`"type": "tool_result"`)) &&
!bytes.Contains(body, []byte(`"thinking":`)) &&
!bytes.Contains(body, []byte(`"thinking" :`)) {
return body
}
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return body
}
modified := false
// Disable top-level thinking for retry to avoid structural/signature constraints upstream.
if _, exists := req["thinking"]; exists {
delete(req, "thinking")
modified = true
}
messages, ok := req["messages"].([]any)
if !ok {
return body
}
newMessages := make([]any, 0, len(messages))
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
if !ok {
newMessages = append(newMessages, msg)
continue
}
role, _ := msgMap["role"].(string)
content, ok := msgMap["content"].([]any)
if !ok {
newMessages = append(newMessages, msg)
continue
}
newContent := make([]any, 0, len(content))
modifiedThisMsg := false
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
newContent = append(newContent, block)
continue
}
blockType, _ := blockMap["type"].(string)
switch blockType {
case "thinking":
modifiedThisMsg = true
thinkingText, _ := blockMap["thinking"].(string)
if thinkingText == "" {
continue
}
newContent = append(newContent, map[string]any{"type": "text", "text": thinkingText})
continue
case "redacted_thinking":
modifiedThisMsg = true
continue
case "tool_use":
modifiedThisMsg = true
name, _ := blockMap["name"].(string)
id, _ := blockMap["id"].(string)
input := blockMap["input"]
inputJSON, _ := json.Marshal(input)
text := "(tool_use)"
if name != "" {
text += " name=" + name
}
if id != "" {
text += " id=" + id
}
if len(inputJSON) > 0 && string(inputJSON) != "null" {
text += " input=" + string(inputJSON)
}
newContent = append(newContent, map[string]any{"type": "text", "text": text})
continue
case "tool_result":
modifiedThisMsg = true
toolUseID, _ := blockMap["tool_use_id"].(string)
isError, _ := blockMap["is_error"].(bool)
content := blockMap["content"]
contentJSON, _ := json.Marshal(content)
text := "(tool_result)"
if toolUseID != "" {
text += " tool_use_id=" + toolUseID
}
if isError {
text += " is_error=true"
}
if len(contentJSON) > 0 && string(contentJSON) != "null" {
text += "\n" + string(contentJSON)
}
newContent = append(newContent, map[string]any{"type": "text", "text": text})
continue
}
if blockType == "" {
if rawThinking, hasThinking := blockMap["thinking"]; hasThinking {
modifiedThisMsg = true
switch v := rawThinking.(type) {
case string:
if v != "" {
newContent = append(newContent, map[string]any{"type": "text", "text": v})
}
default:
if b, err := json.Marshal(v); err == nil && len(b) > 0 {
newContent = append(newContent, map[string]any{"type": "text", "text": string(b)})
}
}
continue
}
}
newContent = append(newContent, block)
}
if modifiedThisMsg {
modified = true
if len(newContent) == 0 {
placeholder := "(content removed)"
if role == "assistant" {
placeholder = "(assistant content removed)"
}
newContent = append(newContent, map[string]any{"type": "text", "text": placeholder})
}
msgMap["content"] = newContent
}
newMessages = append(newMessages, msgMap)
}
if !modified {
return body
}
req["messages"] = newMessages
newBody, err := json.Marshal(req)
if err != nil {
return body
}
return newBody
}
// filterThinkingBlocksInternal removes invalid thinking blocks from request
// Strategy:
// - When thinking.type != "enabled": Remove all thinking blocks

View File

@@ -151,3 +151,148 @@ func TestFilterThinkingBlocks(t *testing.T) {
})
}
}
func TestFilterThinkingBlocksForRetry_DisablesThinkingAndPreservesAsText(t *testing.T) {
input := []byte(`{
"model":"claude-3-5-sonnet-20241022",
"thinking":{"type":"enabled","budget_tokens":1024},
"messages":[
{"role":"user","content":[{"type":"text","text":"Hi"}]},
{"role":"assistant","content":[
{"type":"thinking","thinking":"Let me think...","signature":"bad_sig"},
{"type":"text","text":"Answer"}
]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
_, hasThinking := req["thinking"]
require.False(t, hasThinking)
msgs, ok := req["messages"].([]any)
require.True(t, ok)
require.Len(t, msgs, 2)
assistant, ok := msgs[1].(map[string]any)
require.True(t, ok)
content, ok := assistant["content"].([]any)
require.True(t, ok)
require.Len(t, content, 2)
first, ok := content[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", first["type"])
require.Equal(t, "Let me think...", first["text"])
}
func TestFilterThinkingBlocksForRetry_DisablesThinkingEvenWithoutThinkingBlocks(t *testing.T) {
input := []byte(`{
"model":"claude-3-5-sonnet-20241022",
"thinking":{"type":"enabled","budget_tokens":1024},
"messages":[
{"role":"user","content":[{"type":"text","text":"Hi"}]},
{"role":"assistant","content":[{"type":"text","text":"Prefill"}]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
_, hasThinking := req["thinking"]
require.False(t, hasThinking)
}
func TestFilterThinkingBlocksForRetry_RemovesRedactedThinkingAndKeepsValidContent(t *testing.T) {
input := []byte(`{
"thinking":{"type":"enabled","budget_tokens":1024},
"messages":[
{"role":"assistant","content":[
{"type":"redacted_thinking","data":"..."},
{"type":"text","text":"Visible"}
]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
_, hasThinking := req["thinking"]
require.False(t, hasThinking)
msgs, ok := req["messages"].([]any)
require.True(t, ok)
msg0, ok := msgs[0].(map[string]any)
require.True(t, ok)
content, ok := msg0["content"].([]any)
require.True(t, ok)
require.Len(t, content, 1)
content0, ok := content[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", content0["type"])
require.Equal(t, "Visible", content0["text"])
}
func TestFilterThinkingBlocksForRetry_EmptyContentGetsPlaceholder(t *testing.T) {
input := []byte(`{
"thinking":{"type":"enabled"},
"messages":[
{"role":"assistant","content":[{"type":"redacted_thinking","data":"..."}]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs, ok := req["messages"].([]any)
require.True(t, ok)
msg0, ok := msgs[0].(map[string]any)
require.True(t, ok)
content, ok := msg0["content"].([]any)
require.True(t, ok)
require.Len(t, content, 1)
content0, ok := content[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", content0["type"])
require.NotEmpty(t, content0["text"])
}
func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) {
input := []byte(`{
"thinking":{"type":"enabled","budget_tokens":1024},
"messages":[
{"role":"assistant","content":[
{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}},
{"type":"tool_result","tool_use_id":"t1","content":"ok","is_error":false}
]}
]
}`)
out := FilterSignatureSensitiveBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
_, hasThinking := req["thinking"]
require.False(t, hasThinking)
msgs, ok := req["messages"].([]any)
require.True(t, ok)
msg0, ok := msgs[0].(map[string]any)
require.True(t, ok)
content, ok := msg0["content"].([]any)
require.True(t, ok)
require.Len(t, content, 2)
content0, ok := content[0].(map[string]any)
require.True(t, ok)
content1, ok := content[1].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", content0["type"])
require.Equal(t, "text", content1["type"])
require.Contains(t, content0["text"], "tool_use")
require.Contains(t, content1["text"], "tool_result")
}

View File

@@ -15,11 +15,14 @@ import (
"regexp"
"sort"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -30,6 +33,7 @@ const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL
defaultMaxLineSize = 10 * 1024 * 1024
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
)
@@ -933,8 +937,16 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (s
// 重试相关常量
const (
maxRetries = 10 // 最大试次数
retryDelay = 3 * time.Second // 重试等待时间
// 最大试次数(包含首次请求)。过多重试会导致请求堆积与资源耗尽。
maxRetryAttempts = 5
// 指数退避:第 N 次失败后的等待 = retryBaseDelay * 2^(N-1),并且上限为 retryMaxDelay。
retryBaseDelay = 300 * time.Millisecond
retryMaxDelay = 3 * time.Second
// 最大重试耗时(包含请求本身耗时 + 退避等待时间)。
// 用于防止极端情况下 goroutine 长时间堆积导致资源耗尽。
maxRetryElapsed = 10 * time.Second
)
func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode int) bool {
@@ -957,6 +969,40 @@ func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
}
}
func retryBackoffDelay(attempt int) time.Duration {
// attempt 从 1 开始,表示第 attempt 次请求刚失败,需要等待后进行第 attempt+1 次请求。
if attempt <= 0 {
return retryBaseDelay
}
delay := retryBaseDelay * time.Duration(1<<(attempt-1))
if delay > retryMaxDelay {
return retryMaxDelay
}
return delay
}
func sleepWithContext(ctx context.Context, d time.Duration) error {
if d <= 0 {
return nil
}
timer := time.NewTimer(d)
defer func() {
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}
// isClaudeCodeClient 判断请求是否来自 Claude Code 客户端
// 简化判断User-Agent 匹配 + metadata.user_id 存在
func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
@@ -1073,7 +1119,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 重试循环
var resp *http.Response
for attempt := 1; attempt <= maxRetries; attempt++ {
retryStart := time.Now()
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel)
if err != nil {
@@ -1083,6 +1130,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 发送请求
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
return nil, fmt.Errorf("upstream request failed: %w", err)
}
@@ -1093,28 +1143,80 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
_ = resp.Body.Close()
if s.isThinkingBlockSignatureError(respBody) {
looksLikeToolSignatureError := func(msg string) bool {
m := strings.ToLower(msg)
return strings.Contains(m, "tool_use") ||
strings.Contains(m, "tool_result") ||
strings.Contains(m, "functioncall") ||
strings.Contains(m, "function_call") ||
strings.Contains(m, "functionresponse") ||
strings.Contains(m, "function_response")
}
// 避免在重试预算已耗尽时再发起额外请求
if time.Since(retryStart) >= maxRetryElapsed {
resp.Body = io.NopCloser(bytes.NewReader(respBody))
break
}
log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID)
// 过滤thinking blocks并重试使用更激进的过滤
// Conservative two-stage fallback:
// 1) Disable thinking + thinking->text (preserve content)
// 2) Only if upstream still errors AND error message points to tool/function signature issues:
// also downgrade tool_use/tool_result blocks to text.
filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
// 使用重试后的响应,继续后续处理
if retryResp.StatusCode < 400 {
log.Printf("Account %d: signature error retry succeeded", account.ID)
} else {
log.Printf("Account %d: signature error retry returned status %d", account.ID, retryResp.StatusCode)
log.Printf("Account %d: signature error retry succeeded (thinking downgraded)", account.ID)
resp = retryResp
break
}
retryRespBody, retryReadErr := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
if retryReadErr == nil && retryResp.StatusCode == 400 && s.isThinkingBlockSignatureError(retryRespBody) {
msg2 := extractUpstreamErrorMessage(retryRespBody)
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel)
if buildErr2 == nil {
retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency)
if retryErr2 == nil {
resp = retryResp2
break
}
if retryResp2 != nil && retryResp2.Body != nil {
_ = retryResp2.Body.Close()
}
log.Printf("Account %d: tool-downgrade signature retry failed: %v", account.ID, retryErr2)
} else {
log.Printf("Account %d: tool-downgrade signature retry build failed: %v", account.ID, buildErr2)
}
}
}
// Fall back to the original retry response context.
resp = &http.Response{
StatusCode: retryResp.StatusCode,
Header: retryResp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(retryRespBody)),
}
resp = retryResp
break
}
if retryResp != nil && retryResp.Body != nil {
_ = retryResp.Body.Close()
}
log.Printf("Account %d: signature error retry failed: %v", account.ID, retryErr)
} else {
log.Printf("Account %d: signature error retry build request failed: %v", account.ID, buildErr)
}
// 重试失败,恢复原始响应体继续处理
// Retry failed: restore original response body and continue handling.
resp.Body = io.NopCloser(bytes.NewReader(respBody))
break
}
@@ -1125,11 +1227,27 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 检查是否需要通用重试排除400因为400已经在上面特殊处理过了
if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
if attempt < maxRetries {
log.Printf("Account %d: upstream error %d, retry %d/%d after %v",
account.ID, resp.StatusCode, attempt, maxRetries, retryDelay)
if attempt < maxRetryAttempts {
elapsed := time.Since(retryStart)
if elapsed >= maxRetryElapsed {
break
}
delay := retryBackoffDelay(attempt)
remaining := maxRetryElapsed - elapsed
if delay > remaining {
delay = remaining
}
if delay <= 0 {
break
}
log.Printf("Account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)",
account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay, elapsed, maxRetryElapsed)
_ = resp.Body.Close()
time.Sleep(retryDelay)
if err := sleepWithContext(ctx, delay); err != nil {
return nil, err
}
continue
}
// 最后一次尝试也失败,跳出循环处理重试耗尽
@@ -1146,6 +1264,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
break
}
if resp == nil || resp.Body == nil {
return nil, errors.New("upstream request failed: empty response")
}
defer func() { _ = resp.Body.Close() }()
// 处理重试耗尽的情况
@@ -1229,7 +1350,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages"
if baseURL != "" {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages"
}
}
// OAuth账号应用统一指纹
@@ -1537,10 +1664,10 @@ func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, re
// OAuth/Setup Token 账号的 403标记账号异常
if account.IsOAuth() && statusCode == 403 {
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body)
log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetries, statusCode)
log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetryAttempts, statusCode)
} else {
// API Key 未配置错误码:不标记账号状态
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries)
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetryAttempts)
}
}
@@ -1577,6 +1704,10 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
if s.cfg != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
}
// 设置SSE响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
@@ -1598,51 +1729,133 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body)
// 设置更大的buffer以处理长行
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞导致超时/keepalive无法处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
// 仅监控上游数据间隔超时,避免下游写入阻塞导致误判
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
needModelReplace := originalModel != mappedModel
for scanner.Scan() {
line := scanner.Text()
if line == "event: error" {
return nil, errors.New("have error in stream")
}
// Extract data from SSE line (supports both "data: " and "data:" formats)
if sseDataRe.MatchString(line) {
data := sseDataRe.ReplaceAllString(line, "")
// 如果有模型映射替换响应中的model字段
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
for {
select {
case ev, ok := <-events:
if !ok {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
}
line := ev.line
if line == "event: error" {
return nil, errors.New("have error in stream")
}
// 转发行
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
// Extract data from SSE line (supports both "data: " and "data:" formats)
if sseDataRe.MatchString(line) {
data := sseDataRe.ReplaceAllString(line, "")
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
if firstTokenMs == nil && data != "" && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
// 如果有模型映射替换响应中的model字段
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
// 转发行
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
if firstTokenMs == nil && data != "" && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsage(data, usage)
} else {
// 非 data 行直接转发
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
s.parseSSEUsage(data, usage)
} else {
// 非 data 行直接转发
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
flusher.Flush()
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
sendErrorEvent("stream_timeout")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
}
if err := scanner.Err(); err != nil {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// replaceModelInSSELine 替换SSE数据行中的model字段
@@ -1747,15 +1960,17 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
// 透传响应头
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
contentType := "application/json"
if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
contentType = upstreamType
}
}
// 写入响应
c.Data(resp.StatusCode, "application/json", body)
c.Data(resp.StatusCode, contentType, body)
return &response.Usage, nil
}
@@ -1989,7 +2204,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) {
log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
filteredBody := FilterThinkingBlocks(body)
filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
@@ -2045,7 +2260,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens"
if baseURL != "" {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages/count_tokens"
}
}
// OAuth 账号:应用统一指纹和重写 userID
@@ -2125,6 +2346,25 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
})
}
func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {

View File

@@ -18,9 +18,12 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
)
@@ -41,6 +44,7 @@ type GeminiMessagesCompatService struct {
rateLimitService *RateLimitService
httpUpstream HTTPUpstream
antigravityGatewayService *AntigravityGatewayService
cfg *config.Config
}
func NewGeminiMessagesCompatService(
@@ -51,6 +55,7 @@ func NewGeminiMessagesCompatService(
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
antigravityGatewayService *AntigravityGatewayService,
cfg *config.Config,
) *GeminiMessagesCompatService {
return &GeminiMessagesCompatService{
accountRepo: accountRepo,
@@ -60,6 +65,7 @@ func NewGeminiMessagesCompatService(
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
antigravityGatewayService: antigravityGatewayService,
cfg: cfg,
}
}
@@ -230,6 +236,25 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
return s.antigravityGatewayService
}
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
var accounts []Account
@@ -359,6 +384,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
}
originalClaudeBody := body
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
@@ -381,16 +407,20 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return nil, "", errors.New("gemini api_key not configured")
}
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
action := "generateContent"
if req.Stream {
action = "streamGenerateContent"
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, action)
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
if req.Stream {
fullURL += "?alt=sse"
}
@@ -427,7 +457,11 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" {
// Mode 1: Code Assist API
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, action)
baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), action)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -453,12 +487,16 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, action)
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -479,6 +517,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
var resp *http.Response
signatureRetryStage := 0
for attempt := 1; attempt <= geminiMaxRetries; attempt++ {
upstreamReq, idHeader, err := buildReq(ctx)
if err != nil {
@@ -503,6 +542,46 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries: "+sanitizeUpstreamErrorMessage(err.Error()))
}
// Special-case: signature/thought_signature validation errors are not transient, but may be fixed by
// downgrading Claude thinking/tool history to plain text (conservative two-stage retry).
if resp.StatusCode == http.StatusBadRequest && signatureRetryStage < 2 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if isGeminiSignatureRelatedError(respBody) {
var strippedClaudeBody []byte
stageName := ""
switch signatureRetryStage {
case 0:
// Stage 1: disable thinking + thinking->text
strippedClaudeBody = FilterThinkingBlocksForRetry(originalClaudeBody)
stageName = "thinking-only"
signatureRetryStage = 1
default:
// Stage 2: additionally downgrade tool_use/tool_result blocks to text
strippedClaudeBody = FilterSignatureSensitiveBlocksForRetry(originalClaudeBody)
stageName = "thinking+tools"
signatureRetryStage = 2
}
retryGeminiReq, txErr := convertClaudeMessagesToGeminiGenerateContent(strippedClaudeBody)
if txErr == nil {
log.Printf("Gemini account %d: detected signature-related 400, retrying with downgraded Claude blocks (%s)", account.ID, stageName)
geminiReq = retryGeminiReq
// Consume one retry budget attempt and continue with the updated request payload.
sleepGeminiBackoff(1)
continue
}
}
// Restore body for downstream error handling.
resp = &http.Response{
StatusCode: http.StatusBadRequest,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break
}
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
@@ -600,6 +679,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}, nil
}
func isGeminiSignatureRelatedError(respBody []byte) bool {
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
if msg == "" {
msg = strings.ToLower(string(respBody))
}
return strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature")
}
func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
startTime := time.Now()
@@ -650,12 +737,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return nil, "", errors.New("gemini api_key not configured")
}
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, upstreamAction)
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -687,7 +778,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" && !forceAIStudio {
// Mode 1: Code Assist API
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, upstreamAction)
baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -713,12 +808,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, upstreamAction)
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -1652,6 +1751,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
_ = json.Unmarshal(respBody, &parsed)
}
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
@@ -1676,6 +1777,10 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte
}
log.Printf("[GeminiAPI] ====================================================")
if s.cfg != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
}
c.Status(resp.StatusCode)
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
@@ -1773,11 +1878,15 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
return nil, errors.New("invalid path")
}
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
fullURL := strings.TrimRight(baseURL, "/") + path
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
fullURL := strings.TrimRight(normalizedBaseURL, "/") + path
var proxyURL string
if account.ProxyID != nil && account.Proxy != nil {
@@ -1816,9 +1925,14 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
wwwAuthenticate := resp.Header.Get("Www-Authenticate")
filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders)
if wwwAuthenticate != "" {
filteredHeaders.Set("Www-Authenticate", wwwAuthenticate)
}
return &UpstreamHTTPResult{
StatusCode: resp.StatusCode,
Headers: resp.Header.Clone(),
Headers: filteredHeaders,
Body: body,
}, nil
}

View File

@@ -1000,8 +1000,9 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: strings.TrimSpace(proxyURL),
Timeout: 30 * time.Second,
ProxyURL: strings.TrimSpace(proxyURL),
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}

View File

@@ -16,9 +16,12 @@ import (
"sort"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
)
@@ -630,10 +633,14 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
case AccountTypeAPIKey:
// API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL()
if baseURL != "" {
targetURL = baseURL + "/responses"
} else {
if baseURL == "" {
targetURL = openaiPlatformAPIURL
} else {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/responses"
}
default:
targetURL = openaiPlatformAPIURL
@@ -755,6 +762,10 @@ type openaiStreamingResult struct {
}
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
if s.cfg != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
}
// Set SSE response headers
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
@@ -775,48 +786,158 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
usage := &OpenAIUsage{}
var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
// 仅监控上游数据间隔超时,不被下游写入阻塞影响
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
keepaliveInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
// 下游 keepalive 仅用于防止代理空闲断开
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
lastDataAt := time.Now()
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
needModelReplace := originalModel != mappedModel
for scanner.Scan() {
line := scanner.Text()
// Extract data from SSE line (supports both "data: " and "data:" formats)
if openaiSSEDataRe.MatchString(line) {
data := openaiSSEDataRe.ReplaceAllString(line, "")
// Replace model in response if needed
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
for {
select {
case ev, ok := <-events:
if !ok {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
}
// Forward line
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
line := ev.line
lastDataAt = time.Now()
// Record first token time
if firstTokenMs == nil && data != "" && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
// Extract data from SSE line (supports both "data: " and "data:" formats)
if openaiSSEDataRe.MatchString(line) {
data := openaiSSEDataRe.ReplaceAllString(line, "")
// Replace model in response if needed
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
// Forward line
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
// Record first token time
if firstTokenMs == nil && data != "" && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsage(data, usage)
} else {
// Forward non-data lines as-is
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
s.parseSSEUsage(data, usage)
} else {
// Forward non-data lines as-is
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
sendErrorEvent("stream_timeout")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
}
if err := scanner.Err(); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
}
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
@@ -911,18 +1032,39 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
// Pass through headers
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
contentType := "application/json"
if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
contentType = upstreamType
}
}
c.Data(resp.StatusCode, "application/json", body)
c.Data(resp.StatusCode, contentType, body)
return usage, nil
}
func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {

View File

@@ -0,0 +1,286 @@
package service
import (
"bufio"
"bytes"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
func TestOpenAIStreamingTimeout(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 1,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
start := time.Now()
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, start, "model", "model")
_ = pw.Close()
_ = pr.Close()
if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") {
t.Fatalf("expected stream timeout error, got %v", err)
}
if !strings.Contains(rec.Body.String(), "stream_timeout") {
t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String())
}
}
func TestOpenAIStreamingTooLong(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: 64 * 1024,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
go func() {
defer func() { _ = pw.Close() }()
// 写入超过 MaxLineSize 的单行数据,触发 ErrTooLong
payload := "data: " + strings.Repeat("a", 128*1024) + "\n"
_, _ = pw.Write([]byte(payload))
}()
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 2}, time.Now(), "model", "model")
_ = pr.Close()
if !errors.Is(err, bufio.ErrTooLong) {
t.Fatalf("expected ErrTooLong, got %v", err)
}
if !strings.Contains(rec.Body.String(), "response_too_large") {
t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String())
}
}
func TestOpenAINonStreamingContentTypePassThrough(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Security: config.SecurityConfig{
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader(body)),
Header: http.Header{"Content-Type": []string{"application/vnd.test+json"}},
}
_, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model")
if err != nil {
t.Fatalf("handleNonStreamingResponse error: %v", err)
}
if !strings.Contains(rec.Header().Get("Content-Type"), "application/vnd.test+json") {
t.Fatalf("expected Content-Type passthrough, got %q", rec.Header().Get("Content-Type"))
}
}
func TestOpenAINonStreamingContentTypeDefault(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Security: config.SecurityConfig{
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader(body)),
Header: http.Header{},
}
_, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model")
if err != nil {
t.Fatalf("handleNonStreamingResponse error: %v", err)
}
if !strings.Contains(rec.Header().Get("Content-Type"), "application/json") {
t.Fatalf("expected default Content-Type, got %q", rec.Header().Get("Content-Type"))
}
}
func TestOpenAIStreamingHeadersOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Security: config.SecurityConfig{
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
},
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{
"Cache-Control": []string{"upstream"},
"X-Request-Id": []string{"req-123"},
"Content-Type": []string{"application/custom"},
},
}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {}\n\n"))
}()
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
_ = pr.Close()
if err != nil {
t.Fatalf("handleStreamingResponse error: %v", err)
}
if rec.Header().Get("Cache-Control") != "no-cache" {
t.Fatalf("expected Cache-Control override, got %q", rec.Header().Get("Cache-Control"))
}
if rec.Header().Get("Content-Type") != "text/event-stream" {
t.Fatalf("expected Content-Type override, got %q", rec.Header().Get("Content-Type"))
}
if rec.Header().Get("X-Request-Id") != "req-123" {
t.Fatalf("expected X-Request-Id passthrough, got %q", rec.Header().Get("X-Request-Id"))
}
}
func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{Enabled: false},
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Credentials: map[string]any{"base_url": "://invalid-url"},
}
_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false)
if err == nil {
t.Fatalf("expected error for invalid base_url when allowlist disabled")
}
}
func TestOpenAIValidateUpstreamBaseURLDisabledRequiresHTTPS(t *testing.T) {
cfg := &config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{Enabled: false},
},
}
svc := &OpenAIGatewayService{cfg: cfg}
if _, err := svc.validateUpstreamBaseURL("http://not-https.example.com"); err == nil {
t.Fatalf("expected http to be rejected when allow_insecure_http is false")
}
normalized, err := svc.validateUpstreamBaseURL("https://example.com")
if err != nil {
t.Fatalf("expected https to be allowed when allowlist disabled, got %v", err)
}
if normalized != "https://example.com" {
t.Fatalf("expected raw url passthrough, got %q", normalized)
}
}
func TestOpenAIValidateUpstreamBaseURLDisabledAllowsHTTP(t *testing.T) {
cfg := &config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{
Enabled: false,
AllowInsecureHTTP: true,
},
},
}
svc := &OpenAIGatewayService{cfg: cfg}
normalized, err := svc.validateUpstreamBaseURL("http://not-https.example.com")
if err != nil {
t.Fatalf("expected http allowed when allow_insecure_http is true, got %v", err)
}
if normalized != "http://not-https.example.com" {
t.Fatalf("expected raw url passthrough, got %q", normalized)
}
}
func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) {
cfg := &config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{
Enabled: true,
UpstreamHosts: []string{"example.com"},
},
},
}
svc := &OpenAIGatewayService{cfg: cfg}
if _, err := svc.validateUpstreamBaseURL("https://example.com"); err != nil {
t.Fatalf("expected allowlisted host to pass, got %v", err)
}
if _, err := svc.validateUpstreamBaseURL("https://evil.com"); err == nil {
t.Fatalf("expected non-allowlisted host to fail")
}
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
var (
@@ -213,16 +214,35 @@ func (s *PricingService) syncWithRemote() error {
// downloadPricingData 从远程下载价格数据
func (s *PricingService) downloadPricingData() error {
log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL)
remoteURL, err := s.validatePricingURL(s.cfg.Pricing.RemoteURL)
if err != nil {
return err
}
log.Printf("[Pricing] Downloading from %s", remoteURL)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
body, err := s.remoteClient.FetchPricingJSON(ctx, s.cfg.Pricing.RemoteURL)
var expectedHash string
if strings.TrimSpace(s.cfg.Pricing.HashURL) != "" {
expectedHash, err = s.fetchRemoteHash()
if err != nil {
return fmt.Errorf("fetch remote hash: %w", err)
}
}
body, err := s.remoteClient.FetchPricingJSON(ctx, remoteURL)
if err != nil {
return fmt.Errorf("download failed: %w", err)
}
if expectedHash != "" {
actualHash := sha256.Sum256(body)
if !strings.EqualFold(expectedHash, hex.EncodeToString(actualHash[:])) {
return fmt.Errorf("pricing hash mismatch")
}
}
// 解析JSON数据使用灵活的解析方式
data, err := s.parsePricingData(body)
if err != nil {
@@ -378,10 +398,38 @@ func (s *PricingService) useFallbackPricing() error {
// fetchRemoteHash 从远程获取哈希值
func (s *PricingService) fetchRemoteHash() (string, error) {
hashURL, err := s.validatePricingURL(s.cfg.Pricing.HashURL)
if err != nil {
return "", err
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
return s.remoteClient.FetchHashText(ctx, s.cfg.Pricing.HashURL)
hash, err := s.remoteClient.FetchHashText(ctx, hashURL)
if err != nil {
return "", err
}
return strings.TrimSpace(hash), nil
}
func (s *PricingService) validatePricingURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
if err != nil {
return "", fmt.Errorf("invalid pricing url: %w", err)
}
return normalized, nil
}
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.PricingHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid pricing url: %w", err)
}
return normalized, nil
}
// computeFileHash 计算文件哈希

View File

@@ -130,6 +130,10 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyFallbackModelGemini] = settings.FallbackModelGemini
updates[SettingKeyFallbackModelAntigravity] = settings.FallbackModelAntigravity
// Identity patch configuration (Claude -> Gemini)
updates[SettingKeyEnableIdentityPatch] = strconv.FormatBool(settings.EnableIdentityPatch)
updates[SettingKeyIdentityPatchPrompt] = settings.IdentityPatchPrompt
return s.settingRepo.SetMultiple(ctx, updates)
}
@@ -213,6 +217,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyFallbackModelOpenAI: "gpt-4o",
SettingKeyFallbackModelGemini: "gemini-2.5-pro",
SettingKeyFallbackModelAntigravity: "gemini-2.5-pro",
// Identity patch defaults
SettingKeyEnableIdentityPatch: "true",
SettingKeyIdentityPatchPrompt: "",
}
return s.settingRepo.SetMultiple(ctx, defaults)
@@ -221,21 +228,23 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// parseSettings 解析设置到结构体
func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings {
result := &SystemSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
SMTPHost: settings[SettingKeySMTPHost],
SMTPUsername: settings[SettingKeySMTPUsername],
SMTPFrom: settings[SettingKeySMTPFrom],
SMTPFromName: settings[SettingKeySMTPFromName],
SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL],
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
SMTPHost: settings[SettingKeySMTPHost],
SMTPUsername: settings[SettingKeySMTPUsername],
SMTPFrom: settings[SettingKeySMTPFrom],
SMTPFromName: settings[SettingKeySMTPFromName],
SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
SMTPPasswordConfigured: settings[SettingKeySMTPPassword] != "",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "",
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL],
}
// 解析整数类型
@@ -269,6 +278,14 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.FallbackModelGemini = s.getStringOrDefault(settings, SettingKeyFallbackModelGemini, "gemini-2.5-pro")
result.FallbackModelAntigravity = s.getStringOrDefault(settings, SettingKeyFallbackModelAntigravity, "gemini-2.5-pro")
// Identity patch settings (default: enabled, to preserve existing behavior)
if v, ok := settings[SettingKeyEnableIdentityPatch]; ok && v != "" {
result.EnableIdentityPatch = v == "true"
} else {
result.EnableIdentityPatch = true
}
result.IdentityPatchPrompt = settings[SettingKeyIdentityPatchPrompt]
return result
}
@@ -298,6 +315,25 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
return value
}
// IsIdentityPatchEnabled 检查是否启用身份补丁Claude -> Gemini systemInstruction 注入)
func (s *SettingService) IsIdentityPatchEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyEnableIdentityPatch)
if err != nil {
// 默认开启,保持兼容
return true
}
return value == "true"
}
// GetIdentityPatchPrompt 获取自定义身份补丁提示词(为空表示使用内置默认模板)
func (s *SettingService) GetIdentityPatchPrompt(ctx context.Context) string {
value, err := s.settingRepo.GetValue(ctx, SettingKeyIdentityPatchPrompt)
if err != nil {
return ""
}
return value
}
// GenerateAdminAPIKey 生成新的管理员 API Key
func (s *SettingService) GenerateAdminAPIKey(ctx context.Context) (string, error) {
// 生成 32 字节随机数 = 64 位十六进制字符

View File

@@ -4,17 +4,19 @@ type SystemSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
SMTPHost string
SMTPPort int
SMTPUsername string
SMTPPassword string
SMTPFrom string
SMTPFromName string
SMTPUseTLS bool
SMTPHost string
SMTPPort int
SMTPUsername string
SMTPPassword string
SMTPPasswordConfigured bool
SMTPFrom string
SMTPFromName string
SMTPUseTLS bool
TurnstileEnabled bool
TurnstileSiteKey string
TurnstileSecretKey string
TurnstileEnabled bool
TurnstileSiteKey string
TurnstileSecretKey string
TurnstileSecretKeyConfigured bool
SiteName string
SiteLogo string
@@ -32,6 +34,10 @@ type SystemSettings struct {
FallbackModelOpenAI string `json:"fallback_model_openai"`
FallbackModelGemini string `json:"fallback_model_gemini"`
FallbackModelAntigravity string `json:"fallback_model_antigravity"`
// Identity patch configuration (Claude -> Gemini)
EnableIdentityPatch bool `json:"enable_identity_patch"`
IdentityPatchPrompt string `json:"identity_patch_prompt"`
}
type PublicSettings struct {

View File

@@ -21,10 +21,44 @@ import (
// Config paths
const (
ConfigFile = "config.yaml"
EnvFile = ".env"
ConfigFileName = "config.yaml"
InstallLockFile = ".installed"
)
// GetDataDir returns the data directory for storing config and lock files.
// Priority: DATA_DIR env > /app/data (if exists and writable) > current directory
func GetDataDir() string {
// Check DATA_DIR environment variable first
if dir := os.Getenv("DATA_DIR"); dir != "" {
return dir
}
// Check if /app/data exists and is writable (Docker environment)
dockerDataDir := "/app/data"
if info, err := os.Stat(dockerDataDir); err == nil && info.IsDir() {
// Try to check if writable by creating a temp file
testFile := dockerDataDir + "/.write_test"
if f, err := os.Create(testFile); err == nil {
_ = f.Close()
_ = os.Remove(testFile)
return dockerDataDir
}
}
// Default to current directory
return "."
}
// GetConfigFilePath returns the full path to config.yaml
func GetConfigFilePath() string {
return GetDataDir() + "/" + ConfigFileName
}
// GetInstallLockPath returns the full path to .installed lock file
func GetInstallLockPath() string {
return GetDataDir() + "/" + InstallLockFile
}
// SetupConfig holds the setup configuration
type SetupConfig struct {
Database DatabaseConfig `json:"database" yaml:"database"`
@@ -71,13 +105,12 @@ type JWTConfig struct {
// Uses multiple checks to prevent attackers from forcing re-setup by deleting config
func NeedsSetup() bool {
// Check 1: Config file must not exist
if _, err := os.Stat(ConfigFile); !os.IsNotExist(err) {
if _, err := os.Stat(GetConfigFilePath()); !os.IsNotExist(err) {
return false // Config exists, no setup needed
}
// Check 2: Installation lock file (harder to bypass)
lockFile := ".installed"
if _, err := os.Stat(lockFile); !os.IsNotExist(err) {
if _, err := os.Stat(GetInstallLockPath()); !os.IsNotExist(err) {
return false // Lock file exists, already installed
}
@@ -201,6 +234,7 @@ func Install(cfg *SetupConfig) error {
return fmt.Errorf("failed to generate jwt secret: %w", err)
}
cfg.JWT.Secret = secret
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
}
// Test connections
@@ -237,9 +271,8 @@ func Install(cfg *SetupConfig) error {
// createInstallLock creates a lock file to prevent re-installation attacks
func createInstallLock() error {
lockFile := ".installed"
content := fmt.Sprintf("installed_at=%s\n", time.Now().UTC().Format(time.RFC3339))
return os.WriteFile(lockFile, []byte(content), 0400) // Read-only for owner
return os.WriteFile(GetInstallLockPath(), []byte(content), 0400) // Read-only for owner
}
func initializeDatabase(cfg *SetupConfig) error {
@@ -390,7 +423,7 @@ func writeConfigFile(cfg *SetupConfig) error {
return err
}
return os.WriteFile(ConfigFile, data, 0600)
return os.WriteFile(GetConfigFilePath(), data, 0600)
}
func generateSecret(length int) (string, error) {
@@ -433,6 +466,7 @@ func getEnvIntOrDefault(key string, defaultValue int) int {
// This is designed for Docker deployment where all config is passed via env vars
func AutoSetupFromEnv() error {
log.Println("Auto setup enabled, configuring from environment variables...")
log.Printf("Data directory: %s", GetDataDir())
// Get timezone from TZ or TIMEZONE env var (TZ is standard for Docker)
tz := getEnvOrDefault("TZ", "")
@@ -479,7 +513,7 @@ func AutoSetupFromEnv() error {
return fmt.Errorf("failed to generate jwt secret: %w", err)
}
cfg.JWT.Secret = secret
log.Println("Generated JWT secret automatically")
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
}
// Generate admin password if not provided
@@ -489,8 +523,8 @@ func AutoSetupFromEnv() error {
return fmt.Errorf("failed to generate admin password: %w", err)
}
cfg.Admin.Password = password
log.Printf("Generated admin password: %s", cfg.Admin.Password)
log.Println("IMPORTANT: Save this password! It will not be shown again.")
fmt.Printf("Generated admin password (one-time): %s\n", cfg.Admin.Password)
fmt.Println("IMPORTANT: Save this password! It will not be shown again.")
}
// Test database connection

View File

@@ -0,0 +1,100 @@
package logredact
import (
"encoding/json"
"strings"
)
// maxRedactDepth 限制递归深度以防止栈溢出
const maxRedactDepth = 32
var defaultSensitiveKeys = map[string]struct{}{
"authorization_code": {},
"code": {},
"code_verifier": {},
"access_token": {},
"refresh_token": {},
"id_token": {},
"client_secret": {},
"password": {},
}
func RedactMap(input map[string]any, extraKeys ...string) map[string]any {
if input == nil {
return map[string]any{}
}
keys := buildKeySet(extraKeys)
redacted, ok := redactValueWithDepth(input, keys, 0).(map[string]any)
if !ok {
return map[string]any{}
}
return redacted
}
func RedactJSON(raw []byte, extraKeys ...string) string {
if len(raw) == 0 {
return ""
}
var value any
if err := json.Unmarshal(raw, &value); err != nil {
return "<non-json payload redacted>"
}
keys := buildKeySet(extraKeys)
redacted := redactValueWithDepth(value, keys, 0)
encoded, err := json.Marshal(redacted)
if err != nil {
return "<redacted>"
}
return string(encoded)
}
func buildKeySet(extraKeys []string) map[string]struct{} {
keys := make(map[string]struct{}, len(defaultSensitiveKeys)+len(extraKeys))
for k := range defaultSensitiveKeys {
keys[k] = struct{}{}
}
for _, key := range extraKeys {
normalized := normalizeKey(key)
if normalized == "" {
continue
}
keys[normalized] = struct{}{}
}
return keys
}
func redactValueWithDepth(value any, keys map[string]struct{}, depth int) any {
if depth > maxRedactDepth {
return "<depth limit exceeded>"
}
switch v := value.(type) {
case map[string]any:
out := make(map[string]any, len(v))
for k, val := range v {
if isSensitiveKey(k, keys) {
out[k] = "***"
continue
}
out[k] = redactValueWithDepth(val, keys, depth+1)
}
return out
case []any:
out := make([]any, len(v))
for i, item := range v {
out[i] = redactValueWithDepth(item, keys, depth+1)
}
return out
default:
return value
}
}
func isSensitiveKey(key string, keys map[string]struct{}) bool {
_, ok := keys[normalizeKey(key)]
return ok
}
func normalizeKey(key string) string {
return strings.ToLower(strings.TrimSpace(key))
}

View File

@@ -0,0 +1,99 @@
package responseheaders
import (
"net/http"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// defaultAllowed 定义允许透传的响应头白名单
// 注意:以下头部由 Go HTTP 包自动处理,不应手动设置:
// - content-length: 由 ResponseWriter 根据实际写入数据自动设置
// - transfer-encoding: 由 HTTP 库根据需要自动添加/移除
// - connection: 由 HTTP 库管理连接复用
var defaultAllowed = map[string]struct{}{
"content-type": {},
"content-encoding": {},
"content-language": {},
"cache-control": {},
"etag": {},
"last-modified": {},
"expires": {},
"vary": {},
"date": {},
"x-request-id": {},
"x-ratelimit-limit-requests": {},
"x-ratelimit-limit-tokens": {},
"x-ratelimit-remaining-requests": {},
"x-ratelimit-remaining-tokens": {},
"x-ratelimit-reset-requests": {},
"x-ratelimit-reset-tokens": {},
"retry-after": {},
"location": {},
"www-authenticate": {},
}
// hopByHopHeaders 是跳过的 hop-by-hop 头部,这些头部由 HTTP 库自动处理
var hopByHopHeaders = map[string]struct{}{
"content-length": {},
"transfer-encoding": {},
"connection": {},
}
func FilterHeaders(src http.Header, cfg config.ResponseHeaderConfig) http.Header {
allowed := make(map[string]struct{}, len(defaultAllowed)+len(cfg.AdditionalAllowed))
for key := range defaultAllowed {
allowed[key] = struct{}{}
}
// 关闭时只使用默认白名单additional/force_remove 不生效
if cfg.Enabled {
for _, key := range cfg.AdditionalAllowed {
normalized := strings.ToLower(strings.TrimSpace(key))
if normalized == "" {
continue
}
allowed[normalized] = struct{}{}
}
}
forceRemove := map[string]struct{}{}
if cfg.Enabled {
forceRemove = make(map[string]struct{}, len(cfg.ForceRemove))
for _, key := range cfg.ForceRemove {
normalized := strings.ToLower(strings.TrimSpace(key))
if normalized == "" {
continue
}
forceRemove[normalized] = struct{}{}
}
}
filtered := make(http.Header, len(src))
for key, values := range src {
lower := strings.ToLower(key)
if _, blocked := forceRemove[lower]; blocked {
continue
}
if _, ok := allowed[lower]; !ok {
continue
}
// 跳过 hop-by-hop 头部,这些由 HTTP 库自动处理
if _, isHopByHop := hopByHopHeaders[lower]; isHopByHop {
continue
}
for _, value := range values {
filtered.Add(key, value)
}
}
return filtered
}
func WriteFilteredHeaders(dst http.Header, src http.Header, cfg config.ResponseHeaderConfig) {
filtered := FilterHeaders(src, cfg)
for key, values := range filtered {
for _, value := range values {
dst.Add(key, value)
}
}
}

View File

@@ -0,0 +1,67 @@
package responseheaders
import (
"net/http"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
func TestFilterHeadersDisabledUsesDefaultAllowlist(t *testing.T) {
src := http.Header{}
src.Add("Content-Type", "application/json")
src.Add("X-Request-Id", "req-123")
src.Add("X-Test", "ok")
src.Add("Connection", "keep-alive")
src.Add("Content-Length", "123")
cfg := config.ResponseHeaderConfig{
Enabled: false,
ForceRemove: []string{"x-request-id"},
}
filtered := FilterHeaders(src, cfg)
if filtered.Get("Content-Type") != "application/json" {
t.Fatalf("expected Content-Type passthrough, got %q", filtered.Get("Content-Type"))
}
if filtered.Get("X-Request-Id") != "req-123" {
t.Fatalf("expected X-Request-Id allowed, got %q", filtered.Get("X-Request-Id"))
}
if filtered.Get("X-Test") != "" {
t.Fatalf("expected X-Test removed, got %q", filtered.Get("X-Test"))
}
if filtered.Get("Connection") != "" {
t.Fatalf("expected Connection to be removed, got %q", filtered.Get("Connection"))
}
if filtered.Get("Content-Length") != "" {
t.Fatalf("expected Content-Length to be removed, got %q", filtered.Get("Content-Length"))
}
}
func TestFilterHeadersEnabledUsesAllowlist(t *testing.T) {
src := http.Header{}
src.Add("Content-Type", "application/json")
src.Add("X-Extra", "ok")
src.Add("X-Remove", "nope")
src.Add("X-Blocked", "nope")
cfg := config.ResponseHeaderConfig{
Enabled: true,
AdditionalAllowed: []string{"x-extra"},
ForceRemove: []string{"x-remove"},
}
filtered := FilterHeaders(src, cfg)
if filtered.Get("Content-Type") != "application/json" {
t.Fatalf("expected Content-Type allowed, got %q", filtered.Get("Content-Type"))
}
if filtered.Get("X-Extra") != "ok" {
t.Fatalf("expected X-Extra allowed, got %q", filtered.Get("X-Extra"))
}
if filtered.Get("X-Remove") != "" {
t.Fatalf("expected X-Remove removed, got %q", filtered.Get("X-Remove"))
}
if filtered.Get("X-Blocked") != "" {
t.Fatalf("expected X-Blocked removed, got %q", filtered.Get("X-Blocked"))
}
}

View File

@@ -0,0 +1,154 @@
package urlvalidator
import (
"context"
"errors"
"fmt"
"net"
"net/url"
"strconv"
"strings"
"time"
)
type ValidationOptions struct {
AllowedHosts []string
RequireAllowlist bool
AllowPrivate bool
}
func ValidateURLFormat(raw string, allowInsecureHTTP bool) (string, error) {
// 最小格式校验:仅保证 URL 可解析且 scheme 合规,不做白名单/私网/SSRF 校验
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return "", errors.New("url is required")
}
parsed, err := url.Parse(trimmed)
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
return "", fmt.Errorf("invalid url: %s", trimmed)
}
scheme := strings.ToLower(parsed.Scheme)
if scheme != "https" && (!allowInsecureHTTP || scheme != "http") {
return "", fmt.Errorf("invalid url scheme: %s", parsed.Scheme)
}
host := strings.TrimSpace(parsed.Hostname())
if host == "" {
return "", errors.New("invalid host")
}
if port := parsed.Port(); port != "" {
num, err := strconv.Atoi(port)
if err != nil || num <= 0 || num > 65535 {
return "", fmt.Errorf("invalid port: %s", port)
}
}
return trimmed, nil
}
func ValidateHTTPSURL(raw string, opts ValidationOptions) (string, error) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return "", errors.New("url is required")
}
parsed, err := url.Parse(trimmed)
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
return "", fmt.Errorf("invalid url: %s", trimmed)
}
if !strings.EqualFold(parsed.Scheme, "https") {
return "", fmt.Errorf("invalid url scheme: %s", parsed.Scheme)
}
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
if host == "" {
return "", errors.New("invalid host")
}
if !opts.AllowPrivate && isBlockedHost(host) {
return "", fmt.Errorf("host is not allowed: %s", host)
}
allowlist := normalizeAllowlist(opts.AllowedHosts)
if opts.RequireAllowlist && len(allowlist) == 0 {
return "", errors.New("allowlist is not configured")
}
if len(allowlist) > 0 && !isAllowedHost(host, allowlist) {
return "", fmt.Errorf("host is not allowed: %s", host)
}
parsed.Path = strings.TrimRight(parsed.Path, "/")
parsed.RawPath = ""
return strings.TrimRight(parsed.String(), "/"), nil
}
// ValidateResolvedIP 验证 DNS 解析后的 IP 地址是否安全
// 用于防止 DNS Rebinding 攻击:在实际 HTTP 请求时调用此函数验证解析后的 IP
func ValidateResolvedIP(host string) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
if err != nil {
return fmt.Errorf("dns resolution failed: %w", err)
}
for _, ip := range ips {
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() || ip.IsUnspecified() {
return fmt.Errorf("resolved ip %s is not allowed", ip.String())
}
}
return nil
}
func normalizeAllowlist(values []string) []string {
if len(values) == 0 {
return nil
}
normalized := make([]string, 0, len(values))
for _, v := range values {
entry := strings.ToLower(strings.TrimSpace(v))
if entry == "" {
continue
}
if host, _, err := net.SplitHostPort(entry); err == nil {
entry = host
}
normalized = append(normalized, entry)
}
return normalized
}
func isAllowedHost(host string, allowlist []string) bool {
for _, entry := range allowlist {
if entry == "" {
continue
}
if strings.HasPrefix(entry, "*.") {
suffix := strings.TrimPrefix(entry, "*.")
if host == suffix || strings.HasSuffix(host, "."+suffix) {
return true
}
continue
}
if host == entry {
return true
}
}
return false
}
func isBlockedHost(host string) bool {
if host == "localhost" || strings.HasSuffix(host, ".localhost") {
return true
}
if ip := net.ParseIP(host); ip != nil {
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsUnspecified() {
return true
}
}
return false
}

View File

@@ -0,0 +1,24 @@
package urlvalidator
import "testing"
func TestValidateURLFormat(t *testing.T) {
if _, err := ValidateURLFormat("", false); err == nil {
t.Fatalf("expected empty url to fail")
}
if _, err := ValidateURLFormat("://bad", false); err == nil {
t.Fatalf("expected invalid url to fail")
}
if _, err := ValidateURLFormat("http://example.com", false); err == nil {
t.Fatalf("expected http to fail when allow_insecure_http is false")
}
if _, err := ValidateURLFormat("https://example.com", false); err != nil {
t.Fatalf("expected https to pass, got %v", err)
}
if _, err := ValidateURLFormat("http://example.com", true); err != nil {
t.Fatalf("expected http to pass when allow_insecure_http is true, got %v", err)
}
if _, err := ValidateURLFormat("https://example.com:bad", true); err == nil {
t.Fatalf("expected invalid port to fail")
}
}

View File

@@ -0,0 +1,7 @@
-- 028_add_account_notes.sql
-- Add optional admin notes for accounts.
ALTER TABLE accounts
ADD COLUMN IF NOT EXISTS notes TEXT;
COMMENT ON COLUMN accounts.notes IS 'Admin-only notes for account';

View File

@@ -54,10 +54,21 @@ ADMIN_PASSWORD=
# -----------------------------------------------------------------------------
# JWT Configuration
# -----------------------------------------------------------------------------
# Leave empty to auto-generate (recommended)
# IMPORTANT: Set a fixed JWT_SECRET to prevent login sessions from being
# invalidated after container restarts. If left empty, a random secret will
# be generated on each startup, causing all users to be logged out.
# Generate a secure secret: openssl rand -hex 32
JWT_SECRET=
JWT_EXPIRE_HOUR=24
# -----------------------------------------------------------------------------
# Configuration File (Optional)
# -----------------------------------------------------------------------------
# Path to custom config file (relative to docker-compose.yml directory)
# Copy config.example.yaml to config.yaml and modify as needed
# Leave unset to use default ./config.yaml
#CONFIG_FILE=./config.yaml
# -----------------------------------------------------------------------------
# Gemini OAuth (OPTIONAL, required only for Gemini OAuth accounts)
# -----------------------------------------------------------------------------

View File

@@ -25,8 +25,8 @@
timeouts {
read_body 30s
read_header 10s
write 60s
idle 120s
write 300s
idle 300s
}
}
}
@@ -77,7 +77,10 @@ example.com {
write_buffer 16KB
compression off
}
# SSE/流式传输优化:禁用响应缓冲,立即刷新数据给客户端
flush_interval -1
# 故障转移
fail_duration 30s
max_fails 3
@@ -92,6 +95,10 @@ example.com {
gzip 6
minimum_length 256
match {
# SSE 请求通常会带 Accept: text/event-stream需排除压缩
not header Accept text/event-stream*
# 排除已知 SSE 路径(即便 Accept 缺失)
not path /v1/messages /v1/responses /responses /antigravity/v1/messages /v1beta/models/* /antigravity/v1beta/models/*
header Content-Type text/*
header Content-Type application/json*
header Content-Type application/javascript*
@@ -179,6 +186,3 @@ example.com {
# =============================================================================
# HTTP 重定向到 HTTPS (Caddy 默认自动处理,此处显式声明)
# =============================================================================
; http://example.com {
; redir https://{host}{uri} permanent
; }

View File

@@ -1,174 +1,390 @@
# Sub2API Configuration File
# Sub2API 配置文件
#
# Copy this file to /etc/sub2api/config.yaml and modify as needed
# Documentation: https://github.com/Wei-Shaw/sub2api
# 复制此文件到 /etc/sub2api/config.yaml 并根据需要修改
#
# Documentation / 文档: https://github.com/Wei-Shaw/sub2api
# =============================================================================
# Server Configuration
# 服务器配置
# =============================================================================
server:
# Bind address (0.0.0.0 for all interfaces)
# 绑定地址0.0.0.0 表示监听所有网络接口)
host: "0.0.0.0"
# Port to listen on
# 监听端口
port: 8080
# Mode: "debug" for development, "release" for production
# 运行模式:"debug" 用于开发,"release" 用于生产环境
mode: "release"
# Trusted proxies for X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies.
# 信任的代理地址CIDR/IP 格式),用于解析 X-Forwarded-For 头。留空则禁用代理信任。
trusted_proxies: []
# =============================================================================
# Run Mode Configuration
# 运行模式配置
# =============================================================================
# Run mode: "standard" (default) or "simple" (for internal use)
# 运行模式:"standard"(默认)或 "simple"(内部使用)
# - standard: Full SaaS features with billing/balance checks
# - standard: 完整 SaaS 功能,包含计费和余额校验
# - simple: Hides SaaS features and skips billing/balance checks
# - simple: 隐藏 SaaS 功能,跳过计费和余额校验
run_mode: "standard"
# =============================================================================
# CORS Configuration
# 跨域资源共享 (CORS) 配置
# =============================================================================
cors:
# Allowed origins list. Leave empty to disable cross-origin requests.
# 允许的来源列表。留空则禁用跨域请求。
allowed_origins: []
# Allow credentials (cookies/authorization headers). Cannot be used with "*".
# 允许携带凭证cookies/授权头)。不能与 "*" 通配符同时使用。
allow_credentials: true
# =============================================================================
# Security Configuration
# 安全配置
# =============================================================================
security:
url_allowlist:
# Enable URL allowlist validation (disable to skip all URL checks)
# 启用 URL 白名单验证(禁用则跳过所有 URL 检查)
enabled: false
# Allowed upstream hosts for API proxying
# 允许代理的上游 API 主机列表
upstream_hosts:
- "api.openai.com"
- "api.anthropic.com"
- "api.kimi.com"
- "open.bigmodel.cn"
- "api.minimaxi.com"
- "generativelanguage.googleapis.com"
- "cloudcode-pa.googleapis.com"
- "*.openai.azure.com"
# Allowed hosts for pricing data download
# 允许下载定价数据的主机列表
pricing_hosts:
- "raw.githubusercontent.com"
# Allowed hosts for CRS sync (required when using CRS sync)
# 允许 CRS 同步的主机列表(使用 CRS 同步功能时必须配置)
crs_hosts: []
# Allow localhost/private IPs for upstream/pricing/CRS (use only in trusted networks)
# 允许本地/私有 IP 地址用于上游/定价/CRS仅在可信网络中使用
allow_private_hosts: true
# Allow http:// URLs when allowlist is disabled (default: false, require https)
# 白名单禁用时是否允许 http:// URL默认: false要求 https
allow_insecure_http: true
response_headers:
# Enable configurable response header filtering (disable to use default allowlist)
# 启用可配置的响应头过滤(禁用则使用默认白名单)
enabled: false
# Extra allowed response headers from upstream
# 额外允许的上游响应头
additional_allowed: []
# Force-remove response headers from upstream
# 强制移除的上游响应头
force_remove: []
csp:
# Enable Content-Security-Policy header
# 启用内容安全策略 (CSP) 响应头
enabled: true
# Default CSP policy (override if you host assets on other domains)
# 默认 CSP 策略(如果静态资源托管在其他域名,请自行覆盖)
policy: "default-src 'self'; script-src 'self' https://challenges.cloudflare.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
proxy_probe:
# Allow skipping TLS verification for proxy probe (debug only)
# 允许代理探测时跳过 TLS 证书验证(仅用于调试)
insecure_skip_verify: false
# =============================================================================
# Gateway Configuration
# 网关配置
# =============================================================================
gateway:
# Timeout for waiting upstream response headers (seconds)
# 等待上游响应头超时时间(秒)
response_header_timeout: 300
response_header_timeout: 600
# Max request body size in bytes (default: 100MB)
# 请求体最大字节数(默认 100MB
max_body_size: 104857600
# Connection pool isolation strategy:
# 连接池隔离策略:
# - proxy: Isolate by proxy, same proxy shares connection pool (suitable for few proxies, many accounts)
# - proxy: 按代理隔离,同一代理共享连接池(适合代理少、账户多)
# - account: Isolate by account, same account shares connection pool (suitable for few accounts, strict isolation)
# - account: 按账户隔离,同一账户共享连接池(适合账户少、需严格隔离)
# - account_proxy: Isolate by account+proxy combination (default, finest granularity)
# - account_proxy: 按账户+代理组合隔离(默认,最细粒度)
connection_pool_isolation: "account_proxy"
# HTTP 上游连接池配置HTTP/2 + 多代理场景默认)
# HTTP upstream connection pool settings (HTTP/2 + multi-proxy scenario defaults)
# HTTP 上游连接池配置HTTP/2 + 多代理场景默认值)
# Max idle connections across all hosts
# 所有主机的最大空闲连接数
max_idle_conns: 240
# Max idle connections per host
# 每个主机的最大空闲连接数
max_idle_conns_per_host: 120
# Max connections per host
# 每个主机的最大连接数
max_conns_per_host: 240
idle_conn_timeout_seconds: 300
# Idle connection timeout (seconds)
# 空闲连接超时时间(秒)
idle_conn_timeout_seconds: 90
# Upstream client cache settings
# 上游连接池客户端缓存配置
# max_upstream_clients: Max cached clients, evicts least recently used when exceeded
# max_upstream_clients: 最大缓存客户端数量,超出后淘汰最久未使用的
# client_idle_ttl_seconds: 客户端空闲回收阈值(秒),超时且无活跃请求时回收
max_upstream_clients: 5000
# client_idle_ttl_seconds: Client idle reclaim threshold (seconds), reclaimed when idle and no active requests
# client_idle_ttl_seconds: 客户端空闲回收阈值(秒),超时且无活跃请求时回收
client_idle_ttl_seconds: 900
# Concurrency slot expiration time (minutes)
# 并发槽位过期时间(分钟)
concurrency_slot_ttl_minutes: 15
concurrency_slot_ttl_minutes: 30
# Stream data interval timeout (seconds), 0=disable
# 流数据间隔超时0=禁用
stream_data_interval_timeout: 180
# Stream keepalive interval (seconds), 0=disable
# 流式 keepalive 间隔0=禁用
stream_keepalive_interval: 10
# SSE max line size in bytes (default: 10MB)
# SSE 单行最大字节数(默认 10MB
max_line_size: 10485760
# Log upstream error response body summary (safe/truncated; does not log request content)
# 记录上游错误响应体摘要(安全/截断;不记录请求内容)
log_upstream_error_body: false
# Max bytes to log from upstream error body
# 记录上游错误响应体的最大字节数
log_upstream_error_body_max_bytes: 2048
# Auto inject anthropic-beta header for API-key accounts when needed (default: off)
# 需要时自动为 API-key 账户注入 anthropic-beta 头(默认:关闭)
inject_beta_for_apikey: false
# Allow failover on selected 400 errors (default: off)
# 允许在特定 400 错误时进行故障转移(默认:关闭)
failover_on_400: false
# =============================================================================
# Concurrency Wait Configuration
# 并发等待配置
# =============================================================================
concurrency:
# SSE ping interval during concurrency wait (seconds)
# 并发等待期间的 SSE ping 间隔(秒)
ping_interval: 10
# =============================================================================
# Database Configuration (PostgreSQL)
# 数据库配置 (PostgreSQL)
# =============================================================================
database:
# Database host address
# 数据库主机地址
host: "localhost"
# Database port
# 数据库端口
port: 5432
# Database username
# 数据库用户名
user: "postgres"
# Database password
# 数据库密码
password: "your_secure_password_here"
# Database name
# 数据库名称
dbname: "sub2api"
# SSL mode: disable, require, verify-ca, verify-full
# SSL 模式disable禁用, require要求, verify-ca验证CA, verify-full完全验证
sslmode: "disable"
# =============================================================================
# Redis Configuration
# Redis 配置
# =============================================================================
redis:
# Redis host address
# Redis 主机地址
host: "localhost"
# Redis port
# Redis 端口
port: 6379
# Leave empty if no password is set
# Redis password (leave empty if no password is set)
# Redis 密码(如果未设置密码则留空)
password: ""
# Database number (0-15)
# 数据库编号0-15
db: 0
# =============================================================================
# JWT Configuration
# JWT 配置
# =============================================================================
jwt:
# IMPORTANT: Change this to a random string in production!
# Generate with: openssl rand -hex 32
# 重要:生产环境中请更改为随机字符串!
# Generate with / 生成命令: openssl rand -hex 32
secret: "change-this-to-a-secure-random-string"
# Token expiration time in hours
# Token expiration time in hours (max 24)
# 令牌过期时间(小时,最大 24
expire_hour: 24
# =============================================================================
# Default Settings
# 默认设置
# =============================================================================
default:
# Initial admin account (created on first run)
# 初始管理员账户(首次运行时创建)
admin_email: "admin@example.com"
admin_password: "admin123"
# Default settings for new users
user_concurrency: 5 # Max concurrent requests per user
user_balance: 0 # Initial balance for new users
# 新用户默认设置
# Max concurrent requests per user
# 每用户最大并发请求数
user_concurrency: 5
# Initial balance for new users
# 新用户初始余额
user_balance: 0
# API key settings
api_key_prefix: "sk-" # Prefix for generated API keys
# API 密钥设置
# Prefix for generated API keys
# 生成的 API 密钥前缀
api_key_prefix: "sk-"
# Rate multiplier (affects billing calculation)
# 费率倍数(影响计费计算)
rate_multiplier: 1.0
# =============================================================================
# Rate Limiting
# 速率限制
# =============================================================================
rate_limit:
# Cooldown time (in minutes) when upstream returns 529 (overloaded)
# 上游返回 529过载时的冷却时间分钟
overload_cooldown_minutes: 10
# =============================================================================
# Pricing Data Source (Optional)
# 定价数据源(可选)
# =============================================================================
pricing:
# URL to fetch model pricing data (default: LiteLLM)
# 获取模型定价数据的 URL默认LiteLLM
remote_url: "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
# Hash verification URL (optional)
# 哈希校验 URL可选
hash_url: ""
# Local data directory for caching
# 本地数据缓存目录
data_dir: "./data"
# Fallback pricing file
# 备用定价文件
fallback_file: "./resources/model-pricing/model_prices_and_context_window.json"
# Update interval in hours
# 更新间隔(小时)
update_interval_hours: 24
# Hash check interval in minutes
# 哈希检查间隔(分钟)
hash_check_interval_minutes: 10
# =============================================================================
# Gateway (Optional)
# Billing Configuration
# 计费配置
# =============================================================================
gateway:
# Wait time (in seconds) for upstream response headers (streaming body not affected)
response_header_timeout: 300
# Log upstream error response body summary (safe/truncated; does not log request content)
log_upstream_error_body: false
# Max bytes to log from upstream error body
log_upstream_error_body_max_bytes: 2048
# Auto inject anthropic-beta for API-key accounts when needed (default off)
inject_beta_for_apikey: false
# Allow failover on selected 400 errors (default off)
failover_on_400: false
billing:
circuit_breaker:
# Enable circuit breaker for billing service
# 启用计费服务熔断器
enabled: true
# Number of failures before opening circuit
# 触发熔断的失败次数阈值
failure_threshold: 5
# Time to wait before attempting reset (seconds)
# 熔断后重试等待时间(秒)
reset_timeout_seconds: 30
# Number of requests to allow in half-open state
# 半开状态允许通过的请求数
half_open_requests: 3
# =============================================================================
# Turnstile Configuration
# Turnstile 人机验证配置
# =============================================================================
turnstile:
# Require Turnstile in release mode (when enabled, login/register will fail if not configured)
# 在 release 模式下要求 Turnstile 验证(启用后,若未配置则登录/注册会失败)
required: false
# =============================================================================
# Gemini OAuth (Required for Gemini accounts)
# Gemini OAuth 配置Gemini 账户必需)
# =============================================================================
# Sub2API supports TWO Gemini OAuth modes:
# Sub2API 支持两种 Gemini OAuth 模式:
#
# 1. Code Assist OAuth (需要 GCP project_id)
# 1. Code Assist OAuth (requires GCP project_id)
# 1. Code Assist OAuth需要 GCP project_id
# - Uses: cloudcode-pa.googleapis.com (Code Assist API)
# - 使用cloudcode-pa.googleapis.comCode Assist API
#
# 2. AI Studio OAuth (不需要 project_id)
# 2. AI Studio OAuth (no project_id needed)
# 2. AI Studio OAuth不需要 project_id
# - Uses: generativelanguage.googleapis.com (AI Studio API)
# - 使用generativelanguage.googleapis.comAI Studio API
#
# Default: Uses Gemini CLI's public OAuth credentials (same as Google's official CLI tool)
# 默认:使用 Gemini CLI 的公开 OAuth 凭证(与 Google 官方 CLI 工具相同)
gemini:
oauth:
# Gemini CLI public OAuth credentials (works for both Code Assist and AI Studio)
# Gemini CLI 公开 OAuth 凭证(适用于 Code Assist 和 AI Studio
client_id: "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
client_secret: "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"
# Optional scopes (space-separated). Leave empty to auto-select based on oauth_type.
# 可选的权限范围(空格分隔)。留空则根据 oauth_type 自动选择。
scopes: ""
quota:
# Optional: local quota simulation for Gemini Code Assist (local billing).
# 可选Gemini Code Assist 本地配额模拟(本地计费)。
# These values are used for UI progress + precheck scheduling, not official Google quotas.
# 这些值用于 UI 进度显示和预检调度,并非 Google 官方配额。
tiers:
LEGACY:
# Pro model requests per day
# Pro 模型每日请求数
pro_rpd: 50
# Flash model requests per day
# Flash 模型每日请求数
flash_rpd: 1500
# Cooldown time (minutes) after hitting quota
# 达到配额后的冷却时间(分钟)
cooldown_minutes: 30
PRO:
# Pro model requests per day
# Pro 模型每日请求数
pro_rpd: 1500
# Flash model requests per day
# Flash 模型每日请求数
flash_rpd: 4000
# Cooldown time (minutes) after hitting quota
# 达到配额后的冷却时间(分钟)
cooldown_minutes: 5
ULTRA:
# Pro model requests per day
# Pro 模型每日请求数
pro_rpd: 2000
# Flash model requests per day (0 = unlimited)
# Flash 模型每日请求数0 = 无限制)
flash_rpd: 0
# Cooldown time (minutes) after hitting quota
# 达到配额后的冷却时间(分钟)
cooldown_minutes: 5

View File

@@ -1,12 +1,13 @@
# =============================================================================
# Sub2API Docker Compose Configuration
# Sub2API Docker Compose Test Configuration (Local Build)
# =============================================================================
# Quick Start:
# 1. Copy .env.example to .env and configure
# 2. docker-compose up -d
# 3. Check logs: docker-compose logs -f sub2api
# 2. docker-compose -f docker-compose-test.yml up -d --build
# 3. Check logs: docker-compose -f docker-compose-test.yml logs -f sub2api
# 4. Access: http://localhost:8080
#
# This configuration builds the image from source (Dockerfile in project root).
# All configuration is done via environment variables.
# No Setup Wizard needed - the system auto-initializes on first run.
# =============================================================================
@@ -17,6 +18,9 @@ services:
# ===========================================================================
sub2api:
image: sub2api:latest
build:
context: ..
dockerfile: Dockerfile
container_name: sub2api
restart: unless-stopped
ulimits:
@@ -28,6 +32,8 @@ services:
volumes:
# Data persistence (config.yaml will be auto-generated here)
- sub2api_data:/app/data
# Mount custom config.yaml (optional, overrides auto-generated config)
- ./config.yaml:/app/data/config.yaml:ro
environment:
# =======================================================================
# Auto Setup (REQUIRED for Docker deployment)
@@ -91,6 +97,12 @@ services:
- GEMINI_OAUTH_CLIENT_SECRET=${GEMINI_OAUTH_CLIENT_SECRET:-}
- GEMINI_OAUTH_SCOPES=${GEMINI_OAUTH_SCOPES:-}
- GEMINI_QUOTA_POLICY=${GEMINI_QUOTA_POLICY:-}
# =======================================================================
# Security Configuration (URL Allowlist)
# =======================================================================
# Allow private IP addresses for CRS sync (for internal deployments)
- SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS=${SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS:-true}
depends_on:
postgres:
condition: service_healthy

View File

@@ -28,6 +28,9 @@ services:
volumes:
# Data persistence (config.yaml will be auto-generated here)
- sub2api_data:/app/data
# Optional: Mount custom config.yaml (uncomment and create the file first)
# Copy config.example.yaml to config.yaml, modify it, then uncomment:
# - ./config.yaml:/app/data/config.yaml:ro
environment:
# =======================================================================
# Auto Setup (REQUIRED for Docker deployment)
@@ -69,7 +72,10 @@ services:
# =======================================================================
# JWT Configuration
# =======================================================================
# Leave empty to auto-generate (recommended)
# IMPORTANT: Set a fixed JWT_SECRET to prevent login sessions from being
# invalidated after container restarts. If left empty, a random secret
# will be generated on each startup.
# Generate a secure secret: openssl rand -hex 32
- JWT_SECRET=${JWT_SECRET:-}
- JWT_EXPIRE_HOUR=${JWT_EXPIRE_HOUR:-24}
@@ -91,6 +97,13 @@ services:
- GEMINI_OAUTH_CLIENT_SECRET=${GEMINI_OAUTH_CLIENT_SECRET:-}
- GEMINI_OAUTH_SCOPES=${GEMINI_OAUTH_SCOPES:-}
- GEMINI_QUOTA_POLICY=${GEMINI_QUOTA_POLICY:-}
# =======================================================================
# Security Configuration (URL Allowlist)
# =======================================================================
- SECURITY_URL_ALLOWLIST_UPSTREAM_HOSTS=${SECURITY_URL_ALLOWLIST_UPSTREAM_HOSTS:-}
# Allow private IP addresses for CRS sync (for internal deployments)
- SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS=${SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS:-false}
depends_on:
postgres:
condition: service_healthy
@@ -138,7 +151,7 @@ services:
# Redis Cache
# ===========================================================================
redis:
image: redis:7-alpine
image: redis:8-alpine
container_name: sub2api-redis
restart: unless-stopped
ulimits:

36
frontend/.eslintrc.cjs Normal file
View File

@@ -0,0 +1,36 @@
module.exports = {
root: true,
env: {
browser: true,
es2021: true,
node: true,
},
parser: "vue-eslint-parser",
parserOptions: {
parser: "@typescript-eslint/parser",
ecmaVersion: "latest",
sourceType: "module",
extraFileExtensions: [".vue"],
},
plugins: ["vue", "@typescript-eslint"],
extends: [
"eslint:recommended",
"plugin:vue/vue3-essential",
"plugin:@typescript-eslint/recommended",
],
rules: {
"no-constant-condition": "off",
"no-mixed-spaces-and-tabs": "off",
"no-useless-escape": "off",
"no-unused-vars": "off",
"@typescript-eslint/no-unused-vars": [
"warn",
{ argsIgnorePattern: "^_", varsIgnorePattern: "^_" },
],
"@typescript-eslint/ban-types": "off",
"@typescript-eslint/ban-ts-comment": "off",
"@typescript-eslint/no-explicit-any": "off",
"vue/multi-word-component-names": "off",
"vue/no-use-v-if-with-v-for": "off",
},
};

4
frontend/.npmrc Normal file
View File

@@ -0,0 +1,4 @@
legacy-peer-deps=true
# 允许运行所有包的构建脚本
# esbuild 和 vue-demi 是已知安全的包,需要 postinstall 脚本才能正常工作
ignore-scripts=false

10784
frontend/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@@ -8,6 +8,7 @@
"build": "vue-tsc -b && vite build",
"preview": "vite preview",
"lint": "eslint . --ext .vue,.js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix",
"lint:check": "eslint . --ext .vue,.js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts",
"typecheck": "vue-tsc --noEmit"
},
"dependencies": {
@@ -30,6 +31,10 @@
"@types/node": "^20.10.5",
"@vitejs/plugin-vue": "^5.2.3",
"autoprefixer": "^10.4.16",
"@typescript-eslint/eslint-plugin": "^7.18.0",
"@typescript-eslint/parser": "^7.18.0",
"eslint": "^8.57.0",
"eslint-plugin-vue": "^9.25.0",
"postcss": "^8.4.32",
"tailwindcss": "^3.4.0",
"typescript": "~5.6.0",

6384
frontend/pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

View File

@@ -26,14 +26,42 @@ export interface SystemSettings {
smtp_host: string
smtp_port: number
smtp_username: string
smtp_password: string
smtp_password_configured: boolean
smtp_from_email: string
smtp_from_name: string
smtp_use_tls: boolean
// Cloudflare Turnstile settings
turnstile_enabled: boolean
turnstile_site_key: string
turnstile_secret_key: string
turnstile_secret_key_configured: boolean
// Identity patch configuration (Claude -> Gemini)
enable_identity_patch: boolean
identity_patch_prompt: string
}
export interface UpdateSettingsRequest {
registration_enabled?: boolean
email_verify_enabled?: boolean
default_balance?: number
default_concurrency?: number
site_name?: string
site_logo?: string
site_subtitle?: string
api_base_url?: string
contact_info?: string
doc_url?: string
smtp_host?: string
smtp_port?: number
smtp_username?: string
smtp_password?: string
smtp_from_email?: string
smtp_from_name?: string
smtp_use_tls?: boolean
turnstile_enabled?: boolean
turnstile_site_key?: string
turnstile_secret_key?: string
enable_identity_patch?: boolean
identity_patch_prompt?: string
}
/**
@@ -50,7 +78,7 @@ export async function getSettings(): Promise<SystemSettings> {
* @param settings - Partial settings to update
* @returns Updated settings
*/
export async function updateSettings(settings: Partial<SystemSettings>): Promise<SystemSettings> {
export async function updateSettings(settings: UpdateSettingsRequest): Promise<SystemSettings> {
const { data } = await apiClient.put<SystemSettings>('/admin/settings', settings)
return data
}

View File

@@ -5,6 +5,7 @@
import axios, { AxiosInstance, AxiosError, InternalAxiosRequestConfig } from 'axios'
import type { ApiResponse } from '@/types'
import { getLocale } from '@/i18n'
// ==================== Axios Instance Configuration ====================
@@ -27,6 +28,12 @@ apiClient.interceptors.request.use(
if (token && config.headers) {
config.headers.Authorization = `Bearer ${token}`
}
// Attach locale for backend translations
if (config.headers) {
config.headers['Accept-Language'] = getLocale()
}
return config
},
(error) => {
@@ -62,8 +69,24 @@ apiClient.interceptors.response.use(
// 401: Unauthorized - clear token and redirect to login
if (status === 401) {
const hasToken = !!localStorage.getItem('auth_token')
const url = error.config?.url || ''
const isAuthEndpoint =
url.includes('/auth/login') || url.includes('/auth/register') || url.includes('/auth/refresh')
const headers = error.config?.headers as Record<string, unknown> | undefined
const authHeader = headers?.Authorization ?? headers?.authorization
const sentAuth =
typeof authHeader === 'string'
? authHeader.trim() !== ''
: Array.isArray(authHeader)
? authHeader.length > 0
: !!authHeader
localStorage.removeItem('auth_token')
localStorage.removeItem('auth_user')
if ((hasToken || sentAuth) && !isAuthEndpoint) {
sessionStorage.setItem('auth_expired', '1')
}
// Only redirect if not already on login page
if (!window.location.pathname.includes('/login')) {
window.location.href = '/login'

View File

@@ -15,14 +15,7 @@
<div
class="flex h-10 w-10 items-center justify-center rounded-lg bg-gradient-to-br from-primary-500 to-primary-600"
>
<svg class="h-5 w-5 text-white" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z"
/>
</svg>
<Icon name="chartBar" size="md" class="text-white" :stroke-width="2" />
</div>
<div>
<div class="font-semibold text-gray-900 dark:text-gray-100">{{ account.name }}</div>
@@ -97,19 +90,7 @@
t('admin.accounts.stats.totalRequests')
}}</span>
<div class="rounded-lg bg-blue-100 p-1.5 dark:bg-blue-900/30">
<svg
class="h-4 w-4 text-blue-600 dark:text-blue-400"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M13 10V3L4 14h7v7l9-11h-7z"
/>
</svg>
<Icon name="bolt" size="sm" class="text-blue-600 dark:text-blue-400" :stroke-width="2" />
</div>
</div>
<p class="text-2xl font-bold text-gray-900 dark:text-white">
@@ -129,19 +110,12 @@
t('admin.accounts.stats.avgDailyCost')
}}</span>
<div class="rounded-lg bg-amber-100 p-1.5 dark:bg-amber-900/30">
<svg
class="h-4 w-4 text-amber-600 dark:text-amber-400"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M9 7h6m0 10v-3m-3 3h.01M9 17h.01M9 14h.01M12 14h.01M15 11h.01M12 11h.01M9 11h.01M7 21h10a2 2 0 002-2V5a2 2 0 00-2-2H7a2 2 0 00-2 2v14a2 2 0 002 2z"
/>
</svg>
<Icon
name="calculator"
size="sm"
class="text-amber-600 dark:text-amber-400"
:stroke-width="2"
/>
</div>
</div>
<p class="text-2xl font-bold text-gray-900 dark:text-white">
@@ -245,19 +219,12 @@
<div class="card p-4">
<div class="mb-3 flex items-center gap-2">
<div class="rounded-lg bg-orange-100 p-1.5 dark:bg-orange-900/30">
<svg
class="h-4 w-4 text-orange-600 dark:text-orange-400"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M17.657 18.657A8 8 0 016.343 7.343S7 9 9 10c0-2 .5-5 2.986-7C14 5 16.09 5.777 17.656 7.343A7.975 7.975 0 0120 13a7.975 7.975 0 01-2.343 5.657z"
/>
</svg>
<Icon
name="fire"
size="sm"
class="text-orange-600 dark:text-orange-400"
:stroke-width="2"
/>
</div>
<span class="text-sm font-semibold text-gray-900 dark:text-white">{{
t('admin.accounts.stats.highestCostDay')
@@ -295,19 +262,12 @@
<div class="card p-4">
<div class="mb-3 flex items-center gap-2">
<div class="rounded-lg bg-indigo-100 p-1.5 dark:bg-indigo-900/30">
<svg
class="h-4 w-4 text-indigo-600 dark:text-indigo-400"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M13 7h8m0 0v8m0-8l-8 8-4-4-6 6"
/>
</svg>
<Icon
name="trendingUp"
size="sm"
class="text-indigo-600 dark:text-indigo-400"
:stroke-width="2"
/>
</div>
<span class="text-sm font-semibold text-gray-900 dark:text-white">{{
t('admin.accounts.stats.highestRequestDay')
@@ -348,19 +308,7 @@
<div class="card p-4">
<div class="mb-3 flex items-center gap-2">
<div class="rounded-lg bg-teal-100 p-1.5 dark:bg-teal-900/30">
<svg
class="h-4 w-4 text-teal-600 dark:text-teal-400"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M20 7l-8-4-8 4m16 0l-8 4m8-4v10l-8 4m0-10L4 7m8 4v10M4 7v10l8 4"
/>
</svg>
<Icon name="cube" size="sm" class="text-teal-600 dark:text-teal-400" :stroke-width="2" />
</div>
<span class="text-sm font-semibold text-gray-900 dark:text-white">{{
t('admin.accounts.stats.accumulatedTokens')
@@ -390,19 +338,7 @@
<div class="card p-4">
<div class="mb-3 flex items-center gap-2">
<div class="rounded-lg bg-rose-100 p-1.5 dark:bg-rose-900/30">
<svg
class="h-4 w-4 text-rose-600 dark:text-rose-400"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M13 10V3L4 14h7v7l9-11h-7z"
/>
</svg>
<Icon name="bolt" size="sm" class="text-rose-600 dark:text-rose-400" :stroke-width="2" />
</div>
<span class="text-sm font-semibold text-gray-900 dark:text-white">{{
t('admin.accounts.stats.performance')
@@ -432,19 +368,12 @@
<div class="card p-4">
<div class="mb-3 flex items-center gap-2">
<div class="rounded-lg bg-lime-100 p-1.5 dark:bg-lime-900/30">
<svg
class="h-4 w-4 text-lime-600 dark:text-lime-400"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2"
/>
</svg>
<Icon
name="clipboard"
size="sm"
class="text-lime-600 dark:text-lime-400"
:stroke-width="2"
/>
</div>
<span class="text-sm font-semibold text-gray-900 dark:text-white">{{
t('admin.accounts.stats.recentActivity')
@@ -504,14 +433,7 @@
v-else-if="!loading"
class="flex flex-col items-center justify-center py-12 text-gray-500 dark:text-gray-400"
>
<svg class="mb-4 h-12 w-12" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="1.5"
d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z"
/>
</svg>
<Icon name="chartBar" size="xl" class="mb-4 h-12 w-12" :stroke-width="1.5" />
<p class="text-sm">{{ t('admin.accounts.stats.noData') }}</p>
</div>
</div>
@@ -547,6 +469,7 @@ import { Line } from 'vue-chartjs'
import BaseDialog from '@/components/common/BaseDialog.vue'
import LoadingSpinner from '@/components/common/LoadingSpinner.vue'
import ModelDistributionChart from '@/components/charts/ModelDistributionChart.vue'
import Icon from '@/components/icons/Icon.vue'
import { adminAPI } from '@/api/admin'
import type { Account, AccountUsageStatsResponse } from '@/types'

View File

@@ -5,7 +5,7 @@
v-if="isTempUnschedulable"
type="button"
:class="['badge text-xs', statusClass, 'cursor-pointer']"
:title="t('admin.accounts.tempUnschedulable.viewDetails')"
:title="t('admin.accounts.status.viewTempUnschedDetails')"
@click="handleTempUnschedClick"
>
{{ statusText }}
@@ -48,20 +48,14 @@
<span
class="inline-flex items-center gap-1 rounded bg-amber-100 px-1.5 py-0.5 text-xs font-medium text-amber-700 dark:bg-amber-900/30 dark:text-amber-400"
>
<svg class="h-3 w-3" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
/>
</svg>
<Icon name="exclamationTriangle" size="xs" :stroke-width="2" />
429
</span>
<!-- Tooltip -->
<div
class="pointer-events-none absolute bottom-full left-1/2 z-50 mb-2 -translate-x-1/2 whitespace-nowrap rounded bg-gray-900 px-2 py-1 text-xs text-white opacity-0 transition-opacity group-hover:opacity-100 dark:bg-gray-700"
>
Rate limited until {{ formatTime(account.rate_limit_reset_at) }}
{{ t('admin.accounts.status.rateLimitedUntil', { time: formatTime(account.rate_limit_reset_at) }) }}
<div
class="absolute left-1/2 top-full -translate-x-1/2 border-4 border-transparent border-t-gray-900 dark:border-t-gray-700"
></div>
@@ -73,20 +67,14 @@
<span
class="inline-flex items-center gap-1 rounded bg-red-100 px-1.5 py-0.5 text-xs font-medium text-red-700 dark:bg-red-900/30 dark:text-red-400"
>
<svg class="h-3 w-3" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
/>
</svg>
<Icon name="exclamationTriangle" size="xs" :stroke-width="2" />
529
</span>
<!-- Tooltip -->
<div
class="pointer-events-none absolute bottom-full left-1/2 z-50 mb-2 -translate-x-1/2 whitespace-nowrap rounded bg-gray-900 px-2 py-1 text-xs text-white opacity-0 transition-opacity group-hover:opacity-100 dark:bg-gray-700"
>
Overloaded until {{ formatTime(account.overload_until) }}
{{ t('admin.accounts.status.overloadedUntil', { time: formatTime(account.overload_until) }) }}
<div
class="absolute left-1/2 top-full -translate-x-1/2 border-4 border-transparent border-t-gray-900 dark:border-t-gray-700"
></div>
@@ -100,6 +88,7 @@ import { computed } from 'vue'
import { useI18n } from 'vue-i18n'
import type { Account } from '@/types'
import { formatTime } from '@/utils/format'
import Icon from '@/components/icons/Icon.vue'
const { t } = useI18n()
@@ -160,7 +149,7 @@ const statusClass = computed(() => {
// Computed: status text
const statusText = computed(() => {
if (hasError.value) {
return t('common.error')
return t('admin.accounts.status.error')
}
if (isTempUnschedulable.value) {
return t('admin.accounts.status.tempUnschedulable')
@@ -171,7 +160,7 @@ const statusText = computed(() => {
if (isRateLimited.value || isOverloaded.value) {
return t('admin.accounts.status.limited')
}
return t(`common.${props.account.status}`)
return t(`admin.accounts.status.${props.account.status}`)
})
const handleTempUnschedClick = () => {

View File

@@ -15,14 +15,7 @@
<div
class="flex h-10 w-10 items-center justify-center rounded-lg bg-gradient-to-br from-primary-500 to-primary-600"
>
<svg class="h-5 w-5 text-white" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M5.121 17.804A13.937 13.937 0 0112 16c2.5 0 4.847.655 6.879 1.804M15 10a3 3 0 11-6 0 3 3 0 016 0zm6 2a9 9 0 11-18 0 9 9 0 0118 0z"
/>
</svg>
<Icon name="userCircle" size="md" class="text-white" :stroke-width="2" />
</div>
<div>
<div class="font-semibold text-gray-900 dark:text-gray-100">{{ account.name }}</div>
@@ -48,21 +41,18 @@
</span>
</div>
<!-- Model Selection -->
<div class="space-y-1.5">
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.accounts.selectTestModel') }}
</label>
<select
<Select
v-model="selectedModelId"
:options="availableModels"
:disabled="loadingModels || status === 'connecting'"
class="w-full rounded-lg border border-gray-300 bg-white px-3 py-2 text-sm text-gray-900 focus:border-primary-500 focus:ring-2 focus:ring-primary-500 disabled:cursor-not-allowed disabled:opacity-50 dark:border-dark-500 dark:bg-dark-700 dark:text-gray-100"
>
<option v-if="loadingModels" value="">{{ t('common.loading') }}...</option>
<option v-for="model in availableModels" :key="model.id" :value="model.id">
{{ model.display_name }} ({{ model.id }})
</option>
</select>
value-key="id"
label-key="display_name"
:placeholder="loadingModels ? t('common.loading') + '...' : t('admin.accounts.selectTestModel')"
/>
</div>
<!-- Terminal Output -->
@@ -73,14 +63,7 @@
>
<!-- Status Line -->
<div v-if="status === 'idle'" class="flex items-center gap-2 text-gray-500">
<svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M13 10V3L4 14h7v7l9-11h-7z"
/>
</svg>
<Icon name="bolt" size="sm" :stroke-width="2" />
<span>{{ t('admin.accounts.readyToTest') }}</span>
</div>
<div v-else-if="status === 'connecting'" class="flex items-center gap-2 text-yellow-400">
@@ -131,14 +114,7 @@
v-else-if="status === 'error'"
class="mt-3 flex items-center gap-2 border-t border-gray-700 pt-3 text-red-400"
>
<svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M10 14l2-2m0 0l2-2m-2 2l-2-2m2 2l2 2m7-2a9 9 0 11-18 0 9 9 0 0118 0z"
/>
</svg>
<Icon name="xCircle" size="sm" :stroke-width="2" />
<span>{{ errorMessage }}</span>
</div>
</div>
@@ -150,14 +126,7 @@
class="absolute right-2 top-2 rounded-lg bg-gray-800/80 p-1.5 text-gray-400 opacity-0 transition-all hover:bg-gray-700 hover:text-white group-hover:opacity-100"
:title="t('admin.accounts.copyOutput')"
>
<svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z"
/>
</svg>
<Icon name="copy" size="sm" :stroke-width="2" />
</button>
</div>
@@ -165,26 +134,12 @@
<div class="flex items-center justify-between px-1 text-xs text-gray-500 dark:text-gray-400">
<div class="flex items-center gap-3">
<span class="flex items-center gap-1">
<svg class="h-3.5 w-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M9 3v2m6-2v2M9 19v2m6-2v2M5 9H3m2 6H3m18-6h-2m2 6h-2M7 19h10a2 2 0 002-2V7a2 2 0 00-2-2H7a2 2 0 00-2 2v10a2 2 0 002 2zM9 9h6v6H9V9z"
/>
</svg>
<Icon name="cpu" size="sm" :stroke-width="2" />
{{ t('admin.accounts.testModel') }}
</span>
</div>
<span class="flex items-center gap-1">
<svg class="h-3.5 w-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M8 10h.01M12 10h.01M16 10h.01M9 16H5a2 2 0 01-2-2V6a2 2 0 012-2h14a2 2 0 012 2v8a2 2 0 01-2 2h-5l-5 5v-5z"
/>
</svg>
<Icon name="chatBubble" size="sm" :stroke-width="2" />
{{ t('admin.accounts.testPrompt') }}
</span>
</div>
@@ -280,6 +235,8 @@
import { ref, watch, nextTick } from 'vue'
import { useI18n } from 'vue-i18n'
import BaseDialog from '@/components/common/BaseDialog.vue'
import Select from '@/components/common/Select.vue'
import Icon from '@/components/icons/Icon.vue'
import { useClipboard } from '@/composables/useClipboard'
import { adminAPI } from '@/api/admin'
import type { Account, ClaudeModel } from '@/types'

View File

@@ -318,19 +318,7 @@
<div v-if="enableCustomErrorCodes" id="bulk-edit-custom-error-codes-body" class="space-y-3">
<div class="rounded-lg bg-amber-50 p-3 dark:bg-amber-900/20">
<p class="text-xs text-amber-700 dark:text-amber-400">
<svg
class="mr-1 inline h-4 w-4"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
/>
</svg>
<Icon name="exclamationTriangle" size="sm" class="mr-1 inline" :stroke-width="2" />
{{ t('admin.accounts.customErrorCodesWarning') }}
</p>
</div>
@@ -391,14 +379,7 @@
class="hover:text-red-900 dark:hover:text-red-300"
@click="removeErrorCode(code)"
>
<svg class="h-3.5 w-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M6 18L18 6M6 6l12 12"
/>
</svg>
<Icon name="x" size="xs" class="h-3.5 w-3.5" :stroke-width="2" />
</button>
</span>
<span v-if="selectedErrorCodes.length === 0" class="text-xs text-gray-400">
@@ -642,6 +623,7 @@ import BaseDialog from '@/components/common/BaseDialog.vue'
import Select from '@/components/common/Select.vue'
import ProxySelector from '@/components/common/ProxySelector.vue'
import GroupSelector from '@/components/common/GroupSelector.vue'
import Icon from '@/components/icons/Icon.vue'
interface Props {
show: boolean
@@ -849,7 +831,8 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
let credentialsChanged = false
if (enableProxy.value) {
updates.proxy_id = proxyId.value
// 后端期望 proxy_id: 0 表示清除代理,而不是 null
updates.proxy_id = proxyId.value === null ? 0 : proxyId.value
}
if (enableConcurrency.value) {

Some files were not shown because too many files have changed in this diff Show More