2026-03-07 23:59:39 +08:00
package service
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
"time"
2026-03-08 00:14:15 +08:00
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
2026-03-07 23:59:39 +08:00
coderws "github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
)
// openAIWSRateLimitSignalRepo is a test double built on stubOpenAIAccountRepo
// that keeps an in-memory log of rate-limit and extra-update calls so tests can
// assert on them afterwards.
type openAIWSRateLimitSignalRepo struct {
	stubOpenAIAccountRepo

	// rateLimitCalls collects the reset time passed to each SetRateLimited call.
	rateLimitCalls []time.Time
	// updateExtra collects a copy of the payload passed to each UpdateExtra call.
	updateExtra []map[string]any
}
// openAICodexSnapshotAsyncRepo is a test double built on stubOpenAIAccountRepo
// that publishes UpdateExtra payloads and SetRateLimited reset times on
// channels, letting a test wait for writes it does not trigger synchronously.
type openAICodexSnapshotAsyncRepo struct {
	stubOpenAIAccountRepo

	// updateExtraCh receives a copy of each UpdateExtra payload when non-nil.
	updateExtraCh chan map[string]any
	// rateLimitCh receives each SetRateLimited reset time when non-nil.
	rateLimitCh chan time.Time
}
2026-03-08 00:14:15 +08:00
// openAICodexExtraListRepo is a test double built on stubOpenAIAccountRepo that
// serves the stub's accounts from ListWithFilters and publishes SetRateLimited
// reset times on a channel.
type openAICodexExtraListRepo struct {
	stubOpenAIAccountRepo

	// rateLimitCh receives each SetRateLimited reset time when non-nil.
	rateLimitCh chan time.Time
}
2026-03-07 23:59:39 +08:00
// SetRateLimited appends resetAt to the recorded call log. The context and
// account ID are ignored and the returned error is always nil.
func (r *openAIWSRateLimitSignalRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
	recorded := append(r.rateLimitCalls, resetAt)
	r.rateLimitCalls = recorded
	return nil
}
// UpdateExtra stores a shallow copy of updates in the recorded call log, so
// later mutation of the caller's map does not alter what was captured. The
// context and account ID are ignored and the returned error is always nil.
func (r *openAIWSRateLimitSignalRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
	snapshot := make(map[string]any, len(updates))
	for key := range updates {
		snapshot[key] = updates[key]
	}
	r.updateExtra = append(r.updateExtra, snapshot)
	return nil
}
// SetRateLimited sends resetAt on rateLimitCh when that channel is configured
// and is a no-op otherwise. The send blocks until the channel can accept the
// value. The returned error is always nil.
func (r *openAICodexSnapshotAsyncRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
	if r.rateLimitCh == nil {
		return nil
	}
	r.rateLimitCh <- resetAt
	return nil
}
// UpdateExtra sends a shallow copy of updates on updateExtraCh when that
// channel is configured and is a no-op otherwise. The send blocks until the
// channel can accept the value. The returned error is always nil.
func (r *openAICodexSnapshotAsyncRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
	if r.updateExtraCh == nil {
		return nil
	}
	snapshot := make(map[string]any, len(updates))
	for key := range updates {
		snapshot[key] = updates[key]
	}
	r.updateExtraCh <- snapshot
	return nil
}
2026-03-08 00:14:15 +08:00
// SetRateLimited sends resetAt on rateLimitCh when that channel is configured
// and is a no-op otherwise. The send blocks until the channel can accept the
// value. The returned error is always nil.
func (r *openAICodexExtraListRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
	if r.rateLimitCh == nil {
		return nil
	}
	r.rateLimitCh <- resetAt
	return nil
}
// ListWithFilters returns every account held by the embedded stub regardless of
// the filter arguments. The pagination result reports the stub's account count
// as Total and echoes the requested Page and PageSize. The error is always nil.
func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
	// The filters are intentionally unused by this stub.
	_, _, _, _, _ = platform, accountType, status, search, groupID

	page := &pagination.PaginationResult{
		Total:    int64(len(r.accounts)),
		Page:     params.Page,
		PageSize: params.PageSize,
	}
	return r.accounts, page, nil
}
2026-03-07 23:59:39 +08:00
// TestOpenAIGatewayService_Forward_WSv2ErrorEventUsageLimitPersistsRateLimit
// drives Forward against a fake WebSocket upstream that answers the first
// request with a "usage_limit_reached" error event carrying a resets_at
// timestamp. It asserts that Forward fails with HTTP 429, that the HTTP
// upstream is never invoked, and that exactly one SetRateLimited call is
// recorded with the upstream-provided reset time.
func TestOpenAIGatewayService_Forward_WSv2ErrorEventUsageLimitPersistsRateLimit(t *testing.T) {
	gin.SetMode(gin.TestMode)

	resetAt := time.Now().Add(2 * time.Hour).Unix()
	upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			t.Errorf("upgrade websocket failed: %v", err)
			return
		}
		defer func() { _ = conn.Close() }()

		// Consume the client's request before replying with the error event.
		var req map[string]any
		if err := conn.ReadJSON(&req); err != nil {
			t.Errorf("read ws request failed: %v", err)
			return
		}

		errorEvent := map[string]any{
			"type": "error",
			"error": map[string]any{
				"code":      "rate_limit_exceeded",
				"type":      "usage_limit_reached",
				"message":   "The usage limit has been reached",
				"resets_at": resetAt,
			},
		}
		// Report a failed write like the other handler failures; otherwise the
		// test would only fail later with an unrelated-looking assertion.
		if err := conn.WriteJSON(errorEvent); err != nil {
			t.Errorf("write ws error event failed: %v", err)
		}
	}))
	defer wsServer.Close()

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	c.Request.Header.Set("User-Agent", "unit-test-agent/1.0")

	// The HTTP upstream has a canned success response; the test asserts below
	// that it is never used.
	upstream := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(` { "id":"resp_http_should_not_run"} `)),
		},
	}

	cfg := newOpenAIWSV2TestConfig()
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true

	account := Account{
		ID:          501,
		Name:        "openai-ws-rate-limit-event",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": wsServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}
	repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}}
	rateSvc := &RateLimitService{accountRepo: repo}
	svc := &OpenAIGatewayService{
		accountRepo:      repo,
		rateLimitService: rateSvc,
		httpUpstream:     upstream,
		cache:            &stubGatewayCache{},
		cfg:              cfg,
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
	}

	body := []byte(` { "model":"gpt-5.1","stream":false,"input":[ { "type":"input_text","text":"hello"}]} `)
	result, err := svc.Forward(context.Background(), c, &account, body)

	require.Error(t, err)
	require.Nil(t, result)
	require.Equal(t, http.StatusTooManyRequests, rec.Code)
	require.Nil(t, upstream.lastReq, "WS 限流 error event 不应回退到同账号 HTTP")
	require.Len(t, repo.rateLimitCalls, 1)
	require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second)
}
// TestOpenAIGatewayService_Forward_WSv2Handshake429PersistsRateLimit points
// Forward at an upstream that answers every request with HTTP 429 and a set of
// x-codex-* usage headers. It asserts that Forward fails with HTTP 429, that
// the HTTP upstream is never invoked, that one SetRateLimited call is recorded,
// and that an UpdateExtra payload containing "codex_usage_updated_at" is stored.
func TestOpenAIGatewayService_Forward_WSv2Handshake429PersistsRateLimit(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Usage headers returned alongside the 429 response.
	codexHeaders := [][2]string{
		{"x-codex-primary-used-percent", "100"},
		{"x-codex-primary-reset-after-seconds", "7200"},
		{"x-codex-primary-window-minutes", "10080"},
		{"x-codex-secondary-used-percent", "3"},
		{"x-codex-secondary-reset-after-seconds", "1800"},
		{"x-codex-secondary-window-minutes", "300"},
	}
	rateLimitedServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		for _, kv := range codexHeaders {
			w.Header().Set(kv[0], kv[1])
		}
		w.WriteHeader(http.StatusTooManyRequests)
		_, _ = w.Write([]byte(` { "error": { "type":"rate_limit_exceeded","message":"rate limited"}} `))
	}))
	defer rateLimitedServer.Close()

	cfg := newOpenAIWSV2TestConfig()
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "unit-test-agent/1.0")

	// Canned HTTP response that must stay unused.
	httpFallback := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(` { "id":"resp_http_should_not_run"} `)),
		},
	}

	acct := Account{
		ID:          502,
		Name:        "openai-ws-rate-limit-handshake",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": rateLimitedServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}
	signalRepo := &openAIWSRateLimitSignalRepo{
		stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{acct}},
	}
	gateway := &OpenAIGatewayService{
		accountRepo:      signalRepo,
		rateLimitService: &RateLimitService{accountRepo: signalRepo},
		httpUpstream:     httpFallback,
		cache:            &stubGatewayCache{},
		cfg:              cfg,
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
	}

	payload := []byte(` { "model":"gpt-5.1","stream":false,"input":[ { "type":"input_text","text":"hello"}]} `)
	result, err := gateway.Forward(context.Background(), ginCtx, &acct, payload)

	require.Error(t, err)
	require.Nil(t, result)
	require.Equal(t, http.StatusTooManyRequests, recorder.Code)
	require.Nil(t, httpFallback.lastReq, "WS 握手 429 不应回退到同账号 HTTP")
	require.Len(t, signalRepo.rateLimitCalls, 1)
	require.NotEmpty(t, signalRepo.updateExtra, "握手 429 的 x-codex 头应立即落库")
	require.Contains(t, signalRepo.updateExtra[0], "codex_usage_updated_at")
}
// TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ErrorEventUsageLimitPersistsRateLimit
// exercises the client-facing WebSocket proxy path. The upstream side is a
// capture connection preloaded with a single "usage_limit_reached" error event;
// a real client dials a local test server whose handler hands the accepted
// connection to ProxyResponsesWebSocketFromClient. The test asserts that the
// proxy call returns an error and that exactly one SetRateLimited call is
// recorded with the reset time carried by the event.
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ErrorEventUsageLimitPersistsRateLimit ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
cfg := newOpenAIWSV2TestConfig ( )
cfg . Security . URLAllowlist . Enabled = false
cfg . Security . URLAllowlist . AllowInsecureHTTP = true
// Pool and timeout settings for the upstream WebSocket connection pool.
cfg . Gateway . OpenAIWS . MaxConnsPerAccount = 1
cfg . Gateway . OpenAIWS . MinIdlePerAccount = 0
cfg . Gateway . OpenAIWS . MaxIdlePerAccount = 1
cfg . Gateway . OpenAIWS . QueueLimitPerConn = 8
cfg . Gateway . OpenAIWS . DialTimeoutSeconds = 3
cfg . Gateway . OpenAIWS . ReadTimeoutSeconds = 3
cfg . Gateway . OpenAIWS . WriteTimeoutSeconds = 3
resetAt := time . Now ( ) . Add ( 90 * time . Minute ) . Unix ( )
// The capture connection's only event is the error payload; PLACEHOLDER is
// substituted with the numeric reset timestamp right below.
captureConn := & openAIWSCaptureConn {
events : [ ] [ ] byte {
[ ] byte ( ` { "type":"error","error": { "code":"rate_limit_exceeded","type":"usage_limit_reached","message":"The usage limit has been reached","resets_at":PLACEHOLDER}} ` ) ,
} ,
}
captureConn . events [ 0 ] = [ ] byte ( strings . ReplaceAll ( string ( captureConn . events [ 0 ] ) , "PLACEHOLDER" , strconv . FormatInt ( resetAt , 10 ) ) )
// Route the pool's upstream dials to the capture connection.
captureDialer := & openAIWSCaptureDialer { conn : captureConn }
pool := newOpenAIWSConnPool ( cfg )
pool . setClientDialerForTest ( captureDialer )
account := Account {
ID : 503 ,
Name : "openai-ingress-rate-limit" ,
Platform : PlatformOpenAI ,
Type : AccountTypeAPIKey ,
Status : StatusActive ,
Schedulable : true ,
Concurrency : 1 ,
Credentials : map [ string ] any {
"api_key" : "sk-test" ,
} ,
Extra : map [ string ] any {
"responses_websockets_v2_enabled" : true ,
} ,
}
repo := & openAIWSRateLimitSignalRepo { stubOpenAIAccountRepo : stubOpenAIAccountRepo { accounts : [ ] Account { account } } }
rateSvc := & RateLimitService { accountRepo : repo }
svc := & OpenAIGatewayService {
accountRepo : repo ,
rateLimitService : rateSvc ,
httpUpstream : & httpUpstreamRecorder { } ,
cache : & stubGatewayCache { } ,
cfg : cfg ,
openaiWSResolver : NewOpenAIWSProtocolResolver ( cfg ) ,
toolCorrector : NewCodexToolCorrector ( ) ,
openaiWSPool : pool ,
}
// Buffered with capacity 1: every handler path sends exactly one value, so
// the handler never blocks on the send.
serverErrCh := make ( chan error , 1 )
wsServer := httptest . NewServer ( http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
conn , err := coderws . Accept ( w , r , & coderws . AcceptOptions { CompressionMode : coderws . CompressionContextTakeover } )
if err != nil {
serverErrCh <- err
return
}
defer func ( ) { _ = conn . CloseNow ( ) } ( )
// Build a gin context around a clone of the incoming request with a fixed
// User-Agent header.
rec := httptest . NewRecorder ( )
ginCtx , _ := gin . CreateTestContext ( rec )
req := r . Clone ( r . Context ( ) )
req . Header = req . Header . Clone ( )
req . Header . Set ( "User-Agent" , "unit-test-agent/1.0" )
ginCtx . Request = req
// Read the client's first message (3s timeout); it is passed to the proxy call.
readCtx , cancel := context . WithTimeout ( r . Context ( ) , 3 * time . Second )
msgType , firstMessage , readErr := conn . Read ( readCtx )
cancel ( )
if readErr != nil {
serverErrCh <- readErr
return
}
// Anything other than a text or binary frame is reported as an unexpected EOF.
if msgType != coderws . MessageText && msgType != coderws . MessageBinary {
serverErrCh <- io . ErrUnexpectedEOF
return
}
// Publish the proxy call's result so the test goroutine can assert on it.
serverErrCh <- svc . ProxyResponsesWebSocketFromClient ( r . Context ( ) , ginCtx , conn , & account , "sk-test" , firstMessage , nil )
} ) )
defer wsServer . Close ( )
// Dial the test server over ws:// by rewriting the http scheme prefix.
dialCtx , cancelDial := context . WithTimeout ( context . Background ( ) , 3 * time . Second )
clientConn , _ , err := coderws . Dial ( dialCtx , "ws" + strings . TrimPrefix ( wsServer . URL , "http" ) , nil )
cancelDial ( )
require . NoError ( t , err )
defer func ( ) { _ = clientConn . CloseNow ( ) } ( )
// Send the first client message that the handler is waiting for.
writeCtx , cancelWrite := context . WithTimeout ( context . Background ( ) , 3 * time . Second )
err = clientConn . Write ( writeCtx , coderws . MessageText , [ ] byte ( ` { "type":"response.create","model":"gpt-5.1","stream":false} ` ) )
cancelWrite ( )
require . NoError ( t , err )
// Wait for the handler's result; repo state is inspected only after the
// channel receive.
select {
case serverErr := <- serverErrCh :
require . Error ( t , serverErr )
require . Len ( t , repo . rateLimitCalls , 1 )
require . WithinDuration ( t , time . Unix ( resetAt , 0 ) , repo . rateLimitCalls [ 0 ] , 2 * time . Second )
case <- time . After ( 5 * time . Second ) :
t . Fatal ( "等待 ingress websocket 结束超时" )
}
}
// TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ExhaustedSnapshotSetsRateLimit
// feeds updateCodexUsageSnapshot a snapshot whose primary window is 100% used
// with a 3600-second reset. It asserts that the persisted extras report
// codex_7d_used_percent as 100 and that SetRateLimited is called with a reset
// time about one hour after the call.
func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ExhaustedSnapshotSetsRateLimit(t *testing.T) {
	asyncRepo := &openAICodexSnapshotAsyncRepo{
		updateExtraCh: make(chan map[string]any, 1),
		rateLimitCh:   make(chan time.Time, 1),
	}
	gateway := &OpenAIGatewayService{accountRepo: asyncRepo}

	exhausted := &OpenAICodexUsageSnapshot{
		PrimaryUsedPercent:         ptrFloat64WS(100),
		PrimaryResetAfterSeconds:   ptrIntWS(3600),
		PrimaryWindowMinutes:       ptrIntWS(10080),
		SecondaryUsedPercent:       ptrFloat64WS(12),
		SecondaryResetAfterSeconds: ptrIntWS(1200),
		SecondaryWindowMinutes:     ptrIntWS(300),
	}

	startedAt := time.Now()
	gateway.updateCodexUsageSnapshot(context.Background(), 601, exhausted)

	// First the usage snapshot must be written.
	select {
	case persisted := <-asyncRepo.updateExtraCh:
		require.Equal(t, 100.0, persisted["codex_7d_used_percent"])
	case <-time.After(2 * time.Second):
		t.Fatal("等待 codex 快照落库超时")
	}

	// Then the account must be marked rate limited until the reset time.
	select {
	case gotReset := <-asyncRepo.rateLimitCh:
		require.WithinDuration(t, startedAt.Add(time.Hour), gotReset, 2*time.Second)
	case <-time.After(2 * time.Second):
		t.Fatal("等待 codex 100% 自动切换限流超时")
	}
}
// TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesNotSetRateLimit
// feeds updateCodexUsageSnapshot a snapshot whose windows are 94% and 22% used.
// It asserts that the extras are still persisted but that no SetRateLimited
// call arrives within 200 milliseconds.
func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesNotSetRateLimit(t *testing.T) {
	asyncRepo := &openAICodexSnapshotAsyncRepo{
		updateExtraCh: make(chan map[string]any, 1),
		rateLimitCh:   make(chan time.Time, 1),
	}
	gateway := &OpenAIGatewayService{accountRepo: asyncRepo}

	partial := &OpenAICodexUsageSnapshot{
		PrimaryUsedPercent:         ptrFloat64WS(94),
		PrimaryResetAfterSeconds:   ptrIntWS(3600),
		PrimaryWindowMinutes:       ptrIntWS(10080),
		SecondaryUsedPercent:       ptrFloat64WS(22),
		SecondaryResetAfterSeconds: ptrIntWS(1200),
		SecondaryWindowMinutes:     ptrIntWS(300),
	}
	gateway.updateCodexUsageSnapshot(context.Background(), 602, partial)

	// The snapshot itself must still be written.
	select {
	case <-asyncRepo.updateExtraCh:
	case <-time.After(2 * time.Second):
		t.Fatal("等待 codex 快照落库超时")
	}

	// No rate-limit write may show up during the grace window.
	select {
	case gotReset := <-asyncRepo.rateLimitCh:
		t.Fatalf("unexpected rate limit reset at: %v", gotReset)
	case <-time.After(200 * time.Millisecond):
	}
}
// ptrFloat64WS returns a pointer to a fresh float64 holding v.
func ptrFloat64WS(v float64) *float64 {
	out := new(float64)
	*out = v
	return out
}
// ptrIntWS returns a pointer to a fresh int holding v.
func ptrIntWS(v int) *int {
	out := new(int)
	*out = v
	return out
}
2026-03-08 00:14:15 +08:00
// TestOpenAIGatewayService_GetSchedulableAccount_ExhaustedCodexExtraSetsRateLimit
// stores an account whose Extra already reports codex_7d_used_percent of 100
// and a future codex_7d_reset_at. It asserts that getSchedulableAccount returns
// the account with RateLimitResetAt set to that reset time and that the same
// time is passed to SetRateLimited.
func TestOpenAIGatewayService_GetSchedulableAccount_ExhaustedCodexExtraSetsRateLimit(t *testing.T) {
	wantReset := time.Now().Add(6 * 24 * time.Hour).UTC()

	stored := Account{
		ID:          701,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Extra: map[string]any{
			"codex_7d_used_percent": 100.0,
			"codex_7d_reset_at":     wantReset.Format(time.RFC3339),
		},
	}
	listRepo := &openAICodexExtraListRepo{
		stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{stored}},
		rateLimitCh:           make(chan time.Time, 1),
	}
	gateway := &OpenAIGatewayService{accountRepo: listRepo}

	got, err := gateway.getSchedulableAccount(context.Background(), stored.ID)
	require.NoError(t, err)
	require.NotNil(t, got)
	require.NotNil(t, got.RateLimitResetAt)
	require.WithinDuration(t, wantReset, *got.RateLimitResetAt, time.Second)

	// The derived rate-limit state must also be written back to the repo.
	select {
	case written := <-listRepo.rateLimitCh:
		require.WithinDuration(t, wantReset, written, time.Second)
	case <-time.After(2 * time.Second):
		t.Fatal("等待旧快照补写限流状态超时")
	}
}
func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount ( t * testing . T ) {
resetAt := time . Now ( ) . Add ( 4 * 24 * time . Hour )
repo := & openAICodexExtraListRepo {
stubOpenAIAccountRepo : stubOpenAIAccountRepo { accounts : [ ] Account { {
ID : 702 ,
Platform : PlatformOpenAI ,
Type : AccountTypeOAuth ,
Status : StatusActive ,
Schedulable : true ,
Concurrency : 1 ,
Extra : map [ string ] any {
"codex_7d_used_percent" : 100.0 ,
"codex_7d_reset_at" : resetAt . UTC ( ) . Format ( time . RFC3339 ) ,
} ,
} } } ,
rateLimitCh : make ( chan time . Time , 1 ) ,
}
svc := & adminServiceImpl { accountRepo : repo }
accounts , total , err := svc . ListAccounts ( context . Background ( ) , 1 , 20 , PlatformOpenAI , AccountTypeOAuth , "" , "" , 0 )
require . NoError ( t , err )
require . Equal ( t , int64 ( 1 ) , total )
require . Len ( t , accounts , 1 )
require . NotNil ( t , accounts [ 0 ] . RateLimitResetAt )
require . WithinDuration ( t , resetAt . UTC ( ) , * accounts [ 0 ] . RateLimitResetAt , time . Second )
select {
case persisted := <- repo . rateLimitCh :
require . WithinDuration ( t , resetAt . UTC ( ) , persisted , time . Second )
case <- time . After ( 2 * time . Second ) :
t . Fatal ( "等待列表补写限流状态超时" )
}
}
2026-03-07 23:59:39 +08:00
func TestOpenAIWSErrorHTTPStatusFromRaw_UsageLimitReachedIs429 ( t * testing . T ) {
require . Equal ( t , http . StatusTooManyRequests , openAIWSErrorHTTPStatusFromRaw ( "" , "usage_limit_reached" ) )
require . Equal ( t , http . StatusTooManyRequests , openAIWSErrorHTTPStatusFromRaw ( "rate_limit_exceeded" , "" ) )
}