mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-21 07:04:45 +08:00
Merge branch 'test' into release
This commit is contained in:
2
.github/workflows/security-scan.yml
vendored
2
.github/workflows/security-scan.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
|||||||
working-directory: backend
|
working-directory: backend
|
||||||
run: |
|
run: |
|
||||||
go install github.com/securego/gosec/v2/cmd/gosec@latest
|
go install github.com/securego/gosec/v2/cmd/gosec@latest
|
||||||
gosec -severity high -confidence high ./...
|
gosec -conf .gosec.json -severity high -confidence high ./...
|
||||||
|
|
||||||
frontend-security:
|
frontend-security:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|||||||
5
backend/.gosec.json
Normal file
5
backend/.gosec.json
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"global": {
|
||||||
|
"exclude": "G704"
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1 +1 @@
|
|||||||
0.1.74.9
|
0.1.83.2
|
||||||
|
|||||||
@@ -184,7 +184,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler)
|
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler)
|
||||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig)
|
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig)
|
||||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig)
|
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig)
|
||||||
soraDirectClient := service.NewSoraDirectClient(configConfig, httpUpstream, openAITokenProvider)
|
soraDirectClient := service.ProvideSoraDirectClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
|
||||||
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
|
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
|
||||||
soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig)
|
soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig)
|
||||||
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig)
|
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig)
|
||||||
|
|||||||
@@ -669,6 +669,7 @@ var (
|
|||||||
{Name: "image_count", Type: field.TypeInt, Default: 0},
|
{Name: "image_count", Type: field.TypeInt, Default: 0},
|
||||||
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
|
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
|
||||||
{Name: "media_type", Type: field.TypeString, Nullable: true, Size: 16},
|
{Name: "media_type", Type: field.TypeString, Nullable: true, Size: 16},
|
||||||
|
{Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false},
|
||||||
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||||
{Name: "api_key_id", Type: field.TypeInt64},
|
{Name: "api_key_id", Type: field.TypeInt64},
|
||||||
{Name: "account_id", Type: field.TypeInt64},
|
{Name: "account_id", Type: field.TypeInt64},
|
||||||
@@ -684,31 +685,31 @@ var (
|
|||||||
ForeignKeys: []*schema.ForeignKey{
|
ForeignKeys: []*schema.ForeignKey{
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_api_keys_usage_logs",
|
Symbol: "usage_logs_api_keys_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||||
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_accounts_usage_logs",
|
Symbol: "usage_logs_accounts_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||||
RefColumns: []*schema.Column{AccountsColumns[0]},
|
RefColumns: []*schema.Column{AccountsColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_groups_usage_logs",
|
Symbol: "usage_logs_groups_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||||
OnDelete: schema.SetNull,
|
OnDelete: schema.SetNull,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_users_usage_logs",
|
Symbol: "usage_logs_users_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||||
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
||||||
OnDelete: schema.SetNull,
|
OnDelete: schema.SetNull,
|
||||||
},
|
},
|
||||||
@@ -717,32 +718,32 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "usagelog_user_id",
|
Name: "usagelog_user_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_api_key_id",
|
Name: "usagelog_api_key_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_account_id",
|
Name: "usagelog_account_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_group_id",
|
Name: "usagelog_group_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_subscription_id",
|
Name: "usagelog_subscription_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_created_at",
|
Name: "usagelog_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[26]},
|
Columns: []*schema.Column{UsageLogsColumns[27]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_model",
|
Name: "usagelog_model",
|
||||||
@@ -757,12 +758,12 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "usagelog_user_id_created_at",
|
Name: "usagelog_user_id_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[26]},
|
Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[27]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_api_key_id_created_at",
|
Name: "usagelog_api_key_id_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[27], UsageLogsColumns[26]},
|
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15980,6 +15980,7 @@ type UsageLogMutation struct {
|
|||||||
addimage_count *int
|
addimage_count *int
|
||||||
image_size *string
|
image_size *string
|
||||||
media_type *string
|
media_type *string
|
||||||
|
cache_ttl_overridden *bool
|
||||||
created_at *time.Time
|
created_at *time.Time
|
||||||
clearedFields map[string]struct{}
|
clearedFields map[string]struct{}
|
||||||
user *int64
|
user *int64
|
||||||
@@ -17655,6 +17656,42 @@ func (m *UsageLogMutation) ResetMediaType() {
|
|||||||
delete(m.clearedFields, usagelog.FieldMediaType)
|
delete(m.clearedFields, usagelog.FieldMediaType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||||
|
func (m *UsageLogMutation) SetCacheTTLOverridden(b bool) {
|
||||||
|
m.cache_ttl_overridden = &b
|
||||||
|
}
|
||||||
|
|
||||||
|
// CacheTTLOverridden returns the value of the "cache_ttl_overridden" field in the mutation.
|
||||||
|
func (m *UsageLogMutation) CacheTTLOverridden() (r bool, exists bool) {
|
||||||
|
v := m.cache_ttl_overridden
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldCacheTTLOverridden returns the old "cache_ttl_overridden" field's value of the UsageLog entity.
|
||||||
|
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *UsageLogMutation) OldCacheTTLOverridden(ctx context.Context) (v bool, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldCacheTTLOverridden is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldCacheTTLOverridden requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldCacheTTLOverridden: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.CacheTTLOverridden, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetCacheTTLOverridden resets all changes to the "cache_ttl_overridden" field.
|
||||||
|
func (m *UsageLogMutation) ResetCacheTTLOverridden() {
|
||||||
|
m.cache_ttl_overridden = nil
|
||||||
|
}
|
||||||
|
|
||||||
// SetCreatedAt sets the "created_at" field.
|
// SetCreatedAt sets the "created_at" field.
|
||||||
func (m *UsageLogMutation) SetCreatedAt(t time.Time) {
|
func (m *UsageLogMutation) SetCreatedAt(t time.Time) {
|
||||||
m.created_at = &t
|
m.created_at = &t
|
||||||
@@ -17860,7 +17897,7 @@ func (m *UsageLogMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *UsageLogMutation) Fields() []string {
|
func (m *UsageLogMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 31)
|
fields := make([]string, 0, 32)
|
||||||
if m.user != nil {
|
if m.user != nil {
|
||||||
fields = append(fields, usagelog.FieldUserID)
|
fields = append(fields, usagelog.FieldUserID)
|
||||||
}
|
}
|
||||||
@@ -17951,6 +17988,9 @@ func (m *UsageLogMutation) Fields() []string {
|
|||||||
if m.media_type != nil {
|
if m.media_type != nil {
|
||||||
fields = append(fields, usagelog.FieldMediaType)
|
fields = append(fields, usagelog.FieldMediaType)
|
||||||
}
|
}
|
||||||
|
if m.cache_ttl_overridden != nil {
|
||||||
|
fields = append(fields, usagelog.FieldCacheTTLOverridden)
|
||||||
|
}
|
||||||
if m.created_at != nil {
|
if m.created_at != nil {
|
||||||
fields = append(fields, usagelog.FieldCreatedAt)
|
fields = append(fields, usagelog.FieldCreatedAt)
|
||||||
}
|
}
|
||||||
@@ -18022,6 +18062,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.ImageSize()
|
return m.ImageSize()
|
||||||
case usagelog.FieldMediaType:
|
case usagelog.FieldMediaType:
|
||||||
return m.MediaType()
|
return m.MediaType()
|
||||||
|
case usagelog.FieldCacheTTLOverridden:
|
||||||
|
return m.CacheTTLOverridden()
|
||||||
case usagelog.FieldCreatedAt:
|
case usagelog.FieldCreatedAt:
|
||||||
return m.CreatedAt()
|
return m.CreatedAt()
|
||||||
}
|
}
|
||||||
@@ -18093,6 +18135,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
|
|||||||
return m.OldImageSize(ctx)
|
return m.OldImageSize(ctx)
|
||||||
case usagelog.FieldMediaType:
|
case usagelog.FieldMediaType:
|
||||||
return m.OldMediaType(ctx)
|
return m.OldMediaType(ctx)
|
||||||
|
case usagelog.FieldCacheTTLOverridden:
|
||||||
|
return m.OldCacheTTLOverridden(ctx)
|
||||||
case usagelog.FieldCreatedAt:
|
case usagelog.FieldCreatedAt:
|
||||||
return m.OldCreatedAt(ctx)
|
return m.OldCreatedAt(ctx)
|
||||||
}
|
}
|
||||||
@@ -18314,6 +18358,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetMediaType(v)
|
m.SetMediaType(v)
|
||||||
return nil
|
return nil
|
||||||
|
case usagelog.FieldCacheTTLOverridden:
|
||||||
|
v, ok := value.(bool)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetCacheTTLOverridden(v)
|
||||||
|
return nil
|
||||||
case usagelog.FieldCreatedAt:
|
case usagelog.FieldCreatedAt:
|
||||||
v, ok := value.(time.Time)
|
v, ok := value.(time.Time)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -18736,6 +18787,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
|
|||||||
case usagelog.FieldMediaType:
|
case usagelog.FieldMediaType:
|
||||||
m.ResetMediaType()
|
m.ResetMediaType()
|
||||||
return nil
|
return nil
|
||||||
|
case usagelog.FieldCacheTTLOverridden:
|
||||||
|
m.ResetCacheTTLOverridden()
|
||||||
|
return nil
|
||||||
case usagelog.FieldCreatedAt:
|
case usagelog.FieldCreatedAt:
|
||||||
m.ResetCreatedAt()
|
m.ResetCreatedAt()
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -821,8 +821,12 @@ func init() {
|
|||||||
usagelogDescMediaType := usagelogFields[29].Descriptor()
|
usagelogDescMediaType := usagelogFields[29].Descriptor()
|
||||||
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
|
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
|
||||||
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
|
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
|
||||||
|
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
|
||||||
|
usagelogDescCacheTTLOverridden := usagelogFields[30].Descriptor()
|
||||||
|
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
|
||||||
|
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
|
||||||
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
||||||
usagelogDescCreatedAt := usagelogFields[30].Descriptor()
|
usagelogDescCreatedAt := usagelogFields[31].Descriptor()
|
||||||
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||||
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
||||||
userMixin := schema.User{}.Mixin()
|
userMixin := schema.User{}.Mixin()
|
||||||
|
|||||||
@@ -124,6 +124,10 @@ func (UsageLog) Fields() []ent.Field {
|
|||||||
Optional().
|
Optional().
|
||||||
Nillable(),
|
Nillable(),
|
||||||
|
|
||||||
|
// Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费)
|
||||||
|
field.Bool("cache_ttl_overridden").
|
||||||
|
Default(false),
|
||||||
|
|
||||||
// 时间戳(只有 created_at,日志不可修改)
|
// 时间戳(只有 created_at,日志不可修改)
|
||||||
field.Time("created_at").
|
field.Time("created_at").
|
||||||
Default(time.Now).
|
Default(time.Now).
|
||||||
|
|||||||
@@ -82,6 +82,8 @@ type UsageLog struct {
|
|||||||
ImageSize *string `json:"image_size,omitempty"`
|
ImageSize *string `json:"image_size,omitempty"`
|
||||||
// MediaType holds the value of the "media_type" field.
|
// MediaType holds the value of the "media_type" field.
|
||||||
MediaType *string `json:"media_type,omitempty"`
|
MediaType *string `json:"media_type,omitempty"`
|
||||||
|
// CacheTTLOverridden holds the value of the "cache_ttl_overridden" field.
|
||||||
|
CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"`
|
||||||
// CreatedAt holds the value of the "created_at" field.
|
// CreatedAt holds the value of the "created_at" field.
|
||||||
CreatedAt time.Time `json:"created_at,omitempty"`
|
CreatedAt time.Time `json:"created_at,omitempty"`
|
||||||
// Edges holds the relations/edges for other nodes in the graph.
|
// Edges holds the relations/edges for other nodes in the graph.
|
||||||
@@ -167,7 +169,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
|
|||||||
values := make([]any, len(columns))
|
values := make([]any, len(columns))
|
||||||
for i := range columns {
|
for i := range columns {
|
||||||
switch columns[i] {
|
switch columns[i] {
|
||||||
case usagelog.FieldStream:
|
case usagelog.FieldStream, usagelog.FieldCacheTTLOverridden:
|
||||||
values[i] = new(sql.NullBool)
|
values[i] = new(sql.NullBool)
|
||||||
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
|
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
|
||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
@@ -387,6 +389,12 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
|
|||||||
_m.MediaType = new(string)
|
_m.MediaType = new(string)
|
||||||
*_m.MediaType = value.String
|
*_m.MediaType = value.String
|
||||||
}
|
}
|
||||||
|
case usagelog.FieldCacheTTLOverridden:
|
||||||
|
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.CacheTTLOverridden = value.Bool
|
||||||
|
}
|
||||||
case usagelog.FieldCreatedAt:
|
case usagelog.FieldCreatedAt:
|
||||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||||
return fmt.Errorf("unexpected type %T for field created_at", values[i])
|
return fmt.Errorf("unexpected type %T for field created_at", values[i])
|
||||||
@@ -562,6 +570,9 @@ func (_m *UsageLog) String() string {
|
|||||||
builder.WriteString(*v)
|
builder.WriteString(*v)
|
||||||
}
|
}
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("cache_ttl_overridden=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden))
|
||||||
|
builder.WriteString(", ")
|
||||||
builder.WriteString("created_at=")
|
builder.WriteString("created_at=")
|
||||||
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
|
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
|
||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
|
|||||||
@@ -74,6 +74,8 @@ const (
|
|||||||
FieldImageSize = "image_size"
|
FieldImageSize = "image_size"
|
||||||
// FieldMediaType holds the string denoting the media_type field in the database.
|
// FieldMediaType holds the string denoting the media_type field in the database.
|
||||||
FieldMediaType = "media_type"
|
FieldMediaType = "media_type"
|
||||||
|
// FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database.
|
||||||
|
FieldCacheTTLOverridden = "cache_ttl_overridden"
|
||||||
// FieldCreatedAt holds the string denoting the created_at field in the database.
|
// FieldCreatedAt holds the string denoting the created_at field in the database.
|
||||||
FieldCreatedAt = "created_at"
|
FieldCreatedAt = "created_at"
|
||||||
// EdgeUser holds the string denoting the user edge name in mutations.
|
// EdgeUser holds the string denoting the user edge name in mutations.
|
||||||
@@ -158,6 +160,7 @@ var Columns = []string{
|
|||||||
FieldImageCount,
|
FieldImageCount,
|
||||||
FieldImageSize,
|
FieldImageSize,
|
||||||
FieldMediaType,
|
FieldMediaType,
|
||||||
|
FieldCacheTTLOverridden,
|
||||||
FieldCreatedAt,
|
FieldCreatedAt,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -216,6 +219,8 @@ var (
|
|||||||
ImageSizeValidator func(string) error
|
ImageSizeValidator func(string) error
|
||||||
// MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
|
// MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
|
||||||
MediaTypeValidator func(string) error
|
MediaTypeValidator func(string) error
|
||||||
|
// DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field.
|
||||||
|
DefaultCacheTTLOverridden bool
|
||||||
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
||||||
DefaultCreatedAt func() time.Time
|
DefaultCreatedAt func() time.Time
|
||||||
)
|
)
|
||||||
@@ -378,6 +383,11 @@ func ByMediaType(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldMediaType, opts...).ToFunc()
|
return sql.OrderByField(FieldMediaType, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ByCacheTTLOverridden orders the results by the cache_ttl_overridden field.
|
||||||
|
func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// ByCreatedAt orders the results by the created_at field.
|
// ByCreatedAt orders the results by the created_at field.
|
||||||
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
|
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
|
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
|
||||||
|
|||||||
@@ -205,6 +205,11 @@ func MediaType(v string) predicate.UsageLog {
|
|||||||
return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v))
|
return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ.
|
||||||
|
func CacheTTLOverridden(v bool) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
|
||||||
|
}
|
||||||
|
|
||||||
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
|
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
|
||||||
func CreatedAt(v time.Time) predicate.UsageLog {
|
func CreatedAt(v time.Time) predicate.UsageLog {
|
||||||
return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v))
|
return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v))
|
||||||
@@ -1520,6 +1525,16 @@ func MediaTypeContainsFold(v string) predicate.UsageLog {
|
|||||||
return predicate.UsageLog(sql.FieldContainsFold(FieldMediaType, v))
|
return predicate.UsageLog(sql.FieldContainsFold(FieldMediaType, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field.
|
||||||
|
func CacheTTLOverriddenEQ(v bool) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CacheTTLOverriddenNEQ applies the NEQ predicate on the "cache_ttl_overridden" field.
|
||||||
|
func CacheTTLOverriddenNEQ(v bool) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldNEQ(FieldCacheTTLOverridden, v))
|
||||||
|
}
|
||||||
|
|
||||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||||
func CreatedAtEQ(v time.Time) predicate.UsageLog {
|
func CreatedAtEQ(v time.Time) predicate.UsageLog {
|
||||||
return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v))
|
return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v))
|
||||||
|
|||||||
@@ -407,6 +407,20 @@ func (_c *UsageLogCreate) SetNillableMediaType(v *string) *UsageLogCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||||
|
func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate {
|
||||||
|
_c.mutation.SetCacheTTLOverridden(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil.
|
||||||
|
func (_c *UsageLogCreate) SetNillableCacheTTLOverridden(v *bool) *UsageLogCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetCacheTTLOverridden(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// SetCreatedAt sets the "created_at" field.
|
// SetCreatedAt sets the "created_at" field.
|
||||||
func (_c *UsageLogCreate) SetCreatedAt(v time.Time) *UsageLogCreate {
|
func (_c *UsageLogCreate) SetCreatedAt(v time.Time) *UsageLogCreate {
|
||||||
_c.mutation.SetCreatedAt(v)
|
_c.mutation.SetCreatedAt(v)
|
||||||
@@ -545,6 +559,10 @@ func (_c *UsageLogCreate) defaults() {
|
|||||||
v := usagelog.DefaultImageCount
|
v := usagelog.DefaultImageCount
|
||||||
_c.mutation.SetImageCount(v)
|
_c.mutation.SetImageCount(v)
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
|
||||||
|
v := usagelog.DefaultCacheTTLOverridden
|
||||||
|
_c.mutation.SetCacheTTLOverridden(v)
|
||||||
|
}
|
||||||
if _, ok := _c.mutation.CreatedAt(); !ok {
|
if _, ok := _c.mutation.CreatedAt(); !ok {
|
||||||
v := usagelog.DefaultCreatedAt()
|
v := usagelog.DefaultCreatedAt()
|
||||||
_c.mutation.SetCreatedAt(v)
|
_c.mutation.SetCreatedAt(v)
|
||||||
@@ -646,6 +664,9 @@ func (_c *UsageLogCreate) check() error {
|
|||||||
return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
|
return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
|
||||||
|
return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)}
|
||||||
|
}
|
||||||
if _, ok := _c.mutation.CreatedAt(); !ok {
|
if _, ok := _c.mutation.CreatedAt(); !ok {
|
||||||
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UsageLog.created_at"`)}
|
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UsageLog.created_at"`)}
|
||||||
}
|
}
|
||||||
@@ -785,6 +806,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
|
_spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
|
||||||
_node.MediaType = &value
|
_node.MediaType = &value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.CacheTTLOverridden(); ok {
|
||||||
|
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
|
||||||
|
_node.CacheTTLOverridden = value
|
||||||
|
}
|
||||||
if value, ok := _c.mutation.CreatedAt(); ok {
|
if value, ok := _c.mutation.CreatedAt(); ok {
|
||||||
_spec.SetField(usagelog.FieldCreatedAt, field.TypeTime, value)
|
_spec.SetField(usagelog.FieldCreatedAt, field.TypeTime, value)
|
||||||
_node.CreatedAt = value
|
_node.CreatedAt = value
|
||||||
@@ -1448,6 +1473,18 @@ func (u *UsageLogUpsert) ClearMediaType() *UsageLogUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||||
|
func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert {
|
||||||
|
u.Set(usagelog.FieldCacheTTLOverridden, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create.
|
||||||
|
func (u *UsageLogUpsert) UpdateCacheTTLOverridden() *UsageLogUpsert {
|
||||||
|
u.SetExcluded(usagelog.FieldCacheTTLOverridden)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||||
// Using this option is equivalent to using:
|
// Using this option is equivalent to using:
|
||||||
//
|
//
|
||||||
@@ -2102,6 +2139,20 @@ func (u *UsageLogUpsertOne) ClearMediaType() *UsageLogUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||||
|
func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.SetCacheTTLOverridden(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create.
|
||||||
|
func (u *UsageLogUpsertOne) UpdateCacheTTLOverridden() *UsageLogUpsertOne {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.UpdateCacheTTLOverridden()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *UsageLogUpsertOne) Exec(ctx context.Context) error {
|
func (u *UsageLogUpsertOne) Exec(ctx context.Context) error {
|
||||||
if len(u.create.conflict) == 0 {
|
if len(u.create.conflict) == 0 {
|
||||||
@@ -2922,6 +2973,20 @@ func (u *UsageLogUpsertBulk) ClearMediaType() *UsageLogUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||||
|
func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.SetCacheTTLOverridden(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create.
|
||||||
|
func (u *UsageLogUpsertBulk) UpdateCacheTTLOverridden() *UsageLogUpsertBulk {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.UpdateCacheTTLOverridden()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *UsageLogUpsertBulk) Exec(ctx context.Context) error {
|
func (u *UsageLogUpsertBulk) Exec(ctx context.Context) error {
|
||||||
if u.create.err != nil {
|
if u.create.err != nil {
|
||||||
|
|||||||
@@ -632,6 +632,20 @@ func (_u *UsageLogUpdate) ClearMediaType() *UsageLogUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||||
|
func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate {
|
||||||
|
_u.mutation.SetCacheTTLOverridden(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil.
|
||||||
|
func (_u *UsageLogUpdate) SetNillableCacheTTLOverridden(v *bool) *UsageLogUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetCacheTTLOverridden(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetUser sets the "user" edge to the User entity.
|
// SetUser sets the "user" edge to the User entity.
|
||||||
func (_u *UsageLogUpdate) SetUser(v *User) *UsageLogUpdate {
|
func (_u *UsageLogUpdate) SetUser(v *User) *UsageLogUpdate {
|
||||||
return _u.SetUserID(v.ID)
|
return _u.SetUserID(v.ID)
|
||||||
@@ -925,6 +939,9 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if _u.mutation.MediaTypeCleared() {
|
if _u.mutation.MediaTypeCleared() {
|
||||||
_spec.ClearField(usagelog.FieldMediaType, field.TypeString)
|
_spec.ClearField(usagelog.FieldMediaType, field.TypeString)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
|
||||||
|
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
|
||||||
|
}
|
||||||
if _u.mutation.UserCleared() {
|
if _u.mutation.UserCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.M2O,
|
Rel: sqlgraph.M2O,
|
||||||
@@ -1690,6 +1707,20 @@ func (_u *UsageLogUpdateOne) ClearMediaType() *UsageLogUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||||
|
func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne {
|
||||||
|
_u.mutation.SetCacheTTLOverridden(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil.
|
||||||
|
func (_u *UsageLogUpdateOne) SetNillableCacheTTLOverridden(v *bool) *UsageLogUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetCacheTTLOverridden(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetUser sets the "user" edge to the User entity.
|
// SetUser sets the "user" edge to the User entity.
|
||||||
func (_u *UsageLogUpdateOne) SetUser(v *User) *UsageLogUpdateOne {
|
func (_u *UsageLogUpdateOne) SetUser(v *User) *UsageLogUpdateOne {
|
||||||
return _u.SetUserID(v.ID)
|
return _u.SetUserID(v.ID)
|
||||||
@@ -2013,6 +2044,9 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
|
|||||||
if _u.mutation.MediaTypeCleared() {
|
if _u.mutation.MediaTypeCleared() {
|
||||||
_spec.ClearField(usagelog.FieldMediaType, field.TypeString)
|
_spec.ClearField(usagelog.FieldMediaType, field.TypeString)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
|
||||||
|
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
|
||||||
|
}
|
||||||
if _u.mutation.UserCleared() {
|
if _u.mutation.UserCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.M2O,
|
Rel: sqlgraph.M2O,
|
||||||
|
|||||||
@@ -162,6 +162,8 @@ type TokenRefreshConfig struct {
|
|||||||
MaxRetries int `mapstructure:"max_retries"`
|
MaxRetries int `mapstructure:"max_retries"`
|
||||||
// 重试退避基础时间(秒)
|
// 重试退避基础时间(秒)
|
||||||
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
|
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
|
||||||
|
// 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token(默认关闭)
|
||||||
|
SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PricingConfig struct {
|
type PricingConfig struct {
|
||||||
@@ -269,17 +271,30 @@ type SoraConfig struct {
|
|||||||
|
|
||||||
// SoraClientConfig 直连 Sora 客户端配置
|
// SoraClientConfig 直连 Sora 客户端配置
|
||||||
type SoraClientConfig struct {
|
type SoraClientConfig struct {
|
||||||
BaseURL string `mapstructure:"base_url"`
|
BaseURL string `mapstructure:"base_url"`
|
||||||
TimeoutSeconds int `mapstructure:"timeout_seconds"`
|
TimeoutSeconds int `mapstructure:"timeout_seconds"`
|
||||||
MaxRetries int `mapstructure:"max_retries"`
|
MaxRetries int `mapstructure:"max_retries"`
|
||||||
PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
|
CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"`
|
||||||
MaxPollAttempts int `mapstructure:"max_poll_attempts"`
|
PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
|
||||||
RecentTaskLimit int `mapstructure:"recent_task_limit"`
|
MaxPollAttempts int `mapstructure:"max_poll_attempts"`
|
||||||
RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
|
RecentTaskLimit int `mapstructure:"recent_task_limit"`
|
||||||
Debug bool `mapstructure:"debug"`
|
RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
|
||||||
Headers map[string]string `mapstructure:"headers"`
|
Debug bool `mapstructure:"debug"`
|
||||||
UserAgent string `mapstructure:"user_agent"`
|
UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"`
|
||||||
DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
|
Headers map[string]string `mapstructure:"headers"`
|
||||||
|
UserAgent string `mapstructure:"user_agent"`
|
||||||
|
DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
|
||||||
|
CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraCurlCFFISidecarConfig Sora 专用 curl_cffi sidecar 配置
|
||||||
|
type SoraCurlCFFISidecarConfig struct {
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
BaseURL string `mapstructure:"base_url"`
|
||||||
|
Impersonate string `mapstructure:"impersonate"`
|
||||||
|
TimeoutSeconds int `mapstructure:"timeout_seconds"`
|
||||||
|
SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"`
|
||||||
|
SessionTTLSeconds int `mapstructure:"session_ttl_seconds"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// SoraStorageConfig 媒体存储配置
|
// SoraStorageConfig 媒体存储配置
|
||||||
@@ -1111,14 +1126,22 @@ func setDefaults() {
|
|||||||
viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend")
|
viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend")
|
||||||
viper.SetDefault("sora.client.timeout_seconds", 120)
|
viper.SetDefault("sora.client.timeout_seconds", 120)
|
||||||
viper.SetDefault("sora.client.max_retries", 3)
|
viper.SetDefault("sora.client.max_retries", 3)
|
||||||
|
viper.SetDefault("sora.client.cloudflare_challenge_cooldown_seconds", 900)
|
||||||
viper.SetDefault("sora.client.poll_interval_seconds", 2)
|
viper.SetDefault("sora.client.poll_interval_seconds", 2)
|
||||||
viper.SetDefault("sora.client.max_poll_attempts", 600)
|
viper.SetDefault("sora.client.max_poll_attempts", 600)
|
||||||
viper.SetDefault("sora.client.recent_task_limit", 50)
|
viper.SetDefault("sora.client.recent_task_limit", 50)
|
||||||
viper.SetDefault("sora.client.recent_task_limit_max", 200)
|
viper.SetDefault("sora.client.recent_task_limit_max", 200)
|
||||||
viper.SetDefault("sora.client.debug", false)
|
viper.SetDefault("sora.client.debug", false)
|
||||||
|
viper.SetDefault("sora.client.use_openai_token_provider", false)
|
||||||
viper.SetDefault("sora.client.headers", map[string]string{})
|
viper.SetDefault("sora.client.headers", map[string]string{})
|
||||||
viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||||
viper.SetDefault("sora.client.disable_tls_fingerprint", false)
|
viper.SetDefault("sora.client.disable_tls_fingerprint", false)
|
||||||
|
viper.SetDefault("sora.client.curl_cffi_sidecar.enabled", true)
|
||||||
|
viper.SetDefault("sora.client.curl_cffi_sidecar.base_url", "http://sora-curl-cffi-sidecar:8080")
|
||||||
|
viper.SetDefault("sora.client.curl_cffi_sidecar.impersonate", "chrome131")
|
||||||
|
viper.SetDefault("sora.client.curl_cffi_sidecar.timeout_seconds", 60)
|
||||||
|
viper.SetDefault("sora.client.curl_cffi_sidecar.session_reuse_enabled", true)
|
||||||
|
viper.SetDefault("sora.client.curl_cffi_sidecar.session_ttl_seconds", 3600)
|
||||||
|
|
||||||
viper.SetDefault("sora.storage.type", "local")
|
viper.SetDefault("sora.storage.type", "local")
|
||||||
viper.SetDefault("sora.storage.local_path", "")
|
viper.SetDefault("sora.storage.local_path", "")
|
||||||
@@ -1137,6 +1160,7 @@ func setDefaults() {
|
|||||||
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token)
|
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token)
|
||||||
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
||||||
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
||||||
|
viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token
|
||||||
|
|
||||||
// Gemini OAuth - configure via environment variables or config file
|
// Gemini OAuth - configure via environment variables or config file
|
||||||
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
|
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
|
||||||
@@ -1505,6 +1529,9 @@ func (c *Config) Validate() error {
|
|||||||
if c.Sora.Client.MaxRetries < 0 {
|
if c.Sora.Client.MaxRetries < 0 {
|
||||||
return fmt.Errorf("sora.client.max_retries must be non-negative")
|
return fmt.Errorf("sora.client.max_retries must be non-negative")
|
||||||
}
|
}
|
||||||
|
if c.Sora.Client.CloudflareChallengeCooldownSeconds < 0 {
|
||||||
|
return fmt.Errorf("sora.client.cloudflare_challenge_cooldown_seconds must be non-negative")
|
||||||
|
}
|
||||||
if c.Sora.Client.PollIntervalSeconds < 0 {
|
if c.Sora.Client.PollIntervalSeconds < 0 {
|
||||||
return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative")
|
return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative")
|
||||||
}
|
}
|
||||||
@@ -1521,6 +1548,18 @@ func (c *Config) Validate() error {
|
|||||||
c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit {
|
c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit {
|
||||||
c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit
|
c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit
|
||||||
}
|
}
|
||||||
|
if c.Sora.Client.CurlCFFISidecar.TimeoutSeconds < 0 {
|
||||||
|
return fmt.Errorf("sora.client.curl_cffi_sidecar.timeout_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Sora.Client.CurlCFFISidecar.SessionTTLSeconds < 0 {
|
||||||
|
return fmt.Errorf("sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if !c.Sora.Client.CurlCFFISidecar.Enabled {
|
||||||
|
return fmt.Errorf("sora.client.curl_cffi_sidecar.enabled must be true")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(c.Sora.Client.CurlCFFISidecar.BaseURL) == "" {
|
||||||
|
return fmt.Errorf("sora.client.curl_cffi_sidecar.base_url is required")
|
||||||
|
}
|
||||||
if c.Sora.Storage.MaxConcurrentDownloads < 0 {
|
if c.Sora.Storage.MaxConcurrentDownloads < 0 {
|
||||||
return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative")
|
return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1024,3 +1024,91 @@ func TestValidateConfigErrors(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSoraCurlCFFISidecarDefaults(t *testing.T) {
|
||||||
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cfg.Sora.Client.CurlCFFISidecar.Enabled {
|
||||||
|
t.Fatalf("Sora curl_cffi sidecar should be enabled by default")
|
||||||
|
}
|
||||||
|
if cfg.Sora.Client.CloudflareChallengeCooldownSeconds <= 0 {
|
||||||
|
t.Fatalf("Sora cloudflare challenge cooldown should be positive by default")
|
||||||
|
}
|
||||||
|
if cfg.Sora.Client.CurlCFFISidecar.BaseURL == "" {
|
||||||
|
t.Fatalf("Sora curl_cffi sidecar base_url should not be empty by default")
|
||||||
|
}
|
||||||
|
if cfg.Sora.Client.CurlCFFISidecar.Impersonate == "" {
|
||||||
|
t.Fatalf("Sora curl_cffi sidecar impersonate should not be empty by default")
|
||||||
|
}
|
||||||
|
if !cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled {
|
||||||
|
t.Fatalf("Sora curl_cffi sidecar session reuse should be enabled by default")
|
||||||
|
}
|
||||||
|
if cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds <= 0 {
|
||||||
|
t.Fatalf("Sora curl_cffi sidecar session ttl should be positive by default")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSoraCurlCFFISidecarRequired(t *testing.T) {
|
||||||
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Sora.Client.CurlCFFISidecar.Enabled = false
|
||||||
|
err = cfg.Validate()
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.enabled must be true") {
|
||||||
|
t.Fatalf("Validate() error = %v, want sidecar enabled error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSoraCurlCFFISidecarBaseURLRequired(t *testing.T) {
|
||||||
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Sora.Client.CurlCFFISidecar.BaseURL = " "
|
||||||
|
err = cfg.Validate()
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.base_url is required") {
|
||||||
|
t.Fatalf("Validate() error = %v, want sidecar base_url required error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSoraCurlCFFISidecarSessionTTLNonNegative(t *testing.T) {
|
||||||
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds = -1
|
||||||
|
err = cfg.Validate()
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") {
|
||||||
|
t.Fatalf("Validate() error = %v, want sidecar session ttl error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSoraCloudflareChallengeCooldownNonNegative(t *testing.T) {
|
||||||
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Sora.Client.CloudflareChallengeCooldownSeconds = -1
|
||||||
|
err = cfg.Validate()
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") {
|
||||||
|
t.Fatalf("Validate() error = %v, want cloudflare cooldown error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -341,7 +341,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc
|
|||||||
pageSize := dataPageCap
|
pageSize := dataPageCap
|
||||||
var out []service.Account
|
var out []service.Account
|
||||||
for {
|
for {
|
||||||
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search)
|
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -200,7 +200,12 @@ func (h *AccountHandler) List(c *gin.Context) {
|
|||||||
search = search[:100]
|
search = search[:100]
|
||||||
}
|
}
|
||||||
|
|
||||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
|
var groupID int64
|
||||||
|
if groupIDStr := c.Query("group"); groupIDStr != "" {
|
||||||
|
groupID, _ = strconv.ParseInt(groupIDStr, 10, 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
@@ -1433,6 +1438,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle Sora accounts
|
||||||
|
if account.Platform == service.PlatformSora {
|
||||||
|
response.Success(c, service.DefaultSoraModels(nil))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Handle Claude/Anthropic accounts
|
// Handle Claude/Anthropic accounts
|
||||||
// For OAuth and Setup-Token accounts: return default models
|
// For OAuth and Setup-Token accounts: return default models
|
||||||
if account.IsOAuth() {
|
if account.IsOAuth() {
|
||||||
@@ -1542,7 +1553,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
|
|||||||
accounts := make([]*service.Account, 0)
|
accounts := make([]*service.Account, 0)
|
||||||
|
|
||||||
if len(req.AccountIDs) == 0 {
|
if len(req.AccountIDs) == 0 {
|
||||||
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "")
|
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
|
|||||||
router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete)
|
router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete)
|
||||||
router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete)
|
router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete)
|
||||||
router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test)
|
router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test)
|
||||||
|
router.POST("/api/v1/admin/proxies/:id/quality-check", proxyHandler.CheckQuality)
|
||||||
router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats)
|
router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats)
|
||||||
router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts)
|
router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts)
|
||||||
|
|
||||||
@@ -208,6 +209,11 @@ func TestProxyHandlerEndpoints(t *testing.T) {
|
|||||||
router.ServeHTTP(rec, req)
|
router.ServeHTTP(rec, req)
|
||||||
require.Equal(t, http.StatusOK, rec.Code)
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/quality-check", nil)
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
|
||||||
rec = httptest.NewRecorder()
|
rec = httptest.NewRecorder()
|
||||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil)
|
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil)
|
||||||
router.ServeHTTP(rec, req)
|
router.ServeHTTP(rec, req)
|
||||||
|
|||||||
@@ -166,7 +166,7 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
|||||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) {
|
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
|
||||||
return s.accounts, int64(len(s.accounts)), nil
|
return s.accounts, int64(len(s.accounts)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -327,6 +327,27 @@ func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.Pr
|
|||||||
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
|
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*service.ProxyQualityCheckResult, error) {
|
||||||
|
return &service.ProxyQualityCheckResult{
|
||||||
|
ProxyID: id,
|
||||||
|
Score: 95,
|
||||||
|
Grade: "A",
|
||||||
|
Summary: "通过 5 项,告警 0 项,失败 0 项,挑战 0 项",
|
||||||
|
PassedCount: 5,
|
||||||
|
WarnCount: 0,
|
||||||
|
FailedCount: 0,
|
||||||
|
ChallengeCount: 0,
|
||||||
|
CheckedAt: time.Now().Unix(),
|
||||||
|
Items: []service.ProxyQualityCheckItem{
|
||||||
|
{Target: "base_connectivity", Status: "pass", Message: "ok"},
|
||||||
|
{Target: "openai", Status: "pass", HTTPStatus: 401},
|
||||||
|
{Target: "anthropic", Status: "pass", HTTPStatus: 401},
|
||||||
|
{Target: "gemini", Status: "pass", HTTPStatus: 200},
|
||||||
|
{Target: "sora", Status: "pass", HTTPStatus: 401},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
|
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
|
||||||
return s.redeems, int64(len(s.redeems)), nil
|
return s.redeems, int64(len(s.redeems)), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package admin
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
@@ -16,6 +17,13 @@ type OpenAIOAuthHandler struct {
|
|||||||
adminService service.AdminService
|
adminService service.AdminService
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func oauthPlatformFromPath(c *gin.Context) string {
|
||||||
|
if strings.Contains(c.FullPath(), "/admin/sora/") {
|
||||||
|
return service.PlatformSora
|
||||||
|
}
|
||||||
|
return service.PlatformOpenAI
|
||||||
|
}
|
||||||
|
|
||||||
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
|
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
|
||||||
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
|
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
|
||||||
return &OpenAIOAuthHandler{
|
return &OpenAIOAuthHandler{
|
||||||
@@ -52,6 +60,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
|||||||
type OpenAIExchangeCodeRequest struct {
|
type OpenAIExchangeCodeRequest struct {
|
||||||
SessionID string `json:"session_id" binding:"required"`
|
SessionID string `json:"session_id" binding:"required"`
|
||||||
Code string `json:"code" binding:"required"`
|
Code string `json:"code" binding:"required"`
|
||||||
|
State string `json:"state" binding:"required"`
|
||||||
RedirectURI string `json:"redirect_uri"`
|
RedirectURI string `json:"redirect_uri"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
}
|
}
|
||||||
@@ -68,6 +77,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
|
|||||||
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||||
SessionID: req.SessionID,
|
SessionID: req.SessionID,
|
||||||
Code: req.Code,
|
Code: req.Code,
|
||||||
|
State: req.State,
|
||||||
RedirectURI: req.RedirectURI,
|
RedirectURI: req.RedirectURI,
|
||||||
ProxyID: req.ProxyID,
|
ProxyID: req.ProxyID,
|
||||||
})
|
})
|
||||||
@@ -81,18 +91,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
|
|||||||
|
|
||||||
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
|
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
|
||||||
type OpenAIRefreshTokenRequest struct {
|
type OpenAIRefreshTokenRequest struct {
|
||||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
RT string `json:"rt"`
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshToken refreshes an OpenAI OAuth token
|
// RefreshToken refreshes an OpenAI OAuth token
|
||||||
// POST /api/v1/admin/openai/refresh-token
|
// POST /api/v1/admin/openai/refresh-token
|
||||||
|
// POST /api/v1/admin/sora/rt2at
|
||||||
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||||
var req OpenAIRefreshTokenRequest
|
var req OpenAIRefreshTokenRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
refreshToken := strings.TrimSpace(req.RefreshToken)
|
||||||
|
if refreshToken == "" {
|
||||||
|
refreshToken = strings.TrimSpace(req.RT)
|
||||||
|
}
|
||||||
|
if refreshToken == "" {
|
||||||
|
response.BadRequest(c, "refresh_token is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var proxyURL string
|
var proxyURL string
|
||||||
if req.ProxyID != nil {
|
if req.ProxyID != nil {
|
||||||
@@ -102,7 +123,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
|
tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, strings.TrimSpace(req.ClientID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
@@ -111,8 +132,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
|||||||
response.Success(c, tokenInfo)
|
response.Success(c, tokenInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshAccountToken refreshes token for a specific OpenAI account
|
// ExchangeSoraSessionToken exchanges Sora session token to access token
|
||||||
|
// POST /api/v1/admin/sora/st2at
|
||||||
|
func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) {
|
||||||
|
var req struct {
|
||||||
|
SessionToken string `json:"session_token"`
|
||||||
|
ST string `json:"st"`
|
||||||
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionToken := strings.TrimSpace(req.SessionToken)
|
||||||
|
if sessionToken == "" {
|
||||||
|
sessionToken = strings.TrimSpace(req.ST)
|
||||||
|
}
|
||||||
|
if sessionToken == "" {
|
||||||
|
response.BadRequest(c, "session_token is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, tokenInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshAccountToken refreshes token for a specific OpenAI/Sora account
|
||||||
// POST /api/v1/admin/openai/accounts/:id/refresh
|
// POST /api/v1/admin/openai/accounts/:id/refresh
|
||||||
|
// POST /api/v1/admin/sora/accounts/:id/refresh
|
||||||
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -127,9 +179,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure account is OpenAI platform
|
platform := oauthPlatformFromPath(c)
|
||||||
if !account.IsOpenAI() {
|
if account.Platform != platform {
|
||||||
response.BadRequest(c, "Account is not an OpenAI account")
|
response.BadRequest(c, "Account platform does not match OAuth endpoint")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,12 +219,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
|||||||
response.Success(c, dto.AccountFromService(updatedAccount))
|
response.Success(c, dto.AccountFromService(updatedAccount))
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
|
// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info
|
||||||
// POST /api/v1/admin/openai/create-from-oauth
|
// POST /api/v1/admin/openai/create-from-oauth
|
||||||
|
// POST /api/v1/admin/sora/create-from-oauth
|
||||||
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||||
var req struct {
|
var req struct {
|
||||||
SessionID string `json:"session_id" binding:"required"`
|
SessionID string `json:"session_id" binding:"required"`
|
||||||
Code string `json:"code" binding:"required"`
|
Code string `json:"code" binding:"required"`
|
||||||
|
State string `json:"state" binding:"required"`
|
||||||
RedirectURI string `json:"redirect_uri"`
|
RedirectURI string `json:"redirect_uri"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
@@ -189,6 +243,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
|||||||
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||||
SessionID: req.SessionID,
|
SessionID: req.SessionID,
|
||||||
Code: req.Code,
|
Code: req.Code,
|
||||||
|
State: req.State,
|
||||||
RedirectURI: req.RedirectURI,
|
RedirectURI: req.RedirectURI,
|
||||||
ProxyID: req.ProxyID,
|
ProxyID: req.ProxyID,
|
||||||
})
|
})
|
||||||
@@ -200,19 +255,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
|||||||
// Build credentials from token info
|
// Build credentials from token info
|
||||||
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||||
|
|
||||||
|
platform := oauthPlatformFromPath(c)
|
||||||
|
|
||||||
// Use email as default name if not provided
|
// Use email as default name if not provided
|
||||||
name := req.Name
|
name := req.Name
|
||||||
if name == "" && tokenInfo.Email != "" {
|
if name == "" && tokenInfo.Email != "" {
|
||||||
name = tokenInfo.Email
|
name = tokenInfo.Email
|
||||||
}
|
}
|
||||||
if name == "" {
|
if name == "" {
|
||||||
name = "OpenAI OAuth Account"
|
if platform == service.PlatformSora {
|
||||||
|
name = "Sora OAuth Account"
|
||||||
|
} else {
|
||||||
|
name = "OpenAI OAuth Account"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create account
|
// Create account
|
||||||
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
|
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
|
||||||
Name: name,
|
Name: name,
|
||||||
Platform: "openai",
|
Platform: platform,
|
||||||
Type: "oauth",
|
Type: "oauth",
|
||||||
Credentials: credentials,
|
Credentials: credentials,
|
||||||
ProxyID: req.ProxyID,
|
ProxyID: req.ProxyID,
|
||||||
|
|||||||
@@ -236,6 +236,24 @@ func (h *ProxyHandler) Test(c *gin.Context) {
|
|||||||
response.Success(c, result)
|
response.Success(c, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CheckQuality handles checking proxy quality across common AI targets.
|
||||||
|
// POST /api/v1/admin/proxies/:id/quality-check
|
||||||
|
func (h *ProxyHandler) CheckQuality(c *gin.Context) {
|
||||||
|
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid proxy ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.adminService.CheckProxyQuality(c.Request.Context(), proxyID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, result)
|
||||||
|
}
|
||||||
|
|
||||||
// GetStats handles getting proxy statistics
|
// GetStats handles getting proxy statistics
|
||||||
// GET /api/v1/admin/proxies/:id/stats
|
// GET /api/v1/admin/proxies/:id/stats
|
||||||
func (h *ProxyHandler) GetStats(c *gin.Context) {
|
func (h *ProxyHandler) GetStats(c *gin.Context) {
|
||||||
|
|||||||
@@ -214,6 +214,13 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
|||||||
enabled := true
|
enabled := true
|
||||||
out.EnableSessionIDMasking = &enabled
|
out.EnableSessionIDMasking = &enabled
|
||||||
}
|
}
|
||||||
|
// 缓存 TTL 强制替换
|
||||||
|
if a.IsCacheTTLOverrideEnabled() {
|
||||||
|
enabled := true
|
||||||
|
out.CacheTTLOverrideEnabled = &enabled
|
||||||
|
target := a.GetCacheTTLOverrideTarget()
|
||||||
|
out.CacheTTLOverrideTarget = &target
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return out
|
return out
|
||||||
@@ -296,6 +303,11 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
|
|||||||
CountryCode: p.CountryCode,
|
CountryCode: p.CountryCode,
|
||||||
Region: p.Region,
|
Region: p.Region,
|
||||||
City: p.City,
|
City: p.City,
|
||||||
|
QualityStatus: p.QualityStatus,
|
||||||
|
QualityScore: p.QualityScore,
|
||||||
|
QualityGrade: p.QualityGrade,
|
||||||
|
QualitySummary: p.QualitySummary,
|
||||||
|
QualityChecked: p.QualityChecked,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -402,6 +414,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
|||||||
ImageSize: l.ImageSize,
|
ImageSize: l.ImageSize,
|
||||||
MediaType: l.MediaType,
|
MediaType: l.MediaType,
|
||||||
UserAgent: l.UserAgent,
|
UserAgent: l.UserAgent,
|
||||||
|
CacheTTLOverridden: l.CacheTTLOverridden,
|
||||||
CreatedAt: l.CreatedAt,
|
CreatedAt: l.CreatedAt,
|
||||||
User: UserFromServiceShallow(l.User),
|
User: UserFromServiceShallow(l.User),
|
||||||
APIKey: APIKeyFromService(l.APIKey),
|
APIKey: APIKeyFromService(l.APIKey),
|
||||||
|
|||||||
@@ -156,6 +156,11 @@ type Account struct {
|
|||||||
// 从 extra 字段提取,方便前端显示和编辑
|
// 从 extra 字段提取,方便前端显示和编辑
|
||||||
EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"`
|
EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"`
|
||||||
|
|
||||||
|
// 缓存 TTL 强制替换(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||||
|
// 启用后将所有 cache creation tokens 归入指定的 TTL 类型计费
|
||||||
|
CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"`
|
||||||
|
CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
|
||||||
|
|
||||||
Proxy *Proxy `json:"proxy,omitempty"`
|
Proxy *Proxy `json:"proxy,omitempty"`
|
||||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||||
|
|
||||||
@@ -197,6 +202,11 @@ type ProxyWithAccountCount struct {
|
|||||||
CountryCode string `json:"country_code,omitempty"`
|
CountryCode string `json:"country_code,omitempty"`
|
||||||
Region string `json:"region,omitempty"`
|
Region string `json:"region,omitempty"`
|
||||||
City string `json:"city,omitempty"`
|
City string `json:"city,omitempty"`
|
||||||
|
QualityStatus string `json:"quality_status,omitempty"`
|
||||||
|
QualityScore *int `json:"quality_score,omitempty"`
|
||||||
|
QualityGrade string `json:"quality_grade,omitempty"`
|
||||||
|
QualitySummary string `json:"quality_summary,omitempty"`
|
||||||
|
QualityChecked *int64 `json:"quality_checked,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProxyAccountSummary struct {
|
type ProxyAccountSummary struct {
|
||||||
@@ -280,6 +290,9 @@ type UsageLog struct {
|
|||||||
// User-Agent
|
// User-Agent
|
||||||
UserAgent *string `json:"user_agent"`
|
UserAgent *string `json:"user_agent"`
|
||||||
|
|
||||||
|
// Cache TTL Override 标记
|
||||||
|
CacheTTLOverridden bool `json:"cache_ttl_overridden"`
|
||||||
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
|
||||||
User *User `json:"user,omitempty"`
|
User *User `json:"user,omitempty"`
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -20,6 +21,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@@ -35,6 +37,7 @@ type SoraGatewayHandler struct {
|
|||||||
concurrencyHelper *ConcurrencyHelper
|
concurrencyHelper *ConcurrencyHelper
|
||||||
maxAccountSwitches int
|
maxAccountSwitches int
|
||||||
streamMode string
|
streamMode string
|
||||||
|
soraTLSEnabled bool
|
||||||
soraMediaSigningKey string
|
soraMediaSigningKey string
|
||||||
soraMediaRoot string
|
soraMediaRoot string
|
||||||
}
|
}
|
||||||
@@ -50,6 +53,7 @@ func NewSoraGatewayHandler(
|
|||||||
pingInterval := time.Duration(0)
|
pingInterval := time.Duration(0)
|
||||||
maxAccountSwitches := 3
|
maxAccountSwitches := 3
|
||||||
streamMode := "force"
|
streamMode := "force"
|
||||||
|
soraTLSEnabled := true
|
||||||
signKey := ""
|
signKey := ""
|
||||||
mediaRoot := "/app/data/sora"
|
mediaRoot := "/app/data/sora"
|
||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
@@ -60,6 +64,7 @@ func NewSoraGatewayHandler(
|
|||||||
if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" {
|
if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" {
|
||||||
streamMode = mode
|
streamMode = mode
|
||||||
}
|
}
|
||||||
|
soraTLSEnabled = !cfg.Sora.Client.DisableTLSFingerprint
|
||||||
signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey)
|
signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey)
|
||||||
if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" {
|
if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" {
|
||||||
mediaRoot = root
|
mediaRoot = root
|
||||||
@@ -72,6 +77,7 @@ func NewSoraGatewayHandler(
|
|||||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||||
maxAccountSwitches: maxAccountSwitches,
|
maxAccountSwitches: maxAccountSwitches,
|
||||||
streamMode: strings.ToLower(streamMode),
|
streamMode: strings.ToLower(streamMode),
|
||||||
|
soraTLSEnabled: soraTLSEnabled,
|
||||||
soraMediaSigningKey: signKey,
|
soraMediaSigningKey: signKey,
|
||||||
soraMediaRoot: mediaRoot,
|
soraMediaRoot: mediaRoot,
|
||||||
}
|
}
|
||||||
@@ -212,6 +218,8 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
switchCount := 0
|
switchCount := 0
|
||||||
failedAccountIDs := make(map[int64]struct{})
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
lastFailoverStatus := 0
|
lastFailoverStatus := 0
|
||||||
|
var lastFailoverBody []byte
|
||||||
|
var lastFailoverHeaders http.Header
|
||||||
|
|
||||||
for {
|
for {
|
||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
|
||||||
@@ -224,11 +232,31 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
|
||||||
|
fields := []zap.Field{
|
||||||
|
zap.Int("last_upstream_status", lastFailoverStatus),
|
||||||
|
}
|
||||||
|
if rayID != "" {
|
||||||
|
fields = append(fields, zap.String("last_upstream_cf_ray", rayID))
|
||||||
|
}
|
||||||
|
if mitigated != "" {
|
||||||
|
fields = append(fields, zap.String("last_upstream_cf_mitigated", mitigated))
|
||||||
|
}
|
||||||
|
if contentType != "" {
|
||||||
|
fields = append(fields, zap.String("last_upstream_content_type", contentType))
|
||||||
|
}
|
||||||
|
reqLog.Warn("sora.failover_exhausted_no_available_accounts", fields...)
|
||||||
|
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
account := selection.Account
|
account := selection.Account
|
||||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||||
|
proxyBound := account.ProxyID != nil
|
||||||
|
proxyID := int64(0)
|
||||||
|
if account.ProxyID != nil {
|
||||||
|
proxyID = *account.ProxyID
|
||||||
|
}
|
||||||
|
tlsFingerprintEnabled := h.soraTLSEnabled
|
||||||
|
|
||||||
accountReleaseFunc := selection.ReleaseFunc
|
accountReleaseFunc := selection.ReleaseFunc
|
||||||
if !selection.Acquired {
|
if !selection.Acquired {
|
||||||
@@ -239,10 +267,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
accountWaitCounted := false
|
accountWaitCounted := false
|
||||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
reqLog.Warn("sora.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
reqLog.Warn("sora.account_wait_counter_increment_failed",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Int64("proxy_id", proxyID),
|
||||||
|
zap.Bool("proxy_bound", proxyBound),
|
||||||
|
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
} else if !canWait {
|
} else if !canWait {
|
||||||
reqLog.Info("sora.account_wait_queue_full",
|
reqLog.Info("sora.account_wait_queue_full",
|
||||||
zap.Int64("account_id", account.ID),
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Int64("proxy_id", proxyID),
|
||||||
|
zap.Bool("proxy_bound", proxyBound),
|
||||||
|
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
||||||
)
|
)
|
||||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||||
@@ -266,7 +303,13 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
&streamStarted,
|
&streamStarted,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
reqLog.Warn("sora.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
reqLog.Warn("sora.account_slot_acquire_failed",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Int64("proxy_id", proxyID),
|
||||||
|
zap.Bool("proxy_bound", proxyBound),
|
||||||
|
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -287,20 +330,67 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
failedAccountIDs[account.ID] = struct{}{}
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
if switchCount >= maxAccountSwitches {
|
if switchCount >= maxAccountSwitches {
|
||||||
lastFailoverStatus = failoverErr.StatusCode
|
lastFailoverStatus = failoverErr.StatusCode
|
||||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
|
||||||
|
lastFailoverBody = failoverErr.ResponseBody
|
||||||
|
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
|
||||||
|
fields := []zap.Field{
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Int64("proxy_id", proxyID),
|
||||||
|
zap.Bool("proxy_bound", proxyBound),
|
||||||
|
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||||
|
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||||
|
zap.Int("switch_count", switchCount),
|
||||||
|
zap.Int("max_switches", maxAccountSwitches),
|
||||||
|
}
|
||||||
|
if rayID != "" {
|
||||||
|
fields = append(fields, zap.String("upstream_cf_ray", rayID))
|
||||||
|
}
|
||||||
|
if mitigated != "" {
|
||||||
|
fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
|
||||||
|
}
|
||||||
|
if contentType != "" {
|
||||||
|
fields = append(fields, zap.String("upstream_content_type", contentType))
|
||||||
|
}
|
||||||
|
reqLog.Warn("sora.upstream_failover_exhausted", fields...)
|
||||||
|
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
lastFailoverStatus = failoverErr.StatusCode
|
lastFailoverStatus = failoverErr.StatusCode
|
||||||
|
lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
|
||||||
|
lastFailoverBody = failoverErr.ResponseBody
|
||||||
switchCount++
|
switchCount++
|
||||||
reqLog.Warn("sora.upstream_failover_switching",
|
upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody)
|
||||||
|
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
|
||||||
|
fields := []zap.Field{
|
||||||
zap.Int64("account_id", account.ID),
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Int64("proxy_id", proxyID),
|
||||||
|
zap.Bool("proxy_bound", proxyBound),
|
||||||
|
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||||
|
zap.String("upstream_error_code", upstreamErrCode),
|
||||||
|
zap.String("upstream_error_message", upstreamErrMsg),
|
||||||
zap.Int("switch_count", switchCount),
|
zap.Int("switch_count", switchCount),
|
||||||
zap.Int("max_switches", maxAccountSwitches),
|
zap.Int("max_switches", maxAccountSwitches),
|
||||||
)
|
}
|
||||||
|
if rayID != "" {
|
||||||
|
fields = append(fields, zap.String("upstream_cf_ray", rayID))
|
||||||
|
}
|
||||||
|
if mitigated != "" {
|
||||||
|
fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
|
||||||
|
}
|
||||||
|
if contentType != "" {
|
||||||
|
fields = append(fields, zap.String("upstream_content_type", contentType))
|
||||||
|
}
|
||||||
|
reqLog.Warn("sora.upstream_failover_switching", fields...)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
reqLog.Error("sora.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
reqLog.Error("sora.forward_failed",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Int64("proxy_id", proxyID),
|
||||||
|
zap.Bool("proxy_bound", proxyBound),
|
||||||
|
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -331,6 +421,9 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
}(result, account, userAgent, clientIP)
|
}(result, account, userAgent, clientIP)
|
||||||
reqLog.Debug("sora.request_completed",
|
reqLog.Debug("sora.request_completed",
|
||||||
zap.Int64("account_id", account.ID),
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Int64("proxy_id", proxyID),
|
||||||
|
zap.Bool("proxy_bound", proxyBound),
|
||||||
|
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||||
zap.Int("switch_count", switchCount),
|
zap.Int("switch_count", switchCount),
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@@ -360,17 +453,41 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s
|
|||||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
|
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
|
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) {
|
||||||
|
if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) {
|
||||||
|
baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode)
|
||||||
|
return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
|
||||||
|
if strings.EqualFold(upstreamCode, "cf_shield_429") {
|
||||||
|
baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
|
||||||
|
return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
|
||||||
|
}
|
||||||
|
if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) {
|
||||||
|
switch statusCode {
|
||||||
|
case 401, 403, 404, 500, 502, 503, 504:
|
||||||
|
return http.StatusBadGateway, "upstream_error", upstreamMessage
|
||||||
|
case 429:
|
||||||
|
return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
switch statusCode {
|
switch statusCode {
|
||||||
case 401:
|
case 401:
|
||||||
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
|
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
|
||||||
case 403:
|
case 403:
|
||||||
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
|
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
|
||||||
|
case 404:
|
||||||
|
if strings.EqualFold(upstreamCode, "unsupported_country_code") {
|
||||||
|
return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator"
|
||||||
|
}
|
||||||
|
return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator"
|
||||||
case 429:
|
case 429:
|
||||||
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
|
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
|
||||||
case 529:
|
case 529:
|
||||||
@@ -382,11 +499,67 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, stri
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cloneHTTPHeaders(headers http.Header) http.Header {
|
||||||
|
if headers == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return headers.Clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractSoraFailoverHeaderInsights(headers http.Header, body []byte) (rayID, mitigated, contentType string) {
|
||||||
|
if headers != nil {
|
||||||
|
mitigated = strings.TrimSpace(headers.Get("cf-mitigated"))
|
||||||
|
contentType = strings.TrimSpace(headers.Get("content-type"))
|
||||||
|
if contentType == "" {
|
||||||
|
contentType = strings.TrimSpace(headers.Get("Content-Type"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rayID = soraerror.ExtractCloudflareRayID(headers, body)
|
||||||
|
return rayID, mitigated, contentType
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
|
||||||
|
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
|
||||||
|
message = strings.TrimSpace(message)
|
||||||
|
if message == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if statusCode == http.StatusForbidden || statusCode == http.StatusTooManyRequests {
|
||||||
|
lower := strings.ToLower(message)
|
||||||
|
if strings.Contains(lower, "<html") || strings.Contains(lower, "<!doctype html") || strings.Contains(lower, "window._cf_chl_opt") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatSoraCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
|
||||||
|
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
|
||||||
|
return soraerror.ExtractUpstreamErrorCodeAndMessage(body)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||||
if streamStarted {
|
if streamStarted {
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
if ok {
|
if ok {
|
||||||
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
errorData := map[string]any{
|
||||||
|
"error": map[string]string{
|
||||||
|
"type": errType,
|
||||||
|
"message": message,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
jsonBytes, err := json.Marshal(errorData)
|
||||||
|
if err != nil {
|
||||||
|
_ = c.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
|
||||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||||
_ = c.Error(err)
|
_ = c.Error(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,6 +43,48 @@ func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.A
|
|||||||
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
|
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
|
||||||
return "task-video", nil
|
return "task-video", nil
|
||||||
}
|
}
|
||||||
|
func (s *stubSoraClient) CreateStoryboardTask(ctx context.Context, account *service.Account, req service.SoraStoryboardRequest) (string, error) {
|
||||||
|
return "task-video", nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClient) UploadCharacterVideo(ctx context.Context, account *service.Account, data []byte) (string, error) {
|
||||||
|
return "cameo-1", nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClient) GetCameoStatus(ctx context.Context, account *service.Account, cameoID string) (*service.SoraCameoStatus, error) {
|
||||||
|
return &service.SoraCameoStatus{
|
||||||
|
Status: "finalized",
|
||||||
|
StatusMessage: "Completed",
|
||||||
|
DisplayNameHint: "Character",
|
||||||
|
UsernameHint: "user.character",
|
||||||
|
ProfileAssetURL: "https://example.com/avatar.webp",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClient) DownloadCharacterImage(ctx context.Context, account *service.Account, imageURL string) ([]byte, error) {
|
||||||
|
return []byte("avatar"), nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClient) UploadCharacterImage(ctx context.Context, account *service.Account, data []byte) (string, error) {
|
||||||
|
return "asset-pointer", nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClient) FinalizeCharacter(ctx context.Context, account *service.Account, req service.SoraCharacterFinalizeRequest) (string, error) {
|
||||||
|
return "character-1", nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClient) SetCharacterPublic(ctx context.Context, account *service.Account, cameoID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClient) DeleteCharacter(ctx context.Context, account *service.Account, characterID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClient) PostVideoForWatermarkFree(ctx context.Context, account *service.Account, generationID string) (string, error) {
|
||||||
|
return "s_post", nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClient) DeletePost(ctx context.Context, account *service.Account, postID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClient) GetWatermarkFreeURLCustom(ctx context.Context, account *service.Account, parseURL, parseToken, postID string) (string, error) {
|
||||||
|
return "https://example.com/no-watermark.mp4", nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) {
|
||||||
|
return "enhanced prompt", nil
|
||||||
|
}
|
||||||
func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) {
|
func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) {
|
||||||
return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
|
return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
|
||||||
}
|
}
|
||||||
@@ -88,7 +130,7 @@ func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error
|
|||||||
func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
|
func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
|
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
|
func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||||||
@@ -495,3 +537,152 @@ func TestGenerateOpenAISessionHash_WithBody(t *testing.T) {
|
|||||||
require.NotEmpty(t, hash3)
|
require.NotEmpty(t, hash3)
|
||||||
require.NotEqual(t, hash, hash3) // 不同来源应产生不同 hash
|
require.NotEqual(t, hash, hash3) // 不同来源应产生不同 hash
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSoraHandleStreamingAwareError_JSONEscaping(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
errType string
|
||||||
|
message string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "包含双引号",
|
||||||
|
errType: "upstream_error",
|
||||||
|
message: `upstream returned "invalid" payload`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "包含换行和制表符",
|
||||||
|
errType: "rate_limit_error",
|
||||||
|
message: "line1\nline2\ttab",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "包含反斜杠",
|
||||||
|
errType: "upstream_error",
|
||||||
|
message: `path C:\Users\test\file.txt not found`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
h := &SoraGatewayHandler{}
|
||||||
|
h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true)
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
require.True(t, strings.HasPrefix(body, "event: error\n"), "应以 SSE error 事件开头")
|
||||||
|
require.True(t, strings.HasSuffix(body, "\n\n"), "应以 SSE 结束分隔符结尾")
|
||||||
|
|
||||||
|
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
|
||||||
|
require.Len(t, lines, 2, "SSE 错误事件应包含 event 行和 data 行")
|
||||||
|
require.Equal(t, "event: error", lines[0])
|
||||||
|
require.True(t, strings.HasPrefix(lines[1], "data: "), "第二行应为 data 前缀")
|
||||||
|
|
||||||
|
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||||
|
var parsed map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed), "data 行必须是合法 JSON")
|
||||||
|
|
||||||
|
errorObj, ok := parsed["error"].(map[string]any)
|
||||||
|
require.True(t, ok, "JSON 中应包含 error 对象")
|
||||||
|
require.Equal(t, tt.errType, errorObj["type"])
|
||||||
|
require.Equal(t, tt.message, errorObj["message"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
h := &SoraGatewayHandler{}
|
||||||
|
resp := []byte(`{"error":{"message":"invalid \"prompt\"\nline2","code":"bad_request"}}`)
|
||||||
|
h.handleFailoverExhausted(c, http.StatusBadGateway, nil, resp, true)
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
require.True(t, strings.HasPrefix(body, "event: error\n"))
|
||||||
|
require.True(t, strings.HasSuffix(body, "\n\n"))
|
||||||
|
|
||||||
|
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
|
||||||
|
require.Len(t, lines, 2)
|
||||||
|
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||||
|
|
||||||
|
var parsed map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
||||||
|
|
||||||
|
errorObj, ok := parsed["error"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "upstream_error", errorObj["type"])
|
||||||
|
require.Equal(t, "invalid \"prompt\"\nline2", errorObj["message"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraHandleFailoverExhausted_CloudflareChallengeIncludesRay(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
|
||||||
|
body := []byte(`<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script></body></html>`)
|
||||||
|
|
||||||
|
h := &SoraGatewayHandler{}
|
||||||
|
h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true)
|
||||||
|
|
||||||
|
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
|
||||||
|
require.Len(t, lines, 2)
|
||||||
|
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||||
|
|
||||||
|
var parsed map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
||||||
|
|
||||||
|
errorObj, ok := parsed["error"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "upstream_error", errorObj["type"])
|
||||||
|
msg, _ := errorObj["message"].(string)
|
||||||
|
require.Contains(t, msg, "Cloudflare challenge")
|
||||||
|
require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
|
||||||
|
body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`)
|
||||||
|
|
||||||
|
h := &SoraGatewayHandler{}
|
||||||
|
h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true)
|
||||||
|
|
||||||
|
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
|
||||||
|
require.Len(t, lines, 2)
|
||||||
|
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||||
|
|
||||||
|
var parsed map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
||||||
|
|
||||||
|
errorObj, ok := parsed["error"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "rate_limit_error", errorObj["type"])
|
||||||
|
msg, _ := errorObj["message"].(string)
|
||||||
|
require.Contains(t, msg, "Cloudflare shield")
|
||||||
|
require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractSoraFailoverHeaderInsights(t *testing.T) {
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("cf-mitigated", "challenge")
|
||||||
|
headers.Set("content-type", "text/html")
|
||||||
|
body := []byte(`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`)
|
||||||
|
|
||||||
|
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(headers, body)
|
||||||
|
require.Equal(t, "9cff2d62d83bb98d", rayID)
|
||||||
|
require.Equal(t, "challenge", mitigated)
|
||||||
|
require.Equal(t, "text/html", contentType)
|
||||||
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ const (
|
|||||||
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
|
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
|
||||||
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
|
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
|
||||||
BetaTokenCounting = "token-counting-2024-11-01"
|
BetaTokenCounting = "token-counting-2024-11-01"
|
||||||
|
BetaContext1M = "context-1m-2025-08-07"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||||
@@ -77,6 +78,12 @@ var DefaultModels = []Model{
|
|||||||
DisplayName: "Claude Opus 4.6",
|
DisplayName: "Claude Opus 4.6",
|
||||||
CreatedAt: "2026-02-06T00:00:00Z",
|
CreatedAt: "2026-02-06T00:00:00Z",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "claude-sonnet-4-6",
|
||||||
|
Type: "model",
|
||||||
|
DisplayName: "Claude Sonnet 4.6",
|
||||||
|
CreatedAt: "2026-02-18T00:00:00Z",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-sonnet-4-5-20250929",
|
ID: "claude-sonnet-4-5-20250929",
|
||||||
Type: "model",
|
Type: "model",
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ import (
|
|||||||
const (
|
const (
|
||||||
// OAuth Client ID for OpenAI (Codex CLI official)
|
// OAuth Client ID for OpenAI (Codex CLI official)
|
||||||
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||||
|
// OAuth Client ID for Sora mobile flow (aligned with sora2api)
|
||||||
|
SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK"
|
||||||
|
|
||||||
// OAuth endpoints
|
// OAuth endpoints
|
||||||
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
|
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
|
||||||
|
|||||||
@@ -435,10 +435,10 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
|
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
|
||||||
return r.ListWithFilters(ctx, params, "", "", "", "")
|
return r.ListWithFilters(ctx, params, "", "", "", "", 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
|
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
|
||||||
q := r.client.Account.Query()
|
q := r.client.Account.Query()
|
||||||
|
|
||||||
if platform != "" {
|
if platform != "" {
|
||||||
@@ -458,6 +458,9 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
|
|||||||
if search != "" {
|
if search != "" {
|
||||||
q = q.Where(dbaccount.NameContainsFold(search))
|
q = q.Where(dbaccount.NameContainsFold(search))
|
||||||
}
|
}
|
||||||
|
if groupID > 0 {
|
||||||
|
q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID)))
|
||||||
|
}
|
||||||
|
|
||||||
total, err := q.Count(ctx)
|
total, err := q.Count(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -238,7 +238,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
|||||||
|
|
||||||
tt.setup(client)
|
tt.setup(client)
|
||||||
|
|
||||||
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search)
|
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, 0)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Len(accounts, tt.wantCount)
|
s.Require().Len(accounts, tt.wantCount)
|
||||||
if tt.validate != nil {
|
if tt.validate != nil {
|
||||||
@@ -305,7 +305,7 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
|
|||||||
s.Require().Len(got.Groups, 1, "expected Groups to be populated")
|
s.Require().Len(got.Groups, 1, "expected Groups to be populated")
|
||||||
s.Require().Equal(group.ID, got.Groups[0].ID)
|
s.Require().Equal(group.ID, got.Groups[0].ID)
|
||||||
|
|
||||||
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc")
|
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0)
|
||||||
s.Require().NoError(err, "ListWithFilters")
|
s.Require().NoError(err, "ListWithFilters")
|
||||||
s.Require().Equal(int64(1), page.Total)
|
s.Require().Equal(int64(1), page.Total)
|
||||||
s.Require().Len(accounts, 1)
|
s.Require().Len(accounts, 1)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
@@ -56,12 +57,49 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||||
|
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
|
||||||
|
if strings.TrimSpace(clientID) != "" {
|
||||||
|
return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, strings.TrimSpace(clientID))
|
||||||
|
}
|
||||||
|
|
||||||
|
clientIDs := []string{
|
||||||
|
openai.ClientID,
|
||||||
|
openai.SoraClientID,
|
||||||
|
}
|
||||||
|
seen := make(map[string]struct{}, len(clientIDs))
|
||||||
|
var lastErr error
|
||||||
|
for _, clientID := range clientIDs {
|
||||||
|
clientID = strings.TrimSpace(clientID)
|
||||||
|
if clientID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[clientID]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[clientID] = struct{}{}
|
||||||
|
|
||||||
|
tokenResp, err := s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
|
||||||
|
if err == nil {
|
||||||
|
return tokenResp, nil
|
||||||
|
}
|
||||||
|
lastErr = err
|
||||||
|
}
|
||||||
|
if lastErr != nil {
|
||||||
|
return nil, lastErr
|
||||||
|
}
|
||||||
|
return nil, infraerrors.New(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) {
|
||||||
client := createOpenAIReqClient(proxyURL)
|
client := createOpenAIReqClient(proxyURL)
|
||||||
|
|
||||||
formData := url.Values{}
|
formData := url.Values{}
|
||||||
formData.Set("grant_type", "refresh_token")
|
formData.Set("grant_type", "refresh_token")
|
||||||
formData.Set("refresh_token", refreshToken)
|
formData.Set("refresh_token", refreshToken)
|
||||||
formData.Set("client_id", openai.ClientID)
|
formData.Set("client_id", clientID)
|
||||||
formData.Set("scope", openai.RefreshScopes)
|
formData.Set("scope", openai.RefreshScopes)
|
||||||
|
|
||||||
var tokenResp openai.TokenResponse
|
var tokenResp openai.TokenResponse
|
||||||
|
|||||||
@@ -136,6 +136,60 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
|
|||||||
require.Equal(s.T(), "rt2", resp.RefreshToken)
|
require.Equal(s.T(), "rt2", resp.RefreshToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() {
|
||||||
|
var seenClientIDs []string
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
clientID := r.PostForm.Get("client_id")
|
||||||
|
seenClientIDs = append(seenClientIDs, clientID)
|
||||||
|
if clientID == openai.ClientID {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, _ = io.WriteString(w, "invalid_grant")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if clientID == openai.SoraClientID {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
}))
|
||||||
|
|
||||||
|
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
|
||||||
|
require.NoError(s.T(), err, "RefreshToken")
|
||||||
|
require.Equal(s.T(), "at-sora", resp.AccessToken)
|
||||||
|
require.Equal(s.T(), "rt-sora", resp.RefreshToken)
|
||||||
|
require.Equal(s.T(), []string{openai.ClientID, openai.SoraClientID}, seenClientIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
|
||||||
|
const customClientID = "custom-client-id"
|
||||||
|
var seenClientIDs []string
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
clientID := r.PostForm.Get("client_id")
|
||||||
|
seenClientIDs = append(seenClientIDs, clientID)
|
||||||
|
if clientID != customClientID {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = io.WriteString(w, `{"access_token":"at-custom","refresh_token":"rt-custom","token_type":"bearer","expires_in":3600}`)
|
||||||
|
}))
|
||||||
|
|
||||||
|
resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", customClientID)
|
||||||
|
require.NoError(s.T(), err, "RefreshTokenWithClientID")
|
||||||
|
require.Equal(s.T(), "at-custom", resp.AccessToken)
|
||||||
|
require.Equal(s.T(), "rt-custom", resp.RefreshToken)
|
||||||
|
require.Equal(s.T(), []string{customClientID}, seenClientIDs)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
|
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
|
||||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import (
|
|||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
)
|
)
|
||||||
|
|
||||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, created_at"
|
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at"
|
||||||
|
|
||||||
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
|
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
|
||||||
var dateFormatWhitelist = map[string]string{
|
var dateFormatWhitelist = map[string]string{
|
||||||
@@ -132,6 +132,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
|||||||
image_size,
|
image_size,
|
||||||
media_type,
|
media_type,
|
||||||
reasoning_effort,
|
reasoning_effort,
|
||||||
|
cache_ttl_overridden,
|
||||||
created_at
|
created_at
|
||||||
) VALUES (
|
) VALUES (
|
||||||
$1, $2, $3, $4, $5,
|
$1, $2, $3, $4, $5,
|
||||||
@@ -139,7 +140,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
|||||||
$8, $9, $10, $11,
|
$8, $9, $10, $11,
|
||||||
$12, $13,
|
$12, $13,
|
||||||
$14, $15, $16, $17, $18, $19,
|
$14, $15, $16, $17, $18, $19,
|
||||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32
|
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33
|
||||||
)
|
)
|
||||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
RETURNING id, created_at
|
RETURNING id, created_at
|
||||||
@@ -192,6 +193,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
|||||||
imageSize,
|
imageSize,
|
||||||
mediaType,
|
mediaType,
|
||||||
reasoningEffort,
|
reasoningEffort,
|
||||||
|
log.CacheTTLOverridden,
|
||||||
createdAt,
|
createdAt,
|
||||||
}
|
}
|
||||||
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
|
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
|
||||||
@@ -2221,6 +2223,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
imageSize sql.NullString
|
imageSize sql.NullString
|
||||||
mediaType sql.NullString
|
mediaType sql.NullString
|
||||||
reasoningEffort sql.NullString
|
reasoningEffort sql.NullString
|
||||||
|
cacheTTLOverridden bool
|
||||||
createdAt time.Time
|
createdAt time.Time
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -2257,6 +2260,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
&imageSize,
|
&imageSize,
|
||||||
&mediaType,
|
&mediaType,
|
||||||
&reasoningEffort,
|
&reasoningEffort,
|
||||||
|
&cacheTTLOverridden,
|
||||||
&createdAt,
|
&createdAt,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -2285,6 +2289,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
BillingType: int8(billingType),
|
BillingType: int8(billingType),
|
||||||
Stream: stream,
|
Stream: stream,
|
||||||
ImageCount: imageCount,
|
ImageCount: imageCount,
|
||||||
|
CacheTTLOverridden: cacheTTLOverridden,
|
||||||
CreatedAt: createdAt,
|
CreatedAt: createdAt,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -406,6 +406,7 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"image_count": 0,
|
"image_count": 0,
|
||||||
"image_size": null,
|
"image_size": null,
|
||||||
"media_type": null,
|
"media_type": null,
|
||||||
|
"cache_ttl_overridden": false,
|
||||||
"created_at": "2025-01-02T03:04:05Z",
|
"created_at": "2025-01-02T03:04:05Z",
|
||||||
"user_agent": null
|
"user_agent": null
|
||||||
}
|
}
|
||||||
@@ -945,7 +946,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination
|
|||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
|
func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -50,6 +50,19 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
allowedSet[origin] = struct{}{}
|
allowedSet[origin] = struct{}{}
|
||||||
}
|
}
|
||||||
|
allowHeaders := []string{
|
||||||
|
"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization",
|
||||||
|
"accept", "origin", "Cache-Control", "X-Requested-With", "X-API-Key",
|
||||||
|
}
|
||||||
|
// OpenAI Node SDK 会发送 x-stainless-* 请求头,需在 CORS 中显式放行。
|
||||||
|
openAIProperties := []string{
|
||||||
|
"lang", "package-version", "os", "arch", "retry-count", "runtime",
|
||||||
|
"runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval", "timeout",
|
||||||
|
}
|
||||||
|
for _, prop := range openAIProperties {
|
||||||
|
allowHeaders = append(allowHeaders, "x-stainless-"+prop)
|
||||||
|
}
|
||||||
|
allowHeadersValue := strings.Join(allowHeaders, ", ")
|
||||||
|
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
origin := strings.TrimSpace(c.GetHeader("Origin"))
|
origin := strings.TrimSpace(c.GetHeader("Origin"))
|
||||||
@@ -68,12 +81,11 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
|
|||||||
if allowCredentials {
|
if allowCredentials {
|
||||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
|
c.Writer.Header().Set("Access-Control-Allow-Headers", allowHeadersValue)
|
||||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
|
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
|
||||||
c.Writer.Header().Set("Access-Control-Expose-Headers", "ETag")
|
c.Writer.Header().Set("Access-Control-Expose-Headers", "ETag")
|
||||||
c.Writer.Header().Set("Access-Control-Max-Age", "86400")
|
c.Writer.Header().Set("Access-Control-Max-Age", "86400")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理预检请求
|
// 处理预检请求
|
||||||
if c.Request.Method == http.MethodOptions {
|
if c.Request.Method == http.MethodOptions {
|
||||||
if originAllowed {
|
if originAllowed {
|
||||||
|
|||||||
@@ -34,6 +34,8 @@ func RegisterAdminRoutes(
|
|||||||
|
|
||||||
// OpenAI OAuth
|
// OpenAI OAuth
|
||||||
registerOpenAIOAuthRoutes(admin, h)
|
registerOpenAIOAuthRoutes(admin, h)
|
||||||
|
// Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立)
|
||||||
|
registerSoraOAuthRoutes(admin, h)
|
||||||
|
|
||||||
// Gemini OAuth
|
// Gemini OAuth
|
||||||
registerGeminiOAuthRoutes(admin, h)
|
registerGeminiOAuthRoutes(admin, h)
|
||||||
@@ -276,6 +278,19 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||||
|
sora := admin.Group("/sora")
|
||||||
|
{
|
||||||
|
sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
|
||||||
|
sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
|
||||||
|
sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
|
||||||
|
sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken)
|
||||||
|
sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken)
|
||||||
|
sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
|
||||||
|
sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||||
gemini := admin.Group("/gemini")
|
gemini := admin.Group("/gemini")
|
||||||
{
|
{
|
||||||
@@ -306,6 +321,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
proxies.PUT("/:id", h.Admin.Proxy.Update)
|
proxies.PUT("/:id", h.Admin.Proxy.Update)
|
||||||
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
|
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
|
||||||
proxies.POST("/:id/test", h.Admin.Proxy.Test)
|
proxies.POST("/:id/test", h.Admin.Proxy.Test)
|
||||||
|
proxies.POST("/:id/quality-check", h.Admin.Proxy.CheckQuality)
|
||||||
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
|
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
|
||||||
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
|
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
|
||||||
proxies.POST("/batch-delete", h.Admin.Proxy.BatchDelete)
|
proxies.POST("/batch-delete", h.Admin.Proxy.BatchDelete)
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package routes
|
package routes
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
@@ -41,16 +43,15 @@ func RegisterGatewayRoutes(
|
|||||||
gateway.GET("/usage", h.Gateway.Usage)
|
gateway.GET("/usage", h.Gateway.Usage)
|
||||||
// OpenAI Responses API
|
// OpenAI Responses API
|
||||||
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
||||||
}
|
// 明确阻止旧入口误用到 Sora,避免客户端把 OpenAI Chat Completions 当作 Sora 入口
|
||||||
|
gateway.POST("/chat/completions", func(c *gin.Context) {
|
||||||
// Sora Chat Completions
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
soraGateway := r.Group("/v1")
|
"error": gin.H{
|
||||||
soraGateway.Use(soraBodyLimit)
|
"type": "invalid_request_error",
|
||||||
soraGateway.Use(clientRequestID)
|
"message": "For Sora, use /sora/v1/chat/completions. OpenAI should use /v1/responses.",
|
||||||
soraGateway.Use(opsErrorLogger)
|
},
|
||||||
soraGateway.Use(gin.HandlerFunc(apiKeyAuth))
|
})
|
||||||
{
|
})
|
||||||
soraGateway.POST("/chat/completions", h.SoraGateway.ChatCompletions)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
|
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
|
||||||
|
|||||||
@@ -786,6 +786,38 @@ func (a *Account) IsSessionIDMaskingEnabled() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsCacheTTLOverrideEnabled 检查是否启用缓存 TTL 强制替换
|
||||||
|
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
|
||||||
|
// 启用后将所有 cache creation tokens 归入指定的 TTL 类型(5m 或 1h)
|
||||||
|
func (a *Account) IsCacheTTLOverrideEnabled() bool {
|
||||||
|
if !a.IsAnthropicOAuthOrSetupToken() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if a.Extra == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if v, ok := a.Extra["cache_ttl_override_enabled"]; ok {
|
||||||
|
if enabled, ok := v.(bool); ok {
|
||||||
|
return enabled
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCacheTTLOverrideTarget 获取缓存 TTL 强制替换的目标类型
|
||||||
|
// 返回 "5m" 或 "1h",默认 "5m"
|
||||||
|
func (a *Account) GetCacheTTLOverrideTarget() string {
|
||||||
|
if a.Extra == nil {
|
||||||
|
return "5m"
|
||||||
|
}
|
||||||
|
if v, ok := a.Extra["cache_ttl_override_target"]; ok {
|
||||||
|
if target, ok := v.(string); ok && (target == "5m" || target == "1h") {
|
||||||
|
return target
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "5m"
|
||||||
|
}
|
||||||
|
|
||||||
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
|
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
|
||||||
// 返回 0 表示未启用
|
// 返回 0 表示未启用
|
||||||
func (a *Account) GetWindowCostLimit() float64 {
|
func (a *Account) GetWindowCostLimit() float64 {
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ type AccountRepository interface {
|
|||||||
Delete(ctx context.Context, id int64) error
|
Delete(ctx context.Context, id int64) error
|
||||||
|
|
||||||
List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
|
List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
|
||||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error)
|
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error)
|
||||||
ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
|
ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
|
||||||
ListActive(ctx context.Context) ([]Account, error)
|
ListActive(ctx context.Context) ([]Account, error)
|
||||||
ListByPlatform(ctx context.Context, platform string) ([]Account, error)
|
ListByPlatform(ctx context.Context, platform string) ([]Account, error)
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination
|
|||||||
panic("unexpected List call")
|
panic("unexpected List call")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||||
panic("unexpected ListWithFilters call")
|
panic("unexpected ListWithFilters call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,13 +12,17 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@@ -32,6 +36,10 @@ const (
|
|||||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
||||||
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
|
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
|
||||||
soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
|
soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
|
||||||
|
soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
|
||||||
|
soraInviteMineURL = "https://sora.chatgpt.com/backend/project_y/invite/mine"
|
||||||
|
soraBootstrapURL = "https://sora.chatgpt.com/backend/m/bootstrap"
|
||||||
|
soraRemainingURL = "https://sora.chatgpt.com/backend/nf/check"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestEvent represents a SSE event for account testing
|
// TestEvent represents a SSE event for account testing
|
||||||
@@ -39,6 +47,9 @@ type TestEvent struct {
|
|||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Text string `json:"text,omitempty"`
|
Text string `json:"text,omitempty"`
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
|
Status string `json:"status,omitempty"`
|
||||||
|
Code string `json:"code,omitempty"`
|
||||||
|
Data any `json:"data,omitempty"`
|
||||||
Success bool `json:"success,omitempty"`
|
Success bool `json:"success,omitempty"`
|
||||||
Error string `json:"error,omitempty"`
|
Error string `json:"error,omitempty"`
|
||||||
}
|
}
|
||||||
@@ -50,8 +61,13 @@ type AccountTestService struct {
|
|||||||
antigravityGatewayService *AntigravityGatewayService
|
antigravityGatewayService *AntigravityGatewayService
|
||||||
httpUpstream HTTPUpstream
|
httpUpstream HTTPUpstream
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
|
soraTestGuardMu sync.Mutex
|
||||||
|
soraTestLastRun map[int64]time.Time
|
||||||
|
soraTestCooldown time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const defaultSoraTestCooldown = 10 * time.Second
|
||||||
|
|
||||||
// NewAccountTestService creates a new AccountTestService
|
// NewAccountTestService creates a new AccountTestService
|
||||||
func NewAccountTestService(
|
func NewAccountTestService(
|
||||||
accountRepo AccountRepository,
|
accountRepo AccountRepository,
|
||||||
@@ -66,6 +82,8 @@ func NewAccountTestService(
|
|||||||
antigravityGatewayService: antigravityGatewayService,
|
antigravityGatewayService: antigravityGatewayService,
|
||||||
httpUpstream: httpUpstream,
|
httpUpstream: httpUpstream,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
|
soraTestLastRun: make(map[int64]time.Time),
|
||||||
|
soraTestCooldown: defaultSoraTestCooldown,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -467,13 +485,129 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
|||||||
return s.processGeminiStream(c, resp.Body)
|
return s.processGeminiStream(c, resp.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type soraProbeStep struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
HTTPStatus int `json:"http_status,omitempty"`
|
||||||
|
ErrorCode string `json:"error_code,omitempty"`
|
||||||
|
Message string `json:"message,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type soraProbeSummary struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
Steps []soraProbeStep `json:"steps"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type soraProbeRecorder struct {
|
||||||
|
steps []soraProbeStep
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *soraProbeRecorder) addStep(name, status string, httpStatus int, errorCode, message string) {
|
||||||
|
r.steps = append(r.steps, soraProbeStep{
|
||||||
|
Name: name,
|
||||||
|
Status: status,
|
||||||
|
HTTPStatus: httpStatus,
|
||||||
|
ErrorCode: strings.TrimSpace(errorCode),
|
||||||
|
Message: strings.TrimSpace(message),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *soraProbeRecorder) finalize() soraProbeSummary {
|
||||||
|
meSuccess := false
|
||||||
|
partial := false
|
||||||
|
for _, step := range r.steps {
|
||||||
|
if step.Name == "me" {
|
||||||
|
meSuccess = strings.EqualFold(step.Status, "success")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.EqualFold(step.Status, "failed") {
|
||||||
|
partial = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
status := "success"
|
||||||
|
if !meSuccess {
|
||||||
|
status = "failed"
|
||||||
|
} else if partial {
|
||||||
|
status = "partial_success"
|
||||||
|
}
|
||||||
|
|
||||||
|
return soraProbeSummary{
|
||||||
|
Status: status,
|
||||||
|
Steps: append([]soraProbeStep(nil), r.steps...),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountTestService) emitSoraProbeSummary(c *gin.Context, rec *soraProbeRecorder) {
|
||||||
|
if rec == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
summary := rec.finalize()
|
||||||
|
code := ""
|
||||||
|
for _, step := range summary.Steps {
|
||||||
|
if strings.EqualFold(step.Status, "failed") && strings.TrimSpace(step.ErrorCode) != "" {
|
||||||
|
code = step.ErrorCode
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.sendEvent(c, TestEvent{
|
||||||
|
Type: "sora_test_result",
|
||||||
|
Status: summary.Status,
|
||||||
|
Code: code,
|
||||||
|
Data: summary,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountTestService) acquireSoraTestPermit(accountID int64) (time.Duration, bool) {
|
||||||
|
if accountID <= 0 {
|
||||||
|
return 0, true
|
||||||
|
}
|
||||||
|
s.soraTestGuardMu.Lock()
|
||||||
|
defer s.soraTestGuardMu.Unlock()
|
||||||
|
|
||||||
|
if s.soraTestLastRun == nil {
|
||||||
|
s.soraTestLastRun = make(map[int64]time.Time)
|
||||||
|
}
|
||||||
|
cooldown := s.soraTestCooldown
|
||||||
|
if cooldown <= 0 {
|
||||||
|
cooldown = defaultSoraTestCooldown
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
if lastRun, ok := s.soraTestLastRun[accountID]; ok {
|
||||||
|
elapsed := now.Sub(lastRun)
|
||||||
|
if elapsed < cooldown {
|
||||||
|
return cooldown - elapsed, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.soraTestLastRun[accountID] = now
|
||||||
|
return 0, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func ceilSeconds(d time.Duration) int {
|
||||||
|
if d <= 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
sec := int(d / time.Second)
|
||||||
|
if d%time.Second != 0 {
|
||||||
|
sec++
|
||||||
|
}
|
||||||
|
if sec < 1 {
|
||||||
|
sec = 1
|
||||||
|
}
|
||||||
|
return sec
|
||||||
|
}
|
||||||
|
|
||||||
// testSoraAccountConnection 测试 Sora 账号的连接
|
// testSoraAccountConnection 测试 Sora 账号的连接
|
||||||
// 调用 /backend/me 接口验证 access_token 有效性(不需要 Sentinel Token)
|
// 调用 /backend/me 接口验证 access_token 有效性(不需要 Sentinel Token)
|
||||||
func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
|
func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
|
||||||
ctx := c.Request.Context()
|
ctx := c.Request.Context()
|
||||||
|
recorder := &soraProbeRecorder{}
|
||||||
|
|
||||||
authToken := account.GetCredential("access_token")
|
authToken := account.GetCredential("access_token")
|
||||||
if authToken == "" {
|
if authToken == "" {
|
||||||
|
recorder.addStep("me", "failed", http.StatusUnauthorized, "missing_access_token", "No access token available")
|
||||||
|
s.emitSoraProbeSummary(c, recorder)
|
||||||
return s.sendErrorAndEnd(c, "No access token available")
|
return s.sendErrorAndEnd(c, "No access token available")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -484,11 +618,20 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
|
|||||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
|
|
||||||
|
if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
|
||||||
|
msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
|
||||||
|
recorder.addStep("rate_limit", "failed", http.StatusTooManyRequests, "test_rate_limited", msg)
|
||||||
|
s.emitSoraProbeSummary(c, recorder)
|
||||||
|
return s.sendErrorAndEnd(c, msg)
|
||||||
|
}
|
||||||
|
|
||||||
// Send test_start event
|
// Send test_start event
|
||||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"})
|
s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"})
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
recorder.addStep("me", "failed", 0, "request_build_failed", err.Error())
|
||||||
|
s.emitSoraProbeSummary(c, recorder)
|
||||||
return s.sendErrorAndEnd(c, "Failed to create request")
|
return s.sendErrorAndEnd(c, "Failed to create request")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -496,15 +639,21 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
|
|||||||
req.Header.Set("Authorization", "Bearer "+authToken)
|
req.Header.Set("Authorization", "Bearer "+authToken)
|
||||||
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||||
req.Header.Set("Accept", "application/json")
|
req.Header.Set("Accept", "application/json")
|
||||||
|
req.Header.Set("Accept-Language", "en-US,en;q=0.9")
|
||||||
|
req.Header.Set("Origin", "https://sora.chatgpt.com")
|
||||||
|
req.Header.Set("Referer", "https://sora.chatgpt.com/")
|
||||||
|
|
||||||
// Get proxy URL
|
// Get proxy URL
|
||||||
proxyURL := ""
|
proxyURL := ""
|
||||||
if account.ProxyID != nil && account.Proxy != nil {
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
proxyURL = account.Proxy.URL()
|
proxyURL = account.Proxy.URL()
|
||||||
}
|
}
|
||||||
|
enableSoraTLSFingerprint := s.shouldEnableSoraTLSFingerprint()
|
||||||
|
|
||||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
recorder.addStep("me", "failed", 0, "network_error", err.Error())
|
||||||
|
s.emitSoraProbeSummary(c, recorder)
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||||
}
|
}
|
||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
@@ -512,8 +661,33 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
|
|||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, string(body)))
|
if isCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
|
||||||
|
recorder.addStep("me", "failed", resp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
|
||||||
|
s.emitSoraProbeSummary(c, recorder)
|
||||||
|
s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body)
|
||||||
|
return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage(fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", resp.StatusCode), resp.Header, body))
|
||||||
|
}
|
||||||
|
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(body)
|
||||||
|
switch {
|
||||||
|
case resp.StatusCode == http.StatusUnauthorized && strings.EqualFold(upstreamCode, "token_invalidated"):
|
||||||
|
recorder.addStep("me", "failed", resp.StatusCode, "token_invalidated", "Sora token invalidated")
|
||||||
|
s.emitSoraProbeSummary(c, recorder)
|
||||||
|
return s.sendErrorAndEnd(c, "Sora token 已失效(token_invalidated),请重新授权账号")
|
||||||
|
case strings.EqualFold(upstreamCode, "unsupported_country_code"):
|
||||||
|
recorder.addStep("me", "failed", resp.StatusCode, "unsupported_country_code", "Sora is unavailable in current egress region")
|
||||||
|
s.emitSoraProbeSummary(c, recorder)
|
||||||
|
return s.sendErrorAndEnd(c, "Sora 在当前网络出口地区不可用(unsupported_country_code),请切换到支持地区后重试")
|
||||||
|
case strings.TrimSpace(upstreamMessage) != "":
|
||||||
|
recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, upstreamMessage)
|
||||||
|
s.emitSoraProbeSummary(c, recorder)
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, upstreamMessage))
|
||||||
|
default:
|
||||||
|
recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, "Sora me endpoint failed")
|
||||||
|
s.emitSoraProbeSummary(c, recorder)
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
recorder.addStep("me", "success", resp.StatusCode, "", "me endpoint ok")
|
||||||
|
|
||||||
// 解析 /me 响应,提取用户信息
|
// 解析 /me 响应,提取用户信息
|
||||||
var meResp map[string]any
|
var meResp map[string]any
|
||||||
@@ -531,10 +705,384 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
|
|||||||
s.sendEvent(c, TestEvent{Type: "content", Text: info})
|
s.sendEvent(c, TestEvent{Type: "content", Text: info})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 追加轻量能力检查:订阅信息查询(失败仅告警,不中断连接测试)
|
||||||
|
subReq, err := http.NewRequestWithContext(ctx, "GET", soraBillingAPIURL, nil)
|
||||||
|
if err == nil {
|
||||||
|
subReq.Header.Set("Authorization", "Bearer "+authToken)
|
||||||
|
subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||||
|
subReq.Header.Set("Accept", "application/json")
|
||||||
|
subReq.Header.Set("Accept-Language", "en-US,en;q=0.9")
|
||||||
|
subReq.Header.Set("Origin", "https://sora.chatgpt.com")
|
||||||
|
subReq.Header.Set("Referer", "https://sora.chatgpt.com/")
|
||||||
|
|
||||||
|
subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
|
||||||
|
if subErr != nil {
|
||||||
|
recorder.addStep("subscription", "failed", 0, "network_error", subErr.Error())
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())})
|
||||||
|
} else {
|
||||||
|
subBody, _ := io.ReadAll(subResp.Body)
|
||||||
|
_ = subResp.Body.Close()
|
||||||
|
if subResp.StatusCode == http.StatusOK {
|
||||||
|
recorder.addStep("subscription", "success", subResp.StatusCode, "", "subscription endpoint ok")
|
||||||
|
if summary := parseSoraSubscriptionSummary(subBody); summary != "" {
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
|
||||||
|
} else {
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if isCloudflareChallengeResponse(subResp.StatusCode, subResp.Header, subBody) {
|
||||||
|
recorder.addStep("subscription", "failed", subResp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
|
||||||
|
s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody)
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Subscription check blocked by Cloudflare challenge (HTTP %d)", subResp.StatusCode), subResp.Header, subBody)})
|
||||||
|
} else {
|
||||||
|
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(subBody)
|
||||||
|
recorder.addStep("subscription", "failed", subResp.StatusCode, upstreamCode, upstreamMessage)
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。
|
||||||
|
s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, enableSoraTLSFingerprint, recorder)
|
||||||
|
|
||||||
|
s.emitSoraProbeSummary(c, recorder)
|
||||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *AccountTestService) testSora2Capabilities(
|
||||||
|
c *gin.Context,
|
||||||
|
ctx context.Context,
|
||||||
|
account *Account,
|
||||||
|
authToken string,
|
||||||
|
proxyURL string,
|
||||||
|
enableTLSFingerprint bool,
|
||||||
|
recorder *soraProbeRecorder,
|
||||||
|
) {
|
||||||
|
inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint(
|
||||||
|
ctx,
|
||||||
|
account,
|
||||||
|
authToken,
|
||||||
|
soraInviteMineURL,
|
||||||
|
proxyURL,
|
||||||
|
enableTLSFingerprint,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if recorder != nil {
|
||||||
|
recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
|
||||||
|
}
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check skipped: %s", err.Error())})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if inviteStatus == http.StatusUnauthorized {
|
||||||
|
bootstrapStatus, _, _, bootstrapErr := s.fetchSoraTestEndpoint(
|
||||||
|
ctx,
|
||||||
|
account,
|
||||||
|
authToken,
|
||||||
|
soraBootstrapURL,
|
||||||
|
proxyURL,
|
||||||
|
enableTLSFingerprint,
|
||||||
|
)
|
||||||
|
if bootstrapErr == nil && bootstrapStatus == http.StatusOK {
|
||||||
|
if recorder != nil {
|
||||||
|
recorder.addStep("sora2_bootstrap", "success", bootstrapStatus, "", "bootstrap endpoint ok")
|
||||||
|
}
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 bootstrap OK, retry invite check"})
|
||||||
|
inviteStatus, inviteHeader, inviteBody, err = s.fetchSoraTestEndpoint(
|
||||||
|
ctx,
|
||||||
|
account,
|
||||||
|
authToken,
|
||||||
|
soraInviteMineURL,
|
||||||
|
proxyURL,
|
||||||
|
enableTLSFingerprint,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if recorder != nil {
|
||||||
|
recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
|
||||||
|
}
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite retry failed: %s", err.Error())})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else if recorder != nil {
|
||||||
|
code := ""
|
||||||
|
msg := ""
|
||||||
|
if bootstrapErr != nil {
|
||||||
|
code = "network_error"
|
||||||
|
msg = bootstrapErr.Error()
|
||||||
|
}
|
||||||
|
recorder.addStep("sora2_bootstrap", "failed", bootstrapStatus, code, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if inviteStatus != http.StatusOK {
|
||||||
|
if isCloudflareChallengeResponse(inviteStatus, inviteHeader, inviteBody) {
|
||||||
|
if recorder != nil {
|
||||||
|
recorder.addStep("sora2_invite", "failed", inviteStatus, "cf_challenge", "Cloudflare challenge detected")
|
||||||
|
}
|
||||||
|
s.logSoraCloudflareChallenge(account, proxyURL, soraInviteMineURL, inviteHeader, inviteBody)
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 invite check blocked by Cloudflare challenge (HTTP %d)", inviteStatus), inviteHeader, inviteBody)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(inviteBody)
|
||||||
|
if recorder != nil {
|
||||||
|
recorder.addStep("sora2_invite", "failed", inviteStatus, upstreamCode, upstreamMessage)
|
||||||
|
}
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if recorder != nil {
|
||||||
|
recorder.addStep("sora2_invite", "success", inviteStatus, "", "invite endpoint ok")
|
||||||
|
}
|
||||||
|
|
||||||
|
if summary := parseSoraInviteSummary(inviteBody); summary != "" {
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
|
||||||
|
} else {
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 invite check OK"})
|
||||||
|
}
|
||||||
|
|
||||||
|
remainingStatus, remainingHeader, remainingBody, remainingErr := s.fetchSoraTestEndpoint(
|
||||||
|
ctx,
|
||||||
|
account,
|
||||||
|
authToken,
|
||||||
|
soraRemainingURL,
|
||||||
|
proxyURL,
|
||||||
|
enableTLSFingerprint,
|
||||||
|
)
|
||||||
|
if remainingErr != nil {
|
||||||
|
if recorder != nil {
|
||||||
|
recorder.addStep("sora2_remaining", "failed", 0, "network_error", remainingErr.Error())
|
||||||
|
}
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if remainingStatus != http.StatusOK {
|
||||||
|
if isCloudflareChallengeResponse(remainingStatus, remainingHeader, remainingBody) {
|
||||||
|
if recorder != nil {
|
||||||
|
recorder.addStep("sora2_remaining", "failed", remainingStatus, "cf_challenge", "Cloudflare challenge detected")
|
||||||
|
}
|
||||||
|
s.logSoraCloudflareChallenge(account, proxyURL, soraRemainingURL, remainingHeader, remainingBody)
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 remaining check blocked by Cloudflare challenge (HTTP %d)", remainingStatus), remainingHeader, remainingBody)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(remainingBody)
|
||||||
|
if recorder != nil {
|
||||||
|
recorder.addStep("sora2_remaining", "failed", remainingStatus, upstreamCode, upstreamMessage)
|
||||||
|
}
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if recorder != nil {
|
||||||
|
recorder.addStep("sora2_remaining", "success", remainingStatus, "", "remaining endpoint ok")
|
||||||
|
}
|
||||||
|
if summary := parseSoraRemainingSummary(remainingBody); summary != "" {
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
|
||||||
|
} else {
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 remaining check OK"})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountTestService) fetchSoraTestEndpoint(
|
||||||
|
ctx context.Context,
|
||||||
|
account *Account,
|
||||||
|
authToken string,
|
||||||
|
url string,
|
||||||
|
proxyURL string,
|
||||||
|
enableTLSFingerprint bool,
|
||||||
|
) (int, http.Header, []byte, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+authToken)
|
||||||
|
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
req.Header.Set("Accept-Language", "en-US,en;q=0.9")
|
||||||
|
req.Header.Set("Origin", "https://sora.chatgpt.com")
|
||||||
|
req.Header.Set("Referer", "https://sora.chatgpt.com/")
|
||||||
|
|
||||||
|
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableTLSFingerprint)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
body, readErr := io.ReadAll(resp.Body)
|
||||||
|
if readErr != nil {
|
||||||
|
return resp.StatusCode, resp.Header, nil, readErr
|
||||||
|
}
|
||||||
|
return resp.StatusCode, resp.Header, body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseSoraSubscriptionSummary(body []byte) string {
|
||||||
|
var subResp struct {
|
||||||
|
Data []struct {
|
||||||
|
Plan struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
} `json:"plan"`
|
||||||
|
EndTS string `json:"end_ts"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &subResp); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if len(subResp.Data) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
first := subResp.Data[0]
|
||||||
|
parts := make([]string, 0, 3)
|
||||||
|
if first.Plan.Title != "" {
|
||||||
|
parts = append(parts, first.Plan.Title)
|
||||||
|
}
|
||||||
|
if first.Plan.ID != "" {
|
||||||
|
parts = append(parts, first.Plan.ID)
|
||||||
|
}
|
||||||
|
if first.EndTS != "" {
|
||||||
|
parts = append(parts, "end="+first.EndTS)
|
||||||
|
}
|
||||||
|
if len(parts) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return "Subscription: " + strings.Join(parts, " | ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseSoraInviteSummary(body []byte) string {
|
||||||
|
var inviteResp struct {
|
||||||
|
InviteCode string `json:"invite_code"`
|
||||||
|
RedeemedCount int64 `json:"redeemed_count"`
|
||||||
|
TotalCount int64 `json:"total_count"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &inviteResp); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := []string{"Sora2: supported"}
|
||||||
|
if inviteResp.InviteCode != "" {
|
||||||
|
parts = append(parts, "invite="+inviteResp.InviteCode)
|
||||||
|
}
|
||||||
|
if inviteResp.TotalCount > 0 {
|
||||||
|
parts = append(parts, fmt.Sprintf("used=%d/%d", inviteResp.RedeemedCount, inviteResp.TotalCount))
|
||||||
|
}
|
||||||
|
return strings.Join(parts, " | ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseSoraRemainingSummary(body []byte) string {
|
||||||
|
var remainingResp struct {
|
||||||
|
RateLimitAndCreditBalance struct {
|
||||||
|
EstimatedNumVideosRemaining int64 `json:"estimated_num_videos_remaining"`
|
||||||
|
RateLimitReached bool `json:"rate_limit_reached"`
|
||||||
|
AccessResetsInSeconds int64 `json:"access_resets_in_seconds"`
|
||||||
|
} `json:"rate_limit_and_credit_balance"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &remainingResp); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
info := remainingResp.RateLimitAndCreditBalance
|
||||||
|
parts := []string{fmt.Sprintf("Sora2 remaining: %d", info.EstimatedNumVideosRemaining)}
|
||||||
|
if info.RateLimitReached {
|
||||||
|
parts = append(parts, "rate_limited=true")
|
||||||
|
}
|
||||||
|
if info.AccessResetsInSeconds > 0 {
|
||||||
|
parts = append(parts, fmt.Sprintf("reset_in=%ds", info.AccessResetsInSeconds))
|
||||||
|
}
|
||||||
|
return strings.Join(parts, " | ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountTestService) shouldEnableSoraTLSFingerprint() bool {
|
||||||
|
if s == nil || s.cfg == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return !s.cfg.Sora.Client.DisableTLSFingerprint
|
||||||
|
}
|
||||||
|
|
||||||
|
func isCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
|
||||||
|
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
|
||||||
|
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractCloudflareRayID(headers http.Header, body []byte) string {
|
||||||
|
return soraerror.ExtractCloudflareRayID(headers, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractSoraEgressIPHint(headers http.Header) string {
|
||||||
|
if headers == nil {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
candidates := []string{
|
||||||
|
"x-openai-public-ip",
|
||||||
|
"x-envoy-external-address",
|
||||||
|
"cf-connecting-ip",
|
||||||
|
"x-forwarded-for",
|
||||||
|
}
|
||||||
|
for _, key := range candidates {
|
||||||
|
if value := strings.TrimSpace(headers.Get(key)); value != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
func sanitizeProxyURLForLog(raw string) string {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
u, err := url.Parse(raw)
|
||||||
|
if err != nil {
|
||||||
|
return "<invalid_proxy_url>"
|
||||||
|
}
|
||||||
|
if u.User != nil {
|
||||||
|
u.User = nil
|
||||||
|
}
|
||||||
|
return u.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func endpointPathForLog(endpoint string) string {
|
||||||
|
parsed, err := url.Parse(strings.TrimSpace(endpoint))
|
||||||
|
if err != nil || parsed.Path == "" {
|
||||||
|
return endpoint
|
||||||
|
}
|
||||||
|
return parsed.Path
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountTestService) logSoraCloudflareChallenge(account *Account, proxyURL, endpoint string, headers http.Header, body []byte) {
|
||||||
|
accountID := int64(0)
|
||||||
|
platform := ""
|
||||||
|
proxyID := "none"
|
||||||
|
if account != nil {
|
||||||
|
accountID = account.ID
|
||||||
|
platform = account.Platform
|
||||||
|
if account.ProxyID != nil {
|
||||||
|
proxyID = fmt.Sprintf("%d", *account.ProxyID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cfRay := extractCloudflareRayID(headers, body)
|
||||||
|
if cfRay == "" {
|
||||||
|
cfRay = "unknown"
|
||||||
|
}
|
||||||
|
log.Printf(
|
||||||
|
"[SoraCFChallenge] account_id=%d platform=%s endpoint=%s path=%s proxy_id=%s proxy_url=%s cf_ray=%s egress_ip_hint=%s",
|
||||||
|
accountID,
|
||||||
|
platform,
|
||||||
|
endpoint,
|
||||||
|
endpointPathForLog(endpoint),
|
||||||
|
proxyID,
|
||||||
|
sanitizeProxyURLForLog(proxyURL),
|
||||||
|
cfRay,
|
||||||
|
extractSoraEgressIPHint(headers),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateSoraErrorBody(body []byte, max int) string {
|
||||||
|
return soraerror.TruncateBody(body, max)
|
||||||
|
}
|
||||||
|
|
||||||
// testAntigravityAccountConnection tests an Antigravity account's connection
|
// testAntigravityAccountConnection tests an Antigravity account's connection
|
||||||
// 支持 Claude 和 Gemini 两种协议,使用非流式请求
|
// 支持 Claude 和 Gemini 两种协议,使用非流式请求
|
||||||
func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
||||||
|
|||||||
319
backend/internal/service/account_test_service_sora_test.go
Normal file
319
backend/internal/service/account_test_service_sora_test.go
Normal file
@@ -0,0 +1,319 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// queuedHTTPUpstream is a test double for the HTTP upstream dependency.
// DoWithTLS replays a FIFO queue of canned responses while recording every
// request and the TLS-fingerprint flag it was invoked with.
type queuedHTTPUpstream struct {
	responses []*http.Response // canned responses, consumed front-to-back
	requests  []*http.Request  // every request seen by DoWithTLS, in call order
	tlsFlags  []bool           // enableTLSFingerprint flag per DoWithTLS call
}
|
||||||
|
|
||||||
|
// Do satisfies the upstream interface but is not expected in these tests;
// it fails loudly so an accidental call surfaces as a test error.
func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
	return nil, fmt.Errorf("unexpected Do call")
}
|
||||||
|
|
||||||
|
func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, enableTLSFingerprint bool) (*http.Response, error) {
|
||||||
|
u.requests = append(u.requests, req)
|
||||||
|
u.tlsFlags = append(u.tlsFlags, enableTLSFingerprint)
|
||||||
|
if len(u.responses) == 0 {
|
||||||
|
return nil, fmt.Errorf("no mocked response")
|
||||||
|
}
|
||||||
|
resp := u.responses[0]
|
||||||
|
u.responses = u.responses[1:]
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newJSONResponse(status int, body string) *http.Response {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: status,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader(body)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newJSONResponseWithHeader(status int, body, key, value string) *http.Response {
|
||||||
|
resp := newJSONResponse(status, body)
|
||||||
|
resp.Header.Set(key, value)
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
// newSoraTestContext creates a gin test context backed by a response
// recorder, pre-loaded with a POST request matching the admin account-test
// route, so SSE output written by the service can be inspected afterwards.
func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
	return c, rec
}
|
||||||
|
|
||||||
|
// TestAccountTestService_testSoraAccountConnection_WithSubscription covers the
// happy path: /me, billing, invite and remaining-quota calls all succeed, the
// TLS fingerprint flag is propagated to every upstream call, and the SSE
// stream reports full success.
func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) {
	// Responses are consumed in FIFO order: me, billing, invite, remaining.
	upstream := &queuedHTTPUpstream{
		responses: []*http.Response{
			newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
			newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`),
			newJSONResponse(http.StatusOK, `{"invite_code":"inv_abc","redeemed_count":3,"total_count":50}`),
			newJSONResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":27,"rate_limit_reached":false,"access_resets_in_seconds":46833}}`),
		},
	}
	// TLS fingerprinting enabled globally and not disabled for Sora.
	svc := &AccountTestService{
		httpUpstream: upstream,
		cfg: &config.Config{
			Gateway: config.GatewayConfig{
				TLSFingerprint: config.TLSFingerprintConfig{
					Enabled: true,
				},
			},
			Sora: config.SoraConfig{
				Client: config.SoraClientConfig{
					DisableTLSFingerprint: false,
				},
			},
		},
	}
	account := &Account{
		ID:          1,
		Platform:    PlatformSora,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token": "test_token",
		},
	}

	c, rec := newSoraTestContext()
	err := svc.testSoraAccountConnection(c, account)

	require.NoError(t, err)
	// All four endpoints hit, in order, with the bearer token attached.
	require.Len(t, upstream.requests, 4)
	require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String())
	require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String())
	require.Equal(t, soraInviteMineURL, upstream.requests[2].URL.String())
	require.Equal(t, soraRemainingURL, upstream.requests[3].URL.String())
	require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization"))
	require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization"))
	require.Equal(t, []bool{true, true, true, true}, upstream.tlsFlags)

	// SSE body should contain the start event, all summary lines and a
	// successful completion event.
	body := rec.Body.String()
	require.Contains(t, body, `"type":"test_start"`)
	require.Contains(t, body, "Sora connection OK - Email: demo@example.com")
	require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z")
	require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50")
	require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s")
	require.Contains(t, body, `"type":"sora_test_result"`)
	require.Contains(t, body, `"status":"success"`)
	require.Contains(t, body, `"type":"test_complete","success":true`)
}
|
||||||
|
|
||||||
|
// TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess
// verifies that failures of the secondary checks (billing 403, invite 401,
// remaining 403) do not fail the overall test: the /me call succeeded, so the
// result is reported as partial_success and the run completes successfully.
func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) {
	upstream := &queuedHTTPUpstream{
		responses: []*http.Response{
			newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
			newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
			newJSONResponse(http.StatusUnauthorized, `{"error":{"message":"Unauthorized"}}`),
			newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
		},
	}
	svc := &AccountTestService{httpUpstream: upstream}
	account := &Account{
		ID:          1,
		Platform:    PlatformSora,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token": "test_token",
		},
	}

	c, rec := newSoraTestContext()
	err := svc.testSoraAccountConnection(c, account)

	require.NoError(t, err)
	require.Len(t, upstream.requests, 4)
	body := rec.Body.String()
	require.Contains(t, body, "Sora connection OK - User: demo-user")
	require.Contains(t, body, "Subscription check returned 403")
	require.Contains(t, body, "Sora2 invite check returned 401")
	require.Contains(t, body, `"type":"sora_test_result"`)
	require.Contains(t, body, `"status":"partial_success"`)
	require.Contains(t, body, `"type":"test_complete","success":true`)
}
|
||||||
|
|
||||||
|
// TestAccountTestService_testSoraAccountConnection_CloudflareChallenge verifies
// that a Cloudflare interstitial on the very first (/me) call fails the test
// and that the cf-ray ID from the response header is surfaced in both the
// returned error and the SSE error event.
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) {
	upstream := &queuedHTTPUpstream{
		responses: []*http.Response{
			newJSONResponseWithHeader(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`, "cf-ray", "9cff2d62d83bb98d"),
		},
	}
	svc := &AccountTestService{httpUpstream: upstream}
	account := &Account{
		ID:          1,
		Platform:    PlatformSora,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token": "test_token",
		},
	}

	c, rec := newSoraTestContext()
	err := svc.testSoraAccountConnection(c, account)

	require.Error(t, err)
	require.Contains(t, err.Error(), "Cloudflare challenge")
	require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d")
	body := rec.Body.String()
	require.Contains(t, body, `"type":"error"`)
	require.Contains(t, body, "Cloudflare challenge")
	require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
}
|
||||||
|
|
||||||
|
// TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader
// verifies that a 429 response carrying the cf-mitigated: challenge header is
// classified as a Cloudflare challenge (not a plain rate limit) and that the
// original HTTP status is preserved in the error.
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) {
	upstream := &queuedHTTPUpstream{
		responses: []*http.Response{
			newJSONResponseWithHeader(http.StatusTooManyRequests, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body></body></html>`, "cf-mitigated", "challenge"),
		},
	}
	svc := &AccountTestService{httpUpstream: upstream}
	account := &Account{
		ID:          1,
		Platform:    PlatformSora,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token": "test_token",
		},
	}

	c, rec := newSoraTestContext()
	err := svc.testSoraAccountConnection(c, account)

	require.Error(t, err)
	require.Contains(t, err.Error(), "Cloudflare challenge")
	require.Contains(t, err.Error(), "HTTP 429")
	body := rec.Body.String()
	require.Contains(t, body, "Cloudflare challenge")
}
|
||||||
|
|
||||||
|
// TestAccountTestService_testSoraAccountConnection_TokenInvalidated verifies
// that a 401 with error code token_invalidated on the /me call fails the test,
// is reported as a failed sora_test_result, and never emits a successful
// completion event.
func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) {
	upstream := &queuedHTTPUpstream{
		responses: []*http.Response{
			newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`),
		},
	}
	svc := &AccountTestService{httpUpstream: upstream}
	account := &Account{
		ID:          1,
		Platform:    PlatformSora,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token": "test_token",
		},
	}

	c, rec := newSoraTestContext()
	err := svc.testSoraAccountConnection(c, account)

	require.Error(t, err)
	require.Contains(t, err.Error(), "token_invalidated")
	body := rec.Body.String()
	require.Contains(t, body, `"type":"sora_test_result"`)
	require.Contains(t, body, `"status":"failed"`)
	require.Contains(t, body, "token_invalidated")
	require.NotContains(t, body, `"type":"test_complete","success":true`)
}
|
||||||
|
|
||||||
|
// TestAccountTestService_testSoraAccountConnection_RateLimited verifies the
// per-account test cooldown: with soraTestCooldown set, a second test run
// against the same account is rejected with a test_rate_limited error instead
// of hitting the upstream again.
func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) {
	upstream := &queuedHTTPUpstream{
		responses: []*http.Response{
			newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
		},
	}
	svc := &AccountTestService{
		httpUpstream:     upstream,
		soraTestCooldown: time.Hour, // long enough that the second call is always within cooldown
	}
	account := &Account{
		ID:          1,
		Platform:    PlatformSora,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token": "test_token",
		},
	}

	// First run consumes the queued response and succeeds.
	c1, _ := newSoraTestContext()
	err := svc.testSoraAccountConnection(c1, account)
	require.NoError(t, err)

	// Second run within the cooldown window must be rejected.
	c2, rec2 := newSoraTestContext()
	err = svc.testSoraAccountConnection(c2, account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "测试过于频繁")
	body := rec2.Body.String()
	require.Contains(t, body, `"type":"sora_test_result"`)
	require.Contains(t, body, `"code":"test_rate_limited"`)
	require.Contains(t, body, `"status":"failed"`)
	require.NotContains(t, body, `"type":"test_complete","success":true`)
}
|
||||||
|
|
||||||
|
// TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay
// verifies that Cloudflare challenges on the secondary checks (billing and
// invite) are reported — including the cRay ID scraped from the challenge
// HTML body — while the overall test still completes successfully because
// the /me call passed.
func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) {
	upstream := &queuedHTTPUpstream{
		responses: []*http.Response{
			newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
			newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
			newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
		},
	}
	svc := &AccountTestService{httpUpstream: upstream}
	account := &Account{
		ID:          1,
		Platform:    PlatformSora,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token": "test_token",
		},
	}

	c, rec := newSoraTestContext()
	err := svc.testSoraAccountConnection(c, account)

	require.NoError(t, err)
	body := rec.Body.String()
	require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)")
	require.Contains(t, body, "Sora2 invite check blocked by Cloudflare challenge (HTTP 403)")
	require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
	require.Contains(t, body, `"type":"test_complete","success":true`)
}
|
||||||
|
|
||||||
|
// TestSanitizeProxyURLForLog verifies credential stripping, empty-input
// passthrough, and the sentinel returned for unparseable proxy URLs.
func TestSanitizeProxyURLForLog(t *testing.T) {
	require.Equal(t, "http://proxy.example.com:8080", sanitizeProxyURLForLog("http://user:pass@proxy.example.com:8080"))
	require.Equal(t, "", sanitizeProxyURLForLog(""))
	require.Equal(t, "<invalid_proxy_url>", sanitizeProxyURLForLog("://invalid"))
}
|
||||||
|
|
||||||
|
// TestExtractSoraEgressIPHint verifies that the egress-IP hint is read from
// the x-openai-public-ip and x-envoy-external-address headers, and that
// missing input (nil or empty headers) yields "unknown".
func TestExtractSoraEgressIPHint(t *testing.T) {
	h := make(http.Header)
	h.Set("x-openai-public-ip", "203.0.113.10")
	require.Equal(t, "203.0.113.10", extractSoraEgressIPHint(h))

	h2 := make(http.Header)
	h2.Set("x-envoy-external-address", "198.51.100.9")
	require.Equal(t, "198.51.100.9", extractSoraEgressIPHint(h2))

	require.Equal(t, "unknown", extractSoraEgressIPHint(nil))
	require.Equal(t, "unknown", extractSoraEgressIPHint(http.Header{}))
}
|
||||||
@@ -4,11 +4,15 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AdminService interface defines admin management operations
|
// AdminService interface defines admin management operations
|
||||||
@@ -39,7 +43,7 @@ type AdminService interface {
|
|||||||
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
|
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
|
||||||
|
|
||||||
// Account management
|
// Account management
|
||||||
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
|
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error)
|
||||||
GetAccount(ctx context.Context, id int64) (*Account, error)
|
GetAccount(ctx context.Context, id int64) (*Account, error)
|
||||||
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
|
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
|
||||||
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
|
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
|
||||||
@@ -65,6 +69,7 @@ type AdminService interface {
|
|||||||
GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error)
|
GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error)
|
||||||
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
|
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
|
||||||
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
|
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
|
||||||
|
CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error)
|
||||||
|
|
||||||
// Redeem code management
|
// Redeem code management
|
||||||
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error)
|
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error)
|
||||||
@@ -288,6 +293,32 @@ type ProxyTestResult struct {
|
|||||||
CountryCode string `json:"country_code,omitempty"`
|
CountryCode string `json:"country_code,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProxyQualityCheckResult aggregates the outcome of a multi-target proxy
// quality check: an overall score/grade plus one item per probed target.
type ProxyQualityCheckResult struct {
	ProxyID        int64                   `json:"proxy_id"`
	Score          int                     `json:"score"`   // 0-100, derived from warn/fail/challenge counts
	Grade          string                  `json:"grade"`   // letter grade A-F derived from Score
	Summary        string                  `json:"summary"` // human-readable pass/warn/fail/challenge tally
	ExitIP         string                  `json:"exit_ip,omitempty"`
	Country        string                  `json:"country,omitempty"`
	CountryCode    string                  `json:"country_code,omitempty"`
	BaseLatencyMs  int64                   `json:"base_latency_ms,omitempty"` // latency of the base connectivity probe
	PassedCount    int                     `json:"passed_count"`
	WarnCount      int                     `json:"warn_count"`
	FailedCount    int                     `json:"failed_count"`
	ChallengeCount int                     `json:"challenge_count"` // targets blocked by a Cloudflare challenge
	CheckedAt      int64                   `json:"checked_at"`      // unix seconds
	Items          []ProxyQualityCheckItem `json:"items"`
}
|
||||||
|
|
||||||
|
// ProxyQualityCheckItem records the outcome of a single quality-check probe
// (the base connectivity probe or one upstream target).
type ProxyQualityCheckItem struct {
	Target     string `json:"target"`
	Status     string `json:"status"` // pass/warn/fail/challenge
	HTTPStatus int    `json:"http_status,omitempty"`
	LatencyMs  int64  `json:"latency_ms,omitempty"`
	Message    string `json:"message,omitempty"`
	CFRay      string `json:"cf_ray,omitempty"` // Cloudflare ray ID when a challenge was detected
}
|
||||||
|
|
||||||
// ProxyExitInfo represents proxy exit information from ip-api.com
|
// ProxyExitInfo represents proxy exit information from ip-api.com
|
||||||
type ProxyExitInfo struct {
|
type ProxyExitInfo struct {
|
||||||
IP string
|
IP string
|
||||||
@@ -302,6 +333,58 @@ type ProxyExitInfoProber interface {
|
|||||||
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
|
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// proxyQualityTarget describes one upstream endpoint probed during a proxy
// quality check and the HTTP status codes considered acceptable for it.
type proxyQualityTarget struct {
	Target string
	URL    string
	Method string
	// AllowedStatuses lists statuses that count as reachable: 2xx entries
	// are graded "pass", non-2xx entries (e.g. 401 without credentials)
	// are graded "warn".
	AllowedStatuses map[int]struct{}
}
|
||||||
|
|
||||||
|
// proxyQualityTargets lists the upstream endpoints probed during a proxy
// quality check. The allowed statuses encode each service's expected
// unauthenticated response — e.g. a 401 from OpenAI proves the proxy
// reached the API even though no credentials were sent.
var proxyQualityTargets = []proxyQualityTarget{
	{
		Target: "openai",
		URL:    "https://api.openai.com/v1/models",
		Method: http.MethodGet,
		AllowedStatuses: map[int]struct{}{
			http.StatusUnauthorized: {},
		},
	},
	{
		Target: "anthropic",
		URL:    "https://api.anthropic.com/v1/messages",
		Method: http.MethodGet,
		// Anthropic may answer an unauthenticated GET with several
		// different codes depending on routing; all prove reachability.
		AllowedStatuses: map[int]struct{}{
			http.StatusUnauthorized:     {},
			http.StatusMethodNotAllowed: {},
			http.StatusNotFound:         {},
			http.StatusBadRequest:       {},
		},
	},
	{
		Target: "gemini",
		// Discovery document is publicly readable, so 200 is expected.
		URL:    "https://generativelanguage.googleapis.com/$discovery/rest?version=v1beta",
		Method: http.MethodGet,
		AllowedStatuses: map[int]struct{}{
			http.StatusOK: {},
		},
	},
	{
		Target: "sora",
		URL:    "https://sora.chatgpt.com/backend/me",
		Method: http.MethodGet,
		AllowedStatuses: map[int]struct{}{
			http.StatusUnauthorized: {},
		},
	},
}
|
||||||
|
|
||||||
|
const (
	// proxyQualityRequestTimeout bounds one full probe request.
	proxyQualityRequestTimeout = 15 * time.Second
	// proxyQualityResponseHeaderTimeout bounds the wait for response headers.
	proxyQualityResponseHeaderTimeout = 10 * time.Second
	// proxyQualityMaxBodyBytes caps how much of a probe response body is read.
	proxyQualityMaxBodyBytes = int64(8 * 1024)
	// proxyQualityClientUserAgent mimics a desktop Chrome browser so probes
	// look like ordinary client traffic to the targets.
	proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36"
)
|
||||||
|
|
||||||
// adminServiceImpl implements AdminService
|
// adminServiceImpl implements AdminService
|
||||||
type adminServiceImpl struct {
|
type adminServiceImpl struct {
|
||||||
userRepo UserRepository
|
userRepo UserRepository
|
||||||
@@ -1054,9 +1137,9 @@ func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Account management implementations
|
// Account management implementations
|
||||||
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) {
|
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) {
|
||||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||||
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
|
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
@@ -1690,6 +1773,270 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CheckProxyQuality runs a multi-stage quality check for the proxy with the
// given id: a base connectivity probe (exit IP / latency), then one HTTP
// probe per entry in proxyQualityTargets through the proxy itself. The
// aggregated result is scored, persisted as a latency/quality snapshot, and
// returned. Probe failures are reported inside the result rather than as an
// error; only a repository lookup failure returns a non-nil error.
func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) {
	proxy, err := s.proxyRepo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}

	// Start from a perfect score; finalizeProxyQualityResult recomputes it
	// from the warn/fail/challenge counters before returning.
	result := &ProxyQualityCheckResult{
		ProxyID:   id,
		Score:     100,
		Grade:     "A",
		CheckedAt: time.Now().Unix(),
		Items:     make([]ProxyQualityCheckItem, 0, len(proxyQualityTargets)+1),
	}

	proxyURL := proxy.URL()
	// Without a prober we cannot even establish base connectivity.
	if s.proxyProber == nil {
		result.Items = append(result.Items, ProxyQualityCheckItem{
			Target:  "base_connectivity",
			Status:  "fail",
			Message: "代理探测服务未配置",
		})
		result.FailedCount++
		finalizeProxyQualityResult(result)
		s.saveProxyQualitySnapshot(ctx, id, result, nil)
		return result, nil
	}

	exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL)
	if err != nil {
		// Base connectivity failed: record and stop — per-target probes
		// through a dead proxy would only add noise.
		result.Items = append(result.Items, ProxyQualityCheckItem{
			Target:    "base_connectivity",
			Status:    "fail",
			LatencyMs: latencyMs,
			Message:   err.Error(),
		})
		result.FailedCount++
		finalizeProxyQualityResult(result)
		s.saveProxyQualitySnapshot(ctx, id, result, nil)
		return result, nil
	}

	result.ExitIP = exitInfo.IP
	result.Country = exitInfo.Country
	result.CountryCode = exitInfo.CountryCode
	result.BaseLatencyMs = latencyMs
	result.Items = append(result.Items, ProxyQualityCheckItem{
		Target:    "base_connectivity",
		Status:    "pass",
		LatencyMs: latencyMs,
		Message:   "代理出口连通正常",
	})
	result.PassedCount++

	// ProxyStrict ensures probe traffic cannot silently bypass the proxy.
	client, err := httpclient.GetClient(httpclient.Options{
		ProxyURL:              proxyURL,
		Timeout:               proxyQualityRequestTimeout,
		ResponseHeaderTimeout: proxyQualityResponseHeaderTimeout,
		ProxyStrict:           true,
	})
	if err != nil {
		result.Items = append(result.Items, ProxyQualityCheckItem{
			Target:  "http_client",
			Status:  "fail",
			Message: fmt.Sprintf("创建检测客户端失败: %v", err),
		})
		result.FailedCount++
		finalizeProxyQualityResult(result)
		s.saveProxyQualitySnapshot(ctx, id, result, exitInfo)
		return result, nil
	}

	// Probe every upstream target and tally the outcomes.
	for _, target := range proxyQualityTargets {
		item := runProxyQualityTarget(ctx, client, target)
		result.Items = append(result.Items, item)
		switch item.Status {
		case "pass":
			result.PassedCount++
		case "warn":
			result.WarnCount++
		case "challenge":
			result.ChallengeCount++
		default:
			result.FailedCount++
		}
	}

	finalizeProxyQualityResult(result)
	s.saveProxyQualitySnapshot(ctx, id, result, exitInfo)
	return result, nil
}
|
||||||
|
|
||||||
|
// runProxyQualityTarget issues one probe request through the proxied client
// and classifies the outcome as pass, warn, fail, or challenge. Latency is
// measured from just before client.Do to headers received.
func runProxyQualityTarget(ctx context.Context, client *http.Client, target proxyQualityTarget) ProxyQualityCheckItem {
	item := ProxyQualityCheckItem{
		Target: target.Target,
	}

	req, err := http.NewRequestWithContext(ctx, target.Method, target.URL, nil)
	if err != nil {
		item.Status = "fail"
		item.Message = fmt.Sprintf("构建请求失败: %v", err)
		return item
	}
	// Browser-like headers so targets treat the probe as ordinary traffic.
	req.Header.Set("Accept", "application/json,text/html,*/*")
	req.Header.Set("User-Agent", proxyQualityClientUserAgent)

	start := time.Now()
	resp, err := client.Do(req)
	if err != nil {
		item.Status = "fail"
		item.LatencyMs = time.Since(start).Milliseconds()
		item.Message = fmt.Sprintf("请求失败: %v", err)
		return item
	}
	defer func() { _ = resp.Body.Close() }()
	item.LatencyMs = time.Since(start).Milliseconds()
	item.HTTPStatus = resp.StatusCode

	// Read one byte past the cap so truncation is detectable, then trim.
	body, readErr := io.ReadAll(io.LimitReader(resp.Body, proxyQualityMaxBodyBytes+1))
	if readErr != nil {
		item.Status = "fail"
		item.Message = fmt.Sprintf("读取响应失败: %v", readErr)
		return item
	}
	if int64(len(body)) > proxyQualityMaxBodyBytes {
		body = body[:proxyQualityMaxBodyBytes]
	}

	// Only the Sora target is checked for a Cloudflare challenge page.
	if target.Target == "sora" && soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
		item.Status = "challenge"
		item.CFRay = soraerror.ExtractCloudflareRayID(resp.Header, body)
		item.Message = "Sora 命中 Cloudflare challenge"
		return item
	}

	if _, ok := target.AllowedStatuses[resp.StatusCode]; ok {
		// 2xx = full pass; an allowed non-2xx (e.g. 401) proves
		// reachability but is flagged as a warning.
		if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices {
			item.Status = "pass"
			item.Message = fmt.Sprintf("HTTP %d", resp.StatusCode)
		} else {
			item.Status = "warn"
			item.Message = fmt.Sprintf("HTTP %d(目标可达,但鉴权或方法受限)", resp.StatusCode)
		}
		return item
	}

	if resp.StatusCode == http.StatusTooManyRequests {
		item.Status = "warn"
		item.Message = "目标返回 429,可能存在频控"
		return item
	}

	item.Status = "fail"
	item.Message = fmt.Sprintf("非预期状态码: %d", resp.StatusCode)
	return item
}
|
||||||
|
|
||||||
|
func finalizeProxyQualityResult(result *ProxyQualityCheckResult) {
|
||||||
|
if result == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
score := 100 - result.WarnCount*10 - result.FailedCount*22 - result.ChallengeCount*30
|
||||||
|
if score < 0 {
|
||||||
|
score = 0
|
||||||
|
}
|
||||||
|
result.Score = score
|
||||||
|
result.Grade = proxyQualityGrade(score)
|
||||||
|
result.Summary = fmt.Sprintf(
|
||||||
|
"通过 %d 项,告警 %d 项,失败 %d 项,挑战 %d 项",
|
||||||
|
result.PassedCount,
|
||||||
|
result.WarnCount,
|
||||||
|
result.FailedCount,
|
||||||
|
result.ChallengeCount,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// proxyQualityGrade maps a 0-100 quality score onto a letter grade:
// >=90 A, >=75 B, >=60 C, >=40 D, otherwise F.
func proxyQualityGrade(score int) string {
	thresholds := []struct {
		min   int
		grade string
	}{
		{90, "A"},
		{75, "B"},
		{60, "C"},
		{40, "D"},
	}
	for _, t := range thresholds {
		if score >= t.min {
			return t.grade
		}
	}
	return "F"
}
|
||||||
|
|
||||||
|
func proxyQualityOverallStatus(result *ProxyQualityCheckResult) string {
|
||||||
|
if result == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if result.ChallengeCount > 0 {
|
||||||
|
return "challenge"
|
||||||
|
}
|
||||||
|
if result.FailedCount > 0 {
|
||||||
|
return "failed"
|
||||||
|
}
|
||||||
|
if result.WarnCount > 0 {
|
||||||
|
return "warn"
|
||||||
|
}
|
||||||
|
if result.PassedCount > 0 {
|
||||||
|
return "healthy"
|
||||||
|
}
|
||||||
|
return "failed"
|
||||||
|
}
|
||||||
|
|
||||||
|
func proxyQualityFirstCFRay(result *ProxyQualityCheckResult) string {
|
||||||
|
if result == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
for _, item := range result.Items {
|
||||||
|
if item.CFRay != "" {
|
||||||
|
return item.CFRay
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func proxyQualityBaseConnectivityPass(result *ProxyQualityCheckResult) bool {
|
||||||
|
if result == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, item := range result.Items {
|
||||||
|
if item.Target == "base_connectivity" {
|
||||||
|
return item.Status == "pass"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveProxyQualitySnapshot persists a quality-check result into the proxy
// latency cache so list views can display quality status without re-running
// the check. exitInfo may be nil when base connectivity never succeeded.
func (s *adminServiceImpl) saveProxyQualitySnapshot(ctx context.Context, proxyID int64, result *ProxyQualityCheckResult, exitInfo *ProxyExitInfo) {
	if result == nil {
		return
	}
	// Copy into locals so the stored pointers reference stable values,
	// independent of later mutation of result.
	score := result.Score
	checkedAt := result.CheckedAt
	info := &ProxyLatencyInfo{
		Success:          proxyQualityBaseConnectivityPass(result),
		Message:          result.Summary,
		QualityStatus:    proxyQualityOverallStatus(result),
		QualityScore:     &score,
		QualityGrade:     result.Grade,
		QualitySummary:   result.Summary,
		QualityCheckedAt: &checkedAt,
		QualityCFRay:     proxyQualityFirstCFRay(result),
		UpdatedAt:        time.Now(),
	}
	// Only record a latency when the base probe actually produced one.
	if result.BaseLatencyMs > 0 {
		latency := result.BaseLatencyMs
		info.LatencyMs = &latency
	}
	if exitInfo != nil {
		info.IPAddress = exitInfo.IP
		info.Country = exitInfo.Country
		info.CountryCode = exitInfo.CountryCode
		info.Region = exitInfo.Region
		info.City = exitInfo.City
	}
	s.saveProxyLatency(ctx, proxyID, info)
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) {
|
func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) {
|
||||||
if s.proxyProber == nil || proxy == nil {
|
if s.proxyProber == nil || proxy == nil {
|
||||||
return
|
return
|
||||||
@@ -1800,6 +2147,11 @@ func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []Pro
|
|||||||
proxies[i].CountryCode = info.CountryCode
|
proxies[i].CountryCode = info.CountryCode
|
||||||
proxies[i].Region = info.Region
|
proxies[i].Region = info.Region
|
||||||
proxies[i].City = info.City
|
proxies[i].City = info.City
|
||||||
|
proxies[i].QualityStatus = info.QualityStatus
|
||||||
|
proxies[i].QualityScore = info.QualityScore
|
||||||
|
proxies[i].QualityGrade = info.QualityGrade
|
||||||
|
proxies[i].QualitySummary = info.QualitySummary
|
||||||
|
proxies[i].QualityChecked = info.QualityCheckedAt
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1807,7 +2159,27 @@ func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64,
|
|||||||
if s.proxyLatencyCache == nil || info == nil {
|
if s.proxyLatencyCache == nil || info == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, info); err != nil {
|
|
||||||
|
merged := *info
|
||||||
|
if latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, []int64{proxyID}); err == nil {
|
||||||
|
if existing := latencies[proxyID]; existing != nil {
|
||||||
|
if merged.QualityCheckedAt == nil &&
|
||||||
|
merged.QualityScore == nil &&
|
||||||
|
merged.QualityGrade == "" &&
|
||||||
|
merged.QualityStatus == "" &&
|
||||||
|
merged.QualitySummary == "" &&
|
||||||
|
merged.QualityCFRay == "" {
|
||||||
|
merged.QualityStatus = existing.QualityStatus
|
||||||
|
merged.QualityScore = existing.QualityScore
|
||||||
|
merged.QualityGrade = existing.QualityGrade
|
||||||
|
merged.QualitySummary = existing.QualitySummary
|
||||||
|
merged.QualityCheckedAt = existing.QualityCheckedAt
|
||||||
|
merged.QualityCFRay = existing.QualityCFRay
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, &merged); err != nil {
|
||||||
logger.LegacyPrintf("service.admin", "Warning: store proxy latency cache failed: %v", err)
|
logger.LegacyPrintf("service.admin", "Warning: store proxy latency cache failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
95
backend/internal/service/admin_service_proxy_quality_test.go
Normal file
95
backend/internal/service/admin_service_proxy_quality_test.go
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFinalizeProxyQualityResult_ScoreAndGrade(t *testing.T) {
|
||||||
|
result := &ProxyQualityCheckResult{
|
||||||
|
PassedCount: 2,
|
||||||
|
WarnCount: 1,
|
||||||
|
FailedCount: 1,
|
||||||
|
ChallengeCount: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
finalizeProxyQualityResult(result)
|
||||||
|
|
||||||
|
require.Equal(t, 38, result.Score)
|
||||||
|
require.Equal(t, "F", result.Grade)
|
||||||
|
require.Contains(t, result.Summary, "通过 2 项")
|
||||||
|
require.Contains(t, result.Summary, "告警 1 项")
|
||||||
|
require.Contains(t, result.Summary, "失败 1 项")
|
||||||
|
require.Contains(t, result.Summary, "挑战 1 项")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/html")
|
||||||
|
w.Header().Set("cf-ray", "test-ray-123")
|
||||||
|
w.WriteHeader(http.StatusForbidden)
|
||||||
|
_, _ = w.Write([]byte("<!DOCTYPE html><title>Just a moment...</title><script>window._cf_chl_opt={};</script>"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
target := proxyQualityTarget{
|
||||||
|
Target: "sora",
|
||||||
|
URL: server.URL,
|
||||||
|
Method: http.MethodGet,
|
||||||
|
AllowedStatuses: map[int]struct{}{
|
||||||
|
http.StatusUnauthorized: {},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
item := runProxyQualityTarget(context.Background(), server.Client(), target)
|
||||||
|
require.Equal(t, "challenge", item.Status)
|
||||||
|
require.Equal(t, http.StatusForbidden, item.HTTPStatus)
|
||||||
|
require.Equal(t, "test-ray-123", item.CFRay)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunProxyQualityTarget_AllowedStatusPass(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte(`{"models":[]}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
target := proxyQualityTarget{
|
||||||
|
Target: "gemini",
|
||||||
|
URL: server.URL,
|
||||||
|
Method: http.MethodGet,
|
||||||
|
AllowedStatuses: map[int]struct{}{
|
||||||
|
http.StatusOK: {},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
item := runProxyQualityTarget(context.Background(), server.Client(), target)
|
||||||
|
require.Equal(t, "pass", item.Status)
|
||||||
|
require.Equal(t, http.StatusOK, item.HTTPStatus)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunProxyQualityTarget_AllowedStatusWarnForUnauthorized(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
_, _ = w.Write([]byte(`{"error":"unauthorized"}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
target := proxyQualityTarget{
|
||||||
|
Target: "openai",
|
||||||
|
URL: server.URL,
|
||||||
|
Method: http.MethodGet,
|
||||||
|
AllowedStatuses: map[int]struct{}{
|
||||||
|
http.StatusUnauthorized: {},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
item := runProxyQualityTarget(context.Background(), server.Client(), target)
|
||||||
|
require.Equal(t, "warn", item.Status)
|
||||||
|
require.Equal(t, http.StatusUnauthorized, item.HTTPStatus)
|
||||||
|
require.Contains(t, item.Message, "目标可达")
|
||||||
|
}
|
||||||
@@ -24,7 +24,7 @@ type accountRepoStubForAdminList struct {
|
|||||||
listWithFiltersErr error
|
listWithFiltersErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||||
s.listWithFiltersCalls++
|
s.listWithFiltersCalls++
|
||||||
s.listWithFiltersParams = params
|
s.listWithFiltersParams = params
|
||||||
s.listWithFiltersPlatform = platform
|
s.listWithFiltersPlatform = platform
|
||||||
@@ -168,7 +168,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
|
|||||||
}
|
}
|
||||||
svc := &adminServiceImpl{accountRepo: repo}
|
svc := &adminServiceImpl{accountRepo: repo}
|
||||||
|
|
||||||
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc")
|
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, int64(10), total)
|
require.Equal(t, int64(10), total)
|
||||||
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
|
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
|
||||||
|
|||||||
@@ -4117,6 +4117,15 @@ func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUs
|
|||||||
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
|
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
|
||||||
usage.CacheCreationInputTokens = int(v)
|
usage.CacheCreationInputTokens = int(v)
|
||||||
}
|
}
|
||||||
|
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
|
||||||
|
if cc, ok := u["cache_creation"].(map[string]any); ok {
|
||||||
|
if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok {
|
||||||
|
usage.CacheCreation5mTokens = int(v)
|
||||||
|
}
|
||||||
|
if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok {
|
||||||
|
usage.CacheCreation1hTokens = int(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractClaudeUsage 从非流式 Claude 响应提取 usage
|
// extractClaudeUsage 从非流式 Claude 响应提取 usage
|
||||||
@@ -4139,6 +4148,15 @@ func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage
|
|||||||
if v, ok := u["cache_creation_input_tokens"].(float64); ok {
|
if v, ok := u["cache_creation_input_tokens"].(float64); ok {
|
||||||
usage.CacheCreationInputTokens = int(v)
|
usage.CacheCreationInputTokens = int(v)
|
||||||
}
|
}
|
||||||
|
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
|
||||||
|
if cc, ok := u["cache_creation"].(map[string]any); ok {
|
||||||
|
if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok {
|
||||||
|
usage.CacheCreation5mTokens = int(v)
|
||||||
|
}
|
||||||
|
if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok {
|
||||||
|
usage.CacheCreation1hTokens = int(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return usage
|
return usage
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,8 +31,8 @@ type ModelPricing struct {
|
|||||||
OutputPricePerToken float64 // 每token输出价格 (USD)
|
OutputPricePerToken float64 // 每token输出价格 (USD)
|
||||||
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
|
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
|
||||||
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
|
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
|
||||||
CacheCreation5mPrice float64 // 5分钟缓存创建价格(每百万token)- 仅用于硬编码回退
|
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
|
||||||
CacheCreation1hPrice float64 // 1小时缓存创建价格(每百万token)- 仅用于硬编码回退
|
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
|
||||||
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
|
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -172,12 +172,20 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
|||||||
if s.pricingService != nil {
|
if s.pricingService != nil {
|
||||||
litellmPricing := s.pricingService.GetModelPricing(model)
|
litellmPricing := s.pricingService.GetModelPricing(model)
|
||||||
if litellmPricing != nil {
|
if litellmPricing != nil {
|
||||||
|
// 启用 5m/1h 分类计费的条件:
|
||||||
|
// 1. 存在 1h 价格
|
||||||
|
// 2. 1h 价格 > 5m 价格(防止 LiteLLM 数据错误导致少收费)
|
||||||
|
price5m := litellmPricing.CacheCreationInputTokenCost
|
||||||
|
price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr
|
||||||
|
enableBreakdown := price1h > 0 && price1h > price5m
|
||||||
return &ModelPricing{
|
return &ModelPricing{
|
||||||
InputPricePerToken: litellmPricing.InputCostPerToken,
|
InputPricePerToken: litellmPricing.InputCostPerToken,
|
||||||
OutputPricePerToken: litellmPricing.OutputCostPerToken,
|
OutputPricePerToken: litellmPricing.OutputCostPerToken,
|
||||||
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
|
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
|
||||||
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
|
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
|
||||||
SupportsCacheBreakdown: false,
|
CacheCreation5mPrice: price5m,
|
||||||
|
CacheCreation1hPrice: price1h,
|
||||||
|
SupportsCacheBreakdown: enableBreakdown,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -209,9 +217,14 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
|
|||||||
|
|
||||||
// 计算缓存费用
|
// 计算缓存费用
|
||||||
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
|
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
|
||||||
// 支持详细缓存分类的模型(5分钟/1小时缓存)
|
// 支持详细缓存分类的模型(5分钟/1小时缓存,价格为 per-token)
|
||||||
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)/1_000_000*pricing.CacheCreation5mPrice +
|
if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 {
|
||||||
float64(tokens.CacheCreation1hTokens)/1_000_000*pricing.CacheCreation1hPrice
|
// API 未返回 ephemeral 明细,回退到全部按 5m 单价计费
|
||||||
|
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice
|
||||||
|
} else {
|
||||||
|
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice +
|
||||||
|
float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// 标准缓存创建价格(per-token)
|
// 标准缓存创建价格(per-token)
|
||||||
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
|
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
|
||||||
@@ -280,10 +293,12 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
|
|||||||
|
|
||||||
// 范围内部分:正常计费
|
// 范围内部分:正常计费
|
||||||
inRangeTokens := UsageTokens{
|
inRangeTokens := UsageTokens{
|
||||||
InputTokens: inRangeInputTokens,
|
InputTokens: inRangeInputTokens,
|
||||||
OutputTokens: tokens.OutputTokens, // 输出只算一次
|
OutputTokens: tokens.OutputTokens, // 输出只算一次
|
||||||
CacheCreationTokens: tokens.CacheCreationTokens,
|
CacheCreationTokens: tokens.CacheCreationTokens,
|
||||||
CacheReadTokens: inRangeCacheTokens,
|
CacheReadTokens: inRangeCacheTokens,
|
||||||
|
CacheCreation5mTokens: tokens.CacheCreation5mTokens,
|
||||||
|
CacheCreation1hTokens: tokens.CacheCreation1hTokens,
|
||||||
}
|
}
|
||||||
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
|
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -399,8 +399,8 @@ func TestCalculateCost_SupportsCacheBreakdown(t *testing.T) {
|
|||||||
InputPricePerToken: 3e-6,
|
InputPricePerToken: 3e-6,
|
||||||
OutputPricePerToken: 15e-6,
|
OutputPricePerToken: 15e-6,
|
||||||
SupportsCacheBreakdown: true,
|
SupportsCacheBreakdown: true,
|
||||||
CacheCreation5mPrice: 4.0, // per million tokens
|
CacheCreation5mPrice: 4e-6, // per token
|
||||||
CacheCreation1hPrice: 5.0, // per million tokens
|
CacheCreation1hPrice: 5e-6, // per token
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -414,8 +414,8 @@ func TestCalculateCost_SupportsCacheBreakdown(t *testing.T) {
|
|||||||
cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
expected5m := float64(100000) / 1_000_000 * 4.0
|
expected5m := float64(tokens.CacheCreation5mTokens) * 4e-6
|
||||||
expected1h := float64(50000) / 1_000_000 * 5.0
|
expected1h := float64(tokens.CacheCreation1hTokens) * 5e-6
|
||||||
require.InDelta(t, expected5m+expected1h, cost.CacheCreationCost, 1e-10)
|
require.InDelta(t, expected5m+expected1h, cost.CacheCreationCost, 1e-10)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -21,3 +21,72 @@ func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) {
|
|||||||
)
|
)
|
||||||
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got)
|
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStripBetaToken(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
header string
|
||||||
|
token string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "token in middle",
|
||||||
|
header: "oauth-2025-04-20,context-1m-2025-08-07,interleaved-thinking-2025-05-14",
|
||||||
|
token: "context-1m-2025-08-07",
|
||||||
|
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "token at start",
|
||||||
|
header: "context-1m-2025-08-07,oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||||
|
token: "context-1m-2025-08-07",
|
||||||
|
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "token at end",
|
||||||
|
header: "oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07",
|
||||||
|
token: "context-1m-2025-08-07",
|
||||||
|
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "token not present",
|
||||||
|
header: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||||
|
token: "context-1m-2025-08-07",
|
||||||
|
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty header",
|
||||||
|
header: "",
|
||||||
|
token: "context-1m-2025-08-07",
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with spaces",
|
||||||
|
header: "oauth-2025-04-20, context-1m-2025-08-07 , interleaved-thinking-2025-05-14",
|
||||||
|
token: "context-1m-2025-08-07",
|
||||||
|
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only token",
|
||||||
|
header: "context-1m-2025-08-07",
|
||||||
|
token: "context-1m-2025-08-07",
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := stripBetaToken(tt.header, tt.token)
|
||||||
|
require.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeAnthropicBetaDropping_Context1M(t *testing.T) {
|
||||||
|
required := []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"}
|
||||||
|
incoming := "context-1m-2025-08-07,foo-beta,oauth-2025-04-20"
|
||||||
|
drop := map[string]struct{}{"context-1m-2025-08-07": {}}
|
||||||
|
|
||||||
|
got := mergeAnthropicBetaDropping(required, incoming, drop)
|
||||||
|
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo-beta", got)
|
||||||
|
require.NotContains(t, got, "context-1m-2025-08-07")
|
||||||
|
}
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error
|
|||||||
func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
||||||
|
|||||||
@@ -349,6 +349,8 @@ type ClaudeUsage struct {
|
|||||||
OutputTokens int `json:"output_tokens"`
|
OutputTokens int `json:"output_tokens"`
|
||||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||||
|
CacheCreation5mTokens int // 5分钟缓存创建token(来自嵌套 cache_creation 对象)
|
||||||
|
CacheCreation1hTokens int // 1小时缓存创建token(来自嵌套 cache_creation 对象)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ForwardResult 转发结果
|
// ForwardResult 转发结果
|
||||||
@@ -373,9 +375,10 @@ type ForwardResult struct {
|
|||||||
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
|
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
|
||||||
type UpstreamFailoverError struct {
|
type UpstreamFailoverError struct {
|
||||||
StatusCode int
|
StatusCode int
|
||||||
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
|
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
|
||||||
ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
|
ResponseHeaders http.Header // 上游响应头,用于透传 cf-ray/cf-mitigated/content-type 等诊断信息
|
||||||
RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换
|
ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
|
||||||
|
RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *UpstreamFailoverError) Error() string {
|
func (e *UpstreamFailoverError) Error() string {
|
||||||
@@ -3580,12 +3583,12 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
// messages requests typically use only oauth + interleaved-thinking.
|
// messages requests typically use only oauth + interleaved-thinking.
|
||||||
// Also drop claude-code beta if a downstream client added it.
|
// Also drop claude-code beta if a downstream client added it.
|
||||||
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
|
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
|
||||||
drop := map[string]struct{}{claude.BetaClaudeCode: {}}
|
drop := map[string]struct{}{claude.BetaClaudeCode: {}, claude.BetaContext1M: {}}
|
||||||
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop))
|
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop))
|
||||||
} else {
|
} else {
|
||||||
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
|
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
|
||||||
clientBetaHeader := req.Header.Get("anthropic-beta")
|
clientBetaHeader := req.Header.Get("anthropic-beta")
|
||||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, clientBetaHeader))
|
req.Header.Set("anthropic-beta", stripBetaToken(s.getBetaHeader(modelID, clientBetaHeader), claude.BetaContext1M))
|
||||||
}
|
}
|
||||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
||||||
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
||||||
@@ -3739,6 +3742,23 @@ func mergeAnthropicBetaDropping(required []string, incoming string, drop map[str
|
|||||||
return strings.Join(out, ",")
|
return strings.Join(out, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// stripBetaToken removes a single beta token from a comma-separated header value.
|
||||||
|
// It short-circuits when the token is not present to avoid unnecessary allocations.
|
||||||
|
func stripBetaToken(header, token string) string {
|
||||||
|
if !strings.Contains(header, token) {
|
||||||
|
return header
|
||||||
|
}
|
||||||
|
out := make([]string, 0, 8)
|
||||||
|
for _, p := range strings.Split(header, ",") {
|
||||||
|
p = strings.TrimSpace(p)
|
||||||
|
if p == "" || p == token {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, p)
|
||||||
|
}
|
||||||
|
return strings.Join(out, ",")
|
||||||
|
}
|
||||||
|
|
||||||
// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers.
|
// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers.
|
||||||
// This mirrors opencode-anthropic-auth behavior: do not trust downstream
|
// This mirrors opencode-anthropic-auth behavior: do not trust downstream
|
||||||
// headers when using Claude Code-scoped OAuth credentials.
|
// headers when using Claude Code-scoped OAuth credentials.
|
||||||
@@ -4305,6 +4325,23 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类
|
||||||
|
if account.IsCacheTTLOverrideEnabled() {
|
||||||
|
overrideTarget := account.GetCacheTTLOverrideTarget()
|
||||||
|
if eventType == "message_start" {
|
||||||
|
if msg, ok := event["message"].(map[string]any); ok {
|
||||||
|
if u, ok := msg["usage"].(map[string]any); ok {
|
||||||
|
rewriteCacheCreationJSON(u, overrideTarget)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if eventType == "message_delta" {
|
||||||
|
if u, ok := event["usage"].(map[string]any); ok {
|
||||||
|
rewriteCacheCreationJSON(u, overrideTarget)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if needModelReplace {
|
if needModelReplace {
|
||||||
if msg, ok := event["message"].(map[string]any); ok {
|
if msg, ok := event["message"].(map[string]any); ok {
|
||||||
if model, ok := msg["model"].(string); ok && model == mappedModel {
|
if model, ok := msg["model"].(string); ok && model == mappedModel {
|
||||||
@@ -4432,6 +4469,14 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
|||||||
usage.InputTokens = msgStart.Message.Usage.InputTokens
|
usage.InputTokens = msgStart.Message.Usage.InputTokens
|
||||||
usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens
|
usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens
|
||||||
usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens
|
usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens
|
||||||
|
|
||||||
|
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
|
||||||
|
cc5m := gjson.Get(data, "message.usage.cache_creation.ephemeral_5m_input_tokens")
|
||||||
|
cc1h := gjson.Get(data, "message.usage.cache_creation.ephemeral_1h_input_tokens")
|
||||||
|
if cc5m.Exists() || cc1h.Exists() {
|
||||||
|
usage.CacheCreation5mTokens = int(cc5m.Int())
|
||||||
|
usage.CacheCreation1hTokens = int(cc1h.Int())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API)
|
// 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API)
|
||||||
@@ -4460,6 +4505,68 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
|||||||
if msgDelta.Usage.CacheReadInputTokens > 0 {
|
if msgDelta.Usage.CacheReadInputTokens > 0 {
|
||||||
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
|
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
|
||||||
|
cc5m := gjson.Get(data, "usage.cache_creation.ephemeral_5m_input_tokens")
|
||||||
|
cc1h := gjson.Get(data, "usage.cache_creation.ephemeral_1h_input_tokens")
|
||||||
|
if cc5m.Exists() && cc5m.Int() > 0 {
|
||||||
|
usage.CacheCreation5mTokens = int(cc5m.Int())
|
||||||
|
}
|
||||||
|
if cc1h.Exists() && cc1h.Int() > 0 {
|
||||||
|
usage.CacheCreation1hTokens = int(cc1h.Int())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyCacheTTLOverride 将所有 cache creation tokens 归入指定的 TTL 类型。
|
||||||
|
// target 为 "5m" 或 "1h"。返回 true 表示发生了变更。
|
||||||
|
func applyCacheTTLOverride(usage *ClaudeUsage, target string) bool {
|
||||||
|
// Fallback: 如果只有聚合字段但无 5m/1h 明细,将聚合字段归入 5m 默认类别
|
||||||
|
if usage.CacheCreation5mTokens == 0 && usage.CacheCreation1hTokens == 0 && usage.CacheCreationInputTokens > 0 {
|
||||||
|
usage.CacheCreation5mTokens = usage.CacheCreationInputTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
total := usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
|
||||||
|
if total == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch target {
|
||||||
|
case "1h":
|
||||||
|
if usage.CacheCreation1hTokens == total {
|
||||||
|
return false // 已经全是 1h
|
||||||
|
}
|
||||||
|
usage.CacheCreation1hTokens = total
|
||||||
|
usage.CacheCreation5mTokens = 0
|
||||||
|
default: // "5m"
|
||||||
|
if usage.CacheCreation5mTokens == total {
|
||||||
|
return false // 已经全是 5m
|
||||||
|
}
|
||||||
|
usage.CacheCreation5mTokens = total
|
||||||
|
usage.CacheCreation1hTokens = 0
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewriteCacheCreationJSON 在 JSON usage 对象中重写 cache_creation 嵌套对象的 TTL 分类。
|
||||||
|
// usageObj 是 usage JSON 对象(map[string]any)。
|
||||||
|
func rewriteCacheCreationJSON(usageObj map[string]any, target string) {
|
||||||
|
ccObj, ok := usageObj["cache_creation"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
v5m, _ := ccObj["ephemeral_5m_input_tokens"].(float64)
|
||||||
|
v1h, _ := ccObj["ephemeral_1h_input_tokens"].(float64)
|
||||||
|
total := v5m + v1h
|
||||||
|
if total == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch target {
|
||||||
|
case "1h":
|
||||||
|
ccObj["ephemeral_1h_input_tokens"] = total
|
||||||
|
ccObj["ephemeral_5m_input_tokens"] = float64(0)
|
||||||
|
default: // "5m"
|
||||||
|
ccObj["ephemeral_5m_input_tokens"] = total
|
||||||
|
ccObj["ephemeral_1h_input_tokens"] = float64(0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -4491,6 +4598,14 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
|||||||
return nil, fmt.Errorf("parse response: %w", err)
|
return nil, fmt.Errorf("parse response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
|
||||||
|
cc5m := gjson.GetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens")
|
||||||
|
cc1h := gjson.GetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens")
|
||||||
|
if cc5m.Exists() || cc1h.Exists() {
|
||||||
|
response.Usage.CacheCreation5mTokens = int(cc5m.Int())
|
||||||
|
response.Usage.CacheCreation1hTokens = int(cc1h.Int())
|
||||||
|
}
|
||||||
|
|
||||||
// 兼容 Kimi cached_tokens → cache_read_input_tokens
|
// 兼容 Kimi cached_tokens → cache_read_input_tokens
|
||||||
if response.Usage.CacheReadInputTokens == 0 {
|
if response.Usage.CacheReadInputTokens == 0 {
|
||||||
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
|
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
|
||||||
@@ -4502,6 +4617,20 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类
|
||||||
|
if account.IsCacheTTLOverrideEnabled() {
|
||||||
|
overrideTarget := account.GetCacheTTLOverrideTarget()
|
||||||
|
if applyCacheTTLOverride(&response.Usage, overrideTarget) {
|
||||||
|
// 同步更新 body JSON 中的嵌套 cache_creation 对象
|
||||||
|
if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens", response.Usage.CacheCreation5mTokens); err == nil {
|
||||||
|
body = newBody
|
||||||
|
}
|
||||||
|
if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens", response.Usage.CacheCreation1hTokens); err == nil {
|
||||||
|
body = newBody
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 如果有模型映射,替换响应中的model字段
|
// 如果有模型映射,替换响应中的model字段
|
||||||
if originalModel != mappedModel {
|
if originalModel != mappedModel {
|
||||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||||
@@ -4570,6 +4699,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
result.Usage.InputTokens = 0
|
result.Usage.InputTokens = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
|
||||||
|
cacheTTLOverridden := false
|
||||||
|
if account.IsCacheTTLOverrideEnabled() {
|
||||||
|
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
|
||||||
|
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
|
||||||
|
}
|
||||||
|
|
||||||
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
|
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
|
||||||
multiplier := s.cfg.Default.RateMultiplier
|
multiplier := s.cfg.Default.RateMultiplier
|
||||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||||
@@ -4617,10 +4753,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
} else {
|
} else {
|
||||||
// Token 计费
|
// Token 计费
|
||||||
tokens := UsageTokens{
|
tokens := UsageTokens{
|
||||||
InputTokens: result.Usage.InputTokens,
|
InputTokens: result.Usage.InputTokens,
|
||||||
OutputTokens: result.Usage.OutputTokens,
|
OutputTokens: result.Usage.OutputTokens,
|
||||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||||
|
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||||
|
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
||||||
@@ -4658,6 +4796,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
OutputTokens: result.Usage.OutputTokens,
|
OutputTokens: result.Usage.OutputTokens,
|
||||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||||
|
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||||
|
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||||
InputCost: cost.InputCost,
|
InputCost: cost.InputCost,
|
||||||
OutputCost: cost.OutputCost,
|
OutputCost: cost.OutputCost,
|
||||||
CacheCreationCost: cost.CacheCreationCost,
|
CacheCreationCost: cost.CacheCreationCost,
|
||||||
@@ -4673,6 +4813,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
ImageCount: result.ImageCount,
|
ImageCount: result.ImageCount,
|
||||||
ImageSize: imageSize,
|
ImageSize: imageSize,
|
||||||
MediaType: mediaType,
|
MediaType: mediaType,
|
||||||
|
CacheTTLOverridden: cacheTTLOverridden,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -4773,6 +4914,13 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
|||||||
result.Usage.InputTokens = 0
|
result.Usage.InputTokens = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
|
||||||
|
cacheTTLOverridden := false
|
||||||
|
if account.IsCacheTTLOverrideEnabled() {
|
||||||
|
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
|
||||||
|
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
|
||||||
|
}
|
||||||
|
|
||||||
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
|
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
|
||||||
multiplier := s.cfg.Default.RateMultiplier
|
multiplier := s.cfg.Default.RateMultiplier
|
||||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||||
@@ -4803,10 +4951,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
|||||||
} else {
|
} else {
|
||||||
// Token 计费(使用长上下文计费方法)
|
// Token 计费(使用长上下文计费方法)
|
||||||
tokens := UsageTokens{
|
tokens := UsageTokens{
|
||||||
InputTokens: result.Usage.InputTokens,
|
InputTokens: result.Usage.InputTokens,
|
||||||
OutputTokens: result.Usage.OutputTokens,
|
OutputTokens: result.Usage.OutputTokens,
|
||||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||||
|
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||||
|
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
|
cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
|
||||||
@@ -4840,6 +4990,8 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
|||||||
OutputTokens: result.Usage.OutputTokens,
|
OutputTokens: result.Usage.OutputTokens,
|
||||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||||
|
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||||
|
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||||
InputCost: cost.InputCost,
|
InputCost: cost.InputCost,
|
||||||
OutputCost: cost.OutputCost,
|
OutputCost: cost.OutputCost,
|
||||||
CacheCreationCost: cost.CacheCreationCost,
|
CacheCreationCost: cost.CacheCreationCost,
|
||||||
@@ -4854,6 +5006,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
|||||||
FirstTokenMs: result.FirstTokenMs,
|
FirstTokenMs: result.FirstTokenMs,
|
||||||
ImageCount: result.ImageCount,
|
ImageCount: result.ImageCount,
|
||||||
ImageSize: imageSize,
|
ImageSize: imageSize,
|
||||||
|
CacheTTLOverridden: cacheTTLOverridden,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -5170,7 +5323,8 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
|||||||
|
|
||||||
incomingBeta := req.Header.Get("anthropic-beta")
|
incomingBeta := req.Header.Get("anthropic-beta")
|
||||||
requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting}
|
requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting}
|
||||||
req.Header.Set("anthropic-beta", mergeAnthropicBeta(requiredBetas, incomingBeta))
|
drop := map[string]struct{}{claude.BetaContext1M: {}}
|
||||||
|
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop))
|
||||||
} else {
|
} else {
|
||||||
clientBetaHeader := req.Header.Get("anthropic-beta")
|
clientBetaHeader := req.Header.Get("anthropic-beta")
|
||||||
if clientBetaHeader == "" {
|
if clientBetaHeader == "" {
|
||||||
@@ -5180,7 +5334,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
|||||||
if !strings.Contains(beta, claude.BetaTokenCounting) {
|
if !strings.Contains(beta, claude.BetaTokenCounting) {
|
||||||
beta = beta + "," + claude.BetaTokenCounting
|
beta = beta + "," + claude.BetaTokenCounting
|
||||||
}
|
}
|
||||||
req.Header.Set("anthropic-beta", beta)
|
req.Header.Set("anthropic-beta", stripBetaToken(beta, claude.BetaContext1M))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
||||||
|
|||||||
@@ -79,6 +79,22 @@ func TestParseSSEUsage_DeltaOverwritesWithNonZero(t *testing.T) {
|
|||||||
require.Equal(t, 60, usage.CacheReadInputTokens)
|
require.Equal(t, 60, usage.CacheReadInputTokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseSSEUsage_DeltaDoesNotResetCacheCreationBreakdown(t *testing.T) {
|
||||||
|
svc := newMinimalGatewayService()
|
||||||
|
usage := &ClaudeUsage{}
|
||||||
|
|
||||||
|
// 先在 message_start 中写入非零 5m/1h 明细
|
||||||
|
svc.parseSSEUsage(`{"type":"message_start","message":{"usage":{"input_tokens":100,"cache_creation":{"ephemeral_5m_input_tokens":30,"ephemeral_1h_input_tokens":70}}}}`, usage)
|
||||||
|
require.Equal(t, 30, usage.CacheCreation5mTokens)
|
||||||
|
require.Equal(t, 70, usage.CacheCreation1hTokens)
|
||||||
|
|
||||||
|
// 后续 delta 带默认 0,不应覆盖已有非零值
|
||||||
|
svc.parseSSEUsage(`{"type":"message_delta","usage":{"output_tokens":12,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0}}}`, usage)
|
||||||
|
require.Equal(t, 30, usage.CacheCreation5mTokens, "delta 的 0 值不应重置 5m 明细")
|
||||||
|
require.Equal(t, 70, usage.CacheCreation1hTokens, "delta 的 0 值不应重置 1h 明细")
|
||||||
|
require.Equal(t, 12, usage.OutputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
func TestParseSSEUsage_InvalidJSON(t *testing.T) {
|
func TestParseSSEUsage_InvalidJSON(t *testing.T) {
|
||||||
svc := newMinimalGatewayService()
|
svc := newMinimalGatewayService()
|
||||||
usage := &ClaudeUsage{}
|
usage := &ClaudeUsage{}
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error
|
|||||||
func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
type OpenAIOAuthClient interface {
|
type OpenAIOAuthClient interface {
|
||||||
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error)
|
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error)
|
||||||
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error)
|
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error)
|
||||||
|
RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClaudeOAuthClient handles HTTP requests for Claude OAuth flows
|
// ClaudeOAuthClient handles HTTP requests for Claude OAuth flows
|
||||||
|
|||||||
@@ -99,13 +99,19 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTran
|
|||||||
result.Modified = true
|
result.Modified = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := reqBody["max_output_tokens"]; ok {
|
// Strip parameters unsupported by codex models via the Responses API.
|
||||||
delete(reqBody, "max_output_tokens")
|
for _, key := range []string{
|
||||||
result.Modified = true
|
"max_output_tokens",
|
||||||
}
|
"max_completion_tokens",
|
||||||
if _, ok := reqBody["max_completion_tokens"]; ok {
|
"temperature",
|
||||||
delete(reqBody, "max_completion_tokens")
|
"top_p",
|
||||||
result.Modified = true
|
"frequency_penalty",
|
||||||
|
"presence_penalty",
|
||||||
|
} {
|
||||||
|
if _, ok := reqBody[key]; ok {
|
||||||
|
delete(reqBody, key)
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if normalizeCodexTools(reqBody) {
|
if normalizeCodexTools(reqBody) {
|
||||||
|
|||||||
@@ -2,13 +2,20 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
|
||||||
|
|
||||||
// OpenAIOAuthService handles OpenAI OAuth authentication flows
|
// OpenAIOAuthService handles OpenAI OAuth authentication flows
|
||||||
type OpenAIOAuthService struct {
|
type OpenAIOAuthService struct {
|
||||||
sessionStore *openai.SessionStore
|
sessionStore *openai.SessionStore
|
||||||
@@ -92,6 +99,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
|||||||
type OpenAIExchangeCodeInput struct {
|
type OpenAIExchangeCodeInput struct {
|
||||||
SessionID string
|
SessionID string
|
||||||
Code string
|
Code string
|
||||||
|
State string
|
||||||
RedirectURI string
|
RedirectURI string
|
||||||
ProxyID *int64
|
ProxyID *int64
|
||||||
}
|
}
|
||||||
@@ -116,6 +124,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired")
|
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired")
|
||||||
}
|
}
|
||||||
|
if input.State == "" {
|
||||||
|
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_STATE_REQUIRED", "oauth state is required")
|
||||||
|
}
|
||||||
|
if subtle.ConstantTimeCompare([]byte(input.State), []byte(session.State)) != 1 {
|
||||||
|
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_STATE", "invalid oauth state")
|
||||||
|
}
|
||||||
|
|
||||||
// Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL
|
// Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL
|
||||||
proxyURL := session.ProxyURL
|
proxyURL := session.ProxyURL
|
||||||
@@ -173,7 +187,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
|||||||
|
|
||||||
// RefreshToken refreshes an OpenAI OAuth token
|
// RefreshToken refreshes an OpenAI OAuth token
|
||||||
func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) {
|
func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) {
|
||||||
tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL)
|
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshTokenWithClientID refreshes an OpenAI/Sora OAuth token with optional client_id.
|
||||||
|
func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) {
|
||||||
|
tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -205,13 +224,83 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
|
|||||||
return tokenInfo, nil
|
return tokenInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshAccountToken refreshes token for an OpenAI account
|
// ExchangeSoraSessionToken exchanges Sora session_token to access_token.
|
||||||
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) {
|
||||||
if !account.IsOpenAI() {
|
if strings.TrimSpace(sessionToken) == "" {
|
||||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account")
|
return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
refreshToken := account.GetOpenAIRefreshToken()
|
proxyURL, err := s.resolveProxyURL(ctx, proxyID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAISoraSessionAuthURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, infraerrors.Newf(http.StatusInternalServerError, "SORA_SESSION_REQUEST_BUILD_FAILED", "failed to build request: %v", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Cookie", "__Secure-next-auth.session-token="+strings.TrimSpace(sessionToken))
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
req.Header.Set("Origin", "https://sora.chatgpt.com")
|
||||||
|
req.Header.Set("Referer", "https://sora.chatgpt.com/")
|
||||||
|
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||||
|
|
||||||
|
client := newOpenAIOAuthHTTPClient(proxyURL)
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_EXCHANGE_FAILED", "status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var sessionResp struct {
|
||||||
|
AccessToken string `json:"accessToken"`
|
||||||
|
Expires string `json:"expires"`
|
||||||
|
User struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
} `json:"user"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &sessionResp); err != nil {
|
||||||
|
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_PARSE_FAILED", "failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(sessionResp.AccessToken) == "" {
|
||||||
|
return nil, infraerrors.New(http.StatusBadGateway, "SORA_SESSION_ACCESS_TOKEN_MISSING", "session exchange response missing access token")
|
||||||
|
}
|
||||||
|
|
||||||
|
expiresAt := time.Now().Add(time.Hour).Unix()
|
||||||
|
if strings.TrimSpace(sessionResp.Expires) != "" {
|
||||||
|
if parsed, parseErr := time.Parse(time.RFC3339, sessionResp.Expires); parseErr == nil {
|
||||||
|
expiresAt = parsed.Unix()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
expiresIn := expiresAt - time.Now().Unix()
|
||||||
|
if expiresIn < 0 {
|
||||||
|
expiresIn = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return &OpenAITokenInfo{
|
||||||
|
AccessToken: strings.TrimSpace(sessionResp.AccessToken),
|
||||||
|
ExpiresIn: expiresIn,
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
Email: strings.TrimSpace(sessionResp.User.Email),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account
|
||||||
|
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
||||||
|
if account.Platform != PlatformOpenAI && account.Platform != PlatformSora {
|
||||||
|
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI/Sora account")
|
||||||
|
}
|
||||||
|
if account.Type != AccountTypeOAuth {
|
||||||
|
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account")
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshToken := account.GetCredential("refresh_token")
|
||||||
if refreshToken == "" {
|
if refreshToken == "" {
|
||||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
|
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
|
||||||
}
|
}
|
||||||
@@ -224,7 +313,8 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
clientID := account.GetCredential("client_id")
|
||||||
|
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildAccountCredentials builds credentials map from token info
|
// BuildAccountCredentials builds credentials map from token info
|
||||||
@@ -260,3 +350,30 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
|
|||||||
func (s *OpenAIOAuthService) Stop() {
|
func (s *OpenAIOAuthService) Stop() {
|
||||||
s.sessionStore.Stop()
|
s.sessionStore.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) {
|
||||||
|
if proxyID == nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||||||
|
if err != nil {
|
||||||
|
return "", infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
|
||||||
|
}
|
||||||
|
if proxy == nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return proxy.URL(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client {
|
||||||
|
transport := &http.Transport{}
|
||||||
|
if strings.TrimSpace(proxyURL) != "" {
|
||||||
|
if parsed, err := url.Parse(proxyURL); err == nil && parsed.Host != "" {
|
||||||
|
transport.Proxy = http.ProxyURL(parsed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &http.Client{
|
||||||
|
Timeout: 120 * time.Second,
|
||||||
|
Transport: transport,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,69 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type openaiOAuthClientNoopStub struct{}
|
||||||
|
|
||||||
|
func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openaiOAuthClientNoopStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openaiOAuthClientNoopStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIOAuthService_ExchangeSoraSessionToken_Success(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
require.Equal(t, http.MethodGet, r.Method)
|
||||||
|
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-token")
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
origin := openAISoraSessionAuthURL
|
||||||
|
openAISoraSessionAuthURL = server.URL
|
||||||
|
defer func() { openAISoraSessionAuthURL = origin }()
|
||||||
|
|
||||||
|
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
||||||
|
defer svc.Stop()
|
||||||
|
|
||||||
|
info, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, info)
|
||||||
|
require.Equal(t, "at-token", info.AccessToken)
|
||||||
|
require.Equal(t, "demo@example.com", info.Email)
|
||||||
|
require.Greater(t, info.ExpiresAt, int64(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"expires":"2099-01-01T00:00:00Z"}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
origin := openAISoraSessionAuthURL
|
||||||
|
openAISoraSessionAuthURL = server.URL
|
||||||
|
defer func() { openAISoraSessionAuthURL = origin }()
|
||||||
|
|
||||||
|
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
||||||
|
defer svc.Stop()
|
||||||
|
|
||||||
|
_, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "missing access token")
|
||||||
|
}
|
||||||
102
backend/internal/service/openai_oauth_service_state_test.go
Normal file
102
backend/internal/service/openai_oauth_service_state_test.go
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type openaiOAuthClientStateStub struct {
|
||||||
|
exchangeCalled int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
||||||
|
atomic.AddInt32(&s.exchangeCalled, 1)
|
||||||
|
return &openai.TokenResponse{
|
||||||
|
AccessToken: "at",
|
||||||
|
RefreshToken: "rt",
|
||||||
|
ExpiresIn: 3600,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openaiOAuthClientStateStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openaiOAuthClientStateStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
|
||||||
|
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIOAuthService_ExchangeCode_StateRequired(t *testing.T) {
|
||||||
|
client := &openaiOAuthClientStateStub{}
|
||||||
|
svc := NewOpenAIOAuthService(nil, client)
|
||||||
|
defer svc.Stop()
|
||||||
|
|
||||||
|
svc.sessionStore.Set("sid", &openai.OAuthSession{
|
||||||
|
State: "expected-state",
|
||||||
|
CodeVerifier: "verifier",
|
||||||
|
RedirectURI: openai.DefaultRedirectURI,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
|
||||||
|
SessionID: "sid",
|
||||||
|
Code: "auth-code",
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "oauth state is required")
|
||||||
|
require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIOAuthService_ExchangeCode_StateMismatch(t *testing.T) {
|
||||||
|
client := &openaiOAuthClientStateStub{}
|
||||||
|
svc := NewOpenAIOAuthService(nil, client)
|
||||||
|
defer svc.Stop()
|
||||||
|
|
||||||
|
svc.sessionStore.Set("sid", &openai.OAuthSession{
|
||||||
|
State: "expected-state",
|
||||||
|
CodeVerifier: "verifier",
|
||||||
|
RedirectURI: openai.DefaultRedirectURI,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
|
||||||
|
SessionID: "sid",
|
||||||
|
Code: "auth-code",
|
||||||
|
State: "wrong-state",
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "invalid oauth state")
|
||||||
|
require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) {
|
||||||
|
client := &openaiOAuthClientStateStub{}
|
||||||
|
svc := NewOpenAIOAuthService(nil, client)
|
||||||
|
defer svc.Stop()
|
||||||
|
|
||||||
|
svc.sessionStore.Set("sid", &openai.OAuthSession{
|
||||||
|
State: "expected-state",
|
||||||
|
CodeVerifier: "verifier",
|
||||||
|
RedirectURI: openai.DefaultRedirectURI,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
|
||||||
|
info, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
|
||||||
|
SessionID: "sid",
|
||||||
|
Code: "auth-code",
|
||||||
|
State: "expected-state",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, info)
|
||||||
|
require.Equal(t, "at", info.AccessToken)
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&client.exchangeCalled))
|
||||||
|
|
||||||
|
_, ok := svc.sessionStore.Get("sid")
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
@@ -157,7 +157,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
}
|
}
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||||
if p.openAIOAuthService == nil {
|
if account.Platform == PlatformSora {
|
||||||
|
slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID)
|
||||||
|
// Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
|
||||||
|
refreshFailed = true
|
||||||
|
} else if p.openAIOAuthService == nil {
|
||||||
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
||||||
p.metrics.refreshFailure.Add(1)
|
p.metrics.refreshFailure.Add(1)
|
||||||
refreshFailed = true // 无法刷新,标记失败
|
refreshFailed = true // 无法刷新,标记失败
|
||||||
@@ -206,7 +210,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
|
|
||||||
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
|
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
|
||||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||||
if p.openAIOAuthService == nil {
|
if account.Platform == PlatformSora {
|
||||||
|
slog.Debug("openai_token_refresh_skipped_for_sora_degraded", "account_id", account.ID)
|
||||||
|
// Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
|
||||||
|
refreshFailed = true
|
||||||
|
} else if p.openAIOAuthService == nil {
|
||||||
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
||||||
p.metrics.refreshFailure.Add(1)
|
p.metrics.refreshFailure.Add(1)
|
||||||
refreshFailed = true
|
refreshFailed = true
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s
|
|||||||
accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{
|
accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{
|
||||||
Page: page,
|
Page: page,
|
||||||
PageSize: opsAccountsPageSize,
|
PageSize: opsAccountsPageSize,
|
||||||
}, platformFilter, "", "", "")
|
}, platformFilter, "", "", "", 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,14 +28,15 @@ var (
|
|||||||
// LiteLLMModelPricing LiteLLM价格数据结构
|
// LiteLLMModelPricing LiteLLM价格数据结构
|
||||||
// 只保留我们需要的字段,使用指针来处理可能缺失的值
|
// 只保留我们需要的字段,使用指针来处理可能缺失的值
|
||||||
type LiteLLMModelPricing struct {
|
type LiteLLMModelPricing struct {
|
||||||
InputCostPerToken float64 `json:"input_cost_per_token"`
|
InputCostPerToken float64 `json:"input_cost_per_token"`
|
||||||
OutputCostPerToken float64 `json:"output_cost_per_token"`
|
OutputCostPerToken float64 `json:"output_cost_per_token"`
|
||||||
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
|
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
|
||||||
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
|
CacheCreationInputTokenCostAbove1hr float64 `json:"cache_creation_input_token_cost_above_1hr"`
|
||||||
LiteLLMProvider string `json:"litellm_provider"`
|
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
|
||||||
Mode string `json:"mode"`
|
LiteLLMProvider string `json:"litellm_provider"`
|
||||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
Mode string `json:"mode"`
|
||||||
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
|
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||||
|
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
|
||||||
}
|
}
|
||||||
|
|
||||||
// PricingRemoteClient 远程价格数据获取接口
|
// PricingRemoteClient 远程价格数据获取接口
|
||||||
@@ -46,14 +47,15 @@ type PricingRemoteClient interface {
|
|||||||
|
|
||||||
// LiteLLMRawEntry 用于解析原始JSON数据
|
// LiteLLMRawEntry 用于解析原始JSON数据
|
||||||
type LiteLLMRawEntry struct {
|
type LiteLLMRawEntry struct {
|
||||||
InputCostPerToken *float64 `json:"input_cost_per_token"`
|
InputCostPerToken *float64 `json:"input_cost_per_token"`
|
||||||
OutputCostPerToken *float64 `json:"output_cost_per_token"`
|
OutputCostPerToken *float64 `json:"output_cost_per_token"`
|
||||||
CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"`
|
CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"`
|
||||||
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"`
|
CacheCreationInputTokenCostAbove1hr *float64 `json:"cache_creation_input_token_cost_above_1hr"`
|
||||||
LiteLLMProvider string `json:"litellm_provider"`
|
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"`
|
||||||
Mode string `json:"mode"`
|
LiteLLMProvider string `json:"litellm_provider"`
|
||||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
Mode string `json:"mode"`
|
||||||
OutputCostPerImage *float64 `json:"output_cost_per_image"`
|
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||||
|
OutputCostPerImage *float64 `json:"output_cost_per_image"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PricingService 动态价格服务
|
// PricingService 动态价格服务
|
||||||
@@ -319,6 +321,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel
|
|||||||
if entry.CacheCreationInputTokenCost != nil {
|
if entry.CacheCreationInputTokenCost != nil {
|
||||||
pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost
|
pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost
|
||||||
}
|
}
|
||||||
|
if entry.CacheCreationInputTokenCostAbove1hr != nil {
|
||||||
|
pricing.CacheCreationInputTokenCostAbove1hr = *entry.CacheCreationInputTokenCostAbove1hr
|
||||||
|
}
|
||||||
if entry.CacheReadInputTokenCost != nil {
|
if entry.CacheReadInputTokenCost != nil {
|
||||||
pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost
|
pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,6 +40,11 @@ type ProxyWithAccountCount struct {
|
|||||||
CountryCode string
|
CountryCode string
|
||||||
Region string
|
Region string
|
||||||
City string
|
City string
|
||||||
|
QualityStatus string
|
||||||
|
QualityScore *int
|
||||||
|
QualityGrade string
|
||||||
|
QualitySummary string
|
||||||
|
QualityChecked *int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProxyAccountSummary struct {
|
type ProxyAccountSummary struct {
|
||||||
|
|||||||
@@ -6,15 +6,21 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type ProxyLatencyInfo struct {
|
type ProxyLatencyInfo struct {
|
||||||
Success bool `json:"success"`
|
Success bool `json:"success"`
|
||||||
LatencyMs *int64 `json:"latency_ms,omitempty"`
|
LatencyMs *int64 `json:"latency_ms,omitempty"`
|
||||||
Message string `json:"message,omitempty"`
|
Message string `json:"message,omitempty"`
|
||||||
IPAddress string `json:"ip_address,omitempty"`
|
IPAddress string `json:"ip_address,omitempty"`
|
||||||
Country string `json:"country,omitempty"`
|
Country string `json:"country,omitempty"`
|
||||||
CountryCode string `json:"country_code,omitempty"`
|
CountryCode string `json:"country_code,omitempty"`
|
||||||
Region string `json:"region,omitempty"`
|
Region string `json:"region,omitempty"`
|
||||||
City string `json:"city,omitempty"`
|
City string `json:"city,omitempty"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
QualityStatus string `json:"quality_status,omitempty"`
|
||||||
|
QualityScore *int `json:"quality_score,omitempty"`
|
||||||
|
QualityGrade string `json:"quality_grade,omitempty"`
|
||||||
|
QualitySummary string `json:"quality_summary,omitempty"`
|
||||||
|
QualityCheckedAt *int64 `json:"quality_checked_at,omitempty"`
|
||||||
|
QualityCFRay string `json:"quality_cf_ray,omitempty"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProxyLatencyCache interface {
|
type ProxyLatencyCache interface {
|
||||||
|
|||||||
@@ -381,10 +381,31 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 尝试从响应头解析重置时间(Anthropic)
|
// 2. Anthropic 平台:尝试解析 per-window 头(5h / 7d),选择实际触发的窗口
|
||||||
|
if result := calculateAnthropic429ResetTime(headers); result != nil {
|
||||||
|
if err := s.accountRepo.SetRateLimited(ctx, account.ID, result.resetAt); err != nil {
|
||||||
|
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新 session window:优先使用 5h-reset 头精确计算,否则从 resetAt 反推
|
||||||
|
windowEnd := result.resetAt
|
||||||
|
if result.fiveHourReset != nil {
|
||||||
|
windowEnd = *result.fiveHourReset
|
||||||
|
}
|
||||||
|
windowStart := windowEnd.Add(-5 * time.Hour)
|
||||||
|
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
|
||||||
|
slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("anthropic_account_rate_limited", "account_id", account.ID, "reset_at", result.resetAt, "reset_in", time.Until(result.resetAt).Truncate(time.Second))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 尝试从响应头解析重置时间(Anthropic 聚合头,向后兼容)
|
||||||
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
|
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
|
||||||
|
|
||||||
// 3. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini)
|
// 4. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini)
|
||||||
if resetTimestamp == "" {
|
if resetTimestamp == "" {
|
||||||
switch account.Platform {
|
switch account.Platform {
|
||||||
case PlatformOpenAI:
|
case PlatformOpenAI:
|
||||||
@@ -497,6 +518,112 @@ func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *tim
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// anthropic429Result holds the parsed Anthropic 429 rate-limit information.
|
||||||
|
type anthropic429Result struct {
|
||||||
|
resetAt time.Time // The correct reset time to use for SetRateLimited
|
||||||
|
fiveHourReset *time.Time // 5h window reset timestamp (for session window calculation), nil if not available
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateAnthropic429ResetTime parses Anthropic's per-window rate-limit headers
|
||||||
|
// to determine which window (5h or 7d) actually triggered the 429.
|
||||||
|
//
|
||||||
|
// Headers used:
|
||||||
|
// - anthropic-ratelimit-unified-5h-utilization / anthropic-ratelimit-unified-5h-surpassed-threshold
|
||||||
|
// - anthropic-ratelimit-unified-5h-reset
|
||||||
|
// - anthropic-ratelimit-unified-7d-utilization / anthropic-ratelimit-unified-7d-surpassed-threshold
|
||||||
|
// - anthropic-ratelimit-unified-7d-reset
|
||||||
|
//
|
||||||
|
// Returns nil when the per-window headers are absent (caller should fall back to
|
||||||
|
// the aggregated anthropic-ratelimit-unified-reset header).
|
||||||
|
func calculateAnthropic429ResetTime(headers http.Header) *anthropic429Result {
|
||||||
|
reset5hStr := headers.Get("anthropic-ratelimit-unified-5h-reset")
|
||||||
|
reset7dStr := headers.Get("anthropic-ratelimit-unified-7d-reset")
|
||||||
|
|
||||||
|
if reset5hStr == "" && reset7dStr == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var reset5h, reset7d *time.Time
|
||||||
|
if ts, err := strconv.ParseInt(reset5hStr, 10, 64); err == nil {
|
||||||
|
t := time.Unix(ts, 0)
|
||||||
|
reset5h = &t
|
||||||
|
}
|
||||||
|
if ts, err := strconv.ParseInt(reset7dStr, 10, 64); err == nil {
|
||||||
|
t := time.Unix(ts, 0)
|
||||||
|
reset7d = &t
|
||||||
|
}
|
||||||
|
|
||||||
|
is5hExceeded := isAnthropicWindowExceeded(headers, "5h")
|
||||||
|
is7dExceeded := isAnthropicWindowExceeded(headers, "7d")
|
||||||
|
|
||||||
|
slog.Info("anthropic_429_window_analysis",
|
||||||
|
"is_5h_exceeded", is5hExceeded,
|
||||||
|
"is_7d_exceeded", is7dExceeded,
|
||||||
|
"reset_5h", reset5hStr,
|
||||||
|
"reset_7d", reset7dStr,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Select the correct reset time based on which window(s) are exceeded.
|
||||||
|
var chosen *time.Time
|
||||||
|
switch {
|
||||||
|
case is5hExceeded && is7dExceeded:
|
||||||
|
// Both exceeded → prefer 7d (longer cooldown), fall back to 5h
|
||||||
|
chosen = reset7d
|
||||||
|
if chosen == nil {
|
||||||
|
chosen = reset5h
|
||||||
|
}
|
||||||
|
case is5hExceeded:
|
||||||
|
chosen = reset5h
|
||||||
|
case is7dExceeded:
|
||||||
|
chosen = reset7d
|
||||||
|
default:
|
||||||
|
// Neither flag clearly exceeded — pick the sooner reset as best guess
|
||||||
|
chosen = pickSooner(reset5h, reset7d)
|
||||||
|
}
|
||||||
|
|
||||||
|
if chosen == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &anthropic429Result{resetAt: *chosen, fiveHourReset: reset5h}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isAnthropicWindowExceeded checks whether a given Anthropic rate-limit window
|
||||||
|
// (e.g. "5h" or "7d") has been exceeded, using utilization and surpassed-threshold headers.
|
||||||
|
func isAnthropicWindowExceeded(headers http.Header, window string) bool {
|
||||||
|
prefix := "anthropic-ratelimit-unified-" + window + "-"
|
||||||
|
|
||||||
|
// Check surpassed-threshold first (most explicit signal)
|
||||||
|
if st := headers.Get(prefix + "surpassed-threshold"); strings.EqualFold(st, "true") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to utilization >= 1.0
|
||||||
|
if utilStr := headers.Get(prefix + "utilization"); utilStr != "" {
|
||||||
|
if util, err := strconv.ParseFloat(utilStr, 64); err == nil && util >= 1.0-1e-9 {
|
||||||
|
// Use a small epsilon to handle floating point: treat 0.9999999... as >= 1.0
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// pickSooner returns whichever of the two time pointers is earlier.
|
||||||
|
// If only one is non-nil, it is returned. If both are nil, returns nil.
|
||||||
|
func pickSooner(a, b *time.Time) *time.Time {
|
||||||
|
switch {
|
||||||
|
case a != nil && b != nil:
|
||||||
|
if a.Before(*b) {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
case a != nil:
|
||||||
|
return a
|
||||||
|
default:
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳
|
// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳
|
||||||
// OpenAI 的 usage_limit_reached 错误格式:
|
// OpenAI 的 usage_limit_reached 错误格式:
|
||||||
//
|
//
|
||||||
|
|||||||
202
backend/internal/service/ratelimit_service_anthropic_test.go
Normal file
202
backend/internal/service/ratelimit_service_anthropic_test.go
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCalculateAnthropic429ResetTime_Only5hExceeded(t *testing.T) {
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.02")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.32")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||||
|
|
||||||
|
result := calculateAnthropic429ResetTime(headers)
|
||||||
|
assertAnthropicResult(t, result, 1770998400)
|
||||||
|
|
||||||
|
if result.fiveHourReset == nil || !result.fiveHourReset.Equal(time.Unix(1770998400, 0)) {
|
||||||
|
t.Errorf("expected fiveHourReset=1770998400, got %v", result.fiveHourReset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateAnthropic429ResetTime_Only7dExceeded(t *testing.T) {
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.50")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.05")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||||
|
|
||||||
|
result := calculateAnthropic429ResetTime(headers)
|
||||||
|
assertAnthropicResult(t, result, 1771549200)
|
||||||
|
|
||||||
|
// fiveHourReset should still be populated for session window calculation
|
||||||
|
if result.fiveHourReset == nil || !result.fiveHourReset.Equal(time.Unix(1770998400, 0)) {
|
||||||
|
t.Errorf("expected fiveHourReset=1770998400, got %v", result.fiveHourReset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateAnthropic429ResetTime_BothExceeded(t *testing.T) {
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.10")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.02")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||||
|
|
||||||
|
result := calculateAnthropic429ResetTime(headers)
|
||||||
|
assertAnthropicResult(t, result, 1771549200)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateAnthropic429ResetTime_NoPerWindowHeaders(t *testing.T) {
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("anthropic-ratelimit-unified-reset", "1771549200")
|
||||||
|
|
||||||
|
result := calculateAnthropic429ResetTime(headers)
|
||||||
|
if result != nil {
|
||||||
|
t.Errorf("expected nil result when no per-window headers, got resetAt=%v", result.resetAt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateAnthropic429ResetTime_NoHeaders(t *testing.T) {
|
||||||
|
result := calculateAnthropic429ResetTime(http.Header{})
|
||||||
|
if result != nil {
|
||||||
|
t.Errorf("expected nil result for empty headers, got resetAt=%v", result.resetAt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateAnthropic429ResetTime_SurpassedThreshold(t *testing.T) {
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-surpassed-threshold", "true")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-surpassed-threshold", "false")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||||
|
|
||||||
|
result := calculateAnthropic429ResetTime(headers)
|
||||||
|
assertAnthropicResult(t, result, 1770998400)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateAnthropic429ResetTime_UtilizationExactlyOne(t *testing.T) {
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.0")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.5")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||||
|
|
||||||
|
result := calculateAnthropic429ResetTime(headers)
|
||||||
|
assertAnthropicResult(t, result, 1770998400)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateAnthropic429ResetTime_NeitherExceeded_UsesShorter(t *testing.T) {
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.95")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") // sooner
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.80")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") // later
|
||||||
|
|
||||||
|
result := calculateAnthropic429ResetTime(headers)
|
||||||
|
assertAnthropicResult(t, result, 1770998400)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateAnthropic429ResetTime_Only5hResetHeader(t *testing.T) {
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.05")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||||
|
|
||||||
|
result := calculateAnthropic429ResetTime(headers)
|
||||||
|
assertAnthropicResult(t, result, 1770998400)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateAnthropic429ResetTime_Only7dResetHeader(t *testing.T) {
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.03")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||||
|
|
||||||
|
result := calculateAnthropic429ResetTime(headers)
|
||||||
|
assertAnthropicResult(t, result, 1771549200)
|
||||||
|
|
||||||
|
if result.fiveHourReset != nil {
|
||||||
|
t.Errorf("expected fiveHourReset=nil when no 5h headers, got %v", result.fiveHourReset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsAnthropicWindowExceeded(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
headers http.Header
|
||||||
|
window string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "utilization above 1.0",
|
||||||
|
headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "1.02"),
|
||||||
|
window: "5h",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "utilization exactly 1.0",
|
||||||
|
headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "1.0"),
|
||||||
|
window: "5h",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "utilization below 1.0",
|
||||||
|
headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "0.99"),
|
||||||
|
window: "5h",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "surpassed-threshold true",
|
||||||
|
headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "true"),
|
||||||
|
window: "7d",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "surpassed-threshold True (case insensitive)",
|
||||||
|
headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "True"),
|
||||||
|
window: "7d",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "surpassed-threshold false",
|
||||||
|
headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "false"),
|
||||||
|
window: "7d",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no headers",
|
||||||
|
headers: http.Header{},
|
||||||
|
window: "5h",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := isAnthropicWindowExceeded(tc.headers, tc.window)
|
||||||
|
if got != tc.expected {
|
||||||
|
t.Errorf("expected %v, got %v", tc.expected, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// assertAnthropicResult is a test helper that verifies the result is non-nil and
|
||||||
|
// has the expected resetAt unix timestamp.
|
||||||
|
func assertAnthropicResult(t *testing.T, result *anthropic429Result, wantUnix int64) {
|
||||||
|
t.Helper()
|
||||||
|
if result == nil {
|
||||||
|
t.Fatal("expected non-nil result")
|
||||||
|
return // unreachable, but satisfies staticcheck SA5011
|
||||||
|
}
|
||||||
|
want := time.Unix(wantUnix, 0)
|
||||||
|
if !result.resetAt.Equal(want) {
|
||||||
|
t.Errorf("expected resetAt=%v, got %v", want, result.resetAt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeHeader(key, value string) http.Header {
|
||||||
|
h := http.Header{}
|
||||||
|
h.Set(key, value)
|
||||||
|
return h
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
260
backend/internal/service/sora_curl_cffi_sidecar.go
Normal file
260
backend/internal/service/sora_curl_cffi_sidecar.go
Normal file
@@ -0,0 +1,260 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||||
|
)
|
||||||
|
|
||||||
|
const soraCurlCFFISidecarDefaultTimeoutSeconds = 60
|
||||||
|
|
||||||
|
type soraCurlCFFISidecarRequest struct {
|
||||||
|
Method string `json:"method"`
|
||||||
|
URL string `json:"url"`
|
||||||
|
Headers map[string][]string `json:"headers,omitempty"`
|
||||||
|
BodyBase64 string `json:"body_base64,omitempty"`
|
||||||
|
ProxyURL string `json:"proxy_url,omitempty"`
|
||||||
|
SessionKey string `json:"session_key,omitempty"`
|
||||||
|
Impersonate string `json:"impersonate,omitempty"`
|
||||||
|
TimeoutSeconds int `json:"timeout_seconds,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type soraCurlCFFISidecarResponse struct {
|
||||||
|
StatusCode int `json:"status_code"`
|
||||||
|
Status int `json:"status"`
|
||||||
|
Headers map[string]any `json:"headers"`
|
||||||
|
BodyBase64 string `json:"body_base64"`
|
||||||
|
Body string `json:"body"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) doHTTPViaCurlCFFISidecar(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
|
||||||
|
if req == nil || req.URL == nil {
|
||||||
|
return nil, errors.New("request url is nil")
|
||||||
|
}
|
||||||
|
if c == nil || c.cfg == nil {
|
||||||
|
return nil, errors.New("sora curl_cffi sidecar config is nil")
|
||||||
|
}
|
||||||
|
if !c.cfg.Sora.Client.CurlCFFISidecar.Enabled {
|
||||||
|
return nil, errors.New("sora curl_cffi sidecar is disabled")
|
||||||
|
}
|
||||||
|
endpoint := c.curlCFFISidecarEndpoint()
|
||||||
|
if endpoint == "" {
|
||||||
|
return nil, errors.New("sora curl_cffi sidecar base_url is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, err := readAndRestoreRequestBody(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("sora curl_cffi sidecar read request body failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
headers := make(map[string][]string, len(req.Header)+1)
|
||||||
|
for key, vals := range req.Header {
|
||||||
|
copied := make([]string, len(vals))
|
||||||
|
copy(copied, vals)
|
||||||
|
headers[key] = copied
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(req.Host) != "" {
|
||||||
|
if _, ok := headers["Host"]; !ok {
|
||||||
|
headers["Host"] = []string{req.Host}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := soraCurlCFFISidecarRequest{
|
||||||
|
Method: req.Method,
|
||||||
|
URL: req.URL.String(),
|
||||||
|
Headers: headers,
|
||||||
|
ProxyURL: strings.TrimSpace(proxyURL),
|
||||||
|
SessionKey: c.sidecarSessionKey(account, proxyURL),
|
||||||
|
Impersonate: c.curlCFFIImpersonate(),
|
||||||
|
TimeoutSeconds: c.curlCFFISidecarTimeoutSeconds(),
|
||||||
|
}
|
||||||
|
if len(bodyBytes) > 0 {
|
||||||
|
payload.BodyBase64 = base64.StdEncoding.EncodeToString(bodyBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
encoded, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("sora curl_cffi sidecar marshal request failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sidecarReq, err := http.NewRequestWithContext(req.Context(), http.MethodPost, endpoint, bytes.NewReader(encoded))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("sora curl_cffi sidecar build request failed: %w", err)
|
||||||
|
}
|
||||||
|
sidecarReq.Header.Set("Content-Type", "application/json")
|
||||||
|
sidecarReq.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
httpClient := &http.Client{Timeout: time.Duration(payload.TimeoutSeconds) * time.Second}
|
||||||
|
sidecarResp, err := httpClient.Do(sidecarReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("sora curl_cffi sidecar request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = sidecarResp.Body.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
sidecarRespBody, err := io.ReadAll(io.LimitReader(sidecarResp.Body, 8<<20))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("sora curl_cffi sidecar read response failed: %w", err)
|
||||||
|
}
|
||||||
|
if sidecarResp.StatusCode != http.StatusOK {
|
||||||
|
redacted := truncateForLog([]byte(logredact.RedactText(string(sidecarRespBody))), 512)
|
||||||
|
return nil, fmt.Errorf("sora curl_cffi sidecar http status=%d body=%s", sidecarResp.StatusCode, redacted)
|
||||||
|
}
|
||||||
|
|
||||||
|
var payloadResp soraCurlCFFISidecarResponse
|
||||||
|
if err := json.Unmarshal(sidecarRespBody, &payloadResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("sora curl_cffi sidecar parse response failed: %w", err)
|
||||||
|
}
|
||||||
|
if msg := strings.TrimSpace(payloadResp.Error); msg != "" {
|
||||||
|
return nil, fmt.Errorf("sora curl_cffi sidecar upstream error: %s", msg)
|
||||||
|
}
|
||||||
|
statusCode := payloadResp.StatusCode
|
||||||
|
if statusCode <= 0 {
|
||||||
|
statusCode = payloadResp.Status
|
||||||
|
}
|
||||||
|
if statusCode <= 0 {
|
||||||
|
return nil, errors.New("sora curl_cffi sidecar response missing status code")
|
||||||
|
}
|
||||||
|
|
||||||
|
responseBody := []byte(payloadResp.Body)
|
||||||
|
if strings.TrimSpace(payloadResp.BodyBase64) != "" {
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(payloadResp.BodyBase64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("sora curl_cffi sidecar decode body failed: %w", err)
|
||||||
|
}
|
||||||
|
responseBody = decoded
|
||||||
|
}
|
||||||
|
|
||||||
|
respHeaders := make(http.Header)
|
||||||
|
for key, rawVal := range payloadResp.Headers {
|
||||||
|
for _, v := range convertSidecarHeaderValue(rawVal) {
|
||||||
|
respHeaders.Add(key, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: statusCode,
|
||||||
|
Header: respHeaders,
|
||||||
|
Body: io.NopCloser(bytes.NewReader(responseBody)),
|
||||||
|
ContentLength: int64(len(responseBody)),
|
||||||
|
Request: req,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readAndRestoreRequestBody(req *http.Request) ([]byte, error) {
|
||||||
|
if req == nil || req.Body == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
bodyBytes, err := io.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_ = req.Body.Close()
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
req.ContentLength = int64(len(bodyBytes))
|
||||||
|
return bodyBytes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) curlCFFISidecarEndpoint() string {
|
||||||
|
if c == nil || c.cfg == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
raw := strings.TrimSpace(c.cfg.Sora.Client.CurlCFFISidecar.BaseURL)
|
||||||
|
if raw == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
parsed, err := url.Parse(raw)
|
||||||
|
if err != nil || strings.TrimSpace(parsed.Scheme) == "" || strings.TrimSpace(parsed.Host) == "" {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
if path := strings.TrimSpace(parsed.Path); path == "" || path == "/" {
|
||||||
|
parsed.Path = "/request"
|
||||||
|
}
|
||||||
|
return parsed.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) curlCFFISidecarTimeoutSeconds() int {
|
||||||
|
if c == nil || c.cfg == nil {
|
||||||
|
return soraCurlCFFISidecarDefaultTimeoutSeconds
|
||||||
|
}
|
||||||
|
timeoutSeconds := c.cfg.Sora.Client.CurlCFFISidecar.TimeoutSeconds
|
||||||
|
if timeoutSeconds <= 0 {
|
||||||
|
return soraCurlCFFISidecarDefaultTimeoutSeconds
|
||||||
|
}
|
||||||
|
return timeoutSeconds
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) curlCFFIImpersonate() string {
|
||||||
|
if c == nil || c.cfg == nil {
|
||||||
|
return "chrome131"
|
||||||
|
}
|
||||||
|
impersonate := strings.TrimSpace(c.cfg.Sora.Client.CurlCFFISidecar.Impersonate)
|
||||||
|
if impersonate == "" {
|
||||||
|
return "chrome131"
|
||||||
|
}
|
||||||
|
return impersonate
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) sidecarSessionReuseEnabled() bool {
|
||||||
|
if c == nil || c.cfg == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return c.cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) sidecarSessionTTLSeconds() int {
|
||||||
|
if c == nil || c.cfg == nil {
|
||||||
|
return 3600
|
||||||
|
}
|
||||||
|
ttl := c.cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds
|
||||||
|
if ttl < 0 {
|
||||||
|
return 3600
|
||||||
|
}
|
||||||
|
return ttl
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertSidecarHeaderValue(raw any) []string {
|
||||||
|
switch val := raw.(type) {
|
||||||
|
case nil:
|
||||||
|
return nil
|
||||||
|
case string:
|
||||||
|
if strings.TrimSpace(val) == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return []string{val}
|
||||||
|
case []any:
|
||||||
|
out := make([]string, 0, len(val))
|
||||||
|
for _, item := range val {
|
||||||
|
s := strings.TrimSpace(fmt.Sprint(item))
|
||||||
|
if s != "" {
|
||||||
|
out = append(out, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
case []string:
|
||||||
|
out := make([]string, 0, len(val))
|
||||||
|
for _, item := range val {
|
||||||
|
if strings.TrimSpace(item) != "" {
|
||||||
|
out = append(out, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
default:
|
||||||
|
s := strings.TrimSpace(fmt.Sprint(val))
|
||||||
|
if s == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return []string{s}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -8,10 +8,12 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"math"
|
||||||
"mime"
|
"mime"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -23,6 +25,9 @@ import (
|
|||||||
const soraImageInputMaxBytes = 20 << 20
|
const soraImageInputMaxBytes = 20 << 20
|
||||||
const soraImageInputMaxRedirects = 3
|
const soraImageInputMaxRedirects = 3
|
||||||
const soraImageInputTimeout = 20 * time.Second
|
const soraImageInputTimeout = 20 * time.Second
|
||||||
|
const soraVideoInputMaxBytes = 200 << 20
|
||||||
|
const soraVideoInputMaxRedirects = 3
|
||||||
|
const soraVideoInputTimeout = 60 * time.Second
|
||||||
|
|
||||||
var soraImageSizeMap = map[string]string{
|
var soraImageSizeMap = map[string]string{
|
||||||
"gpt-image": "360",
|
"gpt-image": "360",
|
||||||
@@ -61,6 +66,36 @@ type SoraGatewayService struct {
|
|||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type soraWatermarkOptions struct {
|
||||||
|
Enabled bool
|
||||||
|
ParseMethod string
|
||||||
|
ParseURL string
|
||||||
|
ParseToken string
|
||||||
|
FallbackOnFailure bool
|
||||||
|
DeletePost bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type soraCharacterOptions struct {
|
||||||
|
SetPublic bool
|
||||||
|
DeleteAfterGenerate bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type soraCharacterFlowResult struct {
|
||||||
|
CameoID string
|
||||||
|
CharacterID string
|
||||||
|
Username string
|
||||||
|
DisplayName string
|
||||||
|
}
|
||||||
|
|
||||||
|
var soraStoryboardPattern = regexp.MustCompile(`\[\d+(?:\.\d+)?s\]`)
|
||||||
|
var soraStoryboardShotPattern = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`)
|
||||||
|
var soraRemixTargetPattern = regexp.MustCompile(`s_[a-f0-9]{32}`)
|
||||||
|
var soraRemixTargetInURLPattern = regexp.MustCompile(`https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}`)
|
||||||
|
|
||||||
|
type soraPreflightChecker interface {
|
||||||
|
PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error
|
||||||
|
}
|
||||||
|
|
||||||
func NewSoraGatewayService(
|
func NewSoraGatewayService(
|
||||||
soraClient SoraClient,
|
soraClient SoraClient,
|
||||||
mediaStorage *SoraMediaStorage,
|
mediaStorage *SoraMediaStorage,
|
||||||
@@ -112,29 +147,133 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
|||||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream)
|
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream)
|
||||||
return nil, fmt.Errorf("unsupported model: %s", reqModel)
|
return nil, fmt.Errorf("unsupported model: %s", reqModel)
|
||||||
}
|
}
|
||||||
if modelCfg.Type == "prompt_enhance" {
|
|
||||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Prompt-enhance 模型暂未支持", clientStream)
|
|
||||||
return nil, fmt.Errorf("prompt-enhance not supported")
|
|
||||||
}
|
|
||||||
|
|
||||||
prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody)
|
prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody)
|
||||||
if strings.TrimSpace(prompt) == "" {
|
prompt = strings.TrimSpace(prompt)
|
||||||
|
imageInput = strings.TrimSpace(imageInput)
|
||||||
|
videoInput = strings.TrimSpace(videoInput)
|
||||||
|
remixTargetID = strings.TrimSpace(remixTargetID)
|
||||||
|
|
||||||
|
if videoInput != "" && modelCfg.Type != "video" {
|
||||||
|
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "video input only supports video models", clientStream)
|
||||||
|
return nil, errors.New("video input only supports video models")
|
||||||
|
}
|
||||||
|
if videoInput != "" && imageInput != "" {
|
||||||
|
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "image input and video input cannot be used together", clientStream)
|
||||||
|
return nil, errors.New("image input and video input cannot be used together")
|
||||||
|
}
|
||||||
|
characterOnly := videoInput != "" && prompt == ""
|
||||||
|
if modelCfg.Type == "prompt_enhance" && prompt == "" {
|
||||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
|
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
|
||||||
return nil, errors.New("prompt is required")
|
return nil, errors.New("prompt is required")
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(videoInput) != "" {
|
if modelCfg.Type != "prompt_enhance" && prompt == "" && !characterOnly {
|
||||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Video input is not supported yet", clientStream)
|
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
|
||||||
return nil, errors.New("video input not supported")
|
return nil, errors.New("prompt is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
reqCtx, cancel := s.withSoraTimeout(ctx, reqStream)
|
reqCtx, cancel := s.withSoraTimeout(ctx, reqStream)
|
||||||
if cancel != nil {
|
if cancel != nil {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
}
|
}
|
||||||
|
if checker, ok := s.soraClient.(soraPreflightChecker); ok && !characterOnly {
|
||||||
|
if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil {
|
||||||
|
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelCfg.Type == "prompt_enhance" {
|
||||||
|
enhancedPrompt, err := s.soraClient.EnhancePrompt(reqCtx, account, prompt, modelCfg.ExpansionLevel, modelCfg.DurationS)
|
||||||
|
if err != nil {
|
||||||
|
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
|
||||||
|
}
|
||||||
|
content := strings.TrimSpace(enhancedPrompt)
|
||||||
|
if content == "" {
|
||||||
|
content = prompt
|
||||||
|
}
|
||||||
|
var firstTokenMs *int
|
||||||
|
if clientStream {
|
||||||
|
ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
|
||||||
|
if streamErr != nil {
|
||||||
|
return nil, streamErr
|
||||||
|
}
|
||||||
|
firstTokenMs = ms
|
||||||
|
} else if c != nil {
|
||||||
|
c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel))
|
||||||
|
}
|
||||||
|
return &ForwardResult{
|
||||||
|
RequestID: "",
|
||||||
|
Model: reqModel,
|
||||||
|
Stream: clientStream,
|
||||||
|
Duration: time.Since(startTime),
|
||||||
|
FirstTokenMs: firstTokenMs,
|
||||||
|
Usage: ClaudeUsage{},
|
||||||
|
MediaType: "prompt",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
characterOpts := parseSoraCharacterOptions(reqBody)
|
||||||
|
watermarkOpts := parseSoraWatermarkOptions(reqBody)
|
||||||
|
var characterResult *soraCharacterFlowResult
|
||||||
|
if videoInput != "" {
|
||||||
|
videoData, videoErr := decodeSoraVideoInput(reqCtx, videoInput)
|
||||||
|
if videoErr != nil {
|
||||||
|
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", videoErr.Error(), clientStream)
|
||||||
|
return nil, videoErr
|
||||||
|
}
|
||||||
|
characterResult, videoErr = s.createCharacterFromVideo(reqCtx, account, videoData, characterOpts)
|
||||||
|
if videoErr != nil {
|
||||||
|
return nil, s.handleSoraRequestError(ctx, account, videoErr, reqModel, c, clientStream)
|
||||||
|
}
|
||||||
|
if characterResult != nil && characterOpts.DeleteAfterGenerate && strings.TrimSpace(characterResult.CharacterID) != "" && !characterOnly {
|
||||||
|
characterID := strings.TrimSpace(characterResult.CharacterID)
|
||||||
|
defer func() {
|
||||||
|
cleanupCtx, cancelCleanup := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancelCleanup()
|
||||||
|
if err := s.soraClient.DeleteCharacter(cleanupCtx, account, characterID); err != nil {
|
||||||
|
log.Printf("[Sora] cleanup character failed, character_id=%s err=%v", characterID, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
if characterOnly {
|
||||||
|
content := "角色创建成功"
|
||||||
|
if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
|
||||||
|
content = fmt.Sprintf("角色创建成功,角色名@%s", strings.TrimSpace(characterResult.Username))
|
||||||
|
}
|
||||||
|
var firstTokenMs *int
|
||||||
|
if clientStream {
|
||||||
|
ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
|
||||||
|
if streamErr != nil {
|
||||||
|
return nil, streamErr
|
||||||
|
}
|
||||||
|
firstTokenMs = ms
|
||||||
|
} else if c != nil {
|
||||||
|
resp := buildSoraNonStreamResponse(content, reqModel)
|
||||||
|
if characterResult != nil {
|
||||||
|
resp["character_id"] = characterResult.CharacterID
|
||||||
|
resp["cameo_id"] = characterResult.CameoID
|
||||||
|
resp["character_username"] = characterResult.Username
|
||||||
|
resp["character_display_name"] = characterResult.DisplayName
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, resp)
|
||||||
|
}
|
||||||
|
return &ForwardResult{
|
||||||
|
RequestID: "",
|
||||||
|
Model: reqModel,
|
||||||
|
Stream: clientStream,
|
||||||
|
Duration: time.Since(startTime),
|
||||||
|
FirstTokenMs: firstTokenMs,
|
||||||
|
Usage: ClaudeUsage{},
|
||||||
|
MediaType: "prompt",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
|
||||||
|
prompt = fmt.Sprintf("@%s %s", characterResult.Username, prompt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var imageData []byte
|
var imageData []byte
|
||||||
imageFilename := ""
|
imageFilename := ""
|
||||||
if strings.TrimSpace(imageInput) != "" {
|
if imageInput != "" {
|
||||||
decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput)
|
decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream)
|
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream)
|
||||||
@@ -164,15 +303,27 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
|||||||
MediaID: mediaID,
|
MediaID: mediaID,
|
||||||
})
|
})
|
||||||
case "video":
|
case "video":
|
||||||
taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
|
if remixTargetID == "" && isSoraStoryboardPrompt(prompt) {
|
||||||
Prompt: prompt,
|
taskID, err = s.soraClient.CreateStoryboardTask(reqCtx, account, SoraStoryboardRequest{
|
||||||
Orientation: modelCfg.Orientation,
|
Prompt: formatSoraStoryboardPrompt(prompt),
|
||||||
Frames: modelCfg.Frames,
|
Orientation: modelCfg.Orientation,
|
||||||
Model: modelCfg.Model,
|
Frames: modelCfg.Frames,
|
||||||
Size: modelCfg.Size,
|
Model: modelCfg.Model,
|
||||||
MediaID: mediaID,
|
Size: modelCfg.Size,
|
||||||
RemixTargetID: remixTargetID,
|
MediaID: mediaID,
|
||||||
})
|
})
|
||||||
|
} else {
|
||||||
|
taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
|
||||||
|
Prompt: prompt,
|
||||||
|
Orientation: modelCfg.Orientation,
|
||||||
|
Frames: modelCfg.Frames,
|
||||||
|
Model: modelCfg.Model,
|
||||||
|
Size: modelCfg.Size,
|
||||||
|
MediaID: mediaID,
|
||||||
|
RemixTargetID: remixTargetID,
|
||||||
|
CameoIDs: extractSoraCameoIDs(reqBody),
|
||||||
|
})
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
err = fmt.Errorf("unsupported model type: %s", modelCfg.Type)
|
err = fmt.Errorf("unsupported model type: %s", modelCfg.Type)
|
||||||
}
|
}
|
||||||
@@ -185,6 +336,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
|||||||
}
|
}
|
||||||
|
|
||||||
var mediaURLs []string
|
var mediaURLs []string
|
||||||
|
videoGenerationID := ""
|
||||||
mediaType := modelCfg.Type
|
mediaType := modelCfg.Type
|
||||||
imageCount := 0
|
imageCount := 0
|
||||||
imageSize := ""
|
imageSize := ""
|
||||||
@@ -198,15 +350,32 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
|||||||
imageCount = len(urls)
|
imageCount = len(urls)
|
||||||
imageSize = soraImageSizeFromModel(reqModel)
|
imageSize = soraImageSizeFromModel(reqModel)
|
||||||
case "video":
|
case "video":
|
||||||
urls, pollErr := s.pollVideoTask(reqCtx, c, account, taskID, clientStream)
|
videoStatus, pollErr := s.pollVideoTaskDetailed(reqCtx, c, account, taskID, clientStream)
|
||||||
if pollErr != nil {
|
if pollErr != nil {
|
||||||
return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
|
return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
|
||||||
}
|
}
|
||||||
mediaURLs = urls
|
if videoStatus != nil {
|
||||||
|
mediaURLs = videoStatus.URLs
|
||||||
|
videoGenerationID = strings.TrimSpace(videoStatus.GenerationID)
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
mediaType = "prompt"
|
mediaType = "prompt"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
watermarkPostID := ""
|
||||||
|
if modelCfg.Type == "video" && watermarkOpts.Enabled {
|
||||||
|
watermarkURL, postID, watermarkErr := s.resolveWatermarkFreeURL(reqCtx, account, videoGenerationID, watermarkOpts)
|
||||||
|
if watermarkErr != nil {
|
||||||
|
if !watermarkOpts.FallbackOnFailure {
|
||||||
|
return nil, s.handleSoraRequestError(ctx, account, watermarkErr, reqModel, c, clientStream)
|
||||||
|
}
|
||||||
|
log.Printf("[Sora] watermark-free fallback to original URL, task_id=%s err=%v", taskID, watermarkErr)
|
||||||
|
} else if strings.TrimSpace(watermarkURL) != "" {
|
||||||
|
mediaURLs = []string{strings.TrimSpace(watermarkURL)}
|
||||||
|
watermarkPostID = strings.TrimSpace(postID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
finalURLs := s.normalizeSoraMediaURLs(mediaURLs)
|
finalURLs := s.normalizeSoraMediaURLs(mediaURLs)
|
||||||
if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
|
if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
|
||||||
stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
|
stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
|
||||||
@@ -217,6 +386,11 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
|||||||
finalURLs = s.normalizeSoraMediaURLs(stored)
|
finalURLs = s.normalizeSoraMediaURLs(stored)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if watermarkPostID != "" && watermarkOpts.DeletePost {
|
||||||
|
if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil {
|
||||||
|
log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
content := buildSoraContent(mediaType, finalURLs)
|
content := buildSoraContent(mediaType, finalURLs)
|
||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
@@ -265,9 +439,270 @@ func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (
|
|||||||
return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
|
return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseSoraWatermarkOptions(body map[string]any) soraWatermarkOptions {
|
||||||
|
opts := soraWatermarkOptions{
|
||||||
|
Enabled: parseBoolWithDefault(body, "watermark_free", false),
|
||||||
|
ParseMethod: strings.ToLower(strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_method", "third_party"))),
|
||||||
|
ParseURL: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_url", "")),
|
||||||
|
ParseToken: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_token", "")),
|
||||||
|
FallbackOnFailure: parseBoolWithDefault(body, "watermark_fallback_on_failure", true),
|
||||||
|
DeletePost: parseBoolWithDefault(body, "watermark_delete_post", false),
|
||||||
|
}
|
||||||
|
if opts.ParseMethod == "" {
|
||||||
|
opts.ParseMethod = "third_party"
|
||||||
|
}
|
||||||
|
return opts
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions {
|
||||||
|
return soraCharacterOptions{
|
||||||
|
SetPublic: parseBoolWithDefault(body, "character_set_public", true),
|
||||||
|
DeleteAfterGenerate: parseBoolWithDefault(body, "character_delete_after_generate", true),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseBoolWithDefault(body map[string]any, key string, def bool) bool {
|
||||||
|
if body == nil {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
val, ok := body[key]
|
||||||
|
if !ok {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
switch typed := val.(type) {
|
||||||
|
case bool:
|
||||||
|
return typed
|
||||||
|
case int:
|
||||||
|
return typed != 0
|
||||||
|
case int32:
|
||||||
|
return typed != 0
|
||||||
|
case int64:
|
||||||
|
return typed != 0
|
||||||
|
case float64:
|
||||||
|
return typed != 0
|
||||||
|
case string:
|
||||||
|
typed = strings.ToLower(strings.TrimSpace(typed))
|
||||||
|
if typed == "true" || typed == "1" || typed == "yes" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if typed == "false" || typed == "0" || typed == "no" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseStringWithDefault(body map[string]any, key, def string) string {
|
||||||
|
if body == nil {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
val, ok := body[key]
|
||||||
|
if !ok {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
if str, ok := val.(string); ok {
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractSoraCameoIDs(body map[string]any) []string {
|
||||||
|
if body == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
raw, ok := body["cameo_ids"]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
switch typed := raw.(type) {
|
||||||
|
case []string:
|
||||||
|
out := make([]string, 0, len(typed))
|
||||||
|
for _, item := range typed {
|
||||||
|
item = strings.TrimSpace(item)
|
||||||
|
if item != "" {
|
||||||
|
out = append(out, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
case []any:
|
||||||
|
out := make([]string, 0, len(typed))
|
||||||
|
for _, item := range typed {
|
||||||
|
str, ok := item.(string)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
str = strings.TrimSpace(str)
|
||||||
|
if str != "" {
|
||||||
|
out = append(out, str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) createCharacterFromVideo(ctx context.Context, account *Account, videoData []byte, opts soraCharacterOptions) (*soraCharacterFlowResult, error) {
|
||||||
|
cameoID, err := s.soraClient.UploadCharacterVideo(ctx, account, videoData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cameoStatus, err := s.pollCameoStatus(ctx, account, cameoID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
username := processSoraCharacterUsername(cameoStatus.UsernameHint)
|
||||||
|
displayName := strings.TrimSpace(cameoStatus.DisplayNameHint)
|
||||||
|
if displayName == "" {
|
||||||
|
displayName = "Character"
|
||||||
|
}
|
||||||
|
profileAssetURL := strings.TrimSpace(cameoStatus.ProfileAssetURL)
|
||||||
|
if profileAssetURL == "" {
|
||||||
|
return nil, errors.New("profile asset url not found in cameo status")
|
||||||
|
}
|
||||||
|
|
||||||
|
avatarData, err := s.soraClient.DownloadCharacterImage(ctx, account, profileAssetURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
assetPointer, err := s.soraClient.UploadCharacterImage(ctx, account, avatarData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
instructionSet := cameoStatus.InstructionSetHint
|
||||||
|
if instructionSet == nil {
|
||||||
|
instructionSet = cameoStatus.InstructionSet
|
||||||
|
}
|
||||||
|
|
||||||
|
characterID, err := s.soraClient.FinalizeCharacter(ctx, account, SoraCharacterFinalizeRequest{
|
||||||
|
CameoID: strings.TrimSpace(cameoID),
|
||||||
|
Username: username,
|
||||||
|
DisplayName: displayName,
|
||||||
|
ProfileAssetPointer: assetPointer,
|
||||||
|
InstructionSet: instructionSet,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.SetPublic {
|
||||||
|
if err := s.soraClient.SetCharacterPublic(ctx, account, cameoID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &soraCharacterFlowResult{
|
||||||
|
CameoID: strings.TrimSpace(cameoID),
|
||||||
|
CharacterID: strings.TrimSpace(characterID),
|
||||||
|
Username: strings.TrimSpace(username),
|
||||||
|
DisplayName: displayName,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) pollCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
|
||||||
|
timeout := 10 * time.Minute
|
||||||
|
interval := 5 * time.Second
|
||||||
|
maxAttempts := int(math.Ceil(timeout.Seconds() / interval.Seconds()))
|
||||||
|
if maxAttempts < 1 {
|
||||||
|
maxAttempts = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
consecutiveErrors := 0
|
||||||
|
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||||
|
status, err := s.soraClient.GetCameoStatus(ctx, account, cameoID)
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
consecutiveErrors++
|
||||||
|
if consecutiveErrors >= 3 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if attempt < maxAttempts-1 {
|
||||||
|
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
|
||||||
|
return nil, sleepErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
consecutiveErrors = 0
|
||||||
|
if status == nil {
|
||||||
|
if attempt < maxAttempts-1 {
|
||||||
|
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
|
||||||
|
return nil, sleepErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
currentStatus := strings.ToLower(strings.TrimSpace(status.Status))
|
||||||
|
statusMessage := strings.TrimSpace(status.StatusMessage)
|
||||||
|
if currentStatus == "failed" {
|
||||||
|
if statusMessage == "" {
|
||||||
|
statusMessage = "character creation failed"
|
||||||
|
}
|
||||||
|
return nil, errors.New(statusMessage)
|
||||||
|
}
|
||||||
|
if strings.EqualFold(statusMessage, "Completed") || currentStatus == "finalized" {
|
||||||
|
return status, nil
|
||||||
|
}
|
||||||
|
if attempt < maxAttempts-1 {
|
||||||
|
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
|
||||||
|
return nil, sleepErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if lastErr != nil {
|
||||||
|
return nil, fmt.Errorf("poll cameo status failed: %w", lastErr)
|
||||||
|
}
|
||||||
|
return nil, errors.New("cameo processing timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
func processSoraCharacterUsername(usernameHint string) string {
|
||||||
|
usernameHint = strings.TrimSpace(usernameHint)
|
||||||
|
if usernameHint == "" {
|
||||||
|
usernameHint = "character"
|
||||||
|
}
|
||||||
|
if strings.Contains(usernameHint, ".") {
|
||||||
|
parts := strings.Split(usernameHint, ".")
|
||||||
|
usernameHint = strings.TrimSpace(parts[len(parts)-1])
|
||||||
|
}
|
||||||
|
if usernameHint == "" {
|
||||||
|
usernameHint = "character"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s%d", usernameHint, soraRandInt(900)+100)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) {
|
||||||
|
generationID = strings.TrimSpace(generationID)
|
||||||
|
if generationID == "" {
|
||||||
|
return "", "", errors.New("generation id is required for watermark-free mode")
|
||||||
|
}
|
||||||
|
postID, err := s.soraClient.PostVideoForWatermarkFree(ctx, account, generationID)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
postID = strings.TrimSpace(postID)
|
||||||
|
if postID == "" {
|
||||||
|
return "", "", errors.New("watermark-free publish returned empty post id")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch opts.ParseMethod {
|
||||||
|
case "custom":
|
||||||
|
urlVal, parseErr := s.soraClient.GetWatermarkFreeURLCustom(ctx, account, opts.ParseURL, opts.ParseToken, postID)
|
||||||
|
if parseErr != nil {
|
||||||
|
return "", postID, parseErr
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(urlVal), postID, nil
|
||||||
|
case "", "third_party":
|
||||||
|
return fmt.Sprintf("https://oscdn2.dyysy.com/MP4/%s.mp4", postID), postID, nil
|
||||||
|
default:
|
||||||
|
return "", postID, fmt.Errorf("unsupported watermark parse method: %s", opts.ParseMethod)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
||||||
switch statusCode {
|
switch statusCode {
|
||||||
case 401, 402, 403, 429, 529:
|
case 401, 402, 403, 404, 429, 529:
|
||||||
return true
|
return true
|
||||||
default:
|
default:
|
||||||
return statusCode >= 500
|
return statusCode >= 500
|
||||||
@@ -434,7 +869,18 @@ func (s *SoraGatewayService) writeSoraError(c *gin.Context, status int, errType,
|
|||||||
}
|
}
|
||||||
if stream {
|
if stream {
|
||||||
flusher, _ := c.Writer.(http.Flusher)
|
flusher, _ := c.Writer.(http.Flusher)
|
||||||
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
errorData := map[string]any{
|
||||||
|
"error": map[string]string{
|
||||||
|
"type": errType,
|
||||||
|
"message": message,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
jsonBytes, err := json.Marshal(errorData)
|
||||||
|
if err != nil {
|
||||||
|
_ = c.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
|
||||||
_, _ = fmt.Fprint(c.Writer, errorEvent)
|
_, _ = fmt.Fprint(c.Writer, errorEvent)
|
||||||
_, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n")
|
_, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n")
|
||||||
if flusher != nil {
|
if flusher != nil {
|
||||||
@@ -460,7 +906,15 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account
|
|||||||
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
|
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
|
||||||
}
|
}
|
||||||
if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) {
|
if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) {
|
||||||
return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode}
|
var responseHeaders http.Header
|
||||||
|
if upstreamErr.Headers != nil {
|
||||||
|
responseHeaders = upstreamErr.Headers.Clone()
|
||||||
|
}
|
||||||
|
return &UpstreamFailoverError{
|
||||||
|
StatusCode: upstreamErr.StatusCode,
|
||||||
|
ResponseBody: upstreamErr.Body,
|
||||||
|
ResponseHeaders: responseHeaders,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
msg := upstreamErr.Message
|
msg := upstreamErr.Message
|
||||||
if override := soraProErrorMessage(model, msg); override != "" {
|
if override := soraProErrorMessage(model, msg); override != "" {
|
||||||
@@ -505,7 +959,7 @@ func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context,
|
|||||||
return nil, errors.New("sora image generation timeout")
|
return nil, errors.New("sora image generation timeout")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) {
|
func (s *SoraGatewayService) pollVideoTaskDetailed(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) (*SoraVideoTaskStatus, error) {
|
||||||
interval := s.pollInterval()
|
interval := s.pollInterval()
|
||||||
maxAttempts := s.pollMaxAttempts()
|
maxAttempts := s.pollMaxAttempts()
|
||||||
lastPing := time.Now()
|
lastPing := time.Now()
|
||||||
@@ -516,7 +970,7 @@ func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context,
|
|||||||
}
|
}
|
||||||
switch strings.ToLower(status.Status) {
|
switch strings.ToLower(status.Status) {
|
||||||
case "completed", "succeeded":
|
case "completed", "succeeded":
|
||||||
return status.URLs, nil
|
return status, nil
|
||||||
case "failed":
|
case "failed":
|
||||||
if status.ErrorMsg != "" {
|
if status.ErrorMsg != "" {
|
||||||
return nil, errors.New(status.ErrorMsg)
|
return nil, errors.New(status.ErrorMsg)
|
||||||
@@ -620,7 +1074,7 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi
|
|||||||
return "", "", "", ""
|
return "", "", "", ""
|
||||||
}
|
}
|
||||||
if v, ok := body["remix_target_id"].(string); ok {
|
if v, ok := body["remix_target_id"].(string); ok {
|
||||||
remixTargetID = v
|
remixTargetID = strings.TrimSpace(v)
|
||||||
}
|
}
|
||||||
if v, ok := body["image"].(string); ok {
|
if v, ok := body["image"].(string); ok {
|
||||||
imageInput = v
|
imageInput = v
|
||||||
@@ -661,6 +1115,10 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi
|
|||||||
prompt = builder.String()
|
prompt = builder.String()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if remixTargetID == "" {
|
||||||
|
remixTargetID = extractRemixTargetIDFromPrompt(prompt)
|
||||||
|
}
|
||||||
|
prompt = cleanRemixLinkFromPrompt(prompt)
|
||||||
return prompt, imageInput, videoInput, remixTargetID
|
return prompt, imageInput, videoInput, remixTargetID
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -708,6 +1166,69 @@ func parseSoraMessageContent(content any) (text, imageInput, videoInput string)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isSoraStoryboardPrompt(prompt string) bool {
|
||||||
|
prompt = strings.TrimSpace(prompt)
|
||||||
|
if prompt == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return len(soraStoryboardPattern.FindAllString(prompt, -1)) >= 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatSoraStoryboardPrompt(prompt string) string {
|
||||||
|
prompt = strings.TrimSpace(prompt)
|
||||||
|
if prompt == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
matches := soraStoryboardShotPattern.FindAllStringSubmatch(prompt, -1)
|
||||||
|
if len(matches) == 0 {
|
||||||
|
return prompt
|
||||||
|
}
|
||||||
|
firstBracketPos := strings.Index(prompt, "[")
|
||||||
|
instructions := ""
|
||||||
|
if firstBracketPos > 0 {
|
||||||
|
instructions = strings.TrimSpace(prompt[:firstBracketPos])
|
||||||
|
}
|
||||||
|
shots := make([]string, 0, len(matches))
|
||||||
|
for i, match := range matches {
|
||||||
|
if len(match) < 3 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
duration := strings.TrimSpace(match[1])
|
||||||
|
scene := strings.TrimSpace(match[2])
|
||||||
|
if scene == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
shots = append(shots, fmt.Sprintf("Shot %d:\nduration: %ssec\nScene: %s", i+1, duration, scene))
|
||||||
|
}
|
||||||
|
if len(shots) == 0 {
|
||||||
|
return prompt
|
||||||
|
}
|
||||||
|
timeline := strings.Join(shots, "\n\n")
|
||||||
|
if instructions == "" {
|
||||||
|
return timeline
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("current timeline:\n%s\n\ninstructions:\n%s", timeline, instructions)
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractRemixTargetIDFromPrompt(prompt string) string {
|
||||||
|
prompt = strings.TrimSpace(prompt)
|
||||||
|
if prompt == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(soraRemixTargetPattern.FindString(prompt))
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanRemixLinkFromPrompt(prompt string) string {
|
||||||
|
prompt = strings.TrimSpace(prompt)
|
||||||
|
if prompt == "" {
|
||||||
|
return prompt
|
||||||
|
}
|
||||||
|
cleaned := soraRemixTargetInURLPattern.ReplaceAllString(prompt, "")
|
||||||
|
cleaned = soraRemixTargetPattern.ReplaceAllString(cleaned, "")
|
||||||
|
cleaned = strings.Join(strings.Fields(cleaned), " ")
|
||||||
|
return strings.TrimSpace(cleaned)
|
||||||
|
}
|
||||||
|
|
||||||
func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) {
|
func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) {
|
||||||
raw := strings.TrimSpace(input)
|
raw := strings.TrimSpace(input)
|
||||||
if raw == "" {
|
if raw == "" {
|
||||||
@@ -720,7 +1241,7 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er
|
|||||||
}
|
}
|
||||||
meta := parts[0]
|
meta := parts[0]
|
||||||
payload := parts[1]
|
payload := parts[1]
|
||||||
decoded, err := base64.StdEncoding.DecodeString(payload)
|
decoded, err := decodeBase64WithLimit(payload, soraImageInputMaxBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
@@ -739,15 +1260,47 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er
|
|||||||
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
|
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
|
||||||
return downloadSoraImageInput(ctx, raw)
|
return downloadSoraImageInput(ctx, raw)
|
||||||
}
|
}
|
||||||
decoded, err := base64.StdEncoding.DecodeString(raw)
|
decoded, err := decodeBase64WithLimit(raw, soraImageInputMaxBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", errors.New("invalid base64 image")
|
return nil, "", errors.New("invalid base64 image")
|
||||||
}
|
}
|
||||||
return decoded, "image.png", nil
|
return decoded, "image.png", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func decodeSoraVideoInput(ctx context.Context, input string) ([]byte, error) {
|
||||||
|
raw := strings.TrimSpace(input)
|
||||||
|
if raw == "" {
|
||||||
|
return nil, errors.New("empty video input")
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(raw, "data:") {
|
||||||
|
parts := strings.SplitN(raw, ",", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return nil, errors.New("invalid video data url")
|
||||||
|
}
|
||||||
|
decoded, err := decodeBase64WithLimit(parts[1], soraVideoInputMaxBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("invalid base64 video")
|
||||||
|
}
|
||||||
|
if len(decoded) == 0 {
|
||||||
|
return nil, errors.New("empty video data")
|
||||||
|
}
|
||||||
|
return decoded, nil
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
|
||||||
|
return downloadSoraVideoInput(ctx, raw)
|
||||||
|
}
|
||||||
|
decoded, err := decodeBase64WithLimit(raw, soraVideoInputMaxBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("invalid base64 video")
|
||||||
|
}
|
||||||
|
if len(decoded) == 0 {
|
||||||
|
return nil, errors.New("empty video data")
|
||||||
|
}
|
||||||
|
return decoded, nil
|
||||||
|
}
|
||||||
|
|
||||||
func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) {
|
func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) {
|
||||||
parsed, err := validateSoraImageURL(rawURL)
|
parsed, err := validateSoraRemoteURL(rawURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
@@ -761,7 +1314,7 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string,
|
|||||||
if len(via) >= soraImageInputMaxRedirects {
|
if len(via) >= soraImageInputMaxRedirects {
|
||||||
return errors.New("too many redirects")
|
return errors.New("too many redirects")
|
||||||
}
|
}
|
||||||
return validateSoraImageURLValue(req.URL)
|
return validateSoraRemoteURLValue(req.URL)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
@@ -784,51 +1337,103 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string,
|
|||||||
return data, filename, nil
|
return data, filename, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateSoraImageURL(raw string) (*url.URL, error) {
|
func downloadSoraVideoInput(ctx context.Context, rawURL string) ([]byte, error) {
|
||||||
|
parsed, err := validateSoraRemoteURL(rawURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
client := &http.Client{
|
||||||
|
Timeout: soraVideoInputTimeout,
|
||||||
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||||
|
if len(via) >= soraVideoInputMaxRedirects {
|
||||||
|
return errors.New("too many redirects")
|
||||||
|
}
|
||||||
|
return validateSoraRemoteURLValue(req.URL)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("download video failed: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
data, err := io.ReadAll(io.LimitReader(resp.Body, soraVideoInputMaxBytes))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(data) == 0 {
|
||||||
|
return nil, errors.New("empty video content")
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeBase64WithLimit(encoded string, maxBytes int64) ([]byte, error) {
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
return nil, errors.New("invalid max bytes limit")
|
||||||
|
}
|
||||||
|
decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
|
||||||
|
limited := io.LimitReader(decoder, maxBytes+1)
|
||||||
|
data, err := io.ReadAll(limited)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if int64(len(data)) > maxBytes {
|
||||||
|
return nil, fmt.Errorf("input exceeds %d bytes limit", maxBytes)
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateSoraRemoteURL(raw string) (*url.URL, error) {
|
||||||
if strings.TrimSpace(raw) == "" {
|
if strings.TrimSpace(raw) == "" {
|
||||||
return nil, errors.New("empty image url")
|
return nil, errors.New("empty remote url")
|
||||||
}
|
}
|
||||||
parsed, err := url.Parse(raw)
|
parsed, err := url.Parse(raw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid image url: %w", err)
|
return nil, fmt.Errorf("invalid remote url: %w", err)
|
||||||
}
|
}
|
||||||
if err := validateSoraImageURLValue(parsed); err != nil {
|
if err := validateSoraRemoteURLValue(parsed); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return parsed, nil
|
return parsed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateSoraImageURLValue(parsed *url.URL) error {
|
func validateSoraRemoteURLValue(parsed *url.URL) error {
|
||||||
if parsed == nil {
|
if parsed == nil {
|
||||||
return errors.New("invalid image url")
|
return errors.New("invalid remote url")
|
||||||
}
|
}
|
||||||
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
|
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
|
||||||
if scheme != "http" && scheme != "https" {
|
if scheme != "http" && scheme != "https" {
|
||||||
return errors.New("only http/https image url is allowed")
|
return errors.New("only http/https remote url is allowed")
|
||||||
}
|
}
|
||||||
if parsed.User != nil {
|
if parsed.User != nil {
|
||||||
return errors.New("image url cannot contain userinfo")
|
return errors.New("remote url cannot contain userinfo")
|
||||||
}
|
}
|
||||||
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
|
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
|
||||||
if host == "" {
|
if host == "" {
|
||||||
return errors.New("image url missing host")
|
return errors.New("remote url missing host")
|
||||||
}
|
}
|
||||||
if _, blocked := soraBlockedHostnames[host]; blocked {
|
if _, blocked := soraBlockedHostnames[host]; blocked {
|
||||||
return errors.New("image url is not allowed")
|
return errors.New("remote url is not allowed")
|
||||||
}
|
}
|
||||||
if ip := net.ParseIP(host); ip != nil {
|
if ip := net.ParseIP(host); ip != nil {
|
||||||
if isSoraBlockedIP(ip) {
|
if isSoraBlockedIP(ip) {
|
||||||
return errors.New("image url is not allowed")
|
return errors.New("remote url is not allowed")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
ips, err := net.LookupIP(host)
|
ips, err := net.LookupIP(host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("resolve image url failed: %w", err)
|
return fmt.Errorf("resolve remote url failed: %w", err)
|
||||||
}
|
}
|
||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
if isSoraBlockedIP(ip) {
|
if isSoraBlockedIP(ip) {
|
||||||
return errors.New("image url is not allowed")
|
return errors.New("remote url is not allowed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -4,10 +4,16 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -18,6 +24,13 @@ type stubSoraClientForPoll struct {
|
|||||||
videoStatus *SoraVideoTaskStatus
|
videoStatus *SoraVideoTaskStatus
|
||||||
imageCalls int
|
imageCalls int
|
||||||
videoCalls int
|
videoCalls int
|
||||||
|
enhanced string
|
||||||
|
enhanceErr error
|
||||||
|
storyboard bool
|
||||||
|
videoReq SoraVideoRequest
|
||||||
|
parseErr error
|
||||||
|
postCalls int
|
||||||
|
deleteCalls int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stubSoraClientForPoll) Enabled() bool { return true }
|
func (s *stubSoraClientForPoll) Enabled() bool { return true }
|
||||||
@@ -28,8 +41,60 @@ func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Ac
|
|||||||
return "task-image", nil
|
return "task-image", nil
|
||||||
}
|
}
|
||||||
func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
|
func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
|
||||||
|
s.videoReq = req
|
||||||
return "task-video", nil
|
return "task-video", nil
|
||||||
}
|
}
|
||||||
|
func (s *stubSoraClientForPoll) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
|
||||||
|
s.storyboard = true
|
||||||
|
return "task-video", nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClientForPoll) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
|
||||||
|
return "cameo-1", nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClientForPoll) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
|
||||||
|
return &SoraCameoStatus{
|
||||||
|
Status: "finalized",
|
||||||
|
StatusMessage: "Completed",
|
||||||
|
DisplayNameHint: "Character",
|
||||||
|
UsernameHint: "user.character",
|
||||||
|
ProfileAssetURL: "https://example.com/avatar.webp",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClientForPoll) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
|
||||||
|
return []byte("avatar"), nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClientForPoll) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
|
||||||
|
return "asset-pointer", nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClientForPoll) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
|
||||||
|
return "character-1", nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClientForPoll) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClientForPoll) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClientForPoll) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
|
||||||
|
s.postCalls++
|
||||||
|
return "s_post", nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClientForPoll) DeletePost(ctx context.Context, account *Account, postID string) error {
|
||||||
|
s.deleteCalls++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClientForPoll) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
|
||||||
|
if s.parseErr != nil {
|
||||||
|
return "", s.parseErr
|
||||||
|
}
|
||||||
|
return "https://example.com/no-watermark.mp4", nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
|
||||||
|
if s.enhanced != "" {
|
||||||
|
return s.enhanced, s.enhanceErr
|
||||||
|
}
|
||||||
|
return "enhanced prompt", s.enhanceErr
|
||||||
|
}
|
||||||
func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
|
func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
|
||||||
s.imageCalls++
|
s.imageCalls++
|
||||||
return s.imageStatus, nil
|
return s.imageStatus, nil
|
||||||
@@ -62,6 +127,136 @@ func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) {
|
|||||||
require.Equal(t, 1, client.imageCalls)
|
require.Equal(t, 1, client.imageCalls)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
|
||||||
|
client := &stubSoraClientForPoll{
|
||||||
|
enhanced: "cinematic prompt",
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Client: config.SoraClientConfig{
|
||||||
|
PollIntervalSeconds: 1,
|
||||||
|
MaxPollAttempts: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Platform: PlatformSora,
|
||||||
|
Status: StatusActive,
|
||||||
|
}
|
||||||
|
body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`)
|
||||||
|
|
||||||
|
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, "prompt", result.MediaType)
|
||||||
|
require.Equal(t, "prompt-enhance-short-10s", result.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {
|
||||||
|
client := &stubSoraClientForPoll{
|
||||||
|
videoStatus: &SoraVideoTaskStatus{
|
||||||
|
Status: "completed",
|
||||||
|
URLs: []string{"https://example.com/v.mp4"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Client: config.SoraClientConfig{
|
||||||
|
PollIntervalSeconds: 1,
|
||||||
|
MaxPollAttempts: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
||||||
|
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
|
||||||
|
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`)
|
||||||
|
|
||||||
|
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, client.storyboard)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) {
|
||||||
|
client := &stubSoraClientForPoll{}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Client: config.SoraClientConfig{
|
||||||
|
PollIntervalSeconds: 1,
|
||||||
|
MaxPollAttempts: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
||||||
|
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
|
||||||
|
body := []byte(`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`)
|
||||||
|
|
||||||
|
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, "prompt", result.MediaType)
|
||||||
|
require.Equal(t, 0, client.videoCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraGatewayService_ForwardWatermarkFallback(t *testing.T) {
|
||||||
|
client := &stubSoraClientForPoll{
|
||||||
|
videoStatus: &SoraVideoTaskStatus{
|
||||||
|
Status: "completed",
|
||||||
|
URLs: []string{"https://example.com/original.mp4"},
|
||||||
|
GenerationID: "gen_1",
|
||||||
|
},
|
||||||
|
parseErr: errors.New("parse failed"),
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Client: config.SoraClientConfig{
|
||||||
|
PollIntervalSeconds: 1,
|
||||||
|
MaxPollAttempts: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
||||||
|
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
|
||||||
|
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`)
|
||||||
|
|
||||||
|
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, "https://example.com/original.mp4", result.MediaURL)
|
||||||
|
require.Equal(t, 1, client.postCalls)
|
||||||
|
require.Equal(t, 0, client.deleteCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete(t *testing.T) {
|
||||||
|
client := &stubSoraClientForPoll{
|
||||||
|
videoStatus: &SoraVideoTaskStatus{
|
||||||
|
Status: "completed",
|
||||||
|
URLs: []string{"https://example.com/original.mp4"},
|
||||||
|
GenerationID: "gen_1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Client: config.SoraClientConfig{
|
||||||
|
PollIntervalSeconds: 1,
|
||||||
|
MaxPollAttempts: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
||||||
|
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
|
||||||
|
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`)
|
||||||
|
|
||||||
|
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, "https://example.com/no-watermark.mp4", result.MediaURL)
|
||||||
|
require.Equal(t, 1, client.postCalls)
|
||||||
|
require.Equal(t, 1, client.deleteCalls)
|
||||||
|
}
|
||||||
|
|
||||||
func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
|
func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
|
||||||
client := &stubSoraClientForPoll{
|
client := &stubSoraClientForPoll{
|
||||||
videoStatus: &SoraVideoTaskStatus{
|
videoStatus: &SoraVideoTaskStatus{
|
||||||
@@ -79,9 +274,9 @@ func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
|
|||||||
}
|
}
|
||||||
service := NewSoraGatewayService(client, nil, nil, cfg)
|
service := NewSoraGatewayService(client, nil, nil, cfg)
|
||||||
|
|
||||||
urls, err := service.pollVideoTask(context.Background(), nil, &Account{ID: 1}, "task", false)
|
status, err := service.pollVideoTaskDetailed(context.Background(), nil, &Account{ID: 1}, "task", false)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Empty(t, urls)
|
require.Nil(t, status)
|
||||||
require.Contains(t, err.Error(), "reject")
|
require.Contains(t, err.Error(), "reject")
|
||||||
require.Equal(t, 1, client.videoCalls)
|
require.Equal(t, 1, client.videoCalls)
|
||||||
}
|
}
|
||||||
@@ -175,9 +370,65 @@ func TestSoraProErrorMessage(t *testing.T) {
|
|||||||
require.Empty(t, soraProErrorMessage("sora-basic", ""))
|
require.Empty(t, soraProErrorMessage("sora-basic", ""))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSoraGatewayService_WriteSoraError_StreamEscapesJSON(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
||||||
|
svc.writeSoraError(c, http.StatusBadGateway, "upstream_error", "invalid \"prompt\"\nline2", true)
|
||||||
|
|
||||||
|
body := rec.Body.String()
|
||||||
|
require.Contains(t, body, "event: error\n")
|
||||||
|
require.Contains(t, body, "data: [DONE]\n\n")
|
||||||
|
|
||||||
|
lines := strings.Split(body, "\n")
|
||||||
|
require.GreaterOrEqual(t, len(lines), 2)
|
||||||
|
require.Equal(t, "event: error", lines[0])
|
||||||
|
require.True(t, strings.HasPrefix(lines[1], "data: "))
|
||||||
|
|
||||||
|
data := strings.TrimPrefix(lines[1], "data: ")
|
||||||
|
var parsed map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(data), &parsed))
|
||||||
|
errObj, ok := parsed["error"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "upstream_error", errObj["type"])
|
||||||
|
require.Equal(t, "invalid \"prompt\"\nline2", errObj["message"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraGatewayService_HandleSoraRequestError_FailoverHeadersCloned(t *testing.T) {
|
||||||
|
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
||||||
|
sourceHeaders := http.Header{}
|
||||||
|
sourceHeaders.Set("cf-ray", "9d01b0e9ecc35829-SEA")
|
||||||
|
|
||||||
|
err := svc.handleSoraRequestError(
|
||||||
|
context.Background(),
|
||||||
|
&Account{ID: 1, Platform: PlatformSora},
|
||||||
|
&SoraUpstreamError{
|
||||||
|
StatusCode: http.StatusForbidden,
|
||||||
|
Message: "forbidden",
|
||||||
|
Headers: sourceHeaders,
|
||||||
|
Body: []byte(`<!DOCTYPE html><title>Just a moment...</title>`),
|
||||||
|
},
|
||||||
|
"sora2-landscape-10s",
|
||||||
|
nil,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
|
||||||
|
var failoverErr *UpstreamFailoverError
|
||||||
|
require.ErrorAs(t, err, &failoverErr)
|
||||||
|
require.NotNil(t, failoverErr.ResponseHeaders)
|
||||||
|
require.Equal(t, "9d01b0e9ecc35829-SEA", failoverErr.ResponseHeaders.Get("cf-ray"))
|
||||||
|
|
||||||
|
sourceHeaders.Set("cf-ray", "mutated-after-return")
|
||||||
|
require.Equal(t, "9d01b0e9ecc35829-SEA", failoverErr.ResponseHeaders.Get("cf-ray"))
|
||||||
|
}
|
||||||
|
|
||||||
func TestShouldFailoverUpstreamError(t *testing.T) {
|
func TestShouldFailoverUpstreamError(t *testing.T) {
|
||||||
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
||||||
require.True(t, svc.shouldFailoverUpstreamError(401))
|
require.True(t, svc.shouldFailoverUpstreamError(401))
|
||||||
|
require.True(t, svc.shouldFailoverUpstreamError(404))
|
||||||
require.True(t, svc.shouldFailoverUpstreamError(429))
|
require.True(t, svc.shouldFailoverUpstreamError(429))
|
||||||
require.True(t, svc.shouldFailoverUpstreamError(500))
|
require.True(t, svc.shouldFailoverUpstreamError(500))
|
||||||
require.True(t, svc.shouldFailoverUpstreamError(502))
|
require.True(t, svc.shouldFailoverUpstreamError(502))
|
||||||
@@ -257,3 +508,19 @@ func TestDecodeSoraImageInput_DataURL(t *testing.T) {
|
|||||||
require.NotEmpty(t, data)
|
require.NotEmpty(t, data)
|
||||||
require.Contains(t, filename, ".png")
|
require.Contains(t, filename, ".png")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDecodeBase64WithLimit_ExceedLimit(t *testing.T) {
|
||||||
|
data, err := decodeBase64WithLimit("aGVsbG8=", 3)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseSoraWatermarkOptions_NumericBool(t *testing.T) {
|
||||||
|
body := map[string]any{
|
||||||
|
"watermark_free": float64(1),
|
||||||
|
"watermark_fallback_on_failure": float64(0),
|
||||||
|
}
|
||||||
|
opts := parseSoraWatermarkOptions(body)
|
||||||
|
require.True(t, opts.Enabled)
|
||||||
|
require.False(t, opts.FallbackOnFailure)
|
||||||
|
}
|
||||||
|
|||||||
@@ -17,6 +17,9 @@ type SoraModelConfig struct {
|
|||||||
Model string
|
Model string
|
||||||
Size string
|
Size string
|
||||||
RequirePro bool
|
RequirePro bool
|
||||||
|
// Prompt-enhance 专用参数
|
||||||
|
ExpansionLevel string
|
||||||
|
DurationS int
|
||||||
}
|
}
|
||||||
|
|
||||||
var soraModelConfigs = map[string]SoraModelConfig{
|
var soraModelConfigs = map[string]SoraModelConfig{
|
||||||
@@ -160,31 +163,49 @@ var soraModelConfigs = map[string]SoraModelConfig{
|
|||||||
RequirePro: true,
|
RequirePro: true,
|
||||||
},
|
},
|
||||||
"prompt-enhance-short-10s": {
|
"prompt-enhance-short-10s": {
|
||||||
Type: "prompt_enhance",
|
Type: "prompt_enhance",
|
||||||
|
ExpansionLevel: "short",
|
||||||
|
DurationS: 10,
|
||||||
},
|
},
|
||||||
"prompt-enhance-short-15s": {
|
"prompt-enhance-short-15s": {
|
||||||
Type: "prompt_enhance",
|
Type: "prompt_enhance",
|
||||||
|
ExpansionLevel: "short",
|
||||||
|
DurationS: 15,
|
||||||
},
|
},
|
||||||
"prompt-enhance-short-20s": {
|
"prompt-enhance-short-20s": {
|
||||||
Type: "prompt_enhance",
|
Type: "prompt_enhance",
|
||||||
|
ExpansionLevel: "short",
|
||||||
|
DurationS: 20,
|
||||||
},
|
},
|
||||||
"prompt-enhance-medium-10s": {
|
"prompt-enhance-medium-10s": {
|
||||||
Type: "prompt_enhance",
|
Type: "prompt_enhance",
|
||||||
|
ExpansionLevel: "medium",
|
||||||
|
DurationS: 10,
|
||||||
},
|
},
|
||||||
"prompt-enhance-medium-15s": {
|
"prompt-enhance-medium-15s": {
|
||||||
Type: "prompt_enhance",
|
Type: "prompt_enhance",
|
||||||
|
ExpansionLevel: "medium",
|
||||||
|
DurationS: 15,
|
||||||
},
|
},
|
||||||
"prompt-enhance-medium-20s": {
|
"prompt-enhance-medium-20s": {
|
||||||
Type: "prompt_enhance",
|
Type: "prompt_enhance",
|
||||||
|
ExpansionLevel: "medium",
|
||||||
|
DurationS: 20,
|
||||||
},
|
},
|
||||||
"prompt-enhance-long-10s": {
|
"prompt-enhance-long-10s": {
|
||||||
Type: "prompt_enhance",
|
Type: "prompt_enhance",
|
||||||
|
ExpansionLevel: "long",
|
||||||
|
DurationS: 10,
|
||||||
},
|
},
|
||||||
"prompt-enhance-long-15s": {
|
"prompt-enhance-long-15s": {
|
||||||
Type: "prompt_enhance",
|
Type: "prompt_enhance",
|
||||||
|
ExpansionLevel: "long",
|
||||||
|
DurationS: 15,
|
||||||
},
|
},
|
||||||
"prompt-enhance-long-20s": {
|
"prompt-enhance-long-20s": {
|
||||||
Type: "prompt_enhance",
|
Type: "prompt_enhance",
|
||||||
|
ExpansionLevel: "long",
|
||||||
|
DurationS: 20,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
266
backend/internal/service/sora_request_guard.go
Normal file
266
backend/internal/service/sora_request_guard.go
Normal file
@@ -0,0 +1,266 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type soraChallengeCooldownEntry struct {
|
||||||
|
Until time.Time
|
||||||
|
StatusCode int
|
||||||
|
CFRay string
|
||||||
|
ConsecutiveChallenges int
|
||||||
|
LastChallengeAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type soraSidecarSessionEntry struct {
|
||||||
|
SessionKey string
|
||||||
|
ExpiresAt time.Time
|
||||||
|
LastUsedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) cloudflareChallengeCooldownSeconds() int {
|
||||||
|
if c == nil || c.cfg == nil {
|
||||||
|
return 900
|
||||||
|
}
|
||||||
|
cooldown := c.cfg.Sora.Client.CloudflareChallengeCooldownSeconds
|
||||||
|
if cooldown <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return cooldown
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) checkCloudflareChallengeCooldown(account *Account, proxyURL string) error {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if account == nil || account.ID <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cooldownSeconds := c.cloudflareChallengeCooldownSeconds()
|
||||||
|
if cooldownSeconds <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
key := soraAccountProxyKey(account, proxyURL)
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
c.challengeCooldownMu.RLock()
|
||||||
|
entry, ok := c.challengeCooldowns[key]
|
||||||
|
c.challengeCooldownMu.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !entry.Until.After(now) {
|
||||||
|
c.challengeCooldownMu.Lock()
|
||||||
|
delete(c.challengeCooldowns, key)
|
||||||
|
c.challengeCooldownMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
remaining := int(math.Ceil(entry.Until.Sub(now).Seconds()))
|
||||||
|
if remaining < 1 {
|
||||||
|
remaining = 1
|
||||||
|
}
|
||||||
|
message := fmt.Sprintf("Sora request cooling down due to recent Cloudflare challenge. Retry in %d seconds.", remaining)
|
||||||
|
if entry.ConsecutiveChallenges > 1 {
|
||||||
|
message = fmt.Sprintf("%s (streak=%d)", message, entry.ConsecutiveChallenges)
|
||||||
|
}
|
||||||
|
if entry.CFRay != "" {
|
||||||
|
message = fmt.Sprintf("%s (last cf-ray: %s)", message, entry.CFRay)
|
||||||
|
}
|
||||||
|
return &SoraUpstreamError{
|
||||||
|
StatusCode: http.StatusTooManyRequests,
|
||||||
|
Message: message,
|
||||||
|
Headers: make(http.Header),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// recordCloudflareChallengeCooldown registers (or extends) a cooldown window
// for the (account, proxy) pair after a Cloudflare challenge response.
// Repeated challenges within 30 minutes of each other grow a streak counter
// that lengthens the cooldown via soraComputeChallengeCooldownSeconds.
// No-ops when the feature is disabled or the account is not identifiable.
func (c *SoraDirectClient) recordCloudflareChallengeCooldown(account *Account, proxyURL string, statusCode int, headers http.Header, body []byte) {
	if c == nil {
		return
	}
	if account == nil || account.ID <= 0 {
		// Cooldowns are keyed by account ID; skip unidentifiable requests.
		return
	}
	cooldownSeconds := c.cloudflareChallengeCooldownSeconds()
	if cooldownSeconds <= 0 {
		// Feature disabled via configuration.
		return
	}
	key := soraAccountProxyKey(account, proxyURL)
	now := time.Now()
	cfRay := soraerror.ExtractCloudflareRayID(headers, body)

	c.challengeCooldownMu.Lock()
	// Opportunistically prune expired entries while we hold the write lock.
	c.cleanupExpiredChallengeCooldownsLocked(now)

	// A challenge within 30 minutes of the previous one extends the streak.
	streak := 1
	existing, ok := c.challengeCooldowns[key]
	if ok && now.Sub(existing.LastChallengeAt) <= 30*time.Minute {
		streak = existing.ConsecutiveChallenges + 1
	}
	effectiveCooldown := soraComputeChallengeCooldownSeconds(cooldownSeconds, streak)
	until := now.Add(time.Duration(effectiveCooldown) * time.Second)
	if ok && existing.Until.After(until) {
		// Never shorten an already-longer cooldown; also keep the stronger
		// entry's streak and its cf-ray when the new response lacks one.
		until = existing.Until
		if existing.ConsecutiveChallenges > streak {
			streak = existing.ConsecutiveChallenges
		}
		if cfRay == "" {
			cfRay = existing.CFRay
		}
	}
	c.challengeCooldowns[key] = soraChallengeCooldownEntry{
		Until:                 until,
		StatusCode:            statusCode,
		CFRay:                 cfRay,
		ConsecutiveChallenges: streak,
		LastChallengeAt:       now,
	}
	c.challengeCooldownMu.Unlock()

	// Debug logging happens outside the lock to keep the critical section short.
	if c.debugEnabled() {
		remain := int(math.Ceil(until.Sub(now).Seconds()))
		if remain < 0 {
			remain = 0
		}
		c.debugLogf("cloudflare_challenge_cooldown_set key=%s status=%d remain_s=%d streak=%d cf_ray=%s", key, statusCode, remain, streak, cfRay)
	}
}
|
||||||
|
|
||||||
|
// soraComputeChallengeCooldownSeconds derives the effective cooldown from the
// configured base window and the number of consecutive challenges. The streak
// multiplies the base (clamped to 1x..4x) and the result is capped at one
// hour. A non-positive base disables the cooldown.
func soraComputeChallengeCooldownSeconds(baseSeconds, streak int) int {
	if baseSeconds <= 0 {
		return 0
	}
	multiplier := streak
	switch {
	case multiplier < 1:
		multiplier = 1
	case multiplier > 4:
		multiplier = 4
	}
	if total := baseSeconds * multiplier; total <= 3600 {
		return total
	}
	return 3600
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) clearCloudflareChallengeCooldown(account *Account, proxyURL string) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if account == nil || account.ID <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
key := soraAccountProxyKey(account, proxyURL)
|
||||||
|
c.challengeCooldownMu.Lock()
|
||||||
|
_, existed := c.challengeCooldowns[key]
|
||||||
|
if existed {
|
||||||
|
delete(c.challengeCooldowns, key)
|
||||||
|
}
|
||||||
|
c.challengeCooldownMu.Unlock()
|
||||||
|
if existed && c.debugEnabled() {
|
||||||
|
c.debugLogf("cloudflare_challenge_cooldown_cleared key=%s", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sidecarSessionKey returns a stable session key for the (account, proxy)
// pair so the curl_cffi sidecar can persist cookies/session state across
// requests. On first use it creates a "sora-<uuid>" key; on reuse it only
// refreshes LastUsedAt. Returns "" when session reuse is disabled or the
// account is not identifiable.
func (c *SoraDirectClient) sidecarSessionKey(account *Account, proxyURL string) string {
	if c == nil || !c.sidecarSessionReuseEnabled() {
		return ""
	}
	if account == nil || account.ID <= 0 {
		// Session keys are scoped per account ID.
		return ""
	}
	key := soraAccountProxyKey(account, proxyURL)
	now := time.Now()
	ttlSeconds := c.sidecarSessionTTLSeconds()

	c.sidecarSessionMu.Lock()
	defer c.sidecarSessionMu.Unlock()
	// Prune expired sessions before lookup so a stale key is never reused.
	c.cleanupExpiredSidecarSessionsLocked(now)
	if existing, exists := c.sidecarSessions[key]; exists {
		// Reuse the live session; only bump its last-used timestamp.
		existing.LastUsedAt = now
		c.sidecarSessions[key] = existing
		return existing.SessionKey
	}

	expiresAt := now.Add(time.Duration(ttlSeconds) * time.Second)
	if ttlSeconds <= 0 {
		// Non-positive TTL means effectively no expiry (one year out).
		expiresAt = now.Add(365 * 24 * time.Hour)
	}
	newEntry := soraSidecarSessionEntry{
		SessionKey: "sora-" + uuid.NewString(),
		ExpiresAt:  expiresAt,
		LastUsedAt: now,
	}
	c.sidecarSessions[key] = newEntry

	if c.debugEnabled() {
		c.debugLogf("sidecar_session_created key=%s ttl_s=%d", key, ttlSeconds)
	}
	return newEntry.SessionKey
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) cleanupExpiredChallengeCooldownsLocked(now time.Time) {
|
||||||
|
if c == nil || len(c.challengeCooldowns) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for key, entry := range c.challengeCooldowns {
|
||||||
|
if !entry.Until.After(now) {
|
||||||
|
delete(c.challengeCooldowns, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) cleanupExpiredSidecarSessionsLocked(now time.Time) {
|
||||||
|
if c == nil || len(c.sidecarSessions) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for key, entry := range c.sidecarSessions {
|
||||||
|
if !entry.ExpiresAt.After(now) {
|
||||||
|
delete(c.sidecarSessions, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func soraAccountProxyKey(account *Account, proxyURL string) string {
|
||||||
|
accountID := int64(0)
|
||||||
|
if account != nil {
|
||||||
|
accountID = account.ID
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("account:%d|proxy:%s", accountID, normalizeSoraProxyKey(proxyURL))
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeSoraProxyKey canonicalizes a proxy URL into a stable cache-key
// form: lowercase "scheme://host[:port]" with default http/https ports
// stripped. An empty value maps to "direct"; unparsable or host-less values
// fall back to the lowercased raw string; a missing scheme becomes "proxy".
func normalizeSoraProxyKey(proxyURL string) string {
	raw := strings.TrimSpace(proxyURL)
	if raw == "" {
		return "direct"
	}
	parsed, err := url.Parse(raw)
	if err != nil {
		return strings.ToLower(raw)
	}
	host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
	if host == "" {
		return strings.ToLower(raw)
	}
	scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
	port := strings.TrimSpace(parsed.Port())
	// Default ports carry no information; drop them for key stability.
	defaultPort := (scheme == "http" && port == "80") || (scheme == "https" && port == "443")
	if port != "" && !defaultPort {
		host += ":" + port
	}
	if scheme == "" {
		scheme = "proxy"
	}
	return scheme + "://" + host
}
|
||||||
@@ -43,10 +43,13 @@ func NewTokenRefreshService(
|
|||||||
stopCh: make(chan struct{}),
|
stopCh: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
openAIRefresher := NewOpenAITokenRefresher(openaiOAuthService, accountRepo)
|
||||||
|
openAIRefresher.SetSyncLinkedSoraAccounts(cfg.TokenRefresh.SyncLinkedSoraAccounts)
|
||||||
|
|
||||||
// 注册平台特定的刷新器
|
// 注册平台特定的刷新器
|
||||||
s.refreshers = []TokenRefresher{
|
s.refreshers = []TokenRefresher{
|
||||||
NewClaudeTokenRefresher(oauthService),
|
NewClaudeTokenRefresher(oauthService),
|
||||||
NewOpenAITokenRefresher(openaiOAuthService, accountRepo),
|
openAIRefresher,
|
||||||
NewGeminiTokenRefresher(geminiOAuthService),
|
NewGeminiTokenRefresher(geminiOAuthService),
|
||||||
NewAntigravityTokenRefresher(antigravityOAuthService),
|
NewAntigravityTokenRefresher(antigravityOAuthService),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -86,6 +86,7 @@ type OpenAITokenRefresher struct {
|
|||||||
openaiOAuthService *OpenAIOAuthService
|
openaiOAuthService *OpenAIOAuthService
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
|
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
|
||||||
|
syncLinkedSora bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOpenAITokenRefresher 创建 OpenAI token刷新器
|
// NewOpenAITokenRefresher 创建 OpenAI token刷新器
|
||||||
@@ -103,11 +104,15 @@ func (r *OpenAITokenRefresher) SetSoraAccountRepo(repo SoraAccountRepository) {
|
|||||||
r.soraAccountRepo = repo
|
r.soraAccountRepo = repo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSyncLinkedSoraAccounts 控制是否同步覆盖关联的 Sora 账号 token。
|
||||||
|
func (r *OpenAITokenRefresher) SetSyncLinkedSoraAccounts(enabled bool) {
|
||||||
|
r.syncLinkedSora = enabled
|
||||||
|
}
|
||||||
|
|
||||||
// CanRefresh 检查是否能处理此账号
|
// CanRefresh 检查是否能处理此账号
|
||||||
// 只处理 openai 平台的 oauth 类型账号
|
// 只处理 openai 平台的 oauth 类型账号(不直接刷新 sora 平台账号)
|
||||||
func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
|
func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
|
||||||
return (account.Platform == PlatformOpenAI || account.Platform == PlatformSora) &&
|
return account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth
|
||||||
account.Type == AccountTypeOAuth
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NeedsRefresh 检查token是否需要刷新
|
// NeedsRefresh 检查token是否需要刷新
|
||||||
@@ -141,7 +146,7 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 异步同步关联的 Sora 账号(不阻塞主流程)
|
// 异步同步关联的 Sora 账号(不阻塞主流程)
|
||||||
if r.accountRepo != nil {
|
if r.accountRepo != nil && r.syncLinkedSora {
|
||||||
go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials)
|
go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -226,3 +226,43 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAITokenRefresher_CanRefresh(t *testing.T) {
|
||||||
|
refresher := &OpenAITokenRefresher{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
platform string
|
||||||
|
accType string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "openai oauth - can refresh",
|
||||||
|
platform: PlatformOpenAI,
|
||||||
|
accType: AccountTypeOAuth,
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sora oauth - cannot refresh directly",
|
||||||
|
platform: PlatformSora,
|
||||||
|
accType: AccountTypeOAuth,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "openai apikey - cannot refresh",
|
||||||
|
platform: PlatformOpenAI,
|
||||||
|
accType: AccountTypeAPIKey,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: tt.platform,
|
||||||
|
Type: tt.accType,
|
||||||
|
}
|
||||||
|
require.Equal(t, tt.want, refresher.CanRefresh(account))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -26,8 +26,8 @@ type UsageLog struct {
|
|||||||
CacheCreationTokens int
|
CacheCreationTokens int
|
||||||
CacheReadTokens int
|
CacheReadTokens int
|
||||||
|
|
||||||
CacheCreation5mTokens int
|
CacheCreation5mTokens int `gorm:"column:cache_creation_5m_tokens"`
|
||||||
CacheCreation1hTokens int
|
CacheCreation1hTokens int `gorm:"column:cache_creation_1h_tokens"`
|
||||||
|
|
||||||
InputCost float64
|
InputCost float64
|
||||||
OutputCost float64
|
OutputCost float64
|
||||||
@@ -46,6 +46,9 @@ type UsageLog struct {
|
|||||||
UserAgent *string
|
UserAgent *string
|
||||||
IPAddress *string
|
IPAddress *string
|
||||||
|
|
||||||
|
// Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费)
|
||||||
|
CacheTTLOverridden bool
|
||||||
|
|
||||||
// 图片生成字段
|
// 图片生成字段
|
||||||
ImageCount int
|
ImageCount int
|
||||||
ImageSize *string
|
ImageSize *string
|
||||||
|
|||||||
@@ -206,6 +206,18 @@ func ProvideSoraMediaStorage(cfg *config.Config) *SoraMediaStorage {
|
|||||||
return NewSoraMediaStorage(cfg)
|
return NewSoraMediaStorage(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ProvideSoraDirectClient(
|
||||||
|
cfg *config.Config,
|
||||||
|
httpUpstream HTTPUpstream,
|
||||||
|
tokenProvider *OpenAITokenProvider,
|
||||||
|
accountRepo AccountRepository,
|
||||||
|
soraAccountRepo SoraAccountRepository,
|
||||||
|
) *SoraDirectClient {
|
||||||
|
client := NewSoraDirectClient(cfg, httpUpstream, tokenProvider)
|
||||||
|
client.SetAccountRepositories(accountRepo, soraAccountRepo)
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
// ProvideSoraMediaCleanupService 创建并启动 Sora 媒体清理服务
|
// ProvideSoraMediaCleanupService 创建并启动 Sora 媒体清理服务
|
||||||
func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService {
|
func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService {
|
||||||
svc := NewSoraMediaCleanupService(storage, cfg)
|
svc := NewSoraMediaCleanupService(storage, cfg)
|
||||||
@@ -255,7 +267,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewGatewayService,
|
NewGatewayService,
|
||||||
ProvideSoraMediaStorage,
|
ProvideSoraMediaStorage,
|
||||||
ProvideSoraMediaCleanupService,
|
ProvideSoraMediaCleanupService,
|
||||||
NewSoraDirectClient,
|
ProvideSoraDirectClient,
|
||||||
wire.Bind(new(SoraClient), new(*SoraDirectClient)),
|
wire.Bind(new(SoraClient), new(*SoraDirectClient)),
|
||||||
NewSoraGatewayService,
|
NewSoraGatewayService,
|
||||||
NewOpenAIGatewayService,
|
NewOpenAIGatewayService,
|
||||||
|
|||||||
170
backend/internal/util/soraerror/soraerror.go
Normal file
170
backend/internal/util/soraerror/soraerror.go
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
package soraerror
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
cfRayPattern = regexp.MustCompile(`(?i)cf-ray[:\s=]+([a-z0-9-]+)`)
|
||||||
|
cRayPattern = regexp.MustCompile(`(?i)cRay:\s*'([a-z0-9-]+)'`)
|
||||||
|
htmlChallenge = []string{
|
||||||
|
"window._cf_chl_opt",
|
||||||
|
"just a moment",
|
||||||
|
"enable javascript and cookies to continue",
|
||||||
|
"__cf_chl_",
|
||||||
|
"challenge-platform",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// IsCloudflareChallengeResponse reports whether the upstream response matches Cloudflare challenge behavior.
|
||||||
|
func IsCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
|
||||||
|
if statusCode != http.StatusForbidden && statusCode != http.StatusTooManyRequests {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if headers != nil && strings.EqualFold(strings.TrimSpace(headers.Get("cf-mitigated")), "challenge") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
preview := strings.ToLower(TruncateBody(body, 4096))
|
||||||
|
for _, marker := range htmlChallenge {
|
||||||
|
if strings.Contains(preview, marker) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
contentType := ""
|
||||||
|
if headers != nil {
|
||||||
|
contentType = strings.ToLower(strings.TrimSpace(headers.Get("content-type")))
|
||||||
|
}
|
||||||
|
if strings.Contains(contentType, "text/html") &&
|
||||||
|
(strings.Contains(preview, "<html") || strings.Contains(preview, "<!doctype html")) &&
|
||||||
|
(strings.Contains(preview, "cloudflare") || strings.Contains(preview, "challenge")) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractCloudflareRayID extracts cf-ray from headers or response body.
|
||||||
|
func ExtractCloudflareRayID(headers http.Header, body []byte) string {
|
||||||
|
if headers != nil {
|
||||||
|
rayID := strings.TrimSpace(headers.Get("cf-ray"))
|
||||||
|
if rayID != "" {
|
||||||
|
return rayID
|
||||||
|
}
|
||||||
|
rayID = strings.TrimSpace(headers.Get("Cf-Ray"))
|
||||||
|
if rayID != "" {
|
||||||
|
return rayID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
preview := TruncateBody(body, 8192)
|
||||||
|
if matches := cfRayPattern.FindStringSubmatch(preview); len(matches) >= 2 {
|
||||||
|
return strings.TrimSpace(matches[1])
|
||||||
|
}
|
||||||
|
if matches := cRayPattern.FindStringSubmatch(preview); len(matches) >= 2 {
|
||||||
|
return strings.TrimSpace(matches[1])
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatCloudflareChallengeMessage appends cf-ray info when available.
|
||||||
|
func FormatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
|
||||||
|
rayID := ExtractCloudflareRayID(headers, body)
|
||||||
|
if rayID == "" {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractUpstreamErrorCodeAndMessage extracts structured error code/message from common JSON layouts.
|
||||||
|
func ExtractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
|
||||||
|
trimmed := strings.TrimSpace(string(body))
|
||||||
|
if trimmed == "" {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
if !json.Valid([]byte(trimmed)) {
|
||||||
|
return "", truncateMessage(trimmed, 256)
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(trimmed), &payload); err != nil {
|
||||||
|
return "", truncateMessage(trimmed, 256)
|
||||||
|
}
|
||||||
|
|
||||||
|
code := firstNonEmpty(
|
||||||
|
extractNestedString(payload, "error", "code"),
|
||||||
|
extractRootString(payload, "code"),
|
||||||
|
)
|
||||||
|
message := firstNonEmpty(
|
||||||
|
extractNestedString(payload, "error", "message"),
|
||||||
|
extractRootString(payload, "message"),
|
||||||
|
extractNestedString(payload, "error", "detail"),
|
||||||
|
extractRootString(payload, "detail"),
|
||||||
|
)
|
||||||
|
return strings.TrimSpace(code), truncateMessage(strings.TrimSpace(message), 512)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TruncateBody truncates body text for logging/inspection.
|
||||||
|
// TruncateBody trims and byte-truncates a response body for logging and
// inspection. A non-positive max falls back to 512 bytes. Note that
// truncation is byte-based, so a multi-byte UTF-8 sequence may be split at
// the boundary.
func TruncateBody(body []byte, max int) string {
	if max <= 0 {
		max = 512
	}
	raw := strings.TrimSpace(string(body))
	if len(raw) > max {
		return raw[:max] + "...(truncated)"
	}
	return raw
}
|
||||||
|
|
||||||
|
// truncateMessage byte-truncates s to at most max bytes, appending a marker
// when content was dropped. A non-positive max yields the empty string.
func truncateMessage(s string, max int) string {
	switch {
	case max <= 0:
		return ""
	case len(s) > max:
		return s[:max] + "...(truncated)"
	default:
		return s
	}
}
|
||||||
|
|
||||||
|
// firstNonEmpty returns the first value that is non-empty after trimming
// whitespace. The value itself is returned untrimmed; "" if none qualify.
func firstNonEmpty(values ...string) string {
	for _, v := range values {
		if strings.TrimSpace(v) == "" {
			continue
		}
		return v
	}
	return ""
}
|
||||||
|
|
||||||
|
// extractRootString returns m[key] when the entry exists and is a string;
// otherwise "". Safe on a nil map.
func extractRootString(m map[string]any, key string) string {
	if m == nil {
		return ""
	}
	if s, ok := m[key].(string); ok {
		return s
	}
	return ""
}
|
||||||
|
|
||||||
|
// extractNestedString returns m[parent][key] when parent maps to a JSON
// object and the nested value is a string; otherwise "". Safe on a nil map
// (indexing a nil map yields the zero value, failing the type assertion).
func extractNestedString(m map[string]any, parent, key string) string {
	child, ok := m[parent].(map[string]any)
	if !ok {
		return ""
	}
	s, _ := child[key].(string)
	return s
}
|
||||||
47
backend/internal/util/soraerror/soraerror_test.go
Normal file
47
backend/internal/util/soraerror/soraerror_test.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package soraerror
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsCloudflareChallengeResponse(t *testing.T) {
|
||||||
|
headers := make(http.Header)
|
||||||
|
headers.Set("cf-mitigated", "challenge")
|
||||||
|
require.True(t, IsCloudflareChallengeResponse(http.StatusForbidden, headers, []byte(`{"ok":false}`)))
|
||||||
|
|
||||||
|
require.True(t, IsCloudflareChallengeResponse(http.StatusTooManyRequests, nil, []byte(`<!DOCTYPE html><title>Just a moment...</title><script>window._cf_chl_opt={};</script>`)))
|
||||||
|
require.False(t, IsCloudflareChallengeResponse(http.StatusBadGateway, nil, []byte(`<!DOCTYPE html><title>Just a moment...</title>`)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractCloudflareRayID(t *testing.T) {
|
||||||
|
headers := make(http.Header)
|
||||||
|
headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
|
||||||
|
require.Equal(t, "9d01b0e9ecc35829-SEA", ExtractCloudflareRayID(headers, nil))
|
||||||
|
|
||||||
|
body := []byte(`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`)
|
||||||
|
require.Equal(t, "9cff2d62d83bb98d", ExtractCloudflareRayID(nil, body))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractUpstreamErrorCodeAndMessage(t *testing.T) {
|
||||||
|
code, msg := ExtractUpstreamErrorCodeAndMessage([]byte(`{"error":{"code":"cf_shield_429","message":"rate limited"}}`))
|
||||||
|
require.Equal(t, "cf_shield_429", code)
|
||||||
|
require.Equal(t, "rate limited", msg)
|
||||||
|
|
||||||
|
code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`{"code":"unsupported_country_code","message":"not available"}`))
|
||||||
|
require.Equal(t, "unsupported_country_code", code)
|
||||||
|
require.Equal(t, "not available", msg)
|
||||||
|
|
||||||
|
code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`plain text`))
|
||||||
|
require.Equal(t, "", code)
|
||||||
|
require.Equal(t, "plain text", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatCloudflareChallengeMessage(t *testing.T) {
|
||||||
|
headers := make(http.Header)
|
||||||
|
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
|
||||||
|
msg := FormatCloudflareChallengeMessage("blocked", headers, nil)
|
||||||
|
require.Equal(t, "blocked (cf-ray: 9d03b68c086027a1-SEA)", msg)
|
||||||
|
}
|
||||||
@@ -86,6 +86,7 @@ func (s *FrontendServer) Middleware() gin.HandlerFunc {
|
|||||||
if strings.HasPrefix(path, "/api/") ||
|
if strings.HasPrefix(path, "/api/") ||
|
||||||
strings.HasPrefix(path, "/v1/") ||
|
strings.HasPrefix(path, "/v1/") ||
|
||||||
strings.HasPrefix(path, "/v1beta/") ||
|
strings.HasPrefix(path, "/v1beta/") ||
|
||||||
|
strings.HasPrefix(path, "/sora/") ||
|
||||||
strings.HasPrefix(path, "/antigravity/") ||
|
strings.HasPrefix(path, "/antigravity/") ||
|
||||||
strings.HasPrefix(path, "/setup/") ||
|
strings.HasPrefix(path, "/setup/") ||
|
||||||
path == "/health" ||
|
path == "/health" ||
|
||||||
@@ -209,6 +210,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
|
|||||||
if strings.HasPrefix(path, "/api/") ||
|
if strings.HasPrefix(path, "/api/") ||
|
||||||
strings.HasPrefix(path, "/v1/") ||
|
strings.HasPrefix(path, "/v1/") ||
|
||||||
strings.HasPrefix(path, "/v1beta/") ||
|
strings.HasPrefix(path, "/v1beta/") ||
|
||||||
|
strings.HasPrefix(path, "/sora/") ||
|
||||||
strings.HasPrefix(path, "/antigravity/") ||
|
strings.HasPrefix(path, "/antigravity/") ||
|
||||||
strings.HasPrefix(path, "/setup/") ||
|
strings.HasPrefix(path, "/setup/") ||
|
||||||
path == "/health" ||
|
path == "/health" ||
|
||||||
|
|||||||
@@ -362,6 +362,7 @@ func TestFrontendServer_Middleware(t *testing.T) {
|
|||||||
"/api/v1/users",
|
"/api/v1/users",
|
||||||
"/v1/models",
|
"/v1/models",
|
||||||
"/v1beta/chat",
|
"/v1beta/chat",
|
||||||
|
"/sora/v1/models",
|
||||||
"/antigravity/test",
|
"/antigravity/test",
|
||||||
"/setup/init",
|
"/setup/init",
|
||||||
"/health",
|
"/health",
|
||||||
@@ -537,6 +538,7 @@ func TestServeEmbeddedFrontend(t *testing.T) {
|
|||||||
"/api/users",
|
"/api/users",
|
||||||
"/v1/models",
|
"/v1/models",
|
||||||
"/v1beta/chat",
|
"/v1beta/chat",
|
||||||
|
"/sora/v1/models",
|
||||||
"/antigravity/test",
|
"/antigravity/test",
|
||||||
"/setup/init",
|
"/setup/init",
|
||||||
"/health",
|
"/health",
|
||||||
|
|||||||
44
backend/migrations/054_drop_legacy_cache_columns.sql
Normal file
44
backend/migrations/054_drop_legacy_cache_columns.sql
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
-- Drop legacy cache token columns that lack the underscore separator.
-- These were created by GORM's automatic snake_case conversion:
--   CacheCreation5mTokens → cache_creation5m_tokens (incorrect)
--   CacheCreation1hTokens → cache_creation1h_tokens (incorrect)
--
-- The canonical columns are:
--   cache_creation_5m_tokens (defined in 001_init.sql)
--   cache_creation_1h_tokens (defined in 001_init.sql)
--
-- Migration 009 already copied data from legacy → canonical columns.
-- But upgraded instances may still have post-009 writes in legacy columns.
-- Backfill once more before dropping to prevent data loss.

-- The information_schema guard makes the backfill a no-op on fresh installs
-- where the legacy columns never existed (a bare UPDATE would error out).
DO $$
BEGIN
    IF EXISTS (
        SELECT 1
        FROM information_schema.columns
        WHERE table_schema = 'public'
          AND table_name = 'usage_logs'
          AND column_name = 'cache_creation5m_tokens'
    ) THEN
        -- Only fill canonical cells that are still zero, so canonical data
        -- written after migration 009 is never overwritten.
        UPDATE usage_logs
        SET cache_creation_5m_tokens = cache_creation5m_tokens
        WHERE cache_creation_5m_tokens = 0
          AND cache_creation5m_tokens <> 0;
    END IF;

    IF EXISTS (
        SELECT 1
        FROM information_schema.columns
        WHERE table_schema = 'public'
          AND table_name = 'usage_logs'
          AND column_name = 'cache_creation1h_tokens'
    ) THEN
        UPDATE usage_logs
        SET cache_creation_1h_tokens = cache_creation1h_tokens
        WHERE cache_creation_1h_tokens = 0
          AND cache_creation1h_tokens <> 0;
    END IF;
END $$;

-- IF EXISTS keeps the drops idempotent and safe on instances where the
-- legacy columns were never created.
ALTER TABLE usage_logs DROP COLUMN IF EXISTS cache_creation5m_tokens;
ALTER TABLE usage_logs DROP COLUMN IF EXISTS cache_creation1h_tokens;
|
||||||
2
backend/migrations/055_add_cache_ttl_overridden.sql
Normal file
2
backend/migrations/055_add_cache_ttl_overridden.sql
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
-- Add cache_ttl_overridden flag to usage_logs for tracking cache TTL override per account.
-- IF NOT EXISTS keeps this migration idempotent if it is re-run.
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS cache_ttl_overridden BOOLEAN NOT NULL DEFAULT FALSE;
|
||||||
@@ -374,6 +374,9 @@ sora:
|
|||||||
# Max retries for upstream requests
|
# Max retries for upstream requests
|
||||||
# 上游请求最大重试次数
|
# 上游请求最大重试次数
|
||||||
max_retries: 3
|
max_retries: 3
|
||||||
|
# Account+proxy cooldown window after Cloudflare challenge (seconds, 0 to disable)
|
||||||
|
# Cloudflare challenge 后按账号+代理冷却窗口(秒,0 表示关闭)
|
||||||
|
cloudflare_challenge_cooldown_seconds: 900
|
||||||
# Poll interval (seconds)
|
# Poll interval (seconds)
|
||||||
# 轮询间隔(秒)
|
# 轮询间隔(秒)
|
||||||
poll_interval_seconds: 2
|
poll_interval_seconds: 2
|
||||||
@@ -388,7 +391,11 @@ sora:
|
|||||||
recent_task_limit_max: 200
|
recent_task_limit_max: 200
|
||||||
# Enable debug logs for Sora upstream requests
|
# Enable debug logs for Sora upstream requests
|
||||||
# 启用 Sora 直连调试日志
|
# 启用 Sora 直连调试日志
|
||||||
|
# 调试日志会输出上游请求尝试、重试、响应摘要;Authorization/openai-sentinel-token 等敏感头会自动脱敏
|
||||||
debug: false
|
debug: false
|
||||||
|
# Allow Sora client to fetch token via OpenAI token provider
|
||||||
|
# 是否允许 Sora 客户端通过 OpenAI token provider 取 token(默认 false,避免误走 OpenAI 刷新链路)
|
||||||
|
use_openai_token_provider: false
|
||||||
# Optional custom headers (key-value)
|
# Optional custom headers (key-value)
|
||||||
# 额外请求头(键值对)
|
# 额外请求头(键值对)
|
||||||
headers: {}
|
headers: {}
|
||||||
@@ -398,6 +405,27 @@ sora:
|
|||||||
# Disable TLS fingerprint for Sora upstream
|
# Disable TLS fingerprint for Sora upstream
|
||||||
# 关闭 Sora 上游 TLS 指纹伪装
|
# 关闭 Sora 上游 TLS 指纹伪装
|
||||||
disable_tls_fingerprint: false
|
disable_tls_fingerprint: false
|
||||||
|
# curl_cffi sidecar for Sora only (required)
|
||||||
|
# 仅 Sora 链路使用的 curl_cffi sidecar(必需)
|
||||||
|
curl_cffi_sidecar:
|
||||||
|
# Sora 强制通过 sidecar 请求,必须启用
|
||||||
|
# Sora is forced to use sidecar only; keep enabled=true
|
||||||
|
enabled: true
|
||||||
|
# Sidecar base URL (default endpoint: /request)
|
||||||
|
# sidecar 基础地址(默认请求端点:/request)
|
||||||
|
base_url: "http://sora-curl-cffi-sidecar:8080"
|
||||||
|
# curl_cffi impersonate profile, e.g. chrome131/chrome124/safari18_0
|
||||||
|
# curl_cffi 指纹伪装 profile,例如 chrome131/chrome124/safari18_0
|
||||||
|
impersonate: "chrome131"
|
||||||
|
# Sidecar request timeout (seconds)
|
||||||
|
# sidecar 请求超时(秒)
|
||||||
|
timeout_seconds: 60
|
||||||
|
# Reuse session key per account+proxy to let sidecar persist cookies/session
|
||||||
|
# 按账号+代理复用 session key,让 sidecar 持久化 cookies/session
|
||||||
|
session_reuse_enabled: true
|
||||||
|
# Session TTL in sidecar (seconds)
|
||||||
|
# sidecar 会话 TTL(秒)
|
||||||
|
session_ttl_seconds: 3600
|
||||||
storage:
|
storage:
|
||||||
# Storage type (local only for now)
|
# Storage type (local only for now)
|
||||||
# 存储类型(首发仅支持 local)
|
# 存储类型(首发仅支持 local)
|
||||||
@@ -431,6 +459,13 @@ sora:
|
|||||||
# Cron 调度表达式
|
# Cron 调度表达式
|
||||||
schedule: "0 3 * * *"
|
schedule: "0 3 * * *"
|
||||||
|
|
||||||
|
# Token refresh behavior
|
||||||
|
# token 刷新行为控制
|
||||||
|
token_refresh:
|
||||||
|
# Whether OpenAI refresh flow is allowed to sync linked Sora accounts
|
||||||
|
# 是否允许 OpenAI 刷新流程同步覆盖 linked_openai_account_id 关联的 Sora 账号 token
|
||||||
|
sync_linked_sora_accounts: false
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# API Key Auth Cache Configuration
|
# API Key Auth Cache Configuration
|
||||||
# API Key 认证缓存配置
|
# API Key 认证缓存配置
|
||||||
|
|||||||
@@ -173,6 +173,7 @@ services:
|
|||||||
- POSTGRES_USER=${POSTGRES_USER:-sub2api}
|
- POSTGRES_USER=${POSTGRES_USER:-sub2api}
|
||||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
||||||
- POSTGRES_DB=${POSTGRES_DB:-sub2api}
|
- POSTGRES_DB=${POSTGRES_DB:-sub2api}
|
||||||
|
- PGDATA=/var/lib/postgresql/data
|
||||||
- TZ=${TZ:-Asia/Shanghai}
|
- TZ=${TZ:-Asia/Shanghai}
|
||||||
networks:
|
networks:
|
||||||
- sub2api-network
|
- sub2api-network
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ export async function list(
|
|||||||
platform?: string
|
platform?: string
|
||||||
type?: string
|
type?: string
|
||||||
status?: string
|
status?: string
|
||||||
|
group?: string
|
||||||
search?: string
|
search?: string
|
||||||
},
|
},
|
||||||
options?: {
|
options?: {
|
||||||
@@ -271,7 +272,7 @@ export async function generateAuthUrl(
|
|||||||
*/
|
*/
|
||||||
export async function exchangeCode(
|
export async function exchangeCode(
|
||||||
endpoint: string,
|
endpoint: string,
|
||||||
exchangeData: { session_id: string; code: string; proxy_id?: number }
|
exchangeData: { session_id: string; code: string; state?: string; proxy_id?: number }
|
||||||
): Promise<Record<string, unknown>> {
|
): Promise<Record<string, unknown>> {
|
||||||
const { data } = await apiClient.post<Record<string, unknown>>(endpoint, exchangeData)
|
const { data } = await apiClient.post<Record<string, unknown>>(endpoint, exchangeData)
|
||||||
return data
|
return data
|
||||||
@@ -493,7 +494,8 @@ export async function getAntigravityDefaultModelMapping(): Promise<Record<string
|
|||||||
*/
|
*/
|
||||||
export async function refreshOpenAIToken(
|
export async function refreshOpenAIToken(
|
||||||
refreshToken: string,
|
refreshToken: string,
|
||||||
proxyId?: number | null
|
proxyId?: number | null,
|
||||||
|
endpoint: string = '/admin/openai/refresh-token'
|
||||||
): Promise<Record<string, unknown>> {
|
): Promise<Record<string, unknown>> {
|
||||||
const payload: { refresh_token: string; proxy_id?: number } = {
|
const payload: { refresh_token: string; proxy_id?: number } = {
|
||||||
refresh_token: refreshToken
|
refresh_token: refreshToken
|
||||||
@@ -501,7 +503,29 @@ export async function refreshOpenAIToken(
|
|||||||
if (proxyId) {
|
if (proxyId) {
|
||||||
payload.proxy_id = proxyId
|
payload.proxy_id = proxyId
|
||||||
}
|
}
|
||||||
const { data } = await apiClient.post<Record<string, unknown>>('/admin/openai/refresh-token', payload)
|
const { data } = await apiClient.post<Record<string, unknown>>(endpoint, payload)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate Sora session token and exchange to access token
|
||||||
|
* @param sessionToken - Sora session token
|
||||||
|
* @param proxyId - Optional proxy ID
|
||||||
|
* @param endpoint - API endpoint path
|
||||||
|
* @returns Token information including access_token
|
||||||
|
*/
|
||||||
|
export async function validateSoraSessionToken(
|
||||||
|
sessionToken: string,
|
||||||
|
proxyId?: number | null,
|
||||||
|
endpoint: string = '/admin/sora/st2at'
|
||||||
|
): Promise<Record<string, unknown>> {
|
||||||
|
const payload: { session_token: string; proxy_id?: number } = {
|
||||||
|
session_token: sessionToken
|
||||||
|
}
|
||||||
|
if (proxyId) {
|
||||||
|
payload.proxy_id = proxyId
|
||||||
|
}
|
||||||
|
const { data } = await apiClient.post<Record<string, unknown>>(endpoint, payload)
|
||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -527,6 +551,7 @@ export const accountsAPI = {
|
|||||||
generateAuthUrl,
|
generateAuthUrl,
|
||||||
exchangeCode,
|
exchangeCode,
|
||||||
refreshOpenAIToken,
|
refreshOpenAIToken,
|
||||||
|
validateSoraSessionToken,
|
||||||
batchCreate,
|
batchCreate,
|
||||||
batchUpdateCredentials,
|
batchUpdateCredentials,
|
||||||
bulkUpdate,
|
bulkUpdate,
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import { apiClient } from '../client'
|
|||||||
import type {
|
import type {
|
||||||
Proxy,
|
Proxy,
|
||||||
ProxyAccountSummary,
|
ProxyAccountSummary,
|
||||||
|
ProxyQualityCheckResult,
|
||||||
CreateProxyRequest,
|
CreateProxyRequest,
|
||||||
UpdateProxyRequest,
|
UpdateProxyRequest,
|
||||||
PaginatedResponse,
|
PaginatedResponse,
|
||||||
@@ -143,6 +144,16 @@ export async function testProxy(id: number): Promise<{
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check proxy quality across common AI targets
|
||||||
|
* @param id - Proxy ID
|
||||||
|
* @returns Quality check result
|
||||||
|
*/
|
||||||
|
export async function checkProxyQuality(id: number): Promise<ProxyQualityCheckResult> {
|
||||||
|
const { data } = await apiClient.post<ProxyQualityCheckResult>(`/admin/proxies/${id}/quality-check`)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get proxy usage statistics
|
* Get proxy usage statistics
|
||||||
* @param id - Proxy ID
|
* @param id - Proxy ID
|
||||||
@@ -248,6 +259,7 @@ export const proxiesAPI = {
|
|||||||
delete: deleteProxy,
|
delete: deleteProxy,
|
||||||
toggleStatus,
|
toggleStatus,
|
||||||
testProxy,
|
testProxy,
|
||||||
|
checkProxyQuality,
|
||||||
getStats,
|
getStats,
|
||||||
getProxyAccounts,
|
getProxyAccounts,
|
||||||
batchCreate,
|
batchCreate,
|
||||||
|
|||||||
@@ -41,7 +41,7 @@
|
|||||||
>
|
>
|
||||||
<div class="mb-2 flex items-center justify-between">
|
<div class="mb-2 flex items-center justify-between">
|
||||||
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">
|
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">
|
||||||
{{ t('admin.accounts.allGroups', { count: groups.length }) }}
|
{{ t('admin.accounts.groupCountTotal', { count: groups.length }) }}
|
||||||
</span>
|
</span>
|
||||||
<button
|
<button
|
||||||
@click="showPopover = false"
|
@click="showPopover = false"
|
||||||
|
|||||||
@@ -41,7 +41,7 @@
|
|||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="space-y-1.5">
|
<div v-if="!isSoraAccount" class="space-y-1.5">
|
||||||
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">
|
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||||
{{ t('admin.accounts.selectTestModel') }}
|
{{ t('admin.accounts.selectTestModel') }}
|
||||||
</label>
|
</label>
|
||||||
@@ -54,6 +54,12 @@
|
|||||||
:placeholder="loadingModels ? t('common.loading') + '...' : t('admin.accounts.selectTestModel')"
|
:placeholder="loadingModels ? t('common.loading') + '...' : t('admin.accounts.selectTestModel')"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
<div
|
||||||
|
v-else
|
||||||
|
class="rounded-lg border border-blue-200 bg-blue-50 px-3 py-2 text-xs text-blue-700 dark:border-blue-700 dark:bg-blue-900/20 dark:text-blue-300"
|
||||||
|
>
|
||||||
|
{{ t('admin.accounts.soraTestHint') }}
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- Terminal Output -->
|
<!-- Terminal Output -->
|
||||||
<div class="group relative">
|
<div class="group relative">
|
||||||
@@ -135,12 +141,12 @@
|
|||||||
<div class="flex items-center gap-3">
|
<div class="flex items-center gap-3">
|
||||||
<span class="flex items-center gap-1">
|
<span class="flex items-center gap-1">
|
||||||
<Icon name="cpu" size="sm" :stroke-width="2" />
|
<Icon name="cpu" size="sm" :stroke-width="2" />
|
||||||
{{ t('admin.accounts.testModel') }}
|
{{ isSoraAccount ? t('admin.accounts.soraTestTarget') : t('admin.accounts.testModel') }}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
<span class="flex items-center gap-1">
|
<span class="flex items-center gap-1">
|
||||||
<Icon name="chatBubble" size="sm" :stroke-width="2" />
|
<Icon name="chatBubble" size="sm" :stroke-width="2" />
|
||||||
{{ t('admin.accounts.testPrompt') }}
|
{{ isSoraAccount ? t('admin.accounts.soraTestMode') : t('admin.accounts.testPrompt') }}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -156,10 +162,10 @@
|
|||||||
</button>
|
</button>
|
||||||
<button
|
<button
|
||||||
@click="startTest"
|
@click="startTest"
|
||||||
:disabled="status === 'connecting' || !selectedModelId"
|
:disabled="status === 'connecting' || (!isSoraAccount && !selectedModelId)"
|
||||||
:class="[
|
:class="[
|
||||||
'flex items-center gap-2 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
'flex items-center gap-2 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
||||||
status === 'connecting' || !selectedModelId
|
status === 'connecting' || (!isSoraAccount && !selectedModelId)
|
||||||
? 'cursor-not-allowed bg-primary-400 text-white'
|
? 'cursor-not-allowed bg-primary-400 text-white'
|
||||||
: status === 'success'
|
: status === 'success'
|
||||||
? 'bg-green-500 text-white hover:bg-green-600'
|
? 'bg-green-500 text-white hover:bg-green-600'
|
||||||
@@ -232,7 +238,7 @@
|
|||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, watch, nextTick } from 'vue'
|
import { computed, ref, watch, nextTick } from 'vue'
|
||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import BaseDialog from '@/components/common/BaseDialog.vue'
|
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||||
import Select from '@/components/common/Select.vue'
|
import Select from '@/components/common/Select.vue'
|
||||||
@@ -267,6 +273,7 @@ const availableModels = ref<ClaudeModel[]>([])
|
|||||||
const selectedModelId = ref('')
|
const selectedModelId = ref('')
|
||||||
const loadingModels = ref(false)
|
const loadingModels = ref(false)
|
||||||
let eventSource: EventSource | null = null
|
let eventSource: EventSource | null = null
|
||||||
|
const isSoraAccount = computed(() => props.account?.platform === 'sora')
|
||||||
|
|
||||||
// Load available models when modal opens
|
// Load available models when modal opens
|
||||||
watch(
|
watch(
|
||||||
@@ -283,6 +290,12 @@ watch(
|
|||||||
|
|
||||||
const loadAvailableModels = async () => {
|
const loadAvailableModels = async () => {
|
||||||
if (!props.account) return
|
if (!props.account) return
|
||||||
|
if (props.account.platform === 'sora') {
|
||||||
|
availableModels.value = []
|
||||||
|
selectedModelId.value = ''
|
||||||
|
loadingModels.value = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
loadingModels.value = true
|
loadingModels.value = true
|
||||||
selectedModelId.value = '' // Reset selection before loading
|
selectedModelId.value = '' // Reset selection before loading
|
||||||
@@ -350,7 +363,7 @@ const scrollToBottom = async () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const startTest = async () => {
|
const startTest = async () => {
|
||||||
if (!props.account || !selectedModelId.value) return
|
if (!props.account || (!isSoraAccount.value && !selectedModelId.value)) return
|
||||||
|
|
||||||
resetState()
|
resetState()
|
||||||
status.value = 'connecting'
|
status.value = 'connecting'
|
||||||
@@ -371,7 +384,9 @@ const startTest = async () => {
|
|||||||
Authorization: `Bearer ${localStorage.getItem('auth_token')}`,
|
Authorization: `Bearer ${localStorage.getItem('auth_token')}`,
|
||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
},
|
},
|
||||||
body: JSON.stringify({ model_id: selectedModelId.value })
|
body: JSON.stringify(
|
||||||
|
isSoraAccount.value ? {} : { model_id: selectedModelId.value }
|
||||||
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
@@ -428,7 +443,10 @@ const handleEvent = (event: {
|
|||||||
if (event.model) {
|
if (event.model) {
|
||||||
addLine(t('admin.accounts.usingModel', { model: event.model }), 'text-cyan-400')
|
addLine(t('admin.accounts.usingModel', { model: event.model }), 'text-cyan-400')
|
||||||
}
|
}
|
||||||
addLine(t('admin.accounts.sendingTestMessage'), 'text-gray-400')
|
addLine(
|
||||||
|
isSoraAccount.value ? t('admin.accounts.soraTestingFlow') : t('admin.accounts.sendingTestMessage'),
|
||||||
|
'text-gray-400'
|
||||||
|
)
|
||||||
addLine('', 'text-gray-300')
|
addLine('', 'text-gray-300')
|
||||||
addLine(t('admin.accounts.response'), 'text-yellow-400')
|
addLine(t('admin.accounts.response'), 'text-yellow-400')
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -710,6 +710,7 @@ const groupIds = ref<number[]>([])
|
|||||||
// All models list (combined Anthropic + OpenAI)
|
// All models list (combined Anthropic + OpenAI)
|
||||||
const allModels = [
|
const allModels = [
|
||||||
{ value: 'claude-opus-4-6', label: 'Claude Opus 4.6' },
|
{ value: 'claude-opus-4-6', label: 'Claude Opus 4.6' },
|
||||||
|
{ value: 'claude-sonnet-4-6', label: 'Claude Sonnet 4.6' },
|
||||||
{ value: 'claude-opus-4-5-20251101', label: 'Claude Opus 4.5' },
|
{ value: 'claude-opus-4-5-20251101', label: 'Claude Opus 4.5' },
|
||||||
{ value: 'claude-sonnet-4-20250514', label: 'Claude Sonnet 4' },
|
{ value: 'claude-sonnet-4-20250514', label: 'Claude Sonnet 4' },
|
||||||
{ value: 'claude-sonnet-4-5-20250929', label: 'Claude Sonnet 4.5' },
|
{ value: 'claude-sonnet-4-5-20250929', label: 'Claude Sonnet 4.5' },
|
||||||
@@ -757,6 +758,13 @@ const presetMappings = [
|
|||||||
color:
|
color:
|
||||||
'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400'
|
'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400'
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
label: 'Sonnet 4.6',
|
||||||
|
from: 'claude-sonnet-4-6',
|
||||||
|
to: 'claude-sonnet-4-6',
|
||||||
|
color:
|
||||||
|
'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400'
|
||||||
|
},
|
||||||
{
|
{
|
||||||
label: 'Opus->Sonnet',
|
label: 'Opus->Sonnet',
|
||||||
from: 'claude-opus-4-5-20251101',
|
from: 'claude-opus-4-5-20251101',
|
||||||
|
|||||||
@@ -109,6 +109,28 @@
|
|||||||
</svg>
|
</svg>
|
||||||
OpenAI
|
OpenAI
|
||||||
</button>
|
</button>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
@click="form.platform = 'sora'"
|
||||||
|
:class="[
|
||||||
|
'flex flex-1 items-center justify-center gap-2 rounded-md px-4 py-2.5 text-sm font-medium transition-all',
|
||||||
|
form.platform === 'sora'
|
||||||
|
? 'bg-white text-rose-600 shadow-sm dark:bg-dark-600 dark:text-rose-400'
|
||||||
|
: 'text-gray-600 hover:text-gray-900 dark:text-gray-400 dark:hover:text-gray-200'
|
||||||
|
]"
|
||||||
|
>
|
||||||
|
<svg
|
||||||
|
class="h-4 w-4"
|
||||||
|
fill="none"
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
stroke="currentColor"
|
||||||
|
stroke-width="2"
|
||||||
|
>
|
||||||
|
<path stroke-linecap="round" stroke-linejoin="round" d="M14.752 11.168l-3.197-2.132A1 1 0 0010 9.87v4.263a1 1 0 001.555.832l3.197-2.132a1 1 0 000-1.664z" />
|
||||||
|
<path stroke-linecap="round" stroke-linejoin="round" d="M21 12a9 9 0 11-18 0 9 9 0 0118 0z" />
|
||||||
|
</svg>
|
||||||
|
Sora
|
||||||
|
</button>
|
||||||
<button
|
<button
|
||||||
type="button"
|
type="button"
|
||||||
@click="form.platform = 'gemini'"
|
@click="form.platform = 'gemini'"
|
||||||
@@ -150,6 +172,38 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- Account Type Selection (Sora) -->
|
||||||
|
<div v-if="form.platform === 'sora'">
|
||||||
|
<label class="input-label">{{ t('admin.accounts.accountType') }}</label>
|
||||||
|
<div class="mt-2 grid grid-cols-1 gap-3" data-tour="account-form-type">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
@click="accountCategory = 'oauth-based'"
|
||||||
|
:class="[
|
||||||
|
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
|
||||||
|
accountCategory === 'oauth-based'
|
||||||
|
? 'border-rose-500 bg-rose-50 dark:bg-rose-900/20'
|
||||||
|
: 'border-gray-200 hover:border-rose-300 dark:border-dark-600 dark:hover:border-rose-700'
|
||||||
|
]"
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
:class="[
|
||||||
|
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
|
||||||
|
accountCategory === 'oauth-based'
|
||||||
|
? 'bg-rose-500 text-white'
|
||||||
|
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
|
||||||
|
]"
|
||||||
|
>
|
||||||
|
<Icon name="key" size="sm" />
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<span class="block text-sm font-medium text-gray-900 dark:text-white">OAuth</span>
|
||||||
|
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.types.chatgptOauth') }}</span>
|
||||||
|
</div>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- Account Type Selection (Anthropic) -->
|
<!-- Account Type Selection (Anthropic) -->
|
||||||
<div v-if="form.platform === 'anthropic'">
|
<div v-if="form.platform === 'anthropic'">
|
||||||
<label class="input-label">{{ t('admin.accounts.accountType') }}</label>
|
<label class="input-label">{{ t('admin.accounts.accountType') }}</label>
|
||||||
@@ -1538,6 +1592,46 @@
|
|||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- Cache TTL Override -->
|
||||||
|
<div class="rounded-lg border border-gray-200 p-4 dark:border-dark-600">
|
||||||
|
<div class="flex items-center justify-between">
|
||||||
|
<div>
|
||||||
|
<label class="input-label mb-0">{{ t('admin.accounts.quotaControl.cacheTTLOverride.label') }}</label>
|
||||||
|
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t('admin.accounts.quotaControl.cacheTTLOverride.hint') }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
@click="cacheTTLOverrideEnabled = !cacheTTLOverrideEnabled"
|
||||||
|
:class="[
|
||||||
|
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
|
||||||
|
cacheTTLOverrideEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
|
||||||
|
]"
|
||||||
|
>
|
||||||
|
<span
|
||||||
|
:class="[
|
||||||
|
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
||||||
|
cacheTTLOverrideEnabled ? 'translate-x-5' : 'translate-x-0'
|
||||||
|
]"
|
||||||
|
/>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<div v-if="cacheTTLOverrideEnabled" class="mt-3">
|
||||||
|
<label class="input-label text-xs">{{ t('admin.accounts.quotaControl.cacheTTLOverride.target') }}</label>
|
||||||
|
<select
|
||||||
|
v-model="cacheTTLOverrideTarget"
|
||||||
|
class="mt-1 block w-full rounded-md border border-gray-300 bg-white px-3 py-2 text-sm shadow-sm focus:border-primary-500 focus:outline-none focus:ring-1 focus:ring-primary-500 dark:border-dark-500 dark:bg-dark-700 dark:text-white"
|
||||||
|
>
|
||||||
|
<option value="5m">5m</option>
|
||||||
|
<option value="1h">1h</option>
|
||||||
|
</select>
|
||||||
|
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t('admin.accounts.quotaControl.cacheTTLOverride.targetHint') }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div>
|
<div>
|
||||||
@@ -1707,32 +1801,6 @@
|
|||||||
|
|
||||||
<!-- Step 2: OAuth Authorization -->
|
<!-- Step 2: OAuth Authorization -->
|
||||||
<div v-else class="space-y-5">
|
<div v-else class="space-y-5">
|
||||||
<!-- 同时启用 Sora 开关 (仅 OpenAI OAuth) -->
|
|
||||||
<div v-if="form.platform === 'openai' && accountCategory === 'oauth-based'" class="mb-4">
|
|
||||||
<label class="flex items-center justify-between rounded-lg border border-gray-200 p-3 dark:border-dark-600">
|
|
||||||
<div class="flex items-center gap-3">
|
|
||||||
<div class="flex h-8 w-8 shrink-0 items-center justify-center rounded-lg bg-rose-100 text-rose-600 dark:bg-rose-900/30 dark:text-rose-400">
|
|
||||||
<svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
|
||||||
<path stroke-linecap="round" stroke-linejoin="round" d="M14.752 11.168l-3.197-2.132A1 1 0 0010 9.87v4.263a1 1 0 001.555.832l3.197-2.132a1 1 0 000-1.664z" />
|
|
||||||
<path stroke-linecap="round" stroke-linejoin="round" d="M21 12a9 9 0 11-18 0 9 9 0 0118 0z" />
|
|
||||||
</svg>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">
|
|
||||||
{{ t('admin.accounts.openai.enableSora') }}
|
|
||||||
</span>
|
|
||||||
<span class="text-xs text-gray-500 dark:text-gray-400">
|
|
||||||
{{ t('admin.accounts.openai.enableSoraHint') }}
|
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<label :class="['switch', { 'switch-active': enableSoraOnOpenAIOAuth }]">
|
|
||||||
<input type="checkbox" v-model="enableSoraOnOpenAIOAuth" class="sr-only" />
|
|
||||||
<span class="switch-thumb"></span>
|
|
||||||
</label>
|
|
||||||
</label>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<OAuthAuthorizationFlow
|
<OAuthAuthorizationFlow
|
||||||
ref="oauthFlowRef"
|
ref="oauthFlowRef"
|
||||||
:add-method="form.platform === 'anthropic' ? addMethod : 'oauth'"
|
:add-method="form.platform === 'anthropic' ? addMethod : 'oauth'"
|
||||||
@@ -1741,15 +1809,17 @@
|
|||||||
:loading="currentOAuthLoading"
|
:loading="currentOAuthLoading"
|
||||||
:error="currentOAuthError"
|
:error="currentOAuthError"
|
||||||
:show-help="form.platform === 'anthropic'"
|
:show-help="form.platform === 'anthropic'"
|
||||||
:show-proxy-warning="form.platform !== 'openai' && !!form.proxy_id"
|
:show-proxy-warning="form.platform !== 'openai' && form.platform !== 'sora' && !!form.proxy_id"
|
||||||
:allow-multiple="form.platform === 'anthropic'"
|
:allow-multiple="form.platform === 'anthropic'"
|
||||||
:show-cookie-option="form.platform === 'anthropic'"
|
:show-cookie-option="form.platform === 'anthropic'"
|
||||||
:show-refresh-token-option="form.platform === 'openai' || form.platform === 'antigravity'"
|
:show-refresh-token-option="form.platform === 'openai' || form.platform === 'sora' || form.platform === 'antigravity'"
|
||||||
|
:show-session-token-option="form.platform === 'sora'"
|
||||||
:platform="form.platform"
|
:platform="form.platform"
|
||||||
:show-project-id="geminiOAuthType === 'code_assist'"
|
:show-project-id="geminiOAuthType === 'code_assist'"
|
||||||
@generate-url="handleGenerateUrl"
|
@generate-url="handleGenerateUrl"
|
||||||
@cookie-auth="handleCookieAuth"
|
@cookie-auth="handleCookieAuth"
|
||||||
@validate-refresh-token="handleValidateRefreshToken"
|
@validate-refresh-token="handleValidateRefreshToken"
|
||||||
|
@validate-session-token="handleValidateSessionToken"
|
||||||
/>
|
/>
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
@@ -2108,6 +2178,7 @@ interface OAuthFlowExposed {
|
|||||||
projectId: string
|
projectId: string
|
||||||
sessionKey: string
|
sessionKey: string
|
||||||
refreshToken: string
|
refreshToken: string
|
||||||
|
sessionToken: string
|
||||||
inputMethod: AuthInputMethod
|
inputMethod: AuthInputMethod
|
||||||
reset: () => void
|
reset: () => void
|
||||||
}
|
}
|
||||||
@@ -2116,7 +2187,7 @@ const { t } = useI18n()
|
|||||||
const authStore = useAuthStore()
|
const authStore = useAuthStore()
|
||||||
|
|
||||||
const oauthStepTitle = computed(() => {
|
const oauthStepTitle = computed(() => {
|
||||||
if (form.platform === 'openai') return t('admin.accounts.oauth.openai.title')
|
if (form.platform === 'openai' || form.platform === 'sora') return t('admin.accounts.oauth.openai.title')
|
||||||
if (form.platform === 'gemini') return t('admin.accounts.oauth.gemini.title')
|
if (form.platform === 'gemini') return t('admin.accounts.oauth.gemini.title')
|
||||||
if (form.platform === 'antigravity') return t('admin.accounts.oauth.antigravity.title')
|
if (form.platform === 'antigravity') return t('admin.accounts.oauth.antigravity.title')
|
||||||
return t('admin.accounts.oauth.title')
|
return t('admin.accounts.oauth.title')
|
||||||
@@ -2124,13 +2195,13 @@ const oauthStepTitle = computed(() => {
|
|||||||
|
|
||||||
// Platform-specific hints for API Key type
|
// Platform-specific hints for API Key type
|
||||||
const baseUrlHint = computed(() => {
|
const baseUrlHint = computed(() => {
|
||||||
if (form.platform === 'openai') return t('admin.accounts.openai.baseUrlHint')
|
if (form.platform === 'openai' || form.platform === 'sora') return t('admin.accounts.openai.baseUrlHint')
|
||||||
if (form.platform === 'gemini') return t('admin.accounts.gemini.baseUrlHint')
|
if (form.platform === 'gemini') return t('admin.accounts.gemini.baseUrlHint')
|
||||||
return t('admin.accounts.baseUrlHint')
|
return t('admin.accounts.baseUrlHint')
|
||||||
})
|
})
|
||||||
|
|
||||||
const apiKeyHint = computed(() => {
|
const apiKeyHint = computed(() => {
|
||||||
if (form.platform === 'openai') return t('admin.accounts.openai.apiKeyHint')
|
if (form.platform === 'openai' || form.platform === 'sora') return t('admin.accounts.openai.apiKeyHint')
|
||||||
if (form.platform === 'gemini') return t('admin.accounts.gemini.apiKeyHint')
|
if (form.platform === 'gemini') return t('admin.accounts.gemini.apiKeyHint')
|
||||||
return t('admin.accounts.apiKeyHint')
|
return t('admin.accounts.apiKeyHint')
|
||||||
})
|
})
|
||||||
@@ -2151,34 +2222,36 @@ const appStore = useAppStore()
|
|||||||
|
|
||||||
// OAuth composables
|
// OAuth composables
|
||||||
const oauth = useAccountOAuth() // For Anthropic OAuth
|
const oauth = useAccountOAuth() // For Anthropic OAuth
|
||||||
const openaiOAuth = useOpenAIOAuth() // For OpenAI OAuth
|
const openaiOAuth = useOpenAIOAuth({ platform: 'openai' }) // For OpenAI OAuth
|
||||||
|
const soraOAuth = useOpenAIOAuth({ platform: 'sora' }) // For Sora OAuth
|
||||||
const geminiOAuth = useGeminiOAuth() // For Gemini OAuth
|
const geminiOAuth = useGeminiOAuth() // For Gemini OAuth
|
||||||
const antigravityOAuth = useAntigravityOAuth() // For Antigravity OAuth
|
const antigravityOAuth = useAntigravityOAuth() // For Antigravity OAuth
|
||||||
|
const activeOpenAIOAuth = computed(() => (form.platform === 'sora' ? soraOAuth : openaiOAuth))
|
||||||
|
|
||||||
// Computed: current OAuth state for template binding
|
// Computed: current OAuth state for template binding
|
||||||
const currentAuthUrl = computed(() => {
|
const currentAuthUrl = computed(() => {
|
||||||
if (form.platform === 'openai') return openaiOAuth.authUrl.value
|
if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.authUrl.value
|
||||||
if (form.platform === 'gemini') return geminiOAuth.authUrl.value
|
if (form.platform === 'gemini') return geminiOAuth.authUrl.value
|
||||||
if (form.platform === 'antigravity') return antigravityOAuth.authUrl.value
|
if (form.platform === 'antigravity') return antigravityOAuth.authUrl.value
|
||||||
return oauth.authUrl.value
|
return oauth.authUrl.value
|
||||||
})
|
})
|
||||||
|
|
||||||
const currentSessionId = computed(() => {
|
const currentSessionId = computed(() => {
|
||||||
if (form.platform === 'openai') return openaiOAuth.sessionId.value
|
if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.sessionId.value
|
||||||
if (form.platform === 'gemini') return geminiOAuth.sessionId.value
|
if (form.platform === 'gemini') return geminiOAuth.sessionId.value
|
||||||
if (form.platform === 'antigravity') return antigravityOAuth.sessionId.value
|
if (form.platform === 'antigravity') return antigravityOAuth.sessionId.value
|
||||||
return oauth.sessionId.value
|
return oauth.sessionId.value
|
||||||
})
|
})
|
||||||
|
|
||||||
const currentOAuthLoading = computed(() => {
|
const currentOAuthLoading = computed(() => {
|
||||||
if (form.platform === 'openai') return openaiOAuth.loading.value
|
if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.loading.value
|
||||||
if (form.platform === 'gemini') return geminiOAuth.loading.value
|
if (form.platform === 'gemini') return geminiOAuth.loading.value
|
||||||
if (form.platform === 'antigravity') return antigravityOAuth.loading.value
|
if (form.platform === 'antigravity') return antigravityOAuth.loading.value
|
||||||
return oauth.loading.value
|
return oauth.loading.value
|
||||||
})
|
})
|
||||||
|
|
||||||
const currentOAuthError = computed(() => {
|
const currentOAuthError = computed(() => {
|
||||||
if (form.platform === 'openai') return openaiOAuth.error.value
|
if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.error.value
|
||||||
if (form.platform === 'gemini') return geminiOAuth.error.value
|
if (form.platform === 'gemini') return geminiOAuth.error.value
|
||||||
if (form.platform === 'antigravity') return antigravityOAuth.error.value
|
if (form.platform === 'antigravity') return antigravityOAuth.error.value
|
||||||
return oauth.error.value
|
return oauth.error.value
|
||||||
@@ -2217,7 +2290,6 @@ const interceptWarmupRequests = ref(false)
|
|||||||
const autoPauseOnExpired = ref(true)
|
const autoPauseOnExpired = ref(true)
|
||||||
const openaiPassthroughEnabled = ref(false)
|
const openaiPassthroughEnabled = ref(false)
|
||||||
const codexCLIOnlyEnabled = ref(false)
|
const codexCLIOnlyEnabled = ref(false)
|
||||||
const enableSoraOnOpenAIOAuth = ref(false) // OpenAI OAuth 时同时启用 Sora
|
|
||||||
const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling
|
const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling
|
||||||
const antigravityAccountType = ref<'oauth' | 'upstream'>('oauth') // For antigravity: oauth or upstream
|
const antigravityAccountType = ref<'oauth' | 'upstream'>('oauth') // For antigravity: oauth or upstream
|
||||||
const upstreamBaseUrl = ref('') // For upstream type: base URL
|
const upstreamBaseUrl = ref('') // For upstream type: base URL
|
||||||
@@ -2250,6 +2322,8 @@ const maxSessions = ref<number | null>(null)
|
|||||||
const sessionIdleTimeout = ref<number | null>(null)
|
const sessionIdleTimeout = ref<number | null>(null)
|
||||||
const tlsFingerprintEnabled = ref(false)
|
const tlsFingerprintEnabled = ref(false)
|
||||||
const sessionIdMaskingEnabled = ref(false)
|
const sessionIdMaskingEnabled = ref(false)
|
||||||
|
const cacheTTLOverrideEnabled = ref(false)
|
||||||
|
const cacheTTLOverrideTarget = ref<string>('5m')
|
||||||
|
|
||||||
// Gemini tier selection (used as fallback when auto-detection is unavailable/fails)
|
// Gemini tier selection (used as fallback when auto-detection is unavailable/fails)
|
||||||
const geminiTierGoogleOne = ref<'google_one_free' | 'google_ai_pro' | 'google_ai_ultra'>('google_one_free')
|
const geminiTierGoogleOne = ref<'google_one_free' | 'google_ai_pro' | 'google_ai_ultra'>('google_one_free')
|
||||||
@@ -2356,8 +2430,8 @@ const expiresAtInput = computed({
|
|||||||
|
|
||||||
const canExchangeCode = computed(() => {
|
const canExchangeCode = computed(() => {
|
||||||
const authCode = oauthFlowRef.value?.authCode || ''
|
const authCode = oauthFlowRef.value?.authCode || ''
|
||||||
if (form.platform === 'openai') {
|
if (form.platform === 'openai' || form.platform === 'sora') {
|
||||||
return authCode.trim() && openaiOAuth.sessionId.value && !openaiOAuth.loading.value
|
return authCode.trim() && activeOpenAIOAuth.value.sessionId.value && !activeOpenAIOAuth.value.loading.value
|
||||||
}
|
}
|
||||||
if (form.platform === 'gemini') {
|
if (form.platform === 'gemini') {
|
||||||
return authCode.trim() && geminiOAuth.sessionId.value && !geminiOAuth.loading.value
|
return authCode.trim() && geminiOAuth.sessionId.value && !geminiOAuth.loading.value
|
||||||
@@ -2417,7 +2491,7 @@ watch(
|
|||||||
(newPlatform) => {
|
(newPlatform) => {
|
||||||
// Reset base URL based on platform
|
// Reset base URL based on platform
|
||||||
apiKeyBaseUrl.value =
|
apiKeyBaseUrl.value =
|
||||||
newPlatform === 'openai'
|
(newPlatform === 'openai' || newPlatform === 'sora')
|
||||||
? 'https://api.openai.com'
|
? 'https://api.openai.com'
|
||||||
: newPlatform === 'gemini'
|
: newPlatform === 'gemini'
|
||||||
? 'https://generativelanguage.googleapis.com'
|
? 'https://generativelanguage.googleapis.com'
|
||||||
@@ -2443,6 +2517,11 @@ watch(
|
|||||||
if (newPlatform !== 'anthropic') {
|
if (newPlatform !== 'anthropic') {
|
||||||
interceptWarmupRequests.value = false
|
interceptWarmupRequests.value = false
|
||||||
}
|
}
|
||||||
|
if (newPlatform === 'sora') {
|
||||||
|
accountCategory.value = 'oauth-based'
|
||||||
|
addMethod.value = 'oauth'
|
||||||
|
form.type = 'oauth'
|
||||||
|
}
|
||||||
if (newPlatform !== 'openai') {
|
if (newPlatform !== 'openai') {
|
||||||
openaiPassthroughEnabled.value = false
|
openaiPassthroughEnabled.value = false
|
||||||
codexCLIOnlyEnabled.value = false
|
codexCLIOnlyEnabled.value = false
|
||||||
@@ -2450,6 +2529,7 @@ watch(
|
|||||||
// Reset OAuth states
|
// Reset OAuth states
|
||||||
oauth.resetState()
|
oauth.resetState()
|
||||||
openaiOAuth.resetState()
|
openaiOAuth.resetState()
|
||||||
|
soraOAuth.resetState()
|
||||||
geminiOAuth.resetState()
|
geminiOAuth.resetState()
|
||||||
antigravityOAuth.resetState()
|
antigravityOAuth.resetState()
|
||||||
}
|
}
|
||||||
@@ -2711,7 +2791,6 @@ const resetForm = () => {
|
|||||||
autoPauseOnExpired.value = true
|
autoPauseOnExpired.value = true
|
||||||
openaiPassthroughEnabled.value = false
|
openaiPassthroughEnabled.value = false
|
||||||
codexCLIOnlyEnabled.value = false
|
codexCLIOnlyEnabled.value = false
|
||||||
enableSoraOnOpenAIOAuth.value = false
|
|
||||||
// Reset quota control state
|
// Reset quota control state
|
||||||
windowCostEnabled.value = false
|
windowCostEnabled.value = false
|
||||||
windowCostLimit.value = null
|
windowCostLimit.value = null
|
||||||
@@ -2721,6 +2800,8 @@ const resetForm = () => {
|
|||||||
sessionIdleTimeout.value = null
|
sessionIdleTimeout.value = null
|
||||||
tlsFingerprintEnabled.value = false
|
tlsFingerprintEnabled.value = false
|
||||||
sessionIdMaskingEnabled.value = false
|
sessionIdMaskingEnabled.value = false
|
||||||
|
cacheTTLOverrideEnabled.value = false
|
||||||
|
cacheTTLOverrideTarget.value = '5m'
|
||||||
antigravityAccountType.value = 'oauth'
|
antigravityAccountType.value = 'oauth'
|
||||||
upstreamBaseUrl.value = ''
|
upstreamBaseUrl.value = ''
|
||||||
upstreamApiKey.value = ''
|
upstreamApiKey.value = ''
|
||||||
@@ -2732,6 +2813,7 @@ const resetForm = () => {
|
|||||||
geminiTierAIStudio.value = 'aistudio_free'
|
geminiTierAIStudio.value = 'aistudio_free'
|
||||||
oauth.resetState()
|
oauth.resetState()
|
||||||
openaiOAuth.resetState()
|
openaiOAuth.resetState()
|
||||||
|
soraOAuth.resetState()
|
||||||
geminiOAuth.resetState()
|
geminiOAuth.resetState()
|
||||||
antigravityOAuth.resetState()
|
antigravityOAuth.resetState()
|
||||||
oauthFlowRef.value?.reset()
|
oauthFlowRef.value?.reset()
|
||||||
@@ -2763,6 +2845,23 @@ const buildOpenAIExtra = (base?: Record<string, unknown>): Record<string, unknow
|
|||||||
return Object.keys(extra).length > 0 ? extra : undefined
|
return Object.keys(extra).length > 0 ? extra : undefined
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const buildSoraExtra = (
|
||||||
|
base?: Record<string, unknown>,
|
||||||
|
linkedOpenAIAccountId?: string | number
|
||||||
|
): Record<string, unknown> | undefined => {
|
||||||
|
const extra: Record<string, unknown> = { ...(base || {}) }
|
||||||
|
if (linkedOpenAIAccountId !== undefined && linkedOpenAIAccountId !== null) {
|
||||||
|
const id = String(linkedOpenAIAccountId).trim()
|
||||||
|
if (id) {
|
||||||
|
extra.linked_openai_account_id = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete extra.openai_passthrough
|
||||||
|
delete extra.openai_oauth_passthrough
|
||||||
|
delete extra.codex_cli_only
|
||||||
|
return Object.keys(extra).length > 0 ? extra : undefined
|
||||||
|
}
|
||||||
|
|
||||||
// Helper function to create account with mixed channel warning handling
|
// Helper function to create account with mixed channel warning handling
|
||||||
const doCreateAccount = async (payload: any) => {
|
const doCreateAccount = async (payload: any) => {
|
||||||
submitting.value = true
|
submitting.value = true
|
||||||
@@ -2878,7 +2977,7 @@ const handleSubmit = async () => {
|
|||||||
|
|
||||||
// Determine default base URL based on platform
|
// Determine default base URL based on platform
|
||||||
const defaultBaseUrl =
|
const defaultBaseUrl =
|
||||||
form.platform === 'openai'
|
(form.platform === 'openai' || form.platform === 'sora')
|
||||||
? 'https://api.openai.com'
|
? 'https://api.openai.com'
|
||||||
: form.platform === 'gemini'
|
: form.platform === 'gemini'
|
||||||
? 'https://generativelanguage.googleapis.com'
|
? 'https://generativelanguage.googleapis.com'
|
||||||
@@ -2930,14 +3029,15 @@ const goBackToBasicInfo = () => {
|
|||||||
step.value = 1
|
step.value = 1
|
||||||
oauth.resetState()
|
oauth.resetState()
|
||||||
openaiOAuth.resetState()
|
openaiOAuth.resetState()
|
||||||
|
soraOAuth.resetState()
|
||||||
geminiOAuth.resetState()
|
geminiOAuth.resetState()
|
||||||
antigravityOAuth.resetState()
|
antigravityOAuth.resetState()
|
||||||
oauthFlowRef.value?.reset()
|
oauthFlowRef.value?.reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
const handleGenerateUrl = async () => {
|
const handleGenerateUrl = async () => {
|
||||||
if (form.platform === 'openai') {
|
if (form.platform === 'openai' || form.platform === 'sora') {
|
||||||
await openaiOAuth.generateAuthUrl(form.proxy_id)
|
await activeOpenAIOAuth.value.generateAuthUrl(form.proxy_id)
|
||||||
} else if (form.platform === 'gemini') {
|
} else if (form.platform === 'gemini') {
|
||||||
await geminiOAuth.generateAuthUrl(
|
await geminiOAuth.generateAuthUrl(
|
||||||
form.proxy_id,
|
form.proxy_id,
|
||||||
@@ -2953,13 +3053,19 @@ const handleGenerateUrl = async () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const handleValidateRefreshToken = (rt: string) => {
|
const handleValidateRefreshToken = (rt: string) => {
|
||||||
if (form.platform === 'openai') {
|
if (form.platform === 'openai' || form.platform === 'sora') {
|
||||||
handleOpenAIValidateRT(rt)
|
handleOpenAIValidateRT(rt)
|
||||||
} else if (form.platform === 'antigravity') {
|
} else if (form.platform === 'antigravity') {
|
||||||
handleAntigravityValidateRT(rt)
|
handleAntigravityValidateRT(rt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const handleValidateSessionToken = (sessionToken: string) => {
|
||||||
|
if (form.platform === 'sora') {
|
||||||
|
handleSoraValidateST(sessionToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const formatDateTimeLocal = formatDateTimeLocalInput
|
const formatDateTimeLocal = formatDateTimeLocalInput
|
||||||
const parseDateTimeLocal = parseDateTimeLocalInput
|
const parseDateTimeLocal = parseDateTimeLocalInput
|
||||||
|
|
||||||
@@ -2995,100 +3101,101 @@ const createAccountAndFinish = async (
|
|||||||
|
|
||||||
// OpenAI OAuth 授权码兑换
|
// OpenAI OAuth 授权码兑换
|
||||||
const handleOpenAIExchange = async (authCode: string) => {
|
const handleOpenAIExchange = async (authCode: string) => {
|
||||||
if (!authCode.trim() || !openaiOAuth.sessionId.value) return
|
const oauthClient = activeOpenAIOAuth.value
|
||||||
|
if (!authCode.trim() || !oauthClient.sessionId.value) return
|
||||||
|
|
||||||
openaiOAuth.loading.value = true
|
oauthClient.loading.value = true
|
||||||
openaiOAuth.error.value = ''
|
oauthClient.error.value = ''
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const tokenInfo = await openaiOAuth.exchangeAuthCode(
|
const stateToUse = (oauthFlowRef.value?.oauthState || oauthClient.oauthState.value || '').trim()
|
||||||
|
if (!stateToUse) {
|
||||||
|
oauthClient.error.value = t('admin.accounts.oauth.authFailed')
|
||||||
|
appStore.showError(oauthClient.error.value)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const tokenInfo = await oauthClient.exchangeAuthCode(
|
||||||
authCode.trim(),
|
authCode.trim(),
|
||||||
openaiOAuth.sessionId.value,
|
oauthClient.sessionId.value,
|
||||||
|
stateToUse,
|
||||||
form.proxy_id
|
form.proxy_id
|
||||||
)
|
)
|
||||||
if (!tokenInfo) return
|
if (!tokenInfo) return
|
||||||
|
|
||||||
const credentials = openaiOAuth.buildCredentials(tokenInfo)
|
const credentials = oauthClient.buildCredentials(tokenInfo)
|
||||||
const oauthExtra = openaiOAuth.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined
|
const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined
|
||||||
const extra = buildOpenAIExtra(oauthExtra)
|
const extra = buildOpenAIExtra(oauthExtra)
|
||||||
|
const shouldCreateOpenAI = form.platform === 'openai'
|
||||||
|
const shouldCreateSora = form.platform === 'sora'
|
||||||
|
|
||||||
// 应用临时不可调度配置
|
// 应用临时不可调度配置
|
||||||
if (!applyTempUnschedConfig(credentials)) {
|
if (!applyTempUnschedConfig(credentials)) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1. 创建 OpenAI 账号
|
let openaiAccountId: string | number | undefined
|
||||||
const openaiAccount = await adminAPI.accounts.create({
|
|
||||||
name: form.name,
|
|
||||||
notes: form.notes,
|
|
||||||
platform: 'openai',
|
|
||||||
type: 'oauth',
|
|
||||||
credentials,
|
|
||||||
extra,
|
|
||||||
proxy_id: form.proxy_id,
|
|
||||||
concurrency: form.concurrency,
|
|
||||||
priority: form.priority,
|
|
||||||
rate_multiplier: form.rate_multiplier,
|
|
||||||
group_ids: form.group_ids,
|
|
||||||
expires_at: form.expires_at,
|
|
||||||
auto_pause_on_expired: autoPauseOnExpired.value
|
|
||||||
})
|
|
||||||
|
|
||||||
appStore.showSuccess(t('admin.accounts.accountCreated'))
|
if (shouldCreateOpenAI) {
|
||||||
|
const openaiAccount = await adminAPI.accounts.create({
|
||||||
|
name: form.name,
|
||||||
|
notes: form.notes,
|
||||||
|
platform: 'openai',
|
||||||
|
type: 'oauth',
|
||||||
|
credentials,
|
||||||
|
extra,
|
||||||
|
proxy_id: form.proxy_id,
|
||||||
|
concurrency: form.concurrency,
|
||||||
|
priority: form.priority,
|
||||||
|
rate_multiplier: form.rate_multiplier,
|
||||||
|
group_ids: form.group_ids,
|
||||||
|
expires_at: form.expires_at,
|
||||||
|
auto_pause_on_expired: autoPauseOnExpired.value
|
||||||
|
})
|
||||||
|
openaiAccountId = openaiAccount.id
|
||||||
|
appStore.showSuccess(t('admin.accounts.accountCreated'))
|
||||||
|
}
|
||||||
|
|
||||||
// 2. 如果启用了 Sora,同时创建 Sora 账号
|
if (shouldCreateSora) {
|
||||||
if (enableSoraOnOpenAIOAuth.value) {
|
const soraCredentials = {
|
||||||
try {
|
access_token: credentials.access_token,
|
||||||
// Sora 使用相同的 OAuth credentials
|
refresh_token: credentials.refresh_token,
|
||||||
const soraCredentials = {
|
expires_at: credentials.expires_at
|
||||||
access_token: credentials.access_token,
|
|
||||||
refresh_token: credentials.refresh_token,
|
|
||||||
expires_at: credentials.expires_at
|
|
||||||
}
|
|
||||||
|
|
||||||
// 建立关联关系
|
|
||||||
const soraExtra: Record<string, unknown> = {
|
|
||||||
...(extra || {}),
|
|
||||||
linked_openai_account_id: String(openaiAccount.id)
|
|
||||||
}
|
|
||||||
delete soraExtra.openai_passthrough
|
|
||||||
delete soraExtra.openai_oauth_passthrough
|
|
||||||
|
|
||||||
await adminAPI.accounts.create({
|
|
||||||
name: `${form.name} (Sora)`,
|
|
||||||
notes: form.notes,
|
|
||||||
platform: 'sora',
|
|
||||||
type: 'oauth',
|
|
||||||
credentials: soraCredentials,
|
|
||||||
extra: soraExtra,
|
|
||||||
proxy_id: form.proxy_id,
|
|
||||||
concurrency: form.concurrency,
|
|
||||||
priority: form.priority,
|
|
||||||
rate_multiplier: form.rate_multiplier,
|
|
||||||
group_ids: form.group_ids,
|
|
||||||
expires_at: form.expires_at,
|
|
||||||
auto_pause_on_expired: autoPauseOnExpired.value
|
|
||||||
})
|
|
||||||
|
|
||||||
appStore.showSuccess(t('admin.accounts.soraAccountCreated'))
|
|
||||||
} catch (error: any) {
|
|
||||||
console.error('创建 Sora 账号失败:', error)
|
|
||||||
appStore.showWarning(t('admin.accounts.soraAccountFailed'))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const soraName = shouldCreateOpenAI ? `${form.name} (Sora)` : form.name
|
||||||
|
const soraExtra = buildSoraExtra(shouldCreateOpenAI ? extra : oauthExtra, openaiAccountId)
|
||||||
|
await adminAPI.accounts.create({
|
||||||
|
name: soraName,
|
||||||
|
notes: form.notes,
|
||||||
|
platform: 'sora',
|
||||||
|
type: 'oauth',
|
||||||
|
credentials: soraCredentials,
|
||||||
|
extra: soraExtra,
|
||||||
|
proxy_id: form.proxy_id,
|
||||||
|
concurrency: form.concurrency,
|
||||||
|
priority: form.priority,
|
||||||
|
rate_multiplier: form.rate_multiplier,
|
||||||
|
group_ids: form.group_ids,
|
||||||
|
expires_at: form.expires_at,
|
||||||
|
auto_pause_on_expired: autoPauseOnExpired.value
|
||||||
|
})
|
||||||
|
appStore.showSuccess(t('admin.accounts.accountCreated'))
|
||||||
}
|
}
|
||||||
|
|
||||||
emit('created')
|
emit('created')
|
||||||
handleClose()
|
handleClose()
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
oauthClient.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||||
appStore.showError(openaiOAuth.error.value)
|
appStore.showError(oauthClient.error.value)
|
||||||
} finally {
|
} finally {
|
||||||
openaiOAuth.loading.value = false
|
oauthClient.loading.value = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenAI 手动 RT 批量验证和创建
|
// OpenAI 手动 RT 批量验证和创建
|
||||||
const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
|
const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
|
||||||
|
const oauthClient = activeOpenAIOAuth.value
|
||||||
if (!refreshTokenInput.trim()) return
|
if (!refreshTokenInput.trim()) return
|
||||||
|
|
||||||
// Parse multiple refresh tokens (one per line)
|
// Parse multiple refresh tokens (one per line)
|
||||||
@@ -3098,53 +3205,86 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
|
|||||||
.filter((rt) => rt)
|
.filter((rt) => rt)
|
||||||
|
|
||||||
if (refreshTokens.length === 0) {
|
if (refreshTokens.length === 0) {
|
||||||
openaiOAuth.error.value = t('admin.accounts.oauth.openai.pleaseEnterRefreshToken')
|
oauthClient.error.value = t('admin.accounts.oauth.openai.pleaseEnterRefreshToken')
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
openaiOAuth.loading.value = true
|
oauthClient.loading.value = true
|
||||||
openaiOAuth.error.value = ''
|
oauthClient.error.value = ''
|
||||||
|
|
||||||
let successCount = 0
|
let successCount = 0
|
||||||
let failedCount = 0
|
let failedCount = 0
|
||||||
const errors: string[] = []
|
const errors: string[] = []
|
||||||
|
const shouldCreateOpenAI = form.platform === 'openai'
|
||||||
|
const shouldCreateSora = form.platform === 'sora'
|
||||||
|
|
||||||
try {
|
try {
|
||||||
for (let i = 0; i < refreshTokens.length; i++) {
|
for (let i = 0; i < refreshTokens.length; i++) {
|
||||||
try {
|
try {
|
||||||
const tokenInfo = await openaiOAuth.validateRefreshToken(
|
const tokenInfo = await oauthClient.validateRefreshToken(
|
||||||
refreshTokens[i],
|
refreshTokens[i],
|
||||||
form.proxy_id
|
form.proxy_id
|
||||||
)
|
)
|
||||||
if (!tokenInfo) {
|
if (!tokenInfo) {
|
||||||
failedCount++
|
failedCount++
|
||||||
errors.push(`#${i + 1}: ${openaiOAuth.error.value || 'Validation failed'}`)
|
errors.push(`#${i + 1}: ${oauthClient.error.value || 'Validation failed'}`)
|
||||||
openaiOAuth.error.value = ''
|
oauthClient.error.value = ''
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
const credentials = openaiOAuth.buildCredentials(tokenInfo)
|
const credentials = oauthClient.buildCredentials(tokenInfo)
|
||||||
const oauthExtra = openaiOAuth.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined
|
const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined
|
||||||
const extra = buildOpenAIExtra(oauthExtra)
|
const extra = buildOpenAIExtra(oauthExtra)
|
||||||
|
|
||||||
// Generate account name with index for batch
|
// Generate account name with index for batch
|
||||||
const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name
|
const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name
|
||||||
|
|
||||||
await adminAPI.accounts.create({
|
let openaiAccountId: string | number | undefined
|
||||||
name: accountName,
|
|
||||||
notes: form.notes,
|
if (shouldCreateOpenAI) {
|
||||||
platform: 'openai',
|
const openaiAccount = await adminAPI.accounts.create({
|
||||||
type: 'oauth',
|
name: accountName,
|
||||||
credentials,
|
notes: form.notes,
|
||||||
extra,
|
platform: 'openai',
|
||||||
proxy_id: form.proxy_id,
|
type: 'oauth',
|
||||||
concurrency: form.concurrency,
|
credentials,
|
||||||
priority: form.priority,
|
extra,
|
||||||
rate_multiplier: form.rate_multiplier,
|
proxy_id: form.proxy_id,
|
||||||
group_ids: form.group_ids,
|
concurrency: form.concurrency,
|
||||||
expires_at: form.expires_at,
|
priority: form.priority,
|
||||||
auto_pause_on_expired: autoPauseOnExpired.value
|
rate_multiplier: form.rate_multiplier,
|
||||||
})
|
group_ids: form.group_ids,
|
||||||
|
expires_at: form.expires_at,
|
||||||
|
auto_pause_on_expired: autoPauseOnExpired.value
|
||||||
|
})
|
||||||
|
openaiAccountId = openaiAccount.id
|
||||||
|
}
|
||||||
|
|
||||||
|
if (shouldCreateSora) {
|
||||||
|
const soraCredentials = {
|
||||||
|
access_token: credentials.access_token,
|
||||||
|
refresh_token: credentials.refresh_token,
|
||||||
|
expires_at: credentials.expires_at
|
||||||
|
}
|
||||||
|
const soraName = shouldCreateOpenAI ? `${accountName} (Sora)` : accountName
|
||||||
|
const soraExtra = buildSoraExtra(shouldCreateOpenAI ? extra : oauthExtra, openaiAccountId)
|
||||||
|
await adminAPI.accounts.create({
|
||||||
|
name: soraName,
|
||||||
|
notes: form.notes,
|
||||||
|
platform: 'sora',
|
||||||
|
type: 'oauth',
|
||||||
|
credentials: soraCredentials,
|
||||||
|
extra: soraExtra,
|
||||||
|
proxy_id: form.proxy_id,
|
||||||
|
concurrency: form.concurrency,
|
||||||
|
priority: form.priority,
|
||||||
|
rate_multiplier: form.rate_multiplier,
|
||||||
|
group_ids: form.group_ids,
|
||||||
|
expires_at: form.expires_at,
|
||||||
|
auto_pause_on_expired: autoPauseOnExpired.value
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
successCount++
|
successCount++
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
failedCount++
|
failedCount++
|
||||||
@@ -3166,14 +3306,99 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
|
|||||||
appStore.showWarning(
|
appStore.showWarning(
|
||||||
t('admin.accounts.oauth.batchPartialSuccess', { success: successCount, failed: failedCount })
|
t('admin.accounts.oauth.batchPartialSuccess', { success: successCount, failed: failedCount })
|
||||||
)
|
)
|
||||||
openaiOAuth.error.value = errors.join('\n')
|
oauthClient.error.value = errors.join('\n')
|
||||||
emit('created')
|
emit('created')
|
||||||
} else {
|
} else {
|
||||||
openaiOAuth.error.value = errors.join('\n')
|
oauthClient.error.value = errors.join('\n')
|
||||||
appStore.showError(t('admin.accounts.oauth.batchFailed'))
|
appStore.showError(t('admin.accounts.oauth.batchFailed'))
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
openaiOAuth.loading.value = false
|
oauthClient.loading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sora 手动 ST 批量验证和创建
|
||||||
|
const handleSoraValidateST = async (sessionTokenInput: string) => {
|
||||||
|
const oauthClient = activeOpenAIOAuth.value
|
||||||
|
if (!sessionTokenInput.trim()) return
|
||||||
|
|
||||||
|
const sessionTokens = sessionTokenInput
|
||||||
|
.split('\n')
|
||||||
|
.map((st) => st.trim())
|
||||||
|
.filter((st) => st)
|
||||||
|
|
||||||
|
if (sessionTokens.length === 0) {
|
||||||
|
oauthClient.error.value = t('admin.accounts.oauth.openai.pleaseEnterSessionToken')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
oauthClient.loading.value = true
|
||||||
|
oauthClient.error.value = ''
|
||||||
|
|
||||||
|
let successCount = 0
|
||||||
|
let failedCount = 0
|
||||||
|
const errors: string[] = []
|
||||||
|
|
||||||
|
try {
|
||||||
|
for (let i = 0; i < sessionTokens.length; i++) {
|
||||||
|
try {
|
||||||
|
const tokenInfo = await oauthClient.validateSessionToken(sessionTokens[i], form.proxy_id)
|
||||||
|
if (!tokenInfo) {
|
||||||
|
failedCount++
|
||||||
|
errors.push(`#${i + 1}: ${oauthClient.error.value || 'Validation failed'}`)
|
||||||
|
oauthClient.error.value = ''
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
const credentials = oauthClient.buildCredentials(tokenInfo)
|
||||||
|
credentials.session_token = sessionTokens[i]
|
||||||
|
const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined
|
||||||
|
const soraExtra = buildSoraExtra(oauthExtra)
|
||||||
|
|
||||||
|
const accountName = sessionTokens.length > 1 ? `${form.name} #${i + 1}` : form.name
|
||||||
|
await adminAPI.accounts.create({
|
||||||
|
name: accountName,
|
||||||
|
notes: form.notes,
|
||||||
|
platform: 'sora',
|
||||||
|
type: 'oauth',
|
||||||
|
credentials,
|
||||||
|
extra: soraExtra,
|
||||||
|
proxy_id: form.proxy_id,
|
||||||
|
concurrency: form.concurrency,
|
||||||
|
priority: form.priority,
|
||||||
|
rate_multiplier: form.rate_multiplier,
|
||||||
|
group_ids: form.group_ids,
|
||||||
|
expires_at: form.expires_at,
|
||||||
|
auto_pause_on_expired: autoPauseOnExpired.value
|
||||||
|
})
|
||||||
|
successCount++
|
||||||
|
} catch (error: any) {
|
||||||
|
failedCount++
|
||||||
|
const errMsg = error.response?.data?.detail || error.message || 'Unknown error'
|
||||||
|
errors.push(`#${i + 1}: ${errMsg}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (successCount > 0 && failedCount === 0) {
|
||||||
|
appStore.showSuccess(
|
||||||
|
sessionTokens.length > 1
|
||||||
|
? t('admin.accounts.oauth.batchSuccess', { count: successCount })
|
||||||
|
: t('admin.accounts.accountCreated')
|
||||||
|
)
|
||||||
|
emit('created')
|
||||||
|
handleClose()
|
||||||
|
} else if (successCount > 0 && failedCount > 0) {
|
||||||
|
appStore.showWarning(
|
||||||
|
t('admin.accounts.oauth.batchPartialSuccess', { success: successCount, failed: failedCount })
|
||||||
|
)
|
||||||
|
oauthClient.error.value = errors.join('\n')
|
||||||
|
emit('created')
|
||||||
|
} else {
|
||||||
|
oauthClient.error.value = errors.join('\n')
|
||||||
|
appStore.showError(t('admin.accounts.oauth.batchFailed'))
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
oauthClient.loading.value = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3393,6 +3618,12 @@ const handleAnthropicExchange = async (authCode: string) => {
|
|||||||
extra.session_id_masking_enabled = true
|
extra.session_id_masking_enabled = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add cache TTL override settings
|
||||||
|
if (cacheTTLOverrideEnabled.value) {
|
||||||
|
extra.cache_ttl_override_enabled = true
|
||||||
|
extra.cache_ttl_override_target = cacheTTLOverrideTarget.value
|
||||||
|
}
|
||||||
|
|
||||||
const credentials = {
|
const credentials = {
|
||||||
...tokenInfo,
|
...tokenInfo,
|
||||||
...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {})
|
...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {})
|
||||||
@@ -3412,6 +3643,7 @@ const handleExchangeCode = async () => {
|
|||||||
|
|
||||||
switch (form.platform) {
|
switch (form.platform) {
|
||||||
case 'openai':
|
case 'openai':
|
||||||
|
case 'sora':
|
||||||
return handleOpenAIExchange(authCode)
|
return handleOpenAIExchange(authCode)
|
||||||
case 'gemini':
|
case 'gemini':
|
||||||
return handleGeminiExchange(authCode)
|
return handleGeminiExchange(authCode)
|
||||||
@@ -3486,6 +3718,12 @@ const handleCookieAuth = async (sessionKey: string) => {
|
|||||||
extra.session_id_masking_enabled = true
|
extra.session_id_masking_enabled = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add cache TTL override settings
|
||||||
|
if (cacheTTLOverrideEnabled.value) {
|
||||||
|
extra.cache_ttl_override_enabled = true
|
||||||
|
extra.cache_ttl_override_target = cacheTTLOverrideTarget.value
|
||||||
|
}
|
||||||
|
|
||||||
const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name
|
const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name
|
||||||
|
|
||||||
// Merge interceptWarmupRequests into credentials
|
// Merge interceptWarmupRequests into credentials
|
||||||
|
|||||||
@@ -975,6 +975,46 @@
|
|||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- Cache TTL Override -->
|
||||||
|
<div class="rounded-lg border border-gray-200 p-4 dark:border-dark-600">
|
||||||
|
<div class="flex items-center justify-between">
|
||||||
|
<div>
|
||||||
|
<label class="input-label mb-0">{{ t('admin.accounts.quotaControl.cacheTTLOverride.label') }}</label>
|
||||||
|
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t('admin.accounts.quotaControl.cacheTTLOverride.hint') }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
@click="cacheTTLOverrideEnabled = !cacheTTLOverrideEnabled"
|
||||||
|
:class="[
|
||||||
|
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
|
||||||
|
cacheTTLOverrideEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
|
||||||
|
]"
|
||||||
|
>
|
||||||
|
<span
|
||||||
|
:class="[
|
||||||
|
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
||||||
|
cacheTTLOverrideEnabled ? 'translate-x-5' : 'translate-x-0'
|
||||||
|
]"
|
||||||
|
/>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<div v-if="cacheTTLOverrideEnabled" class="mt-3">
|
||||||
|
<label class="input-label text-xs">{{ t('admin.accounts.quotaControl.cacheTTLOverride.target') }}</label>
|
||||||
|
<select
|
||||||
|
v-model="cacheTTLOverrideTarget"
|
||||||
|
class="mt-1 block w-full rounded-md border border-gray-300 bg-white px-3 py-2 text-sm shadow-sm focus:border-primary-500 focus:outline-none focus:ring-1 focus:ring-primary-500 dark:border-dark-500 dark:bg-dark-700 dark:text-white"
|
||||||
|
>
|
||||||
|
<option value="5m">5m</option>
|
||||||
|
<option value="1h">1h</option>
|
||||||
|
</select>
|
||||||
|
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t('admin.accounts.quotaControl.cacheTTLOverride.targetHint') }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||||
@@ -1177,6 +1217,8 @@ const maxSessions = ref<number | null>(null)
|
|||||||
const sessionIdleTimeout = ref<number | null>(null)
|
const sessionIdleTimeout = ref<number | null>(null)
|
||||||
const tlsFingerprintEnabled = ref(false)
|
const tlsFingerprintEnabled = ref(false)
|
||||||
const sessionIdMaskingEnabled = ref(false)
|
const sessionIdMaskingEnabled = ref(false)
|
||||||
|
const cacheTTLOverrideEnabled = ref(false)
|
||||||
|
const cacheTTLOverrideTarget = ref<string>('5m')
|
||||||
|
|
||||||
// OpenAI 自动透传开关(OAuth/API Key)
|
// OpenAI 自动透传开关(OAuth/API Key)
|
||||||
const openaiPassthroughEnabled = ref(false)
|
const openaiPassthroughEnabled = ref(false)
|
||||||
@@ -1581,6 +1623,8 @@ function loadQuotaControlSettings(account: Account) {
|
|||||||
sessionIdleTimeout.value = null
|
sessionIdleTimeout.value = null
|
||||||
tlsFingerprintEnabled.value = false
|
tlsFingerprintEnabled.value = false
|
||||||
sessionIdMaskingEnabled.value = false
|
sessionIdMaskingEnabled.value = false
|
||||||
|
cacheTTLOverrideEnabled.value = false
|
||||||
|
cacheTTLOverrideTarget.value = '5m'
|
||||||
|
|
||||||
// Only applies to Anthropic OAuth/SetupToken accounts
|
// Only applies to Anthropic OAuth/SetupToken accounts
|
||||||
if (account.platform !== 'anthropic' || (account.type !== 'oauth' && account.type !== 'setup-token')) {
|
if (account.platform !== 'anthropic' || (account.type !== 'oauth' && account.type !== 'setup-token')) {
|
||||||
@@ -1609,6 +1653,12 @@ function loadQuotaControlSettings(account: Account) {
|
|||||||
if (account.session_id_masking_enabled === true) {
|
if (account.session_id_masking_enabled === true) {
|
||||||
sessionIdMaskingEnabled.value = true
|
sessionIdMaskingEnabled.value = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Load cache TTL override setting
|
||||||
|
if (account.cache_ttl_override_enabled === true) {
|
||||||
|
cacheTTLOverrideEnabled.value = true
|
||||||
|
cacheTTLOverrideTarget.value = account.cache_ttl_override_target || '5m'
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function formatTempUnschedKeywords(value: unknown) {
|
function formatTempUnschedKeywords(value: unknown) {
|
||||||
@@ -1820,6 +1870,15 @@ const handleSubmit = async () => {
|
|||||||
delete newExtra.session_id_masking_enabled
|
delete newExtra.session_id_masking_enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cache TTL override setting
|
||||||
|
if (cacheTTLOverrideEnabled.value) {
|
||||||
|
newExtra.cache_ttl_override_enabled = true
|
||||||
|
newExtra.cache_ttl_override_target = cacheTTLOverrideTarget.value
|
||||||
|
} else {
|
||||||
|
delete newExtra.cache_ttl_override_enabled
|
||||||
|
delete newExtra.cache_ttl_override_target
|
||||||
|
}
|
||||||
|
|
||||||
updatePayload.extra = newExtra
|
updatePayload.extra = newExtra
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -48,6 +48,17 @@
|
|||||||
t(getOAuthKey('refreshTokenAuth'))
|
t(getOAuthKey('refreshTokenAuth'))
|
||||||
}}</span>
|
}}</span>
|
||||||
</label>
|
</label>
|
||||||
|
<label v-if="showSessionTokenOption" class="flex cursor-pointer items-center gap-2">
|
||||||
|
<input
|
||||||
|
v-model="inputMethod"
|
||||||
|
type="radio"
|
||||||
|
value="session_token"
|
||||||
|
class="text-blue-600 focus:ring-blue-500"
|
||||||
|
/>
|
||||||
|
<span class="text-sm text-blue-900 dark:text-blue-200">{{
|
||||||
|
t(getOAuthKey('sessionTokenAuth'))
|
||||||
|
}}</span>
|
||||||
|
</label>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -135,6 +146,87 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- Session Token Input (Sora) -->
|
||||||
|
<div v-if="inputMethod === 'session_token'" class="space-y-4">
|
||||||
|
<div
|
||||||
|
class="rounded-lg border border-blue-300 bg-white/80 p-4 dark:border-blue-600 dark:bg-gray-800/80"
|
||||||
|
>
|
||||||
|
<p class="mb-3 text-sm text-blue-700 dark:text-blue-300">
|
||||||
|
{{ t(getOAuthKey('sessionTokenDesc')) }}
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<div class="mb-4">
|
||||||
|
<label
|
||||||
|
class="mb-2 flex items-center gap-2 text-sm font-semibold text-gray-700 dark:text-gray-300"
|
||||||
|
>
|
||||||
|
<Icon name="key" size="sm" class="text-blue-500" />
|
||||||
|
Session Token
|
||||||
|
<span
|
||||||
|
v-if="parsedSessionTokenCount > 1"
|
||||||
|
class="rounded-full bg-blue-500 px-2 py-0.5 text-xs text-white"
|
||||||
|
>
|
||||||
|
{{ t('admin.accounts.oauth.keysCount', { count: parsedSessionTokenCount }) }}
|
||||||
|
</span>
|
||||||
|
</label>
|
||||||
|
<textarea
|
||||||
|
v-model="sessionTokenInput"
|
||||||
|
rows="3"
|
||||||
|
class="input w-full resize-y font-mono text-sm"
|
||||||
|
:placeholder="t(getOAuthKey('sessionTokenPlaceholder'))"
|
||||||
|
></textarea>
|
||||||
|
<p
|
||||||
|
v-if="parsedSessionTokenCount > 1"
|
||||||
|
class="mt-1 text-xs text-blue-600 dark:text-blue-400"
|
||||||
|
>
|
||||||
|
{{ t('admin.accounts.oauth.batchCreateAccounts', { count: parsedSessionTokenCount }) }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div
|
||||||
|
v-if="error"
|
||||||
|
class="mb-4 rounded-lg border border-red-200 bg-red-50 p-3 dark:border-red-700 dark:bg-red-900/30"
|
||||||
|
>
|
||||||
|
<p class="whitespace-pre-line text-sm text-red-600 dark:text-red-400">
|
||||||
|
{{ error }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="btn btn-primary w-full"
|
||||||
|
:disabled="loading || !sessionTokenInput.trim()"
|
||||||
|
@click="handleValidateSessionToken"
|
||||||
|
>
|
||||||
|
<svg
|
||||||
|
v-if="loading"
|
||||||
|
class="-ml-1 mr-2 h-4 w-4 animate-spin"
|
||||||
|
fill="none"
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
>
|
||||||
|
<circle
|
||||||
|
class="opacity-25"
|
||||||
|
cx="12"
|
||||||
|
cy="12"
|
||||||
|
r="10"
|
||||||
|
stroke="currentColor"
|
||||||
|
stroke-width="4"
|
||||||
|
></circle>
|
||||||
|
<path
|
||||||
|
class="opacity-75"
|
||||||
|
fill="currentColor"
|
||||||
|
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
|
||||||
|
></path>
|
||||||
|
</svg>
|
||||||
|
<Icon v-else name="sparkles" size="sm" class="mr-2" />
|
||||||
|
{{
|
||||||
|
loading
|
||||||
|
? t(getOAuthKey('validating'))
|
||||||
|
: t(getOAuthKey('validateAndCreate'))
|
||||||
|
}}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- Cookie Auto-Auth Form -->
|
<!-- Cookie Auto-Auth Form -->
|
||||||
<div v-if="inputMethod === 'cookie'" class="space-y-4">
|
<div v-if="inputMethod === 'cookie'" class="space-y-4">
|
||||||
<div
|
<div
|
||||||
@@ -521,13 +613,14 @@ interface Props {
|
|||||||
error?: string
|
error?: string
|
||||||
showHelp?: boolean
|
showHelp?: boolean
|
||||||
showProxyWarning?: boolean
|
showProxyWarning?: boolean
|
||||||
allowMultiple?: boolean
|
allowMultiple?: boolean
|
||||||
methodLabel?: string
|
methodLabel?: string
|
||||||
showCookieOption?: boolean // Whether to show cookie auto-auth option
|
showCookieOption?: boolean // Whether to show cookie auto-auth option
|
||||||
showRefreshTokenOption?: boolean // Whether to show refresh token input option (OpenAI only)
|
showRefreshTokenOption?: boolean // Whether to show refresh token input option (OpenAI only)
|
||||||
platform?: AccountPlatform // Platform type for different UI/text
|
showSessionTokenOption?: boolean // Whether to show session token input option (Sora only)
|
||||||
showProjectId?: boolean // New prop to control project ID visibility
|
platform?: AccountPlatform // Platform type for different UI/text
|
||||||
}
|
showProjectId?: boolean // New prop to control project ID visibility
|
||||||
|
}
|
||||||
|
|
||||||
const props = withDefaults(defineProps<Props>(), {
|
const props = withDefaults(defineProps<Props>(), {
|
||||||
authUrl: '',
|
authUrl: '',
|
||||||
@@ -540,6 +633,7 @@ const props = withDefaults(defineProps<Props>(), {
|
|||||||
methodLabel: 'Authorization Method',
|
methodLabel: 'Authorization Method',
|
||||||
showCookieOption: true,
|
showCookieOption: true,
|
||||||
showRefreshTokenOption: false,
|
showRefreshTokenOption: false,
|
||||||
|
showSessionTokenOption: false,
|
||||||
platform: 'anthropic',
|
platform: 'anthropic',
|
||||||
showProjectId: true
|
showProjectId: true
|
||||||
})
|
})
|
||||||
@@ -549,6 +643,7 @@ const emit = defineEmits<{
|
|||||||
'exchange-code': [code: string]
|
'exchange-code': [code: string]
|
||||||
'cookie-auth': [sessionKey: string]
|
'cookie-auth': [sessionKey: string]
|
||||||
'validate-refresh-token': [refreshToken: string]
|
'validate-refresh-token': [refreshToken: string]
|
||||||
|
'validate-session-token': [sessionToken: string]
|
||||||
'update:inputMethod': [method: AuthInputMethod]
|
'update:inputMethod': [method: AuthInputMethod]
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
@@ -587,12 +682,13 @@ const inputMethod = ref<AuthInputMethod>(props.showCookieOption ? 'manual' : 'ma
|
|||||||
const authCodeInput = ref('')
|
const authCodeInput = ref('')
|
||||||
const sessionKeyInput = ref('')
|
const sessionKeyInput = ref('')
|
||||||
const refreshTokenInput = ref('')
|
const refreshTokenInput = ref('')
|
||||||
|
const sessionTokenInput = ref('')
|
||||||
const showHelpDialog = ref(false)
|
const showHelpDialog = ref(false)
|
||||||
const oauthState = ref('')
|
const oauthState = ref('')
|
||||||
const projectId = ref('')
|
const projectId = ref('')
|
||||||
|
|
||||||
// Computed: show method selection when either cookie or refresh token option is enabled
|
// Computed: show method selection when either cookie or refresh token option is enabled
|
||||||
const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption)
|
const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showSessionTokenOption)
|
||||||
|
|
||||||
// Clipboard
|
// Clipboard
|
||||||
const { copied, copyToClipboard } = useClipboard()
|
const { copied, copyToClipboard } = useClipboard()
|
||||||
@@ -613,6 +709,13 @@ const parsedRefreshTokenCount = computed(() => {
|
|||||||
.filter((rt) => rt).length
|
.filter((rt) => rt).length
|
||||||
})
|
})
|
||||||
|
|
||||||
|
const parsedSessionTokenCount = computed(() => {
|
||||||
|
return sessionTokenInput.value
|
||||||
|
.split('\n')
|
||||||
|
.map((st) => st.trim())
|
||||||
|
.filter((st) => st).length
|
||||||
|
})
|
||||||
|
|
||||||
// Watchers
|
// Watchers
|
||||||
watch(inputMethod, (newVal) => {
|
watch(inputMethod, (newVal) => {
|
||||||
emit('update:inputMethod', newVal)
|
emit('update:inputMethod', newVal)
|
||||||
@@ -631,7 +734,7 @@ watch(authCodeInput, (newVal) => {
|
|||||||
const url = new URL(trimmed)
|
const url = new URL(trimmed)
|
||||||
const code = url.searchParams.get('code')
|
const code = url.searchParams.get('code')
|
||||||
const stateParam = url.searchParams.get('state')
|
const stateParam = url.searchParams.get('state')
|
||||||
if ((props.platform === 'gemini' || props.platform === 'antigravity') && stateParam) {
|
if ((props.platform === 'openai' || props.platform === 'sora' || props.platform === 'gemini' || props.platform === 'antigravity') && stateParam) {
|
||||||
oauthState.value = stateParam
|
oauthState.value = stateParam
|
||||||
}
|
}
|
||||||
if (code && code !== trimmed) {
|
if (code && code !== trimmed) {
|
||||||
@@ -642,7 +745,7 @@ watch(authCodeInput, (newVal) => {
|
|||||||
// If URL parsing fails, try regex extraction
|
// If URL parsing fails, try regex extraction
|
||||||
const match = trimmed.match(/[?&]code=([^&]+)/)
|
const match = trimmed.match(/[?&]code=([^&]+)/)
|
||||||
const stateMatch = trimmed.match(/[?&]state=([^&]+)/)
|
const stateMatch = trimmed.match(/[?&]state=([^&]+)/)
|
||||||
if ((props.platform === 'gemini' || props.platform === 'antigravity') && stateMatch && stateMatch[1]) {
|
if ((props.platform === 'openai' || props.platform === 'sora' || props.platform === 'gemini' || props.platform === 'antigravity') && stateMatch && stateMatch[1]) {
|
||||||
oauthState.value = stateMatch[1]
|
oauthState.value = stateMatch[1]
|
||||||
}
|
}
|
||||||
if (match && match[1] && match[1] !== trimmed) {
|
if (match && match[1] && match[1] !== trimmed) {
|
||||||
@@ -680,6 +783,12 @@ const handleValidateRefreshToken = () => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const handleValidateSessionToken = () => {
|
||||||
|
if (sessionTokenInput.value.trim()) {
|
||||||
|
emit('validate-session-token', sessionTokenInput.value.trim())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Expose methods and state
|
// Expose methods and state
|
||||||
defineExpose({
|
defineExpose({
|
||||||
authCode: authCodeInput,
|
authCode: authCodeInput,
|
||||||
@@ -687,6 +796,7 @@ defineExpose({
|
|||||||
projectId,
|
projectId,
|
||||||
sessionKey: sessionKeyInput,
|
sessionKey: sessionKeyInput,
|
||||||
refreshToken: refreshTokenInput,
|
refreshToken: refreshTokenInput,
|
||||||
|
sessionToken: sessionTokenInput,
|
||||||
inputMethod,
|
inputMethod,
|
||||||
reset: () => {
|
reset: () => {
|
||||||
authCodeInput.value = ''
|
authCodeInput.value = ''
|
||||||
@@ -694,6 +804,7 @@ defineExpose({
|
|||||||
projectId.value = ''
|
projectId.value = ''
|
||||||
sessionKeyInput.value = ''
|
sessionKeyInput.value = ''
|
||||||
refreshTokenInput.value = ''
|
refreshTokenInput.value = ''
|
||||||
|
sessionTokenInput.value = ''
|
||||||
inputMethod.value = 'manual'
|
inputMethod.value = 'manual'
|
||||||
showHelpDialog.value = false
|
showHelpDialog.value = false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
<div
|
<div
|
||||||
:class="[
|
:class="[
|
||||||
'flex h-10 w-10 items-center justify-center rounded-lg bg-gradient-to-br',
|
'flex h-10 w-10 items-center justify-center rounded-lg bg-gradient-to-br',
|
||||||
isOpenAI
|
isOpenAILike
|
||||||
? 'from-green-500 to-green-600'
|
? 'from-green-500 to-green-600'
|
||||||
: isGemini
|
: isGemini
|
||||||
? 'from-blue-500 to-blue-600'
|
? 'from-blue-500 to-blue-600'
|
||||||
@@ -33,6 +33,8 @@
|
|||||||
{{
|
{{
|
||||||
isOpenAI
|
isOpenAI
|
||||||
? t('admin.accounts.openaiAccount')
|
? t('admin.accounts.openaiAccount')
|
||||||
|
: isSora
|
||||||
|
? t('admin.accounts.soraAccount')
|
||||||
: isGemini
|
: isGemini
|
||||||
? t('admin.accounts.geminiAccount')
|
? t('admin.accounts.geminiAccount')
|
||||||
: isAntigravity
|
: isAntigravity
|
||||||
@@ -128,7 +130,7 @@
|
|||||||
:show-cookie-option="isAnthropic"
|
:show-cookie-option="isAnthropic"
|
||||||
:allow-multiple="false"
|
:allow-multiple="false"
|
||||||
:method-label="t('admin.accounts.inputMethod')"
|
:method-label="t('admin.accounts.inputMethod')"
|
||||||
:platform="isOpenAI ? 'openai' : isGemini ? 'gemini' : isAntigravity ? 'antigravity' : 'anthropic'"
|
:platform="isOpenAI ? 'openai' : isSora ? 'sora' : isGemini ? 'gemini' : isAntigravity ? 'antigravity' : 'anthropic'"
|
||||||
:show-project-id="isGemini && geminiOAuthType === 'code_assist'"
|
:show-project-id="isGemini && geminiOAuthType === 'code_assist'"
|
||||||
@generate-url="handleGenerateUrl"
|
@generate-url="handleGenerateUrl"
|
||||||
@cookie-auth="handleCookieAuth"
|
@cookie-auth="handleCookieAuth"
|
||||||
@@ -224,7 +226,8 @@ const { t } = useI18n()
|
|||||||
|
|
||||||
// OAuth composables
|
// OAuth composables
|
||||||
const claudeOAuth = useAccountOAuth()
|
const claudeOAuth = useAccountOAuth()
|
||||||
const openaiOAuth = useOpenAIOAuth()
|
const openaiOAuth = useOpenAIOAuth({ platform: 'openai' })
|
||||||
|
const soraOAuth = useOpenAIOAuth({ platform: 'sora' })
|
||||||
const geminiOAuth = useGeminiOAuth()
|
const geminiOAuth = useGeminiOAuth()
|
||||||
const antigravityOAuth = useAntigravityOAuth()
|
const antigravityOAuth = useAntigravityOAuth()
|
||||||
|
|
||||||
@@ -237,31 +240,34 @@ const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('code_as
|
|||||||
|
|
||||||
// Computed - check platform
|
// Computed - check platform
|
||||||
const isOpenAI = computed(() => props.account?.platform === 'openai')
|
const isOpenAI = computed(() => props.account?.platform === 'openai')
|
||||||
|
const isSora = computed(() => props.account?.platform === 'sora')
|
||||||
|
const isOpenAILike = computed(() => isOpenAI.value || isSora.value)
|
||||||
const isGemini = computed(() => props.account?.platform === 'gemini')
|
const isGemini = computed(() => props.account?.platform === 'gemini')
|
||||||
const isAnthropic = computed(() => props.account?.platform === 'anthropic')
|
const isAnthropic = computed(() => props.account?.platform === 'anthropic')
|
||||||
const isAntigravity = computed(() => props.account?.platform === 'antigravity')
|
const isAntigravity = computed(() => props.account?.platform === 'antigravity')
|
||||||
|
const activeOpenAIOAuth = computed(() => (isSora.value ? soraOAuth : openaiOAuth))
|
||||||
|
|
||||||
// Computed - current OAuth state based on platform
|
// Computed - current OAuth state based on platform
|
||||||
const currentAuthUrl = computed(() => {
|
const currentAuthUrl = computed(() => {
|
||||||
if (isOpenAI.value) return openaiOAuth.authUrl.value
|
if (isOpenAILike.value) return activeOpenAIOAuth.value.authUrl.value
|
||||||
if (isGemini.value) return geminiOAuth.authUrl.value
|
if (isGemini.value) return geminiOAuth.authUrl.value
|
||||||
if (isAntigravity.value) return antigravityOAuth.authUrl.value
|
if (isAntigravity.value) return antigravityOAuth.authUrl.value
|
||||||
return claudeOAuth.authUrl.value
|
return claudeOAuth.authUrl.value
|
||||||
})
|
})
|
||||||
const currentSessionId = computed(() => {
|
const currentSessionId = computed(() => {
|
||||||
if (isOpenAI.value) return openaiOAuth.sessionId.value
|
if (isOpenAILike.value) return activeOpenAIOAuth.value.sessionId.value
|
||||||
if (isGemini.value) return geminiOAuth.sessionId.value
|
if (isGemini.value) return geminiOAuth.sessionId.value
|
||||||
if (isAntigravity.value) return antigravityOAuth.sessionId.value
|
if (isAntigravity.value) return antigravityOAuth.sessionId.value
|
||||||
return claudeOAuth.sessionId.value
|
return claudeOAuth.sessionId.value
|
||||||
})
|
})
|
||||||
const currentLoading = computed(() => {
|
const currentLoading = computed(() => {
|
||||||
if (isOpenAI.value) return openaiOAuth.loading.value
|
if (isOpenAILike.value) return activeOpenAIOAuth.value.loading.value
|
||||||
if (isGemini.value) return geminiOAuth.loading.value
|
if (isGemini.value) return geminiOAuth.loading.value
|
||||||
if (isAntigravity.value) return antigravityOAuth.loading.value
|
if (isAntigravity.value) return antigravityOAuth.loading.value
|
||||||
return claudeOAuth.loading.value
|
return claudeOAuth.loading.value
|
||||||
})
|
})
|
||||||
const currentError = computed(() => {
|
const currentError = computed(() => {
|
||||||
if (isOpenAI.value) return openaiOAuth.error.value
|
if (isOpenAILike.value) return activeOpenAIOAuth.value.error.value
|
||||||
if (isGemini.value) return geminiOAuth.error.value
|
if (isGemini.value) return geminiOAuth.error.value
|
||||||
if (isAntigravity.value) return antigravityOAuth.error.value
|
if (isAntigravity.value) return antigravityOAuth.error.value
|
||||||
return claudeOAuth.error.value
|
return claudeOAuth.error.value
|
||||||
@@ -269,8 +275,8 @@ const currentError = computed(() => {
|
|||||||
|
|
||||||
// Computed
|
// Computed
|
||||||
const isManualInputMethod = computed(() => {
|
const isManualInputMethod = computed(() => {
|
||||||
// OpenAI/Gemini/Antigravity always use manual input (no cookie auth option)
|
// OpenAI/Sora/Gemini/Antigravity always use manual input (no cookie auth option)
|
||||||
return isOpenAI.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual'
|
return isOpenAILike.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual'
|
||||||
})
|
})
|
||||||
|
|
||||||
const canExchangeCode = computed(() => {
|
const canExchangeCode = computed(() => {
|
||||||
@@ -313,6 +319,7 @@ const resetState = () => {
|
|||||||
geminiOAuthType.value = 'code_assist'
|
geminiOAuthType.value = 'code_assist'
|
||||||
claudeOAuth.resetState()
|
claudeOAuth.resetState()
|
||||||
openaiOAuth.resetState()
|
openaiOAuth.resetState()
|
||||||
|
soraOAuth.resetState()
|
||||||
geminiOAuth.resetState()
|
geminiOAuth.resetState()
|
||||||
antigravityOAuth.resetState()
|
antigravityOAuth.resetState()
|
||||||
oauthFlowRef.value?.reset()
|
oauthFlowRef.value?.reset()
|
||||||
@@ -325,8 +332,8 @@ const handleClose = () => {
|
|||||||
const handleGenerateUrl = async () => {
|
const handleGenerateUrl = async () => {
|
||||||
if (!props.account) return
|
if (!props.account) return
|
||||||
|
|
||||||
if (isOpenAI.value) {
|
if (isOpenAILike.value) {
|
||||||
await openaiOAuth.generateAuthUrl(props.account.proxy_id)
|
await activeOpenAIOAuth.value.generateAuthUrl(props.account.proxy_id)
|
||||||
} else if (isGemini.value) {
|
} else if (isGemini.value) {
|
||||||
const creds = (props.account.credentials || {}) as Record<string, unknown>
|
const creds = (props.account.credentials || {}) as Record<string, unknown>
|
||||||
const tierId = typeof creds.tier_id === 'string' ? creds.tier_id : undefined
|
const tierId = typeof creds.tier_id === 'string' ? creds.tier_id : undefined
|
||||||
@@ -345,21 +352,29 @@ const handleExchangeCode = async () => {
|
|||||||
const authCode = oauthFlowRef.value?.authCode || ''
|
const authCode = oauthFlowRef.value?.authCode || ''
|
||||||
if (!authCode.trim()) return
|
if (!authCode.trim()) return
|
||||||
|
|
||||||
if (isOpenAI.value) {
|
if (isOpenAILike.value) {
|
||||||
// OpenAI OAuth flow
|
// OpenAI OAuth flow
|
||||||
const sessionId = openaiOAuth.sessionId.value
|
const oauthClient = activeOpenAIOAuth.value
|
||||||
|
const sessionId = oauthClient.sessionId.value
|
||||||
if (!sessionId) return
|
if (!sessionId) return
|
||||||
|
const stateToUse = (oauthFlowRef.value?.oauthState || oauthClient.oauthState.value || '').trim()
|
||||||
|
if (!stateToUse) {
|
||||||
|
oauthClient.error.value = t('admin.accounts.oauth.authFailed')
|
||||||
|
appStore.showError(oauthClient.error.value)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
const tokenInfo = await openaiOAuth.exchangeAuthCode(
|
const tokenInfo = await oauthClient.exchangeAuthCode(
|
||||||
authCode.trim(),
|
authCode.trim(),
|
||||||
sessionId,
|
sessionId,
|
||||||
|
stateToUse,
|
||||||
props.account.proxy_id
|
props.account.proxy_id
|
||||||
)
|
)
|
||||||
if (!tokenInfo) return
|
if (!tokenInfo) return
|
||||||
|
|
||||||
// Build credentials and extra info
|
// Build credentials and extra info
|
||||||
const credentials = openaiOAuth.buildCredentials(tokenInfo)
|
const credentials = oauthClient.buildCredentials(tokenInfo)
|
||||||
const extra = openaiOAuth.buildExtraInfo(tokenInfo)
|
const extra = oauthClient.buildExtraInfo(tokenInfo)
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Update account with new credentials
|
// Update account with new credentials
|
||||||
@@ -376,8 +391,8 @@ const handleExchangeCode = async () => {
|
|||||||
emit('reauthorized')
|
emit('reauthorized')
|
||||||
handleClose()
|
handleClose()
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
oauthClient.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||||
appStore.showError(openaiOAuth.error.value)
|
appStore.showError(oauthClient.error.value)
|
||||||
}
|
}
|
||||||
} else if (isGemini.value) {
|
} else if (isGemini.value) {
|
||||||
const sessionId = geminiOAuth.sessionId.value
|
const sessionId = geminiOAuth.sessionId.value
|
||||||
@@ -490,7 +505,7 @@ const handleExchangeCode = async () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const handleCookieAuth = async (sessionKey: string) => {
|
const handleCookieAuth = async (sessionKey: string) => {
|
||||||
if (!props.account || isOpenAI.value) return
|
if (!props.account || isOpenAILike.value) return
|
||||||
|
|
||||||
claudeOAuth.loading.value = true
|
claudeOAuth.loading.value = true
|
||||||
claudeOAuth.error.value = ''
|
claudeOAuth.error.value = ''
|
||||||
|
|||||||
@@ -10,16 +10,21 @@
|
|||||||
<Select :model-value="filters.platform" class="w-40" :options="pOpts" @update:model-value="updatePlatform" @change="$emit('change')" />
|
<Select :model-value="filters.platform" class="w-40" :options="pOpts" @update:model-value="updatePlatform" @change="$emit('change')" />
|
||||||
<Select :model-value="filters.type" class="w-40" :options="tOpts" @update:model-value="updateType" @change="$emit('change')" />
|
<Select :model-value="filters.type" class="w-40" :options="tOpts" @update:model-value="updateType" @change="$emit('change')" />
|
||||||
<Select :model-value="filters.status" class="w-40" :options="sOpts" @update:model-value="updateStatus" @change="$emit('change')" />
|
<Select :model-value="filters.status" class="w-40" :options="sOpts" @update:model-value="updateStatus" @change="$emit('change')" />
|
||||||
|
<Select :model-value="filters.group" class="w-40" :options="gOpts" @update:model-value="updateGroup" @change="$emit('change')" />
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { computed } from 'vue'; import { useI18n } from 'vue-i18n'; import Select from '@/components/common/Select.vue'; import SearchInput from '@/components/common/SearchInput.vue'
|
import { computed } from 'vue'; import { useI18n } from 'vue-i18n'; import Select from '@/components/common/Select.vue'; import SearchInput from '@/components/common/SearchInput.vue'
|
||||||
const props = defineProps(['searchQuery', 'filters']); const emit = defineEmits(['update:searchQuery', 'update:filters', 'change']); const { t } = useI18n()
|
import type { AdminGroup } from '@/types'
|
||||||
|
const props = defineProps<{ searchQuery: string; filters: Record<string, any>; groups?: AdminGroup[] }>()
|
||||||
|
const emit = defineEmits(['update:searchQuery', 'update:filters', 'change']); const { t } = useI18n()
|
||||||
const updatePlatform = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, platform: value }) }
|
const updatePlatform = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, platform: value }) }
|
||||||
const updateType = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, type: value }) }
|
const updateType = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, type: value }) }
|
||||||
const updateStatus = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, status: value }) }
|
const updateStatus = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, status: value }) }
|
||||||
|
const updateGroup = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, group: value }) }
|
||||||
const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }, { value: 'sora', label: 'Sora' }])
|
const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }, { value: 'sora', label: 'Sora' }])
|
||||||
const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }])
|
const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }])
|
||||||
const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }])
|
const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }])
|
||||||
|
const gOpts = computed(() => [{ value: '', label: t('admin.accounts.allGroups') }, ...(props.groups || []).map(g => ({ value: String(g.id), label: g.name }))])
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -41,7 +41,7 @@
|
|||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="space-y-1.5">
|
<div v-if="!isSoraAccount" class="space-y-1.5">
|
||||||
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">
|
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||||
{{ t('admin.accounts.selectTestModel') }}
|
{{ t('admin.accounts.selectTestModel') }}
|
||||||
</label>
|
</label>
|
||||||
@@ -54,6 +54,12 @@
|
|||||||
:placeholder="loadingModels ? t('common.loading') + '...' : t('admin.accounts.selectTestModel')"
|
:placeholder="loadingModels ? t('common.loading') + '...' : t('admin.accounts.selectTestModel')"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
<div
|
||||||
|
v-else
|
||||||
|
class="rounded-lg border border-blue-200 bg-blue-50 px-3 py-2 text-xs text-blue-700 dark:border-blue-700 dark:bg-blue-900/20 dark:text-blue-300"
|
||||||
|
>
|
||||||
|
{{ t('admin.accounts.soraTestHint') }}
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- Terminal Output -->
|
<!-- Terminal Output -->
|
||||||
<div class="group relative">
|
<div class="group relative">
|
||||||
@@ -114,12 +120,12 @@
|
|||||||
<div class="flex items-center gap-3">
|
<div class="flex items-center gap-3">
|
||||||
<span class="flex items-center gap-1">
|
<span class="flex items-center gap-1">
|
||||||
<Icon name="grid" size="sm" :stroke-width="2" />
|
<Icon name="grid" size="sm" :stroke-width="2" />
|
||||||
{{ t('admin.accounts.testModel') }}
|
{{ isSoraAccount ? t('admin.accounts.soraTestTarget') : t('admin.accounts.testModel') }}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
<span class="flex items-center gap-1">
|
<span class="flex items-center gap-1">
|
||||||
<Icon name="chat" size="sm" :stroke-width="2" />
|
<Icon name="chat" size="sm" :stroke-width="2" />
|
||||||
{{ t('admin.accounts.testPrompt') }}
|
{{ isSoraAccount ? t('admin.accounts.soraTestMode') : t('admin.accounts.testPrompt') }}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -135,10 +141,10 @@
|
|||||||
</button>
|
</button>
|
||||||
<button
|
<button
|
||||||
@click="startTest"
|
@click="startTest"
|
||||||
:disabled="status === 'connecting' || !selectedModelId"
|
:disabled="status === 'connecting' || (!isSoraAccount && !selectedModelId)"
|
||||||
:class="[
|
:class="[
|
||||||
'flex items-center gap-2 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
'flex items-center gap-2 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
||||||
status === 'connecting' || !selectedModelId
|
status === 'connecting' || (!isSoraAccount && !selectedModelId)
|
||||||
? 'cursor-not-allowed bg-primary-400 text-white'
|
? 'cursor-not-allowed bg-primary-400 text-white'
|
||||||
: status === 'success'
|
: status === 'success'
|
||||||
? 'bg-green-500 text-white hover:bg-green-600'
|
? 'bg-green-500 text-white hover:bg-green-600'
|
||||||
@@ -172,7 +178,7 @@
|
|||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, watch, nextTick } from 'vue'
|
import { computed, ref, watch, nextTick } from 'vue'
|
||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import BaseDialog from '@/components/common/BaseDialog.vue'
|
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||||
import Select from '@/components/common/Select.vue'
|
import Select from '@/components/common/Select.vue'
|
||||||
@@ -207,6 +213,7 @@ const availableModels = ref<ClaudeModel[]>([])
|
|||||||
const selectedModelId = ref('')
|
const selectedModelId = ref('')
|
||||||
const loadingModels = ref(false)
|
const loadingModels = ref(false)
|
||||||
let eventSource: EventSource | null = null
|
let eventSource: EventSource | null = null
|
||||||
|
const isSoraAccount = computed(() => props.account?.platform === 'sora')
|
||||||
|
|
||||||
// Load available models when modal opens
|
// Load available models when modal opens
|
||||||
watch(
|
watch(
|
||||||
@@ -223,6 +230,12 @@ watch(
|
|||||||
|
|
||||||
const loadAvailableModels = async () => {
|
const loadAvailableModels = async () => {
|
||||||
if (!props.account) return
|
if (!props.account) return
|
||||||
|
if (props.account.platform === 'sora') {
|
||||||
|
availableModels.value = []
|
||||||
|
selectedModelId.value = ''
|
||||||
|
loadingModels.value = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
loadingModels.value = true
|
loadingModels.value = true
|
||||||
selectedModelId.value = '' // Reset selection before loading
|
selectedModelId.value = '' // Reset selection before loading
|
||||||
@@ -290,7 +303,7 @@ const scrollToBottom = async () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const startTest = async () => {
|
const startTest = async () => {
|
||||||
if (!props.account || !selectedModelId.value) return
|
if (!props.account || (!isSoraAccount.value && !selectedModelId.value)) return
|
||||||
|
|
||||||
resetState()
|
resetState()
|
||||||
status.value = 'connecting'
|
status.value = 'connecting'
|
||||||
@@ -311,7 +324,9 @@ const startTest = async () => {
|
|||||||
Authorization: `Bearer ${localStorage.getItem('auth_token')}`,
|
Authorization: `Bearer ${localStorage.getItem('auth_token')}`,
|
||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
},
|
},
|
||||||
body: JSON.stringify({ model_id: selectedModelId.value })
|
body: JSON.stringify(
|
||||||
|
isSoraAccount.value ? {} : { model_id: selectedModelId.value }
|
||||||
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
@@ -368,7 +383,10 @@ const handleEvent = (event: {
|
|||||||
if (event.model) {
|
if (event.model) {
|
||||||
addLine(t('admin.accounts.usingModel', { model: event.model }), 'text-cyan-400')
|
addLine(t('admin.accounts.usingModel', { model: event.model }), 'text-cyan-400')
|
||||||
}
|
}
|
||||||
addLine(t('admin.accounts.sendingTestMessage'), 'text-gray-400')
|
addLine(
|
||||||
|
isSoraAccount.value ? t('admin.accounts.soraTestingFlow') : t('admin.accounts.sendingTestMessage'),
|
||||||
|
'text-gray-400'
|
||||||
|
)
|
||||||
addLine('', 'text-gray-300')
|
addLine('', 'text-gray-300')
|
||||||
addLine(t('admin.accounts.response'), 'text-yellow-400')
|
addLine(t('admin.accounts.response'), 'text-yellow-400')
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
<div
|
<div
|
||||||
:class="[
|
:class="[
|
||||||
'flex h-10 w-10 items-center justify-center rounded-lg bg-gradient-to-br',
|
'flex h-10 w-10 items-center justify-center rounded-lg bg-gradient-to-br',
|
||||||
isOpenAI
|
isOpenAILike
|
||||||
? 'from-green-500 to-green-600'
|
? 'from-green-500 to-green-600'
|
||||||
: isGemini
|
: isGemini
|
||||||
? 'from-blue-500 to-blue-600'
|
? 'from-blue-500 to-blue-600'
|
||||||
@@ -33,6 +33,8 @@
|
|||||||
{{
|
{{
|
||||||
isOpenAI
|
isOpenAI
|
||||||
? t('admin.accounts.openaiAccount')
|
? t('admin.accounts.openaiAccount')
|
||||||
|
: isSora
|
||||||
|
? t('admin.accounts.soraAccount')
|
||||||
: isGemini
|
: isGemini
|
||||||
? t('admin.accounts.geminiAccount')
|
? t('admin.accounts.geminiAccount')
|
||||||
: isAntigravity
|
: isAntigravity
|
||||||
@@ -128,7 +130,7 @@
|
|||||||
:show-cookie-option="isAnthropic"
|
:show-cookie-option="isAnthropic"
|
||||||
:allow-multiple="false"
|
:allow-multiple="false"
|
||||||
:method-label="t('admin.accounts.inputMethod')"
|
:method-label="t('admin.accounts.inputMethod')"
|
||||||
:platform="isOpenAI ? 'openai' : isGemini ? 'gemini' : isAntigravity ? 'antigravity' : 'anthropic'"
|
:platform="isOpenAI ? 'openai' : isSora ? 'sora' : isGemini ? 'gemini' : isAntigravity ? 'antigravity' : 'anthropic'"
|
||||||
:show-project-id="isGemini && geminiOAuthType === 'code_assist'"
|
:show-project-id="isGemini && geminiOAuthType === 'code_assist'"
|
||||||
@generate-url="handleGenerateUrl"
|
@generate-url="handleGenerateUrl"
|
||||||
@cookie-auth="handleCookieAuth"
|
@cookie-auth="handleCookieAuth"
|
||||||
@@ -224,7 +226,8 @@ const { t } = useI18n()
|
|||||||
|
|
||||||
// OAuth composables
|
// OAuth composables
|
||||||
const claudeOAuth = useAccountOAuth()
|
const claudeOAuth = useAccountOAuth()
|
||||||
const openaiOAuth = useOpenAIOAuth()
|
const openaiOAuth = useOpenAIOAuth({ platform: 'openai' })
|
||||||
|
const soraOAuth = useOpenAIOAuth({ platform: 'sora' })
|
||||||
const geminiOAuth = useGeminiOAuth()
|
const geminiOAuth = useGeminiOAuth()
|
||||||
const antigravityOAuth = useAntigravityOAuth()
|
const antigravityOAuth = useAntigravityOAuth()
|
||||||
|
|
||||||
@@ -237,31 +240,34 @@ const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('code_as
|
|||||||
|
|
||||||
// Computed - check platform
|
// Computed - check platform
|
||||||
const isOpenAI = computed(() => props.account?.platform === 'openai')
|
const isOpenAI = computed(() => props.account?.platform === 'openai')
|
||||||
|
const isSora = computed(() => props.account?.platform === 'sora')
|
||||||
|
const isOpenAILike = computed(() => isOpenAI.value || isSora.value)
|
||||||
const isGemini = computed(() => props.account?.platform === 'gemini')
|
const isGemini = computed(() => props.account?.platform === 'gemini')
|
||||||
const isAnthropic = computed(() => props.account?.platform === 'anthropic')
|
const isAnthropic = computed(() => props.account?.platform === 'anthropic')
|
||||||
const isAntigravity = computed(() => props.account?.platform === 'antigravity')
|
const isAntigravity = computed(() => props.account?.platform === 'antigravity')
|
||||||
|
const activeOpenAIOAuth = computed(() => (isSora.value ? soraOAuth : openaiOAuth))
|
||||||
|
|
||||||
// Computed - current OAuth state based on platform
|
// Computed - current OAuth state based on platform
|
||||||
const currentAuthUrl = computed(() => {
|
const currentAuthUrl = computed(() => {
|
||||||
if (isOpenAI.value) return openaiOAuth.authUrl.value
|
if (isOpenAILike.value) return activeOpenAIOAuth.value.authUrl.value
|
||||||
if (isGemini.value) return geminiOAuth.authUrl.value
|
if (isGemini.value) return geminiOAuth.authUrl.value
|
||||||
if (isAntigravity.value) return antigravityOAuth.authUrl.value
|
if (isAntigravity.value) return antigravityOAuth.authUrl.value
|
||||||
return claudeOAuth.authUrl.value
|
return claudeOAuth.authUrl.value
|
||||||
})
|
})
|
||||||
const currentSessionId = computed(() => {
|
const currentSessionId = computed(() => {
|
||||||
if (isOpenAI.value) return openaiOAuth.sessionId.value
|
if (isOpenAILike.value) return activeOpenAIOAuth.value.sessionId.value
|
||||||
if (isGemini.value) return geminiOAuth.sessionId.value
|
if (isGemini.value) return geminiOAuth.sessionId.value
|
||||||
if (isAntigravity.value) return antigravityOAuth.sessionId.value
|
if (isAntigravity.value) return antigravityOAuth.sessionId.value
|
||||||
return claudeOAuth.sessionId.value
|
return claudeOAuth.sessionId.value
|
||||||
})
|
})
|
||||||
const currentLoading = computed(() => {
|
const currentLoading = computed(() => {
|
||||||
if (isOpenAI.value) return openaiOAuth.loading.value
|
if (isOpenAILike.value) return activeOpenAIOAuth.value.loading.value
|
||||||
if (isGemini.value) return geminiOAuth.loading.value
|
if (isGemini.value) return geminiOAuth.loading.value
|
||||||
if (isAntigravity.value) return antigravityOAuth.loading.value
|
if (isAntigravity.value) return antigravityOAuth.loading.value
|
||||||
return claudeOAuth.loading.value
|
return claudeOAuth.loading.value
|
||||||
})
|
})
|
||||||
const currentError = computed(() => {
|
const currentError = computed(() => {
|
||||||
if (isOpenAI.value) return openaiOAuth.error.value
|
if (isOpenAILike.value) return activeOpenAIOAuth.value.error.value
|
||||||
if (isGemini.value) return geminiOAuth.error.value
|
if (isGemini.value) return geminiOAuth.error.value
|
||||||
if (isAntigravity.value) return antigravityOAuth.error.value
|
if (isAntigravity.value) return antigravityOAuth.error.value
|
||||||
return claudeOAuth.error.value
|
return claudeOAuth.error.value
|
||||||
@@ -269,8 +275,8 @@ const currentError = computed(() => {
|
|||||||
|
|
||||||
// Computed
|
// Computed
|
||||||
const isManualInputMethod = computed(() => {
|
const isManualInputMethod = computed(() => {
|
||||||
// OpenAI/Gemini/Antigravity always use manual input (no cookie auth option)
|
// OpenAI/Sora/Gemini/Antigravity always use manual input (no cookie auth option)
|
||||||
return isOpenAI.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual'
|
return isOpenAILike.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual'
|
||||||
})
|
})
|
||||||
|
|
||||||
const canExchangeCode = computed(() => {
|
const canExchangeCode = computed(() => {
|
||||||
@@ -313,6 +319,7 @@ const resetState = () => {
|
|||||||
geminiOAuthType.value = 'code_assist'
|
geminiOAuthType.value = 'code_assist'
|
||||||
claudeOAuth.resetState()
|
claudeOAuth.resetState()
|
||||||
openaiOAuth.resetState()
|
openaiOAuth.resetState()
|
||||||
|
soraOAuth.resetState()
|
||||||
geminiOAuth.resetState()
|
geminiOAuth.resetState()
|
||||||
antigravityOAuth.resetState()
|
antigravityOAuth.resetState()
|
||||||
oauthFlowRef.value?.reset()
|
oauthFlowRef.value?.reset()
|
||||||
@@ -325,8 +332,8 @@ const handleClose = () => {
|
|||||||
const handleGenerateUrl = async () => {
|
const handleGenerateUrl = async () => {
|
||||||
if (!props.account) return
|
if (!props.account) return
|
||||||
|
|
||||||
if (isOpenAI.value) {
|
if (isOpenAILike.value) {
|
||||||
await openaiOAuth.generateAuthUrl(props.account.proxy_id)
|
await activeOpenAIOAuth.value.generateAuthUrl(props.account.proxy_id)
|
||||||
} else if (isGemini.value) {
|
} else if (isGemini.value) {
|
||||||
const creds = (props.account.credentials || {}) as Record<string, unknown>
|
const creds = (props.account.credentials || {}) as Record<string, unknown>
|
||||||
const tierId = typeof creds.tier_id === 'string' ? creds.tier_id : undefined
|
const tierId = typeof creds.tier_id === 'string' ? creds.tier_id : undefined
|
||||||
@@ -345,21 +352,29 @@ const handleExchangeCode = async () => {
|
|||||||
const authCode = oauthFlowRef.value?.authCode || ''
|
const authCode = oauthFlowRef.value?.authCode || ''
|
||||||
if (!authCode.trim()) return
|
if (!authCode.trim()) return
|
||||||
|
|
||||||
if (isOpenAI.value) {
|
if (isOpenAILike.value) {
|
||||||
// OpenAI OAuth flow
|
// OpenAI OAuth flow
|
||||||
const sessionId = openaiOAuth.sessionId.value
|
const oauthClient = activeOpenAIOAuth.value
|
||||||
|
const sessionId = oauthClient.sessionId.value
|
||||||
if (!sessionId) return
|
if (!sessionId) return
|
||||||
|
const stateToUse = (oauthFlowRef.value?.oauthState || oauthClient.oauthState.value || '').trim()
|
||||||
|
if (!stateToUse) {
|
||||||
|
oauthClient.error.value = t('admin.accounts.oauth.authFailed')
|
||||||
|
appStore.showError(oauthClient.error.value)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
const tokenInfo = await openaiOAuth.exchangeAuthCode(
|
const tokenInfo = await oauthClient.exchangeAuthCode(
|
||||||
authCode.trim(),
|
authCode.trim(),
|
||||||
sessionId,
|
sessionId,
|
||||||
|
stateToUse,
|
||||||
props.account.proxy_id
|
props.account.proxy_id
|
||||||
)
|
)
|
||||||
if (!tokenInfo) return
|
if (!tokenInfo) return
|
||||||
|
|
||||||
// Build credentials and extra info
|
// Build credentials and extra info
|
||||||
const credentials = openaiOAuth.buildCredentials(tokenInfo)
|
const credentials = oauthClient.buildCredentials(tokenInfo)
|
||||||
const extra = openaiOAuth.buildExtraInfo(tokenInfo)
|
const extra = oauthClient.buildExtraInfo(tokenInfo)
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Update account with new credentials
|
// Update account with new credentials
|
||||||
@@ -376,8 +391,8 @@ const handleExchangeCode = async () => {
|
|||||||
emit('reauthorized', updatedAccount)
|
emit('reauthorized', updatedAccount)
|
||||||
handleClose()
|
handleClose()
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
oauthClient.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||||
appStore.showError(openaiOAuth.error.value)
|
appStore.showError(oauthClient.error.value)
|
||||||
}
|
}
|
||||||
} else if (isGemini.value) {
|
} else if (isGemini.value) {
|
||||||
const sessionId = geminiOAuth.sessionId.value
|
const sessionId = geminiOAuth.sessionId.value
|
||||||
@@ -490,7 +505,7 @@ const handleExchangeCode = async () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const handleCookieAuth = async (sessionKey: string) => {
|
const handleCookieAuth = async (sessionKey: string) => {
|
||||||
if (!props.account || isOpenAI.value) return
|
if (!props.account || isOpenAILike.value) return
|
||||||
|
|
||||||
claudeOAuth.loading.value = true
|
claudeOAuth.loading.value = true
|
||||||
claudeOAuth.error.value = ''
|
claudeOAuth.error.value = ''
|
||||||
|
|||||||
@@ -70,6 +70,8 @@
|
|||||||
<div v-if="row.cache_creation_tokens > 0" class="inline-flex items-center gap-1">
|
<div v-if="row.cache_creation_tokens > 0" class="inline-flex items-center gap-1">
|
||||||
<svg class="h-3.5 w-3.5 text-amber-500" fill="none" stroke="currentColor" viewBox="0 0 24 24"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z" /></svg>
|
<svg class="h-3.5 w-3.5 text-amber-500" fill="none" stroke="currentColor" viewBox="0 0 24 24"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z" /></svg>
|
||||||
<span class="font-medium text-amber-600 dark:text-amber-400">{{ formatCacheTokens(row.cache_creation_tokens) }}</span>
|
<span class="font-medium text-amber-600 dark:text-amber-400">{{ formatCacheTokens(row.cache_creation_tokens) }}</span>
|
||||||
|
<span v-if="row.cache_creation_1h_tokens > 0" class="inline-flex items-center rounded px-1 py-px text-[10px] font-medium leading-tight bg-orange-100 text-orange-600 ring-1 ring-inset ring-orange-200 dark:bg-orange-500/20 dark:text-orange-400 dark:ring-orange-500/30">1h</span>
|
||||||
|
<span v-if="row.cache_ttl_overridden" :title="t('usage.cacheTtlOverriddenHint')" class="inline-flex items-center rounded px-1 py-px text-[10px] font-medium leading-tight bg-rose-100 text-rose-600 ring-1 ring-inset ring-rose-200 dark:bg-rose-500/20 dark:text-rose-400 dark:ring-rose-500/30 cursor-help">R</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -157,9 +159,36 @@
|
|||||||
<span class="text-gray-400">{{ t('admin.usage.outputTokens') }}</span>
|
<span class="text-gray-400">{{ t('admin.usage.outputTokens') }}</span>
|
||||||
<span class="font-medium text-white">{{ tokenTooltipData.output_tokens.toLocaleString() }}</span>
|
<span class="font-medium text-white">{{ tokenTooltipData.output_tokens.toLocaleString() }}</span>
|
||||||
</div>
|
</div>
|
||||||
<div v-if="tokenTooltipData && tokenTooltipData.cache_creation_tokens > 0" class="flex items-center justify-between gap-4">
|
<div v-if="tokenTooltipData && tokenTooltipData.cache_creation_tokens > 0">
|
||||||
<span class="text-gray-400">{{ t('admin.usage.cacheCreationTokens') }}</span>
|
<!-- 有 5m/1h 明细时,展开显示 -->
|
||||||
<span class="font-medium text-white">{{ tokenTooltipData.cache_creation_tokens.toLocaleString() }}</span>
|
<template v-if="tokenTooltipData.cache_creation_5m_tokens > 0 || tokenTooltipData.cache_creation_1h_tokens > 0">
|
||||||
|
<div v-if="tokenTooltipData.cache_creation_5m_tokens > 0" class="flex items-center justify-between gap-4">
|
||||||
|
<span class="text-gray-400 flex items-center gap-1.5">
|
||||||
|
{{ t('admin.usage.cacheCreation5mTokens') }}
|
||||||
|
<span class="inline-flex items-center rounded px-1 py-px text-[10px] font-medium leading-tight bg-amber-500/20 text-amber-400 ring-1 ring-inset ring-amber-500/30">5m</span>
|
||||||
|
</span>
|
||||||
|
<span class="font-medium text-white">{{ tokenTooltipData.cache_creation_5m_tokens.toLocaleString() }}</span>
|
||||||
|
</div>
|
||||||
|
<div v-if="tokenTooltipData.cache_creation_1h_tokens > 0" class="flex items-center justify-between gap-4">
|
||||||
|
<span class="text-gray-400 flex items-center gap-1.5">
|
||||||
|
{{ t('admin.usage.cacheCreation1hTokens') }}
|
||||||
|
<span class="inline-flex items-center rounded px-1 py-px text-[10px] font-medium leading-tight bg-orange-500/20 text-orange-400 ring-1 ring-inset ring-orange-500/30">1h</span>
|
||||||
|
</span>
|
||||||
|
<span class="font-medium text-white">{{ tokenTooltipData.cache_creation_1h_tokens.toLocaleString() }}</span>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
<!-- 无明细时,只显示聚合值 -->
|
||||||
|
<div v-else class="flex items-center justify-between gap-4">
|
||||||
|
<span class="text-gray-400">{{ t('admin.usage.cacheCreationTokens') }}</span>
|
||||||
|
<span class="font-medium text-white">{{ tokenTooltipData.cache_creation_tokens.toLocaleString() }}</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div v-if="tokenTooltipData && tokenTooltipData.cache_ttl_overridden" class="flex items-center justify-between gap-4">
|
||||||
|
<span class="text-gray-400 flex items-center gap-1.5">
|
||||||
|
{{ t('usage.cacheTtlOverriddenLabel') }}
|
||||||
|
<span class="inline-flex items-center rounded px-1 py-px text-[10px] font-medium leading-tight bg-rose-500/20 text-rose-400 ring-1 ring-inset ring-rose-500/30">R-{{ tokenTooltipData.cache_creation_1h_tokens > 0 ? '5m' : '1H' }}</span>
|
||||||
|
</span>
|
||||||
|
<span class="font-medium text-rose-400">{{ tokenTooltipData.cache_creation_1h_tokens > 0 ? t('usage.cacheTtlOverridden1h') : t('usage.cacheTtlOverridden5m') }}</span>
|
||||||
</div>
|
</div>
|
||||||
<div v-if="tokenTooltipData && tokenTooltipData.cache_read_tokens > 0" class="flex items-center justify-between gap-4">
|
<div v-if="tokenTooltipData && tokenTooltipData.cache_read_tokens > 0" class="flex items-center justify-between gap-4">
|
||||||
<span class="text-gray-400">{{ t('admin.usage.cacheReadTokens') }}</span>
|
<span class="text-gray-400">{{ t('admin.usage.cacheReadTokens') }}</span>
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
<div class="min-w-0 flex-1">
|
<div class="min-w-0 flex-1">
|
||||||
<p class="stat-label truncate">{{ title }}</p>
|
<p class="stat-label truncate">{{ title }}</p>
|
||||||
<div class="mt-1 flex items-baseline gap-2">
|
<div class="mt-1 flex items-baseline gap-2">
|
||||||
<p class="stat-value">{{ formattedValue }}</p>
|
<p class="stat-value" :title="String(formattedValue)">{{ formattedValue }}</p>
|
||||||
<span v-if="change !== undefined" :class="['stat-trend', trendClass]">
|
<span v-if="change !== undefined" :class="['stat-trend', trendClass]">
|
||||||
<Icon
|
<Icon
|
||||||
v-if="changeType !== 'neutral'"
|
v-if="changeType !== 'neutral'"
|
||||||
|
|||||||
@@ -10,7 +10,7 @@
|
|||||||
<div class="sidebar-header">
|
<div class="sidebar-header">
|
||||||
<!-- Custom Logo or Default Logo -->
|
<!-- Custom Logo or Default Logo -->
|
||||||
<div class="flex h-9 w-9 items-center justify-center overflow-hidden rounded-xl shadow-glow">
|
<div class="flex h-9 w-9 items-center justify-center overflow-hidden rounded-xl shadow-glow">
|
||||||
<img :src="siteLogo || '/logo.png'" alt="Logo" class="h-full w-full object-contain" />
|
<img v-if="settingsLoaded" :src="siteLogo || '/logo.png'" alt="Logo" class="h-full w-full object-contain" />
|
||||||
</div>
|
</div>
|
||||||
<transition name="fade">
|
<transition name="fade">
|
||||||
<div v-if="!sidebarCollapsed" class="flex flex-col">
|
<div v-if="!sidebarCollapsed" class="flex flex-col">
|
||||||
@@ -167,6 +167,7 @@ const isDark = ref(document.documentElement.classList.contains('dark'))
|
|||||||
const siteName = computed(() => appStore.siteName)
|
const siteName = computed(() => appStore.siteName)
|
||||||
const siteLogo = computed(() => appStore.siteLogo)
|
const siteLogo = computed(() => appStore.siteLogo)
|
||||||
const siteVersion = computed(() => appStore.siteVersion)
|
const siteVersion = computed(() => appStore.siteVersion)
|
||||||
|
const settingsLoaded = computed(() => appStore.publicSettingsLoaded)
|
||||||
|
|
||||||
// SVG Icon Components
|
// SVG Icon Components
|
||||||
const DashboardIcon = {
|
const DashboardIcon = {
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import { useAppStore } from '@/stores/app'
|
|||||||
import { adminAPI } from '@/api/admin'
|
import { adminAPI } from '@/api/admin'
|
||||||
|
|
||||||
export type AddMethod = 'oauth' | 'setup-token'
|
export type AddMethod = 'oauth' | 'setup-token'
|
||||||
export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token'
|
export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' | 'session_token'
|
||||||
|
|
||||||
export interface OAuthState {
|
export interface OAuthState {
|
||||||
authUrl: string
|
authUrl: string
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user