Compare commits

..

10 Commits

Author SHA1 Message Date
IanShaw
afcfbb458d fix(gemini): Google One 强制使用内置 OAuth client + 自动获取 project_id + UI 优化 (#212)
* fix(gemini): Google One 强制使用内置 OAuth client + 自动获取 project_id + UI 优化

## 后端改动

### 1. Google One 强制使用内置 Gemini CLI OAuth Client
**问题**:
- Google One 之前允许使用自定义 OAuth client,导致认证流程不稳定
- 与 Code Assist 的行为不一致

**解决方案**:
- 修改 `gemini_oauth_service.go`: Google One 现在与 Code Assist 一样强制使用内置 client (L122-135)
- 更新 `gemini_oauth_client.go`: ExchangeCode 和 RefreshToken 方法支持强制内置 client (L31-44, L77-86)
- 简化 `geminicli/oauth.go`: Google One scope 选择逻辑 (L187-190)
- 标记 `geminicli/constants.go`: DefaultGoogleOneScopes 为 DEPRECATED (L30-33)
- 更新测试用例以反映新行为

**OAuth 类型对比**:
| OAuth类型 | Client来源 | Scopes | Redirect URI |
|-----------|-----------|--------|-----------------|
| code_assist | 内置 Gemini CLI | DefaultCodeAssistScopes | https://codeassist.google.com/authcode |
| google_one | 内置 Gemini CLI (新) | DefaultCodeAssistScopes | https://codeassist.google.com/authcode |
| ai_studio | 必须自定义 | DefaultAIStudioScopes | http://localhost:1455/auth/callback |

### 2. Google One 自动获取 project_id
**问题**:
- Google One 个人账号测试模型时返回 403/404 错误
- 原因:cloudaicompanion API 需要 project_id,但个人账号无需手动创建 GCP 项目

**解决方案**:
- 修改 `gemini_oauth_service.go`: OAuth 流程中自动调用 fetchProjectID
- Google 通过 LoadCodeAssist API 自动分配 project_id
- 与 Gemini CLI 行为保持一致
- 后端根据 project_id 自动选择正确的 API 端点

**影响**:
- Google One 账号现在可以正常使用(需要重新授权)
- Code Assist 和 AI Studio 账号不受影响

### 3. 修复 Gemini 测试账号无内容输出问题
**问题**:
- 测试 Gemini 账号时只显示"测试成功",没有显示 AI 响应内容
- 原因:processGeminiStream 在检查到 finishReason 时立即返回,跳过了内容提取

**解决方案**:
- 修改 `account_test_service.go`: 调整逻辑顺序,先提取内容再检查是否完成
- 确保最后一个 chunk 的内容也能被正确显示

**影响**:
- 所有 Gemini 账号类型(API Key、OAuth)的测试现在都会显示完整响应内容
- 用户可以看到流式输出效果

## 前端改动

### 1. 修复图标宽度压缩问题
**问题**:
- 账户类型选择按钮中的图标在某些情况下会被压缩变形

**解决方案**:
- 修改 `CreateAccountModal.vue`: 为所有平台图标容器添加 `shrink-0` 类
- 确保 Anthropic、OpenAI、Gemini、Antigravity 图标保持固定 8×8 尺寸 (32px × 32px)

### 2. 优化重新授权界面
**问题**:
- 重新授权时显示三个可点击的授权类型选择按钮,可能导致用户误切换到不兼容的授权方式

**解决方案**:
- 修改 `ReAuthAccountModal.vue` (admin 和普通用户版本):
  - 将可点击的授权类型选择按钮改为只读信息展示框
  - 根据账号的 `credentials.oauth_type` 动态显示对应图标和文本
  - 删除 `geminiAIStudioOAuthEnabled` 状态和 `handleSelectGeminiOAuthType` 方法
  - 防止用户误操作

## 测试验证
-  所有后端单元测试通过
-  OAuth client 选择逻辑正确
-  Google One 和 Code Assist 行为一致
-  测试账号显示完整响应内容
-  UI 图标显示正常
-  重新授权界面只读展示正确

* fix(lint): 修复 golangci-lint 错误信息格式问题

- 将错误信息改为小写开头以符合 Go 代码规范
- 修复 ST1005: error strings should not be capitalized
2026-01-08 23:47:29 +08:00
Edric Li
8f24d239af fix: update integration tests for GatewayCache groupID parameter 2026-01-08 23:25:05 +08:00
Edric Li
b7a29a4bac fix: update mock interfaces and fix gofmt issues for CI
- Update mockGatewayCacheForPlatform and mockGatewayCacheForGemini
  to match new GatewayCache interface with groupID parameter
- Fix gofmt formatting in group_handler.go and admin_service.go
2026-01-08 23:13:57 +08:00
Edric Li
a42105881f feat(groups): add Claude Code client restriction and session isolation
- Add claude_code_only field to restrict groups to Claude Code clients only
- Add fallback_group_id for non-Claude Code requests to use alternate group
- Implement ClaudeCodeValidator for User-Agent detection
- Add group-level session binding isolation (groupID in Redis key)
- Prevent cross-group sticky session pollution
- Update frontend with Claude Code restriction controls
2026-01-08 23:07:00 +08:00
Edric Li
958ffe7a8a test: fix unit tests for user_agent and proxy repo interface
- Add user_agent field to API contract test expectation
- Add ListWithFiltersAndAccountCount stub to proxyRepoStub
2026-01-08 21:44:18 +08:00
Edric Li
b46b3c5c3c Merge remote-tracking branch 'upstream/main' 2026-01-08 21:35:34 +08:00
Edric Li
fd1b14fd1d feat(home): redirect admin users to admin dashboard
When clicking "Enter Console" button on home page, admin users are now
redirected to /admin/dashboard instead of /dashboard.
2026-01-08 21:26:00 +08:00
Edric Li
eb198e5969 feat(proxies): add account count column to proxy list
Display the number of accounts bound to each proxy in the admin proxy
management page, similar to the groups list view.
2026-01-08 21:20:12 +08:00
Edric Li
70fcbd7006 feat(usage): add User-Agent column to usage logs
- Add user_agent field to UsageLog DTO and mapper
- Display User-Agent column in admin and user usage tables
- Add formatUserAgent helper to show friendly client names
- Include user_agent in Excel export
- Remove request_id column from admin usage table
2026-01-08 21:02:13 +08:00
shaw
b015a3bd8a fix(antigravity): 修复频繁出现429错误的问题 2026-01-08 20:06:32 +08:00
56 changed files with 1962 additions and 447 deletions

View File

@@ -51,6 +51,10 @@ type Group struct {
ImagePrice2k *float64 `json:"image_price_2k,omitempty"`
// ImagePrice4k holds the value of the "image_price_4k" field.
ImagePrice4k *float64 `json:"image_price_4k,omitempty"`
// 是否仅允许 Claude Code 客户端
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
// 非 Claude Code 请求降级使用的分组 ID
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the GroupQuery when eager-loading is set.
Edges GroupEdges `json:"edges"`
@@ -157,11 +161,11 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case group.FieldIsExclusive:
case group.FieldIsExclusive, group.FieldClaudeCodeOnly:
values[i] = new(sql.NullBool)
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
values[i] = new(sql.NullFloat64)
case group.FieldID, group.FieldDefaultValidityDays:
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID:
values[i] = new(sql.NullInt64)
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
values[i] = new(sql.NullString)
@@ -298,6 +302,19 @@ func (_m *Group) assignValues(columns []string, values []any) error {
_m.ImagePrice4k = new(float64)
*_m.ImagePrice4k = value.Float64
}
case group.FieldClaudeCodeOnly:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field claude_code_only", values[i])
} else if value.Valid {
_m.ClaudeCodeOnly = value.Bool
}
case group.FieldFallbackGroupID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field fallback_group_id", values[i])
} else if value.Valid {
_m.FallbackGroupID = new(int64)
*_m.FallbackGroupID = value.Int64
}
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -440,6 +457,14 @@ func (_m *Group) String() string {
builder.WriteString("image_price_4k=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("claude_code_only=")
builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly))
builder.WriteString(", ")
if v := _m.FallbackGroupID; v != nil {
builder.WriteString("fallback_group_id=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteByte(')')
return builder.String()
}

View File

@@ -49,6 +49,10 @@ const (
FieldImagePrice2k = "image_price_2k"
// FieldImagePrice4k holds the string denoting the image_price_4k field in the database.
FieldImagePrice4k = "image_price_4k"
// FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database.
FieldClaudeCodeOnly = "claude_code_only"
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
FieldFallbackGroupID = "fallback_group_id"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -141,6 +145,8 @@ var Columns = []string{
FieldImagePrice1k,
FieldImagePrice2k,
FieldImagePrice4k,
FieldClaudeCodeOnly,
FieldFallbackGroupID,
}
var (
@@ -196,6 +202,8 @@ var (
SubscriptionTypeValidator func(string) error
// DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field.
DefaultDefaultValidityDays int
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
DefaultClaudeCodeOnly bool
)
// OrderOption defines the ordering options for the Group queries.
@@ -291,6 +299,16 @@ func ByImagePrice4k(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc()
}
// ByClaudeCodeOnly orders the results by the claude_code_only field.
func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc()
}
// ByFallbackGroupID orders the results by the fallback_group_id field.
func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc()
}
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {

View File

@@ -140,6 +140,16 @@ func ImagePrice4k(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v))
}
// ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ.
func ClaudeCodeOnly(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
}
// FallbackGroupID applies equality check predicate on the "fallback_group_id" field. It's identical to FallbackGroupIDEQ.
func FallbackGroupID(v int64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
@@ -995,6 +1005,66 @@ func ImagePrice4kNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldImagePrice4k))
}
// ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field.
func ClaudeCodeOnlyEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
}
// ClaudeCodeOnlyNEQ applies the NEQ predicate on the "claude_code_only" field.
func ClaudeCodeOnlyNEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldClaudeCodeOnly, v))
}
// FallbackGroupIDEQ applies the EQ predicate on the "fallback_group_id" field.
func FallbackGroupIDEQ(v int64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
}
// FallbackGroupIDNEQ applies the NEQ predicate on the "fallback_group_id" field.
func FallbackGroupIDNEQ(v int64) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldFallbackGroupID, v))
}
// FallbackGroupIDIn applies the In predicate on the "fallback_group_id" field.
func FallbackGroupIDIn(vs ...int64) predicate.Group {
return predicate.Group(sql.FieldIn(FieldFallbackGroupID, vs...))
}
// FallbackGroupIDNotIn applies the NotIn predicate on the "fallback_group_id" field.
func FallbackGroupIDNotIn(vs ...int64) predicate.Group {
return predicate.Group(sql.FieldNotIn(FieldFallbackGroupID, vs...))
}
// FallbackGroupIDGT applies the GT predicate on the "fallback_group_id" field.
func FallbackGroupIDGT(v int64) predicate.Group {
return predicate.Group(sql.FieldGT(FieldFallbackGroupID, v))
}
// FallbackGroupIDGTE applies the GTE predicate on the "fallback_group_id" field.
func FallbackGroupIDGTE(v int64) predicate.Group {
return predicate.Group(sql.FieldGTE(FieldFallbackGroupID, v))
}
// FallbackGroupIDLT applies the LT predicate on the "fallback_group_id" field.
func FallbackGroupIDLT(v int64) predicate.Group {
return predicate.Group(sql.FieldLT(FieldFallbackGroupID, v))
}
// FallbackGroupIDLTE applies the LTE predicate on the "fallback_group_id" field.
func FallbackGroupIDLTE(v int64) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldFallbackGroupID, v))
}
// FallbackGroupIDIsNil applies the IsNil predicate on the "fallback_group_id" field.
func FallbackGroupIDIsNil() predicate.Group {
return predicate.Group(sql.FieldIsNull(FieldFallbackGroupID))
}
// FallbackGroupIDNotNil applies the NotNil predicate on the "fallback_group_id" field.
func FallbackGroupIDNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID))
}
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.Group {
return predicate.Group(func(s *sql.Selector) {

View File

@@ -258,6 +258,34 @@ func (_c *GroupCreate) SetNillableImagePrice4k(v *float64) *GroupCreate {
return _c
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate {
_c.mutation.SetClaudeCodeOnly(v)
return _c
}
// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil.
func (_c *GroupCreate) SetNillableClaudeCodeOnly(v *bool) *GroupCreate {
if v != nil {
_c.SetClaudeCodeOnly(*v)
}
return _c
}
// SetFallbackGroupID sets the "fallback_group_id" field.
func (_c *GroupCreate) SetFallbackGroupID(v int64) *GroupCreate {
_c.mutation.SetFallbackGroupID(v)
return _c
}
// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil.
func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate {
if v != nil {
_c.SetFallbackGroupID(*v)
}
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -423,6 +451,10 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultDefaultValidityDays
_c.mutation.SetDefaultValidityDays(v)
}
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
v := group.DefaultClaudeCodeOnly
_c.mutation.SetClaudeCodeOnly(v)
}
return nil
}
@@ -475,6 +507,9 @@ func (_c *GroupCreate) check() error {
if _, ok := _c.mutation.DefaultValidityDays(); !ok {
return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)}
}
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
}
return nil
}
@@ -570,6 +605,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value)
_node.ImagePrice4k = &value
}
if value, ok := _c.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
_node.ClaudeCodeOnly = value
}
if value, ok := _c.mutation.FallbackGroupID(); ok {
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
_node.FallbackGroupID = &value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1014,6 +1057,42 @@ func (u *GroupUpsert) ClearImagePrice4k() *GroupUpsert {
return u
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert {
u.Set(group.FieldClaudeCodeOnly, v)
return u
}
// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create.
func (u *GroupUpsert) UpdateClaudeCodeOnly() *GroupUpsert {
u.SetExcluded(group.FieldClaudeCodeOnly)
return u
}
// SetFallbackGroupID sets the "fallback_group_id" field.
func (u *GroupUpsert) SetFallbackGroupID(v int64) *GroupUpsert {
u.Set(group.FieldFallbackGroupID, v)
return u
}
// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create.
func (u *GroupUpsert) UpdateFallbackGroupID() *GroupUpsert {
u.SetExcluded(group.FieldFallbackGroupID)
return u
}
// AddFallbackGroupID adds v to the "fallback_group_id" field.
func (u *GroupUpsert) AddFallbackGroupID(v int64) *GroupUpsert {
u.Add(group.FieldFallbackGroupID, v)
return u
}
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert {
u.SetNull(group.FieldFallbackGroupID)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -1395,6 +1474,48 @@ func (u *GroupUpsertOne) ClearImagePrice4k() *GroupUpsertOne {
})
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetClaudeCodeOnly(v)
})
}
// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateClaudeCodeOnly() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateClaudeCodeOnly()
})
}
// SetFallbackGroupID sets the "fallback_group_id" field.
func (u *GroupUpsertOne) SetFallbackGroupID(v int64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetFallbackGroupID(v)
})
}
// AddFallbackGroupID adds v to the "fallback_group_id" field.
func (u *GroupUpsertOne) AddFallbackGroupID(v int64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.AddFallbackGroupID(v)
})
}
// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateFallbackGroupID() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateFallbackGroupID()
})
}
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.ClearFallbackGroupID()
})
}
// Exec executes the query.
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -1942,6 +2063,48 @@ func (u *GroupUpsertBulk) ClearImagePrice4k() *GroupUpsertBulk {
})
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetClaudeCodeOnly(v)
})
}
// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateClaudeCodeOnly() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateClaudeCodeOnly()
})
}
// SetFallbackGroupID sets the "fallback_group_id" field.
func (u *GroupUpsertBulk) SetFallbackGroupID(v int64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetFallbackGroupID(v)
})
}
// AddFallbackGroupID adds v to the "fallback_group_id" field.
func (u *GroupUpsertBulk) AddFallbackGroupID(v int64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.AddFallbackGroupID(v)
})
}
// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateFallbackGroupID() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateFallbackGroupID()
})
}
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.ClearFallbackGroupID()
})
}
// Exec executes the query.
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {

View File

@@ -354,6 +354,47 @@ func (_u *GroupUpdate) ClearImagePrice4k() *GroupUpdate {
return _u
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate {
_u.mutation.SetClaudeCodeOnly(v)
return _u
}
// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableClaudeCodeOnly(v *bool) *GroupUpdate {
if v != nil {
_u.SetClaudeCodeOnly(*v)
}
return _u
}
// SetFallbackGroupID sets the "fallback_group_id" field.
func (_u *GroupUpdate) SetFallbackGroupID(v int64) *GroupUpdate {
_u.mutation.ResetFallbackGroupID()
_u.mutation.SetFallbackGroupID(v)
return _u
}
// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableFallbackGroupID(v *int64) *GroupUpdate {
if v != nil {
_u.SetFallbackGroupID(*v)
}
return _u
}
// AddFallbackGroupID adds value to the "fallback_group_id" field.
func (_u *GroupUpdate) AddFallbackGroupID(v int64) *GroupUpdate {
_u.mutation.AddFallbackGroupID(v)
return _u
}
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate {
_u.mutation.ClearFallbackGroupID()
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -750,6 +791,18 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.ImagePrice4kCleared() {
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
}
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
}
if value, ok := _u.mutation.FallbackGroupID(); ok {
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedFallbackGroupID(); ok {
_spec.AddField(group.FieldFallbackGroupID, field.TypeInt64, value)
}
if _u.mutation.FallbackGroupIDCleared() {
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1384,6 +1437,47 @@ func (_u *GroupUpdateOne) ClearImagePrice4k() *GroupUpdateOne {
return _u
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne {
_u.mutation.SetClaudeCodeOnly(v)
return _u
}
// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableClaudeCodeOnly(v *bool) *GroupUpdateOne {
if v != nil {
_u.SetClaudeCodeOnly(*v)
}
return _u
}
// SetFallbackGroupID sets the "fallback_group_id" field.
func (_u *GroupUpdateOne) SetFallbackGroupID(v int64) *GroupUpdateOne {
_u.mutation.ResetFallbackGroupID()
_u.mutation.SetFallbackGroupID(v)
return _u
}
// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableFallbackGroupID(v *int64) *GroupUpdateOne {
if v != nil {
_u.SetFallbackGroupID(*v)
}
return _u
}
// AddFallbackGroupID adds value to the "fallback_group_id" field.
func (_u *GroupUpdateOne) AddFallbackGroupID(v int64) *GroupUpdateOne {
_u.mutation.AddFallbackGroupID(v)
return _u
}
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne {
_u.mutation.ClearFallbackGroupID()
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1810,6 +1904,18 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if _u.mutation.ImagePrice4kCleared() {
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
}
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
}
if value, ok := _u.mutation.FallbackGroupID(); ok {
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedFallbackGroupID(); ok {
_spec.AddField(group.FieldFallbackGroupID, field.TypeInt64, value)
}
if _u.mutation.FallbackGroupIDCleared() {
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,

View File

@@ -221,6 +221,8 @@ var (
{Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
}
// GroupsTable holds the schema information for the "groups" table.
GroupsTable = &schema.Table{

View File

@@ -3590,6 +3590,9 @@ type GroupMutation struct {
addimage_price_2k *float64
image_price_4k *float64
addimage_price_4k *float64
claude_code_only *bool
fallback_group_id *int64
addfallback_group_id *int64
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
@@ -4594,6 +4597,112 @@ func (m *GroupMutation) ResetImagePrice4k() {
delete(m.clearedFields, group.FieldImagePrice4k)
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (m *GroupMutation) SetClaudeCodeOnly(b bool) {
m.claude_code_only = &b
}
// ClaudeCodeOnly returns the value of the "claude_code_only" field in the mutation.
func (m *GroupMutation) ClaudeCodeOnly() (r bool, exists bool) {
v := m.claude_code_only
if v == nil {
return
}
return *v, true
}
// OldClaudeCodeOnly returns the old "claude_code_only" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldClaudeCodeOnly(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldClaudeCodeOnly is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldClaudeCodeOnly requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldClaudeCodeOnly: %w", err)
}
return oldValue.ClaudeCodeOnly, nil
}
// ResetClaudeCodeOnly resets all changes to the "claude_code_only" field.
func (m *GroupMutation) ResetClaudeCodeOnly() {
m.claude_code_only = nil
}
// SetFallbackGroupID sets the "fallback_group_id" field.
func (m *GroupMutation) SetFallbackGroupID(i int64) {
m.fallback_group_id = &i
m.addfallback_group_id = nil
}
// FallbackGroupID returns the value of the "fallback_group_id" field in the mutation.
func (m *GroupMutation) FallbackGroupID() (r int64, exists bool) {
v := m.fallback_group_id
if v == nil {
return
}
return *v, true
}
// OldFallbackGroupID returns the old "fallback_group_id" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldFallbackGroupID(ctx context.Context) (v *int64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldFallbackGroupID is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldFallbackGroupID requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldFallbackGroupID: %w", err)
}
return oldValue.FallbackGroupID, nil
}
// AddFallbackGroupID adds i to the "fallback_group_id" field.
func (m *GroupMutation) AddFallbackGroupID(i int64) {
if m.addfallback_group_id != nil {
*m.addfallback_group_id += i
} else {
m.addfallback_group_id = &i
}
}
// AddedFallbackGroupID returns the value that was added to the "fallback_group_id" field in this mutation.
func (m *GroupMutation) AddedFallbackGroupID() (r int64, exists bool) {
v := m.addfallback_group_id
if v == nil {
return
}
return *v, true
}
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
func (m *GroupMutation) ClearFallbackGroupID() {
m.fallback_group_id = nil
m.addfallback_group_id = nil
m.clearedFields[group.FieldFallbackGroupID] = struct{}{}
}
// FallbackGroupIDCleared returns if the "fallback_group_id" field was cleared in this mutation.
func (m *GroupMutation) FallbackGroupIDCleared() bool {
_, ok := m.clearedFields[group.FieldFallbackGroupID]
return ok
}
// ResetFallbackGroupID resets all changes to the "fallback_group_id" field.
func (m *GroupMutation) ResetFallbackGroupID() {
m.fallback_group_id = nil
m.addfallback_group_id = nil
delete(m.clearedFields, group.FieldFallbackGroupID)
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
@@ -4952,7 +5061,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *GroupMutation) Fields() []string {
fields := make([]string, 0, 17)
fields := make([]string, 0, 19)
if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt)
}
@@ -5004,6 +5113,12 @@ func (m *GroupMutation) Fields() []string {
if m.image_price_4k != nil {
fields = append(fields, group.FieldImagePrice4k)
}
if m.claude_code_only != nil {
fields = append(fields, group.FieldClaudeCodeOnly)
}
if m.fallback_group_id != nil {
fields = append(fields, group.FieldFallbackGroupID)
}
return fields
}
@@ -5046,6 +5161,10 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.ImagePrice2k()
case group.FieldImagePrice4k:
return m.ImagePrice4k()
case group.FieldClaudeCodeOnly:
return m.ClaudeCodeOnly()
case group.FieldFallbackGroupID:
return m.FallbackGroupID()
}
return nil, false
}
@@ -5089,6 +5208,10 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldImagePrice2k(ctx)
case group.FieldImagePrice4k:
return m.OldImagePrice4k(ctx)
case group.FieldClaudeCodeOnly:
return m.OldClaudeCodeOnly(ctx)
case group.FieldFallbackGroupID:
return m.OldFallbackGroupID(ctx)
}
return nil, fmt.Errorf("unknown Group field %s", name)
}
@@ -5217,6 +5340,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
}
m.SetImagePrice4k(v)
return nil
case group.FieldClaudeCodeOnly:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetClaudeCodeOnly(v)
return nil
case group.FieldFallbackGroupID:
v, ok := value.(int64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetFallbackGroupID(v)
return nil
}
return fmt.Errorf("unknown Group field %s", name)
}
@@ -5249,6 +5386,9 @@ func (m *GroupMutation) AddedFields() []string {
if m.addimage_price_4k != nil {
fields = append(fields, group.FieldImagePrice4k)
}
if m.addfallback_group_id != nil {
fields = append(fields, group.FieldFallbackGroupID)
}
return fields
}
@@ -5273,6 +5413,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedImagePrice2k()
case group.FieldImagePrice4k:
return m.AddedImagePrice4k()
case group.FieldFallbackGroupID:
return m.AddedFallbackGroupID()
}
return nil, false
}
@@ -5338,6 +5480,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
}
m.AddImagePrice4k(v)
return nil
case group.FieldFallbackGroupID:
v, ok := value.(int64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddFallbackGroupID(v)
return nil
}
return fmt.Errorf("unknown Group numeric field %s", name)
}
@@ -5370,6 +5519,9 @@ func (m *GroupMutation) ClearedFields() []string {
if m.FieldCleared(group.FieldImagePrice4k) {
fields = append(fields, group.FieldImagePrice4k)
}
if m.FieldCleared(group.FieldFallbackGroupID) {
fields = append(fields, group.FieldFallbackGroupID)
}
return fields
}
@@ -5408,6 +5560,9 @@ func (m *GroupMutation) ClearField(name string) error {
case group.FieldImagePrice4k:
m.ClearImagePrice4k()
return nil
case group.FieldFallbackGroupID:
m.ClearFallbackGroupID()
return nil
}
return fmt.Errorf("unknown Group nullable field %s", name)
}
@@ -5467,6 +5622,12 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldImagePrice4k:
m.ResetImagePrice4k()
return nil
case group.FieldClaudeCodeOnly:
m.ResetClaudeCodeOnly()
return nil
case group.FieldFallbackGroupID:
m.ResetFallbackGroupID()
return nil
}
return fmt.Errorf("unknown Group field %s", name)
}

View File

@@ -270,6 +270,10 @@ func init() {
groupDescDefaultValidityDays := groupFields[10].Descriptor()
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
// groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field.
groupDescClaudeCodeOnly := groupFields[14].Descriptor()
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
proxyMixin := schema.Proxy{}.Mixin()
proxyMixinHooks1 := proxyMixin[1].Hooks()
proxy.Hooks[0] = proxyMixinHooks1[0]

View File

@@ -86,6 +86,15 @@ func (Group) Fields() []ent.Field {
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
// Claude Code 客户端限制 (added by migration 029)
field.Bool("claude_code_only").
Default(false).
Comment("是否仅允许 Claude Code 客户端"),
field.Int64("fallback_group_id").
Optional().
Nillable().
Comment("非 Claude Code 请求降级使用的分组 ID"),
}
}
@@ -101,6 +110,8 @@ func (Group) Edges() []ent.Edge {
edge.From("allowed_users", User.Type).
Ref("allowed_groups").
Through("user_allowed_groups", UserAllowedGroup.Type),
// 注意fallback_group_id 直接作为字段使用,不定义 edge
// 这样允许多个分组指向同一个降级分组M2O 关系)
}
}

View File

@@ -34,9 +34,11 @@ type CreateGroupRequest struct {
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
// 图片生成计费配置antigravity 和 gemini 平台使用,负数表示清除配置)
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
}
// UpdateGroupRequest represents update group request
@@ -52,9 +54,11 @@ type UpdateGroupRequest struct {
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
// 图片生成计费配置antigravity 和 gemini 平台使用,负数表示清除配置)
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
ClaudeCodeOnly *bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
}
// List handles listing all groups with pagination
@@ -150,6 +154,8 @@ func (h *GroupHandler) Create(c *gin.Context) {
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -188,6 +194,8 @@ func (h *GroupHandler) Update(c *gin.Context) {
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
})
if err != nil {
response.ErrorFrom(c, err)

View File

@@ -52,15 +52,15 @@ func (h *ProxyHandler) List(c *gin.Context) {
status := c.Query("status")
search := c.Query("search")
proxies, total, err := h.adminService.ListProxies(c.Request.Context(), page, pageSize, protocol, status, search)
proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.Proxy, 0, len(proxies))
out := make([]dto.ProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
out = append(out, *dto.ProxyFromService(&proxies[i]))
out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i]))
}
response.Paginated(c, out, total, page, pageSize)
}

View File

@@ -85,6 +85,8 @@ func GroupFromServiceShallow(g *service.Group) *Group {
ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
AccountCount: g.AccountCount,
@@ -280,6 +282,7 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *Usag
FirstTokenMs: l.FirstTokenMs,
ImageCount: l.ImageCount,
ImageSize: l.ImageSize,
UserAgent: l.UserAgent,
CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User),
APIKey: APIKeyFromService(l.APIKey),

View File

@@ -52,6 +52,10 @@ type Group struct {
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
// Claude Code 客户端限制
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
@@ -180,6 +184,9 @@ type UsageLog struct {
ImageCount int `json:"image_count"`
ImageSize *string `json:"image_size"`
// User-Agent
UserAgent *string `json:"user_agent"`
CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"`

View File

@@ -96,6 +96,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
reqModel := parsedReq.Model
reqStream := parsedReq.Stream
// 设置 Claude Code 客户端标识到 context用于分组限制检查
SetClaudeCodeClientContext(c, body)
// 验证 model 必填
if reqModel == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
@@ -229,7 +232,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
@@ -357,7 +360,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
@@ -683,6 +686,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
// 设置 Claude Code 客户端标识到 context用于分组限制检查
SetClaudeCodeClientContext(c, body)
// 验证 model 必填
if parsedReq.Model == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")

View File

@@ -2,6 +2,7 @@ package handler
import (
"context"
"encoding/json"
"fmt"
"math/rand"
"net/http"
@@ -13,6 +14,26 @@ import (
"github.com/gin-gonic/gin"
)
// claudeCodeValidator is a singleton validator for Claude Code client detection
var claudeCodeValidator = service.NewClaudeCodeValidator()
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
// 返回更新后的 context
func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
// 解析请求体为 map
var bodyMap map[string]any
if len(body) > 0 {
_ = json.Unmarshal(body, &bodyMap)
}
// 验证是否为 Claude Code 客户端
isClaudeCode := claudeCodeValidator.Validate(c.Request, bodyMap)
// 更新 request context
ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode)
c.Request = c.Request.WithContext(ctx)
}
// 并发槽位等待相关常量
//
// 性能优化说明:

View File

@@ -203,6 +203,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 3) select account (sticky session based on request body)
parsedReq, _ := service.ParseGatewayRequest(body)
// 设置 Claude Code 客户端标识到 context用于分组限制检查
SetClaudeCodeClientContext(c, body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
sessionKey := sessionHash
if sessionHash != "" {
@@ -262,7 +266,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}

View File

@@ -206,7 +206,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil {
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}

View File

@@ -149,6 +149,11 @@ func defaultIdentityPatch(_ string) string {
return antigravityIdentity
}
// GetDefaultIdentityPatch 返回默认的 Antigravity 身份提示词
func GetDefaultIdentityPatch() string {
return antigravityIdentity
}
// buildSystemInstruction 构建 systemInstruction
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions) *GeminiContent {
var parts []GeminiPart

View File

@@ -7,4 +7,6 @@ type Key string
const (
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
ForcePlatform Key = "ctx_force_platform"
// IsClaudeCodeClient 是否为 Claude Code 客户端,由中间件设置
IsClaudeCodeClient Key = "ctx_is_claude_code_client"
)

View File

@@ -27,10 +27,9 @@ const (
// https://www.googleapis.com/auth/generative-language.retriever (often with cloud-platform).
DefaultAIStudioScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever"
// DefaultScopes for Google One (personal Google accounts with Gemini access)
// Only used when a custom OAuth client is configured. When using the built-in Gemini CLI client,
// Google One uses DefaultCodeAssistScopes (same as code_assist) because the built-in client
// cannot request restricted scopes like generative-language.retriever or drive.readonly.
// DefaultGoogleOneScopes (DEPRECATED, no longer used)
// Google One now always uses the built-in Gemini CLI client with DefaultCodeAssistScopes.
// This constant is kept for backward compatibility but is not actively used.
DefaultGoogleOneScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile"
// GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth.

View File

@@ -185,13 +185,9 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error
effective.Scopes = DefaultAIStudioScopes
}
case "google_one":
// Google One uses built-in Gemini CLI client (same as code_assist)
// Built-in client can't request restricted scopes like generative-language.retriever
if isBuiltinClient {
effective.Scopes = DefaultCodeAssistScopes
} else {
effective.Scopes = DefaultGoogleOneScopes
}
// Google One always uses built-in Gemini CLI client (same as code_assist)
// Built-in client can't request restricted scopes like generative-language.retriever or drive.readonly
effective.Scopes = DefaultCodeAssistScopes
default:
// Default to Code Assist scopes
effective.Scopes = DefaultCodeAssistScopes

View File

@@ -23,14 +23,14 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
wantErr: false,
},
{
name: "Google One with custom client",
name: "Google One always uses built-in client (even if custom credentials passed)",
input: OAuthConfig{
ClientID: "custom-client-id",
ClientSecret: "custom-client-secret",
},
oauthType: "google_one",
wantClientID: "custom-client-id",
wantScopes: DefaultGoogleOneScopes,
wantScopes: DefaultCodeAssistScopes, // Uses code assist scopes even with custom client
wantErr: false,
},
{

View File

@@ -325,6 +325,8 @@ func groupEntityToService(g *dbent.Group) *service.Group {
ImagePrice2K: g.ImagePrice2k,
ImagePrice4K: g.ImagePrice4k,
DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}

View File

@@ -2,6 +2,7 @@ package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -18,17 +19,23 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
return &gatewayCache{rdb: rdb}
}
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
key := stickySessionPrefix + sessionHash
// buildSessionKey 构建 session key包含 groupID 实现分组隔离
// 格式: sticky_session:{groupID}:{sessionHash}
func buildSessionKey(groupID int64, sessionHash string) string {
return fmt.Sprintf("%s%d:%s", stickySessionPrefix, groupID, sessionHash)
}
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Get(ctx, key).Int64()
}
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
key := stickySessionPrefix + sessionHash
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Set(ctx, key, accountID, ttl).Err()
}
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
key := stickySessionPrefix + sessionHash
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Expire(ctx, key, ttl).Err()
}

View File

@@ -24,18 +24,19 @@ func (s *GatewayCacheSuite) SetupTest() {
}
func (s *GatewayCacheSuite) TestGetSessionAccountID_Missing() {
_, err := s.cache.GetSessionAccountID(s.ctx, "nonexistent")
_, err := s.cache.GetSessionAccountID(s.ctx, 1, "nonexistent")
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing session")
}
func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() {
sessionID := "s1"
accountID := int64(99)
groupID := int64(1)
sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
sid, err := s.cache.GetSessionAccountID(s.ctx, sessionID)
sid, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
require.NoError(s.T(), err, "GetSessionAccountID")
require.Equal(s.T(), accountID, sid, "session id mismatch")
}
@@ -43,11 +44,12 @@ func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() {
func (s *GatewayCacheSuite) TestSessionAccountID_TTL() {
sessionID := "s2"
accountID := int64(100)
groupID := int64(1)
sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
sessionKey := stickySessionPrefix + sessionID
sessionKey := buildSessionKey(groupID, sessionID)
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
require.NoError(s.T(), err, "TTL sessionKey after Set")
s.AssertTTLWithin(ttl, 1*time.Second, sessionTTL)
@@ -56,14 +58,15 @@ func (s *GatewayCacheSuite) TestSessionAccountID_TTL() {
func (s *GatewayCacheSuite) TestRefreshSessionTTL() {
sessionID := "s3"
accountID := int64(101)
groupID := int64(1)
initialTTL := 1 * time.Minute
refreshTTL := 3 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, initialTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, initialTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, sessionID, refreshTTL), "RefreshSessionTTL")
require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, groupID, sessionID, refreshTTL), "RefreshSessionTTL")
sessionKey := stickySessionPrefix + sessionID
sessionKey := buildSessionKey(groupID, sessionID)
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
require.NoError(s.T(), err, "TTL after Refresh")
s.AssertTTLWithin(ttl, 1*time.Second, refreshTTL)
@@ -71,18 +74,19 @@ func (s *GatewayCacheSuite) TestRefreshSessionTTL() {
func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
// RefreshSessionTTL on a missing key should not error (no-op)
err := s.cache.RefreshSessionTTL(s.ctx, "missing-session", 1*time.Minute)
err := s.cache.RefreshSessionTTL(s.ctx, 1, "missing-session", 1*time.Minute)
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
}
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
sessionID := "corrupted"
sessionKey := stickySessionPrefix + sessionID
groupID := int64(1)
sessionKey := buildSessionKey(groupID, sessionID)
// Set a non-integer value
require.NoError(s.T(), s.rdb.Set(s.ctx, sessionKey, "not-a-number", 1*time.Minute).Err(), "Set invalid value")
_, err := s.cache.GetSessionAccountID(s.ctx, sessionID)
_, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
require.Error(s.T(), err, "expected error for corrupted value")
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
}

View File

@@ -30,14 +30,15 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c
// Use different OAuth clients based on oauthType:
// - code_assist: always use built-in Gemini CLI OAuth client (public)
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client
// - google_one: always use built-in Gemini CLI OAuth client (public)
// - ai_studio: requires a user-provided OAuth client
oauthCfgInput := geminicli.OAuthConfig{
ClientID: c.cfg.Gemini.OAuth.ClientID,
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
Scopes: c.cfg.Gemini.OAuth.Scopes,
}
if oauthType == "code_assist" {
if oauthType == "code_assist" || oauthType == "google_one" {
// Force use of built-in Gemini CLI OAuth client
oauthCfgInput.ClientID = ""
oauthCfgInput.ClientSecret = ""
}
@@ -78,7 +79,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
Scopes: c.cfg.Gemini.OAuth.Scopes,
}
if oauthType == "code_assist" {
if oauthType == "code_assist" || oauthType == "google_one" {
// Force use of built-in Gemini CLI OAuth client
oauthCfgInput.ClientID = ""
oauthCfgInput.ClientSecret = ""
}

View File

@@ -46,7 +46,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays)
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetNillableFallbackGroupID(groupIn.FallbackGroupID)
created, err := builder.Save(ctx)
if err == nil {
@@ -72,7 +74,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
}
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {
updated, err := r.client.Group.UpdateOneID(groupIn.ID).
builder := r.client.Group.UpdateOneID(groupIn.ID).
SetName(groupIn.Name).
SetDescription(groupIn.Description).
SetPlatform(groupIn.Platform).
@@ -87,7 +89,16 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
Save(ctx)
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly)
// 处理 FallbackGroupIDnil 时清除,否则设置
if groupIn.FallbackGroupID != nil {
builder = builder.SetFallbackGroupID(*groupIn.FallbackGroupID)
} else {
builder = builder.ClearFallbackGroupID()
}
updated, err := builder.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
}

View File

@@ -133,6 +133,55 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination
return outProxies, paginationResultFromTotal(int64(total), params), nil
}
// ListWithFiltersAndAccountCount lists proxies with filters and includes account count per proxy
func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
q := r.client.Proxy.Query()
if protocol != "" {
q = q.Where(proxy.ProtocolEQ(protocol))
}
if status != "" {
q = q.Where(proxy.StatusEQ(status))
}
if search != "" {
q = q.Where(proxy.NameContainsFold(search))
}
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
proxies, err := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(proxy.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
// Get account counts
counts, err := r.GetAccountCountsForProxies(ctx)
if err != nil {
return nil, nil, err
}
// Build result with account counts
result := make([]service.ProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
proxyOut := proxyEntityToService(proxies[i])
if proxyOut == nil {
continue
}
result = append(result, service.ProxyWithAccountCount{
Proxy: *proxyOut,
AccountCount: counts[proxyOut.ID],
})
}
return result, paginationResultFromTotal(int64(total), params), nil
}
func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {
proxies, err := r.client.Proxy.Query().
Where(proxy.StatusEQ(service.StatusActive)).

View File

@@ -243,7 +243,8 @@ func TestAPIContracts(t *testing.T) {
"first_token_ms": 50,
"image_count": 0,
"image_size": null,
"created_at": "2025-01-02T03:04:05Z"
"created_at": "2025-01-02T03:04:05Z",
"user_agent": null
}
],
"total": 1,

View File

@@ -661,13 +661,7 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
}
if candidates, ok := data["candidates"].([]any); ok && len(candidates) > 0 {
if candidate, ok := candidates[0].(map[string]any); ok {
// Check for completion
if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
// Extract content
// Extract content first (before checking completion)
if content, ok := candidate["content"].(map[string]any); ok {
if parts, ok := content["parts"].([]any); ok {
for _, part := range parts {
@@ -679,6 +673,12 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
}
}
}
// Check for completion after extracting content
if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
}
}

View File

@@ -47,6 +47,7 @@ type AdminService interface {
// Proxy management
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error)
ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error)
GetAllProxies(ctx context.Context) ([]Proxy, error)
GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
GetProxy(ctx context.Context, id int64) (*Proxy, error)
@@ -99,9 +100,11 @@ type CreateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用)
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
}
type UpdateGroupInput struct {
@@ -116,9 +119,11 @@ type UpdateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用)
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
}
type CreateAccountInput struct {
@@ -515,6 +520,13 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
imagePrice2K := normalizePrice(input.ImagePrice2K)
imagePrice4K := normalizePrice(input.ImagePrice4K)
// 校验降级分组
if input.FallbackGroupID != nil {
if err := s.validateFallbackGroup(ctx, 0, *input.FallbackGroupID); err != nil {
return nil, err
}
}
group := &Group{
Name: input.Name,
Description: input.Description,
@@ -529,6 +541,8 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
ImagePrice1K: imagePrice1K,
ImagePrice2K: imagePrice2K,
ImagePrice4K: imagePrice4K,
ClaudeCodeOnly: input.ClaudeCodeOnly,
FallbackGroupID: input.FallbackGroupID,
}
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
@@ -552,6 +566,29 @@ func normalizePrice(price *float64) *float64 {
return price
}
// validateFallbackGroup 校验降级分组的有效性
// currentGroupID: 当前分组 ID新建时为 0
// fallbackGroupID: 降级分组 ID
func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGroupID, fallbackGroupID int64) error {
// 不能将自己设置为降级分组
if currentGroupID > 0 && currentGroupID == fallbackGroupID {
return fmt.Errorf("cannot set self as fallback group")
}
// 检查降级分组是否存在
fallbackGroup, err := s.groupRepo.GetByID(ctx, fallbackGroupID)
if err != nil {
return fmt.Errorf("fallback group not found: %w", err)
}
// 降级分组不能启用 claude_code_only否则会造成死循环
if fallbackGroup.ClaudeCodeOnly {
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
}
return nil
}
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
@@ -602,6 +639,23 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
}
// Claude Code 客户端限制
if input.ClaudeCodeOnly != nil {
group.ClaudeCodeOnly = *input.ClaudeCodeOnly
}
if input.FallbackGroupID != nil {
// 校验降级分组
if *input.FallbackGroupID > 0 {
if err := s.validateFallbackGroup(ctx, id, *input.FallbackGroupID); err != nil {
return nil, err
}
group.FallbackGroupID = input.FallbackGroupID
} else {
// 传入 0 或负数表示清除降级分组
group.FallbackGroupID = nil
}
}
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err
}
@@ -950,6 +1004,15 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int,
return proxies, result.Total, nil
}
func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
proxies, result, err := s.proxyRepo.ListWithFiltersAndAccountCount(ctx, params, protocol, status, search)
if err != nil {
return nil, 0, err
}
return proxies, result.Total, nil
}
func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) {
return s.proxyRepo.ListActive(ctx)
}

View File

@@ -186,6 +186,10 @@ func (s *proxyRepoStub) ListActiveWithAccountCount(ctx context.Context) ([]Proxy
panic("unexpected ListActiveWithAccountCount call")
}
func (s *proxyRepoStub) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) {
panic("unexpected ListWithFiltersAndAccountCount call")
}
func (s *proxyRepoStub) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
panic("unexpected ExistsByHostPortAuth call")
}

View File

@@ -187,10 +187,8 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
return nil, err
}
// DEBUG: 打印请求 header 和 body
log.Printf("[DEBUG] Antigravity TestConnection - URL: %s", req.URL.String())
log.Printf("[DEBUG] Antigravity TestConnection - Headers: %v", req.Header)
log.Printf("[DEBUG] Antigravity TestConnection - Body: %s", string(requestBody))
// 调试日志Test 请求信息
log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String())
// 代理 URL
proxyURL := ""
@@ -235,6 +233,12 @@ func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model stri
},
},
},
// Antigravity 上游要求必须包含身份提示词
"systemInstruction": map[string]any{
"parts": []map[string]any{
{"text": antigravity.GetDefaultIdentityPatch()},
},
},
}
payloadBytes, _ := json.Marshal(payload)
return s.wrapV1InternalRequest(projectID, model, payloadBytes)
@@ -333,6 +337,53 @@ func extractTextFromSSEResponse(respBody []byte) string {
return strings.Join(texts, "")
}
// injectIdentityPatchToGeminiRequest 为 Gemini 格式请求注入身份提示词
// 如果请求中已包含 "You are Antigravity" 则不重复注入
func injectIdentityPatchToGeminiRequest(body []byte) ([]byte, error) {
var request map[string]any
if err := json.Unmarshal(body, &request); err != nil {
return nil, fmt.Errorf("解析 Gemini 请求失败: %w", err)
}
// 检查现有 systemInstruction 是否已包含身份提示词
if sysInst, ok := request["systemInstruction"].(map[string]any); ok {
if parts, ok := sysInst["parts"].([]any); ok {
for _, part := range parts {
if partMap, ok := part.(map[string]any); ok {
if text, ok := partMap["text"].(string); ok {
if strings.Contains(text, "You are Antigravity") {
// 已包含身份提示词,直接返回原始请求
return body, nil
}
}
}
}
}
}
// 获取默认身份提示词
identityPatch := antigravity.GetDefaultIdentityPatch()
// 构建新的 systemInstruction
newPart := map[string]any{"text": identityPatch}
if existing, ok := request["systemInstruction"].(map[string]any); ok {
// 已有 systemInstruction在开头插入身份提示词
if parts, ok := existing["parts"].([]any); ok {
existing["parts"] = append([]any{newPart}, parts...)
} else {
existing["parts"] = []any{newPart}
}
} else {
// 没有 systemInstruction创建新的
request["systemInstruction"] = map[string]any{
"parts": []any{newPart},
}
}
return json.Marshal(request)
}
// wrapV1InternalRequest 包装请求为 v1internal 格式
func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte) ([]byte, error) {
var request any
@@ -418,17 +469,20 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
proxyURL = account.Proxy.URL()
}
// 获取转换选项
// Antigravity 上游要求必须包含身份提示词,否则会返回 429
transformOpts := s.getClaudeTransformOptions(ctx)
transformOpts.EnableIdentityPatch = true // 强制启用Antigravity 上游必需
// 转换 Claude 请求为 Gemini 格式
geminiBody, err := antigravity.TransformClaudeToGeminiWithOptions(&claudeReq, projectID, mappedModel, s.getClaudeTransformOptions(ctx))
geminiBody, err := antigravity.TransformClaudeToGeminiWithOptions(&claudeReq, projectID, mappedModel, transformOpts)
if err != nil {
return nil, fmt.Errorf("transform request: %w", err)
}
// 构建上游 actionNewAPIRequest 会自动处理 ?alt=sse 和 Accept Header
action := "generateContent"
if claudeReq.Stream {
action = "streamGenerateContent"
}
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
action := "streamGenerateContent"
// 重试循环
var resp *http.Response
@@ -465,7 +519,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
_ = resp.Body.Close()
if attempt < antigravityMaxRetries {
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
@@ -584,6 +638,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
var usage *ClaudeUsage
var firstTokenMs *int
if claudeReq.Stream {
// 客户端要求流式,直接透传转换
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
if err != nil {
log.Printf("%s status=stream_error error=%v", prefix, err)
@@ -592,10 +647,14 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
} else {
usage, err = s.handleClaudeNonStreamingResponse(c, resp, originalModel)
// 客户端要求非流式,收集流式响应后转换返回
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
if err != nil {
log.Printf("%s status=stream_collect_error error=%v", prefix, err)
return nil, err
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
}
return &ForwardResult{
@@ -928,18 +987,22 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
proxyURL = account.Proxy.URL()
}
// 包装请求
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, body)
// Antigravity 上游要求必须包含身份提示词,注入到请求
injectedBody, err := injectIdentityPatchToGeminiRequest(body)
if err != nil {
return nil, err
}
// 构建上游 actionNewAPIRequest 会自动处理 ?alt=sse 和 Accept Header
upstreamAction := action
if action == "generateContent" && stream {
upstreamAction = "streamGenerateContent"
// 包装请求
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody)
if err != nil {
return nil, err
}
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回
upstreamAction := "streamGenerateContent"
// 重试循环
var resp *http.Response
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
@@ -1016,7 +1079,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if fallbackModel != "" && fallbackModel != mappedModel {
log.Printf("[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, body)
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody)
if err == nil {
fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackWrapped)
if err == nil {
@@ -1066,7 +1129,8 @@ handleSuccess:
var usage *ClaudeUsage
var firstTokenMs *int
if stream || upstreamAction == "streamGenerateContent" {
if stream {
// 客户端要求流式,直接透传
streamRes, err := s.handleGeminiStreamingResponse(c, resp, startTime)
if err != nil {
log.Printf("%s status=stream_error error=%v", prefix, err)
@@ -1075,11 +1139,14 @@ handleSuccess:
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
} else {
usageResp, err := s.handleGeminiNonStreamingResponse(c, resp)
// 客户端要求非流式,收集流式响应后返回
streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime)
if err != nil {
log.Printf("%s status=stream_collect_error error=%v", prefix, err)
return nil, err
}
usage = usageResp
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
}
if usage == nil {
@@ -1126,9 +1193,9 @@ func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int)
// sleepAntigravityBackoffWithContext 带 context 取消检查的退避等待
// 返回 true 表示正常完成等待false 表示 context 已取消
func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
delay := geminiRetryBaseDelay * time.Duration(1<<uint(attempt-1))
if delay > geminiRetryMaxDelay {
delay = geminiRetryMaxDelay
delay := antigravityRetryBaseDelay * time.Duration(1<<uint(attempt-1))
if delay > antigravityRetryMaxDelay {
delay = antigravityRetryMaxDelay
}
// +/- 20% jitter
@@ -1340,25 +1407,150 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
}
}
func (s *AntigravityGatewayService) handleGeminiNonStreamingResponse(c *gin.Context, resp *http.Response) (*ClaudeUsage, error) {
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
// handleGeminiStreamToNonStreaming 读取上游流式响应,合并为非流式响应返回给客户端
// Gemini 流式响应中每个 chunk 都包含累积的完整文本,只需保留最后一个有效响应
func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) {
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
usage := &ClaudeUsage{}
var firstTokenMs *int
var last map[string]any
var lastWithParts map[string]any
type scanEvent struct {
line string
err error
}
// 解包 v1internal 响应
unwrapped, _ := s.unwrapV1InternalResponse(respBody)
var parsed map[string]any
if json.Unmarshal(unwrapped, &parsed) == nil {
if u := extractGeminiUsage(parsed); u != nil {
c.Data(resp.StatusCode, "application/json", unwrapped)
return u, nil
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
c.Data(resp.StatusCode, "application/json", unwrapped)
return &ClaudeUsage{}, nil
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
streamInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
for {
select {
case ev, ok := <-events:
if !ok {
// 流结束,返回收集的响应
goto returnResponse
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity non-stream): max_size=%d error=%v", maxLineSize, ev.err)
}
return nil, ev.err
}
line := ev.line
trimmed := strings.TrimRight(line, "\r\n")
if !strings.HasPrefix(trimmed, "data:") {
continue
}
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
if payload == "" || payload == "[DONE]" {
continue
}
// 解包 v1internal 响应
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
if parseErr != nil {
continue
}
var parsed map[string]any
if err := json.Unmarshal(inner, &parsed); err != nil {
continue
}
// 记录首 token 时间
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
last = parsed
// 提取 usage
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
// 保留最后一个有 parts 的响应
if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed
}
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout (antigravity non-stream)")
return nil, fmt.Errorf("stream data interval timeout")
}
}
returnResponse:
// 选择最后一个有效响应
finalResponse := pickGeminiCollectResult(last, lastWithParts)
// 处理空响应情况
if last == nil && lastWithParts == nil {
log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received")
}
respBody, err := json.Marshal(finalResponse)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}
c.Data(http.StatusOK, "application/json", respBody)
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int, errType, message string) error {
@@ -1435,17 +1627,148 @@ func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int,
return fmt.Errorf("%s", message)
}
// handleClaudeNonStreamingResponse 处理 Claude 非流式响应Gemini → Claude 转换)
func (s *AntigravityGatewayService) handleClaudeNonStreamingResponse(c *gin.Context, resp *http.Response, originalModel string) (*ClaudeUsage, error) {
body, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
// handleClaudeStreamToNonStreaming 收集上游流式响应,转换为 Claude 非流式格式返回
// 用于处理客户端非流式请求但上游只支持流式的情况
func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
var firstTokenMs *int
var last map[string]any
var lastWithParts map[string]any
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
streamInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
for {
select {
case ev, ok := <-events:
if !ok {
// 流结束,转换并返回响应
goto returnResponse
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity claude non-stream): max_size=%d error=%v", maxLineSize, ev.err)
}
return nil, ev.err
}
line := ev.line
trimmed := strings.TrimRight(line, "\r\n")
if !strings.HasPrefix(trimmed, "data:") {
continue
}
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
if payload == "" || payload == "[DONE]" {
continue
}
// 解包 v1internal 响应
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
if parseErr != nil {
continue
}
var parsed map[string]any
if err := json.Unmarshal(inner, &parsed); err != nil {
continue
}
// 记录首 token 时间
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
last = parsed
// 保留最后一个有 parts 的响应
if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed
}
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout (antigravity claude non-stream)")
return nil, fmt.Errorf("stream data interval timeout")
}
}
returnResponse:
// 选择最后一个有效响应
finalResponse := pickGeminiCollectResult(last, lastWithParts)
// 处理空响应情况
if last == nil && lastWithParts == nil {
log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received")
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream")
}
// 序列化为 JSONGemini 格式)
geminiBody, err := json.Marshal(finalResponse)
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response")
return nil, fmt.Errorf("failed to marshal gemini response: %w", err)
}
// 转换 Gemini 响应为 Claude 格式
claudeResp, agUsage, err := antigravity.TransformGeminiToClaude(body, originalModel)
claudeResp, agUsage, err := antigravity.TransformGeminiToClaude(geminiBody, originalModel)
if err != nil {
log.Printf("[antigravity-Forward] transform_error error=%v body=%s", err, string(body))
log.Printf("[antigravity-Forward] transform_error error=%v body=%s", err, string(geminiBody))
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
}
@@ -1458,7 +1781,8 @@ func (s *AntigravityGatewayService) handleClaudeNonStreamingResponse(c *gin.Cont
CacheCreationInputTokens: agUsage.CacheCreationInputTokens,
CacheReadInputTokens: agUsage.CacheReadInputTokens,
}
return usage, nil
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// handleClaudeStreamingResponse 处理 Claude 流式响应Gemini SSE → Claude SSE 转换)

View File

@@ -0,0 +1,265 @@
package service
import (
"context"
"net/http"
"regexp"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
)
// ClaudeCodeValidator 验证请求是否来自 Claude Code 客户端
// 完全学习自 claude-relay-service 项目的验证逻辑
type ClaudeCodeValidator struct{}
var (
// User-Agent 匹配: claude-cli/x.x.x (仅支持官方 CLI大小写不敏感)
claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
// metadata.user_id 格式: user_{64位hex}_account__session_{uuid}
userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`)
// System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致)
systemPromptThreshold = 0.5
)
// Claude Code 官方 System Prompt 模板
// 从 claude-relay-service/src/utils/contents.js 提取
var claudeCodeSystemPrompts = []string{
// claudeOtherSystemPrompt1 - Primary
"You are Claude Code, Anthropic's official CLI for Claude.",
// claudeOtherSystemPrompt3 - Agent SDK
"You are a Claude agent, built on Anthropic's Claude Agent SDK.",
// claudeOtherSystemPrompt4 - Compact Agent SDK
"You are Claude Code, Anthropic's official CLI for Claude, running within the Claude Agent SDK.",
// exploreAgentSystemPrompt
"You are a file search specialist for Claude Code, Anthropic's official CLI for Claude.",
// claudeOtherSystemPromptCompact - Compact (用于对话摘要)
"You are a helpful AI assistant tasked with summarizing conversations.",
// claudeOtherSystemPrompt2 - Secondary (长提示词的关键部分)
"You are an interactive CLI tool that helps users",
}
// NewClaudeCodeValidator 创建验证器实例
func NewClaudeCodeValidator() *ClaudeCodeValidator {
return &ClaudeCodeValidator{}
}
// Validate 验证请求是否来自 Claude Code CLI
// 采用与 claude-relay-service 完全一致的验证策略:
//
// Step 1: User-Agent 检查 (必需) - 必须是 claude-cli/x.x.x
// Step 2: 对于非 messages 路径,只要 UA 匹配就通过
// Step 3: 对于 messages 路径,进行严格验证:
// - System prompt 相似度检查
// - X-App header 检查
// - anthropic-beta header 检查
// - anthropic-version header 检查
// - metadata.user_id 格式验证
func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) bool {
// Step 1: User-Agent 检查
ua := r.Header.Get("User-Agent")
if !claudeCodeUAPattern.MatchString(ua) {
return false
}
// Step 2: 非 messages 路径,只要 UA 匹配就通过
path := r.URL.Path
if !strings.Contains(path, "messages") {
return true
}
// Step 3: messages 路径,进行严格验证
// 3.1 检查 system prompt 相似度
if !v.hasClaudeCodeSystemPrompt(body) {
return false
}
// 3.2 检查必需的 headers值不为空即可
xApp := r.Header.Get("X-App")
if xApp == "" {
return false
}
anthropicBeta := r.Header.Get("anthropic-beta")
if anthropicBeta == "" {
return false
}
anthropicVersion := r.Header.Get("anthropic-version")
if anthropicVersion == "" {
return false
}
// 3.3 验证 metadata.user_id
if body == nil {
return false
}
metadata, ok := body["metadata"].(map[string]any)
if !ok {
return false
}
userID, ok := metadata["user_id"].(string)
if !ok || userID == "" {
return false
}
if !userIDPattern.MatchString(userID) {
return false
}
return true
}
// hasClaudeCodeSystemPrompt 检查请求是否包含 Claude Code 系统提示词
// 使用字符串相似度匹配Dice coefficient
func (v *ClaudeCodeValidator) hasClaudeCodeSystemPrompt(body map[string]any) bool {
if body == nil {
return false
}
// 检查 model 字段
if _, ok := body["model"].(string); !ok {
return false
}
// 获取 system 字段
systemEntries, ok := body["system"].([]any)
if !ok {
return false
}
// 检查每个 system entry
for _, entry := range systemEntries {
entryMap, ok := entry.(map[string]any)
if !ok {
continue
}
text, ok := entryMap["text"].(string)
if !ok || text == "" {
continue
}
// 计算与所有模板的最佳相似度
bestScore := v.bestSimilarityScore(text)
if bestScore >= systemPromptThreshold {
return true
}
}
return false
}
// bestSimilarityScore 计算文本与所有 Claude Code 模板的最佳相似度
func (v *ClaudeCodeValidator) bestSimilarityScore(text string) float64 {
normalizedText := normalizePrompt(text)
bestScore := 0.0
for _, template := range claudeCodeSystemPrompts {
normalizedTemplate := normalizePrompt(template)
score := diceCoefficient(normalizedText, normalizedTemplate)
if score > bestScore {
bestScore = score
}
}
return bestScore
}
// normalizePrompt 标准化提示词文本(去除多余空白)
func normalizePrompt(text string) string {
// 将所有空白字符替换为单个空格,并去除首尾空白
return strings.Join(strings.Fields(text), " ")
}
// diceCoefficient 计算两个字符串的 Dice 系数SørensenDice coefficient
// 这是 string-similarity 库使用的算法
// 公式: 2 * |intersection| / (|bigrams(a)| + |bigrams(b)|)
func diceCoefficient(a, b string) float64 {
if a == b {
return 1.0
}
if len(a) < 2 || len(b) < 2 {
return 0.0
}
// 生成 bigrams
bigramsA := getBigrams(a)
bigramsB := getBigrams(b)
if len(bigramsA) == 0 || len(bigramsB) == 0 {
return 0.0
}
// 计算交集大小
intersection := 0
for bigram, countA := range bigramsA {
if countB, exists := bigramsB[bigram]; exists {
if countA < countB {
intersection += countA
} else {
intersection += countB
}
}
}
// 计算总 bigram 数量
totalA := 0
for _, count := range bigramsA {
totalA += count
}
totalB := 0
for _, count := range bigramsB {
totalB += count
}
return float64(2*intersection) / float64(totalA+totalB)
}
// getBigrams 获取字符串的所有 bigrams相邻字符对
func getBigrams(s string) map[string]int {
bigrams := make(map[string]int)
runes := []rune(strings.ToLower(s))
for i := 0; i < len(runes)-1; i++ {
bigram := string(runes[i : i+2])
bigrams[bigram]++
}
return bigrams
}
// ValidateUserAgent 仅验证 User-Agent用于不需要解析请求体的场景
func (v *ClaudeCodeValidator) ValidateUserAgent(ua string) bool {
return claudeCodeUAPattern.MatchString(ua)
}
// IncludesClaudeCodeSystemPrompt 检查请求体是否包含 Claude Code 系统提示词
// 只要存在匹配的系统提示词就返回 true用于宽松检测
func (v *ClaudeCodeValidator) IncludesClaudeCodeSystemPrompt(body map[string]any) bool {
return v.hasClaudeCodeSystemPrompt(body)
}
// IsClaudeCodeClient 从 context 中获取 Claude Code 客户端标识
func IsClaudeCodeClient(ctx context.Context) bool {
if v, ok := ctx.Value(ctxkey.IsClaudeCodeClient).(bool); ok {
return v
}
return false
}
// SetClaudeCodeClient 将 Claude Code 客户端标识设置到 context 中
func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context {
return context.WithValue(ctx, ctxkey.IsClaudeCodeClient, isClaudeCode)
}

View File

@@ -166,14 +166,14 @@ type mockGatewayCacheForPlatform struct {
sessionBindings map[string]int64
}
func (m *mockGatewayCacheForPlatform) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
func (m *mockGatewayCacheForPlatform) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
if id, ok := m.sessionBindings[sessionHash]; ok {
return id, nil
}
return 0, errors.New("not found")
}
func (m *mockGatewayCacheForPlatform) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
func (m *mockGatewayCacheForPlatform) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
if m.sessionBindings == nil {
m.sessionBindings = make(map[string]int64)
}
@@ -181,7 +181,7 @@ func (m *mockGatewayCacheForPlatform) SetSessionAccountID(ctx context.Context, s
return nil
}
func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
return nil
}

View File

@@ -56,6 +56,9 @@ var (
}
)
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
// allowedHeaders 白名单headers参考CRS项目
var allowedHeaders = map[string]bool{
"accept": true,
@@ -80,9 +83,17 @@ var allowedHeaders = map[string]bool{
// GatewayCache defines cache operations for gateway service
type GatewayCache interface {
GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error)
SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error)
SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error
RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
func derefGroupID(groupID *int64) int64 {
if groupID == nil {
return 0
}
return *groupID
}
type AccountWaitPlan struct {
@@ -225,11 +236,11 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
}
// BindStickySession sets session -> account binding with standard TTL.
func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error {
if sessionHash == "" || accountID <= 0 || s.cache == nil {
return nil
}
return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL)
return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, accountID, stickySessionTTL)
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
@@ -356,6 +367,21 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
return nil, fmt.Errorf("get group failed: %w", err)
}
platform = group.Platform
// 检查 Claude Code 客户端限制
if group.ClaudeCodeOnly {
isClaudeCode := IsClaudeCodeClient(ctx)
if !isClaudeCode {
// 非 Claude Code 客户端,检查是否有降级分组
if group.FallbackGroupID != nil {
// 使用降级分组重新调度
fallbackGroupID := *group.FallbackGroupID
return s.SelectAccountForModelWithExclusions(ctx, &fallbackGroupID, sessionHash, requestedModel, excludedIDs)
}
// 无降级分组,拒绝访问
return nil, ErrClaudeCodeOnly
}
}
} else {
// 无分组时只使用原生 anthropic 平台
platform = PlatformAnthropic
@@ -377,10 +403,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
cfg := s.schedulingConfig()
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
stickyAccountID = accountID
}
}
// 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组)
groupID, err := s.checkClaudeCodeRestriction(ctx, groupID)
if err != nil {
return nil, err
}
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
if err != nil {
@@ -443,7 +476,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
// ============ Layer 1: 粘性会话优先 ============
if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && s.isAccountInGroup(account, groupID) &&
@@ -452,7 +485,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
_ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
@@ -506,7 +539,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
return result, nil
}
} else {
@@ -556,7 +589,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL)
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: item.account,
@@ -584,7 +617,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return nil, errors.New("no available accounts")
}
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
@@ -592,7 +625,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL)
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: acc,
@@ -619,6 +652,42 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
}
}
// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制
// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端:
// - 有降级分组:返回降级分组的 ID
// - 无降级分组:返回 ErrClaudeCodeOnly 错误
func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*int64, error) {
if groupID == nil {
return groupID, nil
}
// 强制平台模式不检查 Claude Code 限制
if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform {
return groupID, nil
}
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
}
if !group.ClaudeCodeOnly {
return groupID, nil
}
// 分组启用了 Claude Code 限制
if IsClaudeCodeClient(ctx) {
return groupID, nil
}
// 非 Claude Code 客户端,检查降级分组
if group.FallbackGroupID != nil {
return group.FallbackGroupID, nil
}
return nil, ErrClaudeCodeOnly
}
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) {
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
@@ -738,13 +807,13 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
preferOAuth := platform == PlatformGemini
// 1. 查询粘性会话
if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
@@ -811,7 +880,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
// 4. 建立粘性绑定
if sessionHash != "" && s.cache != nil {
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
}
}
@@ -827,14 +896,14 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
// 1. 查询粘性会话
if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
// 检查账号分组归属和有效性原生平台直接匹配antigravity 需要启用混合调度
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
@@ -903,7 +972,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
// 4. 建立粘性绑定
if sessionHash != "" && s.cache != nil {
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
}
}

View File

@@ -109,7 +109,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
cacheKey := "gemini:" + sessionHash
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
@@ -133,7 +133,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
}
}
if usable {
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
return account, nil
}
}
@@ -217,7 +217,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
}
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, cacheKey, selected.ID, geminiStickySessionTTL)
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL)
}
return selected, nil

View File

@@ -190,14 +190,14 @@ type mockGatewayCacheForGemini struct {
sessionBindings map[string]int64
}
func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
if id, ok := m.sessionBindings[sessionHash]; ok {
return id, nil
}
return 0, errors.New("not found")
}
func (m *mockGatewayCacheForGemini) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
func (m *mockGatewayCacheForGemini) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
if m.sessionBindings == nil {
m.sessionBindings = make(map[string]int64)
}
@@ -205,7 +205,7 @@ func (m *mockGatewayCacheForGemini) SetSessionAccountID(ctx context.Context, ses
return nil
}
func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
return nil
}

View File

@@ -120,15 +120,16 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
}
// OAuth client selection:
// - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret.
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client.
// - ai_studio: requires a user-provided OAuth client.
// - code_assist: always use built-in Gemini CLI OAuth client (public)
// - google_one: always use built-in Gemini CLI OAuth client (public)
// - ai_studio: requires a user-provided OAuth client
oauthCfg := geminicli.OAuthConfig{
ClientID: s.cfg.Gemini.OAuth.ClientID,
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
Scopes: s.cfg.Gemini.OAuth.Scopes,
}
if oauthType == "code_assist" {
if oauthType == "code_assist" || oauthType == "google_one" {
// Force use of built-in Gemini CLI OAuth client
oauthCfg.ClientID = ""
oauthCfg.ClientSecret = ""
}
@@ -576,6 +577,20 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
case "google_one":
log.Printf("[GeminiOAuth] Processing google_one OAuth type")
// Google One accounts use cloudaicompanion API, which requires a project_id.
// For personal accounts, Google auto-assigns a project_id via the LoadCodeAssist API.
if projectID == "" {
log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...")
var err error
projectID, _, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
log.Printf("[GeminiOAuth] ERROR: Failed to fetch project_id: %v", err)
return nil, fmt.Errorf("google One accounts require a project_id, failed to auto-detect: %w", err)
}
log.Printf("[GeminiOAuth] Successfully fetched project_id: %s", projectID)
}
log.Printf("[GeminiOAuth] Attempting to fetch Google One tier from Drive API...")
// Attempt to fetch Drive storage tier
var storageInfo *geminicli.DriveStorageInfo

View File

@@ -40,7 +40,7 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) {
wantProjectID: "",
},
{
name: "google_one uses custom client when configured and redirects to localhost",
name: "google_one always forces built-in client even when custom client configured",
cfg: &config.Config{
Gemini: config.GeminiConfig{
OAuth: config.GeminiOAuthConfig{
@@ -50,9 +50,9 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) {
},
},
oauthType: "google_one",
wantClientID: "custom-client-id",
wantRedirect: geminicli.AIStudioOAuthRedirectURI,
wantScope: geminicli.DefaultGoogleOneScopes,
wantClientID: geminicli.GeminiCLIOAuthClientID,
wantRedirect: geminicli.GeminiCLIRedirectURI,
wantScope: geminicli.DefaultCodeAssistScopes,
wantProjectID: "",
},
{

View File

@@ -22,6 +22,10 @@ type Group struct {
ImagePrice2K *float64
ImagePrice4K *float64
// Claude Code 客户端限制
ClaudeCodeOnly bool
FallbackGroupID *int64
CreatedAt time.Time
UpdatedAt time.Time

View File

@@ -134,11 +134,11 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
}
// BindStickySession sets session -> account binding with standard TTL.
func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error {
if sessionHash == "" || accountID <= 0 {
return nil
}
return s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, accountID, openaiStickySessionTTL)
return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, accountID, openaiStickySessionTTL)
}
// SelectAccount selects an OpenAI account with sticky session support
@@ -155,13 +155,13 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 1. Check sticky session
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
// Refresh sticky session TTL
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL)
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
return account, nil
}
}
@@ -227,7 +227,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
// 4. Set sticky session
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, selected.ID, openaiStickySessionTTL)
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, selected.ID, openaiStickySessionTTL)
}
return selected, nil
@@ -238,7 +238,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
cfg := s.schedulingConfig()
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash); err == nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash); err == nil {
stickyAccountID = accountID
}
}
@@ -298,14 +298,14 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
// ============ Layer 1: Sticky session ============
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL)
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
@@ -362,7 +362,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, acc.ID, openaiStickySessionTTL)
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, acc.ID, openaiStickySessionTTL)
}
return &AccountSelectionResult{
Account: acc,
@@ -415,7 +415,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL)
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL)
}
return &AccountSelectionResult{
Account: item.account,

View File

@@ -20,6 +20,7 @@ type ProxyRepository interface {
List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error)
ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]Proxy, error)
ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)

View File

@@ -0,0 +1,21 @@
-- 029_add_group_claude_code_restriction.sql
-- 添加分组级别的 Claude Code 客户端限制功能
-- 添加 claude_code_only 字段:是否仅允许 Claude Code 客户端
ALTER TABLE groups
ADD COLUMN IF NOT EXISTS claude_code_only BOOLEAN NOT NULL DEFAULT FALSE;
-- 添加 fallback_group_id 字段:非 Claude Code 请求降级到的分组
ALTER TABLE groups
ADD COLUMN IF NOT EXISTS fallback_group_id BIGINT REFERENCES groups(id) ON DELETE SET NULL;
-- 添加索引优化查询
CREATE INDEX IF NOT EXISTS idx_groups_claude_code_only
ON groups(claude_code_only) WHERE deleted_at IS NULL;
CREATE INDEX IF NOT EXISTS idx_groups_fallback_group_id
ON groups(fallback_group_id) WHERE deleted_at IS NULL AND fallback_group_id IS NOT NULL;
-- 添加字段注释
COMMENT ON COLUMN groups.claude_code_only IS '是否仅允许 Claude Code 客户端访问此分组';
COMMENT ON COLUMN groups.fallback_group_id IS '非 Claude Code 请求降级使用的分组 ID';

View File

@@ -166,7 +166,7 @@
>
<div
:class="[
'flex h-8 w-8 items-center justify-center rounded-lg',
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
accountCategory === 'oauth-based'
? 'bg-orange-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
@@ -196,7 +196,7 @@
>
<div
:class="[
'flex h-8 w-8 items-center justify-center rounded-lg',
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
accountCategory === 'apikey'
? 'bg-purple-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
@@ -232,7 +232,7 @@
>
<div
:class="[
'flex h-8 w-8 items-center justify-center rounded-lg',
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
accountCategory === 'oauth-based'
? 'bg-green-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
@@ -258,7 +258,7 @@
>
<div
:class="[
'flex h-8 w-8 items-center justify-center rounded-lg',
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
accountCategory === 'apikey'
? 'bg-purple-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
@@ -302,7 +302,7 @@
>
<div
:class="[
'flex h-8 w-8 items-center justify-center rounded-lg',
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
accountCategory === 'oauth-based'
? 'bg-blue-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
@@ -332,7 +332,7 @@
>
<div
:class="[
'flex h-8 w-8 items-center justify-center rounded-lg',
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
accountCategory === 'apikey'
? 'bg-purple-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
@@ -397,7 +397,7 @@
>
<div
:class="[
'flex h-8 w-8 items-center justify-center rounded-lg',
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
geminiOAuthType === 'google_one'
? 'bg-purple-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
@@ -440,7 +440,7 @@
>
<div
:class="[
'flex h-8 w-8 items-center justify-center rounded-lg',
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
geminiOAuthType === 'code_assist'
? 'bg-blue-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
@@ -518,7 +518,7 @@
>
<div
:class="[
'flex h-8 w-8 items-center justify-center rounded-lg',
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
geminiOAuthType === 'ai_studio'
? 'bg-amber-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
@@ -621,7 +621,7 @@
<div
class="flex items-center gap-3 rounded-lg border-2 border-purple-500 bg-purple-50 p-3 dark:bg-purple-900/20"
>
<div class="flex h-8 w-8 items-center justify-center rounded-lg bg-purple-500 text-white">
<div class="flex h-8 w-8 shrink-0 items-center justify-center rounded-lg bg-purple-500 text-white">
<Icon name="key" size="sm" />
</div>
<div>

View File

@@ -73,113 +73,48 @@
</div>
</fieldset>
<!-- Gemini OAuth Type Selection -->
<fieldset v-if="isGemini" class="border-0 p-0">
<legend class="input-label">{{ t('admin.accounts.oauth.gemini.oauthTypeLabel') }}</legend>
<div class="mt-2 grid grid-cols-3 gap-3">
<button
type="button"
@click="handleSelectGeminiOAuthType('google_one')"
:class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
geminiOAuthType === 'google_one'
? 'border-purple-500 bg-purple-50 dark:bg-purple-900/20'
: 'border-gray-200 hover:border-purple-300 dark:border-dark-600 dark:hover:border-purple-700'
]"
>
<div
:class="[
'flex h-8 w-8 items-center justify-center rounded-lg',
geminiOAuthType === 'google_one'
? 'bg-purple-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
]"
>
<svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="1.5">
<path stroke-linecap="round" stroke-linejoin="round" d="M15.75 6a3.75 3.75 0 11-7.5 0 3.75 3.75 0 017.5 0zM4.501 20.118a7.5 7.5 0 0114.998 0A17.933 17.933 0 0112 21.75c-2.676 0-5.216-.584-7.499-1.632z" />
</svg>
</div>
<div class="min-w-0">
<span class="block text-sm font-medium text-gray-900 dark:text-white">Google One</span>
<span class="text-xs text-gray-500 dark:text-gray-400">个人账号</span>
</div>
</button>
<button
type="button"
@click="handleSelectGeminiOAuthType('code_assist')"
:class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
geminiOAuthType === 'code_assist'
? 'border-blue-500 bg-blue-50 dark:bg-blue-900/20'
: 'border-gray-200 hover:border-blue-300 dark:border-dark-600 dark:hover:border-blue-700'
]"
>
<div
:class="[
'flex h-8 w-8 items-center justify-center rounded-lg',
geminiOAuthType === 'code_assist'
? 'bg-blue-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
]"
>
<Icon name="cloud" size="sm" />
</div>
<div class="min-w-0">
<span class="block text-sm font-medium text-gray-900 dark:text-white">
{{ t('admin.accounts.gemini.oauthType.builtInTitle') }}
</span>
<span class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.gemini.oauthType.builtInDesc') }}
</span>
</div>
</button>
<button
type="button"
:disabled="!geminiAIStudioOAuthEnabled"
@click="handleSelectGeminiOAuthType('ai_studio')"
:class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
!geminiAIStudioOAuthEnabled ? 'cursor-not-allowed opacity-60' : '',
geminiOAuthType === 'ai_studio'
? 'border-purple-500 bg-purple-50 dark:bg-purple-900/20'
: 'border-gray-200 hover:border-purple-300 dark:border-dark-600 dark:hover:border-purple-700'
]"
>
<div
:class="[
'flex h-8 w-8 items-center justify-center rounded-lg',
geminiOAuthType === 'ai_studio'
? 'bg-purple-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
]"
>
<Icon name="sparkles" size="sm" />
</div>
<div class="min-w-0">
<span class="block text-sm font-medium text-gray-900 dark:text-white">
{{ t('admin.accounts.gemini.oauthType.customTitle') }}
</span>
<span class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.gemini.oauthType.customDesc') }}
</span>
<div v-if="!geminiAIStudioOAuthEnabled" class="group relative mt-1 inline-block">
<span
class="rounded bg-amber-100 px-2 py-0.5 text-xs text-amber-700 dark:bg-amber-900/30 dark:text-amber-300"
>
{{ t('admin.accounts.oauth.gemini.aiStudioNotConfiguredShort') }}
</span>
<div
class="pointer-events-none absolute left-0 top-full z-10 mt-2 w-[28rem] rounded-md border border-amber-200 bg-amber-50 px-3 py-2 text-xs text-amber-800 opacity-0 shadow-sm transition-opacity group-hover:opacity-100 dark:border-amber-700 dark:bg-amber-900/40 dark:text-amber-200"
>
{{ t('admin.accounts.oauth.gemini.aiStudioNotConfiguredTip') }}
</div>
</div>
</div>
</button>
<!-- Gemini OAuth Type Display (read-only) -->
<div v-if="isGemini" class="rounded-lg border border-gray-200 bg-gray-50 p-4 dark:border-dark-600 dark:bg-dark-700">
<div class="mb-2 text-sm font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.accounts.oauth.gemini.oauthTypeLabel') }}
</div>
</fieldset>
<div class="flex items-center gap-3">
<div
:class="[
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
geminiOAuthType === 'google_one'
? 'bg-purple-500 text-white'
: geminiOAuthType === 'code_assist'
? 'bg-blue-500 text-white'
: 'bg-amber-500 text-white'
]"
>
<Icon v-if="geminiOAuthType === 'google_one'" name="user" size="sm" />
<Icon v-else-if="geminiOAuthType === 'code_assist'" name="cloud" size="sm" />
<Icon v-else name="sparkles" size="sm" />
</div>
<div>
<span class="block text-sm font-medium text-gray-900 dark:text-white">
{{
geminiOAuthType === 'google_one'
? 'Google One'
: geminiOAuthType === 'code_assist'
? t('admin.accounts.gemini.oauthType.builtInTitle')
: t('admin.accounts.gemini.oauthType.customTitle')
}}
</span>
<span class="text-xs text-gray-500 dark:text-gray-400">
{{
geminiOAuthType === 'google_one'
? '个人账号'
: geminiOAuthType === 'code_assist'
? t('admin.accounts.gemini.oauthType.builtInDesc')
: t('admin.accounts.gemini.oauthType.customDesc')
}}
</span>
</div>
</div>
</div>
<OAuthAuthorizationFlow
ref="oauthFlowRef"
@@ -299,7 +234,6 @@ const oauthFlowRef = ref<OAuthFlowExposed | null>(null)
// State
const addMethod = ref<AddMethod>('oauth')
const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('code_assist')
const geminiAIStudioOAuthEnabled = ref(false)
// Computed - check platform
const isOpenAI = computed(() => props.account?.platform === 'openai')
@@ -367,14 +301,6 @@ watch(
? 'ai_studio'
: 'code_assist'
}
if (isGemini.value) {
geminiOAuth.getCapabilities().then((caps) => {
geminiAIStudioOAuthEnabled.value = !!caps?.ai_studio_oauth_enabled
if (!geminiAIStudioOAuthEnabled.value && geminiOAuthType.value === 'ai_studio') {
geminiOAuthType.value = 'code_assist'
}
})
}
} else {
resetState()
}
@@ -385,7 +311,6 @@ watch(
const resetState = () => {
addMethod.value = 'oauth'
geminiOAuthType.value = 'code_assist'
geminiAIStudioOAuthEnabled.value = false
claudeOAuth.resetState()
openaiOAuth.resetState()
geminiOAuth.resetState()
@@ -393,14 +318,6 @@ const resetState = () => {
oauthFlowRef.value?.reset()
}
const handleSelectGeminiOAuthType = (oauthType: 'code_assist' | 'google_one' | 'ai_studio') => {
if (oauthType === 'ai_studio' && !geminiAIStudioOAuthEnabled.value) {
appStore.showError(t('admin.accounts.oauth.gemini.aiStudioNotConfigured'))
return
}
geminiOAuthType.value = oauthType
}
const handleClose = () => {
emit('close')
}

View File

@@ -73,111 +73,48 @@
</div>
</fieldset>
<!-- Gemini OAuth Type Selection -->
<fieldset v-if="isGemini" class="border-0 p-0">
<legend class="input-label">{{ t('admin.accounts.oauth.gemini.oauthTypeLabel') }}</legend>
<div class="mt-2 grid grid-cols-3 gap-3">
<button
type="button"
@click="handleSelectGeminiOAuthType('google_one')"
:class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
geminiOAuthType === 'google_one'
? 'border-purple-500 bg-purple-50 dark:bg-purple-900/20'
: 'border-gray-200 hover:border-purple-300 dark:border-dark-600 dark:hover:border-purple-700'
]"
>
<div
:class="[
'flex h-8 w-8 items-center justify-center rounded-lg',
geminiOAuthType === 'google_one'
? 'bg-purple-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
]"
>
<Icon name="user" size="sm" />
</div>
<div class="min-w-0">
<span class="block text-sm font-medium text-gray-900 dark:text-white">Google One</span>
<span class="text-xs text-gray-500 dark:text-gray-400">个人账号</span>
</div>
</button>
<button
type="button"
@click="handleSelectGeminiOAuthType('code_assist')"
:class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
geminiOAuthType === 'code_assist'
? 'border-blue-500 bg-blue-50 dark:bg-blue-900/20'
: 'border-gray-200 hover:border-blue-300 dark:border-dark-600 dark:hover:border-blue-700'
]"
>
<div
:class="[
'flex h-8 w-8 items-center justify-center rounded-lg',
geminiOAuthType === 'code_assist'
? 'bg-blue-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
]"
>
<Icon name="cloud" size="sm" />
</div>
<div class="min-w-0">
<span class="block text-sm font-medium text-gray-900 dark:text-white">
{{ t('admin.accounts.gemini.oauthType.builtInTitle') }}
</span>
<span class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.gemini.oauthType.builtInDesc') }}
</span>
</div>
</button>
<button
type="button"
:disabled="!geminiAIStudioOAuthEnabled"
@click="handleSelectGeminiOAuthType('ai_studio')"
:class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
!geminiAIStudioOAuthEnabled ? 'cursor-not-allowed opacity-60' : '',
geminiOAuthType === 'ai_studio'
? 'border-purple-500 bg-purple-50 dark:bg-purple-900/20'
: 'border-gray-200 hover:border-purple-300 dark:border-dark-600 dark:hover:border-purple-700'
]"
>
<div
:class="[
'flex h-8 w-8 items-center justify-center rounded-lg',
geminiOAuthType === 'ai_studio'
? 'bg-purple-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
]"
>
<Icon name="sparkles" size="sm" />
</div>
<div class="min-w-0">
<span class="block text-sm font-medium text-gray-900 dark:text-white">
{{ t('admin.accounts.gemini.oauthType.customTitle') }}
</span>
<span class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.gemini.oauthType.customDesc') }}
</span>
<div v-if="!geminiAIStudioOAuthEnabled" class="group relative mt-1 inline-block">
<span
class="rounded bg-amber-100 px-2 py-0.5 text-xs text-amber-700 dark:bg-amber-900/30 dark:text-amber-300"
>
{{ t('admin.accounts.oauth.gemini.aiStudioNotConfiguredShort') }}
</span>
<div
class="pointer-events-none absolute left-0 top-full z-10 mt-2 w-[28rem] rounded-md border border-amber-200 bg-amber-50 px-3 py-2 text-xs text-amber-800 opacity-0 shadow-sm transition-opacity group-hover:opacity-100 dark:border-amber-700 dark:bg-amber-900/40 dark:text-amber-200"
>
{{ t('admin.accounts.oauth.gemini.aiStudioNotConfiguredTip') }}
</div>
</div>
</div>
</button>
<!-- Gemini OAuth Type Display (read-only) -->
<div v-if="isGemini" class="rounded-lg border border-gray-200 bg-gray-50 p-4 dark:border-dark-600 dark:bg-dark-700">
<div class="mb-2 text-sm font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.accounts.oauth.gemini.oauthTypeLabel') }}
</div>
</fieldset>
<div class="flex items-center gap-3">
<div
:class="[
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
geminiOAuthType === 'google_one'
? 'bg-purple-500 text-white'
: geminiOAuthType === 'code_assist'
? 'bg-blue-500 text-white'
: 'bg-amber-500 text-white'
]"
>
<Icon v-if="geminiOAuthType === 'google_one'" name="user" size="sm" />
<Icon v-else-if="geminiOAuthType === 'code_assist'" name="cloud" size="sm" />
<Icon v-else name="sparkles" size="sm" />
</div>
<div>
<span class="block text-sm font-medium text-gray-900 dark:text-white">
{{
geminiOAuthType === 'google_one'
? 'Google One'
: geminiOAuthType === 'code_assist'
? t('admin.accounts.gemini.oauthType.builtInTitle')
: t('admin.accounts.gemini.oauthType.customTitle')
}}
</span>
<span class="text-xs text-gray-500 dark:text-gray-400">
{{
geminiOAuthType === 'google_one'
? '个人账号'
: geminiOAuthType === 'code_assist'
? t('admin.accounts.gemini.oauthType.builtInDesc')
: t('admin.accounts.gemini.oauthType.customDesc')
}}
</span>
</div>
</div>
</div>
<OAuthAuthorizationFlow
ref="oauthFlowRef"
@@ -297,7 +234,6 @@ const oauthFlowRef = ref<OAuthFlowExposed | null>(null)
// State
const addMethod = ref<AddMethod>('oauth')
const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('code_assist')
const geminiAIStudioOAuthEnabled = ref(false)
// Computed - check platform
const isOpenAI = computed(() => props.account?.platform === 'openai')
@@ -365,14 +301,6 @@ watch(
? 'ai_studio'
: 'code_assist'
}
if (isGemini.value) {
geminiOAuth.getCapabilities().then((caps) => {
geminiAIStudioOAuthEnabled.value = !!caps?.ai_studio_oauth_enabled
if (!geminiAIStudioOAuthEnabled.value && geminiOAuthType.value === 'ai_studio') {
geminiOAuthType.value = 'code_assist'
}
})
}
} else {
resetState()
}
@@ -383,7 +311,6 @@ watch(
const resetState = () => {
addMethod.value = 'oauth'
geminiOAuthType.value = 'code_assist'
geminiAIStudioOAuthEnabled.value = false
claudeOAuth.resetState()
openaiOAuth.resetState()
geminiOAuth.resetState()
@@ -391,14 +318,6 @@ const resetState = () => {
oauthFlowRef.value?.reset()
}
const handleSelectGeminiOAuthType = (oauthType: 'code_assist' | 'google_one' | 'ai_studio') => {
if (oauthType === 'ai_studio' && !geminiAIStudioOAuthEnabled.value) {
appStore.showError(t('admin.accounts.oauth.gemini.aiStudioNotConfigured'))
return
}
geminiOAuthType.value = oauthType
}
const handleClose = () => {
emit('close')
}

View File

@@ -115,15 +115,9 @@
<span class="text-sm text-gray-600 dark:text-gray-400">{{ formatDateTime(value) }}</span>
</template>
<template #cell-request_id="{ row }">
<div v-if="row.request_id" class="flex items-center gap-1.5 max-w-[120px]">
<span class="font-mono text-xs text-gray-500 dark:text-gray-400 truncate" :title="row.request_id">{{ row.request_id }}</span>
<button @click="copyRequestId(row.request_id)" class="flex-shrink-0 rounded p-0.5 transition-colors hover:bg-gray-100 dark:hover:bg-dark-700" :class="copiedRequestId === row.request_id ? 'text-green-500' : 'text-gray-400 hover:text-gray-600 dark:hover:text-gray-300'" :title="copiedRequestId === row.request_id ? t('keys.copied') : t('keys.copyToClipboard')">
<svg v-if="copiedRequestId === row.request_id" class="h-3.5 w-3.5" fill="none" stroke="currentColor" viewBox="0 0 24 24" stroke-width="2"><path stroke-linecap="round" stroke-linejoin="round" d="M5 13l4 4L19 7" /></svg>
<Icon v-else name="copy" size="sm" class="h-3.5 w-3.5" />
</button>
</div>
<span v-else class="text-gray-400 dark:text-gray-500">-</span>
<template #cell-user_agent="{ row }">
<span v-if="row.user_agent" class="text-sm text-gray-600 dark:text-gray-400 max-w-[150px] truncate block" :title="row.user_agent">{{ formatUserAgent(row.user_agent) }}</span>
<span v-else class="text-sm text-gray-400 dark:text-gray-500">-</span>
</template>
<template #empty><EmptyState :message="t('usage.noRecords')" /></template>
@@ -228,7 +222,6 @@
import { ref, computed } from 'vue'
import { useI18n } from 'vue-i18n'
import { formatDateTime } from '@/utils/format'
import { useAppStore } from '@/stores/app'
import DataTable from '@/components/common/DataTable.vue'
import EmptyState from '@/components/common/EmptyState.vue'
import Icon from '@/components/icons/Icon.vue'
@@ -236,8 +229,6 @@ import type { UsageLog } from '@/types'
defineProps(['data', 'loading'])
const { t } = useI18n()
const appStore = useAppStore()
const copiedRequestId = ref<string | null>(null)
// Tooltip state - cost
const tooltipVisible = ref(false)
@@ -262,7 +253,7 @@ const cols = computed(() => [
{ key: 'first_token', label: t('usage.firstToken'), sortable: false },
{ key: 'duration', label: t('usage.duration'), sortable: false },
{ key: 'created_at', label: t('usage.time'), sortable: true },
{ key: 'request_id', label: t('admin.usage.requestId'), sortable: false }
{ key: 'user_agent', label: t('usage.userAgent'), sortable: false }
])
const formatCacheTokens = (tokens: number): string => {
@@ -271,23 +262,25 @@ const formatCacheTokens = (tokens: number): string => {
return tokens.toString()
}
const formatUserAgent = (ua: string): string => {
// 提取主要客户端标识
if (ua.includes('claude-cli')) return ua.match(/claude-cli\/[\d.]+/)?.[0] || 'Claude CLI'
if (ua.includes('Cursor')) return 'Cursor'
if (ua.includes('VSCode') || ua.includes('vscode')) return 'VS Code'
if (ua.includes('Continue')) return 'Continue'
if (ua.includes('Cline')) return 'Cline'
if (ua.includes('OpenAI')) return 'OpenAI SDK'
if (ua.includes('anthropic')) return 'Anthropic SDK'
// 截断过长的 UA
return ua.length > 30 ? ua.substring(0, 30) + '...' : ua
}
const formatDuration = (ms: number | null | undefined): string => {
if (ms == null) return '-'
if (ms < 1000) return `${ms}ms`
return `${(ms / 1000).toFixed(2)}s`
}
const copyRequestId = async (requestId: string) => {
try {
await navigator.clipboard.writeText(requestId)
copiedRequestId.value = requestId
appStore.showSuccess(t('admin.usage.requestIdCopied'))
setTimeout(() => { copiedRequestId.value = null }, 2000)
} catch {
appStore.showError(t('common.copyFailed'))
}
}
// Cost tooltip functions
const showTooltip = (event: MouseEvent, row: UsageLog) => {
const target = event.currentTarget as HTMLElement

View File

@@ -424,7 +424,8 @@ export default {
billingType: 'Billing',
balance: 'Balance',
subscription: 'Subscription',
imageUnit: ' images'
imageUnit: ' images',
userAgent: 'User-Agent'
},
// Redeem
@@ -856,6 +857,15 @@ export default {
imagePricing: {
title: 'Image Generation Pricing',
description: 'Configure pricing for gemini-3-pro-image model. Leave empty to use default prices.'
},
claudeCode: {
title: 'Claude Code Client Restriction',
tooltip: 'When enabled, this group only allows official Claude Code clients. Non-Claude Code requests will be rejected or fallback to the specified group.',
enabled: 'Claude Code Only',
disabled: 'Allow All Clients',
fallbackGroup: 'Fallback Group',
fallbackHint: 'Non-Claude Code requests will use this group. Leave empty to reject directly.',
noFallback: 'No Fallback (Reject)'
}
},
@@ -1560,6 +1570,7 @@ export default {
protocol: 'Protocol',
address: 'Address',
status: 'Status',
accounts: 'Accounts',
actions: 'Actions'
},
testConnection: 'Test Connection',

View File

@@ -421,7 +421,8 @@ export default {
billingType: '消费类型',
balance: '余额',
subscription: '订阅',
imageUnit: '张'
imageUnit: '张',
userAgent: 'User-Agent'
},
// Redeem
@@ -933,6 +934,15 @@ export default {
imagePricing: {
title: '图片生成计费',
description: '配置 gemini-3-pro-image 模型的图片生成价格,留空则使用默认价格'
},
claudeCode: {
title: 'Claude Code 客户端限制',
tooltip: '启用后,此分组仅允许 Claude Code 官方客户端访问。非 Claude Code 请求将被拒绝或降级到指定分组。',
enabled: '仅限 Claude Code',
disabled: '允许所有客户端',
fallbackGroup: '降级分组',
fallbackHint: '非 Claude Code 请求将使用此分组,留空则直接拒绝',
noFallback: '不降级(直接拒绝)'
}
},
@@ -1646,6 +1656,7 @@ export default {
protocol: '协议',
address: '地址',
status: '状态',
accounts: '账号数',
actions: '操作',
nameLabel: '名称',
namePlaceholder: '请输入代理名称',

View File

@@ -263,6 +263,9 @@ export interface Group {
image_price_1k: number | null
image_price_2k: number | null
image_price_4k: number | null
// Claude Code 客户端限制
claude_code_only: boolean
fallback_group_id: number | null
account_count?: number
created_at: string
updated_at: string
@@ -298,6 +301,15 @@ export interface CreateGroupRequest {
platform?: GroupPlatform
rate_multiplier?: number
is_exclusive?: boolean
subscription_type?: SubscriptionType
daily_limit_usd?: number | null
weekly_limit_usd?: number | null
monthly_limit_usd?: number | null
image_price_1k?: number | null
image_price_2k?: number | null
image_price_4k?: number | null
claude_code_only?: boolean
fallback_group_id?: number | null
}
export interface UpdateGroupRequest {
@@ -307,6 +319,15 @@ export interface UpdateGroupRequest {
rate_multiplier?: number
is_exclusive?: boolean
status?: 'active' | 'inactive'
subscription_type?: SubscriptionType
daily_limit_usd?: number | null
weekly_limit_usd?: number | null
monthly_limit_usd?: number | null
image_price_1k?: number | null
image_price_2k?: number | null
image_price_4k?: number | null
claude_code_only?: boolean
fallback_group_id?: number | null
}
// ==================== Account & Proxy Types ====================
@@ -576,6 +597,9 @@ export interface UsageLog {
image_count: number
image_size: string | null
// User-Agent
user_agent: string | null
created_at: string
user?: User

View File

@@ -61,7 +61,7 @@
<!-- Login / Dashboard Button -->
<router-link
v-if="isAuthenticated"
to="/dashboard"
:to="dashboardPath"
class="inline-flex items-center gap-1.5 rounded-full bg-gray-900 py-1 pl-1 pr-2.5 transition-colors hover:bg-gray-800 dark:bg-gray-800 dark:hover:bg-gray-700"
>
<span
@@ -114,7 +114,7 @@
<!-- CTA Button -->
<div>
<router-link
:to="isAuthenticated ? '/dashboard' : '/login'"
:to="isAuthenticated ? dashboardPath : '/login'"
class="btn btn-primary px-8 py-3 text-base shadow-lg shadow-primary-500/30"
>
{{ isAuthenticated ? t('home.goToDashboard') : t('home.getStarted') }}
@@ -416,6 +416,8 @@ const githubUrl = 'https://github.com/Wei-Shaw/sub2api'
// Auth state
const isAuthenticated = computed(() => authStore.isAuthenticated)
const isAdmin = computed(() => authStore.isAdmin)
const dashboardPath = computed(() => isAdmin.value ? '/admin/dashboard' : '/dashboard')
const userInitial = computed(() => {
const user = authStore.user
if (!user || !user.email) return ''

View File

@@ -403,6 +403,62 @@
</div>
</div>
<!-- Claude Code 客户端限制 anthropic 平台 -->
<div v-if="createForm.platform === 'anthropic'" class="border-t pt-4">
<div class="mb-1.5 flex items-center gap-1">
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.groups.claudeCode.title') }}
</label>
<!-- Help Tooltip -->
<div class="group relative inline-flex">
<Icon
name="questionCircle"
size="sm"
:stroke-width="2"
class="cursor-help text-gray-400 transition-colors hover:text-primary-500 dark:text-gray-500 dark:hover:text-primary-400"
/>
<div class="pointer-events-none absolute bottom-full left-0 z-50 mb-2 w-72 opacity-0 transition-all duration-200 group-hover:pointer-events-auto group-hover:opacity-100">
<div class="rounded-lg bg-gray-900 p-3 text-white shadow-lg dark:bg-gray-800">
<p class="text-xs leading-relaxed text-gray-300">
{{ t('admin.groups.claudeCode.tooltip') }}
</p>
<div class="absolute -bottom-1.5 left-3 h-3 w-3 rotate-45 bg-gray-900 dark:bg-gray-800"></div>
</div>
</div>
</div>
</div>
<div class="flex items-center gap-3">
<button
type="button"
@click="createForm.claude_code_only = !createForm.claude_code_only"
:class="[
'relative inline-flex h-6 w-11 items-center rounded-full transition-colors',
createForm.claude_code_only ? 'bg-primary-500' : 'bg-gray-300 dark:bg-dark-600'
]"
>
<span
:class="[
'inline-block h-4 w-4 transform rounded-full bg-white shadow transition-transform',
createForm.claude_code_only ? 'translate-x-6' : 'translate-x-1'
]"
/>
</button>
<span class="text-sm text-gray-500 dark:text-gray-400">
{{ createForm.claude_code_only ? t('admin.groups.claudeCode.enabled') : t('admin.groups.claudeCode.disabled') }}
</span>
</div>
<!-- 降级分组选择仅当启用 claude_code_only 时显示 -->
<div v-if="createForm.claude_code_only" class="mt-3">
<label class="input-label">{{ t('admin.groups.claudeCode.fallbackGroup') }}</label>
<Select
v-model="createForm.fallback_group_id"
:options="fallbackGroupOptions"
:placeholder="t('admin.groups.claudeCode.noFallback')"
/>
<p class="input-hint">{{ t('admin.groups.claudeCode.fallbackHint') }}</p>
</div>
</div>
</form>
<template #footer>
@@ -648,6 +704,62 @@
</div>
</div>
<!-- Claude Code 客户端限制 anthropic 平台 -->
<div v-if="editForm.platform === 'anthropic'" class="border-t pt-4">
<div class="mb-1.5 flex items-center gap-1">
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.groups.claudeCode.title') }}
</label>
<!-- Help Tooltip -->
<div class="group relative inline-flex">
<Icon
name="questionCircle"
size="sm"
:stroke-width="2"
class="cursor-help text-gray-400 transition-colors hover:text-primary-500 dark:text-gray-500 dark:hover:text-primary-400"
/>
<div class="pointer-events-none absolute bottom-full left-0 z-50 mb-2 w-72 opacity-0 transition-all duration-200 group-hover:pointer-events-auto group-hover:opacity-100">
<div class="rounded-lg bg-gray-900 p-3 text-white shadow-lg dark:bg-gray-800">
<p class="text-xs leading-relaxed text-gray-300">
{{ t('admin.groups.claudeCode.tooltip') }}
</p>
<div class="absolute -bottom-1.5 left-3 h-3 w-3 rotate-45 bg-gray-900 dark:bg-gray-800"></div>
</div>
</div>
</div>
</div>
<div class="flex items-center gap-3">
<button
type="button"
@click="editForm.claude_code_only = !editForm.claude_code_only"
:class="[
'relative inline-flex h-6 w-11 items-center rounded-full transition-colors',
editForm.claude_code_only ? 'bg-primary-500' : 'bg-gray-300 dark:bg-dark-600'
]"
>
<span
:class="[
'inline-block h-4 w-4 transform rounded-full bg-white shadow transition-transform',
editForm.claude_code_only ? 'translate-x-6' : 'translate-x-1'
]"
/>
</button>
<span class="text-sm text-gray-500 dark:text-gray-400">
{{ editForm.claude_code_only ? t('admin.groups.claudeCode.enabled') : t('admin.groups.claudeCode.disabled') }}
</span>
</div>
<!-- 降级分组选择仅当启用 claude_code_only 时显示 -->
<div v-if="editForm.claude_code_only" class="mt-3">
<label class="input-label">{{ t('admin.groups.claudeCode.fallbackGroup') }}</label>
<Select
v-model="editForm.fallback_group_id"
:options="fallbackGroupOptionsForEdit"
:placeholder="t('admin.groups.claudeCode.noFallback')"
/>
<p class="input-hint">{{ t('admin.groups.claudeCode.fallbackHint') }}</p>
</div>
</div>
</form>
<template #footer>
@@ -774,6 +886,35 @@ const subscriptionTypeOptions = computed(() => [
{ value: 'subscription', label: t('admin.groups.subscription.subscription') }
])
// 降级分组选项(创建时)- 仅包含 anthropic 平台且未启用 claude_code_only 的分组
const fallbackGroupOptions = computed(() => {
const options: { value: number | null; label: string }[] = [
{ value: null, label: t('admin.groups.claudeCode.noFallback') }
]
const eligibleGroups = groups.value.filter(
(g) => g.platform === 'anthropic' && !g.claude_code_only && g.status === 'active'
)
eligibleGroups.forEach((g) => {
options.push({ value: g.id, label: g.name })
})
return options
})
// 降级分组选项(编辑时)- 排除自身
const fallbackGroupOptionsForEdit = computed(() => {
const options: { value: number | null; label: string }[] = [
{ value: null, label: t('admin.groups.claudeCode.noFallback') }
]
const currentId = editingGroup.value?.id
const eligibleGroups = groups.value.filter(
(g) => g.platform === 'anthropic' && !g.claude_code_only && g.status === 'active' && g.id !== currentId
)
eligibleGroups.forEach((g) => {
options.push({ value: g.id, label: g.name })
})
return options
})
const groups = ref<Group[]>([])
const loading = ref(false)
const searchQuery = ref('')
@@ -821,7 +962,10 @@ const createForm = reactive({
// 图片生成计费配置(仅 antigravity 平台使用)
image_price_1k: null as number | null,
image_price_2k: null as number | null,
image_price_4k: null as number | null
image_price_4k: null as number | null,
// Claude Code 客户端限制(仅 anthropic 平台使用)
claude_code_only: false,
fallback_group_id: null as number | null
})
const editForm = reactive({
@@ -838,7 +982,10 @@ const editForm = reactive({
// 图片生成计费配置(仅 antigravity 平台使用)
image_price_1k: null as number | null,
image_price_2k: null as number | null,
image_price_4k: null as number | null
image_price_4k: null as number | null,
// Claude Code 客户端限制(仅 anthropic 平台使用)
claude_code_only: false,
fallback_group_id: null as number | null
})
// 根据分组类型返回不同的删除确认消息
@@ -908,6 +1055,8 @@ const closeCreateModal = () => {
createForm.image_price_1k = null
createForm.image_price_2k = null
createForm.image_price_4k = null
createForm.claude_code_only = false
createForm.fallback_group_id = null
}
const handleCreateGroup = async () => {
@@ -949,6 +1098,8 @@ const handleEdit = (group: Group) => {
editForm.image_price_1k = group.image_price_1k
editForm.image_price_2k = group.image_price_2k
editForm.image_price_4k = group.image_price_4k
editForm.claude_code_only = group.claude_code_only || false
editForm.fallback_group_id = group.fallback_group_id
showEditModal.value = true
}
@@ -966,7 +1117,12 @@ const handleUpdateGroup = async () => {
submitting.value = true
try {
await adminAPI.groups.update(editingGroup.value.id, editForm)
// 转换 fallback_group_id: null -> 0 (后端使用 0 表示清除)
const payload = {
...editForm,
fallback_group_id: editForm.fallback_group_id === null ? 0 : editForm.fallback_group_id
}
await adminAPI.groups.update(editingGroup.value.id, payload)
appStore.showSuccess(t('admin.groups.groupUpdated'))
closeEditModal()
loadGroups()

View File

@@ -85,6 +85,14 @@
</span>
</template>
<template #cell-account_count="{ value }">
<span
class="inline-flex items-center rounded bg-gray-100 px-2 py-0.5 text-xs font-medium text-gray-800 dark:bg-dark-600 dark:text-gray-300"
>
{{ t('admin.groups.accountsCount', { count: value || 0 }) }}
</span>
</template>
<template #cell-actions="{ row }">
<div class="flex items-center gap-1">
<button
@@ -534,6 +542,7 @@ const columns = computed<Column[]>(() => [
{ key: 'name', label: t('admin.proxies.columns.name'), sortable: true },
{ key: 'protocol', label: t('admin.proxies.columns.protocol'), sortable: true },
{ key: 'address', label: t('admin.proxies.columns.address'), sortable: false },
{ key: 'account_count', label: t('admin.proxies.columns.accounts'), sortable: true },
{ key: 'status', label: t('admin.proxies.columns.status'), sortable: true },
{ key: 'actions', label: t('admin.proxies.columns.actions'), sortable: false }
])

View File

@@ -96,7 +96,7 @@ const exportToExcel = async () => {
t('admin.usage.cacheReadCost'), t('admin.usage.cacheCreationCost'),
t('usage.rate'), t('usage.original'), t('usage.billed'),
t('usage.billingType'), t('usage.firstToken'), t('usage.duration'),
t('admin.usage.requestId')
t('admin.usage.requestId'), t('usage.userAgent')
]
const rows = all.map(log => [
log.created_at,
@@ -120,7 +120,8 @@ const exportToExcel = async () => {
log.billing_type === 1 ? t('usage.subscription') : t('usage.balance'),
log.first_token_ms ?? '',
log.duration_ms,
log.request_id || ''
log.request_id || '',
log.user_agent || ''
])
const ws = XLSX.utils.aoa_to_sheet([headers, ...rows])
const wb = XLSX.utils.book_new()

View File

@@ -308,6 +308,11 @@
}}</span>
</template>
<template #cell-user_agent="{ row }">
<span v-if="row.user_agent" class="text-sm text-gray-600 dark:text-gray-400 max-w-[150px] truncate block" :title="row.user_agent">{{ formatUserAgent(row.user_agent) }}</span>
<span v-else class="text-sm text-gray-400 dark:text-gray-500">-</span>
</template>
<template #empty>
<EmptyState :message="t('usage.noRecords')" />
</template>
@@ -480,7 +485,8 @@ const columns = computed<Column[]>(() => [
{ key: 'billing_type', label: t('usage.billingType'), sortable: false },
{ key: 'first_token', label: t('usage.firstToken'), sortable: false },
{ key: 'duration', label: t('usage.duration'), sortable: false },
{ key: 'created_at', label: t('usage.time'), sortable: true }
{ key: 'created_at', label: t('usage.time'), sortable: true },
{ key: 'user_agent', label: t('usage.userAgent'), sortable: false }
])
const usageLogs = ref<UsageLog[]>([])
@@ -545,6 +551,19 @@ const formatDuration = (ms: number): string => {
return `${(ms / 1000).toFixed(2)}s`
}
const formatUserAgent = (ua: string): string => {
// 提取主要客户端标识
if (ua.includes('claude-cli')) return ua.match(/claude-cli\/[\d.]+/)?.[0] || 'Claude CLI'
if (ua.includes('Cursor')) return 'Cursor'
if (ua.includes('VSCode') || ua.includes('vscode')) return 'VS Code'
if (ua.includes('Continue')) return 'Continue'
if (ua.includes('Cline')) return 'Cline'
if (ua.includes('OpenAI')) return 'OpenAI SDK'
if (ua.includes('anthropic')) return 'Anthropic SDK'
// 截断过长的 UA
return ua.length > 30 ? ua.substring(0, 30) + '...' : ua
}
const formatTokens = (value: number): string => {
if (value >= 1_000_000_000) {
return `${(value / 1_000_000_000).toFixed(2)}B`