2026-01-04 19:49:59 +08:00
|
|
|
|
package service
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"bufio"
|
2026-01-05 13:54:43 +08:00
|
|
|
|
"bytes"
|
2026-01-13 22:49:26 -08:00
|
|
|
|
"context"
|
2026-01-04 19:49:59 +08:00
|
|
|
|
"errors"
|
2026-02-28 15:01:20 +08:00
|
|
|
|
"fmt"
|
2026-01-04 19:49:59 +08:00
|
|
|
|
"io"
|
|
|
|
|
|
"net/http"
|
|
|
|
|
|
"net/http/httptest"
|
|
|
|
|
|
"strings"
|
|
|
|
|
|
"testing"
|
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
2026-03-07 14:12:38 +08:00
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
2026-02-28 15:01:20 +08:00
|
|
|
|
"github.com/cespare/xxhash/v2"
|
2026-01-04 19:49:59 +08:00
|
|
|
|
"github.com/gin-gonic/gin"
|
2026-02-06 22:55:12 +08:00
|
|
|
|
"github.com/stretchr/testify/require"
|
2026-01-04 19:49:59 +08:00
|
|
|
|
)
|
|
|
|
|
|
|
2026-02-08 12:05:39 +08:00
|
|
|
|
// 编译期接口断言
|
|
|
|
|
|
var _ AccountRepository = (*stubOpenAIAccountRepo)(nil)
|
|
|
|
|
|
var _ GatewayCache = (*stubGatewayCache)(nil)
|
|
|
|
|
|
|
2026-01-13 22:49:26 -08:00
|
|
|
|
type stubOpenAIAccountRepo struct {
|
|
|
|
|
|
AccountRepository
|
|
|
|
|
|
accounts []Account
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-03-06 20:46:10 +08:00
|
|
|
|
type snapshotUpdateAccountRepo struct {
|
|
|
|
|
|
stubOpenAIAccountRepo
|
|
|
|
|
|
updateExtraCalls chan map[string]any
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (r *snapshotUpdateAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
|
|
|
|
|
if r.updateExtraCalls != nil {
|
|
|
|
|
|
copied := make(map[string]any, len(updates))
|
|
|
|
|
|
for k, v := range updates {
|
|
|
|
|
|
copied[k] = v
|
|
|
|
|
|
}
|
|
|
|
|
|
r.updateExtraCalls <- copied
|
|
|
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-20 11:19:32 +08:00
|
|
|
|
func (r stubOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) {
|
|
|
|
|
|
for i := range r.accounts {
|
|
|
|
|
|
if r.accounts[i].ID == id {
|
|
|
|
|
|
return &r.accounts[i], nil
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
return nil, errors.New("account not found")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-13 22:49:26 -08:00
|
|
|
|
func (r stubOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
|
2026-01-20 11:19:32 +08:00
|
|
|
|
var result []Account
|
|
|
|
|
|
for _, acc := range r.accounts {
|
|
|
|
|
|
if acc.Platform == platform {
|
|
|
|
|
|
result = append(result, acc)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
return result, nil
|
2026-01-13 22:49:26 -08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
2026-01-20 11:19:32 +08:00
|
|
|
|
var result []Account
|
|
|
|
|
|
for _, acc := range r.accounts {
|
|
|
|
|
|
if acc.Platform == platform {
|
|
|
|
|
|
result = append(result, acc)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
return result, nil
|
2026-01-13 22:49:26 -08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-03-03 13:10:26 +08:00
|
|
|
|
func (r stubOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
|
|
|
|
|
return r.ListSchedulableByPlatform(ctx, platform)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-13 22:49:26 -08:00
|
|
|
|
type stubConcurrencyCache struct {
|
|
|
|
|
|
ConcurrencyCache
|
2026-01-20 11:19:32 +08:00
|
|
|
|
loadBatchErr error
|
|
|
|
|
|
loadMap map[int64]*AccountLoadInfo
|
|
|
|
|
|
acquireResults map[int64]bool
|
|
|
|
|
|
waitCounts map[int64]int
|
|
|
|
|
|
skipDefaultLoad bool
|
2026-01-13 22:49:26 -08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-15 21:42:13 +08:00
|
|
|
|
// cancelReadCloser is a body stub whose reads always fail with
// context.Canceled and whose Close always succeeds.
type cancelReadCloser struct{}

// Read never writes into p; it reports zero bytes and context.Canceled.
func (cancelReadCloser) Read(p []byte) (int, error) {
	return 0, context.Canceled
}

// Close is a no-op that reports success.
func (cancelReadCloser) Close() error {
	return nil
}
|
|
|
|
|
|
|
2026-01-15 21:51:14 +08:00
|
|
|
|
type failingGinWriter struct {
|
|
|
|
|
|
gin.ResponseWriter
|
|
|
|
|
|
failAfter int
|
|
|
|
|
|
writes int
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (w *failingGinWriter) Write(p []byte) (int, error) {
|
|
|
|
|
|
if w.writes >= w.failAfter {
|
|
|
|
|
|
return 0, errors.New("write failed")
|
|
|
|
|
|
}
|
|
|
|
|
|
w.writes++
|
|
|
|
|
|
return w.ResponseWriter.Write(p)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-13 22:49:26 -08:00
|
|
|
|
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
2026-01-20 11:19:32 +08:00
|
|
|
|
if c.acquireResults != nil {
|
|
|
|
|
|
if result, ok := c.acquireResults[accountID]; ok {
|
|
|
|
|
|
return result, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
2026-01-13 22:49:26 -08:00
|
|
|
|
return true, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
2026-01-20 11:19:32 +08:00
|
|
|
|
if c.loadBatchErr != nil {
|
|
|
|
|
|
return nil, c.loadBatchErr
|
|
|
|
|
|
}
|
2026-01-13 22:49:26 -08:00
|
|
|
|
out := make(map[int64]*AccountLoadInfo, len(accounts))
|
2026-01-20 11:19:32 +08:00
|
|
|
|
if c.skipDefaultLoad && c.loadMap != nil {
|
|
|
|
|
|
for _, acc := range accounts {
|
|
|
|
|
|
if load, ok := c.loadMap[acc.ID]; ok {
|
|
|
|
|
|
out[acc.ID] = load
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
return out, nil
|
|
|
|
|
|
}
|
2026-01-13 22:49:26 -08:00
|
|
|
|
for _, acc := range accounts {
|
2026-01-20 11:19:32 +08:00
|
|
|
|
if c.loadMap != nil {
|
|
|
|
|
|
if load, ok := c.loadMap[acc.ID]; ok {
|
|
|
|
|
|
out[acc.ID] = load
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
2026-01-13 22:49:26 -08:00
|
|
|
|
out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
|
|
|
|
|
|
}
|
|
|
|
|
|
return out, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-17 02:31:16 +08:00
|
|
|
|
func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{}
|
|
|
|
|
|
|
2026-02-10 08:59:30 +08:00
|
|
|
|
bodyWithKey := []byte(`{"prompt_cache_key":"ses_aaa"}`)
|
|
|
|
|
|
|
2026-01-17 02:31:16 +08:00
|
|
|
|
// 1) session_id header wins
|
|
|
|
|
|
c.Request.Header.Set("session_id", "sess-123")
|
|
|
|
|
|
c.Request.Header.Set("conversation_id", "conv-456")
|
2026-02-10 08:59:30 +08:00
|
|
|
|
h1 := svc.GenerateSessionHash(c, bodyWithKey)
|
2026-01-17 02:31:16 +08:00
|
|
|
|
if h1 == "" {
|
|
|
|
|
|
t.Fatalf("expected non-empty hash")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 2) conversation_id used when session_id absent
|
|
|
|
|
|
c.Request.Header.Del("session_id")
|
2026-02-10 08:59:30 +08:00
|
|
|
|
h2 := svc.GenerateSessionHash(c, bodyWithKey)
|
2026-01-17 02:31:16 +08:00
|
|
|
|
if h2 == "" {
|
|
|
|
|
|
t.Fatalf("expected non-empty hash")
|
|
|
|
|
|
}
|
|
|
|
|
|
if h1 == h2 {
|
|
|
|
|
|
t.Fatalf("expected different hashes for different keys")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 3) prompt_cache_key used when both headers absent
|
|
|
|
|
|
c.Request.Header.Del("conversation_id")
|
2026-02-10 08:59:30 +08:00
|
|
|
|
h3 := svc.GenerateSessionHash(c, bodyWithKey)
|
2026-01-17 02:31:16 +08:00
|
|
|
|
if h3 == "" {
|
|
|
|
|
|
t.Fatalf("expected non-empty hash")
|
|
|
|
|
|
}
|
|
|
|
|
|
if h2 == h3 {
|
|
|
|
|
|
t.Fatalf("expected different hashes for different keys")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 4) empty when no signals
|
2026-02-10 08:59:30 +08:00
|
|
|
|
h4 := svc.GenerateSessionHash(c, []byte(`{}`))
|
2026-01-17 02:31:16 +08:00
|
|
|
|
if h4 != "" {
|
|
|
|
|
|
t.Fatalf("expected empty hash when no signals")
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-02-28 15:01:20 +08:00
|
|
|
|
func TestOpenAIGatewayService_GenerateSessionHash_UsesXXHash64(t *testing.T) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
|
|
|
|
|
|
|
|
|
|
|
c.Request.Header.Set("session_id", "sess-fixed-value")
|
|
|
|
|
|
svc := &OpenAIGatewayService{}
|
|
|
|
|
|
|
|
|
|
|
|
got := svc.GenerateSessionHash(c, nil)
|
|
|
|
|
|
want := fmt.Sprintf("%016x", xxhash.Sum64String("sess-fixed-value"))
|
|
|
|
|
|
require.Equal(t, want, got)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext(t *testing.T) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
|
|
|
|
|
|
|
|
|
|
|
c.Request.Header.Set("session_id", "sess-legacy-check")
|
|
|
|
|
|
svc := &OpenAIGatewayService{}
|
|
|
|
|
|
|
|
|
|
|
|
sessionHash := svc.GenerateSessionHash(c, nil)
|
|
|
|
|
|
require.NotEmpty(t, sessionHash)
|
|
|
|
|
|
require.NotNil(t, c.Request)
|
|
|
|
|
|
require.NotNil(t, c.Request.Context())
|
|
|
|
|
|
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{}
|
|
|
|
|
|
seed := "openai_ws_ingress:9:100:200"
|
|
|
|
|
|
|
|
|
|
|
|
got := svc.GenerateSessionHashWithFallback(c, []byte(`{}`), seed)
|
|
|
|
|
|
want := fmt.Sprintf("%016x", xxhash.Sum64String(seed))
|
|
|
|
|
|
require.Equal(t, want, got)
|
|
|
|
|
|
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
|
|
|
|
|
|
|
|
|
|
|
|
empty := svc.GenerateSessionHashWithFallback(c, []byte(`{}`), " ")
|
|
|
|
|
|
require.Equal(t, "", empty)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-20 11:19:32 +08:00
|
|
|
|
func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
|
|
|
|
|
if c.waitCounts != nil {
|
|
|
|
|
|
if count, ok := c.waitCounts[accountID]; ok {
|
|
|
|
|
|
return count, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
return 0, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// stubGatewayCache is an in-memory sticky-session cache for tests. Every
// method ignores ctx, groupID and ttl; entries are keyed by the session hash
// alone.
type stubGatewayCache struct {
	// sessionBindings maps a session hash to the account ID bound to it.
	sessionBindings map[string]int64

	// deletedSessions counts DeleteSessionAccountID calls per session hash.
	deletedSessions map[string]int
}

// GetSessionAccountID returns the account bound to sessionHash, or a
// "not found" error when there is no binding.
func (c *stubGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
	accountID, bound := c.sessionBindings[sessionHash]
	if !bound {
		return 0, errors.New("not found")
	}
	return accountID, nil
}

// SetSessionAccountID binds sessionHash to accountID, allocating the binding
// map on first use. It always returns nil.
func (c *stubGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
	if c.sessionBindings == nil {
		c.sessionBindings = map[string]int64{}
	}
	c.sessionBindings[sessionHash] = accountID
	return nil
}

// RefreshSessionTTL is a no-op because the stub keeps no expiry state.
func (c *stubGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
	return nil
}

// DeleteSessionAccountID removes the binding for sessionHash and bumps its
// deletion counter. When the binding map was never allocated the call returns
// immediately and nothing is counted. It always returns nil.
func (c *stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
	if c.sessionBindings != nil {
		if c.deletedSessions == nil {
			c.deletedSessions = map[string]int{}
		}
		c.deletedSessions[sessionHash]++
		delete(c.sessionBindings, sessionHash)
	}
	return nil
}
|
|
|
|
|
|
|
2026-01-13 22:49:26 -08:00
|
|
|
|
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
|
|
|
|
|
|
now := time.Now()
|
|
|
|
|
|
resetAt := now.Add(10 * time.Minute)
|
|
|
|
|
|
groupID := int64(1)
|
|
|
|
|
|
|
|
|
|
|
|
rateLimited := Account{
|
|
|
|
|
|
ID: 1,
|
|
|
|
|
|
Platform: PlatformOpenAI,
|
|
|
|
|
|
Type: AccountTypeAPIKey,
|
|
|
|
|
|
Status: StatusActive,
|
|
|
|
|
|
Schedulable: true,
|
|
|
|
|
|
Concurrency: 1,
|
|
|
|
|
|
Priority: 0,
|
|
|
|
|
|
RateLimitResetAt: &resetAt,
|
|
|
|
|
|
}
|
|
|
|
|
|
available := Account{
|
|
|
|
|
|
ID: 2,
|
|
|
|
|
|
Platform: PlatformOpenAI,
|
|
|
|
|
|
Type: AccountTypeAPIKey,
|
|
|
|
|
|
Status: StatusActive,
|
|
|
|
|
|
Schedulable: true,
|
|
|
|
|
|
Concurrency: 1,
|
|
|
|
|
|
Priority: 1,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{
|
|
|
|
|
|
accountRepo: stubOpenAIAccountRepo{accounts: []Account{rateLimited, available}},
|
|
|
|
|
|
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-5.2", nil)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if selection == nil || selection.Account == nil {
|
|
|
|
|
|
t.Fatalf("expected selection with account")
|
|
|
|
|
|
}
|
|
|
|
|
|
if selection.Account.ID != available.ID {
|
|
|
|
|
|
t.Fatalf("expected account %d, got %d", available.ID, selection.Account.ID)
|
|
|
|
|
|
}
|
|
|
|
|
|
if selection.ReleaseFunc != nil {
|
|
|
|
|
|
selection.ReleaseFunc()
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurrencyService(t *testing.T) {
|
|
|
|
|
|
now := time.Now()
|
|
|
|
|
|
resetAt := now.Add(10 * time.Minute)
|
|
|
|
|
|
groupID := int64(1)
|
|
|
|
|
|
|
|
|
|
|
|
rateLimited := Account{
|
|
|
|
|
|
ID: 1,
|
|
|
|
|
|
Platform: PlatformOpenAI,
|
|
|
|
|
|
Type: AccountTypeAPIKey,
|
|
|
|
|
|
Status: StatusActive,
|
|
|
|
|
|
Schedulable: true,
|
|
|
|
|
|
Concurrency: 1,
|
|
|
|
|
|
Priority: 0,
|
|
|
|
|
|
RateLimitResetAt: &resetAt,
|
|
|
|
|
|
}
|
|
|
|
|
|
available := Account{
|
|
|
|
|
|
ID: 2,
|
|
|
|
|
|
Platform: PlatformOpenAI,
|
|
|
|
|
|
Type: AccountTypeAPIKey,
|
|
|
|
|
|
Status: StatusActive,
|
|
|
|
|
|
Schedulable: true,
|
|
|
|
|
|
Concurrency: 1,
|
|
|
|
|
|
Priority: 1,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{
|
|
|
|
|
|
accountRepo: stubOpenAIAccountRepo{accounts: []Account{rateLimited, available}},
|
|
|
|
|
|
// concurrencyService is nil, forcing the non-load-batch selection path.
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-5.2", nil)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if selection == nil || selection.Account == nil {
|
|
|
|
|
|
t.Fatalf("expected selection with account")
|
|
|
|
|
|
}
|
|
|
|
|
|
if selection.Account.ID != available.ID {
|
|
|
|
|
|
t.Fatalf("expected account %d, got %d", available.ID, selection.Account.ID)
|
|
|
|
|
|
}
|
|
|
|
|
|
if selection.ReleaseFunc != nil {
|
|
|
|
|
|
selection.ReleaseFunc()
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-20 11:19:32 +08:00
|
|
|
|
func TestOpenAISelectAccountForModelWithExclusions_StickyUnschedulableClearsSession(t *testing.T) {
|
|
|
|
|
|
sessionHash := "session-1"
|
|
|
|
|
|
repo := stubOpenAIAccountRepo{
|
|
|
|
|
|
accounts: []Account{
|
|
|
|
|
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusDisabled, Schedulable: true, Concurrency: 1},
|
|
|
|
|
|
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
cache := &stubGatewayCache{
|
|
|
|
|
|
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{
|
|
|
|
|
|
accountRepo: repo,
|
|
|
|
|
|
cache: cache,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if acc == nil || acc.ID != 2 {
|
|
|
|
|
|
t.Fatalf("expected account 2, got %+v", acc)
|
|
|
|
|
|
}
|
|
|
|
|
|
if cache.deletedSessions["openai:"+sessionHash] != 1 {
|
|
|
|
|
|
t.Fatalf("expected sticky session to be deleted")
|
|
|
|
|
|
}
|
|
|
|
|
|
if cache.sessionBindings["openai:"+sessionHash] != 2 {
|
|
|
|
|
|
t.Fatalf("expected sticky session to bind to account 2")
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAISelectAccountWithLoadAwareness_StickyUnschedulableClearsSession(t *testing.T) {
|
|
|
|
|
|
sessionHash := "session-2"
|
|
|
|
|
|
groupID := int64(1)
|
|
|
|
|
|
repo := stubOpenAIAccountRepo{
|
|
|
|
|
|
accounts: []Account{
|
|
|
|
|
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusDisabled, Schedulable: true, Concurrency: 1},
|
|
|
|
|
|
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
cache := &stubGatewayCache{
|
|
|
|
|
|
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{
|
|
|
|
|
|
accountRepo: repo,
|
|
|
|
|
|
cache: cache,
|
|
|
|
|
|
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, sessionHash, "gpt-4", nil)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
|
|
|
|
|
t.Fatalf("expected account 2, got %+v", selection)
|
|
|
|
|
|
}
|
|
|
|
|
|
if cache.deletedSessions["openai:"+sessionHash] != 1 {
|
|
|
|
|
|
t.Fatalf("expected sticky session to be deleted")
|
|
|
|
|
|
}
|
|
|
|
|
|
if cache.sessionBindings["openai:"+sessionHash] != 2 {
|
|
|
|
|
|
t.Fatalf("expected sticky session to bind to account 2")
|
|
|
|
|
|
}
|
|
|
|
|
|
if selection.ReleaseFunc != nil {
|
|
|
|
|
|
selection.ReleaseFunc()
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAISelectAccountForModelWithExclusions_NoModelSupport(t *testing.T) {
|
|
|
|
|
|
repo := stubOpenAIAccountRepo{
|
|
|
|
|
|
accounts: []Account{
|
|
|
|
|
|
{
|
|
|
|
|
|
ID: 1,
|
|
|
|
|
|
Platform: PlatformOpenAI,
|
|
|
|
|
|
Status: StatusActive,
|
|
|
|
|
|
Schedulable: true,
|
|
|
|
|
|
Credentials: map[string]any{"model_mapping": map[string]any{"gpt-3.5-turbo": "gpt-3.5-turbo"}},
|
|
|
|
|
|
},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
cache := &stubGatewayCache{}
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{
|
|
|
|
|
|
accountRepo: repo,
|
|
|
|
|
|
cache: cache,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil)
|
|
|
|
|
|
if err == nil {
|
|
|
|
|
|
t.Fatalf("expected error for unsupported model")
|
|
|
|
|
|
}
|
|
|
|
|
|
if acc != nil {
|
|
|
|
|
|
t.Fatalf("expected nil account for unsupported model")
|
|
|
|
|
|
}
|
|
|
|
|
|
if !strings.Contains(err.Error(), "supporting model") {
|
|
|
|
|
|
t.Fatalf("unexpected error: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorFallback(t *testing.T) {
|
|
|
|
|
|
groupID := int64(1)
|
|
|
|
|
|
repo := stubOpenAIAccountRepo{
|
|
|
|
|
|
accounts: []Account{
|
|
|
|
|
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
|
|
|
|
|
|
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
cache := &stubGatewayCache{}
|
|
|
|
|
|
concurrencyCache := stubConcurrencyCache{
|
|
|
|
|
|
loadBatchErr: errors.New("load batch failed"),
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{
|
|
|
|
|
|
accountRepo: repo,
|
|
|
|
|
|
cache: cache,
|
|
|
|
|
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "fallback", "gpt-4", nil)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if selection == nil || selection.Account == nil {
|
|
|
|
|
|
t.Fatalf("expected selection")
|
|
|
|
|
|
}
|
|
|
|
|
|
if selection.Account.ID != 2 {
|
|
|
|
|
|
t.Fatalf("expected account 2, got %d", selection.Account.ID)
|
|
|
|
|
|
}
|
|
|
|
|
|
if cache.sessionBindings["openai:fallback"] != 2 {
|
|
|
|
|
|
t.Fatalf("expected sticky session updated")
|
|
|
|
|
|
}
|
|
|
|
|
|
if selection.ReleaseFunc != nil {
|
|
|
|
|
|
selection.ReleaseFunc()
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAISelectAccountWithLoadAwareness_NoSlotFallbackWait(t *testing.T) {
|
|
|
|
|
|
groupID := int64(1)
|
|
|
|
|
|
repo := stubOpenAIAccountRepo{
|
|
|
|
|
|
accounts: []Account{
|
|
|
|
|
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
cache := &stubGatewayCache{}
|
|
|
|
|
|
concurrencyCache := stubConcurrencyCache{
|
|
|
|
|
|
acquireResults: map[int64]bool{1: false},
|
|
|
|
|
|
loadMap: map[int64]*AccountLoadInfo{
|
|
|
|
|
|
1: {AccountID: 1, LoadRate: 10},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{
|
|
|
|
|
|
accountRepo: repo,
|
|
|
|
|
|
cache: cache,
|
|
|
|
|
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if selection == nil || selection.WaitPlan == nil {
|
|
|
|
|
|
t.Fatalf("expected wait plan fallback")
|
|
|
|
|
|
}
|
|
|
|
|
|
if selection.Account == nil || selection.Account.ID != 1 {
|
|
|
|
|
|
t.Fatalf("expected account 1")
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAISelectAccountForModelWithExclusions_SetsStickyBinding(t *testing.T) {
|
|
|
|
|
|
sessionHash := "bind"
|
|
|
|
|
|
repo := stubOpenAIAccountRepo{
|
|
|
|
|
|
accounts: []Account{
|
|
|
|
|
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
cache := &stubGatewayCache{}
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{
|
|
|
|
|
|
accountRepo: repo,
|
|
|
|
|
|
cache: cache,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if acc == nil || acc.ID != 1 {
|
|
|
|
|
|
t.Fatalf("expected account 1")
|
|
|
|
|
|
}
|
|
|
|
|
|
if cache.sessionBindings["openai:"+sessionHash] != 1 {
|
|
|
|
|
|
t.Fatalf("expected sticky session binding")
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAISelectAccountWithLoadAwareness_StickyWaitPlan(t *testing.T) {
|
|
|
|
|
|
sessionHash := "sticky-wait"
|
|
|
|
|
|
groupID := int64(1)
|
|
|
|
|
|
repo := stubOpenAIAccountRepo{
|
|
|
|
|
|
accounts: []Account{
|
|
|
|
|
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
cache := &stubGatewayCache{
|
|
|
|
|
|
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
|
|
|
|
|
}
|
|
|
|
|
|
concurrencyCache := stubConcurrencyCache{
|
|
|
|
|
|
acquireResults: map[int64]bool{1: false},
|
|
|
|
|
|
waitCounts: map[int64]int{1: 0},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{
|
|
|
|
|
|
accountRepo: repo,
|
|
|
|
|
|
cache: cache,
|
|
|
|
|
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, sessionHash, "gpt-4", nil)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if selection == nil || selection.WaitPlan == nil {
|
|
|
|
|
|
t.Fatalf("expected sticky wait plan")
|
|
|
|
|
|
}
|
|
|
|
|
|
if selection.Account == nil || selection.Account.ID != 1 {
|
|
|
|
|
|
t.Fatalf("expected account 1")
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAISelectAccountWithLoadAwareness_PrefersLowerLoad(t *testing.T) {
|
|
|
|
|
|
groupID := int64(1)
|
|
|
|
|
|
repo := stubOpenAIAccountRepo{
|
|
|
|
|
|
accounts: []Account{
|
|
|
|
|
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
|
|
|
|
|
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
cache := &stubGatewayCache{}
|
|
|
|
|
|
concurrencyCache := stubConcurrencyCache{
|
|
|
|
|
|
loadMap: map[int64]*AccountLoadInfo{
|
|
|
|
|
|
1: {AccountID: 1, LoadRate: 80},
|
|
|
|
|
|
2: {AccountID: 2, LoadRate: 10},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{
|
|
|
|
|
|
accountRepo: repo,
|
|
|
|
|
|
cache: cache,
|
|
|
|
|
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "load", "gpt-4", nil)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
|
|
|
|
|
t.Fatalf("expected account 2")
|
|
|
|
|
|
}
|
|
|
|
|
|
if cache.sessionBindings["openai:load"] != 2 {
|
|
|
|
|
|
t.Fatalf("expected sticky session updated")
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAISelectAccountForModelWithExclusions_StickyExcludedFallback(t *testing.T) {
|
|
|
|
|
|
sessionHash := "excluded"
|
|
|
|
|
|
repo := stubOpenAIAccountRepo{
|
|
|
|
|
|
accounts: []Account{
|
|
|
|
|
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
|
|
|
|
|
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
cache := &stubGatewayCache{
|
|
|
|
|
|
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{
|
|
|
|
|
|
accountRepo: repo,
|
|
|
|
|
|
cache: cache,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
excluded := map[int64]struct{}{1: {}}
|
|
|
|
|
|
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", excluded)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if acc == nil || acc.ID != 2 {
|
|
|
|
|
|
t.Fatalf("expected account 2")
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAISelectAccountForModelWithExclusions_StickyNonOpenAI(t *testing.T) {
|
|
|
|
|
|
sessionHash := "non-openai"
|
|
|
|
|
|
repo := stubOpenAIAccountRepo{
|
|
|
|
|
|
accounts: []Account{
|
|
|
|
|
|
{ID: 1, Platform: PlatformAnthropic, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
|
|
|
|
|
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
cache := &stubGatewayCache{
|
|
|
|
|
|
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{
|
|
|
|
|
|
accountRepo: repo,
|
|
|
|
|
|
cache: cache,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if acc == nil || acc.ID != 2 {
|
|
|
|
|
|
t.Fatalf("expected account 2")
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAISelectAccountForModelWithExclusions_NoAccounts(t *testing.T) {
|
|
|
|
|
|
repo := stubOpenAIAccountRepo{accounts: []Account{}}
|
|
|
|
|
|
cache := &stubGatewayCache{}
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{
|
|
|
|
|
|
accountRepo: repo,
|
|
|
|
|
|
cache: cache,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "", nil)
|
|
|
|
|
|
if err == nil {
|
|
|
|
|
|
t.Fatalf("expected error for no accounts")
|
|
|
|
|
|
}
|
|
|
|
|
|
if acc != nil {
|
|
|
|
|
|
t.Fatalf("expected nil account")
|
|
|
|
|
|
}
|
|
|
|
|
|
if !strings.Contains(err.Error(), "no available OpenAI accounts") {
|
|
|
|
|
|
t.Fatalf("unexpected error: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAISelectAccountWithLoadAwareness_NoCandidates(t *testing.T) {
|
|
|
|
|
|
groupID := int64(1)
|
|
|
|
|
|
resetAt := time.Now().Add(1 * time.Hour)
|
|
|
|
|
|
repo := stubOpenAIAccountRepo{
|
|
|
|
|
|
accounts: []Account{
|
|
|
|
|
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, RateLimitResetAt: &resetAt},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
cache := &stubGatewayCache{}
|
|
|
|
|
|
concurrencyCache := stubConcurrencyCache{}
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{
|
|
|
|
|
|
accountRepo: repo,
|
|
|
|
|
|
cache: cache,
|
|
|
|
|
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
|
|
|
|
|
if err == nil {
|
|
|
|
|
|
t.Fatalf("expected error for no candidates")
|
|
|
|
|
|
}
|
|
|
|
|
|
if selection != nil {
|
|
|
|
|
|
t.Fatalf("expected nil selection")
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAISelectAccountWithLoadAwareness_AllFullWaitPlan(t *testing.T) {
|
|
|
|
|
|
groupID := int64(1)
|
|
|
|
|
|
repo := stubOpenAIAccountRepo{
|
|
|
|
|
|
accounts: []Account{
|
|
|
|
|
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
cache := &stubGatewayCache{}
|
|
|
|
|
|
concurrencyCache := stubConcurrencyCache{
|
|
|
|
|
|
loadMap: map[int64]*AccountLoadInfo{
|
|
|
|
|
|
1: {AccountID: 1, LoadRate: 100},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{
|
|
|
|
|
|
accountRepo: repo,
|
|
|
|
|
|
cache: cache,
|
|
|
|
|
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if selection == nil || selection.WaitPlan == nil {
|
|
|
|
|
|
t.Fatalf("expected wait plan")
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorNoAcquire(t *testing.T) {
|
|
|
|
|
|
groupID := int64(1)
|
|
|
|
|
|
repo := stubOpenAIAccountRepo{
|
|
|
|
|
|
accounts: []Account{
|
|
|
|
|
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
cache := &stubGatewayCache{}
|
|
|
|
|
|
concurrencyCache := stubConcurrencyCache{
|
|
|
|
|
|
loadBatchErr: errors.New("load batch failed"),
|
|
|
|
|
|
acquireResults: map[int64]bool{1: false},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{
|
|
|
|
|
|
accountRepo: repo,
|
|
|
|
|
|
cache: cache,
|
|
|
|
|
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if selection == nil || selection.WaitPlan == nil {
|
|
|
|
|
|
t.Fatalf("expected wait plan")
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAISelectAccountWithLoadAwareness_MissingLoadInfo(t *testing.T) {
|
|
|
|
|
|
groupID := int64(1)
|
|
|
|
|
|
repo := stubOpenAIAccountRepo{
|
|
|
|
|
|
accounts: []Account{
|
|
|
|
|
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
|
|
|
|
|
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
cache := &stubGatewayCache{}
|
|
|
|
|
|
concurrencyCache := stubConcurrencyCache{
|
|
|
|
|
|
loadMap: map[int64]*AccountLoadInfo{
|
|
|
|
|
|
1: {AccountID: 1, LoadRate: 50},
|
|
|
|
|
|
},
|
|
|
|
|
|
skipDefaultLoad: true,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{
|
|
|
|
|
|
accountRepo: repo,
|
|
|
|
|
|
cache: cache,
|
|
|
|
|
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
|
|
|
|
|
t.Fatalf("expected account 2")
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAISelectAccountForModelWithExclusions_LeastRecentlyUsed(t *testing.T) {
|
|
|
|
|
|
oldTime := time.Now().Add(-2 * time.Hour)
|
|
|
|
|
|
newTime := time.Now().Add(-1 * time.Hour)
|
|
|
|
|
|
repo := stubOpenAIAccountRepo{
|
|
|
|
|
|
accounts: []Account{
|
|
|
|
|
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1, LastUsedAt: &newTime},
|
|
|
|
|
|
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1, LastUsedAt: &oldTime},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
cache := &stubGatewayCache{}
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{
|
|
|
|
|
|
accountRepo: repo,
|
|
|
|
|
|
cache: cache,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if acc == nil || acc.ID != 2 {
|
|
|
|
|
|
t.Fatalf("expected account 2")
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAISelectAccountWithLoadAwareness_PreferNeverUsed(t *testing.T) {
|
|
|
|
|
|
groupID := int64(1)
|
|
|
|
|
|
lastUsed := time.Now().Add(-1 * time.Hour)
|
|
|
|
|
|
repo := stubOpenAIAccountRepo{
|
|
|
|
|
|
accounts: []Account{
|
|
|
|
|
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, LastUsedAt: &lastUsed},
|
|
|
|
|
|
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
cache := &stubGatewayCache{}
|
|
|
|
|
|
concurrencyCache := stubConcurrencyCache{
|
|
|
|
|
|
loadMap: map[int64]*AccountLoadInfo{
|
|
|
|
|
|
1: {AccountID: 1, LoadRate: 10},
|
|
|
|
|
|
2: {AccountID: 2, LoadRate: 10},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{
|
|
|
|
|
|
accountRepo: repo,
|
|
|
|
|
|
cache: cache,
|
|
|
|
|
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
|
|
|
|
|
t.Fatalf("expected account 2")
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-04 19:49:59 +08:00
|
|
|
|
func TestOpenAIStreamingTimeout(t *testing.T) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
cfg := &config.Config{
|
|
|
|
|
|
Gateway: config.GatewayConfig{
|
|
|
|
|
|
StreamDataIntervalTimeout: 1,
|
|
|
|
|
|
StreamKeepaliveInterval: 0,
|
|
|
|
|
|
MaxLineSize: defaultMaxLineSize,
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
|
|
|
|
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
|
|
|
|
|
|
|
|
|
|
|
pr, pw := io.Pipe()
|
|
|
|
|
|
resp := &http.Response{
|
|
|
|
|
|
StatusCode: http.StatusOK,
|
|
|
|
|
|
Body: pr,
|
|
|
|
|
|
Header: http.Header{},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
start := time.Now()
|
|
|
|
|
|
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, start, "model", "model")
|
|
|
|
|
|
_ = pw.Close()
|
|
|
|
|
|
_ = pr.Close()
|
|
|
|
|
|
|
|
|
|
|
|
if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") {
|
|
|
|
|
|
t.Fatalf("expected stream timeout error, got %v", err)
|
|
|
|
|
|
}
|
2026-01-19 13:53:39 +08:00
|
|
|
|
if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "stream_timeout") {
|
|
|
|
|
|
t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String())
|
2026-01-04 19:49:59 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-15 21:42:13 +08:00
|
|
|
|
func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
cfg := &config.Config{
|
|
|
|
|
|
Gateway: config.GatewayConfig{
|
|
|
|
|
|
StreamDataIntervalTimeout: 0,
|
|
|
|
|
|
StreamKeepaliveInterval: 0,
|
|
|
|
|
|
MaxLineSize: defaultMaxLineSize,
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
|
|
|
|
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
|
|
cancel()
|
|
|
|
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
|
|
|
|
|
|
|
|
|
|
|
|
resp := &http.Response{
|
|
|
|
|
|
StatusCode: http.StatusOK,
|
|
|
|
|
|
Body: cancelReadCloser{},
|
|
|
|
|
|
Header: http.Header{},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
t.Fatalf("expected nil error, got %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") {
|
|
|
|
|
|
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-15 21:51:14 +08:00
|
|
|
|
func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
cfg := &config.Config{
|
|
|
|
|
|
Gateway: config.GatewayConfig{
|
|
|
|
|
|
StreamDataIntervalTimeout: 0,
|
|
|
|
|
|
StreamKeepaliveInterval: 0,
|
|
|
|
|
|
MaxLineSize: defaultMaxLineSize,
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
|
|
|
|
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
|
|
|
|
|
c.Writer = &failingGinWriter{ResponseWriter: c.Writer, failAfter: 0}
|
|
|
|
|
|
|
|
|
|
|
|
pr, pw := io.Pipe()
|
|
|
|
|
|
resp := &http.Response{
|
|
|
|
|
|
StatusCode: http.StatusOK,
|
|
|
|
|
|
Body: pr,
|
|
|
|
|
|
Header: http.Header{},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
|
|
defer func() { _ = pw.Close() }()
|
|
|
|
|
|
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
|
|
|
|
|
|
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":3,\"output_tokens\":5,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
|
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
|
|
result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
|
|
|
|
|
_ = pr.Close()
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
t.Fatalf("expected nil error, got %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if result == nil || result.usage == nil {
|
|
|
|
|
|
t.Fatalf("expected usage result")
|
|
|
|
|
|
}
|
|
|
|
|
|
if result.usage.InputTokens != 3 || result.usage.OutputTokens != 5 || result.usage.CacheReadInputTokens != 1 {
|
|
|
|
|
|
t.Fatalf("unexpected usage: %+v", *result.usage)
|
|
|
|
|
|
}
|
|
|
|
|
|
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "write_failed") {
|
|
|
|
|
|
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-04 19:49:59 +08:00
|
|
|
|
func TestOpenAIStreamingTooLong(t *testing.T) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
cfg := &config.Config{
|
|
|
|
|
|
Gateway: config.GatewayConfig{
|
|
|
|
|
|
StreamDataIntervalTimeout: 0,
|
|
|
|
|
|
StreamKeepaliveInterval: 0,
|
|
|
|
|
|
MaxLineSize: 64 * 1024,
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
|
|
|
|
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
|
|
|
|
|
|
|
|
|
|
|
pr, pw := io.Pipe()
|
|
|
|
|
|
resp := &http.Response{
|
|
|
|
|
|
StatusCode: http.StatusOK,
|
|
|
|
|
|
Body: pr,
|
|
|
|
|
|
Header: http.Header{},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
go func() {
|
2026-01-04 22:10:32 +08:00
|
|
|
|
defer func() { _ = pw.Close() }()
|
2026-01-04 19:49:59 +08:00
|
|
|
|
// 写入超过 MaxLineSize 的单行数据,触发 ErrTooLong
|
|
|
|
|
|
payload := "data: " + strings.Repeat("a", 128*1024) + "\n"
|
|
|
|
|
|
_, _ = pw.Write([]byte(payload))
|
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
|
|
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 2}, time.Now(), "model", "model")
|
|
|
|
|
|
_ = pr.Close()
|
|
|
|
|
|
|
|
|
|
|
|
if !errors.Is(err, bufio.ErrTooLong) {
|
|
|
|
|
|
t.Fatalf("expected ErrTooLong, got %v", err)
|
|
|
|
|
|
}
|
2026-01-19 13:53:39 +08:00
|
|
|
|
if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "response_too_large") {
|
|
|
|
|
|
t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String())
|
2026-01-04 19:49:59 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
2026-01-05 13:54:43 +08:00
|
|
|
|
|
|
|
|
|
|
func TestOpenAINonStreamingContentTypePassThrough(t *testing.T) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
cfg := &config.Config{
|
|
|
|
|
|
Security: config.SecurityConfig{
|
|
|
|
|
|
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
|
|
|
|
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
|
|
|
|
|
|
|
|
|
|
|
body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
|
|
|
|
|
|
resp := &http.Response{
|
|
|
|
|
|
StatusCode: http.StatusOK,
|
|
|
|
|
|
Body: io.NopCloser(bytes.NewReader(body)),
|
|
|
|
|
|
Header: http.Header{"Content-Type": []string{"application/vnd.test+json"}},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
_, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model")
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
t.Fatalf("handleNonStreamingResponse error: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if !strings.Contains(rec.Header().Get("Content-Type"), "application/vnd.test+json") {
|
|
|
|
|
|
t.Fatalf("expected Content-Type passthrough, got %q", rec.Header().Get("Content-Type"))
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAINonStreamingContentTypeDefault(t *testing.T) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
cfg := &config.Config{
|
|
|
|
|
|
Security: config.SecurityConfig{
|
|
|
|
|
|
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
|
|
|
|
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
|
|
|
|
|
|
|
|
|
|
|
body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
|
|
|
|
|
|
resp := &http.Response{
|
|
|
|
|
|
StatusCode: http.StatusOK,
|
|
|
|
|
|
Body: io.NopCloser(bytes.NewReader(body)),
|
|
|
|
|
|
Header: http.Header{},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
_, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model")
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
t.Fatalf("handleNonStreamingResponse error: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if !strings.Contains(rec.Header().Get("Content-Type"), "application/json") {
|
|
|
|
|
|
t.Fatalf("expected default Content-Type, got %q", rec.Header().Get("Content-Type"))
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAIStreamingHeadersOverride(t *testing.T) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
cfg := &config.Config{
|
|
|
|
|
|
Security: config.SecurityConfig{
|
|
|
|
|
|
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
|
|
|
|
|
|
},
|
|
|
|
|
|
Gateway: config.GatewayConfig{
|
|
|
|
|
|
StreamDataIntervalTimeout: 0,
|
|
|
|
|
|
StreamKeepaliveInterval: 0,
|
|
|
|
|
|
MaxLineSize: defaultMaxLineSize,
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
|
|
|
|
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
|
|
|
|
|
|
|
|
|
|
|
pr, pw := io.Pipe()
|
|
|
|
|
|
resp := &http.Response{
|
|
|
|
|
|
StatusCode: http.StatusOK,
|
|
|
|
|
|
Body: pr,
|
|
|
|
|
|
Header: http.Header{
|
|
|
|
|
|
"Cache-Control": []string{"upstream"},
|
2026-01-05 14:41:08 +08:00
|
|
|
|
"X-Request-Id": []string{"req-123"},
|
2026-01-05 13:54:43 +08:00
|
|
|
|
"Content-Type": []string{"application/custom"},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
|
|
defer func() { _ = pw.Close() }()
|
|
|
|
|
|
_, _ = pw.Write([]byte("data: {}\n\n"))
|
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
|
|
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
|
|
|
|
|
_ = pr.Close()
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
t.Fatalf("handleStreamingResponse error: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if rec.Header().Get("Cache-Control") != "no-cache" {
|
|
|
|
|
|
t.Fatalf("expected Cache-Control override, got %q", rec.Header().Get("Cache-Control"))
|
|
|
|
|
|
}
|
|
|
|
|
|
if rec.Header().Get("Content-Type") != "text/event-stream" {
|
|
|
|
|
|
t.Fatalf("expected Content-Type override, got %q", rec.Header().Get("Content-Type"))
|
|
|
|
|
|
}
|
2026-01-05 14:41:08 +08:00
|
|
|
|
if rec.Header().Get("X-Request-Id") != "req-123" {
|
|
|
|
|
|
t.Fatalf("expected X-Request-Id passthrough, got %q", rec.Header().Get("X-Request-Id"))
|
2026-01-05 13:54:43 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-02-06 22:55:12 +08:00
|
|
|
|
func TestOpenAIStreamingReuseScannerBufferAndStillWorks(t *testing.T) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
cfg := &config.Config{
|
|
|
|
|
|
Gateway: config.GatewayConfig{
|
|
|
|
|
|
StreamDataIntervalTimeout: 0,
|
|
|
|
|
|
StreamKeepaliveInterval: 0,
|
|
|
|
|
|
MaxLineSize: defaultMaxLineSize,
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
|
|
|
|
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
|
|
|
|
|
|
|
|
|
|
|
pr, pw := io.Pipe()
|
|
|
|
|
|
resp := &http.Response{
|
|
|
|
|
|
StatusCode: http.StatusOK,
|
|
|
|
|
|
Body: pr,
|
|
|
|
|
|
Header: http.Header{},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
|
|
defer func() { _ = pw.Close() }()
|
|
|
|
|
|
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"input_tokens_details\":{\"cached_tokens\":3}}}}\n\n"))
|
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
|
|
result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
|
|
|
|
|
_ = pr.Close()
|
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
require.NotNil(t, result)
|
|
|
|
|
|
require.NotNil(t, result.usage)
|
|
|
|
|
|
require.Equal(t, 1, result.usage.InputTokens)
|
|
|
|
|
|
require.Equal(t, 2, result.usage.OutputTokens)
|
|
|
|
|
|
require.Equal(t, 3, result.usage.CacheReadInputTokens)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-05 13:54:43 +08:00
|
|
|
|
func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
cfg := &config.Config{
|
|
|
|
|
|
Security: config.SecurityConfig{
|
|
|
|
|
|
URLAllowlist: config.URLAllowlistConfig{Enabled: false},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
|
|
|
|
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
|
|
|
|
|
|
|
|
|
|
|
account := &Account{
|
|
|
|
|
|
Platform: PlatformOpenAI,
|
|
|
|
|
|
Type: AccountTypeAPIKey,
|
|
|
|
|
|
Credentials: map[string]any{"base_url": "://invalid-url"},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-10 20:53:16 +08:00
|
|
|
|
_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false, "", false)
|
2026-01-05 13:54:43 +08:00
|
|
|
|
if err == nil {
|
|
|
|
|
|
t.Fatalf("expected error for invalid base_url when allowlist disabled")
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-05 14:41:08 +08:00
|
|
|
|
func TestOpenAIValidateUpstreamBaseURLDisabledRequiresHTTPS(t *testing.T) {
|
2026-01-05 13:54:43 +08:00
|
|
|
|
cfg := &config.Config{
|
|
|
|
|
|
Security: config.SecurityConfig{
|
|
|
|
|
|
URLAllowlist: config.URLAllowlistConfig{Enabled: false},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
|
|
|
2026-01-05 14:41:08 +08:00
|
|
|
|
if _, err := svc.validateUpstreamBaseURL("http://not-https.example.com"); err == nil {
|
|
|
|
|
|
t.Fatalf("expected http to be rejected when allow_insecure_http is false")
|
|
|
|
|
|
}
|
|
|
|
|
|
normalized, err := svc.validateUpstreamBaseURL("https://example.com")
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
t.Fatalf("expected https to be allowed when allowlist disabled, got %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if normalized != "https://example.com" {
|
|
|
|
|
|
t.Fatalf("expected raw url passthrough, got %q", normalized)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAIValidateUpstreamBaseURLDisabledAllowsHTTP(t *testing.T) {
|
|
|
|
|
|
cfg := &config.Config{
|
|
|
|
|
|
Security: config.SecurityConfig{
|
|
|
|
|
|
URLAllowlist: config.URLAllowlistConfig{
|
|
|
|
|
|
Enabled: false,
|
|
|
|
|
|
AllowInsecureHTTP: true,
|
|
|
|
|
|
},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
|
|
|
2026-01-05 13:54:43 +08:00
|
|
|
|
normalized, err := svc.validateUpstreamBaseURL("http://not-https.example.com")
|
|
|
|
|
|
if err != nil {
|
2026-01-05 14:41:08 +08:00
|
|
|
|
t.Fatalf("expected http allowed when allow_insecure_http is true, got %v", err)
|
2026-01-05 13:54:43 +08:00
|
|
|
|
}
|
|
|
|
|
|
if normalized != "http://not-https.example.com" {
|
|
|
|
|
|
t.Fatalf("expected raw url passthrough, got %q", normalized)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) {
|
|
|
|
|
|
cfg := &config.Config{
|
|
|
|
|
|
Security: config.SecurityConfig{
|
|
|
|
|
|
URLAllowlist: config.URLAllowlistConfig{
|
|
|
|
|
|
Enabled: true,
|
|
|
|
|
|
UpstreamHosts: []string{"example.com"},
|
|
|
|
|
|
},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
|
|
|
|
|
|
|
|
if _, err := svc.validateUpstreamBaseURL("https://example.com"); err != nil {
|
|
|
|
|
|
t.Fatalf("expected allowlisted host to pass, got %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if _, err := svc.validateUpstreamBaseURL("https://evil.com"); err == nil {
|
|
|
|
|
|
t.Fatalf("expected non-allowlisted host to fail")
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
2026-02-07 17:09:55 +08:00
|
|
|
|
|
2026-03-06 20:46:10 +08:00
|
|
|
|
func TestOpenAIUpdateCodexUsageSnapshotFromHeaders(t *testing.T) {
|
|
|
|
|
|
repo := &snapshotUpdateAccountRepo{updateExtraCalls: make(chan map[string]any, 1)}
|
|
|
|
|
|
svc := &OpenAIGatewayService{accountRepo: repo}
|
|
|
|
|
|
headers := http.Header{}
|
|
|
|
|
|
headers.Set("x-codex-primary-used-percent", "12")
|
|
|
|
|
|
headers.Set("x-codex-secondary-used-percent", "34")
|
|
|
|
|
|
headers.Set("x-codex-primary-window-minutes", "300")
|
|
|
|
|
|
headers.Set("x-codex-secondary-window-minutes", "10080")
|
|
|
|
|
|
headers.Set("x-codex-primary-reset-after-seconds", "600")
|
|
|
|
|
|
headers.Set("x-codex-secondary-reset-after-seconds", "86400")
|
|
|
|
|
|
|
|
|
|
|
|
svc.UpdateCodexUsageSnapshotFromHeaders(context.Background(), 123, headers)
|
|
|
|
|
|
|
|
|
|
|
|
select {
|
|
|
|
|
|
case updates := <-repo.updateExtraCalls:
|
|
|
|
|
|
require.Equal(t, 12.0, updates["codex_5h_used_percent"])
|
|
|
|
|
|
require.Equal(t, 34.0, updates["codex_7d_used_percent"])
|
|
|
|
|
|
require.Equal(t, 600, updates["codex_5h_reset_after_seconds"])
|
|
|
|
|
|
require.Equal(t, 86400, updates["codex_7d_reset_after_seconds"])
|
|
|
|
|
|
case <-time.After(2 * time.Second):
|
|
|
|
|
|
t.Fatal("expected UpdateExtra to be called")
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-03-06 18:50:28 +08:00
|
|
|
|
func TestOpenAIResponsesRequestPathSuffix(t *testing.T) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
|
|
|
|
|
|
|
tests := []struct {
|
|
|
|
|
|
name string
|
|
|
|
|
|
path string
|
|
|
|
|
|
want string
|
|
|
|
|
|
}{
|
|
|
|
|
|
{name: "exact v1 responses", path: "/v1/responses", want: ""},
|
|
|
|
|
|
{name: "compact v1 responses", path: "/v1/responses/compact", want: "/compact"},
|
|
|
|
|
|
{name: "compact alias responses", path: "/responses/compact/", want: "/compact"},
|
|
|
|
|
|
{name: "nested suffix", path: "/openai/v1/responses/compact/detail", want: "/compact/detail"},
|
|
|
|
|
|
{name: "unrelated path", path: "/v1/chat/completions", want: ""},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
|
c.Request = httptest.NewRequest(http.MethodPost, tt.path, nil)
|
|
|
|
|
|
require.Equal(t, tt.want, openAIResponsesRequestPathSuffix(c))
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAIBuildUpstreamRequestOpenAIPassthroughPreservesCompactPath(t *testing.T) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", bytes.NewReader([]byte(`{"model":"gpt-5"}`)))
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{}
|
|
|
|
|
|
account := &Account{Type: AccountTypeOAuth}
|
|
|
|
|
|
|
|
|
|
|
|
req, err := svc.buildUpstreamRequestOpenAIPassthrough(c.Request.Context(), c, account, []byte(`{"model":"gpt-5"}`), "token")
|
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
require.Equal(t, chatgptCodexURL+"/compact", req.URL.String())
|
|
|
|
|
|
require.Equal(t, "application/json", req.Header.Get("Accept"))
|
|
|
|
|
|
require.Equal(t, codexCLIVersion, req.Header.Get("Version"))
|
|
|
|
|
|
require.NotEmpty(t, req.Header.Get("Session_Id"))
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAIBuildUpstreamRequestCompactForcesJSONAcceptForOAuth(t *testing.T) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", bytes.NewReader([]byte(`{"model":"gpt-5"}`)))
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{}
|
|
|
|
|
|
account := &Account{
|
|
|
|
|
|
Type: AccountTypeOAuth,
|
|
|
|
|
|
Credentials: map[string]any{"chatgpt_account_id": "chatgpt-acc"},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
req, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte(`{"model":"gpt-5"}`), "token", false, "", true)
|
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
require.Equal(t, chatgptCodexURL+"/compact", req.URL.String())
|
|
|
|
|
|
require.Equal(t, "application/json", req.Header.Get("Accept"))
|
|
|
|
|
|
require.Equal(t, codexCLIVersion, req.Header.Get("Version"))
|
|
|
|
|
|
require.NotEmpty(t, req.Header.Get("Session_Id"))
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestOpenAIBuildUpstreamRequestPreservesCompactPathForAPIKeyBaseURL(t *testing.T) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact", bytes.NewReader([]byte(`{"model":"gpt-5"}`)))
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{cfg: &config.Config{
|
|
|
|
|
|
Security: config.SecurityConfig{
|
|
|
|
|
|
URLAllowlist: config.URLAllowlistConfig{Enabled: false},
|
|
|
|
|
|
},
|
|
|
|
|
|
}}
|
|
|
|
|
|
account := &Account{
|
|
|
|
|
|
Type: AccountTypeAPIKey,
|
|
|
|
|
|
Platform: PlatformOpenAI,
|
|
|
|
|
|
Credentials: map[string]any{"base_url": "https://example.com/v1"},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
req, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte(`{"model":"gpt-5"}`), "token", false, "", false)
|
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
require.Equal(t, "https://example.com/v1/responses/compact", req.URL.String())
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-03-07 14:12:38 +08:00
|
|
|
|
func TestOpenAIBuildUpstreamRequestOAuthOfficialClientOriginatorCompatibility(t *testing.T) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
|
|
|
|
|
|
tests := []struct {
|
|
|
|
|
|
name string
|
|
|
|
|
|
userAgent string
|
|
|
|
|
|
originator string
|
|
|
|
|
|
wantOriginator string
|
|
|
|
|
|
}{
|
|
|
|
|
|
{name: "desktop originator preserved", originator: "Codex Desktop", wantOriginator: "Codex Desktop"},
|
|
|
|
|
|
{name: "vscode originator preserved", originator: "codex_vscode", wantOriginator: "codex_vscode"},
|
|
|
|
|
|
{name: "official ua fallback to codex_cli_rs", userAgent: "Codex Desktop/1.2.3", wantOriginator: "codex_cli_rs"},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader([]byte(`{"model":"gpt-5"}`)))
|
|
|
|
|
|
if tt.userAgent != "" {
|
|
|
|
|
|
c.Request.Header.Set("User-Agent", tt.userAgent)
|
|
|
|
|
|
}
|
|
|
|
|
|
if tt.originator != "" {
|
|
|
|
|
|
c.Request.Header.Set("originator", tt.originator)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{}
|
|
|
|
|
|
account := &Account{
|
|
|
|
|
|
Type: AccountTypeOAuth,
|
|
|
|
|
|
Credentials: map[string]any{"chatgpt_account_id": "chatgpt-acc"},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator"))
|
|
|
|
|
|
req, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte(`{"model":"gpt-5"}`), "token", false, "", isCodexCLI)
|
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
require.Equal(t, tt.wantOriginator, req.Header.Get("originator"))
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-02-07 17:09:55 +08:00
|
|
|
|
// ==================== P1-08 fix: model replacement performance optimization tests ====================
|
2026-02-07 17:09:55 +08:00
|
|
|
|
func TestReplaceModelInSSELine(t *testing.T) {
|
|
|
|
|
|
svc := &OpenAIGatewayService{}
|
|
|
|
|
|
|
|
|
|
|
|
tests := []struct {
|
|
|
|
|
|
name string
|
|
|
|
|
|
line string
|
|
|
|
|
|
from string
|
|
|
|
|
|
to string
|
|
|
|
|
|
expected string
|
|
|
|
|
|
}{
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "顶层 model 字段替换",
|
|
|
|
|
|
line: `data: {"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`,
|
|
|
|
|
|
from: "gpt-4o",
|
|
|
|
|
|
to: "my-custom-model",
|
|
|
|
|
|
expected: `data: {"id":"chatcmpl-123","model":"my-custom-model","choices":[]}`,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "嵌套 response.model 替换",
|
|
|
|
|
|
line: `data: {"type":"response","response":{"id":"resp-1","model":"gpt-4o","output":[]}}`,
|
|
|
|
|
|
from: "gpt-4o",
|
|
|
|
|
|
to: "my-model",
|
|
|
|
|
|
expected: `data: {"type":"response","response":{"id":"resp-1","model":"my-model","output":[]}}`,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "model 不匹配时不替换",
|
|
|
|
|
|
line: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
|
|
|
|
|
|
from: "gpt-4o",
|
|
|
|
|
|
to: "my-model",
|
|
|
|
|
|
expected: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "无 model 字段时不替换",
|
|
|
|
|
|
line: `data: {"id":"chatcmpl-123","choices":[]}`,
|
|
|
|
|
|
from: "gpt-4o",
|
|
|
|
|
|
to: "my-model",
|
|
|
|
|
|
expected: `data: {"id":"chatcmpl-123","choices":[]}`,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "空 data 行",
|
|
|
|
|
|
line: `data: `,
|
|
|
|
|
|
from: "gpt-4o",
|
|
|
|
|
|
to: "my-model",
|
|
|
|
|
|
expected: `data: `,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "[DONE] 行",
|
|
|
|
|
|
line: `data: [DONE]`,
|
|
|
|
|
|
from: "gpt-4o",
|
|
|
|
|
|
to: "my-model",
|
|
|
|
|
|
expected: `data: [DONE]`,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "非 data: 前缀行",
|
|
|
|
|
|
line: `event: message`,
|
|
|
|
|
|
from: "gpt-4o",
|
|
|
|
|
|
to: "my-model",
|
|
|
|
|
|
expected: `event: message`,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "非法 JSON 不替换",
|
|
|
|
|
|
line: `data: {invalid json}`,
|
|
|
|
|
|
from: "gpt-4o",
|
|
|
|
|
|
to: "my-model",
|
|
|
|
|
|
expected: `data: {invalid json}`,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "无空格 data: 格式",
|
|
|
|
|
|
line: `data:{"id":"x","model":"gpt-4o"}`,
|
|
|
|
|
|
from: "gpt-4o",
|
|
|
|
|
|
to: "my-model",
|
|
|
|
|
|
expected: `data: {"id":"x","model":"my-model"}`,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "model 名含特殊字符",
|
|
|
|
|
|
line: `data: {"model":"org/model-v2.1-beta"}`,
|
|
|
|
|
|
from: "org/model-v2.1-beta",
|
|
|
|
|
|
to: "custom/alias",
|
|
|
|
|
|
expected: `data: {"model":"custom/alias"}`,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "空行",
|
|
|
|
|
|
line: "",
|
|
|
|
|
|
from: "gpt-4o",
|
|
|
|
|
|
to: "my-model",
|
|
|
|
|
|
expected: "",
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "保持其他字段不变",
|
|
|
|
|
|
line: `data: {"id":"abc","object":"chat.completion.chunk","model":"gpt-4o","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`,
|
|
|
|
|
|
from: "gpt-4o",
|
|
|
|
|
|
to: "alias",
|
|
|
|
|
|
expected: `data: {"id":"abc","object":"chat.completion.chunk","model":"alias","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "顶层优先于嵌套:同时存在两个 model",
|
|
|
|
|
|
line: `data: {"model":"gpt-4o","response":{"model":"gpt-4o"}}`,
|
|
|
|
|
|
from: "gpt-4o",
|
|
|
|
|
|
to: "replaced",
|
|
|
|
|
|
expected: `data: {"model":"replaced","response":{"model":"gpt-4o"}}`,
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
|
got := svc.replaceModelInSSELine(tt.line, tt.from, tt.to)
|
|
|
|
|
|
require.Equal(t, tt.expected, got)
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestReplaceModelInSSEBody(t *testing.T) {
|
|
|
|
|
|
svc := &OpenAIGatewayService{}
|
|
|
|
|
|
|
|
|
|
|
|
tests := []struct {
|
|
|
|
|
|
name string
|
|
|
|
|
|
body string
|
|
|
|
|
|
from string
|
|
|
|
|
|
to string
|
|
|
|
|
|
expected string
|
|
|
|
|
|
}{
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "多行 SSE body 替换",
|
|
|
|
|
|
body: "data: {\"model\":\"gpt-4o\",\"choices\":[]}\n\ndata: {\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n",
|
|
|
|
|
|
from: "gpt-4o",
|
|
|
|
|
|
to: "alias",
|
|
|
|
|
|
expected: "data: {\"model\":\"alias\",\"choices\":[]}\n\ndata: {\"model\":\"alias\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n",
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "无需替换的 body",
|
|
|
|
|
|
body: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n",
|
|
|
|
|
|
from: "gpt-4o",
|
|
|
|
|
|
to: "alias",
|
|
|
|
|
|
expected: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n",
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "混合 event 和 data 行",
|
|
|
|
|
|
body: "event: message\ndata: {\"model\":\"gpt-4o\"}\n\n",
|
|
|
|
|
|
from: "gpt-4o",
|
|
|
|
|
|
to: "alias",
|
|
|
|
|
|
expected: "event: message\ndata: {\"model\":\"alias\"}\n\n",
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "空 body",
|
|
|
|
|
|
body: "",
|
|
|
|
|
|
from: "gpt-4o",
|
|
|
|
|
|
to: "alias",
|
|
|
|
|
|
expected: "",
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
|
got := svc.replaceModelInSSEBody(tt.body, tt.from, tt.to)
|
|
|
|
|
|
require.Equal(t, tt.expected, got)
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestReplaceModelInResponseBody(t *testing.T) {
|
|
|
|
|
|
svc := &OpenAIGatewayService{}
|
|
|
|
|
|
|
|
|
|
|
|
tests := []struct {
|
|
|
|
|
|
name string
|
|
|
|
|
|
body string
|
|
|
|
|
|
from string
|
|
|
|
|
|
to string
|
|
|
|
|
|
expected string
|
|
|
|
|
|
}{
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "替换顶层 model",
|
|
|
|
|
|
body: `{"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`,
|
|
|
|
|
|
from: "gpt-4o",
|
|
|
|
|
|
to: "alias",
|
|
|
|
|
|
expected: `{"id":"chatcmpl-123","model":"alias","choices":[]}`,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "model 不匹配不替换",
|
|
|
|
|
|
body: `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
|
|
|
|
|
|
from: "gpt-4o",
|
|
|
|
|
|
to: "alias",
|
|
|
|
|
|
expected: `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "无 model 字段不替换",
|
|
|
|
|
|
body: `{"id":"chatcmpl-123","choices":[]}`,
|
|
|
|
|
|
from: "gpt-4o",
|
|
|
|
|
|
to: "alias",
|
|
|
|
|
|
expected: `{"id":"chatcmpl-123","choices":[]}`,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "非法 JSON 返回原值",
|
|
|
|
|
|
body: `not json`,
|
|
|
|
|
|
from: "gpt-4o",
|
|
|
|
|
|
to: "alias",
|
|
|
|
|
|
expected: `not json`,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "空 body 返回原值",
|
|
|
|
|
|
body: ``,
|
|
|
|
|
|
from: "gpt-4o",
|
|
|
|
|
|
to: "alias",
|
|
|
|
|
|
expected: ``,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "保持嵌套结构不变",
|
|
|
|
|
|
body: `{"model":"gpt-4o","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`,
|
|
|
|
|
|
from: "gpt-4o",
|
|
|
|
|
|
to: "alias",
|
|
|
|
|
|
expected: `{"model":"alias","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`,
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
|
got := svc.replaceModelInResponseBody([]byte(tt.body), tt.from, tt.to)
|
|
|
|
|
|
require.Equal(t, tt.expected, string(got))
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
2026-02-12 09:41:37 +08:00
|
|
|
|
|
|
|
|
|
|
func TestExtractOpenAISSEDataLine(t *testing.T) {
|
|
|
|
|
|
tests := []struct {
|
|
|
|
|
|
name string
|
|
|
|
|
|
line string
|
|
|
|
|
|
wantData string
|
|
|
|
|
|
wantOK bool
|
|
|
|
|
|
}{
|
|
|
|
|
|
{name: "标准格式", line: `data: {"type":"x"}`, wantData: `{"type":"x"}`, wantOK: true},
|
|
|
|
|
|
{name: "无空格格式", line: `data:{"type":"x"}`, wantData: `{"type":"x"}`, wantOK: true},
|
|
|
|
|
|
{name: "纯空数据", line: `data: `, wantData: ``, wantOK: true},
|
|
|
|
|
|
{name: "非 data 行", line: `event: message`, wantData: ``, wantOK: false},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
|
got, ok := extractOpenAISSEDataLine(tt.line)
|
|
|
|
|
|
require.Equal(t, tt.wantOK, ok)
|
|
|
|
|
|
require.Equal(t, tt.wantData, got)
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestParseSSEUsage_SelectiveParsing(t *testing.T) {
|
|
|
|
|
|
svc := &OpenAIGatewayService{}
|
|
|
|
|
|
usage := &OpenAIUsage{InputTokens: 9, OutputTokens: 8, CacheReadInputTokens: 7}
|
|
|
|
|
|
|
|
|
|
|
|
// 非 completed 事件,不应覆盖 usage
|
|
|
|
|
|
svc.parseSSEUsage(`{"type":"response.in_progress","response":{"usage":{"input_tokens":1,"output_tokens":2}}}`, usage)
|
|
|
|
|
|
require.Equal(t, 9, usage.InputTokens)
|
|
|
|
|
|
require.Equal(t, 8, usage.OutputTokens)
|
|
|
|
|
|
require.Equal(t, 7, usage.CacheReadInputTokens)
|
|
|
|
|
|
|
|
|
|
|
|
// completed 事件,应提取 usage
|
|
|
|
|
|
svc.parseSSEUsage(`{"type":"response.completed","response":{"usage":{"input_tokens":3,"output_tokens":5,"input_tokens_details":{"cached_tokens":2}}}}`, usage)
|
|
|
|
|
|
require.Equal(t, 3, usage.InputTokens)
|
|
|
|
|
|
require.Equal(t, 5, usage.OutputTokens)
|
|
|
|
|
|
require.Equal(t, 2, usage.CacheReadInputTokens)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) {
|
|
|
|
|
|
body := strings.Join([]string{
|
|
|
|
|
|
`event: message`,
|
|
|
|
|
|
`data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
|
|
|
|
|
|
`data: {"type":"response.completed","response":{"id":"resp_1","model":"gpt-4o","usage":{"input_tokens":11,"output_tokens":22,"input_tokens_details":{"cached_tokens":3}}}}`,
|
|
|
|
|
|
`data: [DONE]`,
|
|
|
|
|
|
}, "\n")
|
|
|
|
|
|
|
|
|
|
|
|
finalResp, ok := extractCodexFinalResponse(body)
|
|
|
|
|
|
require.True(t, ok)
|
|
|
|
|
|
require.Contains(t, string(finalResp), `"id":"resp_1"`)
|
|
|
|
|
|
require.Contains(t, string(finalResp), `"input_tokens":11`)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestHandleOAuthSSEToJSON_CompletedEventReturnsJSON(t *testing.T) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{cfg: &config.Config{}}
|
|
|
|
|
|
resp := &http.Response{
|
|
|
|
|
|
StatusCode: http.StatusOK,
|
|
|
|
|
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
|
|
|
|
|
}
|
|
|
|
|
|
body := []byte(strings.Join([]string{
|
|
|
|
|
|
`data: {"type":"response.in_progress","response":{"id":"resp_2"}}`,
|
|
|
|
|
|
`data: {"type":"response.completed","response":{"id":"resp_2","model":"gpt-4o","usage":{"input_tokens":7,"output_tokens":9,"input_tokens_details":{"cached_tokens":1}}}}`,
|
|
|
|
|
|
`data: [DONE]`,
|
|
|
|
|
|
}, "\n"))
|
|
|
|
|
|
|
|
|
|
|
|
usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o")
|
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
require.NotNil(t, usage)
|
|
|
|
|
|
require.Equal(t, 7, usage.InputTokens)
|
|
|
|
|
|
require.Equal(t, 9, usage.OutputTokens)
|
|
|
|
|
|
require.Equal(t, 1, usage.CacheReadInputTokens)
|
|
|
|
|
|
// Header 可能由上游 Content-Type 透传;关键是 body 已转换为最终 JSON 响应。
|
|
|
|
|
|
require.NotContains(t, rec.Body.String(), "event:")
|
|
|
|
|
|
require.Contains(t, rec.Body.String(), `"id":"resp_2"`)
|
|
|
|
|
|
require.NotContains(t, rec.Body.String(), "data:")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{cfg: &config.Config{}}
|
|
|
|
|
|
resp := &http.Response{
|
|
|
|
|
|
StatusCode: http.StatusOK,
|
|
|
|
|
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
|
|
|
|
|
}
|
|
|
|
|
|
body := []byte(strings.Join([]string{
|
|
|
|
|
|
`data: {"type":"response.in_progress","response":{"id":"resp_3"}}`,
|
|
|
|
|
|
`data: [DONE]`,
|
|
|
|
|
|
}, "\n"))
|
|
|
|
|
|
|
|
|
|
|
|
usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o")
|
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
require.NotNil(t, usage)
|
|
|
|
|
|
require.Equal(t, 0, usage.InputTokens)
|
|
|
|
|
|
require.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream")
|
|
|
|
|
|
require.Contains(t, rec.Body.String(), `data: {"type":"response.in_progress"`)
|
|
|
|
|
|
}
|
2026-03-06 14:54:52 +08:00
|
|
|
|
|
|
|
|
|
|
func TestHandleOAuthSSEToJSON_ResponseFailedReturnsProtocolError(t *testing.T) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
|
|
|
|
|
|
|
|
|
|
|
svc := &OpenAIGatewayService{cfg: &config.Config{}}
|
|
|
|
|
|
resp := &http.Response{
|
|
|
|
|
|
StatusCode: http.StatusOK,
|
|
|
|
|
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
|
|
|
|
|
}
|
|
|
|
|
|
body := []byte(strings.Join([]string{
|
|
|
|
|
|
`data: {"type":"response.failed","error":{"message":"upstream rejected request"}}`,
|
|
|
|
|
|
`data: [DONE]`,
|
|
|
|
|
|
}, "\n"))
|
|
|
|
|
|
|
|
|
|
|
|
usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o")
|
|
|
|
|
|
require.Nil(t, usage)
|
|
|
|
|
|
require.Error(t, err)
|
|
|
|
|
|
require.Equal(t, http.StatusBadGateway, rec.Code)
|
|
|
|
|
|
require.Contains(t, rec.Body.String(), "upstream rejected request")
|
|
|
|
|
|
require.Contains(t, rec.Header().Get("Content-Type"), "application/json")
|
|
|
|
|
|
}
|