2026-01-04 19:49:59 +08:00
package service
import (
"bufio"
2026-01-05 13:54:43 +08:00
"bytes"
2026-01-13 22:49:26 -08:00
"context"
2026-01-04 19:49:59 +08:00
"errors"
2026-02-28 15:01:20 +08:00
"fmt"
2026-01-04 19:49:59 +08:00
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
2026-03-07 14:12:38 +08:00
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
2026-02-28 15:01:20 +08:00
"github.com/cespare/xxhash/v2"
2026-01-04 19:49:59 +08:00
"github.com/gin-gonic/gin"
2026-02-06 22:55:12 +08:00
"github.com/stretchr/testify/require"
2026-01-04 19:49:59 +08:00
)
2026-02-08 12:05:39 +08:00
// 编译期接口断言
var _ AccountRepository = ( * stubOpenAIAccountRepo ) ( nil )
var _ GatewayCache = ( * stubGatewayCache ) ( nil )
2026-01-13 22:49:26 -08:00
type stubOpenAIAccountRepo struct {
AccountRepository
accounts [ ] Account
}
2026-03-06 20:46:10 +08:00
type snapshotUpdateAccountRepo struct {
stubOpenAIAccountRepo
updateExtraCalls chan map [ string ] any
}
func ( r * snapshotUpdateAccountRepo ) UpdateExtra ( ctx context . Context , id int64 , updates map [ string ] any ) error {
if r . updateExtraCalls != nil {
copied := make ( map [ string ] any , len ( updates ) )
for k , v := range updates {
copied [ k ] = v
}
r . updateExtraCalls <- copied
}
return nil
}
2026-01-20 11:19:32 +08:00
func ( r stubOpenAIAccountRepo ) GetByID ( ctx context . Context , id int64 ) ( * Account , error ) {
for i := range r . accounts {
if r . accounts [ i ] . ID == id {
return & r . accounts [ i ] , nil
}
}
return nil , errors . New ( "account not found" )
}
2026-01-13 22:49:26 -08:00
func ( r stubOpenAIAccountRepo ) ListSchedulableByGroupIDAndPlatform ( ctx context . Context , groupID int64 , platform string ) ( [ ] Account , error ) {
2026-01-20 11:19:32 +08:00
var result [ ] Account
for _ , acc := range r . accounts {
if acc . Platform == platform {
result = append ( result , acc )
}
}
return result , nil
2026-01-13 22:49:26 -08:00
}
func ( r stubOpenAIAccountRepo ) ListSchedulableByPlatform ( ctx context . Context , platform string ) ( [ ] Account , error ) {
2026-01-20 11:19:32 +08:00
var result [ ] Account
for _ , acc := range r . accounts {
if acc . Platform == platform {
result = append ( result , acc )
}
}
return result , nil
2026-01-13 22:49:26 -08:00
}
2026-03-03 13:10:26 +08:00
func ( r stubOpenAIAccountRepo ) ListSchedulableUngroupedByPlatform ( ctx context . Context , platform string ) ( [ ] Account , error ) {
return r . ListSchedulableByPlatform ( ctx , platform )
}
2026-01-13 22:49:26 -08:00
type stubConcurrencyCache struct {
ConcurrencyCache
2026-01-20 11:19:32 +08:00
loadBatchErr error
loadMap map [ int64 ] * AccountLoadInfo
acquireResults map [ int64 ] bool
waitCounts map [ int64 ] int
skipDefaultLoad bool
2026-01-13 22:49:26 -08:00
}
2026-01-15 21:42:13 +08:00
type cancelReadCloser struct { }
func ( c cancelReadCloser ) Read ( p [ ] byte ) ( int , error ) { return 0 , context . Canceled }
func ( c cancelReadCloser ) Close ( ) error { return nil }
2026-01-15 21:51:14 +08:00
type failingGinWriter struct {
gin . ResponseWriter
failAfter int
writes int
}
func ( w * failingGinWriter ) Write ( p [ ] byte ) ( int , error ) {
if w . writes >= w . failAfter {
return 0 , errors . New ( "write failed" )
}
w . writes ++
return w . ResponseWriter . Write ( p )
}
2026-01-13 22:49:26 -08:00
func ( c stubConcurrencyCache ) AcquireAccountSlot ( ctx context . Context , accountID int64 , maxConcurrency int , requestID string ) ( bool , error ) {
2026-01-20 11:19:32 +08:00
if c . acquireResults != nil {
if result , ok := c . acquireResults [ accountID ] ; ok {
return result , nil
}
}
2026-01-13 22:49:26 -08:00
return true , nil
}
func ( c stubConcurrencyCache ) ReleaseAccountSlot ( ctx context . Context , accountID int64 , requestID string ) error {
return nil
}
func ( c stubConcurrencyCache ) GetAccountsLoadBatch ( ctx context . Context , accounts [ ] AccountWithConcurrency ) ( map [ int64 ] * AccountLoadInfo , error ) {
2026-01-20 11:19:32 +08:00
if c . loadBatchErr != nil {
return nil , c . loadBatchErr
}
2026-01-13 22:49:26 -08:00
out := make ( map [ int64 ] * AccountLoadInfo , len ( accounts ) )
2026-01-20 11:19:32 +08:00
if c . skipDefaultLoad && c . loadMap != nil {
for _ , acc := range accounts {
if load , ok := c . loadMap [ acc . ID ] ; ok {
out [ acc . ID ] = load
}
}
return out , nil
}
2026-01-13 22:49:26 -08:00
for _ , acc := range accounts {
2026-01-20 11:19:32 +08:00
if c . loadMap != nil {
if load , ok := c . loadMap [ acc . ID ] ; ok {
out [ acc . ID ] = load
continue
}
}
2026-01-13 22:49:26 -08:00
out [ acc . ID ] = & AccountLoadInfo { AccountID : acc . ID , LoadRate : 0 }
}
return out , nil
}
2026-01-17 02:31:16 +08:00
func TestOpenAIGatewayService_GenerateSessionHash_Priority ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/openai/v1/responses" , nil )
svc := & OpenAIGatewayService { }
2026-02-10 08:59:30 +08:00
bodyWithKey := [ ] byte ( ` { "prompt_cache_key":"ses_aaa"} ` )
2026-01-17 02:31:16 +08:00
// 1) session_id header wins
c . Request . Header . Set ( "session_id" , "sess-123" )
c . Request . Header . Set ( "conversation_id" , "conv-456" )
2026-02-10 08:59:30 +08:00
h1 := svc . GenerateSessionHash ( c , bodyWithKey )
2026-01-17 02:31:16 +08:00
if h1 == "" {
t . Fatalf ( "expected non-empty hash" )
}
// 2) conversation_id used when session_id absent
c . Request . Header . Del ( "session_id" )
2026-02-10 08:59:30 +08:00
h2 := svc . GenerateSessionHash ( c , bodyWithKey )
2026-01-17 02:31:16 +08:00
if h2 == "" {
t . Fatalf ( "expected non-empty hash" )
}
if h1 == h2 {
t . Fatalf ( "expected different hashes for different keys" )
}
// 3) prompt_cache_key used when both headers absent
c . Request . Header . Del ( "conversation_id" )
2026-02-10 08:59:30 +08:00
h3 := svc . GenerateSessionHash ( c , bodyWithKey )
2026-01-17 02:31:16 +08:00
if h3 == "" {
t . Fatalf ( "expected non-empty hash" )
}
if h2 == h3 {
t . Fatalf ( "expected different hashes for different keys" )
}
// 4) empty when no signals
2026-02-10 08:59:30 +08:00
h4 := svc . GenerateSessionHash ( c , [ ] byte ( ` { } ` ) )
2026-01-17 02:31:16 +08:00
if h4 != "" {
t . Fatalf ( "expected empty hash when no signals" )
}
}
2026-02-28 15:01:20 +08:00
func TestOpenAIGatewayService_GenerateSessionHash_UsesXXHash64 ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/openai/v1/responses" , nil )
c . Request . Header . Set ( "session_id" , "sess-fixed-value" )
svc := & OpenAIGatewayService { }
got := svc . GenerateSessionHash ( c , nil )
want := fmt . Sprintf ( "%016x" , xxhash . Sum64String ( "sess-fixed-value" ) )
require . Equal ( t , want , got )
}
func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/openai/v1/responses" , nil )
c . Request . Header . Set ( "session_id" , "sess-legacy-check" )
svc := & OpenAIGatewayService { }
sessionHash := svc . GenerateSessionHash ( c , nil )
require . NotEmpty ( t , sessionHash )
require . NotNil ( t , c . Request )
require . NotNil ( t , c . Request . Context ( ) )
require . NotEmpty ( t , openAILegacySessionHashFromContext ( c . Request . Context ( ) ) )
}
func TestOpenAIGatewayService_GenerateSessionHashWithFallback ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/openai/v1/responses" , nil )
svc := & OpenAIGatewayService { }
seed := "openai_ws_ingress:9:100:200"
got := svc . GenerateSessionHashWithFallback ( c , [ ] byte ( ` { } ` ) , seed )
want := fmt . Sprintf ( "%016x" , xxhash . Sum64String ( seed ) )
require . Equal ( t , want , got )
require . NotEmpty ( t , openAILegacySessionHashFromContext ( c . Request . Context ( ) ) )
empty := svc . GenerateSessionHashWithFallback ( c , [ ] byte ( ` { } ` ) , " " )
require . Equal ( t , "" , empty )
}
test(gateway): add tests for content-based session hash fallback
- 20 unit tests for deriveOpenAIContentSessionSeed covering:
- Empty/nil inputs, model-only, stable across turns
- Different model/system/first-user produce different seeds
- Tools, functions, developer role, structured content
- Responses API: input string, input array, instructions, input_text typed items
- JSON canonicalization (whitespace/key-order insensitive)
- Prefix presence, empty tools ignored, messages preferred over input
- 3 integration tests for GenerateSessionHash content fallback:
- Content fallback produces stable hash
- Explicit signals override content fallback
- Empty body still returns empty hash
2026-04-02 00:11:17 +08:00
func TestOpenAIGatewayService_GenerateSessionHash_ContentFallback ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/openai/v1/chat/completions" , nil )
svc := & OpenAIGatewayService { }
body := [ ] byte ( ` { "model":"gpt-5.4","messages":[ { "role":"system","content":"You are helpful."}, { "role":"user","content":"Hello"}]} ` )
hash := svc . GenerateSessionHash ( c , body )
require . NotEmpty ( t , hash , "content-based fallback should produce a hash" )
hash2 := svc . GenerateSessionHash ( c , body )
require . Equal ( t , hash , hash2 , "same content should produce same hash" )
bodyExtended := [ ] byte ( ` { "model":"gpt-5.4","messages":[ { "role":"system","content":"You are helpful."}, { "role":"user","content":"Hello"}, { "role":"assistant","content":"Hi!"}, { "role":"user","content":"How are you?"}]} ` )
hashExtended := svc . GenerateSessionHash ( c , bodyExtended )
require . Equal ( t , hash , hashExtended , "hash should be stable across later turns" )
bodyDifferent := [ ] byte ( ` { "model":"gpt-5.4","messages":[ { "role":"user","content":"Different question"}]} ` )
hashDifferent := svc . GenerateSessionHash ( c , bodyDifferent )
require . NotEqual ( t , hash , hashDifferent , "different content should produce different hash" )
}
func TestOpenAIGatewayService_GenerateSessionHash_ExplicitSignalWinsOverContent ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/openai/v1/chat/completions" , nil )
svc := & OpenAIGatewayService { }
body := [ ] byte ( ` { "model":"gpt-5.4","messages":[ { "role":"user","content":"Hello"}]} ` )
contentHash := svc . GenerateSessionHash ( c , body )
require . NotEmpty ( t , contentHash )
c . Request . Header . Set ( "session_id" , "explicit-session" )
explicitHash := svc . GenerateSessionHash ( c , body )
require . NotEmpty ( t , explicitHash )
require . NotEqual ( t , contentHash , explicitHash , "explicit session_id should override content fallback" )
}
func TestOpenAIGatewayService_GenerateSessionHash_EmptyBodyStillEmpty ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/openai/v1/chat/completions" , nil )
svc := & OpenAIGatewayService { }
require . Empty ( t , svc . GenerateSessionHash ( c , [ ] byte ( ` { } ` ) ) )
require . Empty ( t , svc . GenerateSessionHash ( c , nil ) )
}
2026-01-20 11:19:32 +08:00
func ( c stubConcurrencyCache ) GetAccountWaitingCount ( ctx context . Context , accountID int64 ) ( int , error ) {
if c . waitCounts != nil {
if count , ok := c . waitCounts [ accountID ] ; ok {
return count , nil
}
}
return 0 , nil
}
type stubGatewayCache struct {
sessionBindings map [ string ] int64
deletedSessions map [ string ] int
}
func ( c * stubGatewayCache ) GetSessionAccountID ( ctx context . Context , groupID int64 , sessionHash string ) ( int64 , error ) {
if id , ok := c . sessionBindings [ sessionHash ] ; ok {
return id , nil
}
return 0 , errors . New ( "not found" )
}
func ( c * stubGatewayCache ) SetSessionAccountID ( ctx context . Context , groupID int64 , sessionHash string , accountID int64 , ttl time . Duration ) error {
if c . sessionBindings == nil {
c . sessionBindings = make ( map [ string ] int64 )
}
c . sessionBindings [ sessionHash ] = accountID
return nil
}
func ( c * stubGatewayCache ) RefreshSessionTTL ( ctx context . Context , groupID int64 , sessionHash string , ttl time . Duration ) error {
return nil
}
func ( c * stubGatewayCache ) DeleteSessionAccountID ( ctx context . Context , groupID int64 , sessionHash string ) error {
if c . sessionBindings == nil {
return nil
}
if c . deletedSessions == nil {
c . deletedSessions = make ( map [ string ] int )
}
c . deletedSessions [ sessionHash ] ++
delete ( c . sessionBindings , sessionHash )
return nil
}
2026-01-13 22:49:26 -08:00
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable ( t * testing . T ) {
now := time . Now ( )
resetAt := now . Add ( 10 * time . Minute )
groupID := int64 ( 1 )
rateLimited := Account {
ID : 1 ,
Platform : PlatformOpenAI ,
Type : AccountTypeAPIKey ,
Status : StatusActive ,
Schedulable : true ,
Concurrency : 1 ,
Priority : 0 ,
RateLimitResetAt : & resetAt ,
}
available := Account {
ID : 2 ,
Platform : PlatformOpenAI ,
Type : AccountTypeAPIKey ,
Status : StatusActive ,
Schedulable : true ,
Concurrency : 1 ,
Priority : 1 ,
}
svc := & OpenAIGatewayService {
accountRepo : stubOpenAIAccountRepo { accounts : [ ] Account { rateLimited , available } } ,
concurrencyService : NewConcurrencyService ( stubConcurrencyCache { } ) ,
}
selection , err := svc . SelectAccountWithLoadAwareness ( context . Background ( ) , & groupID , "" , "gpt-5.2" , nil )
if err != nil {
t . Fatalf ( "SelectAccountWithLoadAwareness error: %v" , err )
}
if selection == nil || selection . Account == nil {
t . Fatalf ( "expected selection with account" )
}
if selection . Account . ID != available . ID {
t . Fatalf ( "expected account %d, got %d" , available . ID , selection . Account . ID )
}
if selection . ReleaseFunc != nil {
selection . ReleaseFunc ( )
}
}
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurrencyService ( t * testing . T ) {
now := time . Now ( )
resetAt := now . Add ( 10 * time . Minute )
groupID := int64 ( 1 )
rateLimited := Account {
ID : 1 ,
Platform : PlatformOpenAI ,
Type : AccountTypeAPIKey ,
Status : StatusActive ,
Schedulable : true ,
Concurrency : 1 ,
Priority : 0 ,
RateLimitResetAt : & resetAt ,
}
available := Account {
ID : 2 ,
Platform : PlatformOpenAI ,
Type : AccountTypeAPIKey ,
Status : StatusActive ,
Schedulable : true ,
Concurrency : 1 ,
Priority : 1 ,
}
svc := & OpenAIGatewayService {
accountRepo : stubOpenAIAccountRepo { accounts : [ ] Account { rateLimited , available } } ,
// concurrencyService is nil, forcing the non-load-batch selection path.
}
selection , err := svc . SelectAccountWithLoadAwareness ( context . Background ( ) , & groupID , "" , "gpt-5.2" , nil )
if err != nil {
t . Fatalf ( "SelectAccountWithLoadAwareness error: %v" , err )
}
if selection == nil || selection . Account == nil {
t . Fatalf ( "expected selection with account" )
}
if selection . Account . ID != available . ID {
t . Fatalf ( "expected account %d, got %d" , available . ID , selection . Account . ID )
}
if selection . ReleaseFunc != nil {
selection . ReleaseFunc ( )
}
}
2026-01-20 11:19:32 +08:00
func TestOpenAISelectAccountForModelWithExclusions_StickyUnschedulableClearsSession ( t * testing . T ) {
sessionHash := "session-1"
repo := stubOpenAIAccountRepo {
accounts : [ ] Account {
{ ID : 1 , Platform : PlatformOpenAI , Status : StatusDisabled , Schedulable : true , Concurrency : 1 } ,
{ ID : 2 , Platform : PlatformOpenAI , Status : StatusActive , Schedulable : true , Concurrency : 1 } ,
} ,
}
cache := & stubGatewayCache {
sessionBindings : map [ string ] int64 { "openai:" + sessionHash : 1 } ,
}
svc := & OpenAIGatewayService {
accountRepo : repo ,
cache : cache ,
}
acc , err := svc . SelectAccountForModelWithExclusions ( context . Background ( ) , nil , sessionHash , "gpt-4" , nil )
if err != nil {
t . Fatalf ( "SelectAccountForModelWithExclusions error: %v" , err )
}
if acc == nil || acc . ID != 2 {
t . Fatalf ( "expected account 2, got %+v" , acc )
}
if cache . deletedSessions [ "openai:" + sessionHash ] != 1 {
t . Fatalf ( "expected sticky session to be deleted" )
}
if cache . sessionBindings [ "openai:" + sessionHash ] != 2 {
t . Fatalf ( "expected sticky session to bind to account 2" )
}
}
func TestOpenAISelectAccountWithLoadAwareness_StickyUnschedulableClearsSession ( t * testing . T ) {
sessionHash := "session-2"
groupID := int64 ( 1 )
repo := stubOpenAIAccountRepo {
accounts : [ ] Account {
{ ID : 1 , Platform : PlatformOpenAI , Status : StatusDisabled , Schedulable : true , Concurrency : 1 } ,
{ ID : 2 , Platform : PlatformOpenAI , Status : StatusActive , Schedulable : true , Concurrency : 1 } ,
} ,
}
cache := & stubGatewayCache {
sessionBindings : map [ string ] int64 { "openai:" + sessionHash : 1 } ,
}
svc := & OpenAIGatewayService {
accountRepo : repo ,
cache : cache ,
concurrencyService : NewConcurrencyService ( stubConcurrencyCache { } ) ,
}
selection , err := svc . SelectAccountWithLoadAwareness ( context . Background ( ) , & groupID , sessionHash , "gpt-4" , nil )
if err != nil {
t . Fatalf ( "SelectAccountWithLoadAwareness error: %v" , err )
}
if selection == nil || selection . Account == nil || selection . Account . ID != 2 {
t . Fatalf ( "expected account 2, got %+v" , selection )
}
if cache . deletedSessions [ "openai:" + sessionHash ] != 1 {
t . Fatalf ( "expected sticky session to be deleted" )
}
if cache . sessionBindings [ "openai:" + sessionHash ] != 2 {
t . Fatalf ( "expected sticky session to bind to account 2" )
}
if selection . ReleaseFunc != nil {
selection . ReleaseFunc ( )
}
}
func TestOpenAISelectAccountForModelWithExclusions_NoModelSupport ( t * testing . T ) {
repo := stubOpenAIAccountRepo {
accounts : [ ] Account {
{
ID : 1 ,
Platform : PlatformOpenAI ,
Status : StatusActive ,
Schedulable : true ,
Credentials : map [ string ] any { "model_mapping" : map [ string ] any { "gpt-3.5-turbo" : "gpt-3.5-turbo" } } ,
} ,
} ,
}
cache := & stubGatewayCache { }
svc := & OpenAIGatewayService {
accountRepo : repo ,
cache : cache ,
}
acc , err := svc . SelectAccountForModelWithExclusions ( context . Background ( ) , nil , "" , "gpt-4" , nil )
if err == nil {
t . Fatalf ( "expected error for unsupported model" )
}
if acc != nil {
t . Fatalf ( "expected nil account for unsupported model" )
}
if ! strings . Contains ( err . Error ( ) , "supporting model" ) {
t . Fatalf ( "unexpected error: %v" , err )
}
}
func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorFallback ( t * testing . T ) {
groupID := int64 ( 1 )
repo := stubOpenAIAccountRepo {
accounts : [ ] Account {
{ ID : 1 , Platform : PlatformOpenAI , Status : StatusActive , Schedulable : true , Concurrency : 1 , Priority : 2 } ,
{ ID : 2 , Platform : PlatformOpenAI , Status : StatusActive , Schedulable : true , Concurrency : 1 , Priority : 1 } ,
} ,
}
cache := & stubGatewayCache { }
concurrencyCache := stubConcurrencyCache {
loadBatchErr : errors . New ( "load batch failed" ) ,
}
svc := & OpenAIGatewayService {
accountRepo : repo ,
cache : cache ,
concurrencyService : NewConcurrencyService ( concurrencyCache ) ,
}
selection , err := svc . SelectAccountWithLoadAwareness ( context . Background ( ) , & groupID , "fallback" , "gpt-4" , nil )
if err != nil {
t . Fatalf ( "SelectAccountWithLoadAwareness error: %v" , err )
}
if selection == nil || selection . Account == nil {
t . Fatalf ( "expected selection" )
}
if selection . Account . ID != 2 {
t . Fatalf ( "expected account 2, got %d" , selection . Account . ID )
}
if cache . sessionBindings [ "openai:fallback" ] != 2 {
t . Fatalf ( "expected sticky session updated" )
}
if selection . ReleaseFunc != nil {
selection . ReleaseFunc ( )
}
}
func TestOpenAISelectAccountWithLoadAwareness_NoSlotFallbackWait ( t * testing . T ) {
groupID := int64 ( 1 )
repo := stubOpenAIAccountRepo {
accounts : [ ] Account {
{ ID : 1 , Platform : PlatformOpenAI , Status : StatusActive , Schedulable : true , Concurrency : 1 , Priority : 1 } ,
} ,
}
cache := & stubGatewayCache { }
concurrencyCache := stubConcurrencyCache {
acquireResults : map [ int64 ] bool { 1 : false } ,
loadMap : map [ int64 ] * AccountLoadInfo {
1 : { AccountID : 1 , LoadRate : 10 } ,
} ,
}
svc := & OpenAIGatewayService {
accountRepo : repo ,
cache : cache ,
concurrencyService : NewConcurrencyService ( concurrencyCache ) ,
}
selection , err := svc . SelectAccountWithLoadAwareness ( context . Background ( ) , & groupID , "" , "gpt-4" , nil )
if err != nil {
t . Fatalf ( "SelectAccountWithLoadAwareness error: %v" , err )
}
if selection == nil || selection . WaitPlan == nil {
t . Fatalf ( "expected wait plan fallback" )
}
if selection . Account == nil || selection . Account . ID != 1 {
t . Fatalf ( "expected account 1" )
}
}
func TestOpenAISelectAccountForModelWithExclusions_SetsStickyBinding ( t * testing . T ) {
sessionHash := "bind"
repo := stubOpenAIAccountRepo {
accounts : [ ] Account {
{ ID : 1 , Platform : PlatformOpenAI , Status : StatusActive , Schedulable : true , Concurrency : 1 , Priority : 1 } ,
} ,
}
cache := & stubGatewayCache { }
svc := & OpenAIGatewayService {
accountRepo : repo ,
cache : cache ,
}
acc , err := svc . SelectAccountForModelWithExclusions ( context . Background ( ) , nil , sessionHash , "gpt-4" , nil )
if err != nil {
t . Fatalf ( "SelectAccountForModelWithExclusions error: %v" , err )
}
if acc == nil || acc . ID != 1 {
t . Fatalf ( "expected account 1" )
}
if cache . sessionBindings [ "openai:" + sessionHash ] != 1 {
t . Fatalf ( "expected sticky session binding" )
}
}
func TestOpenAISelectAccountWithLoadAwareness_StickyWaitPlan ( t * testing . T ) {
sessionHash := "sticky-wait"
groupID := int64 ( 1 )
repo := stubOpenAIAccountRepo {
accounts : [ ] Account {
{ ID : 1 , Platform : PlatformOpenAI , Status : StatusActive , Schedulable : true , Concurrency : 1 , Priority : 1 } ,
} ,
}
cache := & stubGatewayCache {
sessionBindings : map [ string ] int64 { "openai:" + sessionHash : 1 } ,
}
concurrencyCache := stubConcurrencyCache {
acquireResults : map [ int64 ] bool { 1 : false } ,
waitCounts : map [ int64 ] int { 1 : 0 } ,
}
svc := & OpenAIGatewayService {
accountRepo : repo ,
cache : cache ,
concurrencyService : NewConcurrencyService ( concurrencyCache ) ,
}
selection , err := svc . SelectAccountWithLoadAwareness ( context . Background ( ) , & groupID , sessionHash , "gpt-4" , nil )
if err != nil {
t . Fatalf ( "SelectAccountWithLoadAwareness error: %v" , err )
}
if selection == nil || selection . WaitPlan == nil {
t . Fatalf ( "expected sticky wait plan" )
}
if selection . Account == nil || selection . Account . ID != 1 {
t . Fatalf ( "expected account 1" )
}
}
func TestOpenAISelectAccountWithLoadAwareness_PrefersLowerLoad ( t * testing . T ) {
groupID := int64 ( 1 )
repo := stubOpenAIAccountRepo {
accounts : [ ] Account {
{ ID : 1 , Platform : PlatformOpenAI , Status : StatusActive , Schedulable : true , Concurrency : 1 , Priority : 1 } ,
{ ID : 2 , Platform : PlatformOpenAI , Status : StatusActive , Schedulable : true , Concurrency : 1 , Priority : 1 } ,
} ,
}
cache := & stubGatewayCache { }
concurrencyCache := stubConcurrencyCache {
loadMap : map [ int64 ] * AccountLoadInfo {
1 : { AccountID : 1 , LoadRate : 80 } ,
2 : { AccountID : 2 , LoadRate : 10 } ,
} ,
}
svc := & OpenAIGatewayService {
accountRepo : repo ,
cache : cache ,
concurrencyService : NewConcurrencyService ( concurrencyCache ) ,
}
selection , err := svc . SelectAccountWithLoadAwareness ( context . Background ( ) , & groupID , "load" , "gpt-4" , nil )
if err != nil {
t . Fatalf ( "SelectAccountWithLoadAwareness error: %v" , err )
}
if selection == nil || selection . Account == nil || selection . Account . ID != 2 {
t . Fatalf ( "expected account 2" )
}
if cache . sessionBindings [ "openai:load" ] != 2 {
t . Fatalf ( "expected sticky session updated" )
}
}
func TestOpenAISelectAccountForModelWithExclusions_StickyExcludedFallback ( t * testing . T ) {
sessionHash := "excluded"
repo := stubOpenAIAccountRepo {
accounts : [ ] Account {
{ ID : 1 , Platform : PlatformOpenAI , Status : StatusActive , Schedulable : true , Concurrency : 1 , Priority : 1 } ,
{ ID : 2 , Platform : PlatformOpenAI , Status : StatusActive , Schedulable : true , Concurrency : 1 , Priority : 2 } ,
} ,
}
cache := & stubGatewayCache {
sessionBindings : map [ string ] int64 { "openai:" + sessionHash : 1 } ,
}
svc := & OpenAIGatewayService {
accountRepo : repo ,
cache : cache ,
}
excluded := map [ int64 ] struct { } { 1 : { } }
acc , err := svc . SelectAccountForModelWithExclusions ( context . Background ( ) , nil , sessionHash , "gpt-4" , excluded )
if err != nil {
t . Fatalf ( "SelectAccountForModelWithExclusions error: %v" , err )
}
if acc == nil || acc . ID != 2 {
t . Fatalf ( "expected account 2" )
}
}
func TestOpenAISelectAccountForModelWithExclusions_StickyNonOpenAI ( t * testing . T ) {
sessionHash := "non-openai"
repo := stubOpenAIAccountRepo {
accounts : [ ] Account {
{ ID : 1 , Platform : PlatformAnthropic , Status : StatusActive , Schedulable : true , Concurrency : 1 , Priority : 1 } ,
{ ID : 2 , Platform : PlatformOpenAI , Status : StatusActive , Schedulable : true , Concurrency : 1 , Priority : 2 } ,
} ,
}
cache := & stubGatewayCache {
sessionBindings : map [ string ] int64 { "openai:" + sessionHash : 1 } ,
}
svc := & OpenAIGatewayService {
accountRepo : repo ,
cache : cache ,
}
acc , err := svc . SelectAccountForModelWithExclusions ( context . Background ( ) , nil , sessionHash , "gpt-4" , nil )
if err != nil {
t . Fatalf ( "SelectAccountForModelWithExclusions error: %v" , err )
}
if acc == nil || acc . ID != 2 {
t . Fatalf ( "expected account 2" )
}
}
func TestOpenAISelectAccountForModelWithExclusions_NoAccounts ( t * testing . T ) {
repo := stubOpenAIAccountRepo { accounts : [ ] Account { } }
cache := & stubGatewayCache { }
svc := & OpenAIGatewayService {
accountRepo : repo ,
cache : cache ,
}
acc , err := svc . SelectAccountForModelWithExclusions ( context . Background ( ) , nil , "" , "" , nil )
if err == nil {
t . Fatalf ( "expected error for no accounts" )
}
if acc != nil {
t . Fatalf ( "expected nil account" )
}
if ! strings . Contains ( err . Error ( ) , "no available OpenAI accounts" ) {
t . Fatalf ( "unexpected error: %v" , err )
}
}
func TestOpenAISelectAccountWithLoadAwareness_NoCandidates ( t * testing . T ) {
groupID := int64 ( 1 )
resetAt := time . Now ( ) . Add ( 1 * time . Hour )
repo := stubOpenAIAccountRepo {
accounts : [ ] Account {
{ ID : 1 , Platform : PlatformOpenAI , Status : StatusActive , Schedulable : true , Concurrency : 1 , Priority : 1 , RateLimitResetAt : & resetAt } ,
} ,
}
cache := & stubGatewayCache { }
concurrencyCache := stubConcurrencyCache { }
svc := & OpenAIGatewayService {
accountRepo : repo ,
cache : cache ,
concurrencyService : NewConcurrencyService ( concurrencyCache ) ,
}
selection , err := svc . SelectAccountWithLoadAwareness ( context . Background ( ) , & groupID , "" , "gpt-4" , nil )
if err == nil {
t . Fatalf ( "expected error for no candidates" )
}
if selection != nil {
t . Fatalf ( "expected nil selection" )
}
}
func TestOpenAISelectAccountWithLoadAwareness_AllFullWaitPlan ( t * testing . T ) {
groupID := int64 ( 1 )
repo := stubOpenAIAccountRepo {
accounts : [ ] Account {
{ ID : 1 , Platform : PlatformOpenAI , Status : StatusActive , Schedulable : true , Concurrency : 1 , Priority : 1 } ,
} ,
}
cache := & stubGatewayCache { }
concurrencyCache := stubConcurrencyCache {
loadMap : map [ int64 ] * AccountLoadInfo {
1 : { AccountID : 1 , LoadRate : 100 } ,
} ,
}
svc := & OpenAIGatewayService {
accountRepo : repo ,
cache : cache ,
concurrencyService : NewConcurrencyService ( concurrencyCache ) ,
}
selection , err := svc . SelectAccountWithLoadAwareness ( context . Background ( ) , & groupID , "" , "gpt-4" , nil )
if err != nil {
t . Fatalf ( "SelectAccountWithLoadAwareness error: %v" , err )
}
if selection == nil || selection . WaitPlan == nil {
t . Fatalf ( "expected wait plan" )
}
}
func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorNoAcquire ( t * testing . T ) {
groupID := int64 ( 1 )
repo := stubOpenAIAccountRepo {
accounts : [ ] Account {
{ ID : 1 , Platform : PlatformOpenAI , Status : StatusActive , Schedulable : true , Concurrency : 1 , Priority : 1 } ,
} ,
}
cache := & stubGatewayCache { }
concurrencyCache := stubConcurrencyCache {
loadBatchErr : errors . New ( "load batch failed" ) ,
acquireResults : map [ int64 ] bool { 1 : false } ,
}
svc := & OpenAIGatewayService {
accountRepo : repo ,
cache : cache ,
concurrencyService : NewConcurrencyService ( concurrencyCache ) ,
}
selection , err := svc . SelectAccountWithLoadAwareness ( context . Background ( ) , & groupID , "" , "gpt-4" , nil )
if err != nil {
t . Fatalf ( "SelectAccountWithLoadAwareness error: %v" , err )
}
if selection == nil || selection . WaitPlan == nil {
t . Fatalf ( "expected wait plan" )
}
}
func TestOpenAISelectAccountWithLoadAwareness_MissingLoadInfo ( t * testing . T ) {
groupID := int64 ( 1 )
repo := stubOpenAIAccountRepo {
accounts : [ ] Account {
{ ID : 1 , Platform : PlatformOpenAI , Status : StatusActive , Schedulable : true , Concurrency : 1 , Priority : 1 } ,
{ ID : 2 , Platform : PlatformOpenAI , Status : StatusActive , Schedulable : true , Concurrency : 1 , Priority : 1 } ,
} ,
}
cache := & stubGatewayCache { }
concurrencyCache := stubConcurrencyCache {
loadMap : map [ int64 ] * AccountLoadInfo {
1 : { AccountID : 1 , LoadRate : 50 } ,
} ,
skipDefaultLoad : true ,
}
svc := & OpenAIGatewayService {
accountRepo : repo ,
cache : cache ,
concurrencyService : NewConcurrencyService ( concurrencyCache ) ,
}
selection , err := svc . SelectAccountWithLoadAwareness ( context . Background ( ) , & groupID , "" , "gpt-4" , nil )
if err != nil {
t . Fatalf ( "SelectAccountWithLoadAwareness error: %v" , err )
}
if selection == nil || selection . Account == nil || selection . Account . ID != 2 {
t . Fatalf ( "expected account 2" )
}
}
func TestOpenAISelectAccountForModelWithExclusions_LeastRecentlyUsed ( t * testing . T ) {
oldTime := time . Now ( ) . Add ( - 2 * time . Hour )
newTime := time . Now ( ) . Add ( - 1 * time . Hour )
repo := stubOpenAIAccountRepo {
accounts : [ ] Account {
{ ID : 1 , Platform : PlatformOpenAI , Status : StatusActive , Schedulable : true , Priority : 1 , LastUsedAt : & newTime } ,
{ ID : 2 , Platform : PlatformOpenAI , Status : StatusActive , Schedulable : true , Priority : 1 , LastUsedAt : & oldTime } ,
} ,
}
cache := & stubGatewayCache { }
svc := & OpenAIGatewayService {
accountRepo : repo ,
cache : cache ,
}
acc , err := svc . SelectAccountForModelWithExclusions ( context . Background ( ) , nil , "" , "gpt-4" , nil )
if err != nil {
t . Fatalf ( "SelectAccountForModelWithExclusions error: %v" , err )
}
if acc == nil || acc . ID != 2 {
t . Fatalf ( "expected account 2" )
}
}
func TestOpenAISelectAccountWithLoadAwareness_PreferNeverUsed ( t * testing . T ) {
groupID := int64 ( 1 )
lastUsed := time . Now ( ) . Add ( - 1 * time . Hour )
repo := stubOpenAIAccountRepo {
accounts : [ ] Account {
{ ID : 1 , Platform : PlatformOpenAI , Status : StatusActive , Schedulable : true , Concurrency : 1 , Priority : 1 , LastUsedAt : & lastUsed } ,
{ ID : 2 , Platform : PlatformOpenAI , Status : StatusActive , Schedulable : true , Concurrency : 1 , Priority : 1 } ,
} ,
}
cache := & stubGatewayCache { }
concurrencyCache := stubConcurrencyCache {
loadMap : map [ int64 ] * AccountLoadInfo {
1 : { AccountID : 1 , LoadRate : 10 } ,
2 : { AccountID : 2 , LoadRate : 10 } ,
} ,
}
svc := & OpenAIGatewayService {
accountRepo : repo ,
cache : cache ,
concurrencyService : NewConcurrencyService ( concurrencyCache ) ,
}
selection , err := svc . SelectAccountWithLoadAwareness ( context . Background ( ) , & groupID , "" , "gpt-4" , nil )
if err != nil {
t . Fatalf ( "SelectAccountWithLoadAwareness error: %v" , err )
}
if selection == nil || selection . Account == nil || selection . Account . ID != 2 {
t . Fatalf ( "expected account 2" )
}
}
2026-01-04 19:49:59 +08:00
func TestOpenAIStreamingTimeout ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
cfg := & config . Config {
Gateway : config . GatewayConfig {
StreamDataIntervalTimeout : 1 ,
StreamKeepaliveInterval : 0 ,
MaxLineSize : defaultMaxLineSize ,
} ,
}
svc := & OpenAIGatewayService { cfg : cfg }
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/" , nil )
pr , pw := io . Pipe ( )
resp := & http . Response {
StatusCode : http . StatusOK ,
Body : pr ,
Header : http . Header { } ,
}
start := time . Now ( )
_ , err := svc . handleStreamingResponse ( c . Request . Context ( ) , resp , c , & Account { ID : 1 } , start , "model" , "model" )
_ = pw . Close ( )
_ = pr . Close ( )
if err == nil || ! strings . Contains ( err . Error ( ) , "stream data interval timeout" ) {
t . Fatalf ( "expected stream timeout error, got %v" , err )
}
2026-01-19 13:53:39 +08:00
if ! strings . Contains ( rec . Body . String ( ) , "\"type\":\"error\"" ) || ! strings . Contains ( rec . Body . String ( ) , "stream_timeout" ) {
t . Fatalf ( "expected OpenAI-compatible error SSE event, got %q" , rec . Body . String ( ) )
2026-01-04 19:49:59 +08:00
}
}
2026-03-12 16:53:18 +08:00
func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErrorEvent ( t * testing . T ) {
2026-01-15 21:42:13 +08:00
gin . SetMode ( gin . TestMode )
cfg := & config . Config {
Gateway : config . GatewayConfig {
StreamDataIntervalTimeout : 0 ,
StreamKeepaliveInterval : 0 ,
MaxLineSize : defaultMaxLineSize ,
} ,
}
svc := & OpenAIGatewayService { cfg : cfg }
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
ctx , cancel := context . WithCancel ( context . Background ( ) )
cancel ( )
c . Request = httptest . NewRequest ( http . MethodPost , "/" , nil ) . WithContext ( ctx )
resp := & http . Response {
StatusCode : http . StatusOK ,
Body : cancelReadCloser { } ,
Header : http . Header { } ,
}
_ , err := svc . handleStreamingResponse ( c . Request . Context ( ) , resp , c , & Account { ID : 1 } , time . Now ( ) , "model" , "model" )
2026-03-12 16:53:18 +08:00
if err == nil || ! strings . Contains ( err . Error ( ) , "stream usage incomplete" ) {
t . Fatalf ( "expected incomplete stream error, got %v" , err )
2026-01-15 21:42:13 +08:00
}
if strings . Contains ( rec . Body . String ( ) , "event: error" ) || strings . Contains ( rec . Body . String ( ) , "stream_read_error" ) {
t . Fatalf ( "expected no injected SSE error event, got %q" , rec . Body . String ( ) )
}
}
2026-01-15 21:51:14 +08:00
func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
cfg := & config . Config {
Gateway : config . GatewayConfig {
StreamDataIntervalTimeout : 0 ,
StreamKeepaliveInterval : 0 ,
MaxLineSize : defaultMaxLineSize ,
} ,
}
svc := & OpenAIGatewayService { cfg : cfg }
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/" , nil )
c . Writer = & failingGinWriter { ResponseWriter : c . Writer , failAfter : 0 }
pr , pw := io . Pipe ( )
resp := & http . Response {
StatusCode : http . StatusOK ,
Body : pr ,
Header : http . Header { } ,
}
go func ( ) {
defer func ( ) { _ = pw . Close ( ) } ( )
_ , _ = pw . Write ( [ ] byte ( "data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n" ) )
_ , _ = pw . Write ( [ ] byte ( "data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":3,\"output_tokens\":5,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n" ) )
} ( )
result , err := svc . handleStreamingResponse ( c . Request . Context ( ) , resp , c , & Account { ID : 1 } , time . Now ( ) , "model" , "model" )
_ = pr . Close ( )
if err != nil {
t . Fatalf ( "expected nil error, got %v" , err )
}
if result == nil || result . usage == nil {
t . Fatalf ( "expected usage result" )
}
if result . usage . InputTokens != 3 || result . usage . OutputTokens != 5 || result . usage . CacheReadInputTokens != 1 {
t . Fatalf ( "unexpected usage: %+v" , * result . usage )
}
if strings . Contains ( rec . Body . String ( ) , "event: error" ) || strings . Contains ( rec . Body . String ( ) , "write_failed" ) {
t . Fatalf ( "expected no injected SSE error event, got %q" , rec . Body . String ( ) )
}
}
2026-03-12 16:53:18 +08:00
func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
cfg := & config . Config {
Gateway : config . GatewayConfig {
StreamDataIntervalTimeout : 0 ,
StreamKeepaliveInterval : 0 ,
MaxLineSize : defaultMaxLineSize ,
} ,
}
svc := & OpenAIGatewayService { cfg : cfg }
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/" , nil )
pr , pw := io . Pipe ( )
resp := & http . Response {
StatusCode : http . StatusOK ,
Body : pr ,
Header : http . Header { } ,
}
go func ( ) {
defer func ( ) { _ = pw . Close ( ) } ( )
_ , _ = pw . Write ( [ ] byte ( "data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n" ) )
} ( )
_ , err := svc . handleStreamingResponse ( c . Request . Context ( ) , resp , c , & Account { ID : 1 } , time . Now ( ) , "model" , "model" )
_ = pr . Close ( )
if err == nil || ! strings . Contains ( err . Error ( ) , "missing terminal event" ) {
t . Fatalf ( "expected missing terminal event error, got %v" , err )
}
}
func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
cfg := & config . Config {
Gateway : config . GatewayConfig {
MaxLineSize : defaultMaxLineSize ,
} ,
}
svc := & OpenAIGatewayService { cfg : cfg }
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/" , nil )
pr , pw := io . Pipe ( )
resp := & http . Response {
StatusCode : http . StatusOK ,
Body : pr ,
Header : http . Header { } ,
}
go func ( ) {
defer func ( ) { _ = pw . Close ( ) } ( )
_ , _ = pw . Write ( [ ] byte ( "data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n" ) )
} ( )
_ , err := svc . handleStreamingResponsePassthrough ( c . Request . Context ( ) , resp , c , & Account { ID : 1 } , time . Now ( ) )
_ = pr . Close ( )
if err == nil || ! strings . Contains ( err . Error ( ) , "missing terminal event" ) {
t . Fatalf ( "expected missing terminal event error, got %v" , err )
}
}
func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
cfg := & config . Config {
Gateway : config . GatewayConfig {
MaxLineSize : defaultMaxLineSize ,
} ,
}
svc := & OpenAIGatewayService { cfg : cfg }
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/" , nil )
pr , pw := io . Pipe ( )
resp := & http . Response {
StatusCode : http . StatusOK ,
Body : pr ,
Header : http . Header { } ,
}
go func ( ) {
defer func ( ) { _ = pw . Close ( ) } ( )
_ , _ = pw . Write ( [ ] byte ( "data: {\"type\":\"response.done\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n" ) )
} ( )
result , err := svc . handleStreamingResponsePassthrough ( c . Request . Context ( ) , resp , c , & Account { ID : 1 } , time . Now ( ) )
_ = pr . Close ( )
require . NoError ( t , err )
require . NotNil ( t , result )
require . NotNil ( t , result . usage )
require . Equal ( t , 2 , result . usage . InputTokens )
require . Equal ( t , 3 , result . usage . OutputTokens )
require . Equal ( t , 1 , result . usage . CacheReadInputTokens )
}
2026-01-04 19:49:59 +08:00
func TestOpenAIStreamingTooLong ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
cfg := & config . Config {
Gateway : config . GatewayConfig {
StreamDataIntervalTimeout : 0 ,
StreamKeepaliveInterval : 0 ,
MaxLineSize : 64 * 1024 ,
} ,
}
svc := & OpenAIGatewayService { cfg : cfg }
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/" , nil )
pr , pw := io . Pipe ( )
resp := & http . Response {
StatusCode : http . StatusOK ,
Body : pr ,
Header : http . Header { } ,
}
go func ( ) {
2026-01-04 22:10:32 +08:00
defer func ( ) { _ = pw . Close ( ) } ( )
2026-01-04 19:49:59 +08:00
// 写入超过 MaxLineSize 的单行数据,触发 ErrTooLong
payload := "data: " + strings . Repeat ( "a" , 128 * 1024 ) + "\n"
_ , _ = pw . Write ( [ ] byte ( payload ) )
} ( )
_ , err := svc . handleStreamingResponse ( c . Request . Context ( ) , resp , c , & Account { ID : 2 } , time . Now ( ) , "model" , "model" )
_ = pr . Close ( )
if ! errors . Is ( err , bufio . ErrTooLong ) {
t . Fatalf ( "expected ErrTooLong, got %v" , err )
}
2026-01-19 13:53:39 +08:00
if ! strings . Contains ( rec . Body . String ( ) , "\"type\":\"error\"" ) || ! strings . Contains ( rec . Body . String ( ) , "response_too_large" ) {
t . Fatalf ( "expected OpenAI-compatible error SSE event, got %q" , rec . Body . String ( ) )
2026-01-04 19:49:59 +08:00
}
}
2026-01-05 13:54:43 +08:00
func TestOpenAINonStreamingContentTypePassThrough ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
cfg := & config . Config {
Security : config . SecurityConfig {
ResponseHeaders : config . ResponseHeaderConfig { Enabled : false } ,
} ,
}
svc := & OpenAIGatewayService { cfg : cfg }
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/" , nil )
body := [ ] byte ( ` { "usage": { "input_tokens":1,"output_tokens":2,"input_tokens_details": { "cached_tokens":0}}} ` )
resp := & http . Response {
StatusCode : http . StatusOK ,
Body : io . NopCloser ( bytes . NewReader ( body ) ) ,
Header : http . Header { "Content-Type" : [ ] string { "application/vnd.test+json" } } ,
}
_ , err := svc . handleNonStreamingResponse ( c . Request . Context ( ) , resp , c , & Account { } , "model" , "model" )
if err != nil {
t . Fatalf ( "handleNonStreamingResponse error: %v" , err )
}
if ! strings . Contains ( rec . Header ( ) . Get ( "Content-Type" ) , "application/vnd.test+json" ) {
t . Fatalf ( "expected Content-Type passthrough, got %q" , rec . Header ( ) . Get ( "Content-Type" ) )
}
}
func TestOpenAINonStreamingContentTypeDefault ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
cfg := & config . Config {
Security : config . SecurityConfig {
ResponseHeaders : config . ResponseHeaderConfig { Enabled : false } ,
} ,
}
svc := & OpenAIGatewayService { cfg : cfg }
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/" , nil )
body := [ ] byte ( ` { "usage": { "input_tokens":1,"output_tokens":2,"input_tokens_details": { "cached_tokens":0}}} ` )
resp := & http . Response {
StatusCode : http . StatusOK ,
Body : io . NopCloser ( bytes . NewReader ( body ) ) ,
Header : http . Header { } ,
}
_ , err := svc . handleNonStreamingResponse ( c . Request . Context ( ) , resp , c , & Account { } , "model" , "model" )
if err != nil {
t . Fatalf ( "handleNonStreamingResponse error: %v" , err )
}
if ! strings . Contains ( rec . Header ( ) . Get ( "Content-Type" ) , "application/json" ) {
t . Fatalf ( "expected default Content-Type, got %q" , rec . Header ( ) . Get ( "Content-Type" ) )
}
}
func TestOpenAIStreamingHeadersOverride ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
cfg := & config . Config {
Security : config . SecurityConfig {
ResponseHeaders : config . ResponseHeaderConfig { Enabled : false } ,
} ,
Gateway : config . GatewayConfig {
StreamDataIntervalTimeout : 0 ,
StreamKeepaliveInterval : 0 ,
MaxLineSize : defaultMaxLineSize ,
} ,
}
svc := & OpenAIGatewayService { cfg : cfg }
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/" , nil )
pr , pw := io . Pipe ( )
resp := & http . Response {
StatusCode : http . StatusOK ,
Body : pr ,
Header : http . Header {
"Cache-Control" : [ ] string { "upstream" } ,
2026-01-05 14:41:08 +08:00
"X-Request-Id" : [ ] string { "req-123" } ,
2026-01-05 13:54:43 +08:00
"Content-Type" : [ ] string { "application/custom" } ,
} ,
}
go func ( ) {
defer func ( ) { _ = pw . Close ( ) } ( )
2026-03-12 16:53:18 +08:00
_ , _ = pw . Write ( [ ] byte ( "data: {\"type\":\"response.completed\",\"response\":{}}\n\n" ) )
2026-01-05 13:54:43 +08:00
} ( )
_ , err := svc . handleStreamingResponse ( c . Request . Context ( ) , resp , c , & Account { ID : 1 } , time . Now ( ) , "model" , "model" )
_ = pr . Close ( )
if err != nil {
t . Fatalf ( "handleStreamingResponse error: %v" , err )
}
if rec . Header ( ) . Get ( "Cache-Control" ) != "no-cache" {
t . Fatalf ( "expected Cache-Control override, got %q" , rec . Header ( ) . Get ( "Cache-Control" ) )
}
if rec . Header ( ) . Get ( "Content-Type" ) != "text/event-stream" {
t . Fatalf ( "expected Content-Type override, got %q" , rec . Header ( ) . Get ( "Content-Type" ) )
}
2026-01-05 14:41:08 +08:00
if rec . Header ( ) . Get ( "X-Request-Id" ) != "req-123" {
t . Fatalf ( "expected X-Request-Id passthrough, got %q" , rec . Header ( ) . Get ( "X-Request-Id" ) )
2026-01-05 13:54:43 +08:00
}
}
2026-02-06 22:55:12 +08:00
func TestOpenAIStreamingReuseScannerBufferAndStillWorks ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
cfg := & config . Config {
Gateway : config . GatewayConfig {
StreamDataIntervalTimeout : 0 ,
StreamKeepaliveInterval : 0 ,
MaxLineSize : defaultMaxLineSize ,
} ,
}
svc := & OpenAIGatewayService { cfg : cfg }
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/" , nil )
pr , pw := io . Pipe ( )
resp := & http . Response {
StatusCode : http . StatusOK ,
Body : pr ,
Header : http . Header { } ,
}
go func ( ) {
defer func ( ) { _ = pw . Close ( ) } ( )
_ , _ = pw . Write ( [ ] byte ( "data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"input_tokens_details\":{\"cached_tokens\":3}}}}\n\n" ) )
} ( )
result , err := svc . handleStreamingResponse ( c . Request . Context ( ) , resp , c , & Account { ID : 1 } , time . Now ( ) , "model" , "model" )
_ = pr . Close ( )
require . NoError ( t , err )
require . NotNil ( t , result )
require . NotNil ( t , result . usage )
require . Equal ( t , 1 , result . usage . InputTokens )
require . Equal ( t , 2 , result . usage . OutputTokens )
require . Equal ( t , 3 , result . usage . CacheReadInputTokens )
}
2026-01-05 13:54:43 +08:00
func TestOpenAIInvalidBaseURLWhenAllowlistDisabled ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
cfg := & config . Config {
Security : config . SecurityConfig {
URLAllowlist : config . URLAllowlistConfig { Enabled : false } ,
} ,
}
svc := & OpenAIGatewayService { cfg : cfg }
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/" , nil )
account := & Account {
Platform : PlatformOpenAI ,
Type : AccountTypeAPIKey ,
Credentials : map [ string ] any { "base_url" : "://invalid-url" } ,
}
2026-01-10 20:53:16 +08:00
_ , err := svc . buildUpstreamRequest ( c . Request . Context ( ) , c , account , [ ] byte ( "{}" ) , "token" , false , "" , false )
2026-01-05 13:54:43 +08:00
if err == nil {
t . Fatalf ( "expected error for invalid base_url when allowlist disabled" )
}
}
2026-01-05 14:41:08 +08:00
func TestOpenAIValidateUpstreamBaseURLDisabledRequiresHTTPS ( t * testing . T ) {
2026-01-05 13:54:43 +08:00
cfg := & config . Config {
Security : config . SecurityConfig {
URLAllowlist : config . URLAllowlistConfig { Enabled : false } ,
} ,
}
svc := & OpenAIGatewayService { cfg : cfg }
2026-01-05 14:41:08 +08:00
if _ , err := svc . validateUpstreamBaseURL ( "http://not-https.example.com" ) ; err == nil {
t . Fatalf ( "expected http to be rejected when allow_insecure_http is false" )
}
normalized , err := svc . validateUpstreamBaseURL ( "https://example.com" )
if err != nil {
t . Fatalf ( "expected https to be allowed when allowlist disabled, got %v" , err )
}
if normalized != "https://example.com" {
t . Fatalf ( "expected raw url passthrough, got %q" , normalized )
}
}
func TestOpenAIValidateUpstreamBaseURLDisabledAllowsHTTP ( t * testing . T ) {
cfg := & config . Config {
Security : config . SecurityConfig {
URLAllowlist : config . URLAllowlistConfig {
Enabled : false ,
AllowInsecureHTTP : true ,
} ,
} ,
}
svc := & OpenAIGatewayService { cfg : cfg }
2026-01-05 13:54:43 +08:00
normalized , err := svc . validateUpstreamBaseURL ( "http://not-https.example.com" )
if err != nil {
2026-01-05 14:41:08 +08:00
t . Fatalf ( "expected http allowed when allow_insecure_http is true, got %v" , err )
2026-01-05 13:54:43 +08:00
}
if normalized != "http://not-https.example.com" {
t . Fatalf ( "expected raw url passthrough, got %q" , normalized )
}
}
func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist ( t * testing . T ) {
cfg := & config . Config {
Security : config . SecurityConfig {
URLAllowlist : config . URLAllowlistConfig {
Enabled : true ,
UpstreamHosts : [ ] string { "example.com" } ,
} ,
} ,
}
svc := & OpenAIGatewayService { cfg : cfg }
if _ , err := svc . validateUpstreamBaseURL ( "https://example.com" ) ; err != nil {
t . Fatalf ( "expected allowlisted host to pass, got %v" , err )
}
if _ , err := svc . validateUpstreamBaseURL ( "https://evil.com" ) ; err == nil {
t . Fatalf ( "expected non-allowlisted host to fail" )
}
}
2026-02-07 17:09:55 +08:00
2026-03-06 20:46:10 +08:00
func TestOpenAIUpdateCodexUsageSnapshotFromHeaders ( t * testing . T ) {
repo := & snapshotUpdateAccountRepo { updateExtraCalls : make ( chan map [ string ] any , 1 ) }
svc := & OpenAIGatewayService { accountRepo : repo }
headers := http . Header { }
headers . Set ( "x-codex-primary-used-percent" , "12" )
headers . Set ( "x-codex-secondary-used-percent" , "34" )
headers . Set ( "x-codex-primary-window-minutes" , "300" )
headers . Set ( "x-codex-secondary-window-minutes" , "10080" )
headers . Set ( "x-codex-primary-reset-after-seconds" , "600" )
headers . Set ( "x-codex-secondary-reset-after-seconds" , "86400" )
svc . UpdateCodexUsageSnapshotFromHeaders ( context . Background ( ) , 123 , headers )
select {
case updates := <- repo . updateExtraCalls :
require . Equal ( t , 12.0 , updates [ "codex_5h_used_percent" ] )
require . Equal ( t , 34.0 , updates [ "codex_7d_used_percent" ] )
require . Equal ( t , 600 , updates [ "codex_5h_reset_after_seconds" ] )
require . Equal ( t , 86400 , updates [ "codex_7d_reset_after_seconds" ] )
case <- time . After ( 2 * time . Second ) :
t . Fatal ( "expected UpdateExtra to be called" )
}
}
2026-03-06 18:50:28 +08:00
func TestOpenAIResponsesRequestPathSuffix ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
tests := [ ] struct {
name string
path string
want string
} {
{ name : "exact v1 responses" , path : "/v1/responses" , want : "" } ,
{ name : "compact v1 responses" , path : "/v1/responses/compact" , want : "/compact" } ,
{ name : "compact alias responses" , path : "/responses/compact/" , want : "/compact" } ,
{ name : "nested suffix" , path : "/openai/v1/responses/compact/detail" , want : "/compact/detail" } ,
{ name : "unrelated path" , path : "/v1/chat/completions" , want : "" } ,
}
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
c . Request = httptest . NewRequest ( http . MethodPost , tt . path , nil )
require . Equal ( t , tt . want , openAIResponsesRequestPathSuffix ( c ) )
} )
}
}
func TestOpenAIBuildUpstreamRequestOpenAIPassthroughPreservesCompactPath ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/v1/responses/compact" , bytes . NewReader ( [ ] byte ( ` { "model":"gpt-5"} ` ) ) )
svc := & OpenAIGatewayService { }
account := & Account { Type : AccountTypeOAuth }
req , err := svc . buildUpstreamRequestOpenAIPassthrough ( c . Request . Context ( ) , c , account , [ ] byte ( ` { "model":"gpt-5"} ` ) , "token" )
require . NoError ( t , err )
require . Equal ( t , chatgptCodexURL + "/compact" , req . URL . String ( ) )
require . Equal ( t , "application/json" , req . Header . Get ( "Accept" ) )
require . Equal ( t , codexCLIVersion , req . Header . Get ( "Version" ) )
require . NotEmpty ( t , req . Header . Get ( "Session_Id" ) )
}
func TestOpenAIBuildUpstreamRequestCompactForcesJSONAcceptForOAuth ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/v1/responses/compact" , bytes . NewReader ( [ ] byte ( ` { "model":"gpt-5"} ` ) ) )
svc := & OpenAIGatewayService { }
account := & Account {
Type : AccountTypeOAuth ,
Credentials : map [ string ] any { "chatgpt_account_id" : "chatgpt-acc" } ,
}
req , err := svc . buildUpstreamRequest ( c . Request . Context ( ) , c , account , [ ] byte ( ` { "model":"gpt-5"} ` ) , "token" , false , "" , true )
require . NoError ( t , err )
require . Equal ( t , chatgptCodexURL + "/compact" , req . URL . String ( ) )
require . Equal ( t , "application/json" , req . Header . Get ( "Accept" ) )
require . Equal ( t , codexCLIVersion , req . Header . Get ( "Version" ) )
require . NotEmpty ( t , req . Header . Get ( "Session_Id" ) )
}
func TestOpenAIBuildUpstreamRequestPreservesCompactPathForAPIKeyBaseURL ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/responses/compact" , bytes . NewReader ( [ ] byte ( ` { "model":"gpt-5"} ` ) ) )
svc := & OpenAIGatewayService { cfg : & config . Config {
Security : config . SecurityConfig {
URLAllowlist : config . URLAllowlistConfig { Enabled : false } ,
} ,
} }
account := & Account {
Type : AccountTypeAPIKey ,
Platform : PlatformOpenAI ,
Credentials : map [ string ] any { "base_url" : "https://example.com/v1" } ,
}
req , err := svc . buildUpstreamRequest ( c . Request . Context ( ) , c , account , [ ] byte ( ` { "model":"gpt-5"} ` ) , "token" , false , "" , false )
require . NoError ( t , err )
require . Equal ( t , "https://example.com/v1/responses/compact" , req . URL . String ( ) )
}
2026-03-07 14:12:38 +08:00
func TestOpenAIBuildUpstreamRequestOAuthOfficialClientOriginatorCompatibility ( t * testing . T ) {
gin . SetMode ( gin . TestMode )
tests := [ ] struct {
name string
userAgent string
originator string
wantOriginator string
} {
{ name : "desktop originator preserved" , originator : "Codex Desktop" , wantOriginator : "Codex Desktop" } ,
{ name : "vscode originator preserved" , originator : "codex_vscode" , wantOriginator : "codex_vscode" } ,
{ name : "official ua fallback to codex_cli_rs" , userAgent : "Codex Desktop/1.2.3" , wantOriginator : "codex_cli_rs" } ,
}
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/v1/responses" , bytes . NewReader ( [ ] byte ( ` { "model":"gpt-5"} ` ) ) )
if tt . userAgent != "" {
c . Request . Header . Set ( "User-Agent" , tt . userAgent )
}
if tt . originator != "" {
c . Request . Header . Set ( "originator" , tt . originator )
}
svc := & OpenAIGatewayService { }
account := & Account {
Type : AccountTypeOAuth ,
Credentials : map [ string ] any { "chatgpt_account_id" : "chatgpt-acc" } ,
}
isCodexCLI := openai . IsCodexOfficialClientByHeaders ( c . GetHeader ( "User-Agent" ) , c . GetHeader ( "originator" ) )
req , err := svc . buildUpstreamRequest ( c . Request . Context ( ) , c , account , [ ] byte ( ` { "model":"gpt-5"} ` ) , "token" , false , "" , isCodexCLI )
require . NoError ( t , err )
require . Equal ( t , tt . wantOriginator , req . Header . Get ( "originator" ) )
} )
}
}
2026-02-07 17:09:55 +08:00
// ==================== P1-08 修复: model 替换性能优化测试 ====================
2026-03-06 20:46:10 +08:00
// ==================== P1-08 修复: model 替换性能优化测试 =============
2026-02-07 17:09:55 +08:00
func TestReplaceModelInSSELine ( t * testing . T ) {
svc := & OpenAIGatewayService { }
tests := [ ] struct {
name string
line string
from string
to string
expected string
} {
{
name : "顶层 model 字段替换" ,
line : ` data: { "id":"chatcmpl-123","model":"gpt-4o","choices":[]} ` ,
from : "gpt-4o" ,
to : "my-custom-model" ,
expected : ` data: { "id":"chatcmpl-123","model":"my-custom-model","choices":[]} ` ,
} ,
{
name : "嵌套 response.model 替换" ,
line : ` data: { "type":"response","response": { "id":"resp-1","model":"gpt-4o","output":[]}} ` ,
from : "gpt-4o" ,
to : "my-model" ,
expected : ` data: { "type":"response","response": { "id":"resp-1","model":"my-model","output":[]}} ` ,
} ,
{
name : "model 不匹配时不替换" ,
line : ` data: { "id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]} ` ,
from : "gpt-4o" ,
to : "my-model" ,
expected : ` data: { "id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]} ` ,
} ,
{
name : "无 model 字段时不替换" ,
line : ` data: { "id":"chatcmpl-123","choices":[]} ` ,
from : "gpt-4o" ,
to : "my-model" ,
expected : ` data: { "id":"chatcmpl-123","choices":[]} ` ,
} ,
{
name : "空 data 行" ,
line : ` data: ` ,
from : "gpt-4o" ,
to : "my-model" ,
expected : ` data: ` ,
} ,
{
name : "[DONE] 行" ,
line : ` data: [DONE] ` ,
from : "gpt-4o" ,
to : "my-model" ,
expected : ` data: [DONE] ` ,
} ,
{
name : "非 data: 前缀行" ,
line : ` event: message ` ,
from : "gpt-4o" ,
to : "my-model" ,
expected : ` event: message ` ,
} ,
{
name : "非法 JSON 不替换" ,
line : ` data: { invalid json} ` ,
from : "gpt-4o" ,
to : "my-model" ,
expected : ` data: { invalid json} ` ,
} ,
{
name : "无空格 data: 格式" ,
line : ` data: { "id":"x","model":"gpt-4o"} ` ,
from : "gpt-4o" ,
to : "my-model" ,
expected : ` data: { "id":"x","model":"my-model"} ` ,
} ,
{
name : "model 名含特殊字符" ,
line : ` data: { "model":"org/model-v2.1-beta"} ` ,
from : "org/model-v2.1-beta" ,
to : "custom/alias" ,
expected : ` data: { "model":"custom/alias"} ` ,
} ,
{
name : "空行" ,
line : "" ,
from : "gpt-4o" ,
to : "my-model" ,
expected : "" ,
} ,
{
name : "保持其他字段不变" ,
line : ` data: { "id":"abc","object":"chat.completion.chunk","model":"gpt-4o","created":1234567890,"choices":[ { "index":0,"delta": { "content":"hi"}}]} ` ,
from : "gpt-4o" ,
to : "alias" ,
expected : ` data: { "id":"abc","object":"chat.completion.chunk","model":"alias","created":1234567890,"choices":[ { "index":0,"delta": { "content":"hi"}}]} ` ,
} ,
{
name : "顶层优先于嵌套:同时存在两个 model" ,
line : ` data: { "model":"gpt-4o","response": { "model":"gpt-4o"}} ` ,
from : "gpt-4o" ,
to : "replaced" ,
expected : ` data: { "model":"replaced","response": { "model":"gpt-4o"}} ` ,
} ,
}
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
got := svc . replaceModelInSSELine ( tt . line , tt . from , tt . to )
require . Equal ( t , tt . expected , got )
} )
}
}
func TestReplaceModelInSSEBody ( t * testing . T ) {
svc := & OpenAIGatewayService { }
tests := [ ] struct {
name string
body string
from string
to string
expected string
} {
{
name : "多行 SSE body 替换" ,
body : "data: {\"model\":\"gpt-4o\",\"choices\":[]}\n\ndata: {\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n" ,
from : "gpt-4o" ,
to : "alias" ,
expected : "data: {\"model\":\"alias\",\"choices\":[]}\n\ndata: {\"model\":\"alias\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n" ,
} ,
{
name : "无需替换的 body" ,
body : "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n" ,
from : "gpt-4o" ,
to : "alias" ,
expected : "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n" ,
} ,
{
name : "混合 event 和 data 行" ,
body : "event: message\ndata: {\"model\":\"gpt-4o\"}\n\n" ,
from : "gpt-4o" ,
to : "alias" ,
expected : "event: message\ndata: {\"model\":\"alias\"}\n\n" ,
} ,
{
name : "空 body" ,
body : "" ,
from : "gpt-4o" ,
to : "alias" ,
expected : "" ,
} ,
}
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
got := svc . replaceModelInSSEBody ( tt . body , tt . from , tt . to )
require . Equal ( t , tt . expected , got )
} )
}
}
func TestReplaceModelInResponseBody ( t * testing . T ) {
svc := & OpenAIGatewayService { }
tests := [ ] struct {
name string
body string
from string
to string
expected string
} {
{
name : "替换顶层 model" ,
body : ` { "id":"chatcmpl-123","model":"gpt-4o","choices":[]} ` ,
from : "gpt-4o" ,
to : "alias" ,
expected : ` { "id":"chatcmpl-123","model":"alias","choices":[]} ` ,
} ,
{
name : "model 不匹配不替换" ,
body : ` { "id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]} ` ,
from : "gpt-4o" ,
to : "alias" ,
expected : ` { "id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]} ` ,
} ,
{
name : "无 model 字段不替换" ,
body : ` { "id":"chatcmpl-123","choices":[]} ` ,
from : "gpt-4o" ,
to : "alias" ,
expected : ` { "id":"chatcmpl-123","choices":[]} ` ,
} ,
{
name : "非法 JSON 返回原值" ,
body : ` not json ` ,
from : "gpt-4o" ,
to : "alias" ,
expected : ` not json ` ,
} ,
{
name : "空 body 返回原值" ,
body : ` ` ,
from : "gpt-4o" ,
to : "alias" ,
expected : ` ` ,
} ,
{
name : "保持嵌套结构不变" ,
body : ` { "model":"gpt-4o","usage": { "prompt_tokens":10,"completion_tokens":20},"choices":[ { "message": { "role":"assistant","content":"hello"}}]} ` ,
from : "gpt-4o" ,
to : "alias" ,
expected : ` { "model":"alias","usage": { "prompt_tokens":10,"completion_tokens":20},"choices":[ { "message": { "role":"assistant","content":"hello"}}]} ` ,
} ,
}
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
got := svc . replaceModelInResponseBody ( [ ] byte ( tt . body ) , tt . from , tt . to )
require . Equal ( t , tt . expected , string ( got ) )
} )
}
}
2026-02-12 09:41:37 +08:00
func TestExtractOpenAISSEDataLine ( t * testing . T ) {
tests := [ ] struct {
name string
line string
wantData string
wantOK bool
} {
{ name : "标准格式" , line : ` data: { "type":"x"} ` , wantData : ` { "type":"x"} ` , wantOK : true } ,
{ name : "无空格格式" , line : ` data: { "type":"x"} ` , wantData : ` { "type":"x"} ` , wantOK : true } ,
{ name : "纯空数据" , line : ` data: ` , wantData : ` ` , wantOK : true } ,
{ name : "非 data 行" , line : ` event: message ` , wantData : ` ` , wantOK : false } ,
}
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
got , ok := extractOpenAISSEDataLine ( tt . line )
require . Equal ( t , tt . wantOK , ok )
require . Equal ( t , tt . wantData , got )
} )
}
}
func TestParseSSEUsage_SelectiveParsing ( t * testing . T ) {
svc := & OpenAIGatewayService { }
usage := & OpenAIUsage { InputTokens : 9 , OutputTokens : 8 , CacheReadInputTokens : 7 }
// 非 completed 事件,不应覆盖 usage
svc . parseSSEUsage ( ` { "type":"response.in_progress","response": { "usage": { "input_tokens":1,"output_tokens":2}}} ` , usage )
require . Equal ( t , 9 , usage . InputTokens )
require . Equal ( t , 8 , usage . OutputTokens )
require . Equal ( t , 7 , usage . CacheReadInputTokens )
// completed 事件,应提取 usage
svc . parseSSEUsage ( ` { "type":"response.completed","response": { "usage": { "input_tokens":3,"output_tokens":5,"input_tokens_details": { "cached_tokens":2}}}} ` , usage )
require . Equal ( t , 3 , usage . InputTokens )
require . Equal ( t , 5 , usage . OutputTokens )
require . Equal ( t , 2 , usage . CacheReadInputTokens )
2026-03-12 16:53:18 +08:00
// done 事件同样可能携带最终 usage
svc . parseSSEUsage ( ` { "type":"response.done","response": { "usage": { "input_tokens":13,"output_tokens":15,"input_tokens_details": { "cached_tokens":4}}}} ` , usage )
require . Equal ( t , 13 , usage . InputTokens )
require . Equal ( t , 15 , usage . OutputTokens )
require . Equal ( t , 4 , usage . CacheReadInputTokens )
2026-02-12 09:41:37 +08:00
}
func TestExtractCodexFinalResponse_SampleReplay ( t * testing . T ) {
body := strings . Join ( [ ] string {
` event: message ` ,
` data: { "type":"response.in_progress","response": { "id":"resp_1"}} ` ,
` data: { "type":"response.completed","response": { "id":"resp_1","model":"gpt-4o","usage": { "input_tokens":11,"output_tokens":22,"input_tokens_details": { "cached_tokens":3}}}} ` ,
` data: [DONE] ` ,
} , "\n" )
finalResp , ok := extractCodexFinalResponse ( body )
require . True ( t , ok )
require . Contains ( t , string ( finalResp ) , ` "id":"resp_1" ` )
require . Contains ( t , string ( finalResp ) , ` "input_tokens":11 ` )
}
2026-04-07 22:49:14 +08:00
func TestHandleSSEToJSON_CompletedEventReturnsJSON ( t * testing . T ) {
2026-02-12 09:41:37 +08:00
gin . SetMode ( gin . TestMode )
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/" , nil )
svc := & OpenAIGatewayService { cfg : & config . Config { } }
resp := & http . Response {
StatusCode : http . StatusOK ,
Header : http . Header { "Content-Type" : [ ] string { "text/event-stream" } } ,
}
body := [ ] byte ( strings . Join ( [ ] string {
` data: { "type":"response.in_progress","response": { "id":"resp_2"}} ` ,
` data: { "type":"response.completed","response": { "id":"resp_2","model":"gpt-4o","usage": { "input_tokens":7,"output_tokens":9,"input_tokens_details": { "cached_tokens":1}}}} ` ,
` data: [DONE] ` ,
} , "\n" ) )
2026-04-07 22:49:14 +08:00
usage , err := svc . handleSSEToJSON ( resp , c , body , "gpt-4o" , "gpt-4o" )
2026-02-12 09:41:37 +08:00
require . NoError ( t , err )
require . NotNil ( t , usage )
require . Equal ( t , 7 , usage . InputTokens )
require . Equal ( t , 9 , usage . OutputTokens )
require . Equal ( t , 1 , usage . CacheReadInputTokens )
// Header 可能由上游 Content-Type 透传;关键是 body 已转换为最终 JSON 响应。
require . NotContains ( t , rec . Body . String ( ) , "event:" )
require . Contains ( t , rec . Body . String ( ) , ` "id":"resp_2" ` )
require . NotContains ( t , rec . Body . String ( ) , "data:" )
}
2026-04-07 22:49:14 +08:00
func TestHandleSSEToJSON_NoFinalResponseKeepsSSEBody ( t * testing . T ) {
2026-02-12 09:41:37 +08:00
gin . SetMode ( gin . TestMode )
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/" , nil )
svc := & OpenAIGatewayService { cfg : & config . Config { } }
resp := & http . Response {
StatusCode : http . StatusOK ,
Header : http . Header { "Content-Type" : [ ] string { "text/event-stream" } } ,
}
body := [ ] byte ( strings . Join ( [ ] string {
` data: { "type":"response.in_progress","response": { "id":"resp_3"}} ` ,
` data: [DONE] ` ,
} , "\n" ) )
2026-04-07 22:49:14 +08:00
usage , err := svc . handleSSEToJSON ( resp , c , body , "gpt-4o" , "gpt-4o" )
2026-02-12 09:41:37 +08:00
require . NoError ( t , err )
require . NotNil ( t , usage )
require . Equal ( t , 0 , usage . InputTokens )
require . Contains ( t , rec . Header ( ) . Get ( "Content-Type" ) , "text/event-stream" )
require . Contains ( t , rec . Body . String ( ) , ` data: { "type":"response.in_progress" ` )
}
2026-03-06 14:54:52 +08:00
2026-04-07 22:49:14 +08:00
func TestHandleSSEToJSON_ResponseFailedReturnsProtocolError ( t * testing . T ) {
2026-03-06 14:54:52 +08:00
gin . SetMode ( gin . TestMode )
rec := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( rec )
c . Request = httptest . NewRequest ( http . MethodPost , "/" , nil )
svc := & OpenAIGatewayService { cfg : & config . Config { } }
resp := & http . Response {
StatusCode : http . StatusOK ,
Header : http . Header { "Content-Type" : [ ] string { "text/event-stream" } } ,
}
body := [ ] byte ( strings . Join ( [ ] string {
` data: { "type":"response.failed","error": { "message":"upstream rejected request"}} ` ,
` data: [DONE] ` ,
} , "\n" ) )
2026-04-07 22:49:14 +08:00
usage , err := svc . handleSSEToJSON ( resp , c , body , "gpt-4o" , "gpt-4o" )
2026-03-06 14:54:52 +08:00
require . Nil ( t , usage )
require . Error ( t , err )
require . Equal ( t , http . StatusBadGateway , rec . Code )
require . Contains ( t , rec . Body . String ( ) , "upstream rejected request" )
require . Contains ( t , rec . Header ( ) . Get ( "Content-Type" ) , "application/json" )
}