mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 15:32:13 +08:00
Compare commits
54 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0cce0a8877 | ||
|
|
225fd035ae | ||
|
|
fb7d1346b5 | ||
|
|
491a744481 | ||
|
|
f366026435 | ||
|
|
1a0d4ed668 | ||
|
|
63a8c76946 | ||
|
|
f355a68bc9 | ||
|
|
c87e6526c1 | ||
|
|
af3a5076d6 | ||
|
|
18f2e21414 | ||
|
|
8a8cdeebb4 | ||
|
|
12b33f4ea4 | ||
|
|
01b3a09d7d | ||
|
|
0d6c1c7790 | ||
|
|
95e366b6c6 | ||
|
|
77701143bf | ||
|
|
02dea7b09b | ||
|
|
c26f93c4a0 | ||
|
|
c826ac28ef | ||
|
|
1893b0eb30 | ||
|
|
05527b13db | ||
|
|
ae5d9c8bfc | ||
|
|
9117c2a4ec | ||
|
|
bab4bb9904 | ||
|
|
33bae6f49b | ||
|
|
32d619a56b | ||
|
|
642432cf2a | ||
|
|
61e9598b08 | ||
|
|
d4e34c7514 | ||
|
|
bfe7a5e452 | ||
|
|
77d916ffec | ||
|
|
831abf7977 | ||
|
|
817a491087 | ||
|
|
9a8dacc514 | ||
|
|
8adf80d98b | ||
|
|
62686a6213 | ||
|
|
3a089242f8 | ||
|
|
9d70c38504 | ||
|
|
aeb464f3ca | ||
|
|
7076717b20 | ||
|
|
c0a4fcea0a | ||
|
|
aa2b195c86 | ||
|
|
1d0872e7ca | ||
|
|
33988637b5 | ||
|
|
d4f6ad7225 | ||
|
|
078fefed03 | ||
|
|
5b10af85b4 | ||
|
|
4caf95e5dd | ||
|
|
8e1bcf53bb | ||
|
|
064f9be7e4 | ||
|
|
adcfb44cb7 | ||
|
|
3d79773ba2 | ||
|
|
6aa8cbbf20 |
@@ -86,6 +86,7 @@ func provideCleanup(
|
||||
geminiOAuth *service.GeminiOAuthService,
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -216,6 +217,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"ScheduledTestRunnerService", func() error {
|
||||
if scheduledTestRunner != nil {
|
||||
scheduledTestRunner.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -195,7 +195,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
|
||||
errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
|
||||
adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler)
|
||||
scheduledTestPlanRepository := repository.NewScheduledTestPlanRepository(db)
|
||||
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
||||
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
@@ -225,7 +229,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, configConfig)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -273,6 +278,7 @@ func provideCleanup(
|
||||
geminiOAuth *service.GeminiOAuthService,
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -402,6 +408,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"ScheduledTestRunnerService", func() error {
|
||||
if scheduledTestRunner != nil {
|
||||
scheduledTestRunner.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -74,6 +74,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
geminiOAuthSvc,
|
||||
antigravityOAuthSvc,
|
||||
nil, // openAIGateway
|
||||
nil, // scheduledTestRunner
|
||||
)
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
|
||||
@@ -41,6 +41,8 @@ type Account struct {
|
||||
ProxyID *int64 `json:"proxy_id,omitempty"`
|
||||
// Concurrency holds the value of the "concurrency" field.
|
||||
Concurrency int `json:"concurrency,omitempty"`
|
||||
// LoadFactor holds the value of the "load_factor" field.
|
||||
LoadFactor *int `json:"load_factor,omitempty"`
|
||||
// Priority holds the value of the "priority" field.
|
||||
Priority int `json:"priority,omitempty"`
|
||||
// RateMultiplier holds the value of the "rate_multiplier" field.
|
||||
@@ -143,7 +145,7 @@ func (*Account) scanValues(columns []string) ([]any, error) {
|
||||
values[i] = new(sql.NullBool)
|
||||
case account.FieldRateMultiplier:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority:
|
||||
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldLoadFactor, account.FieldPriority:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldTempUnschedulableReason, account.FieldSessionWindowStatus:
|
||||
values[i] = new(sql.NullString)
|
||||
@@ -243,6 +245,13 @@ func (_m *Account) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.Concurrency = int(value.Int64)
|
||||
}
|
||||
case account.FieldLoadFactor:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field load_factor", values[i])
|
||||
} else if value.Valid {
|
||||
_m.LoadFactor = new(int)
|
||||
*_m.LoadFactor = int(value.Int64)
|
||||
}
|
||||
case account.FieldPriority:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field priority", values[i])
|
||||
@@ -445,6 +454,11 @@ func (_m *Account) String() string {
|
||||
builder.WriteString("concurrency=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Concurrency))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.LoadFactor; v != nil {
|
||||
builder.WriteString("load_factor=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("priority=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Priority))
|
||||
builder.WriteString(", ")
|
||||
|
||||
@@ -37,6 +37,8 @@ const (
|
||||
FieldProxyID = "proxy_id"
|
||||
// FieldConcurrency holds the string denoting the concurrency field in the database.
|
||||
FieldConcurrency = "concurrency"
|
||||
// FieldLoadFactor holds the string denoting the load_factor field in the database.
|
||||
FieldLoadFactor = "load_factor"
|
||||
// FieldPriority holds the string denoting the priority field in the database.
|
||||
FieldPriority = "priority"
|
||||
// FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
|
||||
@@ -121,6 +123,7 @@ var Columns = []string{
|
||||
FieldExtra,
|
||||
FieldProxyID,
|
||||
FieldConcurrency,
|
||||
FieldLoadFactor,
|
||||
FieldPriority,
|
||||
FieldRateMultiplier,
|
||||
FieldStatus,
|
||||
@@ -250,6 +253,11 @@ func ByConcurrency(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldConcurrency, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByLoadFactor orders the results by the load_factor field.
|
||||
func ByLoadFactor(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldLoadFactor, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByPriority orders the results by the priority field.
|
||||
func ByPriority(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldPriority, opts...).ToFunc()
|
||||
|
||||
@@ -100,6 +100,11 @@ func Concurrency(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldConcurrency, v))
|
||||
}
|
||||
|
||||
// LoadFactor applies equality check predicate on the "load_factor" field. It's identical to LoadFactorEQ.
|
||||
func LoadFactor(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// Priority applies equality check predicate on the "priority" field. It's identical to PriorityEQ.
|
||||
func Priority(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldPriority, v))
|
||||
@@ -650,6 +655,56 @@ func ConcurrencyLTE(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldLTE(FieldConcurrency, v))
|
||||
}
|
||||
|
||||
// LoadFactorEQ applies the EQ predicate on the "load_factor" field.
|
||||
func LoadFactorEQ(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorNEQ applies the NEQ predicate on the "load_factor" field.
|
||||
func LoadFactorNEQ(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldNEQ(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorIn applies the In predicate on the "load_factor" field.
|
||||
func LoadFactorIn(vs ...int) predicate.Account {
|
||||
return predicate.Account(sql.FieldIn(FieldLoadFactor, vs...))
|
||||
}
|
||||
|
||||
// LoadFactorNotIn applies the NotIn predicate on the "load_factor" field.
|
||||
func LoadFactorNotIn(vs ...int) predicate.Account {
|
||||
return predicate.Account(sql.FieldNotIn(FieldLoadFactor, vs...))
|
||||
}
|
||||
|
||||
// LoadFactorGT applies the GT predicate on the "load_factor" field.
|
||||
func LoadFactorGT(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldGT(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorGTE applies the GTE predicate on the "load_factor" field.
|
||||
func LoadFactorGTE(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldGTE(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorLT applies the LT predicate on the "load_factor" field.
|
||||
func LoadFactorLT(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldLT(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorLTE applies the LTE predicate on the "load_factor" field.
|
||||
func LoadFactorLTE(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldLTE(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorIsNil applies the IsNil predicate on the "load_factor" field.
|
||||
func LoadFactorIsNil() predicate.Account {
|
||||
return predicate.Account(sql.FieldIsNull(FieldLoadFactor))
|
||||
}
|
||||
|
||||
// LoadFactorNotNil applies the NotNil predicate on the "load_factor" field.
|
||||
func LoadFactorNotNil() predicate.Account {
|
||||
return predicate.Account(sql.FieldNotNull(FieldLoadFactor))
|
||||
}
|
||||
|
||||
// PriorityEQ applies the EQ predicate on the "priority" field.
|
||||
func PriorityEQ(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldPriority, v))
|
||||
|
||||
@@ -139,6 +139,20 @@ func (_c *AccountCreate) SetNillableConcurrency(v *int) *AccountCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (_c *AccountCreate) SetLoadFactor(v int) *AccountCreate {
|
||||
_c.mutation.SetLoadFactor(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil.
|
||||
func (_c *AccountCreate) SetNillableLoadFactor(v *int) *AccountCreate {
|
||||
if v != nil {
|
||||
_c.SetLoadFactor(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (_c *AccountCreate) SetPriority(v int) *AccountCreate {
|
||||
_c.mutation.SetPriority(v)
|
||||
@@ -623,6 +637,10 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(account.FieldConcurrency, field.TypeInt, value)
|
||||
_node.Concurrency = value
|
||||
}
|
||||
if value, ok := _c.mutation.LoadFactor(); ok {
|
||||
_spec.SetField(account.FieldLoadFactor, field.TypeInt, value)
|
||||
_node.LoadFactor = &value
|
||||
}
|
||||
if value, ok := _c.mutation.Priority(); ok {
|
||||
_spec.SetField(account.FieldPriority, field.TypeInt, value)
|
||||
_node.Priority = value
|
||||
@@ -936,6 +954,30 @@ func (u *AccountUpsert) AddConcurrency(v int) *AccountUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (u *AccountUpsert) SetLoadFactor(v int) *AccountUpsert {
|
||||
u.Set(account.FieldLoadFactor, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create.
|
||||
func (u *AccountUpsert) UpdateLoadFactor() *AccountUpsert {
|
||||
u.SetExcluded(account.FieldLoadFactor)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddLoadFactor adds v to the "load_factor" field.
|
||||
func (u *AccountUpsert) AddLoadFactor(v int) *AccountUpsert {
|
||||
u.Add(account.FieldLoadFactor, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (u *AccountUpsert) ClearLoadFactor() *AccountUpsert {
|
||||
u.SetNull(account.FieldLoadFactor)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (u *AccountUpsert) SetPriority(v int) *AccountUpsert {
|
||||
u.Set(account.FieldPriority, v)
|
||||
@@ -1419,6 +1461,34 @@ func (u *AccountUpsertOne) UpdateConcurrency() *AccountUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (u *AccountUpsertOne) SetLoadFactor(v int) *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.SetLoadFactor(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddLoadFactor adds v to the "load_factor" field.
|
||||
func (u *AccountUpsertOne) AddLoadFactor(v int) *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.AddLoadFactor(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create.
|
||||
func (u *AccountUpsertOne) UpdateLoadFactor() *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.UpdateLoadFactor()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (u *AccountUpsertOne) ClearLoadFactor() *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.ClearLoadFactor()
|
||||
})
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (u *AccountUpsertOne) SetPriority(v int) *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
@@ -2113,6 +2183,34 @@ func (u *AccountUpsertBulk) UpdateConcurrency() *AccountUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (u *AccountUpsertBulk) SetLoadFactor(v int) *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.SetLoadFactor(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddLoadFactor adds v to the "load_factor" field.
|
||||
func (u *AccountUpsertBulk) AddLoadFactor(v int) *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.AddLoadFactor(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create.
|
||||
func (u *AccountUpsertBulk) UpdateLoadFactor() *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.UpdateLoadFactor()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (u *AccountUpsertBulk) ClearLoadFactor() *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.ClearLoadFactor()
|
||||
})
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (u *AccountUpsertBulk) SetPriority(v int) *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
|
||||
@@ -172,6 +172,33 @@ func (_u *AccountUpdate) AddConcurrency(v int) *AccountUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (_u *AccountUpdate) SetLoadFactor(v int) *AccountUpdate {
|
||||
_u.mutation.ResetLoadFactor()
|
||||
_u.mutation.SetLoadFactor(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil.
|
||||
func (_u *AccountUpdate) SetNillableLoadFactor(v *int) *AccountUpdate {
|
||||
if v != nil {
|
||||
_u.SetLoadFactor(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddLoadFactor adds value to the "load_factor" field.
|
||||
func (_u *AccountUpdate) AddLoadFactor(v int) *AccountUpdate {
|
||||
_u.mutation.AddLoadFactor(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (_u *AccountUpdate) ClearLoadFactor() *AccountUpdate {
|
||||
_u.mutation.ClearLoadFactor()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (_u *AccountUpdate) SetPriority(v int) *AccountUpdate {
|
||||
_u.mutation.ResetPriority()
|
||||
@@ -684,6 +711,15 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.AddedConcurrency(); ok {
|
||||
_spec.AddField(account.FieldConcurrency, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.LoadFactor(); ok {
|
||||
_spec.SetField(account.FieldLoadFactor, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedLoadFactor(); ok {
|
||||
_spec.AddField(account.FieldLoadFactor, field.TypeInt, value)
|
||||
}
|
||||
if _u.mutation.LoadFactorCleared() {
|
||||
_spec.ClearField(account.FieldLoadFactor, field.TypeInt)
|
||||
}
|
||||
if value, ok := _u.mutation.Priority(); ok {
|
||||
_spec.SetField(account.FieldPriority, field.TypeInt, value)
|
||||
}
|
||||
@@ -1063,6 +1099,33 @@ func (_u *AccountUpdateOne) AddConcurrency(v int) *AccountUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (_u *AccountUpdateOne) SetLoadFactor(v int) *AccountUpdateOne {
|
||||
_u.mutation.ResetLoadFactor()
|
||||
_u.mutation.SetLoadFactor(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil.
|
||||
func (_u *AccountUpdateOne) SetNillableLoadFactor(v *int) *AccountUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetLoadFactor(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddLoadFactor adds value to the "load_factor" field.
|
||||
func (_u *AccountUpdateOne) AddLoadFactor(v int) *AccountUpdateOne {
|
||||
_u.mutation.AddLoadFactor(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (_u *AccountUpdateOne) ClearLoadFactor() *AccountUpdateOne {
|
||||
_u.mutation.ClearLoadFactor()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (_u *AccountUpdateOne) SetPriority(v int) *AccountUpdateOne {
|
||||
_u.mutation.ResetPriority()
|
||||
@@ -1605,6 +1668,15 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er
|
||||
if value, ok := _u.mutation.AddedConcurrency(); ok {
|
||||
_spec.AddField(account.FieldConcurrency, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.LoadFactor(); ok {
|
||||
_spec.SetField(account.FieldLoadFactor, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedLoadFactor(); ok {
|
||||
_spec.AddField(account.FieldLoadFactor, field.TypeInt, value)
|
||||
}
|
||||
if _u.mutation.LoadFactorCleared() {
|
||||
_spec.ClearField(account.FieldLoadFactor, field.TypeInt)
|
||||
}
|
||||
if value, ok := _u.mutation.Priority(); ok {
|
||||
_spec.SetField(account.FieldPriority, field.TypeInt, value)
|
||||
}
|
||||
|
||||
@@ -106,6 +106,7 @@ var (
|
||||
{Name: "credentials", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "extra", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "concurrency", Type: field.TypeInt, Default: 3},
|
||||
{Name: "load_factor", Type: field.TypeInt, Nullable: true},
|
||||
{Name: "priority", Type: field.TypeInt, Default: 50},
|
||||
{Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
|
||||
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
|
||||
@@ -132,7 +133,7 @@ var (
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "accounts_proxies_proxy",
|
||||
Columns: []*schema.Column{AccountsColumns[27]},
|
||||
Columns: []*schema.Column{AccountsColumns[28]},
|
||||
RefColumns: []*schema.Column{ProxiesColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
@@ -151,52 +152,52 @@ var (
|
||||
{
|
||||
Name: "account_status",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[13]},
|
||||
Columns: []*schema.Column{AccountsColumns[14]},
|
||||
},
|
||||
{
|
||||
Name: "account_proxy_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[27]},
|
||||
Columns: []*schema.Column{AccountsColumns[28]},
|
||||
},
|
||||
{
|
||||
Name: "account_priority",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[11]},
|
||||
Columns: []*schema.Column{AccountsColumns[12]},
|
||||
},
|
||||
{
|
||||
Name: "account_last_used_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[15]},
|
||||
Columns: []*schema.Column{AccountsColumns[16]},
|
||||
},
|
||||
{
|
||||
Name: "account_schedulable",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[18]},
|
||||
Columns: []*schema.Column{AccountsColumns[19]},
|
||||
},
|
||||
{
|
||||
Name: "account_rate_limited_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[19]},
|
||||
Columns: []*schema.Column{AccountsColumns[20]},
|
||||
},
|
||||
{
|
||||
Name: "account_rate_limit_reset_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[20]},
|
||||
Columns: []*schema.Column{AccountsColumns[21]},
|
||||
},
|
||||
{
|
||||
Name: "account_overload_until",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[21]},
|
||||
Columns: []*schema.Column{AccountsColumns[22]},
|
||||
},
|
||||
{
|
||||
Name: "account_platform_priority",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[6], AccountsColumns[11]},
|
||||
Columns: []*schema.Column{AccountsColumns[6], AccountsColumns[12]},
|
||||
},
|
||||
{
|
||||
Name: "account_priority_status",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[11], AccountsColumns[13]},
|
||||
Columns: []*schema.Column{AccountsColumns[12], AccountsColumns[14]},
|
||||
},
|
||||
{
|
||||
Name: "account_deleted_at",
|
||||
|
||||
@@ -2260,6 +2260,8 @@ type AccountMutation struct {
|
||||
extra *map[string]interface{}
|
||||
concurrency *int
|
||||
addconcurrency *int
|
||||
load_factor *int
|
||||
addload_factor *int
|
||||
priority *int
|
||||
addpriority *int
|
||||
rate_multiplier *float64
|
||||
@@ -2845,6 +2847,76 @@ func (m *AccountMutation) ResetConcurrency() {
|
||||
m.addconcurrency = nil
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (m *AccountMutation) SetLoadFactor(i int) {
|
||||
m.load_factor = &i
|
||||
m.addload_factor = nil
|
||||
}
|
||||
|
||||
// LoadFactor returns the value of the "load_factor" field in the mutation.
|
||||
func (m *AccountMutation) LoadFactor() (r int, exists bool) {
|
||||
v := m.load_factor
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldLoadFactor returns the old "load_factor" field's value of the Account entity.
|
||||
// If the Account object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *AccountMutation) OldLoadFactor(ctx context.Context) (v *int, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldLoadFactor is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldLoadFactor requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldLoadFactor: %w", err)
|
||||
}
|
||||
return oldValue.LoadFactor, nil
|
||||
}
|
||||
|
||||
// AddLoadFactor adds i to the "load_factor" field.
|
||||
func (m *AccountMutation) AddLoadFactor(i int) {
|
||||
if m.addload_factor != nil {
|
||||
*m.addload_factor += i
|
||||
} else {
|
||||
m.addload_factor = &i
|
||||
}
|
||||
}
|
||||
|
||||
// AddedLoadFactor returns the value that was added to the "load_factor" field in this mutation.
|
||||
func (m *AccountMutation) AddedLoadFactor() (r int, exists bool) {
|
||||
v := m.addload_factor
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (m *AccountMutation) ClearLoadFactor() {
|
||||
m.load_factor = nil
|
||||
m.addload_factor = nil
|
||||
m.clearedFields[account.FieldLoadFactor] = struct{}{}
|
||||
}
|
||||
|
||||
// LoadFactorCleared returns if the "load_factor" field was cleared in this mutation.
|
||||
func (m *AccountMutation) LoadFactorCleared() bool {
|
||||
_, ok := m.clearedFields[account.FieldLoadFactor]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetLoadFactor resets all changes to the "load_factor" field.
|
||||
func (m *AccountMutation) ResetLoadFactor() {
|
||||
m.load_factor = nil
|
||||
m.addload_factor = nil
|
||||
delete(m.clearedFields, account.FieldLoadFactor)
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (m *AccountMutation) SetPriority(i int) {
|
||||
m.priority = &i
|
||||
@@ -3773,7 +3845,7 @@ func (m *AccountMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *AccountMutation) Fields() []string {
|
||||
fields := make([]string, 0, 27)
|
||||
fields := make([]string, 0, 28)
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, account.FieldCreatedAt)
|
||||
}
|
||||
@@ -3807,6 +3879,9 @@ func (m *AccountMutation) Fields() []string {
|
||||
if m.concurrency != nil {
|
||||
fields = append(fields, account.FieldConcurrency)
|
||||
}
|
||||
if m.load_factor != nil {
|
||||
fields = append(fields, account.FieldLoadFactor)
|
||||
}
|
||||
if m.priority != nil {
|
||||
fields = append(fields, account.FieldPriority)
|
||||
}
|
||||
@@ -3885,6 +3960,8 @@ func (m *AccountMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.ProxyID()
|
||||
case account.FieldConcurrency:
|
||||
return m.Concurrency()
|
||||
case account.FieldLoadFactor:
|
||||
return m.LoadFactor()
|
||||
case account.FieldPriority:
|
||||
return m.Priority()
|
||||
case account.FieldRateMultiplier:
|
||||
@@ -3948,6 +4025,8 @@ func (m *AccountMutation) OldField(ctx context.Context, name string) (ent.Value,
|
||||
return m.OldProxyID(ctx)
|
||||
case account.FieldConcurrency:
|
||||
return m.OldConcurrency(ctx)
|
||||
case account.FieldLoadFactor:
|
||||
return m.OldLoadFactor(ctx)
|
||||
case account.FieldPriority:
|
||||
return m.OldPriority(ctx)
|
||||
case account.FieldRateMultiplier:
|
||||
@@ -4066,6 +4145,13 @@ func (m *AccountMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetConcurrency(v)
|
||||
return nil
|
||||
case account.FieldLoadFactor:
|
||||
v, ok := value.(int)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetLoadFactor(v)
|
||||
return nil
|
||||
case account.FieldPriority:
|
||||
v, ok := value.(int)
|
||||
if !ok {
|
||||
@@ -4189,6 +4275,9 @@ func (m *AccountMutation) AddedFields() []string {
|
||||
if m.addconcurrency != nil {
|
||||
fields = append(fields, account.FieldConcurrency)
|
||||
}
|
||||
if m.addload_factor != nil {
|
||||
fields = append(fields, account.FieldLoadFactor)
|
||||
}
|
||||
if m.addpriority != nil {
|
||||
fields = append(fields, account.FieldPriority)
|
||||
}
|
||||
@@ -4205,6 +4294,8 @@ func (m *AccountMutation) AddedField(name string) (ent.Value, bool) {
|
||||
switch name {
|
||||
case account.FieldConcurrency:
|
||||
return m.AddedConcurrency()
|
||||
case account.FieldLoadFactor:
|
||||
return m.AddedLoadFactor()
|
||||
case account.FieldPriority:
|
||||
return m.AddedPriority()
|
||||
case account.FieldRateMultiplier:
|
||||
@@ -4225,6 +4316,13 @@ func (m *AccountMutation) AddField(name string, value ent.Value) error {
|
||||
}
|
||||
m.AddConcurrency(v)
|
||||
return nil
|
||||
case account.FieldLoadFactor:
|
||||
v, ok := value.(int)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.AddLoadFactor(v)
|
||||
return nil
|
||||
case account.FieldPriority:
|
||||
v, ok := value.(int)
|
||||
if !ok {
|
||||
@@ -4256,6 +4354,9 @@ func (m *AccountMutation) ClearedFields() []string {
|
||||
if m.FieldCleared(account.FieldProxyID) {
|
||||
fields = append(fields, account.FieldProxyID)
|
||||
}
|
||||
if m.FieldCleared(account.FieldLoadFactor) {
|
||||
fields = append(fields, account.FieldLoadFactor)
|
||||
}
|
||||
if m.FieldCleared(account.FieldErrorMessage) {
|
||||
fields = append(fields, account.FieldErrorMessage)
|
||||
}
|
||||
@@ -4312,6 +4413,9 @@ func (m *AccountMutation) ClearField(name string) error {
|
||||
case account.FieldProxyID:
|
||||
m.ClearProxyID()
|
||||
return nil
|
||||
case account.FieldLoadFactor:
|
||||
m.ClearLoadFactor()
|
||||
return nil
|
||||
case account.FieldErrorMessage:
|
||||
m.ClearErrorMessage()
|
||||
return nil
|
||||
@@ -4386,6 +4490,9 @@ func (m *AccountMutation) ResetField(name string) error {
|
||||
case account.FieldConcurrency:
|
||||
m.ResetConcurrency()
|
||||
return nil
|
||||
case account.FieldLoadFactor:
|
||||
m.ResetLoadFactor()
|
||||
return nil
|
||||
case account.FieldPriority:
|
||||
m.ResetPriority()
|
||||
return nil
|
||||
@@ -10191,7 +10298,7 @@ func (m *GroupMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *GroupMutation) Fields() []string {
|
||||
fields := make([]string, 0, 30)
|
||||
fields := make([]string, 0, 31)
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, group.FieldCreatedAt)
|
||||
}
|
||||
|
||||
@@ -212,29 +212,29 @@ func init() {
|
||||
// account.DefaultConcurrency holds the default value on creation for the concurrency field.
|
||||
account.DefaultConcurrency = accountDescConcurrency.Default.(int)
|
||||
// accountDescPriority is the schema descriptor for priority field.
|
||||
accountDescPriority := accountFields[8].Descriptor()
|
||||
accountDescPriority := accountFields[9].Descriptor()
|
||||
// account.DefaultPriority holds the default value on creation for the priority field.
|
||||
account.DefaultPriority = accountDescPriority.Default.(int)
|
||||
// accountDescRateMultiplier is the schema descriptor for rate_multiplier field.
|
||||
accountDescRateMultiplier := accountFields[9].Descriptor()
|
||||
accountDescRateMultiplier := accountFields[10].Descriptor()
|
||||
// account.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
|
||||
account.DefaultRateMultiplier = accountDescRateMultiplier.Default.(float64)
|
||||
// accountDescStatus is the schema descriptor for status field.
|
||||
accountDescStatus := accountFields[10].Descriptor()
|
||||
accountDescStatus := accountFields[11].Descriptor()
|
||||
// account.DefaultStatus holds the default value on creation for the status field.
|
||||
account.DefaultStatus = accountDescStatus.Default.(string)
|
||||
// account.StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||
account.StatusValidator = accountDescStatus.Validators[0].(func(string) error)
|
||||
// accountDescAutoPauseOnExpired is the schema descriptor for auto_pause_on_expired field.
|
||||
accountDescAutoPauseOnExpired := accountFields[14].Descriptor()
|
||||
accountDescAutoPauseOnExpired := accountFields[15].Descriptor()
|
||||
// account.DefaultAutoPauseOnExpired holds the default value on creation for the auto_pause_on_expired field.
|
||||
account.DefaultAutoPauseOnExpired = accountDescAutoPauseOnExpired.Default.(bool)
|
||||
// accountDescSchedulable is the schema descriptor for schedulable field.
|
||||
accountDescSchedulable := accountFields[15].Descriptor()
|
||||
accountDescSchedulable := accountFields[16].Descriptor()
|
||||
// account.DefaultSchedulable holds the default value on creation for the schedulable field.
|
||||
account.DefaultSchedulable = accountDescSchedulable.Default.(bool)
|
||||
// accountDescSessionWindowStatus is the schema descriptor for session_window_status field.
|
||||
accountDescSessionWindowStatus := accountFields[23].Descriptor()
|
||||
accountDescSessionWindowStatus := accountFields[24].Descriptor()
|
||||
// account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save.
|
||||
account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error)
|
||||
accountgroupFields := schema.AccountGroup{}.Fields()
|
||||
|
||||
@@ -97,6 +97,8 @@ func (Account) Fields() []ent.Field {
|
||||
field.Int("concurrency").
|
||||
Default(3),
|
||||
|
||||
field.Int("load_factor").Optional().Nillable(),
|
||||
|
||||
// priority: 账户优先级,数值越小优先级越高
|
||||
// 调度器会优先使用高优先级的账户
|
||||
field.Int("priority").
|
||||
|
||||
@@ -124,6 +124,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
||||
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
|
||||
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
|
||||
@@ -171,8 +173,6 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
@@ -182,7 +182,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
|
||||
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
@@ -203,6 +202,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
||||
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
|
||||
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
|
||||
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
@@ -285,6 +286,10 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo=
|
||||
github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
||||
github.com/pkoukk/tiktoken-go-loader v0.0.2 h1:LUKws63GV3pVHwH1srkBplBv+7URgmOmhSkRxsIvsK4=
|
||||
github.com/pkoukk/tiktoken-go-loader v0.0.2/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
@@ -398,8 +403,6 @@ go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/Wgbsd
|
||||
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
|
||||
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
|
||||
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps=
|
||||
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
|
||||
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
|
||||
go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0=
|
||||
@@ -455,8 +458,6 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
|
||||
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
|
||||
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk=
|
||||
|
||||
@@ -516,7 +516,7 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string {
|
||||
type GatewayOpenAIWSConfig struct {
|
||||
// ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为)
|
||||
ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"`
|
||||
// IngressModeDefault: ingress 默认模式(off/shared/dedicated)
|
||||
// IngressModeDefault: ingress 默认模式(off/ctx_pool/passthrough)
|
||||
IngressModeDefault string `mapstructure:"ingress_mode_default"`
|
||||
// Enabled: 全局总开关(默认 true)
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
@@ -1335,7 +1335,7 @@ func setDefaults() {
|
||||
// OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚)
|
||||
viper.SetDefault("gateway.openai_ws.enabled", true)
|
||||
viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false)
|
||||
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared")
|
||||
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "ctx_pool")
|
||||
viper.SetDefault("gateway.openai_ws.oauth_enabled", true)
|
||||
viper.SetDefault("gateway.openai_ws.apikey_enabled", true)
|
||||
viper.SetDefault("gateway.openai_ws.force_http", false)
|
||||
@@ -2043,9 +2043,11 @@ func (c *Config) Validate() error {
|
||||
}
|
||||
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" {
|
||||
switch mode {
|
||||
case "off", "shared", "dedicated":
|
||||
case "off", "ctx_pool", "passthrough":
|
||||
case "shared", "dedicated":
|
||||
slog.Warn("gateway.openai_ws.ingress_mode_default is deprecated, treating as ctx_pool; please update to off|ctx_pool|passthrough", "value", mode)
|
||||
default:
|
||||
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated")
|
||||
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|ctx_pool|passthrough")
|
||||
}
|
||||
}
|
||||
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" {
|
||||
|
||||
@@ -153,8 +153,8 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
|
||||
if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
|
||||
t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false")
|
||||
}
|
||||
if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" {
|
||||
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared")
|
||||
if cfg.Gateway.OpenAIWS.IngressModeDefault != "ctx_pool" {
|
||||
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "ctx_pool")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1373,7 +1373,7 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) {
|
||||
wantErr: "gateway.openai_ws.store_disabled_conn_mode",
|
||||
},
|
||||
{
|
||||
name: "ingress_mode_default 必须为 off|shared|dedicated",
|
||||
name: "ingress_mode_default 必须为 off|ctx_pool|passthrough",
|
||||
mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" },
|
||||
wantErr: "gateway.openai_ws.ingress_mode_default",
|
||||
},
|
||||
|
||||
@@ -102,6 +102,7 @@ type CreateAccountRequest struct {
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
LoadFactor *int `json:"load_factor"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
||||
@@ -120,6 +121,7 @@ type UpdateAccountRequest struct {
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
LoadFactor *int `json:"load_factor"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
@@ -135,6 +137,7 @@ type BulkUpdateAccountsRequest struct {
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
LoadFactor *int `json:"load_factor"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
Schedulable *bool `json:"schedulable"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
@@ -240,52 +243,64 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
var windowCosts map[int64]float64
|
||||
var activeSessions map[int64]int
|
||||
var rpmCounts map[int64]int
|
||||
if !lite {
|
||||
// Get current concurrency counts for all accounts
|
||||
if h.concurrencyService != nil {
|
||||
if cc, ccErr := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs); ccErr == nil && cc != nil {
|
||||
concurrencyCounts = cc
|
||||
|
||||
// 始终获取并发数(Redis ZCARD,极低开销)
|
||||
if h.concurrencyService != nil {
|
||||
if cc, ccErr := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs); ccErr == nil && cc != nil {
|
||||
concurrencyCounts = cc
|
||||
}
|
||||
}
|
||||
|
||||
// 识别需要查询窗口费用、会话数和 RPM 的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
|
||||
windowCostAccountIDs := make([]int64, 0)
|
||||
sessionLimitAccountIDs := make([]int64, 0)
|
||||
rpmAccountIDs := make([]int64, 0)
|
||||
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if acc.IsAnthropicOAuthOrSetupToken() {
|
||||
if acc.GetWindowCostLimit() > 0 {
|
||||
windowCostAccountIDs = append(windowCostAccountIDs, acc.ID)
|
||||
}
|
||||
if acc.GetMaxSessions() > 0 {
|
||||
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
|
||||
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
|
||||
}
|
||||
if acc.GetBaseRPM() > 0 {
|
||||
rpmAccountIDs = append(rpmAccountIDs, acc.ID)
|
||||
}
|
||||
}
|
||||
// 识别需要查询窗口费用、会话数和 RPM 的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
|
||||
windowCostAccountIDs := make([]int64, 0)
|
||||
sessionLimitAccountIDs := make([]int64, 0)
|
||||
rpmAccountIDs := make([]int64, 0)
|
||||
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if acc.IsAnthropicOAuthOrSetupToken() {
|
||||
if acc.GetWindowCostLimit() > 0 {
|
||||
windowCostAccountIDs = append(windowCostAccountIDs, acc.ID)
|
||||
}
|
||||
if acc.GetMaxSessions() > 0 {
|
||||
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
|
||||
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
|
||||
}
|
||||
if acc.GetBaseRPM() > 0 {
|
||||
rpmAccountIDs = append(rpmAccountIDs, acc.ID)
|
||||
}
|
||||
|
||||
// 始终获取 RPM 计数(Redis GET,极低开销)
|
||||
if len(rpmAccountIDs) > 0 && h.rpmCache != nil {
|
||||
rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs)
|
||||
if rpmCounts == nil {
|
||||
rpmCounts = make(map[int64]int)
|
||||
}
|
||||
}
|
||||
|
||||
// 始终获取活跃会话数(Redis ZCARD,低开销)
|
||||
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
|
||||
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
|
||||
if activeSessions == nil {
|
||||
activeSessions = make(map[int64]int)
|
||||
}
|
||||
}
|
||||
|
||||
// 窗口费用获取:lite 模式从快照缓存读取,非 lite 模式执行 PostgreSQL 查询后写入缓存
|
||||
if len(windowCostAccountIDs) > 0 {
|
||||
if lite {
|
||||
// lite 模式:尝试从快照缓存读取
|
||||
cacheKey := buildWindowCostCacheKey(windowCostAccountIDs)
|
||||
if cached, ok := accountWindowCostCache.Get(cacheKey); ok {
|
||||
if costs, ok := cached.Payload.(map[int64]float64); ok {
|
||||
windowCosts = costs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取 RPM 计数(批量查询)
|
||||
if len(rpmAccountIDs) > 0 && h.rpmCache != nil {
|
||||
rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs)
|
||||
if rpmCounts == nil {
|
||||
rpmCounts = make(map[int64]int)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
|
||||
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
|
||||
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
|
||||
if activeSessions == nil {
|
||||
activeSessions = make(map[int64]int)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取窗口费用(并行查询)
|
||||
if len(windowCostAccountIDs) > 0 {
|
||||
// 缓存未命中则 windowCosts 保持 nil(仅发生在服务刚启动时)
|
||||
} else {
|
||||
// 非 lite 模式:执行 PostgreSQL 聚合查询(高开销)
|
||||
windowCosts = make(map[int64]float64)
|
||||
var mu sync.Mutex
|
||||
g, gctx := errgroup.WithContext(c.Request.Context())
|
||||
@@ -310,6 +325,10 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
_ = g.Wait()
|
||||
|
||||
// 查询完毕后写入快照缓存,供 lite 模式使用
|
||||
cacheKey := buildWindowCostCacheKey(windowCostAccountIDs)
|
||||
accountWindowCostCache.Set(cacheKey, windowCosts)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -506,6 +525,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
LoadFactor: req.LoadFactor,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
AutoPauseOnExpired: req.AutoPauseOnExpired,
|
||||
@@ -575,6 +595,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
||||
Concurrency: req.Concurrency, // 指针类型,nil 表示未提供
|
||||
Priority: req.Priority, // 指针类型,nil 表示未提供
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
LoadFactor: req.LoadFactor,
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
@@ -1101,6 +1122,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
req.Concurrency != nil ||
|
||||
req.Priority != nil ||
|
||||
req.RateMultiplier != nil ||
|
||||
req.LoadFactor != nil ||
|
||||
req.Status != "" ||
|
||||
req.Schedulable != nil ||
|
||||
req.GroupIDs != nil ||
|
||||
@@ -1119,6 +1141,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
LoadFactor: req.LoadFactor,
|
||||
Status: req.Status,
|
||||
Schedulable: req.Schedulable,
|
||||
GroupIDs: req.GroupIDs,
|
||||
@@ -1328,6 +1351,29 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// ResetQuota handles resetting account quota usage
|
||||
// POST /api/v1/admin/accounts/:id/reset-quota
|
||||
func (h *AccountHandler) ResetQuota(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.adminService.ResetAccountQuota(c.Request.Context(), accountID); err != nil {
|
||||
response.InternalError(c, "Failed to reset account quota: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// GetTempUnschedulable handles getting temporary unschedulable status
|
||||
// GET /api/v1/admin/accounts/:id/temp-unschedulable
|
||||
func (h *AccountHandler) GetTempUnschedulable(c *gin.Context) {
|
||||
|
||||
25
backend/internal/handler/admin/account_window_cost_cache.go
Normal file
25
backend/internal/handler/admin/account_window_cost_cache.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// accountWindowCostCache stores window-cost snapshots under keys produced by
// buildWindowCostCacheKey. It is constructed with a 30-second duration
// (presumably the entry TTL — confirm against newSnapshotCache).
var accountWindowCostCache = newSnapshotCache(30 * time.Second)
|
||||
|
||||
// buildWindowCostCacheKey derives the snapshot-cache key for a set of account IDs.
//
// An empty (or nil) slice maps to a fixed sentinel key. Otherwise the key is the
// "accounts_window_cost:" prefix followed by the decimal IDs, comma-separated, in
// the order given — so the key is sensitive to the order of accountIDs.
func buildWindowCostCacheKey(accountIDs []int64) string {
	if len(accountIDs) == 0 {
		return "accounts_window_cost_empty"
	}

	parts := make([]string, len(accountIDs))
	for idx, accountID := range accountIDs {
		parts[idx] = strconv.FormatInt(accountID, 10)
	}
	return "accounts_window_cost:" + strings.Join(parts, ",")
}
|
||||
@@ -425,5 +425,9 @@ func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure stub implements interface.
|
||||
var _ service.AdminService = (*stubAdminService)(nil)
|
||||
|
||||
155
backend/internal/handler/admin/scheduled_test_handler.go
Normal file
155
backend/internal/handler/admin/scheduled_test_handler.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ScheduledTestHandler handles admin scheduled-test-plan management.
|
||||
type ScheduledTestHandler struct {
|
||||
scheduledTestSvc *service.ScheduledTestService
|
||||
}
|
||||
|
||||
// NewScheduledTestHandler creates a new ScheduledTestHandler.
|
||||
func NewScheduledTestHandler(scheduledTestSvc *service.ScheduledTestService) *ScheduledTestHandler {
|
||||
return &ScheduledTestHandler{scheduledTestSvc: scheduledTestSvc}
|
||||
}
|
||||
|
||||
type (
	// createScheduledTestPlanRequest is the JSON body accepted by Create.
	createScheduledTestPlanRequest struct {
		// AccountID is required by the binding rules.
		AccountID int64 `json:"account_id" binding:"required"`
		// ModelID is optional.
		ModelID string `json:"model_id"`
		// CronExpression is required by the binding rules.
		CronExpression string `json:"cron_expression" binding:"required"`
		// Enabled is a pointer so an omitted value can be told apart from false.
		Enabled *bool `json:"enabled"`
		// MaxResults is passed to the service unchanged.
		MaxResults int `json:"max_results"`
	}

	// updateScheduledTestPlanRequest is the JSON body accepted by Update.
	// Update only applies fields that are non-empty / non-nil / positive.
	updateScheduledTestPlanRequest struct {
		ModelID        string `json:"model_id"`
		CronExpression string `json:"cron_expression"`
		Enabled        *bool  `json:"enabled"`
		MaxResults     int    `json:"max_results"`
	}
)
|
||||
|
||||
// ListByAccount GET /admin/accounts/:id/scheduled-test-plans
|
||||
func (h *ScheduledTestHandler) ListByAccount(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid account id")
|
||||
return
|
||||
}
|
||||
|
||||
plans, err := h.scheduledTestSvc.ListPlansByAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.InternalError(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, plans)
|
||||
}
|
||||
|
||||
// Create POST /admin/scheduled-test-plans
|
||||
func (h *ScheduledTestHandler) Create(c *gin.Context) {
|
||||
var req createScheduledTestPlanRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
plan := &service.ScheduledTestPlan{
|
||||
AccountID: req.AccountID,
|
||||
ModelID: req.ModelID,
|
||||
CronExpression: req.CronExpression,
|
||||
Enabled: true,
|
||||
MaxResults: req.MaxResults,
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
plan.Enabled = *req.Enabled
|
||||
}
|
||||
|
||||
created, err := h.scheduledTestSvc.CreatePlan(c.Request.Context(), plan)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, created)
|
||||
}
|
||||
|
||||
// Update PUT /admin/scheduled-test-plans/:id
|
||||
func (h *ScheduledTestHandler) Update(c *gin.Context) {
|
||||
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid plan id")
|
||||
return
|
||||
}
|
||||
|
||||
existing, err := h.scheduledTestSvc.GetPlan(c.Request.Context(), planID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "plan not found")
|
||||
return
|
||||
}
|
||||
|
||||
var req updateScheduledTestPlanRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.ModelID != "" {
|
||||
existing.ModelID = req.ModelID
|
||||
}
|
||||
if req.CronExpression != "" {
|
||||
existing.CronExpression = req.CronExpression
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
existing.Enabled = *req.Enabled
|
||||
}
|
||||
if req.MaxResults > 0 {
|
||||
existing.MaxResults = req.MaxResults
|
||||
}
|
||||
|
||||
updated, err := h.scheduledTestSvc.UpdatePlan(c.Request.Context(), existing)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, updated)
|
||||
}
|
||||
|
||||
// Delete DELETE /admin/scheduled-test-plans/:id
|
||||
func (h *ScheduledTestHandler) Delete(c *gin.Context) {
|
||||
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid plan id")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.scheduledTestSvc.DeletePlan(c.Request.Context(), planID); err != nil {
|
||||
response.InternalError(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "deleted"})
|
||||
}
|
||||
|
||||
// ListResults GET /admin/scheduled-test-plans/:id/results
|
||||
func (h *ScheduledTestHandler) ListResults(c *gin.Context) {
|
||||
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid plan id")
|
||||
return
|
||||
}
|
||||
|
||||
limit := 50
|
||||
if l, err := strconv.Atoi(c.Query("limit")); err == nil && l > 0 {
|
||||
limit = l
|
||||
}
|
||||
|
||||
results, err := h.scheduledTestSvc.ListResults(c.Request.Context(), planID, limit)
|
||||
if err != nil {
|
||||
response.InternalError(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, results)
|
||||
}
|
||||
@@ -819,7 +819,7 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
|
||||
|
||||
err := h.emailService.TestSMTPConnectionWithConfig(config)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
response.BadRequest(c, "SMTP connection test failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -905,7 +905,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
||||
`
|
||||
|
||||
if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
response.BadRequest(c, "Failed to send test email: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -183,6 +183,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
Extra: a.Extra,
|
||||
ProxyID: a.ProxyID,
|
||||
Concurrency: a.Concurrency,
|
||||
LoadFactor: a.LoadFactor,
|
||||
Priority: a.Priority,
|
||||
RateMultiplier: a.BillingRateMultiplier(),
|
||||
Status: a.Status,
|
||||
@@ -248,6 +249,17 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
}
|
||||
}
|
||||
|
||||
// 提取 API Key 账号配额限制(仅 apikey 类型有效)
|
||||
if a.Type == service.AccountTypeAPIKey {
|
||||
if limit := a.GetQuotaLimit(); limit > 0 {
|
||||
out.QuotaLimit = &limit
|
||||
}
|
||||
used := a.GetQuotaUsed()
|
||||
if out.QuotaLimit != nil {
|
||||
out.QuotaUsed = &used
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
|
||||
@@ -131,6 +131,7 @@ type Account struct {
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
LoadFactor *int `json:"load_factor,omitempty"`
|
||||
Priority int `json:"priority"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
Status string `json:"status"`
|
||||
@@ -185,6 +186,10 @@ type Account struct {
|
||||
CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"`
|
||||
CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
|
||||
|
||||
// API Key 账号配额限制
|
||||
QuotaLimit *float64 `json:"quota_limit,omitempty"`
|
||||
QuotaUsed *float64 `json:"quota_used,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ type AdminHandlers struct {
|
||||
UserAttribute *admin.UserAttributeHandler
|
||||
ErrorPassthrough *admin.ErrorPassthroughHandler
|
||||
APIKey *admin.AdminAPIKeyHandler
|
||||
ScheduledTest *admin.ScheduledTestHandler
|
||||
}
|
||||
|
||||
// Handlers contains all HTTP handlers
|
||||
|
||||
192
backend/internal/handler/openai_gateway_compact_log_test.go
Normal file
192
backend/internal/handler/openai_gateway_compact_log_test.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var handlerStructuredLogCaptureMu sync.Mutex
|
||||
|
||||
type handlerInMemoryLogSink struct {
|
||||
mu sync.Mutex
|
||||
events []*logger.LogEvent
|
||||
}
|
||||
|
||||
func (s *handlerInMemoryLogSink) WriteLogEvent(event *logger.LogEvent) {
|
||||
if event == nil {
|
||||
return
|
||||
}
|
||||
cloned := *event
|
||||
if event.Fields != nil {
|
||||
cloned.Fields = make(map[string]any, len(event.Fields))
|
||||
for k, v := range event.Fields {
|
||||
cloned.Fields[k] = v
|
||||
}
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.events = append(s.events, &cloned)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *handlerInMemoryLogSink) ContainsMessageAtLevel(substr, level string) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
wantLevel := strings.ToLower(strings.TrimSpace(level))
|
||||
for _, ev := range s.events {
|
||||
if ev == nil {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(ev.Message, substr) && strings.ToLower(strings.TrimSpace(ev.Level)) == wantLevel {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *handlerInMemoryLogSink) ContainsFieldValue(field, substr string) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, ev := range s.events {
|
||||
if ev == nil || ev.Fields == nil {
|
||||
continue
|
||||
}
|
||||
if v, ok := ev.Fields[field]; ok && strings.Contains(fmt.Sprint(v), substr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func captureHandlerStructuredLog(t *testing.T) (*handlerInMemoryLogSink, func()) {
|
||||
t.Helper()
|
||||
handlerStructuredLogCaptureMu.Lock()
|
||||
|
||||
err := logger.Init(logger.InitOptions{
|
||||
Level: "debug",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: logger.OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: false,
|
||||
},
|
||||
Sampling: logger.SamplingOptions{Enabled: false},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
sink := &handlerInMemoryLogSink{}
|
||||
logger.SetSink(sink)
|
||||
return sink, func() {
|
||||
logger.SetSink(nil)
|
||||
handlerStructuredLogCaptureMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsOpenAIRemoteCompactPath(t *testing.T) {
|
||||
require.False(t, isOpenAIRemoteCompactPath(nil))
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
|
||||
require.True(t, isOpenAIRemoteCompactPath(c))
|
||||
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact/", nil)
|
||||
require.True(t, isOpenAIRemoteCompactPath(c))
|
||||
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
require.False(t, isOpenAIRemoteCompactPath(c))
|
||||
}
|
||||
|
||||
func TestLogOpenAIRemoteCompactOutcome_Succeeded(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureHandlerStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
|
||||
c.Set(opsModelKey, "gpt-5.3-codex")
|
||||
c.Set(opsAccountIDKey, int64(123))
|
||||
c.Header("x-request-id", "rid-compact-ok")
|
||||
c.Status(http.StatusOK)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.logOpenAIRemoteCompactOutcome(c, time.Now().Add(-8*time.Millisecond))
|
||||
|
||||
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.succeeded", "info"))
|
||||
require.True(t, logSink.ContainsFieldValue("compact_outcome", "succeeded"))
|
||||
require.True(t, logSink.ContainsFieldValue("status_code", "200"))
|
||||
require.True(t, logSink.ContainsFieldValue("path", "/v1/responses/compact"))
|
||||
require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.3-codex"))
|
||||
require.True(t, logSink.ContainsFieldValue("account_id", "123"))
|
||||
require.True(t, logSink.ContainsFieldValue("upstream_request_id", "rid-compact-ok"))
|
||||
}
|
||||
|
||||
func TestLogOpenAIRemoteCompactOutcome_Failed(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureHandlerStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact", nil)
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
|
||||
c.Status(http.StatusBadGateway)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.logOpenAIRemoteCompactOutcome(c, time.Now())
|
||||
|
||||
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
|
||||
require.True(t, logSink.ContainsFieldValue("compact_outcome", "failed"))
|
||||
require.True(t, logSink.ContainsFieldValue("status_code", "502"))
|
||||
require.True(t, logSink.ContainsFieldValue("path", "/responses/compact"))
|
||||
}
|
||||
|
||||
func TestLogOpenAIRemoteCompactOutcome_NonCompactSkips(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureHandlerStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
c.Status(http.StatusOK)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.logOpenAIRemoteCompactOutcome(c, time.Now())
|
||||
|
||||
require.False(t, logSink.ContainsMessageAtLevel("codex.remote_compact.succeeded", "info"))
|
||||
require.False(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
|
||||
}
|
||||
|
||||
func TestOpenAIResponses_CompactUnauthorizedLogsFailed(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureHandlerStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", strings.NewReader(`{"model":"gpt-5.3-codex"}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.Responses(c)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
|
||||
require.True(t, logSink.ContainsFieldValue("status_code", "401"))
|
||||
require.True(t, logSink.ContainsFieldValue("path", "/v1/responses/compact"))
|
||||
}
|
||||
@@ -33,6 +33,7 @@ type OpenAIGatewayHandler struct {
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||
@@ -61,6 +62,7 @@ func NewOpenAIGatewayHandler(
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,6 +72,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// 局部兜底:确保该 handler 内部任何 panic 都不会击穿到进程级。
|
||||
streamStarted := false
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
compactStartedAt := time.Now()
|
||||
defer h.logOpenAIRemoteCompactOutcome(c, compactStartedAt)
|
||||
setOpenAIClientTransportHTTP(c)
|
||||
|
||||
requestStart := time.Now()
|
||||
@@ -340,6 +344,86 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func isOpenAIRemoteCompactPath(c *gin.Context) bool {
|
||||
if c == nil || c.Request == nil || c.Request.URL == nil {
|
||||
return false
|
||||
}
|
||||
normalizedPath := strings.TrimRight(strings.TrimSpace(c.Request.URL.Path), "/")
|
||||
return strings.HasSuffix(normalizedPath, "/responses/compact")
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) logOpenAIRemoteCompactOutcome(c *gin.Context, startedAt time.Time) {
|
||||
if !isOpenAIRemoteCompactPath(c) {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
ctx = context.Background()
|
||||
path string
|
||||
status int
|
||||
)
|
||||
if c != nil {
|
||||
if c.Request != nil {
|
||||
ctx = c.Request.Context()
|
||||
if c.Request.URL != nil {
|
||||
path = strings.TrimSpace(c.Request.URL.Path)
|
||||
}
|
||||
}
|
||||
if c.Writer != nil {
|
||||
status = c.Writer.Status()
|
||||
}
|
||||
}
|
||||
|
||||
outcome := "failed"
|
||||
if status >= 200 && status < 300 {
|
||||
outcome = "succeeded"
|
||||
}
|
||||
latencyMs := time.Since(startedAt).Milliseconds()
|
||||
if latencyMs < 0 {
|
||||
latencyMs = 0
|
||||
}
|
||||
|
||||
fields := []zap.Field{
|
||||
zap.String("component", "handler.openai_gateway.responses"),
|
||||
zap.Bool("remote_compact", true),
|
||||
zap.String("compact_outcome", outcome),
|
||||
zap.Int("status_code", status),
|
||||
zap.Int64("latency_ms", latencyMs),
|
||||
zap.String("path", path),
|
||||
zap.Bool("force_codex_cli", h != nil && h.cfg != nil && h.cfg.Gateway.ForceCodexCLI),
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
if userAgent := strings.TrimSpace(c.GetHeader("User-Agent")); userAgent != "" {
|
||||
fields = append(fields, zap.String("request_user_agent", userAgent))
|
||||
}
|
||||
if v, ok := c.Get(opsModelKey); ok {
|
||||
if model, ok := v.(string); ok && strings.TrimSpace(model) != "" {
|
||||
fields = append(fields, zap.String("request_model", strings.TrimSpace(model)))
|
||||
}
|
||||
}
|
||||
if v, ok := c.Get(opsAccountIDKey); ok {
|
||||
if accountID, ok := v.(int64); ok && accountID > 0 {
|
||||
fields = append(fields, zap.Int64("account_id", accountID))
|
||||
}
|
||||
}
|
||||
if c.Writer != nil {
|
||||
if upstreamRequestID := strings.TrimSpace(c.Writer.Header().Get("x-request-id")); upstreamRequestID != "" {
|
||||
fields = append(fields, zap.String("upstream_request_id", upstreamRequestID))
|
||||
} else if upstreamRequestID := strings.TrimSpace(c.Writer.Header().Get("X-Request-Id")); upstreamRequestID != "" {
|
||||
fields = append(fields, zap.String("upstream_request_id", upstreamRequestID))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log := logger.FromContext(ctx).With(fields...)
|
||||
if outcome == "succeeded" {
|
||||
log.Info("codex.remote_compact.succeeded")
|
||||
return
|
||||
}
|
||||
log.Warn("codex.remote_compact.failed")
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool {
|
||||
if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
|
||||
return true
|
||||
|
||||
@@ -2132,6 +2132,14 @@ func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepoForHandler) IncrementQuotaUsed(context.Context, int64, float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepoForHandler) ResetQuotaUsed(context.Context, int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== Stub: SoraClient (用于 SoraGatewayService) ====================
|
||||
|
||||
var _ service.SoraClient = (*stubSoraClientForHandler)(nil)
|
||||
|
||||
@@ -216,6 +216,14 @@ func (r *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates s
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) listSchedulable() []service.Account {
|
||||
var result []service.Account
|
||||
for _, acc := range r.accounts {
|
||||
|
||||
@@ -30,6 +30,7 @@ func ProvideAdminHandlers(
|
||||
userAttributeHandler *admin.UserAttributeHandler,
|
||||
errorPassthroughHandler *admin.ErrorPassthroughHandler,
|
||||
apiKeyHandler *admin.AdminAPIKeyHandler,
|
||||
scheduledTestHandler *admin.ScheduledTestHandler,
|
||||
) *AdminHandlers {
|
||||
return &AdminHandlers{
|
||||
Dashboard: dashboardHandler,
|
||||
@@ -53,6 +54,7 @@ func ProvideAdminHandlers(
|
||||
UserAttribute: userAttributeHandler,
|
||||
ErrorPassthrough: errorPassthroughHandler,
|
||||
APIKey: apiKeyHandler,
|
||||
ScheduledTest: scheduledTestHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,6 +143,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewUserAttributeHandler,
|
||||
admin.NewErrorPassthroughHandler,
|
||||
admin.NewAdminAPIKeyHandler,
|
||||
admin.NewScheduledTestHandler,
|
||||
|
||||
// AdminHandlers and Handlers constructors
|
||||
ProvideAdminHandlers,
|
||||
|
||||
@@ -15,6 +15,7 @@ type Model struct {
|
||||
|
||||
// DefaultModels OpenAI models list
|
||||
var DefaultModels = []Model{
|
||||
{ID: "gpt-5.4", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4"},
|
||||
{ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"},
|
||||
{ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"},
|
||||
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
|
||||
|
||||
@@ -57,25 +57,28 @@ type DashboardStats struct {
|
||||
|
||||
// TrendDataPoint represents a single point in trend data
|
||||
type TrendDataPoint struct {
|
||||
Date string `json:"date"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheTokens int64 `json:"cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
Date string `json:"date"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheCreationTokens int64 `json:"cache_creation_tokens"`
|
||||
CacheReadTokens int64 `json:"cache_read_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// ModelStat represents usage statistics for a single model
|
||||
type ModelStat struct {
|
||||
Model string `json:"model"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
Model string `json:"model"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheCreationTokens int64 `json:"cache_creation_tokens"`
|
||||
CacheReadTokens int64 `json:"cache_read_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// GroupStat represents usage statistics for a single group
|
||||
|
||||
@@ -84,6 +84,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
|
||||
if account.RateMultiplier != nil {
|
||||
builder.SetRateMultiplier(*account.RateMultiplier)
|
||||
}
|
||||
if account.LoadFactor != nil {
|
||||
builder.SetLoadFactor(*account.LoadFactor)
|
||||
}
|
||||
|
||||
if account.ProxyID != nil {
|
||||
builder.SetProxyID(*account.ProxyID)
|
||||
@@ -318,6 +321,11 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
if account.RateMultiplier != nil {
|
||||
builder.SetRateMultiplier(*account.RateMultiplier)
|
||||
}
|
||||
if account.LoadFactor != nil {
|
||||
builder.SetLoadFactor(*account.LoadFactor)
|
||||
} else {
|
||||
builder.ClearLoadFactor()
|
||||
}
|
||||
|
||||
if account.ProxyID != nil {
|
||||
builder.SetProxyID(*account.ProxyID)
|
||||
@@ -437,6 +445,14 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
|
||||
switch status {
|
||||
case "rate_limited":
|
||||
q = q.Where(dbaccount.RateLimitResetAtGT(time.Now()))
|
||||
case "temp_unschedulable":
|
||||
q = q.Where(dbpredicate.Account(func(s *entsql.Selector) {
|
||||
col := s.C("temp_unschedulable_until")
|
||||
s.Where(entsql.And(
|
||||
entsql.Not(entsql.IsNull(col)),
|
||||
entsql.GT(col, entsql.Expr("NOW()")),
|
||||
))
|
||||
}))
|
||||
default:
|
||||
q = q.Where(dbaccount.StatusEQ(status))
|
||||
}
|
||||
@@ -640,7 +656,17 @@ func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
|
||||
SetStatus(service.StatusActive).
|
||||
SetErrorMessage("").
|
||||
Save(ctx)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 清除临时不可调度状态,重置 401 升级链
|
||||
_, _ = r.sql.ExecContext(ctx, `
|
||||
UPDATE accounts
|
||||
SET temp_unschedulable_until = NULL,
|
||||
temp_unschedulable_reason = NULL
|
||||
WHERE id = $1 AND deleted_at IS NULL
|
||||
`, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
|
||||
@@ -1205,6 +1231,15 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
||||
args = append(args, *updates.RateMultiplier)
|
||||
idx++
|
||||
}
|
||||
if updates.LoadFactor != nil {
|
||||
if *updates.LoadFactor <= 0 {
|
||||
setClauses = append(setClauses, "load_factor = NULL")
|
||||
} else {
|
||||
setClauses = append(setClauses, "load_factor = $"+itoa(idx))
|
||||
args = append(args, *updates.LoadFactor)
|
||||
idx++
|
||||
}
|
||||
}
|
||||
if updates.Status != nil {
|
||||
setClauses = append(setClauses, "status = $"+itoa(idx))
|
||||
args = append(args, *updates.Status)
|
||||
@@ -1527,6 +1562,7 @@ func accountEntityToService(m *dbent.Account) *service.Account {
|
||||
Concurrency: m.Concurrency,
|
||||
Priority: m.Priority,
|
||||
RateMultiplier: &rateMultiplier,
|
||||
LoadFactor: m.LoadFactor,
|
||||
Status: m.Status,
|
||||
ErrorMessage: derefString(m.ErrorMessage),
|
||||
LastUsedAt: m.LastUsedAt,
|
||||
@@ -1639,3 +1675,60 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
|
||||
|
||||
return r.accountsToService(ctx, accounts)
|
||||
}
|
||||
|
||||
// IncrementQuotaUsed 原子递增账号的 extra.quota_used 字段
|
||||
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
rows, err := r.sql.QueryContext(ctx,
|
||||
`UPDATE accounts SET extra = jsonb_set(
|
||||
COALESCE(extra, '{}'::jsonb),
|
||||
'{quota_used}',
|
||||
to_jsonb(COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|
||||
), updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
RETURNING
|
||||
COALESCE((extra->>'quota_used')::numeric, 0),
|
||||
COALESCE((extra->>'quota_limit')::numeric, 0)`,
|
||||
amount, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var newUsed, limit float64
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&newUsed, &limit); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 配额刚超限时触发调度快照刷新,使账号及时从调度候选中移除
|
||||
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", id, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetQuotaUsed 重置账号的 extra.quota_used 为 0
|
||||
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
_, err := r.sql.ExecContext(ctx,
|
||||
`UPDATE accounts SET extra = jsonb_set(
|
||||
COALESCE(extra, '{}'::jsonb),
|
||||
'{quota_used}',
|
||||
'0'::jsonb
|
||||
), updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 重置配额后触发调度快照刷新,使账号重新参与调度
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota reset failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
@@ -95,7 +96,8 @@ func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *se
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
|
||||
msg := fmt.Sprintf("API returned status %d: %s", resp.StatusCode, string(body))
|
||||
return nil, infraerrors.New(http.StatusInternalServerError, "UPSTREAM_ERROR", msg)
|
||||
}
|
||||
|
||||
var usageResp service.ClaudeUsageResponse
|
||||
|
||||
183
backend/internal/repository/scheduled_test_repo.go
Normal file
183
backend/internal/repository/scheduled_test_repo.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// --- Plan Repository ---
|
||||
|
||||
// scheduledTestPlanRepository persists scheduled test plans in the
// scheduled_test_plans table using database/sql directly.
type scheduledTestPlanRepository struct {
	// db is the connection handle used for every query.
	db *sql.DB
}
|
||||
|
||||
func NewScheduledTestPlanRepository(db *sql.DB) service.ScheduledTestPlanRepository {
|
||||
return &scheduledTestPlanRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) Create(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, max_results, next_run_at, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, NOW(), NOW())
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans WHERE id = $1
|
||||
`, id)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accountID int64) ([]*service.ScheduledTestPlan, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans WHERE account_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
return scanPlans(rows)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time) ([]*service.ScheduledTestPlan, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans
|
||||
WHERE enabled = true AND next_run_at <= $1
|
||||
ORDER BY next_run_at ASC
|
||||
`, now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
return scanPlans(rows)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) Update(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
UPDATE scheduled_test_plans
|
||||
SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, next_run_at = $6, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) Delete(ctx context.Context, id int64) error {
|
||||
_, err := r.db.ExecContext(ctx, `DELETE FROM scheduled_test_plans WHERE id = $1`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) UpdateAfterRun(ctx context.Context, id int64, lastRunAt time.Time, nextRunAt time.Time) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
UPDATE scheduled_test_plans SET last_run_at = $2, next_run_at = $3, updated_at = NOW() WHERE id = $1
|
||||
`, id, lastRunAt, nextRunAt)
|
||||
return err
|
||||
}
|
||||
|
||||
// --- Result Repository ---
|
||||
|
||||
// scheduledTestResultRepository persists scheduled test run results in the
// scheduled_test_results table using database/sql directly.
type scheduledTestResultRepository struct {
	// db is the connection handle used for every query.
	db *sql.DB
}
|
||||
|
||||
func NewScheduledTestResultRepository(db *sql.DB) service.ScheduledTestResultRepository {
|
||||
return &scheduledTestResultRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *scheduledTestResultRepository) Create(ctx context.Context, result *service.ScheduledTestResult) (*service.ScheduledTestResult, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO scheduled_test_results (plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
|
||||
RETURNING id, plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at
|
||||
`, result.PlanID, result.Status, result.ResponseText, result.ErrorMessage, result.LatencyMs, result.StartedAt, result.FinishedAt)
|
||||
|
||||
out := &service.ScheduledTestResult{}
|
||||
if err := row.Scan(
|
||||
&out.ID, &out.PlanID, &out.Status, &out.ResponseText, &out.ErrorMessage,
|
||||
&out.LatencyMs, &out.StartedAt, &out.FinishedAt, &out.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *scheduledTestResultRepository) ListByPlanID(ctx context.Context, planID int64, limit int) ([]*service.ScheduledTestResult, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at
|
||||
FROM scheduled_test_results
|
||||
WHERE plan_id = $1
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $2
|
||||
`, planID, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var results []*service.ScheduledTestResult
|
||||
for rows.Next() {
|
||||
r := &service.ScheduledTestResult{}
|
||||
if err := rows.Scan(
|
||||
&r.ID, &r.PlanID, &r.Status, &r.ResponseText, &r.ErrorMessage,
|
||||
&r.LatencyMs, &r.StartedAt, &r.FinishedAt, &r.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, r)
|
||||
}
|
||||
return results, rows.Err()
|
||||
}
|
||||
|
||||
func (r *scheduledTestResultRepository) PruneOldResults(ctx context.Context, planID int64, keepCount int) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
DELETE FROM scheduled_test_results
|
||||
WHERE id IN (
|
||||
SELECT id FROM (
|
||||
SELECT id, ROW_NUMBER() OVER (PARTITION BY plan_id ORDER BY created_at DESC) AS rn
|
||||
FROM scheduled_test_results
|
||||
WHERE plan_id = $1
|
||||
) ranked
|
||||
WHERE rn > $2
|
||||
)
|
||||
`, planID, keepCount)
|
||||
return err
|
||||
}
|
||||
|
||||
// --- scan helpers ---
|
||||
|
||||
// scannable is the single Scan method that scanPlan needs, so one helper can
// decode both a single-row query result and the current row of a multi-row
// result.
type scannable interface {
	// Scan copies the current row's columns into dest.
	Scan(dest ...any) error
}
|
||||
|
||||
func scanPlan(row scannable) (*service.ScheduledTestPlan, error) {
|
||||
p := &service.ScheduledTestPlan{}
|
||||
if err := row.Scan(
|
||||
&p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults,
|
||||
&p.LastRunAt, &p.NextRunAt, &p.CreatedAt, &p.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func scanPlans(rows *sql.Rows) ([]*service.ScheduledTestPlan, error) {
|
||||
var plans []*service.ScheduledTestPlan
|
||||
for rows.Next() {
|
||||
p, err := scanPlan(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plans = append(plans, p)
|
||||
}
|
||||
return plans, rows.Err()
|
||||
}
|
||||
@@ -1363,7 +1363,8 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
@@ -1401,6 +1402,8 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
@@ -1664,7 +1667,8 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
@@ -1747,7 +1751,8 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st
|
||||
total_requests as requests,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
(cache_creation_tokens + cache_read_tokens) as cache_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens,
|
||||
total_cost as cost,
|
||||
actual_cost
|
||||
@@ -1762,7 +1767,8 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st
|
||||
total_requests as requests,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
(cache_creation_tokens + cache_read_tokens) as cache_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens,
|
||||
total_cost as cost,
|
||||
actual_cost
|
||||
@@ -1806,6 +1812,8 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
%s
|
||||
@@ -2622,7 +2630,8 @@ func scanTrendRows(rows *sql.Rows) ([]TrendDataPoint, error) {
|
||||
&row.Requests,
|
||||
&row.InputTokens,
|
||||
&row.OutputTokens,
|
||||
&row.CacheTokens,
|
||||
&row.CacheCreationTokens,
|
||||
&row.CacheReadTokens,
|
||||
&row.TotalTokens,
|
||||
&row.Cost,
|
||||
&row.ActualCost,
|
||||
@@ -2646,6 +2655,8 @@ func scanModelStatsRows(rows *sql.Rows) ([]ModelStat, error) {
|
||||
&row.Requests,
|
||||
&row.InputTokens,
|
||||
&row.OutputTokens,
|
||||
&row.CacheCreationTokens,
|
||||
&row.CacheReadTokens,
|
||||
&row.TotalTokens,
|
||||
&row.Cost,
|
||||
&row.ActualCost,
|
||||
|
||||
@@ -125,7 +125,7 @@ func TestUsageLogRepositoryGetUsageTrendWithFiltersRequestTypePriority(t *testin
|
||||
|
||||
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE\\)\\)").
|
||||
WithArgs(start, end, requestType).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_tokens", "total_tokens", "cost", "actual_cost"}))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"}))
|
||||
|
||||
trend, err := repo.GetUsageTrendWithFilters(context.Background(), start, end, "day", 0, 0, 0, 0, "", &requestType, &stream, nil)
|
||||
require.NoError(t, err)
|
||||
@@ -144,7 +144,7 @@ func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testin
|
||||
|
||||
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
|
||||
WithArgs(start, end, requestType).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "total_tokens", "cost", "actual_cost"}))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"}))
|
||||
|
||||
stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -53,7 +53,9 @@ var ProviderSet = wire.NewSet(
|
||||
NewAPIKeyRepository,
|
||||
NewGroupRepository,
|
||||
NewAccountRepository,
|
||||
NewSoraAccountRepository, // Sora 账号扩展表仓储
|
||||
NewSoraAccountRepository, // Sora 账号扩展表仓储
|
||||
NewScheduledTestPlanRepository, // 定时测试计划仓储
|
||||
NewScheduledTestResultRepository, // 定时测试结果仓储
|
||||
NewProxyRepository,
|
||||
NewRedeemCodeRepository,
|
||||
NewPromoCodeRepository,
|
||||
|
||||
@@ -1096,6 +1096,14 @@ func (s *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
// IncrementQuotaUsed is not supported by this stub; it ignores its arguments
// and always returns a "not implemented" error.
func (s *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
	return errors.New("not implemented")
}
|
||||
|
||||
// ResetQuotaUsed is not supported by this stub; it ignores its arguments and
// always returns a "not implemented" error.
func (s *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error {
	return errors.New("not implemented")
}
|
||||
|
||||
func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
||||
s.bulkUpdateIDs = append([]int64{}, ids...)
|
||||
return int64(len(ids)), nil
|
||||
|
||||
@@ -78,6 +78,9 @@ func RegisterAdminRoutes(
|
||||
|
||||
// API Key 管理
|
||||
registerAdminAPIKeyRoutes(admin, h)
|
||||
|
||||
// 定时测试计划
|
||||
registerScheduledTestRoutes(admin, h)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -249,6 +252,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
|
||||
accounts.POST("/today-stats/batch", h.Admin.Account.GetBatchTodayStats)
|
||||
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
|
||||
accounts.POST("/:id/reset-quota", h.Admin.Account.ResetQuota)
|
||||
accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable)
|
||||
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
|
||||
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
|
||||
@@ -478,6 +482,18 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
}
|
||||
}
|
||||
|
||||
func registerScheduledTestRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
plans := admin.Group("/scheduled-test-plans")
|
||||
{
|
||||
plans.POST("", h.Admin.ScheduledTest.Create)
|
||||
plans.PUT("/:id", h.Admin.ScheduledTest.Update)
|
||||
plans.DELETE("/:id", h.Admin.ScheduledTest.Delete)
|
||||
plans.GET("/:id/results", h.Admin.ScheduledTest.ListResults)
|
||||
}
|
||||
// Nested under accounts
|
||||
admin.GET("/accounts/:id/scheduled-test-plans", h.Admin.ScheduledTest.ListByAccount)
|
||||
}
|
||||
|
||||
func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
rules := admin.Group("/error-passthrough-rules")
|
||||
{
|
||||
|
||||
@@ -28,6 +28,7 @@ type Account struct {
|
||||
// RateMultiplier 账号计费倍率(>=0,允许 0 表示该账号计费为 0)。
|
||||
// 使用指针用于兼容旧版本调度缓存(Redis)中缺字段的情况:nil 表示按 1.0 处理。
|
||||
RateMultiplier *float64
|
||||
LoadFactor *int // 调度负载因子;nil 表示使用 Concurrency
|
||||
Status string
|
||||
ErrorMessage string
|
||||
LastUsedAt *time.Time
|
||||
@@ -88,6 +89,19 @@ func (a *Account) BillingRateMultiplier() float64 {
|
||||
return *a.RateMultiplier
|
||||
}
|
||||
|
||||
func (a *Account) EffectiveLoadFactor() int {
|
||||
if a == nil {
|
||||
return 1
|
||||
}
|
||||
if a.LoadFactor != nil && *a.LoadFactor > 0 {
|
||||
return *a.LoadFactor
|
||||
}
|
||||
if a.Concurrency > 0 {
|
||||
return a.Concurrency
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func (a *Account) IsSchedulable() bool {
|
||||
if !a.IsActive() || !a.Schedulable {
|
||||
return false
|
||||
@@ -853,15 +867,21 @@ func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool {
|
||||
}
|
||||
|
||||
const (
|
||||
OpenAIWSIngressModeOff = "off"
|
||||
OpenAIWSIngressModeShared = "shared"
|
||||
OpenAIWSIngressModeDedicated = "dedicated"
|
||||
OpenAIWSIngressModeOff = "off"
|
||||
OpenAIWSIngressModeShared = "shared"
|
||||
OpenAIWSIngressModeDedicated = "dedicated"
|
||||
OpenAIWSIngressModeCtxPool = "ctx_pool"
|
||||
OpenAIWSIngressModePassthrough = "passthrough"
|
||||
)
|
||||
|
||||
func normalizeOpenAIWSIngressMode(mode string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(mode)) {
|
||||
case OpenAIWSIngressModeOff:
|
||||
return OpenAIWSIngressModeOff
|
||||
case OpenAIWSIngressModeCtxPool:
|
||||
return OpenAIWSIngressModeCtxPool
|
||||
case OpenAIWSIngressModePassthrough:
|
||||
return OpenAIWSIngressModePassthrough
|
||||
case OpenAIWSIngressModeShared:
|
||||
return OpenAIWSIngressModeShared
|
||||
case OpenAIWSIngressModeDedicated:
|
||||
@@ -873,18 +893,21 @@ func normalizeOpenAIWSIngressMode(mode string) string {
|
||||
|
||||
func normalizeOpenAIWSIngressDefaultMode(mode string) string {
|
||||
if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" {
|
||||
if normalized == OpenAIWSIngressModeShared || normalized == OpenAIWSIngressModeDedicated {
|
||||
return OpenAIWSIngressModeCtxPool
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
return OpenAIWSIngressModeShared
|
||||
return OpenAIWSIngressModeCtxPool
|
||||
}
|
||||
|
||||
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/shared/dedicated)。
|
||||
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/ctx_pool/passthrough)。
|
||||
//
|
||||
// 优先级:
|
||||
// 1. 分类型 mode 新字段(string)
|
||||
// 2. 分类型 enabled 旧字段(bool)
|
||||
// 3. 兼容 enabled 旧字段(bool)
|
||||
// 4. defaultMode(非法时回退 shared)
|
||||
// 4. defaultMode(非法时回退 ctx_pool)
|
||||
func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string {
|
||||
resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode)
|
||||
if a == nil || !a.IsOpenAI() {
|
||||
@@ -919,7 +942,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
|
||||
return "", false
|
||||
}
|
||||
if enabled {
|
||||
return OpenAIWSIngressModeShared, true
|
||||
return OpenAIWSIngressModeCtxPool, true
|
||||
}
|
||||
return OpenAIWSIngressModeOff, true
|
||||
}
|
||||
@@ -946,6 +969,10 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
|
||||
if mode, ok := resolveBoolMode("openai_ws_enabled"); ok {
|
||||
return mode
|
||||
}
|
||||
// 兼容旧值:shared/dedicated 语义都归并到 ctx_pool。
|
||||
if resolvedDefault == OpenAIWSIngressModeShared || resolvedDefault == OpenAIWSIngressModeDedicated {
|
||||
return OpenAIWSIngressModeCtxPool
|
||||
}
|
||||
return resolvedDefault
|
||||
}
|
||||
|
||||
@@ -1104,6 +1131,38 @@ func (a *Account) GetCacheTTLOverrideTarget() string {
|
||||
return "5m"
|
||||
}
|
||||
|
||||
// GetQuotaLimit 获取 API Key 账号的配额限制(美元)
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetQuotaLimit() float64 {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra["quota_limit"]; ok {
|
||||
return parseExtraFloat64(v)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetQuotaUsed 获取 API Key 账号的已用配额(美元)
|
||||
func (a *Account) GetQuotaUsed() float64 {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra["quota_used"]; ok {
|
||||
return parseExtraFloat64(v)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// IsQuotaExceeded 检查 API Key 账号配额是否已超限
|
||||
func (a *Account) IsQuotaExceeded() bool {
|
||||
limit := a.GetQuotaLimit()
|
||||
if limit <= 0 {
|
||||
return false
|
||||
}
|
||||
return a.GetQuotaUsed() >= limit
|
||||
}
|
||||
|
||||
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetWindowCostLimit() float64 {
|
||||
|
||||
46
backend/internal/service/account_load_factor_test.go
Normal file
46
backend/internal/service/account_load_factor_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// intPtrHelper returns a pointer to a freshly allocated int holding v.
func intPtrHelper(v int) *int {
	p := new(int)
	*p = v
	return p
}
|
||||
|
||||
func TestEffectiveLoadFactor_NilAccount(t *testing.T) {
|
||||
var a *Account
|
||||
require.Equal(t, 1, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_NilLoadFactor_PositiveConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 5}
|
||||
require.Equal(t, 5, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_NilLoadFactor_ZeroConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 0}
|
||||
require.Equal(t, 1, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_PositiveLoadFactor(t *testing.T) {
|
||||
a := &Account{Concurrency: 5, LoadFactor: intPtrHelper(20)}
|
||||
require.Equal(t, 20, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_ZeroLoadFactor_FallbackToConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 5, LoadFactor: intPtrHelper(0)}
|
||||
require.Equal(t, 5, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_NegativeLoadFactor_FallbackToConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 3, LoadFactor: intPtrHelper(-1)}
|
||||
require.Equal(t, 3, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_ZeroLoadFactor_ZeroConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 0, LoadFactor: intPtrHelper(0)}
|
||||
require.Equal(t, 1, a.EffectiveLoadFactor())
|
||||
}
|
||||
@@ -206,14 +206,14 @@ func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
||||
t.Run("default fallback to shared", func(t *testing.T) {
|
||||
t.Run("default fallback to ctx_pool", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
|
||||
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
|
||||
})
|
||||
|
||||
t.Run("oauth mode field has highest priority", func(t *testing.T) {
|
||||
@@ -221,15 +221,15 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
|
||||
"openai_oauth_responses_websockets_v2_enabled": false,
|
||||
"responses_websockets_v2_enabled": false,
|
||||
},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
|
||||
require.Equal(t, OpenAIWSIngressModePassthrough, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
|
||||
})
|
||||
|
||||
t.Run("legacy enabled maps to shared", func(t *testing.T) {
|
||||
t.Run("legacy enabled maps to ctx_pool", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
@@ -237,7 +237,28 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||
})
|
||||
|
||||
t.Run("shared/dedicated mode strings are compatible with ctx_pool", func(t *testing.T) {
|
||||
shared := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
|
||||
},
|
||||
}
|
||||
dedicated := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
|
||||
},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeShared, shared.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||
require.Equal(t, OpenAIWSIngressModeDedicated, dedicated.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeShared))
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeDedicated))
|
||||
})
|
||||
|
||||
t.Run("legacy disabled maps to off", func(t *testing.T) {
|
||||
@@ -249,7 +270,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
|
||||
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
|
||||
})
|
||||
|
||||
t.Run("non openai always off", func(t *testing.T) {
|
||||
|
||||
@@ -68,6 +68,10 @@ type AccountRepository interface {
|
||||
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
||||
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
|
||||
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
|
||||
// IncrementQuotaUsed 原子递增 API Key 账号的配额用量
|
||||
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error
|
||||
// ResetQuotaUsed 重置 API Key 账号的配额用量为 0
|
||||
ResetQuotaUsed(ctx context.Context, id int64) error
|
||||
}
|
||||
|
||||
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
|
||||
@@ -78,6 +82,7 @@ type AccountBulkUpdate struct {
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
RateMultiplier *float64
|
||||
LoadFactor *int
|
||||
Status *string
|
||||
Schedulable *bool
|
||||
Credentials map[string]any
|
||||
|
||||
@@ -199,6 +199,14 @@ func (s *accountRepoStub) BulkUpdate(ctx context.Context, ids []int64, updates A
|
||||
panic("unexpected BulkUpdate call")
|
||||
}
|
||||
|
||||
// IncrementQuotaUsed is a no-op in this stub: it ignores its arguments and
// always returns nil.
func (s *accountRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
	return nil
}
|
||||
|
||||
// ResetQuotaUsed is a no-op in this stub: it ignores its arguments and always
// returns nil.
func (s *accountRepoStub) ResetQuotaUsed(ctx context.Context, id int64) error {
	return nil
}
|
||||
|
||||
// TestAccountService_Delete_NotFound 测试删除不存在的账号时返回正确的错误。
|
||||
// 预期行为:
|
||||
// - ExistsByID 返回 false(账号不存在)
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
@@ -33,7 +34,7 @@ import (
|
||||
var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
|
||||
|
||||
const (
|
||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
|
||||
soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
|
||||
soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
|
||||
@@ -179,7 +180,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
}
|
||||
|
||||
if account.Platform == PlatformAntigravity {
|
||||
return s.testAntigravityAccountConnection(c, account, modelID)
|
||||
return s.routeAntigravityTest(c, account, modelID)
|
||||
}
|
||||
|
||||
if account.Platform == PlatformSora {
|
||||
@@ -238,7 +239,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
|
||||
}
|
||||
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages"
|
||||
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages?beta=true"
|
||||
} else {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
@@ -1176,6 +1177,18 @@ func truncateSoraErrorBody(body []byte, max int) string {
|
||||
return soraerror.TruncateBody(body, max)
|
||||
}
|
||||
|
||||
// routeAntigravityTest 路由 Antigravity 账号的测试请求。
|
||||
// APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。
|
||||
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string) error {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if strings.HasPrefix(modelID, "gemini-") {
|
||||
return s.testGeminiAccountConnection(c, account, modelID)
|
||||
}
|
||||
return s.testClaudeAccountConnection(c, account, modelID)
|
||||
}
|
||||
return s.testAntigravityAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
// testAntigravityAccountConnection tests an Antigravity account's connection
|
||||
// 支持 Claude 和 Gemini 两种协议,使用非流式请求
|
||||
func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
||||
@@ -1560,3 +1573,62 @@ func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) er
|
||||
s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg})
|
||||
return fmt.Errorf("%s", errorMsg)
|
||||
}
|
||||
|
||||
// RunTestBackground executes an account test in-memory (no real HTTP client),
|
||||
// capturing SSE output via httptest.NewRecorder, then parses the result.
|
||||
func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID int64, modelID string) (*ScheduledTestResult, error) {
|
||||
startedAt := time.Now()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(w)
|
||||
ginCtx.Request = (&http.Request{}).WithContext(ctx)
|
||||
|
||||
testErr := s.TestAccountConnection(ginCtx, accountID, modelID)
|
||||
|
||||
finishedAt := time.Now()
|
||||
body := w.Body.String()
|
||||
responseText, errMsg := parseTestSSEOutput(body)
|
||||
|
||||
status := "success"
|
||||
if testErr != nil || errMsg != "" {
|
||||
status = "failed"
|
||||
if errMsg == "" && testErr != nil {
|
||||
errMsg = testErr.Error()
|
||||
}
|
||||
}
|
||||
|
||||
return &ScheduledTestResult{
|
||||
Status: status,
|
||||
ResponseText: responseText,
|
||||
ErrorMessage: errMsg,
|
||||
LatencyMs: finishedAt.Sub(startedAt).Milliseconds(),
|
||||
StartedAt: startedAt,
|
||||
FinishedAt: finishedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseTestSSEOutput extracts response text and error message from captured SSE output.
|
||||
func parseTestSSEOutput(body string) (responseText, errMsg string) {
|
||||
var texts []string
|
||||
for _, line := range strings.Split(body, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
jsonStr := strings.TrimPrefix(line, "data: ")
|
||||
var event TestEvent
|
||||
if err := json.Unmarshal([]byte(jsonStr), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
switch event.Type {
|
||||
case "content":
|
||||
if event.Text != "" {
|
||||
texts = append(texts, event.Text)
|
||||
}
|
||||
case "error":
|
||||
errMsg = event.Error
|
||||
}
|
||||
}
|
||||
responseText = strings.Join(texts, "")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -84,6 +84,7 @@ type AdminService interface {
|
||||
DeleteRedeemCode(ctx context.Context, id int64) error
|
||||
BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error)
|
||||
ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
|
||||
ResetAccountQuota(ctx context.Context, id int64) error
|
||||
}
|
||||
|
||||
// CreateUserInput represents input for creating a new user via admin operations.
|
||||
@@ -195,6 +196,7 @@ type CreateAccountInput struct {
|
||||
Concurrency int
|
||||
Priority int
|
||||
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
|
||||
LoadFactor *int
|
||||
GroupIDs []int64
|
||||
ExpiresAt *int64
|
||||
AutoPauseOnExpired *bool
|
||||
@@ -215,6 +217,7 @@ type UpdateAccountInput struct {
|
||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||
Priority *int // 使用指针区分"未提供"和"设置为0"
|
||||
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
|
||||
LoadFactor *int
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
ExpiresAt *int64
|
||||
@@ -230,6 +233,7 @@ type BulkUpdateAccountsInput struct {
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
|
||||
LoadFactor *int
|
||||
Status string
|
||||
Schedulable *bool
|
||||
GroupIDs *[]int64
|
||||
@@ -1413,6 +1417,12 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
}
|
||||
account.RateMultiplier = input.RateMultiplier
|
||||
}
|
||||
if input.LoadFactor != nil && *input.LoadFactor > 0 {
|
||||
if *input.LoadFactor > 10000 {
|
||||
return nil, errors.New("load_factor must be <= 10000")
|
||||
}
|
||||
account.LoadFactor = input.LoadFactor
|
||||
}
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1458,6 +1468,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
account.Credentials = input.Credentials
|
||||
}
|
||||
if len(input.Extra) > 0 {
|
||||
// 保留 quota_used,防止编辑账号时意外重置配额用量
|
||||
if oldQuotaUsed, ok := account.Extra["quota_used"]; ok {
|
||||
input.Extra["quota_used"] = oldQuotaUsed
|
||||
}
|
||||
account.Extra = input.Extra
|
||||
}
|
||||
if input.ProxyID != nil {
|
||||
@@ -1483,6 +1497,15 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
}
|
||||
account.RateMultiplier = input.RateMultiplier
|
||||
}
|
||||
if input.LoadFactor != nil {
|
||||
if *input.LoadFactor <= 0 {
|
||||
account.LoadFactor = nil // 0 或负数表示清除
|
||||
} else if *input.LoadFactor > 10000 {
|
||||
return nil, errors.New("load_factor must be <= 10000")
|
||||
} else {
|
||||
account.LoadFactor = input.LoadFactor
|
||||
}
|
||||
}
|
||||
if input.Status != "" {
|
||||
account.Status = input.Status
|
||||
}
|
||||
@@ -1616,6 +1639,15 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
if input.RateMultiplier != nil {
|
||||
repoUpdates.RateMultiplier = input.RateMultiplier
|
||||
}
|
||||
if input.LoadFactor != nil {
|
||||
if *input.LoadFactor <= 0 {
|
||||
repoUpdates.LoadFactor = nil // 0 或负数表示清除
|
||||
} else if *input.LoadFactor > 10000 {
|
||||
return nil, errors.New("load_factor must be <= 10000")
|
||||
} else {
|
||||
repoUpdates.LoadFactor = input.LoadFactor
|
||||
}
|
||||
}
|
||||
if input.Status != "" {
|
||||
repoUpdates.Status = &input.Status
|
||||
}
|
||||
@@ -2439,3 +2471,7 @@ func (e *MixedChannelError) Error() string {
|
||||
return fmt.Sprintf("mixed_channel_warning: Group '%s' contains both %s and %s accounts. Using mixed channels in the same context may cause thinking block signature validation issues, which will fallback to non-thinking mode for historical messages.",
|
||||
e.GroupName, e.CurrentPlatform, e.OtherPlatform)
|
||||
}
|
||||
|
||||
// ResetAccountQuota resets the used quota of the account with the given id by
// delegating to accountRepo.ResetQuotaUsed, returning its error unchanged.
func (s *adminServiceImpl) ResetAccountQuota(ctx context.Context, id int64) error {
	return s.accountRepo.ResetQuotaUsed(ctx, id)
}
|
||||
|
||||
@@ -43,15 +43,24 @@ type BillingCache interface {
|
||||
|
||||
// ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致)
|
||||
type ModelPricing struct {
|
||||
InputPricePerToken float64 // 每token输入价格 (USD)
|
||||
OutputPricePerToken float64 // 每token输出价格 (USD)
|
||||
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
|
||||
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
|
||||
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
|
||||
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
|
||||
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
|
||||
InputPricePerToken float64 // 每token输入价格 (USD)
|
||||
OutputPricePerToken float64 // 每token输出价格 (USD)
|
||||
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
|
||||
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
|
||||
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
|
||||
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
|
||||
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
|
||||
LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格
|
||||
LongContextInputMultiplier float64 // 长上下文整次会话输入倍率
|
||||
LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率
|
||||
}
|
||||
|
||||
// Long-context pricing parameters for OpenAI GPT-5.4. They populate the
// LongContext* fields of the "gpt-5.4" fallback ModelPricing entry.
const (
	// Threshold above which long-context pricing applies
	// (presumably counted in input tokens — TODO confirm against
	// shouldApplySessionLongContextPricing).
	openAIGPT54LongContextInputThreshold = 272000
	// Multiplier applied to the per-token input price under long-context pricing.
	openAIGPT54LongContextInputMultiplier = 2.0
	// Multiplier applied to the per-token output price under long-context pricing.
	openAIGPT54LongContextOutputMultiplier = 1.5
)
|
||||
|
||||
// UsageTokens 使用的token数量
|
||||
type UsageTokens struct {
|
||||
InputTokens int
|
||||
@@ -161,6 +170,35 @@ func (s *BillingService) initFallbackPricing() {
|
||||
CacheReadPricePerToken: 0.2e-6, // $0.20 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
|
||||
// OpenAI GPT-5.1(本地兜底,防止动态定价不可用时拒绝计费)
|
||||
s.fallbackPrices["gpt-5.1"] = &ModelPricing{
|
||||
InputPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
OutputPricePerToken: 10e-6, // $10 per MTok
|
||||
CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
CacheReadPricePerToken: 0.125e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
// OpenAI GPT-5.4(业务指定价格)
|
||||
s.fallbackPrices["gpt-5.4"] = &ModelPricing{
|
||||
InputPricePerToken: 2.5e-6, // $2.5 per MTok
|
||||
OutputPricePerToken: 15e-6, // $15 per MTok
|
||||
CacheCreationPricePerToken: 2.5e-6, // $2.5 per MTok
|
||||
CacheReadPricePerToken: 0.25e-6, // $0.25 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
LongContextInputThreshold: openAIGPT54LongContextInputThreshold,
|
||||
LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier,
|
||||
LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier,
|
||||
}
|
||||
// Codex 族兜底统一按 GPT-5.1 Codex 价格计费
|
||||
s.fallbackPrices["gpt-5.1-codex"] = &ModelPricing{
|
||||
InputPricePerToken: 1.5e-6, // $1.5 per MTok
|
||||
OutputPricePerToken: 12e-6, // $12 per MTok
|
||||
CacheCreationPricePerToken: 1.5e-6, // $1.5 per MTok
|
||||
CacheReadPricePerToken: 0.15e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
s.fallbackPrices["gpt-5.3-codex"] = s.fallbackPrices["gpt-5.1-codex"]
|
||||
}
|
||||
|
||||
// getFallbackPricing 根据模型系列获取回退价格
|
||||
@@ -189,12 +227,30 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
|
||||
}
|
||||
return s.fallbackPrices["claude-3-haiku"]
|
||||
}
|
||||
// Claude 未知型号统一回退到 Sonnet,避免计费中断。
|
||||
if strings.Contains(modelLower, "claude") {
|
||||
return s.fallbackPrices["claude-sonnet-4"]
|
||||
}
|
||||
if strings.Contains(modelLower, "gemini-3.1-pro") || strings.Contains(modelLower, "gemini-3-1-pro") {
|
||||
return s.fallbackPrices["gemini-3.1-pro"]
|
||||
}
|
||||
|
||||
// 默认使用Sonnet价格
|
||||
return s.fallbackPrices["claude-sonnet-4"]
|
||||
// OpenAI 仅匹配已知 GPT-5/Codex 族,避免未知 OpenAI 型号误计价。
|
||||
if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") {
|
||||
normalized := normalizeCodexModel(modelLower)
|
||||
switch normalized {
|
||||
case "gpt-5.4":
|
||||
return s.fallbackPrices["gpt-5.4"]
|
||||
case "gpt-5.3-codex":
|
||||
return s.fallbackPrices["gpt-5.3-codex"]
|
||||
case "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini", "codex-mini-latest":
|
||||
return s.fallbackPrices["gpt-5.1-codex"]
|
||||
case "gpt-5.1":
|
||||
return s.fallbackPrices["gpt-5.1"]
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetModelPricing 获取模型价格配置
|
||||
@@ -212,15 +268,18 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
||||
price5m := litellmPricing.CacheCreationInputTokenCost
|
||||
price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr
|
||||
enableBreakdown := price1h > 0 && price1h > price5m
|
||||
return &ModelPricing{
|
||||
InputPricePerToken: litellmPricing.InputCostPerToken,
|
||||
OutputPricePerToken: litellmPricing.OutputCostPerToken,
|
||||
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
|
||||
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
|
||||
CacheCreation5mPrice: price5m,
|
||||
CacheCreation1hPrice: price1h,
|
||||
SupportsCacheBreakdown: enableBreakdown,
|
||||
}, nil
|
||||
return s.applyModelSpecificPricingPolicy(model, &ModelPricing{
|
||||
InputPricePerToken: litellmPricing.InputCostPerToken,
|
||||
OutputPricePerToken: litellmPricing.OutputCostPerToken,
|
||||
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
|
||||
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
|
||||
CacheCreation5mPrice: price5m,
|
||||
CacheCreation1hPrice: price1h,
|
||||
SupportsCacheBreakdown: enableBreakdown,
|
||||
LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold,
|
||||
LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier,
|
||||
LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier,
|
||||
}), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -228,7 +287,7 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
||||
fallback := s.getFallbackPricing(model)
|
||||
if fallback != nil {
|
||||
log.Printf("[Billing] Using fallback pricing for model: %s", model)
|
||||
return fallback, nil
|
||||
return s.applyModelSpecificPricingPolicy(model, fallback), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("pricing not found for model: %s", model)
|
||||
@@ -242,12 +301,18 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
|
||||
}
|
||||
|
||||
breakdown := &CostBreakdown{}
|
||||
inputPricePerToken := pricing.InputPricePerToken
|
||||
outputPricePerToken := pricing.OutputPricePerToken
|
||||
if s.shouldApplySessionLongContextPricing(tokens, pricing) {
|
||||
inputPricePerToken *= pricing.LongContextInputMultiplier
|
||||
outputPricePerToken *= pricing.LongContextOutputMultiplier
|
||||
}
|
||||
|
||||
// 计算输入token费用(使用per-token价格)
|
||||
breakdown.InputCost = float64(tokens.InputTokens) * pricing.InputPricePerToken
|
||||
breakdown.InputCost = float64(tokens.InputTokens) * inputPricePerToken
|
||||
|
||||
// 计算输出token费用
|
||||
breakdown.OutputCost = float64(tokens.OutputTokens) * pricing.OutputPricePerToken
|
||||
breakdown.OutputCost = float64(tokens.OutputTokens) * outputPricePerToken
|
||||
|
||||
// 计算缓存费用
|
||||
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
|
||||
@@ -279,6 +344,45 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
|
||||
return breakdown, nil
|
||||
}
|
||||
|
||||
func (s *BillingService) applyModelSpecificPricingPolicy(model string, pricing *ModelPricing) *ModelPricing {
|
||||
if pricing == nil {
|
||||
return nil
|
||||
}
|
||||
if !isOpenAIGPT54Model(model) {
|
||||
return pricing
|
||||
}
|
||||
if pricing.LongContextInputThreshold > 0 && pricing.LongContextInputMultiplier > 0 && pricing.LongContextOutputMultiplier > 0 {
|
||||
return pricing
|
||||
}
|
||||
cloned := *pricing
|
||||
if cloned.LongContextInputThreshold <= 0 {
|
||||
cloned.LongContextInputThreshold = openAIGPT54LongContextInputThreshold
|
||||
}
|
||||
if cloned.LongContextInputMultiplier <= 0 {
|
||||
cloned.LongContextInputMultiplier = openAIGPT54LongContextInputMultiplier
|
||||
}
|
||||
if cloned.LongContextOutputMultiplier <= 0 {
|
||||
cloned.LongContextOutputMultiplier = openAIGPT54LongContextOutputMultiplier
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens, pricing *ModelPricing) bool {
|
||||
if pricing == nil || pricing.LongContextInputThreshold <= 0 {
|
||||
return false
|
||||
}
|
||||
if pricing.LongContextInputMultiplier <= 1 && pricing.LongContextOutputMultiplier <= 1 {
|
||||
return false
|
||||
}
|
||||
totalInputTokens := tokens.InputTokens + tokens.CacheReadTokens
|
||||
return totalInputTokens > pricing.LongContextInputThreshold
|
||||
}
|
||||
|
||||
func isOpenAIGPT54Model(model string) bool {
|
||||
normalized := normalizeCodexModel(strings.TrimSpace(strings.ToLower(model)))
|
||||
return normalized == "gpt-5.4"
|
||||
}
|
||||
|
||||
// CalculateCostWithConfig 使用配置中的默认倍率计算费用
|
||||
func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageTokens) (*CostBreakdown, error) {
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
|
||||
@@ -133,7 +133,7 @@ func TestGetModelPricing_CaseInsensitive(t *testing.T) {
|
||||
require.Equal(t, p1.InputPricePerToken, p2.InputPricePerToken)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_UnknownModelFallsBackToSonnet(t *testing.T) {
|
||||
func TestGetModelPricing_UnknownClaudeModelFallsBackToSonnet(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
// 不包含 opus/sonnet/haiku 关键词的 Claude 模型会走默认 Sonnet 价格
|
||||
@@ -142,6 +142,93 @@ func TestGetModelPricing_UnknownModelFallsBackToSonnet(t *testing.T) {
|
||||
require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_UnknownOpenAIModelReturnsError(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
pricing, err := svc.GetModelPricing("gpt-unknown-model")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, pricing)
|
||||
require.Contains(t, err.Error(), "pricing not found")
|
||||
}
|
||||
|
||||
func TestGetModelPricing_OpenAIGPT51Fallback(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
pricing, err := svc.GetModelPricing("gpt-5.1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, 1.25e-6, pricing.InputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
pricing, err := svc.GetModelPricing("gpt-5.4")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, 2.5e-6, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 0.25e-6, pricing.CacheReadPricePerToken, 1e-12)
|
||||
require.Equal(t, 272000, pricing.LongContextInputThreshold)
|
||||
require.InDelta(t, 2.0, pricing.LongContextInputMultiplier, 1e-12)
|
||||
require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateCost_OpenAIGPT54LongContextAppliesWholeSessionMultipliers(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 300000,
|
||||
OutputTokens: 4000,
|
||||
}
|
||||
|
||||
cost, err := svc.CalculateCost("gpt-5.4-2026-03-05", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedInput := float64(tokens.InputTokens) * 2.5e-6 * 2.0
|
||||
expectedOutput := float64(tokens.OutputTokens) * 15e-6 * 1.5
|
||||
require.InDelta(t, expectedInput, cost.InputCost, 1e-10)
|
||||
require.InDelta(t, expectedOutput, cost.OutputCost, 1e-10)
|
||||
require.InDelta(t, expectedInput+expectedOutput, cost.TotalCost, 1e-10)
|
||||
require.InDelta(t, expectedInput+expectedOutput, cost.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestGetFallbackPricing_FamilyMatching(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
expectedInput float64
|
||||
expectNilPricing bool
|
||||
}{
|
||||
{name: "empty model", model: " ", expectNilPricing: true},
|
||||
{name: "claude opus 4.6", model: "claude-opus-4.6-20260201", expectedInput: 5e-6},
|
||||
{name: "claude opus 4.5 alt separator", model: "claude-opus-4-5-20260101", expectedInput: 5e-6},
|
||||
{name: "claude generic model fallback sonnet", model: "claude-foo-bar", expectedInput: 3e-6},
|
||||
{name: "gemini explicit fallback", model: "gemini-3-1-pro", expectedInput: 2e-6},
|
||||
{name: "gemini unknown no fallback", model: "gemini-2.0-pro", expectNilPricing: true},
|
||||
{name: "openai gpt5.1", model: "gpt-5.1", expectedInput: 1.25e-6},
|
||||
{name: "openai gpt5.4", model: "gpt-5.4", expectedInput: 2.5e-6},
|
||||
{name: "openai gpt5.3 codex", model: "gpt-5.3-codex", expectedInput: 1.5e-6},
|
||||
{name: "openai gpt5.1 codex max alias", model: "gpt-5.1-codex-max", expectedInput: 1.5e-6},
|
||||
{name: "openai codex mini latest alias", model: "codex-mini-latest", expectedInput: 1.5e-6},
|
||||
{name: "openai unknown no fallback", model: "gpt-unknown-model", expectNilPricing: true},
|
||||
{name: "non supported family", model: "qwen-max", expectNilPricing: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pricing := svc.getFallbackPricing(tt.model)
|
||||
if tt.expectNilPricing {
|
||||
require.Nil(t, pricing)
|
||||
return
|
||||
}
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, tt.expectedInput, pricing.InputPricePerToken, 1e-12)
|
||||
})
|
||||
}
|
||||
}
|
||||
func TestCalculateCostWithLongContext_BelowThreshold(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
|
||||
@@ -88,6 +88,49 @@ func TestCheckErrorPolicy(t *testing.T) {
|
||||
body: []byte(`overloaded service`),
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_401_first_hit_returns_temp_unscheduled",
|
||||
account: &Account{
|
||||
ID: 14,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_401_second_hit_upgrades_to_none",
|
||||
account: &Account{
|
||||
ID: 15,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicyNone,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_body_miss_returns_none",
|
||||
account: &Account{
|
||||
|
||||
@@ -171,8 +171,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.Stream)
|
||||
|
||||
require.Equal(t, body, upstream.lastBody, "透传模式不应改写上游请求体")
|
||||
require.Equal(t, "claude-3-7-sonnet-20250219", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(upstream.lastBody, "model").String(), "透传模式应应用账号级模型映射")
|
||||
|
||||
require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
|
||||
require.Empty(t, upstream.lastReq.Header.Get("authorization"))
|
||||
@@ -190,7 +189,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
|
||||
require.True(t, ok)
|
||||
bodyBytes, ok := rawBody.([]byte)
|
||||
require.True(t, ok, "应以 []byte 形式缓存上游请求体,避免重复 string 拷贝")
|
||||
require.Equal(t, body, bodyBytes)
|
||||
require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(bodyBytes, "model").String(), "缓存的上游请求体应包含映射后的模型")
|
||||
}
|
||||
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBody(t *testing.T) {
|
||||
@@ -253,8 +252,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
|
||||
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, body, upstream.lastBody, "count_tokens 透传模式不应改写请求体")
|
||||
require.Equal(t, "claude-3-5-sonnet-latest", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
require.Equal(t, "claude-3-opus-20240229", gjson.GetBytes(upstream.lastBody, "model").String(), "count_tokens 透传模式应应用账号级模型映射")
|
||||
require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
|
||||
require.Empty(t, upstream.lastReq.Header.Get("authorization"))
|
||||
require.Empty(t, upstream.lastReq.Header.Get("cookie"))
|
||||
@@ -263,6 +261,273 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
|
||||
require.Empty(t, rec.Header().Get("Set-Cookie"))
|
||||
}
|
||||
|
||||
// TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingEdgeCases 覆盖透传模式下模型映射的各种边界情况
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingEdgeCases(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
modelMapping map[string]any // nil = 不配置映射
|
||||
expectedModel string
|
||||
endpoint string // "messages" or "count_tokens"
|
||||
}{
|
||||
{
|
||||
name: "Forward: 无映射配置时不改写模型",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: nil,
|
||||
expectedModel: "claude-sonnet-4-20250514",
|
||||
endpoint: "messages",
|
||||
},
|
||||
{
|
||||
name: "Forward: 空映射配置时不改写模型",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: map[string]any{},
|
||||
expectedModel: "claude-sonnet-4-20250514",
|
||||
endpoint: "messages",
|
||||
},
|
||||
{
|
||||
name: "Forward: 模型不在映射表中时不改写",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: map[string]any{"claude-3-haiku-20240307": "claude-3-opus-20240229"},
|
||||
expectedModel: "claude-sonnet-4-20250514",
|
||||
endpoint: "messages",
|
||||
},
|
||||
{
|
||||
name: "Forward: 精确匹配映射应改写模型",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
|
||||
expectedModel: "claude-sonnet-4-5-20241022",
|
||||
endpoint: "messages",
|
||||
},
|
||||
{
|
||||
name: "Forward: 通配符映射应改写模型",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: map[string]any{"claude-sonnet-4-*": "claude-sonnet-4-5-20241022"},
|
||||
expectedModel: "claude-sonnet-4-5-20241022",
|
||||
endpoint: "messages",
|
||||
},
|
||||
{
|
||||
name: "CountTokens: 无映射配置时不改写模型",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: nil,
|
||||
expectedModel: "claude-sonnet-4-20250514",
|
||||
endpoint: "count_tokens",
|
||||
},
|
||||
{
|
||||
name: "CountTokens: 模型不在映射表中时不改写",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: map[string]any{"claude-3-haiku-20240307": "claude-3-opus-20240229"},
|
||||
expectedModel: "claude-sonnet-4-20250514",
|
||||
endpoint: "count_tokens",
|
||||
},
|
||||
{
|
||||
name: "CountTokens: 精确匹配映射应改写模型",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
|
||||
expectedModel: "claude-sonnet-4-5-20241022",
|
||||
endpoint: "count_tokens",
|
||||
},
|
||||
{
|
||||
name: "CountTokens: 通配符映射应改写模型",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: map[string]any{"claude-sonnet-4-*": "claude-sonnet-4-5-20241022"},
|
||||
expectedModel: "claude-sonnet-4-5-20241022",
|
||||
endpoint: "count_tokens",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
body := []byte(`{"model":"` + tt.model + `","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
|
||||
parsed := &ParsedRequest{
|
||||
Body: body,
|
||||
Model: tt.model,
|
||||
}
|
||||
|
||||
credentials := map[string]any{
|
||||
"api_key": "upstream-key",
|
||||
"base_url": "https://api.anthropic.com",
|
||||
}
|
||||
if tt.modelMapping != nil {
|
||||
credentials["model_mapping"] = tt.modelMapping
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 300,
|
||||
Name: "edge-case-test",
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: credentials,
|
||||
Extra: map[string]any{"anthropic_passthrough": true},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
if tt.endpoint == "messages" {
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
parsed.Stream = false
|
||||
|
||||
upstreamJSON := `{"id":"msg_1","type":"message","usage":{"input_tokens":5,"output_tokens":3}}`
|
||||
upstream := &anthropicHTTPUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamJSON)),
|
||||
},
|
||||
}
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{},
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, parsed)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, tt.expectedModel, gjson.GetBytes(upstream.lastBody, "model").String(),
|
||||
"Forward 上游请求体中的模型应为: %s", tt.expectedModel)
|
||||
} else {
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
|
||||
|
||||
upstreamRespBody := `{"input_tokens":42}`
|
||||
upstream := &anthropicHTTPUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
|
||||
},
|
||||
}
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expectedModel, gjson.GetBytes(upstream.lastBody, "model").String(),
|
||||
"CountTokens 上游请求体中的模型应为: %s", tt.expectedModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFields
|
||||
// 确保模型映射只替换 model 字段,不影响请求体中的其他字段
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFields(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
|
||||
|
||||
// 包含复杂字段的请求体:system、thinking、messages
|
||||
body := []byte(`{"model":"claude-sonnet-4-20250514","system":[{"type":"text","text":"You are a helpful assistant."}],"messages":[{"role":"user","content":[{"type":"text","text":"hello world"}]}],"thinking":{"type":"enabled","budget_tokens":5000},"max_tokens":1024}`)
|
||||
parsed := &ParsedRequest{
|
||||
Body: body,
|
||||
Model: "claude-sonnet-4-20250514",
|
||||
}
|
||||
|
||||
upstreamRespBody := `{"input_tokens":42}`
|
||||
upstream := &anthropicHTTPUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 301,
|
||||
Name: "preserve-fields-test",
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "upstream-key",
|
||||
"base_url": "https://api.anthropic.com",
|
||||
"model_mapping": map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
|
||||
},
|
||||
Extra: map[string]any{"anthropic_passthrough": true},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
sentBody := upstream.lastBody
|
||||
require.Equal(t, "claude-sonnet-4-5-20241022", gjson.GetBytes(sentBody, "model").String(), "model 应被映射")
|
||||
require.Equal(t, "You are a helpful assistant.", gjson.GetBytes(sentBody, "system.0.text").String(), "system 字段不应被修改")
|
||||
require.Equal(t, "hello world", gjson.GetBytes(sentBody, "messages.0.content.0.text").String(), "messages 字段不应被修改")
|
||||
require.Equal(t, "enabled", gjson.GetBytes(sentBody, "thinking.type").String(), "thinking 字段不应被修改")
|
||||
require.Equal(t, int64(5000), gjson.GetBytes(sentBody, "thinking.budget_tokens").Int(), "thinking.budget_tokens 不应被修改")
|
||||
require.Equal(t, int64(1024), gjson.GetBytes(sentBody, "max_tokens").Int(), "max_tokens 不应被修改")
|
||||
}
|
||||
|
||||
// TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping
|
||||
// 确保空模型名不会触发映射逻辑
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
|
||||
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)
|
||||
parsed := &ParsedRequest{
|
||||
Body: body,
|
||||
Model: "", // 空模型
|
||||
}
|
||||
|
||||
upstreamRespBody := `{"input_tokens":10}`
|
||||
upstream := &anthropicHTTPUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 302,
|
||||
Name: "empty-model-test",
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "upstream-key",
|
||||
"base_url": "https://api.anthropic.com",
|
||||
"model_mapping": map[string]any{"*": "claude-3-opus-20240229"},
|
||||
},
|
||||
Extra: map[string]any{"anthropic_passthrough": true},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||
require.NoError(t, err)
|
||||
// 空模型名时,body 应原样透传,不应触发映射
|
||||
require.Equal(t, body, upstream.lastBody, "空模型名时请求体不应被修改")
|
||||
}
|
||||
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -187,6 +187,14 @@ func (m *mockAccountRepoForPlatform) BulkUpdate(ctx context.Context, ids []int64
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForPlatform) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForPlatform) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Verify interface implementation
|
||||
var _ AccountRepository = (*mockAccountRepoForPlatform)(nil)
|
||||
|
||||
|
||||
@@ -1228,6 +1228,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
modelScopeSkippedIDs = append(modelScopeSkippedIDs, account.ID)
|
||||
continue
|
||||
}
|
||||
// 配额检查
|
||||
if !s.isAccountSchedulableForQuota(account) {
|
||||
continue
|
||||
}
|
||||
// 窗口费用检查(非粘性会话路径)
|
||||
if !s.isAccountSchedulableForWindowCost(ctx, account, false) {
|
||||
filteredWindowCost++
|
||||
@@ -1260,6 +1264,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
|
||||
s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) &&
|
||||
s.isAccountSchedulableForQuota(stickyAccount) &&
|
||||
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) &&
|
||||
|
||||
s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查
|
||||
@@ -1311,7 +1316,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
for _, acc := range routingCandidates {
|
||||
routingLoads = append(routingLoads, AccountWithConcurrency{
|
||||
ID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
MaxConcurrency: acc.EffectiveLoadFactor(),
|
||||
})
|
||||
}
|
||||
routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads)
|
||||
@@ -1416,6 +1421,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) &&
|
||||
s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) &&
|
||||
s.isAccountSchedulableForQuota(account) &&
|
||||
s.isAccountSchedulableForWindowCost(ctx, account, true) &&
|
||||
|
||||
s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查
|
||||
@@ -1480,6 +1486,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
// 配额检查
|
||||
if !s.isAccountSchedulableForQuota(acc) {
|
||||
continue
|
||||
}
|
||||
// 窗口费用检查(非粘性会话路径)
|
||||
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
|
||||
continue
|
||||
@@ -1499,7 +1509,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
for _, acc := range candidates {
|
||||
accountLoads = append(accountLoads, AccountWithConcurrency{
|
||||
ID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
MaxConcurrency: acc.EffectiveLoadFactor(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2113,6 +2123,15 @@ func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts []
|
||||
return context.WithValue(ctx, windowCostPrefetchContextKey, costs)
|
||||
}
|
||||
|
||||
// isAccountSchedulableForQuota 检查 API Key 账号是否在配额限制内
|
||||
// 仅适用于配置了 quota_limit 的 apikey 类型账号
|
||||
func (s *GatewayService) isAccountSchedulableForQuota(account *Account) bool {
|
||||
if account.Type != AccountTypeAPIKey {
|
||||
return true
|
||||
}
|
||||
return !account.IsQuotaExceeded()
|
||||
}
|
||||
|
||||
// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 账号
|
||||
// 返回 true 表示可调度,false 表示不可调度
|
||||
@@ -2590,7 +2609,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
}
|
||||
@@ -2644,6 +2663,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForQuota(acc) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
|
||||
continue
|
||||
}
|
||||
@@ -2700,7 +2722,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
@@ -2743,6 +2765,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForQuota(acc) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
|
||||
continue
|
||||
}
|
||||
@@ -2818,7 +2843,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
@@ -2874,6 +2899,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForQuota(acc) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
|
||||
continue
|
||||
}
|
||||
@@ -2930,7 +2958,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
return account, nil
|
||||
}
|
||||
@@ -2975,6 +3003,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForQuota(acc) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
|
||||
continue
|
||||
}
|
||||
@@ -3889,7 +3920,16 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
|
||||
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
||||
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body, parsed.Model, parsed.Stream, startTime)
|
||||
passthroughBody := parsed.Body
|
||||
passthroughModel := parsed.Model
|
||||
if passthroughModel != "" {
|
||||
if mappedModel := account.GetMappedModel(passthroughModel); mappedModel != passthroughModel {
|
||||
passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel)
|
||||
logger.LegacyPrintf("service.gateway", "Passthrough model mapping: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
|
||||
passthroughModel = mappedModel
|
||||
}
|
||||
}
|
||||
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime)
|
||||
}
|
||||
|
||||
body := parsed.Body
|
||||
@@ -4574,7 +4614,7 @@ func (s *GatewayService) buildUpstreamRequestAnthropicAPIKeyPassthrough(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/v1/messages"
|
||||
targetURL = validatedURL + "/v1/messages?beta=true"
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
|
||||
@@ -4954,7 +4994,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/v1/messages"
|
||||
targetURL = validatedURL + "/v1/messages?beta=true"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6370,6 +6410,89 @@ type APIKeyQuotaUpdater interface {
|
||||
UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error
|
||||
}
|
||||
|
||||
// postUsageBillingParams 统一扣费所需的参数
|
||||
type postUsageBillingParams struct {
|
||||
Cost *CostBreakdown
|
||||
User *User
|
||||
APIKey *APIKey
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
IsSubscriptionBill bool
|
||||
AccountRateMultiplier float64
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
}
|
||||
|
||||
// postUsageBilling 统一处理使用量记录后的扣费逻辑:
|
||||
// - 订阅/余额扣费
|
||||
// - API Key 配额更新
|
||||
// - API Key 限速用量更新
|
||||
// - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率)
|
||||
func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) {
|
||||
cost := p.Cost
|
||||
|
||||
// 1. 订阅 / 余额扣费
|
||||
if p.IsSubscriptionBill {
|
||||
if cost.TotalCost > 0 {
|
||||
if err := deps.userSubRepo.IncrementUsage(ctx, p.Subscription.ID, cost.TotalCost); err != nil {
|
||||
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
|
||||
}
|
||||
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost)
|
||||
}
|
||||
} else {
|
||||
if cost.ActualCost > 0 {
|
||||
if err := deps.userRepo.DeductBalance(ctx, p.User.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err)
|
||||
}
|
||||
deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. API Key 配额
|
||||
if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
|
||||
if err := p.APIKeyService.UpdateQuotaUsed(ctx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. API Key 限速用量
|
||||
if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
|
||||
if err := p.APIKeyService.UpdateRateLimitUsage(ctx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||
}
|
||||
deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, cost.ActualCost)
|
||||
}
|
||||
|
||||
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
|
||||
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.GetQuotaLimit() > 0 {
|
||||
accountCost := cost.TotalCost * p.AccountRateMultiplier
|
||||
if err := deps.accountRepo.IncrementQuotaUsed(ctx, p.Account.ID, accountCost); err != nil {
|
||||
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 更新账号最近使用时间
|
||||
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
|
||||
}
|
||||
|
||||
// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供)
|
||||
type billingDeps struct {
|
||||
accountRepo AccountRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
billingCacheService *BillingCacheService
|
||||
deferredService *DeferredService
|
||||
}
|
||||
|
||||
func (s *GatewayService) billingDeps() *billingDeps {
|
||||
return &billingDeps{
|
||||
accountRepo: s.accountRepo,
|
||||
userRepo: s.userRepo,
|
||||
userSubRepo: s.userSubRepo,
|
||||
billingCacheService: s.billingCacheService,
|
||||
deferredService: s.deferredService,
|
||||
}
|
||||
}
|
||||
|
||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
|
||||
result := input.Result
|
||||
@@ -6533,45 +6656,21 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
// 根据计费类型执行扣费
|
||||
if isSubscriptionBilling {
|
||||
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
|
||||
if shouldBill && cost.TotalCost > 0 {
|
||||
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err)
|
||||
}
|
||||
// 异步更新订阅缓存
|
||||
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
|
||||
}
|
||||
if shouldBill {
|
||||
postUsageBilling(ctx, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
User: user,
|
||||
APIKey: apiKey,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
}, s.billingDeps())
|
||||
} else {
|
||||
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
|
||||
if shouldBill && cost.ActualCost > 0 {
|
||||
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err)
|
||||
}
|
||||
// 异步更新余额缓存
|
||||
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
|
||||
}
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
}
|
||||
|
||||
// 更新 API Key 配额(如果设置了配额限制)
|
||||
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
|
||||
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Update API key quota failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update API Key rate limit usage
|
||||
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
|
||||
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err)
|
||||
}
|
||||
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
|
||||
}
|
||||
|
||||
// Schedule batch update for account last_used_at
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6731,44 +6830,21 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
// 根据计费类型执行扣费
|
||||
if isSubscriptionBilling {
|
||||
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
|
||||
if shouldBill && cost.TotalCost > 0 {
|
||||
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err)
|
||||
}
|
||||
// 异步更新订阅缓存
|
||||
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
|
||||
}
|
||||
if shouldBill {
|
||||
postUsageBilling(ctx, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
User: user,
|
||||
APIKey: apiKey,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
}, s.billingDeps())
|
||||
} else {
|
||||
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
|
||||
if shouldBill && cost.ActualCost > 0 {
|
||||
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err)
|
||||
}
|
||||
// 异步更新余额缓存
|
||||
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
|
||||
// API Key 独立配额扣费
|
||||
if input.APIKeyService != nil && apiKey.Quota > 0 {
|
||||
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Add API key quota used failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
}
|
||||
|
||||
// Update API Key rate limit usage
|
||||
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
|
||||
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err)
|
||||
}
|
||||
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
|
||||
}
|
||||
|
||||
// Schedule batch update for account last_used_at
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6781,7 +6857,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
||||
return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body)
|
||||
passthroughBody := parsed.Body
|
||||
if reqModel := parsed.Model; reqModel != "" {
|
||||
if mappedModel := account.GetMappedModel(reqModel); mappedModel != reqModel {
|
||||
passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel)
|
||||
logger.LegacyPrintf("service.gateway", "CountTokens passthrough model mapping: %s -> %s (account: %s)", reqModel, mappedModel, account.Name)
|
||||
}
|
||||
}
|
||||
return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody)
|
||||
}
|
||||
|
||||
body := parsed.Body
|
||||
@@ -7072,7 +7155,7 @@ func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/v1/messages/count_tokens"
|
||||
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
|
||||
@@ -7119,7 +7202,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/v1/messages/count_tokens"
|
||||
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -122,6 +122,28 @@ func TestCheckErrorPolicy_GeminiAccounts(t *testing.T) {
|
||||
body: []byte(`overloaded service`),
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "gemini_apikey_temp_unschedulable_401_second_hit_returns_none",
|
||||
account: &Account{
|
||||
ID: 105,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicyNone,
|
||||
},
|
||||
{
|
||||
name: "gemini_custom_codes_override_temp_unschedulable",
|
||||
account: &Account{
|
||||
|
||||
@@ -176,6 +176,14 @@ func (m *mockAccountRepoForGemini) BulkUpdate(ctx context.Context, ids []int64,
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Verify interface implementation
|
||||
var _ AccountRepository = (*mockAccountRepoForGemini)(nil)
|
||||
|
||||
|
||||
@@ -19,8 +19,10 @@ import (
|
||||
|
||||
// 预编译正则表达式(避免每次调用重新编译)
|
||||
var (
|
||||
// 匹配 user_id 格式: user_{64位hex}_account__session_{uuid}
|
||||
userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account__session_([a-f0-9-]{36})$`)
|
||||
// 匹配 user_id 格式:
|
||||
// 旧格式: user_{64位hex}_account__session_{uuid} (account 后无 UUID)
|
||||
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid} (account 后有 UUID)
|
||||
userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account_([a-f0-9-]*)_session_([a-f0-9-]{36})$`)
|
||||
// 匹配 User-Agent 版本号: xxx/x.y.z
|
||||
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
|
||||
)
|
||||
@@ -239,13 +241,16 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// 匹配格式: user_{64位hex}_account__session_{uuid}
|
||||
// 匹配格式:
|
||||
// 旧格式: user_{64位hex}_account__session_{uuid}
|
||||
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid}
|
||||
matches := userIDRegex.FindStringSubmatch(userID)
|
||||
if matches == nil {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
sessionTail := matches[1] // 原始session UUID
|
||||
// matches[1] = account UUID (可能为空), matches[2] = session UUID
|
||||
sessionTail := matches[2] // 原始session UUID
|
||||
|
||||
// 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式
|
||||
seed := fmt.Sprintf("%d::%s", accountID, sessionTail)
|
||||
|
||||
@@ -342,6 +342,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
|
||||
}
|
||||
|
||||
cfg := s.service.schedulingConfig()
|
||||
// WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
|
||||
if s.service.concurrencyService != nil {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
@@ -590,7 +591,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
filtered = append(filtered, account)
|
||||
loadReq = append(loadReq, AccountWithConcurrency{
|
||||
ID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
MaxConcurrency: account.EffectiveLoadFactor(),
|
||||
})
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
@@ -703,6 +704,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
}
|
||||
|
||||
cfg := s.service.schedulingConfig()
|
||||
// WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
|
||||
candidate := selectionOrder[0]
|
||||
return &AccountSelectionResult{
|
||||
Account: candidate.account,
|
||||
|
||||
@@ -9,6 +9,13 @@ import (
|
||||
var codexCLIInstructions string
|
||||
|
||||
var codexModelMap = map[string]string{
|
||||
"gpt-5.4": "gpt-5.4",
|
||||
"gpt-5.4-none": "gpt-5.4",
|
||||
"gpt-5.4-low": "gpt-5.4",
|
||||
"gpt-5.4-medium": "gpt-5.4",
|
||||
"gpt-5.4-high": "gpt-5.4",
|
||||
"gpt-5.4-xhigh": "gpt-5.4",
|
||||
"gpt-5.4-chat-latest": "gpt-5.4",
|
||||
"gpt-5.3": "gpt-5.3-codex",
|
||||
"gpt-5.3-none": "gpt-5.3-codex",
|
||||
"gpt-5.3-low": "gpt-5.3-codex",
|
||||
@@ -154,6 +161,9 @@ func normalizeCodexModel(model string) string {
|
||||
|
||||
normalized := strings.ToLower(modelID)
|
||||
|
||||
if strings.Contains(normalized, "gpt-5.4") || strings.Contains(normalized, "gpt 5.4") {
|
||||
return "gpt-5.4"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") {
|
||||
return "gpt-5.2-codex"
|
||||
}
|
||||
|
||||
@@ -167,6 +167,10 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
|
||||
|
||||
func TestNormalizeCodexModel_Gpt53(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"gpt-5.4": "gpt-5.4",
|
||||
"gpt-5.4-high": "gpt-5.4",
|
||||
"gpt-5.4-chat-latest": "gpt-5.4",
|
||||
"gpt 5.4": "gpt-5.4",
|
||||
"gpt-5.3": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
|
||||
|
||||
@@ -263,13 +263,15 @@ type OpenAIGatewayService struct {
|
||||
toolCorrector *CodexToolCorrector
|
||||
openaiWSResolver OpenAIWSProtocolResolver
|
||||
|
||||
openaiWSPoolOnce sync.Once
|
||||
openaiWSStateStoreOnce sync.Once
|
||||
openaiSchedulerOnce sync.Once
|
||||
openaiWSPool *openAIWSConnPool
|
||||
openaiWSStateStore OpenAIWSStateStore
|
||||
openaiScheduler OpenAIAccountScheduler
|
||||
openaiAccountStats *openAIAccountRuntimeStats
|
||||
openaiWSPoolOnce sync.Once
|
||||
openaiWSStateStoreOnce sync.Once
|
||||
openaiSchedulerOnce sync.Once
|
||||
openaiWSPassthroughDialerOnce sync.Once
|
||||
openaiWSPool *openAIWSConnPool
|
||||
openaiWSStateStore OpenAIWSStateStore
|
||||
openaiScheduler OpenAIAccountScheduler
|
||||
openaiWSPassthroughDialer openAIWSClientDialer
|
||||
openaiAccountStats *openAIAccountRuntimeStats
|
||||
|
||||
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
|
||||
openaiWSRetryMetrics openAIWSRetryMetrics
|
||||
@@ -317,6 +319,16 @@ func NewOpenAIGatewayService(
|
||||
return svc
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) billingDeps() *billingDeps {
|
||||
return &billingDeps{
|
||||
accountRepo: s.accountRepo,
|
||||
userRepo: s.userRepo,
|
||||
userSubRepo: s.userSubRepo,
|
||||
billingCacheService: s.billingCacheService,
|
||||
deferredService: s.deferredService,
|
||||
}
|
||||
}
|
||||
|
||||
// CloseOpenAIWSPool 关闭 OpenAI WebSocket 连接池的后台 worker 和空闲连接。
|
||||
// 应在应用优雅关闭时调用。
|
||||
func (s *OpenAIGatewayService) CloseOpenAIWSPool() {
|
||||
@@ -1240,7 +1252,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
for _, acc := range candidates {
|
||||
accountLoads = append(accountLoads, AccountWithConcurrency{
|
||||
ID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
MaxConcurrency: acc.EffectiveLoadFactor(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3472,37 +3484,21 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
// Deduct based on billing type
|
||||
if isSubscriptionBilling {
|
||||
if shouldBill && cost.TotalCost > 0 {
|
||||
_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
|
||||
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
|
||||
}
|
||||
if shouldBill {
|
||||
postUsageBilling(ctx, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
User: user,
|
||||
APIKey: apiKey,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
}, s.billingDeps())
|
||||
} else {
|
||||
if shouldBill && cost.ActualCost > 0 {
|
||||
_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
|
||||
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
|
||||
}
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
}
|
||||
|
||||
// Update API key quota if applicable (only for balance mode with quota set)
|
||||
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
|
||||
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.openai_gateway", "Update API key quota failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update API Key rate limit usage
|
||||
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
|
||||
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.openai_gateway", "Update API key rate limit usage failed: %v", err)
|
||||
}
|
||||
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
|
||||
}
|
||||
|
||||
// Schedule batch update for account last_used_at
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/coder/websocket/wsjson"
|
||||
)
|
||||
@@ -234,6 +235,8 @@ type coderOpenAIWSClientConn struct {
|
||||
conn *coderws.Conn
|
||||
}
|
||||
|
||||
var _ openaiwsv2.FrameConn = (*coderOpenAIWSClientConn)(nil)
|
||||
|
||||
func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error {
|
||||
if c == nil || c.conn == nil {
|
||||
return errOpenAIWSConnClosed
|
||||
@@ -264,6 +267,30 @@ func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, erro
|
||||
}
|
||||
}
|
||||
|
||||
func (c *coderOpenAIWSClientConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||
if c == nil || c.conn == nil {
|
||||
return coderws.MessageText, nil, errOpenAIWSConnClosed
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
msgType, payload, err := c.conn.Read(ctx)
|
||||
if err != nil {
|
||||
return coderws.MessageText, nil, err
|
||||
}
|
||||
return msgType, payload, nil
|
||||
}
|
||||
|
||||
func (c *coderOpenAIWSClientConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
|
||||
if c == nil || c.conn == nil {
|
||||
return errOpenAIWSConnClosed
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return c.conn.Write(ctx, msgType, payload)
|
||||
}
|
||||
|
||||
func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error {
|
||||
if c == nil || c.conn == nil {
|
||||
return errOpenAIWSConnClosed
|
||||
|
||||
@@ -46,9 +46,10 @@ const (
|
||||
openAIWSPayloadSizeEstimateMaxBytes = 64 * 1024
|
||||
openAIWSPayloadSizeEstimateMaxItems = 16
|
||||
|
||||
openAIWSEventFlushBatchSizeDefault = 4
|
||||
openAIWSEventFlushIntervalDefault = 25 * time.Millisecond
|
||||
openAIWSPayloadLogSampleDefault = 0.2
|
||||
openAIWSEventFlushBatchSizeDefault = 4
|
||||
openAIWSEventFlushIntervalDefault = 25 * time.Millisecond
|
||||
openAIWSPayloadLogSampleDefault = 0.2
|
||||
openAIWSPassthroughIdleTimeoutDefault = time.Hour
|
||||
|
||||
openAIWSStoreDisabledConnModeStrict = "strict"
|
||||
openAIWSStoreDisabledConnModeAdaptive = "adaptive"
|
||||
@@ -863,7 +864,8 @@ func isOpenAIWSClientDisconnectError(err error) bool {
|
||||
strings.Contains(message, "unexpected eof") ||
|
||||
strings.Contains(message, "use of closed network connection") ||
|
||||
strings.Contains(message, "connection reset by peer") ||
|
||||
strings.Contains(message, "broken pipe")
|
||||
strings.Contains(message, "broken pipe") ||
|
||||
strings.Contains(message, "an established connection was aborted")
|
||||
}
|
||||
|
||||
func classifyOpenAIWSReadFallbackReason(err error) string {
|
||||
@@ -904,6 +906,18 @@ func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool {
|
||||
return s.openaiWSPool
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) getOpenAIWSPassthroughDialer() openAIWSClientDialer {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
s.openaiWSPassthroughDialerOnce.Do(func() {
|
||||
if s.openaiWSPassthroughDialer == nil {
|
||||
s.openaiWSPassthroughDialer = newDefaultOpenAIWSClientDialer()
|
||||
}
|
||||
})
|
||||
return s.openaiWSPassthroughDialer
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot {
|
||||
pool := s.getOpenAIWSConnPool()
|
||||
if pool == nil {
|
||||
@@ -967,6 +981,13 @@ func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration {
|
||||
return 15 * time.Minute
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) openAIWSPassthroughIdleTimeout() time.Duration {
|
||||
if timeout := s.openAIWSReadTimeout(); timeout > 0 {
|
||||
return timeout
|
||||
}
|
||||
return openAIWSPassthroughIdleTimeoutDefault
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration {
|
||||
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 {
|
||||
return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second
|
||||
@@ -2322,7 +2343,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
|
||||
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
|
||||
modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
|
||||
ingressMode := OpenAIWSIngressModeShared
|
||||
ingressMode := OpenAIWSIngressModeCtxPool
|
||||
if modeRouterV2Enabled {
|
||||
ingressMode = account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault)
|
||||
if ingressMode == OpenAIWSIngressModeOff {
|
||||
@@ -2332,6 +2353,30 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
nil,
|
||||
)
|
||||
}
|
||||
switch ingressMode {
|
||||
case OpenAIWSIngressModePassthrough:
|
||||
if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
|
||||
return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport)
|
||||
}
|
||||
return s.proxyResponsesWebSocketV2Passthrough(
|
||||
ctx,
|
||||
c,
|
||||
clientConn,
|
||||
account,
|
||||
token,
|
||||
firstClientMessage,
|
||||
hooks,
|
||||
wsDecision,
|
||||
)
|
||||
case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
|
||||
// continue
|
||||
default:
|
||||
return NewOpenAIWSClientCloseError(
|
||||
coderws.StatusPolicyViolation,
|
||||
"websocket mode only supports ctx_pool/passthrough",
|
||||
nil,
|
||||
)
|
||||
}
|
||||
}
|
||||
if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
|
||||
return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport)
|
||||
|
||||
@@ -149,7 +149,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossT
|
||||
require.True(t, <-turnWSModeCh, "首轮 turn 应标记为 WS 模式")
|
||||
require.True(t, <-turnWSModeCh, "第二轮 turn 应标记为 WS 模式")
|
||||
|
||||
require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
|
||||
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
|
||||
|
||||
select {
|
||||
case serverErr := <-serverErrCh:
|
||||
@@ -298,6 +298,140 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoe
|
||||
require.Equal(t, 2, dialer.DialCount(), "dedicated 模式下跨客户端会话不应复用上游连接")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeRelaysByCaddyAdapter(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
|
||||
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
|
||||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||
|
||||
upstreamConn := &openAIWSCaptureConn{
|
||||
events: [][]byte{
|
||||
[]byte(`{"type":"response.completed","response":{"id":"resp_passthrough_turn_1","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":3}}}`),
|
||||
},
|
||||
}
|
||||
captureDialer := &openAIWSCaptureDialer{conn: upstreamConn}
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: &httpUpstreamRecorder{},
|
||||
cache: &stubGatewayCache{},
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
openaiWSPassthroughDialer: captureDialer,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 452,
|
||||
Name: "openai-ingress-passthrough",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
|
||||
},
|
||||
}
|
||||
|
||||
serverErrCh := make(chan error, 1)
|
||||
resultCh := make(chan *OpenAIForwardResult, 1)
|
||||
hooks := &OpenAIWSIngressHooks{
|
||||
AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) {
|
||||
if turnErr == nil && result != nil {
|
||||
resultCh <- result
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||
CompressionMode: coderws.CompressionContextTakeover,
|
||||
})
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.CloseNow()
|
||||
}()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
req := r.Clone(r.Context())
|
||||
req.Header = req.Header.Clone()
|
||||
req.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||
ginCtx.Request = req
|
||||
|
||||
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
msgType, firstMessage, readErr := conn.Read(readCtx)
|
||||
cancel()
|
||||
if readErr != nil {
|
||||
serverErrCh <- readErr
|
||||
return
|
||||
}
|
||||
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||
serverErrCh <- errors.New("unsupported websocket client message type")
|
||||
return
|
||||
}
|
||||
|
||||
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks)
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_, event, readErr := clientConn.Read(readCtx)
|
||||
cancelRead()
|
||||
require.NoError(t, readErr)
|
||||
require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
|
||||
require.Equal(t, "resp_passthrough_turn_1", gjson.GetBytes(event, "response.id").String())
|
||||
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
|
||||
|
||||
select {
|
||||
case serverErr := <-serverErrCh:
|
||||
require.NoError(t, serverErr)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("等待 passthrough websocket 结束超时")
|
||||
}
|
||||
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
require.Equal(t, "resp_passthrough_turn_1", result.RequestID)
|
||||
require.True(t, result.OpenAIWSMode)
|
||||
require.Equal(t, 2, result.Usage.InputTokens)
|
||||
require.Equal(t, 3, result.Usage.OutputTokens)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("未收到 passthrough turn 结果回调")
|
||||
}
|
||||
|
||||
require.Equal(t, 1, captureDialer.DialCount(), "passthrough 模式应直接建立上游 websocket")
|
||||
require.Len(t, upstreamConn.writes, 1, "passthrough 模式应透传首条 response.create")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -1282,6 +1283,18 @@ func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) {
|
||||
return event, nil
|
||||
}
|
||||
|
||||
func (c *openAIWSCaptureConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||
payload, err := c.ReadMessage(ctx)
|
||||
if err != nil {
|
||||
return coderws.MessageText, nil, err
|
||||
}
|
||||
return coderws.MessageText, payload, nil
|
||||
}
|
||||
|
||||
func (c *openAIWSCaptureConn) WriteFrame(ctx context.Context, _ coderws.MessageType, payload []byte) error {
|
||||
return c.WriteJSON(ctx, json.RawMessage(payload))
|
||||
}
|
||||
|
||||
func (c *openAIWSCaptureConn) Ping(ctx context.Context) error {
|
||||
_ = ctx
|
||||
return nil
|
||||
|
||||
@@ -69,8 +69,11 @@ func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProt
|
||||
switch mode {
|
||||
case OpenAIWSIngressModeOff:
|
||||
return openAIWSHTTPDecision("account_mode_off")
|
||||
case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
|
||||
case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModePassthrough:
|
||||
// continue
|
||||
case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
|
||||
// 历史值兼容:按 ctx_pool 处理。
|
||||
mode = OpenAIWSIngressModeCtxPool
|
||||
default:
|
||||
return openAIWSHTTPDecision("account_mode_off")
|
||||
}
|
||||
|
||||
@@ -143,21 +143,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
|
||||
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared
|
||||
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
|
||||
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("dedicated mode routes to ws v2", func(t *testing.T) {
|
||||
t.Run("ctx_pool mode routes to ws v2", func(t *testing.T) {
|
||||
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account)
|
||||
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
|
||||
require.Equal(t, "ws_v2_mode_dedicated", decision.Reason)
|
||||
require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason)
|
||||
})
|
||||
|
||||
t.Run("off mode routes to http", func(t *testing.T) {
|
||||
@@ -174,7 +174,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
|
||||
require.Equal(t, "account_mode_off", decision.Reason)
|
||||
})
|
||||
|
||||
t.Run("legacy boolean maps to shared in v2 router", func(t *testing.T) {
|
||||
t.Run("legacy boolean maps to ctx_pool in v2 router", func(t *testing.T) {
|
||||
legacyAccount := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
@@ -185,7 +185,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
|
||||
}
|
||||
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount)
|
||||
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
|
||||
require.Equal(t, "ws_v2_mode_shared", decision.Reason)
|
||||
require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason)
|
||||
})
|
||||
|
||||
t.Run("passthrough mode routes to ws v2", func(t *testing.T) {
|
||||
passthroughAccount := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
|
||||
},
|
||||
}
|
||||
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(passthroughAccount)
|
||||
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
|
||||
require.Equal(t, "ws_v2_mode_passthrough", decision.Reason)
|
||||
})
|
||||
|
||||
t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) {
|
||||
@@ -193,7 +207,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
|
||||
},
|
||||
}
|
||||
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency)
|
||||
|
||||
24
backend/internal/service/openai_ws_v2/caddy_adapter.go
Normal file
24
backend/internal/service/openai_ws_v2/caddy_adapter.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package openai_ws_v2
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// runCaddyStyleRelay 采用 Caddy reverseproxy 的双向隧道思想:
|
||||
// 连接建立后并发复制两个方向,任一方向退出触发收敛关闭。
|
||||
//
|
||||
// Reference:
|
||||
// - Project: caddyserver/caddy (Apache-2.0)
|
||||
// - Commit: f283062d37c50627d53ca682ebae2ce219b35515
|
||||
// - Files:
|
||||
// - modules/caddyhttp/reverseproxy/streaming.go
|
||||
// - modules/caddyhttp/reverseproxy/reverseproxy.go
|
||||
func runCaddyStyleRelay(
|
||||
ctx context.Context,
|
||||
clientConn FrameConn,
|
||||
upstreamConn FrameConn,
|
||||
firstClientMessage []byte,
|
||||
options RelayOptions,
|
||||
) (RelayResult, *RelayExit) {
|
||||
return Relay(ctx, clientConn, upstreamConn, firstClientMessage, options)
|
||||
}
|
||||
23
backend/internal/service/openai_ws_v2/entry.go
Normal file
23
backend/internal/service/openai_ws_v2/entry.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package openai_ws_v2
|
||||
|
||||
import "context"
|
||||
|
||||
// EntryInput 是 passthrough v2 数据面的入口参数。
|
||||
type EntryInput struct {
|
||||
Ctx context.Context
|
||||
ClientConn FrameConn
|
||||
UpstreamConn FrameConn
|
||||
FirstClientMessage []byte
|
||||
Options RelayOptions
|
||||
}
|
||||
|
||||
// RunEntry 是 openai_ws_v2 包对外的统一入口。
|
||||
func RunEntry(input EntryInput) (RelayResult, *RelayExit) {
|
||||
return runCaddyStyleRelay(
|
||||
input.Ctx,
|
||||
input.ClientConn,
|
||||
input.UpstreamConn,
|
||||
input.FirstClientMessage,
|
||||
input.Options,
|
||||
)
|
||||
}
|
||||
29
backend/internal/service/openai_ws_v2/metrics.go
Normal file
29
backend/internal/service/openai_ws_v2/metrics.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package openai_ws_v2
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// MetricsSnapshot is a lightweight point-in-time copy of the runtime
// counters kept for the OpenAI WS v2 passthrough path.
type MetricsSnapshot struct {
	// SemanticMutationTotal is the value of the semantic-mutation counter.
	SemanticMutationTotal int64 `json:"semantic_mutation_total"`
	// UsageParseFailureTotal is the number of recorded usage parse failures.
	UsageParseFailureTotal int64 `json:"usage_parse_failure_total"`
}
|
||||
|
||||
var (
	// The passthrough path performs no semantic rewriting by default, so this
	// counter is normally expected to stay at 0 (kept for future defensive
	// checks).
	passthroughSemanticMutationTotal atomic.Int64
	// passthroughUsageParseFailureTotal counts usage payloads that could not
	// be parsed.
	passthroughUsageParseFailureTotal atomic.Int64
)
|
||||
|
||||
func recordUsageParseFailure() {
|
||||
passthroughUsageParseFailureTotal.Add(1)
|
||||
}
|
||||
|
||||
// SnapshotMetrics 返回当前 passthrough 指标快照。
|
||||
func SnapshotMetrics() MetricsSnapshot {
|
||||
return MetricsSnapshot{
|
||||
SemanticMutationTotal: passthroughSemanticMutationTotal.Load(),
|
||||
UsageParseFailureTotal: passthroughUsageParseFailureTotal.Load(),
|
||||
}
|
||||
}
|
||||
807
backend/internal/service/openai_ws_v2/passthrough_relay.go
Normal file
807
backend/internal/service/openai_ws_v2/passthrough_relay.go
Normal file
@@ -0,0 +1,807 @@
|
||||
package openai_ws_v2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type FrameConn interface {
|
||||
ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error)
|
||||
WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
// Usage holds token counts extracted from upstream usage payloads.
type Usage struct {
	// InputTokens is read from response.usage.input_tokens.
	InputTokens int
	// OutputTokens is read from response.usage.output_tokens.
	OutputTokens int
	// CacheCreationInputTokens is not populated by the usage parser in this file.
	CacheCreationInputTokens int
	// CacheReadInputTokens is read from
	// response.usage.input_tokens_details.cached_tokens.
	CacheReadInputTokens int
}
|
||||
|
||||
type RelayResult struct {
|
||||
RequestModel string
|
||||
Usage Usage
|
||||
RequestID string
|
||||
TerminalEventType string
|
||||
FirstTokenMs *int
|
||||
Duration time.Duration
|
||||
ClientToUpstreamFrames int64
|
||||
UpstreamToClientFrames int64
|
||||
DroppedDownstreamFrames int64
|
||||
}
|
||||
|
||||
type RelayTurnResult struct {
|
||||
RequestModel string
|
||||
Usage Usage
|
||||
RequestID string
|
||||
TerminalEventType string
|
||||
Duration time.Duration
|
||||
FirstTokenMs *int
|
||||
}
|
||||
|
||||
// RelayExit describes why a relay ended abnormally. Relay returns a nil
// *RelayExit when the session completed gracefully.
type RelayExit struct {
	// Stage names the step that ended the relay (for example "read_client",
	// "write_upstream", "idle_timeout" or "client_disconnected").
	Stage string
	// Err is the error associated with the exit.
	Err error
	// WroteDownstream reports whether the exit signals observed by Relay
	// indicated that at least one frame had been written to the client.
	WroteDownstream bool
}
|
||||
|
||||
type RelayOptions struct {
|
||||
WriteTimeout time.Duration
|
||||
IdleTimeout time.Duration
|
||||
UpstreamDrainTimeout time.Duration
|
||||
FirstMessageType coderws.MessageType
|
||||
OnUsageParseFailure func(eventType string, usageRaw string)
|
||||
OnTurnComplete func(turn RelayTurnResult)
|
||||
OnTrace func(event RelayTraceEvent)
|
||||
Now func() time.Time
|
||||
}
|
||||
|
||||
// RelayTraceEvent is a diagnostic event emitted through RelayOptions.OnTrace.
// Fields that do not apply to a given stage are left at their zero value.
type RelayTraceEvent struct {
	// Stage names the traced step, e.g. "relay_start" or "first_exit".
	Stage string
	// Direction is "client_to_upstream", "upstream_to_client", "watchdog" or empty.
	Direction string
	// MessageType is "text", "binary" or "unknown(N)".
	MessageType string
	// PayloadBytes is the size of the frame payload involved.
	PayloadBytes int
	// Graceful reports whether the traced exit was classified as graceful.
	Graceful bool
	// WroteDownstream reports whether a frame had been written to the client.
	WroteDownstream bool
	// Error is the error text, or empty when there was no error.
	Error string
}
|
||||
|
||||
type relayState struct {
|
||||
usage Usage
|
||||
requestModel string
|
||||
lastResponseID string
|
||||
terminalEventType string
|
||||
firstTokenMs *int
|
||||
turnTimingByID map[string]*relayTurnTiming
|
||||
}
|
||||
|
||||
// relayExitSignal is what a relay goroutine sends on the exit channel when
// it stops.
type relayExitSignal struct {
	stage           string // step that ended the goroutine, e.g. "read_client"
	err             error  // error that caused the exit, if any
	graceful        bool   // true when the exit is treated as a normal disconnect
	wroteDownstream bool   // true when a frame was already written to the client
}
|
||||
|
||||
type observedUpstreamEvent struct {
|
||||
terminal bool
|
||||
eventType string
|
||||
responseID string
|
||||
usage Usage
|
||||
duration time.Duration
|
||||
firstToken *int
|
||||
}
|
||||
|
||||
// relayTurnTiming tracks timing of one turn, keyed by response ID.
type relayTurnTiming struct {
	startAt      time.Time // when the first event for this response ID was observed
	firstTokenMs *int      // delay from startAt to the first token event, in ms
}
|
||||
|
||||
func Relay(
|
||||
ctx context.Context,
|
||||
clientConn FrameConn,
|
||||
upstreamConn FrameConn,
|
||||
firstClientMessage []byte,
|
||||
options RelayOptions,
|
||||
) (RelayResult, *RelayExit) {
|
||||
result := RelayResult{RequestModel: strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())}
|
||||
if clientConn == nil || upstreamConn == nil {
|
||||
return result, &RelayExit{Stage: "relay_init", Err: errors.New("relay connection is nil")}
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
nowFn := options.Now
|
||||
if nowFn == nil {
|
||||
nowFn = time.Now
|
||||
}
|
||||
writeTimeout := options.WriteTimeout
|
||||
if writeTimeout <= 0 {
|
||||
writeTimeout = 2 * time.Minute
|
||||
}
|
||||
drainTimeout := options.UpstreamDrainTimeout
|
||||
if drainTimeout <= 0 {
|
||||
drainTimeout = 1200 * time.Millisecond
|
||||
}
|
||||
firstMessageType := options.FirstMessageType
|
||||
if firstMessageType != coderws.MessageBinary {
|
||||
firstMessageType = coderws.MessageText
|
||||
}
|
||||
startAt := nowFn()
|
||||
state := &relayState{requestModel: result.RequestModel}
|
||||
onTrace := options.OnTrace
|
||||
|
||||
relayCtx, relayCancel := context.WithCancel(ctx)
|
||||
defer relayCancel()
|
||||
|
||||
lastActivity := atomic.Int64{}
|
||||
lastActivity.Store(nowFn().UnixNano())
|
||||
markActivity := func() {
|
||||
lastActivity.Store(nowFn().UnixNano())
|
||||
}
|
||||
|
||||
writeUpstream := func(msgType coderws.MessageType, payload []byte) error {
|
||||
writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
|
||||
defer cancel()
|
||||
return upstreamConn.WriteFrame(writeCtx, msgType, payload)
|
||||
}
|
||||
writeClient := func(msgType coderws.MessageType, payload []byte) error {
|
||||
writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
|
||||
defer cancel()
|
||||
return clientConn.WriteFrame(writeCtx, msgType, payload)
|
||||
}
|
||||
|
||||
clientToUpstreamFrames := &atomic.Int64{}
|
||||
upstreamToClientFrames := &atomic.Int64{}
|
||||
droppedDownstreamFrames := &atomic.Int64{}
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "relay_start",
|
||||
PayloadBytes: len(firstClientMessage),
|
||||
MessageType: relayMessageTypeString(firstMessageType),
|
||||
})
|
||||
|
||||
if err := writeUpstream(firstMessageType, firstClientMessage); err != nil {
|
||||
result.Duration = nowFn().Sub(startAt)
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "write_first_message_failed",
|
||||
Direction: "client_to_upstream",
|
||||
MessageType: relayMessageTypeString(firstMessageType),
|
||||
PayloadBytes: len(firstClientMessage),
|
||||
Error: err.Error(),
|
||||
})
|
||||
return result, &RelayExit{Stage: "write_upstream", Err: err}
|
||||
}
|
||||
clientToUpstreamFrames.Add(1)
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "write_first_message_ok",
|
||||
Direction: "client_to_upstream",
|
||||
MessageType: relayMessageTypeString(firstMessageType),
|
||||
PayloadBytes: len(firstClientMessage),
|
||||
})
|
||||
markActivity()
|
||||
|
||||
exitCh := make(chan relayExitSignal, 3)
|
||||
dropDownstreamWrites := atomic.Bool{}
|
||||
go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh)
|
||||
go runUpstreamToClient(
|
||||
relayCtx,
|
||||
upstreamConn,
|
||||
writeClient,
|
||||
startAt,
|
||||
nowFn,
|
||||
state,
|
||||
options.OnUsageParseFailure,
|
||||
options.OnTurnComplete,
|
||||
&dropDownstreamWrites,
|
||||
upstreamToClientFrames,
|
||||
droppedDownstreamFrames,
|
||||
markActivity,
|
||||
onTrace,
|
||||
exitCh,
|
||||
)
|
||||
go runIdleWatchdog(relayCtx, nowFn, options.IdleTimeout, &lastActivity, onTrace, exitCh)
|
||||
|
||||
firstExit := <-exitCh
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "first_exit",
|
||||
Direction: relayDirectionFromStage(firstExit.stage),
|
||||
Graceful: firstExit.graceful,
|
||||
WroteDownstream: firstExit.wroteDownstream,
|
||||
Error: relayErrorString(firstExit.err),
|
||||
})
|
||||
combinedWroteDownstream := firstExit.wroteDownstream
|
||||
secondExit := relayExitSignal{graceful: true}
|
||||
hasSecondExit := false
|
||||
|
||||
// 客户端断开后尽力继续读取上游短窗口,捕获延迟 usage/terminal 事件用于计费。
|
||||
if firstExit.stage == "read_client" && firstExit.graceful {
|
||||
dropDownstreamWrites.Store(true)
|
||||
secondExit, hasSecondExit = waitRelayExit(exitCh, drainTimeout)
|
||||
} else {
|
||||
relayCancel()
|
||||
_ = upstreamConn.Close()
|
||||
secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond)
|
||||
}
|
||||
if hasSecondExit {
|
||||
combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "second_exit",
|
||||
Direction: relayDirectionFromStage(secondExit.stage),
|
||||
Graceful: secondExit.graceful,
|
||||
WroteDownstream: secondExit.wroteDownstream,
|
||||
Error: relayErrorString(secondExit.err),
|
||||
})
|
||||
}
|
||||
|
||||
relayCancel()
|
||||
_ = upstreamConn.Close()
|
||||
|
||||
enrichResult(&result, state, nowFn().Sub(startAt))
|
||||
result.ClientToUpstreamFrames = clientToUpstreamFrames.Load()
|
||||
result.UpstreamToClientFrames = upstreamToClientFrames.Load()
|
||||
result.DroppedDownstreamFrames = droppedDownstreamFrames.Load()
|
||||
if firstExit.stage == "read_client" && firstExit.graceful {
|
||||
stage := "client_disconnected"
|
||||
exitErr := firstExit.err
|
||||
if hasSecondExit && !secondExit.graceful {
|
||||
stage = secondExit.stage
|
||||
exitErr = secondExit.err
|
||||
}
|
||||
if exitErr == nil {
|
||||
exitErr = io.EOF
|
||||
}
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "relay_exit",
|
||||
Direction: relayDirectionFromStage(stage),
|
||||
Graceful: false,
|
||||
WroteDownstream: combinedWroteDownstream,
|
||||
Error: relayErrorString(exitErr),
|
||||
})
|
||||
return result, &RelayExit{
|
||||
Stage: stage,
|
||||
Err: exitErr,
|
||||
WroteDownstream: combinedWroteDownstream,
|
||||
}
|
||||
}
|
||||
if firstExit.graceful && (!hasSecondExit || secondExit.graceful) {
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "relay_complete",
|
||||
Graceful: true,
|
||||
WroteDownstream: combinedWroteDownstream,
|
||||
})
|
||||
_ = clientConn.Close()
|
||||
return result, nil
|
||||
}
|
||||
if !firstExit.graceful {
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "relay_exit",
|
||||
Direction: relayDirectionFromStage(firstExit.stage),
|
||||
Graceful: false,
|
||||
WroteDownstream: combinedWroteDownstream,
|
||||
Error: relayErrorString(firstExit.err),
|
||||
})
|
||||
return result, &RelayExit{
|
||||
Stage: firstExit.stage,
|
||||
Err: firstExit.err,
|
||||
WroteDownstream: combinedWroteDownstream,
|
||||
}
|
||||
}
|
||||
if hasSecondExit && !secondExit.graceful {
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "relay_exit",
|
||||
Direction: relayDirectionFromStage(secondExit.stage),
|
||||
Graceful: false,
|
||||
WroteDownstream: combinedWroteDownstream,
|
||||
Error: relayErrorString(secondExit.err),
|
||||
})
|
||||
return result, &RelayExit{
|
||||
Stage: secondExit.stage,
|
||||
Err: secondExit.err,
|
||||
WroteDownstream: combinedWroteDownstream,
|
||||
}
|
||||
}
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "relay_complete",
|
||||
Graceful: true,
|
||||
WroteDownstream: combinedWroteDownstream,
|
||||
})
|
||||
_ = clientConn.Close()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func runClientToUpstream(
|
||||
ctx context.Context,
|
||||
clientConn FrameConn,
|
||||
writeUpstream func(msgType coderws.MessageType, payload []byte) error,
|
||||
markActivity func(),
|
||||
forwardedFrames *atomic.Int64,
|
||||
onTrace func(event RelayTraceEvent),
|
||||
exitCh chan<- relayExitSignal,
|
||||
) {
|
||||
for {
|
||||
msgType, payload, err := clientConn.ReadFrame(ctx)
|
||||
if err != nil {
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "read_client_failed",
|
||||
Direction: "client_to_upstream",
|
||||
Error: err.Error(),
|
||||
Graceful: isDisconnectError(err),
|
||||
})
|
||||
exitCh <- relayExitSignal{stage: "read_client", err: err, graceful: isDisconnectError(err)}
|
||||
return
|
||||
}
|
||||
markActivity()
|
||||
if err := writeUpstream(msgType, payload); err != nil {
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "write_upstream_failed",
|
||||
Direction: "client_to_upstream",
|
||||
MessageType: relayMessageTypeString(msgType),
|
||||
PayloadBytes: len(payload),
|
||||
Error: err.Error(),
|
||||
})
|
||||
exitCh <- relayExitSignal{stage: "write_upstream", err: err}
|
||||
return
|
||||
}
|
||||
if forwardedFrames != nil {
|
||||
forwardedFrames.Add(1)
|
||||
}
|
||||
markActivity()
|
||||
}
|
||||
}
|
||||
|
||||
func runUpstreamToClient(
|
||||
ctx context.Context,
|
||||
upstreamConn FrameConn,
|
||||
writeClient func(msgType coderws.MessageType, payload []byte) error,
|
||||
startAt time.Time,
|
||||
nowFn func() time.Time,
|
||||
state *relayState,
|
||||
onUsageParseFailure func(eventType string, usageRaw string),
|
||||
onTurnComplete func(turn RelayTurnResult),
|
||||
dropDownstreamWrites *atomic.Bool,
|
||||
forwardedFrames *atomic.Int64,
|
||||
droppedFrames *atomic.Int64,
|
||||
markActivity func(),
|
||||
onTrace func(event RelayTraceEvent),
|
||||
exitCh chan<- relayExitSignal,
|
||||
) {
|
||||
wroteDownstream := false
|
||||
for {
|
||||
msgType, payload, err := upstreamConn.ReadFrame(ctx)
|
||||
if err != nil {
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "read_upstream_failed",
|
||||
Direction: "upstream_to_client",
|
||||
Error: err.Error(),
|
||||
Graceful: isDisconnectError(err),
|
||||
WroteDownstream: wroteDownstream,
|
||||
})
|
||||
exitCh <- relayExitSignal{
|
||||
stage: "read_upstream",
|
||||
err: err,
|
||||
graceful: isDisconnectError(err),
|
||||
wroteDownstream: wroteDownstream,
|
||||
}
|
||||
return
|
||||
}
|
||||
markActivity()
|
||||
observedEvent := observedUpstreamEvent{}
|
||||
switch msgType {
|
||||
case coderws.MessageText:
|
||||
observedEvent = observeUpstreamMessage(state, payload, startAt, nowFn, onUsageParseFailure)
|
||||
case coderws.MessageBinary:
|
||||
// binary frame 直接透传,不进入 JSON 观测路径(避免无效解析开销)。
|
||||
}
|
||||
emitTurnComplete(onTurnComplete, state, observedEvent)
|
||||
if dropDownstreamWrites != nil && dropDownstreamWrites.Load() {
|
||||
if droppedFrames != nil {
|
||||
droppedFrames.Add(1)
|
||||
}
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "drop_downstream_frame",
|
||||
Direction: "upstream_to_client",
|
||||
MessageType: relayMessageTypeString(msgType),
|
||||
PayloadBytes: len(payload),
|
||||
WroteDownstream: wroteDownstream,
|
||||
})
|
||||
if observedEvent.terminal {
|
||||
exitCh <- relayExitSignal{
|
||||
stage: "drain_terminal",
|
||||
graceful: true,
|
||||
wroteDownstream: wroteDownstream,
|
||||
}
|
||||
return
|
||||
}
|
||||
markActivity()
|
||||
continue
|
||||
}
|
||||
if err := writeClient(msgType, payload); err != nil {
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "write_client_failed",
|
||||
Direction: "upstream_to_client",
|
||||
MessageType: relayMessageTypeString(msgType),
|
||||
PayloadBytes: len(payload),
|
||||
WroteDownstream: wroteDownstream,
|
||||
Error: err.Error(),
|
||||
})
|
||||
exitCh <- relayExitSignal{stage: "write_client", err: err, wroteDownstream: wroteDownstream}
|
||||
return
|
||||
}
|
||||
wroteDownstream = true
|
||||
if forwardedFrames != nil {
|
||||
forwardedFrames.Add(1)
|
||||
}
|
||||
markActivity()
|
||||
}
|
||||
}
|
||||
|
||||
func runIdleWatchdog(
|
||||
ctx context.Context,
|
||||
nowFn func() time.Time,
|
||||
idleTimeout time.Duration,
|
||||
lastActivity *atomic.Int64,
|
||||
onTrace func(event RelayTraceEvent),
|
||||
exitCh chan<- relayExitSignal,
|
||||
) {
|
||||
if idleTimeout <= 0 {
|
||||
return
|
||||
}
|
||||
checkInterval := minDuration(idleTimeout/4, 5*time.Second)
|
||||
if checkInterval < time.Second {
|
||||
checkInterval = time.Second
|
||||
}
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
last := time.Unix(0, lastActivity.Load())
|
||||
if nowFn().Sub(last) < idleTimeout {
|
||||
continue
|
||||
}
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "idle_timeout_triggered",
|
||||
Direction: "watchdog",
|
||||
Error: context.DeadlineExceeded.Error(),
|
||||
})
|
||||
exitCh <- relayExitSignal{stage: "idle_timeout", err: context.DeadlineExceeded}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func emitRelayTrace(onTrace func(event RelayTraceEvent), event RelayTraceEvent) {
|
||||
if onTrace == nil {
|
||||
return
|
||||
}
|
||||
onTrace(event)
|
||||
}
|
||||
|
||||
func relayMessageTypeString(msgType coderws.MessageType) string {
|
||||
switch msgType {
|
||||
case coderws.MessageText:
|
||||
return "text"
|
||||
case coderws.MessageBinary:
|
||||
return "binary"
|
||||
default:
|
||||
return "unknown(" + strconv.Itoa(int(msgType)) + ")"
|
||||
}
|
||||
}
|
||||
|
||||
// relayDirectionFromStage maps an exit stage to the trace direction label.
// Stages it does not know (including "client_disconnected") map to "".
func relayDirectionFromStage(stage string) string {
	switch stage {
	case "read_client", "write_upstream":
		return "client_to_upstream"
	case "read_upstream", "write_client", "drain_terminal":
		return "upstream_to_client"
	case "idle_timeout":
		return "watchdog"
	default:
		return ""
	}
}
|
||||
|
||||
// relayErrorString returns err.Error(), or "" for a nil error.
func relayErrorString(err error) string {
	if err == nil {
		return ""
	}
	return err.Error()
}
|
||||
|
||||
func observeUpstreamMessage(
|
||||
state *relayState,
|
||||
message []byte,
|
||||
startAt time.Time,
|
||||
nowFn func() time.Time,
|
||||
onUsageParseFailure func(eventType string, usageRaw string),
|
||||
) observedUpstreamEvent {
|
||||
if state == nil || len(message) == 0 {
|
||||
return observedUpstreamEvent{}
|
||||
}
|
||||
values := gjson.GetManyBytes(message, "type", "response.id", "response_id", "id")
|
||||
eventType := strings.TrimSpace(values[0].String())
|
||||
if eventType == "" {
|
||||
return observedUpstreamEvent{}
|
||||
}
|
||||
responseID := strings.TrimSpace(values[1].String())
|
||||
if responseID == "" {
|
||||
responseID = strings.TrimSpace(values[2].String())
|
||||
}
|
||||
// 仅 terminal 事件兜底读取顶层 id,避免把 event_id 当成 response_id 关联到 turn。
|
||||
if responseID == "" && isTerminalEvent(eventType) {
|
||||
responseID = strings.TrimSpace(values[3].String())
|
||||
}
|
||||
now := nowFn()
|
||||
|
||||
if state.firstTokenMs == nil && isTokenEvent(eventType) {
|
||||
ms := int(now.Sub(startAt).Milliseconds())
|
||||
if ms >= 0 {
|
||||
state.firstTokenMs = &ms
|
||||
}
|
||||
}
|
||||
parsedUsage := parseUsageAndAccumulate(state, message, eventType, onUsageParseFailure)
|
||||
observed := observedUpstreamEvent{
|
||||
eventType: eventType,
|
||||
responseID: responseID,
|
||||
usage: parsedUsage,
|
||||
}
|
||||
if responseID != "" {
|
||||
turnTiming := openAIWSRelayGetOrInitTurnTiming(state, responseID, now)
|
||||
if turnTiming != nil && turnTiming.firstTokenMs == nil && isTokenEvent(eventType) {
|
||||
ms := int(now.Sub(turnTiming.startAt).Milliseconds())
|
||||
if ms >= 0 {
|
||||
turnTiming.firstTokenMs = &ms
|
||||
}
|
||||
}
|
||||
}
|
||||
if !isTerminalEvent(eventType) {
|
||||
return observed
|
||||
}
|
||||
observed.terminal = true
|
||||
state.terminalEventType = eventType
|
||||
if responseID != "" {
|
||||
state.lastResponseID = responseID
|
||||
if turnTiming, ok := openAIWSRelayDeleteTurnTiming(state, responseID); ok {
|
||||
duration := now.Sub(turnTiming.startAt)
|
||||
if duration < 0 {
|
||||
duration = 0
|
||||
}
|
||||
observed.duration = duration
|
||||
observed.firstToken = openAIWSRelayCloneIntPtr(turnTiming.firstTokenMs)
|
||||
}
|
||||
}
|
||||
return observed
|
||||
}
|
||||
|
||||
func emitTurnComplete(
|
||||
onTurnComplete func(turn RelayTurnResult),
|
||||
state *relayState,
|
||||
observed observedUpstreamEvent,
|
||||
) {
|
||||
if onTurnComplete == nil || !observed.terminal {
|
||||
return
|
||||
}
|
||||
responseID := strings.TrimSpace(observed.responseID)
|
||||
if responseID == "" {
|
||||
return
|
||||
}
|
||||
requestModel := ""
|
||||
if state != nil {
|
||||
requestModel = state.requestModel
|
||||
}
|
||||
onTurnComplete(RelayTurnResult{
|
||||
RequestModel: requestModel,
|
||||
Usage: observed.usage,
|
||||
RequestID: responseID,
|
||||
TerminalEventType: observed.eventType,
|
||||
Duration: observed.duration,
|
||||
FirstTokenMs: openAIWSRelayCloneIntPtr(observed.firstToken),
|
||||
})
|
||||
}
|
||||
|
||||
func openAIWSRelayGetOrInitTurnTiming(state *relayState, responseID string, now time.Time) *relayTurnTiming {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
if state.turnTimingByID == nil {
|
||||
state.turnTimingByID = make(map[string]*relayTurnTiming, 8)
|
||||
}
|
||||
timing, ok := state.turnTimingByID[responseID]
|
||||
if !ok || timing == nil || timing.startAt.IsZero() {
|
||||
timing = &relayTurnTiming{startAt: now}
|
||||
state.turnTimingByID[responseID] = timing
|
||||
return timing
|
||||
}
|
||||
return timing
|
||||
}
|
||||
|
||||
func openAIWSRelayDeleteTurnTiming(state *relayState, responseID string) (relayTurnTiming, bool) {
|
||||
if state == nil || state.turnTimingByID == nil {
|
||||
return relayTurnTiming{}, false
|
||||
}
|
||||
timing, ok := state.turnTimingByID[responseID]
|
||||
if !ok || timing == nil {
|
||||
return relayTurnTiming{}, false
|
||||
}
|
||||
delete(state.turnTimingByID, responseID)
|
||||
return *timing, true
|
||||
}
|
||||
|
||||
// openAIWSRelayCloneIntPtr returns a pointer to a fresh copy of *v, or nil
// when v is nil, so the caller's value cannot be mutated through the result.
func openAIWSRelayCloneIntPtr(v *int) *int {
	if v == nil {
		return nil
	}
	cloned := *v
	return &cloned
}
|
||||
|
||||
func parseUsageAndAccumulate(
|
||||
state *relayState,
|
||||
message []byte,
|
||||
eventType string,
|
||||
onParseFailure func(eventType string, usageRaw string),
|
||||
) Usage {
|
||||
if state == nil || len(message) == 0 || !shouldParseUsage(eventType) {
|
||||
return Usage{}
|
||||
}
|
||||
usageResult := gjson.GetBytes(message, "response.usage")
|
||||
if !usageResult.Exists() {
|
||||
return Usage{}
|
||||
}
|
||||
usageRaw := strings.TrimSpace(usageResult.Raw)
|
||||
if usageRaw == "" || !strings.HasPrefix(usageRaw, "{") {
|
||||
recordUsageParseFailure()
|
||||
if onParseFailure != nil {
|
||||
onParseFailure(eventType, usageRaw)
|
||||
}
|
||||
return Usage{}
|
||||
}
|
||||
|
||||
inputResult := gjson.GetBytes(message, "response.usage.input_tokens")
|
||||
outputResult := gjson.GetBytes(message, "response.usage.output_tokens")
|
||||
cachedResult := gjson.GetBytes(message, "response.usage.input_tokens_details.cached_tokens")
|
||||
|
||||
inputTokens, inputOK := parseUsageIntField(inputResult, true)
|
||||
outputTokens, outputOK := parseUsageIntField(outputResult, true)
|
||||
cachedTokens, cachedOK := parseUsageIntField(cachedResult, false)
|
||||
if !inputOK || !outputOK || !cachedOK {
|
||||
recordUsageParseFailure()
|
||||
if onParseFailure != nil {
|
||||
onParseFailure(eventType, usageRaw)
|
||||
}
|
||||
// 解析失败时不做部分字段累加,避免计费 usage 出现“半有效”状态。
|
||||
return Usage{}
|
||||
}
|
||||
parsedUsage := Usage{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
CacheReadInputTokens: cachedTokens,
|
||||
}
|
||||
|
||||
state.usage.InputTokens += parsedUsage.InputTokens
|
||||
state.usage.OutputTokens += parsedUsage.OutputTokens
|
||||
state.usage.CacheReadInputTokens += parsedUsage.CacheReadInputTokens
|
||||
return parsedUsage
|
||||
}
|
||||
|
||||
func parseUsageIntField(value gjson.Result, required bool) (int, bool) {
|
||||
if !value.Exists() {
|
||||
return 0, !required
|
||||
}
|
||||
if value.Type != gjson.Number {
|
||||
return 0, false
|
||||
}
|
||||
return int(value.Int()), true
|
||||
}
|
||||
|
||||
func enrichResult(result *RelayResult, state *relayState, duration time.Duration) {
|
||||
if result == nil {
|
||||
return
|
||||
}
|
||||
result.Duration = duration
|
||||
if state == nil {
|
||||
return
|
||||
}
|
||||
result.RequestModel = state.requestModel
|
||||
result.Usage = state.usage
|
||||
result.RequestID = state.lastResponseID
|
||||
result.TerminalEventType = state.terminalEventType
|
||||
result.FirstTokenMs = state.firstTokenMs
|
||||
}
|
||||
|
||||
func isDisconnectError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) {
|
||||
return true
|
||||
}
|
||||
switch coderws.CloseStatus(err) {
|
||||
case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure:
|
||||
return true
|
||||
}
|
||||
message := strings.ToLower(strings.TrimSpace(err.Error()))
|
||||
if message == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(message, "failed to read frame header: eof") ||
|
||||
strings.Contains(message, "unexpected eof") ||
|
||||
strings.Contains(message, "use of closed network connection") ||
|
||||
strings.Contains(message, "connection reset by peer") ||
|
||||
strings.Contains(message, "broken pipe")
|
||||
}
|
||||
|
||||
// isTerminalEvent reports whether eventType is one of the event types this
// package treats as terminal: completed, done, failed, incomplete, and both
// spellings of cancelled.
func isTerminalEvent(eventType string) bool {
	switch eventType {
	case "response.completed", "response.done":
	case "response.failed", "response.incomplete":
	case "response.cancelled", "response.canceled":
	default:
		return false
	}
	return true
}
|
||||
|
||||
// shouldParseUsage reports whether usage should be parsed for eventType.
// Only "response.completed", "response.done" and "response.failed" qualify.
func shouldParseUsage(eventType string) bool {
	return eventType == "response.completed" ||
		eventType == "response.done" ||
		eventType == "response.failed"
}
|
||||
|
||||
// isTokenEvent reports whether eventType counts as a token-bearing event.
//
// The empty string and the lifecycle events created / in_progress /
// output_item.added / output_item.done never count. Apart from those, an
// event counts when its type contains ".delta", starts with
// "response.output", or is "response.completed" / "response.done".
func isTokenEvent(eventType string) bool {
	switch eventType {
	case "",
		"response.created",
		"response.in_progress",
		"response.output_item.added",
		"response.output_item.done":
		return false
	case "response.completed", "response.done":
		return true
	}
	// The "response.output" prefix also covers every "response.output_text…" type.
	return strings.Contains(eventType, ".delta") ||
		strings.HasPrefix(eventType, "response.output")
}
|
||||
|
||||
// minDuration returns the smaller of a and b, except that a non-positive
// argument is ignored: if a <= 0 the result is b, and otherwise if b <= 0
// the result is a.
func minDuration(a, b time.Duration) time.Duration {
	switch {
	case a <= 0:
		return b
	case b <= 0:
		return a
	case b < a:
		return b
	default:
		return a
	}
}
|
||||
|
||||
func waitRelayExit(exitCh <-chan relayExitSignal, timeout time.Duration) (relayExitSignal, bool) {
|
||||
if timeout <= 0 {
|
||||
timeout = 200 * time.Millisecond
|
||||
}
|
||||
select {
|
||||
case sig := <-exitCh:
|
||||
return sig, true
|
||||
case <-time.After(timeout):
|
||||
return relayExitSignal{}, false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,432 @@
|
||||
package openai_ws_v2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestRunEntry_DelegatesRelay(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||
{
|
||||
msgType: coderws.MessageText,
|
||||
payload: []byte(`{"type":"response.completed","response":{"id":"resp_entry","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||
},
|
||||
}, true)
|
||||
|
||||
result, relayExit := RunEntry(EntryInput{
|
||||
Ctx: context.Background(),
|
||||
ClientConn: clientConn,
|
||||
UpstreamConn: upstreamConn,
|
||||
FirstClientMessage: []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`),
|
||||
})
|
||||
require.Nil(t, relayExit)
|
||||
require.Equal(t, "resp_entry", result.RequestID)
|
||||
}
|
||||
|
||||
func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("read client eof", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
exitCh := make(chan relayExitSignal, 1)
|
||||
runClientToUpstream(
|
||||
context.Background(),
|
||||
newPassthroughTestFrameConn(nil, true),
|
||||
func(_ coderws.MessageType, _ []byte) error { return nil },
|
||||
func() {},
|
||||
nil,
|
||||
nil,
|
||||
exitCh,
|
||||
)
|
||||
sig := <-exitCh
|
||||
require.Equal(t, "read_client", sig.stage)
|
||||
require.True(t, sig.graceful)
|
||||
})
|
||||
|
||||
t.Run("write upstream failed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
exitCh := make(chan relayExitSignal, 1)
|
||||
runClientToUpstream(
|
||||
context.Background(),
|
||||
newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
|
||||
}, true),
|
||||
func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") },
|
||||
func() {},
|
||||
nil,
|
||||
nil,
|
||||
exitCh,
|
||||
)
|
||||
sig := <-exitCh
|
||||
require.Equal(t, "write_upstream", sig.stage)
|
||||
require.False(t, sig.graceful)
|
||||
})
|
||||
|
||||
t.Run("forwarded counter and trace callback", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
exitCh := make(chan relayExitSignal, 1)
|
||||
forwarded := &atomic.Int64{}
|
||||
traces := make([]RelayTraceEvent, 0, 2)
|
||||
runClientToUpstream(
|
||||
context.Background(),
|
||||
newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
|
||||
}, true),
|
||||
func(_ coderws.MessageType, _ []byte) error { return nil },
|
||||
func() {},
|
||||
forwarded,
|
||||
func(event RelayTraceEvent) {
|
||||
traces = append(traces, event)
|
||||
},
|
||||
exitCh,
|
||||
)
|
||||
sig := <-exitCh
|
||||
require.Equal(t, "read_client", sig.stage)
|
||||
require.Equal(t, int64(1), forwarded.Load())
|
||||
require.NotEmpty(t, traces)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("read upstream eof", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
exitCh := make(chan relayExitSignal, 1)
|
||||
drop := &atomic.Bool{}
|
||||
drop.Store(false)
|
||||
runUpstreamToClient(
|
||||
context.Background(),
|
||||
newPassthroughTestFrameConn(nil, true),
|
||||
func(_ coderws.MessageType, _ []byte) error { return nil },
|
||||
time.Now(),
|
||||
time.Now,
|
||||
&relayState{},
|
||||
nil,
|
||||
nil,
|
||||
drop,
|
||||
nil,
|
||||
nil,
|
||||
func() {},
|
||||
nil,
|
||||
exitCh,
|
||||
)
|
||||
sig := <-exitCh
|
||||
require.Equal(t, "read_upstream", sig.stage)
|
||||
require.True(t, sig.graceful)
|
||||
})
|
||||
|
||||
t.Run("write client failed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
exitCh := make(chan relayExitSignal, 1)
|
||||
drop := &atomic.Bool{}
|
||||
drop.Store(false)
|
||||
runUpstreamToClient(
|
||||
context.Background(),
|
||||
newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||
{msgType: coderws.MessageText, payload: []byte(`{"type":"response.output_text.delta","delta":"x"}`)},
|
||||
}, true),
|
||||
func(_ coderws.MessageType, _ []byte) error { return errors.New("write failed") },
|
||||
time.Now(),
|
||||
time.Now,
|
||||
&relayState{},
|
||||
nil,
|
||||
nil,
|
||||
drop,
|
||||
nil,
|
||||
nil,
|
||||
func() {},
|
||||
nil,
|
||||
exitCh,
|
||||
)
|
||||
sig := <-exitCh
|
||||
require.Equal(t, "write_client", sig.stage)
|
||||
})
|
||||
|
||||
t.Run("drop downstream and stop on terminal", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
exitCh := make(chan relayExitSignal, 1)
|
||||
drop := &atomic.Bool{}
|
||||
drop.Store(true)
|
||||
dropped := &atomic.Int64{}
|
||||
runUpstreamToClient(
|
||||
context.Background(),
|
||||
newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||
{
|
||||
msgType: coderws.MessageText,
|
||||
payload: []byte(`{"type":"response.completed","response":{"id":"resp_drop","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||
},
|
||||
}, true),
|
||||
func(_ coderws.MessageType, _ []byte) error { return nil },
|
||||
time.Now(),
|
||||
time.Now,
|
||||
&relayState{},
|
||||
nil,
|
||||
nil,
|
||||
drop,
|
||||
nil,
|
||||
dropped,
|
||||
func() {},
|
||||
nil,
|
||||
exitCh,
|
||||
)
|
||||
sig := <-exitCh
|
||||
require.Equal(t, "drain_terminal", sig.stage)
|
||||
require.True(t, sig.graceful)
|
||||
require.Equal(t, int64(1), dropped.Load())
|
||||
})
|
||||
}
|
||||
|
||||
func TestRunIdleWatchdog_NoTimeoutWhenDisabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
exitCh := make(chan relayExitSignal, 1)
|
||||
lastActivity := &atomic.Int64{}
|
||||
lastActivity.Store(time.Now().UnixNano())
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go runIdleWatchdog(ctx, time.Now, 0, lastActivity, nil, exitCh)
|
||||
select {
|
||||
case <-exitCh:
|
||||
t.Fatal("unexpected idle timeout signal")
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func TestHelperFunctionsCoverage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Equal(t, "text", relayMessageTypeString(coderws.MessageText))
|
||||
require.Equal(t, "binary", relayMessageTypeString(coderws.MessageBinary))
|
||||
require.Contains(t, relayMessageTypeString(coderws.MessageType(99)), "unknown(")
|
||||
|
||||
require.Equal(t, "", relayErrorString(nil))
|
||||
require.Equal(t, "x", relayErrorString(errors.New("x")))
|
||||
|
||||
require.True(t, isDisconnectError(io.EOF))
|
||||
require.True(t, isDisconnectError(net.ErrClosed))
|
||||
require.True(t, isDisconnectError(context.Canceled))
|
||||
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusGoingAway}))
|
||||
require.True(t, isDisconnectError(errors.New("broken pipe")))
|
||||
require.False(t, isDisconnectError(errors.New("unrelated")))
|
||||
|
||||
require.True(t, isTokenEvent("response.output_text.delta"))
|
||||
require.True(t, isTokenEvent("response.output_audio.delta"))
|
||||
require.True(t, isTokenEvent("response.completed"))
|
||||
require.False(t, isTokenEvent(""))
|
||||
require.False(t, isTokenEvent("response.created"))
|
||||
|
||||
require.Equal(t, 2*time.Second, minDuration(2*time.Second, 5*time.Second))
|
||||
require.Equal(t, 2*time.Second, minDuration(5*time.Second, 2*time.Second))
|
||||
require.Equal(t, 5*time.Second, minDuration(0, 5*time.Second))
|
||||
require.Equal(t, 2*time.Second, minDuration(2*time.Second, 0))
|
||||
|
||||
ch := make(chan relayExitSignal, 1)
|
||||
ch <- relayExitSignal{stage: "ok"}
|
||||
sig, ok := waitRelayExit(ch, 10*time.Millisecond)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "ok", sig.stage)
|
||||
ch <- relayExitSignal{stage: "ok2"}
|
||||
sig, ok = waitRelayExit(ch, 0)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "ok2", sig.stage)
|
||||
_, ok = waitRelayExit(ch, 10*time.Millisecond)
|
||||
require.False(t, ok)
|
||||
|
||||
n, ok := parseUsageIntField(gjson.Get(`{"n":3}`, "n"), true)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 3, n)
|
||||
_, ok = parseUsageIntField(gjson.Get(`{"n":"x"}`, "n"), true)
|
||||
require.False(t, ok)
|
||||
n, ok = parseUsageIntField(gjson.Result{}, false)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 0, n)
|
||||
_, ok = parseUsageIntField(gjson.Result{}, true)
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestParseUsageAndEnrichCoverage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
state := &relayState{}
|
||||
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":"bad"}}}`), "response.completed", nil)
|
||||
require.Equal(t, 0, state.usage.InputTokens)
|
||||
|
||||
parseUsageAndAccumulate(
|
||||
state,
|
||||
[]byte(`{"type":"response.completed","response":{"usage":{"input_tokens":9,"output_tokens":"bad","input_tokens_details":{"cached_tokens":2}}}}`),
|
||||
"response.completed",
|
||||
nil,
|
||||
)
|
||||
require.Equal(t, 0, state.usage.InputTokens, "部分字段解析失败时不应累加 usage")
|
||||
require.Equal(t, 0, state.usage.OutputTokens)
|
||||
require.Equal(t, 0, state.usage.CacheReadInputTokens)
|
||||
|
||||
parseUsageAndAccumulate(
|
||||
state,
|
||||
[]byte(`{"type":"response.completed","response":{"usage":{"input_tokens_details":{"cached_tokens":2}}}}`),
|
||||
"response.completed",
|
||||
nil,
|
||||
)
|
||||
require.Equal(t, 0, state.usage.InputTokens, "必填 usage 字段缺失时不应累加 usage")
|
||||
require.Equal(t, 0, state.usage.OutputTokens)
|
||||
require.Equal(t, 0, state.usage.CacheReadInputTokens)
|
||||
|
||||
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1}}}}`), "response.completed", nil)
|
||||
require.Equal(t, 2, state.usage.InputTokens)
|
||||
require.Equal(t, 1, state.usage.OutputTokens)
|
||||
require.Equal(t, 1, state.usage.CacheReadInputTokens)
|
||||
|
||||
result := &RelayResult{}
|
||||
enrichResult(result, state, 5*time.Millisecond)
|
||||
require.Equal(t, state.usage.InputTokens, result.Usage.InputTokens)
|
||||
require.Equal(t, 5*time.Millisecond, result.Duration)
|
||||
parseUsageAndAccumulate(state, []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":9}}}`), "response.in_progress", nil)
|
||||
require.Equal(t, 2, state.usage.InputTokens)
|
||||
enrichResult(nil, state, 0)
|
||||
}
|
||||
|
||||
func TestEmitTurnCompleteCoverage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 非 terminal 事件不应触发。
|
||||
called := 0
|
||||
emitTurnComplete(func(turn RelayTurnResult) {
|
||||
called++
|
||||
}, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{
|
||||
terminal: false,
|
||||
eventType: "response.output_text.delta",
|
||||
responseID: "resp_ignored",
|
||||
usage: Usage{InputTokens: 1},
|
||||
})
|
||||
require.Equal(t, 0, called)
|
||||
|
||||
// 缺少 response_id 时不应触发。
|
||||
emitTurnComplete(func(turn RelayTurnResult) {
|
||||
called++
|
||||
}, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{
|
||||
terminal: true,
|
||||
eventType: "response.completed",
|
||||
})
|
||||
require.Equal(t, 0, called)
|
||||
|
||||
// terminal 且 response_id 存在,应该触发;state=nil 时 model 为空串。
|
||||
var got RelayTurnResult
|
||||
emitTurnComplete(func(turn RelayTurnResult) {
|
||||
called++
|
||||
got = turn
|
||||
}, nil, observedUpstreamEvent{
|
||||
terminal: true,
|
||||
eventType: "response.completed",
|
||||
responseID: "resp_emit",
|
||||
usage: Usage{InputTokens: 2, OutputTokens: 3},
|
||||
})
|
||||
require.Equal(t, 1, called)
|
||||
require.Equal(t, "resp_emit", got.RequestID)
|
||||
require.Equal(t, "response.completed", got.TerminalEventType)
|
||||
require.Equal(t, 2, got.Usage.InputTokens)
|
||||
require.Equal(t, 3, got.Usage.OutputTokens)
|
||||
require.Equal(t, "", got.RequestModel)
|
||||
}
|
||||
|
||||
func TestIsDisconnectErrorCoverage_CloseStatusesAndMessageBranches(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNormalClosure}))
|
||||
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNoStatusRcvd}))
|
||||
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusAbnormalClosure}))
|
||||
require.True(t, isDisconnectError(errors.New("connection reset by peer")))
|
||||
require.False(t, isDisconnectError(errors.New(" ")))
|
||||
}
|
||||
|
||||
func TestIsTokenEventCoverageBranches(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.False(t, isTokenEvent("response.in_progress"))
|
||||
require.False(t, isTokenEvent("response.output_item.added"))
|
||||
require.True(t, isTokenEvent("response.output_audio.delta"))
|
||||
require.True(t, isTokenEvent("response.output"))
|
||||
require.True(t, isTokenEvent("response.done"))
|
||||
}
|
||||
|
||||
func TestRelayTurnTimingHelpersCoverage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Unix(100, 0)
|
||||
// nil state
|
||||
require.Nil(t, openAIWSRelayGetOrInitTurnTiming(nil, "resp_nil", now))
|
||||
_, ok := openAIWSRelayDeleteTurnTiming(nil, "resp_nil")
|
||||
require.False(t, ok)
|
||||
|
||||
state := &relayState{}
|
||||
timing := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now)
|
||||
require.NotNil(t, timing)
|
||||
require.Equal(t, now, timing.startAt)
|
||||
|
||||
// 再次获取返回同一条 timing
|
||||
timing2 := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now.Add(5*time.Second))
|
||||
require.NotNil(t, timing2)
|
||||
require.Equal(t, now, timing2.startAt)
|
||||
|
||||
// 删除存在键
|
||||
deleted, ok := openAIWSRelayDeleteTurnTiming(state, "resp_a")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, now, deleted.startAt)
|
||||
|
||||
// 删除不存在键
|
||||
_, ok = openAIWSRelayDeleteTurnTiming(state, "resp_a")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestObserveUpstreamMessage_ResponseIDFallbackPolicy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
state := &relayState{requestModel: "gpt-5"}
|
||||
startAt := time.Unix(0, 0)
|
||||
now := startAt
|
||||
nowFn := func() time.Time {
|
||||
now = now.Add(5 * time.Millisecond)
|
||||
return now
|
||||
}
|
||||
|
||||
// 非 terminal:仅有顶层 id,不应把 event id 当成 response_id。
|
||||
observed := observeUpstreamMessage(
|
||||
state,
|
||||
[]byte(`{"type":"response.output_text.delta","id":"evt_123","delta":"hi"}`),
|
||||
startAt,
|
||||
nowFn,
|
||||
nil,
|
||||
)
|
||||
require.False(t, observed.terminal)
|
||||
require.Equal(t, "", observed.responseID)
|
||||
|
||||
// terminal:允许兜底用顶层 id(用于兼容少数字段变体)。
|
||||
observed = observeUpstreamMessage(
|
||||
state,
|
||||
[]byte(`{"type":"response.completed","id":"resp_fallback","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||
startAt,
|
||||
nowFn,
|
||||
nil,
|
||||
)
|
||||
require.True(t, observed.terminal)
|
||||
require.Equal(t, "resp_fallback", observed.responseID)
|
||||
}
|
||||
752
backend/internal/service/openai_ws_v2/passthrough_relay_test.go
Normal file
752
backend/internal/service/openai_ws_v2/passthrough_relay_test.go
Normal file
@@ -0,0 +1,752 @@
|
||||
package openai_ws_v2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type passthroughTestFrame struct {
|
||||
msgType coderws.MessageType
|
||||
payload []byte
|
||||
}
|
||||
|
||||
type passthroughTestFrameConn struct {
|
||||
mu sync.Mutex
|
||||
writes []passthroughTestFrame
|
||||
readCh chan passthroughTestFrame
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
type delayedReadFrameConn struct {
|
||||
base FrameConn
|
||||
firstDelay time.Duration
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
type closeSpyFrameConn struct {
|
||||
closeCalls atomic.Int32
|
||||
}
|
||||
|
||||
func newPassthroughTestFrameConn(frames []passthroughTestFrame, autoClose bool) *passthroughTestFrameConn {
|
||||
c := &passthroughTestFrameConn{
|
||||
readCh: make(chan passthroughTestFrame, len(frames)+1),
|
||||
}
|
||||
for _, frame := range frames {
|
||||
copied := passthroughTestFrame{msgType: frame.msgType, payload: append([]byte(nil), frame.payload...)}
|
||||
c.readCh <- copied
|
||||
}
|
||||
if autoClose {
|
||||
close(c.readCh)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *passthroughTestFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return coderws.MessageText, nil, ctx.Err()
|
||||
case frame, ok := <-c.readCh:
|
||||
if !ok {
|
||||
return coderws.MessageText, nil, io.EOF
|
||||
}
|
||||
return frame.msgType, append([]byte(nil), frame.payload...), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *passthroughTestFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.writes = append(c.writes, passthroughTestFrame{msgType: msgType, payload: append([]byte(nil), payload...)})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *passthroughTestFrameConn) Close() error {
|
||||
c.once.Do(func() {
|
||||
defer func() { _ = recover() }()
|
||||
close(c.readCh)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *passthroughTestFrameConn) Writes() []passthroughTestFrame {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
out := make([]passthroughTestFrame, len(c.writes))
|
||||
copy(out, c.writes)
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *delayedReadFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||
if c == nil || c.base == nil {
|
||||
return coderws.MessageText, nil, io.EOF
|
||||
}
|
||||
c.once.Do(func() {
|
||||
if c.firstDelay > 0 {
|
||||
timer := time.NewTimer(c.firstDelay)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
})
|
||||
return c.base.ReadFrame(ctx)
|
||||
}
|
||||
|
||||
func (c *delayedReadFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
|
||||
if c == nil || c.base == nil {
|
||||
return io.EOF
|
||||
}
|
||||
return c.base.WriteFrame(ctx, msgType, payload)
|
||||
}
|
||||
|
||||
func (c *delayedReadFrameConn) Close() error {
|
||||
if c == nil || c.base == nil {
|
||||
return nil
|
||||
}
|
||||
return c.base.Close()
|
||||
}
|
||||
|
||||
func (c *closeSpyFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
<-ctx.Done()
|
||||
return coderws.MessageText, nil, ctx.Err()
|
||||
}
|
||||
|
||||
func (c *closeSpyFrameConn) WriteFrame(ctx context.Context, _ coderws.MessageType, _ []byte) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *closeSpyFrameConn) Close() error {
|
||||
if c != nil {
|
||||
c.closeCalls.Add(1)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *closeSpyFrameConn) CloseCalls() int32 {
|
||||
if c == nil {
|
||||
return 0
|
||||
}
|
||||
return c.closeCalls.Load()
|
||||
}
|
||||
|
||||
func TestRelay_BasicRelayAndUsage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||
{
|
||||
msgType: coderws.MessageText,
|
||||
payload: []byte(`{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`),
|
||||
},
|
||||
}, true)
|
||||
|
||||
firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"input_text","text":"hello"}]}`)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||
require.Nil(t, relayExit)
|
||||
require.Equal(t, "gpt-5.3-codex", result.RequestModel)
|
||||
require.Equal(t, "resp_123", result.RequestID)
|
||||
require.Equal(t, "response.completed", result.TerminalEventType)
|
||||
require.Equal(t, 7, result.Usage.InputTokens)
|
||||
require.Equal(t, 3, result.Usage.OutputTokens)
|
||||
require.Equal(t, 2, result.Usage.CacheReadInputTokens)
|
||||
require.NotNil(t, result.FirstTokenMs)
|
||||
require.Equal(t, int64(1), result.ClientToUpstreamFrames)
|
||||
require.Equal(t, int64(1), result.UpstreamToClientFrames)
|
||||
require.Equal(t, int64(0), result.DroppedDownstreamFrames)
|
||||
|
||||
upstreamWrites := upstreamConn.Writes()
|
||||
require.Len(t, upstreamWrites, 1)
|
||||
require.Equal(t, coderws.MessageText, upstreamWrites[0].msgType)
|
||||
require.JSONEq(t, string(firstPayload), string(upstreamWrites[0].payload))
|
||||
|
||||
clientWrites := clientConn.Writes()
|
||||
require.Len(t, clientWrites, 1)
|
||||
require.Equal(t, coderws.MessageText, clientWrites[0].msgType)
|
||||
require.JSONEq(t, `{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`, string(clientWrites[0].payload))
|
||||
}
|
||||
|
||||
func TestRelay_FunctionCallOutputBytesPreserved(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||
{
|
||||
msgType: coderws.MessageText,
|
||||
payload: []byte(`{"type":"response.completed","response":{"id":"resp_func","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||
},
|
||||
}, true)
|
||||
|
||||
firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"function_call_output","call_id":"call_abc123","output":"{\"ok\":true}"}]}`)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||
require.Nil(t, relayExit)
|
||||
|
||||
upstreamWrites := upstreamConn.Writes()
|
||||
require.Len(t, upstreamWrites, 1)
|
||||
require.Equal(t, coderws.MessageText, upstreamWrites[0].msgType)
|
||||
require.Equal(t, firstPayload, upstreamWrites[0].payload)
|
||||
}
|
||||
|
||||
func TestRelay_UpstreamDisconnect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 上游立即关闭(EOF),客户端不发送额外帧
|
||||
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||
upstreamConn := newPassthroughTestFrameConn(nil, true) // 立即 close -> EOF
|
||||
|
||||
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||
// 上游 EOF 属于 disconnect,标记为 graceful
|
||||
require.Nil(t, relayExit, "上游 EOF 应被视为 graceful disconnect")
|
||||
require.Equal(t, "gpt-4o", result.RequestModel)
|
||||
}
|
||||
|
||||
func TestRelay_ClientDisconnect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 客户端立即关闭(EOF),上游阻塞读取直到 context 取消
|
||||
clientConn := newPassthroughTestFrameConn(nil, true) // 立即 close -> EOF
|
||||
upstreamConn := newPassthroughTestFrameConn(nil, false)
|
||||
|
||||
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||
require.NotNil(t, relayExit, "客户端 EOF 应返回可观测的中断状态")
|
||||
require.Equal(t, "client_disconnected", relayExit.Stage)
|
||||
require.Equal(t, "gpt-4o", result.RequestModel)
|
||||
}
|
||||
|
||||
func TestRelay_ClientDisconnect_DrainCapturesLateUsage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clientConn := newPassthroughTestFrameConn(nil, true)
|
||||
upstreamBase := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||
{
|
||||
msgType: coderws.MessageText,
|
||||
payload: []byte(`{"type":"response.completed","response":{"id":"resp_drain","usage":{"input_tokens":6,"output_tokens":4,"input_tokens_details":{"cached_tokens":1}}}}`),
|
||||
},
|
||||
}, true)
|
||||
upstreamConn := &delayedReadFrameConn{
|
||||
base: upstreamBase,
|
||||
firstDelay: 80 * time.Millisecond,
|
||||
}
|
||||
|
||||
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||
UpstreamDrainTimeout: 400 * time.Millisecond,
|
||||
})
|
||||
require.NotNil(t, relayExit)
|
||||
require.Equal(t, "client_disconnected", relayExit.Stage)
|
||||
require.Equal(t, "resp_drain", result.RequestID)
|
||||
require.Equal(t, "response.completed", result.TerminalEventType)
|
||||
require.Equal(t, 6, result.Usage.InputTokens)
|
||||
require.Equal(t, 4, result.Usage.OutputTokens)
|
||||
require.Equal(t, 1, result.Usage.CacheReadInputTokens)
|
||||
require.Equal(t, int64(1), result.ClientToUpstreamFrames)
|
||||
require.Equal(t, int64(0), result.UpstreamToClientFrames)
|
||||
require.Equal(t, int64(1), result.DroppedDownstreamFrames)
|
||||
}
|
||||
|
||||
func TestRelay_IdleTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 客户端和上游都不发送帧,idle timeout 应触发
|
||||
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||
upstreamConn := newPassthroughTestFrameConn(nil, false)
|
||||
|
||||
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 使用快进时间来加速 idle timeout
|
||||
now := time.Now()
|
||||
callCount := 0
|
||||
nowFn := func() time.Time {
|
||||
callCount++
|
||||
// 前几次调用返回正常时间(初始化阶段),之后快进
|
||||
if callCount <= 5 {
|
||||
return now
|
||||
}
|
||||
return now.Add(time.Hour) // 快进到超时
|
||||
}
|
||||
|
||||
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||
IdleTimeout: 2 * time.Second,
|
||||
Now: nowFn,
|
||||
})
|
||||
require.NotNil(t, relayExit, "应因 idle timeout 退出")
|
||||
require.Equal(t, "idle_timeout", relayExit.Stage)
|
||||
require.Equal(t, "gpt-4o", result.RequestModel)
|
||||
}
|
||||
|
||||
func TestRelay_IdleTimeoutDoesNotCloseClientOnError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clientConn := &closeSpyFrameConn{}
|
||||
upstreamConn := &closeSpyFrameConn{}
|
||||
|
||||
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
now := time.Now()
|
||||
callCount := 0
|
||||
nowFn := func() time.Time {
|
||||
callCount++
|
||||
if callCount <= 5 {
|
||||
return now
|
||||
}
|
||||
return now.Add(time.Hour)
|
||||
}
|
||||
|
||||
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||
IdleTimeout: 2 * time.Second,
|
||||
Now: nowFn,
|
||||
})
|
||||
require.NotNil(t, relayExit, "应因 idle timeout 退出")
|
||||
require.Equal(t, "idle_timeout", relayExit.Stage)
|
||||
require.Zero(t, clientConn.CloseCalls(), "错误路径不应提前关闭客户端连接,交给上层决定 close code")
|
||||
require.GreaterOrEqual(t, upstreamConn.CloseCalls(), int32(1))
|
||||
}
|
||||
|
||||
func TestRelay_NilConnections(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("nil client conn", func(t *testing.T) {
|
||||
upstreamConn := newPassthroughTestFrameConn(nil, true)
|
||||
_, relayExit := Relay(ctx, nil, upstreamConn, firstPayload, RelayOptions{})
|
||||
require.NotNil(t, relayExit)
|
||||
require.Equal(t, "relay_init", relayExit.Stage)
|
||||
require.Contains(t, relayExit.Err.Error(), "nil")
|
||||
})
|
||||
|
||||
t.Run("nil upstream conn", func(t *testing.T) {
|
||||
clientConn := newPassthroughTestFrameConn(nil, true)
|
||||
_, relayExit := Relay(ctx, clientConn, nil, firstPayload, RelayOptions{})
|
||||
require.NotNil(t, relayExit)
|
||||
require.Equal(t, "relay_init", relayExit.Stage)
|
||||
require.Contains(t, relayExit.Err.Error(), "nil")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRelay_MultipleUpstreamMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 上游发送多个事件(delta + completed),验证多帧中继和 usage 聚合
|
||||
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||
{
|
||||
msgType: coderws.MessageText,
|
||||
payload: []byte(`{"type":"response.output_text.delta","delta":"Hello"}`),
|
||||
},
|
||||
{
|
||||
msgType: coderws.MessageText,
|
||||
payload: []byte(`{"type":"response.output_text.delta","delta":" world"}`),
|
||||
},
|
||||
{
|
||||
msgType: coderws.MessageText,
|
||||
payload: []byte(`{"type":"response.completed","response":{"id":"resp_multi","usage":{"input_tokens":10,"output_tokens":5,"input_tokens_details":{"cached_tokens":3}}}}`),
|
||||
},
|
||||
}, true)
|
||||
|
||||
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[{"type":"input_text","text":"hi"}]}`)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||
require.Nil(t, relayExit)
|
||||
require.Equal(t, "resp_multi", result.RequestID)
|
||||
require.Equal(t, "response.completed", result.TerminalEventType)
|
||||
require.Equal(t, 10, result.Usage.InputTokens)
|
||||
require.Equal(t, 5, result.Usage.OutputTokens)
|
||||
require.Equal(t, 3, result.Usage.CacheReadInputTokens)
|
||||
require.NotNil(t, result.FirstTokenMs)
|
||||
|
||||
// 验证所有 3 个上游帧都转发给了客户端
|
||||
clientWrites := clientConn.Writes()
|
||||
require.Len(t, clientWrites, 3)
|
||||
}
|
||||
|
||||
func TestRelay_OnTurnComplete_PerTerminalEvent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||
{
|
||||
msgType: coderws.MessageText,
|
||||
payload: []byte(`{"type":"response.completed","response":{"id":"resp_turn_1","usage":{"input_tokens":2,"output_tokens":1}}}`),
|
||||
},
|
||||
{
|
||||
msgType: coderws.MessageText,
|
||||
payload: []byte(`{"type":"response.failed","response":{"id":"resp_turn_2","usage":{"input_tokens":3,"output_tokens":4}}}`),
|
||||
},
|
||||
}, true)
|
||||
|
||||
firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
turns := make([]RelayTurnResult, 0, 2)
|
||||
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||
OnTurnComplete: func(turn RelayTurnResult) {
|
||||
turns = append(turns, turn)
|
||||
},
|
||||
})
|
||||
require.Nil(t, relayExit)
|
||||
require.Len(t, turns, 2)
|
||||
require.Equal(t, "resp_turn_1", turns[0].RequestID)
|
||||
require.Equal(t, "response.completed", turns[0].TerminalEventType)
|
||||
require.Equal(t, 2, turns[0].Usage.InputTokens)
|
||||
require.Equal(t, 1, turns[0].Usage.OutputTokens)
|
||||
require.Equal(t, "resp_turn_2", turns[1].RequestID)
|
||||
require.Equal(t, "response.failed", turns[1].TerminalEventType)
|
||||
require.Equal(t, 3, turns[1].Usage.InputTokens)
|
||||
require.Equal(t, 4, turns[1].Usage.OutputTokens)
|
||||
require.Equal(t, 5, result.Usage.InputTokens)
|
||||
require.Equal(t, 5, result.Usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestRelay_OnTurnComplete_ProvidesTurnMetrics(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||
{
|
||||
msgType: coderws.MessageText,
|
||||
payload: []byte(`{"type":"response.output_text.delta","response_id":"resp_metric","delta":"hi"}`),
|
||||
},
|
||||
{
|
||||
msgType: coderws.MessageText,
|
||||
payload: []byte(`{"type":"response.completed","response":{"id":"resp_metric","usage":{"input_tokens":2,"output_tokens":1}}}`),
|
||||
},
|
||||
}, true)
|
||||
|
||||
firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
base := time.Unix(0, 0)
|
||||
var nowTick atomic.Int64
|
||||
nowFn := func() time.Time {
|
||||
step := nowTick.Add(1)
|
||||
return base.Add(time.Duration(step) * 5 * time.Millisecond)
|
||||
}
|
||||
|
||||
var turn RelayTurnResult
|
||||
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||
Now: nowFn,
|
||||
OnTurnComplete: func(current RelayTurnResult) {
|
||||
turn = current
|
||||
},
|
||||
})
|
||||
require.Nil(t, relayExit)
|
||||
require.Equal(t, "resp_metric", turn.RequestID)
|
||||
require.Equal(t, "response.completed", turn.TerminalEventType)
|
||||
require.NotNil(t, turn.FirstTokenMs)
|
||||
require.GreaterOrEqual(t, *turn.FirstTokenMs, 0)
|
||||
require.Greater(t, turn.Duration.Milliseconds(), int64(0))
|
||||
require.NotNil(t, result.FirstTokenMs)
|
||||
require.Greater(t, result.Duration.Milliseconds(), int64(0))
|
||||
}
|
||||
|
||||
func TestRelay_BinaryFramePassthrough(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 验证 binary frame 被透传但不进行 usage 解析
|
||||
binaryPayload := []byte{0x00, 0x01, 0x02, 0x03}
|
||||
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||
{
|
||||
msgType: coderws.MessageBinary,
|
||||
payload: binaryPayload,
|
||||
},
|
||||
}, true)
|
||||
|
||||
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||
require.Nil(t, relayExit)
|
||||
// binary frame 不解析 usage
|
||||
require.Equal(t, 0, result.Usage.InputTokens)
|
||||
|
||||
clientWrites := clientConn.Writes()
|
||||
require.Len(t, clientWrites, 1)
|
||||
require.Equal(t, coderws.MessageBinary, clientWrites[0].msgType)
|
||||
require.Equal(t, binaryPayload, clientWrites[0].payload)
|
||||
}
|
||||
|
||||
func TestRelay_BinaryJSONFrameSkipsObservation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||
{
|
||||
msgType: coderws.MessageBinary,
|
||||
payload: []byte(`{"type":"response.completed","response":{"id":"resp_binary","usage":{"input_tokens":7,"output_tokens":3}}}`),
|
||||
},
|
||||
}, true)
|
||||
|
||||
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||
require.Nil(t, relayExit)
|
||||
require.Equal(t, 0, result.Usage.InputTokens)
|
||||
require.Equal(t, "", result.RequestID)
|
||||
require.Equal(t, "", result.TerminalEventType)
|
||||
|
||||
clientWrites := clientConn.Writes()
|
||||
require.Len(t, clientWrites, 1)
|
||||
require.Equal(t, coderws.MessageBinary, clientWrites[0].msgType)
|
||||
}
|
||||
|
||||
func TestRelay_UpstreamErrorEventPassthroughRaw(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||
errorEvent := []byte(`{"type":"error","error":{"type":"invalid_request_error","message":"No tool call found"}}`)
|
||||
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||
{
|
||||
msgType: coderws.MessageText,
|
||||
payload: errorEvent,
|
||||
},
|
||||
}, true)
|
||||
|
||||
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||
require.Nil(t, relayExit)
|
||||
|
||||
clientWrites := clientConn.Writes()
|
||||
require.Len(t, clientWrites, 1)
|
||||
require.Equal(t, coderws.MessageText, clientWrites[0].msgType)
|
||||
require.Equal(t, errorEvent, clientWrites[0].payload)
|
||||
}
|
||||
|
||||
func TestRelay_PreservesFirstMessageType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||
upstreamConn := newPassthroughTestFrameConn(nil, true)
|
||||
|
||||
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||
FirstMessageType: coderws.MessageBinary,
|
||||
})
|
||||
require.Nil(t, relayExit)
|
||||
|
||||
upstreamWrites := upstreamConn.Writes()
|
||||
require.Len(t, upstreamWrites, 1)
|
||||
require.Equal(t, coderws.MessageBinary, upstreamWrites[0].msgType)
|
||||
require.Equal(t, firstPayload, upstreamWrites[0].payload)
|
||||
}
|
||||
|
||||
func TestRelay_UsageParseFailureDoesNotBlockRelay(t *testing.T) {
|
||||
baseline := SnapshotMetrics().UsageParseFailureTotal
|
||||
|
||||
// 上游发送无效 JSON(非 usage 格式),不应影响透传
|
||||
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||
{
|
||||
msgType: coderws.MessageText,
|
||||
payload: []byte(`{"type":"response.completed","response":{"id":"resp_bad","usage":"not_an_object"}}`),
|
||||
},
|
||||
}, true)
|
||||
|
||||
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||
require.Nil(t, relayExit)
|
||||
// usage 解析失败,值为 0 但不影响透传
|
||||
require.Equal(t, 0, result.Usage.InputTokens)
|
||||
require.Equal(t, "response.completed", result.TerminalEventType)
|
||||
|
||||
// 帧仍然被转发
|
||||
clientWrites := clientConn.Writes()
|
||||
require.Len(t, clientWrites, 1)
|
||||
require.GreaterOrEqual(t, SnapshotMetrics().UsageParseFailureTotal, baseline+1)
|
||||
}
|
||||
|
||||
func TestRelay_WriteUpstreamFirstMessageFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 上游连接立即关闭,首包写入失败
|
||||
upstreamConn := newPassthroughTestFrameConn(nil, true)
|
||||
_ = upstreamConn.Close()
|
||||
|
||||
// 覆盖 WriteFrame 使其返回错误
|
||||
errConn := &errorOnWriteFrameConn{}
|
||||
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||
|
||||
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, relayExit := Relay(ctx, clientConn, errConn, firstPayload, RelayOptions{})
|
||||
require.NotNil(t, relayExit)
|
||||
require.Equal(t, "write_upstream", relayExit.Stage)
|
||||
}
|
||||
|
||||
func TestRelay_ContextCanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||
upstreamConn := newPassthroughTestFrameConn(nil, false)
|
||||
|
||||
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||
|
||||
// 立即取消 context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||
// context 取消导致写首包失败
|
||||
require.NotNil(t, relayExit)
|
||||
}
|
||||
|
||||
func TestRelay_TraceEvents_ContainsLifecycleStages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||
{
|
||||
msgType: coderws.MessageText,
|
||||
payload: []byte(`{"type":"response.completed","response":{"id":"resp_trace","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||
},
|
||||
}, true)
|
||||
|
||||
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
stages := make([]string, 0, 8)
|
||||
var stagesMu sync.Mutex
|
||||
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||
OnTrace: func(event RelayTraceEvent) {
|
||||
stagesMu.Lock()
|
||||
stages = append(stages, event.Stage)
|
||||
stagesMu.Unlock()
|
||||
},
|
||||
})
|
||||
require.Nil(t, relayExit)
|
||||
stagesMu.Lock()
|
||||
capturedStages := append([]string(nil), stages...)
|
||||
stagesMu.Unlock()
|
||||
require.Contains(t, capturedStages, "relay_start")
|
||||
require.Contains(t, capturedStages, "write_first_message_ok")
|
||||
require.Contains(t, capturedStages, "first_exit")
|
||||
require.Contains(t, capturedStages, "relay_complete")
|
||||
}
|
||||
|
||||
func TestRelay_TraceEvents_IdleTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||
upstreamConn := newPassthroughTestFrameConn(nil, false)
|
||||
|
||||
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
now := time.Now()
|
||||
callCount := 0
|
||||
nowFn := func() time.Time {
|
||||
callCount++
|
||||
if callCount <= 5 {
|
||||
return now
|
||||
}
|
||||
return now.Add(time.Hour)
|
||||
}
|
||||
|
||||
stages := make([]string, 0, 8)
|
||||
var stagesMu sync.Mutex
|
||||
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||
IdleTimeout: 2 * time.Second,
|
||||
Now: nowFn,
|
||||
OnTrace: func(event RelayTraceEvent) {
|
||||
stagesMu.Lock()
|
||||
stages = append(stages, event.Stage)
|
||||
stagesMu.Unlock()
|
||||
},
|
||||
})
|
||||
require.NotNil(t, relayExit)
|
||||
require.Equal(t, "idle_timeout", relayExit.Stage)
|
||||
stagesMu.Lock()
|
||||
capturedStages := append([]string(nil), stages...)
|
||||
stagesMu.Unlock()
|
||||
require.Contains(t, capturedStages, "idle_timeout_triggered")
|
||||
require.Contains(t, capturedStages, "relay_exit")
|
||||
}
|
||||
|
||||
// errorOnWriteFrameConn 是一个写入总是失败的 FrameConn 实现,用于测试首包写入失败。
|
||||
type errorOnWriteFrameConn struct{}
|
||||
|
||||
func (c *errorOnWriteFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||
<-ctx.Done()
|
||||
return coderws.MessageText, nil, ctx.Err()
|
||||
}
|
||||
|
||||
func (c *errorOnWriteFrameConn) WriteFrame(_ context.Context, _ coderws.MessageType, _ []byte) error {
|
||||
return errors.New("write failed: connection refused")
|
||||
}
|
||||
|
||||
func (c *errorOnWriteFrameConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
367
backend/internal/service/openai_ws_v2_passthrough_adapter.go
Normal file
367
backend/internal/service/openai_ws_v2_passthrough_adapter.go
Normal file
@@ -0,0 +1,367 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type openAIWSClientFrameConn struct {
|
||||
conn *coderws.Conn
|
||||
}
|
||||
|
||||
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
|
||||
|
||||
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
|
||||
|
||||
func (c *openAIWSClientFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||
if c == nil || c.conn == nil {
|
||||
return coderws.MessageText, nil, errOpenAIWSConnClosed
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return c.conn.Read(ctx)
|
||||
}
|
||||
|
||||
func (c *openAIWSClientFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
|
||||
if c == nil || c.conn == nil {
|
||||
return errOpenAIWSConnClosed
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return c.conn.Write(ctx, msgType, payload)
|
||||
}
|
||||
|
||||
func (c *openAIWSClientFrameConn) Close() error {
|
||||
if c == nil || c.conn == nil {
|
||||
return nil
|
||||
}
|
||||
_ = c.conn.Close(coderws.StatusNormalClosure, "")
|
||||
_ = c.conn.CloseNow()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
clientConn *coderws.Conn,
|
||||
account *Account,
|
||||
token string,
|
||||
firstClientMessage []byte,
|
||||
hooks *OpenAIWSIngressHooks,
|
||||
wsDecision OpenAIWSProtocolDecision,
|
||||
) error {
|
||||
if s == nil {
|
||||
return errors.New("service is nil")
|
||||
}
|
||||
if clientConn == nil {
|
||||
return errors.New("client websocket is nil")
|
||||
}
|
||||
if account == nil {
|
||||
return errors.New("account is nil")
|
||||
}
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return errors.New("token is empty")
|
||||
}
|
||||
requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())
|
||||
requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String())
|
||||
logOpenAIWSV2Passthrough(
|
||||
"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d",
|
||||
account.ID,
|
||||
truncateOpenAIWSLogValue(requestModel, openAIWSLogValueMaxLen),
|
||||
truncateOpenAIWSLogValue(requestPreviousResponseID, openAIWSIDValueMaxLen),
|
||||
openaiwsv2RelayMessageTypeName(coderws.MessageText),
|
||||
len(firstClientMessage),
|
||||
)
|
||||
|
||||
wsURL, err := s.buildOpenAIResponsesWSURL(account)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build ws url: %w", err)
|
||||
}
|
||||
wsHost := "-"
|
||||
wsPath := "-"
|
||||
if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil {
|
||||
wsHost = normalizeOpenAIWSLogValue(parsedURL.Host)
|
||||
wsPath = normalizeOpenAIWSLogValue(parsedURL.Path)
|
||||
}
|
||||
logOpenAIWSV2Passthrough(
|
||||
"relay_dial_start account_id=%d ws_host=%s ws_path=%s proxy_enabled=%v",
|
||||
account.ID,
|
||||
wsHost,
|
||||
wsPath,
|
||||
account.ProxyID != nil && account.Proxy != nil,
|
||||
)
|
||||
|
||||
isCodexCLI := false
|
||||
if c != nil {
|
||||
isCodexCLI = openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
|
||||
}
|
||||
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
|
||||
isCodexCLI = true
|
||||
}
|
||||
headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, "", "", "")
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
dialer := s.getOpenAIWSPassthroughDialer()
|
||||
if dialer == nil {
|
||||
return errors.New("openai ws passthrough dialer is nil")
|
||||
}
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(ctx, s.openAIWSDialTimeout())
|
||||
defer cancelDial()
|
||||
upstreamConn, statusCode, handshakeHeaders, err := dialer.Dial(dialCtx, wsURL, headers, proxyURL)
|
||||
if err != nil {
|
||||
logOpenAIWSV2Passthrough(
|
||||
"relay_dial_failed account_id=%d status_code=%d err=%s",
|
||||
account.ID,
|
||||
statusCode,
|
||||
truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
|
||||
)
|
||||
return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders)
|
||||
}
|
||||
defer func() {
|
||||
_ = upstreamConn.Close()
|
||||
}()
|
||||
logOpenAIWSV2Passthrough(
|
||||
"relay_dial_ok account_id=%d status_code=%d upstream_request_id=%s",
|
||||
account.ID,
|
||||
statusCode,
|
||||
openAIWSHeaderValueForLog(handshakeHeaders, "x-request-id"),
|
||||
)
|
||||
|
||||
upstreamFrameConn, ok := upstreamConn.(openaiwsv2.FrameConn)
|
||||
if !ok {
|
||||
return errors.New("openai ws passthrough upstream connection does not support frame relay")
|
||||
}
|
||||
|
||||
completedTurns := atomic.Int32{}
|
||||
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
|
||||
Ctx: ctx,
|
||||
ClientConn: &openAIWSClientFrameConn{conn: clientConn},
|
||||
UpstreamConn: upstreamFrameConn,
|
||||
FirstClientMessage: firstClientMessage,
|
||||
Options: openaiwsv2.RelayOptions{
|
||||
WriteTimeout: s.openAIWSWriteTimeout(),
|
||||
IdleTimeout: s.openAIWSPassthroughIdleTimeout(),
|
||||
FirstMessageType: coderws.MessageText,
|
||||
OnUsageParseFailure: func(eventType string, usageRaw string) {
|
||||
logOpenAIWSV2Passthrough(
|
||||
"usage_parse_failed event_type=%s usage_raw=%s",
|
||||
truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen),
|
||||
truncateOpenAIWSLogValue(usageRaw, openAIWSLogValueMaxLen),
|
||||
)
|
||||
},
|
||||
OnTurnComplete: func(turn openaiwsv2.RelayTurnResult) {
|
||||
turnNo := int(completedTurns.Add(1))
|
||||
turnResult := &OpenAIForwardResult{
|
||||
RequestID: turn.RequestID,
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: turn.Usage.InputTokens,
|
||||
OutputTokens: turn.Usage.OutputTokens,
|
||||
CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens,
|
||||
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
|
||||
},
|
||||
Model: turn.RequestModel,
|
||||
Stream: true,
|
||||
OpenAIWSMode: true,
|
||||
Duration: turn.Duration,
|
||||
FirstTokenMs: turn.FirstTokenMs,
|
||||
}
|
||||
logOpenAIWSV2Passthrough(
|
||||
"relay_turn_completed account_id=%d turn=%d request_id=%s terminal_event=%s duration_ms=%d first_token_ms=%d input_tokens=%d output_tokens=%d cache_read_tokens=%d",
|
||||
account.ID,
|
||||
turnNo,
|
||||
truncateOpenAIWSLogValue(turnResult.RequestID, openAIWSIDValueMaxLen),
|
||||
truncateOpenAIWSLogValue(turn.TerminalEventType, openAIWSLogValueMaxLen),
|
||||
turnResult.Duration.Milliseconds(),
|
||||
openAIWSFirstTokenMsForLog(turnResult.FirstTokenMs),
|
||||
turnResult.Usage.InputTokens,
|
||||
turnResult.Usage.OutputTokens,
|
||||
turnResult.Usage.CacheReadInputTokens,
|
||||
)
|
||||
if hooks != nil && hooks.AfterTurn != nil {
|
||||
hooks.AfterTurn(turnNo, turnResult, nil)
|
||||
}
|
||||
},
|
||||
OnTrace: func(event openaiwsv2.RelayTraceEvent) {
|
||||
logOpenAIWSV2Passthrough(
|
||||
"relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s",
|
||||
account.ID,
|
||||
truncateOpenAIWSLogValue(event.Stage, openAIWSLogValueMaxLen),
|
||||
truncateOpenAIWSLogValue(event.Direction, openAIWSLogValueMaxLen),
|
||||
truncateOpenAIWSLogValue(event.MessageType, openAIWSLogValueMaxLen),
|
||||
event.PayloadBytes,
|
||||
event.Graceful,
|
||||
event.WroteDownstream,
|
||||
truncateOpenAIWSLogValue(event.Error, openAIWSLogValueMaxLen),
|
||||
)
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
result := &OpenAIForwardResult{
|
||||
RequestID: relayResult.RequestID,
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: relayResult.Usage.InputTokens,
|
||||
OutputTokens: relayResult.Usage.OutputTokens,
|
||||
CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens,
|
||||
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
|
||||
},
|
||||
Model: relayResult.RequestModel,
|
||||
Stream: true,
|
||||
OpenAIWSMode: true,
|
||||
Duration: relayResult.Duration,
|
||||
FirstTokenMs: relayResult.FirstTokenMs,
|
||||
}
|
||||
|
||||
turnCount := int(completedTurns.Load())
|
||||
if relayExit == nil {
|
||||
logOpenAIWSV2Passthrough(
|
||||
"relay_completed account_id=%d request_id=%s terminal_event=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
|
||||
account.ID,
|
||||
truncateOpenAIWSLogValue(result.RequestID, openAIWSIDValueMaxLen),
|
||||
truncateOpenAIWSLogValue(relayResult.TerminalEventType, openAIWSLogValueMaxLen),
|
||||
result.Duration.Milliseconds(),
|
||||
relayResult.ClientToUpstreamFrames,
|
||||
relayResult.UpstreamToClientFrames,
|
||||
relayResult.DroppedDownstreamFrames,
|
||||
turnCount,
|
||||
)
|
||||
// 正常路径按 terminal 事件逐 turn 已回调;仅在零 turn 场景兜底回调一次。
|
||||
if turnCount == 0 && hooks != nil && hooks.AfterTurn != nil {
|
||||
hooks.AfterTurn(1, result, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
logOpenAIWSV2Passthrough(
|
||||
"relay_failed account_id=%d stage=%s wrote_downstream=%v err=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
|
||||
account.ID,
|
||||
truncateOpenAIWSLogValue(relayExit.Stage, openAIWSLogValueMaxLen),
|
||||
relayExit.WroteDownstream,
|
||||
truncateOpenAIWSLogValue(relayErrorText(relayExit.Err), openAIWSLogValueMaxLen),
|
||||
result.Duration.Milliseconds(),
|
||||
relayResult.ClientToUpstreamFrames,
|
||||
relayResult.UpstreamToClientFrames,
|
||||
relayResult.DroppedDownstreamFrames,
|
||||
turnCount,
|
||||
)
|
||||
|
||||
relayErr := relayExit.Err
|
||||
if relayExit.Stage == "idle_timeout" {
|
||||
relayErr = NewOpenAIWSClientCloseError(
|
||||
coderws.StatusPolicyViolation,
|
||||
"client websocket idle timeout",
|
||||
relayErr,
|
||||
)
|
||||
}
|
||||
turnErr := wrapOpenAIWSIngressTurnError(
|
||||
relayExit.Stage,
|
||||
relayErr,
|
||||
relayExit.WroteDownstream,
|
||||
)
|
||||
if hooks != nil && hooks.AfterTurn != nil {
|
||||
hooks.AfterTurn(turnCount+1, nil, turnErr)
|
||||
}
|
||||
return turnErr
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) mapOpenAIWSPassthroughDialError(
|
||||
err error,
|
||||
statusCode int,
|
||||
handshakeHeaders http.Header,
|
||||
) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
wrappedErr := err
|
||||
var dialErr *openAIWSDialError
|
||||
if !errors.As(err, &dialErr) {
|
||||
wrappedErr = &openAIWSDialError{
|
||||
StatusCode: statusCode,
|
||||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return err
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return NewOpenAIWSClientCloseError(
|
||||
coderws.StatusTryAgainLater,
|
||||
"upstream websocket connect timeout",
|
||||
wrappedErr,
|
||||
)
|
||||
}
|
||||
if statusCode == http.StatusTooManyRequests {
|
||||
return NewOpenAIWSClientCloseError(
|
||||
coderws.StatusTryAgainLater,
|
||||
"upstream websocket is busy, please retry later",
|
||||
wrappedErr,
|
||||
)
|
||||
}
|
||||
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
||||
return NewOpenAIWSClientCloseError(
|
||||
coderws.StatusPolicyViolation,
|
||||
"upstream websocket authentication failed",
|
||||
wrappedErr,
|
||||
)
|
||||
}
|
||||
if statusCode >= http.StatusBadRequest && statusCode < http.StatusInternalServerError {
|
||||
return NewOpenAIWSClientCloseError(
|
||||
coderws.StatusPolicyViolation,
|
||||
"upstream websocket handshake rejected",
|
||||
wrappedErr,
|
||||
)
|
||||
}
|
||||
return fmt.Errorf("openai ws passthrough dial: %w", wrappedErr)
|
||||
}
|
||||
|
||||
func openaiwsv2RelayMessageTypeName(msgType coderws.MessageType) string {
|
||||
switch msgType {
|
||||
case coderws.MessageText:
|
||||
return "text"
|
||||
case coderws.MessageBinary:
|
||||
return "binary"
|
||||
default:
|
||||
return fmt.Sprintf("unknown(%d)", msgType)
|
||||
}
|
||||
}
|
||||
|
||||
// relayErrorText returns err.Error(), or the empty string when err is nil, so
// it can be passed straight to a log formatter.
func relayErrorText(err error) string {
	if err != nil {
		return err.Error()
	}
	return ""
}
|
||||
|
||||
// openAIWSFirstTokenMsForLog dereferences firstTokenMs for logging and returns
// -1 when the pointer is nil.
func openAIWSFirstTokenMsForLog(firstTokenMs *int) int {
	if firstTokenMs != nil {
		return *firstTokenMs
	}
	return -1
}
|
||||
|
||||
func logOpenAIWSV2Passthrough(format string, args ...any) {
|
||||
logger.LegacyPrintf(
|
||||
"service.openai_ws_v2",
|
||||
"[OpenAI WS v2 passthrough] %s "+format,
|
||||
append([]any{openaiWSV2PassthroughModeFields}, args...)...,
|
||||
)
|
||||
}
|
||||
@@ -64,8 +64,12 @@ func (s *OpsService) getAccountsLoadMapBestEffort(ctx context.Context, accounts
|
||||
if acc.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
if prev, ok := unique[acc.ID]; !ok || acc.Concurrency > prev {
|
||||
unique[acc.ID] = acc.Concurrency
|
||||
c := acc.Concurrency
|
||||
if c <= 0 {
|
||||
c = 1
|
||||
}
|
||||
if prev, ok := unique[acc.ID]; !ok || c > prev {
|
||||
unique[acc.ID] = c
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -389,13 +389,9 @@ func (c *OpsMetricsCollector) collectConcurrencyQueueDepth(parentCtx context.Con
|
||||
if acc.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
maxConc := acc.Concurrency
|
||||
if maxConc < 0 {
|
||||
maxConc = 0
|
||||
}
|
||||
batch = append(batch, AccountWithConcurrency{
|
||||
ID: acc.ID,
|
||||
MaxConcurrency: maxConc,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
})
|
||||
}
|
||||
if len(batch) == 0 {
|
||||
|
||||
@@ -21,8 +21,19 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
openAIModelDatePattern = regexp.MustCompile(`-\d{8}$`)
|
||||
openAIModelBasePattern = regexp.MustCompile(`^(gpt-\d+(?:\.\d+)?)(?:-|$)`)
|
||||
openAIModelDatePattern = regexp.MustCompile(`-\d{8}$`)
|
||||
openAIModelBasePattern = regexp.MustCompile(`^(gpt-\d+(?:\.\d+)?)(?:-|$)`)
|
||||
openAIGPT54FallbackPricing = &LiteLLMModelPricing{
|
||||
InputCostPerToken: 2.5e-06, // $2.5 per MTok
|
||||
OutputCostPerToken: 1.5e-05, // $15 per MTok
|
||||
CacheReadInputTokenCost: 2.5e-07, // $0.25 per MTok
|
||||
LongContextInputTokenThreshold: 272000,
|
||||
LongContextInputCostMultiplier: 2.0,
|
||||
LongContextOutputCostMultiplier: 1.5,
|
||||
LiteLLMProvider: "openai",
|
||||
Mode: "chat",
|
||||
SupportsPromptCaching: true,
|
||||
}
|
||||
)
|
||||
|
||||
// LiteLLMModelPricing LiteLLM价格数据结构
|
||||
@@ -33,6 +44,9 @@ type LiteLLMModelPricing struct {
|
||||
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
|
||||
CacheCreationInputTokenCostAbove1hr float64 `json:"cache_creation_input_token_cost_above_1hr"`
|
||||
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
|
||||
LongContextInputTokenThreshold int `json:"long_context_input_token_threshold,omitempty"`
|
||||
LongContextInputCostMultiplier float64 `json:"long_context_input_cost_multiplier,omitempty"`
|
||||
LongContextOutputCostMultiplier float64 `json:"long_context_output_cost_multiplier,omitempty"`
|
||||
LiteLLMProvider string `json:"litellm_provider"`
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
@@ -660,7 +674,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
// 2. gpt-5.2-codex -> gpt-5.2(去掉后缀如 -codex, -mini, -max 等)
|
||||
// 3. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号)
|
||||
// 4. gpt-5.3-codex -> gpt-5.2-codex
|
||||
// 5. 最终回退到 DefaultTestModel (gpt-5.1-codex)
|
||||
// 5. gpt-5.4* -> 业务静态兜底价
|
||||
// 6. 最终回退到 DefaultTestModel (gpt-5.1-codex)
|
||||
func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
|
||||
if strings.HasPrefix(model, "gpt-5.3-codex-spark") {
|
||||
if pricing, ok := s.pricingData["gpt-5.1-codex"]; ok {
|
||||
@@ -690,6 +705,12 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(model, "gpt-5.4") {
|
||||
logger.With(zap.String("component", "service.pricing")).
|
||||
Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4(static)"))
|
||||
return openAIGPT54FallbackPricing
|
||||
}
|
||||
|
||||
// 最终回退到 DefaultTestModel
|
||||
defaultModel := strings.ToLower(openai.DefaultTestModel)
|
||||
if pricing, ok := s.pricingData[defaultModel]; ok {
|
||||
|
||||
@@ -51,3 +51,20 @@ func TestGetModelPricing_OpenAIFallbackMatchedLoggedAsInfo(t *testing.T) {
|
||||
require.True(t, logSink.ContainsMessageAtLevel("[Pricing] OpenAI fallback matched gpt-5.3-codex -> gpt-5.2-codex", "info"))
|
||||
require.False(t, logSink.ContainsMessageAtLevel("[Pricing] OpenAI fallback matched gpt-5.3-codex -> gpt-5.2-codex", "warn"))
|
||||
}
|
||||
|
||||
func TestGetModelPricing_Gpt54UsesStaticFallbackWhenRemoteMissing(t *testing.T) {
|
||||
svc := &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"gpt-5.1-codex": &LiteLLMModelPricing{InputCostPerToken: 1.25e-6},
|
||||
},
|
||||
}
|
||||
|
||||
got := svc.GetModelPricing("gpt-5.4")
|
||||
require.NotNil(t, got)
|
||||
require.InDelta(t, 2.5e-6, got.InputCostPerToken, 1e-12)
|
||||
require.InDelta(t, 1.5e-5, got.OutputCostPerToken, 1e-12)
|
||||
require.InDelta(t, 2.5e-7, got.CacheReadInputTokenCost, 1e-12)
|
||||
require.Equal(t, 272000, got.LongContextInputTokenThreshold)
|
||||
require.InDelta(t, 2.0, got.LongContextInputCostMultiplier, 1e-12)
|
||||
require.InDelta(t, 1.5, got.LongContextOutputCostMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
@@ -1091,6 +1091,22 @@ func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Ac
|
||||
if !account.IsTempUnschedulableEnabled() {
|
||||
return false
|
||||
}
|
||||
// 401 首次命中可临时不可调度(给 token 刷新窗口);
|
||||
// 若历史上已因 401 进入过临时不可调度,则本次应升级为 error(返回 false 交由默认错误逻辑处理)。
|
||||
if statusCode == http.StatusUnauthorized {
|
||||
reason := account.TempUnschedulableReason
|
||||
// 缓存可能没有 reason,从 DB 回退读取
|
||||
if reason == "" {
|
||||
if dbAcc, err := s.accountRepo.GetByID(ctx, account.ID); err == nil && dbAcc != nil {
|
||||
reason = dbAcc.TempUnschedulableReason
|
||||
}
|
||||
}
|
||||
if wasTempUnschedByStatusCode(reason, statusCode) {
|
||||
slog.Info("401_escalated_to_error", "account_id", account.ID,
|
||||
"reason", "previous temp-unschedulable was also 401")
|
||||
return false
|
||||
}
|
||||
}
|
||||
rules := account.GetTempUnschedulableRules()
|
||||
if len(rules) == 0 {
|
||||
return false
|
||||
@@ -1122,6 +1138,22 @@ func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Ac
|
||||
return false
|
||||
}
|
||||
|
||||
func wasTempUnschedByStatusCode(reason string, statusCode int) bool {
|
||||
if statusCode <= 0 {
|
||||
return false
|
||||
}
|
||||
reason = strings.TrimSpace(reason)
|
||||
if reason == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
var state TempUnschedState
|
||||
if err := json.Unmarshal([]byte(reason), &state); err != nil {
|
||||
return false
|
||||
}
|
||||
return state.StatusCode == statusCode
|
||||
}
|
||||
|
||||
func matchTempUnschedKeyword(bodyLower string, keywords []string) string {
|
||||
if bodyLower == "" {
|
||||
return ""
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// dbFallbackRepoStub extends errorPolicyRepoStub with a configurable DB account
|
||||
// returned by GetByID, simulating cache miss + DB fallback.
|
||||
type dbFallbackRepoStub struct {
|
||||
errorPolicyRepoStub
|
||||
dbAccount *Account // returned by GetByID when non-nil
|
||||
}
|
||||
|
||||
func (r *dbFallbackRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
if r.dbAccount != nil && r.dbAccount.ID == id {
|
||||
return r.dbAccount, nil
|
||||
}
|
||||
return nil, nil // not found, no error
|
||||
}
|
||||
|
||||
func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) {
|
||||
// Scenario: cache account has empty TempUnschedulableReason (cache miss),
|
||||
// but DB account has a previous 401 record → should escalate to ErrorPolicyNone.
|
||||
repo := &dbFallbackRepoStub{
|
||||
dbAccount: &Account{
|
||||
ID: 20,
|
||||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||
},
|
||||
}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
|
||||
account := &Account{
|
||||
ID: 20,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
TempUnschedulableReason: "", // cache miss — reason is empty
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
|
||||
require.Equal(t, ErrorPolicyNone, result, "401 with DB fallback showing previous 401 should escalate to ErrorPolicyNone")
|
||||
}
|
||||
|
||||
func TestCheckErrorPolicy_401_DBFallback_NoDBRecord_FirstHit(t *testing.T) {
|
||||
// Scenario: cache account has empty TempUnschedulableReason,
|
||||
// DB also has no previous 401 record → should NOT escalate (first hit → temp unscheduled).
|
||||
repo := &dbFallbackRepoStub{
|
||||
dbAccount: &Account{
|
||||
ID: 21,
|
||||
TempUnschedulableReason: "", // DB also empty
|
||||
},
|
||||
}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
|
||||
account := &Account{
|
||||
ID: 21,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
TempUnschedulableReason: "",
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
|
||||
require.Equal(t, ErrorPolicyTempUnscheduled, result, "401 first hit with no DB record should temp-unschedule")
|
||||
}
|
||||
|
||||
func TestCheckErrorPolicy_401_DBFallback_DBError_FirstHit(t *testing.T) {
|
||||
// Scenario: cache account has empty TempUnschedulableReason,
|
||||
// DB lookup returns nil (not found) → should treat as first hit → temp unscheduled.
|
||||
repo := &dbFallbackRepoStub{
|
||||
dbAccount: nil, // GetByID returns nil, nil
|
||||
}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
|
||||
account := &Account{
|
||||
ID: 22,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
TempUnschedulableReason: "",
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
|
||||
require.Equal(t, ErrorPolicyTempUnscheduled, result, "401 first hit with DB not found should temp-unschedule")
|
||||
}
|
||||
51
backend/internal/service/scheduled_test_port.go
Normal file
51
backend/internal/service/scheduled_test_port.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ScheduledTestPlan is the domain model of a recurring test for one account.
//
// LastRunAt and NextRunAt are pointers without omitempty, so an unset value is
// serialized as JSON null.
type ScheduledTestPlan struct {
	// Identity.
	ID        int64 `json:"id"`
	AccountID int64 `json:"account_id"`

	// What to run and on which schedule. CronExpression is parsed with the
	// package's five-field cron parser.
	ModelID        string `json:"model_id"`
	CronExpression string `json:"cron_expression"`
	Enabled        bool   `json:"enabled"`

	// MaxResults is handed to result pruning as the number of results to keep.
	MaxResults int `json:"max_results"`

	// Scheduling state.
	LastRunAt *time.Time `json:"last_run_at"`
	NextRunAt *time.Time `json:"next_run_at"`

	// Bookkeeping timestamps.
	CreatedAt time.Time `json:"created_at"`
	UpdatedAt time.Time `json:"updated_at"`
}
|
||||
|
||||
// ScheduledTestResult is the stored outcome of one execution of a plan.
type ScheduledTestResult struct {
	// Identity; PlanID links the result to its ScheduledTestPlan.
	ID     int64 `json:"id"`
	PlanID int64 `json:"plan_id"`

	// Outcome of the run.
	Status       string `json:"status"`
	ResponseText string `json:"response_text"`
	ErrorMessage string `json:"error_message"`
	LatencyMs    int64  `json:"latency_ms"`

	// Timing of the run and of the record itself.
	StartedAt  time.Time `json:"started_at"`
	FinishedAt time.Time `json:"finished_at"`
	CreatedAt  time.Time `json:"created_at"`
}
|
||||
|
||||
// ScheduledTestPlanRepository defines the data access interface for test plans.
|
||||
type ScheduledTestPlanRepository interface {
|
||||
Create(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error)
|
||||
GetByID(ctx context.Context, id int64) (*ScheduledTestPlan, error)
|
||||
ListByAccountID(ctx context.Context, accountID int64) ([]*ScheduledTestPlan, error)
|
||||
ListDue(ctx context.Context, now time.Time) ([]*ScheduledTestPlan, error)
|
||||
Update(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error)
|
||||
Delete(ctx context.Context, id int64) error
|
||||
UpdateAfterRun(ctx context.Context, id int64, lastRunAt time.Time, nextRunAt time.Time) error
|
||||
}
|
||||
|
||||
// ScheduledTestResultRepository defines the data access interface for test results.
|
||||
type ScheduledTestResultRepository interface {
|
||||
Create(ctx context.Context, result *ScheduledTestResult) (*ScheduledTestResult, error)
|
||||
ListByPlanID(ctx context.Context, planID int64, limit int) ([]*ScheduledTestResult, error)
|
||||
PruneOldResults(ctx context.Context, planID int64, keepCount int) error
|
||||
}
|
||||
139
backend/internal/service/scheduled_test_runner_service.go
Normal file
139
backend/internal/service/scheduled_test_runner_service.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/robfig/cron/v3"
|
||||
)
|
||||
|
||||
const scheduledTestDefaultMaxWorkers = 10
|
||||
|
||||
// ScheduledTestRunnerService periodically scans due test plans and executes them.
|
||||
type ScheduledTestRunnerService struct {
|
||||
planRepo ScheduledTestPlanRepository
|
||||
scheduledSvc *ScheduledTestService
|
||||
accountTestSvc *AccountTestService
|
||||
cfg *config.Config
|
||||
|
||||
cron *cron.Cron
|
||||
startOnce sync.Once
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
// NewScheduledTestRunnerService creates a new runner.
|
||||
func NewScheduledTestRunnerService(
|
||||
planRepo ScheduledTestPlanRepository,
|
||||
scheduledSvc *ScheduledTestService,
|
||||
accountTestSvc *AccountTestService,
|
||||
cfg *config.Config,
|
||||
) *ScheduledTestRunnerService {
|
||||
return &ScheduledTestRunnerService{
|
||||
planRepo: planRepo,
|
||||
scheduledSvc: scheduledSvc,
|
||||
accountTestSvc: accountTestSvc,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the cron ticker (every minute).
|
||||
func (s *ScheduledTestRunnerService) Start() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.startOnce.Do(func() {
|
||||
loc := time.Local
|
||||
if s.cfg != nil {
|
||||
if parsed, err := time.LoadLocation(s.cfg.Timezone); err == nil && parsed != nil {
|
||||
loc = parsed
|
||||
}
|
||||
}
|
||||
|
||||
c := cron.New(cron.WithParser(scheduledTestCronParser), cron.WithLocation(loc))
|
||||
_, err := c.AddFunc("* * * * *", func() { s.runScheduled() })
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] not started (invalid schedule): %v", err)
|
||||
return
|
||||
}
|
||||
s.cron = c
|
||||
s.cron.Start()
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] started (tick=every minute)")
|
||||
})
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the cron scheduler.
|
||||
func (s *ScheduledTestRunnerService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
if s.cron != nil {
|
||||
ctx := s.cron.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(3 * time.Second):
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] cron stop timed out")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ScheduledTestRunnerService) runScheduled() {
|
||||
// Delay 10s so execution lands at ~:10 of each minute instead of :00.
|
||||
time.Sleep(10 * time.Second)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
now := time.Now()
|
||||
plans, err := s.planRepo.ListDue(ctx, now)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] ListDue error: %v", err)
|
||||
return
|
||||
}
|
||||
if len(plans) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] found %d due plans", len(plans))
|
||||
|
||||
sem := make(chan struct{}, scheduledTestDefaultMaxWorkers)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for _, plan := range plans {
|
||||
sem <- struct{}{}
|
||||
wg.Add(1)
|
||||
go func(p *ScheduledTestPlan) {
|
||||
defer wg.Done()
|
||||
defer func() { <-sem }()
|
||||
s.runOnePlan(ctx, p)
|
||||
}(plan)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (s *ScheduledTestRunnerService) runOnePlan(ctx context.Context, plan *ScheduledTestPlan) {
|
||||
result, err := s.accountTestSvc.RunTestBackground(ctx, plan.AccountID, plan.ModelID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d RunTestBackground error: %v", plan.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.scheduledSvc.SaveResult(ctx, plan.ID, plan.MaxResults, result); err != nil {
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d SaveResult error: %v", plan.ID, err)
|
||||
}
|
||||
|
||||
nextRun, err := computeNextRun(plan.CronExpression, time.Now())
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d computeNextRun error: %v", plan.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.planRepo.UpdateAfterRun(ctx, plan.ID, time.Now(), nextRun); err != nil {
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d UpdateAfterRun error: %v", plan.ID, err)
|
||||
}
|
||||
}
|
||||
94
backend/internal/service/scheduled_test_service.go
Normal file
94
backend/internal/service/scheduled_test_service.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/robfig/cron/v3"
|
||||
)
|
||||
|
||||
var scheduledTestCronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
|
||||
|
||||
// ScheduledTestService provides CRUD operations for scheduled test plans and results.
|
||||
type ScheduledTestService struct {
|
||||
planRepo ScheduledTestPlanRepository
|
||||
resultRepo ScheduledTestResultRepository
|
||||
}
|
||||
|
||||
// NewScheduledTestService creates a new ScheduledTestService.
|
||||
func NewScheduledTestService(
|
||||
planRepo ScheduledTestPlanRepository,
|
||||
resultRepo ScheduledTestResultRepository,
|
||||
) *ScheduledTestService {
|
||||
return &ScheduledTestService{
|
||||
planRepo: planRepo,
|
||||
resultRepo: resultRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// CreatePlan validates the cron expression, computes next_run_at, and persists the plan.
|
||||
func (s *ScheduledTestService) CreatePlan(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error) {
|
||||
nextRun, err := computeNextRun(plan.CronExpression, time.Now())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid cron expression: %w", err)
|
||||
}
|
||||
plan.NextRunAt = &nextRun
|
||||
|
||||
if plan.MaxResults <= 0 {
|
||||
plan.MaxResults = 50
|
||||
}
|
||||
|
||||
return s.planRepo.Create(ctx, plan)
|
||||
}
|
||||
|
||||
// GetPlan retrieves a plan by ID.
|
||||
func (s *ScheduledTestService) GetPlan(ctx context.Context, id int64) (*ScheduledTestPlan, error) {
|
||||
return s.planRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// ListPlansByAccount returns all plans for a given account.
|
||||
func (s *ScheduledTestService) ListPlansByAccount(ctx context.Context, accountID int64) ([]*ScheduledTestPlan, error) {
|
||||
return s.planRepo.ListByAccountID(ctx, accountID)
|
||||
}
|
||||
|
||||
// UpdatePlan validates cron and updates the plan.
|
||||
func (s *ScheduledTestService) UpdatePlan(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error) {
|
||||
nextRun, err := computeNextRun(plan.CronExpression, time.Now())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid cron expression: %w", err)
|
||||
}
|
||||
plan.NextRunAt = &nextRun
|
||||
|
||||
return s.planRepo.Update(ctx, plan)
|
||||
}
|
||||
|
||||
// DeletePlan removes a plan and its results (via CASCADE).
|
||||
func (s *ScheduledTestService) DeletePlan(ctx context.Context, id int64) error {
|
||||
return s.planRepo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
// ListResults returns the most recent results for a plan.
|
||||
func (s *ScheduledTestService) ListResults(ctx context.Context, planID int64, limit int) ([]*ScheduledTestResult, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
return s.resultRepo.ListByPlanID(ctx, planID, limit)
|
||||
}
|
||||
|
||||
// SaveResult inserts a result and prunes old entries beyond maxResults.
|
||||
func (s *ScheduledTestService) SaveResult(ctx context.Context, planID int64, maxResults int, result *ScheduledTestResult) error {
|
||||
result.PlanID = planID
|
||||
if _, err := s.resultRepo.Create(ctx, result); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.resultRepo.PruneOldResults(ctx, planID, maxResults)
|
||||
}
|
||||
|
||||
func computeNextRun(cronExpr string, from time.Time) (time.Time, error) {
|
||||
sched, err := scheduledTestCronParser.Parse(cronExpr)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
return sched.Next(from), nil
|
||||
}
|
||||
@@ -34,7 +34,7 @@ func TestCalculateProgress_BasicFields(t *testing.T) {
|
||||
assert.Equal(t, int64(100), progress.ID)
|
||||
assert.Equal(t, "Premium", progress.GroupName)
|
||||
assert.Equal(t, sub.ExpiresAt, progress.ExpiresAt)
|
||||
assert.Equal(t, 29, progress.ExpiresInDays) // 约 30 天
|
||||
assert.True(t, progress.ExpiresInDays == 29 || progress.ExpiresInDays == 30, "ExpiresInDays should be 29 or 30, got %d", progress.ExpiresInDays)
|
||||
assert.Nil(t, progress.Daily, "无日限额时 Daily 应为 nil")
|
||||
assert.Nil(t, progress.Weekly, "无周限额时 Weekly 应为 nil")
|
||||
assert.Nil(t, progress.Monthly, "无月限额时 Monthly 应为 nil")
|
||||
|
||||
@@ -274,6 +274,26 @@ func ProvideIdempotencyCleanupService(repo IdempotencyRepository, cfg *config.Co
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideScheduledTestService creates ScheduledTestService.
|
||||
func ProvideScheduledTestService(
|
||||
planRepo ScheduledTestPlanRepository,
|
||||
resultRepo ScheduledTestResultRepository,
|
||||
) *ScheduledTestService {
|
||||
return NewScheduledTestService(planRepo, resultRepo)
|
||||
}
|
||||
|
||||
// ProvideScheduledTestRunnerService creates and starts ScheduledTestRunnerService.
|
||||
func ProvideScheduledTestRunnerService(
|
||||
planRepo ScheduledTestPlanRepository,
|
||||
scheduledSvc *ScheduledTestService,
|
||||
accountTestSvc *AccountTestService,
|
||||
cfg *config.Config,
|
||||
) *ScheduledTestRunnerService {
|
||||
svc := NewScheduledTestRunnerService(planRepo, scheduledSvc, accountTestSvc, cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideOpsScheduledReportService creates and starts OpsScheduledReportService.
|
||||
func ProvideOpsScheduledReportService(
|
||||
opsService *OpsService,
|
||||
@@ -380,4 +400,6 @@ var ProviderSet = wire.NewSet(
|
||||
ProvideIdempotencyCoordinator,
|
||||
ProvideSystemOperationLockService,
|
||||
ProvideIdempotencyCleanupService,
|
||||
ProvideScheduledTestService,
|
||||
ProvideScheduledTestRunnerService,
|
||||
)
|
||||
|
||||
30
backend/migrations/066_add_scheduled_test_tables.sql
Normal file
30
backend/migrations/066_add_scheduled_test_tables.sql
Normal file
@@ -0,0 +1,30 @@
|
||||
-- 066_add_scheduled_test_tables.sql
|
||||
-- Scheduled account test plans and results
|
||||
|
||||
CREATE TABLE IF NOT EXISTS scheduled_test_plans (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
account_id BIGINT NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
|
||||
model_id VARCHAR(100) NOT NULL DEFAULT '',
|
||||
cron_expression VARCHAR(100) NOT NULL DEFAULT '*/30 * * * *',
|
||||
enabled BOOLEAN NOT NULL DEFAULT true,
|
||||
max_results INT NOT NULL DEFAULT 50,
|
||||
last_run_at TIMESTAMPTZ,
|
||||
next_run_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_stp_account_id ON scheduled_test_plans(account_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_stp_enabled_next_run ON scheduled_test_plans(enabled, next_run_at) WHERE enabled = true;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS scheduled_test_results (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
plan_id BIGINT NOT NULL REFERENCES scheduled_test_plans(id) ON DELETE CASCADE,
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'success',
|
||||
response_text TEXT NOT NULL DEFAULT '',
|
||||
error_message TEXT NOT NULL DEFAULT '',
|
||||
latency_ms BIGINT NOT NULL DEFAULT 0,
|
||||
started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
finished_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_str_plan_created ON scheduled_test_results(plan_id, created_at DESC);
|
||||
1
backend/migrations/067_add_account_load_factor.sql
Normal file
1
backend/migrations/067_add_account_load_factor.sql
Normal file
@@ -0,0 +1 @@
|
||||
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS load_factor INTEGER;
|
||||
@@ -5140,6 +5140,39 @@
|
||||
"supports_vision": true,
|
||||
"supports_web_search": true
|
||||
},
|
||||
"gpt-5.4": {
|
||||
"cache_read_input_token_cost": 2.5e-07,
|
||||
"input_cost_per_token": 2.5e-06,
|
||||
"litellm_provider": "openai",
|
||||
"max_input_tokens": 1050000,
|
||||
"max_output_tokens": 128000,
|
||||
"max_tokens": 128000,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 1.5e-05,
|
||||
"supported_endpoints": [
|
||||
"/v1/chat/completions",
|
||||
"/v1/responses"
|
||||
],
|
||||
"supported_modalities": [
|
||||
"text",
|
||||
"image"
|
||||
],
|
||||
"supported_output_modalities": [
|
||||
"text",
|
||||
"image"
|
||||
],
|
||||
"supports_function_calling": true,
|
||||
"supports_native_streaming": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_pdf_input": true,
|
||||
"supports_prompt_caching": true,
|
||||
"supports_reasoning": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_service_tier": true,
|
||||
"supports_system_messages": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"gpt-5.3-codex": {
|
||||
"cache_read_input_token_cost": 1.75e-07,
|
||||
"cache_read_input_token_cost_priority": 3.5e-07,
|
||||
|
||||
@@ -209,8 +209,9 @@ gateway:
|
||||
openai_ws:
|
||||
# 新版 WS mode 路由(默认关闭)。关闭时保持当前 legacy 实现行为。
|
||||
mode_router_v2_enabled: false
|
||||
# ingress 默认模式:off|shared|dedicated(仅 mode_router_v2_enabled=true 生效)
|
||||
ingress_mode_default: shared
|
||||
# ingress 默认模式:off|ctx_pool|passthrough(仅 mode_router_v2_enabled=true 生效)
|
||||
# 兼容旧值:shared/dedicated 会按 ctx_pool 处理。
|
||||
ingress_mode_default: ctx_pool
|
||||
# 全局总开关,默认 true;关闭时所有请求保持原有 HTTP/SSE 路由
|
||||
enabled: true
|
||||
# 按账号类型细分开关
|
||||
|
||||
@@ -240,6 +240,18 @@ export async function clearRateLimit(id: number): Promise<Account> {
|
||||
return data
|
||||
}
|
||||
|
||||
/**
 * Reset the quota usage of an account.
 *
 * Sends `POST /admin/accounts/{id}/reset-quota`.
 *
 * @param id - Account ID
 * @returns The account returned by the server
 */
export async function resetAccountQuota(id: number): Promise<Account> {
  const response = await apiClient.post<Account>(`/admin/accounts/${id}/reset-quota`)
  return response.data
}
|
||||
|
||||
/**
|
||||
* Get temporary unschedulable status
|
||||
* @param id - Account ID
|
||||
@@ -576,6 +588,7 @@ export const accountsAPI = {
|
||||
getTodayStats,
|
||||
getBatchTodayStats,
|
||||
clearRateLimit,
|
||||
resetAccountQuota,
|
||||
getTempUnschedulableStatus,
|
||||
resetTempUnschedulable,
|
||||
setSchedulable,
|
||||
|
||||
@@ -22,6 +22,7 @@ import opsAPI from './ops'
|
||||
import errorPassthroughAPI from './errorPassthrough'
|
||||
import dataManagementAPI from './dataManagement'
|
||||
import apiKeysAPI from './apiKeys'
|
||||
import scheduledTestsAPI from './scheduledTests'
|
||||
|
||||
/**
|
||||
* Unified admin API object for convenient access
|
||||
@@ -45,7 +46,8 @@ export const adminAPI = {
|
||||
ops: opsAPI,
|
||||
errorPassthrough: errorPassthroughAPI,
|
||||
dataManagement: dataManagementAPI,
|
||||
apiKeys: apiKeysAPI
|
||||
apiKeys: apiKeysAPI,
|
||||
scheduledTests: scheduledTestsAPI
|
||||
}
|
||||
|
||||
export {
|
||||
@@ -67,7 +69,8 @@ export {
|
||||
opsAPI,
|
||||
errorPassthroughAPI,
|
||||
dataManagementAPI,
|
||||
apiKeysAPI
|
||||
apiKeysAPI,
|
||||
scheduledTestsAPI
|
||||
}
|
||||
|
||||
export default adminAPI
|
||||
|
||||
85
frontend/src/api/admin/scheduledTests.ts
Normal file
85
frontend/src/api/admin/scheduledTests.ts
Normal file
@@ -0,0 +1,85 @@
|
||||
/**
|
||||
* Admin Scheduled Tests API endpoints
|
||||
* Handles scheduled test plan management for account connectivity monitoring
|
||||
*/
|
||||
|
||||
import { apiClient } from '../client'
|
||||
import type {
|
||||
ScheduledTestPlan,
|
||||
ScheduledTestResult,
|
||||
CreateScheduledTestPlanRequest,
|
||||
UpdateScheduledTestPlanRequest
|
||||
} from '@/types'
|
||||
|
||||
/**
 * Fetch the scheduled test plans of one account.
 *
 * @param accountId - Account ID
 * @returns The plans, or an empty array when the response carries no data
 */
export async function listByAccount(accountId: number): Promise<ScheduledTestPlan[]> {
  const url = `/admin/accounts/${accountId}/scheduled-test-plans`
  const response = await apiClient.get<ScheduledTestPlan[]>(url)
  return response.data ?? []
}
|
||||
|
||||
/**
 * Create a scheduled test plan.
 *
 * @param req - Plan creation request, sent as the request body
 * @returns The plan returned by the server
 */
export async function create(req: CreateScheduledTestPlanRequest): Promise<ScheduledTestPlan> {
  const response = await apiClient.post<ScheduledTestPlan>('/admin/scheduled-test-plans', req)
  return response.data
}
|
||||
|
||||
/**
 * Update a scheduled test plan.
 *
 * @param id - Plan ID
 * @param req - Fields to update, sent as the request body
 * @returns The plan returned by the server
 */
export async function update(
  id: number,
  req: UpdateScheduledTestPlanRequest
): Promise<ScheduledTestPlan> {
  const url = `/admin/scheduled-test-plans/${id}`
  const response = await apiClient.put<ScheduledTestPlan>(url, req)
  return response.data
}
|
||||
|
||||
/**
 * Delete a scheduled test plan.
 *
 * Named `deletePlan` because `delete` is a reserved word; it is exposed as
 * `scheduledTestsAPI.delete`.
 *
 * @param id - Plan ID
 */
export async function deletePlan(id: number): Promise<void> {
  const url = `/admin/scheduled-test-plans/${id}`
  await apiClient.delete(url)
}
|
||||
|
||||
/**
 * Fetch the test results of a plan.
 *
 * @param planId - Plan ID
 * @param limit - Optional maximum number of results; a falsy value
 *   (`undefined` or `0`) sends no `limit` query parameter at all
 * @returns The results, or an empty array when the response carries no data
 */
export async function listResults(planId: number, limit?: number): Promise<ScheduledTestResult[]> {
  const query = limit ? { limit } : undefined
  const response = await apiClient.get<ScheduledTestResult[]>(
    `/admin/scheduled-test-plans/${planId}/results`,
    { params: query }
  )
  return response.data ?? []
}
|
||||
|
||||
/**
 * Scheduled-test endpoints grouped into one object.
 * `delete` maps to `deletePlan`.
 */
export const scheduledTestsAPI = {
  listByAccount: listByAccount,
  create: create,
  update: update,
  delete: deletePlan,
  listResults: listResults
}
|
||||
|
||||
export default scheduledTestsAPI
|
||||
@@ -71,6 +71,24 @@
|
||||
<span class="text-[9px] opacity-60">{{ rpmStrategyTag }}</span>
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- API Key 账号配额限制 -->
|
||||
<div v-if="showQuotaLimit" class="flex items-center gap-1">
|
||||
<span
|
||||
:class="[
|
||||
'inline-flex items-center gap-1 rounded-md px-1.5 py-0.5 text-[10px] font-medium',
|
||||
quotaClass
|
||||
]"
|
||||
:title="quotaTooltip"
|
||||
>
|
||||
<svg class="h-2.5 w-2.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M2.25 18.75a60.07 60.07 0 0115.797 2.101c.727.198 1.453-.342 1.453-1.096V18.75M3.75 4.5v.75A.75.75 0 013 6h-.75m0 0v-.375c0-.621.504-1.125 1.125-1.125H20.25M2.25 6v9m18-10.5v.75c0 .414.336.75.75.75h.75m-1.5-1.5h.375c.621 0 1.125.504 1.125 1.125v9.75c0 .621-.504 1.125-1.125 1.125h-.375m1.5-1.5H21a.75.75 0 00-.75.75v.75m0 0H3.75m0 0h-.375a1.125 1.125 0 01-1.125-1.125V15m1.5 1.5v-.75A.75.75 0 003 15h-.75M15 10.5a3 3 0 11-6 0 3 3 0 016 0zm3 0h.008v.008H18V10.5zm-12 0h.008v.008H6V10.5z" />
|
||||
</svg>
|
||||
<span class="font-mono">${{ formatCost(currentQuotaUsed) }}</span>
|
||||
<span class="text-gray-400 dark:text-gray-500">/</span>
|
||||
<span class="font-mono">${{ formatCost(account.quota_limit) }}</span>
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
@@ -286,6 +304,48 @@ const rpmTooltip = computed(() => {
|
||||
}
|
||||
})
|
||||
|
||||
// Whether to show the quota badge: only for 'apikey' accounts that have a
// positive quota_limit set.
const showQuotaLimit = computed(() => {
  const { type, quota_limit: limit } = props.account
  return type === 'apikey' && limit !== undefined && limit !== null && limit > 0
})
|
||||
|
||||
// Quota used so far; a missing value counts as 0.
const currentQuotaUsed = computed(() => {
  const used = props.account.quota_used
  return used ?? 0
})
|
||||
|
||||
// Badge colour classes for the quota state: red once usage reaches the limit,
// yellow from 80% of the limit, green otherwise. Empty when the badge is hidden.
const quotaClass = computed(() => {
  if (!showQuotaLimit.value) return ''

  const spent = currentQuotaUsed.value
  const cap = props.account.quota_limit || 0

  return spent >= cap
    ? 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-400'
    : spent >= cap * 0.8
      ? 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900/30 dark:text-yellow-400'
      : 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400'
})
|
||||
|
||||
// Tooltip text for the quota badge: the "exceeded" message once usage reaches
// the limit, the "normal" message otherwise. Empty when the badge is hidden.
const quotaTooltip = computed(() => {
  if (!showQuotaLimit.value) return ''

  const spent = currentQuotaUsed.value
  const cap = props.account.quota_limit || 0
  const exceeded = spent >= cap

  return exceeded
    ? t('admin.accounts.capacity.quota.exceeded')
    : t('admin.accounts.capacity.quota.normal')
})
|
||||
|
||||
// 格式化费用显示
|
||||
const formatCost = (value: number | null | undefined) => {
|
||||
if (value === null || value === undefined) return '0'
|
||||
|
||||
@@ -469,7 +469,7 @@
|
||||
</div>
|
||||
|
||||
<!-- Concurrency & Priority -->
|
||||
<div class="grid grid-cols-2 gap-4 border-t border-gray-200 pt-4 dark:border-dark-600 lg:grid-cols-3">
|
||||
<div class="grid grid-cols-2 gap-4 border-t border-gray-200 pt-4 dark:border-dark-600 lg:grid-cols-4">
|
||||
<div>
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
<label
|
||||
@@ -496,8 +496,39 @@
|
||||
class="input"
|
||||
:class="!enableConcurrency && 'cursor-not-allowed opacity-50'"
|
||||
aria-labelledby="bulk-edit-concurrency-label"
|
||||
@input="concurrency = Math.max(1, concurrency || 1)"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
<label
|
||||
id="bulk-edit-load-factor-label"
|
||||
class="input-label mb-0"
|
||||
for="bulk-edit-load-factor-enabled"
|
||||
>
|
||||
{{ t('admin.accounts.loadFactor') }}
|
||||
</label>
|
||||
<input
|
||||
v-model="enableLoadFactor"
|
||||
id="bulk-edit-load-factor-enabled"
|
||||
type="checkbox"
|
||||
aria-controls="bulk-edit-load-factor"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
</div>
|
||||
<input
|
||||
v-model.number="loadFactor"
|
||||
id="bulk-edit-load-factor"
|
||||
type="number"
|
||||
min="1"
|
||||
:disabled="!enableLoadFactor"
|
||||
class="input"
|
||||
:class="!enableLoadFactor && 'cursor-not-allowed opacity-50'"
|
||||
aria-labelledby="bulk-edit-load-factor-label"
|
||||
@input="loadFactor = (loadFactor && loadFactor >= 1) ? loadFactor : null"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.loadFactorHint') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
<label
|
||||
@@ -869,6 +900,7 @@ const enableCustomErrorCodes = ref(false)
|
||||
const enableInterceptWarmup = ref(false)
|
||||
const enableProxy = ref(false)
|
||||
const enableConcurrency = ref(false)
|
||||
const enableLoadFactor = ref(false)
|
||||
const enablePriority = ref(false)
|
||||
const enableRateMultiplier = ref(false)
|
||||
const enableStatus = ref(false)
|
||||
@@ -889,6 +921,7 @@ const customErrorCodeInput = ref<number | null>(null)
|
||||
const interceptWarmupRequests = ref(false)
|
||||
const proxyId = ref<number | null>(null)
|
||||
const concurrency = ref(1)
|
||||
const loadFactor = ref<number | null>(null)
|
||||
const priority = ref(1)
|
||||
const rateMultiplier = ref(1)
|
||||
const status = ref<'active' | 'inactive'>('active')
|
||||
@@ -918,6 +951,7 @@ const allModels = [
|
||||
{ value: 'claude-3-haiku-20240307', label: 'Claude 3 Haiku' },
|
||||
{ value: 'gpt-5.3-codex', label: 'GPT-5.3 Codex' },
|
||||
{ value: 'gpt-5.3-codex-spark', label: 'GPT-5.3 Codex Spark' },
|
||||
{ value: 'gpt-5.4', label: 'GPT-5.4' },
|
||||
{ value: 'gpt-5.2-2025-12-11', label: 'GPT-5.2' },
|
||||
{ value: 'gpt-5.2-codex', label: 'GPT-5.2 Codex' },
|
||||
{ value: 'gpt-5.1-codex-max', label: 'GPT-5.1 Codex Max' },
|
||||
@@ -1032,6 +1066,12 @@ const presetMappings = [
|
||||
to: 'gpt-5.3-codex-spark',
|
||||
color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400'
|
||||
},
|
||||
{
|
||||
label: 'GPT-5.4',
|
||||
from: 'gpt-5.4',
|
||||
to: 'gpt-5.4',
|
||||
color: 'bg-rose-100 text-rose-700 hover:bg-rose-200 dark:bg-rose-900/30 dark:text-rose-400'
|
||||
},
|
||||
{
|
||||
label: '5.2→5.3',
|
||||
from: 'gpt-5.2-codex',
|
||||
@@ -1195,6 +1235,12 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
|
||||
updates.concurrency = concurrency.value
|
||||
}
|
||||
|
||||
if (enableLoadFactor.value) {
|
||||
// 空值/NaN/0 时发送 0(后端约定 <= 0 表示清除)
|
||||
const lf = loadFactor.value
|
||||
updates.load_factor = (lf != null && !Number.isNaN(lf) && lf > 0) ? lf : 0
|
||||
}
|
||||
|
||||
if (enablePriority.value) {
|
||||
updates.priority = priority.value
|
||||
}
|
||||
@@ -1340,6 +1386,7 @@ const handleSubmit = async () => {
|
||||
enableInterceptWarmup.value ||
|
||||
enableProxy.value ||
|
||||
enableConcurrency.value ||
|
||||
enableLoadFactor.value ||
|
||||
enablePriority.value ||
|
||||
enableRateMultiplier.value ||
|
||||
enableStatus.value ||
|
||||
@@ -1430,6 +1477,7 @@ watch(
|
||||
enableInterceptWarmup.value = false
|
||||
enableProxy.value = false
|
||||
enableConcurrency.value = false
|
||||
enableLoadFactor.value = false
|
||||
enablePriority.value = false
|
||||
enableRateMultiplier.value = false
|
||||
enableStatus.value = false
|
||||
@@ -1446,6 +1494,7 @@ watch(
|
||||
interceptWarmupRequests.value = false
|
||||
proxyId.value = null
|
||||
concurrency.value = 1
|
||||
loadFactor.value = null
|
||||
priority.value = 1
|
||||
rateMultiplier.value = 1
|
||||
status.value = 'active'
|
||||
|
||||
@@ -1227,6 +1227,9 @@
|
||||
|
||||
</div>
|
||||
|
||||
<!-- API Key 账号配额限制 -->
|
||||
<QuotaLimitCard v-if="form.type === 'apikey'" v-model="editQuotaLimit" />
|
||||
|
||||
<!-- Temp Unschedulable Rules -->
|
||||
<div class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4">
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
@@ -1749,10 +1752,18 @@
|
||||
<ProxySelector v-model="form.proxy_id" :proxies="proxies" />
|
||||
</div>
|
||||
|
||||
<div class="grid grid-cols-2 gap-4 lg:grid-cols-3">
|
||||
<div class="grid grid-cols-2 gap-4 lg:grid-cols-4">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.concurrency') }}</label>
|
||||
<input v-model.number="form.concurrency" type="number" min="1" class="input" />
|
||||
<input v-model.number="form.concurrency" type="number" min="1" class="input"
|
||||
@input="form.concurrency = Math.max(1, form.concurrency || 1)" />
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.loadFactor') }}</label>
|
||||
<input v-model.number="form.load_factor" type="number" min="1"
|
||||
class="input" :placeholder="String(form.concurrency || 1)"
|
||||
@input="form.load_factor = (form.load_factor && form.load_factor >= 1) ? form.load_factor : null" />
|
||||
<p class="input-hint">{{ t('admin.accounts.loadFactorHint') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.priority') }}</label>
|
||||
@@ -1807,7 +1818,7 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- OpenAI WS Mode 三态(off/shared/dedicated) -->
|
||||
<!-- OpenAI WS Mode 三态(off/ctx_pool/passthrough) -->
|
||||
<div
|
||||
v-if="form.platform === 'openai' && (accountCategory === 'oauth-based' || accountCategory === 'apikey')"
|
||||
class="border-t border-gray-200 pt-4 dark:border-dark-600"
|
||||
@@ -1819,7 +1830,7 @@
|
||||
{{ t('admin.accounts.openai.wsModeDesc') }}
|
||||
</p>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.openai.wsModeConcurrencyHint') }}
|
||||
{{ t(openAIWSModeConcurrencyHintKey) }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="w-52">
|
||||
@@ -2337,14 +2348,16 @@ import Icon from '@/components/icons/Icon.vue'
|
||||
import ProxySelector from '@/components/common/ProxySelector.vue'
|
||||
import GroupSelector from '@/components/common/GroupSelector.vue'
|
||||
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
|
||||
import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue'
|
||||
import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
|
||||
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
|
||||
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||
import {
|
||||
OPENAI_WS_MODE_DEDICATED,
|
||||
// OPENAI_WS_MODE_CTX_POOL,
|
||||
OPENAI_WS_MODE_OFF,
|
||||
OPENAI_WS_MODE_SHARED,
|
||||
OPENAI_WS_MODE_PASSTHROUGH,
|
||||
isOpenAIWSModeEnabled,
|
||||
resolveOpenAIWSModeConcurrencyHintKey,
|
||||
type OpenAIWSMode
|
||||
} from '@/utils/openaiWsMode'
|
||||
import OAuthAuthorizationFlow from './OAuthAuthorizationFlow.vue'
|
||||
@@ -2459,6 +2472,7 @@ const accountCategory = ref<'oauth-based' | 'apikey'>('oauth-based') // UI selec
|
||||
const addMethod = ref<AddMethod>('oauth') // For oauth-based: 'oauth' or 'setup-token'
|
||||
const apiKeyBaseUrl = ref('https://api.anthropic.com')
|
||||
const apiKeyValue = ref('')
|
||||
const editQuotaLimit = ref<number | null>(null)
|
||||
const modelMappings = ref<ModelMapping[]>([])
|
||||
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
|
||||
const allowedModels = ref<string[]>([])
|
||||
@@ -2541,8 +2555,9 @@ const geminiSelectedTier = computed(() => {
|
||||
|
||||
const openAIWSModeOptions = computed(() => [
|
||||
{ value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') },
|
||||
{ value: OPENAI_WS_MODE_SHARED, label: t('admin.accounts.openai.wsModeShared') },
|
||||
{ value: OPENAI_WS_MODE_DEDICATED, label: t('admin.accounts.openai.wsModeDedicated') }
|
||||
// TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复
|
||||
// { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') },
|
||||
{ value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') }
|
||||
])
|
||||
|
||||
const openaiResponsesWebSocketV2Mode = computed({
|
||||
@@ -2561,6 +2576,10 @@ const openaiResponsesWebSocketV2Mode = computed({
|
||||
}
|
||||
})
|
||||
|
||||
// Translation key for the concurrency hint shown under the WS mode description.
// Recomputed from the currently selected WS mode; the template passes the
// result to t().
const openAIWSModeConcurrencyHintKey = computed(() =>
|
||||
resolveOpenAIWSModeConcurrencyHintKey(openaiResponsesWebSocketV2Mode.value)
|
||||
)
|
||||
|
||||
const isOpenAIModelRestrictionDisabled = computed(() =>
|
||||
form.platform === 'openai' && openaiPassthroughEnabled.value
|
||||
)
|
||||
@@ -2627,6 +2646,7 @@ const form = reactive({
|
||||
credentials: {} as Record<string, unknown>,
|
||||
proxy_id: null as number | null,
|
||||
concurrency: 10,
|
||||
load_factor: null as number | null,
|
||||
priority: 1,
|
||||
rate_multiplier: 1,
|
||||
group_ids: [] as number[],
|
||||
@@ -3106,6 +3126,7 @@ const resetForm = () => {
|
||||
form.credentials = {}
|
||||
form.proxy_id = null
|
||||
form.concurrency = 10
|
||||
form.load_factor = null
|
||||
form.priority = 1
|
||||
form.rate_multiplier = 1
|
||||
form.group_ids = []
|
||||
@@ -3114,6 +3135,7 @@ const resetForm = () => {
|
||||
addMethod.value = 'oauth'
|
||||
apiKeyBaseUrl.value = 'https://api.anthropic.com'
|
||||
apiKeyValue.value = ''
|
||||
editQuotaLimit.value = null
|
||||
modelMappings.value = []
|
||||
modelRestrictionMode.value = 'whitelist'
|
||||
allowedModels.value = [...claudeModels] // Default fill related models
|
||||
@@ -3180,10 +3202,13 @@ const buildOpenAIExtra = (base?: Record<string, unknown>): Record<string, unknow
|
||||
}
|
||||
|
||||
const extra: Record<string, unknown> = { ...(base || {}) }
|
||||
extra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value
|
||||
extra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value
|
||||
extra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value)
|
||||
extra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value)
|
||||
if (accountCategory.value === 'oauth-based') {
|
||||
extra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value
|
||||
extra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value)
|
||||
} else if (accountCategory.value === 'apikey') {
|
||||
extra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value
|
||||
extra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value)
|
||||
}
|
||||
// 清理兼容旧键,统一改用分类型开关。
|
||||
delete extra.responses_websockets_v2_enabled
|
||||
delete extra.openai_ws_enabled
|
||||
@@ -3474,6 +3499,7 @@ const handleImportAccessToken = async (accessTokenInput: string) => {
|
||||
extra: soraExtra,
|
||||
proxy_id: form.proxy_id,
|
||||
concurrency: form.concurrency,
|
||||
load_factor: form.load_factor ?? undefined,
|
||||
priority: form.priority,
|
||||
rate_multiplier: form.rate_multiplier,
|
||||
group_ids: form.group_ids,
|
||||
@@ -3524,15 +3550,21 @@ const createAccountAndFinish = async (
|
||||
if (!applyTempUnschedConfig(credentials)) {
|
||||
return
|
||||
}
|
||||
// Inject quota_limit for apikey accounts
|
||||
let finalExtra = extra
|
||||
if (type === 'apikey' && editQuotaLimit.value != null && editQuotaLimit.value > 0) {
|
||||
finalExtra = { ...(extra || {}), quota_limit: editQuotaLimit.value }
|
||||
}
|
||||
await doCreateAccount({
|
||||
name: form.name,
|
||||
notes: form.notes,
|
||||
platform,
|
||||
type,
|
||||
credentials,
|
||||
extra,
|
||||
extra: finalExtra,
|
||||
proxy_id: form.proxy_id,
|
||||
concurrency: form.concurrency,
|
||||
load_factor: form.load_factor ?? undefined,
|
||||
priority: form.priority,
|
||||
rate_multiplier: form.rate_multiplier,
|
||||
group_ids: form.group_ids,
|
||||
@@ -3588,6 +3620,7 @@ const handleOpenAIExchange = async (authCode: string) => {
|
||||
extra,
|
||||
proxy_id: form.proxy_id,
|
||||
concurrency: form.concurrency,
|
||||
load_factor: form.load_factor ?? undefined,
|
||||
priority: form.priority,
|
||||
rate_multiplier: form.rate_multiplier,
|
||||
group_ids: form.group_ids,
|
||||
@@ -3617,6 +3650,7 @@ const handleOpenAIExchange = async (authCode: string) => {
|
||||
extra: soraExtra,
|
||||
proxy_id: form.proxy_id,
|
||||
concurrency: form.concurrency,
|
||||
load_factor: form.load_factor ?? undefined,
|
||||
priority: form.priority,
|
||||
rate_multiplier: form.rate_multiplier,
|
||||
group_ids: form.group_ids,
|
||||
@@ -3694,6 +3728,7 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
|
||||
extra,
|
||||
proxy_id: form.proxy_id,
|
||||
concurrency: form.concurrency,
|
||||
load_factor: form.load_factor ?? undefined,
|
||||
priority: form.priority,
|
||||
rate_multiplier: form.rate_multiplier,
|
||||
group_ids: form.group_ids,
|
||||
@@ -3721,6 +3756,7 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
|
||||
extra: soraExtra,
|
||||
proxy_id: form.proxy_id,
|
||||
concurrency: form.concurrency,
|
||||
load_factor: form.load_factor ?? undefined,
|
||||
priority: form.priority,
|
||||
rate_multiplier: form.rate_multiplier,
|
||||
group_ids: form.group_ids,
|
||||
@@ -3809,6 +3845,7 @@ const handleSoraValidateST = async (sessionTokenInput: string) => {
|
||||
extra: soraExtra,
|
||||
proxy_id: form.proxy_id,
|
||||
concurrency: form.concurrency,
|
||||
load_factor: form.load_factor ?? undefined,
|
||||
priority: form.priority,
|
||||
rate_multiplier: form.rate_multiplier,
|
||||
group_ids: form.group_ids,
|
||||
@@ -3897,6 +3934,7 @@ const handleAntigravityValidateRT = async (refreshTokenInput: string) => {
|
||||
extra: {},
|
||||
proxy_id: form.proxy_id,
|
||||
concurrency: form.concurrency,
|
||||
load_factor: form.load_factor ?? undefined,
|
||||
priority: form.priority,
|
||||
rate_multiplier: form.rate_multiplier,
|
||||
group_ids: form.group_ids,
|
||||
@@ -4055,8 +4093,11 @@ const handleAnthropicExchange = async (authCode: string) => {
|
||||
}
|
||||
|
||||
// Add RPM limit settings
|
||||
if (rpmLimitEnabled.value && baseRpm.value != null && baseRpm.value > 0) {
|
||||
extra.base_rpm = baseRpm.value
|
||||
if (rpmLimitEnabled.value) {
|
||||
const DEFAULT_BASE_RPM = 15
|
||||
extra.base_rpm = (baseRpm.value != null && baseRpm.value > 0)
|
||||
? baseRpm.value
|
||||
: DEFAULT_BASE_RPM
|
||||
extra.rpm_strategy = rpmStrategy.value
|
||||
if (rpmStickyBuffer.value != null && rpmStickyBuffer.value > 0) {
|
||||
extra.rpm_sticky_buffer = rpmStickyBuffer.value
|
||||
@@ -4167,8 +4208,11 @@ const handleCookieAuth = async (sessionKey: string) => {
|
||||
}
|
||||
|
||||
// Add RPM limit settings
|
||||
if (rpmLimitEnabled.value && baseRpm.value != null && baseRpm.value > 0) {
|
||||
extra.base_rpm = baseRpm.value
|
||||
if (rpmLimitEnabled.value) {
|
||||
const DEFAULT_BASE_RPM = 15
|
||||
extra.base_rpm = (baseRpm.value != null && baseRpm.value > 0)
|
||||
? baseRpm.value
|
||||
: DEFAULT_BASE_RPM
|
||||
extra.rpm_strategy = rpmStrategy.value
|
||||
if (rpmStickyBuffer.value != null && rpmStickyBuffer.value > 0) {
|
||||
extra.rpm_sticky_buffer = rpmStickyBuffer.value
|
||||
@@ -4214,6 +4258,7 @@ const handleCookieAuth = async (sessionKey: string) => {
|
||||
extra,
|
||||
proxy_id: form.proxy_id,
|
||||
concurrency: form.concurrency,
|
||||
load_factor: form.load_factor ?? undefined,
|
||||
priority: form.priority,
|
||||
rate_multiplier: form.rate_multiplier,
|
||||
group_ids: form.group_ids,
|
||||
|
||||
@@ -650,10 +650,18 @@
|
||||
<ProxySelector v-model="form.proxy_id" :proxies="proxies" />
|
||||
</div>
|
||||
|
||||
<div class="grid grid-cols-2 gap-4 lg:grid-cols-3">
|
||||
<div class="grid grid-cols-2 gap-4 lg:grid-cols-4">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.concurrency') }}</label>
|
||||
<input v-model.number="form.concurrency" type="number" min="1" class="input" />
|
||||
<input v-model.number="form.concurrency" type="number" min="1" class="input"
|
||||
@input="form.concurrency = Math.max(1, form.concurrency || 1)" />
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.loadFactor') }}</label>
|
||||
<input v-model.number="form.load_factor" type="number" min="1"
|
||||
class="input" :placeholder="String(form.concurrency || 1)"
|
||||
@input="form.load_factor = (form.load_factor && form.load_factor >= 1) ? form.load_factor : null" />
|
||||
<p class="input-hint">{{ t('admin.accounts.loadFactorHint') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.priority') }}</label>
|
||||
@@ -708,7 +716,7 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- OpenAI WS Mode 三态(off/shared/dedicated) -->
|
||||
<!-- OpenAI WS Mode 三态(off/ctx_pool/passthrough) -->
|
||||
<div
|
||||
v-if="account?.platform === 'openai' && (account?.type === 'oauth' || account?.type === 'apikey')"
|
||||
class="border-t border-gray-200 pt-4 dark:border-dark-600"
|
||||
@@ -720,7 +728,7 @@
|
||||
{{ t('admin.accounts.openai.wsModeDesc') }}
|
||||
</p>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.openai.wsModeConcurrencyHint') }}
|
||||
{{ t(openAIWSModeConcurrencyHintKey) }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="w-52">
|
||||
@@ -759,6 +767,9 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- API Key 账号配额限制 -->
|
||||
<QuotaLimitCard v-if="account?.type === 'apikey'" v-model="editQuotaLimit" />
|
||||
|
||||
<!-- OpenAI OAuth Codex 官方客户端限制开关 -->
|
||||
<div
|
||||
v-if="account?.platform === 'openai' && account?.type === 'oauth'"
|
||||
@@ -1269,14 +1280,16 @@ import Icon from '@/components/icons/Icon.vue'
|
||||
import ProxySelector from '@/components/common/ProxySelector.vue'
|
||||
import GroupSelector from '@/components/common/GroupSelector.vue'
|
||||
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
|
||||
import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue'
|
||||
import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
|
||||
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
|
||||
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||
import {
|
||||
OPENAI_WS_MODE_DEDICATED,
|
||||
// OPENAI_WS_MODE_CTX_POOL,
|
||||
OPENAI_WS_MODE_OFF,
|
||||
OPENAI_WS_MODE_SHARED,
|
||||
OPENAI_WS_MODE_PASSTHROUGH,
|
||||
isOpenAIWSModeEnabled,
|
||||
resolveOpenAIWSModeConcurrencyHintKey,
|
||||
type OpenAIWSMode,
|
||||
resolveOpenAIWSModeFromExtra
|
||||
} from '@/utils/openaiWsMode'
|
||||
@@ -1385,10 +1398,12 @@ const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF
|
||||
const openaiAPIKeyResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
||||
const codexCLIOnlyEnabled = ref(false)
|
||||
const anthropicPassthroughEnabled = ref(false)
|
||||
const editQuotaLimit = ref<number | null>(null)
|
||||
const openAIWSModeOptions = computed(() => [
|
||||
{ value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') },
|
||||
{ value: OPENAI_WS_MODE_SHARED, label: t('admin.accounts.openai.wsModeShared') },
|
||||
{ value: OPENAI_WS_MODE_DEDICATED, label: t('admin.accounts.openai.wsModeDedicated') }
|
||||
// TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复
|
||||
// { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') },
|
||||
{ value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') }
|
||||
])
|
||||
const openaiResponsesWebSocketV2Mode = computed({
|
||||
get: () => {
|
||||
@@ -1405,6 +1420,9 @@ const openaiResponsesWebSocketV2Mode = computed({
|
||||
openaiOAuthResponsesWebSocketV2Mode.value = mode
|
||||
}
|
||||
})
|
||||
// Translation key for the concurrency hint shown under the WS mode description.
// Recomputed from the currently selected WS mode; the template passes the
// result to t().
const openAIWSModeConcurrencyHintKey = computed(() =>
|
||||
resolveOpenAIWSModeConcurrencyHintKey(openaiResponsesWebSocketV2Mode.value)
|
||||
)
|
||||
const isOpenAIModelRestrictionDisabled = computed(() =>
|
||||
props.account?.platform === 'openai' && openaiPassthroughEnabled.value
|
||||
)
|
||||
@@ -1460,6 +1478,7 @@ const form = reactive({
|
||||
notes: '',
|
||||
proxy_id: null as number | null,
|
||||
concurrency: 1,
|
||||
load_factor: null as number | null,
|
||||
priority: 1,
|
||||
rate_multiplier: 1,
|
||||
status: 'active' as 'active' | 'inactive',
|
||||
@@ -1493,9 +1512,12 @@ watch(
|
||||
form.notes = newAccount.notes || ''
|
||||
form.proxy_id = newAccount.proxy_id
|
||||
form.concurrency = newAccount.concurrency
|
||||
form.load_factor = newAccount.load_factor ?? null
|
||||
form.priority = newAccount.priority
|
||||
form.rate_multiplier = newAccount.rate_multiplier ?? 1
|
||||
form.status = newAccount.status as 'active' | 'inactive'
|
||||
form.status = (newAccount.status === 'active' || newAccount.status === 'inactive')
|
||||
? newAccount.status
|
||||
: 'active'
|
||||
form.group_ids = newAccount.group_ids || []
|
||||
form.expires_at = newAccount.expires_at ?? null
|
||||
|
||||
@@ -1536,6 +1558,14 @@ watch(
|
||||
anthropicPassthroughEnabled.value = extra?.anthropic_passthrough === true
|
||||
}
|
||||
|
||||
// Load quota limit for apikey accounts
|
||||
if (newAccount.type === 'apikey') {
|
||||
const quotaVal = extra?.quota_limit as number | undefined
|
||||
editQuotaLimit.value = (quotaVal && quotaVal > 0) ? quotaVal : null
|
||||
} else {
|
||||
editQuotaLimit.value = null
|
||||
}
|
||||
|
||||
// Load antigravity model mapping (Antigravity 只支持映射模式)
|
||||
if (newAccount.platform === 'antigravity') {
|
||||
const credentials = newAccount.credentials as Record<string, unknown> | undefined
|
||||
@@ -2035,6 +2065,11 @@ const handleSubmit = async () => {
|
||||
if (!props.account) return
|
||||
const accountID = props.account.id
|
||||
|
||||
if (form.status !== 'active' && form.status !== 'inactive') {
|
||||
appStore.showError(t('admin.accounts.pleaseSelectStatus'))
|
||||
return
|
||||
}
|
||||
|
||||
const updatePayload: Record<string, unknown> = { ...form }
|
||||
try {
|
||||
// 后端期望 proxy_id: 0 表示清除代理,而不是 null
|
||||
@@ -2044,6 +2079,11 @@ const handleSubmit = async () => {
|
||||
if (form.expires_at === null) {
|
||||
updatePayload.expires_at = 0
|
||||
}
|
||||
// load_factor: 空值/NaN/0/负数 时发送 0(后端约定 <= 0 = 清除)
|
||||
const lf = form.load_factor
|
||||
if (lf == null || Number.isNaN(lf) || lf <= 0) {
|
||||
updatePayload.load_factor = 0
|
||||
}
|
||||
updatePayload.auto_pause_on_expired = autoPauseOnExpired.value
|
||||
|
||||
// For apikey type, handle credentials update
|
||||
@@ -2183,8 +2223,11 @@ const handleSubmit = async () => {
|
||||
}
|
||||
|
||||
// RPM limit settings
|
||||
if (rpmLimitEnabled.value && baseRpm.value != null && baseRpm.value > 0) {
|
||||
newExtra.base_rpm = baseRpm.value
|
||||
if (rpmLimitEnabled.value) {
|
||||
const DEFAULT_BASE_RPM = 15
|
||||
newExtra.base_rpm = (baseRpm.value != null && baseRpm.value > 0)
|
||||
? baseRpm.value
|
||||
: DEFAULT_BASE_RPM
|
||||
newExtra.rpm_strategy = rpmStrategy.value
|
||||
if (rpmStickyBuffer.value != null && rpmStickyBuffer.value > 0) {
|
||||
newExtra.rpm_sticky_buffer = rpmStickyBuffer.value
|
||||
@@ -2248,10 +2291,13 @@ const handleSubmit = async () => {
|
||||
const currentExtra = (props.account.extra as Record<string, unknown>) || {}
|
||||
const newExtra: Record<string, unknown> = { ...currentExtra }
|
||||
const hadCodexCLIOnlyEnabled = currentExtra.codex_cli_only === true
|
||||
newExtra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value
|
||||
newExtra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value
|
||||
newExtra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value)
|
||||
newExtra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value)
|
||||
if (props.account.type === 'oauth') {
|
||||
newExtra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value
|
||||
newExtra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value)
|
||||
} else if (props.account.type === 'apikey') {
|
||||
newExtra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value
|
||||
newExtra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value)
|
||||
}
|
||||
delete newExtra.responses_websockets_v2_enabled
|
||||
delete newExtra.openai_ws_enabled
|
||||
if (openaiPassthroughEnabled.value) {
|
||||
@@ -2275,6 +2321,19 @@ const handleSubmit = async () => {
|
||||
updatePayload.extra = newExtra
|
||||
}
|
||||
|
||||
// For apikey accounts, handle quota_limit in extra
|
||||
if (props.account.type === 'apikey') {
|
||||
const currentExtra = (updatePayload.extra as Record<string, unknown>) ||
|
||||
(props.account.extra as Record<string, unknown>) || {}
|
||||
const newExtra: Record<string, unknown> = { ...currentExtra }
|
||||
if (editQuotaLimit.value != null && editQuotaLimit.value > 0) {
|
||||
newExtra.quota_limit = editQuotaLimit.value
|
||||
} else {
|
||||
delete newExtra.quota_limit
|
||||
}
|
||||
updatePayload.extra = newExtra
|
||||
}
|
||||
|
||||
const canContinue = await ensureAntigravityMixedChannelConfirmed(async () => {
|
||||
await submitUpdateAccount(accountID, updatePayload)
|
||||
})
|
||||
|
||||
92
frontend/src/components/account/QuotaLimitCard.vue
Normal file
92
frontend/src/components/account/QuotaLimitCard.vue
Normal file
@@ -0,0 +1,92 @@
|
||||
<script setup lang="ts">
/**
 * QuotaLimitCard — toggle + amount input for an account quota limit.
 *
 * v-model contract:
 *   - modelValue is a positive number when a limit is set, otherwise null.
 *   - Turning the toggle off emits null.
 *   - The amount input emits the typed number when it is finite and > 0,
 *     and null for anything else (empty, invalid, zero, negative).
 */
import { ref, watch } from 'vue'
import { useI18n } from 'vue-i18n'

const { t } = useI18n()

const props = defineProps<{
  modelValue: number | null
}>()

const emit = defineEmits<{
  'update:modelValue': [value: number | null]
}>()

/** True when the value represents an actual limit (non-null and > 0). */
const isPositive = (v: number | null | undefined): v is number => v != null && v > 0

// Toggle state; starts on only when a limit is already configured.
const enabled = ref(isPositive(props.modelValue))

// Value most recently emitted by the amount input and not yet seen coming
// back through the prop. Lets the prop watcher distinguish "the parent echoed
// what the user just typed" from "the parent loaded a different value".
const NO_PENDING_INPUT = Symbol('no-pending-input')
let pendingInput: number | null | typeof NO_PENDING_INPUT = NO_PENDING_INPUT

// Sync the toggle when modelValue changes externally (e.g. account load).
// Changes that merely echo the user's own input are ignored; otherwise
// clearing the field (or typing a leading "0") would flip the toggle off and
// hide the input while the user is still editing it.
watch(
  () => props.modelValue,
  (val) => {
    const echoedFromInput = pendingInput !== NO_PENDING_INPUT && pendingInput === val
    pendingInput = NO_PENDING_INPUT
    if (echoedFromInput) return
    enabled.value = isPositive(val)
  }
)

// When the toggle is turned off, clear the value.
watch(enabled, (val) => {
  if (!val) {
    emit('update:modelValue', null)
  }
})

/**
 * Handles typing in the amount field.
 * valueAsNumber is NaN for an empty or unparsable field; NaN, zero and
 * negative numbers are all reported as null because only values > 0 count
 * as a limit.
 */
const onInput = (e: Event) => {
  const raw = (e.target as HTMLInputElement).valueAsNumber
  const next = Number.isFinite(raw) && raw > 0 ? raw : null
  pendingInput = next
  emit('update:modelValue', next)
}
</script>
|
||||
|
||||
<template>
  <!-- Quota limit section: heading + hint, then a bordered card holding the
       on/off switch and (when on) the amount input. -->
  <div class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4">
    <div class="mb-3">
      <h3 class="input-label mb-0 text-base font-semibold">{{ t('admin.accounts.quotaLimit') }}</h3>
      <p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
        {{ t('admin.accounts.quotaLimitHint') }}
      </p>
    </div>

    <div class="rounded-lg border border-gray-200 p-4 dark:border-dark-600">
      <div class="mb-3 flex items-center justify-between">
        <div>
          <label class="input-label mb-0">{{ t('admin.accounts.quotaLimitToggle') }}</label>
          <p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
            {{ t('admin.accounts.quotaLimitToggleHint') }}
          </p>
        </div>
        <!-- Exposed as an ARIA switch so assistive technology announces both
             its name and its on/off state. aria-label is used instead of an
             id reference so several cards can coexist on one page. -->
        <button
          type="button"
          role="switch"
          :aria-checked="enabled"
          :aria-label="t('admin.accounts.quotaLimitToggle')"
          @click="enabled = !enabled"
          :class="[
            'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
            enabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
          ]"
        >
          <span
            :class="[
              'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
              enabled ? 'translate-x-5' : 'translate-x-0'
            ]"
          />
        </button>
      </div>

      <!-- Amount input is only rendered while the switch is on. -->
      <div v-if="enabled" class="space-y-3">
        <div>
          <label class="input-label">{{ t('admin.accounts.quotaLimitAmount') }}</label>
          <div class="relative">
            <span
              class="absolute left-3 top-1/2 -translate-y-1/2 text-gray-500 dark:text-gray-400"
              aria-hidden="true"
            >$</span>
            <input
              :value="modelValue"
              @input="onInput"
              type="number"
              min="0"
              step="0.01"
              class="input pl-7"
              :aria-label="t('admin.accounts.quotaLimitAmount')"
              :placeholder="t('admin.accounts.quotaLimitPlaceholder')"
            />
          </div>
          <p class="input-hint">{{ t('admin.accounts.quotaLimitAmountHint') }}</p>
        </div>
      </div>
    </div>
  </div>
</template>
|
||||
@@ -18,6 +18,10 @@
|
||||
<Icon name="chart" size="sm" class="text-indigo-500" />
|
||||
{{ t('admin.accounts.viewStats') }}
|
||||
</button>
|
||||
<button @click="$emit('schedule', account); $emit('close')" class="flex w-full items-center gap-2 px-4 py-2 text-sm hover:bg-gray-100 dark:hover:bg-dark-700">
|
||||
<Icon name="clock" size="sm" class="text-orange-500" />
|
||||
{{ t('admin.scheduledTests.schedule') }}
|
||||
</button>
|
||||
<template v-if="account.type === 'oauth' || account.type === 'setup-token'">
|
||||
<button @click="$emit('reauth', account); $emit('close')" class="flex w-full items-center gap-2 px-4 py-2 text-sm text-blue-600 hover:bg-gray-100 dark:hover:bg-dark-700">
|
||||
<Icon name="link" size="sm" />
|
||||
@@ -37,6 +41,10 @@
|
||||
<Icon name="clock" size="sm" />
|
||||
{{ t('admin.accounts.clearRateLimit') }}
|
||||
</button>
|
||||
<button v-if="hasQuotaLimit" @click="$emit('reset-quota', account); $emit('close')" class="flex w-full items-center gap-2 px-4 py-2 text-sm text-teal-600 hover:bg-gray-100 dark:hover:bg-dark-700">
|
||||
<Icon name="refresh" size="sm" />
|
||||
{{ t('admin.accounts.resetQuota') }}
|
||||
</button>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
@@ -51,7 +59,7 @@ import { Icon } from '@/components/icons'
|
||||
import type { Account } from '@/types'
|
||||
|
||||
const props = defineProps<{ show: boolean; account: Account | null; position: { top: number; left: number } | null }>()
|
||||
const emit = defineEmits(['close', 'test', 'stats', 'reauth', 'refresh-token', 'reset-status', 'clear-rate-limit'])
|
||||
const emit = defineEmits(['close', 'test', 'stats', 'schedule', 'reauth', 'refresh-token', 'reset-status', 'clear-rate-limit', 'reset-quota'])
|
||||
const { t } = useI18n()
|
||||
const isRateLimited = computed(() => {
|
||||
if (props.account?.rate_limit_reset_at && new Date(props.account.rate_limit_reset_at) > new Date()) {
|
||||
@@ -67,6 +75,12 @@ const isRateLimited = computed(() => {
|
||||
return false
|
||||
})
|
||||
const isOverloaded = computed(() => props.account?.overload_until && new Date(props.account.overload_until) > new Date())
|
||||
// Gates the "reset quota" menu entry: true only for an API-key account whose
// quota_limit is set to a value greater than zero.
const hasQuotaLimit = computed(() => {
  const acct = props.account
  if (acct?.type !== 'apikey') return false
  // A missing (undefined/null) limit is treated as 0, i.e. "no limit".
  return (acct?.quota_limit ?? 0) > 0
})
|
||||
|
||||
const handleKeydown = (event: KeyboardEvent) => {
|
||||
if (event.key === 'Escape') emit('close')
|
||||
|
||||
@@ -25,6 +25,6 @@ const updateStatus = (value: string | number | boolean | null) => { emit('update
|
||||
const updateGroup = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, group: value }) }
|
||||
const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }, { value: 'sora', label: 'Sora' }])
|
||||
const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }])
|
||||
const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }])
|
||||
const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }, { value: 'temp_unschedulable', label: t('admin.accounts.status.tempUnschedulable') }])
|
||||
const gOpts = computed(() => [{ value: '', label: t('admin.accounts.allGroups') }, ...(props.groups || []).map(g => ({ value: String(g.id), label: g.name }))])
|
||||
</script>
|
||||
|
||||
587
frontend/src/components/admin/account/ScheduledTestsPanel.vue
Normal file
587
frontend/src/components/admin/account/ScheduledTestsPanel.vue
Normal file
@@ -0,0 +1,587 @@
|
||||
<template>
|
||||
<BaseDialog
|
||||
:show="show"
|
||||
:title="t('admin.scheduledTests.title')"
|
||||
width="wide"
|
||||
@close="emit('close')"
|
||||
>
|
||||
<div class="space-y-4">
|
||||
<!-- Add Plan Button -->
|
||||
<div class="flex items-center justify-between">
|
||||
<p class="text-sm text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.scheduledTests.title') }}
|
||||
</p>
|
||||
<button
|
||||
@click="showAddForm = !showAddForm"
|
||||
class="btn btn-primary flex items-center gap-1.5 text-sm"
|
||||
>
|
||||
<Icon name="plus" size="sm" :stroke-width="2" />
|
||||
{{ t('admin.scheduledTests.addPlan') }}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Add Plan Form -->
|
||||
<div
|
||||
v-if="showAddForm"
|
||||
class="rounded-xl border border-primary-200 bg-primary-50/50 p-4 dark:border-primary-800 dark:bg-primary-900/20"
|
||||
>
|
||||
<div class="mb-3 text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||
{{ t('admin.scheduledTests.addPlan') }}
|
||||
</div>
|
||||
<div class="grid grid-cols-1 gap-3 sm:grid-cols-2">
|
||||
<div>
|
||||
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
|
||||
{{ t('admin.scheduledTests.model') }}
|
||||
</label>
|
||||
<Select
|
||||
v-model="newPlan.model_id"
|
||||
:options="modelOptions"
|
||||
:placeholder="t('admin.scheduledTests.model')"
|
||||
:searchable="modelOptions.length > 5"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
|
||||
{{ t('admin.scheduledTests.cronExpression') }}
|
||||
</label>
|
||||
<Input
|
||||
v-model="newPlan.cron_expression"
|
||||
:placeholder="'*/30 * * * *'"
|
||||
:hint="t('admin.scheduledTests.cronHelp')"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
|
||||
{{ t('admin.scheduledTests.maxResults') }}
|
||||
</label>
|
||||
<Input
|
||||
v-model="newPlan.max_results"
|
||||
type="number"
|
||||
placeholder="100"
|
||||
/>
|
||||
</div>
|
||||
<div class="flex items-end">
|
||||
<label class="flex items-center gap-2 text-sm text-gray-700 dark:text-gray-300">
|
||||
<Toggle v-model="newPlan.enabled" />
|
||||
{{ t('admin.scheduledTests.enabled') }}
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
<div class="mt-3 flex justify-end gap-2">
|
||||
<button
|
||||
@click="showAddForm = false; resetNewPlan()"
|
||||
class="rounded-lg bg-gray-100 px-3 py-1.5 text-sm font-medium text-gray-700 transition-colors hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-300 dark:hover:bg-dark-500"
|
||||
>
|
||||
{{ t('common.cancel') }}
|
||||
</button>
|
||||
<button
|
||||
@click="handleCreate"
|
||||
:disabled="!newPlan.model_id || !newPlan.cron_expression || creating"
|
||||
class="flex items-center gap-1.5 rounded-lg bg-primary-500 px-3 py-1.5 text-sm font-medium text-white transition-colors hover:bg-primary-600 disabled:cursor-not-allowed disabled:opacity-50"
|
||||
>
|
||||
<Icon v-if="creating" name="refresh" size="sm" class="animate-spin" :stroke-width="2" />
|
||||
{{ t('common.save') }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Loading State -->
|
||||
<div v-if="loading" class="flex items-center justify-center py-8">
|
||||
<Icon name="refresh" size="md" class="animate-spin text-gray-400" :stroke-width="2" />
|
||||
<span class="ml-2 text-sm text-gray-500">{{ t('common.loading') }}...</span>
|
||||
</div>
|
||||
|
||||
<!-- Empty State -->
|
||||
<div
|
||||
v-else-if="plans.length === 0"
|
||||
class="rounded-xl border border-dashed border-gray-300 py-10 text-center dark:border-dark-600"
|
||||
>
|
||||
<Icon name="calendar" size="lg" class="mx-auto mb-2 text-gray-400" :stroke-width="1.5" />
|
||||
<p class="text-sm text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.scheduledTests.noPlans') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Plans List -->
|
||||
<div v-else class="space-y-3">
|
||||
<div
|
||||
v-for="plan in plans"
|
||||
:key="plan.id"
|
||||
class="rounded-xl border border-gray-200 bg-white transition-all dark:border-dark-600 dark:bg-dark-800"
|
||||
>
|
||||
<!-- Plan Header -->
|
||||
<div
|
||||
class="flex cursor-pointer items-center justify-between px-4 py-3"
|
||||
@click="toggleExpand(plan.id)"
|
||||
>
|
||||
<div class="flex flex-1 items-center gap-4">
|
||||
<!-- Model -->
|
||||
<div class="min-w-0">
|
||||
<div class="text-sm font-medium text-gray-900 dark:text-gray-100">
|
||||
{{ plan.model_id }}
|
||||
</div>
|
||||
<div class="mt-0.5 font-mono text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ plan.cron_expression }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Enabled Toggle -->
|
||||
<div class="flex items-center gap-1.5" @click.stop>
|
||||
<Toggle
|
||||
:model-value="plan.enabled"
|
||||
@update:model-value="(val: boolean) => handleToggleEnabled(plan, val)"
|
||||
/>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ plan.enabled ? t('admin.scheduledTests.enabled') : '' }}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="flex items-center gap-3">
|
||||
<!-- Last Run -->
|
||||
<div v-if="plan.last_run_at" class="hidden text-right text-xs text-gray-500 dark:text-gray-400 sm:block">
|
||||
<div>{{ t('admin.scheduledTests.lastRun') }}</div>
|
||||
<div>{{ formatDateTime(plan.last_run_at) }}</div>
|
||||
</div>
|
||||
|
||||
<!-- Next Run -->
|
||||
<div v-if="plan.next_run_at" class="hidden text-right text-xs text-gray-500 dark:text-gray-400 sm:block">
|
||||
<div>{{ t('admin.scheduledTests.nextRun') }}</div>
|
||||
<div>{{ formatDateTime(plan.next_run_at) }}</div>
|
||||
</div>
|
||||
|
||||
<!-- Actions -->
|
||||
<div class="flex items-center gap-1" @click.stop>
|
||||
<button
|
||||
@click="startEdit(plan)"
|
||||
class="rounded-lg p-1.5 text-gray-400 transition-colors hover:bg-blue-50 hover:text-blue-500 dark:hover:bg-blue-900/20"
|
||||
:title="t('admin.scheduledTests.editPlan')"
|
||||
>
|
||||
<Icon name="edit" size="sm" :stroke-width="2" />
|
||||
</button>
|
||||
<button
|
||||
@click="confirmDeletePlan(plan)"
|
||||
class="rounded-lg p-1.5 text-gray-400 transition-colors hover:bg-red-50 hover:text-red-500 dark:hover:bg-red-900/20"
|
||||
:title="t('admin.scheduledTests.deletePlan')"
|
||||
>
|
||||
<Icon name="trash" size="sm" :stroke-width="2" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Expand indicator -->
|
||||
<Icon
|
||||
name="chevronDown"
|
||||
size="sm"
|
||||
:class="[
|
||||
'text-gray-400 transition-transform duration-200',
|
||||
expandedPlanId === plan.id ? 'rotate-180' : ''
|
||||
]"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Edit Form -->
|
||||
<div
|
||||
v-if="editingPlanId === plan.id"
|
||||
class="border-t border-blue-100 bg-blue-50/50 px-4 py-3 dark:border-blue-900 dark:bg-blue-900/10"
|
||||
@click.stop
|
||||
>
|
||||
<div class="mb-2 text-xs font-medium text-gray-600 dark:text-gray-400">
|
||||
{{ t('admin.scheduledTests.editPlan') }}
|
||||
</div>
|
||||
<div class="grid grid-cols-1 gap-3 sm:grid-cols-2">
|
||||
<div>
|
||||
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
|
||||
{{ t('admin.scheduledTests.model') }}
|
||||
</label>
|
||||
<Select
|
||||
v-model="editForm.model_id"
|
||||
:options="modelOptions"
|
||||
:placeholder="t('admin.scheduledTests.model')"
|
||||
:searchable="modelOptions.length > 5"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
|
||||
{{ t('admin.scheduledTests.cronExpression') }}
|
||||
</label>
|
||||
<Input
|
||||
v-model="editForm.cron_expression"
|
||||
:placeholder="'*/30 * * * *'"
|
||||
:hint="t('admin.scheduledTests.cronHelp')"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
|
||||
{{ t('admin.scheduledTests.maxResults') }}
|
||||
</label>
|
||||
<Input
|
||||
v-model="editForm.max_results"
|
||||
type="number"
|
||||
placeholder="100"
|
||||
/>
|
||||
</div>
|
||||
<div class="flex items-end">
|
||||
<label class="flex items-center gap-2 text-sm text-gray-700 dark:text-gray-300">
|
||||
<Toggle v-model="editForm.enabled" />
|
||||
{{ t('admin.scheduledTests.enabled') }}
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
<div class="mt-3 flex justify-end gap-2">
|
||||
<button
|
||||
@click="cancelEdit"
|
||||
class="rounded-lg bg-gray-100 px-3 py-1.5 text-sm font-medium text-gray-700 transition-colors hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-300 dark:hover:bg-dark-500"
|
||||
>
|
||||
{{ t('common.cancel') }}
|
||||
</button>
|
||||
<button
|
||||
@click="handleEdit"
|
||||
:disabled="!editForm.model_id || !editForm.cron_expression || updating"
|
||||
class="flex items-center gap-1.5 rounded-lg bg-primary-500 px-3 py-1.5 text-sm font-medium text-white transition-colors hover:bg-primary-600 disabled:cursor-not-allowed disabled:opacity-50"
|
||||
>
|
||||
<Icon v-if="updating" name="refresh" size="sm" class="animate-spin" :stroke-width="2" />
|
||||
{{ t('common.save') }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Expanded Results Section -->
|
||||
<div
|
||||
v-if="expandedPlanId === plan.id"
|
||||
class="border-t border-gray-100 px-4 py-3 dark:border-dark-700"
|
||||
>
|
||||
<div class="mb-2 text-xs font-medium text-gray-600 dark:text-gray-400">
|
||||
{{ t('admin.scheduledTests.results') }}
|
||||
</div>
|
||||
|
||||
<!-- Results Loading -->
|
||||
<div v-if="loadingResults" class="flex items-center justify-center py-4">
|
||||
<Icon name="refresh" size="sm" class="animate-spin text-gray-400" :stroke-width="2" />
|
||||
<span class="ml-2 text-xs text-gray-500">{{ t('common.loading') }}...</span>
|
||||
</div>
|
||||
|
||||
<!-- No Results -->
|
||||
<div
|
||||
v-else-if="results.length === 0"
|
||||
class="py-4 text-center text-xs text-gray-500 dark:text-gray-400"
|
||||
>
|
||||
{{ t('admin.scheduledTests.noResults') }}
|
||||
</div>
|
||||
|
||||
<!-- Results List -->
|
||||
<div v-else class="max-h-64 space-y-2 overflow-y-auto">
|
||||
<div
|
||||
v-for="result in results"
|
||||
:key="result.id"
|
||||
class="rounded-lg border border-gray-100 bg-gray-50 p-3 dark:border-dark-700 dark:bg-dark-900"
|
||||
>
|
||||
<div class="flex items-center justify-between">
|
||||
<div class="flex items-center gap-2">
|
||||
<!-- Status Badge -->
|
||||
<span
|
||||
:class="[
|
||||
'inline-flex items-center rounded-full px-2 py-0.5 text-xs font-medium',
|
||||
result.status === 'success'
|
||||
? 'bg-green-100 text-green-700 dark:bg-green-500/20 dark:text-green-400'
|
||||
: result.status === 'running'
|
||||
? 'bg-blue-100 text-blue-700 dark:bg-blue-500/20 dark:text-blue-400'
|
||||
: 'bg-red-100 text-red-700 dark:bg-red-500/20 dark:text-red-400'
|
||||
]"
|
||||
>
|
||||
{{
|
||||
result.status === 'success'
|
||||
? t('admin.scheduledTests.success')
|
||||
: result.status === 'running'
|
||||
? t('admin.scheduledTests.running')
|
||||
: t('admin.scheduledTests.failed')
|
||||
}}
|
||||
</span>
|
||||
|
||||
<!-- Latency -->
|
||||
<span v-if="result.latency_ms > 0" class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ result.latency_ms }}ms
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Started At -->
|
||||
<span class="text-xs text-gray-400">
|
||||
{{ formatDateTime(result.started_at) }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Response / Error (collapsible) -->
|
||||
<div v-if="result.error_message" class="mt-2">
|
||||
<div
|
||||
class="cursor-pointer text-xs font-medium text-red-600 dark:text-red-400"
|
||||
@click="toggleResultDetail(result.id)"
|
||||
>
|
||||
{{ t('admin.scheduledTests.errorMessage') }}
|
||||
<Icon
|
||||
name="chevronDown"
|
||||
size="sm"
|
||||
:class="[
|
||||
'inline transition-transform duration-200',
|
||||
expandedResultIds.has(result.id) ? 'rotate-180' : ''
|
||||
]"
|
||||
/>
|
||||
</div>
|
||||
<pre
|
||||
v-if="expandedResultIds.has(result.id)"
|
||||
class="mt-1 max-h-32 overflow-auto whitespace-pre-wrap rounded bg-red-50 p-2 text-xs text-red-700 dark:bg-red-900/20 dark:text-red-300"
|
||||
>{{ result.error_message }}</pre>
|
||||
</div>
|
||||
<div v-else-if="result.response_text" class="mt-2">
|
||||
<div
|
||||
class="cursor-pointer text-xs font-medium text-gray-600 dark:text-gray-400"
|
||||
@click="toggleResultDetail(result.id)"
|
||||
>
|
||||
{{ t('admin.scheduledTests.responseText') }}
|
||||
<Icon
|
||||
name="chevronDown"
|
||||
size="sm"
|
||||
:class="[
|
||||
'inline transition-transform duration-200',
|
||||
expandedResultIds.has(result.id) ? 'rotate-180' : ''
|
||||
]"
|
||||
/>
|
||||
</div>
|
||||
<pre
|
||||
v-if="expandedResultIds.has(result.id)"
|
||||
class="mt-1 max-h-32 overflow-auto whitespace-pre-wrap rounded bg-gray-100 p-2 text-xs text-gray-700 dark:bg-dark-800 dark:text-gray-300"
|
||||
>{{ result.response_text }}</pre>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Delete Confirmation -->
|
||||
<ConfirmDialog
|
||||
:show="showDeleteConfirm"
|
||||
:title="t('admin.scheduledTests.deletePlan')"
|
||||
:message="t('admin.scheduledTests.confirmDelete')"
|
||||
:confirm-text="t('common.delete')"
|
||||
:cancel-text="t('common.cancel')"
|
||||
:danger="true"
|
||||
@confirm="handleDelete"
|
||||
@cancel="showDeleteConfirm = false"
|
||||
/>
|
||||
</BaseDialog>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, reactive, watch } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
||||
import Select, { type SelectOption } from '@/components/common/Select.vue'
|
||||
import Input from '@/components/common/Input.vue'
|
||||
import Toggle from '@/components/common/Toggle.vue'
|
||||
import { Icon } from '@/components/icons'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
import { formatDateTime } from '@/utils/format'
|
||||
import type { ScheduledTestPlan, ScheduledTestResult } from '@/types'
|
||||
|
||||
const { t } = useI18n()
|
||||
const appStore = useAppStore()
|
||||
|
||||
const props = defineProps<{
|
||||
show: boolean
|
||||
accountId: number | null
|
||||
modelOptions: SelectOption[]
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
(e: 'close'): void
|
||||
}>()
|
||||
|
||||
// State
|
||||
const loading = ref(false)
|
||||
const creating = ref(false)
|
||||
const loadingResults = ref(false)
|
||||
const plans = ref<ScheduledTestPlan[]>([])
|
||||
const results = ref<ScheduledTestResult[]>([])
|
||||
const expandedPlanId = ref<number | null>(null)
|
||||
const expandedResultIds = reactive(new Set<number>())
|
||||
const showAddForm = ref(false)
|
||||
const showDeleteConfirm = ref(false)
|
||||
const deletingPlan = ref<ScheduledTestPlan | null>(null)
|
||||
const editingPlanId = ref<number | null>(null)
|
||||
const updating = ref(false)
|
||||
const editForm = reactive({
|
||||
model_id: '' as string,
|
||||
cron_expression: '' as string,
|
||||
max_results: '100' as string,
|
||||
enabled: true
|
||||
})
|
||||
|
||||
const newPlan = reactive({
|
||||
model_id: '' as string,
|
||||
cron_expression: '' as string,
|
||||
max_results: '100' as string,
|
||||
enabled: true
|
||||
})
|
||||
|
||||
/** Restores the "add plan" form fields to their initial values. */
const resetNewPlan = () => {
  Object.assign(newPlan, {
    model_id: '',
    cron_expression: '',
    max_results: '100',
    enabled: true
  })
}
|
||||
|
||||
// Load plans when the dialog opens for an account; otherwise drop all
// per-session view state so nothing leaks into the next time it is shown.
watch(
  () => props.show,
  async (visible) => {
    if (visible && props.accountId) {
      await loadPlans()
      return
    }

    // Dialog closed (or shown without an account): clear list and result data.
    plans.value = []
    results.value = []
    expandedPlanId.value = null
    expandedResultIds.clear()

    // Close every inline form/dialog. The inline edit form must be closed too,
    // otherwise it reappears with stale field values when the dialog is reopened.
    editingPlanId.value = null
    showAddForm.value = false
    showDeleteConfirm.value = false
  }
)
|
||||
|
||||
/**
 * Fetches the scheduled test plans of the current account into `plans`.
 * Does nothing when no account id is set; failures are reported via `appStore.showError`.
 */
const loadPlans = async () => {
  const accountId = props.accountId
  if (!accountId) return

  loading.value = true
  try {
    const fetched = await adminAPI.scheduledTests.listByAccount(accountId)
    plans.value = fetched
  } catch (err: any) {
    const message = err?.message || 'Failed to load plans'
    appStore.showError(message)
  } finally {
    loading.value = false
  }
}
|
||||
|
||||
/**
 * Creates a new scheduled test plan from the "add plan" form, then hides and
 * resets the form and reloads the plan list. Requires an account id, a model
 * and a cron expression; otherwise it returns without doing anything.
 */
const handleCreate = async () => {
  const accountId = props.accountId
  const { model_id, cron_expression } = newPlan
  if (!accountId || !model_id || !cron_expression) return

  creating.value = true
  try {
    await adminAPI.scheduledTests.create({
      account_id: accountId,
      model_id,
      cron_expression,
      enabled: newPlan.enabled,
      // Empty, zero or non-numeric input falls back to 100.
      max_results: Number(newPlan.max_results) || 100
    })
    appStore.showSuccess(t('admin.scheduledTests.createSuccess'))
    showAddForm.value = false
    resetNewPlan()
    await loadPlans()
  } catch (err: any) {
    const message = err?.message || 'Failed to create plan'
    appStore.showError(message)
  } finally {
    creating.value = false
  }
}
|
||||
|
||||
/**
 * Persists a change of a plan's `enabled` flag and replaces the list entry
 * with the plan returned by the API. Failures are reported via `appStore.showError`.
 */
const handleToggleEnabled = async (plan: ScheduledTestPlan, enabled: boolean) => {
  try {
    const saved = await adminAPI.scheduledTests.update(plan.id, { enabled })
    const position = plans.value.findIndex((candidate) => candidate.id === plan.id)
    if (position >= 0) {
      plans.value[position] = saved
    }
    appStore.showSuccess(t('admin.scheduledTests.updateSuccess'))
  } catch (err: any) {
    const message = err?.message || 'Failed to update plan'
    appStore.showError(message)
  }
}
|
||||
|
||||
/** Opens the inline edit form for `plan`, pre-filled with the plan's current values. */
const startEdit = (plan: ScheduledTestPlan) => {
  editingPlanId.value = plan.id
  Object.assign(editForm, {
    model_id: plan.model_id,
    cron_expression: plan.cron_expression,
    // The form keeps max_results as a string.
    max_results: String(plan.max_results),
    enabled: plan.enabled
  })
}
|
||||
|
||||
/** Closes the inline edit form without saving. */
const cancelEdit = (): void => {
  editingPlanId.value = null
}
|
||||
|
||||
/**
 * Saves the inline edit form for the plan currently being edited and replaces
 * the list entry with the plan returned by the API.
 *
 * The plan id is captured before the request is sent: `editingPlanId` can change
 * while the request is in flight (cancel, or editing another plan), and the
 * response must still be applied to the plan that was actually saved.
 */
const handleEdit = async () => {
  const planId = editingPlanId.value
  if (!planId || !editForm.model_id || !editForm.cron_expression) return

  updating.value = true
  try {
    const updated = await adminAPI.scheduledTests.update(planId, {
      model_id: editForm.model_id,
      cron_expression: editForm.cron_expression,
      // Empty, zero or non-numeric input falls back to 100.
      max_results: Number(editForm.max_results) || 100,
      enabled: editForm.enabled
    })
    const index = plans.value.findIndex((p) => p.id === planId)
    if (index !== -1) {
      plans.value[index] = updated
    }
    appStore.showSuccess(t('admin.scheduledTests.updateSuccess'))
    // Only close the form if it still belongs to the plan that was saved;
    // the user may have opened the edit form of another plan in the meantime.
    if (editingPlanId.value === planId) {
      editingPlanId.value = null
    }
  } catch (error: any) {
    appStore.showError(error?.message || 'Failed to update plan')
  } finally {
    updating.value = false
  }
}
|
||||
|
||||
/** Remembers `plan` as the deletion candidate and opens the confirmation dialog. */
const confirmDeletePlan = (plan: ScheduledTestPlan): void => {
  deletingPlan.value = plan
  showDeleteConfirm.value = true
}
|
||||
|
||||
/**
 * Deletes the plan selected via the confirmation dialog and removes it from
 * the local list. The confirmation dialog is always closed afterwards.
 *
 * The plan id is captured before the request is sent so that the clean-up
 * after the `await` never dereferences `deletingPlan`, which is reset to null
 * in `finally` (e.g. by an overlapping call triggered by a second confirm click).
 */
const handleDelete = async () => {
  const target = deletingPlan.value
  if (!target) return
  const planId = target.id

  try {
    await adminAPI.scheduledTests.delete(planId)
    appStore.showSuccess(t('admin.scheduledTests.deleteSuccess'))
    plans.value = plans.value.filter((p) => p.id !== planId)
    // Drop any view state that still points at the deleted plan.
    if (expandedPlanId.value === planId) {
      expandedPlanId.value = null
      results.value = []
      expandedResultIds.clear()
    }
    if (editingPlanId.value === planId) {
      editingPlanId.value = null
    }
  } catch (error: any) {
    appStore.showError(error?.message || 'Failed to delete plan')
  } finally {
    showDeleteConfirm.value = false
    deletingPlan.value = null
  }
}
|
||||
|
||||
/**
 * Expands or collapses the results section of a plan.
 *
 * Clicking the expanded plan collapses it and clears its results. Clicking
 * another plan expands it and loads up to 20 of its results.
 *
 * Responses are applied only if the plan they belong to is still the expanded
 * one. Without that check a slow response for a previously expanded plan
 * would overwrite the results (and hide the spinner) of the plan the user
 * has switched to in the meantime.
 */
const toggleExpand = async (planId: number) => {
  if (expandedPlanId.value === planId) {
    expandedPlanId.value = null
    results.value = []
    expandedResultIds.clear()
    // A request may still be in flight; its response is ignored below.
    loadingResults.value = false
    return
  }

  expandedPlanId.value = planId
  expandedResultIds.clear()
  loadingResults.value = true
  try {
    const loaded = await adminAPI.scheduledTests.listResults(planId, 20)
    if (expandedPlanId.value === planId) {
      results.value = loaded
    }
  } catch (error: any) {
    appStore.showError(error?.message || 'Failed to load results')
    if (expandedPlanId.value === planId) {
      results.value = []
    }
  } finally {
    // Only the request for the currently expanded plan may end the loading state.
    if (expandedPlanId.value === planId) {
      loadingResults.value = false
    }
  }
}
|
||||
|
||||
/** Shows or hides the detail text (error message / response) of a single result. */
const toggleResultDetail = (resultId: number) => {
  // Set.delete() reports whether the id was present; if it was not, expand it instead.
  const wasExpanded = expandedResultIds.delete(resultId)
  if (!wasExpanded) {
    expandedResultIds.add(resultId)
  }
}
|
||||
</script>
|
||||
@@ -63,7 +63,8 @@ const chartColors = computed(() => ({
|
||||
grid: isDarkMode.value ? '#374151' : '#e5e7eb',
|
||||
input: '#3b82f6',
|
||||
output: '#10b981',
|
||||
cache: '#f59e0b'
|
||||
cacheCreation: '#f59e0b',
|
||||
cacheRead: '#06b6d4'
|
||||
}))
|
||||
|
||||
const chartData = computed(() => {
|
||||
@@ -89,10 +90,18 @@ const chartData = computed(() => {
|
||||
tension: 0.3
|
||||
},
|
||||
{
|
||||
label: 'Cache',
|
||||
data: props.trendData.map((d) => d.cache_tokens),
|
||||
borderColor: chartColors.value.cache,
|
||||
backgroundColor: `${chartColors.value.cache}20`,
|
||||
label: 'Cache Creation',
|
||||
data: props.trendData.map((d) => d.cache_creation_tokens),
|
||||
borderColor: chartColors.value.cacheCreation,
|
||||
backgroundColor: `${chartColors.value.cacheCreation}20`,
|
||||
fill: true,
|
||||
tension: 0.3
|
||||
},
|
||||
{
|
||||
label: 'Cache Read',
|
||||
data: props.trendData.map((d) => d.cache_read_tokens),
|
||||
borderColor: chartColors.value.cacheRead,
|
||||
backgroundColor: `${chartColors.value.cacheRead}20`,
|
||||
fill: true,
|
||||
tension: 0.3
|
||||
}
|
||||
|
||||
@@ -443,7 +443,22 @@ $env:ANTHROPIC_AUTH_TOKEN="${apiKey}"`
|
||||
content = ''
|
||||
}
|
||||
|
||||
return [{ path, content }]
|
||||
const vscodeSettingsPath = activeTab.value === 'unix'
|
||||
? '~/.claude/settings.json'
|
||||
: '%userprofile%\\.claude\\settings.json'
|
||||
|
||||
const vscodeContent = `{
|
||||
"env": {
|
||||
"ANTHROPIC_BASE_URL": "${baseUrl}",
|
||||
"ANTHROPIC_AUTH_TOKEN": "${apiKey}",
|
||||
"CLAUDE_CODE_ATTRIBUTION_HEADER": "0"
|
||||
}
|
||||
}`
|
||||
|
||||
return [
|
||||
{ path, content },
|
||||
{ path: vscodeSettingsPath, content: vscodeContent, hint: 'VSCode Claude Code' }
|
||||
]
|
||||
}
|
||||
|
||||
function generateGeminiCliContent(baseUrl: string, apiKey: string): FileConfig {
|
||||
@@ -496,16 +511,18 @@ function generateOpenAIFiles(baseUrl: string, apiKey: string): FileConfig[] {
|
||||
const configDir = isWindows ? '%userprofile%\\.codex' : '~/.codex'
|
||||
|
||||
// config.toml content
|
||||
const configContent = `model_provider = "sub2api"
|
||||
model = "gpt-5.3-codex"
|
||||
model_reasoning_effort = "high"
|
||||
network_access = "enabled"
|
||||
const configContent = `model_provider = "OpenAI"
|
||||
model = "gpt-5.4"
|
||||
review_model = "gpt-5.4"
|
||||
model_reasoning_effort = "xhigh"
|
||||
disable_response_storage = true
|
||||
network_access = "enabled"
|
||||
windows_wsl_setup_acknowledged = true
|
||||
model_verbosity = "high"
|
||||
model_context_window = 1000000
|
||||
model_auto_compact_token_limit = 900000
|
||||
|
||||
[model_providers.sub2api]
|
||||
name = "sub2api"
|
||||
[model_providers.OpenAI]
|
||||
name = "OpenAI"
|
||||
base_url = "${baseUrl}"
|
||||
wire_api = "responses"
|
||||
requires_openai_auth = true`
|
||||
@@ -533,16 +550,18 @@ function generateOpenAIWsFiles(baseUrl: string, apiKey: string): FileConfig[] {
|
||||
const configDir = isWindows ? '%userprofile%\\.codex' : '~/.codex'
|
||||
|
||||
// config.toml content with WebSocket v2
|
||||
const configContent = `model_provider = "sub2api"
|
||||
model = "gpt-5.3-codex"
|
||||
model_reasoning_effort = "high"
|
||||
network_access = "enabled"
|
||||
const configContent = `model_provider = "OpenAI"
|
||||
model = "gpt-5.4"
|
||||
review_model = "gpt-5.4"
|
||||
model_reasoning_effort = "xhigh"
|
||||
disable_response_storage = true
|
||||
network_access = "enabled"
|
||||
windows_wsl_setup_acknowledged = true
|
||||
model_verbosity = "high"
|
||||
model_context_window = 1000000
|
||||
model_auto_compact_token_limit = 900000
|
||||
|
||||
[model_providers.sub2api]
|
||||
name = "sub2api"
|
||||
[model_providers.OpenAI]
|
||||
name = "OpenAI"
|
||||
base_url = "${baseUrl}"
|
||||
wire_api = "responses"
|
||||
supports_websockets = true
|
||||
@@ -655,6 +674,22 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin
|
||||
xhigh: {}
|
||||
}
|
||||
},
|
||||
'gpt-5.4': {
|
||||
name: 'GPT-5.4',
|
||||
limit: {
|
||||
context: 1050000,
|
||||
output: 128000
|
||||
},
|
||||
options: {
|
||||
store: false
|
||||
},
|
||||
variants: {
|
||||
low: {},
|
||||
medium: {},
|
||||
high: {},
|
||||
xhigh: {}
|
||||
}
|
||||
},
|
||||
'gpt-5.3-codex-spark': {
|
||||
name: 'GPT-5.3 Codex Spark',
|
||||
limit: {
|
||||
|
||||
@@ -2,6 +2,13 @@ import { describe, expect, it } from 'vitest'
|
||||
import { buildModelMappingObject, getModelsByPlatform } from '../useModelWhitelist'
|
||||
|
||||
describe('useModelWhitelist', () => {
|
||||
it('openai 模型列表包含 GPT-5.4 官方快照', () => {
|
||||
const models = getModelsByPlatform('openai')
|
||||
|
||||
expect(models).toContain('gpt-5.4')
|
||||
expect(models).toContain('gpt-5.4-2026-03-05')
|
||||
})
|
||||
|
||||
it('antigravity 模型列表包含图片模型兼容项', () => {
|
||||
const models = getModelsByPlatform('antigravity')
|
||||
|
||||
@@ -15,4 +22,12 @@ describe('useModelWhitelist', () => {
|
||||
'gemini-3.1-flash-image': 'gemini-3.1-flash-image'
|
||||
})
|
||||
})
|
||||
|
||||
it('whitelist 模式会保留 GPT-5.4 官方快照的精确映射', () => {
|
||||
const mapping = buildModelMappingObject('whitelist', ['gpt-5.4-2026-03-05'], [])
|
||||
|
||||
expect(mapping).toEqual({
|
||||
'gpt-5.4-2026-03-05': 'gpt-5.4-2026-03-05'
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user