mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 07:22:13 +08:00
Merge branch 'test' into release
This commit is contained in:
5
backend/.gosec.json
Normal file
5
backend/.gosec.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"global": {
|
||||
"exclude": "G704"
|
||||
}
|
||||
}
|
||||
@@ -1 +1 @@
|
||||
0.1.74.9
|
||||
0.1.83.2
|
||||
|
||||
@@ -184,7 +184,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig)
|
||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig)
|
||||
soraDirectClient := service.NewSoraDirectClient(configConfig, httpUpstream, openAITokenProvider)
|
||||
soraDirectClient := service.ProvideSoraDirectClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
|
||||
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
|
||||
soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig)
|
||||
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig)
|
||||
|
||||
@@ -669,6 +669,7 @@ var (
|
||||
{Name: "image_count", Type: field.TypeInt, Default: 0},
|
||||
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
|
||||
{Name: "media_type", Type: field.TypeString, Nullable: true, Size: 16},
|
||||
{Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false},
|
||||
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "api_key_id", Type: field.TypeInt64},
|
||||
{Name: "account_id", Type: field.TypeInt64},
|
||||
@@ -684,31 +685,31 @@ var (
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "usage_logs_api_keys_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_accounts_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||
RefColumns: []*schema.Column{AccountsColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_groups_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_users_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
@@ -717,32 +718,32 @@ var (
|
||||
{
|
||||
Name: "usagelog_user_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_api_key_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_account_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_group_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_subscription_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[26]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_model",
|
||||
@@ -757,12 +758,12 @@ var (
|
||||
{
|
||||
Name: "usagelog_user_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[26]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[27]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_api_key_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[27], UsageLogsColumns[26]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -15980,6 +15980,7 @@ type UsageLogMutation struct {
|
||||
addimage_count *int
|
||||
image_size *string
|
||||
media_type *string
|
||||
cache_ttl_overridden *bool
|
||||
created_at *time.Time
|
||||
clearedFields map[string]struct{}
|
||||
user *int64
|
||||
@@ -17655,6 +17656,42 @@ func (m *UsageLogMutation) ResetMediaType() {
|
||||
delete(m.clearedFields, usagelog.FieldMediaType)
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (m *UsageLogMutation) SetCacheTTLOverridden(b bool) {
|
||||
m.cache_ttl_overridden = &b
|
||||
}
|
||||
|
||||
// CacheTTLOverridden returns the value of the "cache_ttl_overridden" field in the mutation.
|
||||
func (m *UsageLogMutation) CacheTTLOverridden() (r bool, exists bool) {
|
||||
v := m.cache_ttl_overridden
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldCacheTTLOverridden returns the old "cache_ttl_overridden" field's value of the UsageLog entity.
|
||||
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *UsageLogMutation) OldCacheTTLOverridden(ctx context.Context) (v bool, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldCacheTTLOverridden is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldCacheTTLOverridden requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldCacheTTLOverridden: %w", err)
|
||||
}
|
||||
return oldValue.CacheTTLOverridden, nil
|
||||
}
|
||||
|
||||
// ResetCacheTTLOverridden resets all changes to the "cache_ttl_overridden" field.
|
||||
func (m *UsageLogMutation) ResetCacheTTLOverridden() {
|
||||
m.cache_ttl_overridden = nil
|
||||
}
|
||||
|
||||
// SetCreatedAt sets the "created_at" field.
|
||||
func (m *UsageLogMutation) SetCreatedAt(t time.Time) {
|
||||
m.created_at = &t
|
||||
@@ -17860,7 +17897,7 @@ func (m *UsageLogMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *UsageLogMutation) Fields() []string {
|
||||
fields := make([]string, 0, 31)
|
||||
fields := make([]string, 0, 32)
|
||||
if m.user != nil {
|
||||
fields = append(fields, usagelog.FieldUserID)
|
||||
}
|
||||
@@ -17951,6 +17988,9 @@ func (m *UsageLogMutation) Fields() []string {
|
||||
if m.media_type != nil {
|
||||
fields = append(fields, usagelog.FieldMediaType)
|
||||
}
|
||||
if m.cache_ttl_overridden != nil {
|
||||
fields = append(fields, usagelog.FieldCacheTTLOverridden)
|
||||
}
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, usagelog.FieldCreatedAt)
|
||||
}
|
||||
@@ -18022,6 +18062,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.ImageSize()
|
||||
case usagelog.FieldMediaType:
|
||||
return m.MediaType()
|
||||
case usagelog.FieldCacheTTLOverridden:
|
||||
return m.CacheTTLOverridden()
|
||||
case usagelog.FieldCreatedAt:
|
||||
return m.CreatedAt()
|
||||
}
|
||||
@@ -18093,6 +18135,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
|
||||
return m.OldImageSize(ctx)
|
||||
case usagelog.FieldMediaType:
|
||||
return m.OldMediaType(ctx)
|
||||
case usagelog.FieldCacheTTLOverridden:
|
||||
return m.OldCacheTTLOverridden(ctx)
|
||||
case usagelog.FieldCreatedAt:
|
||||
return m.OldCreatedAt(ctx)
|
||||
}
|
||||
@@ -18314,6 +18358,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetMediaType(v)
|
||||
return nil
|
||||
case usagelog.FieldCacheTTLOverridden:
|
||||
v, ok := value.(bool)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetCacheTTLOverridden(v)
|
||||
return nil
|
||||
case usagelog.FieldCreatedAt:
|
||||
v, ok := value.(time.Time)
|
||||
if !ok {
|
||||
@@ -18736,6 +18787,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
|
||||
case usagelog.FieldMediaType:
|
||||
m.ResetMediaType()
|
||||
return nil
|
||||
case usagelog.FieldCacheTTLOverridden:
|
||||
m.ResetCacheTTLOverridden()
|
||||
return nil
|
||||
case usagelog.FieldCreatedAt:
|
||||
m.ResetCreatedAt()
|
||||
return nil
|
||||
|
||||
@@ -821,8 +821,12 @@ func init() {
|
||||
usagelogDescMediaType := usagelogFields[29].Descriptor()
|
||||
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
|
||||
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
|
||||
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
|
||||
usagelogDescCacheTTLOverridden := usagelogFields[30].Descriptor()
|
||||
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
|
||||
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
|
||||
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
||||
usagelogDescCreatedAt := usagelogFields[30].Descriptor()
|
||||
usagelogDescCreatedAt := usagelogFields[31].Descriptor()
|
||||
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
||||
userMixin := schema.User{}.Mixin()
|
||||
|
||||
@@ -124,6 +124,10 @@ func (UsageLog) Fields() []ent.Field {
|
||||
Optional().
|
||||
Nillable(),
|
||||
|
||||
// Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费)
|
||||
field.Bool("cache_ttl_overridden").
|
||||
Default(false),
|
||||
|
||||
// 时间戳(只有 created_at,日志不可修改)
|
||||
field.Time("created_at").
|
||||
Default(time.Now).
|
||||
|
||||
@@ -82,6 +82,8 @@ type UsageLog struct {
|
||||
ImageSize *string `json:"image_size,omitempty"`
|
||||
// MediaType holds the value of the "media_type" field.
|
||||
MediaType *string `json:"media_type,omitempty"`
|
||||
// CacheTTLOverridden holds the value of the "cache_ttl_overridden" field.
|
||||
CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"`
|
||||
// CreatedAt holds the value of the "created_at" field.
|
||||
CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
@@ -167,7 +169,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
|
||||
values := make([]any, len(columns))
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
case usagelog.FieldStream:
|
||||
case usagelog.FieldStream, usagelog.FieldCacheTTLOverridden:
|
||||
values[i] = new(sql.NullBool)
|
||||
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
@@ -387,6 +389,12 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
|
||||
_m.MediaType = new(string)
|
||||
*_m.MediaType = value.String
|
||||
}
|
||||
case usagelog.FieldCacheTTLOverridden:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i])
|
||||
} else if value.Valid {
|
||||
_m.CacheTTLOverridden = value.Bool
|
||||
}
|
||||
case usagelog.FieldCreatedAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field created_at", values[i])
|
||||
@@ -562,6 +570,9 @@ func (_m *UsageLog) String() string {
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("cache_ttl_overridden=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("created_at=")
|
||||
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
|
||||
builder.WriteByte(')')
|
||||
|
||||
@@ -74,6 +74,8 @@ const (
|
||||
FieldImageSize = "image_size"
|
||||
// FieldMediaType holds the string denoting the media_type field in the database.
|
||||
FieldMediaType = "media_type"
|
||||
// FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database.
|
||||
FieldCacheTTLOverridden = "cache_ttl_overridden"
|
||||
// FieldCreatedAt holds the string denoting the created_at field in the database.
|
||||
FieldCreatedAt = "created_at"
|
||||
// EdgeUser holds the string denoting the user edge name in mutations.
|
||||
@@ -158,6 +160,7 @@ var Columns = []string{
|
||||
FieldImageCount,
|
||||
FieldImageSize,
|
||||
FieldMediaType,
|
||||
FieldCacheTTLOverridden,
|
||||
FieldCreatedAt,
|
||||
}
|
||||
|
||||
@@ -216,6 +219,8 @@ var (
|
||||
ImageSizeValidator func(string) error
|
||||
// MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
|
||||
MediaTypeValidator func(string) error
|
||||
// DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field.
|
||||
DefaultCacheTTLOverridden bool
|
||||
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
||||
DefaultCreatedAt func() time.Time
|
||||
)
|
||||
@@ -378,6 +383,11 @@ func ByMediaType(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldMediaType, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByCacheTTLOverridden orders the results by the cache_ttl_overridden field.
|
||||
func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByCreatedAt orders the results by the created_at field.
|
||||
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
|
||||
|
||||
@@ -205,6 +205,11 @@ func MediaType(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v))
|
||||
}
|
||||
|
||||
// CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ.
|
||||
func CacheTTLOverridden(v bool) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
|
||||
}
|
||||
|
||||
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
|
||||
func CreatedAt(v time.Time) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v))
|
||||
@@ -1520,6 +1525,16 @@ func MediaTypeContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldMediaType, v))
|
||||
}
|
||||
|
||||
// CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field.
|
||||
func CacheTTLOverriddenEQ(v bool) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
|
||||
}
|
||||
|
||||
// CacheTTLOverriddenNEQ applies the NEQ predicate on the "cache_ttl_overridden" field.
|
||||
func CacheTTLOverriddenNEQ(v bool) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldCacheTTLOverridden, v))
|
||||
}
|
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v))
|
||||
|
||||
@@ -407,6 +407,20 @@ func (_c *UsageLogCreate) SetNillableMediaType(v *string) *UsageLogCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate {
|
||||
_c.mutation.SetCacheTTLOverridden(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableCacheTTLOverridden(v *bool) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetCacheTTLOverridden(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetCreatedAt sets the "created_at" field.
|
||||
func (_c *UsageLogCreate) SetCreatedAt(v time.Time) *UsageLogCreate {
|
||||
_c.mutation.SetCreatedAt(v)
|
||||
@@ -545,6 +559,10 @@ func (_c *UsageLogCreate) defaults() {
|
||||
v := usagelog.DefaultImageCount
|
||||
_c.mutation.SetImageCount(v)
|
||||
}
|
||||
if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
|
||||
v := usagelog.DefaultCacheTTLOverridden
|
||||
_c.mutation.SetCacheTTLOverridden(v)
|
||||
}
|
||||
if _, ok := _c.mutation.CreatedAt(); !ok {
|
||||
v := usagelog.DefaultCreatedAt()
|
||||
_c.mutation.SetCreatedAt(v)
|
||||
@@ -646,6 +664,9 @@ func (_c *UsageLogCreate) check() error {
|
||||
return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
|
||||
return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.CreatedAt(); !ok {
|
||||
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UsageLog.created_at"`)}
|
||||
}
|
||||
@@ -785,6 +806,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
|
||||
_node.MediaType = &value
|
||||
}
|
||||
if value, ok := _c.mutation.CacheTTLOverridden(); ok {
|
||||
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
|
||||
_node.CacheTTLOverridden = value
|
||||
}
|
||||
if value, ok := _c.mutation.CreatedAt(); ok {
|
||||
_spec.SetField(usagelog.FieldCreatedAt, field.TypeTime, value)
|
||||
_node.CreatedAt = value
|
||||
@@ -1448,6 +1473,18 @@ func (u *UsageLogUpsert) ClearMediaType() *UsageLogUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldCacheTTLOverridden, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateCacheTTLOverridden() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldCacheTTLOverridden)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||
// Using this option is equivalent to using:
|
||||
//
|
||||
@@ -2102,6 +2139,20 @@ func (u *UsageLogUpsertOne) ClearMediaType() *UsageLogUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetCacheTTLOverridden(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateCacheTTLOverridden() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateCacheTTLOverridden()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *UsageLogUpsertOne) Exec(ctx context.Context) error {
|
||||
if len(u.create.conflict) == 0 {
|
||||
@@ -2922,6 +2973,20 @@ func (u *UsageLogUpsertBulk) ClearMediaType() *UsageLogUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetCacheTTLOverridden(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateCacheTTLOverridden() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateCacheTTLOverridden()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *UsageLogUpsertBulk) Exec(ctx context.Context) error {
|
||||
if u.create.err != nil {
|
||||
|
||||
@@ -632,6 +632,20 @@ func (_u *UsageLogUpdate) ClearMediaType() *UsageLogUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate {
|
||||
_u.mutation.SetCacheTTLOverridden(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableCacheTTLOverridden(v *bool) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetCacheTTLOverridden(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUser sets the "user" edge to the User entity.
|
||||
func (_u *UsageLogUpdate) SetUser(v *User) *UsageLogUpdate {
|
||||
return _u.SetUserID(v.ID)
|
||||
@@ -925,6 +939,9 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if _u.mutation.MediaTypeCleared() {
|
||||
_spec.ClearField(usagelog.FieldMediaType, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
|
||||
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
|
||||
}
|
||||
if _u.mutation.UserCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
@@ -1690,6 +1707,20 @@ func (_u *UsageLogUpdateOne) ClearMediaType() *UsageLogUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne {
|
||||
_u.mutation.SetCacheTTLOverridden(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableCacheTTLOverridden(v *bool) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetCacheTTLOverridden(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUser sets the "user" edge to the User entity.
|
||||
func (_u *UsageLogUpdateOne) SetUser(v *User) *UsageLogUpdateOne {
|
||||
return _u.SetUserID(v.ID)
|
||||
@@ -2013,6 +2044,9 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
|
||||
if _u.mutation.MediaTypeCleared() {
|
||||
_spec.ClearField(usagelog.FieldMediaType, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
|
||||
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
|
||||
}
|
||||
if _u.mutation.UserCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
|
||||
@@ -162,6 +162,8 @@ type TokenRefreshConfig struct {
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
// 重试退避基础时间(秒)
|
||||
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
|
||||
// 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token(默认关闭)
|
||||
SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"`
|
||||
}
|
||||
|
||||
type PricingConfig struct {
|
||||
@@ -269,17 +271,30 @@ type SoraConfig struct {
|
||||
|
||||
// SoraClientConfig 直连 Sora 客户端配置
|
||||
type SoraClientConfig struct {
|
||||
BaseURL string `mapstructure:"base_url"`
|
||||
TimeoutSeconds int `mapstructure:"timeout_seconds"`
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
|
||||
MaxPollAttempts int `mapstructure:"max_poll_attempts"`
|
||||
RecentTaskLimit int `mapstructure:"recent_task_limit"`
|
||||
RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
|
||||
Debug bool `mapstructure:"debug"`
|
||||
Headers map[string]string `mapstructure:"headers"`
|
||||
UserAgent string `mapstructure:"user_agent"`
|
||||
DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
|
||||
BaseURL string `mapstructure:"base_url"`
|
||||
TimeoutSeconds int `mapstructure:"timeout_seconds"`
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"`
|
||||
PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
|
||||
MaxPollAttempts int `mapstructure:"max_poll_attempts"`
|
||||
RecentTaskLimit int `mapstructure:"recent_task_limit"`
|
||||
RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
|
||||
Debug bool `mapstructure:"debug"`
|
||||
UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"`
|
||||
Headers map[string]string `mapstructure:"headers"`
|
||||
UserAgent string `mapstructure:"user_agent"`
|
||||
DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
|
||||
CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"`
|
||||
}
|
||||
|
||||
// SoraCurlCFFISidecarConfig Sora 专用 curl_cffi sidecar 配置
|
||||
type SoraCurlCFFISidecarConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
BaseURL string `mapstructure:"base_url"`
|
||||
Impersonate string `mapstructure:"impersonate"`
|
||||
TimeoutSeconds int `mapstructure:"timeout_seconds"`
|
||||
SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"`
|
||||
SessionTTLSeconds int `mapstructure:"session_ttl_seconds"`
|
||||
}
|
||||
|
||||
// SoraStorageConfig 媒体存储配置
|
||||
@@ -1111,14 +1126,22 @@ func setDefaults() {
|
||||
viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend")
|
||||
viper.SetDefault("sora.client.timeout_seconds", 120)
|
||||
viper.SetDefault("sora.client.max_retries", 3)
|
||||
viper.SetDefault("sora.client.cloudflare_challenge_cooldown_seconds", 900)
|
||||
viper.SetDefault("sora.client.poll_interval_seconds", 2)
|
||||
viper.SetDefault("sora.client.max_poll_attempts", 600)
|
||||
viper.SetDefault("sora.client.recent_task_limit", 50)
|
||||
viper.SetDefault("sora.client.recent_task_limit_max", 200)
|
||||
viper.SetDefault("sora.client.debug", false)
|
||||
viper.SetDefault("sora.client.use_openai_token_provider", false)
|
||||
viper.SetDefault("sora.client.headers", map[string]string{})
|
||||
viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||
viper.SetDefault("sora.client.disable_tls_fingerprint", false)
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.enabled", true)
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.base_url", "http://sora-curl-cffi-sidecar:8080")
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.impersonate", "chrome131")
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.timeout_seconds", 60)
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.session_reuse_enabled", true)
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.session_ttl_seconds", 3600)
|
||||
|
||||
viper.SetDefault("sora.storage.type", "local")
|
||||
viper.SetDefault("sora.storage.local_path", "")
|
||||
@@ -1137,6 +1160,7 @@ func setDefaults() {
|
||||
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token)
|
||||
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
||||
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
||||
viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token
|
||||
|
||||
// Gemini OAuth - configure via environment variables or config file
|
||||
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
|
||||
@@ -1505,6 +1529,9 @@ func (c *Config) Validate() error {
|
||||
if c.Sora.Client.MaxRetries < 0 {
|
||||
return fmt.Errorf("sora.client.max_retries must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.CloudflareChallengeCooldownSeconds < 0 {
|
||||
return fmt.Errorf("sora.client.cloudflare_challenge_cooldown_seconds must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.PollIntervalSeconds < 0 {
|
||||
return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative")
|
||||
}
|
||||
@@ -1521,6 +1548,18 @@ func (c *Config) Validate() error {
|
||||
c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit {
|
||||
c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit
|
||||
}
|
||||
if c.Sora.Client.CurlCFFISidecar.TimeoutSeconds < 0 {
|
||||
return fmt.Errorf("sora.client.curl_cffi_sidecar.timeout_seconds must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.CurlCFFISidecar.SessionTTLSeconds < 0 {
|
||||
return fmt.Errorf("sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative")
|
||||
}
|
||||
if !c.Sora.Client.CurlCFFISidecar.Enabled {
|
||||
return fmt.Errorf("sora.client.curl_cffi_sidecar.enabled must be true")
|
||||
}
|
||||
if strings.TrimSpace(c.Sora.Client.CurlCFFISidecar.BaseURL) == "" {
|
||||
return fmt.Errorf("sora.client.curl_cffi_sidecar.base_url is required")
|
||||
}
|
||||
if c.Sora.Storage.MaxConcurrentDownloads < 0 {
|
||||
return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative")
|
||||
}
|
||||
|
||||
@@ -1024,3 +1024,91 @@ func TestValidateConfigErrors(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoraCurlCFFISidecarDefaults(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
if !cfg.Sora.Client.CurlCFFISidecar.Enabled {
|
||||
t.Fatalf("Sora curl_cffi sidecar should be enabled by default")
|
||||
}
|
||||
if cfg.Sora.Client.CloudflareChallengeCooldownSeconds <= 0 {
|
||||
t.Fatalf("Sora cloudflare challenge cooldown should be positive by default")
|
||||
}
|
||||
if cfg.Sora.Client.CurlCFFISidecar.BaseURL == "" {
|
||||
t.Fatalf("Sora curl_cffi sidecar base_url should not be empty by default")
|
||||
}
|
||||
if cfg.Sora.Client.CurlCFFISidecar.Impersonate == "" {
|
||||
t.Fatalf("Sora curl_cffi sidecar impersonate should not be empty by default")
|
||||
}
|
||||
if !cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled {
|
||||
t.Fatalf("Sora curl_cffi sidecar session reuse should be enabled by default")
|
||||
}
|
||||
if cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds <= 0 {
|
||||
t.Fatalf("Sora curl_cffi sidecar session ttl should be positive by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSoraCurlCFFISidecarRequired(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Sora.Client.CurlCFFISidecar.Enabled = false
|
||||
err = cfg.Validate()
|
||||
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.enabled must be true") {
|
||||
t.Fatalf("Validate() error = %v, want sidecar enabled error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSoraCurlCFFISidecarBaseURLRequired(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Sora.Client.CurlCFFISidecar.BaseURL = " "
|
||||
err = cfg.Validate()
|
||||
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.base_url is required") {
|
||||
t.Fatalf("Validate() error = %v, want sidecar base_url required error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSoraCurlCFFISidecarSessionTTLNonNegative(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds = -1
|
||||
err = cfg.Validate()
|
||||
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") {
|
||||
t.Fatalf("Validate() error = %v, want sidecar session ttl error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSoraCloudflareChallengeCooldownNonNegative(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Sora.Client.CloudflareChallengeCooldownSeconds = -1
|
||||
err = cfg.Validate()
|
||||
if err == nil || !strings.Contains(err.Error(), "sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") {
|
||||
t.Fatalf("Validate() error = %v, want cloudflare cooldown error", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -341,7 +341,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc
|
||||
pageSize := dataPageCap
|
||||
var out []service.Account
|
||||
for {
|
||||
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search)
|
||||
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -200,7 +200,12 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
search = search[:100]
|
||||
}
|
||||
|
||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
|
||||
var groupID int64
|
||||
if groupIDStr := c.Query("group"); groupIDStr != "" {
|
||||
groupID, _ = strconv.ParseInt(groupIDStr, 10, 64)
|
||||
}
|
||||
|
||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -1433,6 +1438,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Sora accounts
|
||||
if account.Platform == service.PlatformSora {
|
||||
response.Success(c, service.DefaultSoraModels(nil))
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Claude/Anthropic accounts
|
||||
// For OAuth and Setup-Token accounts: return default models
|
||||
if account.IsOAuth() {
|
||||
@@ -1542,7 +1553,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
|
||||
accounts := make([]*service.Account, 0)
|
||||
|
||||
if len(req.AccountIDs) == 0 {
|
||||
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "")
|
||||
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -47,6 +47,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
|
||||
router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete)
|
||||
router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete)
|
||||
router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test)
|
||||
router.POST("/api/v1/admin/proxies/:id/quality-check", proxyHandler.CheckQuality)
|
||||
router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats)
|
||||
router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts)
|
||||
|
||||
@@ -208,6 +209,11 @@ func TestProxyHandlerEndpoints(t *testing.T) {
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/quality-check", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
@@ -166,7 +166,7 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) {
|
||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
|
||||
return s.accounts, int64(len(s.accounts)), nil
|
||||
}
|
||||
|
||||
@@ -327,6 +327,27 @@ func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.Pr
|
||||
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*service.ProxyQualityCheckResult, error) {
|
||||
return &service.ProxyQualityCheckResult{
|
||||
ProxyID: id,
|
||||
Score: 95,
|
||||
Grade: "A",
|
||||
Summary: "通过 5 项,告警 0 项,失败 0 项,挑战 0 项",
|
||||
PassedCount: 5,
|
||||
WarnCount: 0,
|
||||
FailedCount: 0,
|
||||
ChallengeCount: 0,
|
||||
CheckedAt: time.Now().Unix(),
|
||||
Items: []service.ProxyQualityCheckItem{
|
||||
{Target: "base_connectivity", Status: "pass", Message: "ok"},
|
||||
{Target: "openai", Status: "pass", HTTPStatus: 401},
|
||||
{Target: "anthropic", Status: "pass", HTTPStatus: 401},
|
||||
{Target: "gemini", Status: "pass", HTTPStatus: 200},
|
||||
{Target: "sora", Status: "pass", HTTPStatus: 401},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
|
||||
return s.redeems, int64(len(s.redeems)), nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
@@ -16,6 +17,13 @@ type OpenAIOAuthHandler struct {
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
func oauthPlatformFromPath(c *gin.Context) string {
|
||||
if strings.Contains(c.FullPath(), "/admin/sora/") {
|
||||
return service.PlatformSora
|
||||
}
|
||||
return service.PlatformOpenAI
|
||||
}
|
||||
|
||||
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
|
||||
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
|
||||
return &OpenAIOAuthHandler{
|
||||
@@ -52,6 +60,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
type OpenAIExchangeCodeRequest struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
State string `json:"state" binding:"required"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
@@ -68,6 +77,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
Code: req.Code,
|
||||
State: req.State,
|
||||
RedirectURI: req.RedirectURI,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
@@ -81,18 +91,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
|
||||
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
|
||||
type OpenAIRefreshTokenRequest struct {
|
||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
RT string `json:"rt"`
|
||||
ClientID string `json:"client_id"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an OpenAI OAuth token
|
||||
// POST /api/v1/admin/openai/refresh-token
|
||||
// POST /api/v1/admin/sora/rt2at
|
||||
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
var req OpenAIRefreshTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
refreshToken := strings.TrimSpace(req.RefreshToken)
|
||||
if refreshToken == "" {
|
||||
refreshToken = strings.TrimSpace(req.RT)
|
||||
}
|
||||
if refreshToken == "" {
|
||||
response.BadRequest(c, "refresh_token is required")
|
||||
return
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if req.ProxyID != nil {
|
||||
@@ -102,7 +123,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, strings.TrimSpace(req.ClientID))
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -111,8 +132,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for a specific OpenAI account
|
||||
// ExchangeSoraSessionToken exchanges Sora session token to access token
|
||||
// POST /api/v1/admin/sora/st2at
|
||||
func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) {
|
||||
var req struct {
|
||||
SessionToken string `json:"session_token"`
|
||||
ST string `json:"st"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sessionToken := strings.TrimSpace(req.SessionToken)
|
||||
if sessionToken == "" {
|
||||
sessionToken = strings.TrimSpace(req.ST)
|
||||
}
|
||||
if sessionToken == "" {
|
||||
response.BadRequest(c, "session_token is required")
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for a specific OpenAI/Sora account
|
||||
// POST /api/v1/admin/openai/accounts/:id/refresh
|
||||
// POST /api/v1/admin/sora/accounts/:id/refresh
|
||||
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
@@ -127,9 +179,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure account is OpenAI platform
|
||||
if !account.IsOpenAI() {
|
||||
response.BadRequest(c, "Account is not an OpenAI account")
|
||||
platform := oauthPlatformFromPath(c)
|
||||
if account.Platform != platform {
|
||||
response.BadRequest(c, "Account platform does not match OAuth endpoint")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -167,12 +219,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
response.Success(c, dto.AccountFromService(updatedAccount))
|
||||
}
|
||||
|
||||
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
|
||||
// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info
|
||||
// POST /api/v1/admin/openai/create-from-oauth
|
||||
// POST /api/v1/admin/sora/create-from-oauth
|
||||
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
var req struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
State string `json:"state" binding:"required"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Name string `json:"name"`
|
||||
@@ -189,6 +243,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
Code: req.Code,
|
||||
State: req.State,
|
||||
RedirectURI: req.RedirectURI,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
@@ -200,19 +255,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
// Build credentials from token info
|
||||
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
platform := oauthPlatformFromPath(c)
|
||||
|
||||
// Use email as default name if not provided
|
||||
name := req.Name
|
||||
if name == "" && tokenInfo.Email != "" {
|
||||
name = tokenInfo.Email
|
||||
}
|
||||
if name == "" {
|
||||
name = "OpenAI OAuth Account"
|
||||
if platform == service.PlatformSora {
|
||||
name = "Sora OAuth Account"
|
||||
} else {
|
||||
name = "OpenAI OAuth Account"
|
||||
}
|
||||
}
|
||||
|
||||
// Create account
|
||||
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
|
||||
Name: name,
|
||||
Platform: "openai",
|
||||
Platform: platform,
|
||||
Type: "oauth",
|
||||
Credentials: credentials,
|
||||
ProxyID: req.ProxyID,
|
||||
|
||||
@@ -236,6 +236,24 @@ func (h *ProxyHandler) Test(c *gin.Context) {
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// CheckQuality handles checking proxy quality across common AI targets.
|
||||
// POST /api/v1/admin/proxies/:id/quality-check
|
||||
func (h *ProxyHandler) CheckQuality(c *gin.Context) {
|
||||
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid proxy ID")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.adminService.CheckProxyQuality(c.Request.Context(), proxyID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// GetStats handles getting proxy statistics
|
||||
// GET /api/v1/admin/proxies/:id/stats
|
||||
func (h *ProxyHandler) GetStats(c *gin.Context) {
|
||||
|
||||
@@ -214,6 +214,13 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
enabled := true
|
||||
out.EnableSessionIDMasking = &enabled
|
||||
}
|
||||
// 缓存 TTL 强制替换
|
||||
if a.IsCacheTTLOverrideEnabled() {
|
||||
enabled := true
|
||||
out.CacheTTLOverrideEnabled = &enabled
|
||||
target := a.GetCacheTTLOverrideTarget()
|
||||
out.CacheTTLOverrideTarget = &target
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
@@ -296,6 +303,11 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
|
||||
CountryCode: p.CountryCode,
|
||||
Region: p.Region,
|
||||
City: p.City,
|
||||
QualityStatus: p.QualityStatus,
|
||||
QualityScore: p.QualityScore,
|
||||
QualityGrade: p.QualityGrade,
|
||||
QualitySummary: p.QualitySummary,
|
||||
QualityChecked: p.QualityChecked,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -402,6 +414,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
ImageSize: l.ImageSize,
|
||||
MediaType: l.MediaType,
|
||||
UserAgent: l.UserAgent,
|
||||
CacheTTLOverridden: l.CacheTTLOverridden,
|
||||
CreatedAt: l.CreatedAt,
|
||||
User: UserFromServiceShallow(l.User),
|
||||
APIKey: APIKeyFromService(l.APIKey),
|
||||
|
||||
@@ -156,6 +156,11 @@ type Account struct {
|
||||
// 从 extra 字段提取,方便前端显示和编辑
|
||||
EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"`
|
||||
|
||||
// 缓存 TTL 强制替换(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||
// 启用后将所有 cache creation tokens 归入指定的 TTL 类型计费
|
||||
CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"`
|
||||
CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
|
||||
@@ -197,6 +202,11 @@ type ProxyWithAccountCount struct {
|
||||
CountryCode string `json:"country_code,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
City string `json:"city,omitempty"`
|
||||
QualityStatus string `json:"quality_status,omitempty"`
|
||||
QualityScore *int `json:"quality_score,omitempty"`
|
||||
QualityGrade string `json:"quality_grade,omitempty"`
|
||||
QualitySummary string `json:"quality_summary,omitempty"`
|
||||
QualityChecked *int64 `json:"quality_checked,omitempty"`
|
||||
}
|
||||
|
||||
type ProxyAccountSummary struct {
|
||||
@@ -280,6 +290,9 @@ type UsageLog struct {
|
||||
// User-Agent
|
||||
UserAgent *string `json:"user_agent"`
|
||||
|
||||
// Cache TTL Override 标记
|
||||
CacheTTLOverridden bool `json:"cache_ttl_overridden"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -20,6 +21,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -35,6 +37,7 @@ type SoraGatewayHandler struct {
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
streamMode string
|
||||
soraTLSEnabled bool
|
||||
soraMediaSigningKey string
|
||||
soraMediaRoot string
|
||||
}
|
||||
@@ -50,6 +53,7 @@ func NewSoraGatewayHandler(
|
||||
pingInterval := time.Duration(0)
|
||||
maxAccountSwitches := 3
|
||||
streamMode := "force"
|
||||
soraTLSEnabled := true
|
||||
signKey := ""
|
||||
mediaRoot := "/app/data/sora"
|
||||
if cfg != nil {
|
||||
@@ -60,6 +64,7 @@ func NewSoraGatewayHandler(
|
||||
if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" {
|
||||
streamMode = mode
|
||||
}
|
||||
soraTLSEnabled = !cfg.Sora.Client.DisableTLSFingerprint
|
||||
signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey)
|
||||
if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" {
|
||||
mediaRoot = root
|
||||
@@ -72,6 +77,7 @@ func NewSoraGatewayHandler(
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
streamMode: strings.ToLower(streamMode),
|
||||
soraTLSEnabled: soraTLSEnabled,
|
||||
soraMediaSigningKey: signKey,
|
||||
soraMediaRoot: mediaRoot,
|
||||
}
|
||||
@@ -212,6 +218,8 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
var lastFailoverBody []byte
|
||||
var lastFailoverHeaders http.Header
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
|
||||
@@ -224,11 +232,31 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
|
||||
fields := []zap.Field{
|
||||
zap.Int("last_upstream_status", lastFailoverStatus),
|
||||
}
|
||||
if rayID != "" {
|
||||
fields = append(fields, zap.String("last_upstream_cf_ray", rayID))
|
||||
}
|
||||
if mitigated != "" {
|
||||
fields = append(fields, zap.String("last_upstream_cf_mitigated", mitigated))
|
||||
}
|
||||
if contentType != "" {
|
||||
fields = append(fields, zap.String("last_upstream_content_type", contentType))
|
||||
}
|
||||
reqLog.Warn("sora.failover_exhausted_no_available_accounts", fields...)
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
proxyBound := account.ProxyID != nil
|
||||
proxyID := int64(0)
|
||||
if account.ProxyID != nil {
|
||||
proxyID = *account.ProxyID
|
||||
}
|
||||
tlsFingerprintEnabled := h.soraTLSEnabled
|
||||
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
@@ -239,10 +267,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
accountWaitCounted := false
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
reqLog.Warn("sora.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
reqLog.Warn("sora.account_wait_counter_increment_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else if !canWait {
|
||||
reqLog.Info("sora.account_wait_queue_full",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
||||
)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
@@ -266,7 +303,13 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("sora.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
reqLog.Warn("sora.account_slot_acquire_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -287,20 +330,67 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
|
||||
lastFailoverBody = failoverErr.ResponseBody
|
||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
}
|
||||
if rayID != "" {
|
||||
fields = append(fields, zap.String("upstream_cf_ray", rayID))
|
||||
}
|
||||
if mitigated != "" {
|
||||
fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
|
||||
}
|
||||
if contentType != "" {
|
||||
fields = append(fields, zap.String("upstream_content_type", contentType))
|
||||
}
|
||||
reqLog.Warn("sora.upstream_failover_exhausted", fields...)
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
|
||||
lastFailoverBody = failoverErr.ResponseBody
|
||||
switchCount++
|
||||
reqLog.Warn("sora.upstream_failover_switching",
|
||||
upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody)
|
||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.String("upstream_error_code", upstreamErrCode),
|
||||
zap.String("upstream_error_message", upstreamErrMsg),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
}
|
||||
if rayID != "" {
|
||||
fields = append(fields, zap.String("upstream_cf_ray", rayID))
|
||||
}
|
||||
if mitigated != "" {
|
||||
fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
|
||||
}
|
||||
if contentType != "" {
|
||||
fields = append(fields, zap.String("upstream_content_type", contentType))
|
||||
}
|
||||
reqLog.Warn("sora.upstream_failover_switching", fields...)
|
||||
continue
|
||||
}
|
||||
reqLog.Error("sora.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
reqLog.Error("sora.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -331,6 +421,9 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
}(result, account, userAgent, clientIP)
|
||||
reqLog.Debug("sora.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
@@ -360,17 +453,41 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
|
||||
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) {
|
||||
if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) {
|
||||
baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode)
|
||||
return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
|
||||
}
|
||||
|
||||
upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
|
||||
if strings.EqualFold(upstreamCode, "cf_shield_429") {
|
||||
baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
|
||||
return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
|
||||
}
|
||||
if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) {
|
||||
switch statusCode {
|
||||
case 401, 403, 404, 500, 502, 503, 504:
|
||||
return http.StatusBadGateway, "upstream_error", upstreamMessage
|
||||
case 429:
|
||||
return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage
|
||||
}
|
||||
}
|
||||
|
||||
switch statusCode {
|
||||
case 401:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
|
||||
case 403:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
|
||||
case 404:
|
||||
if strings.EqualFold(upstreamCode, "unsupported_country_code") {
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator"
|
||||
}
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator"
|
||||
case 429:
|
||||
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
|
||||
case 529:
|
||||
@@ -382,11 +499,67 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, stri
|
||||
}
|
||||
}
|
||||
|
||||
func cloneHTTPHeaders(headers http.Header) http.Header {
|
||||
if headers == nil {
|
||||
return nil
|
||||
}
|
||||
return headers.Clone()
|
||||
}
|
||||
|
||||
func extractSoraFailoverHeaderInsights(headers http.Header, body []byte) (rayID, mitigated, contentType string) {
|
||||
if headers != nil {
|
||||
mitigated = strings.TrimSpace(headers.Get("cf-mitigated"))
|
||||
contentType = strings.TrimSpace(headers.Get("content-type"))
|
||||
if contentType == "" {
|
||||
contentType = strings.TrimSpace(headers.Get("Content-Type"))
|
||||
}
|
||||
}
|
||||
rayID = soraerror.ExtractCloudflareRayID(headers, body)
|
||||
return rayID, mitigated, contentType
|
||||
}
|
||||
|
||||
func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
|
||||
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
|
||||
}
|
||||
|
||||
func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
|
||||
message = strings.TrimSpace(message)
|
||||
if message == "" {
|
||||
return false
|
||||
}
|
||||
if statusCode == http.StatusForbidden || statusCode == http.StatusTooManyRequests {
|
||||
lower := strings.ToLower(message)
|
||||
if strings.Contains(lower, "<html") || strings.Contains(lower, "<!doctype html") || strings.Contains(lower, "window._cf_chl_opt") {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func formatSoraCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
|
||||
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
|
||||
}
|
||||
|
||||
func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
|
||||
return soraerror.ExtractUpstreamErrorCodeAndMessage(body)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
||||
errorData := map[string]any{
|
||||
"error": map[string]string{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
jsonBytes, err := json.Marshal(errorData)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
|
||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
|
||||
@@ -43,6 +43,48 @@ func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.A
|
||||
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
|
||||
return "task-video", nil
|
||||
}
|
||||
func (s *stubSoraClient) CreateStoryboardTask(ctx context.Context, account *service.Account, req service.SoraStoryboardRequest) (string, error) {
|
||||
return "task-video", nil
|
||||
}
|
||||
func (s *stubSoraClient) UploadCharacterVideo(ctx context.Context, account *service.Account, data []byte) (string, error) {
|
||||
return "cameo-1", nil
|
||||
}
|
||||
func (s *stubSoraClient) GetCameoStatus(ctx context.Context, account *service.Account, cameoID string) (*service.SoraCameoStatus, error) {
|
||||
return &service.SoraCameoStatus{
|
||||
Status: "finalized",
|
||||
StatusMessage: "Completed",
|
||||
DisplayNameHint: "Character",
|
||||
UsernameHint: "user.character",
|
||||
ProfileAssetURL: "https://example.com/avatar.webp",
|
||||
}, nil
|
||||
}
|
||||
func (s *stubSoraClient) DownloadCharacterImage(ctx context.Context, account *service.Account, imageURL string) ([]byte, error) {
|
||||
return []byte("avatar"), nil
|
||||
}
|
||||
func (s *stubSoraClient) UploadCharacterImage(ctx context.Context, account *service.Account, data []byte) (string, error) {
|
||||
return "asset-pointer", nil
|
||||
}
|
||||
func (s *stubSoraClient) FinalizeCharacter(ctx context.Context, account *service.Account, req service.SoraCharacterFinalizeRequest) (string, error) {
|
||||
return "character-1", nil
|
||||
}
|
||||
func (s *stubSoraClient) SetCharacterPublic(ctx context.Context, account *service.Account, cameoID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClient) DeleteCharacter(ctx context.Context, account *service.Account, characterID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClient) PostVideoForWatermarkFree(ctx context.Context, account *service.Account, generationID string) (string, error) {
|
||||
return "s_post", nil
|
||||
}
|
||||
func (s *stubSoraClient) DeletePost(ctx context.Context, account *service.Account, postID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClient) GetWatermarkFreeURLCustom(ctx context.Context, account *service.Account, parseURL, parseToken, postID string) (string, error) {
|
||||
return "https://example.com/no-watermark.mp4", nil
|
||||
}
|
||||
func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) {
|
||||
return "enhanced prompt", nil
|
||||
}
|
||||
func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) {
|
||||
return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
|
||||
}
|
||||
@@ -88,7 +130,7 @@ func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error
|
||||
func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||||
@@ -495,3 +537,152 @@ func TestGenerateOpenAISessionHash_WithBody(t *testing.T) {
|
||||
require.NotEmpty(t, hash3)
|
||||
require.NotEqual(t, hash, hash3) // 不同来源应产生不同 hash
|
||||
}
|
||||
|
||||
func TestSoraHandleStreamingAwareError_JSONEscaping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errType string
|
||||
message string
|
||||
}{
|
||||
{
|
||||
name: "包含双引号",
|
||||
errType: "upstream_error",
|
||||
message: `upstream returned "invalid" payload`,
|
||||
},
|
||||
{
|
||||
name: "包含换行和制表符",
|
||||
errType: "rate_limit_error",
|
||||
message: "line1\nline2\ttab",
|
||||
},
|
||||
{
|
||||
name: "包含反斜杠",
|
||||
errType: "upstream_error",
|
||||
message: `path C:\Users\test\file.txt not found`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &SoraGatewayHandler{}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true)
|
||||
|
||||
body := w.Body.String()
|
||||
require.True(t, strings.HasPrefix(body, "event: error\n"), "应以 SSE error 事件开头")
|
||||
require.True(t, strings.HasSuffix(body, "\n\n"), "应以 SSE 结束分隔符结尾")
|
||||
|
||||
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
|
||||
require.Len(t, lines, 2, "SSE 错误事件应包含 event 行和 data 行")
|
||||
require.Equal(t, "event: error", lines[0])
|
||||
require.True(t, strings.HasPrefix(lines[1], "data: "), "第二行应为 data 前缀")
|
||||
|
||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed), "data 行必须是合法 JSON")
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok, "JSON 中应包含 error 对象")
|
||||
require.Equal(t, tt.errType, errorObj["type"])
|
||||
require.Equal(t, tt.message, errorObj["message"])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &SoraGatewayHandler{}
|
||||
resp := []byte(`{"error":{"message":"invalid \"prompt\"\nline2","code":"bad_request"}}`)
|
||||
h.handleFailoverExhausted(c, http.StatusBadGateway, nil, resp, true)
|
||||
|
||||
body := w.Body.String()
|
||||
require.True(t, strings.HasPrefix(body, "event: error\n"))
|
||||
require.True(t, strings.HasSuffix(body, "\n\n"))
|
||||
|
||||
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
|
||||
require.Len(t, lines, 2)
|
||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "upstream_error", errorObj["type"])
|
||||
require.Equal(t, "invalid \"prompt\"\nline2", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestSoraHandleFailoverExhausted_CloudflareChallengeIncludesRay(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
|
||||
body := []byte(`<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script></body></html>`)
|
||||
|
||||
h := &SoraGatewayHandler{}
|
||||
h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true)
|
||||
|
||||
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
|
||||
require.Len(t, lines, 2)
|
||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "upstream_error", errorObj["type"])
|
||||
msg, _ := errorObj["message"].(string)
|
||||
require.Contains(t, msg, "Cloudflare challenge")
|
||||
require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA")
|
||||
}
|
||||
|
||||
func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
|
||||
body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`)
|
||||
|
||||
h := &SoraGatewayHandler{}
|
||||
h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true)
|
||||
|
||||
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
|
||||
require.Len(t, lines, 2)
|
||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "rate_limit_error", errorObj["type"])
|
||||
msg, _ := errorObj["message"].(string)
|
||||
require.Contains(t, msg, "Cloudflare shield")
|
||||
require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA")
|
||||
}
|
||||
|
||||
func TestExtractSoraFailoverHeaderInsights(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("cf-mitigated", "challenge")
|
||||
headers.Set("content-type", "text/html")
|
||||
body := []byte(`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`)
|
||||
|
||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(headers, body)
|
||||
require.Equal(t, "9cff2d62d83bb98d", rayID)
|
||||
require.Equal(t, "challenge", mitigated)
|
||||
require.Equal(t, "text/html", contentType)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ const (
|
||||
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
|
||||
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
|
||||
BetaTokenCounting = "token-counting-2024-11-01"
|
||||
BetaContext1M = "context-1m-2025-08-07"
|
||||
)
|
||||
|
||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||
@@ -77,6 +78,12 @@ var DefaultModels = []Model{
|
||||
DisplayName: "Claude Opus 4.6",
|
||||
CreatedAt: "2026-02-06T00:00:00Z",
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-6",
|
||||
Type: "model",
|
||||
DisplayName: "Claude Sonnet 4.6",
|
||||
CreatedAt: "2026-02-18T00:00:00Z",
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-5-20250929",
|
||||
Type: "model",
|
||||
|
||||
@@ -17,6 +17,8 @@ import (
|
||||
const (
|
||||
// OAuth Client ID for OpenAI (Codex CLI official)
|
||||
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
// OAuth Client ID for Sora mobile flow (aligned with sora2api)
|
||||
SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK"
|
||||
|
||||
// OAuth endpoints
|
||||
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
|
||||
|
||||
@@ -435,10 +435,10 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
||||
}
|
||||
|
||||
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "", "")
|
||||
return r.ListWithFilters(ctx, params, "", "", "", "", 0)
|
||||
}
|
||||
|
||||
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
q := r.client.Account.Query()
|
||||
|
||||
if platform != "" {
|
||||
@@ -458,6 +458,9 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
|
||||
if search != "" {
|
||||
q = q.Where(dbaccount.NameContainsFold(search))
|
||||
}
|
||||
if groupID > 0 {
|
||||
q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID)))
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
|
||||
@@ -238,7 +238,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
|
||||
tt.setup(client)
|
||||
|
||||
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search)
|
||||
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, 0)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, tt.wantCount)
|
||||
if tt.validate != nil {
|
||||
@@ -305,7 +305,7 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
|
||||
s.Require().Len(got.Groups, 1, "expected Groups to be populated")
|
||||
s.Require().Equal(group.ID, got.Groups[0].ID)
|
||||
|
||||
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc")
|
||||
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0)
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
s.Require().Len(accounts, 1)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
@@ -56,12 +57,49 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
|
||||
if strings.TrimSpace(clientID) != "" {
|
||||
return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, strings.TrimSpace(clientID))
|
||||
}
|
||||
|
||||
clientIDs := []string{
|
||||
openai.ClientID,
|
||||
openai.SoraClientID,
|
||||
}
|
||||
seen := make(map[string]struct{}, len(clientIDs))
|
||||
var lastErr error
|
||||
for _, clientID := range clientIDs {
|
||||
clientID = strings.TrimSpace(clientID)
|
||||
if clientID == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[clientID]; ok {
|
||||
continue
|
||||
}
|
||||
seen[clientID] = struct{}{}
|
||||
|
||||
tokenResp, err := s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
|
||||
if err == nil {
|
||||
return tokenResp, nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, infraerrors.New(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed")
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) {
|
||||
client := createOpenAIReqClient(proxyURL)
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "refresh_token")
|
||||
formData.Set("refresh_token", refreshToken)
|
||||
formData.Set("client_id", openai.ClientID)
|
||||
formData.Set("client_id", clientID)
|
||||
formData.Set("scope", openai.RefreshScopes)
|
||||
|
||||
var tokenResp openai.TokenResponse
|
||||
|
||||
@@ -136,6 +136,60 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
|
||||
require.Equal(s.T(), "rt2", resp.RefreshToken)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() {
|
||||
var seenClientIDs []string
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
clientID := r.PostForm.Get("client_id")
|
||||
seenClientIDs = append(seenClientIDs, clientID)
|
||||
if clientID == openai.ClientID {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = io.WriteString(w, "invalid_grant")
|
||||
return
|
||||
}
|
||||
if clientID == openai.SoraClientID {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}))
|
||||
|
||||
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
|
||||
require.NoError(s.T(), err, "RefreshToken")
|
||||
require.Equal(s.T(), "at-sora", resp.AccessToken)
|
||||
require.Equal(s.T(), "rt-sora", resp.RefreshToken)
|
||||
require.Equal(s.T(), []string{openai.ClientID, openai.SoraClientID}, seenClientIDs)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
|
||||
const customClientID = "custom-client-id"
|
||||
var seenClientIDs []string
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
clientID := r.PostForm.Get("client_id")
|
||||
seenClientIDs = append(seenClientIDs, clientID)
|
||||
if clientID != customClientID {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at-custom","refresh_token":"rt-custom","token_type":"bearer","expires_in":3600}`)
|
||||
}))
|
||||
|
||||
resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", customClientID)
|
||||
require.NoError(s.T(), err, "RefreshTokenWithClientID")
|
||||
require.Equal(s.T(), "at-custom", resp.AccessToken)
|
||||
require.Equal(s.T(), "rt-custom", resp.RefreshToken)
|
||||
require.Equal(s.T(), []string{customClientID}, seenClientIDs)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at"
|
||||
|
||||
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
|
||||
var dateFormatWhitelist = map[string]string{
|
||||
@@ -132,6 +132,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
image_size,
|
||||
media_type,
|
||||
reasoning_effort,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5,
|
||||
@@ -139,7 +140,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
@@ -192,6 +193,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
imageSize,
|
||||
mediaType,
|
||||
reasoningEffort,
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
}
|
||||
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
|
||||
@@ -2221,6 +2223,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
imageSize sql.NullString
|
||||
mediaType sql.NullString
|
||||
reasoningEffort sql.NullString
|
||||
cacheTTLOverridden bool
|
||||
createdAt time.Time
|
||||
)
|
||||
|
||||
@@ -2257,6 +2260,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&imageSize,
|
||||
&mediaType,
|
||||
&reasoningEffort,
|
||||
&cacheTTLOverridden,
|
||||
&createdAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
@@ -2285,6 +2289,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
BillingType: int8(billingType),
|
||||
Stream: stream,
|
||||
ImageCount: imageCount,
|
||||
CacheTTLOverridden: cacheTTLOverridden,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
|
||||
|
||||
@@ -406,6 +406,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"image_count": 0,
|
||||
"image_size": null,
|
||||
"media_type": null,
|
||||
"cache_ttl_overridden": false,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"user_agent": null
|
||||
}
|
||||
@@ -945,7 +946,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -50,6 +50,19 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
|
||||
}
|
||||
allowedSet[origin] = struct{}{}
|
||||
}
|
||||
allowHeaders := []string{
|
||||
"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization",
|
||||
"accept", "origin", "Cache-Control", "X-Requested-With", "X-API-Key",
|
||||
}
|
||||
// OpenAI Node SDK 会发送 x-stainless-* 请求头,需在 CORS 中显式放行。
|
||||
openAIProperties := []string{
|
||||
"lang", "package-version", "os", "arch", "retry-count", "runtime",
|
||||
"runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval", "timeout",
|
||||
}
|
||||
for _, prop := range openAIProperties {
|
||||
allowHeaders = append(allowHeaders, "x-stainless-"+prop)
|
||||
}
|
||||
allowHeadersValue := strings.Join(allowHeaders, ", ")
|
||||
|
||||
return func(c *gin.Context) {
|
||||
origin := strings.TrimSpace(c.GetHeader("Origin"))
|
||||
@@ -68,12 +81,11 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
|
||||
if allowCredentials {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", allowHeadersValue)
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
|
||||
c.Writer.Header().Set("Access-Control-Expose-Headers", "ETag")
|
||||
c.Writer.Header().Set("Access-Control-Max-Age", "86400")
|
||||
}
|
||||
|
||||
// 处理预检请求
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
if originAllowed {
|
||||
|
||||
@@ -34,6 +34,8 @@ func RegisterAdminRoutes(
|
||||
|
||||
// OpenAI OAuth
|
||||
registerOpenAIOAuthRoutes(admin, h)
|
||||
// Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立)
|
||||
registerSoraOAuthRoutes(admin, h)
|
||||
|
||||
// Gemini OAuth
|
||||
registerGeminiOAuthRoutes(admin, h)
|
||||
@@ -276,6 +278,19 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
}
|
||||
}
|
||||
|
||||
func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
sora := admin.Group("/sora")
|
||||
{
|
||||
sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
|
||||
sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
|
||||
sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
|
||||
sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken)
|
||||
sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken)
|
||||
sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
|
||||
sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
gemini := admin.Group("/gemini")
|
||||
{
|
||||
@@ -306,6 +321,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
proxies.PUT("/:id", h.Admin.Proxy.Update)
|
||||
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
|
||||
proxies.POST("/:id/test", h.Admin.Proxy.Test)
|
||||
proxies.POST("/:id/quality-check", h.Admin.Proxy.CheckQuality)
|
||||
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
|
||||
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
|
||||
proxies.POST("/batch-delete", h.Admin.Proxy.BatchDelete)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
@@ -41,16 +43,15 @@ func RegisterGatewayRoutes(
|
||||
gateway.GET("/usage", h.Gateway.Usage)
|
||||
// OpenAI Responses API
|
||||
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
||||
}
|
||||
|
||||
// Sora Chat Completions
|
||||
soraGateway := r.Group("/v1")
|
||||
soraGateway.Use(soraBodyLimit)
|
||||
soraGateway.Use(clientRequestID)
|
||||
soraGateway.Use(opsErrorLogger)
|
||||
soraGateway.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
{
|
||||
soraGateway.POST("/chat/completions", h.SoraGateway.ChatCompletions)
|
||||
// 明确阻止旧入口误用到 Sora,避免客户端把 OpenAI Chat Completions 当作 Sora 入口
|
||||
gateway.POST("/chat/completions", func(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "invalid_request_error",
|
||||
"message": "For Sora, use /sora/v1/chat/completions. OpenAI should use /v1/responses.",
|
||||
},
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
|
||||
|
||||
@@ -786,6 +786,38 @@ func (a *Account) IsSessionIDMaskingEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsCacheTTLOverrideEnabled 检查是否启用缓存 TTL 强制替换
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
|
||||
// 启用后将所有 cache creation tokens 归入指定的 TTL 类型(5m 或 1h)
|
||||
func (a *Account) IsCacheTTLOverrideEnabled() bool {
|
||||
if !a.IsAnthropicOAuthOrSetupToken() {
|
||||
return false
|
||||
}
|
||||
if a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Extra["cache_ttl_override_enabled"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetCacheTTLOverrideTarget 获取缓存 TTL 强制替换的目标类型
|
||||
// 返回 "5m" 或 "1h",默认 "5m"
|
||||
func (a *Account) GetCacheTTLOverrideTarget() string {
|
||||
if a.Extra == nil {
|
||||
return "5m"
|
||||
}
|
||||
if v, ok := a.Extra["cache_ttl_override_target"]; ok {
|
||||
if target, ok := v.(string); ok && (target == "5m" || target == "1h") {
|
||||
return target
|
||||
}
|
||||
}
|
||||
return "5m"
|
||||
}
|
||||
|
||||
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetWindowCostLimit() float64 {
|
||||
|
||||
@@ -35,7 +35,7 @@ type AccountRepository interface {
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error)
|
||||
ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
|
||||
ListActive(ctx context.Context) ([]Account, error)
|
||||
ListByPlatform(ctx context.Context, platform string) ([]Account, error)
|
||||
|
||||
@@ -79,7 +79,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
||||
func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
|
||||
@@ -12,13 +12,17 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -32,6 +36,10 @@ const (
|
||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
||||
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
|
||||
soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
|
||||
soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
|
||||
soraInviteMineURL = "https://sora.chatgpt.com/backend/project_y/invite/mine"
|
||||
soraBootstrapURL = "https://sora.chatgpt.com/backend/m/bootstrap"
|
||||
soraRemainingURL = "https://sora.chatgpt.com/backend/nf/check"
|
||||
)
|
||||
|
||||
// TestEvent represents a SSE event for account testing
|
||||
@@ -39,6 +47,9 @@ type TestEvent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Data any `json:"data,omitempty"`
|
||||
Success bool `json:"success,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
@@ -50,8 +61,13 @@ type AccountTestService struct {
|
||||
antigravityGatewayService *AntigravityGatewayService
|
||||
httpUpstream HTTPUpstream
|
||||
cfg *config.Config
|
||||
soraTestGuardMu sync.Mutex
|
||||
soraTestLastRun map[int64]time.Time
|
||||
soraTestCooldown time.Duration
|
||||
}
|
||||
|
||||
const defaultSoraTestCooldown = 10 * time.Second
|
||||
|
||||
// NewAccountTestService creates a new AccountTestService
|
||||
func NewAccountTestService(
|
||||
accountRepo AccountRepository,
|
||||
@@ -66,6 +82,8 @@ func NewAccountTestService(
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
httpUpstream: httpUpstream,
|
||||
cfg: cfg,
|
||||
soraTestLastRun: make(map[int64]time.Time),
|
||||
soraTestCooldown: defaultSoraTestCooldown,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -467,13 +485,129 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
return s.processGeminiStream(c, resp.Body)
|
||||
}
|
||||
|
||||
type soraProbeStep struct {
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
HTTPStatus int `json:"http_status,omitempty"`
|
||||
ErrorCode string `json:"error_code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
type soraProbeSummary struct {
|
||||
Status string `json:"status"`
|
||||
Steps []soraProbeStep `json:"steps"`
|
||||
}
|
||||
|
||||
type soraProbeRecorder struct {
|
||||
steps []soraProbeStep
|
||||
}
|
||||
|
||||
func (r *soraProbeRecorder) addStep(name, status string, httpStatus int, errorCode, message string) {
|
||||
r.steps = append(r.steps, soraProbeStep{
|
||||
Name: name,
|
||||
Status: status,
|
||||
HTTPStatus: httpStatus,
|
||||
ErrorCode: strings.TrimSpace(errorCode),
|
||||
Message: strings.TrimSpace(message),
|
||||
})
|
||||
}
|
||||
|
||||
func (r *soraProbeRecorder) finalize() soraProbeSummary {
|
||||
meSuccess := false
|
||||
partial := false
|
||||
for _, step := range r.steps {
|
||||
if step.Name == "me" {
|
||||
meSuccess = strings.EqualFold(step.Status, "success")
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(step.Status, "failed") {
|
||||
partial = true
|
||||
}
|
||||
}
|
||||
|
||||
status := "success"
|
||||
if !meSuccess {
|
||||
status = "failed"
|
||||
} else if partial {
|
||||
status = "partial_success"
|
||||
}
|
||||
|
||||
return soraProbeSummary{
|
||||
Status: status,
|
||||
Steps: append([]soraProbeStep(nil), r.steps...),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AccountTestService) emitSoraProbeSummary(c *gin.Context, rec *soraProbeRecorder) {
|
||||
if rec == nil {
|
||||
return
|
||||
}
|
||||
summary := rec.finalize()
|
||||
code := ""
|
||||
for _, step := range summary.Steps {
|
||||
if strings.EqualFold(step.Status, "failed") && strings.TrimSpace(step.ErrorCode) != "" {
|
||||
code = step.ErrorCode
|
||||
break
|
||||
}
|
||||
}
|
||||
s.sendEvent(c, TestEvent{
|
||||
Type: "sora_test_result",
|
||||
Status: summary.Status,
|
||||
Code: code,
|
||||
Data: summary,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *AccountTestService) acquireSoraTestPermit(accountID int64) (time.Duration, bool) {
|
||||
if accountID <= 0 {
|
||||
return 0, true
|
||||
}
|
||||
s.soraTestGuardMu.Lock()
|
||||
defer s.soraTestGuardMu.Unlock()
|
||||
|
||||
if s.soraTestLastRun == nil {
|
||||
s.soraTestLastRun = make(map[int64]time.Time)
|
||||
}
|
||||
cooldown := s.soraTestCooldown
|
||||
if cooldown <= 0 {
|
||||
cooldown = defaultSoraTestCooldown
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if lastRun, ok := s.soraTestLastRun[accountID]; ok {
|
||||
elapsed := now.Sub(lastRun)
|
||||
if elapsed < cooldown {
|
||||
return cooldown - elapsed, false
|
||||
}
|
||||
}
|
||||
s.soraTestLastRun[accountID] = now
|
||||
return 0, true
|
||||
}
|
||||
|
||||
func ceilSeconds(d time.Duration) int {
|
||||
if d <= 0 {
|
||||
return 1
|
||||
}
|
||||
sec := int(d / time.Second)
|
||||
if d%time.Second != 0 {
|
||||
sec++
|
||||
}
|
||||
if sec < 1 {
|
||||
sec = 1
|
||||
}
|
||||
return sec
|
||||
}
|
||||
|
||||
// testSoraAccountConnection 测试 Sora 账号的连接
|
||||
// 调用 /backend/me 接口验证 access_token 有效性(不需要 Sentinel Token)
|
||||
func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
|
||||
ctx := c.Request.Context()
|
||||
recorder := &soraProbeRecorder{}
|
||||
|
||||
authToken := account.GetCredential("access_token")
|
||||
if authToken == "" {
|
||||
recorder.addStep("me", "failed", http.StatusUnauthorized, "missing_access_token", "No access token available")
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, "No access token available")
|
||||
}
|
||||
|
||||
@@ -484,11 +618,20 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.Flush()
|
||||
|
||||
if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
|
||||
msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
|
||||
recorder.addStep("rate_limit", "failed", http.StatusTooManyRequests, "test_rate_limited", msg)
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, msg)
|
||||
}
|
||||
|
||||
// Send test_start event
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil)
|
||||
if err != nil {
|
||||
recorder.addStep("me", "failed", 0, "request_build_failed", err.Error())
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, "Failed to create request")
|
||||
}
|
||||
|
||||
@@ -496,15 +639,21 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
|
||||
req.Header.Set("Authorization", "Bearer "+authToken)
|
||||
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Accept-Language", "en-US,en;q=0.9")
|
||||
req.Header.Set("Origin", "https://sora.chatgpt.com")
|
||||
req.Header.Set("Referer", "https://sora.chatgpt.com/")
|
||||
|
||||
// Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
enableSoraTLSFingerprint := s.shouldEnableSoraTLSFingerprint()
|
||||
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
|
||||
if err != nil {
|
||||
recorder.addStep("me", "failed", 0, "network_error", err.Error())
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
@@ -512,8 +661,33 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, string(body)))
|
||||
if isCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
|
||||
recorder.addStep("me", "failed", resp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body)
|
||||
return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage(fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", resp.StatusCode), resp.Header, body))
|
||||
}
|
||||
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(body)
|
||||
switch {
|
||||
case resp.StatusCode == http.StatusUnauthorized && strings.EqualFold(upstreamCode, "token_invalidated"):
|
||||
recorder.addStep("me", "failed", resp.StatusCode, "token_invalidated", "Sora token invalidated")
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, "Sora token 已失效(token_invalidated),请重新授权账号")
|
||||
case strings.EqualFold(upstreamCode, "unsupported_country_code"):
|
||||
recorder.addStep("me", "failed", resp.StatusCode, "unsupported_country_code", "Sora is unavailable in current egress region")
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, "Sora 在当前网络出口地区不可用(unsupported_country_code),请切换到支持地区后重试")
|
||||
case strings.TrimSpace(upstreamMessage) != "":
|
||||
recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, upstreamMessage)
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, upstreamMessage))
|
||||
default:
|
||||
recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, "Sora me endpoint failed")
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
|
||||
}
|
||||
}
|
||||
recorder.addStep("me", "success", resp.StatusCode, "", "me endpoint ok")
|
||||
|
||||
// 解析 /me 响应,提取用户信息
|
||||
var meResp map[string]any
|
||||
@@ -531,10 +705,384 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: info})
|
||||
}
|
||||
|
||||
// 追加轻量能力检查:订阅信息查询(失败仅告警,不中断连接测试)
|
||||
subReq, err := http.NewRequestWithContext(ctx, "GET", soraBillingAPIURL, nil)
|
||||
if err == nil {
|
||||
subReq.Header.Set("Authorization", "Bearer "+authToken)
|
||||
subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||
subReq.Header.Set("Accept", "application/json")
|
||||
subReq.Header.Set("Accept-Language", "en-US,en;q=0.9")
|
||||
subReq.Header.Set("Origin", "https://sora.chatgpt.com")
|
||||
subReq.Header.Set("Referer", "https://sora.chatgpt.com/")
|
||||
|
||||
subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
|
||||
if subErr != nil {
|
||||
recorder.addStep("subscription", "failed", 0, "network_error", subErr.Error())
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())})
|
||||
} else {
|
||||
subBody, _ := io.ReadAll(subResp.Body)
|
||||
_ = subResp.Body.Close()
|
||||
if subResp.StatusCode == http.StatusOK {
|
||||
recorder.addStep("subscription", "success", subResp.StatusCode, "", "subscription endpoint ok")
|
||||
if summary := parseSoraSubscriptionSummary(subBody); summary != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
|
||||
} else {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"})
|
||||
}
|
||||
} else {
|
||||
if isCloudflareChallengeResponse(subResp.StatusCode, subResp.Header, subBody) {
|
||||
recorder.addStep("subscription", "failed", subResp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
|
||||
s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody)
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Subscription check blocked by Cloudflare challenge (HTTP %d)", subResp.StatusCode), subResp.Header, subBody)})
|
||||
} else {
|
||||
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(subBody)
|
||||
recorder.addStep("subscription", "failed", subResp.StatusCode, upstreamCode, upstreamMessage)
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。
|
||||
s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, enableSoraTLSFingerprint, recorder)
|
||||
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AccountTestService) testSora2Capabilities(
|
||||
c *gin.Context,
|
||||
ctx context.Context,
|
||||
account *Account,
|
||||
authToken string,
|
||||
proxyURL string,
|
||||
enableTLSFingerprint bool,
|
||||
recorder *soraProbeRecorder,
|
||||
) {
|
||||
inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint(
|
||||
ctx,
|
||||
account,
|
||||
authToken,
|
||||
soraInviteMineURL,
|
||||
proxyURL,
|
||||
enableTLSFingerprint,
|
||||
)
|
||||
if err != nil {
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
|
||||
}
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check skipped: %s", err.Error())})
|
||||
return
|
||||
}
|
||||
|
||||
if inviteStatus == http.StatusUnauthorized {
|
||||
bootstrapStatus, _, _, bootstrapErr := s.fetchSoraTestEndpoint(
|
||||
ctx,
|
||||
account,
|
||||
authToken,
|
||||
soraBootstrapURL,
|
||||
proxyURL,
|
||||
enableTLSFingerprint,
|
||||
)
|
||||
if bootstrapErr == nil && bootstrapStatus == http.StatusOK {
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_bootstrap", "success", bootstrapStatus, "", "bootstrap endpoint ok")
|
||||
}
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 bootstrap OK, retry invite check"})
|
||||
inviteStatus, inviteHeader, inviteBody, err = s.fetchSoraTestEndpoint(
|
||||
ctx,
|
||||
account,
|
||||
authToken,
|
||||
soraInviteMineURL,
|
||||
proxyURL,
|
||||
enableTLSFingerprint,
|
||||
)
|
||||
if err != nil {
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
|
||||
}
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite retry failed: %s", err.Error())})
|
||||
return
|
||||
}
|
||||
} else if recorder != nil {
|
||||
code := ""
|
||||
msg := ""
|
||||
if bootstrapErr != nil {
|
||||
code = "network_error"
|
||||
msg = bootstrapErr.Error()
|
||||
}
|
||||
recorder.addStep("sora2_bootstrap", "failed", bootstrapStatus, code, msg)
|
||||
}
|
||||
}
|
||||
|
||||
if inviteStatus != http.StatusOK {
|
||||
if isCloudflareChallengeResponse(inviteStatus, inviteHeader, inviteBody) {
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_invite", "failed", inviteStatus, "cf_challenge", "Cloudflare challenge detected")
|
||||
}
|
||||
s.logSoraCloudflareChallenge(account, proxyURL, soraInviteMineURL, inviteHeader, inviteBody)
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 invite check blocked by Cloudflare challenge (HTTP %d)", inviteStatus), inviteHeader, inviteBody)})
|
||||
return
|
||||
}
|
||||
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(inviteBody)
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_invite", "failed", inviteStatus, upstreamCode, upstreamMessage)
|
||||
}
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)})
|
||||
return
|
||||
}
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_invite", "success", inviteStatus, "", "invite endpoint ok")
|
||||
}
|
||||
|
||||
if summary := parseSoraInviteSummary(inviteBody); summary != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
|
||||
} else {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 invite check OK"})
|
||||
}
|
||||
|
||||
remainingStatus, remainingHeader, remainingBody, remainingErr := s.fetchSoraTestEndpoint(
|
||||
ctx,
|
||||
account,
|
||||
authToken,
|
||||
soraRemainingURL,
|
||||
proxyURL,
|
||||
enableTLSFingerprint,
|
||||
)
|
||||
if remainingErr != nil {
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_remaining", "failed", 0, "network_error", remainingErr.Error())
|
||||
}
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())})
|
||||
return
|
||||
}
|
||||
if remainingStatus != http.StatusOK {
|
||||
if isCloudflareChallengeResponse(remainingStatus, remainingHeader, remainingBody) {
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_remaining", "failed", remainingStatus, "cf_challenge", "Cloudflare challenge detected")
|
||||
}
|
||||
s.logSoraCloudflareChallenge(account, proxyURL, soraRemainingURL, remainingHeader, remainingBody)
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 remaining check blocked by Cloudflare challenge (HTTP %d)", remainingStatus), remainingHeader, remainingBody)})
|
||||
return
|
||||
}
|
||||
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(remainingBody)
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_remaining", "failed", remainingStatus, upstreamCode, upstreamMessage)
|
||||
}
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)})
|
||||
return
|
||||
}
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_remaining", "success", remainingStatus, "", "remaining endpoint ok")
|
||||
}
|
||||
if summary := parseSoraRemainingSummary(remainingBody); summary != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
|
||||
} else {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 remaining check OK"})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AccountTestService) fetchSoraTestEndpoint(
|
||||
ctx context.Context,
|
||||
account *Account,
|
||||
authToken string,
|
||||
url string,
|
||||
proxyURL string,
|
||||
enableTLSFingerprint bool,
|
||||
) (int, http.Header, []byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return 0, nil, nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+authToken)
|
||||
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Accept-Language", "en-US,en;q=0.9")
|
||||
req.Header.Set("Origin", "https://sora.chatgpt.com")
|
||||
req.Header.Set("Referer", "https://sora.chatgpt.com/")
|
||||
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableTLSFingerprint)
|
||||
if err != nil {
|
||||
return 0, nil, nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return resp.StatusCode, resp.Header, nil, readErr
|
||||
}
|
||||
return resp.StatusCode, resp.Header, body, nil
|
||||
}
|
||||
|
||||
func parseSoraSubscriptionSummary(body []byte) string {
|
||||
var subResp struct {
|
||||
Data []struct {
|
||||
Plan struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
} `json:"plan"`
|
||||
EndTS string `json:"end_ts"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &subResp); err != nil {
|
||||
return ""
|
||||
}
|
||||
if len(subResp.Data) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
first := subResp.Data[0]
|
||||
parts := make([]string, 0, 3)
|
||||
if first.Plan.Title != "" {
|
||||
parts = append(parts, first.Plan.Title)
|
||||
}
|
||||
if first.Plan.ID != "" {
|
||||
parts = append(parts, first.Plan.ID)
|
||||
}
|
||||
if first.EndTS != "" {
|
||||
parts = append(parts, "end="+first.EndTS)
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
return "Subscription: " + strings.Join(parts, " | ")
|
||||
}
|
||||
|
||||
func parseSoraInviteSummary(body []byte) string {
|
||||
var inviteResp struct {
|
||||
InviteCode string `json:"invite_code"`
|
||||
RedeemedCount int64 `json:"redeemed_count"`
|
||||
TotalCount int64 `json:"total_count"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &inviteResp); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
parts := []string{"Sora2: supported"}
|
||||
if inviteResp.InviteCode != "" {
|
||||
parts = append(parts, "invite="+inviteResp.InviteCode)
|
||||
}
|
||||
if inviteResp.TotalCount > 0 {
|
||||
parts = append(parts, fmt.Sprintf("used=%d/%d", inviteResp.RedeemedCount, inviteResp.TotalCount))
|
||||
}
|
||||
return strings.Join(parts, " | ")
|
||||
}
|
||||
|
||||
func parseSoraRemainingSummary(body []byte) string {
|
||||
var remainingResp struct {
|
||||
RateLimitAndCreditBalance struct {
|
||||
EstimatedNumVideosRemaining int64 `json:"estimated_num_videos_remaining"`
|
||||
RateLimitReached bool `json:"rate_limit_reached"`
|
||||
AccessResetsInSeconds int64 `json:"access_resets_in_seconds"`
|
||||
} `json:"rate_limit_and_credit_balance"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &remainingResp); err != nil {
|
||||
return ""
|
||||
}
|
||||
info := remainingResp.RateLimitAndCreditBalance
|
||||
parts := []string{fmt.Sprintf("Sora2 remaining: %d", info.EstimatedNumVideosRemaining)}
|
||||
if info.RateLimitReached {
|
||||
parts = append(parts, "rate_limited=true")
|
||||
}
|
||||
if info.AccessResetsInSeconds > 0 {
|
||||
parts = append(parts, fmt.Sprintf("reset_in=%ds", info.AccessResetsInSeconds))
|
||||
}
|
||||
return strings.Join(parts, " | ")
|
||||
}
|
||||
|
||||
func (s *AccountTestService) shouldEnableSoraTLSFingerprint() bool {
|
||||
if s == nil || s.cfg == nil {
|
||||
return true
|
||||
}
|
||||
return !s.cfg.Sora.Client.DisableTLSFingerprint
|
||||
}
|
||||
|
||||
func isCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
|
||||
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
|
||||
}
|
||||
|
||||
func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
|
||||
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
|
||||
}
|
||||
|
||||
func extractCloudflareRayID(headers http.Header, body []byte) string {
|
||||
return soraerror.ExtractCloudflareRayID(headers, body)
|
||||
}
|
||||
|
||||
func extractSoraEgressIPHint(headers http.Header) string {
|
||||
if headers == nil {
|
||||
return "unknown"
|
||||
}
|
||||
candidates := []string{
|
||||
"x-openai-public-ip",
|
||||
"x-envoy-external-address",
|
||||
"cf-connecting-ip",
|
||||
"x-forwarded-for",
|
||||
}
|
||||
for _, key := range candidates {
|
||||
if value := strings.TrimSpace(headers.Get(key)); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func sanitizeProxyURLForLog(raw string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
u, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return "<invalid_proxy_url>"
|
||||
}
|
||||
if u.User != nil {
|
||||
u.User = nil
|
||||
}
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func endpointPathForLog(endpoint string) string {
|
||||
parsed, err := url.Parse(strings.TrimSpace(endpoint))
|
||||
if err != nil || parsed.Path == "" {
|
||||
return endpoint
|
||||
}
|
||||
return parsed.Path
|
||||
}
|
||||
|
||||
func (s *AccountTestService) logSoraCloudflareChallenge(account *Account, proxyURL, endpoint string, headers http.Header, body []byte) {
|
||||
accountID := int64(0)
|
||||
platform := ""
|
||||
proxyID := "none"
|
||||
if account != nil {
|
||||
accountID = account.ID
|
||||
platform = account.Platform
|
||||
if account.ProxyID != nil {
|
||||
proxyID = fmt.Sprintf("%d", *account.ProxyID)
|
||||
}
|
||||
}
|
||||
cfRay := extractCloudflareRayID(headers, body)
|
||||
if cfRay == "" {
|
||||
cfRay = "unknown"
|
||||
}
|
||||
log.Printf(
|
||||
"[SoraCFChallenge] account_id=%d platform=%s endpoint=%s path=%s proxy_id=%s proxy_url=%s cf_ray=%s egress_ip_hint=%s",
|
||||
accountID,
|
||||
platform,
|
||||
endpoint,
|
||||
endpointPathForLog(endpoint),
|
||||
proxyID,
|
||||
sanitizeProxyURLForLog(proxyURL),
|
||||
cfRay,
|
||||
extractSoraEgressIPHint(headers),
|
||||
)
|
||||
}
|
||||
|
||||
func truncateSoraErrorBody(body []byte, max int) string {
|
||||
return soraerror.TruncateBody(body, max)
|
||||
}
|
||||
|
||||
// testAntigravityAccountConnection tests an Antigravity account's connection
|
||||
// 支持 Claude 和 Gemini 两种协议,使用非流式请求
|
||||
func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
||||
|
||||
319
backend/internal/service/account_test_service_sora_test.go
Normal file
319
backend/internal/service/account_test_service_sora_test.go
Normal file
@@ -0,0 +1,319 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type queuedHTTPUpstream struct {
|
||||
responses []*http.Response
|
||||
requests []*http.Request
|
||||
tlsFlags []bool
|
||||
}
|
||||
|
||||
func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
||||
return nil, fmt.Errorf("unexpected Do call")
|
||||
}
|
||||
|
||||
func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, enableTLSFingerprint bool) (*http.Response, error) {
|
||||
u.requests = append(u.requests, req)
|
||||
u.tlsFlags = append(u.tlsFlags, enableTLSFingerprint)
|
||||
if len(u.responses) == 0 {
|
||||
return nil, fmt.Errorf("no mocked response")
|
||||
}
|
||||
resp := u.responses[0]
|
||||
u.responses = u.responses[1:]
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func newJSONResponse(status int, body string) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: status,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
}
|
||||
|
||||
func newJSONResponseWithHeader(status int, body, key, value string) *http.Response {
|
||||
resp := newJSONResponse(status, body)
|
||||
resp.Header.Set(key, value)
|
||||
return resp
|
||||
}
|
||||
|
||||
func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
|
||||
return c, rec
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
|
||||
newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`),
|
||||
newJSONResponse(http.StatusOK, `{"invite_code":"inv_abc","redeemed_count":3,"total_count":50}`),
|
||||
newJSONResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":27,"rate_limit_reached":false,"access_resets_in_seconds":46833}}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
httpUpstream: upstream,
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
TLSFingerprint: config.TLSFingerprintConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
DisableTLSFingerprint: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, upstream.requests, 4)
|
||||
require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String())
|
||||
require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String())
|
||||
require.Equal(t, soraInviteMineURL, upstream.requests[2].URL.String())
|
||||
require.Equal(t, soraRemainingURL, upstream.requests[3].URL.String())
|
||||
require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization"))
|
||||
require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization"))
|
||||
require.Equal(t, []bool{true, true, true, true}, upstream.tlsFlags)
|
||||
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, `"type":"test_start"`)
|
||||
require.Contains(t, body, "Sora connection OK - Email: demo@example.com")
|
||||
require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z")
|
||||
require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50")
|
||||
require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s")
|
||||
require.Contains(t, body, `"type":"sora_test_result"`)
|
||||
require.Contains(t, body, `"status":"success"`)
|
||||
require.Contains(t, body, `"type":"test_complete","success":true`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
|
||||
newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
|
||||
newJSONResponse(http.StatusUnauthorized, `{"error":{"message":"Unauthorized"}}`),
|
||||
newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, upstream.requests, 4)
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "Sora connection OK - User: demo-user")
|
||||
require.Contains(t, body, "Subscription check returned 403")
|
||||
require.Contains(t, body, "Sora2 invite check returned 401")
|
||||
require.Contains(t, body, `"type":"sora_test_result"`)
|
||||
require.Contains(t, body, `"status":"partial_success"`)
|
||||
require.Contains(t, body, `"type":"test_complete","success":true`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponseWithHeader(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`, "cf-ray", "9cff2d62d83bb98d"),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "Cloudflare challenge")
|
||||
require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d")
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, `"type":"error"`)
|
||||
require.Contains(t, body, "Cloudflare challenge")
|
||||
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponseWithHeader(http.StatusTooManyRequests, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body></body></html>`, "cf-mitigated", "challenge"),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "Cloudflare challenge")
|
||||
require.Contains(t, err.Error(), "HTTP 429")
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "Cloudflare challenge")
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "token_invalidated")
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, `"type":"sora_test_result"`)
|
||||
require.Contains(t, body, `"status":"failed"`)
|
||||
require.Contains(t, body, "token_invalidated")
|
||||
require.NotContains(t, body, `"type":"test_complete","success":true`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
httpUpstream: upstream,
|
||||
soraTestCooldown: time.Hour,
|
||||
}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c1, _ := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c1, account)
|
||||
require.NoError(t, err)
|
||||
|
||||
c2, rec2 := newSoraTestContext()
|
||||
err = svc.testSoraAccountConnection(c2, account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "测试过于频繁")
|
||||
body := rec2.Body.String()
|
||||
require.Contains(t, body, `"type":"sora_test_result"`)
|
||||
require.Contains(t, body, `"code":"test_rate_limited"`)
|
||||
require.Contains(t, body, `"status":"failed"`)
|
||||
require.NotContains(t, body, `"type":"test_complete","success":true`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
|
||||
newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
|
||||
newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.NoError(t, err)
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)")
|
||||
require.Contains(t, body, "Sora2 invite check blocked by Cloudflare challenge (HTTP 403)")
|
||||
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
|
||||
require.Contains(t, body, `"type":"test_complete","success":true`)
|
||||
}
|
||||
|
||||
func TestSanitizeProxyURLForLog(t *testing.T) {
|
||||
require.Equal(t, "http://proxy.example.com:8080", sanitizeProxyURLForLog("http://user:pass@proxy.example.com:8080"))
|
||||
require.Equal(t, "", sanitizeProxyURLForLog(""))
|
||||
require.Equal(t, "<invalid_proxy_url>", sanitizeProxyURLForLog("://invalid"))
|
||||
}
|
||||
|
||||
func TestExtractSoraEgressIPHint(t *testing.T) {
|
||||
h := make(http.Header)
|
||||
h.Set("x-openai-public-ip", "203.0.113.10")
|
||||
require.Equal(t, "203.0.113.10", extractSoraEgressIPHint(h))
|
||||
|
||||
h2 := make(http.Header)
|
||||
h2.Set("x-envoy-external-address", "198.51.100.9")
|
||||
require.Equal(t, "198.51.100.9", extractSoraEgressIPHint(h2))
|
||||
|
||||
require.Equal(t, "unknown", extractSoraEgressIPHint(nil))
|
||||
require.Equal(t, "unknown", extractSoraEgressIPHint(http.Header{}))
|
||||
}
|
||||
@@ -4,11 +4,15 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
|
||||
)
|
||||
|
||||
// AdminService interface defines admin management operations
|
||||
@@ -39,7 +43,7 @@ type AdminService interface {
|
||||
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
|
||||
|
||||
// Account management
|
||||
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
|
||||
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error)
|
||||
GetAccount(ctx context.Context, id int64) (*Account, error)
|
||||
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
|
||||
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
|
||||
@@ -65,6 +69,7 @@ type AdminService interface {
|
||||
GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error)
|
||||
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
|
||||
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
|
||||
CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error)
|
||||
|
||||
// Redeem code management
|
||||
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error)
|
||||
@@ -288,6 +293,32 @@ type ProxyTestResult struct {
|
||||
CountryCode string `json:"country_code,omitempty"`
|
||||
}
|
||||
|
||||
type ProxyQualityCheckResult struct {
|
||||
ProxyID int64 `json:"proxy_id"`
|
||||
Score int `json:"score"`
|
||||
Grade string `json:"grade"`
|
||||
Summary string `json:"summary"`
|
||||
ExitIP string `json:"exit_ip,omitempty"`
|
||||
Country string `json:"country,omitempty"`
|
||||
CountryCode string `json:"country_code,omitempty"`
|
||||
BaseLatencyMs int64 `json:"base_latency_ms,omitempty"`
|
||||
PassedCount int `json:"passed_count"`
|
||||
WarnCount int `json:"warn_count"`
|
||||
FailedCount int `json:"failed_count"`
|
||||
ChallengeCount int `json:"challenge_count"`
|
||||
CheckedAt int64 `json:"checked_at"`
|
||||
Items []ProxyQualityCheckItem `json:"items"`
|
||||
}
|
||||
|
||||
type ProxyQualityCheckItem struct {
|
||||
Target string `json:"target"`
|
||||
Status string `json:"status"` // pass/warn/fail/challenge
|
||||
HTTPStatus int `json:"http_status,omitempty"`
|
||||
LatencyMs int64 `json:"latency_ms,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
CFRay string `json:"cf_ray,omitempty"`
|
||||
}
|
||||
|
||||
// ProxyExitInfo represents proxy exit information from ip-api.com
|
||||
type ProxyExitInfo struct {
|
||||
IP string
|
||||
@@ -302,6 +333,58 @@ type ProxyExitInfoProber interface {
|
||||
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
|
||||
}
|
||||
|
||||
type proxyQualityTarget struct {
|
||||
Target string
|
||||
URL string
|
||||
Method string
|
||||
AllowedStatuses map[int]struct{}
|
||||
}
|
||||
|
||||
var proxyQualityTargets = []proxyQualityTarget{
|
||||
{
|
||||
Target: "openai",
|
||||
URL: "https://api.openai.com/v1/models",
|
||||
Method: http.MethodGet,
|
||||
AllowedStatuses: map[int]struct{}{
|
||||
http.StatusUnauthorized: {},
|
||||
},
|
||||
},
|
||||
{
|
||||
Target: "anthropic",
|
||||
URL: "https://api.anthropic.com/v1/messages",
|
||||
Method: http.MethodGet,
|
||||
AllowedStatuses: map[int]struct{}{
|
||||
http.StatusUnauthorized: {},
|
||||
http.StatusMethodNotAllowed: {},
|
||||
http.StatusNotFound: {},
|
||||
http.StatusBadRequest: {},
|
||||
},
|
||||
},
|
||||
{
|
||||
Target: "gemini",
|
||||
URL: "https://generativelanguage.googleapis.com/$discovery/rest?version=v1beta",
|
||||
Method: http.MethodGet,
|
||||
AllowedStatuses: map[int]struct{}{
|
||||
http.StatusOK: {},
|
||||
},
|
||||
},
|
||||
{
|
||||
Target: "sora",
|
||||
URL: "https://sora.chatgpt.com/backend/me",
|
||||
Method: http.MethodGet,
|
||||
AllowedStatuses: map[int]struct{}{
|
||||
http.StatusUnauthorized: {},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
const (
|
||||
proxyQualityRequestTimeout = 15 * time.Second
|
||||
proxyQualityResponseHeaderTimeout = 10 * time.Second
|
||||
proxyQualityMaxBodyBytes = int64(8 * 1024)
|
||||
proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36"
|
||||
)
|
||||
|
||||
// adminServiceImpl implements AdminService
|
||||
type adminServiceImpl struct {
|
||||
userRepo UserRepository
|
||||
@@ -1054,9 +1137,9 @@ func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []
|
||||
}
|
||||
|
||||
// Account management implementations
|
||||
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) {
|
||||
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
|
||||
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@@ -1690,6 +1773,270 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &ProxyQualityCheckResult{
|
||||
ProxyID: id,
|
||||
Score: 100,
|
||||
Grade: "A",
|
||||
CheckedAt: time.Now().Unix(),
|
||||
Items: make([]ProxyQualityCheckItem, 0, len(proxyQualityTargets)+1),
|
||||
}
|
||||
|
||||
proxyURL := proxy.URL()
|
||||
if s.proxyProber == nil {
|
||||
result.Items = append(result.Items, ProxyQualityCheckItem{
|
||||
Target: "base_connectivity",
|
||||
Status: "fail",
|
||||
Message: "代理探测服务未配置",
|
||||
})
|
||||
result.FailedCount++
|
||||
finalizeProxyQualityResult(result)
|
||||
s.saveProxyQualitySnapshot(ctx, id, result, nil)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL)
|
||||
if err != nil {
|
||||
result.Items = append(result.Items, ProxyQualityCheckItem{
|
||||
Target: "base_connectivity",
|
||||
Status: "fail",
|
||||
LatencyMs: latencyMs,
|
||||
Message: err.Error(),
|
||||
})
|
||||
result.FailedCount++
|
||||
finalizeProxyQualityResult(result)
|
||||
s.saveProxyQualitySnapshot(ctx, id, result, nil)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
result.ExitIP = exitInfo.IP
|
||||
result.Country = exitInfo.Country
|
||||
result.CountryCode = exitInfo.CountryCode
|
||||
result.BaseLatencyMs = latencyMs
|
||||
result.Items = append(result.Items, ProxyQualityCheckItem{
|
||||
Target: "base_connectivity",
|
||||
Status: "pass",
|
||||
LatencyMs: latencyMs,
|
||||
Message: "代理出口连通正常",
|
||||
})
|
||||
result.PassedCount++
|
||||
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: proxyQualityRequestTimeout,
|
||||
ResponseHeaderTimeout: proxyQualityResponseHeaderTimeout,
|
||||
ProxyStrict: true,
|
||||
})
|
||||
if err != nil {
|
||||
result.Items = append(result.Items, ProxyQualityCheckItem{
|
||||
Target: "http_client",
|
||||
Status: "fail",
|
||||
Message: fmt.Sprintf("创建检测客户端失败: %v", err),
|
||||
})
|
||||
result.FailedCount++
|
||||
finalizeProxyQualityResult(result)
|
||||
s.saveProxyQualitySnapshot(ctx, id, result, exitInfo)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
for _, target := range proxyQualityTargets {
|
||||
item := runProxyQualityTarget(ctx, client, target)
|
||||
result.Items = append(result.Items, item)
|
||||
switch item.Status {
|
||||
case "pass":
|
||||
result.PassedCount++
|
||||
case "warn":
|
||||
result.WarnCount++
|
||||
case "challenge":
|
||||
result.ChallengeCount++
|
||||
default:
|
||||
result.FailedCount++
|
||||
}
|
||||
}
|
||||
|
||||
finalizeProxyQualityResult(result)
|
||||
s.saveProxyQualitySnapshot(ctx, id, result, exitInfo)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func runProxyQualityTarget(ctx context.Context, client *http.Client, target proxyQualityTarget) ProxyQualityCheckItem {
|
||||
item := ProxyQualityCheckItem{
|
||||
Target: target.Target,
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, target.Method, target.URL, nil)
|
||||
if err != nil {
|
||||
item.Status = "fail"
|
||||
item.Message = fmt.Sprintf("构建请求失败: %v", err)
|
||||
return item
|
||||
}
|
||||
req.Header.Set("Accept", "application/json,text/html,*/*")
|
||||
req.Header.Set("User-Agent", proxyQualityClientUserAgent)
|
||||
|
||||
start := time.Now()
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
item.Status = "fail"
|
||||
item.LatencyMs = time.Since(start).Milliseconds()
|
||||
item.Message = fmt.Sprintf("请求失败: %v", err)
|
||||
return item
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
item.LatencyMs = time.Since(start).Milliseconds()
|
||||
item.HTTPStatus = resp.StatusCode
|
||||
|
||||
body, readErr := io.ReadAll(io.LimitReader(resp.Body, proxyQualityMaxBodyBytes+1))
|
||||
if readErr != nil {
|
||||
item.Status = "fail"
|
||||
item.Message = fmt.Sprintf("读取响应失败: %v", readErr)
|
||||
return item
|
||||
}
|
||||
if int64(len(body)) > proxyQualityMaxBodyBytes {
|
||||
body = body[:proxyQualityMaxBodyBytes]
|
||||
}
|
||||
|
||||
if target.Target == "sora" && soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
|
||||
item.Status = "challenge"
|
||||
item.CFRay = soraerror.ExtractCloudflareRayID(resp.Header, body)
|
||||
item.Message = "Sora 命中 Cloudflare challenge"
|
||||
return item
|
||||
}
|
||||
|
||||
if _, ok := target.AllowedStatuses[resp.StatusCode]; ok {
|
||||
if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices {
|
||||
item.Status = "pass"
|
||||
item.Message = fmt.Sprintf("HTTP %d", resp.StatusCode)
|
||||
} else {
|
||||
item.Status = "warn"
|
||||
item.Message = fmt.Sprintf("HTTP %d(目标可达,但鉴权或方法受限)", resp.StatusCode)
|
||||
}
|
||||
return item
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
item.Status = "warn"
|
||||
item.Message = "目标返回 429,可能存在频控"
|
||||
return item
|
||||
}
|
||||
|
||||
item.Status = "fail"
|
||||
item.Message = fmt.Sprintf("非预期状态码: %d", resp.StatusCode)
|
||||
return item
|
||||
}
|
||||
|
||||
func finalizeProxyQualityResult(result *ProxyQualityCheckResult) {
|
||||
if result == nil {
|
||||
return
|
||||
}
|
||||
score := 100 - result.WarnCount*10 - result.FailedCount*22 - result.ChallengeCount*30
|
||||
if score < 0 {
|
||||
score = 0
|
||||
}
|
||||
result.Score = score
|
||||
result.Grade = proxyQualityGrade(score)
|
||||
result.Summary = fmt.Sprintf(
|
||||
"通过 %d 项,告警 %d 项,失败 %d 项,挑战 %d 项",
|
||||
result.PassedCount,
|
||||
result.WarnCount,
|
||||
result.FailedCount,
|
||||
result.ChallengeCount,
|
||||
)
|
||||
}
|
||||
|
||||
func proxyQualityGrade(score int) string {
|
||||
switch {
|
||||
case score >= 90:
|
||||
return "A"
|
||||
case score >= 75:
|
||||
return "B"
|
||||
case score >= 60:
|
||||
return "C"
|
||||
case score >= 40:
|
||||
return "D"
|
||||
default:
|
||||
return "F"
|
||||
}
|
||||
}
|
||||
|
||||
func proxyQualityOverallStatus(result *ProxyQualityCheckResult) string {
|
||||
if result == nil {
|
||||
return ""
|
||||
}
|
||||
if result.ChallengeCount > 0 {
|
||||
return "challenge"
|
||||
}
|
||||
if result.FailedCount > 0 {
|
||||
return "failed"
|
||||
}
|
||||
if result.WarnCount > 0 {
|
||||
return "warn"
|
||||
}
|
||||
if result.PassedCount > 0 {
|
||||
return "healthy"
|
||||
}
|
||||
return "failed"
|
||||
}
|
||||
|
||||
func proxyQualityFirstCFRay(result *ProxyQualityCheckResult) string {
|
||||
if result == nil {
|
||||
return ""
|
||||
}
|
||||
for _, item := range result.Items {
|
||||
if item.CFRay != "" {
|
||||
return item.CFRay
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func proxyQualityBaseConnectivityPass(result *ProxyQualityCheckResult) bool {
|
||||
if result == nil {
|
||||
return false
|
||||
}
|
||||
for _, item := range result.Items {
|
||||
if item.Target == "base_connectivity" {
|
||||
return item.Status == "pass"
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) saveProxyQualitySnapshot(ctx context.Context, proxyID int64, result *ProxyQualityCheckResult, exitInfo *ProxyExitInfo) {
|
||||
if result == nil {
|
||||
return
|
||||
}
|
||||
score := result.Score
|
||||
checkedAt := result.CheckedAt
|
||||
info := &ProxyLatencyInfo{
|
||||
Success: proxyQualityBaseConnectivityPass(result),
|
||||
Message: result.Summary,
|
||||
QualityStatus: proxyQualityOverallStatus(result),
|
||||
QualityScore: &score,
|
||||
QualityGrade: result.Grade,
|
||||
QualitySummary: result.Summary,
|
||||
QualityCheckedAt: &checkedAt,
|
||||
QualityCFRay: proxyQualityFirstCFRay(result),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
if result.BaseLatencyMs > 0 {
|
||||
latency := result.BaseLatencyMs
|
||||
info.LatencyMs = &latency
|
||||
}
|
||||
if exitInfo != nil {
|
||||
info.IPAddress = exitInfo.IP
|
||||
info.Country = exitInfo.Country
|
||||
info.CountryCode = exitInfo.CountryCode
|
||||
info.Region = exitInfo.Region
|
||||
info.City = exitInfo.City
|
||||
}
|
||||
s.saveProxyLatency(ctx, proxyID, info)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) {
|
||||
if s.proxyProber == nil || proxy == nil {
|
||||
return
|
||||
@@ -1800,6 +2147,11 @@ func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []Pro
|
||||
proxies[i].CountryCode = info.CountryCode
|
||||
proxies[i].Region = info.Region
|
||||
proxies[i].City = info.City
|
||||
proxies[i].QualityStatus = info.QualityStatus
|
||||
proxies[i].QualityScore = info.QualityScore
|
||||
proxies[i].QualityGrade = info.QualityGrade
|
||||
proxies[i].QualitySummary = info.QualitySummary
|
||||
proxies[i].QualityChecked = info.QualityCheckedAt
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1807,7 +2159,27 @@ func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64,
|
||||
if s.proxyLatencyCache == nil || info == nil {
|
||||
return
|
||||
}
|
||||
if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, info); err != nil {
|
||||
|
||||
merged := *info
|
||||
if latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, []int64{proxyID}); err == nil {
|
||||
if existing := latencies[proxyID]; existing != nil {
|
||||
if merged.QualityCheckedAt == nil &&
|
||||
merged.QualityScore == nil &&
|
||||
merged.QualityGrade == "" &&
|
||||
merged.QualityStatus == "" &&
|
||||
merged.QualitySummary == "" &&
|
||||
merged.QualityCFRay == "" {
|
||||
merged.QualityStatus = existing.QualityStatus
|
||||
merged.QualityScore = existing.QualityScore
|
||||
merged.QualityGrade = existing.QualityGrade
|
||||
merged.QualitySummary = existing.QualitySummary
|
||||
merged.QualityCheckedAt = existing.QualityCheckedAt
|
||||
merged.QualityCFRay = existing.QualityCFRay
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, &merged); err != nil {
|
||||
logger.LegacyPrintf("service.admin", "Warning: store proxy latency cache failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
95
backend/internal/service/admin_service_proxy_quality_test.go
Normal file
95
backend/internal/service/admin_service_proxy_quality_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFinalizeProxyQualityResult_ScoreAndGrade(t *testing.T) {
|
||||
result := &ProxyQualityCheckResult{
|
||||
PassedCount: 2,
|
||||
WarnCount: 1,
|
||||
FailedCount: 1,
|
||||
ChallengeCount: 1,
|
||||
}
|
||||
|
||||
finalizeProxyQualityResult(result)
|
||||
|
||||
require.Equal(t, 38, result.Score)
|
||||
require.Equal(t, "F", result.Grade)
|
||||
require.Contains(t, result.Summary, "通过 2 项")
|
||||
require.Contains(t, result.Summary, "告警 1 项")
|
||||
require.Contains(t, result.Summary, "失败 1 项")
|
||||
require.Contains(t, result.Summary, "挑战 1 项")
|
||||
}
|
||||
|
||||
func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.Header().Set("cf-ray", "test-ray-123")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
_, _ = w.Write([]byte("<!DOCTYPE html><title>Just a moment...</title><script>window._cf_chl_opt={};</script>"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
target := proxyQualityTarget{
|
||||
Target: "sora",
|
||||
URL: server.URL,
|
||||
Method: http.MethodGet,
|
||||
AllowedStatuses: map[int]struct{}{
|
||||
http.StatusUnauthorized: {},
|
||||
},
|
||||
}
|
||||
|
||||
item := runProxyQualityTarget(context.Background(), server.Client(), target)
|
||||
require.Equal(t, "challenge", item.Status)
|
||||
require.Equal(t, http.StatusForbidden, item.HTTPStatus)
|
||||
require.Equal(t, "test-ray-123", item.CFRay)
|
||||
}
|
||||
|
||||
func TestRunProxyQualityTarget_AllowedStatusPass(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"models":[]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
target := proxyQualityTarget{
|
||||
Target: "gemini",
|
||||
URL: server.URL,
|
||||
Method: http.MethodGet,
|
||||
AllowedStatuses: map[int]struct{}{
|
||||
http.StatusOK: {},
|
||||
},
|
||||
}
|
||||
|
||||
item := runProxyQualityTarget(context.Background(), server.Client(), target)
|
||||
require.Equal(t, "pass", item.Status)
|
||||
require.Equal(t, http.StatusOK, item.HTTPStatus)
|
||||
}
|
||||
|
||||
func TestRunProxyQualityTarget_AllowedStatusWarnForUnauthorized(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"error":"unauthorized"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
target := proxyQualityTarget{
|
||||
Target: "openai",
|
||||
URL: server.URL,
|
||||
Method: http.MethodGet,
|
||||
AllowedStatuses: map[int]struct{}{
|
||||
http.StatusUnauthorized: {},
|
||||
},
|
||||
}
|
||||
|
||||
item := runProxyQualityTarget(context.Background(), server.Client(), target)
|
||||
require.Equal(t, "warn", item.Status)
|
||||
require.Equal(t, http.StatusUnauthorized, item.HTTPStatus)
|
||||
require.Contains(t, item.Message, "目标可达")
|
||||
}
|
||||
@@ -24,7 +24,7 @@ type accountRepoStubForAdminList struct {
|
||||
listWithFiltersErr error
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
||||
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersPlatform = platform
|
||||
@@ -168,7 +168,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc")
|
||||
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(10), total)
|
||||
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
|
||||
|
||||
@@ -4117,6 +4117,15 @@ func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUs
|
||||
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
|
||||
usage.CacheCreationInputTokens = int(v)
|
||||
}
|
||||
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
|
||||
if cc, ok := u["cache_creation"].(map[string]any); ok {
|
||||
if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok {
|
||||
usage.CacheCreation5mTokens = int(v)
|
||||
}
|
||||
if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok {
|
||||
usage.CacheCreation1hTokens = int(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractClaudeUsage 从非流式 Claude 响应提取 usage
|
||||
@@ -4139,6 +4148,15 @@ func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage
|
||||
if v, ok := u["cache_creation_input_tokens"].(float64); ok {
|
||||
usage.CacheCreationInputTokens = int(v)
|
||||
}
|
||||
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
|
||||
if cc, ok := u["cache_creation"].(map[string]any); ok {
|
||||
if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok {
|
||||
usage.CacheCreation5mTokens = int(v)
|
||||
}
|
||||
if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok {
|
||||
usage.CacheCreation1hTokens = int(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
@@ -31,8 +31,8 @@ type ModelPricing struct {
|
||||
OutputPricePerToken float64 // 每token输出价格 (USD)
|
||||
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
|
||||
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
|
||||
CacheCreation5mPrice float64 // 5分钟缓存创建价格(每百万token)- 仅用于硬编码回退
|
||||
CacheCreation1hPrice float64 // 1小时缓存创建价格(每百万token)- 仅用于硬编码回退
|
||||
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
|
||||
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
|
||||
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
|
||||
}
|
||||
|
||||
@@ -172,12 +172,20 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
||||
if s.pricingService != nil {
|
||||
litellmPricing := s.pricingService.GetModelPricing(model)
|
||||
if litellmPricing != nil {
|
||||
// 启用 5m/1h 分类计费的条件:
|
||||
// 1. 存在 1h 价格
|
||||
// 2. 1h 价格 > 5m 价格(防止 LiteLLM 数据错误导致少收费)
|
||||
price5m := litellmPricing.CacheCreationInputTokenCost
|
||||
price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr
|
||||
enableBreakdown := price1h > 0 && price1h > price5m
|
||||
return &ModelPricing{
|
||||
InputPricePerToken: litellmPricing.InputCostPerToken,
|
||||
OutputPricePerToken: litellmPricing.OutputCostPerToken,
|
||||
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
|
||||
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
|
||||
SupportsCacheBreakdown: false,
|
||||
CacheCreation5mPrice: price5m,
|
||||
CacheCreation1hPrice: price1h,
|
||||
SupportsCacheBreakdown: enableBreakdown,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
@@ -209,9 +217,14 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
|
||||
|
||||
// 计算缓存费用
|
||||
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
|
||||
// 支持详细缓存分类的模型(5分钟/1小时缓存)
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)/1_000_000*pricing.CacheCreation5mPrice +
|
||||
float64(tokens.CacheCreation1hTokens)/1_000_000*pricing.CacheCreation1hPrice
|
||||
// 支持详细缓存分类的模型(5分钟/1小时缓存,价格为 per-token)
|
||||
if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 {
|
||||
// API 未返回 ephemeral 明细,回退到全部按 5m 单价计费
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice
|
||||
} else {
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice +
|
||||
float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice
|
||||
}
|
||||
} else {
|
||||
// 标准缓存创建价格(per-token)
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
|
||||
@@ -280,10 +293,12 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
|
||||
|
||||
// 范围内部分:正常计费
|
||||
inRangeTokens := UsageTokens{
|
||||
InputTokens: inRangeInputTokens,
|
||||
OutputTokens: tokens.OutputTokens, // 输出只算一次
|
||||
CacheCreationTokens: tokens.CacheCreationTokens,
|
||||
CacheReadTokens: inRangeCacheTokens,
|
||||
InputTokens: inRangeInputTokens,
|
||||
OutputTokens: tokens.OutputTokens, // 输出只算一次
|
||||
CacheCreationTokens: tokens.CacheCreationTokens,
|
||||
CacheReadTokens: inRangeCacheTokens,
|
||||
CacheCreation5mTokens: tokens.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: tokens.CacheCreation1hTokens,
|
||||
}
|
||||
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
|
||||
if err != nil {
|
||||
|
||||
@@ -399,8 +399,8 @@ func TestCalculateCost_SupportsCacheBreakdown(t *testing.T) {
|
||||
InputPricePerToken: 3e-6,
|
||||
OutputPricePerToken: 15e-6,
|
||||
SupportsCacheBreakdown: true,
|
||||
CacheCreation5mPrice: 4.0, // per million tokens
|
||||
CacheCreation1hPrice: 5.0, // per million tokens
|
||||
CacheCreation5mPrice: 4e-6, // per token
|
||||
CacheCreation1hPrice: 5e-6, // per token
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -414,8 +414,8 @@ func TestCalculateCost_SupportsCacheBreakdown(t *testing.T) {
|
||||
cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
expected5m := float64(100000) / 1_000_000 * 4.0
|
||||
expected1h := float64(50000) / 1_000_000 * 5.0
|
||||
expected5m := float64(tokens.CacheCreation5mTokens) * 4e-6
|
||||
expected1h := float64(tokens.CacheCreation1hTokens) * 5e-6
|
||||
require.InDelta(t, expected5m+expected1h, cost.CacheCreationCost, 1e-10)
|
||||
}
|
||||
|
||||
|
||||
@@ -21,3 +21,72 @@ func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) {
|
||||
)
|
||||
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got)
|
||||
}
|
||||
|
||||
func TestStripBetaToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
token string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "token in middle",
|
||||
header: "oauth-2025-04-20,context-1m-2025-08-07,interleaved-thinking-2025-05-14",
|
||||
token: "context-1m-2025-08-07",
|
||||
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
},
|
||||
{
|
||||
name: "token at start",
|
||||
header: "context-1m-2025-08-07,oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
token: "context-1m-2025-08-07",
|
||||
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
},
|
||||
{
|
||||
name: "token at end",
|
||||
header: "oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07",
|
||||
token: "context-1m-2025-08-07",
|
||||
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
},
|
||||
{
|
||||
name: "token not present",
|
||||
header: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
token: "context-1m-2025-08-07",
|
||||
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
},
|
||||
{
|
||||
name: "empty header",
|
||||
header: "",
|
||||
token: "context-1m-2025-08-07",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "with spaces",
|
||||
header: "oauth-2025-04-20, context-1m-2025-08-07 , interleaved-thinking-2025-05-14",
|
||||
token: "context-1m-2025-08-07",
|
||||
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
},
|
||||
{
|
||||
name: "only token",
|
||||
header: "context-1m-2025-08-07",
|
||||
token: "context-1m-2025-08-07",
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := stripBetaToken(tt.header, tt.token)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeAnthropicBetaDropping_Context1M(t *testing.T) {
|
||||
required := []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"}
|
||||
incoming := "context-1m-2025-08-07,foo-beta,oauth-2025-04-20"
|
||||
drop := map[string]struct{}{"context-1m-2025-08-07": {}}
|
||||
|
||||
got := mergeAnthropicBetaDropping(required, incoming, drop)
|
||||
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo-beta", got)
|
||||
require.NotContains(t, got, "context-1m-2025-08-07")
|
||||
}
|
||||
|
||||
@@ -92,7 +92,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error
|
||||
func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
||||
func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
|
||||
@@ -349,6 +349,8 @@ type ClaudeUsage struct {
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
CacheCreation5mTokens int // 5分钟缓存创建token(来自嵌套 cache_creation 对象)
|
||||
CacheCreation1hTokens int // 1小时缓存创建token(来自嵌套 cache_creation 对象)
|
||||
}
|
||||
|
||||
// ForwardResult 转发结果
|
||||
@@ -373,9 +375,10 @@ type ForwardResult struct {
|
||||
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
|
||||
type UpstreamFailoverError struct {
|
||||
StatusCode int
|
||||
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
|
||||
ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
|
||||
RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换
|
||||
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
|
||||
ResponseHeaders http.Header // 上游响应头,用于透传 cf-ray/cf-mitigated/content-type 等诊断信息
|
||||
ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
|
||||
RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换
|
||||
}
|
||||
|
||||
func (e *UpstreamFailoverError) Error() string {
|
||||
@@ -3580,12 +3583,12 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
// messages requests typically use only oauth + interleaved-thinking.
|
||||
// Also drop claude-code beta if a downstream client added it.
|
||||
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
|
||||
drop := map[string]struct{}{claude.BetaClaudeCode: {}}
|
||||
drop := map[string]struct{}{claude.BetaClaudeCode: {}, claude.BetaContext1M: {}}
|
||||
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop))
|
||||
} else {
|
||||
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
|
||||
clientBetaHeader := req.Header.Get("anthropic-beta")
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, clientBetaHeader))
|
||||
req.Header.Set("anthropic-beta", stripBetaToken(s.getBetaHeader(modelID, clientBetaHeader), claude.BetaContext1M))
|
||||
}
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
||||
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
||||
@@ -3739,6 +3742,23 @@ func mergeAnthropicBetaDropping(required []string, incoming string, drop map[str
|
||||
return strings.Join(out, ",")
|
||||
}
|
||||
|
||||
// stripBetaToken returns header with every occurrence of token removed from
// the comma-separated list. Each kept element is whitespace-trimmed and empty
// elements are dropped. When token does not appear as a substring the header
// is returned untouched, skipping the split/join allocations.
func stripBetaToken(header, token string) string {
	if !strings.Contains(header, token) {
		return header
	}
	parts := strings.Split(header, ",")
	// Filter in place: parts is freshly allocated by Split, so reusing its
	// backing array is safe.
	kept := parts[:0]
	for _, part := range parts {
		trimmed := strings.TrimSpace(part)
		if trimmed == "" || trimmed == token {
			continue
		}
		kept = append(kept, trimmed)
	}
	return strings.Join(kept, ",")
}
|
||||
|
||||
// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers.
|
||||
// This mirrors opencode-anthropic-auth behavior: do not trust downstream
|
||||
// headers when using Claude Code-scoped OAuth credentials.
|
||||
@@ -4305,6 +4325,23 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
}
|
||||
|
||||
// Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类
|
||||
if account.IsCacheTTLOverrideEnabled() {
|
||||
overrideTarget := account.GetCacheTTLOverrideTarget()
|
||||
if eventType == "message_start" {
|
||||
if msg, ok := event["message"].(map[string]any); ok {
|
||||
if u, ok := msg["usage"].(map[string]any); ok {
|
||||
rewriteCacheCreationJSON(u, overrideTarget)
|
||||
}
|
||||
}
|
||||
}
|
||||
if eventType == "message_delta" {
|
||||
if u, ok := event["usage"].(map[string]any); ok {
|
||||
rewriteCacheCreationJSON(u, overrideTarget)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if needModelReplace {
|
||||
if msg, ok := event["message"].(map[string]any); ok {
|
||||
if model, ok := msg["model"].(string); ok && model == mappedModel {
|
||||
@@ -4432,6 +4469,14 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
||||
usage.InputTokens = msgStart.Message.Usage.InputTokens
|
||||
usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens
|
||||
usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens
|
||||
|
||||
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
|
||||
cc5m := gjson.Get(data, "message.usage.cache_creation.ephemeral_5m_input_tokens")
|
||||
cc1h := gjson.Get(data, "message.usage.cache_creation.ephemeral_1h_input_tokens")
|
||||
if cc5m.Exists() || cc1h.Exists() {
|
||||
usage.CacheCreation5mTokens = int(cc5m.Int())
|
||||
usage.CacheCreation1hTokens = int(cc1h.Int())
|
||||
}
|
||||
}
|
||||
|
||||
// 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API)
|
||||
@@ -4460,6 +4505,68 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
||||
if msgDelta.Usage.CacheReadInputTokens > 0 {
|
||||
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
|
||||
}
|
||||
|
||||
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
|
||||
cc5m := gjson.Get(data, "usage.cache_creation.ephemeral_5m_input_tokens")
|
||||
cc1h := gjson.Get(data, "usage.cache_creation.ephemeral_1h_input_tokens")
|
||||
if cc5m.Exists() && cc5m.Int() > 0 {
|
||||
usage.CacheCreation5mTokens = int(cc5m.Int())
|
||||
}
|
||||
if cc1h.Exists() && cc1h.Int() > 0 {
|
||||
usage.CacheCreation1hTokens = int(cc1h.Int())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// applyCacheTTLOverride 将所有 cache creation tokens 归入指定的 TTL 类型。
|
||||
// target 为 "5m" 或 "1h"。返回 true 表示发生了变更。
|
||||
func applyCacheTTLOverride(usage *ClaudeUsage, target string) bool {
|
||||
// Fallback: 如果只有聚合字段但无 5m/1h 明细,将聚合字段归入 5m 默认类别
|
||||
if usage.CacheCreation5mTokens == 0 && usage.CacheCreation1hTokens == 0 && usage.CacheCreationInputTokens > 0 {
|
||||
usage.CacheCreation5mTokens = usage.CacheCreationInputTokens
|
||||
}
|
||||
|
||||
total := usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
|
||||
if total == 0 {
|
||||
return false
|
||||
}
|
||||
switch target {
|
||||
case "1h":
|
||||
if usage.CacheCreation1hTokens == total {
|
||||
return false // 已经全是 1h
|
||||
}
|
||||
usage.CacheCreation1hTokens = total
|
||||
usage.CacheCreation5mTokens = 0
|
||||
default: // "5m"
|
||||
if usage.CacheCreation5mTokens == total {
|
||||
return false // 已经全是 5m
|
||||
}
|
||||
usage.CacheCreation5mTokens = total
|
||||
usage.CacheCreation1hTokens = 0
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// rewriteCacheCreationJSON rebuckets the TTL breakdown inside the nested
// "cache_creation" object of a decoded usage JSON map. target selects the
// destination bucket ("1h"; anything else means "5m"). Maps lacking a
// cache_creation object, or whose breakdown sums to zero, are left untouched.
func rewriteCacheCreationJSON(usageObj map[string]any, target string) {
	cc, ok := usageObj["cache_creation"].(map[string]any)
	if !ok {
		return
	}

	// encoding/json decodes numbers as float64, so these assertions cover the
	// normal decode path; missing or non-float values count as zero.
	tok5m, _ := cc["ephemeral_5m_input_tokens"].(float64)
	tok1h, _ := cc["ephemeral_1h_input_tokens"].(float64)
	sum := tok5m + tok1h
	if sum == 0 {
		return
	}

	if target == "1h" {
		cc["ephemeral_1h_input_tokens"] = sum
		cc["ephemeral_5m_input_tokens"] = float64(0)
		return
	}
	// Default bucket is "5m".
	cc["ephemeral_5m_input_tokens"] = sum
	cc["ephemeral_1h_input_tokens"] = float64(0)
}
|
||||
|
||||
@@ -4491,6 +4598,14 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
|
||||
cc5m := gjson.GetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens")
|
||||
cc1h := gjson.GetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens")
|
||||
if cc5m.Exists() || cc1h.Exists() {
|
||||
response.Usage.CacheCreation5mTokens = int(cc5m.Int())
|
||||
response.Usage.CacheCreation1hTokens = int(cc1h.Int())
|
||||
}
|
||||
|
||||
// 兼容 Kimi cached_tokens → cache_read_input_tokens
|
||||
if response.Usage.CacheReadInputTokens == 0 {
|
||||
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
|
||||
@@ -4502,6 +4617,20 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
}
|
||||
}
|
||||
|
||||
// Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类
|
||||
if account.IsCacheTTLOverrideEnabled() {
|
||||
overrideTarget := account.GetCacheTTLOverrideTarget()
|
||||
if applyCacheTTLOverride(&response.Usage, overrideTarget) {
|
||||
// 同步更新 body JSON 中的嵌套 cache_creation 对象
|
||||
if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens", response.Usage.CacheCreation5mTokens); err == nil {
|
||||
body = newBody
|
||||
}
|
||||
if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens", response.Usage.CacheCreation1hTokens); err == nil {
|
||||
body = newBody
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果有模型映射,替换响应中的model字段
|
||||
if originalModel != mappedModel {
|
||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||
@@ -4570,6 +4699,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
result.Usage.InputTokens = 0
|
||||
}
|
||||
|
||||
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
|
||||
cacheTTLOverridden := false
|
||||
if account.IsCacheTTLOverrideEnabled() {
|
||||
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
|
||||
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
|
||||
}
|
||||
|
||||
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||
@@ -4617,10 +4753,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
} else {
|
||||
// Token 计费
|
||||
tokens := UsageTokens{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
}
|
||||
var err error
|
||||
cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
||||
@@ -4658,6 +4796,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
InputCost: cost.InputCost,
|
||||
OutputCost: cost.OutputCost,
|
||||
CacheCreationCost: cost.CacheCreationCost,
|
||||
@@ -4673,6 +4813,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
ImageCount: result.ImageCount,
|
||||
ImageSize: imageSize,
|
||||
MediaType: mediaType,
|
||||
CacheTTLOverridden: cacheTTLOverridden,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
@@ -4773,6 +4914,13 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
result.Usage.InputTokens = 0
|
||||
}
|
||||
|
||||
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
|
||||
cacheTTLOverridden := false
|
||||
if account.IsCacheTTLOverrideEnabled() {
|
||||
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
|
||||
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
|
||||
}
|
||||
|
||||
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||
@@ -4803,10 +4951,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
} else {
|
||||
// Token 计费(使用长上下文计费方法)
|
||||
tokens := UsageTokens{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
}
|
||||
var err error
|
||||
cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
|
||||
@@ -4840,6 +4990,8 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
InputCost: cost.InputCost,
|
||||
OutputCost: cost.OutputCost,
|
||||
CacheCreationCost: cost.CacheCreationCost,
|
||||
@@ -4854,6 +5006,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
ImageCount: result.ImageCount,
|
||||
ImageSize: imageSize,
|
||||
CacheTTLOverridden: cacheTTLOverridden,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
@@ -5170,7 +5323,8 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
|
||||
incomingBeta := req.Header.Get("anthropic-beta")
|
||||
requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting}
|
||||
req.Header.Set("anthropic-beta", mergeAnthropicBeta(requiredBetas, incomingBeta))
|
||||
drop := map[string]struct{}{claude.BetaContext1M: {}}
|
||||
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop))
|
||||
} else {
|
||||
clientBetaHeader := req.Header.Get("anthropic-beta")
|
||||
if clientBetaHeader == "" {
|
||||
@@ -5180,7 +5334,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
if !strings.Contains(beta, claude.BetaTokenCounting) {
|
||||
beta = beta + "," + claude.BetaTokenCounting
|
||||
}
|
||||
req.Header.Set("anthropic-beta", beta)
|
||||
req.Header.Set("anthropic-beta", stripBetaToken(beta, claude.BetaContext1M))
|
||||
}
|
||||
}
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
||||
|
||||
@@ -79,6 +79,22 @@ func TestParseSSEUsage_DeltaOverwritesWithNonZero(t *testing.T) {
|
||||
require.Equal(t, 60, usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
func TestParseSSEUsage_DeltaDoesNotResetCacheCreationBreakdown(t *testing.T) {
|
||||
svc := newMinimalGatewayService()
|
||||
usage := &ClaudeUsage{}
|
||||
|
||||
// 先在 message_start 中写入非零 5m/1h 明细
|
||||
svc.parseSSEUsage(`{"type":"message_start","message":{"usage":{"input_tokens":100,"cache_creation":{"ephemeral_5m_input_tokens":30,"ephemeral_1h_input_tokens":70}}}}`, usage)
|
||||
require.Equal(t, 30, usage.CacheCreation5mTokens)
|
||||
require.Equal(t, 70, usage.CacheCreation1hTokens)
|
||||
|
||||
// 后续 delta 带默认 0,不应覆盖已有非零值
|
||||
svc.parseSSEUsage(`{"type":"message_delta","usage":{"output_tokens":12,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0}}}`, usage)
|
||||
require.Equal(t, 30, usage.CacheCreation5mTokens, "delta 的 0 值不应重置 5m 明细")
|
||||
require.Equal(t, 70, usage.CacheCreation1hTokens, "delta 的 0 值不应重置 1h 明细")
|
||||
require.Equal(t, 12, usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestParseSSEUsage_InvalidJSON(t *testing.T) {
|
||||
svc := newMinimalGatewayService()
|
||||
usage := &ClaudeUsage{}
|
||||
|
||||
@@ -79,7 +79,7 @@ func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error
|
||||
func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
||||
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
type OpenAIOAuthClient interface {
|
||||
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error)
|
||||
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error)
|
||||
RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error)
|
||||
}
|
||||
|
||||
// ClaudeOAuthClient handles HTTP requests for Claude OAuth flows
|
||||
|
||||
@@ -99,13 +99,19 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTran
|
||||
result.Modified = true
|
||||
}
|
||||
|
||||
if _, ok := reqBody["max_output_tokens"]; ok {
|
||||
delete(reqBody, "max_output_tokens")
|
||||
result.Modified = true
|
||||
}
|
||||
if _, ok := reqBody["max_completion_tokens"]; ok {
|
||||
delete(reqBody, "max_completion_tokens")
|
||||
result.Modified = true
|
||||
// Strip parameters unsupported by codex models via the Responses API.
|
||||
for _, key := range []string{
|
||||
"max_output_tokens",
|
||||
"max_completion_tokens",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
} {
|
||||
if _, ok := reqBody[key]; ok {
|
||||
delete(reqBody, key)
|
||||
result.Modified = true
|
||||
}
|
||||
}
|
||||
|
||||
if normalizeCodexTools(reqBody) {
|
||||
|
||||
@@ -2,13 +2,20 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
)
|
||||
|
||||
var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
|
||||
|
||||
// OpenAIOAuthService handles OpenAI OAuth authentication flows
|
||||
type OpenAIOAuthService struct {
|
||||
sessionStore *openai.SessionStore
|
||||
@@ -92,6 +99,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
type OpenAIExchangeCodeInput struct {
|
||||
SessionID string
|
||||
Code string
|
||||
State string
|
||||
RedirectURI string
|
||||
ProxyID *int64
|
||||
}
|
||||
@@ -116,6 +124,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
||||
if !ok {
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired")
|
||||
}
|
||||
if input.State == "" {
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_STATE_REQUIRED", "oauth state is required")
|
||||
}
|
||||
if subtle.ConstantTimeCompare([]byte(input.State), []byte(session.State)) != 1 {
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_STATE", "invalid oauth state")
|
||||
}
|
||||
|
||||
// Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL
|
||||
proxyURL := session.ProxyURL
|
||||
@@ -173,7 +187,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
||||
|
||||
// RefreshToken refreshes an OpenAI OAuth token
|
||||
func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) {
|
||||
tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
|
||||
}
|
||||
|
||||
// RefreshTokenWithClientID refreshes an OpenAI/Sora OAuth token with optional client_id.
|
||||
func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) {
|
||||
tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -205,13 +224,83 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for an OpenAI account
|
||||
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
||||
if !account.IsOpenAI() {
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account")
|
||||
// ExchangeSoraSessionToken exchanges Sora session_token to access_token.
|
||||
func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) {
|
||||
if strings.TrimSpace(sessionToken) == "" {
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required")
|
||||
}
|
||||
|
||||
refreshToken := account.GetOpenAIRefreshToken()
|
||||
proxyURL, err := s.resolveProxyURL(ctx, proxyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAISoraSessionAuthURL, nil)
|
||||
if err != nil {
|
||||
return nil, infraerrors.Newf(http.StatusInternalServerError, "SORA_SESSION_REQUEST_BUILD_FAILED", "failed to build request: %v", err)
|
||||
}
|
||||
req.Header.Set("Cookie", "__Secure-next-auth.session-token="+strings.TrimSpace(sessionToken))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Origin", "https://sora.chatgpt.com")
|
||||
req.Header.Set("Referer", "https://sora.chatgpt.com/")
|
||||
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||
|
||||
client := newOpenAIOAuthHTTPClient(proxyURL)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_EXCHANGE_FAILED", "status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var sessionResp struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
Expires string `json:"expires"`
|
||||
User struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
} `json:"user"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &sessionResp); err != nil {
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_PARSE_FAILED", "failed to parse response: %v", err)
|
||||
}
|
||||
if strings.TrimSpace(sessionResp.AccessToken) == "" {
|
||||
return nil, infraerrors.New(http.StatusBadGateway, "SORA_SESSION_ACCESS_TOKEN_MISSING", "session exchange response missing access token")
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(time.Hour).Unix()
|
||||
if strings.TrimSpace(sessionResp.Expires) != "" {
|
||||
if parsed, parseErr := time.Parse(time.RFC3339, sessionResp.Expires); parseErr == nil {
|
||||
expiresAt = parsed.Unix()
|
||||
}
|
||||
}
|
||||
expiresIn := expiresAt - time.Now().Unix()
|
||||
if expiresIn < 0 {
|
||||
expiresIn = 0
|
||||
}
|
||||
|
||||
return &OpenAITokenInfo{
|
||||
AccessToken: strings.TrimSpace(sessionResp.AccessToken),
|
||||
ExpiresIn: expiresIn,
|
||||
ExpiresAt: expiresAt,
|
||||
Email: strings.TrimSpace(sessionResp.User.Email),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account
|
||||
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
||||
if account.Platform != PlatformOpenAI && account.Platform != PlatformSora {
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI/Sora account")
|
||||
}
|
||||
if account.Type != AccountTypeOAuth {
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account")
|
||||
}
|
||||
|
||||
refreshToken := account.GetCredential("refresh_token")
|
||||
if refreshToken == "" {
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
|
||||
}
|
||||
@@ -224,7 +313,8 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
||||
}
|
||||
}
|
||||
|
||||
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
clientID := account.GetCredential("client_id")
|
||||
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
|
||||
}
|
||||
|
||||
// BuildAccountCredentials builds credentials map from token info
|
||||
@@ -260,3 +350,30 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
|
||||
func (s *OpenAIOAuthService) Stop() {
|
||||
s.sessionStore.Stop()
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) {
|
||||
if proxyID == nil {
|
||||
return "", nil
|
||||
}
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||||
if err != nil {
|
||||
return "", infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
|
||||
}
|
||||
if proxy == nil {
|
||||
return "", nil
|
||||
}
|
||||
return proxy.URL(), nil
|
||||
}
|
||||
|
||||
func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client {
|
||||
transport := &http.Transport{}
|
||||
if strings.TrimSpace(proxyURL) != "" {
|
||||
if parsed, err := url.Parse(proxyURL); err == nil && parsed.Host != "" {
|
||||
transport.Proxy = http.ProxyURL(parsed)
|
||||
}
|
||||
}
|
||||
return &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
Transport: transport,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openaiOAuthClientNoopStub struct{}
|
||||
|
||||
func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *openaiOAuthClientNoopStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *openaiOAuthClientNoopStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodGet, r.Method)
|
||||
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-token")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origin := openAISoraSessionAuthURL
|
||||
openAISoraSessionAuthURL = server.URL
|
||||
defer func() { openAISoraSessionAuthURL = origin }()
|
||||
|
||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
||||
defer svc.Stop()
|
||||
|
||||
info, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, info)
|
||||
require.Equal(t, "at-token", info.AccessToken)
|
||||
require.Equal(t, "demo@example.com", info.Email)
|
||||
require.Greater(t, info.ExpiresAt, int64(0))
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"expires":"2099-01-01T00:00:00Z"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origin := openAISoraSessionAuthURL
|
||||
openAISoraSessionAuthURL = server.URL
|
||||
defer func() { openAISoraSessionAuthURL = origin }()
|
||||
|
||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
||||
defer svc.Stop()
|
||||
|
||||
_, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "missing access token")
|
||||
}
|
||||
102
backend/internal/service/openai_oauth_service_state_test.go
Normal file
102
backend/internal/service/openai_oauth_service_state_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openaiOAuthClientStateStub struct {
|
||||
exchangeCalled int32
|
||||
}
|
||||
|
||||
func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
||||
atomic.AddInt32(&s.exchangeCalled, 1)
|
||||
return &openai.TokenResponse{
|
||||
AccessToken: "at",
|
||||
RefreshToken: "rt",
|
||||
ExpiresIn: 3600,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *openaiOAuthClientStateStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *openaiOAuthClientStateStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
|
||||
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeCode_StateRequired(t *testing.T) {
|
||||
client := &openaiOAuthClientStateStub{}
|
||||
svc := NewOpenAIOAuthService(nil, client)
|
||||
defer svc.Stop()
|
||||
|
||||
svc.sessionStore.Set("sid", &openai.OAuthSession{
|
||||
State: "expected-state",
|
||||
CodeVerifier: "verifier",
|
||||
RedirectURI: openai.DefaultRedirectURI,
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
_, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
|
||||
SessionID: "sid",
|
||||
Code: "auth-code",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "oauth state is required")
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled))
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeCode_StateMismatch(t *testing.T) {
|
||||
client := &openaiOAuthClientStateStub{}
|
||||
svc := NewOpenAIOAuthService(nil, client)
|
||||
defer svc.Stop()
|
||||
|
||||
svc.sessionStore.Set("sid", &openai.OAuthSession{
|
||||
State: "expected-state",
|
||||
CodeVerifier: "verifier",
|
||||
RedirectURI: openai.DefaultRedirectURI,
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
_, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
|
||||
SessionID: "sid",
|
||||
Code: "auth-code",
|
||||
State: "wrong-state",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid oauth state")
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled))
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) {
|
||||
client := &openaiOAuthClientStateStub{}
|
||||
svc := NewOpenAIOAuthService(nil, client)
|
||||
defer svc.Stop()
|
||||
|
||||
svc.sessionStore.Set("sid", &openai.OAuthSession{
|
||||
State: "expected-state",
|
||||
CodeVerifier: "verifier",
|
||||
RedirectURI: openai.DefaultRedirectURI,
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
info, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
|
||||
SessionID: "sid",
|
||||
Code: "auth-code",
|
||||
State: "expected-state",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, info)
|
||||
require.Equal(t, "at", info.AccessToken)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&client.exchangeCalled))
|
||||
|
||||
_, ok := svc.sessionStore.Get("sid")
|
||||
require.False(t, ok)
|
||||
}
|
||||
@@ -157,7 +157,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||
if p.openAIOAuthService == nil {
|
||||
if account.Platform == PlatformSora {
|
||||
slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID)
|
||||
// Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
|
||||
refreshFailed = true
|
||||
} else if p.openAIOAuthService == nil {
|
||||
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
||||
p.metrics.refreshFailure.Add(1)
|
||||
refreshFailed = true // 无法刷新,标记失败
|
||||
@@ -206,7 +210,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
|
||||
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||
if p.openAIOAuthService == nil {
|
||||
if account.Platform == PlatformSora {
|
||||
slog.Debug("openai_token_refresh_skipped_for_sora_degraded", "account_id", account.ID)
|
||||
// Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
|
||||
refreshFailed = true
|
||||
} else if p.openAIOAuthService == nil {
|
||||
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
||||
p.metrics.refreshFailure.Add(1)
|
||||
refreshFailed = true
|
||||
|
||||
@@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s
|
||||
accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{
|
||||
Page: page,
|
||||
PageSize: opsAccountsPageSize,
|
||||
}, platformFilter, "", "", "")
|
||||
}, platformFilter, "", "", "", 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -28,14 +28,15 @@ var (
|
||||
// LiteLLMModelPricing LiteLLM价格数据结构
|
||||
// 只保留我们需要的字段,使用指针来处理可能缺失的值
|
||||
type LiteLLMModelPricing struct {
|
||||
InputCostPerToken float64 `json:"input_cost_per_token"`
|
||||
OutputCostPerToken float64 `json:"output_cost_per_token"`
|
||||
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
|
||||
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
|
||||
LiteLLMProvider string `json:"litellm_provider"`
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
|
||||
InputCostPerToken float64 `json:"input_cost_per_token"`
|
||||
OutputCostPerToken float64 `json:"output_cost_per_token"`
|
||||
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
|
||||
CacheCreationInputTokenCostAbove1hr float64 `json:"cache_creation_input_token_cost_above_1hr"`
|
||||
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
|
||||
LiteLLMProvider string `json:"litellm_provider"`
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
|
||||
}
|
||||
|
||||
// PricingRemoteClient 远程价格数据获取接口
|
||||
@@ -46,14 +47,15 @@ type PricingRemoteClient interface {
|
||||
|
||||
// LiteLLMRawEntry 用于解析原始JSON数据
|
||||
type LiteLLMRawEntry struct {
|
||||
InputCostPerToken *float64 `json:"input_cost_per_token"`
|
||||
OutputCostPerToken *float64 `json:"output_cost_per_token"`
|
||||
CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"`
|
||||
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"`
|
||||
LiteLLMProvider string `json:"litellm_provider"`
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
OutputCostPerImage *float64 `json:"output_cost_per_image"`
|
||||
InputCostPerToken *float64 `json:"input_cost_per_token"`
|
||||
OutputCostPerToken *float64 `json:"output_cost_per_token"`
|
||||
CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"`
|
||||
CacheCreationInputTokenCostAbove1hr *float64 `json:"cache_creation_input_token_cost_above_1hr"`
|
||||
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"`
|
||||
LiteLLMProvider string `json:"litellm_provider"`
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
OutputCostPerImage *float64 `json:"output_cost_per_image"`
|
||||
}
|
||||
|
||||
// PricingService 动态价格服务
|
||||
@@ -319,6 +321,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel
|
||||
if entry.CacheCreationInputTokenCost != nil {
|
||||
pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost
|
||||
}
|
||||
if entry.CacheCreationInputTokenCostAbove1hr != nil {
|
||||
pricing.CacheCreationInputTokenCostAbove1hr = *entry.CacheCreationInputTokenCostAbove1hr
|
||||
}
|
||||
if entry.CacheReadInputTokenCost != nil {
|
||||
pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost
|
||||
}
|
||||
|
||||
@@ -40,6 +40,11 @@ type ProxyWithAccountCount struct {
|
||||
CountryCode string
|
||||
Region string
|
||||
City string
|
||||
QualityStatus string
|
||||
QualityScore *int
|
||||
QualityGrade string
|
||||
QualitySummary string
|
||||
QualityChecked *int64
|
||||
}
|
||||
|
||||
type ProxyAccountSummary struct {
|
||||
|
||||
@@ -6,15 +6,21 @@ import (
|
||||
)
|
||||
|
||||
type ProxyLatencyInfo struct {
|
||||
Success bool `json:"success"`
|
||||
LatencyMs *int64 `json:"latency_ms,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
IPAddress string `json:"ip_address,omitempty"`
|
||||
Country string `json:"country,omitempty"`
|
||||
CountryCode string `json:"country_code,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
City string `json:"city,omitempty"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Success bool `json:"success"`
|
||||
LatencyMs *int64 `json:"latency_ms,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
IPAddress string `json:"ip_address,omitempty"`
|
||||
Country string `json:"country,omitempty"`
|
||||
CountryCode string `json:"country_code,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
City string `json:"city,omitempty"`
|
||||
QualityStatus string `json:"quality_status,omitempty"`
|
||||
QualityScore *int `json:"quality_score,omitempty"`
|
||||
QualityGrade string `json:"quality_grade,omitempty"`
|
||||
QualitySummary string `json:"quality_summary,omitempty"`
|
||||
QualityCheckedAt *int64 `json:"quality_checked_at,omitempty"`
|
||||
QualityCFRay string `json:"quality_cf_ray,omitempty"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type ProxyLatencyCache interface {
|
||||
|
||||
@@ -381,10 +381,31 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 尝试从响应头解析重置时间(Anthropic)
|
||||
// 2. Anthropic 平台:尝试解析 per-window 头(5h / 7d),选择实际触发的窗口
|
||||
if result := calculateAnthropic429ResetTime(headers); result != nil {
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, result.resetAt); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 更新 session window:优先使用 5h-reset 头精确计算,否则从 resetAt 反推
|
||||
windowEnd := result.resetAt
|
||||
if result.fiveHourReset != nil {
|
||||
windowEnd = *result.fiveHourReset
|
||||
}
|
||||
windowStart := windowEnd.Add(-5 * time.Hour)
|
||||
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
|
||||
slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
|
||||
slog.Info("anthropic_account_rate_limited", "account_id", account.ID, "reset_at", result.resetAt, "reset_in", time.Until(result.resetAt).Truncate(time.Second))
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 尝试从响应头解析重置时间(Anthropic 聚合头,向后兼容)
|
||||
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
|
||||
|
||||
// 3. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini)
|
||||
// 4. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini)
|
||||
if resetTimestamp == "" {
|
||||
switch account.Platform {
|
||||
case PlatformOpenAI:
|
||||
@@ -497,6 +518,112 @@ func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *tim
|
||||
return nil
|
||||
}
|
||||
|
||||
// anthropic429Result holds the parsed Anthropic 429 rate-limit information.
|
||||
type anthropic429Result struct {
|
||||
resetAt time.Time // The correct reset time to use for SetRateLimited
|
||||
fiveHourReset *time.Time // 5h window reset timestamp (for session window calculation), nil if not available
|
||||
}
|
||||
|
||||
// calculateAnthropic429ResetTime parses Anthropic's per-window rate-limit headers
|
||||
// to determine which window (5h or 7d) actually triggered the 429.
|
||||
//
|
||||
// Headers used:
|
||||
// - anthropic-ratelimit-unified-5h-utilization / anthropic-ratelimit-unified-5h-surpassed-threshold
|
||||
// - anthropic-ratelimit-unified-5h-reset
|
||||
// - anthropic-ratelimit-unified-7d-utilization / anthropic-ratelimit-unified-7d-surpassed-threshold
|
||||
// - anthropic-ratelimit-unified-7d-reset
|
||||
//
|
||||
// Returns nil when the per-window headers are absent (caller should fall back to
|
||||
// the aggregated anthropic-ratelimit-unified-reset header).
|
||||
func calculateAnthropic429ResetTime(headers http.Header) *anthropic429Result {
|
||||
reset5hStr := headers.Get("anthropic-ratelimit-unified-5h-reset")
|
||||
reset7dStr := headers.Get("anthropic-ratelimit-unified-7d-reset")
|
||||
|
||||
if reset5hStr == "" && reset7dStr == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var reset5h, reset7d *time.Time
|
||||
if ts, err := strconv.ParseInt(reset5hStr, 10, 64); err == nil {
|
||||
t := time.Unix(ts, 0)
|
||||
reset5h = &t
|
||||
}
|
||||
if ts, err := strconv.ParseInt(reset7dStr, 10, 64); err == nil {
|
||||
t := time.Unix(ts, 0)
|
||||
reset7d = &t
|
||||
}
|
||||
|
||||
is5hExceeded := isAnthropicWindowExceeded(headers, "5h")
|
||||
is7dExceeded := isAnthropicWindowExceeded(headers, "7d")
|
||||
|
||||
slog.Info("anthropic_429_window_analysis",
|
||||
"is_5h_exceeded", is5hExceeded,
|
||||
"is_7d_exceeded", is7dExceeded,
|
||||
"reset_5h", reset5hStr,
|
||||
"reset_7d", reset7dStr,
|
||||
)
|
||||
|
||||
// Select the correct reset time based on which window(s) are exceeded.
|
||||
var chosen *time.Time
|
||||
switch {
|
||||
case is5hExceeded && is7dExceeded:
|
||||
// Both exceeded → prefer 7d (longer cooldown), fall back to 5h
|
||||
chosen = reset7d
|
||||
if chosen == nil {
|
||||
chosen = reset5h
|
||||
}
|
||||
case is5hExceeded:
|
||||
chosen = reset5h
|
||||
case is7dExceeded:
|
||||
chosen = reset7d
|
||||
default:
|
||||
// Neither flag clearly exceeded — pick the sooner reset as best guess
|
||||
chosen = pickSooner(reset5h, reset7d)
|
||||
}
|
||||
|
||||
if chosen == nil {
|
||||
return nil
|
||||
}
|
||||
return &anthropic429Result{resetAt: *chosen, fiveHourReset: reset5h}
|
||||
}
|
||||
|
||||
// isAnthropicWindowExceeded checks whether a given Anthropic rate-limit window
|
||||
// (e.g. "5h" or "7d") has been exceeded, using utilization and surpassed-threshold headers.
|
||||
func isAnthropicWindowExceeded(headers http.Header, window string) bool {
|
||||
prefix := "anthropic-ratelimit-unified-" + window + "-"
|
||||
|
||||
// Check surpassed-threshold first (most explicit signal)
|
||||
if st := headers.Get(prefix + "surpassed-threshold"); strings.EqualFold(st, "true") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Fall back to utilization >= 1.0
|
||||
if utilStr := headers.Get(prefix + "utilization"); utilStr != "" {
|
||||
if util, err := strconv.ParseFloat(utilStr, 64); err == nil && util >= 1.0-1e-9 {
|
||||
// Use a small epsilon to handle floating point: treat 0.9999999... as >= 1.0
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// pickSooner returns whichever of the two time pointers is earlier.
|
||||
// If only one is non-nil, it is returned. If both are nil, returns nil.
|
||||
func pickSooner(a, b *time.Time) *time.Time {
|
||||
switch {
|
||||
case a != nil && b != nil:
|
||||
if a.Before(*b) {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
case a != nil:
|
||||
return a
|
||||
default:
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳
|
||||
// OpenAI 的 usage_limit_reached 错误格式:
|
||||
//
|
||||
|
||||
202
backend/internal/service/ratelimit_service_anthropic_test.go
Normal file
202
backend/internal/service/ratelimit_service_anthropic_test.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCalculateAnthropic429ResetTime_Only5hExceeded(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.02")
|
||||
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.32")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||
|
||||
result := calculateAnthropic429ResetTime(headers)
|
||||
assertAnthropicResult(t, result, 1770998400)
|
||||
|
||||
if result.fiveHourReset == nil || !result.fiveHourReset.Equal(time.Unix(1770998400, 0)) {
|
||||
t.Errorf("expected fiveHourReset=1770998400, got %v", result.fiveHourReset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateAnthropic429ResetTime_Only7dExceeded(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.50")
|
||||
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.05")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||
|
||||
result := calculateAnthropic429ResetTime(headers)
|
||||
assertAnthropicResult(t, result, 1771549200)
|
||||
|
||||
// fiveHourReset should still be populated for session window calculation
|
||||
if result.fiveHourReset == nil || !result.fiveHourReset.Equal(time.Unix(1770998400, 0)) {
|
||||
t.Errorf("expected fiveHourReset=1770998400, got %v", result.fiveHourReset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateAnthropic429ResetTime_BothExceeded(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.10")
|
||||
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.02")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||
|
||||
result := calculateAnthropic429ResetTime(headers)
|
||||
assertAnthropicResult(t, result, 1771549200)
|
||||
}
|
||||
|
||||
func TestCalculateAnthropic429ResetTime_NoPerWindowHeaders(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("anthropic-ratelimit-unified-reset", "1771549200")
|
||||
|
||||
result := calculateAnthropic429ResetTime(headers)
|
||||
if result != nil {
|
||||
t.Errorf("expected nil result when no per-window headers, got resetAt=%v", result.resetAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateAnthropic429ResetTime_NoHeaders(t *testing.T) {
|
||||
result := calculateAnthropic429ResetTime(http.Header{})
|
||||
if result != nil {
|
||||
t.Errorf("expected nil result for empty headers, got resetAt=%v", result.resetAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateAnthropic429ResetTime_SurpassedThreshold(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("anthropic-ratelimit-unified-5h-surpassed-threshold", "true")
|
||||
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-surpassed-threshold", "false")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||
|
||||
result := calculateAnthropic429ResetTime(headers)
|
||||
assertAnthropicResult(t, result, 1770998400)
|
||||
}
|
||||
|
||||
func TestCalculateAnthropic429ResetTime_UtilizationExactlyOne(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.0")
|
||||
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.5")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||
|
||||
result := calculateAnthropic429ResetTime(headers)
|
||||
assertAnthropicResult(t, result, 1770998400)
|
||||
}
|
||||
|
||||
func TestCalculateAnthropic429ResetTime_NeitherExceeded_UsesShorter(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.95")
|
||||
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") // sooner
|
||||
headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.80")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") // later
|
||||
|
||||
result := calculateAnthropic429ResetTime(headers)
|
||||
assertAnthropicResult(t, result, 1770998400)
|
||||
}
|
||||
|
||||
func TestCalculateAnthropic429ResetTime_Only5hResetHeader(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.05")
|
||||
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||
|
||||
result := calculateAnthropic429ResetTime(headers)
|
||||
assertAnthropicResult(t, result, 1770998400)
|
||||
}
|
||||
|
||||
func TestCalculateAnthropic429ResetTime_Only7dResetHeader(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.03")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||
|
||||
result := calculateAnthropic429ResetTime(headers)
|
||||
assertAnthropicResult(t, result, 1771549200)
|
||||
|
||||
if result.fiveHourReset != nil {
|
||||
t.Errorf("expected fiveHourReset=nil when no 5h headers, got %v", result.fiveHourReset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAnthropicWindowExceeded(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
headers http.Header
|
||||
window string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "utilization above 1.0",
|
||||
headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "1.02"),
|
||||
window: "5h",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "utilization exactly 1.0",
|
||||
headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "1.0"),
|
||||
window: "5h",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "utilization below 1.0",
|
||||
headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "0.99"),
|
||||
window: "5h",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "surpassed-threshold true",
|
||||
headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "true"),
|
||||
window: "7d",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "surpassed-threshold True (case insensitive)",
|
||||
headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "True"),
|
||||
window: "7d",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "surpassed-threshold false",
|
||||
headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "false"),
|
||||
window: "7d",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "no headers",
|
||||
headers: http.Header{},
|
||||
window: "5h",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := isAnthropicWindowExceeded(tc.headers, tc.window)
|
||||
if got != tc.expected {
|
||||
t.Errorf("expected %v, got %v", tc.expected, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// assertAnthropicResult is a test helper that verifies the result is non-nil and
|
||||
// has the expected resetAt unix timestamp.
|
||||
func assertAnthropicResult(t *testing.T, result *anthropic429Result, wantUnix int64) {
|
||||
t.Helper()
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
return // unreachable, but satisfies staticcheck SA5011
|
||||
}
|
||||
want := time.Unix(wantUnix, 0)
|
||||
if !result.resetAt.Equal(want) {
|
||||
t.Errorf("expected resetAt=%v, got %v", want, result.resetAt)
|
||||
}
|
||||
}
|
||||
|
||||
func makeHeader(key, value string) http.Header {
|
||||
h := http.Header{}
|
||||
h.Set(key, value)
|
||||
return h
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
260
backend/internal/service/sora_curl_cffi_sidecar.go
Normal file
260
backend/internal/service/sora_curl_cffi_sidecar.go
Normal file
@@ -0,0 +1,260 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||
)
|
||||
|
||||
// soraCurlCFFISidecarDefaultTimeoutSeconds is the fallback request
// timeout for sidecar calls when the configuration does not set one.
const soraCurlCFFISidecarDefaultTimeoutSeconds = 60

// soraCurlCFFISidecarRequest is the JSON payload sent to the curl_cffi
// sidecar, describing the upstream HTTP request it should replay.
type soraCurlCFFISidecarRequest struct {
	Method         string              `json:"method"`
	URL            string              `json:"url"`
	Headers        map[string][]string `json:"headers,omitempty"`
	BodyBase64     string              `json:"body_base64,omitempty"` // request body, base64-encoded when non-empty
	ProxyURL       string              `json:"proxy_url,omitempty"`
	SessionKey     string              `json:"session_key,omitempty"` // presumably keys sidecar connection reuse — confirm against sidecar impl
	Impersonate    string              `json:"impersonate,omitempty"` // curl_cffi impersonation profile name
	TimeoutSeconds int                 `json:"timeout_seconds,omitempty"`
}

// soraCurlCFFISidecarResponse is the JSON payload returned by the
// sidecar. StatusCode/Status and BodyBase64/Body are alternative wire
// names; callers prefer StatusCode and BodyBase64 when set. A non-empty
// Error indicates an upstream failure reported by the sidecar.
type soraCurlCFFISidecarResponse struct {
	StatusCode int            `json:"status_code"`
	Status     int            `json:"status"`
	Headers    map[string]any `json:"headers"`
	BodyBase64 string         `json:"body_base64"`
	Body       string         `json:"body"`
	Error      string         `json:"error"`
}
|
||||
|
||||
func (c *SoraDirectClient) doHTTPViaCurlCFFISidecar(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
|
||||
if req == nil || req.URL == nil {
|
||||
return nil, errors.New("request url is nil")
|
||||
}
|
||||
if c == nil || c.cfg == nil {
|
||||
return nil, errors.New("sora curl_cffi sidecar config is nil")
|
||||
}
|
||||
if !c.cfg.Sora.Client.CurlCFFISidecar.Enabled {
|
||||
return nil, errors.New("sora curl_cffi sidecar is disabled")
|
||||
}
|
||||
endpoint := c.curlCFFISidecarEndpoint()
|
||||
if endpoint == "" {
|
||||
return nil, errors.New("sora curl_cffi sidecar base_url is empty")
|
||||
}
|
||||
|
||||
bodyBytes, err := readAndRestoreRequestBody(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar read request body failed: %w", err)
|
||||
}
|
||||
|
||||
headers := make(map[string][]string, len(req.Header)+1)
|
||||
for key, vals := range req.Header {
|
||||
copied := make([]string, len(vals))
|
||||
copy(copied, vals)
|
||||
headers[key] = copied
|
||||
}
|
||||
if strings.TrimSpace(req.Host) != "" {
|
||||
if _, ok := headers["Host"]; !ok {
|
||||
headers["Host"] = []string{req.Host}
|
||||
}
|
||||
}
|
||||
|
||||
payload := soraCurlCFFISidecarRequest{
|
||||
Method: req.Method,
|
||||
URL: req.URL.String(),
|
||||
Headers: headers,
|
||||
ProxyURL: strings.TrimSpace(proxyURL),
|
||||
SessionKey: c.sidecarSessionKey(account, proxyURL),
|
||||
Impersonate: c.curlCFFIImpersonate(),
|
||||
TimeoutSeconds: c.curlCFFISidecarTimeoutSeconds(),
|
||||
}
|
||||
if len(bodyBytes) > 0 {
|
||||
payload.BodyBase64 = base64.StdEncoding.EncodeToString(bodyBytes)
|
||||
}
|
||||
|
||||
encoded, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar marshal request failed: %w", err)
|
||||
}
|
||||
|
||||
sidecarReq, err := http.NewRequestWithContext(req.Context(), http.MethodPost, endpoint, bytes.NewReader(encoded))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar build request failed: %w", err)
|
||||
}
|
||||
sidecarReq.Header.Set("Content-Type", "application/json")
|
||||
sidecarReq.Header.Set("Accept", "application/json")
|
||||
|
||||
httpClient := &http.Client{Timeout: time.Duration(payload.TimeoutSeconds) * time.Second}
|
||||
sidecarResp, err := httpClient.Do(sidecarReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = sidecarResp.Body.Close()
|
||||
}()
|
||||
|
||||
sidecarRespBody, err := io.ReadAll(io.LimitReader(sidecarResp.Body, 8<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar read response failed: %w", err)
|
||||
}
|
||||
if sidecarResp.StatusCode != http.StatusOK {
|
||||
redacted := truncateForLog([]byte(logredact.RedactText(string(sidecarRespBody))), 512)
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar http status=%d body=%s", sidecarResp.StatusCode, redacted)
|
||||
}
|
||||
|
||||
var payloadResp soraCurlCFFISidecarResponse
|
||||
if err := json.Unmarshal(sidecarRespBody, &payloadResp); err != nil {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar parse response failed: %w", err)
|
||||
}
|
||||
if msg := strings.TrimSpace(payloadResp.Error); msg != "" {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar upstream error: %s", msg)
|
||||
}
|
||||
statusCode := payloadResp.StatusCode
|
||||
if statusCode <= 0 {
|
||||
statusCode = payloadResp.Status
|
||||
}
|
||||
if statusCode <= 0 {
|
||||
return nil, errors.New("sora curl_cffi sidecar response missing status code")
|
||||
}
|
||||
|
||||
responseBody := []byte(payloadResp.Body)
|
||||
if strings.TrimSpace(payloadResp.BodyBase64) != "" {
|
||||
decoded, err := base64.StdEncoding.DecodeString(payloadResp.BodyBase64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar decode body failed: %w", err)
|
||||
}
|
||||
responseBody = decoded
|
||||
}
|
||||
|
||||
respHeaders := make(http.Header)
|
||||
for key, rawVal := range payloadResp.Headers {
|
||||
for _, v := range convertSidecarHeaderValue(rawVal) {
|
||||
respHeaders.Add(key, v)
|
||||
}
|
||||
}
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: statusCode,
|
||||
Header: respHeaders,
|
||||
Body: io.NopCloser(bytes.NewReader(responseBody)),
|
||||
ContentLength: int64(len(responseBody)),
|
||||
Request: req,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func readAndRestoreRequestBody(req *http.Request) ([]byte, error) {
|
||||
if req == nil || req.Body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
bodyBytes, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = req.Body.Close()
|
||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
req.ContentLength = int64(len(bodyBytes))
|
||||
return bodyBytes, nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) curlCFFISidecarEndpoint() string {
|
||||
if c == nil || c.cfg == nil {
|
||||
return ""
|
||||
}
|
||||
raw := strings.TrimSpace(c.cfg.Sora.Client.CurlCFFISidecar.BaseURL)
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil || strings.TrimSpace(parsed.Scheme) == "" || strings.TrimSpace(parsed.Host) == "" {
|
||||
return raw
|
||||
}
|
||||
if path := strings.TrimSpace(parsed.Path); path == "" || path == "/" {
|
||||
parsed.Path = "/request"
|
||||
}
|
||||
return parsed.String()
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) curlCFFISidecarTimeoutSeconds() int {
|
||||
if c == nil || c.cfg == nil {
|
||||
return soraCurlCFFISidecarDefaultTimeoutSeconds
|
||||
}
|
||||
timeoutSeconds := c.cfg.Sora.Client.CurlCFFISidecar.TimeoutSeconds
|
||||
if timeoutSeconds <= 0 {
|
||||
return soraCurlCFFISidecarDefaultTimeoutSeconds
|
||||
}
|
||||
return timeoutSeconds
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) curlCFFIImpersonate() string {
|
||||
if c == nil || c.cfg == nil {
|
||||
return "chrome131"
|
||||
}
|
||||
impersonate := strings.TrimSpace(c.cfg.Sora.Client.CurlCFFISidecar.Impersonate)
|
||||
if impersonate == "" {
|
||||
return "chrome131"
|
||||
}
|
||||
return impersonate
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) sidecarSessionReuseEnabled() bool {
|
||||
if c == nil || c.cfg == nil {
|
||||
return true
|
||||
}
|
||||
return c.cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) sidecarSessionTTLSeconds() int {
|
||||
if c == nil || c.cfg == nil {
|
||||
return 3600
|
||||
}
|
||||
ttl := c.cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds
|
||||
if ttl < 0 {
|
||||
return 3600
|
||||
}
|
||||
return ttl
|
||||
}
|
||||
|
||||
func convertSidecarHeaderValue(raw any) []string {
|
||||
switch val := raw.(type) {
|
||||
case nil:
|
||||
return nil
|
||||
case string:
|
||||
if strings.TrimSpace(val) == "" {
|
||||
return nil
|
||||
}
|
||||
return []string{val}
|
||||
case []any:
|
||||
out := make([]string, 0, len(val))
|
||||
for _, item := range val {
|
||||
s := strings.TrimSpace(fmt.Sprint(item))
|
||||
if s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
case []string:
|
||||
out := make([]string, 0, len(val))
|
||||
for _, item := range val {
|
||||
if strings.TrimSpace(item) != "" {
|
||||
out = append(out, item)
|
||||
}
|
||||
}
|
||||
return out
|
||||
default:
|
||||
s := strings.TrimSpace(fmt.Sprint(val))
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return []string{s}
|
||||
}
|
||||
}
|
||||
@@ -8,10 +8,12 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math"
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -23,6 +25,9 @@ import (
|
||||
const soraImageInputMaxBytes = 20 << 20
|
||||
const soraImageInputMaxRedirects = 3
|
||||
const soraImageInputTimeout = 20 * time.Second
|
||||
const soraVideoInputMaxBytes = 200 << 20
|
||||
const soraVideoInputMaxRedirects = 3
|
||||
const soraVideoInputTimeout = 60 * time.Second
|
||||
|
||||
var soraImageSizeMap = map[string]string{
|
||||
"gpt-image": "360",
|
||||
@@ -61,6 +66,36 @@ type SoraGatewayService struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
type soraWatermarkOptions struct {
|
||||
Enabled bool
|
||||
ParseMethod string
|
||||
ParseURL string
|
||||
ParseToken string
|
||||
FallbackOnFailure bool
|
||||
DeletePost bool
|
||||
}
|
||||
|
||||
type soraCharacterOptions struct {
|
||||
SetPublic bool
|
||||
DeleteAfterGenerate bool
|
||||
}
|
||||
|
||||
type soraCharacterFlowResult struct {
|
||||
CameoID string
|
||||
CharacterID string
|
||||
Username string
|
||||
DisplayName string
|
||||
}
|
||||
|
||||
var soraStoryboardPattern = regexp.MustCompile(`\[\d+(?:\.\d+)?s\]`)
|
||||
var soraStoryboardShotPattern = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`)
|
||||
var soraRemixTargetPattern = regexp.MustCompile(`s_[a-f0-9]{32}`)
|
||||
var soraRemixTargetInURLPattern = regexp.MustCompile(`https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}`)
|
||||
|
||||
type soraPreflightChecker interface {
|
||||
PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error
|
||||
}
|
||||
|
||||
func NewSoraGatewayService(
|
||||
soraClient SoraClient,
|
||||
mediaStorage *SoraMediaStorage,
|
||||
@@ -112,29 +147,133 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream)
|
||||
return nil, fmt.Errorf("unsupported model: %s", reqModel)
|
||||
}
|
||||
if modelCfg.Type == "prompt_enhance" {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Prompt-enhance 模型暂未支持", clientStream)
|
||||
return nil, fmt.Errorf("prompt-enhance not supported")
|
||||
}
|
||||
|
||||
prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody)
|
||||
if strings.TrimSpace(prompt) == "" {
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
imageInput = strings.TrimSpace(imageInput)
|
||||
videoInput = strings.TrimSpace(videoInput)
|
||||
remixTargetID = strings.TrimSpace(remixTargetID)
|
||||
|
||||
if videoInput != "" && modelCfg.Type != "video" {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "video input only supports video models", clientStream)
|
||||
return nil, errors.New("video input only supports video models")
|
||||
}
|
||||
if videoInput != "" && imageInput != "" {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "image input and video input cannot be used together", clientStream)
|
||||
return nil, errors.New("image input and video input cannot be used together")
|
||||
}
|
||||
characterOnly := videoInput != "" && prompt == ""
|
||||
if modelCfg.Type == "prompt_enhance" && prompt == "" {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
|
||||
return nil, errors.New("prompt is required")
|
||||
}
|
||||
if strings.TrimSpace(videoInput) != "" {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Video input is not supported yet", clientStream)
|
||||
return nil, errors.New("video input not supported")
|
||||
if modelCfg.Type != "prompt_enhance" && prompt == "" && !characterOnly {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
|
||||
return nil, errors.New("prompt is required")
|
||||
}
|
||||
|
||||
reqCtx, cancel := s.withSoraTimeout(ctx, reqStream)
|
||||
if cancel != nil {
|
||||
defer cancel()
|
||||
}
|
||||
if checker, ok := s.soraClient.(soraPreflightChecker); ok && !characterOnly {
|
||||
if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil {
|
||||
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
|
||||
}
|
||||
}
|
||||
|
||||
if modelCfg.Type == "prompt_enhance" {
|
||||
enhancedPrompt, err := s.soraClient.EnhancePrompt(reqCtx, account, prompt, modelCfg.ExpansionLevel, modelCfg.DurationS)
|
||||
if err != nil {
|
||||
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
|
||||
}
|
||||
content := strings.TrimSpace(enhancedPrompt)
|
||||
if content == "" {
|
||||
content = prompt
|
||||
}
|
||||
var firstTokenMs *int
|
||||
if clientStream {
|
||||
ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
|
||||
if streamErr != nil {
|
||||
return nil, streamErr
|
||||
}
|
||||
firstTokenMs = ms
|
||||
} else if c != nil {
|
||||
c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel))
|
||||
}
|
||||
return &ForwardResult{
|
||||
RequestID: "",
|
||||
Model: reqModel,
|
||||
Stream: clientStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
Usage: ClaudeUsage{},
|
||||
MediaType: "prompt",
|
||||
}, nil
|
||||
}
|
||||
|
||||
characterOpts := parseSoraCharacterOptions(reqBody)
|
||||
watermarkOpts := parseSoraWatermarkOptions(reqBody)
|
||||
var characterResult *soraCharacterFlowResult
|
||||
if videoInput != "" {
|
||||
videoData, videoErr := decodeSoraVideoInput(reqCtx, videoInput)
|
||||
if videoErr != nil {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", videoErr.Error(), clientStream)
|
||||
return nil, videoErr
|
||||
}
|
||||
characterResult, videoErr = s.createCharacterFromVideo(reqCtx, account, videoData, characterOpts)
|
||||
if videoErr != nil {
|
||||
return nil, s.handleSoraRequestError(ctx, account, videoErr, reqModel, c, clientStream)
|
||||
}
|
||||
if characterResult != nil && characterOpts.DeleteAfterGenerate && strings.TrimSpace(characterResult.CharacterID) != "" && !characterOnly {
|
||||
characterID := strings.TrimSpace(characterResult.CharacterID)
|
||||
defer func() {
|
||||
cleanupCtx, cancelCleanup := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancelCleanup()
|
||||
if err := s.soraClient.DeleteCharacter(cleanupCtx, account, characterID); err != nil {
|
||||
log.Printf("[Sora] cleanup character failed, character_id=%s err=%v", characterID, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
if characterOnly {
|
||||
content := "角色创建成功"
|
||||
if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
|
||||
content = fmt.Sprintf("角色创建成功,角色名@%s", strings.TrimSpace(characterResult.Username))
|
||||
}
|
||||
var firstTokenMs *int
|
||||
if clientStream {
|
||||
ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
|
||||
if streamErr != nil {
|
||||
return nil, streamErr
|
||||
}
|
||||
firstTokenMs = ms
|
||||
} else if c != nil {
|
||||
resp := buildSoraNonStreamResponse(content, reqModel)
|
||||
if characterResult != nil {
|
||||
resp["character_id"] = characterResult.CharacterID
|
||||
resp["cameo_id"] = characterResult.CameoID
|
||||
resp["character_username"] = characterResult.Username
|
||||
resp["character_display_name"] = characterResult.DisplayName
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
return &ForwardResult{
|
||||
RequestID: "",
|
||||
Model: reqModel,
|
||||
Stream: clientStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
Usage: ClaudeUsage{},
|
||||
MediaType: "prompt",
|
||||
}, nil
|
||||
}
|
||||
if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
|
||||
prompt = fmt.Sprintf("@%s %s", characterResult.Username, prompt)
|
||||
}
|
||||
}
|
||||
|
||||
var imageData []byte
|
||||
imageFilename := ""
|
||||
if strings.TrimSpace(imageInput) != "" {
|
||||
if imageInput != "" {
|
||||
decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput)
|
||||
if err != nil {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream)
|
||||
@@ -164,15 +303,27 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
MediaID: mediaID,
|
||||
})
|
||||
case "video":
|
||||
taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
|
||||
Prompt: prompt,
|
||||
Orientation: modelCfg.Orientation,
|
||||
Frames: modelCfg.Frames,
|
||||
Model: modelCfg.Model,
|
||||
Size: modelCfg.Size,
|
||||
MediaID: mediaID,
|
||||
RemixTargetID: remixTargetID,
|
||||
})
|
||||
if remixTargetID == "" && isSoraStoryboardPrompt(prompt) {
|
||||
taskID, err = s.soraClient.CreateStoryboardTask(reqCtx, account, SoraStoryboardRequest{
|
||||
Prompt: formatSoraStoryboardPrompt(prompt),
|
||||
Orientation: modelCfg.Orientation,
|
||||
Frames: modelCfg.Frames,
|
||||
Model: modelCfg.Model,
|
||||
Size: modelCfg.Size,
|
||||
MediaID: mediaID,
|
||||
})
|
||||
} else {
|
||||
taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
|
||||
Prompt: prompt,
|
||||
Orientation: modelCfg.Orientation,
|
||||
Frames: modelCfg.Frames,
|
||||
Model: modelCfg.Model,
|
||||
Size: modelCfg.Size,
|
||||
MediaID: mediaID,
|
||||
RemixTargetID: remixTargetID,
|
||||
CameoIDs: extractSoraCameoIDs(reqBody),
|
||||
})
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("unsupported model type: %s", modelCfg.Type)
|
||||
}
|
||||
@@ -185,6 +336,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
}
|
||||
|
||||
var mediaURLs []string
|
||||
videoGenerationID := ""
|
||||
mediaType := modelCfg.Type
|
||||
imageCount := 0
|
||||
imageSize := ""
|
||||
@@ -198,15 +350,32 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
imageCount = len(urls)
|
||||
imageSize = soraImageSizeFromModel(reqModel)
|
||||
case "video":
|
||||
urls, pollErr := s.pollVideoTask(reqCtx, c, account, taskID, clientStream)
|
||||
videoStatus, pollErr := s.pollVideoTaskDetailed(reqCtx, c, account, taskID, clientStream)
|
||||
if pollErr != nil {
|
||||
return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
|
||||
}
|
||||
mediaURLs = urls
|
||||
if videoStatus != nil {
|
||||
mediaURLs = videoStatus.URLs
|
||||
videoGenerationID = strings.TrimSpace(videoStatus.GenerationID)
|
||||
}
|
||||
default:
|
||||
mediaType = "prompt"
|
||||
}
|
||||
|
||||
watermarkPostID := ""
|
||||
if modelCfg.Type == "video" && watermarkOpts.Enabled {
|
||||
watermarkURL, postID, watermarkErr := s.resolveWatermarkFreeURL(reqCtx, account, videoGenerationID, watermarkOpts)
|
||||
if watermarkErr != nil {
|
||||
if !watermarkOpts.FallbackOnFailure {
|
||||
return nil, s.handleSoraRequestError(ctx, account, watermarkErr, reqModel, c, clientStream)
|
||||
}
|
||||
log.Printf("[Sora] watermark-free fallback to original URL, task_id=%s err=%v", taskID, watermarkErr)
|
||||
} else if strings.TrimSpace(watermarkURL) != "" {
|
||||
mediaURLs = []string{strings.TrimSpace(watermarkURL)}
|
||||
watermarkPostID = strings.TrimSpace(postID)
|
||||
}
|
||||
}
|
||||
|
||||
finalURLs := s.normalizeSoraMediaURLs(mediaURLs)
|
||||
if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
|
||||
stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
|
||||
@@ -217,6 +386,11 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
finalURLs = s.normalizeSoraMediaURLs(stored)
|
||||
}
|
||||
}
|
||||
if watermarkPostID != "" && watermarkOpts.DeletePost {
|
||||
if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil {
|
||||
log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr)
|
||||
}
|
||||
}
|
||||
|
||||
content := buildSoraContent(mediaType, finalURLs)
|
||||
var firstTokenMs *int
|
||||
@@ -265,9 +439,270 @@ func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (
|
||||
return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
|
||||
}
|
||||
|
||||
func parseSoraWatermarkOptions(body map[string]any) soraWatermarkOptions {
|
||||
opts := soraWatermarkOptions{
|
||||
Enabled: parseBoolWithDefault(body, "watermark_free", false),
|
||||
ParseMethod: strings.ToLower(strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_method", "third_party"))),
|
||||
ParseURL: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_url", "")),
|
||||
ParseToken: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_token", "")),
|
||||
FallbackOnFailure: parseBoolWithDefault(body, "watermark_fallback_on_failure", true),
|
||||
DeletePost: parseBoolWithDefault(body, "watermark_delete_post", false),
|
||||
}
|
||||
if opts.ParseMethod == "" {
|
||||
opts.ParseMethod = "third_party"
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions {
|
||||
return soraCharacterOptions{
|
||||
SetPublic: parseBoolWithDefault(body, "character_set_public", true),
|
||||
DeleteAfterGenerate: parseBoolWithDefault(body, "character_delete_after_generate", true),
|
||||
}
|
||||
}
|
||||
|
||||
func parseBoolWithDefault(body map[string]any, key string, def bool) bool {
|
||||
if body == nil {
|
||||
return def
|
||||
}
|
||||
val, ok := body[key]
|
||||
if !ok {
|
||||
return def
|
||||
}
|
||||
switch typed := val.(type) {
|
||||
case bool:
|
||||
return typed
|
||||
case int:
|
||||
return typed != 0
|
||||
case int32:
|
||||
return typed != 0
|
||||
case int64:
|
||||
return typed != 0
|
||||
case float64:
|
||||
return typed != 0
|
||||
case string:
|
||||
typed = strings.ToLower(strings.TrimSpace(typed))
|
||||
if typed == "true" || typed == "1" || typed == "yes" {
|
||||
return true
|
||||
}
|
||||
if typed == "false" || typed == "0" || typed == "no" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func parseStringWithDefault(body map[string]any, key, def string) string {
|
||||
if body == nil {
|
||||
return def
|
||||
}
|
||||
val, ok := body[key]
|
||||
if !ok {
|
||||
return def
|
||||
}
|
||||
if str, ok := val.(string); ok {
|
||||
return str
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func extractSoraCameoIDs(body map[string]any) []string {
|
||||
if body == nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := body["cameo_ids"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
switch typed := raw.(type) {
|
||||
case []string:
|
||||
out := make([]string, 0, len(typed))
|
||||
for _, item := range typed {
|
||||
item = strings.TrimSpace(item)
|
||||
if item != "" {
|
||||
out = append(out, item)
|
||||
}
|
||||
}
|
||||
return out
|
||||
case []any:
|
||||
out := make([]string, 0, len(typed))
|
||||
for _, item := range typed {
|
||||
str, ok := item.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
str = strings.TrimSpace(str)
|
||||
if str != "" {
|
||||
out = append(out, str)
|
||||
}
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) createCharacterFromVideo(ctx context.Context, account *Account, videoData []byte, opts soraCharacterOptions) (*soraCharacterFlowResult, error) {
|
||||
cameoID, err := s.soraClient.UploadCharacterVideo(ctx, account, videoData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cameoStatus, err := s.pollCameoStatus(ctx, account, cameoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
username := processSoraCharacterUsername(cameoStatus.UsernameHint)
|
||||
displayName := strings.TrimSpace(cameoStatus.DisplayNameHint)
|
||||
if displayName == "" {
|
||||
displayName = "Character"
|
||||
}
|
||||
profileAssetURL := strings.TrimSpace(cameoStatus.ProfileAssetURL)
|
||||
if profileAssetURL == "" {
|
||||
return nil, errors.New("profile asset url not found in cameo status")
|
||||
}
|
||||
|
||||
avatarData, err := s.soraClient.DownloadCharacterImage(ctx, account, profileAssetURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
assetPointer, err := s.soraClient.UploadCharacterImage(ctx, account, avatarData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
instructionSet := cameoStatus.InstructionSetHint
|
||||
if instructionSet == nil {
|
||||
instructionSet = cameoStatus.InstructionSet
|
||||
}
|
||||
|
||||
characterID, err := s.soraClient.FinalizeCharacter(ctx, account, SoraCharacterFinalizeRequest{
|
||||
CameoID: strings.TrimSpace(cameoID),
|
||||
Username: username,
|
||||
DisplayName: displayName,
|
||||
ProfileAssetPointer: assetPointer,
|
||||
InstructionSet: instructionSet,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if opts.SetPublic {
|
||||
if err := s.soraClient.SetCharacterPublic(ctx, account, cameoID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &soraCharacterFlowResult{
|
||||
CameoID: strings.TrimSpace(cameoID),
|
||||
CharacterID: strings.TrimSpace(characterID),
|
||||
Username: strings.TrimSpace(username),
|
||||
DisplayName: displayName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) pollCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
|
||||
timeout := 10 * time.Minute
|
||||
interval := 5 * time.Second
|
||||
maxAttempts := int(math.Ceil(timeout.Seconds() / interval.Seconds()))
|
||||
if maxAttempts < 1 {
|
||||
maxAttempts = 1
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
consecutiveErrors := 0
|
||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||
status, err := s.soraClient.GetCameoStatus(ctx, account, cameoID)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
consecutiveErrors++
|
||||
if consecutiveErrors >= 3 {
|
||||
break
|
||||
}
|
||||
if attempt < maxAttempts-1 {
|
||||
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
|
||||
return nil, sleepErr
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
consecutiveErrors = 0
|
||||
if status == nil {
|
||||
if attempt < maxAttempts-1 {
|
||||
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
|
||||
return nil, sleepErr
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
currentStatus := strings.ToLower(strings.TrimSpace(status.Status))
|
||||
statusMessage := strings.TrimSpace(status.StatusMessage)
|
||||
if currentStatus == "failed" {
|
||||
if statusMessage == "" {
|
||||
statusMessage = "character creation failed"
|
||||
}
|
||||
return nil, errors.New(statusMessage)
|
||||
}
|
||||
if strings.EqualFold(statusMessage, "Completed") || currentStatus == "finalized" {
|
||||
return status, nil
|
||||
}
|
||||
if attempt < maxAttempts-1 {
|
||||
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
|
||||
return nil, sleepErr
|
||||
}
|
||||
}
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, fmt.Errorf("poll cameo status failed: %w", lastErr)
|
||||
}
|
||||
return nil, errors.New("cameo processing timeout")
|
||||
}
|
||||
|
||||
func processSoraCharacterUsername(usernameHint string) string {
|
||||
usernameHint = strings.TrimSpace(usernameHint)
|
||||
if usernameHint == "" {
|
||||
usernameHint = "character"
|
||||
}
|
||||
if strings.Contains(usernameHint, ".") {
|
||||
parts := strings.Split(usernameHint, ".")
|
||||
usernameHint = strings.TrimSpace(parts[len(parts)-1])
|
||||
}
|
||||
if usernameHint == "" {
|
||||
usernameHint = "character"
|
||||
}
|
||||
return fmt.Sprintf("%s%d", usernameHint, soraRandInt(900)+100)
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) {
|
||||
generationID = strings.TrimSpace(generationID)
|
||||
if generationID == "" {
|
||||
return "", "", errors.New("generation id is required for watermark-free mode")
|
||||
}
|
||||
postID, err := s.soraClient.PostVideoForWatermarkFree(ctx, account, generationID)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
postID = strings.TrimSpace(postID)
|
||||
if postID == "" {
|
||||
return "", "", errors.New("watermark-free publish returned empty post id")
|
||||
}
|
||||
|
||||
switch opts.ParseMethod {
|
||||
case "custom":
|
||||
urlVal, parseErr := s.soraClient.GetWatermarkFreeURLCustom(ctx, account, opts.ParseURL, opts.ParseToken, postID)
|
||||
if parseErr != nil {
|
||||
return "", postID, parseErr
|
||||
}
|
||||
return strings.TrimSpace(urlVal), postID, nil
|
||||
case "", "third_party":
|
||||
return fmt.Sprintf("https://oscdn2.dyysy.com/MP4/%s.mp4", postID), postID, nil
|
||||
default:
|
||||
return "", postID, fmt.Errorf("unsupported watermark parse method: %s", opts.ParseMethod)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 401, 402, 403, 429, 529:
|
||||
case 401, 402, 403, 404, 429, 529:
|
||||
return true
|
||||
default:
|
||||
return statusCode >= 500
|
||||
@@ -434,7 +869,18 @@ func (s *SoraGatewayService) writeSoraError(c *gin.Context, status int, errType,
|
||||
}
|
||||
if stream {
|
||||
flusher, _ := c.Writer.(http.Flusher)
|
||||
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
||||
errorData := map[string]any{
|
||||
"error": map[string]string{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
jsonBytes, err := json.Marshal(errorData)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
|
||||
_, _ = fmt.Fprint(c.Writer, errorEvent)
|
||||
_, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n")
|
||||
if flusher != nil {
|
||||
@@ -460,7 +906,15 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
|
||||
}
|
||||
if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) {
|
||||
return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode}
|
||||
var responseHeaders http.Header
|
||||
if upstreamErr.Headers != nil {
|
||||
responseHeaders = upstreamErr.Headers.Clone()
|
||||
}
|
||||
return &UpstreamFailoverError{
|
||||
StatusCode: upstreamErr.StatusCode,
|
||||
ResponseBody: upstreamErr.Body,
|
||||
ResponseHeaders: responseHeaders,
|
||||
}
|
||||
}
|
||||
msg := upstreamErr.Message
|
||||
if override := soraProErrorMessage(model, msg); override != "" {
|
||||
@@ -505,7 +959,7 @@ func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context,
|
||||
return nil, errors.New("sora image generation timeout")
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) {
|
||||
func (s *SoraGatewayService) pollVideoTaskDetailed(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) (*SoraVideoTaskStatus, error) {
|
||||
interval := s.pollInterval()
|
||||
maxAttempts := s.pollMaxAttempts()
|
||||
lastPing := time.Now()
|
||||
@@ -516,7 +970,7 @@ func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
switch strings.ToLower(status.Status) {
|
||||
case "completed", "succeeded":
|
||||
return status.URLs, nil
|
||||
return status, nil
|
||||
case "failed":
|
||||
if status.ErrorMsg != "" {
|
||||
return nil, errors.New(status.ErrorMsg)
|
||||
@@ -620,7 +1074,7 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi
|
||||
return "", "", "", ""
|
||||
}
|
||||
if v, ok := body["remix_target_id"].(string); ok {
|
||||
remixTargetID = v
|
||||
remixTargetID = strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := body["image"].(string); ok {
|
||||
imageInput = v
|
||||
@@ -661,6 +1115,10 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi
|
||||
prompt = builder.String()
|
||||
}
|
||||
}
|
||||
if remixTargetID == "" {
|
||||
remixTargetID = extractRemixTargetIDFromPrompt(prompt)
|
||||
}
|
||||
prompt = cleanRemixLinkFromPrompt(prompt)
|
||||
return prompt, imageInput, videoInput, remixTargetID
|
||||
}
|
||||
|
||||
@@ -708,6 +1166,69 @@ func parseSoraMessageContent(content any) (text, imageInput, videoInput string)
|
||||
}
|
||||
}
|
||||
|
||||
func isSoraStoryboardPrompt(prompt string) bool {
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
if prompt == "" {
|
||||
return false
|
||||
}
|
||||
return len(soraStoryboardPattern.FindAllString(prompt, -1)) >= 1
|
||||
}
|
||||
|
||||
func formatSoraStoryboardPrompt(prompt string) string {
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
if prompt == "" {
|
||||
return ""
|
||||
}
|
||||
matches := soraStoryboardShotPattern.FindAllStringSubmatch(prompt, -1)
|
||||
if len(matches) == 0 {
|
||||
return prompt
|
||||
}
|
||||
firstBracketPos := strings.Index(prompt, "[")
|
||||
instructions := ""
|
||||
if firstBracketPos > 0 {
|
||||
instructions = strings.TrimSpace(prompt[:firstBracketPos])
|
||||
}
|
||||
shots := make([]string, 0, len(matches))
|
||||
for i, match := range matches {
|
||||
if len(match) < 3 {
|
||||
continue
|
||||
}
|
||||
duration := strings.TrimSpace(match[1])
|
||||
scene := strings.TrimSpace(match[2])
|
||||
if scene == "" {
|
||||
continue
|
||||
}
|
||||
shots = append(shots, fmt.Sprintf("Shot %d:\nduration: %ssec\nScene: %s", i+1, duration, scene))
|
||||
}
|
||||
if len(shots) == 0 {
|
||||
return prompt
|
||||
}
|
||||
timeline := strings.Join(shots, "\n\n")
|
||||
if instructions == "" {
|
||||
return timeline
|
||||
}
|
||||
return fmt.Sprintf("current timeline:\n%s\n\ninstructions:\n%s", timeline, instructions)
|
||||
}
|
||||
|
||||
func extractRemixTargetIDFromPrompt(prompt string) string {
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
if prompt == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(soraRemixTargetPattern.FindString(prompt))
|
||||
}
|
||||
|
||||
func cleanRemixLinkFromPrompt(prompt string) string {
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
if prompt == "" {
|
||||
return prompt
|
||||
}
|
||||
cleaned := soraRemixTargetInURLPattern.ReplaceAllString(prompt, "")
|
||||
cleaned = soraRemixTargetPattern.ReplaceAllString(cleaned, "")
|
||||
cleaned = strings.Join(strings.Fields(cleaned), " ")
|
||||
return strings.TrimSpace(cleaned)
|
||||
}
|
||||
|
||||
func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) {
|
||||
raw := strings.TrimSpace(input)
|
||||
if raw == "" {
|
||||
@@ -720,7 +1241,7 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er
|
||||
}
|
||||
meta := parts[0]
|
||||
payload := parts[1]
|
||||
decoded, err := base64.StdEncoding.DecodeString(payload)
|
||||
decoded, err := decodeBase64WithLimit(payload, soraImageInputMaxBytes)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@@ -739,15 +1260,47 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er
|
||||
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
|
||||
return downloadSoraImageInput(ctx, raw)
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(raw)
|
||||
decoded, err := decodeBase64WithLimit(raw, soraImageInputMaxBytes)
|
||||
if err != nil {
|
||||
return nil, "", errors.New("invalid base64 image")
|
||||
}
|
||||
return decoded, "image.png", nil
|
||||
}
|
||||
|
||||
func decodeSoraVideoInput(ctx context.Context, input string) ([]byte, error) {
|
||||
raw := strings.TrimSpace(input)
|
||||
if raw == "" {
|
||||
return nil, errors.New("empty video input")
|
||||
}
|
||||
if strings.HasPrefix(raw, "data:") {
|
||||
parts := strings.SplitN(raw, ",", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, errors.New("invalid video data url")
|
||||
}
|
||||
decoded, err := decodeBase64WithLimit(parts[1], soraVideoInputMaxBytes)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid base64 video")
|
||||
}
|
||||
if len(decoded) == 0 {
|
||||
return nil, errors.New("empty video data")
|
||||
}
|
||||
return decoded, nil
|
||||
}
|
||||
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
|
||||
return downloadSoraVideoInput(ctx, raw)
|
||||
}
|
||||
decoded, err := decodeBase64WithLimit(raw, soraVideoInputMaxBytes)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid base64 video")
|
||||
}
|
||||
if len(decoded) == 0 {
|
||||
return nil, errors.New("empty video data")
|
||||
}
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) {
|
||||
parsed, err := validateSoraImageURL(rawURL)
|
||||
parsed, err := validateSoraRemoteURL(rawURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@@ -761,7 +1314,7 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string,
|
||||
if len(via) >= soraImageInputMaxRedirects {
|
||||
return errors.New("too many redirects")
|
||||
}
|
||||
return validateSoraImageURLValue(req.URL)
|
||||
return validateSoraRemoteURLValue(req.URL)
|
||||
},
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
@@ -784,51 +1337,103 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string,
|
||||
return data, filename, nil
|
||||
}
|
||||
|
||||
func validateSoraImageURL(raw string) (*url.URL, error) {
|
||||
func downloadSoraVideoInput(ctx context.Context, rawURL string) ([]byte, error) {
|
||||
parsed, err := validateSoraRemoteURL(rawURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := &http.Client{
|
||||
Timeout: soraVideoInputTimeout,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= soraVideoInputMaxRedirects {
|
||||
return errors.New("too many redirects")
|
||||
}
|
||||
return validateSoraRemoteURLValue(req.URL)
|
||||
},
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("download video failed: %d", resp.StatusCode)
|
||||
}
|
||||
data, err := io.ReadAll(io.LimitReader(resp.Body, soraVideoInputMaxBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("empty video content")
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// decodeBase64WithLimit decodes a standard-alphabet base64 string while
// enforcing an upper bound on the decoded size. It streams through a decoder,
// so the full payload is never buffered before the limit check fires.
func decodeBase64WithLimit(encoded string, maxBytes int64) ([]byte, error) {
	if maxBytes <= 0 {
		return nil, errors.New("invalid max bytes limit")
	}
	// Allow one extra byte: successfully reading past maxBytes proves the
	// payload is too large without materializing the whole thing.
	reader := io.LimitReader(
		base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)),
		maxBytes+1,
	)
	data, err := io.ReadAll(reader)
	if err != nil {
		return nil, err
	}
	if int64(len(data)) > maxBytes {
		return nil, fmt.Errorf("input exceeds %d bytes limit", maxBytes)
	}
	return data, nil
}
|
||||
|
||||
func validateSoraRemoteURL(raw string) (*url.URL, error) {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return nil, errors.New("empty image url")
|
||||
return nil, errors.New("empty remote url")
|
||||
}
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid image url: %w", err)
|
||||
return nil, fmt.Errorf("invalid remote url: %w", err)
|
||||
}
|
||||
if err := validateSoraImageURLValue(parsed); err != nil {
|
||||
if err := validateSoraRemoteURLValue(parsed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func validateSoraImageURLValue(parsed *url.URL) error {
|
||||
func validateSoraRemoteURLValue(parsed *url.URL) error {
|
||||
if parsed == nil {
|
||||
return errors.New("invalid image url")
|
||||
return errors.New("invalid remote url")
|
||||
}
|
||||
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
|
||||
if scheme != "http" && scheme != "https" {
|
||||
return errors.New("only http/https image url is allowed")
|
||||
return errors.New("only http/https remote url is allowed")
|
||||
}
|
||||
if parsed.User != nil {
|
||||
return errors.New("image url cannot contain userinfo")
|
||||
return errors.New("remote url cannot contain userinfo")
|
||||
}
|
||||
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
|
||||
if host == "" {
|
||||
return errors.New("image url missing host")
|
||||
return errors.New("remote url missing host")
|
||||
}
|
||||
if _, blocked := soraBlockedHostnames[host]; blocked {
|
||||
return errors.New("image url is not allowed")
|
||||
return errors.New("remote url is not allowed")
|
||||
}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if isSoraBlockedIP(ip) {
|
||||
return errors.New("image url is not allowed")
|
||||
return errors.New("remote url is not allowed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve image url failed: %w", err)
|
||||
return fmt.Errorf("resolve remote url failed: %w", err)
|
||||
}
|
||||
for _, ip := range ips {
|
||||
if isSoraBlockedIP(ip) {
|
||||
return errors.New("image url is not allowed")
|
||||
return errors.New("remote url is not allowed")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -4,10 +4,16 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -18,6 +24,13 @@ type stubSoraClientForPoll struct {
|
||||
videoStatus *SoraVideoTaskStatus
|
||||
imageCalls int
|
||||
videoCalls int
|
||||
enhanced string
|
||||
enhanceErr error
|
||||
storyboard bool
|
||||
videoReq SoraVideoRequest
|
||||
parseErr error
|
||||
postCalls int
|
||||
deleteCalls int
|
||||
}
|
||||
|
||||
func (s *stubSoraClientForPoll) Enabled() bool { return true }
|
||||
@@ -28,8 +41,60 @@ func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Ac
|
||||
return "task-image", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
|
||||
s.videoReq = req
|
||||
return "task-video", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
|
||||
s.storyboard = true
|
||||
return "task-video", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
|
||||
return "cameo-1", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
|
||||
return &SoraCameoStatus{
|
||||
Status: "finalized",
|
||||
StatusMessage: "Completed",
|
||||
DisplayNameHint: "Character",
|
||||
UsernameHint: "user.character",
|
||||
ProfileAssetURL: "https://example.com/avatar.webp",
|
||||
}, nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
|
||||
return []byte("avatar"), nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
|
||||
return "asset-pointer", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
|
||||
return "character-1", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
|
||||
s.postCalls++
|
||||
return "s_post", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) DeletePost(ctx context.Context, account *Account, postID string) error {
|
||||
s.deleteCalls++
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
|
||||
if s.parseErr != nil {
|
||||
return "", s.parseErr
|
||||
}
|
||||
return "https://example.com/no-watermark.mp4", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
|
||||
if s.enhanced != "" {
|
||||
return s.enhanced, s.enhanceErr
|
||||
}
|
||||
return "enhanced prompt", s.enhanceErr
|
||||
}
|
||||
func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
|
||||
s.imageCalls++
|
||||
return s.imageStatus, nil
|
||||
@@ -62,6 +127,136 @@ func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) {
|
||||
require.Equal(t, 1, client.imageCalls)
|
||||
}
|
||||
|
||||
// TestSoraGatewayService_ForwardPromptEnhance verifies that a
// "prompt-enhance-*" model routes through the prompt-enhance path: the
// result is of media type "prompt" and echoes the requested model name.
func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
	// Stub returns a fixed enhanced prompt instead of calling upstream.
	client := &stubSoraClientForPoll{
		enhanced: "cinematic prompt",
	}
	cfg := &config.Config{
		Sora: config.SoraConfig{
			Client: config.SoraClientConfig{
				PollIntervalSeconds: 1,
				MaxPollAttempts:     1,
			},
		},
	}
	svc := NewSoraGatewayService(client, nil, nil, cfg)
	account := &Account{
		ID:       1,
		Platform: PlatformSora,
		Status:   StatusActive,
	}
	body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`)

	result, err := svc.Forward(context.Background(), nil, account, body, false)
	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, "prompt", result.MediaType)
	require.Equal(t, "prompt-enhance-short-10s", result.Model)
}
|
||||
|
||||
func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {
|
||||
client := &stubSoraClientForPoll{
|
||||
videoStatus: &SoraVideoTaskStatus{
|
||||
Status: "completed",
|
||||
URLs: []string{"https://example.com/v.mp4"},
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
PollIntervalSeconds: 1,
|
||||
MaxPollAttempts: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
||||
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
|
||||
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`)
|
||||
|
||||
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, client.storyboard)
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) {
|
||||
client := &stubSoraClientForPoll{}
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
PollIntervalSeconds: 1,
|
||||
MaxPollAttempts: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
||||
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
|
||||
body := []byte(`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`)
|
||||
|
||||
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "prompt", result.MediaType)
|
||||
require.Equal(t, 0, client.videoCalls)
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_ForwardWatermarkFallback(t *testing.T) {
|
||||
client := &stubSoraClientForPoll{
|
||||
videoStatus: &SoraVideoTaskStatus{
|
||||
Status: "completed",
|
||||
URLs: []string{"https://example.com/original.mp4"},
|
||||
GenerationID: "gen_1",
|
||||
},
|
||||
parseErr: errors.New("parse failed"),
|
||||
}
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
PollIntervalSeconds: 1,
|
||||
MaxPollAttempts: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
||||
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
|
||||
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`)
|
||||
|
||||
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "https://example.com/original.mp4", result.MediaURL)
|
||||
require.Equal(t, 1, client.postCalls)
|
||||
require.Equal(t, 0, client.deleteCalls)
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete(t *testing.T) {
|
||||
client := &stubSoraClientForPoll{
|
||||
videoStatus: &SoraVideoTaskStatus{
|
||||
Status: "completed",
|
||||
URLs: []string{"https://example.com/original.mp4"},
|
||||
GenerationID: "gen_1",
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
PollIntervalSeconds: 1,
|
||||
MaxPollAttempts: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
||||
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
|
||||
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`)
|
||||
|
||||
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "https://example.com/no-watermark.mp4", result.MediaURL)
|
||||
require.Equal(t, 1, client.postCalls)
|
||||
require.Equal(t, 1, client.deleteCalls)
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
|
||||
client := &stubSoraClientForPoll{
|
||||
videoStatus: &SoraVideoTaskStatus{
|
||||
@@ -79,9 +274,9 @@ func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
|
||||
}
|
||||
service := NewSoraGatewayService(client, nil, nil, cfg)
|
||||
|
||||
urls, err := service.pollVideoTask(context.Background(), nil, &Account{ID: 1}, "task", false)
|
||||
status, err := service.pollVideoTaskDetailed(context.Background(), nil, &Account{ID: 1}, "task", false)
|
||||
require.Error(t, err)
|
||||
require.Empty(t, urls)
|
||||
require.Nil(t, status)
|
||||
require.Contains(t, err.Error(), "reject")
|
||||
require.Equal(t, 1, client.videoCalls)
|
||||
}
|
||||
@@ -175,9 +370,65 @@ func TestSoraProErrorMessage(t *testing.T) {
|
||||
require.Empty(t, soraProErrorMessage("sora-basic", ""))
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_WriteSoraError_StreamEscapesJSON(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
||||
svc.writeSoraError(c, http.StatusBadGateway, "upstream_error", "invalid \"prompt\"\nline2", true)
|
||||
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "event: error\n")
|
||||
require.Contains(t, body, "data: [DONE]\n\n")
|
||||
|
||||
lines := strings.Split(body, "\n")
|
||||
require.GreaterOrEqual(t, len(lines), 2)
|
||||
require.Equal(t, "event: error", lines[0])
|
||||
require.True(t, strings.HasPrefix(lines[1], "data: "))
|
||||
|
||||
data := strings.TrimPrefix(lines[1], "data: ")
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(data), &parsed))
|
||||
errObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "upstream_error", errObj["type"])
|
||||
require.Equal(t, "invalid \"prompt\"\nline2", errObj["message"])
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_HandleSoraRequestError_FailoverHeadersCloned(t *testing.T) {
|
||||
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
||||
sourceHeaders := http.Header{}
|
||||
sourceHeaders.Set("cf-ray", "9d01b0e9ecc35829-SEA")
|
||||
|
||||
err := svc.handleSoraRequestError(
|
||||
context.Background(),
|
||||
&Account{ID: 1, Platform: PlatformSora},
|
||||
&SoraUpstreamError{
|
||||
StatusCode: http.StatusForbidden,
|
||||
Message: "forbidden",
|
||||
Headers: sourceHeaders,
|
||||
Body: []byte(`<!DOCTYPE html><title>Just a moment...</title>`),
|
||||
},
|
||||
"sora2-landscape-10s",
|
||||
nil,
|
||||
false,
|
||||
)
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.ErrorAs(t, err, &failoverErr)
|
||||
require.NotNil(t, failoverErr.ResponseHeaders)
|
||||
require.Equal(t, "9d01b0e9ecc35829-SEA", failoverErr.ResponseHeaders.Get("cf-ray"))
|
||||
|
||||
sourceHeaders.Set("cf-ray", "mutated-after-return")
|
||||
require.Equal(t, "9d01b0e9ecc35829-SEA", failoverErr.ResponseHeaders.Get("cf-ray"))
|
||||
}
|
||||
|
||||
func TestShouldFailoverUpstreamError(t *testing.T) {
|
||||
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
||||
require.True(t, svc.shouldFailoverUpstreamError(401))
|
||||
require.True(t, svc.shouldFailoverUpstreamError(404))
|
||||
require.True(t, svc.shouldFailoverUpstreamError(429))
|
||||
require.True(t, svc.shouldFailoverUpstreamError(500))
|
||||
require.True(t, svc.shouldFailoverUpstreamError(502))
|
||||
@@ -257,3 +508,19 @@ func TestDecodeSoraImageInput_DataURL(t *testing.T) {
|
||||
require.NotEmpty(t, data)
|
||||
require.Contains(t, filename, ".png")
|
||||
}
|
||||
|
||||
// TestDecodeBase64WithLimit_ExceedLimit: "aGVsbG8=" decodes to 5 bytes
// ("hello"), so a 3-byte limit must produce an error and no data.
func TestDecodeBase64WithLimit_ExceedLimit(t *testing.T) {
	data, err := decodeBase64WithLimit("aGVsbG8=", 3)
	require.Error(t, err)
	require.Nil(t, data)
}
|
||||
|
||||
// TestParseSoraWatermarkOptions_NumericBool verifies that numeric JSON
// booleans (1/0, which arrive as float64 after decoding) are accepted for
// the watermark option flags.
func TestParseSoraWatermarkOptions_NumericBool(t *testing.T) {
	body := map[string]any{
		"watermark_free":                float64(1),
		"watermark_fallback_on_failure": float64(0),
	}
	opts := parseSoraWatermarkOptions(body)
	require.True(t, opts.Enabled)
	require.False(t, opts.FallbackOnFailure)
}
|
||||
|
||||
@@ -17,6 +17,9 @@ type SoraModelConfig struct {
|
||||
Model string
|
||||
Size string
|
||||
RequirePro bool
|
||||
// Prompt-enhance 专用参数
|
||||
ExpansionLevel string
|
||||
DurationS int
|
||||
}
|
||||
|
||||
var soraModelConfigs = map[string]SoraModelConfig{
|
||||
@@ -160,31 +163,49 @@ var soraModelConfigs = map[string]SoraModelConfig{
|
||||
RequirePro: true,
|
||||
},
|
||||
"prompt-enhance-short-10s": {
|
||||
Type: "prompt_enhance",
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "short",
|
||||
DurationS: 10,
|
||||
},
|
||||
"prompt-enhance-short-15s": {
|
||||
Type: "prompt_enhance",
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "short",
|
||||
DurationS: 15,
|
||||
},
|
||||
"prompt-enhance-short-20s": {
|
||||
Type: "prompt_enhance",
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "short",
|
||||
DurationS: 20,
|
||||
},
|
||||
"prompt-enhance-medium-10s": {
|
||||
Type: "prompt_enhance",
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "medium",
|
||||
DurationS: 10,
|
||||
},
|
||||
"prompt-enhance-medium-15s": {
|
||||
Type: "prompt_enhance",
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "medium",
|
||||
DurationS: 15,
|
||||
},
|
||||
"prompt-enhance-medium-20s": {
|
||||
Type: "prompt_enhance",
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "medium",
|
||||
DurationS: 20,
|
||||
},
|
||||
"prompt-enhance-long-10s": {
|
||||
Type: "prompt_enhance",
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "long",
|
||||
DurationS: 10,
|
||||
},
|
||||
"prompt-enhance-long-15s": {
|
||||
Type: "prompt_enhance",
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "long",
|
||||
DurationS: 15,
|
||||
},
|
||||
"prompt-enhance-long-20s": {
|
||||
Type: "prompt_enhance",
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "long",
|
||||
DurationS: 20,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
266
backend/internal/service/sora_request_guard.go
Normal file
266
backend/internal/service/sora_request_guard.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// soraChallengeCooldownEntry records an active Cloudflare-challenge cooldown
// for one (account, proxy) key.
type soraChallengeCooldownEntry struct {
	Until                 time.Time // cooldown stays in force until this instant
	StatusCode            int       // upstream status that triggered the challenge
	CFRay                 string    // last observed cf-ray identifier, for diagnostics
	ConsecutiveChallenges int       // streak length used to scale the cooldown
	LastChallengeAt       time.Time // when the most recent challenge was recorded
}
|
||||
|
||||
// soraSidecarSessionEntry tracks a reusable sidecar session identifier and
// its lifetime for one (account, proxy) pair.
type soraSidecarSessionEntry struct {
	SessionKey string    // opaque "sora-<uuid>" session identifier
	ExpiresAt  time.Time // hard expiry; the entry is evicted after this
	LastUsedAt time.Time // refreshed on every reuse
}
|
||||
|
||||
func (c *SoraDirectClient) cloudflareChallengeCooldownSeconds() int {
|
||||
if c == nil || c.cfg == nil {
|
||||
return 900
|
||||
}
|
||||
cooldown := c.cfg.Sora.Client.CloudflareChallengeCooldownSeconds
|
||||
if cooldown <= 0 {
|
||||
return 0
|
||||
}
|
||||
return cooldown
|
||||
}
|
||||
|
||||
// checkCloudflareChallengeCooldown returns a synthetic 429 SoraUpstreamError
// when the (account, proxy) pair is still inside a Cloudflare-challenge
// cooldown window, and nil otherwise. Expired entries are lazily deleted.
func (c *SoraDirectClient) checkCloudflareChallengeCooldown(account *Account, proxyURL string) error {
	if c == nil {
		return nil
	}
	// Cooldowns are only tracked for persisted accounts (positive ID).
	if account == nil || account.ID <= 0 {
		return nil
	}
	cooldownSeconds := c.cloudflareChallengeCooldownSeconds()
	if cooldownSeconds <= 0 {
		// Feature disabled via configuration.
		return nil
	}
	key := soraAccountProxyKey(account, proxyURL)
	now := time.Now()

	c.challengeCooldownMu.RLock()
	entry, ok := c.challengeCooldowns[key]
	c.challengeCooldownMu.RUnlock()
	if !ok {
		return nil
	}
	if !entry.Until.After(now) {
		// Window elapsed: clean up the stale entry and allow the request.
		// NOTE(review): another goroutine may replace the entry between the
		// RUnlock above and this Lock; this delete is best-effort — confirm
		// this is acceptable.
		c.challengeCooldownMu.Lock()
		delete(c.challengeCooldowns, key)
		c.challengeCooldownMu.Unlock()
		return nil
	}

	// Round up so callers never see "retry in 0 seconds".
	remaining := int(math.Ceil(entry.Until.Sub(now).Seconds()))
	if remaining < 1 {
		remaining = 1
	}
	message := fmt.Sprintf("Sora request cooling down due to recent Cloudflare challenge. Retry in %d seconds.", remaining)
	if entry.ConsecutiveChallenges > 1 {
		message = fmt.Sprintf("%s (streak=%d)", message, entry.ConsecutiveChallenges)
	}
	if entry.CFRay != "" {
		message = fmt.Sprintf("%s (last cf-ray: %s)", message, entry.CFRay)
	}
	return &SoraUpstreamError{
		StatusCode: http.StatusTooManyRequests,
		Message:    message,
		Headers:    make(http.Header),
	}
}
|
||||
|
||||
// recordCloudflareChallengeCooldown registers (or extends) a cooldown after
// an upstream response identified as a Cloudflare challenge. Repeat
// challenges within 30 minutes grow the streak, which lengthens the
// effective cooldown (see soraComputeChallengeCooldownSeconds). The status
// code, cf-ray id, and streak are retained for diagnostics.
func (c *SoraDirectClient) recordCloudflareChallengeCooldown(account *Account, proxyURL string, statusCode int, headers http.Header, body []byte) {
	if c == nil {
		return
	}
	// Only persisted accounts participate in cooldown tracking.
	if account == nil || account.ID <= 0 {
		return
	}
	cooldownSeconds := c.cloudflareChallengeCooldownSeconds()
	if cooldownSeconds <= 0 {
		// Feature disabled via configuration.
		return
	}
	key := soraAccountProxyKey(account, proxyURL)
	now := time.Now()
	cfRay := soraerror.ExtractCloudflareRayID(headers, body)

	c.challengeCooldownMu.Lock()
	// Opportunistically purge expired entries while the write lock is held.
	c.cleanupExpiredChallengeCooldownsLocked(now)

	// A challenge within 30 minutes of the previous one extends the streak.
	streak := 1
	existing, ok := c.challengeCooldowns[key]
	if ok && now.Sub(existing.LastChallengeAt) <= 30*time.Minute {
		streak = existing.ConsecutiveChallenges + 1
	}
	effectiveCooldown := soraComputeChallengeCooldownSeconds(cooldownSeconds, streak)
	until := now.Add(time.Duration(effectiveCooldown) * time.Second)
	// Never shorten an already-longer cooldown; keep the larger streak and
	// fall back to the previous cf-ray when the new response carried none.
	if ok && existing.Until.After(until) {
		until = existing.Until
		if existing.ConsecutiveChallenges > streak {
			streak = existing.ConsecutiveChallenges
		}
		if cfRay == "" {
			cfRay = existing.CFRay
		}
	}
	c.challengeCooldowns[key] = soraChallengeCooldownEntry{
		Until:                 until,
		StatusCode:            statusCode,
		CFRay:                 cfRay,
		ConsecutiveChallenges: streak,
		LastChallengeAt:       now,
	}
	c.challengeCooldownMu.Unlock()

	if c.debugEnabled() {
		remain := int(math.Ceil(until.Sub(now).Seconds()))
		if remain < 0 {
			remain = 0
		}
		c.debugLogf("cloudflare_challenge_cooldown_set key=%s status=%d remain_s=%d streak=%d cf_ray=%s", key, statusCode, remain, streak, cfRay)
	}
}
|
||||
|
||||
// soraComputeChallengeCooldownSeconds derives the effective cooldown for a
// Cloudflare-challenge streak: the base cooldown scaled linearly by the
// streak length, with the multiplier capped at 4 and the result capped at
// one hour. A non-positive base disables the cooldown entirely.
func soraComputeChallengeCooldownSeconds(baseSeconds, streak int) int {
	const (
		maxMultiplier      = 4
		maxCooldownSeconds = 3600
	)
	if baseSeconds <= 0 {
		return 0
	}
	multiplier := streak
	if multiplier < 1 {
		multiplier = 1
	} else if multiplier > maxMultiplier {
		multiplier = maxMultiplier
	}
	if cooldown := baseSeconds * multiplier; cooldown <= maxCooldownSeconds {
		return cooldown
	}
	return maxCooldownSeconds
}
|
||||
|
||||
func (c *SoraDirectClient) clearCloudflareChallengeCooldown(account *Account, proxyURL string) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if account == nil || account.ID <= 0 {
|
||||
return
|
||||
}
|
||||
key := soraAccountProxyKey(account, proxyURL)
|
||||
c.challengeCooldownMu.Lock()
|
||||
_, existed := c.challengeCooldowns[key]
|
||||
if existed {
|
||||
delete(c.challengeCooldowns, key)
|
||||
}
|
||||
c.challengeCooldownMu.Unlock()
|
||||
if existed && c.debugEnabled() {
|
||||
c.debugLogf("cloudflare_challenge_cooldown_cleared key=%s", key)
|
||||
}
|
||||
}
|
||||
|
||||
// sidecarSessionKey returns a stable per-(account, proxy) session identifier
// for sidecar requests, creating one on first use. It returns "" when
// session reuse is disabled or the account is not persisted. A TTL <= 0 is
// treated as effectively unlimited (one year).
func (c *SoraDirectClient) sidecarSessionKey(account *Account, proxyURL string) string {
	if c == nil || !c.sidecarSessionReuseEnabled() {
		return ""
	}
	if account == nil || account.ID <= 0 {
		return ""
	}
	key := soraAccountProxyKey(account, proxyURL)
	now := time.Now()
	ttlSeconds := c.sidecarSessionTTLSeconds()

	c.sidecarSessionMu.Lock()
	defer c.sidecarSessionMu.Unlock()
	// Evict expired sessions first so a stale key is never handed back.
	c.cleanupExpiredSidecarSessionsLocked(now)
	if existing, exists := c.sidecarSessions[key]; exists {
		// Refresh last-use time; the expiry itself is not extended.
		existing.LastUsedAt = now
		c.sidecarSessions[key] = existing
		return existing.SessionKey
	}

	expiresAt := now.Add(time.Duration(ttlSeconds) * time.Second)
	if ttlSeconds <= 0 {
		// Non-positive TTL: keep the session for a year (practically forever).
		expiresAt = now.Add(365 * 24 * time.Hour)
	}
	newEntry := soraSidecarSessionEntry{
		SessionKey: "sora-" + uuid.NewString(),
		ExpiresAt:  expiresAt,
		LastUsedAt: now,
	}
	c.sidecarSessions[key] = newEntry

	if c.debugEnabled() {
		c.debugLogf("sidecar_session_created key=%s ttl_s=%d", key, ttlSeconds)
	}
	return newEntry.SessionKey
}
|
||||
|
||||
func (c *SoraDirectClient) cleanupExpiredChallengeCooldownsLocked(now time.Time) {
|
||||
if c == nil || len(c.challengeCooldowns) == 0 {
|
||||
return
|
||||
}
|
||||
for key, entry := range c.challengeCooldowns {
|
||||
if !entry.Until.After(now) {
|
||||
delete(c.challengeCooldowns, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) cleanupExpiredSidecarSessionsLocked(now time.Time) {
|
||||
if c == nil || len(c.sidecarSessions) == 0 {
|
||||
return
|
||||
}
|
||||
for key, entry := range c.sidecarSessions {
|
||||
if !entry.ExpiresAt.After(now) {
|
||||
delete(c.sidecarSessions, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func soraAccountProxyKey(account *Account, proxyURL string) string {
|
||||
accountID := int64(0)
|
||||
if account != nil {
|
||||
accountID = account.ID
|
||||
}
|
||||
return fmt.Sprintf("account:%d|proxy:%s", accountID, normalizeSoraProxyKey(proxyURL))
|
||||
}
|
||||
|
||||
// normalizeSoraProxyKey canonicalizes a proxy URL into a stable cache-key
// form: lowercase scheme://host[:port], with default ports (http:80,
// https:443) stripped. An empty value maps to "direct"; unparseable or
// host-less values fall back to the lowercased raw string.
func normalizeSoraProxyKey(proxyURL string) string {
	raw := strings.TrimSpace(proxyURL)
	if raw == "" {
		return "direct"
	}
	parsed, err := url.Parse(raw)
	if err != nil {
		return strings.ToLower(raw)
	}

	host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
	if host == "" {
		// No usable host component; keep the raw string as the key.
		return strings.ToLower(raw)
	}
	scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
	port := strings.TrimSpace(parsed.Port())

	// Default ports carry no information; drop them for key stability.
	isDefaultPort := (scheme == "http" && port == "80") || (scheme == "https" && port == "443")
	if port != "" && !isDefaultPort {
		host += ":" + port
	}
	if scheme == "" {
		scheme = "proxy"
	}
	return scheme + "://" + host
}
|
||||
@@ -43,10 +43,13 @@ func NewTokenRefreshService(
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
openAIRefresher := NewOpenAITokenRefresher(openaiOAuthService, accountRepo)
|
||||
openAIRefresher.SetSyncLinkedSoraAccounts(cfg.TokenRefresh.SyncLinkedSoraAccounts)
|
||||
|
||||
// 注册平台特定的刷新器
|
||||
s.refreshers = []TokenRefresher{
|
||||
NewClaudeTokenRefresher(oauthService),
|
||||
NewOpenAITokenRefresher(openaiOAuthService, accountRepo),
|
||||
openAIRefresher,
|
||||
NewGeminiTokenRefresher(geminiOAuthService),
|
||||
NewAntigravityTokenRefresher(antigravityOAuthService),
|
||||
}
|
||||
|
||||
@@ -86,6 +86,7 @@ type OpenAITokenRefresher struct {
|
||||
openaiOAuthService *OpenAIOAuthService
|
||||
accountRepo AccountRepository
|
||||
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
|
||||
syncLinkedSora bool
|
||||
}
|
||||
|
||||
// NewOpenAITokenRefresher 创建 OpenAI token刷新器
|
||||
@@ -103,11 +104,15 @@ func (r *OpenAITokenRefresher) SetSoraAccountRepo(repo SoraAccountRepository) {
|
||||
r.soraAccountRepo = repo
|
||||
}
|
||||
|
||||
// SetSyncLinkedSoraAccounts 控制是否同步覆盖关联的 Sora 账号 token。
|
||||
func (r *OpenAITokenRefresher) SetSyncLinkedSoraAccounts(enabled bool) {
|
||||
r.syncLinkedSora = enabled
|
||||
}
|
||||
|
||||
// CanRefresh 检查是否能处理此账号
|
||||
// 只处理 openai 平台的 oauth 类型账号
|
||||
// 只处理 openai 平台的 oauth 类型账号(不直接刷新 sora 平台账号)
|
||||
func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
|
||||
return (account.Platform == PlatformOpenAI || account.Platform == PlatformSora) &&
|
||||
account.Type == AccountTypeOAuth
|
||||
return account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
// NeedsRefresh 检查token是否需要刷新
|
||||
@@ -141,7 +146,7 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m
|
||||
}
|
||||
|
||||
// 异步同步关联的 Sora 账号(不阻塞主流程)
|
||||
if r.accountRepo != nil {
|
||||
if r.accountRepo != nil && r.syncLinkedSora {
|
||||
go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials)
|
||||
}
|
||||
|
||||
|
||||
@@ -226,3 +226,43 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAITokenRefresher_CanRefresh(t *testing.T) {
|
||||
refresher := &OpenAITokenRefresher{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
platform string
|
||||
accType string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "openai oauth - can refresh",
|
||||
platform: PlatformOpenAI,
|
||||
accType: AccountTypeOAuth,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "sora oauth - cannot refresh directly",
|
||||
platform: PlatformSora,
|
||||
accType: AccountTypeOAuth,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "openai apikey - cannot refresh",
|
||||
platform: PlatformOpenAI,
|
||||
accType: AccountTypeAPIKey,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: tt.platform,
|
||||
Type: tt.accType,
|
||||
}
|
||||
require.Equal(t, tt.want, refresher.CanRefresh(account))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,8 +26,8 @@ type UsageLog struct {
|
||||
CacheCreationTokens int
|
||||
CacheReadTokens int
|
||||
|
||||
CacheCreation5mTokens int
|
||||
CacheCreation1hTokens int
|
||||
CacheCreation5mTokens int `gorm:"column:cache_creation_5m_tokens"`
|
||||
CacheCreation1hTokens int `gorm:"column:cache_creation_1h_tokens"`
|
||||
|
||||
InputCost float64
|
||||
OutputCost float64
|
||||
@@ -46,6 +46,9 @@ type UsageLog struct {
|
||||
UserAgent *string
|
||||
IPAddress *string
|
||||
|
||||
// Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费)
|
||||
CacheTTLOverridden bool
|
||||
|
||||
// 图片生成字段
|
||||
ImageCount int
|
||||
ImageSize *string
|
||||
|
||||
@@ -206,6 +206,18 @@ func ProvideSoraMediaStorage(cfg *config.Config) *SoraMediaStorage {
|
||||
return NewSoraMediaStorage(cfg)
|
||||
}
|
||||
|
||||
func ProvideSoraDirectClient(
|
||||
cfg *config.Config,
|
||||
httpUpstream HTTPUpstream,
|
||||
tokenProvider *OpenAITokenProvider,
|
||||
accountRepo AccountRepository,
|
||||
soraAccountRepo SoraAccountRepository,
|
||||
) *SoraDirectClient {
|
||||
client := NewSoraDirectClient(cfg, httpUpstream, tokenProvider)
|
||||
client.SetAccountRepositories(accountRepo, soraAccountRepo)
|
||||
return client
|
||||
}
|
||||
|
||||
// ProvideSoraMediaCleanupService 创建并启动 Sora 媒体清理服务
|
||||
func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService {
|
||||
svc := NewSoraMediaCleanupService(storage, cfg)
|
||||
@@ -255,7 +267,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewGatewayService,
|
||||
ProvideSoraMediaStorage,
|
||||
ProvideSoraMediaCleanupService,
|
||||
NewSoraDirectClient,
|
||||
ProvideSoraDirectClient,
|
||||
wire.Bind(new(SoraClient), new(*SoraDirectClient)),
|
||||
NewSoraGatewayService,
|
||||
NewOpenAIGatewayService,
|
||||
|
||||
170
backend/internal/util/soraerror/soraerror.go
Normal file
170
backend/internal/util/soraerror/soraerror.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package soraerror
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
cfRayPattern = regexp.MustCompile(`(?i)cf-ray[:\s=]+([a-z0-9-]+)`)
|
||||
cRayPattern = regexp.MustCompile(`(?i)cRay:\s*'([a-z0-9-]+)'`)
|
||||
htmlChallenge = []string{
|
||||
"window._cf_chl_opt",
|
||||
"just a moment",
|
||||
"enable javascript and cookies to continue",
|
||||
"__cf_chl_",
|
||||
"challenge-platform",
|
||||
}
|
||||
)
|
||||
|
||||
// IsCloudflareChallengeResponse reports whether the upstream response matches Cloudflare challenge behavior.
|
||||
func IsCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
|
||||
if statusCode != http.StatusForbidden && statusCode != http.StatusTooManyRequests {
|
||||
return false
|
||||
}
|
||||
|
||||
if headers != nil && strings.EqualFold(strings.TrimSpace(headers.Get("cf-mitigated")), "challenge") {
|
||||
return true
|
||||
}
|
||||
|
||||
preview := strings.ToLower(TruncateBody(body, 4096))
|
||||
for _, marker := range htmlChallenge {
|
||||
if strings.Contains(preview, marker) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
contentType := ""
|
||||
if headers != nil {
|
||||
contentType = strings.ToLower(strings.TrimSpace(headers.Get("content-type")))
|
||||
}
|
||||
if strings.Contains(contentType, "text/html") &&
|
||||
(strings.Contains(preview, "<html") || strings.Contains(preview, "<!doctype html")) &&
|
||||
(strings.Contains(preview, "cloudflare") || strings.Contains(preview, "challenge")) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ExtractCloudflareRayID extracts cf-ray from headers or response body.
|
||||
func ExtractCloudflareRayID(headers http.Header, body []byte) string {
|
||||
if headers != nil {
|
||||
rayID := strings.TrimSpace(headers.Get("cf-ray"))
|
||||
if rayID != "" {
|
||||
return rayID
|
||||
}
|
||||
rayID = strings.TrimSpace(headers.Get("Cf-Ray"))
|
||||
if rayID != "" {
|
||||
return rayID
|
||||
}
|
||||
}
|
||||
|
||||
preview := TruncateBody(body, 8192)
|
||||
if matches := cfRayPattern.FindStringSubmatch(preview); len(matches) >= 2 {
|
||||
return strings.TrimSpace(matches[1])
|
||||
}
|
||||
if matches := cRayPattern.FindStringSubmatch(preview); len(matches) >= 2 {
|
||||
return strings.TrimSpace(matches[1])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// FormatCloudflareChallengeMessage appends cf-ray info when available.
|
||||
func FormatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
|
||||
rayID := ExtractCloudflareRayID(headers, body)
|
||||
if rayID == "" {
|
||||
return base
|
||||
}
|
||||
return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
|
||||
}
|
||||
|
||||
// ExtractUpstreamErrorCodeAndMessage extracts structured error code/message from common JSON layouts.
|
||||
func ExtractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
|
||||
trimmed := strings.TrimSpace(string(body))
|
||||
if trimmed == "" {
|
||||
return "", ""
|
||||
}
|
||||
if !json.Valid([]byte(trimmed)) {
|
||||
return "", truncateMessage(trimmed, 256)
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(trimmed), &payload); err != nil {
|
||||
return "", truncateMessage(trimmed, 256)
|
||||
}
|
||||
|
||||
code := firstNonEmpty(
|
||||
extractNestedString(payload, "error", "code"),
|
||||
extractRootString(payload, "code"),
|
||||
)
|
||||
message := firstNonEmpty(
|
||||
extractNestedString(payload, "error", "message"),
|
||||
extractRootString(payload, "message"),
|
||||
extractNestedString(payload, "error", "detail"),
|
||||
extractRootString(payload, "detail"),
|
||||
)
|
||||
return strings.TrimSpace(code), truncateMessage(strings.TrimSpace(message), 512)
|
||||
}
|
||||
|
||||
// TruncateBody truncates body text for logging/inspection.
|
||||
func TruncateBody(body []byte, max int) string {
|
||||
if max <= 0 {
|
||||
max = 512
|
||||
}
|
||||
raw := strings.TrimSpace(string(body))
|
||||
if len(raw) <= max {
|
||||
return raw
|
||||
}
|
||||
return raw[:max] + "...(truncated)"
|
||||
}
|
||||
|
||||
func truncateMessage(s string, max int) string {
|
||||
if max <= 0 {
|
||||
return ""
|
||||
}
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
return s[:max] + "...(truncated)"
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, v := range values {
|
||||
if strings.TrimSpace(v) != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractRootString(m map[string]any, key string) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
v, ok := m[key]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
s, _ := v.(string)
|
||||
return s
|
||||
}
|
||||
|
||||
func extractNestedString(m map[string]any, parent, key string) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
node, ok := m[parent]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
child, ok := node.(map[string]any)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
s, _ := child[key].(string)
|
||||
return s
|
||||
}
|
||||
47
backend/internal/util/soraerror/soraerror_test.go
Normal file
47
backend/internal/util/soraerror/soraerror_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package soraerror
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsCloudflareChallengeResponse(t *testing.T) {
|
||||
headers := make(http.Header)
|
||||
headers.Set("cf-mitigated", "challenge")
|
||||
require.True(t, IsCloudflareChallengeResponse(http.StatusForbidden, headers, []byte(`{"ok":false}`)))
|
||||
|
||||
require.True(t, IsCloudflareChallengeResponse(http.StatusTooManyRequests, nil, []byte(`<!DOCTYPE html><title>Just a moment...</title><script>window._cf_chl_opt={};</script>`)))
|
||||
require.False(t, IsCloudflareChallengeResponse(http.StatusBadGateway, nil, []byte(`<!DOCTYPE html><title>Just a moment...</title>`)))
|
||||
}
|
||||
|
||||
func TestExtractCloudflareRayID(t *testing.T) {
|
||||
headers := make(http.Header)
|
||||
headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
|
||||
require.Equal(t, "9d01b0e9ecc35829-SEA", ExtractCloudflareRayID(headers, nil))
|
||||
|
||||
body := []byte(`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`)
|
||||
require.Equal(t, "9cff2d62d83bb98d", ExtractCloudflareRayID(nil, body))
|
||||
}
|
||||
|
||||
func TestExtractUpstreamErrorCodeAndMessage(t *testing.T) {
|
||||
code, msg := ExtractUpstreamErrorCodeAndMessage([]byte(`{"error":{"code":"cf_shield_429","message":"rate limited"}}`))
|
||||
require.Equal(t, "cf_shield_429", code)
|
||||
require.Equal(t, "rate limited", msg)
|
||||
|
||||
code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`{"code":"unsupported_country_code","message":"not available"}`))
|
||||
require.Equal(t, "unsupported_country_code", code)
|
||||
require.Equal(t, "not available", msg)
|
||||
|
||||
code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`plain text`))
|
||||
require.Equal(t, "", code)
|
||||
require.Equal(t, "plain text", msg)
|
||||
}
|
||||
|
||||
func TestFormatCloudflareChallengeMessage(t *testing.T) {
|
||||
headers := make(http.Header)
|
||||
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
|
||||
msg := FormatCloudflareChallengeMessage("blocked", headers, nil)
|
||||
require.Equal(t, "blocked (cf-ray: 9d03b68c086027a1-SEA)", msg)
|
||||
}
|
||||
@@ -86,6 +86,7 @@ func (s *FrontendServer) Middleware() gin.HandlerFunc {
|
||||
if strings.HasPrefix(path, "/api/") ||
|
||||
strings.HasPrefix(path, "/v1/") ||
|
||||
strings.HasPrefix(path, "/v1beta/") ||
|
||||
strings.HasPrefix(path, "/sora/") ||
|
||||
strings.HasPrefix(path, "/antigravity/") ||
|
||||
strings.HasPrefix(path, "/setup/") ||
|
||||
path == "/health" ||
|
||||
@@ -209,6 +210,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
|
||||
if strings.HasPrefix(path, "/api/") ||
|
||||
strings.HasPrefix(path, "/v1/") ||
|
||||
strings.HasPrefix(path, "/v1beta/") ||
|
||||
strings.HasPrefix(path, "/sora/") ||
|
||||
strings.HasPrefix(path, "/antigravity/") ||
|
||||
strings.HasPrefix(path, "/setup/") ||
|
||||
path == "/health" ||
|
||||
|
||||
@@ -362,6 +362,7 @@ func TestFrontendServer_Middleware(t *testing.T) {
|
||||
"/api/v1/users",
|
||||
"/v1/models",
|
||||
"/v1beta/chat",
|
||||
"/sora/v1/models",
|
||||
"/antigravity/test",
|
||||
"/setup/init",
|
||||
"/health",
|
||||
@@ -537,6 +538,7 @@ func TestServeEmbeddedFrontend(t *testing.T) {
|
||||
"/api/users",
|
||||
"/v1/models",
|
||||
"/v1beta/chat",
|
||||
"/sora/v1/models",
|
||||
"/antigravity/test",
|
||||
"/setup/init",
|
||||
"/health",
|
||||
|
||||
44
backend/migrations/054_drop_legacy_cache_columns.sql
Normal file
44
backend/migrations/054_drop_legacy_cache_columns.sql
Normal file
@@ -0,0 +1,44 @@
|
||||
-- Drop legacy cache token columns that lack the underscore separator.
|
||||
-- These were created by GORM's automatic snake_case conversion:
|
||||
-- CacheCreation5mTokens → cache_creation5m_tokens (incorrect)
|
||||
-- CacheCreation1hTokens → cache_creation1h_tokens (incorrect)
|
||||
--
|
||||
-- The canonical columns are:
|
||||
-- cache_creation_5m_tokens (defined in 001_init.sql)
|
||||
-- cache_creation_1h_tokens (defined in 001_init.sql)
|
||||
--
|
||||
-- Migration 009 already copied data from legacy → canonical columns.
|
||||
-- But upgraded instances may still have post-009 writes in legacy columns.
|
||||
-- Backfill once more before dropping to prevent data loss.
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name = 'usage_logs'
|
||||
AND column_name = 'cache_creation5m_tokens'
|
||||
) THEN
|
||||
UPDATE usage_logs
|
||||
SET cache_creation_5m_tokens = cache_creation5m_tokens
|
||||
WHERE cache_creation_5m_tokens = 0
|
||||
AND cache_creation5m_tokens <> 0;
|
||||
END IF;
|
||||
|
||||
IF EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name = 'usage_logs'
|
||||
AND column_name = 'cache_creation1h_tokens'
|
||||
) THEN
|
||||
UPDATE usage_logs
|
||||
SET cache_creation_1h_tokens = cache_creation1h_tokens
|
||||
WHERE cache_creation_1h_tokens = 0
|
||||
AND cache_creation1h_tokens <> 0;
|
||||
END IF;
|
||||
END $$;
|
||||
|
||||
ALTER TABLE usage_logs DROP COLUMN IF EXISTS cache_creation5m_tokens;
|
||||
ALTER TABLE usage_logs DROP COLUMN IF EXISTS cache_creation1h_tokens;
|
||||
2
backend/migrations/055_add_cache_ttl_overridden.sql
Normal file
2
backend/migrations/055_add_cache_ttl_overridden.sql
Normal file
@@ -0,0 +1,2 @@
|
||||
-- Add cache_ttl_overridden flag to usage_logs for tracking cache TTL override per account.
|
||||
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS cache_ttl_overridden BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
Reference in New Issue
Block a user