mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-06 16:30:22 +08:00
Compare commits
243 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
04b4d7de3a | ||
|
|
b11cfdc11f | ||
|
|
cdddbec629 | ||
|
|
b3fe0506fb | ||
|
|
8ca0e2772e | ||
|
|
826090e099 | ||
|
|
7399de6ecc | ||
|
|
25cb5e7505 | ||
|
|
5c13ec3121 | ||
|
|
d8aff3a7e3 | ||
|
|
f44927b9f8 | ||
|
|
c0110cb5af | ||
|
|
1f8e1142a0 | ||
|
|
1e51de88d6 | ||
|
|
30995b5397 | ||
|
|
eb60f67054 | ||
|
|
78193ceec1 | ||
|
|
f0e08e7687 | ||
|
|
10b8259259 | ||
|
|
eb0b77bf4d | ||
|
|
9d81467937 | ||
|
|
fd8ccaf01a | ||
|
|
2b30e3b6d7 | ||
|
|
6e90ec6111 | ||
|
|
8dd38f4775 | ||
|
|
fbd73f248f | ||
|
|
3fcefe6c32 | ||
|
|
f740d2c291 | ||
|
|
bf6585a40f | ||
|
|
8c2dd7b3f0 | ||
|
|
4167c437a8 | ||
|
|
0ddaef3c9a | ||
|
|
2fc6aaf936 | ||
|
|
1c0519f1c7 | ||
|
|
6bbe7800be | ||
|
|
2694149489 | ||
|
|
a17ac50118 | ||
|
|
656a77d585 | ||
|
|
7455476c60 | ||
|
|
5f6e929d61 | ||
|
|
4caa3c2701 | ||
|
|
36cda57c81 | ||
|
|
9f1f203b84 | ||
|
|
b41a8ca93f | ||
|
|
e3cf0c0e10 | ||
|
|
de18bce9aa | ||
|
|
a9ecd7bcc6 | ||
|
|
f89465fb39 | ||
|
|
440c3f46a7 | ||
|
|
c746964936 | ||
|
|
b2d6879b3f | ||
|
|
ce4095904e | ||
|
|
0d2dcbff11 | ||
|
|
02dca78dbe | ||
|
|
8df299b767 | ||
|
|
95cf59b2f6 | ||
|
|
a0f643da0e | ||
|
|
99331a5285 | ||
|
|
9245e197a8 | ||
|
|
5fa1d0978f | ||
|
|
0bb683a6f1 | ||
|
|
cb3958bac3 | ||
|
|
91ebf95efa | ||
|
|
7895dd65b2 | ||
|
|
09f8894906 | ||
|
|
45fdef9da7 | ||
|
|
704f256bd7 | ||
|
|
a6026e7ac4 | ||
|
|
1e03b2974a | ||
|
|
daa7c783b9 | ||
|
|
8a82a2a648 | ||
|
|
c3ac68af2a | ||
|
|
ec3897b981 | ||
|
|
62486cee37 | ||
|
|
7c5746ffbc | ||
|
|
47f7b0213b | ||
|
|
8df42f7aab | ||
|
|
d666e05a6d | ||
|
|
c37edf2de5 | ||
|
|
bb9af2465e | ||
|
|
7af00864b3 | ||
|
|
81903e87e3 | ||
|
|
0b96c7a65e | ||
|
|
34ccfe45ea | ||
|
|
d4231150a9 | ||
|
|
2268e93aec | ||
|
|
c976916b6d | ||
|
|
0bbd003a18 | ||
|
|
336a844712 | ||
|
|
51e7f262bd | ||
|
|
99663a3f20 | ||
|
|
4d88248091 | ||
|
|
c143367e15 | ||
|
|
ca44389baa | ||
|
|
05c2a65ef0 | ||
|
|
3f21d204a7 | ||
|
|
c848f950ce | ||
|
|
87bd765a57 | ||
|
|
49325a769e | ||
|
|
238d86f502 | ||
|
|
7064063230 | ||
|
|
b1de4352a8 | ||
|
|
7bf5c1cbcb | ||
|
|
411e24146d | ||
|
|
c7392fc80b | ||
|
|
c37c68a341 | ||
|
|
9230d3cbc9 | ||
|
|
19925e22d9 | ||
|
|
65e0c1b258 | ||
|
|
94bdde32bb | ||
|
|
14c80d26c6 | ||
|
|
9555a99d1c | ||
|
|
0e69895603 | ||
|
|
cc3cf1d70a | ||
|
|
3382d496e3 | ||
|
|
e0b4b00dc1 | ||
|
|
81d896bf78 | ||
|
|
ec576fdbde | ||
|
|
741eae59bb | ||
|
|
505494b378 | ||
|
|
c1033c12bd | ||
|
|
b789333b68 | ||
|
|
61ef73cb12 | ||
|
|
e71be7e0f1 | ||
|
|
0302c03864 | ||
|
|
d21fe54d55 | ||
|
|
a70d3ff82d | ||
|
|
574359f1df | ||
|
|
6da2f54e50 | ||
|
|
886464b2e9 | ||
|
|
396044e354 | ||
|
|
3d15202124 | ||
|
|
756b09b6b8 | ||
|
|
f4d3fadd6f | ||
|
|
1fb6e9e830 | ||
|
|
78ac6a7a29 | ||
|
|
efe8810dff | ||
|
|
5c07e11473 | ||
|
|
d552ad7673 | ||
|
|
496173da1f | ||
|
|
1cdaf33272 | ||
|
|
e2b3969492 | ||
|
|
8661bf8837 | ||
|
|
292fa7a6d2 | ||
|
|
39d7300a8e | ||
|
|
0ad20c9489 | ||
|
|
7eb3b23ddf | ||
|
|
5b06542193 | ||
|
|
38dca4f787 | ||
|
|
737d1ecf5b | ||
|
|
f819cef6d5 | ||
|
|
facae2a6db | ||
|
|
8b021c099d | ||
|
|
8a625188ce | ||
|
|
c0cfa6acde | ||
|
|
0913cfc082 | ||
|
|
59e8465325 | ||
|
|
be6e8ff77b | ||
|
|
ebb85cf843 | ||
|
|
ef959bc3c6 | ||
|
|
8cb7356bbc | ||
|
|
496545188a | ||
|
|
739c80227a | ||
|
|
6218eefd61 | ||
|
|
5715587baf | ||
|
|
ae770a625b | ||
|
|
428ee065d3 | ||
|
|
320ca28f90 | ||
|
|
5e518f5fbd | ||
|
|
24dcba1d72 | ||
|
|
1abc688cad | ||
|
|
34936189d8 | ||
|
|
daf7bf3e8b | ||
|
|
3a9f1c5796 | ||
|
|
bb1e205516 | ||
|
|
9af4a55176 | ||
|
|
51e903c34e | ||
|
|
7f03319646 | ||
|
|
0c33d18a4d | ||
|
|
a747c63b8e | ||
|
|
a03c361b04 | ||
|
|
807d0018ef | ||
|
|
850e267763 | ||
|
|
c75ae56f10 | ||
|
|
caaed775aa | ||
|
|
d11b295729 | ||
|
|
91d0059f8d | ||
|
|
44693d0dfb | ||
|
|
c722212e12 | ||
|
|
f12da65962 | ||
|
|
92745f7534 | ||
|
|
f2917aeaf8 | ||
|
|
a1e2ffd586 | ||
|
|
78a9705fad | ||
|
|
5b6da04a02 | ||
|
|
4661c2f90f | ||
|
|
90e4328885 | ||
|
|
130112a84a | ||
|
|
b368bb6ea1 | ||
|
|
9ecb6211d6 | ||
|
|
79fba9c8d3 | ||
|
|
37c76a93ab | ||
|
|
86e600aa52 | ||
|
|
6a1c28b70e | ||
|
|
9375f1809c | ||
|
|
f176150c93 | ||
|
|
30d25084f0 | ||
|
|
91ad94d941 | ||
|
|
57a778dccf | ||
|
|
f702c66659 | ||
|
|
a095468850 | ||
|
|
f9b6a20995 | ||
|
|
f2770da880 | ||
|
|
d269659e61 | ||
|
|
c4d6715443 | ||
|
|
fc4a1c5433 | ||
|
|
6bdd580b3f | ||
|
|
9cf4882f4c | ||
|
|
406dad998d | ||
|
|
8b0db22c18 | ||
|
|
f06048eccf | ||
|
|
05f5a8b61d | ||
|
|
662625a091 | ||
|
|
6328e69441 | ||
|
|
425dfb80d9 | ||
|
|
4c1fd570f0 | ||
|
|
345f853b5d | ||
|
|
100a70f87c | ||
|
|
18b591bc3b | ||
|
|
6a52b24369 | ||
|
|
228aca9523 | ||
|
|
7e4637cd70 | ||
|
|
3e3c015efa | ||
|
|
30c30b1712 | ||
|
|
e666356483 | ||
|
|
3710bc883b | ||
|
|
64f60d15b0 | ||
|
|
bc4a044337 | ||
|
|
cb233bfa66 | ||
|
|
d46059a735 | ||
|
|
da2fbd9924 | ||
|
|
084e0adb34 | ||
|
|
3bddbb6afe |
2
.github/workflows/backend-ci.yml
vendored
2
.github/workflows/backend-ci.yml
vendored
@@ -17,6 +17,7 @@ jobs:
|
|||||||
go-version-file: backend/go.mod
|
go-version-file: backend/go.mod
|
||||||
check-latest: false
|
check-latest: false
|
||||||
cache: true
|
cache: true
|
||||||
|
cache-dependency-path: backend/go.sum
|
||||||
- name: Verify Go version
|
- name: Verify Go version
|
||||||
run: |
|
run: |
|
||||||
go version | grep -q 'go1.26.1'
|
go version | grep -q 'go1.26.1'
|
||||||
@@ -36,6 +37,7 @@ jobs:
|
|||||||
go-version-file: backend/go.mod
|
go-version-file: backend/go.mod
|
||||||
check-latest: false
|
check-latest: false
|
||||||
cache: true
|
cache: true
|
||||||
|
cache-dependency-path: backend/go.sum
|
||||||
- name: Verify Go version
|
- name: Verify Go version
|
||||||
run: |
|
run: |
|
||||||
go version | grep -q 'go1.26.1'
|
go version | grep -q 'go1.26.1'
|
||||||
|
|||||||
10
.gitignore
vendored
10
.gitignore
vendored
@@ -78,6 +78,7 @@ Desktop.ini
|
|||||||
# ===================
|
# ===================
|
||||||
tmp/
|
tmp/
|
||||||
temp/
|
temp/
|
||||||
|
logs/
|
||||||
*.tmp
|
*.tmp
|
||||||
*.temp
|
*.temp
|
||||||
*.log
|
*.log
|
||||||
@@ -128,8 +129,15 @@ deploy/docker-compose.override.yml
|
|||||||
vite.config.js
|
vite.config.js
|
||||||
docs/*
|
docs/*
|
||||||
.serena/
|
.serena/
|
||||||
|
|
||||||
|
# ===================
|
||||||
|
# 压测工具
|
||||||
|
# ===================
|
||||||
|
tools/loadtest/
|
||||||
|
# Antigravity Manager
|
||||||
|
Antigravity-Manager/
|
||||||
|
antigravity_projectid_fix.patch
|
||||||
.codex/
|
.codex/
|
||||||
frontend/coverage/
|
frontend/coverage/
|
||||||
aicodex
|
aicodex
|
||||||
output/
|
output/
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
0.1.88
|
0.1.96.1
|
||||||
|
|||||||
@@ -62,26 +62,28 @@ type Group struct {
|
|||||||
SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"`
|
SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"`
|
||||||
// SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
|
// SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
|
||||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
|
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
|
||||||
// 是否仅允许 Claude Code 客户端
|
// allow Claude Code client only
|
||||||
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
||||||
// 非 Claude Code 请求降级使用的分组 ID
|
// fallback group for non-Claude-Code requests
|
||||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||||
// 无效请求兜底使用的分组 ID
|
// fallback group for invalid request
|
||||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
|
||||||
// 模型路由配置:模型模式 -> 优先账号ID列表
|
// model routing config: pattern -> account ids
|
||||||
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
||||||
// 是否启用模型路由配置
|
// whether model routing is enabled
|
||||||
ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"`
|
ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"`
|
||||||
// 是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)
|
// whether MCP XML prompt injection is enabled
|
||||||
McpXMLInject bool `json:"mcp_xml_inject,omitempty"`
|
McpXMLInject bool `json:"mcp_xml_inject,omitempty"`
|
||||||
// 支持的模型系列:claude, gemini_text, gemini_image
|
// supported model scopes: claude, gemini_text, gemini_image
|
||||||
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
||||||
// 分组显示排序,数值越小越靠前
|
// group display order, lower comes first
|
||||||
SortOrder int `json:"sort_order,omitempty"`
|
SortOrder int `json:"sort_order,omitempty"`
|
||||||
// 是否允许 /v1/messages 调度到此 OpenAI 分组
|
// 是否允许 /v1/messages 调度到此 OpenAI 分组
|
||||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch,omitempty"`
|
AllowMessagesDispatch bool `json:"allow_messages_dispatch,omitempty"`
|
||||||
// 默认映射模型 ID,当账号级映射找不到时使用此值
|
// 默认映射模型 ID,当账号级映射找不到时使用此值
|
||||||
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
||||||
|
// simulate claude usage as claude-max style (1h cache write)
|
||||||
|
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled,omitempty"`
|
||||||
// Edges holds the relations/edges for other nodes in the graph.
|
// Edges holds the relations/edges for other nodes in the graph.
|
||||||
// The values are being populated by the GroupQuery when eager-loading is set.
|
// The values are being populated by the GroupQuery when eager-loading is set.
|
||||||
Edges GroupEdges `json:"edges"`
|
Edges GroupEdges `json:"edges"`
|
||||||
@@ -190,7 +192,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
|||||||
switch columns[i] {
|
switch columns[i] {
|
||||||
case group.FieldModelRouting, group.FieldSupportedModelScopes:
|
case group.FieldModelRouting, group.FieldSupportedModelScopes:
|
||||||
values[i] = new([]byte)
|
values[i] = new([]byte)
|
||||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch:
|
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldSimulateClaudeMaxEnabled:
|
||||||
values[i] = new(sql.NullBool)
|
values[i] = new(sql.NullBool)
|
||||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
|
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
|
||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
@@ -431,6 +433,12 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
|||||||
} else if value.Valid {
|
} else if value.Valid {
|
||||||
_m.DefaultMappedModel = value.String
|
_m.DefaultMappedModel = value.String
|
||||||
}
|
}
|
||||||
|
case group.FieldSimulateClaudeMaxEnabled:
|
||||||
|
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field simulate_claude_max_enabled", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.SimulateClaudeMaxEnabled = value.Bool
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
_m.selectValues.Set(columns[i], values[i])
|
_m.selectValues.Set(columns[i], values[i])
|
||||||
}
|
}
|
||||||
@@ -630,6 +638,9 @@ func (_m *Group) String() string {
|
|||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
builder.WriteString("default_mapped_model=")
|
builder.WriteString("default_mapped_model=")
|
||||||
builder.WriteString(_m.DefaultMappedModel)
|
builder.WriteString(_m.DefaultMappedModel)
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("simulate_claude_max_enabled=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.SimulateClaudeMaxEnabled))
|
||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
return builder.String()
|
return builder.String()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -79,6 +79,8 @@ const (
|
|||||||
FieldAllowMessagesDispatch = "allow_messages_dispatch"
|
FieldAllowMessagesDispatch = "allow_messages_dispatch"
|
||||||
// FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database.
|
// FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database.
|
||||||
FieldDefaultMappedModel = "default_mapped_model"
|
FieldDefaultMappedModel = "default_mapped_model"
|
||||||
|
// FieldSimulateClaudeMaxEnabled holds the string denoting the simulate_claude_max_enabled field in the database.
|
||||||
|
FieldSimulateClaudeMaxEnabled = "simulate_claude_max_enabled"
|
||||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||||
EdgeAPIKeys = "api_keys"
|
EdgeAPIKeys = "api_keys"
|
||||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||||
@@ -186,6 +188,7 @@ var Columns = []string{
|
|||||||
FieldSortOrder,
|
FieldSortOrder,
|
||||||
FieldAllowMessagesDispatch,
|
FieldAllowMessagesDispatch,
|
||||||
FieldDefaultMappedModel,
|
FieldDefaultMappedModel,
|
||||||
|
FieldSimulateClaudeMaxEnabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -259,6 +262,8 @@ var (
|
|||||||
DefaultDefaultMappedModel string
|
DefaultDefaultMappedModel string
|
||||||
// DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
// DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
||||||
DefaultMappedModelValidator func(string) error
|
DefaultMappedModelValidator func(string) error
|
||||||
|
// DefaultSimulateClaudeMaxEnabled holds the default value on creation for the "simulate_claude_max_enabled" field.
|
||||||
|
DefaultSimulateClaudeMaxEnabled bool
|
||||||
)
|
)
|
||||||
|
|
||||||
// OrderOption defines the ordering options for the Group queries.
|
// OrderOption defines the ordering options for the Group queries.
|
||||||
@@ -419,6 +424,11 @@ func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc()
|
return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BySimulateClaudeMaxEnabled orders the results by the simulate_claude_max_enabled field.
|
||||||
|
func BySimulateClaudeMaxEnabled(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldSimulateClaudeMaxEnabled, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// ByAPIKeysCount orders the results by api_keys count.
|
// ByAPIKeysCount orders the results by api_keys count.
|
||||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return func(s *sql.Selector) {
|
return func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -205,6 +205,11 @@ func DefaultMappedModel(v string) predicate.Group {
|
|||||||
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
|
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SimulateClaudeMaxEnabled applies equality check predicate on the "simulate_claude_max_enabled" field. It's identical to SimulateClaudeMaxEnabledEQ.
|
||||||
|
func SimulateClaudeMaxEnabled(v bool) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldEQ(FieldSimulateClaudeMaxEnabled, v))
|
||||||
|
}
|
||||||
|
|
||||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||||
func CreatedAtEQ(v time.Time) predicate.Group {
|
func CreatedAtEQ(v time.Time) predicate.Group {
|
||||||
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
||||||
@@ -1555,6 +1560,16 @@ func DefaultMappedModelContainsFold(v string) predicate.Group {
|
|||||||
return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v))
|
return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SimulateClaudeMaxEnabledEQ applies the EQ predicate on the "simulate_claude_max_enabled" field.
|
||||||
|
func SimulateClaudeMaxEnabledEQ(v bool) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldEQ(FieldSimulateClaudeMaxEnabled, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SimulateClaudeMaxEnabledNEQ applies the NEQ predicate on the "simulate_claude_max_enabled" field.
|
||||||
|
func SimulateClaudeMaxEnabledNEQ(v bool) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldNEQ(FieldSimulateClaudeMaxEnabled, v))
|
||||||
|
}
|
||||||
|
|
||||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||||
func HasAPIKeys() predicate.Group {
|
func HasAPIKeys() predicate.Group {
|
||||||
return predicate.Group(func(s *sql.Selector) {
|
return predicate.Group(func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -452,6 +452,20 @@ func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
||||||
|
func (_c *GroupCreate) SetSimulateClaudeMaxEnabled(v bool) *GroupCreate {
|
||||||
|
_c.mutation.SetSimulateClaudeMaxEnabled(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil.
|
||||||
|
func (_c *GroupCreate) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetSimulateClaudeMaxEnabled(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
||||||
_c.mutation.AddAPIKeyIDs(ids...)
|
_c.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -649,6 +663,10 @@ func (_c *GroupCreate) defaults() error {
|
|||||||
v := group.DefaultDefaultMappedModel
|
v := group.DefaultDefaultMappedModel
|
||||||
_c.mutation.SetDefaultMappedModel(v)
|
_c.mutation.SetDefaultMappedModel(v)
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.SimulateClaudeMaxEnabled(); !ok {
|
||||||
|
v := group.DefaultSimulateClaudeMaxEnabled
|
||||||
|
_c.mutation.SetSimulateClaudeMaxEnabled(v)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -730,6 +748,9 @@ func (_c *GroupCreate) check() error {
|
|||||||
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
|
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.SimulateClaudeMaxEnabled(); !ok {
|
||||||
|
return &ValidationError{Name: "simulate_claude_max_enabled", err: errors.New(`ent: missing required field "Group.simulate_claude_max_enabled"`)}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -885,6 +906,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||||
_node.DefaultMappedModel = value
|
_node.DefaultMappedModel = value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.SimulateClaudeMaxEnabled(); ok {
|
||||||
|
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
|
||||||
|
_node.SimulateClaudeMaxEnabled = value
|
||||||
|
}
|
||||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@@ -1599,6 +1624,18 @@ func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
||||||
|
func (u *GroupUpsert) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsert {
|
||||||
|
u.Set(group.FieldSimulateClaudeMaxEnabled, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsert) UpdateSimulateClaudeMaxEnabled() *GroupUpsert {
|
||||||
|
u.SetExcluded(group.FieldSimulateClaudeMaxEnabled)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||||
// Using this option is equivalent to using:
|
// Using this option is equivalent to using:
|
||||||
//
|
//
|
||||||
@@ -2295,6 +2332,20 @@ func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
||||||
|
func (u *GroupUpsertOne) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.SetSimulateClaudeMaxEnabled(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsertOne) UpdateSimulateClaudeMaxEnabled() *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.UpdateSimulateClaudeMaxEnabled()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
||||||
if len(u.create.conflict) == 0 {
|
if len(u.create.conflict) == 0 {
|
||||||
@@ -3157,6 +3208,20 @@ func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
||||||
|
func (u *GroupUpsertBulk) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.SetSimulateClaudeMaxEnabled(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsertBulk) UpdateSimulateClaudeMaxEnabled() *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.UpdateSimulateClaudeMaxEnabled()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
||||||
if u.create.err != nil {
|
if u.create.err != nil {
|
||||||
|
|||||||
@@ -653,6 +653,20 @@ func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
||||||
|
func (_u *GroupUpdate) SetSimulateClaudeMaxEnabled(v bool) *GroupUpdate {
|
||||||
|
_u.mutation.SetSimulateClaudeMaxEnabled(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil.
|
||||||
|
func (_u *GroupUpdate) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetSimulateClaudeMaxEnabled(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -1149,6 +1163,9 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.SimulateClaudeMaxEnabled(); ok {
|
||||||
|
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
|
||||||
|
}
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@@ -2081,6 +2098,20 @@ func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateO
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
||||||
|
func (_u *GroupUpdateOne) SetSimulateClaudeMaxEnabled(v bool) *GroupUpdateOne {
|
||||||
|
_u.mutation.SetSimulateClaudeMaxEnabled(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil.
|
||||||
|
func (_u *GroupUpdateOne) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetSimulateClaudeMaxEnabled(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -2607,6 +2638,9 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
|||||||
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.SimulateClaudeMaxEnabled(); ok {
|
||||||
|
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
|
||||||
|
}
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
|
|||||||
@@ -410,6 +410,7 @@ var (
|
|||||||
{Name: "sort_order", Type: field.TypeInt, Default: 0},
|
{Name: "sort_order", Type: field.TypeInt, Default: 0},
|
||||||
{Name: "allow_messages_dispatch", Type: field.TypeBool, Default: false},
|
{Name: "allow_messages_dispatch", Type: field.TypeBool, Default: false},
|
||||||
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
|
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
|
||||||
|
{Name: "simulate_claude_max_enabled", Type: field.TypeBool, Default: false},
|
||||||
}
|
}
|
||||||
// GroupsTable holds the schema information for the "groups" table.
|
// GroupsTable holds the schema information for the "groups" table.
|
||||||
GroupsTable = &schema.Table{
|
GroupsTable = &schema.Table{
|
||||||
|
|||||||
@@ -8252,6 +8252,7 @@ type GroupMutation struct {
|
|||||||
addsort_order *int
|
addsort_order *int
|
||||||
allow_messages_dispatch *bool
|
allow_messages_dispatch *bool
|
||||||
default_mapped_model *string
|
default_mapped_model *string
|
||||||
|
simulate_claude_max_enabled *bool
|
||||||
clearedFields map[string]struct{}
|
clearedFields map[string]struct{}
|
||||||
api_keys map[int64]struct{}
|
api_keys map[int64]struct{}
|
||||||
removedapi_keys map[int64]struct{}
|
removedapi_keys map[int64]struct{}
|
||||||
@@ -10068,6 +10069,42 @@ func (m *GroupMutation) ResetDefaultMappedModel() {
|
|||||||
m.default_mapped_model = nil
|
m.default_mapped_model = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
||||||
|
func (m *GroupMutation) SetSimulateClaudeMaxEnabled(b bool) {
|
||||||
|
m.simulate_claude_max_enabled = &b
|
||||||
|
}
|
||||||
|
|
||||||
|
// SimulateClaudeMaxEnabled returns the value of the "simulate_claude_max_enabled" field in the mutation.
|
||||||
|
func (m *GroupMutation) SimulateClaudeMaxEnabled() (r bool, exists bool) {
|
||||||
|
v := m.simulate_claude_max_enabled
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldSimulateClaudeMaxEnabled returns the old "simulate_claude_max_enabled" field's value of the Group entity.
|
||||||
|
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *GroupMutation) OldSimulateClaudeMaxEnabled(ctx context.Context) (v bool, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldSimulateClaudeMaxEnabled is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldSimulateClaudeMaxEnabled requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldSimulateClaudeMaxEnabled: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.SimulateClaudeMaxEnabled, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetSimulateClaudeMaxEnabled resets all changes to the "simulate_claude_max_enabled" field.
|
||||||
|
func (m *GroupMutation) ResetSimulateClaudeMaxEnabled() {
|
||||||
|
m.simulate_claude_max_enabled = nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
||||||
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
|
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
|
||||||
if m.api_keys == nil {
|
if m.api_keys == nil {
|
||||||
@@ -10426,7 +10463,7 @@ func (m *GroupMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *GroupMutation) Fields() []string {
|
func (m *GroupMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 32)
|
fields := make([]string, 0, 33)
|
||||||
if m.created_at != nil {
|
if m.created_at != nil {
|
||||||
fields = append(fields, group.FieldCreatedAt)
|
fields = append(fields, group.FieldCreatedAt)
|
||||||
}
|
}
|
||||||
@@ -10523,6 +10560,9 @@ func (m *GroupMutation) Fields() []string {
|
|||||||
if m.default_mapped_model != nil {
|
if m.default_mapped_model != nil {
|
||||||
fields = append(fields, group.FieldDefaultMappedModel)
|
fields = append(fields, group.FieldDefaultMappedModel)
|
||||||
}
|
}
|
||||||
|
if m.simulate_claude_max_enabled != nil {
|
||||||
|
fields = append(fields, group.FieldSimulateClaudeMaxEnabled)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -10595,6 +10635,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.AllowMessagesDispatch()
|
return m.AllowMessagesDispatch()
|
||||||
case group.FieldDefaultMappedModel:
|
case group.FieldDefaultMappedModel:
|
||||||
return m.DefaultMappedModel()
|
return m.DefaultMappedModel()
|
||||||
|
case group.FieldSimulateClaudeMaxEnabled:
|
||||||
|
return m.SimulateClaudeMaxEnabled()
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -10668,6 +10710,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
|
|||||||
return m.OldAllowMessagesDispatch(ctx)
|
return m.OldAllowMessagesDispatch(ctx)
|
||||||
case group.FieldDefaultMappedModel:
|
case group.FieldDefaultMappedModel:
|
||||||
return m.OldDefaultMappedModel(ctx)
|
return m.OldDefaultMappedModel(ctx)
|
||||||
|
case group.FieldSimulateClaudeMaxEnabled:
|
||||||
|
return m.OldSimulateClaudeMaxEnabled(ctx)
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unknown Group field %s", name)
|
return nil, fmt.Errorf("unknown Group field %s", name)
|
||||||
}
|
}
|
||||||
@@ -10901,6 +10945,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetDefaultMappedModel(v)
|
m.SetDefaultMappedModel(v)
|
||||||
return nil
|
return nil
|
||||||
|
case group.FieldSimulateClaudeMaxEnabled:
|
||||||
|
v, ok := value.(bool)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetSimulateClaudeMaxEnabled(v)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown Group field %s", name)
|
return fmt.Errorf("unknown Group field %s", name)
|
||||||
}
|
}
|
||||||
@@ -11334,6 +11385,9 @@ func (m *GroupMutation) ResetField(name string) error {
|
|||||||
case group.FieldDefaultMappedModel:
|
case group.FieldDefaultMappedModel:
|
||||||
m.ResetDefaultMappedModel()
|
m.ResetDefaultMappedModel()
|
||||||
return nil
|
return nil
|
||||||
|
case group.FieldSimulateClaudeMaxEnabled:
|
||||||
|
m.ResetSimulateClaudeMaxEnabled()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown Group field %s", name)
|
return fmt.Errorf("unknown Group field %s", name)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -463,6 +463,10 @@ func init() {
|
|||||||
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
|
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
|
||||||
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
||||||
group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error)
|
group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error)
|
||||||
|
// groupDescSimulateClaudeMaxEnabled is the schema descriptor for simulate_claude_max_enabled field.
|
||||||
|
groupDescSimulateClaudeMaxEnabled := groupFields[29].Descriptor()
|
||||||
|
// group.DefaultSimulateClaudeMaxEnabled holds the default value on creation for the simulate_claude_max_enabled field.
|
||||||
|
group.DefaultSimulateClaudeMaxEnabled = groupDescSimulateClaudeMaxEnabled.Default.(bool)
|
||||||
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
|
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
|
||||||
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
|
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
|
||||||
_ = idempotencyrecordMixinFields0
|
_ = idempotencyrecordMixinFields0
|
||||||
|
|||||||
@@ -33,8 +33,6 @@ func (Group) Mixin() []ent.Mixin {
|
|||||||
|
|
||||||
func (Group) Fields() []ent.Field {
|
func (Group) Fields() []ent.Field {
|
||||||
return []ent.Field{
|
return []ent.Field{
|
||||||
// 唯一约束通过部分索引实现(WHERE deleted_at IS NULL),支持软删除后重用
|
|
||||||
// 见迁移文件 016_soft_delete_partial_unique_indexes.sql
|
|
||||||
field.String("name").
|
field.String("name").
|
||||||
MaxLen(100).
|
MaxLen(100).
|
||||||
NotEmpty(),
|
NotEmpty(),
|
||||||
@@ -51,7 +49,6 @@ func (Group) Fields() []ent.Field {
|
|||||||
MaxLen(20).
|
MaxLen(20).
|
||||||
Default(domain.StatusActive),
|
Default(domain.StatusActive),
|
||||||
|
|
||||||
// Subscription-related fields (added by migration 003)
|
|
||||||
field.String("platform").
|
field.String("platform").
|
||||||
MaxLen(50).
|
MaxLen(50).
|
||||||
Default(domain.PlatformAnthropic),
|
Default(domain.PlatformAnthropic),
|
||||||
@@ -73,7 +70,6 @@ func (Group) Fields() []ent.Field {
|
|||||||
field.Int("default_validity_days").
|
field.Int("default_validity_days").
|
||||||
Default(30),
|
Default(30),
|
||||||
|
|
||||||
// 图片生成计费配置(antigravity 和 gemini 平台使用)
|
|
||||||
field.Float("image_price_1k").
|
field.Float("image_price_1k").
|
||||||
Optional().
|
Optional().
|
||||||
Nillable().
|
Nillable().
|
||||||
@@ -87,7 +83,6 @@ func (Group) Fields() []ent.Field {
|
|||||||
Nillable().
|
Nillable().
|
||||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
||||||
|
|
||||||
// Sora 按次计费配置(阶段 1)
|
|
||||||
field.Float("sora_image_price_360").
|
field.Float("sora_image_price_360").
|
||||||
Optional().
|
Optional().
|
||||||
Nillable().
|
Nillable().
|
||||||
@@ -109,45 +104,38 @@ func (Group) Fields() []ent.Field {
|
|||||||
field.Int64("sora_storage_quota_bytes").
|
field.Int64("sora_storage_quota_bytes").
|
||||||
Default(0),
|
Default(0),
|
||||||
|
|
||||||
// Claude Code 客户端限制 (added by migration 029)
|
|
||||||
field.Bool("claude_code_only").
|
field.Bool("claude_code_only").
|
||||||
Default(false).
|
Default(false).
|
||||||
Comment("是否仅允许 Claude Code 客户端"),
|
Comment("allow Claude Code client only"),
|
||||||
field.Int64("fallback_group_id").
|
field.Int64("fallback_group_id").
|
||||||
Optional().
|
Optional().
|
||||||
Nillable().
|
Nillable().
|
||||||
Comment("非 Claude Code 请求降级使用的分组 ID"),
|
Comment("fallback group for non-Claude-Code requests"),
|
||||||
field.Int64("fallback_group_id_on_invalid_request").
|
field.Int64("fallback_group_id_on_invalid_request").
|
||||||
Optional().
|
Optional().
|
||||||
Nillable().
|
Nillable().
|
||||||
Comment("无效请求兜底使用的分组 ID"),
|
Comment("fallback group for invalid request"),
|
||||||
|
|
||||||
// 模型路由配置 (added by migration 040)
|
|
||||||
field.JSON("model_routing", map[string][]int64{}).
|
field.JSON("model_routing", map[string][]int64{}).
|
||||||
Optional().
|
Optional().
|
||||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
||||||
Comment("模型路由配置:模型模式 -> 优先账号ID列表"),
|
Comment("model routing config: pattern -> account ids"),
|
||||||
|
|
||||||
// 模型路由开关 (added by migration 041)
|
|
||||||
field.Bool("model_routing_enabled").
|
field.Bool("model_routing_enabled").
|
||||||
Default(false).
|
Default(false).
|
||||||
Comment("是否启用模型路由配置"),
|
Comment("whether model routing is enabled"),
|
||||||
|
|
||||||
// MCP XML 协议注入开关 (added by migration 042)
|
|
||||||
field.Bool("mcp_xml_inject").
|
field.Bool("mcp_xml_inject").
|
||||||
Default(true).
|
Default(true).
|
||||||
Comment("是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)"),
|
Comment("whether MCP XML prompt injection is enabled"),
|
||||||
|
|
||||||
// 支持的模型系列 (added by migration 046)
|
|
||||||
field.JSON("supported_model_scopes", []string{}).
|
field.JSON("supported_model_scopes", []string{}).
|
||||||
Default([]string{"claude", "gemini_text", "gemini_image"}).
|
Default([]string{"claude", "gemini_text", "gemini_image"}).
|
||||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
||||||
Comment("支持的模型系列:claude, gemini_text, gemini_image"),
|
Comment("supported model scopes: claude, gemini_text, gemini_image"),
|
||||||
|
|
||||||
// 分组排序 (added by migration 052)
|
|
||||||
field.Int("sort_order").
|
field.Int("sort_order").
|
||||||
Default(0).
|
Default(0).
|
||||||
Comment("分组显示排序,数值越小越靠前"),
|
Comment("group display order, lower comes first"),
|
||||||
|
|
||||||
// OpenAI Messages 调度配置 (added by migration 069)
|
// OpenAI Messages 调度配置 (added by migration 069)
|
||||||
field.Bool("allow_messages_dispatch").
|
field.Bool("allow_messages_dispatch").
|
||||||
@@ -157,6 +145,9 @@ func (Group) Fields() []ent.Field {
|
|||||||
MaxLen(100).
|
MaxLen(100).
|
||||||
Default("").
|
Default("").
|
||||||
Comment("默认映射模型 ID,当账号级映射找不到时使用此值"),
|
Comment("默认映射模型 ID,当账号级映射找不到时使用此值"),
|
||||||
|
field.Bool("simulate_claude_max_enabled").
|
||||||
|
Default(false).
|
||||||
|
Comment("simulate claude usage as claude-max style (1h cache write)"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -172,14 +163,11 @@ func (Group) Edges() []ent.Edge {
|
|||||||
edge.From("allowed_users", User.Type).
|
edge.From("allowed_users", User.Type).
|
||||||
Ref("allowed_groups").
|
Ref("allowed_groups").
|
||||||
Through("user_allowed_groups", UserAllowedGroup.Type),
|
Through("user_allowed_groups", UserAllowedGroup.Type),
|
||||||
// 注意:fallback_group_id 直接作为字段使用,不定义 edge
|
|
||||||
// 这样允许多个分组指向同一个降级分组(M2O 关系)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (Group) Indexes() []ent.Index {
|
func (Group) Indexes() []ent.Index {
|
||||||
return []ent.Index{
|
return []ent.Index{
|
||||||
// name 字段已在 Fields() 中声明 Unique(),无需重复索引
|
|
||||||
index.Fields("status"),
|
index.Fields("status"),
|
||||||
index.Fields("platform"),
|
index.Fields("platform"),
|
||||||
index.Fields("subscription_type"),
|
index.Fields("subscription_type"),
|
||||||
|
|||||||
@@ -87,6 +87,7 @@ require (
|
|||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
github.com/distribution/reference v0.6.0 // indirect
|
github.com/distribution/reference v0.6.0 // indirect
|
||||||
|
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||||
github.com/docker/docker v28.5.1+incompatible // indirect
|
github.com/docker/docker v28.5.1+incompatible // indirect
|
||||||
github.com/docker/go-connections v0.6.0 // indirect
|
github.com/docker/go-connections v0.6.0 // indirect
|
||||||
github.com/docker/go-units v0.5.0 // indirect
|
github.com/docker/go-units v0.5.0 // indirect
|
||||||
@@ -137,6 +138,8 @@ require (
|
|||||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||||
github.com/pkg/errors v0.9.1 // indirect
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
|
github.com/pkoukk/tiktoken-go v0.1.8 // indirect
|
||||||
|
github.com/pkoukk/tiktoken-go-loader v0.0.2 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||||
github.com/quic-go/qpack v0.6.0 // indirect
|
github.com/quic-go/qpack v0.6.0 // indirect
|
||||||
|
|||||||
@@ -124,6 +124,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
|||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||||
|
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
||||||
|
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||||
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
|
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
|
||||||
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||||
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
|
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
|
||||||
@@ -180,6 +182,7 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
|
|||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||||
|
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
||||||
@@ -199,6 +202,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
|||||||
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
|
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
|
||||||
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
|
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
|
||||||
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
|
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
|
||||||
|
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||||
|
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||||
@@ -281,6 +286,10 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6
|
|||||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo=
|
||||||
|
github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
||||||
|
github.com/pkoukk/tiktoken-go-loader v0.0.2 h1:LUKws63GV3pVHwH1srkBplBv+7URgmOmhSkRxsIvsK4=
|
||||||
|
github.com/pkoukk/tiktoken-go-loader v0.0.2/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
@@ -333,6 +342,8 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS
|
|||||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
@@ -341,8 +352,6 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
|||||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
|
||||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
|
||||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||||
github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 h1:YqAladjX7xpA6BM04leXMWAEjS0mTZ5kUU9KRBriQJc=
|
github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 h1:YqAladjX7xpA6BM04leXMWAEjS0mTZ5kUU9KRBriQJc=
|
||||||
@@ -432,11 +441,11 @@ golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
|
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
|
||||||
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
|
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
|
||||||
|
|||||||
@@ -84,10 +84,12 @@ var DefaultAntigravityModelMapping = map[string]string{
|
|||||||
"claude-haiku-4-5": "claude-sonnet-4-5",
|
"claude-haiku-4-5": "claude-sonnet-4-5",
|
||||||
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
|
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
|
||||||
// Gemini 2.5 白名单
|
// Gemini 2.5 白名单
|
||||||
"gemini-2.5-flash": "gemini-2.5-flash",
|
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||||
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
|
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
|
||||||
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
"gemini-2.5-flash-image-preview": "gemini-2.5-flash-image",
|
||||||
"gemini-2.5-pro": "gemini-2.5-pro",
|
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
|
||||||
|
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
||||||
|
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||||
// Gemini 3 白名单
|
// Gemini 3 白名单
|
||||||
"gemini-3-flash": "gemini-3-flash",
|
"gemini-3-flash": "gemini-3-flash",
|
||||||
"gemini-3-pro-high": "gemini-3-pro-high",
|
"gemini-3-pro-high": "gemini-3-pro-high",
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T)
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
cases := map[string]string{
|
cases := map[string]string{
|
||||||
|
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
|
||||||
|
"gemini-2.5-flash-image-preview": "gemini-2.5-flash-image",
|
||||||
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
|
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
|
||||||
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
|
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
|
||||||
"gemini-3-pro-image": "gemini-3.1-flash-image",
|
"gemini-3-pro-image": "gemini-3.1-flash-image",
|
||||||
|
|||||||
@@ -628,6 +628,7 @@ func (h *AccountHandler) Delete(c *gin.Context) {
|
|||||||
// TestAccountRequest represents the request body for testing an account
|
// TestAccountRequest represents the request body for testing an account
|
||||||
type TestAccountRequest struct {
|
type TestAccountRequest struct {
|
||||||
ModelID string `json:"model_id"`
|
ModelID string `json:"model_id"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type SyncFromCRSRequest struct {
|
type SyncFromCRSRequest struct {
|
||||||
@@ -658,7 +659,7 @@ func (h *AccountHandler) Test(c *gin.Context) {
|
|||||||
_ = c.ShouldBindJSON(&req)
|
_ = c.ShouldBindJSON(&req)
|
||||||
|
|
||||||
// Use AccountTestService to test the account with SSE streaming
|
// Use AccountTestService to test the account with SSE streaming
|
||||||
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID); err != nil {
|
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt); err != nil {
|
||||||
// Error already sent via SSE, just log
|
// Error already sent via SSE, just log
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1337,6 +1338,12 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
|||||||
c.JSON(409, gin.H{
|
c.JSON(409, gin.H{
|
||||||
"error": "mixed_channel_warning",
|
"error": "mixed_channel_warning",
|
||||||
"message": mixedErr.Error(),
|
"message": mixedErr.Error(),
|
||||||
|
"details": gin.H{
|
||||||
|
"group_id": mixedErr.GroupID,
|
||||||
|
"group_name": mixedErr.GroupName,
|
||||||
|
"current_platform": mixedErr.CurrentPlatform,
|
||||||
|
"other_platform": mixedErr.OtherPlatform,
|
||||||
|
},
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T
|
|||||||
var resp map[string]any
|
var resp map[string]any
|
||||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||||
require.Equal(t, "mixed_channel_warning", resp["error"])
|
require.Equal(t, "mixed_channel_warning", resp["error"])
|
||||||
require.Contains(t, resp["message"], "mixed_channel_warning")
|
require.Contains(t, resp["message"], "claude-max")
|
||||||
_, hasDetails := resp["details"]
|
_, hasDetails := resp["details"]
|
||||||
_, hasRequireConfirmation := resp["require_confirmation"]
|
_, hasRequireConfirmation := resp["require_confirmation"]
|
||||||
require.False(t, hasDetails)
|
require.False(t, hasDetails)
|
||||||
@@ -140,7 +140,7 @@ func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T
|
|||||||
var resp map[string]any
|
var resp map[string]any
|
||||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||||
require.Equal(t, "mixed_channel_warning", resp["error"])
|
require.Equal(t, "mixed_channel_warning", resp["error"])
|
||||||
require.Contains(t, resp["message"], "mixed_channel_warning")
|
require.Contains(t, resp["message"], "claude-max")
|
||||||
_, hasDetails := resp["details"]
|
_, hasDetails := resp["details"]
|
||||||
_, hasRequireConfirmation := resp["require_confirmation"]
|
_, hasRequireConfirmation := resp["require_confirmation"]
|
||||||
require.False(t, hasDetails)
|
require.False(t, hasDetails)
|
||||||
|
|||||||
@@ -175,6 +175,10 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
|||||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubAdminService) GetGroupRateMultipliers(_ context.Context, _ int64) ([]service.UserGroupRateEntry, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
|
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
|
||||||
return s.accounts, int64(len(s.accounts)), nil
|
return s.accounts, int64(len(s.accounts)), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -249,11 +249,12 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
trend, hit, err := h.getUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get usage trend")
|
response.Error(c, 500, "Failed to get usage trend")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||||
|
|
||||||
response.Success(c, gin.H{
|
response.Success(c, gin.H{
|
||||||
"trend": trend,
|
"trend": trend,
|
||||||
@@ -321,11 +322,12 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get model statistics")
|
response.Error(c, 500, "Failed to get model statistics")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||||
|
|
||||||
response.Success(c, gin.H{
|
response.Success(c, gin.H{
|
||||||
"models": stats,
|
"models": stats,
|
||||||
@@ -391,11 +393,12 @@ func (h *DashboardHandler) GetGroupStats(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
stats, hit, err := h.getGroupStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get group statistics")
|
response.Error(c, 500, "Failed to get group statistics")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||||
|
|
||||||
response.Success(c, gin.H{
|
response.Success(c, gin.H{
|
||||||
"groups": stats,
|
"groups": stats,
|
||||||
@@ -416,11 +419,12 @@ func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) {
|
|||||||
limit = 5
|
limit = 5
|
||||||
}
|
}
|
||||||
|
|
||||||
trend, err := h.dashboardService.GetAPIKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
trend, hit, err := h.getAPIKeyUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get API key usage trend")
|
response.Error(c, 500, "Failed to get API key usage trend")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||||
|
|
||||||
response.Success(c, gin.H{
|
response.Success(c, gin.H{
|
||||||
"trend": trend,
|
"trend": trend,
|
||||||
@@ -442,11 +446,12 @@ func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
|
|||||||
limit = 12
|
limit = 12
|
||||||
}
|
}
|
||||||
|
|
||||||
trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
trend, hit, err := h.getUserUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get user usage trend")
|
response.Error(c, 500, "Failed to get user usage trend")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||||
|
|
||||||
response.Success(c, gin.H{
|
response.Success(c, gin.H{
|
||||||
"trend": trend,
|
"trend": trend,
|
||||||
|
|||||||
118
backend/internal/handler/admin/dashboard_handler_cache_test.go
Normal file
118
backend/internal/handler/admin/dashboard_handler_cache_test.go
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type dashboardUsageRepoCacheProbe struct {
|
||||||
|
service.UsageLogRepository
|
||||||
|
trendCalls atomic.Int32
|
||||||
|
usersTrendCalls atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardUsageRepoCacheProbe) GetUsageTrendWithFilters(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
granularity string,
|
||||||
|
userID, apiKeyID, accountID, groupID int64,
|
||||||
|
model string,
|
||||||
|
requestType *int16,
|
||||||
|
stream *bool,
|
||||||
|
billingType *int8,
|
||||||
|
) ([]usagestats.TrendDataPoint, error) {
|
||||||
|
r.trendCalls.Add(1)
|
||||||
|
return []usagestats.TrendDataPoint{{
|
||||||
|
Date: "2026-03-11",
|
||||||
|
Requests: 1,
|
||||||
|
TotalTokens: 2,
|
||||||
|
Cost: 3,
|
||||||
|
ActualCost: 4,
|
||||||
|
}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardUsageRepoCacheProbe) GetUserUsageTrend(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
granularity string,
|
||||||
|
limit int,
|
||||||
|
) ([]usagestats.UserUsageTrendPoint, error) {
|
||||||
|
r.usersTrendCalls.Add(1)
|
||||||
|
return []usagestats.UserUsageTrendPoint{{
|
||||||
|
Date: "2026-03-11",
|
||||||
|
UserID: 1,
|
||||||
|
Email: "cache@test.dev",
|
||||||
|
Requests: 2,
|
||||||
|
Tokens: 20,
|
||||||
|
Cost: 2,
|
||||||
|
ActualCost: 1,
|
||||||
|
}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resetDashboardReadCachesForTest() {
|
||||||
|
dashboardTrendCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardHandler_GetUsageTrend_UsesCache(t *testing.T) {
|
||||||
|
t.Cleanup(resetDashboardReadCachesForTest)
|
||||||
|
resetDashboardReadCachesForTest()
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
repo := &dashboardUsageRepoCacheProbe{}
|
||||||
|
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||||
|
handler := NewDashboardHandler(dashboardSvc, nil)
|
||||||
|
router := gin.New()
|
||||||
|
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
||||||
|
|
||||||
|
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
|
||||||
|
rec1 := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec1, req1)
|
||||||
|
require.Equal(t, http.StatusOK, rec1.Code)
|
||||||
|
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
|
||||||
|
|
||||||
|
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
|
||||||
|
rec2 := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec2, req2)
|
||||||
|
require.Equal(t, http.StatusOK, rec2.Code)
|
||||||
|
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||||
|
require.Equal(t, int32(1), repo.trendCalls.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardHandler_GetUserUsageTrend_UsesCache(t *testing.T) {
|
||||||
|
t.Cleanup(resetDashboardReadCachesForTest)
|
||||||
|
resetDashboardReadCachesForTest()
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
repo := &dashboardUsageRepoCacheProbe{}
|
||||||
|
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||||
|
handler := NewDashboardHandler(dashboardSvc, nil)
|
||||||
|
router := gin.New()
|
||||||
|
router.GET("/admin/dashboard/users-trend", handler.GetUserUsageTrend)
|
||||||
|
|
||||||
|
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
|
||||||
|
rec1 := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec1, req1)
|
||||||
|
require.Equal(t, http.StatusOK, rec1.Code)
|
||||||
|
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
|
||||||
|
|
||||||
|
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
|
||||||
|
rec2 := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec2, req2)
|
||||||
|
require.Equal(t, http.StatusOK, rec2.Code)
|
||||||
|
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||||
|
require.Equal(t, int32(1), repo.usersTrendCalls.Load())
|
||||||
|
}
|
||||||
200
backend/internal/handler/admin/dashboard_query_cache.go
Normal file
200
backend/internal/handler/admin/dashboard_query_cache.go
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
dashboardTrendCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
|
||||||
|
)
|
||||||
|
|
||||||
|
type dashboardTrendCacheKey struct {
|
||||||
|
StartTime string `json:"start_time"`
|
||||||
|
EndTime string `json:"end_time"`
|
||||||
|
Granularity string `json:"granularity"`
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
APIKeyID int64 `json:"api_key_id"`
|
||||||
|
AccountID int64 `json:"account_id"`
|
||||||
|
GroupID int64 `json:"group_id"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
RequestType *int16 `json:"request_type"`
|
||||||
|
Stream *bool `json:"stream"`
|
||||||
|
BillingType *int8 `json:"billing_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type dashboardModelGroupCacheKey struct {
|
||||||
|
StartTime string `json:"start_time"`
|
||||||
|
EndTime string `json:"end_time"`
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
APIKeyID int64 `json:"api_key_id"`
|
||||||
|
AccountID int64 `json:"account_id"`
|
||||||
|
GroupID int64 `json:"group_id"`
|
||||||
|
RequestType *int16 `json:"request_type"`
|
||||||
|
Stream *bool `json:"stream"`
|
||||||
|
BillingType *int8 `json:"billing_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type dashboardEntityTrendCacheKey struct {
|
||||||
|
StartTime string `json:"start_time"`
|
||||||
|
EndTime string `json:"end_time"`
|
||||||
|
Granularity string `json:"granularity"`
|
||||||
|
Limit int `json:"limit"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func cacheStatusValue(hit bool) string {
|
||||||
|
if hit {
|
||||||
|
return "hit"
|
||||||
|
}
|
||||||
|
return "miss"
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustMarshalDashboardCacheKey(value any) string {
|
||||||
|
raw, err := json.Marshal(value)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func snapshotPayloadAs[T any](payload any) (T, error) {
|
||||||
|
typed, ok := payload.(T)
|
||||||
|
if !ok {
|
||||||
|
var zero T
|
||||||
|
return zero, fmt.Errorf("unexpected cache payload type %T", payload)
|
||||||
|
}
|
||||||
|
return typed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DashboardHandler) getUsageTrendCached(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
granularity string,
|
||||||
|
userID, apiKeyID, accountID, groupID int64,
|
||||||
|
model string,
|
||||||
|
requestType *int16,
|
||||||
|
stream *bool,
|
||||||
|
billingType *int8,
|
||||||
|
) ([]usagestats.TrendDataPoint, bool, error) {
|
||||||
|
key := mustMarshalDashboardCacheKey(dashboardTrendCacheKey{
|
||||||
|
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||||
|
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||||
|
Granularity: granularity,
|
||||||
|
UserID: userID,
|
||||||
|
APIKeyID: apiKeyID,
|
||||||
|
AccountID: accountID,
|
||||||
|
GroupID: groupID,
|
||||||
|
Model: model,
|
||||||
|
RequestType: requestType,
|
||||||
|
Stream: stream,
|
||||||
|
BillingType: billingType,
|
||||||
|
})
|
||||||
|
entry, hit, err := dashboardTrendCache.GetOrLoad(key, func() (any, error) {
|
||||||
|
return h.dashboardService.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, hit, err
|
||||||
|
}
|
||||||
|
trend, err := snapshotPayloadAs[[]usagestats.TrendDataPoint](entry.Payload)
|
||||||
|
return trend, hit, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DashboardHandler) getModelStatsCached(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
userID, apiKeyID, accountID, groupID int64,
|
||||||
|
requestType *int16,
|
||||||
|
stream *bool,
|
||||||
|
billingType *int8,
|
||||||
|
) ([]usagestats.ModelStat, bool, error) {
|
||||||
|
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
|
||||||
|
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||||
|
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||||
|
UserID: userID,
|
||||||
|
APIKeyID: apiKeyID,
|
||||||
|
AccountID: accountID,
|
||||||
|
GroupID: groupID,
|
||||||
|
RequestType: requestType,
|
||||||
|
Stream: stream,
|
||||||
|
BillingType: billingType,
|
||||||
|
})
|
||||||
|
entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) {
|
||||||
|
return h.dashboardService.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, hit, err
|
||||||
|
}
|
||||||
|
stats, err := snapshotPayloadAs[[]usagestats.ModelStat](entry.Payload)
|
||||||
|
return stats, hit, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DashboardHandler) getGroupStatsCached(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
userID, apiKeyID, accountID, groupID int64,
|
||||||
|
requestType *int16,
|
||||||
|
stream *bool,
|
||||||
|
billingType *int8,
|
||||||
|
) ([]usagestats.GroupStat, bool, error) {
|
||||||
|
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
|
||||||
|
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||||
|
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||||
|
UserID: userID,
|
||||||
|
APIKeyID: apiKeyID,
|
||||||
|
AccountID: accountID,
|
||||||
|
GroupID: groupID,
|
||||||
|
RequestType: requestType,
|
||||||
|
Stream: stream,
|
||||||
|
BillingType: billingType,
|
||||||
|
})
|
||||||
|
entry, hit, err := dashboardGroupStatsCache.GetOrLoad(key, func() (any, error) {
|
||||||
|
return h.dashboardService.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, hit, err
|
||||||
|
}
|
||||||
|
stats, err := snapshotPayloadAs[[]usagestats.GroupStat](entry.Payload)
|
||||||
|
return stats, hit, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DashboardHandler) getAPIKeyUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, bool, error) {
|
||||||
|
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
|
||||||
|
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||||
|
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||||
|
Granularity: granularity,
|
||||||
|
Limit: limit,
|
||||||
|
})
|
||||||
|
entry, hit, err := dashboardAPIKeysTrendCache.GetOrLoad(key, func() (any, error) {
|
||||||
|
return h.dashboardService.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, hit, err
|
||||||
|
}
|
||||||
|
trend, err := snapshotPayloadAs[[]usagestats.APIKeyUsageTrendPoint](entry.Payload)
|
||||||
|
return trend, hit, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DashboardHandler) getUserUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, bool, error) {
|
||||||
|
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
|
||||||
|
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||||
|
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||||
|
Granularity: granularity,
|
||||||
|
Limit: limit,
|
||||||
|
})
|
||||||
|
entry, hit, err := dashboardUsersTrendCache.GetOrLoad(key, func() (any, error) {
|
||||||
|
return h.dashboardService.GetUserUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, hit, err
|
||||||
|
}
|
||||||
|
trend, err := snapshotPayloadAs[[]usagestats.UserUsageTrendPoint](entry.Payload)
|
||||||
|
return trend, hit, err
|
||||||
|
}
|
||||||
@@ -1,7 +1,9 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -111,20 +113,45 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
cacheKey := string(keyRaw)
|
cacheKey := string(keyRaw)
|
||||||
|
|
||||||
if cached, ok := dashboardSnapshotV2Cache.Get(cacheKey); ok {
|
cached, hit, err := dashboardSnapshotV2Cache.GetOrLoad(cacheKey, func() (any, error) {
|
||||||
if cached.ETag != "" {
|
return h.buildSnapshotV2Response(
|
||||||
c.Header("ETag", cached.ETag)
|
c.Request.Context(),
|
||||||
c.Header("Vary", "If-None-Match")
|
startTime,
|
||||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
endTime,
|
||||||
c.Status(http.StatusNotModified)
|
granularity,
|
||||||
return
|
filters,
|
||||||
}
|
includeStats,
|
||||||
}
|
includeTrend,
|
||||||
c.Header("X-Snapshot-Cache", "hit")
|
includeModels,
|
||||||
response.Success(c, cached.Payload)
|
includeGroups,
|
||||||
|
includeUsersTrend,
|
||||||
|
usersTrendLimit,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if cached.ETag != "" {
|
||||||
|
c.Header("ETag", cached.ETag)
|
||||||
|
c.Header("Vary", "If-None-Match")
|
||||||
|
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||||
|
c.Status(http.StatusNotModified)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||||
|
response.Success(c, cached.Payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DashboardHandler) buildSnapshotV2Response(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
granularity string,
|
||||||
|
filters *dashboardSnapshotV2Filters,
|
||||||
|
includeStats, includeTrend, includeModels, includeGroups, includeUsersTrend bool,
|
||||||
|
usersTrendLimit int,
|
||||||
|
) (*dashboardSnapshotV2Response, error) {
|
||||||
resp := &dashboardSnapshotV2Response{
|
resp := &dashboardSnapshotV2Response{
|
||||||
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
|
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
|
||||||
StartDate: startTime.Format("2006-01-02"),
|
StartDate: startTime.Format("2006-01-02"),
|
||||||
@@ -133,10 +160,9 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if includeStats {
|
if includeStats {
|
||||||
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
|
stats, err := h.dashboardService.GetDashboardStats(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get dashboard statistics")
|
return nil, errors.New("failed to get dashboard statistics")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
resp.Stats = &dashboardSnapshotV2Stats{
|
resp.Stats = &dashboardSnapshotV2Stats{
|
||||||
DashboardStats: *stats,
|
DashboardStats: *stats,
|
||||||
@@ -145,8 +171,8 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if includeTrend {
|
if includeTrend {
|
||||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(
|
trend, _, err := h.getUsageTrendCached(
|
||||||
c.Request.Context(),
|
ctx,
|
||||||
startTime,
|
startTime,
|
||||||
endTime,
|
endTime,
|
||||||
granularity,
|
granularity,
|
||||||
@@ -160,15 +186,14 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
|||||||
filters.BillingType,
|
filters.BillingType,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get usage trend")
|
return nil, errors.New("failed to get usage trend")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
resp.Trend = trend
|
resp.Trend = trend
|
||||||
}
|
}
|
||||||
|
|
||||||
if includeModels {
|
if includeModels {
|
||||||
models, err := h.dashboardService.GetModelStatsWithFilters(
|
models, _, err := h.getModelStatsCached(
|
||||||
c.Request.Context(),
|
ctx,
|
||||||
startTime,
|
startTime,
|
||||||
endTime,
|
endTime,
|
||||||
filters.UserID,
|
filters.UserID,
|
||||||
@@ -180,15 +205,14 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
|||||||
filters.BillingType,
|
filters.BillingType,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get model statistics")
|
return nil, errors.New("failed to get model statistics")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
resp.Models = models
|
resp.Models = models
|
||||||
}
|
}
|
||||||
|
|
||||||
if includeGroups {
|
if includeGroups {
|
||||||
groups, err := h.dashboardService.GetGroupStatsWithFilters(
|
groups, _, err := h.getGroupStatsCached(
|
||||||
c.Request.Context(),
|
ctx,
|
||||||
startTime,
|
startTime,
|
||||||
endTime,
|
endTime,
|
||||||
filters.UserID,
|
filters.UserID,
|
||||||
@@ -200,34 +224,20 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
|||||||
filters.BillingType,
|
filters.BillingType,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get group statistics")
|
return nil, errors.New("failed to get group statistics")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
resp.Groups = groups
|
resp.Groups = groups
|
||||||
}
|
}
|
||||||
|
|
||||||
if includeUsersTrend {
|
if includeUsersTrend {
|
||||||
usersTrend, err := h.dashboardService.GetUserUsageTrend(
|
usersTrend, _, err := h.getUserUsageTrendCached(ctx, startTime, endTime, granularity, usersTrendLimit)
|
||||||
c.Request.Context(),
|
|
||||||
startTime,
|
|
||||||
endTime,
|
|
||||||
granularity,
|
|
||||||
usersTrendLimit,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get user usage trend")
|
return nil, errors.New("failed to get user usage trend")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
resp.UsersTrend = usersTrend
|
resp.UsersTrend = usersTrend
|
||||||
}
|
}
|
||||||
|
|
||||||
cached := dashboardSnapshotV2Cache.Set(cacheKey, resp)
|
return resp, nil
|
||||||
if cached.ETag != "" {
|
|
||||||
c.Header("ETag", cached.ETag)
|
|
||||||
c.Header("Vary", "If-None-Match")
|
|
||||||
}
|
|
||||||
c.Header("X-Snapshot-Cache", "miss")
|
|
||||||
response.Success(c, resp)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) {
|
func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) {
|
||||||
|
|||||||
@@ -46,9 +46,10 @@ type CreateGroupRequest struct {
|
|||||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||||
|
SimulateClaudeMaxEnabled *bool `json:"simulate_claude_max_enabled"`
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||||
// Sora 存储配额
|
// Sora 存储配额
|
||||||
@@ -84,9 +85,10 @@ type UpdateGroupRequest struct {
|
|||||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||||
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
|
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
|
||||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||||
|
SimulateClaudeMaxEnabled *bool `json:"simulate_claude_max_enabled"`
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
SupportedModelScopes *[]string `json:"supported_model_scopes"`
|
SupportedModelScopes *[]string `json:"supported_model_scopes"`
|
||||||
// Sora 存储配额
|
// Sora 存储配额
|
||||||
@@ -207,6 +209,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
|||||||
ModelRouting: req.ModelRouting,
|
ModelRouting: req.ModelRouting,
|
||||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||||
MCPXMLInject: req.MCPXMLInject,
|
MCPXMLInject: req.MCPXMLInject,
|
||||||
|
SimulateClaudeMaxEnabled: req.SimulateClaudeMaxEnabled,
|
||||||
SupportedModelScopes: req.SupportedModelScopes,
|
SupportedModelScopes: req.SupportedModelScopes,
|
||||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||||
@@ -260,6 +263,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
|||||||
ModelRouting: req.ModelRouting,
|
ModelRouting: req.ModelRouting,
|
||||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||||
MCPXMLInject: req.MCPXMLInject,
|
MCPXMLInject: req.MCPXMLInject,
|
||||||
|
SimulateClaudeMaxEnabled: req.SimulateClaudeMaxEnabled,
|
||||||
SupportedModelScopes: req.SupportedModelScopes,
|
SupportedModelScopes: req.SupportedModelScopes,
|
||||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||||
@@ -335,6 +339,27 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
|||||||
response.Paginated(c, outKeys, total, page, pageSize)
|
response.Paginated(c, outKeys, total, page, pageSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetGroupRateMultipliers handles getting rate multipliers for users in a group
|
||||||
|
// GET /api/v1/admin/groups/:id/rate-multipliers
|
||||||
|
func (h *GroupHandler) GetGroupRateMultipliers(c *gin.Context) {
|
||||||
|
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid group ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := h.adminService.GetGroupRateMultipliers(c.Request.Context(), groupID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if entries == nil {
|
||||||
|
entries = []service.UserGroupRateEntry{}
|
||||||
|
}
|
||||||
|
response.Success(c, entries)
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateSortOrderRequest represents the request to update group sort orders
|
// UpdateSortOrderRequest represents the request to update group sort orders
|
||||||
type UpdateSortOrderRequest struct {
|
type UpdateSortOrderRequest struct {
|
||||||
Updates []struct {
|
Updates []struct {
|
||||||
|
|||||||
@@ -23,6 +23,13 @@ var validOpsAlertMetricTypes = []string{
|
|||||||
"cpu_usage_percent",
|
"cpu_usage_percent",
|
||||||
"memory_usage_percent",
|
"memory_usage_percent",
|
||||||
"concurrency_queue_depth",
|
"concurrency_queue_depth",
|
||||||
|
"group_available_accounts",
|
||||||
|
"group_available_ratio",
|
||||||
|
"group_rate_limit_ratio",
|
||||||
|
"account_rate_limited_count",
|
||||||
|
"account_error_count",
|
||||||
|
"account_error_ratio",
|
||||||
|
"overload_account_count",
|
||||||
}
|
}
|
||||||
|
|
||||||
var validOpsAlertMetricTypeSet = func() map[string]struct{} {
|
var validOpsAlertMetricTypeSet = func() map[string]struct{} {
|
||||||
@@ -82,7 +89,10 @@ func isPercentOrRateMetric(metricType string) bool {
|
|||||||
"error_rate",
|
"error_rate",
|
||||||
"upstream_error_rate",
|
"upstream_error_rate",
|
||||||
"cpu_usage_percent",
|
"cpu_usage_percent",
|
||||||
"memory_usage_percent":
|
"memory_usage_percent",
|
||||||
|
"group_available_ratio",
|
||||||
|
"group_rate_limit_ratio",
|
||||||
|
"account_error_ratio":
|
||||||
return true
|
return true
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sync/singleflight"
|
||||||
)
|
)
|
||||||
|
|
||||||
type snapshotCacheEntry struct {
|
type snapshotCacheEntry struct {
|
||||||
@@ -19,6 +21,12 @@ type snapshotCache struct {
|
|||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
ttl time.Duration
|
ttl time.Duration
|
||||||
items map[string]snapshotCacheEntry
|
items map[string]snapshotCacheEntry
|
||||||
|
sf singleflight.Group
|
||||||
|
}
|
||||||
|
|
||||||
|
type snapshotCacheLoadResult struct {
|
||||||
|
Entry snapshotCacheEntry
|
||||||
|
Hit bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSnapshotCache(ttl time.Duration) *snapshotCache {
|
func newSnapshotCache(ttl time.Duration) *snapshotCache {
|
||||||
@@ -70,6 +78,41 @@ func (c *snapshotCache) Set(key string, payload any) snapshotCacheEntry {
|
|||||||
return entry
|
return entry
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *snapshotCache) GetOrLoad(key string, load func() (any, error)) (snapshotCacheEntry, bool, error) {
|
||||||
|
if load == nil {
|
||||||
|
return snapshotCacheEntry{}, false, nil
|
||||||
|
}
|
||||||
|
if entry, ok := c.Get(key); ok {
|
||||||
|
return entry, true, nil
|
||||||
|
}
|
||||||
|
if c == nil || key == "" {
|
||||||
|
payload, err := load()
|
||||||
|
if err != nil {
|
||||||
|
return snapshotCacheEntry{}, false, err
|
||||||
|
}
|
||||||
|
return c.Set(key, payload), false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
value, err, _ := c.sf.Do(key, func() (any, error) {
|
||||||
|
if entry, ok := c.Get(key); ok {
|
||||||
|
return snapshotCacheLoadResult{Entry: entry, Hit: true}, nil
|
||||||
|
}
|
||||||
|
payload, err := load()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return snapshotCacheLoadResult{Entry: c.Set(key, payload), Hit: false}, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return snapshotCacheEntry{}, false, err
|
||||||
|
}
|
||||||
|
result, ok := value.(snapshotCacheLoadResult)
|
||||||
|
if !ok {
|
||||||
|
return snapshotCacheEntry{}, false, nil
|
||||||
|
}
|
||||||
|
return result.Entry, result.Hit, nil
|
||||||
|
}
|
||||||
|
|
||||||
func buildETagFromAny(payload any) string {
|
func buildETagFromAny(payload any) string {
|
||||||
raw, err := json.Marshal(payload)
|
raw, err := json.Marshal(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -3,6 +3,8 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -95,6 +97,61 @@ func TestBuildETagFromAny_UnmarshalablePayload(t *testing.T) {
|
|||||||
require.Empty(t, etag)
|
require.Empty(t, etag)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSnapshotCache_GetOrLoad_MissThenHit(t *testing.T) {
|
||||||
|
c := newSnapshotCache(5 * time.Second)
|
||||||
|
var loads atomic.Int32
|
||||||
|
|
||||||
|
entry, hit, err := c.GetOrLoad("key1", func() (any, error) {
|
||||||
|
loads.Add(1)
|
||||||
|
return map[string]string{"hello": "world"}, nil
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, hit)
|
||||||
|
require.NotEmpty(t, entry.ETag)
|
||||||
|
require.Equal(t, int32(1), loads.Load())
|
||||||
|
|
||||||
|
entry2, hit, err := c.GetOrLoad("key1", func() (any, error) {
|
||||||
|
loads.Add(1)
|
||||||
|
return map[string]string{"unexpected": "value"}, nil
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, hit)
|
||||||
|
require.Equal(t, entry.ETag, entry2.ETag)
|
||||||
|
require.Equal(t, int32(1), loads.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSnapshotCache_GetOrLoad_ConcurrentSingleflight(t *testing.T) {
|
||||||
|
c := newSnapshotCache(5 * time.Second)
|
||||||
|
var loads atomic.Int32
|
||||||
|
start := make(chan struct{})
|
||||||
|
const callers = 8
|
||||||
|
errCh := make(chan error, callers)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(callers)
|
||||||
|
for range callers {
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
<-start
|
||||||
|
_, _, err := c.GetOrLoad("shared", func() (any, error) {
|
||||||
|
loads.Add(1)
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
return "value", nil
|
||||||
|
})
|
||||||
|
errCh <- err
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
close(start)
|
||||||
|
wg.Wait()
|
||||||
|
close(errCh)
|
||||||
|
|
||||||
|
for err := range errCh {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, int32(1), loads.Load())
|
||||||
|
}
|
||||||
|
|
||||||
func TestParseBoolQueryWithDefault(t *testing.T) {
|
func TestParseBoolQueryWithDefault(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -216,6 +216,37 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResetSubscriptionQuotaRequest represents the reset quota request
|
||||||
|
type ResetSubscriptionQuotaRequest struct {
|
||||||
|
Daily bool `json:"daily"`
|
||||||
|
Weekly bool `json:"weekly"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetQuota resets daily and/or weekly usage for a subscription.
|
||||||
|
// POST /api/v1/admin/subscriptions/:id/reset-quota
|
||||||
|
func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
|
||||||
|
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid subscription ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var req ResetSubscriptionQuotaRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !req.Daily && !req.Weekly {
|
||||||
|
response.BadRequest(c, "At least one of 'daily' or 'weekly' must be true")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, dto.UserSubscriptionFromServiceAdmin(sub))
|
||||||
|
}
|
||||||
|
|
||||||
// Revoke handles revoking a subscription
|
// Revoke handles revoking a subscription
|
||||||
// DELETE /api/v1/admin/subscriptions/:id
|
// DELETE /api/v1/admin/subscriptions/:id
|
||||||
func (h *SubscriptionHandler) Revoke(c *gin.Context) {
|
func (h *SubscriptionHandler) Revoke(c *gin.Context) {
|
||||||
|
|||||||
@@ -135,14 +135,15 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
out := &AdminGroup{
|
out := &AdminGroup{
|
||||||
Group: groupFromServiceBase(g),
|
Group: groupFromServiceBase(g),
|
||||||
ModelRouting: g.ModelRouting,
|
ModelRouting: g.ModelRouting,
|
||||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||||
MCPXMLInject: g.MCPXMLInject,
|
MCPXMLInject: g.MCPXMLInject,
|
||||||
DefaultMappedModel: g.DefaultMappedModel,
|
DefaultMappedModel: g.DefaultMappedModel,
|
||||||
SupportedModelScopes: g.SupportedModelScopes,
|
SimulateClaudeMaxEnabled: g.SimulateClaudeMaxEnabled,
|
||||||
AccountCount: g.AccountCount,
|
SupportedModelScopes: g.SupportedModelScopes,
|
||||||
SortOrder: g.SortOrder,
|
AccountCount: g.AccountCount,
|
||||||
|
SortOrder: g.SortOrder,
|
||||||
}
|
}
|
||||||
if len(g.AccountGroups) > 0 {
|
if len(g.AccountGroups) > 0 {
|
||||||
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
||||||
|
|||||||
@@ -117,6 +117,8 @@ type AdminGroup struct {
|
|||||||
|
|
||||||
// MCP XML 协议注入(仅 antigravity 平台使用)
|
// MCP XML 协议注入(仅 antigravity 平台使用)
|
||||||
MCPXMLInject bool `json:"mcp_xml_inject"`
|
MCPXMLInject bool `json:"mcp_xml_inject"`
|
||||||
|
// Claude usage 模拟开关(仅管理员可见)
|
||||||
|
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled"`
|
||||||
|
|
||||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||||
DefaultMappedModel string `json:"default_mapped_model"`
|
DefaultMappedModel string `json:"default_mapped_model"`
|
||||||
|
|||||||
@@ -439,6 +439,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
|
ParsedRequest: parsedReq,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
@@ -630,6 +631,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// ===== 用户消息串行队列 END =====
|
// ===== 用户消息串行队列 END =====
|
||||||
|
|
||||||
// 转发请求 - 根据账号平台分流
|
// 转发请求 - 根据账号平台分流
|
||||||
|
c.Set("parsed_request", parsedReq)
|
||||||
var result *service.ForwardResult
|
var result *service.ForwardResult
|
||||||
requestCtx := c.Request.Context()
|
requestCtx := c.Request.Context()
|
||||||
if fs.SwitchCount > 0 {
|
if fs.SwitchCount > 0 {
|
||||||
@@ -741,6 +743,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
|
ParsedRequest: parsedReq,
|
||||||
APIKey: currentAPIKey,
|
APIKey: currentAPIKey,
|
||||||
User: currentAPIKey.User,
|
User: currentAPIKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
|
|||||||
290
backend/internal/handler/openai_chat_completions.go
Normal file
290
backend/internal/handler/openai_chat_completions.go
Normal file
@@ -0,0 +1,290 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ChatCompletions handles OpenAI Chat Completions API requests.
|
||||||
|
// POST /v1/chat/completions
|
||||||
|
func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||||
|
streamStarted := false
|
||||||
|
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||||
|
|
||||||
|
requestStart := time.Now()
|
||||||
|
|
||||||
|
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reqLog := requestLogger(
|
||||||
|
c,
|
||||||
|
"handler.openai_gateway.chat_completions",
|
||||||
|
zap.Int64("user_id", subject.UserID),
|
||||||
|
zap.Int64("api_key_id", apiKey.ID),
|
||||||
|
zap.Any("group_id", apiKey.GroupID),
|
||||||
|
)
|
||||||
|
|
||||||
|
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||||
|
if err != nil {
|
||||||
|
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||||
|
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(body) == 0 {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !gjson.ValidBytes(body) {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
modelResult := gjson.GetBytes(body, "model")
|
||||||
|
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reqModel := modelResult.String()
|
||||||
|
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||||
|
|
||||||
|
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||||
|
|
||||||
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
|
|
||||||
|
if h.errorPassthroughService != nil {
|
||||||
|
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||||
|
}
|
||||||
|
|
||||||
|
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||||
|
|
||||||
|
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||||
|
routingStart := time.Now()
|
||||||
|
|
||||||
|
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
|
||||||
|
if !acquired {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if userReleaseFunc != nil {
|
||||||
|
defer userReleaseFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||||
|
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
|
||||||
|
status, code, message := billingErrorDetails(err)
|
||||||
|
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||||
|
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
|
||||||
|
|
||||||
|
maxAccountSwitches := h.maxAccountSwitches
|
||||||
|
switchCount := 0
|
||||||
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
|
sameAccountRetryCount := make(map[int64]int)
|
||||||
|
var lastFailoverErr *service.UpstreamFailoverError
|
||||||
|
|
||||||
|
for {
|
||||||
|
c.Set("openai_chat_completions_fallback_model", "")
|
||||||
|
reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||||
|
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||||
|
c.Request.Context(),
|
||||||
|
apiKey.GroupID,
|
||||||
|
"",
|
||||||
|
sessionHash,
|
||||||
|
reqModel,
|
||||||
|
failedAccountIDs,
|
||||||
|
service.OpenAIUpstreamTransportAny,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
reqLog.Warn("openai_chat_completions.account_select_failed",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||||
|
)
|
||||||
|
if len(failedAccountIDs) == 0 {
|
||||||
|
defaultModel := ""
|
||||||
|
if apiKey.Group != nil {
|
||||||
|
defaultModel = apiKey.Group.DefaultMappedModel
|
||||||
|
}
|
||||||
|
if defaultModel != "" && defaultModel != reqModel {
|
||||||
|
reqLog.Info("openai_chat_completions.fallback_to_default_model",
|
||||||
|
zap.String("default_mapped_model", defaultModel),
|
||||||
|
)
|
||||||
|
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
|
||||||
|
c.Request.Context(),
|
||||||
|
apiKey.GroupID,
|
||||||
|
"",
|
||||||
|
sessionHash,
|
||||||
|
defaultModel,
|
||||||
|
failedAccountIDs,
|
||||||
|
service.OpenAIUpstreamTransportAny,
|
||||||
|
)
|
||||||
|
if err == nil && selection != nil {
|
||||||
|
c.Set("openai_chat_completions_fallback_model", defaultModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if lastFailoverErr != nil {
|
||||||
|
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
|
||||||
|
} else {
|
||||||
|
h.handleStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if selection == nil || selection.Account == nil {
|
||||||
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
account := selection.Account
|
||||||
|
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||||
|
reqLog.Debug("openai_chat_completions.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||||
|
_ = scheduleDecision
|
||||||
|
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||||
|
|
||||||
|
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
|
||||||
|
if !acquired {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||||
|
forwardStart := time.Now()
|
||||||
|
|
||||||
|
defaultMappedModel := ""
|
||||||
|
if apiKey.Group != nil {
|
||||||
|
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
||||||
|
}
|
||||||
|
if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" {
|
||||||
|
defaultMappedModel = fallbackModel
|
||||||
|
}
|
||||||
|
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||||
|
|
||||||
|
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||||
|
if accountReleaseFunc != nil {
|
||||||
|
accountReleaseFunc()
|
||||||
|
}
|
||||||
|
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||||
|
responseLatencyMs := forwardDurationMs
|
||||||
|
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||||
|
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
||||||
|
}
|
||||||
|
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
||||||
|
if err == nil && result != nil && result.FirstTokenMs != nil {
|
||||||
|
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
var failoverErr *service.UpstreamFailoverError
|
||||||
|
if errors.As(err, &failoverErr) {
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||||
|
// Pool mode: retry on the same account
|
||||||
|
if failoverErr.RetryableOnSameAccount {
|
||||||
|
retryLimit := account.GetPoolModeRetryCount()
|
||||||
|
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||||
|
sameAccountRetryCount[account.ID]++
|
||||||
|
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||||
|
zap.Int("retry_limit", retryLimit),
|
||||||
|
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||||
|
)
|
||||||
|
select {
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
return
|
||||||
|
case <-time.After(sameAccountRetryDelay):
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||||
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
|
lastFailoverErr = failoverErr
|
||||||
|
if switchCount >= maxAccountSwitches {
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switchCount++
|
||||||
|
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||||
|
zap.Int("switch_count", switchCount),
|
||||||
|
zap.Int("max_switches", maxAccountSwitches),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||||
|
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||||
|
reqLog.Warn("openai_chat_completions.forward_failed",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if result != nil {
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||||
|
} else {
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
userAgent := c.GetHeader("User-Agent")
|
||||||
|
clientIP := ip.GetClientIP(c)
|
||||||
|
|
||||||
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
|
Result: result,
|
||||||
|
APIKey: apiKey,
|
||||||
|
User: apiKey.User,
|
||||||
|
Account: account,
|
||||||
|
Subscription: subscription,
|
||||||
|
UserAgent: userAgent,
|
||||||
|
IPAddress: clientIP,
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
|
}); err != nil {
|
||||||
|
logger.L().With(
|
||||||
|
zap.String("component", "handler.openai_gateway.chat_completions"),
|
||||||
|
zap.Int64("user_id", subject.UserID),
|
||||||
|
zap.Int64("api_key_id", apiKey.ID),
|
||||||
|
zap.Any("group_id", apiKey.GroupID),
|
||||||
|
zap.String("model", reqModel),
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
).Error("openai_chat_completions.record_usage_failed", zap.Error(err))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
reqLog.Debug("openai_chat_completions.request_completed",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Int("switch_count", switchCount),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -31,6 +31,7 @@ const (
|
|||||||
const (
|
const (
|
||||||
opsErrorLogTimeout = 5 * time.Second
|
opsErrorLogTimeout = 5 * time.Second
|
||||||
opsErrorLogDrainTimeout = 10 * time.Second
|
opsErrorLogDrainTimeout = 10 * time.Second
|
||||||
|
opsErrorLogBatchWindow = 200 * time.Millisecond
|
||||||
|
|
||||||
opsErrorLogMinWorkerCount = 4
|
opsErrorLogMinWorkerCount = 4
|
||||||
opsErrorLogMaxWorkerCount = 32
|
opsErrorLogMaxWorkerCount = 32
|
||||||
@@ -38,6 +39,7 @@ const (
|
|||||||
opsErrorLogQueueSizePerWorker = 128
|
opsErrorLogQueueSizePerWorker = 128
|
||||||
opsErrorLogMinQueueSize = 256
|
opsErrorLogMinQueueSize = 256
|
||||||
opsErrorLogMaxQueueSize = 8192
|
opsErrorLogMaxQueueSize = 8192
|
||||||
|
opsErrorLogBatchSize = 32
|
||||||
)
|
)
|
||||||
|
|
||||||
type opsErrorLogJob struct {
|
type opsErrorLogJob struct {
|
||||||
@@ -82,27 +84,82 @@ func startOpsErrorLogWorkers() {
|
|||||||
for i := 0; i < workerCount; i++ {
|
for i := 0; i < workerCount; i++ {
|
||||||
go func() {
|
go func() {
|
||||||
defer opsErrorLogWorkersWg.Done()
|
defer opsErrorLogWorkersWg.Done()
|
||||||
for job := range opsErrorLogQueue {
|
for {
|
||||||
opsErrorLogQueueLen.Add(-1)
|
job, ok := <-opsErrorLogQueue
|
||||||
if job.ops == nil || job.entry == nil {
|
if !ok {
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
func() {
|
opsErrorLogQueueLen.Add(-1)
|
||||||
defer func() {
|
batch := make([]opsErrorLogJob, 0, opsErrorLogBatchSize)
|
||||||
if r := recover(); r != nil {
|
batch = append(batch, job)
|
||||||
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
|
|
||||||
|
timer := time.NewTimer(opsErrorLogBatchWindow)
|
||||||
|
batchLoop:
|
||||||
|
for len(batch) < opsErrorLogBatchSize {
|
||||||
|
select {
|
||||||
|
case nextJob, ok := <-opsErrorLogQueue:
|
||||||
|
if !ok {
|
||||||
|
if !timer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
flushOpsErrorLogBatch(batch)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}()
|
opsErrorLogQueueLen.Add(-1)
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
batch = append(batch, nextJob)
|
||||||
_ = job.ops.RecordError(ctx, job.entry, nil)
|
case <-timer.C:
|
||||||
cancel()
|
break batchLoop
|
||||||
opsErrorLogProcessed.Add(1)
|
}
|
||||||
}()
|
}
|
||||||
|
if !timer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
flushOpsErrorLogBatch(batch)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func flushOpsErrorLogBatch(batch []opsErrorLogJob) {
|
||||||
|
if len(batch) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
grouped := make(map[*service.OpsService][]*service.OpsInsertErrorLogInput, len(batch))
|
||||||
|
var processed int64
|
||||||
|
for _, job := range batch {
|
||||||
|
if job.ops == nil || job.entry == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
grouped[job.ops] = append(grouped[job.ops], job.entry)
|
||||||
|
processed++
|
||||||
|
}
|
||||||
|
if processed == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for opsSvc, entries := range grouped {
|
||||||
|
if opsSvc == nil || len(entries) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
||||||
|
_ = opsSvc.RecordErrorBatch(ctx, entries)
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
opsErrorLogProcessed.Add(processed)
|
||||||
|
}
|
||||||
|
|
||||||
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) {
|
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) {
|
||||||
if ops == nil || entry == nil {
|
if ops == nil || entry == nil {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -159,6 +159,8 @@ var claudeModels = []modelDef{
|
|||||||
// Antigravity 支持的 Gemini 模型
|
// Antigravity 支持的 Gemini 模型
|
||||||
var geminiModels = []modelDef{
|
var geminiModels = []modelDef{
|
||||||
{ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"},
|
{ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||||
|
{ID: "gemini-2.5-flash-image", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||||
|
{ID: "gemini-2.5-flash-image-preview", DisplayName: "Gemini 2.5 Flash Image Preview", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||||
{ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"},
|
{ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||||
{ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"},
|
{ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||||
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
|
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) {
|
|||||||
|
|
||||||
requiredIDs := []string{
|
requiredIDs := []string{
|
||||||
"claude-opus-4-6-thinking",
|
"claude-opus-4-6-thinking",
|
||||||
|
"gemini-2.5-flash-image",
|
||||||
|
"gemini-2.5-flash-image-preview",
|
||||||
"gemini-3.1-flash-image",
|
"gemini-3.1-flash-image",
|
||||||
"gemini-3.1-flash-image-preview",
|
"gemini-3.1-flash-image-preview",
|
||||||
"gemini-3-pro-image", // legacy compatibility
|
"gemini-3-pro-image", // legacy compatibility
|
||||||
|
|||||||
@@ -18,6 +18,9 @@ const (
|
|||||||
BlockTypeFunction
|
BlockTypeFunction
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// UsageMapHook is a callback that can modify usage data before it's emitted in SSE events.
|
||||||
|
type UsageMapHook func(usageMap map[string]any)
|
||||||
|
|
||||||
// StreamingProcessor 流式响应处理器
|
// StreamingProcessor 流式响应处理器
|
||||||
type StreamingProcessor struct {
|
type StreamingProcessor struct {
|
||||||
blockType BlockType
|
blockType BlockType
|
||||||
@@ -30,6 +33,7 @@ type StreamingProcessor struct {
|
|||||||
originalModel string
|
originalModel string
|
||||||
webSearchQueries []string
|
webSearchQueries []string
|
||||||
groundingChunks []GeminiGroundingChunk
|
groundingChunks []GeminiGroundingChunk
|
||||||
|
usageMapHook UsageMapHook
|
||||||
|
|
||||||
// 累计 usage
|
// 累计 usage
|
||||||
inputTokens int
|
inputTokens int
|
||||||
@@ -45,6 +49,25 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUsageMapHook sets an optional hook that modifies usage maps before they are emitted.
|
||||||
|
func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) {
|
||||||
|
p.usageMapHook = fn
|
||||||
|
}
|
||||||
|
|
||||||
|
func usageToMap(u ClaudeUsage) map[string]any {
|
||||||
|
m := map[string]any{
|
||||||
|
"input_tokens": u.InputTokens,
|
||||||
|
"output_tokens": u.OutputTokens,
|
||||||
|
}
|
||||||
|
if u.CacheCreationInputTokens > 0 {
|
||||||
|
m["cache_creation_input_tokens"] = u.CacheCreationInputTokens
|
||||||
|
}
|
||||||
|
if u.CacheReadInputTokens > 0 {
|
||||||
|
m["cache_read_input_tokens"] = u.CacheReadInputTokens
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
|
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
|
||||||
func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||||
line = strings.TrimSpace(line)
|
line = strings.TrimSpace(line)
|
||||||
@@ -168,6 +191,13 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
|
|||||||
responseID = "msg_" + generateRandomID()
|
responseID = "msg_" + generateRandomID()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var usageValue any = usage
|
||||||
|
if p.usageMapHook != nil {
|
||||||
|
usageMap := usageToMap(usage)
|
||||||
|
p.usageMapHook(usageMap)
|
||||||
|
usageValue = usageMap
|
||||||
|
}
|
||||||
|
|
||||||
message := map[string]any{
|
message := map[string]any{
|
||||||
"id": responseID,
|
"id": responseID,
|
||||||
"type": "message",
|
"type": "message",
|
||||||
@@ -176,7 +206,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
|
|||||||
"model": p.originalModel,
|
"model": p.originalModel,
|
||||||
"stop_reason": nil,
|
"stop_reason": nil,
|
||||||
"stop_sequence": nil,
|
"stop_sequence": nil,
|
||||||
"usage": usage,
|
"usage": usageValue,
|
||||||
}
|
}
|
||||||
|
|
||||||
event := map[string]any{
|
event := map[string]any{
|
||||||
@@ -487,13 +517,20 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
|
|||||||
CacheReadInputTokens: p.cacheReadTokens,
|
CacheReadInputTokens: p.cacheReadTokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var usageValue any = usage
|
||||||
|
if p.usageMapHook != nil {
|
||||||
|
usageMap := usageToMap(usage)
|
||||||
|
p.usageMapHook(usageMap)
|
||||||
|
usageValue = usageMap
|
||||||
|
}
|
||||||
|
|
||||||
deltaEvent := map[string]any{
|
deltaEvent := map[string]any{
|
||||||
"type": "message_delta",
|
"type": "message_delta",
|
||||||
"delta": map[string]any{
|
"delta": map[string]any{
|
||||||
"stop_reason": stopReason,
|
"stop_reason": stopReason,
|
||||||
"stop_sequence": nil,
|
"stop_sequence": nil,
|
||||||
},
|
},
|
||||||
"usage": usage,
|
"usage": usageValue,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
|
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
|
||||||
|
|||||||
733
backend/internal/pkg/apicompat/chatcompletions_responses_test.go
Normal file
733
backend/internal/pkg/apicompat/chatcompletions_responses_test.go
Normal file
@@ -0,0 +1,733 @@
|
|||||||
|
package apicompat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// ChatCompletionsToResponses tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_BasicText(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Messages: []ChatMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Hello"`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "gpt-4o", resp.Model)
|
||||||
|
assert.True(t, resp.Stream) // always forced true
|
||||||
|
assert.False(t, *resp.Store)
|
||||||
|
|
||||||
|
var items []ResponsesInputItem
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||||
|
require.Len(t, items, 1)
|
||||||
|
assert.Equal(t, "user", items[0].Role)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_SystemMessage(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Messages: []ChatMessage{
|
||||||
|
{Role: "system", Content: json.RawMessage(`"You are helpful."`)},
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var items []ResponsesInputItem
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||||
|
require.Len(t, items, 2)
|
||||||
|
assert.Equal(t, "system", items[0].Role)
|
||||||
|
assert.Equal(t, "user", items[1].Role)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_ToolCalls(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Messages: []ChatMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Call the function"`)},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []ChatToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Type: "function",
|
||||||
|
Function: ChatFunctionCall{
|
||||||
|
Name: "ping",
|
||||||
|
Arguments: `{"host":"example.com"}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "tool",
|
||||||
|
ToolCallID: "call_1",
|
||||||
|
Content: json.RawMessage(`"pong"`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Tools: []ChatTool{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: &ChatFunction{
|
||||||
|
Name: "ping",
|
||||||
|
Description: "Ping a host",
|
||||||
|
Parameters: json.RawMessage(`{"type":"object"}`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var items []ResponsesInputItem
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||||
|
// user + function_call + function_call_output = 3
|
||||||
|
// (assistant message with empty content + tool_calls → only function_call items emitted)
|
||||||
|
require.Len(t, items, 3)
|
||||||
|
|
||||||
|
// Check function_call item
|
||||||
|
assert.Equal(t, "function_call", items[1].Type)
|
||||||
|
assert.Equal(t, "call_1", items[1].CallID)
|
||||||
|
assert.Equal(t, "ping", items[1].Name)
|
||||||
|
|
||||||
|
// Check function_call_output item
|
||||||
|
assert.Equal(t, "function_call_output", items[2].Type)
|
||||||
|
assert.Equal(t, "call_1", items[2].CallID)
|
||||||
|
assert.Equal(t, "pong", items[2].Output)
|
||||||
|
|
||||||
|
// Check tools
|
||||||
|
require.Len(t, resp.Tools, 1)
|
||||||
|
assert.Equal(t, "function", resp.Tools[0].Type)
|
||||||
|
assert.Equal(t, "ping", resp.Tools[0].Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_MaxTokens(t *testing.T) {
|
||||||
|
t.Run("max_tokens", func(t *testing.T) {
|
||||||
|
maxTokens := 100
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
MaxTokens: &maxTokens,
|
||||||
|
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||||
|
}
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp.MaxOutputTokens)
|
||||||
|
// Below minMaxOutputTokens (128), should be clamped
|
||||||
|
assert.Equal(t, minMaxOutputTokens, *resp.MaxOutputTokens)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("max_completion_tokens_preferred", func(t *testing.T) {
|
||||||
|
maxTokens := 100
|
||||||
|
maxCompletion := 500
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
MaxTokens: &maxTokens,
|
||||||
|
MaxCompletionTokens: &maxCompletion,
|
||||||
|
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||||
|
}
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp.MaxOutputTokens)
|
||||||
|
assert.Equal(t, 500, *resp.MaxOutputTokens)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_ReasoningEffort(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
ReasoningEffort: "high",
|
||||||
|
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||||
|
}
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp.Reasoning)
|
||||||
|
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||||
|
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_ImageURL(t *testing.T) {
|
||||||
|
content := `[{"type":"text","text":"Describe this"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc123"}}]`
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Messages: []ChatMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(content)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var items []ResponsesInputItem
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||||
|
require.Len(t, items, 1)
|
||||||
|
|
||||||
|
var parts []ResponsesContentPart
|
||||||
|
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
|
||||||
|
require.Len(t, parts, 2)
|
||||||
|
assert.Equal(t, "input_text", parts[0].Type)
|
||||||
|
assert.Equal(t, "Describe this", parts[0].Text)
|
||||||
|
assert.Equal(t, "input_image", parts[1].Type)
|
||||||
|
assert.Equal(t, "data:image/png;base64,abc123", parts[1].ImageURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Messages: []ChatMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||||
|
},
|
||||||
|
Functions: []ChatFunction{
|
||||||
|
{
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get weather",
|
||||||
|
Parameters: json.RawMessage(`{"type":"object"}`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FunctionCall: json.RawMessage(`{"name":"get_weather"}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, resp.Tools, 1)
|
||||||
|
assert.Equal(t, "function", resp.Tools[0].Type)
|
||||||
|
assert.Equal(t, "get_weather", resp.Tools[0].Name)
|
||||||
|
|
||||||
|
// tool_choice should be converted
|
||||||
|
require.NotNil(t, resp.ToolChoice)
|
||||||
|
var tc map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
|
||||||
|
assert.Equal(t, "function", tc["type"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_ServiceTier(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
ServiceTier: "flex",
|
||||||
|
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||||
|
}
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "flex", resp.ServiceTier)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Messages: []ChatMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Do something"`)},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: json.RawMessage(`"Let me call a function."`),
|
||||||
|
ToolCalls: []ChatToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_abc",
|
||||||
|
Type: "function",
|
||||||
|
Function: ChatFunctionCall{
|
||||||
|
Name: "do_thing",
|
||||||
|
Arguments: `{}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var items []ResponsesInputItem
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||||
|
// user + assistant message (with text) + function_call
|
||||||
|
require.Len(t, items, 3)
|
||||||
|
assert.Equal(t, "user", items[0].Role)
|
||||||
|
assert.Equal(t, "assistant", items[1].Role)
|
||||||
|
assert.Equal(t, "function_call", items[2].Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// ResponsesToChatCompletions tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestResponsesToChatCompletions_BasicText(t *testing.T) {
|
||||||
|
resp := &ResponsesResponse{
|
||||||
|
ID: "resp_123",
|
||||||
|
Status: "completed",
|
||||||
|
Output: []ResponsesOutput{
|
||||||
|
{
|
||||||
|
Type: "message",
|
||||||
|
Content: []ResponsesContentPart{
|
||||||
|
{Type: "output_text", Text: "Hello, world!"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Usage: &ResponsesUsage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 5,
|
||||||
|
TotalTokens: 15,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||||
|
assert.Equal(t, "chat.completion", chat.Object)
|
||||||
|
assert.Equal(t, "gpt-4o", chat.Model)
|
||||||
|
require.Len(t, chat.Choices, 1)
|
||||||
|
assert.Equal(t, "stop", chat.Choices[0].FinishReason)
|
||||||
|
|
||||||
|
var content string
|
||||||
|
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||||
|
assert.Equal(t, "Hello, world!", content)
|
||||||
|
|
||||||
|
require.NotNil(t, chat.Usage)
|
||||||
|
assert.Equal(t, 10, chat.Usage.PromptTokens)
|
||||||
|
assert.Equal(t, 5, chat.Usage.CompletionTokens)
|
||||||
|
assert.Equal(t, 15, chat.Usage.TotalTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesToChatCompletions_ToolCalls(t *testing.T) {
|
||||||
|
resp := &ResponsesResponse{
|
||||||
|
ID: "resp_456",
|
||||||
|
Status: "completed",
|
||||||
|
Output: []ResponsesOutput{
|
||||||
|
{
|
||||||
|
Type: "function_call",
|
||||||
|
CallID: "call_xyz",
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: `{"city":"NYC"}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||||
|
require.Len(t, chat.Choices, 1)
|
||||||
|
assert.Equal(t, "tool_calls", chat.Choices[0].FinishReason)
|
||||||
|
|
||||||
|
msg := chat.Choices[0].Message
|
||||||
|
require.Len(t, msg.ToolCalls, 1)
|
||||||
|
assert.Equal(t, "call_xyz", msg.ToolCalls[0].ID)
|
||||||
|
assert.Equal(t, "function", msg.ToolCalls[0].Type)
|
||||||
|
assert.Equal(t, "get_weather", msg.ToolCalls[0].Function.Name)
|
||||||
|
assert.Equal(t, `{"city":"NYC"}`, msg.ToolCalls[0].Function.Arguments)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesToChatCompletions_Reasoning(t *testing.T) {
|
||||||
|
resp := &ResponsesResponse{
|
||||||
|
ID: "resp_789",
|
||||||
|
Status: "completed",
|
||||||
|
Output: []ResponsesOutput{
|
||||||
|
{
|
||||||
|
Type: "reasoning",
|
||||||
|
Summary: []ResponsesSummary{
|
||||||
|
{Type: "summary_text", Text: "I thought about it."},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: "message",
|
||||||
|
Content: []ResponsesContentPart{
|
||||||
|
{Type: "output_text", Text: "The answer is 42."},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||||
|
require.Len(t, chat.Choices, 1)
|
||||||
|
|
||||||
|
var content string
|
||||||
|
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||||
|
// Reasoning summary is prepended to text
|
||||||
|
assert.Equal(t, "I thought about it.The answer is 42.", content)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
|
||||||
|
resp := &ResponsesResponse{
|
||||||
|
ID: "resp_inc",
|
||||||
|
Status: "incomplete",
|
||||||
|
IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
|
||||||
|
Output: []ResponsesOutput{
|
||||||
|
{
|
||||||
|
Type: "message",
|
||||||
|
Content: []ResponsesContentPart{
|
||||||
|
{Type: "output_text", Text: "partial..."},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||||
|
require.Len(t, chat.Choices, 1)
|
||||||
|
assert.Equal(t, "length", chat.Choices[0].FinishReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesToChatCompletions_CachedTokens(t *testing.T) {
|
||||||
|
resp := &ResponsesResponse{
|
||||||
|
ID: "resp_cache",
|
||||||
|
Status: "completed",
|
||||||
|
Output: []ResponsesOutput{
|
||||||
|
{
|
||||||
|
Type: "message",
|
||||||
|
Content: []ResponsesContentPart{{Type: "output_text", Text: "cached"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Usage: &ResponsesUsage{
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 10,
|
||||||
|
TotalTokens: 110,
|
||||||
|
InputTokensDetails: &ResponsesInputTokensDetails{
|
||||||
|
CachedTokens: 80,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||||
|
require.NotNil(t, chat.Usage)
|
||||||
|
require.NotNil(t, chat.Usage.PromptTokensDetails)
|
||||||
|
assert.Equal(t, 80, chat.Usage.PromptTokensDetails.CachedTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesToChatCompletions_WebSearch(t *testing.T) {
|
||||||
|
resp := &ResponsesResponse{
|
||||||
|
ID: "resp_ws",
|
||||||
|
Status: "completed",
|
||||||
|
Output: []ResponsesOutput{
|
||||||
|
{
|
||||||
|
Type: "web_search_call",
|
||||||
|
Action: &WebSearchAction{Type: "search", Query: "test"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: "message",
|
||||||
|
Content: []ResponsesContentPart{{Type: "output_text", Text: "search results"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||||
|
require.Len(t, chat.Choices, 1)
|
||||||
|
assert.Equal(t, "stop", chat.Choices[0].FinishReason)
|
||||||
|
|
||||||
|
var content string
|
||||||
|
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||||
|
assert.Equal(t, "search results", content)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Streaming: ResponsesEventToChatChunks tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestResponsesEventToChatChunks_TextDelta(t *testing.T) {
|
||||||
|
state := NewResponsesEventToChatState()
|
||||||
|
state.Model = "gpt-4o"
|
||||||
|
|
||||||
|
// response.created → role chunk
|
||||||
|
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.created",
|
||||||
|
Response: &ResponsesResponse{
|
||||||
|
ID: "resp_stream",
|
||||||
|
},
|
||||||
|
}, state)
|
||||||
|
require.Len(t, chunks, 1)
|
||||||
|
assert.Equal(t, "assistant", chunks[0].Choices[0].Delta.Role)
|
||||||
|
assert.True(t, state.SentRole)
|
||||||
|
|
||||||
|
// response.output_text.delta → content chunk
|
||||||
|
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.output_text.delta",
|
||||||
|
Delta: "Hello",
|
||||||
|
}, state)
|
||||||
|
require.Len(t, chunks, 1)
|
||||||
|
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
|
||||||
|
assert.Equal(t, "Hello", *chunks[0].Choices[0].Delta.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesEventToChatChunks_ToolCallDelta(t *testing.T) {
|
||||||
|
state := NewResponsesEventToChatState()
|
||||||
|
state.Model = "gpt-4o"
|
||||||
|
state.SentRole = true
|
||||||
|
|
||||||
|
// response.output_item.added (function_call) — output_index=1 (e.g. after a message item at 0)
|
||||||
|
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.output_item.added",
|
||||||
|
OutputIndex: 1,
|
||||||
|
Item: &ResponsesOutput{
|
||||||
|
Type: "function_call",
|
||||||
|
CallID: "call_1",
|
||||||
|
Name: "get_weather",
|
||||||
|
},
|
||||||
|
}, state)
|
||||||
|
require.Len(t, chunks, 1)
|
||||||
|
require.Len(t, chunks[0].Choices[0].Delta.ToolCalls, 1)
|
||||||
|
tc := chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||||
|
assert.Equal(t, "call_1", tc.ID)
|
||||||
|
assert.Equal(t, "get_weather", tc.Function.Name)
|
||||||
|
require.NotNil(t, tc.Index)
|
||||||
|
assert.Equal(t, 0, *tc.Index)
|
||||||
|
|
||||||
|
// response.function_call_arguments.delta — uses output_index (NOT call_id) to find tool
|
||||||
|
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.function_call_arguments.delta",
|
||||||
|
OutputIndex: 1, // matches the output_index from output_item.added above
|
||||||
|
Delta: `{"city":`,
|
||||||
|
}, state)
|
||||||
|
require.Len(t, chunks, 1)
|
||||||
|
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||||
|
require.NotNil(t, tc.Index)
|
||||||
|
assert.Equal(t, 0, *tc.Index, "argument delta must use same index as the tool call")
|
||||||
|
assert.Equal(t, `{"city":`, tc.Function.Arguments)
|
||||||
|
|
||||||
|
// Add a second function call at output_index=2
|
||||||
|
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.output_item.added",
|
||||||
|
OutputIndex: 2,
|
||||||
|
Item: &ResponsesOutput{
|
||||||
|
Type: "function_call",
|
||||||
|
CallID: "call_2",
|
||||||
|
Name: "get_time",
|
||||||
|
},
|
||||||
|
}, state)
|
||||||
|
require.Len(t, chunks, 1)
|
||||||
|
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||||
|
require.NotNil(t, tc.Index)
|
||||||
|
assert.Equal(t, 1, *tc.Index, "second tool call should get index 1")
|
||||||
|
|
||||||
|
// Argument delta for second tool call
|
||||||
|
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.function_call_arguments.delta",
|
||||||
|
OutputIndex: 2,
|
||||||
|
Delta: `{"tz":"UTC"}`,
|
||||||
|
}, state)
|
||||||
|
require.Len(t, chunks, 1)
|
||||||
|
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||||
|
require.NotNil(t, tc.Index)
|
||||||
|
assert.Equal(t, 1, *tc.Index, "second tool arg delta must use index 1")
|
||||||
|
|
||||||
|
// Argument delta for first tool call (interleaved)
|
||||||
|
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.function_call_arguments.delta",
|
||||||
|
OutputIndex: 1,
|
||||||
|
Delta: `"Tokyo"}`,
|
||||||
|
}, state)
|
||||||
|
require.Len(t, chunks, 1)
|
||||||
|
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||||
|
require.NotNil(t, tc.Index)
|
||||||
|
assert.Equal(t, 0, *tc.Index, "first tool arg delta must still use index 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesEventToChatChunks_Completed(t *testing.T) {
|
||||||
|
state := NewResponsesEventToChatState()
|
||||||
|
state.Model = "gpt-4o"
|
||||||
|
state.IncludeUsage = true
|
||||||
|
|
||||||
|
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.completed",
|
||||||
|
Response: &ResponsesResponse{
|
||||||
|
Status: "completed",
|
||||||
|
Usage: &ResponsesUsage{
|
||||||
|
InputTokens: 50,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalTokens: 70,
|
||||||
|
InputTokensDetails: &ResponsesInputTokensDetails{
|
||||||
|
CachedTokens: 30,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, state)
|
||||||
|
// finish chunk + usage chunk
|
||||||
|
require.Len(t, chunks, 2)
|
||||||
|
|
||||||
|
// First chunk: finish_reason
|
||||||
|
require.NotNil(t, chunks[0].Choices[0].FinishReason)
|
||||||
|
assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)
|
||||||
|
|
||||||
|
// Second chunk: usage
|
||||||
|
require.NotNil(t, chunks[1].Usage)
|
||||||
|
assert.Equal(t, 50, chunks[1].Usage.PromptTokens)
|
||||||
|
assert.Equal(t, 20, chunks[1].Usage.CompletionTokens)
|
||||||
|
assert.Equal(t, 70, chunks[1].Usage.TotalTokens)
|
||||||
|
require.NotNil(t, chunks[1].Usage.PromptTokensDetails)
|
||||||
|
assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) {
|
||||||
|
state := NewResponsesEventToChatState()
|
||||||
|
state.Model = "gpt-4o"
|
||||||
|
state.SawToolCall = true
|
||||||
|
|
||||||
|
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.completed",
|
||||||
|
Response: &ResponsesResponse{
|
||||||
|
Status: "completed",
|
||||||
|
},
|
||||||
|
}, state)
|
||||||
|
require.Len(t, chunks, 1)
|
||||||
|
require.NotNil(t, chunks[0].Choices[0].FinishReason)
|
||||||
|
assert.Equal(t, "tool_calls", *chunks[0].Choices[0].FinishReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) {
|
||||||
|
state := NewResponsesEventToChatState()
|
||||||
|
state.Model = "gpt-4o"
|
||||||
|
state.SentRole = true
|
||||||
|
|
||||||
|
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.reasoning_summary_text.delta",
|
||||||
|
Delta: "Thinking...",
|
||||||
|
}, state)
|
||||||
|
require.Len(t, chunks, 1)
|
||||||
|
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
|
||||||
|
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFinalizeResponsesChatStream(t *testing.T) {
|
||||||
|
state := NewResponsesEventToChatState()
|
||||||
|
state.Model = "gpt-4o"
|
||||||
|
state.IncludeUsage = true
|
||||||
|
state.Usage = &ChatUsage{
|
||||||
|
PromptTokens: 100,
|
||||||
|
CompletionTokens: 50,
|
||||||
|
TotalTokens: 150,
|
||||||
|
}
|
||||||
|
|
||||||
|
chunks := FinalizeResponsesChatStream(state)
|
||||||
|
require.Len(t, chunks, 2)
|
||||||
|
|
||||||
|
// Finish chunk
|
||||||
|
require.NotNil(t, chunks[0].Choices[0].FinishReason)
|
||||||
|
assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)
|
||||||
|
|
||||||
|
// Usage chunk
|
||||||
|
require.NotNil(t, chunks[1].Usage)
|
||||||
|
assert.Equal(t, 100, chunks[1].Usage.PromptTokens)
|
||||||
|
|
||||||
|
// Idempotent: second call returns nil
|
||||||
|
assert.Nil(t, FinalizeResponsesChatStream(state))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFinalizeResponsesChatStream_AfterCompleted(t *testing.T) {
|
||||||
|
// If response.completed already emitted the finish chunk, FinalizeResponsesChatStream
|
||||||
|
// must be a no-op (prevents double finish_reason being sent to the client).
|
||||||
|
state := NewResponsesEventToChatState()
|
||||||
|
state.Model = "gpt-4o"
|
||||||
|
state.IncludeUsage = true
|
||||||
|
|
||||||
|
// Simulate response.completed
|
||||||
|
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.completed",
|
||||||
|
Response: &ResponsesResponse{
|
||||||
|
Status: "completed",
|
||||||
|
Usage: &ResponsesUsage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 5,
|
||||||
|
TotalTokens: 15,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, state)
|
||||||
|
require.NotEmpty(t, chunks) // finish + usage chunks
|
||||||
|
|
||||||
|
// Now FinalizeResponsesChatStream should return nil — already finalized.
|
||||||
|
assert.Nil(t, FinalizeResponsesChatStream(state))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatChunkToSSE(t *testing.T) {
|
||||||
|
chunk := ChatCompletionsChunk{
|
||||||
|
ID: "chatcmpl-test",
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: 1700000000,
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Choices: []ChatChunkChoice{
|
||||||
|
{
|
||||||
|
Index: 0,
|
||||||
|
Delta: ChatDelta{Role: "assistant"},
|
||||||
|
FinishReason: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
sse, err := ChatChunkToSSE(chunk)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, sse, "data: ")
|
||||||
|
assert.Contains(t, sse, "chatcmpl-test")
|
||||||
|
assert.Contains(t, sse, "assistant")
|
||||||
|
assert.True(t, len(sse) > 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Stream round-trip test
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestChatCompletionsStreamRoundTrip(t *testing.T) {
|
||||||
|
// Simulate: client sends chat completions request, upstream returns Responses SSE events.
|
||||||
|
// Verify that the streaming state machine produces correct chat completions chunks.
|
||||||
|
|
||||||
|
state := NewResponsesEventToChatState()
|
||||||
|
state.Model = "gpt-4o"
|
||||||
|
state.IncludeUsage = true
|
||||||
|
|
||||||
|
var allChunks []ChatCompletionsChunk
|
||||||
|
|
||||||
|
// 1. response.created
|
||||||
|
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.created",
|
||||||
|
Response: &ResponsesResponse{ID: "resp_rt"},
|
||||||
|
}, state)
|
||||||
|
allChunks = append(allChunks, chunks...)
|
||||||
|
|
||||||
|
// 2. text deltas
|
||||||
|
for _, text := range []string{"Hello", ", ", "world", "!"} {
|
||||||
|
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.output_text.delta",
|
||||||
|
Delta: text,
|
||||||
|
}, state)
|
||||||
|
allChunks = append(allChunks, chunks...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. response.completed
|
||||||
|
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.completed",
|
||||||
|
Response: &ResponsesResponse{
|
||||||
|
Status: "completed",
|
||||||
|
Usage: &ResponsesUsage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 4,
|
||||||
|
TotalTokens: 14,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, state)
|
||||||
|
allChunks = append(allChunks, chunks...)
|
||||||
|
|
||||||
|
// Verify: role chunk + 4 text chunks + finish chunk + usage chunk = 7
|
||||||
|
require.Len(t, allChunks, 7)
|
||||||
|
|
||||||
|
// First chunk has role
|
||||||
|
assert.Equal(t, "assistant", allChunks[0].Choices[0].Delta.Role)
|
||||||
|
|
||||||
|
// Text chunks
|
||||||
|
var fullText string
|
||||||
|
for i := 1; i <= 4; i++ {
|
||||||
|
require.NotNil(t, allChunks[i].Choices[0].Delta.Content)
|
||||||
|
fullText += *allChunks[i].Choices[0].Delta.Content
|
||||||
|
}
|
||||||
|
assert.Equal(t, "Hello, world!", fullText)
|
||||||
|
|
||||||
|
// Finish chunk
|
||||||
|
require.NotNil(t, allChunks[5].Choices[0].FinishReason)
|
||||||
|
assert.Equal(t, "stop", *allChunks[5].Choices[0].FinishReason)
|
||||||
|
|
||||||
|
// Usage chunk
|
||||||
|
require.NotNil(t, allChunks[6].Usage)
|
||||||
|
assert.Equal(t, 10, allChunks[6].Usage.PromptTokens)
|
||||||
|
assert.Equal(t, 4, allChunks[6].Usage.CompletionTokens)
|
||||||
|
|
||||||
|
// All chunks share the same ID
|
||||||
|
for _, c := range allChunks {
|
||||||
|
assert.Equal(t, "resp_rt", c.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
312
backend/internal/pkg/apicompat/chatcompletions_to_responses.go
Normal file
312
backend/internal/pkg/apicompat/chatcompletions_to_responses.go
Normal file
@@ -0,0 +1,312 @@
|
|||||||
|
package apicompat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ChatCompletionsToResponses converts a Chat Completions request into a
|
||||||
|
// Responses API request. The upstream always streams, so Stream is forced to
|
||||||
|
// true. store is always false and reasoning.encrypted_content is always
|
||||||
|
// included so that the response translator has full context.
|
||||||
|
func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest, error) {
|
||||||
|
input, err := convertChatMessagesToResponsesInput(req.Messages)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
inputJSON, err := json.Marshal(input)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
out := &ResponsesRequest{
|
||||||
|
Model: req.Model,
|
||||||
|
Input: inputJSON,
|
||||||
|
Temperature: req.Temperature,
|
||||||
|
TopP: req.TopP,
|
||||||
|
Stream: true, // upstream always streams
|
||||||
|
Include: []string{"reasoning.encrypted_content"},
|
||||||
|
ServiceTier: req.ServiceTier,
|
||||||
|
}
|
||||||
|
|
||||||
|
storeFalse := false
|
||||||
|
out.Store = &storeFalse
|
||||||
|
|
||||||
|
// max_tokens / max_completion_tokens → max_output_tokens, prefer max_completion_tokens
|
||||||
|
maxTokens := 0
|
||||||
|
if req.MaxTokens != nil {
|
||||||
|
maxTokens = *req.MaxTokens
|
||||||
|
}
|
||||||
|
if req.MaxCompletionTokens != nil {
|
||||||
|
maxTokens = *req.MaxCompletionTokens
|
||||||
|
}
|
||||||
|
if maxTokens > 0 {
|
||||||
|
v := maxTokens
|
||||||
|
if v < minMaxOutputTokens {
|
||||||
|
v = minMaxOutputTokens
|
||||||
|
}
|
||||||
|
out.MaxOutputTokens = &v
|
||||||
|
}
|
||||||
|
|
||||||
|
// reasoning_effort → reasoning.effort + reasoning.summary="auto"
|
||||||
|
if req.ReasoningEffort != "" {
|
||||||
|
out.Reasoning = &ResponsesReasoning{
|
||||||
|
Effort: req.ReasoningEffort,
|
||||||
|
Summary: "auto",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tools[] and legacy functions[] → ResponsesTool[]
|
||||||
|
if len(req.Tools) > 0 || len(req.Functions) > 0 {
|
||||||
|
out.Tools = convertChatToolsToResponses(req.Tools, req.Functions)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tool_choice: already compatible format — pass through directly.
|
||||||
|
// Legacy function_call needs mapping.
|
||||||
|
if len(req.ToolChoice) > 0 {
|
||||||
|
out.ToolChoice = req.ToolChoice
|
||||||
|
} else if len(req.FunctionCall) > 0 {
|
||||||
|
tc, err := convertChatFunctionCallToToolChoice(req.FunctionCall)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("convert function_call: %w", err)
|
||||||
|
}
|
||||||
|
out.ToolChoice = tc
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertChatMessagesToResponsesInput converts the Chat Completions messages
|
||||||
|
// array into a Responses API input items array.
|
||||||
|
func convertChatMessagesToResponsesInput(msgs []ChatMessage) ([]ResponsesInputItem, error) {
|
||||||
|
var out []ResponsesInputItem
|
||||||
|
for _, m := range msgs {
|
||||||
|
items, err := chatMessageToResponsesItems(m)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
out = append(out, items...)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatMessageToResponsesItems converts a single ChatMessage into one or more
|
||||||
|
// ResponsesInputItem values.
|
||||||
|
func chatMessageToResponsesItems(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||||
|
switch m.Role {
|
||||||
|
case "system":
|
||||||
|
return chatSystemToResponses(m)
|
||||||
|
case "user":
|
||||||
|
return chatUserToResponses(m)
|
||||||
|
case "assistant":
|
||||||
|
return chatAssistantToResponses(m)
|
||||||
|
case "tool":
|
||||||
|
return chatToolToResponses(m)
|
||||||
|
case "function":
|
||||||
|
return chatFunctionToResponses(m)
|
||||||
|
default:
|
||||||
|
return chatUserToResponses(m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatSystemToResponses converts a system message.
|
||||||
|
func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||||
|
text, err := parseChatContent(m.Content)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
content, err := json.Marshal(text)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return []ResponsesInputItem{{Role: "system", Content: content}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatUserToResponses converts a user message, handling both plain strings and
|
||||||
|
// multi-modal content arrays.
|
||||||
|
func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||||
|
// Try plain string first.
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(m.Content, &s); err == nil {
|
||||||
|
content, _ := json.Marshal(s)
|
||||||
|
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var parts []ChatContentPart
|
||||||
|
if err := json.Unmarshal(m.Content, &parts); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse user content: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var responseParts []ResponsesContentPart
|
||||||
|
for _, p := range parts {
|
||||||
|
switch p.Type {
|
||||||
|
case "text":
|
||||||
|
if p.Text != "" {
|
||||||
|
responseParts = append(responseParts, ResponsesContentPart{
|
||||||
|
Type: "input_text",
|
||||||
|
Text: p.Text,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
case "image_url":
|
||||||
|
if p.ImageURL != nil && p.ImageURL.URL != "" {
|
||||||
|
responseParts = append(responseParts, ResponsesContentPart{
|
||||||
|
Type: "input_image",
|
||||||
|
ImageURL: p.ImageURL.URL,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := json.Marshal(responseParts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatAssistantToResponses converts an assistant message. If there is both
|
||||||
|
// text content and tool_calls, the text is emitted as an assistant message
|
||||||
|
// first, then each tool_call becomes a function_call item. If the content is
|
||||||
|
// empty/nil and there are tool_calls, only function_call items are emitted.
|
||||||
|
func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||||
|
var items []ResponsesInputItem
|
||||||
|
|
||||||
|
// Emit assistant message with output_text if content is non-empty.
|
||||||
|
if len(m.Content) > 0 {
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(m.Content, &s); err == nil && s != "" {
|
||||||
|
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
|
||||||
|
partsJSON, err := json.Marshal(parts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit one function_call item per tool_call.
|
||||||
|
for _, tc := range m.ToolCalls {
|
||||||
|
args := tc.Function.Arguments
|
||||||
|
if args == "" {
|
||||||
|
args = "{}"
|
||||||
|
}
|
||||||
|
items = append(items, ResponsesInputItem{
|
||||||
|
Type: "function_call",
|
||||||
|
CallID: tc.ID,
|
||||||
|
Name: tc.Function.Name,
|
||||||
|
Arguments: args,
|
||||||
|
ID: tc.ID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatToolToResponses converts a tool result message (role=tool) into a
|
||||||
|
// function_call_output item.
|
||||||
|
func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||||
|
output, err := parseChatContent(m.Content)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if output == "" {
|
||||||
|
output = "(empty)"
|
||||||
|
}
|
||||||
|
return []ResponsesInputItem{{
|
||||||
|
Type: "function_call_output",
|
||||||
|
CallID: m.ToolCallID,
|
||||||
|
Output: output,
|
||||||
|
}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatFunctionToResponses converts a legacy function result message
|
||||||
|
// (role=function) into a function_call_output item. The Name field is used as
|
||||||
|
// call_id since legacy function calls do not carry a separate call_id.
|
||||||
|
func chatFunctionToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||||
|
output, err := parseChatContent(m.Content)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if output == "" {
|
||||||
|
output = "(empty)"
|
||||||
|
}
|
||||||
|
return []ResponsesInputItem{{
|
||||||
|
Type: "function_call_output",
|
||||||
|
CallID: m.Name,
|
||||||
|
Output: output,
|
||||||
|
}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseChatContent returns the string value of a ChatMessage Content field.
|
||||||
|
// Content must be a JSON string. Returns "" if content is null or empty.
|
||||||
|
func parseChatContent(raw json.RawMessage) (string, error) {
|
||||||
|
if len(raw) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(raw, &s); err != nil {
|
||||||
|
return "", fmt.Errorf("parse content as string: %w", err)
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertChatToolsToResponses maps Chat Completions tool definitions and legacy
|
||||||
|
// function definitions to Responses API tool definitions.
|
||||||
|
func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []ResponsesTool {
|
||||||
|
var out []ResponsesTool
|
||||||
|
|
||||||
|
for _, t := range tools {
|
||||||
|
if t.Type != "function" || t.Function == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rt := ResponsesTool{
|
||||||
|
Type: "function",
|
||||||
|
Name: t.Function.Name,
|
||||||
|
Description: t.Function.Description,
|
||||||
|
Parameters: t.Function.Parameters,
|
||||||
|
Strict: t.Function.Strict,
|
||||||
|
}
|
||||||
|
out = append(out, rt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Legacy functions[] are treated as function-type tools.
|
||||||
|
for _, f := range functions {
|
||||||
|
rt := ResponsesTool{
|
||||||
|
Type: "function",
|
||||||
|
Name: f.Name,
|
||||||
|
Description: f.Description,
|
||||||
|
Parameters: f.Parameters,
|
||||||
|
Strict: f.Strict,
|
||||||
|
}
|
||||||
|
out = append(out, rt)
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertChatFunctionCallToToolChoice maps the legacy function_call field to a
|
||||||
|
// Responses API tool_choice value.
|
||||||
|
//
|
||||||
|
// "auto" → "auto"
|
||||||
|
// "none" → "none"
|
||||||
|
// {"name":"X"} → {"type":"function","function":{"name":"X"}}
|
||||||
|
func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) {
|
||||||
|
// Try string first ("auto", "none", etc.) — pass through as-is.
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(raw, &s); err == nil {
|
||||||
|
return json.Marshal(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Object form: {"name":"X"}
|
||||||
|
var obj struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(raw, &obj); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return json.Marshal(map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]string{"name": obj.Name},
|
||||||
|
})
|
||||||
|
}
|
||||||
368
backend/internal/pkg/apicompat/responses_to_chatcompletions.go
Normal file
368
backend/internal/pkg/apicompat/responses_to_chatcompletions.go
Normal file
@@ -0,0 +1,368 @@
|
|||||||
|
package apicompat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Non-streaming: ResponsesResponse → ChatCompletionsResponse
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// ResponsesToChatCompletions converts a Responses API response into a Chat
|
||||||
|
// Completions response. Text output items are concatenated into
|
||||||
|
// choices[0].message.content; function_call items become tool_calls.
|
||||||
|
func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatCompletionsResponse {
|
||||||
|
id := resp.ID
|
||||||
|
if id == "" {
|
||||||
|
id = generateChatCmplID()
|
||||||
|
}
|
||||||
|
|
||||||
|
out := &ChatCompletionsResponse{
|
||||||
|
ID: id,
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
Model: model,
|
||||||
|
}
|
||||||
|
|
||||||
|
var contentText string
|
||||||
|
var toolCalls []ChatToolCall
|
||||||
|
|
||||||
|
for _, item := range resp.Output {
|
||||||
|
switch item.Type {
|
||||||
|
case "message":
|
||||||
|
for _, part := range item.Content {
|
||||||
|
if part.Type == "output_text" && part.Text != "" {
|
||||||
|
contentText += part.Text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "function_call":
|
||||||
|
toolCalls = append(toolCalls, ChatToolCall{
|
||||||
|
ID: item.CallID,
|
||||||
|
Type: "function",
|
||||||
|
Function: ChatFunctionCall{
|
||||||
|
Name: item.Name,
|
||||||
|
Arguments: item.Arguments,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
case "reasoning":
|
||||||
|
for _, s := range item.Summary {
|
||||||
|
if s.Type == "summary_text" && s.Text != "" {
|
||||||
|
contentText += s.Text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "web_search_call":
|
||||||
|
// silently consumed — results already incorporated into text output
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := ChatMessage{Role: "assistant"}
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
msg.ToolCalls = toolCalls
|
||||||
|
}
|
||||||
|
if contentText != "" {
|
||||||
|
raw, _ := json.Marshal(contentText)
|
||||||
|
msg.Content = raw
|
||||||
|
}
|
||||||
|
|
||||||
|
finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls)
|
||||||
|
|
||||||
|
out.Choices = []ChatChoice{{
|
||||||
|
Index: 0,
|
||||||
|
Message: msg,
|
||||||
|
FinishReason: finishReason,
|
||||||
|
}}
|
||||||
|
|
||||||
|
if resp.Usage != nil {
|
||||||
|
usage := &ChatUsage{
|
||||||
|
PromptTokens: resp.Usage.InputTokens,
|
||||||
|
CompletionTokens: resp.Usage.OutputTokens,
|
||||||
|
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
|
||||||
|
}
|
||||||
|
if resp.Usage.InputTokensDetails != nil && resp.Usage.InputTokensDetails.CachedTokens > 0 {
|
||||||
|
usage.PromptTokensDetails = &ChatTokenDetails{
|
||||||
|
CachedTokens: resp.Usage.InputTokensDetails.CachedTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out.Usage = usage
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func responsesStatusToChatFinishReason(status string, details *ResponsesIncompleteDetails, toolCalls []ChatToolCall) string {
|
||||||
|
switch status {
|
||||||
|
case "incomplete":
|
||||||
|
if details != nil && details.Reason == "max_output_tokens" {
|
||||||
|
return "length"
|
||||||
|
}
|
||||||
|
return "stop"
|
||||||
|
case "completed":
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
return "tool_calls"
|
||||||
|
}
|
||||||
|
return "stop"
|
||||||
|
default:
|
||||||
|
return "stop"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Streaming: ResponsesStreamEvent → []ChatCompletionsChunk (stateful converter)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// ResponsesEventToChatState tracks state for converting a sequence of Responses
|
||||||
|
// SSE events into Chat Completions SSE chunks.
|
||||||
|
type ResponsesEventToChatState struct {
|
||||||
|
ID string
|
||||||
|
Model string
|
||||||
|
Created int64
|
||||||
|
SentRole bool
|
||||||
|
SawToolCall bool
|
||||||
|
SawText bool
|
||||||
|
Finalized bool // true after finish chunk has been emitted
|
||||||
|
NextToolCallIndex int // next sequential tool_call index to assign
|
||||||
|
OutputIndexToToolIndex map[int]int // Responses output_index → Chat tool_calls index
|
||||||
|
IncludeUsage bool
|
||||||
|
Usage *ChatUsage
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewResponsesEventToChatState returns an initialised stream state.
|
||||||
|
func NewResponsesEventToChatState() *ResponsesEventToChatState {
|
||||||
|
return &ResponsesEventToChatState{
|
||||||
|
ID: generateChatCmplID(),
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
OutputIndexToToolIndex: make(map[int]int),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponsesEventToChatChunks converts a single Responses SSE event into zero
|
||||||
|
// or more Chat Completions chunks, updating state as it goes.
|
||||||
|
func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
switch evt.Type {
|
||||||
|
case "response.created":
|
||||||
|
return resToChatHandleCreated(evt, state)
|
||||||
|
case "response.output_text.delta":
|
||||||
|
return resToChatHandleTextDelta(evt, state)
|
||||||
|
case "response.output_item.added":
|
||||||
|
return resToChatHandleOutputItemAdded(evt, state)
|
||||||
|
case "response.function_call_arguments.delta":
|
||||||
|
return resToChatHandleFuncArgsDelta(evt, state)
|
||||||
|
case "response.reasoning_summary_text.delta":
|
||||||
|
return resToChatHandleReasoningDelta(evt, state)
|
||||||
|
case "response.completed", "response.incomplete", "response.failed":
|
||||||
|
return resToChatHandleCompleted(evt, state)
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FinalizeResponsesChatStream emits a final chunk with finish_reason if the
|
||||||
|
// stream ended without a proper completion event (e.g. upstream disconnect).
|
||||||
|
// It is idempotent: if a completion event already emitted the finish chunk,
|
||||||
|
// this returns nil.
|
||||||
|
func FinalizeResponsesChatStream(state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
if state.Finalized {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
state.Finalized = true
|
||||||
|
|
||||||
|
finishReason := "stop"
|
||||||
|
if state.SawToolCall {
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
}
|
||||||
|
|
||||||
|
chunks := []ChatCompletionsChunk{makeChatFinishChunk(state, finishReason)}
|
||||||
|
|
||||||
|
if state.IncludeUsage && state.Usage != nil {
|
||||||
|
chunks = append(chunks, ChatCompletionsChunk{
|
||||||
|
ID: state.ID,
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: state.Created,
|
||||||
|
Model: state.Model,
|
||||||
|
Choices: []ChatChunkChoice{},
|
||||||
|
Usage: state.Usage,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatChunkToSSE formats a ChatCompletionsChunk as an SSE data line.
|
||||||
|
func ChatChunkToSSE(chunk ChatCompletionsChunk) (string, error) {
|
||||||
|
data, err := json.Marshal(chunk)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("data: %s\n\n", data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- internal handlers ---
|
||||||
|
|
||||||
|
func resToChatHandleCreated(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
if evt.Response != nil {
|
||||||
|
if evt.Response.ID != "" {
|
||||||
|
state.ID = evt.Response.ID
|
||||||
|
}
|
||||||
|
if state.Model == "" && evt.Response.Model != "" {
|
||||||
|
state.Model = evt.Response.Model
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Emit the role chunk.
|
||||||
|
if state.SentRole {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
state.SentRole = true
|
||||||
|
|
||||||
|
role := "assistant"
|
||||||
|
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Role: role})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resToChatHandleTextDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
if evt.Delta == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
state.SawText = true
|
||||||
|
content := evt.Delta
|
||||||
|
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resToChatHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
if evt.Item == nil || evt.Item.Type != "function_call" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
state.SawToolCall = true
|
||||||
|
idx := state.NextToolCallIndex
|
||||||
|
state.OutputIndexToToolIndex[evt.OutputIndex] = idx
|
||||||
|
state.NextToolCallIndex++
|
||||||
|
|
||||||
|
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{
|
||||||
|
ToolCalls: []ChatToolCall{{
|
||||||
|
Index: &idx,
|
||||||
|
ID: evt.Item.CallID,
|
||||||
|
Type: "function",
|
||||||
|
Function: ChatFunctionCall{
|
||||||
|
Name: evt.Item.Name,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resToChatHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
if evt.Delta == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
idx, ok := state.OutputIndexToToolIndex[evt.OutputIndex]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{
|
||||||
|
ToolCalls: []ChatToolCall{{
|
||||||
|
Index: &idx,
|
||||||
|
Function: ChatFunctionCall{
|
||||||
|
Arguments: evt.Delta,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
if evt.Delta == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
content := evt.Delta
|
||||||
|
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
state.Finalized = true
|
||||||
|
finishReason := "stop"
|
||||||
|
|
||||||
|
if evt.Response != nil {
|
||||||
|
if evt.Response.Usage != nil {
|
||||||
|
u := evt.Response.Usage
|
||||||
|
usage := &ChatUsage{
|
||||||
|
PromptTokens: u.InputTokens,
|
||||||
|
CompletionTokens: u.OutputTokens,
|
||||||
|
TotalTokens: u.InputTokens + u.OutputTokens,
|
||||||
|
}
|
||||||
|
if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 {
|
||||||
|
usage.PromptTokensDetails = &ChatTokenDetails{
|
||||||
|
CachedTokens: u.InputTokensDetails.CachedTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.Usage = usage
|
||||||
|
}
|
||||||
|
|
||||||
|
switch evt.Response.Status {
|
||||||
|
case "incomplete":
|
||||||
|
if evt.Response.IncompleteDetails != nil && evt.Response.IncompleteDetails.Reason == "max_output_tokens" {
|
||||||
|
finishReason = "length"
|
||||||
|
}
|
||||||
|
case "completed":
|
||||||
|
if state.SawToolCall {
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if state.SawToolCall {
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
}
|
||||||
|
|
||||||
|
var chunks []ChatCompletionsChunk
|
||||||
|
chunks = append(chunks, makeChatFinishChunk(state, finishReason))
|
||||||
|
|
||||||
|
if state.IncludeUsage && state.Usage != nil {
|
||||||
|
chunks = append(chunks, ChatCompletionsChunk{
|
||||||
|
ID: state.ID,
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: state.Created,
|
||||||
|
Model: state.Model,
|
||||||
|
Choices: []ChatChunkChoice{},
|
||||||
|
Usage: state.Usage,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeChatDeltaChunk(state *ResponsesEventToChatState, delta ChatDelta) ChatCompletionsChunk {
|
||||||
|
return ChatCompletionsChunk{
|
||||||
|
ID: state.ID,
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: state.Created,
|
||||||
|
Model: state.Model,
|
||||||
|
Choices: []ChatChunkChoice{{
|
||||||
|
Index: 0,
|
||||||
|
Delta: delta,
|
||||||
|
FinishReason: nil,
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeChatFinishChunk(state *ResponsesEventToChatState, finishReason string) ChatCompletionsChunk {
|
||||||
|
empty := ""
|
||||||
|
return ChatCompletionsChunk{
|
||||||
|
ID: state.ID,
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: state.Created,
|
||||||
|
Model: state.Model,
|
||||||
|
Choices: []ChatChunkChoice{{
|
||||||
|
Index: 0,
|
||||||
|
Delta: ChatDelta{Content: &empty},
|
||||||
|
FinishReason: &finishReason,
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateChatCmplID returns a "chatcmpl-" prefixed random hex ID.
|
||||||
|
func generateChatCmplID() string {
|
||||||
|
b := make([]byte, 12)
|
||||||
|
_, _ = rand.Read(b)
|
||||||
|
return "chatcmpl-" + hex.EncodeToString(b)
|
||||||
|
}
|
||||||
@@ -329,6 +329,148 @@ type ResponsesStreamEvent struct {
|
|||||||
SequenceNumber int `json:"sequence_number,omitempty"`
|
SequenceNumber int `json:"sequence_number,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// OpenAI Chat Completions API types
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// ChatCompletionsRequest is the request body for POST /v1/chat/completions.
|
||||||
|
type ChatCompletionsRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []ChatMessage `json:"messages"`
|
||||||
|
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||||
|
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
|
||||||
|
Temperature *float64 `json:"temperature,omitempty"`
|
||||||
|
TopP *float64 `json:"top_p,omitempty"`
|
||||||
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"`
|
||||||
|
Tools []ChatTool `json:"tools,omitempty"`
|
||||||
|
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||||
|
ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high"
|
||||||
|
ServiceTier string `json:"service_tier,omitempty"`
|
||||||
|
Stop json.RawMessage `json:"stop,omitempty"` // string or []string
|
||||||
|
|
||||||
|
// Legacy function calling (deprecated but still supported)
|
||||||
|
Functions []ChatFunction `json:"functions,omitempty"`
|
||||||
|
FunctionCall json.RawMessage `json:"function_call,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatStreamOptions configures streaming behavior.
|
||||||
|
type ChatStreamOptions struct {
|
||||||
|
IncludeUsage bool `json:"include_usage,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatMessage is a single message in the Chat Completions conversation.
|
||||||
|
type ChatMessage struct {
|
||||||
|
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
|
||||||
|
Content json.RawMessage `json:"content,omitempty"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||||
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||||
|
|
||||||
|
// Legacy function calling
|
||||||
|
FunctionCall *ChatFunctionCall `json:"function_call,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatContentPart is a typed content part in a multi-modal message.
|
||||||
|
type ChatContentPart struct {
|
||||||
|
Type string `json:"type"` // "text" | "image_url"
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
ImageURL *ChatImageURL `json:"image_url,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatImageURL contains the URL for an image content part.
|
||||||
|
type ChatImageURL struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
Detail string `json:"detail,omitempty"` // "auto" | "low" | "high"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatTool describes a tool available to the model.
|
||||||
|
type ChatTool struct {
|
||||||
|
Type string `json:"type"` // "function"
|
||||||
|
Function *ChatFunction `json:"function,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatFunction describes a function tool definition.
|
||||||
|
type ChatFunction struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||||
|
Strict *bool `json:"strict,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatToolCall represents a tool call made by the assistant.
|
||||||
|
// Index is only populated in streaming chunks (omitted in non-streaming responses).
|
||||||
|
type ChatToolCall struct {
|
||||||
|
Index *int `json:"index,omitempty"`
|
||||||
|
ID string `json:"id,omitempty"`
|
||||||
|
Type string `json:"type,omitempty"` // "function"
|
||||||
|
Function ChatFunctionCall `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatFunctionCall contains the function name and arguments.
|
||||||
|
type ChatFunctionCall struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments string `json:"arguments"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatCompletionsResponse is the non-streaming response from POST /v1/chat/completions.
|
||||||
|
type ChatCompletionsResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"` // "chat.completion"
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Choices []ChatChoice `json:"choices"`
|
||||||
|
Usage *ChatUsage `json:"usage,omitempty"`
|
||||||
|
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||||
|
ServiceTier string `json:"service_tier,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatChoice is a single completion choice.
|
||||||
|
type ChatChoice struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
Message ChatMessage `json:"message"`
|
||||||
|
FinishReason string `json:"finish_reason"` // "stop" | "length" | "tool_calls" | "content_filter"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatUsage holds token counts in Chat Completions format.
|
||||||
|
type ChatUsage struct {
|
||||||
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
PromptTokensDetails *ChatTokenDetails `json:"prompt_tokens_details,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatTokenDetails provides a breakdown of token usage.
|
||||||
|
type ChatTokenDetails struct {
|
||||||
|
CachedTokens int `json:"cached_tokens,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatCompletionsChunk is a single streaming chunk from POST /v1/chat/completions.
|
||||||
|
type ChatCompletionsChunk struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"` // "chat.completion.chunk"
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Choices []ChatChunkChoice `json:"choices"`
|
||||||
|
Usage *ChatUsage `json:"usage,omitempty"`
|
||||||
|
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||||
|
ServiceTier string `json:"service_tier,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatChunkChoice is a single choice in a streaming chunk.
|
||||||
|
type ChatChunkChoice struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
Delta ChatDelta `json:"delta"`
|
||||||
|
FinishReason *string `json:"finish_reason"` // pointer: null when not final
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatDelta carries incremental content in a streaming chunk.
|
||||||
|
type ChatDelta struct {
|
||||||
|
Role string `json:"role,omitempty"`
|
||||||
|
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
|
||||||
|
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// Shared constants
|
// Shared constants
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -18,10 +18,12 @@ func DefaultModels() []Model {
|
|||||||
return []Model{
|
return []Model{
|
||||||
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
|
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
|
||||||
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
|
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
|
||||||
|
{Name: "models/gemini-2.5-flash-image", SupportedGenerationMethods: methods},
|
||||||
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
|
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
|
||||||
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
|
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
|
||||||
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
|
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
|
||||||
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},
|
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},
|
||||||
|
{Name: "models/gemini-3.1-flash-image", SupportedGenerationMethods: methods},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
28
backend/internal/pkg/gemini/models_test.go
Normal file
28
backend/internal/pkg/gemini/models_test.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package gemini
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestDefaultModels_ContainsImageModels(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
models := DefaultModels()
|
||||||
|
byName := make(map[string]Model, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
byName[model.Name] = model
|
||||||
|
}
|
||||||
|
|
||||||
|
required := []string{
|
||||||
|
"models/gemini-2.5-flash-image",
|
||||||
|
"models/gemini-3.1-flash-image",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range required {
|
||||||
|
model, ok := byName[name]
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected fallback model %q to exist", name)
|
||||||
|
}
|
||||||
|
if len(model.SupportedGenerationMethods) == 0 {
|
||||||
|
t.Fatalf("expected fallback model %q to advertise generation methods", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -13,10 +13,12 @@ type Model struct {
|
|||||||
var DefaultModels = []Model{
|
var DefaultModels = []Model{
|
||||||
{ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""},
|
{ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""},
|
||||||
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
|
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
|
||||||
|
{ID: "gemini-2.5-flash-image", Type: "model", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: ""},
|
||||||
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
|
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
|
||||||
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
|
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
|
||||||
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
|
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
|
||||||
{ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""},
|
{ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""},
|
||||||
|
{ID: "gemini-3.1-flash-image", Type: "model", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: ""},
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultTestModel is the default model to preselect in test flows.
|
// DefaultTestModel is the default model to preselect in test flows.
|
||||||
|
|||||||
23
backend/internal/pkg/geminicli/models_test.go
Normal file
23
backend/internal/pkg/geminicli/models_test.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package geminicli
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestDefaultModels_ContainsImageModels(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
byID := make(map[string]Model, len(DefaultModels))
|
||||||
|
for _, model := range DefaultModels {
|
||||||
|
byID[model.ID] = model
|
||||||
|
}
|
||||||
|
|
||||||
|
required := []string{
|
||||||
|
"gemini-2.5-flash-image",
|
||||||
|
"gemini-3.1-flash-image",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, id := range required {
|
||||||
|
if _, ok := byID[id]; !ok {
|
||||||
|
t.Fatalf("expected curated Gemini model %q to exist", id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
@@ -50,6 +51,18 @@ type accountRepository struct {
|
|||||||
schedulerCache service.SchedulerCache
|
schedulerCache service.SchedulerCache
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var schedulerNeutralExtraKeyPrefixes = []string{
|
||||||
|
"codex_primary_",
|
||||||
|
"codex_secondary_",
|
||||||
|
"codex_5h_",
|
||||||
|
"codex_7d_",
|
||||||
|
}
|
||||||
|
|
||||||
|
var schedulerNeutralExtraKeys = map[string]struct{}{
|
||||||
|
"codex_usage_updated_at": {},
|
||||||
|
"session_window_utilization": {},
|
||||||
|
}
|
||||||
|
|
||||||
// NewAccountRepository 创建账户仓储实例。
|
// NewAccountRepository 创建账户仓储实例。
|
||||||
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
|
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
|
||||||
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
|
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
|
||||||
@@ -1185,12 +1198,48 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
|
|||||||
if affected == 0 {
|
if affected == 0 {
|
||||||
return service.ErrAccountNotFound
|
return service.ErrAccountNotFound
|
||||||
}
|
}
|
||||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
if shouldEnqueueSchedulerOutboxForExtraUpdates(updates) {
|
||||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 观测型 extra 字段不需要触发 bucket 重建,但仍同步单账号快照,
|
||||||
|
// 让 sticky session / GetAccount 命中缓存时也能读到最新数据,
|
||||||
|
// 同时避免缓存局部 patch 覆盖掉并发写入的其它账号字段。
|
||||||
|
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldEnqueueSchedulerOutboxForExtraUpdates(updates map[string]any) bool {
|
||||||
|
if len(updates) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for key := range updates {
|
||||||
|
if isSchedulerNeutralExtraKey(key) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSchedulerNeutralExtraKey(key string) bool {
|
||||||
|
key = strings.TrimSpace(key)
|
||||||
|
if key == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if _, ok := schedulerNeutralExtraKeys[key]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for _, prefix := range schedulerNeutralExtraKeyPrefixes {
|
||||||
|
if strings.HasPrefix(key, prefix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
||||||
if len(ids) == 0 {
|
if len(ids) == 0 {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ type AccountRepoSuite struct {
|
|||||||
|
|
||||||
type schedulerCacheRecorder struct {
|
type schedulerCacheRecorder struct {
|
||||||
setAccounts []*service.Account
|
setAccounts []*service.Account
|
||||||
|
accounts map[int64]*service.Account
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
|
func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
|
||||||
@@ -34,11 +35,20 @@ func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
|
func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
|
||||||
return nil, nil
|
if s.accounts == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return s.accounts[accountID], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error {
|
func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error {
|
||||||
s.setAccounts = append(s.setAccounts, account)
|
s.setAccounts = append(s.setAccounts, account)
|
||||||
|
if s.accounts == nil {
|
||||||
|
s.accounts = make(map[int64]*service.Account)
|
||||||
|
}
|
||||||
|
if account != nil {
|
||||||
|
s.accounts[account.ID] = account
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -623,6 +633,96 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
|
|||||||
s.Require().Equal("val", got.Extra["key"])
|
s.Require().Equal("val", got.Extra["key"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestUpdateExtra_SchedulerNeutralSkipsOutboxAndSyncsFreshSnapshot() {
|
||||||
|
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||||
|
Name: "acc-extra-neutral",
|
||||||
|
Platform: service.PlatformOpenAI,
|
||||||
|
Extra: map[string]any{"codex_usage_updated_at": "old"},
|
||||||
|
})
|
||||||
|
cacheRecorder := &schedulerCacheRecorder{
|
||||||
|
accounts: map[int64]*service.Account{
|
||||||
|
account.ID: {
|
||||||
|
ID: account.ID,
|
||||||
|
Platform: account.Platform,
|
||||||
|
Status: service.StatusDisabled,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"codex_usage_updated_at": "old",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s.repo.schedulerCache = cacheRecorder
|
||||||
|
|
||||||
|
updates := map[string]any{
|
||||||
|
"codex_usage_updated_at": "2026-03-11T10:00:00Z",
|
||||||
|
"codex_5h_used_percent": 88.5,
|
||||||
|
"session_window_utilization": 0.42,
|
||||||
|
}
|
||||||
|
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, updates))
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal("2026-03-11T10:00:00Z", got.Extra["codex_usage_updated_at"])
|
||||||
|
s.Require().Equal(88.5, got.Extra["codex_5h_used_percent"])
|
||||||
|
s.Require().Equal(0.42, got.Extra["session_window_utilization"])
|
||||||
|
|
||||||
|
var outboxCount int
|
||||||
|
s.Require().NoError(scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &outboxCount))
|
||||||
|
s.Require().Zero(outboxCount)
|
||||||
|
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||||
|
s.Require().NotNil(cacheRecorder.accounts[account.ID])
|
||||||
|
s.Require().Equal(service.StatusActive, cacheRecorder.accounts[account.ID].Status)
|
||||||
|
s.Require().Equal("2026-03-11T10:00:00Z", cacheRecorder.accounts[account.ID].Extra["codex_usage_updated_at"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestUpdateExtra_ExhaustedCodexSnapshotSyncsSchedulerCache() {
|
||||||
|
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||||
|
Name: "acc-extra-codex-exhausted",
|
||||||
|
Platform: service.PlatformOpenAI,
|
||||||
|
Type: service.AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{},
|
||||||
|
})
|
||||||
|
cacheRecorder := &schedulerCacheRecorder{}
|
||||||
|
s.repo.schedulerCache = cacheRecorder
|
||||||
|
_, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{
|
||||||
|
"codex_7d_used_percent": 100.0,
|
||||||
|
"codex_7d_reset_at": "2026-03-12T13:00:00Z",
|
||||||
|
"codex_7d_reset_after_seconds": 86400,
|
||||||
|
}))
|
||||||
|
|
||||||
|
var count int
|
||||||
|
err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(0, count)
|
||||||
|
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||||
|
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||||
|
s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status)
|
||||||
|
s.Require().Equal(100.0, cacheRecorder.setAccounts[0].Extra["codex_7d_used_percent"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestUpdateExtra_SchedulerRelevantStillEnqueuesOutbox() {
|
||||||
|
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||||
|
Name: "acc-extra-mixed",
|
||||||
|
Platform: service.PlatformAntigravity,
|
||||||
|
Extra: map[string]any{},
|
||||||
|
})
|
||||||
|
_, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{
|
||||||
|
"mixed_scheduling": true,
|
||||||
|
"codex_usage_updated_at": "2026-03-11T10:00:00Z",
|
||||||
|
}))
|
||||||
|
|
||||||
|
var count int
|
||||||
|
err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(1, count)
|
||||||
|
}
|
||||||
|
|
||||||
// --- GetByCRSAccountID ---
|
// --- GetByCRSAccountID ---
|
||||||
|
|
||||||
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
|
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
|
||||||
|
|||||||
@@ -164,6 +164,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
|||||||
group.FieldModelRoutingEnabled,
|
group.FieldModelRoutingEnabled,
|
||||||
group.FieldModelRouting,
|
group.FieldModelRouting,
|
||||||
group.FieldMcpXMLInject,
|
group.FieldMcpXMLInject,
|
||||||
|
group.FieldSimulateClaudeMaxEnabled,
|
||||||
group.FieldSupportedModelScopes,
|
group.FieldSupportedModelScopes,
|
||||||
group.FieldAllowMessagesDispatch,
|
group.FieldAllowMessagesDispatch,
|
||||||
group.FieldDefaultMappedModel,
|
group.FieldDefaultMappedModel,
|
||||||
@@ -452,6 +453,32 @@ func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amo
|
|||||||
return updated.QuotaUsed, nil
|
return updated.QuotaUsed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IncrementQuotaUsedAndGetState atomically increments quota_used, conditionally marks the key
|
||||||
|
// as quota_exhausted, and returns the latest quota state in one round trip.
|
||||||
|
func (r *apiKeyRepository) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*service.APIKeyQuotaUsageState, error) {
|
||||||
|
query := `
|
||||||
|
UPDATE api_keys
|
||||||
|
SET
|
||||||
|
quota_used = quota_used + $1,
|
||||||
|
status = CASE
|
||||||
|
WHEN quota > 0 AND quota_used + $1 >= quota THEN $2
|
||||||
|
ELSE status
|
||||||
|
END,
|
||||||
|
updated_at = NOW()
|
||||||
|
WHERE id = $3 AND deleted_at IS NULL
|
||||||
|
RETURNING quota_used, quota, key, status
|
||||||
|
`
|
||||||
|
|
||||||
|
state := &service.APIKeyQuotaUsageState{}
|
||||||
|
if err := scanSingleRow(ctx, r.sql, query, []any{amount, service.StatusAPIKeyQuotaExhausted, id}, &state.QuotaUsed, &state.Quota, &state.Key, &state.Status); err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, service.ErrAPIKeyNotFound
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return state, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
||||||
affected, err := r.client.APIKey.Update().
|
affected, err := r.client.APIKey.Update().
|
||||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||||
@@ -619,6 +646,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
|||||||
ModelRouting: g.ModelRouting,
|
ModelRouting: g.ModelRouting,
|
||||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||||
MCPXMLInject: g.McpXMLInject,
|
MCPXMLInject: g.McpXMLInject,
|
||||||
|
SimulateClaudeMaxEnabled: g.SimulateClaudeMaxEnabled,
|
||||||
SupportedModelScopes: g.SupportedModelScopes,
|
SupportedModelScopes: g.SupportedModelScopes,
|
||||||
SortOrder: g.SortOrder,
|
SortOrder: g.SortOrder,
|
||||||
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
||||||
|
|||||||
@@ -417,6 +417,27 @@ func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() {
|
|||||||
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound")
|
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyRepoSuite) TestIncrementQuotaUsedAndGetState() {
|
||||||
|
user := s.mustCreateUser("quota-state@test.com")
|
||||||
|
key := s.mustCreateApiKey(user.ID, "sk-quota-state", "QuotaState", nil)
|
||||||
|
key.Quota = 3
|
||||||
|
key.QuotaUsed = 1
|
||||||
|
s.Require().NoError(s.repo.Update(s.ctx, key), "Update quota")
|
||||||
|
|
||||||
|
state, err := s.repo.IncrementQuotaUsedAndGetState(s.ctx, key.ID, 2.5)
|
||||||
|
s.Require().NoError(err, "IncrementQuotaUsedAndGetState")
|
||||||
|
s.Require().NotNil(state)
|
||||||
|
s.Require().Equal(3.5, state.QuotaUsed)
|
||||||
|
s.Require().Equal(3.0, state.Quota)
|
||||||
|
s.Require().Equal(service.StatusAPIKeyQuotaExhausted, state.Status)
|
||||||
|
s.Require().Equal(key.Key, state.Key)
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||||
|
s.Require().NoError(err, "GetByID")
|
||||||
|
s.Require().Equal(3.5, got.QuotaUsed)
|
||||||
|
s.Require().Equal(service.StatusAPIKeyQuotaExhausted, got.Status)
|
||||||
|
}
|
||||||
|
|
||||||
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
|
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
|
||||||
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
|
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
|
||||||
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
|
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
|
||||||
|
|||||||
@@ -61,7 +61,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
|||||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
||||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||||
SetDefaultMappedModel(groupIn.DefaultMappedModel)
|
SetDefaultMappedModel(groupIn.DefaultMappedModel).
|
||||||
|
SetSimulateClaudeMaxEnabled(groupIn.SimulateClaudeMaxEnabled)
|
||||||
|
|
||||||
// 设置模型路由配置
|
// 设置模型路由配置
|
||||||
if groupIn.ModelRouting != nil {
|
if groupIn.ModelRouting != nil {
|
||||||
@@ -129,7 +130,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
|||||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
||||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||||
SetDefaultMappedModel(groupIn.DefaultMappedModel)
|
SetDefaultMappedModel(groupIn.DefaultMappedModel).
|
||||||
|
SetSimulateClaudeMaxEnabled(groupIn.SimulateClaudeMaxEnabled)
|
||||||
|
|
||||||
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
|
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
|
||||||
if groupIn.DailyLimitUSD != nil {
|
if groupIn.DailyLimitUSD != nil {
|
||||||
|
|||||||
@@ -16,19 +16,7 @@ type opsRepository struct {
|
|||||||
db *sql.DB
|
db *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOpsRepository(db *sql.DB) service.OpsRepository {
|
const insertOpsErrorLogSQL = `
|
||||||
return &opsRepository{db: db}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *opsRepository) InsertErrorLog(ctx context.Context, input *service.OpsInsertErrorLogInput) (int64, error) {
|
|
||||||
if r == nil || r.db == nil {
|
|
||||||
return 0, fmt.Errorf("nil ops repository")
|
|
||||||
}
|
|
||||||
if input == nil {
|
|
||||||
return 0, fmt.Errorf("nil input")
|
|
||||||
}
|
|
||||||
|
|
||||||
q := `
|
|
||||||
INSERT INTO ops_error_logs (
|
INSERT INTO ops_error_logs (
|
||||||
request_id,
|
request_id,
|
||||||
client_request_id,
|
client_request_id,
|
||||||
@@ -70,12 +58,77 @@ INSERT INTO ops_error_logs (
|
|||||||
created_at
|
created_at
|
||||||
) VALUES (
|
) VALUES (
|
||||||
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38
|
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38
|
||||||
) RETURNING id`
|
)`
|
||||||
|
|
||||||
|
func NewOpsRepository(db *sql.DB) service.OpsRepository {
|
||||||
|
return &opsRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *opsRepository) InsertErrorLog(ctx context.Context, input *service.OpsInsertErrorLogInput) (int64, error) {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return 0, fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
if input == nil {
|
||||||
|
return 0, fmt.Errorf("nil input")
|
||||||
|
}
|
||||||
|
|
||||||
var id int64
|
var id int64
|
||||||
err := r.db.QueryRowContext(
|
err := r.db.QueryRowContext(
|
||||||
ctx,
|
ctx,
|
||||||
q,
|
insertOpsErrorLogSQL+" RETURNING id",
|
||||||
|
opsInsertErrorLogArgs(input)...,
|
||||||
|
).Scan(&id)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *opsRepository) BatchInsertErrorLogs(ctx context.Context, inputs []*service.OpsInsertErrorLogInput) (int64, error) {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return 0, fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
if len(inputs) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := r.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
stmt, err := tx.PrepareContext(ctx, insertOpsErrorLogSQL)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = stmt.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
var inserted int64
|
||||||
|
for _, input := range inputs {
|
||||||
|
if input == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, err = stmt.ExecContext(ctx, opsInsertErrorLogArgs(input)...); err != nil {
|
||||||
|
return inserted, err
|
||||||
|
}
|
||||||
|
inserted++
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tx.Commit(); err != nil {
|
||||||
|
return inserted, err
|
||||||
|
}
|
||||||
|
return inserted, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func opsInsertErrorLogArgs(input *service.OpsInsertErrorLogInput) []any {
|
||||||
|
return []any{
|
||||||
opsNullString(input.RequestID),
|
opsNullString(input.RequestID),
|
||||||
opsNullString(input.ClientRequestID),
|
opsNullString(input.ClientRequestID),
|
||||||
opsNullInt64(input.UserID),
|
opsNullInt64(input.UserID),
|
||||||
@@ -114,11 +167,7 @@ INSERT INTO ops_error_logs (
|
|||||||
input.IsRetryable,
|
input.IsRetryable,
|
||||||
input.RetryCount,
|
input.RetryCount,
|
||||||
input.CreatedAt,
|
input.CreatedAt,
|
||||||
).Scan(&id)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
}
|
||||||
return id, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsErrorLogFilter) (*service.OpsErrorLogList, error) {
|
func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsErrorLogFilter) (*service.OpsErrorLogList, error) {
|
||||||
|
|||||||
@@ -0,0 +1,79 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOpsRepositoryBatchInsertErrorLogs(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE ops_error_logs RESTART IDENTITY")
|
||||||
|
|
||||||
|
repo := NewOpsRepository(integrationDB).(*opsRepository)
|
||||||
|
now := time.Now().UTC()
|
||||||
|
inserted, err := repo.BatchInsertErrorLogs(ctx, []*service.OpsInsertErrorLogInput{
|
||||||
|
{
|
||||||
|
RequestID: "batch-ops-1",
|
||||||
|
ErrorPhase: "upstream",
|
||||||
|
ErrorType: "upstream_error",
|
||||||
|
Severity: "error",
|
||||||
|
StatusCode: 429,
|
||||||
|
ErrorMessage: "rate limited",
|
||||||
|
CreatedAt: now,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
RequestID: "batch-ops-2",
|
||||||
|
ErrorPhase: "internal",
|
||||||
|
ErrorType: "api_error",
|
||||||
|
Severity: "error",
|
||||||
|
StatusCode: 500,
|
||||||
|
ErrorMessage: "internal error",
|
||||||
|
CreatedAt: now.Add(time.Millisecond),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 2, inserted)
|
||||||
|
|
||||||
|
var count int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM ops_error_logs WHERE request_id IN ('batch-ops-1', 'batch-ops-2')").Scan(&count))
|
||||||
|
require.Equal(t, 2, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnqueueSchedulerOutbox_DeduplicatesIdempotentEvents(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY")
|
||||||
|
|
||||||
|
accountID := int64(12345)
|
||||||
|
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
|
||||||
|
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
|
||||||
|
|
||||||
|
var count int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count))
|
||||||
|
require.Equal(t, 1, count)
|
||||||
|
|
||||||
|
time.Sleep(schedulerOutboxDedupWindow + 150*time.Millisecond)
|
||||||
|
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count))
|
||||||
|
require.Equal(t, 2, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnqueueSchedulerOutbox_DoesNotDeduplicateLastUsed(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY")
|
||||||
|
|
||||||
|
accountID := int64(67890)
|
||||||
|
payload1 := map[string]any{"last_used": map[string]int64{"67890": 100}}
|
||||||
|
payload2 := map[string]any{"last_used": map[string]int64{"67890": 200}}
|
||||||
|
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload1))
|
||||||
|
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload2))
|
||||||
|
|
||||||
|
var count int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountLastUsed).Scan(&count))
|
||||||
|
require.Equal(t, 2, count)
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
)
|
)
|
||||||
@@ -12,6 +13,8 @@ type schedulerOutboxRepository struct {
|
|||||||
db *sql.DB
|
db *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const schedulerOutboxDedupWindow = time.Second
|
||||||
|
|
||||||
func NewSchedulerOutboxRepository(db *sql.DB) service.SchedulerOutboxRepository {
|
func NewSchedulerOutboxRepository(db *sql.DB) service.SchedulerOutboxRepository {
|
||||||
return &schedulerOutboxRepository{db: db}
|
return &schedulerOutboxRepository{db: db}
|
||||||
}
|
}
|
||||||
@@ -88,9 +91,37 @@ func enqueueSchedulerOutbox(ctx context.Context, exec sqlExecutor, eventType str
|
|||||||
}
|
}
|
||||||
payloadArg = encoded
|
payloadArg = encoded
|
||||||
}
|
}
|
||||||
_, err := exec.ExecContext(ctx, `
|
query := `
|
||||||
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
|
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
|
||||||
VALUES ($1, $2, $3, $4)
|
VALUES ($1, $2, $3, $4)
|
||||||
`, eventType, accountID, groupID, payloadArg)
|
`
|
||||||
|
args := []any{eventType, accountID, groupID, payloadArg}
|
||||||
|
if schedulerOutboxEventSupportsDedup(eventType) {
|
||||||
|
query = `
|
||||||
|
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
|
||||||
|
SELECT $1, $2, $3, $4
|
||||||
|
WHERE NOT EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM scheduler_outbox
|
||||||
|
WHERE event_type = $1
|
||||||
|
AND account_id IS NOT DISTINCT FROM $2
|
||||||
|
AND group_id IS NOT DISTINCT FROM $3
|
||||||
|
AND created_at >= NOW() - make_interval(secs => $5)
|
||||||
|
)
|
||||||
|
`
|
||||||
|
args = append(args, schedulerOutboxDedupWindow.Seconds())
|
||||||
|
}
|
||||||
|
_, err := exec.ExecContext(ctx, query, args...)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func schedulerOutboxEventSupportsDedup(eventType string) bool {
|
||||||
|
switch eventType {
|
||||||
|
case service.SchedulerOutboxEventAccountChanged,
|
||||||
|
service.SchedulerOutboxEventGroupChanged,
|
||||||
|
service.SchedulerOutboxEventFullRebuild:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1873,7 +1873,7 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start
|
|||||||
query := `
|
query := `
|
||||||
SELECT
|
SELECT
|
||||||
COALESCE(ul.group_id, 0) as group_id,
|
COALESCE(ul.group_id, 0) as group_id,
|
||||||
COALESCE(g.name, '') as group_name,
|
COALESCE(g.name, '(无分组)') as group_name,
|
||||||
COUNT(*) as requests,
|
COUNT(*) as requests,
|
||||||
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
|
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
|
||||||
COALESCE(SUM(ul.total_cost), 0) as cost,
|
COALESCE(SUM(ul.total_cost), 0) as cost,
|
||||||
|
|||||||
@@ -95,6 +95,35 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetByGroupID 获取指定分组下所有用户的专属倍率
|
||||||
|
func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) {
|
||||||
|
query := `
|
||||||
|
SELECT ugr.user_id, u.email, ugr.rate_multiplier
|
||||||
|
FROM user_group_rate_multipliers ugr
|
||||||
|
JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
|
||||||
|
WHERE ugr.group_id = $1
|
||||||
|
ORDER BY ugr.user_id
|
||||||
|
`
|
||||||
|
rows, err := r.sql.QueryContext(ctx, query, groupID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var result []service.UserGroupRateEntry
|
||||||
|
for rows.Next() {
|
||||||
|
var entry service.UserGroupRateEntry
|
||||||
|
if err := rows.Scan(&entry.UserID, &entry.UserEmail, &entry.RateMultiplier); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, entry)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetByUserAndGroup 获取用户在特定分组的专属倍率
|
// GetByUserAndGroup 获取用户在特定分组的专属倍率
|
||||||
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
||||||
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
|
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
|
||||||
|
|||||||
@@ -228,6 +228,7 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
groups.PUT("/:id", h.Admin.Group.Update)
|
groups.PUT("/:id", h.Admin.Group.Update)
|
||||||
groups.DELETE("/:id", h.Admin.Group.Delete)
|
groups.DELETE("/:id", h.Admin.Group.Delete)
|
||||||
groups.GET("/:id/stats", h.Admin.Group.GetStats)
|
groups.GET("/:id/stats", h.Admin.Group.GetStats)
|
||||||
|
groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers)
|
||||||
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
|
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -456,6 +457,7 @@ func registerSubscriptionRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
|
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
|
||||||
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
|
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
|
||||||
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
|
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
|
||||||
|
subscriptions.POST("/:id/reset-quota", h.Admin.Subscription.ResetQuota)
|
||||||
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
|
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -71,15 +71,8 @@ func RegisterGatewayRoutes(
|
|||||||
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
||||||
gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses)
|
gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses)
|
||||||
gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
|
gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
|
||||||
// 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。
|
// OpenAI Chat Completions API
|
||||||
gateway.POST("/chat/completions", func(c *gin.Context) {
|
gateway.POST("/chat/completions", h.OpenAIGateway.ChatCompletions)
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
|
||||||
"error": gin.H{
|
|
||||||
"type": "invalid_request_error",
|
|
||||||
"message": "Unsupported legacy protocol: /v1/chat/completions is not supported. Please use /v1/responses.",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
|
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
|
||||||
@@ -100,6 +93,8 @@ func RegisterGatewayRoutes(
|
|||||||
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||||
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||||
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
|
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
|
||||||
|
// OpenAI Chat Completions API(不带v1前缀的别名)
|
||||||
|
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions)
|
||||||
|
|
||||||
// Antigravity 模型列表
|
// Antigravity 模型列表
|
||||||
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
|
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
|
||||||
|
|||||||
@@ -45,16 +45,23 @@ const (
|
|||||||
|
|
||||||
// TestEvent represents a SSE event for account testing
|
// TestEvent represents a SSE event for account testing
|
||||||
type TestEvent struct {
|
type TestEvent struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Text string `json:"text,omitempty"`
|
Text string `json:"text,omitempty"`
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
Status string `json:"status,omitempty"`
|
Status string `json:"status,omitempty"`
|
||||||
Code string `json:"code,omitempty"`
|
Code string `json:"code,omitempty"`
|
||||||
Data any `json:"data,omitempty"`
|
ImageURL string `json:"image_url,omitempty"`
|
||||||
Success bool `json:"success,omitempty"`
|
MimeType string `json:"mime_type,omitempty"`
|
||||||
Error string `json:"error,omitempty"`
|
Data any `json:"data,omitempty"`
|
||||||
|
Success bool `json:"success,omitempty"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultGeminiTextTestPrompt = "hi"
|
||||||
|
defaultGeminiImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background."
|
||||||
|
)
|
||||||
|
|
||||||
// AccountTestService handles account testing operations
|
// AccountTestService handles account testing operations
|
||||||
type AccountTestService struct {
|
type AccountTestService struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
@@ -161,7 +168,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
|
|||||||
// TestAccountConnection tests an account's connection by sending a test request
|
// TestAccountConnection tests an account's connection by sending a test request
|
||||||
// All account types use full Claude Code client characteristics, only auth header differs
|
// All account types use full Claude Code client characteristics, only auth header differs
|
||||||
// modelID is optional - if empty, defaults to claude.DefaultTestModel
|
// modelID is optional - if empty, defaults to claude.DefaultTestModel
|
||||||
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string) error {
|
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string) error {
|
||||||
ctx := c.Request.Context()
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
// Get account
|
// Get account
|
||||||
@@ -176,11 +183,11 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
|||||||
}
|
}
|
||||||
|
|
||||||
if account.IsGemini() {
|
if account.IsGemini() {
|
||||||
return s.testGeminiAccountConnection(c, account, modelID)
|
return s.testGeminiAccountConnection(c, account, modelID, prompt)
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.Platform == PlatformAntigravity {
|
if account.Platform == PlatformAntigravity {
|
||||||
return s.routeAntigravityTest(c, account, modelID)
|
return s.routeAntigravityTest(c, account, modelID, prompt)
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.Platform == PlatformSora {
|
if account.Platform == PlatformSora {
|
||||||
@@ -435,7 +442,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
|||||||
}
|
}
|
||||||
|
|
||||||
// testGeminiAccountConnection tests a Gemini account's connection
|
// testGeminiAccountConnection tests a Gemini account's connection
|
||||||
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
|
||||||
ctx := c.Request.Context()
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
// Determine the model to use
|
// Determine the model to use
|
||||||
@@ -462,7 +469,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
|||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
|
|
||||||
// Create test payload (Gemini format)
|
// Create test payload (Gemini format)
|
||||||
payload := createGeminiTestPayload()
|
payload := createGeminiTestPayload(testModelID, prompt)
|
||||||
|
|
||||||
// Build request based on account type
|
// Build request based on account type
|
||||||
var req *http.Request
|
var req *http.Request
|
||||||
@@ -1198,10 +1205,10 @@ func truncateSoraErrorBody(body []byte, max int) string {
|
|||||||
|
|
||||||
// routeAntigravityTest 路由 Antigravity 账号的测试请求。
|
// routeAntigravityTest 路由 Antigravity 账号的测试请求。
|
||||||
// APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。
|
// APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。
|
||||||
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string) error {
|
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string, prompt string) error {
|
||||||
if account.Type == AccountTypeAPIKey {
|
if account.Type == AccountTypeAPIKey {
|
||||||
if strings.HasPrefix(modelID, "gemini-") {
|
if strings.HasPrefix(modelID, "gemini-") {
|
||||||
return s.testGeminiAccountConnection(c, account, modelID)
|
return s.testGeminiAccountConnection(c, account, modelID, prompt)
|
||||||
}
|
}
|
||||||
return s.testClaudeAccountConnection(c, account, modelID)
|
return s.testClaudeAccountConnection(c, account, modelID)
|
||||||
}
|
}
|
||||||
@@ -1349,14 +1356,46 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT
|
|||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createGeminiTestPayload creates a minimal test payload for Gemini API
|
// createGeminiTestPayload creates a minimal test payload for Gemini API.
|
||||||
func createGeminiTestPayload() []byte {
|
// Image models use the image-generation path so the frontend can preview the returned image.
|
||||||
|
func createGeminiTestPayload(modelID string, prompt string) []byte {
|
||||||
|
if isImageGenerationModel(modelID) {
|
||||||
|
imagePrompt := strings.TrimSpace(prompt)
|
||||||
|
if imagePrompt == "" {
|
||||||
|
imagePrompt = defaultGeminiImageTestPrompt
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := map[string]any{
|
||||||
|
"contents": []map[string]any{
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"parts": []map[string]any{
|
||||||
|
{"text": imagePrompt},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"generationConfig": map[string]any{
|
||||||
|
"responseModalities": []string{"TEXT", "IMAGE"},
|
||||||
|
"imageConfig": map[string]any{
|
||||||
|
"aspectRatio": "1:1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
bytes, _ := json.Marshal(payload)
|
||||||
|
return bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
textPrompt := strings.TrimSpace(prompt)
|
||||||
|
if textPrompt == "" {
|
||||||
|
textPrompt = defaultGeminiTextTestPrompt
|
||||||
|
}
|
||||||
|
|
||||||
payload := map[string]any{
|
payload := map[string]any{
|
||||||
"contents": []map[string]any{
|
"contents": []map[string]any{
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"parts": []map[string]any{
|
"parts": []map[string]any{
|
||||||
{"text": "hi"},
|
{"text": textPrompt},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -1416,6 +1455,17 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
|
|||||||
if text, ok := partMap["text"].(string); ok && text != "" {
|
if text, ok := partMap["text"].(string); ok && text != "" {
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: text})
|
s.sendEvent(c, TestEvent{Type: "content", Text: text})
|
||||||
}
|
}
|
||||||
|
if inlineData, ok := partMap["inlineData"].(map[string]any); ok {
|
||||||
|
mimeType, _ := inlineData["mimeType"].(string)
|
||||||
|
data, _ := inlineData["data"].(string)
|
||||||
|
if strings.HasPrefix(strings.ToLower(mimeType), "image/") && data != "" {
|
||||||
|
s.sendEvent(c, TestEvent{
|
||||||
|
Type: "image",
|
||||||
|
ImageURL: fmt.Sprintf("data:%s;base64,%s", mimeType, data),
|
||||||
|
MimeType: mimeType,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1602,7 +1652,7 @@ func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID in
|
|||||||
ginCtx, _ := gin.CreateTestContext(w)
|
ginCtx, _ := gin.CreateTestContext(w)
|
||||||
ginCtx.Request = (&http.Request{}).WithContext(ctx)
|
ginCtx.Request = (&http.Request{}).WithContext(ctx)
|
||||||
|
|
||||||
testErr := s.TestAccountConnection(ginCtx, accountID, modelID)
|
testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "")
|
||||||
|
|
||||||
finishedAt := time.Now()
|
finishedAt := time.Now()
|
||||||
body := w.Body.String()
|
body := w.Body.String()
|
||||||
|
|||||||
59
backend/internal/service/account_test_service_gemini_test.go
Normal file
59
backend/internal/service/account_test_service_gemini_test.go
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCreateGeminiTestPayload_ImageModel(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
payload := createGeminiTestPayload("gemini-2.5-flash-image", "draw a tiny robot")
|
||||||
|
|
||||||
|
var parsed struct {
|
||||||
|
Contents []struct {
|
||||||
|
Parts []struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"parts"`
|
||||||
|
} `json:"contents"`
|
||||||
|
GenerationConfig struct {
|
||||||
|
ResponseModalities []string `json:"responseModalities"`
|
||||||
|
ImageConfig struct {
|
||||||
|
AspectRatio string `json:"aspectRatio"`
|
||||||
|
} `json:"imageConfig"`
|
||||||
|
} `json:"generationConfig"`
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, json.Unmarshal(payload, &parsed))
|
||||||
|
require.Len(t, parsed.Contents, 1)
|
||||||
|
require.Len(t, parsed.Contents[0].Parts, 1)
|
||||||
|
require.Equal(t, "draw a tiny robot", parsed.Contents[0].Parts[0].Text)
|
||||||
|
require.Equal(t, []string{"TEXT", "IMAGE"}, parsed.GenerationConfig.ResponseModalities)
|
||||||
|
require.Equal(t, "1:1", parsed.GenerationConfig.ImageConfig.AspectRatio)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
ctx, recorder := newSoraTestContext()
|
||||||
|
svc := &AccountTestService{}
|
||||||
|
|
||||||
|
stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n")
|
||||||
|
|
||||||
|
err := svc.processGeminiStream(ctx, stream)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
body := recorder.Body.String()
|
||||||
|
require.Contains(t, body, "\"type\":\"content\"")
|
||||||
|
require.Contains(t, body, "\"text\":\"ok\"")
|
||||||
|
require.Contains(t, body, "\"type\":\"image\"")
|
||||||
|
require.Contains(t, body, "\"image_url\":\"data:image/png;base64,QUJD\"")
|
||||||
|
require.Contains(t, body, "\"mime_type\":\"image/png\"")
|
||||||
|
}
|
||||||
@@ -369,8 +369,11 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
|
|||||||
}
|
}
|
||||||
|
|
||||||
if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
|
if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
|
||||||
if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 {
|
if updates, resetAt, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && (len(updates) > 0 || resetAt != nil) {
|
||||||
mergeAccountExtra(account, updates)
|
mergeAccountExtra(account, updates)
|
||||||
|
if resetAt != nil {
|
||||||
|
account.RateLimitResetAt = resetAt
|
||||||
|
}
|
||||||
if usage.UpdatedAt == nil {
|
if usage.UpdatedAt == nil {
|
||||||
usage.UpdatedAt = &now
|
usage.UpdatedAt = &now
|
||||||
}
|
}
|
||||||
@@ -457,26 +460,26 @@ func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, no
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, error) {
|
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, *time.Time, error) {
|
||||||
if account == nil || !account.IsOAuth() {
|
if account == nil || !account.IsOAuth() {
|
||||||
return nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
accessToken := account.GetOpenAIAccessToken()
|
accessToken := account.GetOpenAIAccessToken()
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return nil, fmt.Errorf("no access token available")
|
return nil, nil, fmt.Errorf("no access token available")
|
||||||
}
|
}
|
||||||
modelID := openaipkg.DefaultTestModel
|
modelID := openaipkg.DefaultTestModel
|
||||||
payload := createOpenAITestPayload(modelID, true)
|
payload := createOpenAITestPayload(modelID, true)
|
||||||
payloadBytes, err := json.Marshal(payload)
|
payloadBytes, err := json.Marshal(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("marshal openai probe payload: %w", err)
|
return nil, nil, fmt.Errorf("marshal openai probe payload: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes))
|
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create openai probe request: %w", err)
|
return nil, nil, fmt.Errorf("create openai probe request: %w", err)
|
||||||
}
|
}
|
||||||
req.Host = "chatgpt.com"
|
req.Host = "chatgpt.com"
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -505,43 +508,67 @@ func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, acco
|
|||||||
ResponseHeaderTimeout: 10 * time.Second,
|
ResponseHeaderTimeout: 10 * time.Second,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("build openai probe client: %w", err)
|
return nil, nil, fmt.Errorf("build openai probe client: %w", err)
|
||||||
}
|
}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("openai codex probe request failed: %w", err)
|
return nil, nil, fmt.Errorf("openai codex probe request failed: %w", err)
|
||||||
}
|
}
|
||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
updates, err := extractOpenAICodexProbeUpdates(resp)
|
updates, resetAt, err := extractOpenAICodexProbeSnapshot(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
if len(updates) > 0 {
|
if len(updates) > 0 || resetAt != nil {
|
||||||
go func(accountID int64, updates map[string]any) {
|
s.persistOpenAICodexProbeSnapshot(account.ID, updates, resetAt)
|
||||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
return updates, resetAt, nil
|
||||||
defer updateCancel()
|
}
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountUsageService) persistOpenAICodexProbeSnapshot(accountID int64, updates map[string]any, resetAt *time.Time) {
|
||||||
|
if s == nil || s.accountRepo == nil || accountID <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(updates) == 0 && resetAt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer updateCancel()
|
||||||
|
if len(updates) > 0 {
|
||||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||||
}(account.ID, updates)
|
}
|
||||||
return updates, nil
|
if resetAt != nil {
|
||||||
|
_ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractOpenAICodexProbeSnapshot(resp *http.Response) (map[string]any, *time.Time, error) {
|
||||||
|
if resp == nil {
|
||||||
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
return nil, nil
|
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||||
|
baseTime := time.Now()
|
||||||
|
updates := buildCodexUsageExtraUpdates(snapshot, baseTime)
|
||||||
|
resetAt := codexRateLimitResetAtFromSnapshot(snapshot, baseTime)
|
||||||
|
if len(updates) > 0 {
|
||||||
|
return updates, resetAt, nil
|
||||||
|
}
|
||||||
|
return nil, resetAt, nil
|
||||||
|
}
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return nil, nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
|
func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
|
||||||
if resp == nil {
|
updates, _, err := extractOpenAICodexProbeSnapshot(resp)
|
||||||
return nil, nil
|
return updates, err
|
||||||
}
|
|
||||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
|
||||||
updates := buildCodexUsageExtraUpdates(snapshot, time.Now())
|
|
||||||
if len(updates) > 0 {
|
|
||||||
return updates, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
||||||
return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func mergeAccountExtra(account *Account, updates map[string]any) {
|
func mergeAccountExtra(account *Account, updates map[string]any) {
|
||||||
|
|||||||
@@ -1,11 +1,36 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type accountUsageCodexProbeRepo struct {
|
||||||
|
stubOpenAIAccountRepo
|
||||||
|
updateExtraCh chan map[string]any
|
||||||
|
rateLimitCh chan time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *accountUsageCodexProbeRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
|
||||||
|
if r.updateExtraCh != nil {
|
||||||
|
copied := make(map[string]any, len(updates))
|
||||||
|
for k, v := range updates {
|
||||||
|
copied[k] = v
|
||||||
|
}
|
||||||
|
r.updateExtraCh <- copied
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *accountUsageCodexProbeRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
|
||||||
|
if r.rateLimitCh != nil {
|
||||||
|
r.rateLimitCh <- resetAt
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) {
|
func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -66,3 +91,60 @@ func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T)
|
|||||||
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
|
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExtractOpenAICodexProbeSnapshotAccepts429WithResetAt(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
headers := make(http.Header)
|
||||||
|
headers.Set("x-codex-primary-used-percent", "100")
|
||||||
|
headers.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||||
|
headers.Set("x-codex-primary-window-minutes", "10080")
|
||||||
|
headers.Set("x-codex-secondary-used-percent", "100")
|
||||||
|
headers.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||||
|
headers.Set("x-codex-secondary-window-minutes", "300")
|
||||||
|
|
||||||
|
updates, resetAt, err := extractOpenAICodexProbeSnapshot(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("extractOpenAICodexProbeSnapshot() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(updates) == 0 {
|
||||||
|
t.Fatal("expected codex probe updates from 429 headers")
|
||||||
|
}
|
||||||
|
if resetAt == nil {
|
||||||
|
t.Fatal("expected resetAt from exhausted codex headers")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
repo := &accountUsageCodexProbeRepo{
|
||||||
|
updateExtraCh: make(chan map[string]any, 1),
|
||||||
|
rateLimitCh: make(chan time.Time, 1),
|
||||||
|
}
|
||||||
|
svc := &AccountUsageService{accountRepo: repo}
|
||||||
|
resetAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second)
|
||||||
|
|
||||||
|
svc.persistOpenAICodexProbeSnapshot(321, map[string]any{
|
||||||
|
"codex_7d_used_percent": 100.0,
|
||||||
|
"codex_7d_reset_at": resetAt.Format(time.RFC3339),
|
||||||
|
}, &resetAt)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case updates := <-repo.updateExtraCh:
|
||||||
|
if got := updates["codex_7d_used_percent"]; got != 100.0 {
|
||||||
|
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("waiting for codex probe extra persistence timed out")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-repo.rateLimitCh:
|
||||||
|
if got.Before(resetAt.Add(-time.Second)) || got.After(resetAt.Add(time.Second)) {
|
||||||
|
t.Fatalf("rate limit resetAt = %v, want around %v", got, resetAt)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("waiting for codex probe rate limit persistence timed out")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ type AdminService interface {
|
|||||||
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
|
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
|
||||||
DeleteGroup(ctx context.Context, id int64) error
|
DeleteGroup(ctx context.Context, id int64) error
|
||||||
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
|
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
|
||||||
|
GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
|
||||||
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
|
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
|
||||||
|
|
||||||
// API Key management (admin)
|
// API Key management (admin)
|
||||||
@@ -138,9 +139,10 @@ type CreateGroupInput struct {
|
|||||||
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||||
FallbackGroupIDOnInvalidRequest *int64
|
FallbackGroupIDOnInvalidRequest *int64
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
ModelRouting map[string][]int64
|
ModelRouting map[string][]int64
|
||||||
ModelRoutingEnabled bool // 是否启用模型路由
|
ModelRoutingEnabled bool // 是否启用模型路由
|
||||||
MCPXMLInject *bool
|
MCPXMLInject *bool
|
||||||
|
SimulateClaudeMaxEnabled *bool
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
SupportedModelScopes []string
|
SupportedModelScopes []string
|
||||||
// Sora 存储配额
|
// Sora 存储配额
|
||||||
@@ -177,9 +179,10 @@ type UpdateGroupInput struct {
|
|||||||
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||||
FallbackGroupIDOnInvalidRequest *int64
|
FallbackGroupIDOnInvalidRequest *int64
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
ModelRouting map[string][]int64
|
ModelRouting map[string][]int64
|
||||||
ModelRoutingEnabled *bool // 是否启用模型路由
|
ModelRoutingEnabled *bool // 是否启用模型路由
|
||||||
MCPXMLInject *bool
|
MCPXMLInject *bool
|
||||||
|
SimulateClaudeMaxEnabled *bool
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
SupportedModelScopes *[]string
|
SupportedModelScopes *[]string
|
||||||
// Sora 存储配额
|
// Sora 存储配额
|
||||||
@@ -363,6 +366,10 @@ type ProxyExitInfoProber interface {
|
|||||||
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
|
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type groupExistenceBatchReader interface {
|
||||||
|
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
type proxyQualityTarget struct {
|
type proxyQualityTarget struct {
|
||||||
Target string
|
Target string
|
||||||
URL string
|
URL string
|
||||||
@@ -439,10 +446,6 @@ type userGroupRateBatchReader interface {
|
|||||||
GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
|
GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type groupExistenceBatchReader interface {
|
|
||||||
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewAdminService creates a new AdminService
|
// NewAdminService creates a new AdminService
|
||||||
func NewAdminService(
|
func NewAdminService(
|
||||||
userRepo UserRepository,
|
userRepo UserRepository,
|
||||||
@@ -860,6 +863,13 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
|||||||
if input.MCPXMLInject != nil {
|
if input.MCPXMLInject != nil {
|
||||||
mcpXMLInject = *input.MCPXMLInject
|
mcpXMLInject = *input.MCPXMLInject
|
||||||
}
|
}
|
||||||
|
simulateClaudeMaxEnabled := false
|
||||||
|
if input.SimulateClaudeMaxEnabled != nil {
|
||||||
|
if platform != PlatformAnthropic && *input.SimulateClaudeMaxEnabled {
|
||||||
|
return nil, fmt.Errorf("simulate_claude_max_enabled only supported for anthropic groups")
|
||||||
|
}
|
||||||
|
simulateClaudeMaxEnabled = *input.SimulateClaudeMaxEnabled
|
||||||
|
}
|
||||||
|
|
||||||
// 如果指定了复制账号的源分组,先获取账号 ID 列表
|
// 如果指定了复制账号的源分组,先获取账号 ID 列表
|
||||||
var accountIDsToCopy []int64
|
var accountIDsToCopy []int64
|
||||||
@@ -916,6 +926,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
|||||||
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
|
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
|
||||||
ModelRouting: input.ModelRouting,
|
ModelRouting: input.ModelRouting,
|
||||||
MCPXMLInject: mcpXMLInject,
|
MCPXMLInject: mcpXMLInject,
|
||||||
|
SimulateClaudeMaxEnabled: simulateClaudeMaxEnabled,
|
||||||
SupportedModelScopes: input.SupportedModelScopes,
|
SupportedModelScopes: input.SupportedModelScopes,
|
||||||
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
|
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
|
||||||
AllowMessagesDispatch: input.AllowMessagesDispatch,
|
AllowMessagesDispatch: input.AllowMessagesDispatch,
|
||||||
@@ -1127,6 +1138,15 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
|||||||
if input.MCPXMLInject != nil {
|
if input.MCPXMLInject != nil {
|
||||||
group.MCPXMLInject = *input.MCPXMLInject
|
group.MCPXMLInject = *input.MCPXMLInject
|
||||||
}
|
}
|
||||||
|
if input.SimulateClaudeMaxEnabled != nil {
|
||||||
|
if group.Platform != PlatformAnthropic && *input.SimulateClaudeMaxEnabled {
|
||||||
|
return nil, fmt.Errorf("simulate_claude_max_enabled only supported for anthropic groups")
|
||||||
|
}
|
||||||
|
group.SimulateClaudeMaxEnabled = *input.SimulateClaudeMaxEnabled
|
||||||
|
}
|
||||||
|
if group.Platform != PlatformAnthropic {
|
||||||
|
group.SimulateClaudeMaxEnabled = false
|
||||||
|
}
|
||||||
|
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
if input.SupportedModelScopes != nil {
|
if input.SupportedModelScopes != nil {
|
||||||
@@ -1244,6 +1264,13 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
|||||||
return keys, result.Total, nil
|
return keys, result.Total, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *adminServiceImpl) GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) {
|
||||||
|
if s.userGroupRateRepo == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return s.userGroupRateRepo.GetByGroupID(ctx, groupID)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
|
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
|
||||||
return s.groupRepo.UpdateSortOrders(ctx, updates)
|
return s.groupRepo.UpdateSortOrders(ctx, updates)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,6 +43,16 @@ func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID i
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
|
||||||
|
if err, ok := s.listByGroupErr[groupID]; ok {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if rows, ok := s.listByGroupData[groupID]; ok {
|
||||||
|
return rows, nil
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) {
|
func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) {
|
||||||
s.getByIDsCalled = true
|
s.getByIDsCalled = true
|
||||||
s.getByIDsIDs = append([]int64{}, ids...)
|
s.getByIDsIDs = append([]int64{}, ids...)
|
||||||
@@ -63,16 +73,6 @@ func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Ac
|
|||||||
return nil, errors.New("account not found")
|
return nil, errors.New("account not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
|
|
||||||
if err, ok := s.listByGroupErr[groupID]; ok {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if rows, ok := s.listByGroupData[groupID]; ok {
|
|
||||||
return rows, nil
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
|
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
|
||||||
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
|
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
|
||||||
repo := &accountRepoStubForBulkUpdate{}
|
repo := &accountRepoStubForBulkUpdate{}
|
||||||
|
|||||||
@@ -785,3 +785,57 @@ func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes
|
|||||||
require.NotNil(t, repo.updated)
|
require.NotNil(t, repo.updated)
|
||||||
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
|
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAdminService_CreateGroup_SimulateClaudeMaxRequiresAnthropic(t *testing.T) {
|
||||||
|
repo := &groupRepoStubForAdmin{}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
enabled := true
|
||||||
|
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||||
|
Name: "openai-group",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
SimulateClaudeMaxEnabled: &enabled,
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "simulate_claude_max_enabled only supported for anthropic groups")
|
||||||
|
require.Nil(t, repo.created)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_UpdateGroup_SimulateClaudeMaxRequiresAnthropic(t *testing.T) {
|
||||||
|
existingGroup := &Group{
|
||||||
|
ID: 1,
|
||||||
|
Name: "openai-group",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Status: StatusActive,
|
||||||
|
}
|
||||||
|
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
enabled := true
|
||||||
|
_, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
|
||||||
|
SimulateClaudeMaxEnabled: &enabled,
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "simulate_claude_max_enabled only supported for anthropic groups")
|
||||||
|
require.Nil(t, repo.updated)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_UpdateGroup_ClearsSimulateClaudeMaxWhenPlatformChanges(t *testing.T) {
|
||||||
|
existingGroup := &Group{
|
||||||
|
ID: 1,
|
||||||
|
Name: "anthropic-group",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Status: StatusActive,
|
||||||
|
SimulateClaudeMaxEnabled: true,
|
||||||
|
}
|
||||||
|
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, group)
|
||||||
|
require.NotNil(t, repo.updated)
|
||||||
|
require.False(t, repo.updated.SimulateClaudeMaxEnabled)
|
||||||
|
}
|
||||||
|
|||||||
@@ -68,6 +68,10 @@ func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context
|
|||||||
panic("unexpected SyncUserGroupRates call")
|
panic("unexpected SyncUserGroupRates call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForListUsers) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
|
||||||
|
panic("unexpected GetByGroupID call")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, groupID int64) error {
|
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, groupID int64) error {
|
||||||
panic("unexpected DeleteByGroupID call")
|
panic("unexpected DeleteByGroupID call")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1673,7 +1673,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
var clientDisconnect bool
|
var clientDisconnect bool
|
||||||
if claudeReq.Stream {
|
if claudeReq.Stream {
|
||||||
// 客户端要求流式,直接透传转换
|
// 客户端要求流式,直接透传转换
|
||||||
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
|
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel, account.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err)
|
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -1683,7 +1683,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
clientDisconnect = streamRes.clientDisconnect
|
clientDisconnect = streamRes.clientDisconnect
|
||||||
} else {
|
} else {
|
||||||
// 客户端要求非流式,收集流式响应后转换返回
|
// 客户端要求非流式,收集流式响应后转换返回
|
||||||
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
|
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel, account.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err)
|
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -1692,6 +1692,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
firstTokenMs = streamRes.firstTokenMs
|
firstTokenMs = streamRes.firstTokenMs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Claude Max cache billing: 同步 ForwardResult.Usage 与客户端响应体一致
|
||||||
|
applyClaudeMaxCacheBillingPolicyToUsage(usage, parsedRequestFromGinContext(c), claudeMaxGroupFromGinContext(c), originalModel, account.ID)
|
||||||
|
|
||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Usage: *usage,
|
Usage: *usage,
|
||||||
@@ -2164,6 +2167,112 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Gemini 原生请求中的 thoughtSignature 可能来自旧上下文/旧账号,触发上游严格校验后返回
|
||||||
|
// "Corrupted thought signature."。检测到此类 400 时,将 thoughtSignature 清理为 dummy 值后重试一次。
|
||||||
|
signatureCheckBody := respBody
|
||||||
|
if unwrapped, unwrapErr := s.unwrapV1InternalResponse(respBody); unwrapErr == nil && len(unwrapped) > 0 {
|
||||||
|
signatureCheckBody = unwrapped
|
||||||
|
}
|
||||||
|
if resp.StatusCode == http.StatusBadRequest &&
|
||||||
|
s.settingService != nil &&
|
||||||
|
s.settingService.IsSignatureRectifierEnabled(ctx) &&
|
||||||
|
isSignatureRelatedError(signatureCheckBody) &&
|
||||||
|
bytes.Contains(injectedBody, []byte(`"thoughtSignature"`)) {
|
||||||
|
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(signatureCheckBody)))
|
||||||
|
upstreamDetail := s.getUpstreamErrorDetail(signatureCheckBody)
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "signature_error",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: detected signature-related 400, retrying with cleaned thought signatures", account.ID)
|
||||||
|
|
||||||
|
cleanedInjectedBody := CleanGeminiNativeThoughtSignatures(injectedBody)
|
||||||
|
retryWrappedBody, wrapErr := s.wrapV1InternalRequest(projectID, mappedModel, cleanedInjectedBody)
|
||||||
|
if wrapErr == nil {
|
||||||
|
retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||||
|
ctx: ctx,
|
||||||
|
prefix: prefix,
|
||||||
|
account: account,
|
||||||
|
proxyURL: proxyURL,
|
||||||
|
accessToken: accessToken,
|
||||||
|
action: upstreamAction,
|
||||||
|
body: retryWrappedBody,
|
||||||
|
c: c,
|
||||||
|
httpUpstream: s.httpUpstream,
|
||||||
|
settingService: s.settingService,
|
||||||
|
accountRepo: s.accountRepo,
|
||||||
|
handleError: s.handleUpstreamError,
|
||||||
|
requestedModel: originalModel,
|
||||||
|
isStickySession: isStickySession,
|
||||||
|
groupID: 0,
|
||||||
|
sessionHash: "",
|
||||||
|
})
|
||||||
|
if retryErr == nil {
|
||||||
|
retryResp := retryResult.resp
|
||||||
|
if retryResp.StatusCode < 400 {
|
||||||
|
resp = retryResp
|
||||||
|
} else {
|
||||||
|
retryRespBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||||||
|
_ = retryResp.Body.Close()
|
||||||
|
retryOpsBody := retryRespBody
|
||||||
|
if retryUnwrapped, unwrapErr := s.unwrapV1InternalResponse(retryRespBody); unwrapErr == nil && len(retryUnwrapped) > 0 {
|
||||||
|
retryOpsBody = retryUnwrapped
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: retryResp.StatusCode,
|
||||||
|
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
|
||||||
|
Kind: "signature_retry",
|
||||||
|
Message: sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(retryOpsBody))),
|
||||||
|
Detail: s.getUpstreamErrorDetail(retryOpsBody),
|
||||||
|
})
|
||||||
|
respBody = retryRespBody
|
||||||
|
resp = &http.Response{
|
||||||
|
StatusCode: retryResp.StatusCode,
|
||||||
|
Header: retryResp.Header.Clone(),
|
||||||
|
Body: io.NopCloser(bytes.NewReader(retryRespBody)),
|
||||||
|
}
|
||||||
|
contentType = resp.Header.Get("Content-Type")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if switchErr, ok := IsAntigravityAccountSwitchError(retryErr); ok {
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: http.StatusServiceUnavailable,
|
||||||
|
Kind: "failover",
|
||||||
|
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
|
||||||
|
})
|
||||||
|
return nil, &UpstreamFailoverError{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
ForceCacheBilling: switchErr.IsStickySession,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: 0,
|
||||||
|
Kind: "signature_retry_request_error",
|
||||||
|
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
|
||||||
|
})
|
||||||
|
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: signature retry request failed: %v", account.ID, retryErr)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: signature retry wrap failed: %v", account.ID, wrapErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// fallback 成功:继续按正常响应处理
|
// fallback 成功:继续按正常响应处理
|
||||||
if resp.StatusCode < 400 {
|
if resp.StatusCode < 400 {
|
||||||
goto handleSuccess
|
goto handleSuccess
|
||||||
@@ -3489,7 +3598,7 @@ func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int,
|
|||||||
|
|
||||||
// handleClaudeStreamToNonStreaming 收集上游流式响应,转换为 Claude 非流式格式返回
|
// handleClaudeStreamToNonStreaming 收集上游流式响应,转换为 Claude 非流式格式返回
|
||||||
// 用于处理客户端非流式请求但上游只支持流式的情况
|
// 用于处理客户端非流式请求但上游只支持流式的情况
|
||||||
func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
|
func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string, accountID int64) (*antigravityStreamResult, error) {
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
maxLineSize := defaultMaxLineSize
|
maxLineSize := defaultMaxLineSize
|
||||||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||||||
@@ -3647,6 +3756,9 @@ returnResponse:
|
|||||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
|
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Claude Max cache billing simulation (non-streaming)
|
||||||
|
claudeResp = applyClaudeMaxNonStreamingRewrite(c, claudeResp, agUsage, originalModel, accountID)
|
||||||
|
|
||||||
c.Data(http.StatusOK, "application/json", claudeResp)
|
c.Data(http.StatusOK, "application/json", claudeResp)
|
||||||
|
|
||||||
// 转换为 service.ClaudeUsage
|
// 转换为 service.ClaudeUsage
|
||||||
@@ -3661,7 +3773,7 @@ returnResponse:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleClaudeStreamingResponse 处理 Claude 流式响应(Gemini SSE → Claude SSE 转换)
|
// handleClaudeStreamingResponse 处理 Claude 流式响应(Gemini SSE → Claude SSE 转换)
|
||||||
func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
|
func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string, accountID int64) (*antigravityStreamResult, error) {
|
||||||
c.Header("Content-Type", "text/event-stream")
|
c.Header("Content-Type", "text/event-stream")
|
||||||
c.Header("Cache-Control", "no-cache")
|
c.Header("Cache-Control", "no-cache")
|
||||||
c.Header("Connection", "keep-alive")
|
c.Header("Connection", "keep-alive")
|
||||||
@@ -3674,6 +3786,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
|||||||
}
|
}
|
||||||
|
|
||||||
processor := antigravity.NewStreamingProcessor(originalModel)
|
processor := antigravity.NewStreamingProcessor(originalModel)
|
||||||
|
setupClaudeMaxStreamingHook(c, processor, originalModel, accountID)
|
||||||
|
|
||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
|
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
|||||||
@@ -134,6 +134,47 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int,
|
|||||||
return s.resp, s.err
|
return s.resp, s.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type queuedHTTPUpstreamStub struct {
|
||||||
|
responses []*http.Response
|
||||||
|
errors []error
|
||||||
|
requestBodies [][]byte
|
||||||
|
callCount int
|
||||||
|
onCall func(*http.Request, *queuedHTTPUpstreamStub)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *queuedHTTPUpstreamStub) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
||||||
|
if req != nil && req.Body != nil {
|
||||||
|
body, _ := io.ReadAll(req.Body)
|
||||||
|
s.requestBodies = append(s.requestBodies, body)
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(body))
|
||||||
|
} else {
|
||||||
|
s.requestBodies = append(s.requestBodies, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
idx := s.callCount
|
||||||
|
s.callCount++
|
||||||
|
if s.onCall != nil {
|
||||||
|
s.onCall(req, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp *http.Response
|
||||||
|
if idx < len(s.responses) {
|
||||||
|
resp = s.responses[idx]
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
if idx < len(s.errors) {
|
||||||
|
err = s.errors[idx]
|
||||||
|
}
|
||||||
|
if resp == nil && err == nil {
|
||||||
|
return nil, errors.New("unexpected upstream call")
|
||||||
|
}
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *queuedHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, concurrency int, _ bool) (*http.Response, error) {
|
||||||
|
return s.Do(req, proxyURL, accountID, concurrency)
|
||||||
|
}
|
||||||
|
|
||||||
type antigravitySettingRepoStub struct{}
|
type antigravitySettingRepoStub struct{}
|
||||||
|
|
||||||
func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||||
@@ -556,6 +597,177 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing
|
|||||||
require.Equal(t, mappedModel, result.Model)
|
require.Equal(t, mappedModel, result.Model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
writer := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(writer)
|
||||||
|
|
||||||
|
body, err := json.Marshal(map[string]any{
|
||||||
|
"contents": []map[string]any{
|
||||||
|
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
|
||||||
|
{"role": "model", "parts": []map[string]any{{"text": "thinking", "thought": true, "thoughtSignature": "sig_bad_1"}}},
|
||||||
|
{"role": "model", "parts": []map[string]any{{"functionCall": map[string]any{"name": "toolA", "args": map[string]any{"x": 1}}, "thoughtSignature": "sig_bad_2"}}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/antigravity/v1beta/models/gemini-3.1-pro-preview:streamGenerateContent", bytes.NewReader(body))
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
firstRespBody := []byte(`{"response":{"error":{"code":400,"message":"Corrupted thought signature.","status":"INVALID_ARGUMENT"}}}`)
|
||||||
|
secondRespBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
|
||||||
|
|
||||||
|
upstream := &queuedHTTPUpstreamStub{
|
||||||
|
responses: []*http.Response{
|
||||||
|
{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"application/json"},
|
||||||
|
"X-Request-Id": []string{"req-sig-1"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(bytes.NewReader(firstRespBody)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"text/event-stream"},
|
||||||
|
"X-Request-Id": []string{"req-sig-2"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(bytes.NewReader(secondRespBody)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{
|
||||||
|
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||||
|
tokenProvider: &AntigravityTokenProvider{},
|
||||||
|
httpUpstream: upstream,
|
||||||
|
}
|
||||||
|
|
||||||
|
const originalModel = "gemini-3.1-pro-preview"
|
||||||
|
const mappedModel = "gemini-3.1-pro-high"
|
||||||
|
account := &Account{
|
||||||
|
ID: 7,
|
||||||
|
Name: "acc-gemini-signature",
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "token",
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
originalModel: mappedModel,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, mappedModel, result.Model)
|
||||||
|
require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry")
|
||||||
|
|
||||||
|
firstReq := string(upstream.requestBodies[0])
|
||||||
|
secondReq := string(upstream.requestBodies[1])
|
||||||
|
require.Contains(t, firstReq, `"thoughtSignature":"sig_bad_1"`)
|
||||||
|
require.Contains(t, firstReq, `"thoughtSignature":"sig_bad_2"`)
|
||||||
|
require.Contains(t, secondReq, `"thoughtSignature":"skip_thought_signature_validator"`)
|
||||||
|
require.NotContains(t, secondReq, `"thoughtSignature":"sig_bad_1"`)
|
||||||
|
require.NotContains(t, secondReq, `"thoughtSignature":"sig_bad_2"`)
|
||||||
|
|
||||||
|
raw, ok := c.Get(OpsUpstreamErrorsKey)
|
||||||
|
require.True(t, ok)
|
||||||
|
events, ok := raw.([]*OpsUpstreamErrorEvent)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.NotEmpty(t, events)
|
||||||
|
require.Equal(t, "signature_error", events[0].Kind)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAntigravityGatewayService_ForwardGemini_SignatureRetryPropagatesFailover(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
writer := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(writer)
|
||||||
|
|
||||||
|
body, err := json.Marshal(map[string]any{
|
||||||
|
"contents": []map[string]any{
|
||||||
|
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
|
||||||
|
{"role": "model", "parts": []map[string]any{{"text": "thinking", "thought": true, "thoughtSignature": "sig_bad_1"}}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/antigravity/v1beta/models/gemini-3.1-pro-preview:streamGenerateContent", bytes.NewReader(body))
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
firstRespBody := []byte(`{"response":{"error":{"code":400,"message":"Corrupted thought signature.","status":"INVALID_ARGUMENT"}}}`)
|
||||||
|
|
||||||
|
const originalModel = "gemini-3.1-pro-preview"
|
||||||
|
const mappedModel = "gemini-3.1-pro-high"
|
||||||
|
account := &Account{
|
||||||
|
ID: 8,
|
||||||
|
Name: "acc-gemini-signature-failover",
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "token",
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
originalModel: mappedModel,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
upstream := &queuedHTTPUpstreamStub{
|
||||||
|
responses: []*http.Response{
|
||||||
|
{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"application/json"},
|
||||||
|
"X-Request-Id": []string{"req-sig-failover-1"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(bytes.NewReader(firstRespBody)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
onCall: func(_ *http.Request, stub *queuedHTTPUpstreamStub) {
|
||||||
|
if stub.callCount != 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
|
||||||
|
account.Extra = map[string]any{
|
||||||
|
modelRateLimitsKey: map[string]any{
|
||||||
|
mappedModel: map[string]any{
|
||||||
|
"rate_limit_reset_at": futureResetAt,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{
|
||||||
|
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||||
|
tokenProvider: &AntigravityTokenProvider{},
|
||||||
|
httpUpstream: upstream,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, true)
|
||||||
|
require.Nil(t, result)
|
||||||
|
|
||||||
|
var failoverErr *UpstreamFailoverError
|
||||||
|
require.ErrorAs(t, err, &failoverErr, "signature retry should propagate failover instead of falling back to the original 400")
|
||||||
|
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
|
||||||
|
require.True(t, failoverErr.ForceCacheBilling)
|
||||||
|
require.Len(t, upstream.requestBodies, 1, "retry should stop at preflight failover and not issue a second upstream request")
|
||||||
|
|
||||||
|
raw, ok := c.Get(OpsUpstreamErrorsKey)
|
||||||
|
require.True(t, ok)
|
||||||
|
events, ok := raw.([]*OpsUpstreamErrorEvent)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, events, 2)
|
||||||
|
require.Equal(t, "signature_error", events[0].Kind)
|
||||||
|
require.Equal(t, "failover", events[1].Kind)
|
||||||
|
}
|
||||||
|
|
||||||
// TestStreamUpstreamResponse_UsageAndFirstToken
|
// TestStreamUpstreamResponse_UsageAndFirstToken
|
||||||
// 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间
|
// 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间
|
||||||
func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
|
func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
|
||||||
@@ -710,7 +922,7 @@ func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) {
|
|||||||
fmt.Fprintln(pw, "")
|
fmt.Fprintln(pw, "")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
|
||||||
_ = pr.Close()
|
_ = pr.Close()
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -787,7 +999,7 @@ func TestHandleClaudeStreamingResponse_ThoughtsTokenCount(t *testing.T) {
|
|||||||
fmt.Fprintln(pw, "")
|
fmt.Fprintln(pw, "")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro")
|
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro", 0)
|
||||||
_ = pr.Close()
|
_ = pr.Close()
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -990,7 +1202,7 @@ func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) {
|
|||||||
fmt.Fprintln(pw, "")
|
fmt.Fprintln(pw, "")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
|
||||||
_ = pr.Close()
|
_ = pr.Close()
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1022,7 +1234,7 @@ func TestHandleClaudeStreamingResponse_EmptyStream(t *testing.T) {
|
|||||||
fmt.Fprintln(pw, "")
|
fmt.Fprintln(pw, "")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
_, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
|
||||||
_ = pr.Close()
|
_ = pr.Close()
|
||||||
|
|
||||||
// 应当返回 UpstreamFailoverError 而非 nil,以便上层触发 failover
|
// 应当返回 UpstreamFailoverError 而非 nil,以便上层触发 failover
|
||||||
@@ -1054,7 +1266,7 @@ func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) {
|
|||||||
|
|
||||||
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
|
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
|
||||||
|
|
||||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
|
|||||||
@@ -59,9 +59,10 @@ type APIKeyAuthGroupSnapshot struct {
|
|||||||
|
|
||||||
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
|
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
|
||||||
// Only anthropic groups use these fields; others may leave them empty.
|
// Only anthropic groups use these fields; others may leave them empty.
|
||||||
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
||||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||||
MCPXMLInject bool `json:"mcp_xml_inject"`
|
MCPXMLInject bool `json:"mcp_xml_inject"`
|
||||||
|
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled"`
|
||||||
|
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
||||||
|
|||||||
@@ -244,6 +244,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
|||||||
ModelRouting: apiKey.Group.ModelRouting,
|
ModelRouting: apiKey.Group.ModelRouting,
|
||||||
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
||||||
MCPXMLInject: apiKey.Group.MCPXMLInject,
|
MCPXMLInject: apiKey.Group.MCPXMLInject,
|
||||||
|
SimulateClaudeMaxEnabled: apiKey.Group.SimulateClaudeMaxEnabled,
|
||||||
SupportedModelScopes: apiKey.Group.SupportedModelScopes,
|
SupportedModelScopes: apiKey.Group.SupportedModelScopes,
|
||||||
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
|
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
|
||||||
DefaultMappedModel: apiKey.Group.DefaultMappedModel,
|
DefaultMappedModel: apiKey.Group.DefaultMappedModel,
|
||||||
@@ -303,6 +304,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
|||||||
ModelRouting: snapshot.Group.ModelRouting,
|
ModelRouting: snapshot.Group.ModelRouting,
|
||||||
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
||||||
MCPXMLInject: snapshot.Group.MCPXMLInject,
|
MCPXMLInject: snapshot.Group.MCPXMLInject,
|
||||||
|
SimulateClaudeMaxEnabled: snapshot.Group.SimulateClaudeMaxEnabled,
|
||||||
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
|
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
|
||||||
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
|
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
|
||||||
DefaultMappedModel: snapshot.Group.DefaultMappedModel,
|
DefaultMappedModel: snapshot.Group.DefaultMappedModel,
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -110,6 +111,15 @@ func (d *APIKeyRateLimitData) EffectiveUsage7d() float64 {
|
|||||||
return d.Usage7d
|
return d.Usage7d
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// APIKeyQuotaUsageState captures the latest quota fields after an atomic quota update.
|
||||||
|
// It is intentionally small so repositories can return it from a single SQL statement.
|
||||||
|
type APIKeyQuotaUsageState struct {
|
||||||
|
QuotaUsed float64
|
||||||
|
Quota float64
|
||||||
|
Key string
|
||||||
|
Status string
|
||||||
|
}
|
||||||
|
|
||||||
// APIKeyCache defines cache operations for API key service
|
// APIKeyCache defines cache operations for API key service
|
||||||
type APIKeyCache interface {
|
type APIKeyCache interface {
|
||||||
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
|
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||||
@@ -817,6 +827,21 @@ func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cos
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type quotaStateReader interface {
|
||||||
|
IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if repo, ok := s.apiKeyRepo.(quotaStateReader); ok {
|
||||||
|
state, err := repo.IncrementQuotaUsedAndGetState(ctx, apiKeyID, cost)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("increment quota used: %w", err)
|
||||||
|
}
|
||||||
|
if state != nil && state.Status == StatusAPIKeyQuotaExhausted && strings.TrimSpace(state.Key) != "" {
|
||||||
|
s.InvalidateAuthCacheByKey(ctx, state.Key)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Use repository to atomically increment quota_used
|
// Use repository to atomically increment quota_used
|
||||||
newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost)
|
newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
170
backend/internal/service/api_key_service_quota_test.go
Normal file
170
backend/internal/service/api_key_service_quota_test.go
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type quotaStateRepoStub struct {
|
||||||
|
quotaBaseAPIKeyRepoStub
|
||||||
|
stateCalls int
|
||||||
|
state *APIKeyQuotaUsageState
|
||||||
|
stateErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateRepoStub) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error) {
|
||||||
|
s.stateCalls++
|
||||||
|
if s.stateErr != nil {
|
||||||
|
return nil, s.stateErr
|
||||||
|
}
|
||||||
|
if s.state == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
out := *s.state
|
||||||
|
return &out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type quotaStateCacheStub struct {
|
||||||
|
deleteAuthKeys []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) GetCreateAttemptCount(context.Context, int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) IncrementCreateAttemptCount(context.Context, int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) DeleteCreateAttemptCount(context.Context, int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) IncrementDailyUsage(context.Context, string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) SetDailyUsageExpiry(context.Context, string, time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) GetAuthCache(context.Context, string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) SetAuthCache(context.Context, string, *APIKeyAuthCacheEntry, time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) DeleteAuthCache(_ context.Context, key string) error {
|
||||||
|
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) PublishAuthCacheInvalidation(context.Context, string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) SubscribeAuthCacheInvalidation(context.Context, func(string)) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type quotaBaseAPIKeyRepoStub struct {
|
||||||
|
getByIDCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) Create(context.Context, *APIKey) error {
|
||||||
|
panic("unexpected Create call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) GetByID(context.Context, int64) (*APIKey, error) {
|
||||||
|
s.getByIDCalls++
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) {
|
||||||
|
panic("unexpected GetKeyAndOwnerID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) GetByKey(context.Context, string) (*APIKey, error) {
|
||||||
|
panic("unexpected GetByKey call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) GetByKeyForAuth(context.Context, string) (*APIKey, error) {
|
||||||
|
panic("unexpected GetByKeyForAuth call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) Update(context.Context, *APIKey) error {
|
||||||
|
panic("unexpected Update call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) Delete(context.Context, int64) error {
|
||||||
|
panic("unexpected Delete call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
|
||||||
|
panic("unexpected ListByUserID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
|
||||||
|
panic("unexpected VerifyOwnership call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) CountByUserID(context.Context, int64) (int64, error) {
|
||||||
|
panic("unexpected CountByUserID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) ExistsByKey(context.Context, string) (bool, error) {
|
||||||
|
panic("unexpected ExistsByKey call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||||
|
panic("unexpected ListByGroupID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) {
|
||||||
|
panic("unexpected SearchAPIKeys call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
|
||||||
|
panic("unexpected ClearGroupIDByGroupID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) CountByGroupID(context.Context, int64) (int64, error) {
|
||||||
|
panic("unexpected CountByGroupID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) ListKeysByUserID(context.Context, int64) ([]string, error) {
|
||||||
|
panic("unexpected ListKeysByUserID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) ListKeysByGroupID(context.Context, int64) ([]string, error) {
|
||||||
|
panic("unexpected ListKeysByGroupID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) {
|
||||||
|
panic("unexpected IncrementQuotaUsed call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) UpdateLastUsed(context.Context, int64, time.Time) error {
|
||||||
|
panic("unexpected UpdateLastUsed call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) IncrementRateLimitUsage(context.Context, int64, float64) error {
|
||||||
|
panic("unexpected IncrementRateLimitUsage call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) ResetRateLimitWindows(context.Context, int64) error {
|
||||||
|
panic("unexpected ResetRateLimitWindows call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) {
|
||||||
|
panic("unexpected GetRateLimitData call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_UpdateQuotaUsed_UsesAtomicStatePath(t *testing.T) {
|
||||||
|
repo := "aStateRepoStub{
|
||||||
|
state: &APIKeyQuotaUsageState{
|
||||||
|
QuotaUsed: 12,
|
||||||
|
Quota: 10,
|
||||||
|
Key: "sk-test-quota",
|
||||||
|
Status: StatusAPIKeyQuotaExhausted,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache := "aStateCacheStub{}
|
||||||
|
svc := &APIKeyService{
|
||||||
|
apiKeyRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.UpdateQuotaUsed(context.Background(), 101, 2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, repo.stateCalls)
|
||||||
|
require.Equal(t, 0, repo.getByIDCalls, "fast path should not re-read API key by id")
|
||||||
|
require.Equal(t, []string{svc.authCacheKey("sk-test-quota")}, cache.deleteAuthKeys)
|
||||||
|
}
|
||||||
450
backend/internal/service/claude_max_cache_billing_policy.go
Normal file
450
backend/internal/service/claude_max_cache_billing_policy.go
Normal file
@@ -0,0 +1,450 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
type claudeMaxCacheBillingOutcome struct {
|
||||||
|
Simulated bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyClaudeMaxCacheBillingPolicyToUsage(usage *ClaudeUsage, parsed *ParsedRequest, group *Group, model string, accountID int64) claudeMaxCacheBillingOutcome {
|
||||||
|
var out claudeMaxCacheBillingOutcome
|
||||||
|
if usage == nil || !shouldApplyClaudeMaxBillingRulesForUsage(group, model, parsed) {
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
resolvedModel := strings.TrimSpace(model)
|
||||||
|
if resolvedModel == "" && parsed != nil {
|
||||||
|
resolvedModel = strings.TrimSpace(parsed.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasCacheCreationTokens(*usage) {
|
||||||
|
// Upstream already returned cache creation usage; keep original usage.
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
if !shouldSimulateClaudeMaxUsageForUsage(*usage, parsed) {
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
beforeInputTokens := usage.InputTokens
|
||||||
|
out.Simulated = safelyProjectUsageToClaudeMax1H(usage, parsed)
|
||||||
|
if out.Simulated {
|
||||||
|
logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage: model=%s account=%d input_tokens:%d->%d cache_creation_1h=%d",
|
||||||
|
resolvedModel,
|
||||||
|
accountID,
|
||||||
|
beforeInputTokens,
|
||||||
|
usage.InputTokens,
|
||||||
|
usage.CacheCreation1hTokens,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func isClaudeFamilyModel(model string) bool {
|
||||||
|
normalized := strings.ToLower(strings.TrimSpace(claude.NormalizeModelID(model)))
|
||||||
|
if normalized == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.Contains(normalized, "claude-")
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldApplyClaudeMaxBillingRules(input *RecordUsageInput) bool {
|
||||||
|
if input == nil || input.Result == nil || input.APIKey == nil || input.APIKey.Group == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return shouldApplyClaudeMaxBillingRulesForUsage(input.APIKey.Group, input.Result.Model, input.ParsedRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldApplyClaudeMaxBillingRulesForUsage(group *Group, model string, parsed *ParsedRequest) bool {
|
||||||
|
if group == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !group.SimulateClaudeMaxEnabled || group.Platform != PlatformAnthropic {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
resolvedModel := model
|
||||||
|
if resolvedModel == "" && parsed != nil {
|
||||||
|
resolvedModel = parsed.Model
|
||||||
|
}
|
||||||
|
if !isClaudeFamilyModel(resolvedModel) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasCacheCreationTokens(usage ClaudeUsage) bool {
|
||||||
|
return usage.CacheCreationInputTokens > 0 || usage.CacheCreation5mTokens > 0 || usage.CacheCreation1hTokens > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldSimulateClaudeMaxUsage(input *RecordUsageInput) bool {
|
||||||
|
if input == nil || input.Result == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !shouldApplyClaudeMaxBillingRules(input) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return shouldSimulateClaudeMaxUsageForUsage(input.Result.Usage, input.ParsedRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldSimulateClaudeMaxUsageForUsage(usage ClaudeUsage, parsed *ParsedRequest) bool {
|
||||||
|
if usage.InputTokens <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if hasCacheCreationTokens(usage) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !hasClaudeCacheSignals(parsed) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func safelyProjectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) (changed bool) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage skipped: panic=%v", r)
|
||||||
|
changed = false
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return projectUsageToClaudeMax1H(usage, parsed)
|
||||||
|
}
|
||||||
|
|
||||||
|
func projectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) bool {
|
||||||
|
if usage == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
totalWindowTokens := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
|
||||||
|
if totalWindowTokens <= 1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
simulatedInputTokens := computeClaudeMaxProjectedInputTokens(totalWindowTokens, parsed)
|
||||||
|
if simulatedInputTokens <= 0 {
|
||||||
|
simulatedInputTokens = 1
|
||||||
|
}
|
||||||
|
if simulatedInputTokens >= totalWindowTokens {
|
||||||
|
simulatedInputTokens = totalWindowTokens - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheCreation1hTokens := totalWindowTokens - simulatedInputTokens
|
||||||
|
if usage.InputTokens == simulatedInputTokens &&
|
||||||
|
usage.CacheCreation5mTokens == 0 &&
|
||||||
|
usage.CacheCreation1hTokens == cacheCreation1hTokens &&
|
||||||
|
usage.CacheCreationInputTokens == cacheCreation1hTokens {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
usage.InputTokens = simulatedInputTokens
|
||||||
|
usage.CacheCreation5mTokens = 0
|
||||||
|
usage.CacheCreation1hTokens = cacheCreation1hTokens
|
||||||
|
usage.CacheCreationInputTokens = cacheCreation1hTokens
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
type claudeCacheProjection struct {
|
||||||
|
HasBreakpoint bool
|
||||||
|
BreakpointCount int
|
||||||
|
TotalEstimatedTokens int
|
||||||
|
TailEstimatedTokens int
|
||||||
|
}
|
||||||
|
|
||||||
|
func computeClaudeMaxProjectedInputTokens(totalWindowTokens int, parsed *ParsedRequest) int {
|
||||||
|
if totalWindowTokens <= 1 {
|
||||||
|
return totalWindowTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
projection := analyzeClaudeCacheProjection(parsed)
|
||||||
|
if !projection.HasBreakpoint || projection.TotalEstimatedTokens <= 0 || projection.TailEstimatedTokens <= 0 {
|
||||||
|
return totalWindowTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
totalEstimate := int64(projection.TotalEstimatedTokens)
|
||||||
|
tailEstimate := int64(projection.TailEstimatedTokens)
|
||||||
|
if tailEstimate > totalEstimate {
|
||||||
|
tailEstimate = totalEstimate
|
||||||
|
}
|
||||||
|
|
||||||
|
scaled := (int64(totalWindowTokens)*tailEstimate + totalEstimate/2) / totalEstimate
|
||||||
|
if scaled <= 0 {
|
||||||
|
scaled = 1
|
||||||
|
}
|
||||||
|
if scaled >= int64(totalWindowTokens) {
|
||||||
|
scaled = int64(totalWindowTokens - 1)
|
||||||
|
}
|
||||||
|
return int(scaled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasClaudeCacheSignals(parsed *ParsedRequest) bool {
|
||||||
|
if parsed == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if hasTopLevelEphemeralCacheControl(parsed) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return countExplicitCacheBreakpoints(parsed) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasTopLevelEphemeralCacheControl(parsed *ParsedRequest) bool {
|
||||||
|
if parsed == nil || len(parsed.Body) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
cacheType := strings.TrimSpace(gjson.GetBytes(parsed.Body, "cache_control.type").String())
|
||||||
|
return strings.EqualFold(cacheType, "ephemeral")
|
||||||
|
}
|
||||||
|
|
||||||
|
func analyzeClaudeCacheProjection(parsed *ParsedRequest) claudeCacheProjection {
|
||||||
|
var projection claudeCacheProjection
|
||||||
|
if parsed == nil {
|
||||||
|
return projection
|
||||||
|
}
|
||||||
|
|
||||||
|
total := 0
|
||||||
|
lastBreakpointAt := -1
|
||||||
|
|
||||||
|
switch system := parsed.System.(type) {
|
||||||
|
case string:
|
||||||
|
total += claudeMaxMessageOverheadTokens + estimateClaudeTextTokens(system)
|
||||||
|
case []any:
|
||||||
|
for _, raw := range system {
|
||||||
|
block, ok := raw.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
total += claudeMaxUnknownContentTokens
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
total += estimateClaudeBlockTokens(block)
|
||||||
|
if hasEphemeralCacheControl(block) {
|
||||||
|
lastBreakpointAt = total
|
||||||
|
projection.BreakpointCount++
|
||||||
|
projection.HasBreakpoint = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rawMsg := range parsed.Messages {
|
||||||
|
total += claudeMaxMessageOverheadTokens
|
||||||
|
msg, ok := rawMsg.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
total += claudeMaxUnknownContentTokens
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
content, exists := msg["content"]
|
||||||
|
if !exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
msgTokens, msgLastBreak, msgBreakCount := estimateClaudeContentTokens(content)
|
||||||
|
total += msgTokens
|
||||||
|
if msgBreakCount > 0 {
|
||||||
|
lastBreakpointAt = total - msgTokens + msgLastBreak
|
||||||
|
projection.BreakpointCount += msgBreakCount
|
||||||
|
projection.HasBreakpoint = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if total <= 0 {
|
||||||
|
total = 1
|
||||||
|
}
|
||||||
|
projection.TotalEstimatedTokens = total
|
||||||
|
|
||||||
|
if projection.HasBreakpoint && lastBreakpointAt >= 0 {
|
||||||
|
tail := total - lastBreakpointAt
|
||||||
|
if tail <= 0 {
|
||||||
|
tail = 1
|
||||||
|
}
|
||||||
|
projection.TailEstimatedTokens = tail
|
||||||
|
return projection
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasTopLevelEphemeralCacheControl(parsed) {
|
||||||
|
tail := estimateLastUserMessageTokens(parsed)
|
||||||
|
if tail <= 0 {
|
||||||
|
tail = 1
|
||||||
|
}
|
||||||
|
projection.HasBreakpoint = true
|
||||||
|
projection.BreakpointCount = 1
|
||||||
|
projection.TailEstimatedTokens = tail
|
||||||
|
}
|
||||||
|
return projection
|
||||||
|
}
|
||||||
|
|
||||||
|
func countExplicitCacheBreakpoints(parsed *ParsedRequest) int {
|
||||||
|
if parsed == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
total := 0
|
||||||
|
if system, ok := parsed.System.([]any); ok {
|
||||||
|
for _, raw := range system {
|
||||||
|
if block, ok := raw.(map[string]any); ok && hasEphemeralCacheControl(block) {
|
||||||
|
total++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, rawMsg := range parsed.Messages {
|
||||||
|
msg, ok := rawMsg.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
content, ok := msg["content"].([]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, raw := range content {
|
||||||
|
if block, ok := raw.(map[string]any); ok && hasEphemeralCacheControl(block) {
|
||||||
|
total++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return total
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasEphemeralCacheControl(block map[string]any) bool {
|
||||||
|
if block == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
raw, ok := block["cache_control"]
|
||||||
|
if !ok || raw == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch cc := raw.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
cacheType, _ := cc["type"].(string)
|
||||||
|
return strings.EqualFold(strings.TrimSpace(cacheType), "ephemeral")
|
||||||
|
case map[string]string:
|
||||||
|
return strings.EqualFold(strings.TrimSpace(cc["type"]), "ephemeral")
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func estimateClaudeContentTokens(content any) (tokens int, lastBreakAt int, breakpointCount int) {
|
||||||
|
switch value := content.(type) {
|
||||||
|
case string:
|
||||||
|
return estimateClaudeTextTokens(value), -1, 0
|
||||||
|
case []any:
|
||||||
|
total := 0
|
||||||
|
lastBreak := -1
|
||||||
|
breaks := 0
|
||||||
|
for _, raw := range value {
|
||||||
|
block, ok := raw.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
total += claudeMaxUnknownContentTokens
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
total += estimateClaudeBlockTokens(block)
|
||||||
|
if hasEphemeralCacheControl(block) {
|
||||||
|
lastBreak = total
|
||||||
|
breaks++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return total, lastBreak, breaks
|
||||||
|
default:
|
||||||
|
return estimateStructuredTokens(value), -1, 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func estimateClaudeBlockTokens(block map[string]any) int {
|
||||||
|
if block == nil {
|
||||||
|
return claudeMaxUnknownContentTokens
|
||||||
|
}
|
||||||
|
tokens := claudeMaxBlockOverheadTokens
|
||||||
|
blockType, _ := block["type"].(string)
|
||||||
|
switch blockType {
|
||||||
|
case "text":
|
||||||
|
if text, ok := block["text"].(string); ok {
|
||||||
|
tokens += estimateClaudeTextTokens(text)
|
||||||
|
}
|
||||||
|
case "tool_result":
|
||||||
|
if content, ok := block["content"]; ok {
|
||||||
|
nested, _, _ := estimateClaudeContentTokens(content)
|
||||||
|
tokens += nested
|
||||||
|
}
|
||||||
|
case "tool_use":
|
||||||
|
if name, ok := block["name"].(string); ok {
|
||||||
|
tokens += estimateClaudeTextTokens(name)
|
||||||
|
}
|
||||||
|
if input, ok := block["input"]; ok {
|
||||||
|
tokens += estimateStructuredTokens(input)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if text, ok := block["text"].(string); ok {
|
||||||
|
tokens += estimateClaudeTextTokens(text)
|
||||||
|
} else if content, ok := block["content"]; ok {
|
||||||
|
nested, _, _ := estimateClaudeContentTokens(content)
|
||||||
|
tokens += nested
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if tokens <= claudeMaxBlockOverheadTokens {
|
||||||
|
tokens += claudeMaxUnknownContentTokens
|
||||||
|
}
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
func estimateLastUserMessageTokens(parsed *ParsedRequest) int {
|
||||||
|
if parsed == nil || len(parsed.Messages) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
for i := len(parsed.Messages) - 1; i >= 0; i-- {
|
||||||
|
msg, ok := parsed.Messages[i].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
role, _ := msg["role"].(string)
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(role), "user") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
tokens, _, _ := estimateClaudeContentTokens(msg["content"])
|
||||||
|
return claudeMaxMessageOverheadTokens + tokens
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func estimateStructuredTokens(v any) int {
|
||||||
|
if v == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
raw, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return claudeMaxUnknownContentTokens
|
||||||
|
}
|
||||||
|
return estimateClaudeTextTokens(string(raw))
|
||||||
|
}
|
||||||
|
|
||||||
|
func estimateClaudeTextTokens(text string) int {
|
||||||
|
if tokens, ok := estimateTokensByThirdPartyTokenizer(text); ok {
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
return estimateClaudeTextTokensHeuristic(text)
|
||||||
|
}
|
||||||
|
|
||||||
|
func estimateClaudeTextTokensHeuristic(text string) int {
|
||||||
|
normalized := strings.Join(strings.Fields(strings.TrimSpace(text)), " ")
|
||||||
|
if normalized == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
asciiChars := 0
|
||||||
|
nonASCIIChars := 0
|
||||||
|
for _, r := range normalized {
|
||||||
|
if r <= 127 {
|
||||||
|
asciiChars++
|
||||||
|
} else {
|
||||||
|
nonASCIIChars++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tokens := nonASCIIChars
|
||||||
|
if asciiChars > 0 {
|
||||||
|
tokens += (asciiChars + 3) / 4
|
||||||
|
}
|
||||||
|
if words := len(strings.Fields(normalized)); words > tokens {
|
||||||
|
tokens = words
|
||||||
|
}
|
||||||
|
if tokens <= 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
156
backend/internal/service/claude_max_simulation_test.go
Normal file
156
backend/internal/service/claude_max_simulation_test.go
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProjectUsageToClaudeMax1H_Conservation(t *testing.T) {
|
||||||
|
usage := &ClaudeUsage{
|
||||||
|
InputTokens: 1200,
|
||||||
|
CacheCreationInputTokens: 0,
|
||||||
|
CacheCreation5mTokens: 0,
|
||||||
|
CacheCreation1hTokens: 0,
|
||||||
|
}
|
||||||
|
parsed := &ParsedRequest{
|
||||||
|
Model: "claude-sonnet-4-5",
|
||||||
|
Messages: []any{
|
||||||
|
map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": strings.Repeat("cached context ", 200),
|
||||||
|
"cache_control": map[string]any{"type": "ephemeral"},
|
||||||
|
},
|
||||||
|
map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": "summarize quickly",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
changed := projectUsageToClaudeMax1H(usage, parsed)
|
||||||
|
if !changed {
|
||||||
|
t.Fatalf("expected usage to be projected")
|
||||||
|
}
|
||||||
|
|
||||||
|
total := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
|
||||||
|
if total != 1200 {
|
||||||
|
t.Fatalf("total tokens changed: got=%d want=%d", total, 1200)
|
||||||
|
}
|
||||||
|
if usage.CacheCreation5mTokens != 0 {
|
||||||
|
t.Fatalf("cache_creation_5m should be 0, got=%d", usage.CacheCreation5mTokens)
|
||||||
|
}
|
||||||
|
if usage.InputTokens <= 0 || usage.InputTokens >= 1200 {
|
||||||
|
t.Fatalf("simulated input out of range, got=%d", usage.InputTokens)
|
||||||
|
}
|
||||||
|
if usage.InputTokens > 100 {
|
||||||
|
t.Fatalf("simulated input should stay near cache breakpoint tail, got=%d", usage.InputTokens)
|
||||||
|
}
|
||||||
|
if usage.CacheCreation1hTokens <= 0 {
|
||||||
|
t.Fatalf("cache_creation_1h should be > 0, got=%d", usage.CacheCreation1hTokens)
|
||||||
|
}
|
||||||
|
if usage.CacheCreationInputTokens != usage.CacheCreation1hTokens {
|
||||||
|
t.Fatalf("cache_creation_input_tokens mismatch: got=%d want=%d", usage.CacheCreationInputTokens, usage.CacheCreation1hTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComputeClaudeMaxProjectedInputTokens_Deterministic(t *testing.T) {
|
||||||
|
parsed := &ParsedRequest{
|
||||||
|
Model: "claude-opus-4-5",
|
||||||
|
Messages: []any{
|
||||||
|
map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": "build context",
|
||||||
|
"cache_control": map[string]any{"type": "ephemeral"},
|
||||||
|
},
|
||||||
|
map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": "what is failing now",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
got1 := computeClaudeMaxProjectedInputTokens(4096, parsed)
|
||||||
|
got2 := computeClaudeMaxProjectedInputTokens(4096, parsed)
|
||||||
|
if got1 != got2 {
|
||||||
|
t.Fatalf("non-deterministic input tokens: %d != %d", got1, got2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldSimulateClaudeMaxUsage(t *testing.T) {
|
||||||
|
group := &Group{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
SimulateClaudeMaxEnabled: true,
|
||||||
|
}
|
||||||
|
input := &RecordUsageInput{
|
||||||
|
Result: &ForwardResult{
|
||||||
|
Model: "claude-sonnet-4-5",
|
||||||
|
Usage: ClaudeUsage{
|
||||||
|
InputTokens: 3000,
|
||||||
|
CacheCreationInputTokens: 0,
|
||||||
|
CacheCreation5mTokens: 0,
|
||||||
|
CacheCreation1hTokens: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ParsedRequest: &ParsedRequest{
|
||||||
|
Messages: []any{
|
||||||
|
map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": "cached",
|
||||||
|
"cache_control": map[string]any{"type": "ephemeral"},
|
||||||
|
},
|
||||||
|
map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": "tail",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{Group: group},
|
||||||
|
}
|
||||||
|
|
||||||
|
if !shouldSimulateClaudeMaxUsage(input) {
|
||||||
|
t.Fatalf("expected simulate=true for claude group with cache signal")
|
||||||
|
}
|
||||||
|
|
||||||
|
input.ParsedRequest = &ParsedRequest{
|
||||||
|
Messages: []any{
|
||||||
|
map[string]any{"role": "user", "content": "no cache signal"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if shouldSimulateClaudeMaxUsage(input) {
|
||||||
|
t.Fatalf("expected simulate=false when request has no cache signal")
|
||||||
|
}
|
||||||
|
|
||||||
|
input.ParsedRequest = &ParsedRequest{
|
||||||
|
Messages: []any{
|
||||||
|
map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": "cached",
|
||||||
|
"cache_control": map[string]any{"type": "ephemeral"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
input.Result.Usage.CacheCreationInputTokens = 100
|
||||||
|
if shouldSimulateClaudeMaxUsage(input) {
|
||||||
|
t.Fatalf("expected simulate=false when cache creation already exists")
|
||||||
|
}
|
||||||
|
}
|
||||||
41
backend/internal/service/claude_tokenizer.go
Normal file
41
backend/internal/service/claude_tokenizer.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
tiktoken "github.com/pkoukk/tiktoken-go"
|
||||||
|
tiktokenloader "github.com/pkoukk/tiktoken-go-loader"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
claudeTokenizerOnce sync.Once
|
||||||
|
claudeTokenizer *tiktoken.Tiktoken
|
||||||
|
)
|
||||||
|
|
||||||
|
func getClaudeTokenizer() *tiktoken.Tiktoken {
|
||||||
|
claudeTokenizerOnce.Do(func() {
|
||||||
|
// Use offline loader to avoid runtime dictionary download.
|
||||||
|
tiktoken.SetBpeLoader(tiktokenloader.NewOfflineLoader())
|
||||||
|
// Use a high-capacity tokenizer as the default approximation for Claude payloads.
|
||||||
|
enc, err := tiktoken.GetEncoding(tiktoken.MODEL_O200K_BASE)
|
||||||
|
if err != nil {
|
||||||
|
enc, err = tiktoken.GetEncoding(tiktoken.MODEL_CL100K_BASE)
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
claudeTokenizer = enc
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return claudeTokenizer
|
||||||
|
}
|
||||||
|
|
||||||
|
func estimateTokensByThirdPartyTokenizer(text string) (int, bool) {
|
||||||
|
enc := getClaudeTokenizer()
|
||||||
|
if enc == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
tokens := len(enc.EncodeOrdinary(text))
|
||||||
|
if tokens <= 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return tokens, true
|
||||||
|
}
|
||||||
@@ -343,8 +343,9 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
|
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts.
|
||||||
// Returns a map of accountID -> current concurrency count
|
// Uses a detached context with timeout to prevent HTTP request cancellation from
|
||||||
|
// causing the entire batch to fail (which would show all concurrency as 0).
|
||||||
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||||
if len(accountIDs) == 0 {
|
if len(accountIDs) == 0 {
|
||||||
return map[int64]int{}, nil
|
return map[int64]int{}, nil
|
||||||
@@ -356,5 +357,11 @@ func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, acc
|
|||||||
}
|
}
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs)
|
|
||||||
|
// Use a detached context so that a cancelled HTTP request doesn't cause
|
||||||
|
// the Redis pipeline to fail and return all-zero concurrency counts.
|
||||||
|
redisCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
return s.cache.GetAccountConcurrencyBatch(redisCtx, accountIDs)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -220,7 +220,7 @@ func TestApplyErrorPassthroughRule_SkipMonitoringSetsContextKey(t *testing.T) {
|
|||||||
v, exists := c.Get(OpsSkipPassthroughKey)
|
v, exists := c.Get(OpsSkipPassthroughKey)
|
||||||
assert.True(t, exists, "OpsSkipPassthroughKey should be set when skip_monitoring=true")
|
assert.True(t, exists, "OpsSkipPassthroughKey should be set when skip_monitoring=true")
|
||||||
boolVal, ok := v.(bool)
|
boolVal, ok := v.(bool)
|
||||||
assert.True(t, ok, "value should be bool")
|
assert.True(t, ok, "value should be a bool")
|
||||||
assert.True(t, boolVal)
|
assert.True(t, boolVal)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
196
backend/internal/service/gateway_claude_max_response_helpers.go
Normal file
196
backend/internal/service/gateway_claude_max_response_helpers.go
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
type claudeMaxResponseRewriteContext struct {
|
||||||
|
Parsed *ParsedRequest
|
||||||
|
Group *Group
|
||||||
|
}
|
||||||
|
|
||||||
|
type claudeMaxResponseRewriteContextKeyType struct{}
|
||||||
|
|
||||||
|
var claudeMaxResponseRewriteContextKey = claudeMaxResponseRewriteContextKeyType{}
|
||||||
|
|
||||||
|
func withClaudeMaxResponseRewriteContext(ctx context.Context, c *gin.Context, parsed *ParsedRequest) context.Context {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
value := claudeMaxResponseRewriteContext{
|
||||||
|
Parsed: parsed,
|
||||||
|
Group: claudeMaxGroupFromGinContext(c),
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, claudeMaxResponseRewriteContextKey, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func claudeMaxResponseRewriteContextFromContext(ctx context.Context) claudeMaxResponseRewriteContext {
|
||||||
|
if ctx == nil {
|
||||||
|
return claudeMaxResponseRewriteContext{}
|
||||||
|
}
|
||||||
|
value, _ := ctx.Value(claudeMaxResponseRewriteContextKey).(claudeMaxResponseRewriteContext)
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
func claudeMaxGroupFromGinContext(c *gin.Context) *Group {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
raw, exists := c.Get("api_key")
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
apiKey, ok := raw.(*APIKey)
|
||||||
|
if !ok || apiKey == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return apiKey.Group
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsedRequestFromGinContext(c *gin.Context) *ParsedRequest {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
raw, exists := c.Get("parsed_request")
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
parsed, _ := raw.(*ParsedRequest)
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyClaudeMaxSimulationToUsage(ctx context.Context, usage *ClaudeUsage, model string, accountID int64) claudeMaxCacheBillingOutcome {
|
||||||
|
var out claudeMaxCacheBillingOutcome
|
||||||
|
if usage == nil {
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
rewriteCtx := claudeMaxResponseRewriteContextFromContext(ctx)
|
||||||
|
return applyClaudeMaxCacheBillingPolicyToUsage(usage, rewriteCtx.Parsed, rewriteCtx.Group, model, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyClaudeMaxSimulationToUsageJSONMap(ctx context.Context, usageObj map[string]any, model string, accountID int64) claudeMaxCacheBillingOutcome {
|
||||||
|
var out claudeMaxCacheBillingOutcome
|
||||||
|
if usageObj == nil {
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
usage := claudeUsageFromJSONMap(usageObj)
|
||||||
|
out = applyClaudeMaxSimulationToUsage(ctx, &usage, model, accountID)
|
||||||
|
if out.Simulated {
|
||||||
|
rewriteClaudeUsageJSONMap(usageObj, usage)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func rewriteClaudeUsageJSONBytes(body []byte, usage ClaudeUsage) []byte {
|
||||||
|
updated := body
|
||||||
|
var err error
|
||||||
|
|
||||||
|
updated, err = sjson.SetBytes(updated, "usage.input_tokens", usage.InputTokens)
|
||||||
|
if err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
updated, err = sjson.SetBytes(updated, "usage.cache_creation_input_tokens", usage.CacheCreationInputTokens)
|
||||||
|
if err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
updated, err = sjson.SetBytes(updated, "usage.cache_creation.ephemeral_5m_input_tokens", usage.CacheCreation5mTokens)
|
||||||
|
if err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
updated, err = sjson.SetBytes(updated, "usage.cache_creation.ephemeral_1h_input_tokens", usage.CacheCreation1hTokens)
|
||||||
|
if err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
return updated
|
||||||
|
}
|
||||||
|
|
||||||
|
func claudeUsageFromJSONMap(usageObj map[string]any) ClaudeUsage {
|
||||||
|
var usage ClaudeUsage
|
||||||
|
if usageObj == nil {
|
||||||
|
return usage
|
||||||
|
}
|
||||||
|
|
||||||
|
usage.InputTokens = usageIntFromAny(usageObj["input_tokens"])
|
||||||
|
usage.OutputTokens = usageIntFromAny(usageObj["output_tokens"])
|
||||||
|
usage.CacheCreationInputTokens = usageIntFromAny(usageObj["cache_creation_input_tokens"])
|
||||||
|
usage.CacheReadInputTokens = usageIntFromAny(usageObj["cache_read_input_tokens"])
|
||||||
|
|
||||||
|
if ccObj, ok := usageObj["cache_creation"].(map[string]any); ok {
|
||||||
|
usage.CacheCreation5mTokens = usageIntFromAny(ccObj["ephemeral_5m_input_tokens"])
|
||||||
|
usage.CacheCreation1hTokens = usageIntFromAny(ccObj["ephemeral_1h_input_tokens"])
|
||||||
|
}
|
||||||
|
return usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func rewriteClaudeUsageJSONMap(usageObj map[string]any, usage ClaudeUsage) {
|
||||||
|
if usageObj == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
usageObj["input_tokens"] = usage.InputTokens
|
||||||
|
usageObj["cache_creation_input_tokens"] = usage.CacheCreationInputTokens
|
||||||
|
|
||||||
|
ccObj, _ := usageObj["cache_creation"].(map[string]any)
|
||||||
|
if ccObj == nil {
|
||||||
|
ccObj = make(map[string]any, 2)
|
||||||
|
usageObj["cache_creation"] = ccObj
|
||||||
|
}
|
||||||
|
ccObj["ephemeral_5m_input_tokens"] = usage.CacheCreation5mTokens
|
||||||
|
ccObj["ephemeral_1h_input_tokens"] = usage.CacheCreation1hTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
func usageIntFromAny(v any) int {
|
||||||
|
switch value := v.(type) {
|
||||||
|
case int:
|
||||||
|
return value
|
||||||
|
case int64:
|
||||||
|
return int(value)
|
||||||
|
case float64:
|
||||||
|
return int(value)
|
||||||
|
case json.Number:
|
||||||
|
if n, err := value.Int64(); err == nil {
|
||||||
|
return int(n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupClaudeMaxStreamingHook 为 Antigravity 流式路径设置 SSE usage 改写 hook。
|
||||||
|
func setupClaudeMaxStreamingHook(c *gin.Context, processor *antigravity.StreamingProcessor, originalModel string, accountID int64) {
|
||||||
|
group := claudeMaxGroupFromGinContext(c)
|
||||||
|
parsed := parsedRequestFromGinContext(c)
|
||||||
|
if !shouldApplyClaudeMaxBillingRulesForUsage(group, originalModel, parsed) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
processor.SetUsageMapHook(func(usageMap map[string]any) {
|
||||||
|
svcUsage := claudeUsageFromJSONMap(usageMap)
|
||||||
|
outcome := applyClaudeMaxCacheBillingPolicyToUsage(&svcUsage, parsed, group, originalModel, accountID)
|
||||||
|
if outcome.Simulated {
|
||||||
|
rewriteClaudeUsageJSONMap(usageMap, svcUsage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyClaudeMaxNonStreamingRewrite 为 Antigravity 非流式路径改写响应体中的 usage。
|
||||||
|
func applyClaudeMaxNonStreamingRewrite(c *gin.Context, claudeResp []byte, agUsage *antigravity.ClaudeUsage, originalModel string, accountID int64) []byte {
|
||||||
|
group := claudeMaxGroupFromGinContext(c)
|
||||||
|
parsed := parsedRequestFromGinContext(c)
|
||||||
|
if !shouldApplyClaudeMaxBillingRulesForUsage(group, originalModel, parsed) {
|
||||||
|
return claudeResp
|
||||||
|
}
|
||||||
|
svcUsage := &ClaudeUsage{
|
||||||
|
InputTokens: agUsage.InputTokens,
|
||||||
|
OutputTokens: agUsage.OutputTokens,
|
||||||
|
CacheCreationInputTokens: agUsage.CacheCreationInputTokens,
|
||||||
|
CacheReadInputTokens: agUsage.CacheReadInputTokens,
|
||||||
|
}
|
||||||
|
outcome := applyClaudeMaxCacheBillingPolicyToUsage(svcUsage, parsed, group, originalModel, accountID)
|
||||||
|
if outcome.Simulated {
|
||||||
|
return rewriteClaudeUsageJSONBytes(claudeResp, *svcUsage)
|
||||||
|
}
|
||||||
|
return claudeResp
|
||||||
|
}
|
||||||
199
backend/internal/service/gateway_record_usage_claude_max_test.go
Normal file
199
backend/internal/service/gateway_record_usage_claude_max_test.go
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type usageLogRepoRecordUsageStub struct {
|
||||||
|
UsageLogRepository
|
||||||
|
|
||||||
|
last *UsageLog
|
||||||
|
inserted bool
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *usageLogRepoRecordUsageStub) Create(_ context.Context, log *UsageLog) (bool, error) {
|
||||||
|
copied := *log
|
||||||
|
s.last = &copied
|
||||||
|
return s.inserted, s.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func newGatewayServiceForRecordUsageTest(repo UsageLogRepository) *GatewayService {
|
||||||
|
return &GatewayService{
|
||||||
|
usageLogRepo: repo,
|
||||||
|
billingService: NewBillingService(&config.Config{}, nil),
|
||||||
|
cfg: &config.Config{RunMode: config.RunModeSimple},
|
||||||
|
deferredService: &DeferredService{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordUsage_SimulateClaudeMaxEnabled_ProjectsUsageAndSkipsTTLOverride(t *testing.T) {
|
||||||
|
repo := &usageLogRepoRecordUsageStub{inserted: true}
|
||||||
|
svc := newGatewayServiceForRecordUsageTest(repo)
|
||||||
|
|
||||||
|
groupID := int64(11)
|
||||||
|
input := &RecordUsageInput{
|
||||||
|
Result: &ForwardResult{
|
||||||
|
RequestID: "req-sim-1",
|
||||||
|
Model: "claude-sonnet-4",
|
||||||
|
Duration: time.Second,
|
||||||
|
Usage: ClaudeUsage{
|
||||||
|
InputTokens: 160,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ParsedRequest: &ParsedRequest{
|
||||||
|
Model: "claude-sonnet-4",
|
||||||
|
Messages: []any{
|
||||||
|
map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": "long cached context for prior turns",
|
||||||
|
"cache_control": map[string]any{"type": "ephemeral"},
|
||||||
|
},
|
||||||
|
map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": "please summarize the logs and provide root cause analysis",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{
|
||||||
|
ID: 1,
|
||||||
|
GroupID: &groupID,
|
||||||
|
Group: &Group{
|
||||||
|
ID: groupID,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
RateMultiplier: 1,
|
||||||
|
SimulateClaudeMaxEnabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
User: &User{ID: 2},
|
||||||
|
Account: &Account{
|
||||||
|
ID: 3,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"cache_ttl_override_enabled": true,
|
||||||
|
"cache_ttl_override_target": "5m",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, repo.last)
|
||||||
|
|
||||||
|
log := repo.last
|
||||||
|
require.Equal(t, 80, log.InputTokens)
|
||||||
|
require.Equal(t, 80, log.CacheCreationTokens)
|
||||||
|
require.Equal(t, 0, log.CacheCreation5mTokens)
|
||||||
|
require.Equal(t, 80, log.CacheCreation1hTokens)
|
||||||
|
require.False(t, log.CacheTTLOverridden, "simulate outcome should skip account ttl override")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordUsage_SimulateClaudeMaxDisabled_AppliesTTLOverride(t *testing.T) {
|
||||||
|
repo := &usageLogRepoRecordUsageStub{inserted: true}
|
||||||
|
svc := newGatewayServiceForRecordUsageTest(repo)
|
||||||
|
|
||||||
|
groupID := int64(12)
|
||||||
|
input := &RecordUsageInput{
|
||||||
|
Result: &ForwardResult{
|
||||||
|
RequestID: "req-sim-2",
|
||||||
|
Model: "claude-sonnet-4",
|
||||||
|
Duration: time.Second,
|
||||||
|
Usage: ClaudeUsage{
|
||||||
|
InputTokens: 40,
|
||||||
|
CacheCreationInputTokens: 120,
|
||||||
|
CacheCreation1hTokens: 120,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{
|
||||||
|
ID: 2,
|
||||||
|
GroupID: &groupID,
|
||||||
|
Group: &Group{
|
||||||
|
ID: groupID,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
RateMultiplier: 1,
|
||||||
|
SimulateClaudeMaxEnabled: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
User: &User{ID: 3},
|
||||||
|
Account: &Account{
|
||||||
|
ID: 4,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"cache_ttl_override_enabled": true,
|
||||||
|
"cache_ttl_override_target": "5m",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, repo.last)
|
||||||
|
|
||||||
|
log := repo.last
|
||||||
|
require.Equal(t, 120, log.CacheCreationTokens)
|
||||||
|
require.Equal(t, 120, log.CacheCreation5mTokens)
|
||||||
|
require.Equal(t, 0, log.CacheCreation1hTokens)
|
||||||
|
require.True(t, log.CacheTTLOverridden)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordUsage_SimulateClaudeMaxEnabled_ExistingCacheCreationBypassesSimulation(t *testing.T) {
|
||||||
|
repo := &usageLogRepoRecordUsageStub{inserted: true}
|
||||||
|
svc := newGatewayServiceForRecordUsageTest(repo)
|
||||||
|
|
||||||
|
groupID := int64(13)
|
||||||
|
input := &RecordUsageInput{
|
||||||
|
Result: &ForwardResult{
|
||||||
|
RequestID: "req-sim-3",
|
||||||
|
Model: "claude-sonnet-4",
|
||||||
|
Duration: time.Second,
|
||||||
|
Usage: ClaudeUsage{
|
||||||
|
InputTokens: 20,
|
||||||
|
CacheCreationInputTokens: 120,
|
||||||
|
CacheCreation5mTokens: 120,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{
|
||||||
|
ID: 3,
|
||||||
|
GroupID: &groupID,
|
||||||
|
Group: &Group{
|
||||||
|
ID: groupID,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
RateMultiplier: 1,
|
||||||
|
SimulateClaudeMaxEnabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
User: &User{ID: 4},
|
||||||
|
Account: &Account{
|
||||||
|
ID: 5,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"cache_ttl_override_enabled": true,
|
||||||
|
"cache_ttl_override_target": "5m",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, repo.last)
|
||||||
|
|
||||||
|
log := repo.last
|
||||||
|
require.Equal(t, 20, log.InputTokens)
|
||||||
|
require.Equal(t, 120, log.CacheCreation5mTokens)
|
||||||
|
require.Equal(t, 0, log.CacheCreation1hTokens)
|
||||||
|
require.Equal(t, 120, log.CacheCreationTokens)
|
||||||
|
require.False(t, log.CacheTTLOverridden, "existing cache_creation with SimulateClaudeMax enabled should skip account ttl override")
|
||||||
|
}
|
||||||
170
backend/internal/service/gateway_response_usage_sync_test.go
Normal file
170
backend/internal/service/gateway_response_usage_sync_test.go
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHandleNonStreamingResponse_UsageAlignedWithClaudeMaxSimulation(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
cfg: &config.Config{},
|
||||||
|
rateLimitService: &RateLimitService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 11,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"cache_ttl_override_enabled": true,
|
||||||
|
"cache_ttl_override_target": "5m",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
group := &Group{
|
||||||
|
ID: 99,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
SimulateClaudeMaxEnabled: true,
|
||||||
|
}
|
||||||
|
parsed := &ParsedRequest{
|
||||||
|
Model: "claude-sonnet-4",
|
||||||
|
Messages: []any{
|
||||||
|
map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": "long cached context",
|
||||||
|
"cache_control": map[string]any{"type": "ephemeral"},
|
||||||
|
},
|
||||||
|
map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": "new user question",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamBody := []byte(`{"id":"msg_1","model":"claude-sonnet-4","usage":{"input_tokens":120,"output_tokens":8}}`)
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
Body: ioNopCloserBytes(upstreamBody),
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(nil))
|
||||||
|
c.Set("api_key", &APIKey{Group: group})
|
||||||
|
requestCtx := withClaudeMaxResponseRewriteContext(context.Background(), c, parsed)
|
||||||
|
|
||||||
|
usage, err := svc.handleNonStreamingResponse(requestCtx, resp, c, account, "claude-sonnet-4", "claude-sonnet-4")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, usage)
|
||||||
|
|
||||||
|
var rendered struct {
|
||||||
|
Usage ClaudeUsage `json:"usage"`
|
||||||
|
}
|
||||||
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &rendered))
|
||||||
|
rendered.Usage.CacheCreation5mTokens = int(gjson.GetBytes(rec.Body.Bytes(), "usage.cache_creation.ephemeral_5m_input_tokens").Int())
|
||||||
|
rendered.Usage.CacheCreation1hTokens = int(gjson.GetBytes(rec.Body.Bytes(), "usage.cache_creation.ephemeral_1h_input_tokens").Int())
|
||||||
|
|
||||||
|
require.Equal(t, rendered.Usage.InputTokens, usage.InputTokens)
|
||||||
|
require.Equal(t, rendered.Usage.OutputTokens, usage.OutputTokens)
|
||||||
|
require.Equal(t, rendered.Usage.CacheCreationInputTokens, usage.CacheCreationInputTokens)
|
||||||
|
require.Equal(t, rendered.Usage.CacheCreation5mTokens, usage.CacheCreation5mTokens)
|
||||||
|
require.Equal(t, rendered.Usage.CacheCreation1hTokens, usage.CacheCreation1hTokens)
|
||||||
|
require.Equal(t, rendered.Usage.CacheReadInputTokens, usage.CacheReadInputTokens)
|
||||||
|
|
||||||
|
require.Greater(t, usage.CacheCreation1hTokens, 0)
|
||||||
|
require.Equal(t, 0, usage.CacheCreation5mTokens)
|
||||||
|
require.Less(t, usage.InputTokens, 120)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleNonStreamingResponse_ClaudeMaxDisabled_NoSimulationIntercept(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
cfg: &config.Config{},
|
||||||
|
rateLimitService: &RateLimitService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 12,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"cache_ttl_override_enabled": true,
|
||||||
|
"cache_ttl_override_target": "5m",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
group := &Group{
|
||||||
|
ID: 100,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
SimulateClaudeMaxEnabled: false,
|
||||||
|
}
|
||||||
|
parsed := &ParsedRequest{
|
||||||
|
Model: "claude-sonnet-4",
|
||||||
|
Messages: []any{
|
||||||
|
map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": "long cached context",
|
||||||
|
"cache_control": map[string]any{"type": "ephemeral"},
|
||||||
|
},
|
||||||
|
map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": "new user question",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamBody := []byte(`{"id":"msg_2","model":"claude-sonnet-4","usage":{"input_tokens":120,"output_tokens":8}}`)
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
Body: ioNopCloserBytes(upstreamBody),
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(nil))
|
||||||
|
c.Set("api_key", &APIKey{Group: group})
|
||||||
|
requestCtx := withClaudeMaxResponseRewriteContext(context.Background(), c, parsed)
|
||||||
|
|
||||||
|
usage, err := svc.handleNonStreamingResponse(requestCtx, resp, c, account, "claude-sonnet-4", "claude-sonnet-4")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, usage)
|
||||||
|
|
||||||
|
require.Equal(t, 120, usage.InputTokens)
|
||||||
|
require.Equal(t, 0, usage.CacheCreationInputTokens)
|
||||||
|
require.Equal(t, 0, usage.CacheCreation5mTokens)
|
||||||
|
require.Equal(t, 0, usage.CacheCreation1hTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ioNopCloserBytes(b []byte) *readCloserFromBytes {
|
||||||
|
return &readCloserFromBytes{Reader: bytes.NewReader(b)}
|
||||||
|
}
|
||||||
|
|
||||||
|
type readCloserFromBytes struct {
|
||||||
|
*bytes.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *readCloserFromBytes) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -56,6 +56,12 @@ const (
|
|||||||
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
|
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
claudeMaxMessageOverheadTokens = 3
|
||||||
|
claudeMaxBlockOverheadTokens = 1
|
||||||
|
claudeMaxUnknownContentTokens = 4
|
||||||
|
)
|
||||||
|
|
||||||
// ForceCacheBillingContextKey 强制缓存计费上下文键
|
// ForceCacheBillingContextKey 强制缓存计费上下文键
|
||||||
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
|
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
|
||||||
type forceCacheBillingKeyType struct{}
|
type forceCacheBillingKeyType struct{}
|
||||||
@@ -4424,6 +4430,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 处理正常响应
|
// 处理正常响应
|
||||||
|
ctx = withClaudeMaxResponseRewriteContext(ctx, c, parsed)
|
||||||
|
|
||||||
// 触发上游接受回调(提前释放串行锁,不等流完成)
|
// 触发上游接受回调(提前释放串行锁,不等流完成)
|
||||||
if parsed.OnUpstreamAccepted != nil {
|
if parsed.OnUpstreamAccepted != nil {
|
||||||
@@ -5998,6 +6005,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
intervalCh = intervalTicker.C
|
intervalCh = intervalTicker.C
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开
|
||||||
|
keepaliveInterval := time.Duration(0)
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||||
|
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||||
|
}
|
||||||
|
var keepaliveTicker *time.Ticker
|
||||||
|
if keepaliveInterval > 0 {
|
||||||
|
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||||
|
defer keepaliveTicker.Stop()
|
||||||
|
}
|
||||||
|
var keepaliveCh <-chan time.Time
|
||||||
|
if keepaliveTicker != nil {
|
||||||
|
keepaliveCh = keepaliveTicker.C
|
||||||
|
}
|
||||||
|
lastDataAt := time.Now()
|
||||||
|
|
||||||
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
|
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
|
||||||
errorEventSent := false
|
errorEventSent := false
|
||||||
sendErrorEvent := func(reason string) {
|
sendErrorEvent := func(reason string) {
|
||||||
@@ -6011,6 +6034,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
|
|
||||||
needModelReplace := originalModel != mappedModel
|
needModelReplace := originalModel != mappedModel
|
||||||
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
|
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
|
||||||
|
skipAccountTTLOverride := false
|
||||||
|
|
||||||
pendingEventLines := make([]string, 0, 4)
|
pendingEventLines := make([]string, 0, 4)
|
||||||
|
|
||||||
@@ -6071,17 +6095,25 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
if msg, ok := event["message"].(map[string]any); ok {
|
if msg, ok := event["message"].(map[string]any); ok {
|
||||||
if u, ok := msg["usage"].(map[string]any); ok {
|
if u, ok := msg["usage"].(map[string]any); ok {
|
||||||
eventChanged = reconcileCachedTokens(u) || eventChanged
|
eventChanged = reconcileCachedTokens(u) || eventChanged
|
||||||
|
claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID)
|
||||||
|
if claudeMaxOutcome.Simulated {
|
||||||
|
skipAccountTTLOverride = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if eventType == "message_delta" {
|
if eventType == "message_delta" {
|
||||||
if u, ok := event["usage"].(map[string]any); ok {
|
if u, ok := event["usage"].(map[string]any); ok {
|
||||||
eventChanged = reconcileCachedTokens(u) || eventChanged
|
eventChanged = reconcileCachedTokens(u) || eventChanged
|
||||||
|
claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID)
|
||||||
|
if claudeMaxOutcome.Simulated {
|
||||||
|
skipAccountTTLOverride = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类
|
// Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类
|
||||||
if account.IsCacheTTLOverrideEnabled() {
|
if account.IsCacheTTLOverrideEnabled() && !skipAccountTTLOverride {
|
||||||
overrideTarget := account.GetCacheTTLOverrideTarget()
|
overrideTarget := account.GetCacheTTLOverrideTarget()
|
||||||
if eventType == "message_start" {
|
if eventType == "message_start" {
|
||||||
if msg, ok := event["message"].(map[string]any); ok {
|
if msg, ok := event["message"].(map[string]any); ok {
|
||||||
@@ -6187,6 +6219,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
|
lastDataAt = time.Now()
|
||||||
}
|
}
|
||||||
if data != "" {
|
if data != "" {
|
||||||
if firstTokenMs == nil && data != "[DONE]" {
|
if firstTokenMs == nil && data != "[DONE]" {
|
||||||
@@ -6220,6 +6253,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
}
|
}
|
||||||
sendErrorEvent("stream_timeout")
|
sendErrorEvent("stream_timeout")
|
||||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
|
|
||||||
|
case <-keepaliveCh:
|
||||||
|
if clientDisconnected {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if time.Since(lastDataAt) < keepaliveInterval {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// SSE ping 事件:Anthropic 原生格式,客户端会正确处理,
|
||||||
|
// 同时保持连接活跃防止 Cloudflare Tunnel 等代理断开
|
||||||
|
if _, werr := fmt.Fprint(w, "event: ping\ndata: {\"type\": \"ping\"}\n\n"); werr != nil {
|
||||||
|
clientDisconnected = true
|
||||||
|
logger.LegacyPrintf("service.gateway", "Client disconnected during keepalive ping, continuing to drain upstream for billing")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -6491,8 +6540,13 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
claudeMaxOutcome := applyClaudeMaxSimulationToUsage(ctx, &response.Usage, originalModel, account.ID)
|
||||||
|
if claudeMaxOutcome.Simulated {
|
||||||
|
body = rewriteClaudeUsageJSONBytes(body, response.Usage)
|
||||||
|
}
|
||||||
|
|
||||||
// Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类
|
// Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类
|
||||||
if account.IsCacheTTLOverrideEnabled() {
|
if account.IsCacheTTLOverrideEnabled() && !claudeMaxOutcome.Simulated {
|
||||||
overrideTarget := account.GetCacheTTLOverrideTarget()
|
overrideTarget := account.GetCacheTTLOverrideTarget()
|
||||||
if applyCacheTTLOverride(&response.Usage, overrideTarget) {
|
if applyCacheTTLOverride(&response.Usage, overrideTarget) {
|
||||||
// 同步更新 body JSON 中的嵌套 cache_creation 对象
|
// 同步更新 body JSON 中的嵌套 cache_creation 对象
|
||||||
@@ -6558,6 +6612,7 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
|
|||||||
// RecordUsageInput 记录使用量的输入参数
|
// RecordUsageInput 记录使用量的输入参数
|
||||||
type RecordUsageInput struct {
|
type RecordUsageInput struct {
|
||||||
Result *ForwardResult
|
Result *ForwardResult
|
||||||
|
ParsedRequest *ParsedRequest
|
||||||
APIKey *APIKey
|
APIKey *APIKey
|
||||||
User *User
|
User *User
|
||||||
Account *Account
|
Account *Account
|
||||||
@@ -6674,9 +6729,19 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
result.Usage.InputTokens = 0
|
result.Usage.InputTokens = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Claude Max cache billing policy (group-level):
|
||||||
|
// - GatewayService 路径: Forward 已改写 usage(含 cache tokens)→ apply 见到 cache tokens 跳过 → simulatedClaudeMax=true(通过第二条件)
|
||||||
|
// - Antigravity 路径: Forward 中 hook 改写了客户端 SSE,但 ForwardResult.Usage 是原始值 → apply 实际执行模拟 → simulatedClaudeMax=true
|
||||||
|
var apiKeyGroup *Group
|
||||||
|
if apiKey != nil {
|
||||||
|
apiKeyGroup = apiKey.Group
|
||||||
|
}
|
||||||
|
claudeMaxOutcome := applyClaudeMaxCacheBillingPolicyToUsage(&result.Usage, input.ParsedRequest, apiKeyGroup, result.Model, account.ID)
|
||||||
|
simulatedClaudeMax := claudeMaxOutcome.Simulated ||
|
||||||
|
(shouldApplyClaudeMaxBillingRulesForUsage(apiKeyGroup, result.Model, input.ParsedRequest) && hasCacheCreationTokens(result.Usage))
|
||||||
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
|
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
|
||||||
cacheTTLOverridden := false
|
cacheTTLOverridden := false
|
||||||
if account.IsCacheTTLOverrideEnabled() {
|
if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax {
|
||||||
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
|
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
|
||||||
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
|
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,75 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCleanGeminiNativeThoughtSignatures_ReplacesNestedThoughtSignatures(t *testing.T) {
|
||||||
|
input := []byte(`{
|
||||||
|
"contents": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"parts": [{"text": "hello"}]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "model",
|
||||||
|
"parts": [
|
||||||
|
{"text": "thinking", "thought": true, "thoughtSignature": "sig_1"},
|
||||||
|
{"functionCall": {"name": "toolA", "args": {"k": "v"}}, "thoughtSignature": "sig_2"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"cachedContent": {
|
||||||
|
"parts": [{"text": "cached", "thoughtSignature": "sig_3"}]
|
||||||
|
},
|
||||||
|
"signature": "keep_me"
|
||||||
|
}`)
|
||||||
|
|
||||||
|
cleaned := CleanGeminiNativeThoughtSignatures(input)
|
||||||
|
|
||||||
|
var got map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(cleaned, &got))
|
||||||
|
|
||||||
|
require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_1"`)
|
||||||
|
require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_2"`)
|
||||||
|
require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_3"`)
|
||||||
|
require.Contains(t, string(cleaned), `"thoughtSignature":"`+antigravity.DummyThoughtSignature+`"`)
|
||||||
|
require.Contains(t, string(cleaned), `"signature":"keep_me"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanGeminiNativeThoughtSignatures_InvalidJSONReturnsOriginal(t *testing.T) {
|
||||||
|
input := []byte(`{"contents":[invalid-json]}`)
|
||||||
|
|
||||||
|
cleaned := CleanGeminiNativeThoughtSignatures(input)
|
||||||
|
|
||||||
|
require.Equal(t, input, cleaned)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReplaceThoughtSignaturesRecursive_OnlyReplacesTargetField(t *testing.T) {
|
||||||
|
input := map[string]any{
|
||||||
|
"thoughtSignature": "sig_root",
|
||||||
|
"signature": "keep_signature",
|
||||||
|
"nested": []any{
|
||||||
|
map[string]any{
|
||||||
|
"thoughtSignature": "sig_nested",
|
||||||
|
"signature": "keep_nested_signature",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
got, ok := replaceThoughtSignaturesRecursive(input).(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, antigravity.DummyThoughtSignature, got["thoughtSignature"])
|
||||||
|
require.Equal(t, "keep_signature", got["signature"])
|
||||||
|
|
||||||
|
nested, ok := got["nested"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
nestedMap, ok := nested[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, antigravity.DummyThoughtSignature, nestedMap["thoughtSignature"])
|
||||||
|
require.Equal(t, "keep_nested_signature", nestedMap["signature"])
|
||||||
|
}
|
||||||
@@ -50,6 +50,9 @@ type Group struct {
|
|||||||
// MCP XML 协议注入开关(仅 antigravity 平台使用)
|
// MCP XML 协议注入开关(仅 antigravity 平台使用)
|
||||||
MCPXMLInject bool
|
MCPXMLInject bool
|
||||||
|
|
||||||
|
// Claude usage 模拟开关:将无写缓存 usage 模拟为 claude-max 风格
|
||||||
|
SimulateClaudeMaxEnabled bool
|
||||||
|
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
// 可选值: claude, gemini_text, gemini_image
|
// 可选值: claude, gemini_text, gemini_image
|
||||||
SupportedModelScopes []string
|
SupportedModelScopes []string
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -146,6 +147,22 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
|
|||||||
input = filterCodexInput(input, needsToolContinuation)
|
input = filterCodexInput(input, needsToolContinuation)
|
||||||
reqBody["input"] = input
|
reqBody["input"] = input
|
||||||
result.Modified = true
|
result.Modified = true
|
||||||
|
} else if inputStr, ok := reqBody["input"].(string); ok {
|
||||||
|
// ChatGPT codex endpoint requires input to be a list, not a string.
|
||||||
|
// Convert string input to the expected message array format.
|
||||||
|
trimmed := strings.TrimSpace(inputStr)
|
||||||
|
if trimmed != "" {
|
||||||
|
reqBody["input"] = []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "message",
|
||||||
|
"role": "user",
|
||||||
|
"content": inputStr,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
reqBody["input"] = []any{}
|
||||||
|
}
|
||||||
|
result.Modified = true
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -210,6 +227,29 @@ func normalizeCodexModel(model string) string {
|
|||||||
return "gpt-5.1"
|
return "gpt-5.1"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SupportsVerbosity(model string) bool {
|
||||||
|
if !strings.HasPrefix(model, "gpt-") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var major, minor int
|
||||||
|
n, _ := fmt.Sscanf(model, "gpt-%d.%d", &major, &minor)
|
||||||
|
|
||||||
|
if major > 5 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if major < 5 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// gpt-5
|
||||||
|
if n == 1 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return minor >= 3
|
||||||
|
}
|
||||||
|
|
||||||
func getNormalizedCodexModel(modelID string) string {
|
func getNormalizedCodexModel(modelID string) string {
|
||||||
if modelID == "" {
|
if modelID == "" {
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -249,6 +249,50 @@ func TestApplyCodexOAuthTransform_NonCodexCLI_PreservesExistingInstructions(t *t
|
|||||||
require.Equal(t, "old instructions", instructions)
|
require.Equal(t, "old instructions", instructions)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_StringInputConvertedToArray(t *testing.T) {
|
||||||
|
reqBody := map[string]any{"model": "gpt-5.4", "input": "Hello, world!"}
|
||||||
|
result := applyCodexOAuthTransform(reqBody, false, false)
|
||||||
|
require.True(t, result.Modified)
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, input, 1)
|
||||||
|
msg, ok := input[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "message", msg["type"])
|
||||||
|
require.Equal(t, "user", msg["role"])
|
||||||
|
require.Equal(t, "Hello, world!", msg["content"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_EmptyStringInputBecomesEmptyArray(t *testing.T) {
|
||||||
|
reqBody := map[string]any{"model": "gpt-5.4", "input": ""}
|
||||||
|
result := applyCodexOAuthTransform(reqBody, false, false)
|
||||||
|
require.True(t, result.Modified)
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, input, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_WhitespaceStringInputBecomesEmptyArray(t *testing.T) {
|
||||||
|
reqBody := map[string]any{"model": "gpt-5.4", "input": " "}
|
||||||
|
result := applyCodexOAuthTransform(reqBody, false, false)
|
||||||
|
require.True(t, result.Modified)
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, input, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_StringInputWithToolsField(t *testing.T) {
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"model": "gpt-5.4",
|
||||||
|
"input": "Run the tests",
|
||||||
|
"tools": []any{map[string]any{"type": "function", "name": "bash"}},
|
||||||
|
}
|
||||||
|
applyCodexOAuthTransform(reqBody, false, false)
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, input, 1)
|
||||||
|
}
|
||||||
|
|
||||||
func TestIsInstructionsEmpty(t *testing.T) {
|
func TestIsInstructionsEmpty(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
512
backend/internal/service/openai_gateway_chat_completions.go
Normal file
512
backend/internal/service/openai_gateway_chat_completions.go
Normal file
@@ -0,0 +1,512 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ForwardAsChatCompletions accepts a Chat Completions request body, converts it
|
||||||
|
// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts
|
||||||
|
// the response back to Chat Completions format. All account types (OAuth and API
|
||||||
|
// Key) go through the Responses API conversion path since the upstream only
|
||||||
|
// exposes the /v1/responses endpoint.
|
||||||
|
func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||||
|
ctx context.Context,
|
||||||
|
c *gin.Context,
|
||||||
|
account *Account,
|
||||||
|
body []byte,
|
||||||
|
promptCacheKey string,
|
||||||
|
defaultMappedModel string,
|
||||||
|
) (*OpenAIForwardResult, error) {
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
// 1. Parse Chat Completions request
|
||||||
|
var chatReq apicompat.ChatCompletionsRequest
|
||||||
|
if err := json.Unmarshal(body, &chatReq); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse chat completions request: %w", err)
|
||||||
|
}
|
||||||
|
originalModel := chatReq.Model
|
||||||
|
clientStream := chatReq.Stream
|
||||||
|
includeUsage := chatReq.StreamOptions != nil && chatReq.StreamOptions.IncludeUsage
|
||||||
|
|
||||||
|
// 2. Convert to Responses and forward
|
||||||
|
// ChatCompletionsToResponses always sets Stream=true (upstream always streams).
|
||||||
|
responsesReq, err := apicompat.ChatCompletionsToResponses(&chatReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Model mapping
|
||||||
|
mappedModel := account.GetMappedModel(originalModel)
|
||||||
|
if mappedModel == originalModel && defaultMappedModel != "" {
|
||||||
|
mappedModel = defaultMappedModel
|
||||||
|
}
|
||||||
|
responsesReq.Model = mappedModel
|
||||||
|
|
||||||
|
logger.L().Debug("openai chat_completions: model mapping applied",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.String("original_model", originalModel),
|
||||||
|
zap.String("mapped_model", mappedModel),
|
||||||
|
zap.Bool("stream", clientStream),
|
||||||
|
)
|
||||||
|
|
||||||
|
// 4. Marshal Responses request body, then apply OAuth codex transform
|
||||||
|
responsesBody, err := json.Marshal(responsesReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal responses request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if account.Type == AccountTypeOAuth {
|
||||||
|
var reqBody map[string]any
|
||||||
|
if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
|
||||||
|
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
|
||||||
|
}
|
||||||
|
codexResult := applyCodexOAuthTransform(reqBody, false, false)
|
||||||
|
if codexResult.PromptCacheKey != "" {
|
||||||
|
promptCacheKey = codexResult.PromptCacheKey
|
||||||
|
} else if promptCacheKey != "" {
|
||||||
|
reqBody["prompt_cache_key"] = promptCacheKey
|
||||||
|
}
|
||||||
|
responsesBody, err = json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("remarshal after codex transform: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Get access token
|
||||||
|
token, _, err := s.GetAccessToken(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get access token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. Build upstream request
|
||||||
|
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if promptCacheKey != "" {
|
||||||
|
upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 7. Send request
|
||||||
|
proxyURL := ""
|
||||||
|
if account.Proxy != nil {
|
||||||
|
proxyURL = account.Proxy.URL()
|
||||||
|
}
|
||||||
|
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||||
|
if err != nil {
|
||||||
|
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||||
|
setOpsUpstreamError(c, 0, safeErr, "")
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: 0,
|
||||||
|
Kind: "request_error",
|
||||||
|
Message: safeErr,
|
||||||
|
})
|
||||||
|
writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
|
||||||
|
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
// 8. Handle error response with failover
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||||
|
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "failover",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
if s.rateLimitService != nil {
|
||||||
|
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
|
}
|
||||||
|
return nil, &UpstreamFailoverError{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
ResponseBody: respBody,
|
||||||
|
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.handleChatCompletionsErrorResponse(resp, c, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 9. Handle normal response
|
||||||
|
var result *OpenAIForwardResult
|
||||||
|
var handleErr error
|
||||||
|
if clientStream {
|
||||||
|
result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, mappedModel, includeUsage, startTime)
|
||||||
|
} else {
|
||||||
|
result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Propagate ServiceTier and ReasoningEffort to result for billing
|
||||||
|
if handleErr == nil && result != nil {
|
||||||
|
if responsesReq.ServiceTier != "" {
|
||||||
|
st := responsesReq.ServiceTier
|
||||||
|
result.ServiceTier = &st
|
||||||
|
}
|
||||||
|
if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" {
|
||||||
|
re := responsesReq.Reasoning.Effort
|
||||||
|
result.ReasoningEffort = &re
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
|
||||||
|
if handleErr == nil && account.Type == AccountTypeOAuth {
|
||||||
|
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||||
|
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, handleErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleChatCompletionsErrorResponse reads an upstream error and returns it in
|
||||||
|
// OpenAI Chat Completions error format.
|
||||||
|
func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(
|
||||||
|
resp *http.Response,
|
||||||
|
c *gin.Context,
|
||||||
|
account *Account,
|
||||||
|
) (*OpenAIForwardResult, error) {
|
||||||
|
return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleChatBufferedStreamingResponse reads all Responses SSE events from the
|
||||||
|
// upstream, finds the terminal event, converts to a Chat Completions JSON
|
||||||
|
// response, and writes it to the client.
|
||||||
|
func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
|
||||||
|
resp *http.Response,
|
||||||
|
c *gin.Context,
|
||||||
|
originalModel string,
|
||||||
|
mappedModel string,
|
||||||
|
startTime time.Time,
|
||||||
|
) (*OpenAIForwardResult, error) {
|
||||||
|
requestID := resp.Header.Get("x-request-id")
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
maxLineSize := defaultMaxLineSize
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||||
|
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||||
|
}
|
||||||
|
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||||
|
|
||||||
|
var finalResponse *apicompat.ResponsesResponse
|
||||||
|
var usage OpenAIUsage
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
payload := line[6:]
|
||||||
|
|
||||||
|
var event apicompat.ResponsesStreamEvent
|
||||||
|
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||||
|
logger.L().Warn("openai chat_completions buffered: failed to parse event",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
||||||
|
event.Response != nil {
|
||||||
|
finalResponse = event.Response
|
||||||
|
if event.Response.Usage != nil {
|
||||||
|
usage = OpenAIUsage{
|
||||||
|
InputTokens: event.Response.Usage.InputTokens,
|
||||||
|
OutputTokens: event.Response.Usage.OutputTokens,
|
||||||
|
}
|
||||||
|
if event.Response.Usage.InputTokensDetails != nil {
|
||||||
|
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
logger.L().Warn("openai chat_completions buffered: read error",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if finalResponse == nil {
|
||||||
|
writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event")
|
||||||
|
return nil, fmt.Errorf("upstream stream ended without terminal event")
|
||||||
|
}
|
||||||
|
|
||||||
|
chatResp := apicompat.ResponsesToChatCompletions(finalResponse, originalModel)
|
||||||
|
|
||||||
|
if s.responseHeaderFilter != nil {
|
||||||
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, chatResp)
|
||||||
|
|
||||||
|
return &OpenAIForwardResult{
|
||||||
|
RequestID: requestID,
|
||||||
|
Usage: usage,
|
||||||
|
Model: originalModel,
|
||||||
|
BillingModel: mappedModel,
|
||||||
|
Stream: false,
|
||||||
|
Duration: time.Since(startTime),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleChatStreamingResponse reads Responses SSE events from upstream,
|
||||||
|
// converts each to Chat Completions SSE chunks, and writes them to the client.
|
||||||
|
func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||||
|
resp *http.Response,
|
||||||
|
c *gin.Context,
|
||||||
|
originalModel string,
|
||||||
|
mappedModel string,
|
||||||
|
includeUsage bool,
|
||||||
|
startTime time.Time,
|
||||||
|
) (*OpenAIForwardResult, error) {
|
||||||
|
requestID := resp.Header.Get("x-request-id")
|
||||||
|
|
||||||
|
if s.responseHeaderFilter != nil {
|
||||||
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||||
|
c.Writer.Header().Set("Connection", "keep-alive")
|
||||||
|
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
c.Writer.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
|
state := apicompat.NewResponsesEventToChatState()
|
||||||
|
state.Model = originalModel
|
||||||
|
state.IncludeUsage = includeUsage
|
||||||
|
|
||||||
|
var usage OpenAIUsage
|
||||||
|
var firstTokenMs *int
|
||||||
|
firstChunk := true
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
maxLineSize := defaultMaxLineSize
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||||
|
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||||
|
}
|
||||||
|
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||||
|
|
||||||
|
resultWithUsage := func() *OpenAIForwardResult {
|
||||||
|
return &OpenAIForwardResult{
|
||||||
|
RequestID: requestID,
|
||||||
|
Usage: usage,
|
||||||
|
Model: originalModel,
|
||||||
|
BillingModel: mappedModel,
|
||||||
|
Stream: true,
|
||||||
|
Duration: time.Since(startTime),
|
||||||
|
FirstTokenMs: firstTokenMs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
processDataLine := func(payload string) bool {
|
||||||
|
if firstChunk {
|
||||||
|
firstChunk = false
|
||||||
|
ms := int(time.Since(startTime).Milliseconds())
|
||||||
|
firstTokenMs = &ms
|
||||||
|
}
|
||||||
|
|
||||||
|
var event apicompat.ResponsesStreamEvent
|
||||||
|
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||||
|
logger.L().Warn("openai chat_completions stream: failed to parse event",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract usage from completion events
|
||||||
|
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
||||||
|
event.Response != nil && event.Response.Usage != nil {
|
||||||
|
usage = OpenAIUsage{
|
||||||
|
InputTokens: event.Response.Usage.InputTokens,
|
||||||
|
OutputTokens: event.Response.Usage.OutputTokens,
|
||||||
|
}
|
||||||
|
if event.Response.Usage.InputTokensDetails != nil {
|
||||||
|
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
chunks := apicompat.ResponsesEventToChatChunks(&event, state)
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
sse, err := apicompat.ChatChunkToSSE(chunk)
|
||||||
|
if err != nil {
|
||||||
|
logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||||
|
logger.L().Info("openai chat_completions stream: client disconnected",
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(chunks) > 0 {
|
||||||
|
c.Writer.Flush()
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
finalizeStream := func() (*OpenAIForwardResult, error) {
|
||||||
|
if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 {
|
||||||
|
for _, chunk := range finalChunks {
|
||||||
|
sse, err := apicompat.ChatChunkToSSE(chunk)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fmt.Fprint(c.Writer, sse) //nolint:errcheck
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Send [DONE] sentinel
|
||||||
|
fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck
|
||||||
|
c.Writer.Flush()
|
||||||
|
return resultWithUsage(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
handleScanErr := func(err error) {
|
||||||
|
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
logger.L().Warn("openai chat_completions stream: read error",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine keepalive interval
|
||||||
|
keepaliveInterval := time.Duration(0)
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||||
|
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
// No keepalive: fast synchronous path
|
||||||
|
if keepaliveInterval <= 0 {
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if processDataLine(line[6:]) {
|
||||||
|
return resultWithUsage(), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
handleScanErr(scanner.Err())
|
||||||
|
return finalizeStream()
|
||||||
|
}
|
||||||
|
|
||||||
|
// With keepalive: goroutine + channel + select
|
||||||
|
type scanEvent struct {
|
||||||
|
line string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
events := make(chan scanEvent, 16)
|
||||||
|
done := make(chan struct{})
|
||||||
|
sendEvent := func(ev scanEvent) bool {
|
||||||
|
select {
|
||||||
|
case events <- ev:
|
||||||
|
return true
|
||||||
|
case <-done:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
defer close(events)
|
||||||
|
for scanner.Scan() {
|
||||||
|
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
_ = sendEvent(scanEvent{err: err})
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer close(done)
|
||||||
|
|
||||||
|
keepaliveTicker := time.NewTicker(keepaliveInterval)
|
||||||
|
defer keepaliveTicker.Stop()
|
||||||
|
lastDataAt := time.Now()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case ev, ok := <-events:
|
||||||
|
if !ok {
|
||||||
|
return finalizeStream()
|
||||||
|
}
|
||||||
|
if ev.err != nil {
|
||||||
|
handleScanErr(ev.err)
|
||||||
|
return finalizeStream()
|
||||||
|
}
|
||||||
|
lastDataAt = time.Now()
|
||||||
|
line := ev.line
|
||||||
|
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if processDataLine(line[6:]) {
|
||||||
|
return resultWithUsage(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-keepaliveTicker.C:
|
||||||
|
if time.Since(lastDataAt) < keepaliveInterval {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Send SSE comment as keepalive
|
||||||
|
if _, err := fmt.Fprint(c.Writer, ":\n\n"); err != nil {
|
||||||
|
logger.L().Info("openai chat_completions stream: client disconnected during keepalive",
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
return resultWithUsage(), nil
|
||||||
|
}
|
||||||
|
c.Writer.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeChatCompletionsError writes an error response in OpenAI Chat Completions format.
|
||||||
|
func writeChatCompletionsError(c *gin.Context, statusCode int, errType, message string) {
|
||||||
|
c.JSON(statusCode, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": errType,
|
||||||
|
"message": message,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -172,7 +172,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
|||||||
return nil, &UpstreamFailoverError{
|
return nil, &UpstreamFailoverError{
|
||||||
StatusCode: resp.StatusCode,
|
StatusCode: resp.StatusCode,
|
||||||
ResponseBody: respBody,
|
ResponseBody: respBody,
|
||||||
RetryableOnSameAccount: account.IsPoolMode() && isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody),
|
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Non-failover error: return Anthropic-formatted error to client
|
// Non-failover error: return Anthropic-formatted error to client
|
||||||
@@ -219,54 +219,7 @@ func (s *OpenAIGatewayService) handleAnthropicErrorResponse(
|
|||||||
c *gin.Context,
|
c *gin.Context,
|
||||||
account *Account,
|
account *Account,
|
||||||
) (*OpenAIForwardResult, error) {
|
) (*OpenAIForwardResult, error) {
|
||||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
return s.handleCompatErrorResponse(resp, c, account, writeAnthropicError)
|
||||||
|
|
||||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
|
||||||
if upstreamMsg == "" {
|
|
||||||
upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
|
||||||
|
|
||||||
// Record upstream error details for ops logging
|
|
||||||
upstreamDetail := ""
|
|
||||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
|
||||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
|
||||||
if maxBytes <= 0 {
|
|
||||||
maxBytes = 2048
|
|
||||||
}
|
|
||||||
upstreamDetail = truncateString(string(body), maxBytes)
|
|
||||||
}
|
|
||||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
|
||||||
|
|
||||||
// Apply error passthrough rules (matches handleErrorResponse pattern in openai_gateway_service.go)
|
|
||||||
if status, errType, errMsg, matched := applyErrorPassthroughRule(
|
|
||||||
c, account.Platform, resp.StatusCode, body,
|
|
||||||
http.StatusBadGateway, "api_error", "Upstream request failed",
|
|
||||||
); matched {
|
|
||||||
writeAnthropicError(c, status, errType, errMsg)
|
|
||||||
if upstreamMsg == "" {
|
|
||||||
upstreamMsg = errMsg
|
|
||||||
}
|
|
||||||
if upstreamMsg == "" {
|
|
||||||
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
|
|
||||||
}
|
|
||||||
|
|
||||||
errType := "api_error"
|
|
||||||
switch {
|
|
||||||
case resp.StatusCode == 400:
|
|
||||||
errType = "invalid_request_error"
|
|
||||||
case resp.StatusCode == 404:
|
|
||||||
errType = "not_found_error"
|
|
||||||
case resp.StatusCode == 429:
|
|
||||||
errType = "rate_limit_error"
|
|
||||||
case resp.StatusCode >= 500:
|
|
||||||
errType = "api_error"
|
|
||||||
}
|
|
||||||
|
|
||||||
writeAnthropicError(c, resp.StatusCode, errType, upstreamMsg)
|
|
||||||
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleAnthropicBufferedStreamingResponse reads all Responses SSE events from
|
// handleAnthropicBufferedStreamingResponse reads all Responses SSE events from
|
||||||
|
|||||||
@@ -52,6 +52,8 @@ const (
|
|||||||
openAIWSRetryJitterRatioDefault = 0.2
|
openAIWSRetryJitterRatioDefault = 0.2
|
||||||
openAICompactSessionSeedKey = "openai_compact_session_seed"
|
openAICompactSessionSeedKey = "openai_compact_session_seed"
|
||||||
codexCLIVersion = "0.104.0"
|
codexCLIVersion = "0.104.0"
|
||||||
|
// Codex 限额快照仅用于后台展示/诊断,不需要每个成功请求都立即落库。
|
||||||
|
openAICodexSnapshotPersistMinInterval = 30 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// OpenAI allowed headers whitelist (for non-passthrough).
|
// OpenAI allowed headers whitelist (for non-passthrough).
|
||||||
@@ -255,6 +257,46 @@ type openAIWSRetryMetrics struct {
|
|||||||
nonRetryableFastFallback atomic.Int64
|
nonRetryableFastFallback atomic.Int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type accountWriteThrottle struct {
|
||||||
|
minInterval time.Duration
|
||||||
|
mu sync.Mutex
|
||||||
|
lastByID map[int64]time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAccountWriteThrottle(minInterval time.Duration) *accountWriteThrottle {
|
||||||
|
return &accountWriteThrottle{
|
||||||
|
minInterval: minInterval,
|
||||||
|
lastByID: make(map[int64]time.Time),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *accountWriteThrottle) Allow(id int64, now time.Time) bool {
|
||||||
|
if t == nil || id <= 0 || t.minInterval <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
|
||||||
|
if last, ok := t.lastByID[id]; ok && now.Sub(last) < t.minInterval {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
t.lastByID[id] = now
|
||||||
|
|
||||||
|
if len(t.lastByID) > 4096 {
|
||||||
|
cutoff := now.Add(-4 * t.minInterval)
|
||||||
|
for accountID, writtenAt := range t.lastByID {
|
||||||
|
if writtenAt.Before(cutoff) {
|
||||||
|
delete(t.lastByID, accountID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultOpenAICodexSnapshotPersistThrottle = newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval)
|
||||||
|
|
||||||
// OpenAIGatewayService handles OpenAI API gateway operations
|
// OpenAIGatewayService handles OpenAI API gateway operations
|
||||||
type OpenAIGatewayService struct {
|
type OpenAIGatewayService struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
@@ -289,6 +331,7 @@ type OpenAIGatewayService struct {
|
|||||||
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
|
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
|
||||||
openaiWSRetryMetrics openAIWSRetryMetrics
|
openaiWSRetryMetrics openAIWSRetryMetrics
|
||||||
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
||||||
|
codexSnapshotThrottle *accountWriteThrottle
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOpenAIGatewayService creates a new OpenAIGatewayService
|
// NewOpenAIGatewayService creates a new OpenAIGatewayService
|
||||||
@@ -329,17 +372,25 @@ func NewOpenAIGatewayService(
|
|||||||
nil,
|
nil,
|
||||||
"service.openai_gateway",
|
"service.openai_gateway",
|
||||||
),
|
),
|
||||||
httpUpstream: httpUpstream,
|
httpUpstream: httpUpstream,
|
||||||
deferredService: deferredService,
|
deferredService: deferredService,
|
||||||
openAITokenProvider: openAITokenProvider,
|
openAITokenProvider: openAITokenProvider,
|
||||||
toolCorrector: NewCodexToolCorrector(),
|
toolCorrector: NewCodexToolCorrector(),
|
||||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||||
|
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
|
||||||
}
|
}
|
||||||
svc.logOpenAIWSModeBootstrap()
|
svc.logOpenAIWSModeBootstrap()
|
||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle {
|
||||||
|
if s != nil && s.codexSnapshotThrottle != nil {
|
||||||
|
return s.codexSnapshotThrottle
|
||||||
|
}
|
||||||
|
return defaultOpenAICodexSnapshotPersistThrottle
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) billingDeps() *billingDeps {
|
func (s *OpenAIGatewayService) billingDeps() *billingDeps {
|
||||||
return &billingDeps{
|
return &billingDeps{
|
||||||
accountRepo: s.accountRepo,
|
accountRepo: s.accountRepo,
|
||||||
@@ -1716,6 +1767,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
bodyModified = true
|
bodyModified = true
|
||||||
markPatchSet("model", normalizedModel)
|
markPatchSet("model", normalizedModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 移除 gpt-5.2-codex 以下的版本 verbosity 参数
|
||||||
|
// 确保高版本模型向低版本模型映射不报错
|
||||||
|
if !SupportsVerbosity(normalizedModel) {
|
||||||
|
if text, ok := reqBody["text"].(map[string]any); ok {
|
||||||
|
delete(text, "verbosity")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 规范化 reasoning.effort 参数(minimal -> none),与上游允许值对齐。
|
// 规范化 reasoning.effort 参数(minimal -> none),与上游允许值对齐。
|
||||||
@@ -2947,6 +3006,120 @@ func (s *OpenAIGatewayService) handleErrorResponse(
|
|||||||
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
|
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// compatErrorWriter is the signature for format-specific error writers used by
|
||||||
|
// the compat paths (Chat Completions and Anthropic Messages).
|
||||||
|
type compatErrorWriter func(c *gin.Context, statusCode int, errType, message string)
|
||||||
|
|
||||||
|
// handleCompatErrorResponse is the shared non-failover error handler for the
|
||||||
|
// Chat Completions and Anthropic Messages compat paths. It mirrors the logic of
|
||||||
|
// handleErrorResponse (passthrough rules, ShouldHandleErrorCode, rate-limit
|
||||||
|
// tracking, secondary failover) but delegates the final error write to the
|
||||||
|
// format-specific writer function.
|
||||||
|
func (s *OpenAIGatewayService) handleCompatErrorResponse(
|
||||||
|
resp *http.Response,
|
||||||
|
c *gin.Context,
|
||||||
|
account *Account,
|
||||||
|
writeError compatErrorWriter,
|
||||||
|
) (*OpenAIForwardResult, error) {
|
||||||
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||||
|
if upstreamMsg == "" {
|
||||||
|
upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(body), maxBytes)
|
||||||
|
}
|
||||||
|
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||||
|
|
||||||
|
// Apply error passthrough rules
|
||||||
|
if status, errType, errMsg, matched := applyErrorPassthroughRule(
|
||||||
|
c, account.Platform, resp.StatusCode, body,
|
||||||
|
http.StatusBadGateway, "api_error", "Upstream request failed",
|
||||||
|
); matched {
|
||||||
|
writeError(c, status, errType, errMsg)
|
||||||
|
if upstreamMsg == "" {
|
||||||
|
upstreamMsg = errMsg
|
||||||
|
}
|
||||||
|
if upstreamMsg == "" {
|
||||||
|
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check custom error codes — if the account does not handle this status,
|
||||||
|
// return a generic error without exposing upstream details.
|
||||||
|
if !account.ShouldHandleErrorCode(resp.StatusCode) {
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "http_error",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
writeError(c, http.StatusInternalServerError, "api_error", "Upstream gateway error")
|
||||||
|
if upstreamMsg == "" {
|
||||||
|
return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("upstream error: %d (not in custom error codes) message=%s", resp.StatusCode, upstreamMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track rate limits and decide whether to trigger secondary failover.
|
||||||
|
shouldDisable := false
|
||||||
|
if s.rateLimitService != nil {
|
||||||
|
shouldDisable = s.rateLimitService.HandleUpstreamError(
|
||||||
|
c.Request.Context(), account, resp.StatusCode, resp.Header, body,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
kind := "http_error"
|
||||||
|
if shouldDisable {
|
||||||
|
kind = "failover"
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: kind,
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
if shouldDisable {
|
||||||
|
return nil, &UpstreamFailoverError{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
ResponseBody: body,
|
||||||
|
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map status code to error type and write response
|
||||||
|
errType := "api_error"
|
||||||
|
switch {
|
||||||
|
case resp.StatusCode == 400:
|
||||||
|
errType = "invalid_request_error"
|
||||||
|
case resp.StatusCode == 404:
|
||||||
|
errType = "not_found_error"
|
||||||
|
case resp.StatusCode == 429:
|
||||||
|
errType = "rate_limit_error"
|
||||||
|
case resp.StatusCode >= 500:
|
||||||
|
errType = "api_error"
|
||||||
|
}
|
||||||
|
|
||||||
|
writeError(c, resp.StatusCode, errType, upstreamMsg)
|
||||||
|
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
|
||||||
|
}
|
||||||
|
|
||||||
// openaiStreamingResult streaming response result
|
// openaiStreamingResult streaming response result
|
||||||
type openaiStreamingResult struct {
|
type openaiStreamingResult struct {
|
||||||
usage *OpenAIUsage
|
usage *OpenAIUsage
|
||||||
@@ -4050,11 +4223,15 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
|
|||||||
if len(updates) == 0 && resetAt == nil {
|
if len(updates) == 0 && resetAt == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
shouldPersistUpdates := len(updates) > 0 && s.getCodexSnapshotThrottle().Allow(accountID, now)
|
||||||
|
if !shouldPersistUpdates && resetAt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if len(updates) > 0 {
|
if shouldPersistUpdates {
|
||||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||||
}
|
}
|
||||||
if resetAt != nil {
|
if resetAt != nil {
|
||||||
|
|||||||
@@ -405,6 +405,40 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesN
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ThrottlesExtraWrites(t *testing.T) {
|
||||||
|
repo := &openAICodexSnapshotAsyncRepo{
|
||||||
|
updateExtraCh: make(chan map[string]any, 2),
|
||||||
|
rateLimitCh: make(chan time.Time, 2),
|
||||||
|
}
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
codexSnapshotThrottle: newAccountWriteThrottle(time.Hour),
|
||||||
|
}
|
||||||
|
snapshot := &OpenAICodexUsageSnapshot{
|
||||||
|
PrimaryUsedPercent: ptrFloat64WS(94),
|
||||||
|
PrimaryResetAfterSeconds: ptrIntWS(3600),
|
||||||
|
PrimaryWindowMinutes: ptrIntWS(10080),
|
||||||
|
SecondaryUsedPercent: ptrFloat64WS(22),
|
||||||
|
SecondaryResetAfterSeconds: ptrIntWS(1200),
|
||||||
|
SecondaryWindowMinutes: ptrIntWS(300),
|
||||||
|
}
|
||||||
|
|
||||||
|
svc.updateCodexUsageSnapshot(context.Background(), 777, snapshot)
|
||||||
|
svc.updateCodexUsageSnapshot(context.Background(), 777, snapshot)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-repo.updateExtraCh:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("等待第一次 codex 快照落库超时")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case updates := <-repo.updateExtraCh:
|
||||||
|
t.Fatalf("unexpected second codex snapshot write: %v", updates)
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func ptrFloat64WS(v float64) *float64 { return &v }
|
func ptrFloat64WS(v float64) *float64 { return &v }
|
||||||
func ptrIntWS(v int) *int { return &v }
|
func ptrIntWS(v int) *int { return &v }
|
||||||
|
|
||||||
|
|||||||
@@ -506,6 +506,48 @@ func (s *OpsAlertEvaluatorService) computeRuleMetric(
|
|||||||
return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
|
return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
|
||||||
return acc.HasError && acc.TempUnschedulableUntil == nil
|
return acc.HasError && acc.TempUnschedulableUntil == nil
|
||||||
})), true
|
})), true
|
||||||
|
case "group_rate_limit_ratio":
|
||||||
|
if groupID == nil || *groupID <= 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
if s == nil || s.opsService == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
|
||||||
|
if err != nil || availability == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
if availability.Group == nil || availability.Group.TotalAccounts <= 0 {
|
||||||
|
return 0, true
|
||||||
|
}
|
||||||
|
return (float64(availability.Group.RateLimitCount) / float64(availability.Group.TotalAccounts)) * 100, true
|
||||||
|
case "account_error_ratio":
|
||||||
|
if s == nil || s.opsService == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
|
||||||
|
if err != nil || availability == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
total := int64(len(availability.Accounts))
|
||||||
|
if total <= 0 {
|
||||||
|
return 0, true
|
||||||
|
}
|
||||||
|
errorCount := countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
|
||||||
|
return acc.HasError && acc.TempUnschedulableUntil == nil
|
||||||
|
})
|
||||||
|
return (float64(errorCount) / float64(total)) * 100, true
|
||||||
|
case "overload_account_count":
|
||||||
|
if s == nil || s.opsService == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
|
||||||
|
if err != nil || availability == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
|
||||||
|
return acc.IsOverloaded
|
||||||
|
})), true
|
||||||
}
|
}
|
||||||
|
|
||||||
overview, err := s.opsRepo.GetDashboardOverview(ctx, &OpsDashboardFilter{
|
overview, err := s.opsRepo.GetDashboardOverview(ctx, &OpsDashboardFilter{
|
||||||
|
|||||||
@@ -64,12 +64,9 @@ func (s *OpsService) getAccountsLoadMapBestEffort(ctx context.Context, accounts
|
|||||||
if acc.ID <= 0 {
|
if acc.ID <= 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
c := acc.Concurrency
|
lf := acc.EffectiveLoadFactor()
|
||||||
if c <= 0 {
|
if prev, ok := unique[acc.ID]; !ok || lf > prev {
|
||||||
c = 1
|
unique[acc.ID] = lf
|
||||||
}
|
|
||||||
if prev, ok := unique[acc.ID]; !ok || c > prev {
|
|
||||||
unique[acc.ID] = c
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -391,7 +391,7 @@ func (c *OpsMetricsCollector) collectConcurrencyQueueDepth(parentCtx context.Con
|
|||||||
}
|
}
|
||||||
batch = append(batch, AccountWithConcurrency{
|
batch = append(batch, AccountWithConcurrency{
|
||||||
ID: acc.ID,
|
ID: acc.ID,
|
||||||
MaxConcurrency: acc.Concurrency,
|
MaxConcurrency: acc.EffectiveLoadFactor(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if len(batch) == 0 {
|
if len(batch) == 0 {
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
type OpsRepository interface {
|
type OpsRepository interface {
|
||||||
InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error)
|
InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error)
|
||||||
|
BatchInsertErrorLogs(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error)
|
||||||
ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error)
|
ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error)
|
||||||
GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error)
|
GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error)
|
||||||
ListRequestDetails(ctx context.Context, filter *OpsRequestDetailFilter) ([]*OpsRequestDetail, int64, error)
|
ListRequestDetails(ctx context.Context, filter *OpsRequestDetailFilter) ([]*OpsRequestDetail, int64, error)
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
|
|
||||||
// opsRepoMock is a test-only OpsRepository implementation with optional function hooks.
|
// opsRepoMock is a test-only OpsRepository implementation with optional function hooks.
|
||||||
type opsRepoMock struct {
|
type opsRepoMock struct {
|
||||||
|
InsertErrorLogFn func(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error)
|
||||||
|
BatchInsertErrorLogsFn func(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error)
|
||||||
BatchInsertSystemLogsFn func(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error)
|
BatchInsertSystemLogsFn func(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error)
|
||||||
ListSystemLogsFn func(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error)
|
ListSystemLogsFn func(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error)
|
||||||
DeleteSystemLogsFn func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error)
|
DeleteSystemLogsFn func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error)
|
||||||
@@ -14,9 +16,19 @@ type opsRepoMock struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *opsRepoMock) InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) {
|
func (m *opsRepoMock) InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) {
|
||||||
|
if m.InsertErrorLogFn != nil {
|
||||||
|
return m.InsertErrorLogFn(ctx, input)
|
||||||
|
}
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *opsRepoMock) BatchInsertErrorLogs(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) {
|
||||||
|
if m.BatchInsertErrorLogsFn != nil {
|
||||||
|
return m.BatchInsertErrorLogsFn(ctx, inputs)
|
||||||
|
}
|
||||||
|
return int64(len(inputs)), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *opsRepoMock) ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error) {
|
func (m *opsRepoMock) ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error) {
|
||||||
return &OpsErrorLogList{Errors: []*OpsErrorLog{}, Page: 1, PageSize: 20}, nil
|
return &OpsErrorLogList{Errors: []*OpsErrorLog{}, Page: 1, PageSize: 20}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -121,14 +121,74 @@ func (s *OpsService) IsMonitoringEnabled(ctx context.Context) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogInput, rawRequestBody []byte) error {
|
func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogInput, rawRequestBody []byte) error {
|
||||||
if entry == nil {
|
prepared, ok, err := s.prepareErrorLogInput(ctx, entry, rawRequestBody)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Ops] RecordError prepare failed: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if _, err := s.opsRepo.InsertErrorLog(ctx, prepared); err != nil {
|
||||||
|
// Never bubble up to gateway; best-effort logging.
|
||||||
|
log.Printf("[Ops] RecordError failed: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpsService) RecordErrorBatch(ctx context.Context, entries []*OpsInsertErrorLogInput) error {
|
||||||
|
if len(entries) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
prepared := make([]*OpsInsertErrorLogInput, 0, len(entries))
|
||||||
|
for _, entry := range entries {
|
||||||
|
item, ok, err := s.prepareErrorLogInput(ctx, entry, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Ops] RecordErrorBatch prepare failed: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
prepared = append(prepared, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(prepared) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(prepared) == 1 {
|
||||||
|
_, err := s.opsRepo.InsertErrorLog(ctx, prepared[0])
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Ops] RecordErrorBatch single insert failed: %v", err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := s.opsRepo.BatchInsertErrorLogs(ctx, prepared); err != nil {
|
||||||
|
log.Printf("[Ops] RecordErrorBatch failed, fallback to single inserts: %v", err)
|
||||||
|
var firstErr error
|
||||||
|
for _, entry := range prepared {
|
||||||
|
if _, insertErr := s.opsRepo.InsertErrorLog(ctx, entry); insertErr != nil {
|
||||||
|
log.Printf("[Ops] RecordErrorBatch fallback insert failed: %v", insertErr)
|
||||||
|
if firstErr == nil {
|
||||||
|
firstErr = insertErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return firstErr
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpsService) prepareErrorLogInput(ctx context.Context, entry *OpsInsertErrorLogInput, rawRequestBody []byte) (*OpsInsertErrorLogInput, bool, error) {
|
||||||
|
if entry == nil {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
if !s.IsMonitoringEnabled(ctx) {
|
if !s.IsMonitoringEnabled(ctx) {
|
||||||
return nil
|
return nil, false, nil
|
||||||
}
|
}
|
||||||
if s.opsRepo == nil {
|
if s.opsRepo == nil {
|
||||||
return nil
|
return nil, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure timestamps are always populated.
|
// Ensure timestamps are always populated.
|
||||||
@@ -185,85 +245,88 @@ func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogIn
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sanitize + serialize upstream error events list.
|
if err := sanitizeOpsUpstreamErrors(entry); err != nil {
|
||||||
if len(entry.UpstreamErrors) > 0 {
|
return nil, false, err
|
||||||
const maxEvents = 32
|
}
|
||||||
events := entry.UpstreamErrors
|
|
||||||
if len(events) > maxEvents {
|
return entry, true, nil
|
||||||
events = events[len(events)-maxEvents:]
|
}
|
||||||
|
|
||||||
|
func sanitizeOpsUpstreamErrors(entry *OpsInsertErrorLogInput) error {
|
||||||
|
if entry == nil || len(entry.UpstreamErrors) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxEvents = 32
|
||||||
|
events := entry.UpstreamErrors
|
||||||
|
if len(events) > maxEvents {
|
||||||
|
events = events[len(events)-maxEvents:]
|
||||||
|
}
|
||||||
|
|
||||||
|
sanitized := make([]*OpsUpstreamErrorEvent, 0, len(events))
|
||||||
|
for _, ev := range events {
|
||||||
|
if ev == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out := *ev
|
||||||
|
|
||||||
|
out.Platform = strings.TrimSpace(out.Platform)
|
||||||
|
out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128)
|
||||||
|
out.Kind = truncateString(strings.TrimSpace(out.Kind), 64)
|
||||||
|
|
||||||
|
if out.AccountID < 0 {
|
||||||
|
out.AccountID = 0
|
||||||
|
}
|
||||||
|
if out.UpstreamStatusCode < 0 {
|
||||||
|
out.UpstreamStatusCode = 0
|
||||||
|
}
|
||||||
|
if out.AtUnixMs < 0 {
|
||||||
|
out.AtUnixMs = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
sanitized := make([]*OpsUpstreamErrorEvent, 0, len(events))
|
msg := sanitizeUpstreamErrorMessage(strings.TrimSpace(out.Message))
|
||||||
for _, ev := range events {
|
msg = truncateString(msg, 2048)
|
||||||
if ev == nil {
|
out.Message = msg
|
||||||
continue
|
|
||||||
}
|
|
||||||
out := *ev
|
|
||||||
|
|
||||||
out.Platform = strings.TrimSpace(out.Platform)
|
detail := strings.TrimSpace(out.Detail)
|
||||||
out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128)
|
if detail != "" {
|
||||||
out.Kind = truncateString(strings.TrimSpace(out.Kind), 64)
|
// Keep upstream detail small; request bodies are not stored here, only upstream error payloads.
|
||||||
|
sanitizedDetail, _ := sanitizeErrorBodyForStorage(detail, opsMaxStoredErrorBodyBytes)
|
||||||
|
out.Detail = sanitizedDetail
|
||||||
|
} else {
|
||||||
|
out.Detail = ""
|
||||||
|
}
|
||||||
|
|
||||||
if out.AccountID < 0 {
|
out.UpstreamRequestBody = strings.TrimSpace(out.UpstreamRequestBody)
|
||||||
out.AccountID = 0
|
if out.UpstreamRequestBody != "" {
|
||||||
}
|
// Reuse the same sanitization/trimming strategy as request body storage.
|
||||||
if out.UpstreamStatusCode < 0 {
|
// Keep it small so it is safe to persist in ops_error_logs JSON.
|
||||||
out.UpstreamStatusCode = 0
|
sanitizedBody, truncated, _ := sanitizeAndTrimRequestBody([]byte(out.UpstreamRequestBody), 10*1024)
|
||||||
}
|
if sanitizedBody != "" {
|
||||||
if out.AtUnixMs < 0 {
|
out.UpstreamRequestBody = sanitizedBody
|
||||||
out.AtUnixMs = 0
|
if truncated {
|
||||||
}
|
out.Kind = strings.TrimSpace(out.Kind)
|
||||||
|
if out.Kind == "" {
|
||||||
msg := sanitizeUpstreamErrorMessage(strings.TrimSpace(out.Message))
|
out.Kind = "upstream"
|
||||||
msg = truncateString(msg, 2048)
|
|
||||||
out.Message = msg
|
|
||||||
|
|
||||||
detail := strings.TrimSpace(out.Detail)
|
|
||||||
if detail != "" {
|
|
||||||
// Keep upstream detail small; request bodies are not stored here, only upstream error payloads.
|
|
||||||
sanitizedDetail, _ := sanitizeErrorBodyForStorage(detail, opsMaxStoredErrorBodyBytes)
|
|
||||||
out.Detail = sanitizedDetail
|
|
||||||
} else {
|
|
||||||
out.Detail = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
out.UpstreamRequestBody = strings.TrimSpace(out.UpstreamRequestBody)
|
|
||||||
if out.UpstreamRequestBody != "" {
|
|
||||||
// Reuse the same sanitization/trimming strategy as request body storage.
|
|
||||||
// Keep it small so it is safe to persist in ops_error_logs JSON.
|
|
||||||
sanitized, truncated, _ := sanitizeAndTrimRequestBody([]byte(out.UpstreamRequestBody), 10*1024)
|
|
||||||
if sanitized != "" {
|
|
||||||
out.UpstreamRequestBody = sanitized
|
|
||||||
if truncated {
|
|
||||||
out.Kind = strings.TrimSpace(out.Kind)
|
|
||||||
if out.Kind == "" {
|
|
||||||
out.Kind = "upstream"
|
|
||||||
}
|
|
||||||
out.Kind = out.Kind + ":request_body_truncated"
|
|
||||||
}
|
}
|
||||||
} else {
|
out.Kind = out.Kind + ":request_body_truncated"
|
||||||
out.UpstreamRequestBody = ""
|
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
out.UpstreamRequestBody = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Drop fully-empty events (can happen if only status code was known).
|
|
||||||
if out.UpstreamStatusCode == 0 && out.Message == "" && out.Detail == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
evCopy := out
|
|
||||||
sanitized = append(sanitized, &evCopy)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
entry.UpstreamErrorsJSON = marshalOpsUpstreamErrors(sanitized)
|
// Drop fully-empty events (can happen if only status code was known).
|
||||||
entry.UpstreamErrors = nil
|
if out.UpstreamStatusCode == 0 && out.Message == "" && out.Detail == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
evCopy := out
|
||||||
|
sanitized = append(sanitized, &evCopy)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := s.opsRepo.InsertErrorLog(ctx, entry); err != nil {
|
entry.UpstreamErrorsJSON = marshalOpsUpstreamErrors(sanitized)
|
||||||
// Never bubble up to gateway; best-effort logging.
|
entry.UpstreamErrors = nil
|
||||||
log.Printf("[Ops] RecordError failed: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
103
backend/internal/service/ops_service_batch_test.go
Normal file
103
backend/internal/service/ops_service_batch_test.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOpsServiceRecordErrorBatch_SanitizesAndBatches(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var captured []*OpsInsertErrorLogInput
|
||||||
|
repo := &opsRepoMock{
|
||||||
|
BatchInsertErrorLogsFn: func(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) {
|
||||||
|
captured = append(captured, inputs...)
|
||||||
|
return int64(len(inputs)), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
|
msg := " upstream failed: https://example.com?access_token=secret-value "
|
||||||
|
detail := `{"authorization":"Bearer secret-token"}`
|
||||||
|
entries := []*OpsInsertErrorLogInput{
|
||||||
|
{
|
||||||
|
ErrorBody: `{"error":"bad","access_token":"secret"}`,
|
||||||
|
UpstreamStatusCode: intPtr(-10),
|
||||||
|
UpstreamErrorMessage: strPtr(msg),
|
||||||
|
UpstreamErrorDetail: strPtr(detail),
|
||||||
|
UpstreamErrors: []*OpsUpstreamErrorEvent{
|
||||||
|
{
|
||||||
|
AccountID: -2,
|
||||||
|
UpstreamStatusCode: 429,
|
||||||
|
Message: " token leaked ",
|
||||||
|
Detail: `{"refresh_token":"secret"}`,
|
||||||
|
UpstreamRequestBody: `{"api_key":"secret","messages":[{"role":"user","content":"hello"}]}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ErrorPhase: "upstream",
|
||||||
|
ErrorType: "upstream_error",
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, svc.RecordErrorBatch(context.Background(), entries))
|
||||||
|
require.Len(t, captured, 2)
|
||||||
|
|
||||||
|
first := captured[0]
|
||||||
|
require.Equal(t, "internal", first.ErrorPhase)
|
||||||
|
require.Equal(t, "api_error", first.ErrorType)
|
||||||
|
require.Nil(t, first.UpstreamStatusCode)
|
||||||
|
require.NotNil(t, first.UpstreamErrorMessage)
|
||||||
|
require.NotContains(t, *first.UpstreamErrorMessage, "secret-value")
|
||||||
|
require.Contains(t, *first.UpstreamErrorMessage, "access_token=***")
|
||||||
|
require.NotNil(t, first.UpstreamErrorDetail)
|
||||||
|
require.NotContains(t, *first.UpstreamErrorDetail, "secret-token")
|
||||||
|
require.NotContains(t, first.ErrorBody, "secret")
|
||||||
|
require.Nil(t, first.UpstreamErrors)
|
||||||
|
require.NotNil(t, first.UpstreamErrorsJSON)
|
||||||
|
require.NotContains(t, *first.UpstreamErrorsJSON, "secret")
|
||||||
|
require.Contains(t, *first.UpstreamErrorsJSON, "[REDACTED]")
|
||||||
|
|
||||||
|
second := captured[1]
|
||||||
|
require.Equal(t, "upstream", second.ErrorPhase)
|
||||||
|
require.Equal(t, "upstream_error", second.ErrorType)
|
||||||
|
require.False(t, second.CreatedAt.IsZero())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpsServiceRecordErrorBatch_FallsBackToSingleInsert(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var (
|
||||||
|
batchCalls int
|
||||||
|
singleCalls int
|
||||||
|
)
|
||||||
|
repo := &opsRepoMock{
|
||||||
|
BatchInsertErrorLogsFn: func(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) {
|
||||||
|
batchCalls++
|
||||||
|
return 0, errors.New("batch failed")
|
||||||
|
},
|
||||||
|
InsertErrorLogFn: func(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) {
|
||||||
|
singleCalls++
|
||||||
|
return int64(singleCalls), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
|
err := svc.RecordErrorBatch(context.Background(), []*OpsInsertErrorLogInput{
|
||||||
|
{ErrorMessage: "first"},
|
||||||
|
{ErrorMessage: "second"},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, batchCalls)
|
||||||
|
require.Equal(t, 2, singleCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func strPtr(v string) *string {
|
||||||
|
return &v
|
||||||
|
}
|
||||||
166
backend/internal/service/subscription_reset_quota_test.go
Normal file
166
backend/internal/service/subscription_reset_quota_test.go
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// resetQuotaUserSubRepoStub 支持 GetByID、ResetDailyUsage、ResetWeeklyUsage,
|
||||||
|
// 其余方法继承 userSubRepoNoop(panic)。
|
||||||
|
type resetQuotaUserSubRepoStub struct {
|
||||||
|
userSubRepoNoop
|
||||||
|
|
||||||
|
sub *UserSubscription
|
||||||
|
|
||||||
|
resetDailyCalled bool
|
||||||
|
resetWeeklyCalled bool
|
||||||
|
resetDailyErr error
|
||||||
|
resetWeeklyErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *resetQuotaUserSubRepoStub) GetByID(_ context.Context, id int64) (*UserSubscription, error) {
|
||||||
|
if r.sub == nil || r.sub.ID != id {
|
||||||
|
return nil, ErrSubscriptionNotFound
|
||||||
|
}
|
||||||
|
cp := *r.sub
|
||||||
|
return &cp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *resetQuotaUserSubRepoStub) ResetDailyUsage(_ context.Context, _ int64, windowStart time.Time) error {
|
||||||
|
r.resetDailyCalled = true
|
||||||
|
if r.resetDailyErr == nil && r.sub != nil {
|
||||||
|
r.sub.DailyUsageUSD = 0
|
||||||
|
r.sub.DailyWindowStart = &windowStart
|
||||||
|
}
|
||||||
|
return r.resetDailyErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *resetQuotaUserSubRepoStub) ResetWeeklyUsage(_ context.Context, _ int64, _ time.Time) error {
|
||||||
|
r.resetWeeklyCalled = true
|
||||||
|
return r.resetWeeklyErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func newResetQuotaSvc(stub *resetQuotaUserSubRepoStub) *SubscriptionService {
|
||||||
|
return NewSubscriptionService(groupRepoNoop{}, stub, nil, nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminResetQuota_ResetBoth(t *testing.T) {
|
||||||
|
stub := &resetQuotaUserSubRepoStub{
|
||||||
|
sub: &UserSubscription{ID: 1, UserID: 10, GroupID: 20},
|
||||||
|
}
|
||||||
|
svc := newResetQuotaSvc(stub)
|
||||||
|
|
||||||
|
result, err := svc.AdminResetQuota(context.Background(), 1, true, true)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, stub.resetDailyCalled, "应调用 ResetDailyUsage")
|
||||||
|
require.True(t, stub.resetWeeklyCalled, "应调用 ResetWeeklyUsage")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminResetQuota_ResetDailyOnly(t *testing.T) {
|
||||||
|
stub := &resetQuotaUserSubRepoStub{
|
||||||
|
sub: &UserSubscription{ID: 2, UserID: 10, GroupID: 20},
|
||||||
|
}
|
||||||
|
svc := newResetQuotaSvc(stub)
|
||||||
|
|
||||||
|
result, err := svc.AdminResetQuota(context.Background(), 2, true, false)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, stub.resetDailyCalled, "应调用 ResetDailyUsage")
|
||||||
|
require.False(t, stub.resetWeeklyCalled, "不应调用 ResetWeeklyUsage")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminResetQuota_ResetWeeklyOnly(t *testing.T) {
|
||||||
|
stub := &resetQuotaUserSubRepoStub{
|
||||||
|
sub: &UserSubscription{ID: 3, UserID: 10, GroupID: 20},
|
||||||
|
}
|
||||||
|
svc := newResetQuotaSvc(stub)
|
||||||
|
|
||||||
|
result, err := svc.AdminResetQuota(context.Background(), 3, false, true)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.False(t, stub.resetDailyCalled, "不应调用 ResetDailyUsage")
|
||||||
|
require.True(t, stub.resetWeeklyCalled, "应调用 ResetWeeklyUsage")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminResetQuota_BothFalseReturnsError(t *testing.T) {
|
||||||
|
stub := &resetQuotaUserSubRepoStub{
|
||||||
|
sub: &UserSubscription{ID: 7, UserID: 10, GroupID: 20},
|
||||||
|
}
|
||||||
|
svc := newResetQuotaSvc(stub)
|
||||||
|
|
||||||
|
_, err := svc.AdminResetQuota(context.Background(), 7, false, false)
|
||||||
|
|
||||||
|
require.ErrorIs(t, err, ErrInvalidInput)
|
||||||
|
require.False(t, stub.resetDailyCalled)
|
||||||
|
require.False(t, stub.resetWeeklyCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminResetQuota_SubscriptionNotFound(t *testing.T) {
|
||||||
|
stub := &resetQuotaUserSubRepoStub{sub: nil}
|
||||||
|
svc := newResetQuotaSvc(stub)
|
||||||
|
|
||||||
|
_, err := svc.AdminResetQuota(context.Background(), 999, true, true)
|
||||||
|
|
||||||
|
require.ErrorIs(t, err, ErrSubscriptionNotFound)
|
||||||
|
require.False(t, stub.resetDailyCalled)
|
||||||
|
require.False(t, stub.resetWeeklyCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminResetQuota_ResetDailyUsageError(t *testing.T) {
|
||||||
|
dbErr := errors.New("db error")
|
||||||
|
stub := &resetQuotaUserSubRepoStub{
|
||||||
|
sub: &UserSubscription{ID: 4, UserID: 10, GroupID: 20},
|
||||||
|
resetDailyErr: dbErr,
|
||||||
|
}
|
||||||
|
svc := newResetQuotaSvc(stub)
|
||||||
|
|
||||||
|
_, err := svc.AdminResetQuota(context.Background(), 4, true, true)
|
||||||
|
|
||||||
|
require.ErrorIs(t, err, dbErr)
|
||||||
|
require.True(t, stub.resetDailyCalled)
|
||||||
|
require.False(t, stub.resetWeeklyCalled, "daily 失败后不应继续调用 weekly")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminResetQuota_ResetWeeklyUsageError(t *testing.T) {
|
||||||
|
dbErr := errors.New("db error")
|
||||||
|
stub := &resetQuotaUserSubRepoStub{
|
||||||
|
sub: &UserSubscription{ID: 5, UserID: 10, GroupID: 20},
|
||||||
|
resetWeeklyErr: dbErr,
|
||||||
|
}
|
||||||
|
svc := newResetQuotaSvc(stub)
|
||||||
|
|
||||||
|
_, err := svc.AdminResetQuota(context.Background(), 5, false, true)
|
||||||
|
|
||||||
|
require.ErrorIs(t, err, dbErr)
|
||||||
|
require.True(t, stub.resetWeeklyCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminResetQuota_ReturnsRefreshedSub(t *testing.T) {
|
||||||
|
stub := &resetQuotaUserSubRepoStub{
|
||||||
|
sub: &UserSubscription{
|
||||||
|
ID: 6,
|
||||||
|
UserID: 10,
|
||||||
|
GroupID: 20,
|
||||||
|
DailyUsageUSD: 99.9,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := newResetQuotaSvc(stub)
|
||||||
|
result, err := svc.AdminResetQuota(context.Background(), 6, true, false)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
// ResetDailyUsage stub 会将 sub.DailyUsageUSD 归零,
|
||||||
|
// 服务应返回第二次 GetByID 的刷新值而非初始的 99.9
|
||||||
|
require.Equal(t, float64(0), result.DailyUsageUSD, "返回的订阅应反映已归零的用量")
|
||||||
|
require.True(t, stub.resetDailyCalled)
|
||||||
|
}
|
||||||
@@ -31,6 +31,7 @@ var (
|
|||||||
ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group")
|
ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group")
|
||||||
ErrSubscriptionAssignConflict = infraerrors.Conflict("SUBSCRIPTION_ASSIGN_CONFLICT", "subscription exists but request conflicts with existing assignment semantics")
|
ErrSubscriptionAssignConflict = infraerrors.Conflict("SUBSCRIPTION_ASSIGN_CONFLICT", "subscription exists but request conflicts with existing assignment semantics")
|
||||||
ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type")
|
ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type")
|
||||||
|
ErrInvalidInput = infraerrors.BadRequest("INVALID_INPUT", "at least one of resetDaily or resetWeekly must be true")
|
||||||
ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded")
|
ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded")
|
||||||
ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
|
ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
|
||||||
ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
|
ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
|
||||||
@@ -695,6 +696,36 @@ func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *U
|
|||||||
return s.userSubRepo.ActivateWindows(ctx, sub.ID, windowStart)
|
return s.userSubRepo.ActivateWindows(ctx, sub.ID, windowStart)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AdminResetQuota manually resets the daily and/or weekly usage windows.
|
||||||
|
// Uses startOfDay(now) as the new window start, matching automatic resets.
|
||||||
|
func (s *SubscriptionService) AdminResetQuota(ctx context.Context, subscriptionID int64, resetDaily, resetWeekly bool) (*UserSubscription, error) {
|
||||||
|
if !resetDaily && !resetWeekly {
|
||||||
|
return nil, ErrInvalidInput
|
||||||
|
}
|
||||||
|
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
windowStart := startOfDay(time.Now())
|
||||||
|
if resetDaily {
|
||||||
|
if err := s.userSubRepo.ResetDailyUsage(ctx, sub.ID, windowStart); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if resetWeekly {
|
||||||
|
if err := s.userSubRepo.ResetWeeklyUsage(ctx, sub.ID, windowStart); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Invalidate caches, same as CheckAndResetWindows
|
||||||
|
s.InvalidateSubCache(sub.UserID, sub.GroupID)
|
||||||
|
if s.billingCacheService != nil {
|
||||||
|
_ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID)
|
||||||
|
}
|
||||||
|
// Return the refreshed subscription from DB
|
||||||
|
return s.userSubRepo.GetByID(ctx, subscriptionID)
|
||||||
|
}
|
||||||
|
|
||||||
// CheckAndResetWindows 检查并重置过期的窗口
|
// CheckAndResetWindows 检查并重置过期的窗口
|
||||||
func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *UserSubscription) error {
|
func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *UserSubscription) error {
|
||||||
// 使用当天零点作为新窗口起始时间
|
// 使用当天零点作为新窗口起始时间
|
||||||
|
|||||||
@@ -2,6 +2,13 @@ package service
|
|||||||
|
|
||||||
import "context"
|
import "context"
|
||||||
|
|
||||||
|
// UserGroupRateEntry 分组下用户专属倍率条目
|
||||||
|
type UserGroupRateEntry struct {
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
UserEmail string `json:"user_email"`
|
||||||
|
RateMultiplier float64 `json:"rate_multiplier"`
|
||||||
|
}
|
||||||
|
|
||||||
// UserGroupRateRepository 用户专属分组倍率仓储接口
|
// UserGroupRateRepository 用户专属分组倍率仓储接口
|
||||||
// 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率
|
// 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率
|
||||||
type UserGroupRateRepository interface {
|
type UserGroupRateRepository interface {
|
||||||
@@ -13,6 +20,9 @@ type UserGroupRateRepository interface {
|
|||||||
// 如果未设置专属倍率,返回 nil
|
// 如果未设置专属倍率,返回 nil
|
||||||
GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error)
|
GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error)
|
||||||
|
|
||||||
|
// GetByGroupID 获取指定分组下所有用户的专属倍率
|
||||||
|
GetByGroupID(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
|
||||||
|
|
||||||
// SyncUserGroupRates 同步用户的分组专属倍率
|
// SyncUserGroupRates 同步用户的分组专属倍率
|
||||||
// rates: map[groupID]*rateMultiplier,nil 表示删除该分组的专属倍率
|
// rates: map[groupID]*rateMultiplier,nil 表示删除该分组的专属倍率
|
||||||
SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error
|
SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error
|
||||||
|
|||||||
42
backend/migrations/056_add_sonnet46_to_model_mapping.sql
Normal file
42
backend/migrations/056_add_sonnet46_to_model_mapping.sql
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
-- Add claude-sonnet-4-6 to model_mapping for all Antigravity accounts
|
||||||
|
--
|
||||||
|
-- Background:
|
||||||
|
-- Antigravity now supports claude-sonnet-4-6
|
||||||
|
--
|
||||||
|
-- Strategy:
|
||||||
|
-- Directly overwrite the entire model_mapping with updated mappings
|
||||||
|
-- This ensures consistency with DefaultAntigravityModelMapping in constants.go
|
||||||
|
|
||||||
|
UPDATE accounts
|
||||||
|
SET credentials = jsonb_set(
|
||||||
|
credentials,
|
||||||
|
'{model_mapping}',
|
||||||
|
'{
|
||||||
|
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
|
||||||
|
"claude-opus-4-6": "claude-opus-4-6-thinking",
|
||||||
|
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking",
|
||||||
|
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking",
|
||||||
|
"claude-sonnet-4-6": "claude-sonnet-4-6",
|
||||||
|
"claude-sonnet-4-5": "claude-sonnet-4-5",
|
||||||
|
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
|
||||||
|
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
|
||||||
|
"claude-haiku-4-5": "claude-sonnet-4-5",
|
||||||
|
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
|
||||||
|
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||||
|
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
|
||||||
|
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
||||||
|
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||||
|
"gemini-3-flash": "gemini-3-flash",
|
||||||
|
"gemini-3-pro-high": "gemini-3-pro-high",
|
||||||
|
"gemini-3-pro-low": "gemini-3-pro-low",
|
||||||
|
"gemini-3-pro-image": "gemini-3-pro-image",
|
||||||
|
"gemini-3-flash-preview": "gemini-3-flash",
|
||||||
|
"gemini-3-pro-preview": "gemini-3-pro-high",
|
||||||
|
"gemini-3-pro-image-preview": "gemini-3-pro-image",
|
||||||
|
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
||||||
|
"tab_flash_lite_preview": "tab_flash_lite_preview"
|
||||||
|
}'::jsonb
|
||||||
|
)
|
||||||
|
WHERE platform = 'antigravity'
|
||||||
|
AND deleted_at IS NULL
|
||||||
|
AND credentials->'model_mapping' IS NOT NULL;
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user