mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-07 17:00:20 +08:00
Compare commits
39 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
826090e099 | ||
|
|
7399de6ecc | ||
|
|
25cb5e7505 | ||
|
|
5c13ec3121 | ||
|
|
d8aff3a7e3 | ||
|
|
f44927b9f8 | ||
|
|
c0110cb5af | ||
|
|
1f8e1142a0 | ||
|
|
1e51de88d6 | ||
|
|
30995b5397 | ||
|
|
eb60f67054 | ||
|
|
78193ceec1 | ||
|
|
f0e08e7687 | ||
|
|
10b8259259 | ||
|
|
eb0b77bf4d | ||
|
|
9d81467937 | ||
|
|
fd8ccaf01a | ||
|
|
2b30e3b6d7 | ||
|
|
6e90ec6111 | ||
|
|
8dd38f4775 | ||
|
|
fbd73f248f | ||
|
|
3fcefe6c32 | ||
|
|
f740d2c291 | ||
|
|
bf6585a40f | ||
|
|
8c2dd7b3f0 | ||
|
|
4167c437a8 | ||
|
|
0ddaef3c9a | ||
|
|
2fc6aaf936 | ||
|
|
1c0519f1c7 | ||
|
|
6bbe7800be | ||
|
|
2694149489 | ||
|
|
a17ac50118 | ||
|
|
656a77d585 | ||
|
|
7455476c60 | ||
|
|
36cda57c81 | ||
|
|
9f1f203b84 | ||
|
|
b41a8ca93f | ||
|
|
e3cf0c0e10 | ||
|
|
de18bce9aa |
@@ -84,10 +84,12 @@ var DefaultAntigravityModelMapping = map[string]string{
|
|||||||
"claude-haiku-4-5": "claude-sonnet-4-5",
|
"claude-haiku-4-5": "claude-sonnet-4-5",
|
||||||
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
|
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
|
||||||
// Gemini 2.5 白名单
|
// Gemini 2.5 白名单
|
||||||
"gemini-2.5-flash": "gemini-2.5-flash",
|
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||||
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
|
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
|
||||||
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
"gemini-2.5-flash-image-preview": "gemini-2.5-flash-image",
|
||||||
"gemini-2.5-pro": "gemini-2.5-pro",
|
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
|
||||||
|
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
||||||
|
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||||
// Gemini 3 白名单
|
// Gemini 3 白名单
|
||||||
"gemini-3-flash": "gemini-3-flash",
|
"gemini-3-flash": "gemini-3-flash",
|
||||||
"gemini-3-pro-high": "gemini-3-pro-high",
|
"gemini-3-pro-high": "gemini-3-pro-high",
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T)
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
cases := map[string]string{
|
cases := map[string]string{
|
||||||
|
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
|
||||||
|
"gemini-2.5-flash-image-preview": "gemini-2.5-flash-image",
|
||||||
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
|
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
|
||||||
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
|
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
|
||||||
"gemini-3-pro-image": "gemini-3.1-flash-image",
|
"gemini-3-pro-image": "gemini-3.1-flash-image",
|
||||||
|
|||||||
@@ -628,6 +628,7 @@ func (h *AccountHandler) Delete(c *gin.Context) {
|
|||||||
// TestAccountRequest represents the request body for testing an account
|
// TestAccountRequest represents the request body for testing an account
|
||||||
type TestAccountRequest struct {
|
type TestAccountRequest struct {
|
||||||
ModelID string `json:"model_id"`
|
ModelID string `json:"model_id"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type SyncFromCRSRequest struct {
|
type SyncFromCRSRequest struct {
|
||||||
@@ -658,7 +659,7 @@ func (h *AccountHandler) Test(c *gin.Context) {
|
|||||||
_ = c.ShouldBindJSON(&req)
|
_ = c.ShouldBindJSON(&req)
|
||||||
|
|
||||||
// Use AccountTestService to test the account with SSE streaming
|
// Use AccountTestService to test the account with SSE streaming
|
||||||
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID); err != nil {
|
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt); err != nil {
|
||||||
// Error already sent via SSE, just log
|
// Error already sent via SSE, just log
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -249,11 +249,12 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
trend, hit, err := h.getUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get usage trend")
|
response.Error(c, 500, "Failed to get usage trend")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||||
|
|
||||||
response.Success(c, gin.H{
|
response.Success(c, gin.H{
|
||||||
"trend": trend,
|
"trend": trend,
|
||||||
@@ -321,11 +322,12 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get model statistics")
|
response.Error(c, 500, "Failed to get model statistics")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||||
|
|
||||||
response.Success(c, gin.H{
|
response.Success(c, gin.H{
|
||||||
"models": stats,
|
"models": stats,
|
||||||
@@ -391,11 +393,12 @@ func (h *DashboardHandler) GetGroupStats(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
stats, hit, err := h.getGroupStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get group statistics")
|
response.Error(c, 500, "Failed to get group statistics")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||||
|
|
||||||
response.Success(c, gin.H{
|
response.Success(c, gin.H{
|
||||||
"groups": stats,
|
"groups": stats,
|
||||||
@@ -416,11 +419,12 @@ func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) {
|
|||||||
limit = 5
|
limit = 5
|
||||||
}
|
}
|
||||||
|
|
||||||
trend, err := h.dashboardService.GetAPIKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
trend, hit, err := h.getAPIKeyUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get API key usage trend")
|
response.Error(c, 500, "Failed to get API key usage trend")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||||
|
|
||||||
response.Success(c, gin.H{
|
response.Success(c, gin.H{
|
||||||
"trend": trend,
|
"trend": trend,
|
||||||
@@ -442,11 +446,12 @@ func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
|
|||||||
limit = 12
|
limit = 12
|
||||||
}
|
}
|
||||||
|
|
||||||
trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
trend, hit, err := h.getUserUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get user usage trend")
|
response.Error(c, 500, "Failed to get user usage trend")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||||
|
|
||||||
response.Success(c, gin.H{
|
response.Success(c, gin.H{
|
||||||
"trend": trend,
|
"trend": trend,
|
||||||
|
|||||||
118
backend/internal/handler/admin/dashboard_handler_cache_test.go
Normal file
118
backend/internal/handler/admin/dashboard_handler_cache_test.go
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type dashboardUsageRepoCacheProbe struct {
|
||||||
|
service.UsageLogRepository
|
||||||
|
trendCalls atomic.Int32
|
||||||
|
usersTrendCalls atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardUsageRepoCacheProbe) GetUsageTrendWithFilters(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
granularity string,
|
||||||
|
userID, apiKeyID, accountID, groupID int64,
|
||||||
|
model string,
|
||||||
|
requestType *int16,
|
||||||
|
stream *bool,
|
||||||
|
billingType *int8,
|
||||||
|
) ([]usagestats.TrendDataPoint, error) {
|
||||||
|
r.trendCalls.Add(1)
|
||||||
|
return []usagestats.TrendDataPoint{{
|
||||||
|
Date: "2026-03-11",
|
||||||
|
Requests: 1,
|
||||||
|
TotalTokens: 2,
|
||||||
|
Cost: 3,
|
||||||
|
ActualCost: 4,
|
||||||
|
}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardUsageRepoCacheProbe) GetUserUsageTrend(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
granularity string,
|
||||||
|
limit int,
|
||||||
|
) ([]usagestats.UserUsageTrendPoint, error) {
|
||||||
|
r.usersTrendCalls.Add(1)
|
||||||
|
return []usagestats.UserUsageTrendPoint{{
|
||||||
|
Date: "2026-03-11",
|
||||||
|
UserID: 1,
|
||||||
|
Email: "cache@test.dev",
|
||||||
|
Requests: 2,
|
||||||
|
Tokens: 20,
|
||||||
|
Cost: 2,
|
||||||
|
ActualCost: 1,
|
||||||
|
}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resetDashboardReadCachesForTest() {
|
||||||
|
dashboardTrendCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardHandler_GetUsageTrend_UsesCache(t *testing.T) {
|
||||||
|
t.Cleanup(resetDashboardReadCachesForTest)
|
||||||
|
resetDashboardReadCachesForTest()
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
repo := &dashboardUsageRepoCacheProbe{}
|
||||||
|
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||||
|
handler := NewDashboardHandler(dashboardSvc, nil)
|
||||||
|
router := gin.New()
|
||||||
|
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
||||||
|
|
||||||
|
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
|
||||||
|
rec1 := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec1, req1)
|
||||||
|
require.Equal(t, http.StatusOK, rec1.Code)
|
||||||
|
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
|
||||||
|
|
||||||
|
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
|
||||||
|
rec2 := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec2, req2)
|
||||||
|
require.Equal(t, http.StatusOK, rec2.Code)
|
||||||
|
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||||
|
require.Equal(t, int32(1), repo.trendCalls.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardHandler_GetUserUsageTrend_UsesCache(t *testing.T) {
|
||||||
|
t.Cleanup(resetDashboardReadCachesForTest)
|
||||||
|
resetDashboardReadCachesForTest()
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
repo := &dashboardUsageRepoCacheProbe{}
|
||||||
|
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||||
|
handler := NewDashboardHandler(dashboardSvc, nil)
|
||||||
|
router := gin.New()
|
||||||
|
router.GET("/admin/dashboard/users-trend", handler.GetUserUsageTrend)
|
||||||
|
|
||||||
|
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
|
||||||
|
rec1 := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec1, req1)
|
||||||
|
require.Equal(t, http.StatusOK, rec1.Code)
|
||||||
|
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
|
||||||
|
|
||||||
|
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
|
||||||
|
rec2 := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec2, req2)
|
||||||
|
require.Equal(t, http.StatusOK, rec2.Code)
|
||||||
|
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||||
|
require.Equal(t, int32(1), repo.usersTrendCalls.Load())
|
||||||
|
}
|
||||||
200
backend/internal/handler/admin/dashboard_query_cache.go
Normal file
200
backend/internal/handler/admin/dashboard_query_cache.go
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
dashboardTrendCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
|
||||||
|
)
|
||||||
|
|
||||||
|
type dashboardTrendCacheKey struct {
|
||||||
|
StartTime string `json:"start_time"`
|
||||||
|
EndTime string `json:"end_time"`
|
||||||
|
Granularity string `json:"granularity"`
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
APIKeyID int64 `json:"api_key_id"`
|
||||||
|
AccountID int64 `json:"account_id"`
|
||||||
|
GroupID int64 `json:"group_id"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
RequestType *int16 `json:"request_type"`
|
||||||
|
Stream *bool `json:"stream"`
|
||||||
|
BillingType *int8 `json:"billing_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type dashboardModelGroupCacheKey struct {
|
||||||
|
StartTime string `json:"start_time"`
|
||||||
|
EndTime string `json:"end_time"`
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
APIKeyID int64 `json:"api_key_id"`
|
||||||
|
AccountID int64 `json:"account_id"`
|
||||||
|
GroupID int64 `json:"group_id"`
|
||||||
|
RequestType *int16 `json:"request_type"`
|
||||||
|
Stream *bool `json:"stream"`
|
||||||
|
BillingType *int8 `json:"billing_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type dashboardEntityTrendCacheKey struct {
|
||||||
|
StartTime string `json:"start_time"`
|
||||||
|
EndTime string `json:"end_time"`
|
||||||
|
Granularity string `json:"granularity"`
|
||||||
|
Limit int `json:"limit"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func cacheStatusValue(hit bool) string {
|
||||||
|
if hit {
|
||||||
|
return "hit"
|
||||||
|
}
|
||||||
|
return "miss"
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustMarshalDashboardCacheKey(value any) string {
|
||||||
|
raw, err := json.Marshal(value)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func snapshotPayloadAs[T any](payload any) (T, error) {
|
||||||
|
typed, ok := payload.(T)
|
||||||
|
if !ok {
|
||||||
|
var zero T
|
||||||
|
return zero, fmt.Errorf("unexpected cache payload type %T", payload)
|
||||||
|
}
|
||||||
|
return typed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DashboardHandler) getUsageTrendCached(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
granularity string,
|
||||||
|
userID, apiKeyID, accountID, groupID int64,
|
||||||
|
model string,
|
||||||
|
requestType *int16,
|
||||||
|
stream *bool,
|
||||||
|
billingType *int8,
|
||||||
|
) ([]usagestats.TrendDataPoint, bool, error) {
|
||||||
|
key := mustMarshalDashboardCacheKey(dashboardTrendCacheKey{
|
||||||
|
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||||
|
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||||
|
Granularity: granularity,
|
||||||
|
UserID: userID,
|
||||||
|
APIKeyID: apiKeyID,
|
||||||
|
AccountID: accountID,
|
||||||
|
GroupID: groupID,
|
||||||
|
Model: model,
|
||||||
|
RequestType: requestType,
|
||||||
|
Stream: stream,
|
||||||
|
BillingType: billingType,
|
||||||
|
})
|
||||||
|
entry, hit, err := dashboardTrendCache.GetOrLoad(key, func() (any, error) {
|
||||||
|
return h.dashboardService.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, hit, err
|
||||||
|
}
|
||||||
|
trend, err := snapshotPayloadAs[[]usagestats.TrendDataPoint](entry.Payload)
|
||||||
|
return trend, hit, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DashboardHandler) getModelStatsCached(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
userID, apiKeyID, accountID, groupID int64,
|
||||||
|
requestType *int16,
|
||||||
|
stream *bool,
|
||||||
|
billingType *int8,
|
||||||
|
) ([]usagestats.ModelStat, bool, error) {
|
||||||
|
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
|
||||||
|
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||||
|
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||||
|
UserID: userID,
|
||||||
|
APIKeyID: apiKeyID,
|
||||||
|
AccountID: accountID,
|
||||||
|
GroupID: groupID,
|
||||||
|
RequestType: requestType,
|
||||||
|
Stream: stream,
|
||||||
|
BillingType: billingType,
|
||||||
|
})
|
||||||
|
entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) {
|
||||||
|
return h.dashboardService.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, hit, err
|
||||||
|
}
|
||||||
|
stats, err := snapshotPayloadAs[[]usagestats.ModelStat](entry.Payload)
|
||||||
|
return stats, hit, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DashboardHandler) getGroupStatsCached(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
userID, apiKeyID, accountID, groupID int64,
|
||||||
|
requestType *int16,
|
||||||
|
stream *bool,
|
||||||
|
billingType *int8,
|
||||||
|
) ([]usagestats.GroupStat, bool, error) {
|
||||||
|
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
|
||||||
|
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||||
|
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||||
|
UserID: userID,
|
||||||
|
APIKeyID: apiKeyID,
|
||||||
|
AccountID: accountID,
|
||||||
|
GroupID: groupID,
|
||||||
|
RequestType: requestType,
|
||||||
|
Stream: stream,
|
||||||
|
BillingType: billingType,
|
||||||
|
})
|
||||||
|
entry, hit, err := dashboardGroupStatsCache.GetOrLoad(key, func() (any, error) {
|
||||||
|
return h.dashboardService.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, hit, err
|
||||||
|
}
|
||||||
|
stats, err := snapshotPayloadAs[[]usagestats.GroupStat](entry.Payload)
|
||||||
|
return stats, hit, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DashboardHandler) getAPIKeyUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, bool, error) {
|
||||||
|
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
|
||||||
|
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||||
|
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||||
|
Granularity: granularity,
|
||||||
|
Limit: limit,
|
||||||
|
})
|
||||||
|
entry, hit, err := dashboardAPIKeysTrendCache.GetOrLoad(key, func() (any, error) {
|
||||||
|
return h.dashboardService.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, hit, err
|
||||||
|
}
|
||||||
|
trend, err := snapshotPayloadAs[[]usagestats.APIKeyUsageTrendPoint](entry.Payload)
|
||||||
|
return trend, hit, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DashboardHandler) getUserUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, bool, error) {
|
||||||
|
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
|
||||||
|
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||||
|
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||||
|
Granularity: granularity,
|
||||||
|
Limit: limit,
|
||||||
|
})
|
||||||
|
entry, hit, err := dashboardUsersTrendCache.GetOrLoad(key, func() (any, error) {
|
||||||
|
return h.dashboardService.GetUserUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, hit, err
|
||||||
|
}
|
||||||
|
trend, err := snapshotPayloadAs[[]usagestats.UserUsageTrendPoint](entry.Payload)
|
||||||
|
return trend, hit, err
|
||||||
|
}
|
||||||
@@ -1,7 +1,9 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -111,20 +113,45 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
cacheKey := string(keyRaw)
|
cacheKey := string(keyRaw)
|
||||||
|
|
||||||
if cached, ok := dashboardSnapshotV2Cache.Get(cacheKey); ok {
|
cached, hit, err := dashboardSnapshotV2Cache.GetOrLoad(cacheKey, func() (any, error) {
|
||||||
if cached.ETag != "" {
|
return h.buildSnapshotV2Response(
|
||||||
c.Header("ETag", cached.ETag)
|
c.Request.Context(),
|
||||||
c.Header("Vary", "If-None-Match")
|
startTime,
|
||||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
endTime,
|
||||||
c.Status(http.StatusNotModified)
|
granularity,
|
||||||
return
|
filters,
|
||||||
}
|
includeStats,
|
||||||
}
|
includeTrend,
|
||||||
c.Header("X-Snapshot-Cache", "hit")
|
includeModels,
|
||||||
response.Success(c, cached.Payload)
|
includeGroups,
|
||||||
|
includeUsersTrend,
|
||||||
|
usersTrendLimit,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if cached.ETag != "" {
|
||||||
|
c.Header("ETag", cached.ETag)
|
||||||
|
c.Header("Vary", "If-None-Match")
|
||||||
|
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||||
|
c.Status(http.StatusNotModified)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||||
|
response.Success(c, cached.Payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DashboardHandler) buildSnapshotV2Response(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
granularity string,
|
||||||
|
filters *dashboardSnapshotV2Filters,
|
||||||
|
includeStats, includeTrend, includeModels, includeGroups, includeUsersTrend bool,
|
||||||
|
usersTrendLimit int,
|
||||||
|
) (*dashboardSnapshotV2Response, error) {
|
||||||
resp := &dashboardSnapshotV2Response{
|
resp := &dashboardSnapshotV2Response{
|
||||||
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
|
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
|
||||||
StartDate: startTime.Format("2006-01-02"),
|
StartDate: startTime.Format("2006-01-02"),
|
||||||
@@ -133,10 +160,9 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if includeStats {
|
if includeStats {
|
||||||
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
|
stats, err := h.dashboardService.GetDashboardStats(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get dashboard statistics")
|
return nil, errors.New("failed to get dashboard statistics")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
resp.Stats = &dashboardSnapshotV2Stats{
|
resp.Stats = &dashboardSnapshotV2Stats{
|
||||||
DashboardStats: *stats,
|
DashboardStats: *stats,
|
||||||
@@ -145,8 +171,8 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if includeTrend {
|
if includeTrend {
|
||||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(
|
trend, _, err := h.getUsageTrendCached(
|
||||||
c.Request.Context(),
|
ctx,
|
||||||
startTime,
|
startTime,
|
||||||
endTime,
|
endTime,
|
||||||
granularity,
|
granularity,
|
||||||
@@ -160,15 +186,14 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
|||||||
filters.BillingType,
|
filters.BillingType,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get usage trend")
|
return nil, errors.New("failed to get usage trend")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
resp.Trend = trend
|
resp.Trend = trend
|
||||||
}
|
}
|
||||||
|
|
||||||
if includeModels {
|
if includeModels {
|
||||||
models, err := h.dashboardService.GetModelStatsWithFilters(
|
models, _, err := h.getModelStatsCached(
|
||||||
c.Request.Context(),
|
ctx,
|
||||||
startTime,
|
startTime,
|
||||||
endTime,
|
endTime,
|
||||||
filters.UserID,
|
filters.UserID,
|
||||||
@@ -180,15 +205,14 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
|||||||
filters.BillingType,
|
filters.BillingType,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get model statistics")
|
return nil, errors.New("failed to get model statistics")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
resp.Models = models
|
resp.Models = models
|
||||||
}
|
}
|
||||||
|
|
||||||
if includeGroups {
|
if includeGroups {
|
||||||
groups, err := h.dashboardService.GetGroupStatsWithFilters(
|
groups, _, err := h.getGroupStatsCached(
|
||||||
c.Request.Context(),
|
ctx,
|
||||||
startTime,
|
startTime,
|
||||||
endTime,
|
endTime,
|
||||||
filters.UserID,
|
filters.UserID,
|
||||||
@@ -200,34 +224,20 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
|||||||
filters.BillingType,
|
filters.BillingType,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get group statistics")
|
return nil, errors.New("failed to get group statistics")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
resp.Groups = groups
|
resp.Groups = groups
|
||||||
}
|
}
|
||||||
|
|
||||||
if includeUsersTrend {
|
if includeUsersTrend {
|
||||||
usersTrend, err := h.dashboardService.GetUserUsageTrend(
|
usersTrend, _, err := h.getUserUsageTrendCached(ctx, startTime, endTime, granularity, usersTrendLimit)
|
||||||
c.Request.Context(),
|
|
||||||
startTime,
|
|
||||||
endTime,
|
|
||||||
granularity,
|
|
||||||
usersTrendLimit,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get user usage trend")
|
return nil, errors.New("failed to get user usage trend")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
resp.UsersTrend = usersTrend
|
resp.UsersTrend = usersTrend
|
||||||
}
|
}
|
||||||
|
|
||||||
cached := dashboardSnapshotV2Cache.Set(cacheKey, resp)
|
return resp, nil
|
||||||
if cached.ETag != "" {
|
|
||||||
c.Header("ETag", cached.ETag)
|
|
||||||
c.Header("Vary", "If-None-Match")
|
|
||||||
}
|
|
||||||
c.Header("X-Snapshot-Cache", "miss")
|
|
||||||
response.Success(c, resp)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) {
|
func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) {
|
||||||
|
|||||||
@@ -23,6 +23,13 @@ var validOpsAlertMetricTypes = []string{
|
|||||||
"cpu_usage_percent",
|
"cpu_usage_percent",
|
||||||
"memory_usage_percent",
|
"memory_usage_percent",
|
||||||
"concurrency_queue_depth",
|
"concurrency_queue_depth",
|
||||||
|
"group_available_accounts",
|
||||||
|
"group_available_ratio",
|
||||||
|
"group_rate_limit_ratio",
|
||||||
|
"account_rate_limited_count",
|
||||||
|
"account_error_count",
|
||||||
|
"account_error_ratio",
|
||||||
|
"overload_account_count",
|
||||||
}
|
}
|
||||||
|
|
||||||
var validOpsAlertMetricTypeSet = func() map[string]struct{} {
|
var validOpsAlertMetricTypeSet = func() map[string]struct{} {
|
||||||
@@ -82,7 +89,10 @@ func isPercentOrRateMetric(metricType string) bool {
|
|||||||
"error_rate",
|
"error_rate",
|
||||||
"upstream_error_rate",
|
"upstream_error_rate",
|
||||||
"cpu_usage_percent",
|
"cpu_usage_percent",
|
||||||
"memory_usage_percent":
|
"memory_usage_percent",
|
||||||
|
"group_available_ratio",
|
||||||
|
"group_rate_limit_ratio",
|
||||||
|
"account_error_ratio":
|
||||||
return true
|
return true
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sync/singleflight"
|
||||||
)
|
)
|
||||||
|
|
||||||
type snapshotCacheEntry struct {
|
type snapshotCacheEntry struct {
|
||||||
@@ -19,6 +21,12 @@ type snapshotCache struct {
|
|||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
ttl time.Duration
|
ttl time.Duration
|
||||||
items map[string]snapshotCacheEntry
|
items map[string]snapshotCacheEntry
|
||||||
|
sf singleflight.Group
|
||||||
|
}
|
||||||
|
|
||||||
|
type snapshotCacheLoadResult struct {
|
||||||
|
Entry snapshotCacheEntry
|
||||||
|
Hit bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSnapshotCache(ttl time.Duration) *snapshotCache {
|
func newSnapshotCache(ttl time.Duration) *snapshotCache {
|
||||||
@@ -70,6 +78,41 @@ func (c *snapshotCache) Set(key string, payload any) snapshotCacheEntry {
|
|||||||
return entry
|
return entry
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *snapshotCache) GetOrLoad(key string, load func() (any, error)) (snapshotCacheEntry, bool, error) {
|
||||||
|
if load == nil {
|
||||||
|
return snapshotCacheEntry{}, false, nil
|
||||||
|
}
|
||||||
|
if entry, ok := c.Get(key); ok {
|
||||||
|
return entry, true, nil
|
||||||
|
}
|
||||||
|
if c == nil || key == "" {
|
||||||
|
payload, err := load()
|
||||||
|
if err != nil {
|
||||||
|
return snapshotCacheEntry{}, false, err
|
||||||
|
}
|
||||||
|
return c.Set(key, payload), false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
value, err, _ := c.sf.Do(key, func() (any, error) {
|
||||||
|
if entry, ok := c.Get(key); ok {
|
||||||
|
return snapshotCacheLoadResult{Entry: entry, Hit: true}, nil
|
||||||
|
}
|
||||||
|
payload, err := load()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return snapshotCacheLoadResult{Entry: c.Set(key, payload), Hit: false}, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return snapshotCacheEntry{}, false, err
|
||||||
|
}
|
||||||
|
result, ok := value.(snapshotCacheLoadResult)
|
||||||
|
if !ok {
|
||||||
|
return snapshotCacheEntry{}, false, nil
|
||||||
|
}
|
||||||
|
return result.Entry, result.Hit, nil
|
||||||
|
}
|
||||||
|
|
||||||
func buildETagFromAny(payload any) string {
|
func buildETagFromAny(payload any) string {
|
||||||
raw, err := json.Marshal(payload)
|
raw, err := json.Marshal(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -3,6 +3,8 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -95,6 +97,61 @@ func TestBuildETagFromAny_UnmarshalablePayload(t *testing.T) {
|
|||||||
require.Empty(t, etag)
|
require.Empty(t, etag)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSnapshotCache_GetOrLoad_MissThenHit(t *testing.T) {
|
||||||
|
c := newSnapshotCache(5 * time.Second)
|
||||||
|
var loads atomic.Int32
|
||||||
|
|
||||||
|
entry, hit, err := c.GetOrLoad("key1", func() (any, error) {
|
||||||
|
loads.Add(1)
|
||||||
|
return map[string]string{"hello": "world"}, nil
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, hit)
|
||||||
|
require.NotEmpty(t, entry.ETag)
|
||||||
|
require.Equal(t, int32(1), loads.Load())
|
||||||
|
|
||||||
|
entry2, hit, err := c.GetOrLoad("key1", func() (any, error) {
|
||||||
|
loads.Add(1)
|
||||||
|
return map[string]string{"unexpected": "value"}, nil
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, hit)
|
||||||
|
require.Equal(t, entry.ETag, entry2.ETag)
|
||||||
|
require.Equal(t, int32(1), loads.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSnapshotCache_GetOrLoad_ConcurrentSingleflight(t *testing.T) {
|
||||||
|
c := newSnapshotCache(5 * time.Second)
|
||||||
|
var loads atomic.Int32
|
||||||
|
start := make(chan struct{})
|
||||||
|
const callers = 8
|
||||||
|
errCh := make(chan error, callers)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(callers)
|
||||||
|
for range callers {
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
<-start
|
||||||
|
_, _, err := c.GetOrLoad("shared", func() (any, error) {
|
||||||
|
loads.Add(1)
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
return "value", nil
|
||||||
|
})
|
||||||
|
errCh <- err
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
close(start)
|
||||||
|
wg.Wait()
|
||||||
|
close(errCh)
|
||||||
|
|
||||||
|
for err := range errCh {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, int32(1), loads.Load())
|
||||||
|
}
|
||||||
|
|
||||||
func TestParseBoolQueryWithDefault(t *testing.T) {
|
func TestParseBoolQueryWithDefault(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -216,6 +216,37 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResetSubscriptionQuotaRequest represents the reset quota request
|
||||||
|
type ResetSubscriptionQuotaRequest struct {
|
||||||
|
Daily bool `json:"daily"`
|
||||||
|
Weekly bool `json:"weekly"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetQuota resets daily and/or weekly usage for a subscription.
|
||||||
|
// POST /api/v1/admin/subscriptions/:id/reset-quota
|
||||||
|
func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
|
||||||
|
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid subscription ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var req ResetSubscriptionQuotaRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !req.Daily && !req.Weekly {
|
||||||
|
response.BadRequest(c, "At least one of 'daily' or 'weekly' must be true")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, dto.UserSubscriptionFromServiceAdmin(sub))
|
||||||
|
}
|
||||||
|
|
||||||
// Revoke handles revoking a subscription
|
// Revoke handles revoking a subscription
|
||||||
// DELETE /api/v1/admin/subscriptions/:id
|
// DELETE /api/v1/admin/subscriptions/:id
|
||||||
func (h *SubscriptionHandler) Revoke(c *gin.Context) {
|
func (h *SubscriptionHandler) Revoke(c *gin.Context) {
|
||||||
|
|||||||
290
backend/internal/handler/openai_chat_completions.go
Normal file
290
backend/internal/handler/openai_chat_completions.go
Normal file
@@ -0,0 +1,290 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ChatCompletions handles OpenAI Chat Completions API requests.
|
||||||
|
// POST /v1/chat/completions
|
||||||
|
func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||||
|
streamStarted := false
|
||||||
|
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||||
|
|
||||||
|
requestStart := time.Now()
|
||||||
|
|
||||||
|
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reqLog := requestLogger(
|
||||||
|
c,
|
||||||
|
"handler.openai_gateway.chat_completions",
|
||||||
|
zap.Int64("user_id", subject.UserID),
|
||||||
|
zap.Int64("api_key_id", apiKey.ID),
|
||||||
|
zap.Any("group_id", apiKey.GroupID),
|
||||||
|
)
|
||||||
|
|
||||||
|
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||||
|
if err != nil {
|
||||||
|
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||||
|
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(body) == 0 {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !gjson.ValidBytes(body) {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
modelResult := gjson.GetBytes(body, "model")
|
||||||
|
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reqModel := modelResult.String()
|
||||||
|
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||||
|
|
||||||
|
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||||
|
|
||||||
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
|
|
||||||
|
if h.errorPassthroughService != nil {
|
||||||
|
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||||
|
}
|
||||||
|
|
||||||
|
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||||
|
|
||||||
|
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||||
|
routingStart := time.Now()
|
||||||
|
|
||||||
|
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
|
||||||
|
if !acquired {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if userReleaseFunc != nil {
|
||||||
|
defer userReleaseFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||||
|
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
|
||||||
|
status, code, message := billingErrorDetails(err)
|
||||||
|
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||||
|
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
|
||||||
|
|
||||||
|
maxAccountSwitches := h.maxAccountSwitches
|
||||||
|
switchCount := 0
|
||||||
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
|
sameAccountRetryCount := make(map[int64]int)
|
||||||
|
var lastFailoverErr *service.UpstreamFailoverError
|
||||||
|
|
||||||
|
for {
|
||||||
|
c.Set("openai_chat_completions_fallback_model", "")
|
||||||
|
reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||||
|
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||||
|
c.Request.Context(),
|
||||||
|
apiKey.GroupID,
|
||||||
|
"",
|
||||||
|
sessionHash,
|
||||||
|
reqModel,
|
||||||
|
failedAccountIDs,
|
||||||
|
service.OpenAIUpstreamTransportAny,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
reqLog.Warn("openai_chat_completions.account_select_failed",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||||
|
)
|
||||||
|
if len(failedAccountIDs) == 0 {
|
||||||
|
defaultModel := ""
|
||||||
|
if apiKey.Group != nil {
|
||||||
|
defaultModel = apiKey.Group.DefaultMappedModel
|
||||||
|
}
|
||||||
|
if defaultModel != "" && defaultModel != reqModel {
|
||||||
|
reqLog.Info("openai_chat_completions.fallback_to_default_model",
|
||||||
|
zap.String("default_mapped_model", defaultModel),
|
||||||
|
)
|
||||||
|
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
|
||||||
|
c.Request.Context(),
|
||||||
|
apiKey.GroupID,
|
||||||
|
"",
|
||||||
|
sessionHash,
|
||||||
|
defaultModel,
|
||||||
|
failedAccountIDs,
|
||||||
|
service.OpenAIUpstreamTransportAny,
|
||||||
|
)
|
||||||
|
if err == nil && selection != nil {
|
||||||
|
c.Set("openai_chat_completions_fallback_model", defaultModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if lastFailoverErr != nil {
|
||||||
|
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
|
||||||
|
} else {
|
||||||
|
h.handleStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if selection == nil || selection.Account == nil {
|
||||||
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
account := selection.Account
|
||||||
|
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||||
|
reqLog.Debug("openai_chat_completions.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||||
|
_ = scheduleDecision
|
||||||
|
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||||
|
|
||||||
|
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
|
||||||
|
if !acquired {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||||
|
forwardStart := time.Now()
|
||||||
|
|
||||||
|
defaultMappedModel := ""
|
||||||
|
if apiKey.Group != nil {
|
||||||
|
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
||||||
|
}
|
||||||
|
if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" {
|
||||||
|
defaultMappedModel = fallbackModel
|
||||||
|
}
|
||||||
|
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||||
|
|
||||||
|
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||||
|
if accountReleaseFunc != nil {
|
||||||
|
accountReleaseFunc()
|
||||||
|
}
|
||||||
|
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||||
|
responseLatencyMs := forwardDurationMs
|
||||||
|
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||||
|
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
||||||
|
}
|
||||||
|
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
||||||
|
if err == nil && result != nil && result.FirstTokenMs != nil {
|
||||||
|
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
var failoverErr *service.UpstreamFailoverError
|
||||||
|
if errors.As(err, &failoverErr) {
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||||
|
// Pool mode: retry on the same account
|
||||||
|
if failoverErr.RetryableOnSameAccount {
|
||||||
|
retryLimit := account.GetPoolModeRetryCount()
|
||||||
|
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||||
|
sameAccountRetryCount[account.ID]++
|
||||||
|
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||||
|
zap.Int("retry_limit", retryLimit),
|
||||||
|
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||||
|
)
|
||||||
|
select {
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
return
|
||||||
|
case <-time.After(sameAccountRetryDelay):
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||||
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
|
lastFailoverErr = failoverErr
|
||||||
|
if switchCount >= maxAccountSwitches {
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switchCount++
|
||||||
|
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||||
|
zap.Int("switch_count", switchCount),
|
||||||
|
zap.Int("max_switches", maxAccountSwitches),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||||
|
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||||
|
reqLog.Warn("openai_chat_completions.forward_failed",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if result != nil {
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||||
|
} else {
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
userAgent := c.GetHeader("User-Agent")
|
||||||
|
clientIP := ip.GetClientIP(c)
|
||||||
|
|
||||||
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
|
Result: result,
|
||||||
|
APIKey: apiKey,
|
||||||
|
User: apiKey.User,
|
||||||
|
Account: account,
|
||||||
|
Subscription: subscription,
|
||||||
|
UserAgent: userAgent,
|
||||||
|
IPAddress: clientIP,
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
|
}); err != nil {
|
||||||
|
logger.L().With(
|
||||||
|
zap.String("component", "handler.openai_gateway.chat_completions"),
|
||||||
|
zap.Int64("user_id", subject.UserID),
|
||||||
|
zap.Int64("api_key_id", apiKey.ID),
|
||||||
|
zap.Any("group_id", apiKey.GroupID),
|
||||||
|
zap.String("model", reqModel),
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
).Error("openai_chat_completions.record_usage_failed", zap.Error(err))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
reqLog.Debug("openai_chat_completions.request_completed",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Int("switch_count", switchCount),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -31,6 +31,7 @@ const (
|
|||||||
const (
|
const (
|
||||||
opsErrorLogTimeout = 5 * time.Second
|
opsErrorLogTimeout = 5 * time.Second
|
||||||
opsErrorLogDrainTimeout = 10 * time.Second
|
opsErrorLogDrainTimeout = 10 * time.Second
|
||||||
|
opsErrorLogBatchWindow = 200 * time.Millisecond
|
||||||
|
|
||||||
opsErrorLogMinWorkerCount = 4
|
opsErrorLogMinWorkerCount = 4
|
||||||
opsErrorLogMaxWorkerCount = 32
|
opsErrorLogMaxWorkerCount = 32
|
||||||
@@ -38,6 +39,7 @@ const (
|
|||||||
opsErrorLogQueueSizePerWorker = 128
|
opsErrorLogQueueSizePerWorker = 128
|
||||||
opsErrorLogMinQueueSize = 256
|
opsErrorLogMinQueueSize = 256
|
||||||
opsErrorLogMaxQueueSize = 8192
|
opsErrorLogMaxQueueSize = 8192
|
||||||
|
opsErrorLogBatchSize = 32
|
||||||
)
|
)
|
||||||
|
|
||||||
type opsErrorLogJob struct {
|
type opsErrorLogJob struct {
|
||||||
@@ -82,27 +84,82 @@ func startOpsErrorLogWorkers() {
|
|||||||
for i := 0; i < workerCount; i++ {
|
for i := 0; i < workerCount; i++ {
|
||||||
go func() {
|
go func() {
|
||||||
defer opsErrorLogWorkersWg.Done()
|
defer opsErrorLogWorkersWg.Done()
|
||||||
for job := range opsErrorLogQueue {
|
for {
|
||||||
opsErrorLogQueueLen.Add(-1)
|
job, ok := <-opsErrorLogQueue
|
||||||
if job.ops == nil || job.entry == nil {
|
if !ok {
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
func() {
|
opsErrorLogQueueLen.Add(-1)
|
||||||
defer func() {
|
batch := make([]opsErrorLogJob, 0, opsErrorLogBatchSize)
|
||||||
if r := recover(); r != nil {
|
batch = append(batch, job)
|
||||||
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
|
|
||||||
|
timer := time.NewTimer(opsErrorLogBatchWindow)
|
||||||
|
batchLoop:
|
||||||
|
for len(batch) < opsErrorLogBatchSize {
|
||||||
|
select {
|
||||||
|
case nextJob, ok := <-opsErrorLogQueue:
|
||||||
|
if !ok {
|
||||||
|
if !timer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
flushOpsErrorLogBatch(batch)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}()
|
opsErrorLogQueueLen.Add(-1)
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
batch = append(batch, nextJob)
|
||||||
_ = job.ops.RecordError(ctx, job.entry, nil)
|
case <-timer.C:
|
||||||
cancel()
|
break batchLoop
|
||||||
opsErrorLogProcessed.Add(1)
|
}
|
||||||
}()
|
}
|
||||||
|
if !timer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
flushOpsErrorLogBatch(batch)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func flushOpsErrorLogBatch(batch []opsErrorLogJob) {
|
||||||
|
if len(batch) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
grouped := make(map[*service.OpsService][]*service.OpsInsertErrorLogInput, len(batch))
|
||||||
|
var processed int64
|
||||||
|
for _, job := range batch {
|
||||||
|
if job.ops == nil || job.entry == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
grouped[job.ops] = append(grouped[job.ops], job.entry)
|
||||||
|
processed++
|
||||||
|
}
|
||||||
|
if processed == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for opsSvc, entries := range grouped {
|
||||||
|
if opsSvc == nil || len(entries) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
||||||
|
_ = opsSvc.RecordErrorBatch(ctx, entries)
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
opsErrorLogProcessed.Add(processed)
|
||||||
|
}
|
||||||
|
|
||||||
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) {
|
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) {
|
||||||
if ops == nil || entry == nil {
|
if ops == nil || entry == nil {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -159,6 +159,8 @@ var claudeModels = []modelDef{
|
|||||||
// Antigravity 支持的 Gemini 模型
|
// Antigravity 支持的 Gemini 模型
|
||||||
var geminiModels = []modelDef{
|
var geminiModels = []modelDef{
|
||||||
{ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"},
|
{ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||||
|
{ID: "gemini-2.5-flash-image", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||||
|
{ID: "gemini-2.5-flash-image-preview", DisplayName: "Gemini 2.5 Flash Image Preview", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||||
{ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"},
|
{ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||||
{ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"},
|
{ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||||
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
|
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) {
|
|||||||
|
|
||||||
requiredIDs := []string{
|
requiredIDs := []string{
|
||||||
"claude-opus-4-6-thinking",
|
"claude-opus-4-6-thinking",
|
||||||
|
"gemini-2.5-flash-image",
|
||||||
|
"gemini-2.5-flash-image-preview",
|
||||||
"gemini-3.1-flash-image",
|
"gemini-3.1-flash-image",
|
||||||
"gemini-3.1-flash-image-preview",
|
"gemini-3.1-flash-image-preview",
|
||||||
"gemini-3-pro-image", // legacy compatibility
|
"gemini-3-pro-image", // legacy compatibility
|
||||||
|
|||||||
733
backend/internal/pkg/apicompat/chatcompletions_responses_test.go
Normal file
733
backend/internal/pkg/apicompat/chatcompletions_responses_test.go
Normal file
@@ -0,0 +1,733 @@
|
|||||||
|
package apicompat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// ChatCompletionsToResponses tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_BasicText(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Messages: []ChatMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Hello"`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "gpt-4o", resp.Model)
|
||||||
|
assert.True(t, resp.Stream) // always forced true
|
||||||
|
assert.False(t, *resp.Store)
|
||||||
|
|
||||||
|
var items []ResponsesInputItem
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||||
|
require.Len(t, items, 1)
|
||||||
|
assert.Equal(t, "user", items[0].Role)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_SystemMessage(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Messages: []ChatMessage{
|
||||||
|
{Role: "system", Content: json.RawMessage(`"You are helpful."`)},
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var items []ResponsesInputItem
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||||
|
require.Len(t, items, 2)
|
||||||
|
assert.Equal(t, "system", items[0].Role)
|
||||||
|
assert.Equal(t, "user", items[1].Role)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_ToolCalls(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Messages: []ChatMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Call the function"`)},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []ChatToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Type: "function",
|
||||||
|
Function: ChatFunctionCall{
|
||||||
|
Name: "ping",
|
||||||
|
Arguments: `{"host":"example.com"}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "tool",
|
||||||
|
ToolCallID: "call_1",
|
||||||
|
Content: json.RawMessage(`"pong"`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Tools: []ChatTool{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: &ChatFunction{
|
||||||
|
Name: "ping",
|
||||||
|
Description: "Ping a host",
|
||||||
|
Parameters: json.RawMessage(`{"type":"object"}`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var items []ResponsesInputItem
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||||
|
// user + function_call + function_call_output = 3
|
||||||
|
// (assistant message with empty content + tool_calls → only function_call items emitted)
|
||||||
|
require.Len(t, items, 3)
|
||||||
|
|
||||||
|
// Check function_call item
|
||||||
|
assert.Equal(t, "function_call", items[1].Type)
|
||||||
|
assert.Equal(t, "call_1", items[1].CallID)
|
||||||
|
assert.Equal(t, "ping", items[1].Name)
|
||||||
|
|
||||||
|
// Check function_call_output item
|
||||||
|
assert.Equal(t, "function_call_output", items[2].Type)
|
||||||
|
assert.Equal(t, "call_1", items[2].CallID)
|
||||||
|
assert.Equal(t, "pong", items[2].Output)
|
||||||
|
|
||||||
|
// Check tools
|
||||||
|
require.Len(t, resp.Tools, 1)
|
||||||
|
assert.Equal(t, "function", resp.Tools[0].Type)
|
||||||
|
assert.Equal(t, "ping", resp.Tools[0].Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_MaxTokens(t *testing.T) {
|
||||||
|
t.Run("max_tokens", func(t *testing.T) {
|
||||||
|
maxTokens := 100
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
MaxTokens: &maxTokens,
|
||||||
|
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||||
|
}
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp.MaxOutputTokens)
|
||||||
|
// Below minMaxOutputTokens (128), should be clamped
|
||||||
|
assert.Equal(t, minMaxOutputTokens, *resp.MaxOutputTokens)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("max_completion_tokens_preferred", func(t *testing.T) {
|
||||||
|
maxTokens := 100
|
||||||
|
maxCompletion := 500
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
MaxTokens: &maxTokens,
|
||||||
|
MaxCompletionTokens: &maxCompletion,
|
||||||
|
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||||
|
}
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp.MaxOutputTokens)
|
||||||
|
assert.Equal(t, 500, *resp.MaxOutputTokens)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_ReasoningEffort(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
ReasoningEffort: "high",
|
||||||
|
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||||
|
}
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp.Reasoning)
|
||||||
|
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||||
|
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_ImageURL(t *testing.T) {
|
||||||
|
content := `[{"type":"text","text":"Describe this"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc123"}}]`
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Messages: []ChatMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(content)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var items []ResponsesInputItem
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||||
|
require.Len(t, items, 1)
|
||||||
|
|
||||||
|
var parts []ResponsesContentPart
|
||||||
|
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
|
||||||
|
require.Len(t, parts, 2)
|
||||||
|
assert.Equal(t, "input_text", parts[0].Type)
|
||||||
|
assert.Equal(t, "Describe this", parts[0].Text)
|
||||||
|
assert.Equal(t, "input_image", parts[1].Type)
|
||||||
|
assert.Equal(t, "data:image/png;base64,abc123", parts[1].ImageURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Messages: []ChatMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||||
|
},
|
||||||
|
Functions: []ChatFunction{
|
||||||
|
{
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get weather",
|
||||||
|
Parameters: json.RawMessage(`{"type":"object"}`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FunctionCall: json.RawMessage(`{"name":"get_weather"}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, resp.Tools, 1)
|
||||||
|
assert.Equal(t, "function", resp.Tools[0].Type)
|
||||||
|
assert.Equal(t, "get_weather", resp.Tools[0].Name)
|
||||||
|
|
||||||
|
// tool_choice should be converted
|
||||||
|
require.NotNil(t, resp.ToolChoice)
|
||||||
|
var tc map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
|
||||||
|
assert.Equal(t, "function", tc["type"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_ServiceTier(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
ServiceTier: "flex",
|
||||||
|
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||||
|
}
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "flex", resp.ServiceTier)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestChatCompletionsToResponses_AssistantWithTextAndToolCalls verifies that
// an assistant message carrying both text content and tool_calls expands into
// an assistant message item followed by one function_call item.
func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T) {
	req := &ChatCompletionsRequest{
		Model: "gpt-4o",
		Messages: []ChatMessage{
			{Role: "user", Content: json.RawMessage(`"Do something"`)},
			{
				Role:    "assistant",
				Content: json.RawMessage(`"Let me call a function."`),
				ToolCalls: []ChatToolCall{
					{
						ID:   "call_abc",
						Type: "function",
						Function: ChatFunctionCall{
							Name:      "do_thing",
							Arguments: `{}`,
						},
					},
				},
			},
		},
	}

	resp, err := ChatCompletionsToResponses(req)
	require.NoError(t, err)

	var items []ResponsesInputItem
	require.NoError(t, json.Unmarshal(resp.Input, &items))
	// user + assistant message (with text) + function_call
	require.Len(t, items, 3)
	assert.Equal(t, "user", items[0].Role)
	assert.Equal(t, "assistant", items[1].Role)
	assert.Equal(t, "function_call", items[2].Type)
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// ResponsesToChatCompletions tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// TestResponsesToChatCompletions_BasicText verifies that a completed message
// output maps to a single "stop" choice with the text as string content, and
// that token usage is translated field-for-field.
func TestResponsesToChatCompletions_BasicText(t *testing.T) {
	resp := &ResponsesResponse{
		ID:     "resp_123",
		Status: "completed",
		Output: []ResponsesOutput{
			{
				Type: "message",
				Content: []ResponsesContentPart{
					{Type: "output_text", Text: "Hello, world!"},
				},
			},
		},
		Usage: &ResponsesUsage{
			InputTokens:  10,
			OutputTokens: 5,
			TotalTokens:  15,
		},
	}

	chat := ResponsesToChatCompletions(resp, "gpt-4o")
	assert.Equal(t, "chat.completion", chat.Object)
	assert.Equal(t, "gpt-4o", chat.Model)
	require.Len(t, chat.Choices, 1)
	assert.Equal(t, "stop", chat.Choices[0].FinishReason)

	var content string
	require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
	assert.Equal(t, "Hello, world!", content)

	require.NotNil(t, chat.Usage)
	assert.Equal(t, 10, chat.Usage.PromptTokens)
	assert.Equal(t, 5, chat.Usage.CompletionTokens)
	assert.Equal(t, 15, chat.Usage.TotalTokens)
}
|
||||||
|
|
||||||
|
// TestResponsesToChatCompletions_ToolCalls verifies that a function_call
// output item becomes an assistant tool_call with finish_reason "tool_calls".
func TestResponsesToChatCompletions_ToolCalls(t *testing.T) {
	resp := &ResponsesResponse{
		ID:     "resp_456",
		Status: "completed",
		Output: []ResponsesOutput{
			{
				Type:      "function_call",
				CallID:    "call_xyz",
				Name:      "get_weather",
				Arguments: `{"city":"NYC"}`,
			},
		},
	}

	chat := ResponsesToChatCompletions(resp, "gpt-4o")
	require.Len(t, chat.Choices, 1)
	assert.Equal(t, "tool_calls", chat.Choices[0].FinishReason)

	msg := chat.Choices[0].Message
	require.Len(t, msg.ToolCalls, 1)
	assert.Equal(t, "call_xyz", msg.ToolCalls[0].ID)
	assert.Equal(t, "function", msg.ToolCalls[0].Type)
	assert.Equal(t, "get_weather", msg.ToolCalls[0].Function.Name)
	assert.Equal(t, `{"city":"NYC"}`, msg.ToolCalls[0].Function.Arguments)
}
|
||||||
|
|
||||||
|
// TestResponsesToChatCompletions_Reasoning verifies that a reasoning item's
// summary text is concatenated ahead of the message text in the final content.
func TestResponsesToChatCompletions_Reasoning(t *testing.T) {
	resp := &ResponsesResponse{
		ID:     "resp_789",
		Status: "completed",
		Output: []ResponsesOutput{
			{
				Type: "reasoning",
				Summary: []ResponsesSummary{
					{Type: "summary_text", Text: "I thought about it."},
				},
			},
			{
				Type: "message",
				Content: []ResponsesContentPart{
					{Type: "output_text", Text: "The answer is 42."},
				},
			},
		},
	}

	chat := ResponsesToChatCompletions(resp, "gpt-4o")
	require.Len(t, chat.Choices, 1)

	var content string
	require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
	// Reasoning summary is prepended to text
	assert.Equal(t, "I thought about it.The answer is 42.", content)
}
|
||||||
|
|
||||||
|
// TestResponsesToChatCompletions_Incomplete verifies that an incomplete
// response truncated by max_output_tokens maps to finish_reason "length".
func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
	resp := &ResponsesResponse{
		ID:                "resp_inc",
		Status:            "incomplete",
		IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
		Output: []ResponsesOutput{
			{
				Type: "message",
				Content: []ResponsesContentPart{
					{Type: "output_text", Text: "partial..."},
				},
			},
		},
	}

	chat := ResponsesToChatCompletions(resp, "gpt-4o")
	require.Len(t, chat.Choices, 1)
	assert.Equal(t, "length", chat.Choices[0].FinishReason)
}
|
||||||
|
|
||||||
|
// TestResponsesToChatCompletions_CachedTokens verifies that cached input
// token details are carried over into usage.prompt_tokens_details.
func TestResponsesToChatCompletions_CachedTokens(t *testing.T) {
	resp := &ResponsesResponse{
		ID:     "resp_cache",
		Status: "completed",
		Output: []ResponsesOutput{
			{
				Type:    "message",
				Content: []ResponsesContentPart{{Type: "output_text", Text: "cached"}},
			},
		},
		Usage: &ResponsesUsage{
			InputTokens:  100,
			OutputTokens: 10,
			TotalTokens:  110,
			InputTokensDetails: &ResponsesInputTokensDetails{
				CachedTokens: 80,
			},
		},
	}

	chat := ResponsesToChatCompletions(resp, "gpt-4o")
	require.NotNil(t, chat.Usage)
	require.NotNil(t, chat.Usage.PromptTokensDetails)
	assert.Equal(t, 80, chat.Usage.PromptTokensDetails.CachedTokens)
}
|
||||||
|
|
||||||
|
// TestResponsesToChatCompletions_WebSearch verifies that web_search_call
// output items are silently dropped and only the message text survives.
func TestResponsesToChatCompletions_WebSearch(t *testing.T) {
	resp := &ResponsesResponse{
		ID:     "resp_ws",
		Status: "completed",
		Output: []ResponsesOutput{
			{
				Type:   "web_search_call",
				Action: &WebSearchAction{Type: "search", Query: "test"},
			},
			{
				Type:    "message",
				Content: []ResponsesContentPart{{Type: "output_text", Text: "search results"}},
			},
		},
	}

	chat := ResponsesToChatCompletions(resp, "gpt-4o")
	require.Len(t, chat.Choices, 1)
	assert.Equal(t, "stop", chat.Choices[0].FinishReason)

	var content string
	require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
	assert.Equal(t, "search results", content)
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Streaming: ResponsesEventToChatChunks tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// TestResponsesEventToChatChunks_TextDelta verifies the basic stream mapping:
// response.created emits the assistant-role chunk exactly once, then each
// output_text delta becomes a content chunk.
func TestResponsesEventToChatChunks_TextDelta(t *testing.T) {
	state := NewResponsesEventToChatState()
	state.Model = "gpt-4o"

	// response.created → role chunk
	chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type: "response.created",
		Response: &ResponsesResponse{
			ID: "resp_stream",
		},
	}, state)
	require.Len(t, chunks, 1)
	assert.Equal(t, "assistant", chunks[0].Choices[0].Delta.Role)
	assert.True(t, state.SentRole)

	// response.output_text.delta → content chunk
	chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type:  "response.output_text.delta",
		Delta: "Hello",
	}, state)
	require.Len(t, chunks, 1)
	require.NotNil(t, chunks[0].Choices[0].Delta.Content)
	assert.Equal(t, "Hello", *chunks[0].Choices[0].Delta.Content)
}
|
||||||
|
|
||||||
|
// TestResponsesEventToChatChunks_ToolCallDelta verifies tool-call index
// bookkeeping: each function_call output item is assigned a monotonically
// increasing chat tool-call index, and argument deltas are routed by the
// event's output_index (not call_id), including when calls interleave.
func TestResponsesEventToChatChunks_ToolCallDelta(t *testing.T) {
	state := NewResponsesEventToChatState()
	state.Model = "gpt-4o"
	state.SentRole = true

	// response.output_item.added (function_call) — output_index=1 (e.g. after a message item at 0)
	chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type:        "response.output_item.added",
		OutputIndex: 1,
		Item: &ResponsesOutput{
			Type:   "function_call",
			CallID: "call_1",
			Name:   "get_weather",
		},
	}, state)
	require.Len(t, chunks, 1)
	require.Len(t, chunks[0].Choices[0].Delta.ToolCalls, 1)
	tc := chunks[0].Choices[0].Delta.ToolCalls[0]
	assert.Equal(t, "call_1", tc.ID)
	assert.Equal(t, "get_weather", tc.Function.Name)
	require.NotNil(t, tc.Index)
	assert.Equal(t, 0, *tc.Index)

	// response.function_call_arguments.delta — uses output_index (NOT call_id) to find tool
	chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type:        "response.function_call_arguments.delta",
		OutputIndex: 1, // matches the output_index from output_item.added above
		Delta:       `{"city":`,
	}, state)
	require.Len(t, chunks, 1)
	tc = chunks[0].Choices[0].Delta.ToolCalls[0]
	require.NotNil(t, tc.Index)
	assert.Equal(t, 0, *tc.Index, "argument delta must use same index as the tool call")
	assert.Equal(t, `{"city":`, tc.Function.Arguments)

	// Add a second function call at output_index=2
	chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type:        "response.output_item.added",
		OutputIndex: 2,
		Item: &ResponsesOutput{
			Type:   "function_call",
			CallID: "call_2",
			Name:   "get_time",
		},
	}, state)
	require.Len(t, chunks, 1)
	tc = chunks[0].Choices[0].Delta.ToolCalls[0]
	require.NotNil(t, tc.Index)
	assert.Equal(t, 1, *tc.Index, "second tool call should get index 1")

	// Argument delta for second tool call
	chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type:        "response.function_call_arguments.delta",
		OutputIndex: 2,
		Delta:       `{"tz":"UTC"}`,
	}, state)
	require.Len(t, chunks, 1)
	tc = chunks[0].Choices[0].Delta.ToolCalls[0]
	require.NotNil(t, tc.Index)
	assert.Equal(t, 1, *tc.Index, "second tool arg delta must use index 1")

	// Argument delta for first tool call (interleaved)
	chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type:        "response.function_call_arguments.delta",
		OutputIndex: 1,
		Delta:       `"Tokyo"}`,
	}, state)
	require.Len(t, chunks, 1)
	tc = chunks[0].Choices[0].Delta.ToolCalls[0]
	require.NotNil(t, tc.Index)
	assert.Equal(t, 0, *tc.Index, "first tool arg delta must still use index 0")
}
|
||||||
|
|
||||||
|
// TestResponsesEventToChatChunks_Completed verifies that response.completed
// emits a "stop" finish chunk followed by a usage chunk (when IncludeUsage is
// set), with token details translated.
func TestResponsesEventToChatChunks_Completed(t *testing.T) {
	state := NewResponsesEventToChatState()
	state.Model = "gpt-4o"
	state.IncludeUsage = true

	chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type: "response.completed",
		Response: &ResponsesResponse{
			Status: "completed",
			Usage: &ResponsesUsage{
				InputTokens:  50,
				OutputTokens: 20,
				TotalTokens:  70,
				InputTokensDetails: &ResponsesInputTokensDetails{
					CachedTokens: 30,
				},
			},
		},
	}, state)
	// finish chunk + usage chunk
	require.Len(t, chunks, 2)

	// First chunk: finish_reason
	require.NotNil(t, chunks[0].Choices[0].FinishReason)
	assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)

	// Second chunk: usage
	require.NotNil(t, chunks[1].Usage)
	assert.Equal(t, 50, chunks[1].Usage.PromptTokens)
	assert.Equal(t, 20, chunks[1].Usage.CompletionTokens)
	assert.Equal(t, 70, chunks[1].Usage.TotalTokens)
	require.NotNil(t, chunks[1].Usage.PromptTokensDetails)
	assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens)
}
|
||||||
|
|
||||||
|
// TestResponsesEventToChatChunks_CompletedWithToolCalls verifies that the
// finish chunk uses "tool_calls" when the stream saw at least one tool call.
func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) {
	state := NewResponsesEventToChatState()
	state.Model = "gpt-4o"
	state.SawToolCall = true

	chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type: "response.completed",
		Response: &ResponsesResponse{
			Status: "completed",
		},
	}, state)
	require.Len(t, chunks, 1)
	require.NotNil(t, chunks[0].Choices[0].FinishReason)
	assert.Equal(t, "tool_calls", *chunks[0].Choices[0].FinishReason)
}
|
||||||
|
|
||||||
|
// TestResponsesEventToChatChunks_ReasoningDelta verifies that reasoning
// summary deltas are surfaced to the client as ordinary content chunks.
func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) {
	state := NewResponsesEventToChatState()
	state.Model = "gpt-4o"
	state.SentRole = true

	chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type:  "response.reasoning_summary_text.delta",
		Delta: "Thinking...",
	}, state)
	require.Len(t, chunks, 1)
	require.NotNil(t, chunks[0].Choices[0].Delta.Content)
	assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.Content)
}
|
||||||
|
|
||||||
|
// TestFinalizeResponsesChatStream verifies that finalizing an unfinished
// stream emits a finish chunk plus a usage chunk, and that a second call is
// an idempotent no-op.
func TestFinalizeResponsesChatStream(t *testing.T) {
	state := NewResponsesEventToChatState()
	state.Model = "gpt-4o"
	state.IncludeUsage = true
	state.Usage = &ChatUsage{
		PromptTokens:     100,
		CompletionTokens: 50,
		TotalTokens:      150,
	}

	chunks := FinalizeResponsesChatStream(state)
	require.Len(t, chunks, 2)

	// Finish chunk
	require.NotNil(t, chunks[0].Choices[0].FinishReason)
	assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)

	// Usage chunk
	require.NotNil(t, chunks[1].Usage)
	assert.Equal(t, 100, chunks[1].Usage.PromptTokens)

	// Idempotent: second call returns nil
	assert.Nil(t, FinalizeResponsesChatStream(state))
}
|
||||||
|
|
||||||
|
// TestFinalizeResponsesChatStream_AfterCompleted verifies that finalization
// does nothing once response.completed has already produced the finish chunk.
func TestFinalizeResponsesChatStream_AfterCompleted(t *testing.T) {
	// If response.completed already emitted the finish chunk, FinalizeResponsesChatStream
	// must be a no-op (prevents double finish_reason being sent to the client).
	state := NewResponsesEventToChatState()
	state.Model = "gpt-4o"
	state.IncludeUsage = true

	// Simulate response.completed
	chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type: "response.completed",
		Response: &ResponsesResponse{
			Status: "completed",
			Usage: &ResponsesUsage{
				InputTokens:  10,
				OutputTokens: 5,
				TotalTokens:  15,
			},
		},
	}, state)
	require.NotEmpty(t, chunks) // finish + usage chunks

	// Now FinalizeResponsesChatStream should return nil — already finalized.
	assert.Nil(t, FinalizeResponsesChatStream(state))
}
|
||||||
|
|
||||||
|
// TestChatChunkToSSE verifies that a chunk serializes into an SSE "data: "
// line containing its JSON payload.
func TestChatChunkToSSE(t *testing.T) {
	chunk := ChatCompletionsChunk{
		ID:      "chatcmpl-test",
		Object:  "chat.completion.chunk",
		Created: 1700000000,
		Model:   "gpt-4o",
		Choices: []ChatChunkChoice{
			{
				Index:        0,
				Delta:        ChatDelta{Role: "assistant"},
				FinishReason: nil,
			},
		},
	}

	sse, err := ChatChunkToSSE(chunk)
	require.NoError(t, err)
	assert.Contains(t, sse, "data: ")
	assert.Contains(t, sse, "chatcmpl-test")
	assert.Contains(t, sse, "assistant")
	assert.True(t, len(sse) > 10)
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Stream round-trip test
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// TestChatCompletionsStreamRoundTrip drives the full streaming state machine
// with a realistic event sequence (created → text deltas → completed) and
// checks chunk count, ordering, content reassembly, usage, and shared IDs.
func TestChatCompletionsStreamRoundTrip(t *testing.T) {
	// Simulate: client sends chat completions request, upstream returns Responses SSE events.
	// Verify that the streaming state machine produces correct chat completions chunks.

	state := NewResponsesEventToChatState()
	state.Model = "gpt-4o"
	state.IncludeUsage = true

	var allChunks []ChatCompletionsChunk

	// 1. response.created
	chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type:     "response.created",
		Response: &ResponsesResponse{ID: "resp_rt"},
	}, state)
	allChunks = append(allChunks, chunks...)

	// 2. text deltas
	for _, text := range []string{"Hello", ", ", "world", "!"} {
		chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
			Type:  "response.output_text.delta",
			Delta: text,
		}, state)
		allChunks = append(allChunks, chunks...)
	}

	// 3. response.completed
	chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type: "response.completed",
		Response: &ResponsesResponse{
			Status: "completed",
			Usage: &ResponsesUsage{
				InputTokens:  10,
				OutputTokens: 4,
				TotalTokens:  14,
			},
		},
	}, state)
	allChunks = append(allChunks, chunks...)

	// Verify: role chunk + 4 text chunks + finish chunk + usage chunk = 7
	require.Len(t, allChunks, 7)

	// First chunk has role
	assert.Equal(t, "assistant", allChunks[0].Choices[0].Delta.Role)

	// Text chunks
	var fullText string
	for i := 1; i <= 4; i++ {
		require.NotNil(t, allChunks[i].Choices[0].Delta.Content)
		fullText += *allChunks[i].Choices[0].Delta.Content
	}
	assert.Equal(t, "Hello, world!", fullText)

	// Finish chunk
	require.NotNil(t, allChunks[5].Choices[0].FinishReason)
	assert.Equal(t, "stop", *allChunks[5].Choices[0].FinishReason)

	// Usage chunk
	require.NotNil(t, allChunks[6].Usage)
	assert.Equal(t, 10, allChunks[6].Usage.PromptTokens)
	assert.Equal(t, 4, allChunks[6].Usage.CompletionTokens)

	// All chunks share the same ID
	for _, c := range allChunks {
		assert.Equal(t, "resp_rt", c.ID)
	}
}
|
||||||
312
backend/internal/pkg/apicompat/chatcompletions_to_responses.go
Normal file
312
backend/internal/pkg/apicompat/chatcompletions_to_responses.go
Normal file
@@ -0,0 +1,312 @@
|
|||||||
|
package apicompat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ChatCompletionsToResponses converts a Chat Completions request into a
|
||||||
|
// Responses API request. The upstream always streams, so Stream is forced to
|
||||||
|
// true. store is always false and reasoning.encrypted_content is always
|
||||||
|
// included so that the response translator has full context.
|
||||||
|
func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest, error) {
|
||||||
|
input, err := convertChatMessagesToResponsesInput(req.Messages)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
inputJSON, err := json.Marshal(input)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
out := &ResponsesRequest{
|
||||||
|
Model: req.Model,
|
||||||
|
Input: inputJSON,
|
||||||
|
Temperature: req.Temperature,
|
||||||
|
TopP: req.TopP,
|
||||||
|
Stream: true, // upstream always streams
|
||||||
|
Include: []string{"reasoning.encrypted_content"},
|
||||||
|
ServiceTier: req.ServiceTier,
|
||||||
|
}
|
||||||
|
|
||||||
|
storeFalse := false
|
||||||
|
out.Store = &storeFalse
|
||||||
|
|
||||||
|
// max_tokens / max_completion_tokens → max_output_tokens, prefer max_completion_tokens
|
||||||
|
maxTokens := 0
|
||||||
|
if req.MaxTokens != nil {
|
||||||
|
maxTokens = *req.MaxTokens
|
||||||
|
}
|
||||||
|
if req.MaxCompletionTokens != nil {
|
||||||
|
maxTokens = *req.MaxCompletionTokens
|
||||||
|
}
|
||||||
|
if maxTokens > 0 {
|
||||||
|
v := maxTokens
|
||||||
|
if v < minMaxOutputTokens {
|
||||||
|
v = minMaxOutputTokens
|
||||||
|
}
|
||||||
|
out.MaxOutputTokens = &v
|
||||||
|
}
|
||||||
|
|
||||||
|
// reasoning_effort → reasoning.effort + reasoning.summary="auto"
|
||||||
|
if req.ReasoningEffort != "" {
|
||||||
|
out.Reasoning = &ResponsesReasoning{
|
||||||
|
Effort: req.ReasoningEffort,
|
||||||
|
Summary: "auto",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tools[] and legacy functions[] → ResponsesTool[]
|
||||||
|
if len(req.Tools) > 0 || len(req.Functions) > 0 {
|
||||||
|
out.Tools = convertChatToolsToResponses(req.Tools, req.Functions)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tool_choice: already compatible format — pass through directly.
|
||||||
|
// Legacy function_call needs mapping.
|
||||||
|
if len(req.ToolChoice) > 0 {
|
||||||
|
out.ToolChoice = req.ToolChoice
|
||||||
|
} else if len(req.FunctionCall) > 0 {
|
||||||
|
tc, err := convertChatFunctionCallToToolChoice(req.FunctionCall)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("convert function_call: %w", err)
|
||||||
|
}
|
||||||
|
out.ToolChoice = tc
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertChatMessagesToResponsesInput converts the Chat Completions messages
|
||||||
|
// array into a Responses API input items array.
|
||||||
|
func convertChatMessagesToResponsesInput(msgs []ChatMessage) ([]ResponsesInputItem, error) {
|
||||||
|
var out []ResponsesInputItem
|
||||||
|
for _, m := range msgs {
|
||||||
|
items, err := chatMessageToResponsesItems(m)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
out = append(out, items...)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatMessageToResponsesItems converts a single ChatMessage into one or more
|
||||||
|
// ResponsesInputItem values.
|
||||||
|
func chatMessageToResponsesItems(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||||
|
switch m.Role {
|
||||||
|
case "system":
|
||||||
|
return chatSystemToResponses(m)
|
||||||
|
case "user":
|
||||||
|
return chatUserToResponses(m)
|
||||||
|
case "assistant":
|
||||||
|
return chatAssistantToResponses(m)
|
||||||
|
case "tool":
|
||||||
|
return chatToolToResponses(m)
|
||||||
|
case "function":
|
||||||
|
return chatFunctionToResponses(m)
|
||||||
|
default:
|
||||||
|
return chatUserToResponses(m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatSystemToResponses converts a system message.
|
||||||
|
func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||||
|
text, err := parseChatContent(m.Content)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
content, err := json.Marshal(text)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return []ResponsesInputItem{{Role: "system", Content: content}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatUserToResponses converts a user message, handling both plain strings and
|
||||||
|
// multi-modal content arrays.
|
||||||
|
func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||||
|
// Try plain string first.
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(m.Content, &s); err == nil {
|
||||||
|
content, _ := json.Marshal(s)
|
||||||
|
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var parts []ChatContentPart
|
||||||
|
if err := json.Unmarshal(m.Content, &parts); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse user content: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var responseParts []ResponsesContentPart
|
||||||
|
for _, p := range parts {
|
||||||
|
switch p.Type {
|
||||||
|
case "text":
|
||||||
|
if p.Text != "" {
|
||||||
|
responseParts = append(responseParts, ResponsesContentPart{
|
||||||
|
Type: "input_text",
|
||||||
|
Text: p.Text,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
case "image_url":
|
||||||
|
if p.ImageURL != nil && p.ImageURL.URL != "" {
|
||||||
|
responseParts = append(responseParts, ResponsesContentPart{
|
||||||
|
Type: "input_image",
|
||||||
|
ImageURL: p.ImageURL.URL,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := json.Marshal(responseParts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatAssistantToResponses converts an assistant message. If there is both
|
||||||
|
// text content and tool_calls, the text is emitted as an assistant message
|
||||||
|
// first, then each tool_call becomes a function_call item. If the content is
|
||||||
|
// empty/nil and there are tool_calls, only function_call items are emitted.
|
||||||
|
func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||||
|
var items []ResponsesInputItem
|
||||||
|
|
||||||
|
// Emit assistant message with output_text if content is non-empty.
|
||||||
|
if len(m.Content) > 0 {
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(m.Content, &s); err == nil && s != "" {
|
||||||
|
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
|
||||||
|
partsJSON, err := json.Marshal(parts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit one function_call item per tool_call.
|
||||||
|
for _, tc := range m.ToolCalls {
|
||||||
|
args := tc.Function.Arguments
|
||||||
|
if args == "" {
|
||||||
|
args = "{}"
|
||||||
|
}
|
||||||
|
items = append(items, ResponsesInputItem{
|
||||||
|
Type: "function_call",
|
||||||
|
CallID: tc.ID,
|
||||||
|
Name: tc.Function.Name,
|
||||||
|
Arguments: args,
|
||||||
|
ID: tc.ID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatToolToResponses converts a tool result message (role=tool) into a
|
||||||
|
// function_call_output item.
|
||||||
|
func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||||
|
output, err := parseChatContent(m.Content)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if output == "" {
|
||||||
|
output = "(empty)"
|
||||||
|
}
|
||||||
|
return []ResponsesInputItem{{
|
||||||
|
Type: "function_call_output",
|
||||||
|
CallID: m.ToolCallID,
|
||||||
|
Output: output,
|
||||||
|
}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatFunctionToResponses converts a legacy function result message
|
||||||
|
// (role=function) into a function_call_output item. The Name field is used as
|
||||||
|
// call_id since legacy function calls do not carry a separate call_id.
|
||||||
|
func chatFunctionToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||||
|
output, err := parseChatContent(m.Content)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if output == "" {
|
||||||
|
output = "(empty)"
|
||||||
|
}
|
||||||
|
return []ResponsesInputItem{{
|
||||||
|
Type: "function_call_output",
|
||||||
|
CallID: m.Name,
|
||||||
|
Output: output,
|
||||||
|
}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseChatContent returns the string value of a ChatMessage Content field.
|
||||||
|
// Content must be a JSON string. Returns "" if content is null or empty.
|
||||||
|
func parseChatContent(raw json.RawMessage) (string, error) {
|
||||||
|
if len(raw) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(raw, &s); err != nil {
|
||||||
|
return "", fmt.Errorf("parse content as string: %w", err)
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertChatToolsToResponses maps Chat Completions tool definitions and legacy
|
||||||
|
// function definitions to Responses API tool definitions.
|
||||||
|
func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []ResponsesTool {
|
||||||
|
var out []ResponsesTool
|
||||||
|
|
||||||
|
for _, t := range tools {
|
||||||
|
if t.Type != "function" || t.Function == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rt := ResponsesTool{
|
||||||
|
Type: "function",
|
||||||
|
Name: t.Function.Name,
|
||||||
|
Description: t.Function.Description,
|
||||||
|
Parameters: t.Function.Parameters,
|
||||||
|
Strict: t.Function.Strict,
|
||||||
|
}
|
||||||
|
out = append(out, rt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Legacy functions[] are treated as function-type tools.
|
||||||
|
for _, f := range functions {
|
||||||
|
rt := ResponsesTool{
|
||||||
|
Type: "function",
|
||||||
|
Name: f.Name,
|
||||||
|
Description: f.Description,
|
||||||
|
Parameters: f.Parameters,
|
||||||
|
Strict: f.Strict,
|
||||||
|
}
|
||||||
|
out = append(out, rt)
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertChatFunctionCallToToolChoice maps the legacy function_call field to a
|
||||||
|
// Responses API tool_choice value.
|
||||||
|
//
|
||||||
|
// "auto" → "auto"
|
||||||
|
// "none" → "none"
|
||||||
|
// {"name":"X"} → {"type":"function","function":{"name":"X"}}
|
||||||
|
func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) {
|
||||||
|
// Try string first ("auto", "none", etc.) — pass through as-is.
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(raw, &s); err == nil {
|
||||||
|
return json.Marshal(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Object form: {"name":"X"}
|
||||||
|
var obj struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(raw, &obj); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return json.Marshal(map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]string{"name": obj.Name},
|
||||||
|
})
|
||||||
|
}
|
||||||
368
backend/internal/pkg/apicompat/responses_to_chatcompletions.go
Normal file
368
backend/internal/pkg/apicompat/responses_to_chatcompletions.go
Normal file
@@ -0,0 +1,368 @@
|
|||||||
|
package apicompat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Non-streaming: ResponsesResponse → ChatCompletionsResponse
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// ResponsesToChatCompletions converts a Responses API response into a Chat
|
||||||
|
// Completions response. Text output items are concatenated into
|
||||||
|
// choices[0].message.content; function_call items become tool_calls.
|
||||||
|
func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatCompletionsResponse {
|
||||||
|
id := resp.ID
|
||||||
|
if id == "" {
|
||||||
|
id = generateChatCmplID()
|
||||||
|
}
|
||||||
|
|
||||||
|
out := &ChatCompletionsResponse{
|
||||||
|
ID: id,
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
Model: model,
|
||||||
|
}
|
||||||
|
|
||||||
|
var contentText string
|
||||||
|
var toolCalls []ChatToolCall
|
||||||
|
|
||||||
|
for _, item := range resp.Output {
|
||||||
|
switch item.Type {
|
||||||
|
case "message":
|
||||||
|
for _, part := range item.Content {
|
||||||
|
if part.Type == "output_text" && part.Text != "" {
|
||||||
|
contentText += part.Text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "function_call":
|
||||||
|
toolCalls = append(toolCalls, ChatToolCall{
|
||||||
|
ID: item.CallID,
|
||||||
|
Type: "function",
|
||||||
|
Function: ChatFunctionCall{
|
||||||
|
Name: item.Name,
|
||||||
|
Arguments: item.Arguments,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
case "reasoning":
|
||||||
|
for _, s := range item.Summary {
|
||||||
|
if s.Type == "summary_text" && s.Text != "" {
|
||||||
|
contentText += s.Text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "web_search_call":
|
||||||
|
// silently consumed — results already incorporated into text output
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := ChatMessage{Role: "assistant"}
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
msg.ToolCalls = toolCalls
|
||||||
|
}
|
||||||
|
if contentText != "" {
|
||||||
|
raw, _ := json.Marshal(contentText)
|
||||||
|
msg.Content = raw
|
||||||
|
}
|
||||||
|
|
||||||
|
finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls)
|
||||||
|
|
||||||
|
out.Choices = []ChatChoice{{
|
||||||
|
Index: 0,
|
||||||
|
Message: msg,
|
||||||
|
FinishReason: finishReason,
|
||||||
|
}}
|
||||||
|
|
||||||
|
if resp.Usage != nil {
|
||||||
|
usage := &ChatUsage{
|
||||||
|
PromptTokens: resp.Usage.InputTokens,
|
||||||
|
CompletionTokens: resp.Usage.OutputTokens,
|
||||||
|
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
|
||||||
|
}
|
||||||
|
if resp.Usage.InputTokensDetails != nil && resp.Usage.InputTokensDetails.CachedTokens > 0 {
|
||||||
|
usage.PromptTokensDetails = &ChatTokenDetails{
|
||||||
|
CachedTokens: resp.Usage.InputTokensDetails.CachedTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out.Usage = usage
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func responsesStatusToChatFinishReason(status string, details *ResponsesIncompleteDetails, toolCalls []ChatToolCall) string {
|
||||||
|
switch status {
|
||||||
|
case "incomplete":
|
||||||
|
if details != nil && details.Reason == "max_output_tokens" {
|
||||||
|
return "length"
|
||||||
|
}
|
||||||
|
return "stop"
|
||||||
|
case "completed":
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
return "tool_calls"
|
||||||
|
}
|
||||||
|
return "stop"
|
||||||
|
default:
|
||||||
|
return "stop"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Streaming: ResponsesStreamEvent → []ChatCompletionsChunk (stateful converter)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// ResponsesEventToChatState tracks state for converting a sequence of Responses
|
||||||
|
// SSE events into Chat Completions SSE chunks.
|
||||||
|
type ResponsesEventToChatState struct {
|
||||||
|
ID string
|
||||||
|
Model string
|
||||||
|
Created int64
|
||||||
|
SentRole bool
|
||||||
|
SawToolCall bool
|
||||||
|
SawText bool
|
||||||
|
Finalized bool // true after finish chunk has been emitted
|
||||||
|
NextToolCallIndex int // next sequential tool_call index to assign
|
||||||
|
OutputIndexToToolIndex map[int]int // Responses output_index → Chat tool_calls index
|
||||||
|
IncludeUsage bool
|
||||||
|
Usage *ChatUsage
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewResponsesEventToChatState returns an initialised stream state.
|
||||||
|
func NewResponsesEventToChatState() *ResponsesEventToChatState {
|
||||||
|
return &ResponsesEventToChatState{
|
||||||
|
ID: generateChatCmplID(),
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
OutputIndexToToolIndex: make(map[int]int),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponsesEventToChatChunks converts a single Responses SSE event into zero
|
||||||
|
// or more Chat Completions chunks, updating state as it goes.
|
||||||
|
func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
switch evt.Type {
|
||||||
|
case "response.created":
|
||||||
|
return resToChatHandleCreated(evt, state)
|
||||||
|
case "response.output_text.delta":
|
||||||
|
return resToChatHandleTextDelta(evt, state)
|
||||||
|
case "response.output_item.added":
|
||||||
|
return resToChatHandleOutputItemAdded(evt, state)
|
||||||
|
case "response.function_call_arguments.delta":
|
||||||
|
return resToChatHandleFuncArgsDelta(evt, state)
|
||||||
|
case "response.reasoning_summary_text.delta":
|
||||||
|
return resToChatHandleReasoningDelta(evt, state)
|
||||||
|
case "response.completed", "response.incomplete", "response.failed":
|
||||||
|
return resToChatHandleCompleted(evt, state)
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FinalizeResponsesChatStream emits a final chunk with finish_reason if the
|
||||||
|
// stream ended without a proper completion event (e.g. upstream disconnect).
|
||||||
|
// It is idempotent: if a completion event already emitted the finish chunk,
|
||||||
|
// this returns nil.
|
||||||
|
func FinalizeResponsesChatStream(state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
if state.Finalized {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
state.Finalized = true
|
||||||
|
|
||||||
|
finishReason := "stop"
|
||||||
|
if state.SawToolCall {
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
}
|
||||||
|
|
||||||
|
chunks := []ChatCompletionsChunk{makeChatFinishChunk(state, finishReason)}
|
||||||
|
|
||||||
|
if state.IncludeUsage && state.Usage != nil {
|
||||||
|
chunks = append(chunks, ChatCompletionsChunk{
|
||||||
|
ID: state.ID,
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: state.Created,
|
||||||
|
Model: state.Model,
|
||||||
|
Choices: []ChatChunkChoice{},
|
||||||
|
Usage: state.Usage,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatChunkToSSE formats a ChatCompletionsChunk as an SSE data line.
|
||||||
|
func ChatChunkToSSE(chunk ChatCompletionsChunk) (string, error) {
|
||||||
|
data, err := json.Marshal(chunk)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("data: %s\n\n", data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- internal handlers ---
|
||||||
|
|
||||||
|
func resToChatHandleCreated(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
if evt.Response != nil {
|
||||||
|
if evt.Response.ID != "" {
|
||||||
|
state.ID = evt.Response.ID
|
||||||
|
}
|
||||||
|
if state.Model == "" && evt.Response.Model != "" {
|
||||||
|
state.Model = evt.Response.Model
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Emit the role chunk.
|
||||||
|
if state.SentRole {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
state.SentRole = true
|
||||||
|
|
||||||
|
role := "assistant"
|
||||||
|
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Role: role})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resToChatHandleTextDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
if evt.Delta == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
state.SawText = true
|
||||||
|
content := evt.Delta
|
||||||
|
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resToChatHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
if evt.Item == nil || evt.Item.Type != "function_call" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
state.SawToolCall = true
|
||||||
|
idx := state.NextToolCallIndex
|
||||||
|
state.OutputIndexToToolIndex[evt.OutputIndex] = idx
|
||||||
|
state.NextToolCallIndex++
|
||||||
|
|
||||||
|
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{
|
||||||
|
ToolCalls: []ChatToolCall{{
|
||||||
|
Index: &idx,
|
||||||
|
ID: evt.Item.CallID,
|
||||||
|
Type: "function",
|
||||||
|
Function: ChatFunctionCall{
|
||||||
|
Name: evt.Item.Name,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resToChatHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
if evt.Delta == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
idx, ok := state.OutputIndexToToolIndex[evt.OutputIndex]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{
|
||||||
|
ToolCalls: []ChatToolCall{{
|
||||||
|
Index: &idx,
|
||||||
|
Function: ChatFunctionCall{
|
||||||
|
Arguments: evt.Delta,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
if evt.Delta == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
content := evt.Delta
|
||||||
|
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
state.Finalized = true
|
||||||
|
finishReason := "stop"
|
||||||
|
|
||||||
|
if evt.Response != nil {
|
||||||
|
if evt.Response.Usage != nil {
|
||||||
|
u := evt.Response.Usage
|
||||||
|
usage := &ChatUsage{
|
||||||
|
PromptTokens: u.InputTokens,
|
||||||
|
CompletionTokens: u.OutputTokens,
|
||||||
|
TotalTokens: u.InputTokens + u.OutputTokens,
|
||||||
|
}
|
||||||
|
if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 {
|
||||||
|
usage.PromptTokensDetails = &ChatTokenDetails{
|
||||||
|
CachedTokens: u.InputTokensDetails.CachedTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.Usage = usage
|
||||||
|
}
|
||||||
|
|
||||||
|
switch evt.Response.Status {
|
||||||
|
case "incomplete":
|
||||||
|
if evt.Response.IncompleteDetails != nil && evt.Response.IncompleteDetails.Reason == "max_output_tokens" {
|
||||||
|
finishReason = "length"
|
||||||
|
}
|
||||||
|
case "completed":
|
||||||
|
if state.SawToolCall {
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if state.SawToolCall {
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
}
|
||||||
|
|
||||||
|
var chunks []ChatCompletionsChunk
|
||||||
|
chunks = append(chunks, makeChatFinishChunk(state, finishReason))
|
||||||
|
|
||||||
|
if state.IncludeUsage && state.Usage != nil {
|
||||||
|
chunks = append(chunks, ChatCompletionsChunk{
|
||||||
|
ID: state.ID,
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: state.Created,
|
||||||
|
Model: state.Model,
|
||||||
|
Choices: []ChatChunkChoice{},
|
||||||
|
Usage: state.Usage,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeChatDeltaChunk(state *ResponsesEventToChatState, delta ChatDelta) ChatCompletionsChunk {
|
||||||
|
return ChatCompletionsChunk{
|
||||||
|
ID: state.ID,
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: state.Created,
|
||||||
|
Model: state.Model,
|
||||||
|
Choices: []ChatChunkChoice{{
|
||||||
|
Index: 0,
|
||||||
|
Delta: delta,
|
||||||
|
FinishReason: nil,
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeChatFinishChunk(state *ResponsesEventToChatState, finishReason string) ChatCompletionsChunk {
|
||||||
|
empty := ""
|
||||||
|
return ChatCompletionsChunk{
|
||||||
|
ID: state.ID,
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: state.Created,
|
||||||
|
Model: state.Model,
|
||||||
|
Choices: []ChatChunkChoice{{
|
||||||
|
Index: 0,
|
||||||
|
Delta: ChatDelta{Content: &empty},
|
||||||
|
FinishReason: &finishReason,
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateChatCmplID returns a "chatcmpl-" prefixed random hex ID.
|
||||||
|
func generateChatCmplID() string {
|
||||||
|
b := make([]byte, 12)
|
||||||
|
_, _ = rand.Read(b)
|
||||||
|
return "chatcmpl-" + hex.EncodeToString(b)
|
||||||
|
}
|
||||||
@@ -329,6 +329,148 @@ type ResponsesStreamEvent struct {
|
|||||||
SequenceNumber int `json:"sequence_number,omitempty"`
|
SequenceNumber int `json:"sequence_number,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// OpenAI Chat Completions API types
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// ChatCompletionsRequest is the request body for POST /v1/chat/completions.
|
||||||
|
type ChatCompletionsRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []ChatMessage `json:"messages"`
|
||||||
|
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||||
|
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
|
||||||
|
Temperature *float64 `json:"temperature,omitempty"`
|
||||||
|
TopP *float64 `json:"top_p,omitempty"`
|
||||||
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"`
|
||||||
|
Tools []ChatTool `json:"tools,omitempty"`
|
||||||
|
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||||
|
ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high"
|
||||||
|
ServiceTier string `json:"service_tier,omitempty"`
|
||||||
|
Stop json.RawMessage `json:"stop,omitempty"` // string or []string
|
||||||
|
|
||||||
|
// Legacy function calling (deprecated but still supported)
|
||||||
|
Functions []ChatFunction `json:"functions,omitempty"`
|
||||||
|
FunctionCall json.RawMessage `json:"function_call,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatStreamOptions configures streaming behavior.
|
||||||
|
type ChatStreamOptions struct {
|
||||||
|
IncludeUsage bool `json:"include_usage,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatMessage is a single message in the Chat Completions conversation.
|
||||||
|
type ChatMessage struct {
|
||||||
|
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
|
||||||
|
Content json.RawMessage `json:"content,omitempty"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||||
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||||
|
|
||||||
|
// Legacy function calling
|
||||||
|
FunctionCall *ChatFunctionCall `json:"function_call,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatContentPart is a typed content part in a multi-modal message.
|
||||||
|
type ChatContentPart struct {
|
||||||
|
Type string `json:"type"` // "text" | "image_url"
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
ImageURL *ChatImageURL `json:"image_url,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatImageURL contains the URL for an image content part.
|
||||||
|
type ChatImageURL struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
Detail string `json:"detail,omitempty"` // "auto" | "low" | "high"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatTool describes a tool available to the model.
|
||||||
|
type ChatTool struct {
|
||||||
|
Type string `json:"type"` // "function"
|
||||||
|
Function *ChatFunction `json:"function,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatFunction describes a function tool definition.
|
||||||
|
type ChatFunction struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||||
|
Strict *bool `json:"strict,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatToolCall represents a tool call made by the assistant.
|
||||||
|
// Index is only populated in streaming chunks (omitted in non-streaming responses).
|
||||||
|
type ChatToolCall struct {
|
||||||
|
Index *int `json:"index,omitempty"`
|
||||||
|
ID string `json:"id,omitempty"`
|
||||||
|
Type string `json:"type,omitempty"` // "function"
|
||||||
|
Function ChatFunctionCall `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatFunctionCall contains the function name and arguments.
|
||||||
|
type ChatFunctionCall struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments string `json:"arguments"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatCompletionsResponse is the non-streaming response from POST /v1/chat/completions.
|
||||||
|
type ChatCompletionsResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"` // "chat.completion"
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Choices []ChatChoice `json:"choices"`
|
||||||
|
Usage *ChatUsage `json:"usage,omitempty"`
|
||||||
|
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||||
|
ServiceTier string `json:"service_tier,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatChoice is a single completion choice.
|
||||||
|
type ChatChoice struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
Message ChatMessage `json:"message"`
|
||||||
|
FinishReason string `json:"finish_reason"` // "stop" | "length" | "tool_calls" | "content_filter"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatUsage holds token counts in Chat Completions format.
|
||||||
|
type ChatUsage struct {
|
||||||
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
PromptTokensDetails *ChatTokenDetails `json:"prompt_tokens_details,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatTokenDetails provides a breakdown of token usage.
|
||||||
|
type ChatTokenDetails struct {
|
||||||
|
CachedTokens int `json:"cached_tokens,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatCompletionsChunk is a single streaming chunk from POST /v1/chat/completions.
|
||||||
|
type ChatCompletionsChunk struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"` // "chat.completion.chunk"
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Choices []ChatChunkChoice `json:"choices"`
|
||||||
|
Usage *ChatUsage `json:"usage,omitempty"`
|
||||||
|
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||||
|
ServiceTier string `json:"service_tier,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatChunkChoice is a single choice in a streaming chunk.
|
||||||
|
type ChatChunkChoice struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
Delta ChatDelta `json:"delta"`
|
||||||
|
FinishReason *string `json:"finish_reason"` // pointer: null when not final
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatDelta carries incremental content in a streaming chunk.
|
||||||
|
type ChatDelta struct {
|
||||||
|
Role string `json:"role,omitempty"`
|
||||||
|
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
|
||||||
|
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// Shared constants
|
// Shared constants
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -18,10 +18,12 @@ func DefaultModels() []Model {
|
|||||||
return []Model{
|
return []Model{
|
||||||
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
|
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
|
||||||
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
|
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
|
||||||
|
{Name: "models/gemini-2.5-flash-image", SupportedGenerationMethods: methods},
|
||||||
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
|
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
|
||||||
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
|
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
|
||||||
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
|
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
|
||||||
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},
|
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},
|
||||||
|
{Name: "models/gemini-3.1-flash-image", SupportedGenerationMethods: methods},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
28
backend/internal/pkg/gemini/models_test.go
Normal file
28
backend/internal/pkg/gemini/models_test.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package gemini
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestDefaultModels_ContainsImageModels(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
models := DefaultModels()
|
||||||
|
byName := make(map[string]Model, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
byName[model.Name] = model
|
||||||
|
}
|
||||||
|
|
||||||
|
required := []string{
|
||||||
|
"models/gemini-2.5-flash-image",
|
||||||
|
"models/gemini-3.1-flash-image",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range required {
|
||||||
|
model, ok := byName[name]
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected fallback model %q to exist", name)
|
||||||
|
}
|
||||||
|
if len(model.SupportedGenerationMethods) == 0 {
|
||||||
|
t.Fatalf("expected fallback model %q to advertise generation methods", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -13,10 +13,12 @@ type Model struct {
|
|||||||
var DefaultModels = []Model{
|
var DefaultModels = []Model{
|
||||||
{ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""},
|
{ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""},
|
||||||
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
|
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
|
||||||
|
{ID: "gemini-2.5-flash-image", Type: "model", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: ""},
|
||||||
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
|
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
|
||||||
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
|
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
|
||||||
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
|
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
|
||||||
{ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""},
|
{ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""},
|
||||||
|
{ID: "gemini-3.1-flash-image", Type: "model", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: ""},
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultTestModel is the default model to preselect in test flows.
|
// DefaultTestModel is the default model to preselect in test flows.
|
||||||
|
|||||||
23
backend/internal/pkg/geminicli/models_test.go
Normal file
23
backend/internal/pkg/geminicli/models_test.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package geminicli
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestDefaultModels_ContainsImageModels(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
byID := make(map[string]Model, len(DefaultModels))
|
||||||
|
for _, model := range DefaultModels {
|
||||||
|
byID[model.ID] = model
|
||||||
|
}
|
||||||
|
|
||||||
|
required := []string{
|
||||||
|
"gemini-2.5-flash-image",
|
||||||
|
"gemini-3.1-flash-image",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, id := range required {
|
||||||
|
if _, ok := byID[id]; !ok {
|
||||||
|
t.Fatalf("expected curated Gemini model %q to exist", id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
@@ -50,6 +51,18 @@ type accountRepository struct {
|
|||||||
schedulerCache service.SchedulerCache
|
schedulerCache service.SchedulerCache
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// schedulerNeutralExtraKeyPrefixes lists extra-field key prefixes that are
// observational Codex usage data and therefore must not trigger a scheduler
// bucket rebuild.
var schedulerNeutralExtraKeyPrefixes = []string{
	"codex_primary_",
	"codex_secondary_",
	"codex_5h_",
	"codex_7d_",
}

// schedulerNeutralExtraKeys lists exact extra-field keys that receive the
// same scheduler-neutral treatment.
var schedulerNeutralExtraKeys = map[string]struct{}{
	"codex_usage_updated_at":     {},
	"session_window_utilization": {},
}
|
||||||
|
|
||||||
// NewAccountRepository 创建账户仓储实例。
|
// NewAccountRepository 创建账户仓储实例。
|
||||||
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
|
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
|
||||||
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
|
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
|
||||||
@@ -1185,12 +1198,48 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
|
|||||||
if affected == 0 {
|
if affected == 0 {
|
||||||
return service.ErrAccountNotFound
|
return service.ErrAccountNotFound
|
||||||
}
|
}
|
||||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
if shouldEnqueueSchedulerOutboxForExtraUpdates(updates) {
|
||||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 观测型 extra 字段不需要触发 bucket 重建,但仍同步单账号快照,
|
||||||
|
// 让 sticky session / GetAccount 命中缓存时也能读到最新数据,
|
||||||
|
// 同时避免缓存局部 patch 覆盖掉并发写入的其它账号字段。
|
||||||
|
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldEnqueueSchedulerOutboxForExtraUpdates(updates map[string]any) bool {
|
||||||
|
if len(updates) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for key := range updates {
|
||||||
|
if isSchedulerNeutralExtraKey(key) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSchedulerNeutralExtraKey(key string) bool {
|
||||||
|
key = strings.TrimSpace(key)
|
||||||
|
if key == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if _, ok := schedulerNeutralExtraKeys[key]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for _, prefix := range schedulerNeutralExtraKeyPrefixes {
|
||||||
|
if strings.HasPrefix(key, prefix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
||||||
if len(ids) == 0 {
|
if len(ids) == 0 {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ type AccountRepoSuite struct {
|
|||||||
|
|
||||||
type schedulerCacheRecorder struct {
|
type schedulerCacheRecorder struct {
|
||||||
setAccounts []*service.Account
|
setAccounts []*service.Account
|
||||||
|
accounts map[int64]*service.Account
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
|
func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
|
||||||
@@ -34,11 +35,20 @@ func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
|
func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
|
||||||
return nil, nil
|
if s.accounts == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return s.accounts[accountID], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error {
|
func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error {
|
||||||
s.setAccounts = append(s.setAccounts, account)
|
s.setAccounts = append(s.setAccounts, account)
|
||||||
|
if s.accounts == nil {
|
||||||
|
s.accounts = make(map[int64]*service.Account)
|
||||||
|
}
|
||||||
|
if account != nil {
|
||||||
|
s.accounts[account.ID] = account
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -623,6 +633,96 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
|
|||||||
s.Require().Equal("val", got.Extra["key"])
|
s.Require().Equal("val", got.Extra["key"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestUpdateExtra_SchedulerNeutralSkipsOutboxAndSyncsFreshSnapshot() {
|
||||||
|
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||||
|
Name: "acc-extra-neutral",
|
||||||
|
Platform: service.PlatformOpenAI,
|
||||||
|
Extra: map[string]any{"codex_usage_updated_at": "old"},
|
||||||
|
})
|
||||||
|
cacheRecorder := &schedulerCacheRecorder{
|
||||||
|
accounts: map[int64]*service.Account{
|
||||||
|
account.ID: {
|
||||||
|
ID: account.ID,
|
||||||
|
Platform: account.Platform,
|
||||||
|
Status: service.StatusDisabled,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"codex_usage_updated_at": "old",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s.repo.schedulerCache = cacheRecorder
|
||||||
|
|
||||||
|
updates := map[string]any{
|
||||||
|
"codex_usage_updated_at": "2026-03-11T10:00:00Z",
|
||||||
|
"codex_5h_used_percent": 88.5,
|
||||||
|
"session_window_utilization": 0.42,
|
||||||
|
}
|
||||||
|
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, updates))
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal("2026-03-11T10:00:00Z", got.Extra["codex_usage_updated_at"])
|
||||||
|
s.Require().Equal(88.5, got.Extra["codex_5h_used_percent"])
|
||||||
|
s.Require().Equal(0.42, got.Extra["session_window_utilization"])
|
||||||
|
|
||||||
|
var outboxCount int
|
||||||
|
s.Require().NoError(scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &outboxCount))
|
||||||
|
s.Require().Zero(outboxCount)
|
||||||
|
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||||
|
s.Require().NotNil(cacheRecorder.accounts[account.ID])
|
||||||
|
s.Require().Equal(service.StatusActive, cacheRecorder.accounts[account.ID].Status)
|
||||||
|
s.Require().Equal("2026-03-11T10:00:00Z", cacheRecorder.accounts[account.ID].Extra["codex_usage_updated_at"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestUpdateExtra_ExhaustedCodexSnapshotSyncsSchedulerCache() {
|
||||||
|
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||||
|
Name: "acc-extra-codex-exhausted",
|
||||||
|
Platform: service.PlatformOpenAI,
|
||||||
|
Type: service.AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{},
|
||||||
|
})
|
||||||
|
cacheRecorder := &schedulerCacheRecorder{}
|
||||||
|
s.repo.schedulerCache = cacheRecorder
|
||||||
|
_, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{
|
||||||
|
"codex_7d_used_percent": 100.0,
|
||||||
|
"codex_7d_reset_at": "2026-03-12T13:00:00Z",
|
||||||
|
"codex_7d_reset_after_seconds": 86400,
|
||||||
|
}))
|
||||||
|
|
||||||
|
var count int
|
||||||
|
err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(0, count)
|
||||||
|
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||||
|
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||||
|
s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status)
|
||||||
|
s.Require().Equal(100.0, cacheRecorder.setAccounts[0].Extra["codex_7d_used_percent"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestUpdateExtra_SchedulerRelevantStillEnqueuesOutbox() {
|
||||||
|
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||||
|
Name: "acc-extra-mixed",
|
||||||
|
Platform: service.PlatformAntigravity,
|
||||||
|
Extra: map[string]any{},
|
||||||
|
})
|
||||||
|
_, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{
|
||||||
|
"mixed_scheduling": true,
|
||||||
|
"codex_usage_updated_at": "2026-03-11T10:00:00Z",
|
||||||
|
}))
|
||||||
|
|
||||||
|
var count int
|
||||||
|
err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(1, count)
|
||||||
|
}
|
||||||
|
|
||||||
// --- GetByCRSAccountID ---
|
// --- GetByCRSAccountID ---
|
||||||
|
|
||||||
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
|
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
|
||||||
|
|||||||
@@ -452,6 +452,32 @@ func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amo
|
|||||||
return updated.QuotaUsed, nil
|
return updated.QuotaUsed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IncrementQuotaUsedAndGetState atomically increments quota_used, conditionally marks the key
|
||||||
|
// as quota_exhausted, and returns the latest quota state in one round trip.
|
||||||
|
func (r *apiKeyRepository) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*service.APIKeyQuotaUsageState, error) {
|
||||||
|
query := `
|
||||||
|
UPDATE api_keys
|
||||||
|
SET
|
||||||
|
quota_used = quota_used + $1,
|
||||||
|
status = CASE
|
||||||
|
WHEN quota > 0 AND quota_used + $1 >= quota THEN $2
|
||||||
|
ELSE status
|
||||||
|
END,
|
||||||
|
updated_at = NOW()
|
||||||
|
WHERE id = $3 AND deleted_at IS NULL
|
||||||
|
RETURNING quota_used, quota, key, status
|
||||||
|
`
|
||||||
|
|
||||||
|
state := &service.APIKeyQuotaUsageState{}
|
||||||
|
if err := scanSingleRow(ctx, r.sql, query, []any{amount, service.StatusAPIKeyQuotaExhausted, id}, &state.QuotaUsed, &state.Quota, &state.Key, &state.Status); err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, service.ErrAPIKeyNotFound
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return state, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
||||||
affected, err := r.client.APIKey.Update().
|
affected, err := r.client.APIKey.Update().
|
||||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||||
|
|||||||
@@ -417,6 +417,27 @@ func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() {
|
|||||||
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound")
|
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyRepoSuite) TestIncrementQuotaUsedAndGetState() {
|
||||||
|
user := s.mustCreateUser("quota-state@test.com")
|
||||||
|
key := s.mustCreateApiKey(user.ID, "sk-quota-state", "QuotaState", nil)
|
||||||
|
key.Quota = 3
|
||||||
|
key.QuotaUsed = 1
|
||||||
|
s.Require().NoError(s.repo.Update(s.ctx, key), "Update quota")
|
||||||
|
|
||||||
|
state, err := s.repo.IncrementQuotaUsedAndGetState(s.ctx, key.ID, 2.5)
|
||||||
|
s.Require().NoError(err, "IncrementQuotaUsedAndGetState")
|
||||||
|
s.Require().NotNil(state)
|
||||||
|
s.Require().Equal(3.5, state.QuotaUsed)
|
||||||
|
s.Require().Equal(3.0, state.Quota)
|
||||||
|
s.Require().Equal(service.StatusAPIKeyQuotaExhausted, state.Status)
|
||||||
|
s.Require().Equal(key.Key, state.Key)
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||||
|
s.Require().NoError(err, "GetByID")
|
||||||
|
s.Require().Equal(3.5, got.QuotaUsed)
|
||||||
|
s.Require().Equal(service.StatusAPIKeyQuotaExhausted, got.Status)
|
||||||
|
}
|
||||||
|
|
||||||
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
|
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
|
||||||
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
|
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
|
||||||
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
|
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
|
||||||
|
|||||||
@@ -16,19 +16,7 @@ type opsRepository struct {
|
|||||||
db *sql.DB
|
db *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOpsRepository(db *sql.DB) service.OpsRepository {
|
const insertOpsErrorLogSQL = `
|
||||||
return &opsRepository{db: db}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *opsRepository) InsertErrorLog(ctx context.Context, input *service.OpsInsertErrorLogInput) (int64, error) {
|
|
||||||
if r == nil || r.db == nil {
|
|
||||||
return 0, fmt.Errorf("nil ops repository")
|
|
||||||
}
|
|
||||||
if input == nil {
|
|
||||||
return 0, fmt.Errorf("nil input")
|
|
||||||
}
|
|
||||||
|
|
||||||
q := `
|
|
||||||
INSERT INTO ops_error_logs (
|
INSERT INTO ops_error_logs (
|
||||||
request_id,
|
request_id,
|
||||||
client_request_id,
|
client_request_id,
|
||||||
@@ -70,12 +58,77 @@ INSERT INTO ops_error_logs (
|
|||||||
created_at
|
created_at
|
||||||
) VALUES (
|
) VALUES (
|
||||||
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38
|
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38
|
||||||
) RETURNING id`
|
)`
|
||||||
|
|
||||||
|
func NewOpsRepository(db *sql.DB) service.OpsRepository {
|
||||||
|
return &opsRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *opsRepository) InsertErrorLog(ctx context.Context, input *service.OpsInsertErrorLogInput) (int64, error) {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return 0, fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
if input == nil {
|
||||||
|
return 0, fmt.Errorf("nil input")
|
||||||
|
}
|
||||||
|
|
||||||
var id int64
|
var id int64
|
||||||
err := r.db.QueryRowContext(
|
err := r.db.QueryRowContext(
|
||||||
ctx,
|
ctx,
|
||||||
q,
|
insertOpsErrorLogSQL+" RETURNING id",
|
||||||
|
opsInsertErrorLogArgs(input)...,
|
||||||
|
).Scan(&id)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *opsRepository) BatchInsertErrorLogs(ctx context.Context, inputs []*service.OpsInsertErrorLogInput) (int64, error) {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return 0, fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
if len(inputs) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := r.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
stmt, err := tx.PrepareContext(ctx, insertOpsErrorLogSQL)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = stmt.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
var inserted int64
|
||||||
|
for _, input := range inputs {
|
||||||
|
if input == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, err = stmt.ExecContext(ctx, opsInsertErrorLogArgs(input)...); err != nil {
|
||||||
|
return inserted, err
|
||||||
|
}
|
||||||
|
inserted++
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tx.Commit(); err != nil {
|
||||||
|
return inserted, err
|
||||||
|
}
|
||||||
|
return inserted, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func opsInsertErrorLogArgs(input *service.OpsInsertErrorLogInput) []any {
|
||||||
|
return []any{
|
||||||
opsNullString(input.RequestID),
|
opsNullString(input.RequestID),
|
||||||
opsNullString(input.ClientRequestID),
|
opsNullString(input.ClientRequestID),
|
||||||
opsNullInt64(input.UserID),
|
opsNullInt64(input.UserID),
|
||||||
@@ -114,11 +167,7 @@ INSERT INTO ops_error_logs (
|
|||||||
input.IsRetryable,
|
input.IsRetryable,
|
||||||
input.RetryCount,
|
input.RetryCount,
|
||||||
input.CreatedAt,
|
input.CreatedAt,
|
||||||
).Scan(&id)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
}
|
||||||
return id, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsErrorLogFilter) (*service.OpsErrorLogList, error) {
|
func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsErrorLogFilter) (*service.OpsErrorLogList, error) {
|
||||||
|
|||||||
@@ -0,0 +1,79 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOpsRepositoryBatchInsertErrorLogs(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE ops_error_logs RESTART IDENTITY")
|
||||||
|
|
||||||
|
repo := NewOpsRepository(integrationDB).(*opsRepository)
|
||||||
|
now := time.Now().UTC()
|
||||||
|
inserted, err := repo.BatchInsertErrorLogs(ctx, []*service.OpsInsertErrorLogInput{
|
||||||
|
{
|
||||||
|
RequestID: "batch-ops-1",
|
||||||
|
ErrorPhase: "upstream",
|
||||||
|
ErrorType: "upstream_error",
|
||||||
|
Severity: "error",
|
||||||
|
StatusCode: 429,
|
||||||
|
ErrorMessage: "rate limited",
|
||||||
|
CreatedAt: now,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
RequestID: "batch-ops-2",
|
||||||
|
ErrorPhase: "internal",
|
||||||
|
ErrorType: "api_error",
|
||||||
|
Severity: "error",
|
||||||
|
StatusCode: 500,
|
||||||
|
ErrorMessage: "internal error",
|
||||||
|
CreatedAt: now.Add(time.Millisecond),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 2, inserted)
|
||||||
|
|
||||||
|
var count int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM ops_error_logs WHERE request_id IN ('batch-ops-1', 'batch-ops-2')").Scan(&count))
|
||||||
|
require.Equal(t, 2, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnqueueSchedulerOutbox_DeduplicatesIdempotentEvents(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY")
|
||||||
|
|
||||||
|
accountID := int64(12345)
|
||||||
|
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
|
||||||
|
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
|
||||||
|
|
||||||
|
var count int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count))
|
||||||
|
require.Equal(t, 1, count)
|
||||||
|
|
||||||
|
time.Sleep(schedulerOutboxDedupWindow + 150*time.Millisecond)
|
||||||
|
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count))
|
||||||
|
require.Equal(t, 2, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnqueueSchedulerOutbox_DoesNotDeduplicateLastUsed(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY")
|
||||||
|
|
||||||
|
accountID := int64(67890)
|
||||||
|
payload1 := map[string]any{"last_used": map[string]int64{"67890": 100}}
|
||||||
|
payload2 := map[string]any{"last_used": map[string]int64{"67890": 200}}
|
||||||
|
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload1))
|
||||||
|
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload2))
|
||||||
|
|
||||||
|
var count int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountLastUsed).Scan(&count))
|
||||||
|
require.Equal(t, 2, count)
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
)
|
)
|
||||||
@@ -12,6 +13,8 @@ type schedulerOutboxRepository struct {
|
|||||||
db *sql.DB
|
db *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const schedulerOutboxDedupWindow = time.Second
|
||||||
|
|
||||||
func NewSchedulerOutboxRepository(db *sql.DB) service.SchedulerOutboxRepository {
|
func NewSchedulerOutboxRepository(db *sql.DB) service.SchedulerOutboxRepository {
|
||||||
return &schedulerOutboxRepository{db: db}
|
return &schedulerOutboxRepository{db: db}
|
||||||
}
|
}
|
||||||
@@ -88,9 +91,37 @@ func enqueueSchedulerOutbox(ctx context.Context, exec sqlExecutor, eventType str
|
|||||||
}
|
}
|
||||||
payloadArg = encoded
|
payloadArg = encoded
|
||||||
}
|
}
|
||||||
_, err := exec.ExecContext(ctx, `
|
query := `
|
||||||
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
|
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
|
||||||
VALUES ($1, $2, $3, $4)
|
VALUES ($1, $2, $3, $4)
|
||||||
`, eventType, accountID, groupID, payloadArg)
|
`
|
||||||
|
args := []any{eventType, accountID, groupID, payloadArg}
|
||||||
|
if schedulerOutboxEventSupportsDedup(eventType) {
|
||||||
|
query = `
|
||||||
|
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
|
||||||
|
SELECT $1, $2, $3, $4
|
||||||
|
WHERE NOT EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM scheduler_outbox
|
||||||
|
WHERE event_type = $1
|
||||||
|
AND account_id IS NOT DISTINCT FROM $2
|
||||||
|
AND group_id IS NOT DISTINCT FROM $3
|
||||||
|
AND created_at >= NOW() - make_interval(secs => $5)
|
||||||
|
)
|
||||||
|
`
|
||||||
|
args = append(args, schedulerOutboxDedupWindow.Seconds())
|
||||||
|
}
|
||||||
|
_, err := exec.ExecContext(ctx, query, args...)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func schedulerOutboxEventSupportsDedup(eventType string) bool {
|
||||||
|
switch eventType {
|
||||||
|
case service.SchedulerOutboxEventAccountChanged,
|
||||||
|
service.SchedulerOutboxEventGroupChanged,
|
||||||
|
service.SchedulerOutboxEventFullRebuild:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -456,6 +456,7 @@ func registerSubscriptionRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
|
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
|
||||||
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
|
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
|
||||||
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
|
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
|
||||||
|
subscriptions.POST("/:id/reset-quota", h.Admin.Subscription.ResetQuota)
|
||||||
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
|
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -71,15 +71,8 @@ func RegisterGatewayRoutes(
|
|||||||
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
||||||
gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses)
|
gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses)
|
||||||
gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
|
gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
|
||||||
// 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。
|
// OpenAI Chat Completions API
|
||||||
gateway.POST("/chat/completions", func(c *gin.Context) {
|
gateway.POST("/chat/completions", h.OpenAIGateway.ChatCompletions)
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
|
||||||
"error": gin.H{
|
|
||||||
"type": "invalid_request_error",
|
|
||||||
"message": "Unsupported legacy protocol: /v1/chat/completions is not supported. Please use /v1/responses.",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
|
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
|
||||||
@@ -100,6 +93,8 @@ func RegisterGatewayRoutes(
|
|||||||
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||||
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||||
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
|
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
|
||||||
|
// OpenAI Chat Completions API(不带v1前缀的别名)
|
||||||
|
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions)
|
||||||
|
|
||||||
// Antigravity 模型列表
|
// Antigravity 模型列表
|
||||||
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
|
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
|
||||||
|
|||||||
@@ -45,16 +45,23 @@ const (
|
|||||||
|
|
||||||
// TestEvent represents a SSE event for account testing
|
// TestEvent represents a SSE event for account testing
|
||||||
type TestEvent struct {
|
type TestEvent struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Text string `json:"text,omitempty"`
|
Text string `json:"text,omitempty"`
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
Status string `json:"status,omitempty"`
|
Status string `json:"status,omitempty"`
|
||||||
Code string `json:"code,omitempty"`
|
Code string `json:"code,omitempty"`
|
||||||
Data any `json:"data,omitempty"`
|
ImageURL string `json:"image_url,omitempty"`
|
||||||
Success bool `json:"success,omitempty"`
|
MimeType string `json:"mime_type,omitempty"`
|
||||||
Error string `json:"error,omitempty"`
|
Data any `json:"data,omitempty"`
|
||||||
|
Success bool `json:"success,omitempty"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultGeminiTextTestPrompt = "hi"
|
||||||
|
defaultGeminiImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background."
|
||||||
|
)
|
||||||
|
|
||||||
// AccountTestService handles account testing operations
|
// AccountTestService handles account testing operations
|
||||||
type AccountTestService struct {
|
type AccountTestService struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
@@ -161,7 +168,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
|
|||||||
// TestAccountConnection tests an account's connection by sending a test request
|
// TestAccountConnection tests an account's connection by sending a test request
|
||||||
// All account types use full Claude Code client characteristics, only auth header differs
|
// All account types use full Claude Code client characteristics, only auth header differs
|
||||||
// modelID is optional - if empty, defaults to claude.DefaultTestModel
|
// modelID is optional - if empty, defaults to claude.DefaultTestModel
|
||||||
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string) error {
|
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string) error {
|
||||||
ctx := c.Request.Context()
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
// Get account
|
// Get account
|
||||||
@@ -176,11 +183,11 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
|||||||
}
|
}
|
||||||
|
|
||||||
if account.IsGemini() {
|
if account.IsGemini() {
|
||||||
return s.testGeminiAccountConnection(c, account, modelID)
|
return s.testGeminiAccountConnection(c, account, modelID, prompt)
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.Platform == PlatformAntigravity {
|
if account.Platform == PlatformAntigravity {
|
||||||
return s.routeAntigravityTest(c, account, modelID)
|
return s.routeAntigravityTest(c, account, modelID, prompt)
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.Platform == PlatformSora {
|
if account.Platform == PlatformSora {
|
||||||
@@ -435,7 +442,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
|||||||
}
|
}
|
||||||
|
|
||||||
// testGeminiAccountConnection tests a Gemini account's connection
|
// testGeminiAccountConnection tests a Gemini account's connection
|
||||||
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
|
||||||
ctx := c.Request.Context()
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
// Determine the model to use
|
// Determine the model to use
|
||||||
@@ -462,7 +469,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
|||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
|
|
||||||
// Create test payload (Gemini format)
|
// Create test payload (Gemini format)
|
||||||
payload := createGeminiTestPayload()
|
payload := createGeminiTestPayload(testModelID, prompt)
|
||||||
|
|
||||||
// Build request based on account type
|
// Build request based on account type
|
||||||
var req *http.Request
|
var req *http.Request
|
||||||
@@ -1198,10 +1205,10 @@ func truncateSoraErrorBody(body []byte, max int) string {
|
|||||||
|
|
||||||
// routeAntigravityTest 路由 Antigravity 账号的测试请求。
|
// routeAntigravityTest 路由 Antigravity 账号的测试请求。
|
||||||
// APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。
|
// APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。
|
||||||
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string) error {
|
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string, prompt string) error {
|
||||||
if account.Type == AccountTypeAPIKey {
|
if account.Type == AccountTypeAPIKey {
|
||||||
if strings.HasPrefix(modelID, "gemini-") {
|
if strings.HasPrefix(modelID, "gemini-") {
|
||||||
return s.testGeminiAccountConnection(c, account, modelID)
|
return s.testGeminiAccountConnection(c, account, modelID, prompt)
|
||||||
}
|
}
|
||||||
return s.testClaudeAccountConnection(c, account, modelID)
|
return s.testClaudeAccountConnection(c, account, modelID)
|
||||||
}
|
}
|
||||||
@@ -1349,14 +1356,46 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT
|
|||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createGeminiTestPayload creates a minimal test payload for Gemini API
|
// createGeminiTestPayload creates a minimal test payload for Gemini API.
|
||||||
func createGeminiTestPayload() []byte {
|
// Image models use the image-generation path so the frontend can preview the returned image.
|
||||||
|
func createGeminiTestPayload(modelID string, prompt string) []byte {
|
||||||
|
if isImageGenerationModel(modelID) {
|
||||||
|
imagePrompt := strings.TrimSpace(prompt)
|
||||||
|
if imagePrompt == "" {
|
||||||
|
imagePrompt = defaultGeminiImageTestPrompt
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := map[string]any{
|
||||||
|
"contents": []map[string]any{
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"parts": []map[string]any{
|
||||||
|
{"text": imagePrompt},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"generationConfig": map[string]any{
|
||||||
|
"responseModalities": []string{"TEXT", "IMAGE"},
|
||||||
|
"imageConfig": map[string]any{
|
||||||
|
"aspectRatio": "1:1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
bytes, _ := json.Marshal(payload)
|
||||||
|
return bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
textPrompt := strings.TrimSpace(prompt)
|
||||||
|
if textPrompt == "" {
|
||||||
|
textPrompt = defaultGeminiTextTestPrompt
|
||||||
|
}
|
||||||
|
|
||||||
payload := map[string]any{
|
payload := map[string]any{
|
||||||
"contents": []map[string]any{
|
"contents": []map[string]any{
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"parts": []map[string]any{
|
"parts": []map[string]any{
|
||||||
{"text": "hi"},
|
{"text": textPrompt},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -1416,6 +1455,17 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
|
|||||||
if text, ok := partMap["text"].(string); ok && text != "" {
|
if text, ok := partMap["text"].(string); ok && text != "" {
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: text})
|
s.sendEvent(c, TestEvent{Type: "content", Text: text})
|
||||||
}
|
}
|
||||||
|
if inlineData, ok := partMap["inlineData"].(map[string]any); ok {
|
||||||
|
mimeType, _ := inlineData["mimeType"].(string)
|
||||||
|
data, _ := inlineData["data"].(string)
|
||||||
|
if strings.HasPrefix(strings.ToLower(mimeType), "image/") && data != "" {
|
||||||
|
s.sendEvent(c, TestEvent{
|
||||||
|
Type: "image",
|
||||||
|
ImageURL: fmt.Sprintf("data:%s;base64,%s", mimeType, data),
|
||||||
|
MimeType: mimeType,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1602,7 +1652,7 @@ func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID in
|
|||||||
ginCtx, _ := gin.CreateTestContext(w)
|
ginCtx, _ := gin.CreateTestContext(w)
|
||||||
ginCtx.Request = (&http.Request{}).WithContext(ctx)
|
ginCtx.Request = (&http.Request{}).WithContext(ctx)
|
||||||
|
|
||||||
testErr := s.TestAccountConnection(ginCtx, accountID, modelID)
|
testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "")
|
||||||
|
|
||||||
finishedAt := time.Now()
|
finishedAt := time.Now()
|
||||||
body := w.Body.String()
|
body := w.Body.String()
|
||||||
|
|||||||
59
backend/internal/service/account_test_service_gemini_test.go
Normal file
59
backend/internal/service/account_test_service_gemini_test.go
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCreateGeminiTestPayload_ImageModel(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
payload := createGeminiTestPayload("gemini-2.5-flash-image", "draw a tiny robot")
|
||||||
|
|
||||||
|
var parsed struct {
|
||||||
|
Contents []struct {
|
||||||
|
Parts []struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"parts"`
|
||||||
|
} `json:"contents"`
|
||||||
|
GenerationConfig struct {
|
||||||
|
ResponseModalities []string `json:"responseModalities"`
|
||||||
|
ImageConfig struct {
|
||||||
|
AspectRatio string `json:"aspectRatio"`
|
||||||
|
} `json:"imageConfig"`
|
||||||
|
} `json:"generationConfig"`
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, json.Unmarshal(payload, &parsed))
|
||||||
|
require.Len(t, parsed.Contents, 1)
|
||||||
|
require.Len(t, parsed.Contents[0].Parts, 1)
|
||||||
|
require.Equal(t, "draw a tiny robot", parsed.Contents[0].Parts[0].Text)
|
||||||
|
require.Equal(t, []string{"TEXT", "IMAGE"}, parsed.GenerationConfig.ResponseModalities)
|
||||||
|
require.Equal(t, "1:1", parsed.GenerationConfig.ImageConfig.AspectRatio)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
ctx, recorder := newSoraTestContext()
|
||||||
|
svc := &AccountTestService{}
|
||||||
|
|
||||||
|
stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n")
|
||||||
|
|
||||||
|
err := svc.processGeminiStream(ctx, stream)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
body := recorder.Body.String()
|
||||||
|
require.Contains(t, body, "\"type\":\"content\"")
|
||||||
|
require.Contains(t, body, "\"text\":\"ok\"")
|
||||||
|
require.Contains(t, body, "\"type\":\"image\"")
|
||||||
|
require.Contains(t, body, "\"image_url\":\"data:image/png;base64,QUJD\"")
|
||||||
|
require.Contains(t, body, "\"mime_type\":\"image/png\"")
|
||||||
|
}
|
||||||
@@ -369,8 +369,11 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
|
|||||||
}
|
}
|
||||||
|
|
||||||
if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
|
if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
|
||||||
if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 {
|
if updates, resetAt, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && (len(updates) > 0 || resetAt != nil) {
|
||||||
mergeAccountExtra(account, updates)
|
mergeAccountExtra(account, updates)
|
||||||
|
if resetAt != nil {
|
||||||
|
account.RateLimitResetAt = resetAt
|
||||||
|
}
|
||||||
if usage.UpdatedAt == nil {
|
if usage.UpdatedAt == nil {
|
||||||
usage.UpdatedAt = &now
|
usage.UpdatedAt = &now
|
||||||
}
|
}
|
||||||
@@ -457,26 +460,26 @@ func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, no
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, error) {
|
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, *time.Time, error) {
|
||||||
if account == nil || !account.IsOAuth() {
|
if account == nil || !account.IsOAuth() {
|
||||||
return nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
accessToken := account.GetOpenAIAccessToken()
|
accessToken := account.GetOpenAIAccessToken()
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return nil, fmt.Errorf("no access token available")
|
return nil, nil, fmt.Errorf("no access token available")
|
||||||
}
|
}
|
||||||
modelID := openaipkg.DefaultTestModel
|
modelID := openaipkg.DefaultTestModel
|
||||||
payload := createOpenAITestPayload(modelID, true)
|
payload := createOpenAITestPayload(modelID, true)
|
||||||
payloadBytes, err := json.Marshal(payload)
|
payloadBytes, err := json.Marshal(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("marshal openai probe payload: %w", err)
|
return nil, nil, fmt.Errorf("marshal openai probe payload: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes))
|
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create openai probe request: %w", err)
|
return nil, nil, fmt.Errorf("create openai probe request: %w", err)
|
||||||
}
|
}
|
||||||
req.Host = "chatgpt.com"
|
req.Host = "chatgpt.com"
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -505,43 +508,67 @@ func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, acco
|
|||||||
ResponseHeaderTimeout: 10 * time.Second,
|
ResponseHeaderTimeout: 10 * time.Second,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("build openai probe client: %w", err)
|
return nil, nil, fmt.Errorf("build openai probe client: %w", err)
|
||||||
}
|
}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("openai codex probe request failed: %w", err)
|
return nil, nil, fmt.Errorf("openai codex probe request failed: %w", err)
|
||||||
}
|
}
|
||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
updates, err := extractOpenAICodexProbeUpdates(resp)
|
updates, resetAt, err := extractOpenAICodexProbeSnapshot(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
if len(updates) > 0 {
|
if len(updates) > 0 || resetAt != nil {
|
||||||
go func(accountID int64, updates map[string]any) {
|
s.persistOpenAICodexProbeSnapshot(account.ID, updates, resetAt)
|
||||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
return updates, resetAt, nil
|
||||||
defer updateCancel()
|
}
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountUsageService) persistOpenAICodexProbeSnapshot(accountID int64, updates map[string]any, resetAt *time.Time) {
|
||||||
|
if s == nil || s.accountRepo == nil || accountID <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(updates) == 0 && resetAt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer updateCancel()
|
||||||
|
if len(updates) > 0 {
|
||||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||||
}(account.ID, updates)
|
}
|
||||||
return updates, nil
|
if resetAt != nil {
|
||||||
|
_ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractOpenAICodexProbeSnapshot(resp *http.Response) (map[string]any, *time.Time, error) {
|
||||||
|
if resp == nil {
|
||||||
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
return nil, nil
|
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||||
|
baseTime := time.Now()
|
||||||
|
updates := buildCodexUsageExtraUpdates(snapshot, baseTime)
|
||||||
|
resetAt := codexRateLimitResetAtFromSnapshot(snapshot, baseTime)
|
||||||
|
if len(updates) > 0 {
|
||||||
|
return updates, resetAt, nil
|
||||||
|
}
|
||||||
|
return nil, resetAt, nil
|
||||||
|
}
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return nil, nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
|
func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
|
||||||
if resp == nil {
|
updates, _, err := extractOpenAICodexProbeSnapshot(resp)
|
||||||
return nil, nil
|
return updates, err
|
||||||
}
|
|
||||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
|
||||||
updates := buildCodexUsageExtraUpdates(snapshot, time.Now())
|
|
||||||
if len(updates) > 0 {
|
|
||||||
return updates, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
||||||
return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func mergeAccountExtra(account *Account, updates map[string]any) {
|
func mergeAccountExtra(account *Account, updates map[string]any) {
|
||||||
|
|||||||
@@ -1,11 +1,36 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type accountUsageCodexProbeRepo struct {
|
||||||
|
stubOpenAIAccountRepo
|
||||||
|
updateExtraCh chan map[string]any
|
||||||
|
rateLimitCh chan time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *accountUsageCodexProbeRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
|
||||||
|
if r.updateExtraCh != nil {
|
||||||
|
copied := make(map[string]any, len(updates))
|
||||||
|
for k, v := range updates {
|
||||||
|
copied[k] = v
|
||||||
|
}
|
||||||
|
r.updateExtraCh <- copied
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *accountUsageCodexProbeRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
|
||||||
|
if r.rateLimitCh != nil {
|
||||||
|
r.rateLimitCh <- resetAt
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) {
|
func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -66,3 +91,60 @@ func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T)
|
|||||||
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
|
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExtractOpenAICodexProbeSnapshotAccepts429WithResetAt(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
headers := make(http.Header)
|
||||||
|
headers.Set("x-codex-primary-used-percent", "100")
|
||||||
|
headers.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||||
|
headers.Set("x-codex-primary-window-minutes", "10080")
|
||||||
|
headers.Set("x-codex-secondary-used-percent", "100")
|
||||||
|
headers.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||||
|
headers.Set("x-codex-secondary-window-minutes", "300")
|
||||||
|
|
||||||
|
updates, resetAt, err := extractOpenAICodexProbeSnapshot(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("extractOpenAICodexProbeSnapshot() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(updates) == 0 {
|
||||||
|
t.Fatal("expected codex probe updates from 429 headers")
|
||||||
|
}
|
||||||
|
if resetAt == nil {
|
||||||
|
t.Fatal("expected resetAt from exhausted codex headers")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
repo := &accountUsageCodexProbeRepo{
|
||||||
|
updateExtraCh: make(chan map[string]any, 1),
|
||||||
|
rateLimitCh: make(chan time.Time, 1),
|
||||||
|
}
|
||||||
|
svc := &AccountUsageService{accountRepo: repo}
|
||||||
|
resetAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second)
|
||||||
|
|
||||||
|
svc.persistOpenAICodexProbeSnapshot(321, map[string]any{
|
||||||
|
"codex_7d_used_percent": 100.0,
|
||||||
|
"codex_7d_reset_at": resetAt.Format(time.RFC3339),
|
||||||
|
}, &resetAt)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case updates := <-repo.updateExtraCh:
|
||||||
|
if got := updates["codex_7d_used_percent"]; got != 100.0 {
|
||||||
|
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("waiting for codex probe extra persistence timed out")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-repo.rateLimitCh:
|
||||||
|
if got.Before(resetAt.Add(-time.Second)) || got.After(resetAt.Add(time.Second)) {
|
||||||
|
t.Fatalf("rate limit resetAt = %v, want around %v", got, resetAt)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("waiting for codex probe rate limit persistence timed out")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -2164,6 +2164,112 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Gemini 原生请求中的 thoughtSignature 可能来自旧上下文/旧账号,触发上游严格校验后返回
|
||||||
|
// "Corrupted thought signature."。检测到此类 400 时,将 thoughtSignature 清理为 dummy 值后重试一次。
|
||||||
|
signatureCheckBody := respBody
|
||||||
|
if unwrapped, unwrapErr := s.unwrapV1InternalResponse(respBody); unwrapErr == nil && len(unwrapped) > 0 {
|
||||||
|
signatureCheckBody = unwrapped
|
||||||
|
}
|
||||||
|
if resp.StatusCode == http.StatusBadRequest &&
|
||||||
|
s.settingService != nil &&
|
||||||
|
s.settingService.IsSignatureRectifierEnabled(ctx) &&
|
||||||
|
isSignatureRelatedError(signatureCheckBody) &&
|
||||||
|
bytes.Contains(injectedBody, []byte(`"thoughtSignature"`)) {
|
||||||
|
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(signatureCheckBody)))
|
||||||
|
upstreamDetail := s.getUpstreamErrorDetail(signatureCheckBody)
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "signature_error",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: detected signature-related 400, retrying with cleaned thought signatures", account.ID)
|
||||||
|
|
||||||
|
cleanedInjectedBody := CleanGeminiNativeThoughtSignatures(injectedBody)
|
||||||
|
retryWrappedBody, wrapErr := s.wrapV1InternalRequest(projectID, mappedModel, cleanedInjectedBody)
|
||||||
|
if wrapErr == nil {
|
||||||
|
retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||||
|
ctx: ctx,
|
||||||
|
prefix: prefix,
|
||||||
|
account: account,
|
||||||
|
proxyURL: proxyURL,
|
||||||
|
accessToken: accessToken,
|
||||||
|
action: upstreamAction,
|
||||||
|
body: retryWrappedBody,
|
||||||
|
c: c,
|
||||||
|
httpUpstream: s.httpUpstream,
|
||||||
|
settingService: s.settingService,
|
||||||
|
accountRepo: s.accountRepo,
|
||||||
|
handleError: s.handleUpstreamError,
|
||||||
|
requestedModel: originalModel,
|
||||||
|
isStickySession: isStickySession,
|
||||||
|
groupID: 0,
|
||||||
|
sessionHash: "",
|
||||||
|
})
|
||||||
|
if retryErr == nil {
|
||||||
|
retryResp := retryResult.resp
|
||||||
|
if retryResp.StatusCode < 400 {
|
||||||
|
resp = retryResp
|
||||||
|
} else {
|
||||||
|
retryRespBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||||||
|
_ = retryResp.Body.Close()
|
||||||
|
retryOpsBody := retryRespBody
|
||||||
|
if retryUnwrapped, unwrapErr := s.unwrapV1InternalResponse(retryRespBody); unwrapErr == nil && len(retryUnwrapped) > 0 {
|
||||||
|
retryOpsBody = retryUnwrapped
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: retryResp.StatusCode,
|
||||||
|
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
|
||||||
|
Kind: "signature_retry",
|
||||||
|
Message: sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(retryOpsBody))),
|
||||||
|
Detail: s.getUpstreamErrorDetail(retryOpsBody),
|
||||||
|
})
|
||||||
|
respBody = retryRespBody
|
||||||
|
resp = &http.Response{
|
||||||
|
StatusCode: retryResp.StatusCode,
|
||||||
|
Header: retryResp.Header.Clone(),
|
||||||
|
Body: io.NopCloser(bytes.NewReader(retryRespBody)),
|
||||||
|
}
|
||||||
|
contentType = resp.Header.Get("Content-Type")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if switchErr, ok := IsAntigravityAccountSwitchError(retryErr); ok {
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: http.StatusServiceUnavailable,
|
||||||
|
Kind: "failover",
|
||||||
|
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
|
||||||
|
})
|
||||||
|
return nil, &UpstreamFailoverError{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
ForceCacheBilling: switchErr.IsStickySession,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: 0,
|
||||||
|
Kind: "signature_retry_request_error",
|
||||||
|
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
|
||||||
|
})
|
||||||
|
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: signature retry request failed: %v", account.ID, retryErr)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: signature retry wrap failed: %v", account.ID, wrapErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// fallback 成功:继续按正常响应处理
|
// fallback 成功:继续按正常响应处理
|
||||||
if resp.StatusCode < 400 {
|
if resp.StatusCode < 400 {
|
||||||
goto handleSuccess
|
goto handleSuccess
|
||||||
|
|||||||
@@ -134,6 +134,47 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int,
|
|||||||
return s.resp, s.err
|
return s.resp, s.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type queuedHTTPUpstreamStub struct {
|
||||||
|
responses []*http.Response
|
||||||
|
errors []error
|
||||||
|
requestBodies [][]byte
|
||||||
|
callCount int
|
||||||
|
onCall func(*http.Request, *queuedHTTPUpstreamStub)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *queuedHTTPUpstreamStub) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
||||||
|
if req != nil && req.Body != nil {
|
||||||
|
body, _ := io.ReadAll(req.Body)
|
||||||
|
s.requestBodies = append(s.requestBodies, body)
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(body))
|
||||||
|
} else {
|
||||||
|
s.requestBodies = append(s.requestBodies, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
idx := s.callCount
|
||||||
|
s.callCount++
|
||||||
|
if s.onCall != nil {
|
||||||
|
s.onCall(req, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp *http.Response
|
||||||
|
if idx < len(s.responses) {
|
||||||
|
resp = s.responses[idx]
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
if idx < len(s.errors) {
|
||||||
|
err = s.errors[idx]
|
||||||
|
}
|
||||||
|
if resp == nil && err == nil {
|
||||||
|
return nil, errors.New("unexpected upstream call")
|
||||||
|
}
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *queuedHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, concurrency int, _ bool) (*http.Response, error) {
|
||||||
|
return s.Do(req, proxyURL, accountID, concurrency)
|
||||||
|
}
|
||||||
|
|
||||||
type antigravitySettingRepoStub struct{}
|
type antigravitySettingRepoStub struct{}
|
||||||
|
|
||||||
func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||||
@@ -556,6 +597,177 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing
|
|||||||
require.Equal(t, mappedModel, result.Model)
|
require.Equal(t, mappedModel, result.Model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
writer := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(writer)
|
||||||
|
|
||||||
|
body, err := json.Marshal(map[string]any{
|
||||||
|
"contents": []map[string]any{
|
||||||
|
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
|
||||||
|
{"role": "model", "parts": []map[string]any{{"text": "thinking", "thought": true, "thoughtSignature": "sig_bad_1"}}},
|
||||||
|
{"role": "model", "parts": []map[string]any{{"functionCall": map[string]any{"name": "toolA", "args": map[string]any{"x": 1}}, "thoughtSignature": "sig_bad_2"}}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/antigravity/v1beta/models/gemini-3.1-pro-preview:streamGenerateContent", bytes.NewReader(body))
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
firstRespBody := []byte(`{"response":{"error":{"code":400,"message":"Corrupted thought signature.","status":"INVALID_ARGUMENT"}}}`)
|
||||||
|
secondRespBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
|
||||||
|
|
||||||
|
upstream := &queuedHTTPUpstreamStub{
|
||||||
|
responses: []*http.Response{
|
||||||
|
{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"application/json"},
|
||||||
|
"X-Request-Id": []string{"req-sig-1"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(bytes.NewReader(firstRespBody)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"text/event-stream"},
|
||||||
|
"X-Request-Id": []string{"req-sig-2"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(bytes.NewReader(secondRespBody)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{
|
||||||
|
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||||
|
tokenProvider: &AntigravityTokenProvider{},
|
||||||
|
httpUpstream: upstream,
|
||||||
|
}
|
||||||
|
|
||||||
|
const originalModel = "gemini-3.1-pro-preview"
|
||||||
|
const mappedModel = "gemini-3.1-pro-high"
|
||||||
|
account := &Account{
|
||||||
|
ID: 7,
|
||||||
|
Name: "acc-gemini-signature",
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "token",
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
originalModel: mappedModel,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, mappedModel, result.Model)
|
||||||
|
require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry")
|
||||||
|
|
||||||
|
firstReq := string(upstream.requestBodies[0])
|
||||||
|
secondReq := string(upstream.requestBodies[1])
|
||||||
|
require.Contains(t, firstReq, `"thoughtSignature":"sig_bad_1"`)
|
||||||
|
require.Contains(t, firstReq, `"thoughtSignature":"sig_bad_2"`)
|
||||||
|
require.Contains(t, secondReq, `"thoughtSignature":"skip_thought_signature_validator"`)
|
||||||
|
require.NotContains(t, secondReq, `"thoughtSignature":"sig_bad_1"`)
|
||||||
|
require.NotContains(t, secondReq, `"thoughtSignature":"sig_bad_2"`)
|
||||||
|
|
||||||
|
raw, ok := c.Get(OpsUpstreamErrorsKey)
|
||||||
|
require.True(t, ok)
|
||||||
|
events, ok := raw.([]*OpsUpstreamErrorEvent)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.NotEmpty(t, events)
|
||||||
|
require.Equal(t, "signature_error", events[0].Kind)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAntigravityGatewayService_ForwardGemini_SignatureRetryPropagatesFailover(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
writer := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(writer)
|
||||||
|
|
||||||
|
body, err := json.Marshal(map[string]any{
|
||||||
|
"contents": []map[string]any{
|
||||||
|
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
|
||||||
|
{"role": "model", "parts": []map[string]any{{"text": "thinking", "thought": true, "thoughtSignature": "sig_bad_1"}}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/antigravity/v1beta/models/gemini-3.1-pro-preview:streamGenerateContent", bytes.NewReader(body))
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
firstRespBody := []byte(`{"response":{"error":{"code":400,"message":"Corrupted thought signature.","status":"INVALID_ARGUMENT"}}}`)
|
||||||
|
|
||||||
|
const originalModel = "gemini-3.1-pro-preview"
|
||||||
|
const mappedModel = "gemini-3.1-pro-high"
|
||||||
|
account := &Account{
|
||||||
|
ID: 8,
|
||||||
|
Name: "acc-gemini-signature-failover",
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "token",
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
originalModel: mappedModel,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
upstream := &queuedHTTPUpstreamStub{
|
||||||
|
responses: []*http.Response{
|
||||||
|
{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"application/json"},
|
||||||
|
"X-Request-Id": []string{"req-sig-failover-1"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(bytes.NewReader(firstRespBody)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
onCall: func(_ *http.Request, stub *queuedHTTPUpstreamStub) {
|
||||||
|
if stub.callCount != 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
|
||||||
|
account.Extra = map[string]any{
|
||||||
|
modelRateLimitsKey: map[string]any{
|
||||||
|
mappedModel: map[string]any{
|
||||||
|
"rate_limit_reset_at": futureResetAt,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{
|
||||||
|
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||||
|
tokenProvider: &AntigravityTokenProvider{},
|
||||||
|
httpUpstream: upstream,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, true)
|
||||||
|
require.Nil(t, result)
|
||||||
|
|
||||||
|
var failoverErr *UpstreamFailoverError
|
||||||
|
require.ErrorAs(t, err, &failoverErr, "signature retry should propagate failover instead of falling back to the original 400")
|
||||||
|
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
|
||||||
|
require.True(t, failoverErr.ForceCacheBilling)
|
||||||
|
require.Len(t, upstream.requestBodies, 1, "retry should stop at preflight failover and not issue a second upstream request")
|
||||||
|
|
||||||
|
raw, ok := c.Get(OpsUpstreamErrorsKey)
|
||||||
|
require.True(t, ok)
|
||||||
|
events, ok := raw.([]*OpsUpstreamErrorEvent)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, events, 2)
|
||||||
|
require.Equal(t, "signature_error", events[0].Kind)
|
||||||
|
require.Equal(t, "failover", events[1].Kind)
|
||||||
|
}
|
||||||
|
|
||||||
// TestStreamUpstreamResponse_UsageAndFirstToken
|
// TestStreamUpstreamResponse_UsageAndFirstToken
|
||||||
// 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间
|
// 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间
|
||||||
func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
|
func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -110,6 +111,15 @@ func (d *APIKeyRateLimitData) EffectiveUsage7d() float64 {
|
|||||||
return d.Usage7d
|
return d.Usage7d
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// APIKeyQuotaUsageState captures the latest quota fields after an atomic quota update.
|
||||||
|
// It is intentionally small so repositories can return it from a single SQL statement.
|
||||||
|
type APIKeyQuotaUsageState struct {
|
||||||
|
QuotaUsed float64
|
||||||
|
Quota float64
|
||||||
|
Key string
|
||||||
|
Status string
|
||||||
|
}
|
||||||
|
|
||||||
// APIKeyCache defines cache operations for API key service
|
// APIKeyCache defines cache operations for API key service
|
||||||
type APIKeyCache interface {
|
type APIKeyCache interface {
|
||||||
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
|
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||||
@@ -817,6 +827,21 @@ func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cos
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type quotaStateReader interface {
|
||||||
|
IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if repo, ok := s.apiKeyRepo.(quotaStateReader); ok {
|
||||||
|
state, err := repo.IncrementQuotaUsedAndGetState(ctx, apiKeyID, cost)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("increment quota used: %w", err)
|
||||||
|
}
|
||||||
|
if state != nil && state.Status == StatusAPIKeyQuotaExhausted && strings.TrimSpace(state.Key) != "" {
|
||||||
|
s.InvalidateAuthCacheByKey(ctx, state.Key)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Use repository to atomically increment quota_used
|
// Use repository to atomically increment quota_used
|
||||||
newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost)
|
newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
170
backend/internal/service/api_key_service_quota_test.go
Normal file
170
backend/internal/service/api_key_service_quota_test.go
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type quotaStateRepoStub struct {
|
||||||
|
quotaBaseAPIKeyRepoStub
|
||||||
|
stateCalls int
|
||||||
|
state *APIKeyQuotaUsageState
|
||||||
|
stateErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateRepoStub) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error) {
|
||||||
|
s.stateCalls++
|
||||||
|
if s.stateErr != nil {
|
||||||
|
return nil, s.stateErr
|
||||||
|
}
|
||||||
|
if s.state == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
out := *s.state
|
||||||
|
return &out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type quotaStateCacheStub struct {
|
||||||
|
deleteAuthKeys []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) GetCreateAttemptCount(context.Context, int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) IncrementCreateAttemptCount(context.Context, int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) DeleteCreateAttemptCount(context.Context, int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) IncrementDailyUsage(context.Context, string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) SetDailyUsageExpiry(context.Context, string, time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) GetAuthCache(context.Context, string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) SetAuthCache(context.Context, string, *APIKeyAuthCacheEntry, time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) DeleteAuthCache(_ context.Context, key string) error {
|
||||||
|
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) PublishAuthCacheInvalidation(context.Context, string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) SubscribeAuthCacheInvalidation(context.Context, func(string)) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type quotaBaseAPIKeyRepoStub struct {
|
||||||
|
getByIDCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) Create(context.Context, *APIKey) error {
|
||||||
|
panic("unexpected Create call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) GetByID(context.Context, int64) (*APIKey, error) {
|
||||||
|
s.getByIDCalls++
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) {
|
||||||
|
panic("unexpected GetKeyAndOwnerID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) GetByKey(context.Context, string) (*APIKey, error) {
|
||||||
|
panic("unexpected GetByKey call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) GetByKeyForAuth(context.Context, string) (*APIKey, error) {
|
||||||
|
panic("unexpected GetByKeyForAuth call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) Update(context.Context, *APIKey) error {
|
||||||
|
panic("unexpected Update call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) Delete(context.Context, int64) error {
|
||||||
|
panic("unexpected Delete call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
|
||||||
|
panic("unexpected ListByUserID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
|
||||||
|
panic("unexpected VerifyOwnership call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) CountByUserID(context.Context, int64) (int64, error) {
|
||||||
|
panic("unexpected CountByUserID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) ExistsByKey(context.Context, string) (bool, error) {
|
||||||
|
panic("unexpected ExistsByKey call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||||
|
panic("unexpected ListByGroupID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) {
|
||||||
|
panic("unexpected SearchAPIKeys call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
|
||||||
|
panic("unexpected ClearGroupIDByGroupID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) CountByGroupID(context.Context, int64) (int64, error) {
|
||||||
|
panic("unexpected CountByGroupID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) ListKeysByUserID(context.Context, int64) ([]string, error) {
|
||||||
|
panic("unexpected ListKeysByUserID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) ListKeysByGroupID(context.Context, int64) ([]string, error) {
|
||||||
|
panic("unexpected ListKeysByGroupID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) {
|
||||||
|
panic("unexpected IncrementQuotaUsed call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) UpdateLastUsed(context.Context, int64, time.Time) error {
|
||||||
|
panic("unexpected UpdateLastUsed call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) IncrementRateLimitUsage(context.Context, int64, float64) error {
|
||||||
|
panic("unexpected IncrementRateLimitUsage call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) ResetRateLimitWindows(context.Context, int64) error {
|
||||||
|
panic("unexpected ResetRateLimitWindows call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) {
|
||||||
|
panic("unexpected GetRateLimitData call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_UpdateQuotaUsed_UsesAtomicStatePath(t *testing.T) {
|
||||||
|
repo := "aStateRepoStub{
|
||||||
|
state: &APIKeyQuotaUsageState{
|
||||||
|
QuotaUsed: 12,
|
||||||
|
Quota: 10,
|
||||||
|
Key: "sk-test-quota",
|
||||||
|
Status: StatusAPIKeyQuotaExhausted,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache := "aStateCacheStub{}
|
||||||
|
svc := &APIKeyService{
|
||||||
|
apiKeyRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.UpdateQuotaUsed(context.Background(), 101, 2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, repo.stateCalls)
|
||||||
|
require.Equal(t, 0, repo.getByIDCalls, "fast path should not re-read API key by id")
|
||||||
|
require.Equal(t, []string{svc.authCacheKey("sk-test-quota")}, cache.deleteAuthKeys)
|
||||||
|
}
|
||||||
@@ -5998,6 +5998,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
intervalCh = intervalTicker.C
|
intervalCh = intervalTicker.C
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开
|
||||||
|
keepaliveInterval := time.Duration(0)
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||||
|
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||||
|
}
|
||||||
|
var keepaliveTicker *time.Ticker
|
||||||
|
if keepaliveInterval > 0 {
|
||||||
|
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||||
|
defer keepaliveTicker.Stop()
|
||||||
|
}
|
||||||
|
var keepaliveCh <-chan time.Time
|
||||||
|
if keepaliveTicker != nil {
|
||||||
|
keepaliveCh = keepaliveTicker.C
|
||||||
|
}
|
||||||
|
lastDataAt := time.Now()
|
||||||
|
|
||||||
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
|
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
|
||||||
errorEventSent := false
|
errorEventSent := false
|
||||||
sendErrorEvent := func(reason string) {
|
sendErrorEvent := func(reason string) {
|
||||||
@@ -6187,6 +6203,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
|
lastDataAt = time.Now()
|
||||||
}
|
}
|
||||||
if data != "" {
|
if data != "" {
|
||||||
if firstTokenMs == nil && data != "[DONE]" {
|
if firstTokenMs == nil && data != "[DONE]" {
|
||||||
@@ -6220,6 +6237,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
}
|
}
|
||||||
sendErrorEvent("stream_timeout")
|
sendErrorEvent("stream_timeout")
|
||||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
|
|
||||||
|
case <-keepaliveCh:
|
||||||
|
if clientDisconnected {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if time.Since(lastDataAt) < keepaliveInterval {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// SSE ping 事件:Anthropic 原生格式,客户端会正确处理,
|
||||||
|
// 同时保持连接活跃防止 Cloudflare Tunnel 等代理断开
|
||||||
|
if _, werr := fmt.Fprint(w, "event: ping\ndata: {\"type\": \"ping\"}\n\n"); werr != nil {
|
||||||
|
clientDisconnected = true
|
||||||
|
logger.LegacyPrintf("service.gateway", "Client disconnected during keepalive ping, continuing to drain upstream for billing")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,75 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCleanGeminiNativeThoughtSignatures_ReplacesNestedThoughtSignatures(t *testing.T) {
|
||||||
|
input := []byte(`{
|
||||||
|
"contents": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"parts": [{"text": "hello"}]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "model",
|
||||||
|
"parts": [
|
||||||
|
{"text": "thinking", "thought": true, "thoughtSignature": "sig_1"},
|
||||||
|
{"functionCall": {"name": "toolA", "args": {"k": "v"}}, "thoughtSignature": "sig_2"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"cachedContent": {
|
||||||
|
"parts": [{"text": "cached", "thoughtSignature": "sig_3"}]
|
||||||
|
},
|
||||||
|
"signature": "keep_me"
|
||||||
|
}`)
|
||||||
|
|
||||||
|
cleaned := CleanGeminiNativeThoughtSignatures(input)
|
||||||
|
|
||||||
|
var got map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(cleaned, &got))
|
||||||
|
|
||||||
|
require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_1"`)
|
||||||
|
require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_2"`)
|
||||||
|
require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_3"`)
|
||||||
|
require.Contains(t, string(cleaned), `"thoughtSignature":"`+antigravity.DummyThoughtSignature+`"`)
|
||||||
|
require.Contains(t, string(cleaned), `"signature":"keep_me"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanGeminiNativeThoughtSignatures_InvalidJSONReturnsOriginal(t *testing.T) {
|
||||||
|
input := []byte(`{"contents":[invalid-json]}`)
|
||||||
|
|
||||||
|
cleaned := CleanGeminiNativeThoughtSignatures(input)
|
||||||
|
|
||||||
|
require.Equal(t, input, cleaned)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReplaceThoughtSignaturesRecursive_OnlyReplacesTargetField(t *testing.T) {
|
||||||
|
input := map[string]any{
|
||||||
|
"thoughtSignature": "sig_root",
|
||||||
|
"signature": "keep_signature",
|
||||||
|
"nested": []any{
|
||||||
|
map[string]any{
|
||||||
|
"thoughtSignature": "sig_nested",
|
||||||
|
"signature": "keep_nested_signature",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
got, ok := replaceThoughtSignaturesRecursive(input).(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, antigravity.DummyThoughtSignature, got["thoughtSignature"])
|
||||||
|
require.Equal(t, "keep_signature", got["signature"])
|
||||||
|
|
||||||
|
nested, ok := got["nested"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
nestedMap, ok := nested[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, antigravity.DummyThoughtSignature, nestedMap["thoughtSignature"])
|
||||||
|
require.Equal(t, "keep_nested_signature", nestedMap["signature"])
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -146,6 +147,22 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
|
|||||||
input = filterCodexInput(input, needsToolContinuation)
|
input = filterCodexInput(input, needsToolContinuation)
|
||||||
reqBody["input"] = input
|
reqBody["input"] = input
|
||||||
result.Modified = true
|
result.Modified = true
|
||||||
|
} else if inputStr, ok := reqBody["input"].(string); ok {
|
||||||
|
// ChatGPT codex endpoint requires input to be a list, not a string.
|
||||||
|
// Convert string input to the expected message array format.
|
||||||
|
trimmed := strings.TrimSpace(inputStr)
|
||||||
|
if trimmed != "" {
|
||||||
|
reqBody["input"] = []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "message",
|
||||||
|
"role": "user",
|
||||||
|
"content": inputStr,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
reqBody["input"] = []any{}
|
||||||
|
}
|
||||||
|
result.Modified = true
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -210,6 +227,29 @@ func normalizeCodexModel(model string) string {
|
|||||||
return "gpt-5.1"
|
return "gpt-5.1"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SupportsVerbosity(model string) bool {
|
||||||
|
if !strings.HasPrefix(model, "gpt-") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var major, minor int
|
||||||
|
n, _ := fmt.Sscanf(model, "gpt-%d.%d", &major, &minor)
|
||||||
|
|
||||||
|
if major > 5 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if major < 5 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// gpt-5
|
||||||
|
if n == 1 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return minor >= 3
|
||||||
|
}
|
||||||
|
|
||||||
func getNormalizedCodexModel(modelID string) string {
|
func getNormalizedCodexModel(modelID string) string {
|
||||||
if modelID == "" {
|
if modelID == "" {
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -249,6 +249,50 @@ func TestApplyCodexOAuthTransform_NonCodexCLI_PreservesExistingInstructions(t *t
|
|||||||
require.Equal(t, "old instructions", instructions)
|
require.Equal(t, "old instructions", instructions)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_StringInputConvertedToArray(t *testing.T) {
|
||||||
|
reqBody := map[string]any{"model": "gpt-5.4", "input": "Hello, world!"}
|
||||||
|
result := applyCodexOAuthTransform(reqBody, false, false)
|
||||||
|
require.True(t, result.Modified)
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, input, 1)
|
||||||
|
msg, ok := input[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "message", msg["type"])
|
||||||
|
require.Equal(t, "user", msg["role"])
|
||||||
|
require.Equal(t, "Hello, world!", msg["content"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_EmptyStringInputBecomesEmptyArray(t *testing.T) {
|
||||||
|
reqBody := map[string]any{"model": "gpt-5.4", "input": ""}
|
||||||
|
result := applyCodexOAuthTransform(reqBody, false, false)
|
||||||
|
require.True(t, result.Modified)
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, input, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_WhitespaceStringInputBecomesEmptyArray(t *testing.T) {
|
||||||
|
reqBody := map[string]any{"model": "gpt-5.4", "input": " "}
|
||||||
|
result := applyCodexOAuthTransform(reqBody, false, false)
|
||||||
|
require.True(t, result.Modified)
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, input, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_StringInputWithToolsField(t *testing.T) {
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"model": "gpt-5.4",
|
||||||
|
"input": "Run the tests",
|
||||||
|
"tools": []any{map[string]any{"type": "function", "name": "bash"}},
|
||||||
|
}
|
||||||
|
applyCodexOAuthTransform(reqBody, false, false)
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, input, 1)
|
||||||
|
}
|
||||||
|
|
||||||
func TestIsInstructionsEmpty(t *testing.T) {
|
func TestIsInstructionsEmpty(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
512
backend/internal/service/openai_gateway_chat_completions.go
Normal file
512
backend/internal/service/openai_gateway_chat_completions.go
Normal file
@@ -0,0 +1,512 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ForwardAsChatCompletions accepts a Chat Completions request body, converts it
|
||||||
|
// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts
|
||||||
|
// the response back to Chat Completions format. All account types (OAuth and API
|
||||||
|
// Key) go through the Responses API conversion path since the upstream only
|
||||||
|
// exposes the /v1/responses endpoint.
|
||||||
|
func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||||
|
ctx context.Context,
|
||||||
|
c *gin.Context,
|
||||||
|
account *Account,
|
||||||
|
body []byte,
|
||||||
|
promptCacheKey string,
|
||||||
|
defaultMappedModel string,
|
||||||
|
) (*OpenAIForwardResult, error) {
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
// 1. Parse Chat Completions request
|
||||||
|
var chatReq apicompat.ChatCompletionsRequest
|
||||||
|
if err := json.Unmarshal(body, &chatReq); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse chat completions request: %w", err)
|
||||||
|
}
|
||||||
|
originalModel := chatReq.Model
|
||||||
|
clientStream := chatReq.Stream
|
||||||
|
includeUsage := chatReq.StreamOptions != nil && chatReq.StreamOptions.IncludeUsage
|
||||||
|
|
||||||
|
// 2. Convert to Responses and forward
|
||||||
|
// ChatCompletionsToResponses always sets Stream=true (upstream always streams).
|
||||||
|
responsesReq, err := apicompat.ChatCompletionsToResponses(&chatReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Model mapping
|
||||||
|
mappedModel := account.GetMappedModel(originalModel)
|
||||||
|
if mappedModel == originalModel && defaultMappedModel != "" {
|
||||||
|
mappedModel = defaultMappedModel
|
||||||
|
}
|
||||||
|
responsesReq.Model = mappedModel
|
||||||
|
|
||||||
|
logger.L().Debug("openai chat_completions: model mapping applied",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.String("original_model", originalModel),
|
||||||
|
zap.String("mapped_model", mappedModel),
|
||||||
|
zap.Bool("stream", clientStream),
|
||||||
|
)
|
||||||
|
|
||||||
|
// 4. Marshal Responses request body, then apply OAuth codex transform
|
||||||
|
responsesBody, err := json.Marshal(responsesReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal responses request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if account.Type == AccountTypeOAuth {
|
||||||
|
var reqBody map[string]any
|
||||||
|
if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
|
||||||
|
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
|
||||||
|
}
|
||||||
|
codexResult := applyCodexOAuthTransform(reqBody, false, false)
|
||||||
|
if codexResult.PromptCacheKey != "" {
|
||||||
|
promptCacheKey = codexResult.PromptCacheKey
|
||||||
|
} else if promptCacheKey != "" {
|
||||||
|
reqBody["prompt_cache_key"] = promptCacheKey
|
||||||
|
}
|
||||||
|
responsesBody, err = json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("remarshal after codex transform: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Get access token
|
||||||
|
token, _, err := s.GetAccessToken(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get access token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. Build upstream request
|
||||||
|
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if promptCacheKey != "" {
|
||||||
|
upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 7. Send request
|
||||||
|
proxyURL := ""
|
||||||
|
if account.Proxy != nil {
|
||||||
|
proxyURL = account.Proxy.URL()
|
||||||
|
}
|
||||||
|
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||||
|
if err != nil {
|
||||||
|
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||||
|
setOpsUpstreamError(c, 0, safeErr, "")
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: 0,
|
||||||
|
Kind: "request_error",
|
||||||
|
Message: safeErr,
|
||||||
|
})
|
||||||
|
writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
|
||||||
|
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
// 8. Handle error response with failover
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||||
|
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "failover",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
if s.rateLimitService != nil {
|
||||||
|
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
|
}
|
||||||
|
return nil, &UpstreamFailoverError{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
ResponseBody: respBody,
|
||||||
|
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.handleChatCompletionsErrorResponse(resp, c, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 9. Handle normal response
|
||||||
|
var result *OpenAIForwardResult
|
||||||
|
var handleErr error
|
||||||
|
if clientStream {
|
||||||
|
result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, mappedModel, includeUsage, startTime)
|
||||||
|
} else {
|
||||||
|
result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Propagate ServiceTier and ReasoningEffort to result for billing
|
||||||
|
if handleErr == nil && result != nil {
|
||||||
|
if responsesReq.ServiceTier != "" {
|
||||||
|
st := responsesReq.ServiceTier
|
||||||
|
result.ServiceTier = &st
|
||||||
|
}
|
||||||
|
if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" {
|
||||||
|
re := responsesReq.Reasoning.Effort
|
||||||
|
result.ReasoningEffort = &re
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
|
||||||
|
if handleErr == nil && account.Type == AccountTypeOAuth {
|
||||||
|
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||||
|
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, handleErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleChatCompletionsErrorResponse reads an upstream error and returns it in
|
||||||
|
// OpenAI Chat Completions error format.
|
||||||
|
func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(
|
||||||
|
resp *http.Response,
|
||||||
|
c *gin.Context,
|
||||||
|
account *Account,
|
||||||
|
) (*OpenAIForwardResult, error) {
|
||||||
|
return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleChatBufferedStreamingResponse reads all Responses SSE events from the
|
||||||
|
// upstream, finds the terminal event, converts to a Chat Completions JSON
|
||||||
|
// response, and writes it to the client.
|
||||||
|
func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
|
||||||
|
resp *http.Response,
|
||||||
|
c *gin.Context,
|
||||||
|
originalModel string,
|
||||||
|
mappedModel string,
|
||||||
|
startTime time.Time,
|
||||||
|
) (*OpenAIForwardResult, error) {
|
||||||
|
requestID := resp.Header.Get("x-request-id")
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
maxLineSize := defaultMaxLineSize
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||||
|
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||||
|
}
|
||||||
|
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||||
|
|
||||||
|
var finalResponse *apicompat.ResponsesResponse
|
||||||
|
var usage OpenAIUsage
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
payload := line[6:]
|
||||||
|
|
||||||
|
var event apicompat.ResponsesStreamEvent
|
||||||
|
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||||
|
logger.L().Warn("openai chat_completions buffered: failed to parse event",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
||||||
|
event.Response != nil {
|
||||||
|
finalResponse = event.Response
|
||||||
|
if event.Response.Usage != nil {
|
||||||
|
usage = OpenAIUsage{
|
||||||
|
InputTokens: event.Response.Usage.InputTokens,
|
||||||
|
OutputTokens: event.Response.Usage.OutputTokens,
|
||||||
|
}
|
||||||
|
if event.Response.Usage.InputTokensDetails != nil {
|
||||||
|
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
logger.L().Warn("openai chat_completions buffered: read error",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if finalResponse == nil {
|
||||||
|
writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event")
|
||||||
|
return nil, fmt.Errorf("upstream stream ended without terminal event")
|
||||||
|
}
|
||||||
|
|
||||||
|
chatResp := apicompat.ResponsesToChatCompletions(finalResponse, originalModel)
|
||||||
|
|
||||||
|
if s.responseHeaderFilter != nil {
|
||||||
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, chatResp)
|
||||||
|
|
||||||
|
return &OpenAIForwardResult{
|
||||||
|
RequestID: requestID,
|
||||||
|
Usage: usage,
|
||||||
|
Model: originalModel,
|
||||||
|
BillingModel: mappedModel,
|
||||||
|
Stream: false,
|
||||||
|
Duration: time.Since(startTime),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleChatStreamingResponse reads Responses SSE events from upstream,
|
||||||
|
// converts each to Chat Completions SSE chunks, and writes them to the client.
|
||||||
|
func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||||
|
resp *http.Response,
|
||||||
|
c *gin.Context,
|
||||||
|
originalModel string,
|
||||||
|
mappedModel string,
|
||||||
|
includeUsage bool,
|
||||||
|
startTime time.Time,
|
||||||
|
) (*OpenAIForwardResult, error) {
|
||||||
|
requestID := resp.Header.Get("x-request-id")
|
||||||
|
|
||||||
|
if s.responseHeaderFilter != nil {
|
||||||
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||||
|
c.Writer.Header().Set("Connection", "keep-alive")
|
||||||
|
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
c.Writer.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
|
state := apicompat.NewResponsesEventToChatState()
|
||||||
|
state.Model = originalModel
|
||||||
|
state.IncludeUsage = includeUsage
|
||||||
|
|
||||||
|
var usage OpenAIUsage
|
||||||
|
var firstTokenMs *int
|
||||||
|
firstChunk := true
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
maxLineSize := defaultMaxLineSize
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||||
|
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||||
|
}
|
||||||
|
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||||
|
|
||||||
|
resultWithUsage := func() *OpenAIForwardResult {
|
||||||
|
return &OpenAIForwardResult{
|
||||||
|
RequestID: requestID,
|
||||||
|
Usage: usage,
|
||||||
|
Model: originalModel,
|
||||||
|
BillingModel: mappedModel,
|
||||||
|
Stream: true,
|
||||||
|
Duration: time.Since(startTime),
|
||||||
|
FirstTokenMs: firstTokenMs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
processDataLine := func(payload string) bool {
|
||||||
|
if firstChunk {
|
||||||
|
firstChunk = false
|
||||||
|
ms := int(time.Since(startTime).Milliseconds())
|
||||||
|
firstTokenMs = &ms
|
||||||
|
}
|
||||||
|
|
||||||
|
var event apicompat.ResponsesStreamEvent
|
||||||
|
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||||
|
logger.L().Warn("openai chat_completions stream: failed to parse event",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract usage from completion events
|
||||||
|
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
||||||
|
event.Response != nil && event.Response.Usage != nil {
|
||||||
|
usage = OpenAIUsage{
|
||||||
|
InputTokens: event.Response.Usage.InputTokens,
|
||||||
|
OutputTokens: event.Response.Usage.OutputTokens,
|
||||||
|
}
|
||||||
|
if event.Response.Usage.InputTokensDetails != nil {
|
||||||
|
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
chunks := apicompat.ResponsesEventToChatChunks(&event, state)
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
sse, err := apicompat.ChatChunkToSSE(chunk)
|
||||||
|
if err != nil {
|
||||||
|
logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||||
|
logger.L().Info("openai chat_completions stream: client disconnected",
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(chunks) > 0 {
|
||||||
|
c.Writer.Flush()
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
finalizeStream := func() (*OpenAIForwardResult, error) {
|
||||||
|
if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 {
|
||||||
|
for _, chunk := range finalChunks {
|
||||||
|
sse, err := apicompat.ChatChunkToSSE(chunk)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fmt.Fprint(c.Writer, sse) //nolint:errcheck
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Send [DONE] sentinel
|
||||||
|
fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck
|
||||||
|
c.Writer.Flush()
|
||||||
|
return resultWithUsage(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
handleScanErr := func(err error) {
|
||||||
|
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
logger.L().Warn("openai chat_completions stream: read error",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine keepalive interval
|
||||||
|
keepaliveInterval := time.Duration(0)
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||||
|
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
// No keepalive: fast synchronous path
|
||||||
|
if keepaliveInterval <= 0 {
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if processDataLine(line[6:]) {
|
||||||
|
return resultWithUsage(), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
handleScanErr(scanner.Err())
|
||||||
|
return finalizeStream()
|
||||||
|
}
|
||||||
|
|
||||||
|
// With keepalive: goroutine + channel + select
|
||||||
|
type scanEvent struct {
|
||||||
|
line string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
events := make(chan scanEvent, 16)
|
||||||
|
done := make(chan struct{})
|
||||||
|
sendEvent := func(ev scanEvent) bool {
|
||||||
|
select {
|
||||||
|
case events <- ev:
|
||||||
|
return true
|
||||||
|
case <-done:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
defer close(events)
|
||||||
|
for scanner.Scan() {
|
||||||
|
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
_ = sendEvent(scanEvent{err: err})
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer close(done)
|
||||||
|
|
||||||
|
keepaliveTicker := time.NewTicker(keepaliveInterval)
|
||||||
|
defer keepaliveTicker.Stop()
|
||||||
|
lastDataAt := time.Now()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case ev, ok := <-events:
|
||||||
|
if !ok {
|
||||||
|
return finalizeStream()
|
||||||
|
}
|
||||||
|
if ev.err != nil {
|
||||||
|
handleScanErr(ev.err)
|
||||||
|
return finalizeStream()
|
||||||
|
}
|
||||||
|
lastDataAt = time.Now()
|
||||||
|
line := ev.line
|
||||||
|
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if processDataLine(line[6:]) {
|
||||||
|
return resultWithUsage(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-keepaliveTicker.C:
|
||||||
|
if time.Since(lastDataAt) < keepaliveInterval {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Send SSE comment as keepalive
|
||||||
|
if _, err := fmt.Fprint(c.Writer, ":\n\n"); err != nil {
|
||||||
|
logger.L().Info("openai chat_completions stream: client disconnected during keepalive",
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
return resultWithUsage(), nil
|
||||||
|
}
|
||||||
|
c.Writer.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeChatCompletionsError writes an error response in OpenAI Chat Completions format.
|
||||||
|
func writeChatCompletionsError(c *gin.Context, statusCode int, errType, message string) {
|
||||||
|
c.JSON(statusCode, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": errType,
|
||||||
|
"message": message,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -172,7 +172,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
|||||||
return nil, &UpstreamFailoverError{
|
return nil, &UpstreamFailoverError{
|
||||||
StatusCode: resp.StatusCode,
|
StatusCode: resp.StatusCode,
|
||||||
ResponseBody: respBody,
|
ResponseBody: respBody,
|
||||||
RetryableOnSameAccount: account.IsPoolMode() && isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody),
|
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Non-failover error: return Anthropic-formatted error to client
|
// Non-failover error: return Anthropic-formatted error to client
|
||||||
@@ -219,54 +219,7 @@ func (s *OpenAIGatewayService) handleAnthropicErrorResponse(
|
|||||||
c *gin.Context,
|
c *gin.Context,
|
||||||
account *Account,
|
account *Account,
|
||||||
) (*OpenAIForwardResult, error) {
|
) (*OpenAIForwardResult, error) {
|
||||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
return s.handleCompatErrorResponse(resp, c, account, writeAnthropicError)
|
||||||
|
|
||||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
|
||||||
if upstreamMsg == "" {
|
|
||||||
upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
|
||||||
|
|
||||||
// Record upstream error details for ops logging
|
|
||||||
upstreamDetail := ""
|
|
||||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
|
||||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
|
||||||
if maxBytes <= 0 {
|
|
||||||
maxBytes = 2048
|
|
||||||
}
|
|
||||||
upstreamDetail = truncateString(string(body), maxBytes)
|
|
||||||
}
|
|
||||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
|
||||||
|
|
||||||
// Apply error passthrough rules (matches handleErrorResponse pattern in openai_gateway_service.go)
|
|
||||||
if status, errType, errMsg, matched := applyErrorPassthroughRule(
|
|
||||||
c, account.Platform, resp.StatusCode, body,
|
|
||||||
http.StatusBadGateway, "api_error", "Upstream request failed",
|
|
||||||
); matched {
|
|
||||||
writeAnthropicError(c, status, errType, errMsg)
|
|
||||||
if upstreamMsg == "" {
|
|
||||||
upstreamMsg = errMsg
|
|
||||||
}
|
|
||||||
if upstreamMsg == "" {
|
|
||||||
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
|
|
||||||
}
|
|
||||||
|
|
||||||
errType := "api_error"
|
|
||||||
switch {
|
|
||||||
case resp.StatusCode == 400:
|
|
||||||
errType = "invalid_request_error"
|
|
||||||
case resp.StatusCode == 404:
|
|
||||||
errType = "not_found_error"
|
|
||||||
case resp.StatusCode == 429:
|
|
||||||
errType = "rate_limit_error"
|
|
||||||
case resp.StatusCode >= 500:
|
|
||||||
errType = "api_error"
|
|
||||||
}
|
|
||||||
|
|
||||||
writeAnthropicError(c, resp.StatusCode, errType, upstreamMsg)
|
|
||||||
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleAnthropicBufferedStreamingResponse reads all Responses SSE events from
|
// handleAnthropicBufferedStreamingResponse reads all Responses SSE events from
|
||||||
|
|||||||
@@ -52,6 +52,8 @@ const (
|
|||||||
openAIWSRetryJitterRatioDefault = 0.2
|
openAIWSRetryJitterRatioDefault = 0.2
|
||||||
openAICompactSessionSeedKey = "openai_compact_session_seed"
|
openAICompactSessionSeedKey = "openai_compact_session_seed"
|
||||||
codexCLIVersion = "0.104.0"
|
codexCLIVersion = "0.104.0"
|
||||||
|
// Codex 限额快照仅用于后台展示/诊断,不需要每个成功请求都立即落库。
|
||||||
|
openAICodexSnapshotPersistMinInterval = 30 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// OpenAI allowed headers whitelist (for non-passthrough).
|
// OpenAI allowed headers whitelist (for non-passthrough).
|
||||||
@@ -255,6 +257,46 @@ type openAIWSRetryMetrics struct {
|
|||||||
nonRetryableFastFallback atomic.Int64
|
nonRetryableFastFallback atomic.Int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type accountWriteThrottle struct {
|
||||||
|
minInterval time.Duration
|
||||||
|
mu sync.Mutex
|
||||||
|
lastByID map[int64]time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAccountWriteThrottle(minInterval time.Duration) *accountWriteThrottle {
|
||||||
|
return &accountWriteThrottle{
|
||||||
|
minInterval: minInterval,
|
||||||
|
lastByID: make(map[int64]time.Time),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *accountWriteThrottle) Allow(id int64, now time.Time) bool {
|
||||||
|
if t == nil || id <= 0 || t.minInterval <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
|
||||||
|
if last, ok := t.lastByID[id]; ok && now.Sub(last) < t.minInterval {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
t.lastByID[id] = now
|
||||||
|
|
||||||
|
if len(t.lastByID) > 4096 {
|
||||||
|
cutoff := now.Add(-4 * t.minInterval)
|
||||||
|
for accountID, writtenAt := range t.lastByID {
|
||||||
|
if writtenAt.Before(cutoff) {
|
||||||
|
delete(t.lastByID, accountID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultOpenAICodexSnapshotPersistThrottle = newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval)
|
||||||
|
|
||||||
// OpenAIGatewayService handles OpenAI API gateway operations
|
// OpenAIGatewayService handles OpenAI API gateway operations
|
||||||
type OpenAIGatewayService struct {
|
type OpenAIGatewayService struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
@@ -289,6 +331,7 @@ type OpenAIGatewayService struct {
|
|||||||
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
|
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
|
||||||
openaiWSRetryMetrics openAIWSRetryMetrics
|
openaiWSRetryMetrics openAIWSRetryMetrics
|
||||||
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
||||||
|
codexSnapshotThrottle *accountWriteThrottle
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOpenAIGatewayService creates a new OpenAIGatewayService
|
// NewOpenAIGatewayService creates a new OpenAIGatewayService
|
||||||
@@ -329,17 +372,25 @@ func NewOpenAIGatewayService(
|
|||||||
nil,
|
nil,
|
||||||
"service.openai_gateway",
|
"service.openai_gateway",
|
||||||
),
|
),
|
||||||
httpUpstream: httpUpstream,
|
httpUpstream: httpUpstream,
|
||||||
deferredService: deferredService,
|
deferredService: deferredService,
|
||||||
openAITokenProvider: openAITokenProvider,
|
openAITokenProvider: openAITokenProvider,
|
||||||
toolCorrector: NewCodexToolCorrector(),
|
toolCorrector: NewCodexToolCorrector(),
|
||||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||||
|
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
|
||||||
}
|
}
|
||||||
svc.logOpenAIWSModeBootstrap()
|
svc.logOpenAIWSModeBootstrap()
|
||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle {
|
||||||
|
if s != nil && s.codexSnapshotThrottle != nil {
|
||||||
|
return s.codexSnapshotThrottle
|
||||||
|
}
|
||||||
|
return defaultOpenAICodexSnapshotPersistThrottle
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) billingDeps() *billingDeps {
|
func (s *OpenAIGatewayService) billingDeps() *billingDeps {
|
||||||
return &billingDeps{
|
return &billingDeps{
|
||||||
accountRepo: s.accountRepo,
|
accountRepo: s.accountRepo,
|
||||||
@@ -1716,6 +1767,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
bodyModified = true
|
bodyModified = true
|
||||||
markPatchSet("model", normalizedModel)
|
markPatchSet("model", normalizedModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 移除 gpt-5.2-codex 以下的版本 verbosity 参数
|
||||||
|
// 确保高版本模型向低版本模型映射不报错
|
||||||
|
if !SupportsVerbosity(normalizedModel) {
|
||||||
|
if text, ok := reqBody["text"].(map[string]any); ok {
|
||||||
|
delete(text, "verbosity")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 规范化 reasoning.effort 参数(minimal -> none),与上游允许值对齐。
|
// 规范化 reasoning.effort 参数(minimal -> none),与上游允许值对齐。
|
||||||
@@ -2947,6 +3006,120 @@ func (s *OpenAIGatewayService) handleErrorResponse(
|
|||||||
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
|
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// compatErrorWriter is the signature for format-specific error writers used by
|
||||||
|
// the compat paths (Chat Completions and Anthropic Messages).
|
||||||
|
type compatErrorWriter func(c *gin.Context, statusCode int, errType, message string)
|
||||||
|
|
||||||
|
// handleCompatErrorResponse is the shared non-failover error handler for the
|
||||||
|
// Chat Completions and Anthropic Messages compat paths. It mirrors the logic of
|
||||||
|
// handleErrorResponse (passthrough rules, ShouldHandleErrorCode, rate-limit
|
||||||
|
// tracking, secondary failover) but delegates the final error write to the
|
||||||
|
// format-specific writer function.
|
||||||
|
func (s *OpenAIGatewayService) handleCompatErrorResponse(
|
||||||
|
resp *http.Response,
|
||||||
|
c *gin.Context,
|
||||||
|
account *Account,
|
||||||
|
writeError compatErrorWriter,
|
||||||
|
) (*OpenAIForwardResult, error) {
|
||||||
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||||
|
if upstreamMsg == "" {
|
||||||
|
upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(body), maxBytes)
|
||||||
|
}
|
||||||
|
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||||
|
|
||||||
|
// Apply error passthrough rules
|
||||||
|
if status, errType, errMsg, matched := applyErrorPassthroughRule(
|
||||||
|
c, account.Platform, resp.StatusCode, body,
|
||||||
|
http.StatusBadGateway, "api_error", "Upstream request failed",
|
||||||
|
); matched {
|
||||||
|
writeError(c, status, errType, errMsg)
|
||||||
|
if upstreamMsg == "" {
|
||||||
|
upstreamMsg = errMsg
|
||||||
|
}
|
||||||
|
if upstreamMsg == "" {
|
||||||
|
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check custom error codes — if the account does not handle this status,
|
||||||
|
// return a generic error without exposing upstream details.
|
||||||
|
if !account.ShouldHandleErrorCode(resp.StatusCode) {
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "http_error",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
writeError(c, http.StatusInternalServerError, "api_error", "Upstream gateway error")
|
||||||
|
if upstreamMsg == "" {
|
||||||
|
return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("upstream error: %d (not in custom error codes) message=%s", resp.StatusCode, upstreamMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track rate limits and decide whether to trigger secondary failover.
|
||||||
|
shouldDisable := false
|
||||||
|
if s.rateLimitService != nil {
|
||||||
|
shouldDisable = s.rateLimitService.HandleUpstreamError(
|
||||||
|
c.Request.Context(), account, resp.StatusCode, resp.Header, body,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
kind := "http_error"
|
||||||
|
if shouldDisable {
|
||||||
|
kind = "failover"
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: kind,
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
if shouldDisable {
|
||||||
|
return nil, &UpstreamFailoverError{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
ResponseBody: body,
|
||||||
|
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map status code to error type and write response
|
||||||
|
errType := "api_error"
|
||||||
|
switch {
|
||||||
|
case resp.StatusCode == 400:
|
||||||
|
errType = "invalid_request_error"
|
||||||
|
case resp.StatusCode == 404:
|
||||||
|
errType = "not_found_error"
|
||||||
|
case resp.StatusCode == 429:
|
||||||
|
errType = "rate_limit_error"
|
||||||
|
case resp.StatusCode >= 500:
|
||||||
|
errType = "api_error"
|
||||||
|
}
|
||||||
|
|
||||||
|
writeError(c, resp.StatusCode, errType, upstreamMsg)
|
||||||
|
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
|
||||||
|
}
|
||||||
|
|
||||||
// openaiStreamingResult streaming response result
|
// openaiStreamingResult streaming response result
|
||||||
type openaiStreamingResult struct {
|
type openaiStreamingResult struct {
|
||||||
usage *OpenAIUsage
|
usage *OpenAIUsage
|
||||||
@@ -4050,11 +4223,15 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
|
|||||||
if len(updates) == 0 && resetAt == nil {
|
if len(updates) == 0 && resetAt == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
shouldPersistUpdates := len(updates) > 0 && s.getCodexSnapshotThrottle().Allow(accountID, now)
|
||||||
|
if !shouldPersistUpdates && resetAt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if len(updates) > 0 {
|
if shouldPersistUpdates {
|
||||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||||
}
|
}
|
||||||
if resetAt != nil {
|
if resetAt != nil {
|
||||||
|
|||||||
@@ -405,6 +405,40 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesN
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ThrottlesExtraWrites(t *testing.T) {
|
||||||
|
repo := &openAICodexSnapshotAsyncRepo{
|
||||||
|
updateExtraCh: make(chan map[string]any, 2),
|
||||||
|
rateLimitCh: make(chan time.Time, 2),
|
||||||
|
}
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
codexSnapshotThrottle: newAccountWriteThrottle(time.Hour),
|
||||||
|
}
|
||||||
|
snapshot := &OpenAICodexUsageSnapshot{
|
||||||
|
PrimaryUsedPercent: ptrFloat64WS(94),
|
||||||
|
PrimaryResetAfterSeconds: ptrIntWS(3600),
|
||||||
|
PrimaryWindowMinutes: ptrIntWS(10080),
|
||||||
|
SecondaryUsedPercent: ptrFloat64WS(22),
|
||||||
|
SecondaryResetAfterSeconds: ptrIntWS(1200),
|
||||||
|
SecondaryWindowMinutes: ptrIntWS(300),
|
||||||
|
}
|
||||||
|
|
||||||
|
svc.updateCodexUsageSnapshot(context.Background(), 777, snapshot)
|
||||||
|
svc.updateCodexUsageSnapshot(context.Background(), 777, snapshot)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-repo.updateExtraCh:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("等待第一次 codex 快照落库超时")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case updates := <-repo.updateExtraCh:
|
||||||
|
t.Fatalf("unexpected second codex snapshot write: %v", updates)
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func ptrFloat64WS(v float64) *float64 { return &v }
|
func ptrFloat64WS(v float64) *float64 { return &v }
|
||||||
func ptrIntWS(v int) *int { return &v }
|
func ptrIntWS(v int) *int { return &v }
|
||||||
|
|
||||||
|
|||||||
@@ -506,6 +506,48 @@ func (s *OpsAlertEvaluatorService) computeRuleMetric(
|
|||||||
return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
|
return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
|
||||||
return acc.HasError && acc.TempUnschedulableUntil == nil
|
return acc.HasError && acc.TempUnschedulableUntil == nil
|
||||||
})), true
|
})), true
|
||||||
|
case "group_rate_limit_ratio":
|
||||||
|
if groupID == nil || *groupID <= 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
if s == nil || s.opsService == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
|
||||||
|
if err != nil || availability == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
if availability.Group == nil || availability.Group.TotalAccounts <= 0 {
|
||||||
|
return 0, true
|
||||||
|
}
|
||||||
|
return (float64(availability.Group.RateLimitCount) / float64(availability.Group.TotalAccounts)) * 100, true
|
||||||
|
case "account_error_ratio":
|
||||||
|
if s == nil || s.opsService == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
|
||||||
|
if err != nil || availability == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
total := int64(len(availability.Accounts))
|
||||||
|
if total <= 0 {
|
||||||
|
return 0, true
|
||||||
|
}
|
||||||
|
errorCount := countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
|
||||||
|
return acc.HasError && acc.TempUnschedulableUntil == nil
|
||||||
|
})
|
||||||
|
return (float64(errorCount) / float64(total)) * 100, true
|
||||||
|
case "overload_account_count":
|
||||||
|
if s == nil || s.opsService == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
|
||||||
|
if err != nil || availability == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
|
||||||
|
return acc.IsOverloaded
|
||||||
|
})), true
|
||||||
}
|
}
|
||||||
|
|
||||||
overview, err := s.opsRepo.GetDashboardOverview(ctx, &OpsDashboardFilter{
|
overview, err := s.opsRepo.GetDashboardOverview(ctx, &OpsDashboardFilter{
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
type OpsRepository interface {
|
type OpsRepository interface {
|
||||||
InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error)
|
InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error)
|
||||||
|
BatchInsertErrorLogs(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error)
|
||||||
ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error)
|
ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error)
|
||||||
GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error)
|
GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error)
|
||||||
ListRequestDetails(ctx context.Context, filter *OpsRequestDetailFilter) ([]*OpsRequestDetail, int64, error)
|
ListRequestDetails(ctx context.Context, filter *OpsRequestDetailFilter) ([]*OpsRequestDetail, int64, error)
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
|
|
||||||
// opsRepoMock is a test-only OpsRepository implementation with optional function hooks.
|
// opsRepoMock is a test-only OpsRepository implementation with optional function hooks.
|
||||||
type opsRepoMock struct {
|
type opsRepoMock struct {
|
||||||
|
InsertErrorLogFn func(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error)
|
||||||
|
BatchInsertErrorLogsFn func(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error)
|
||||||
BatchInsertSystemLogsFn func(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error)
|
BatchInsertSystemLogsFn func(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error)
|
||||||
ListSystemLogsFn func(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error)
|
ListSystemLogsFn func(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error)
|
||||||
DeleteSystemLogsFn func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error)
|
DeleteSystemLogsFn func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error)
|
||||||
@@ -14,9 +16,19 @@ type opsRepoMock struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *opsRepoMock) InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) {
|
func (m *opsRepoMock) InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) {
|
||||||
|
if m.InsertErrorLogFn != nil {
|
||||||
|
return m.InsertErrorLogFn(ctx, input)
|
||||||
|
}
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *opsRepoMock) BatchInsertErrorLogs(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) {
|
||||||
|
if m.BatchInsertErrorLogsFn != nil {
|
||||||
|
return m.BatchInsertErrorLogsFn(ctx, inputs)
|
||||||
|
}
|
||||||
|
return int64(len(inputs)), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *opsRepoMock) ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error) {
|
func (m *opsRepoMock) ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error) {
|
||||||
return &OpsErrorLogList{Errors: []*OpsErrorLog{}, Page: 1, PageSize: 20}, nil
|
return &OpsErrorLogList{Errors: []*OpsErrorLog{}, Page: 1, PageSize: 20}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -121,14 +121,74 @@ func (s *OpsService) IsMonitoringEnabled(ctx context.Context) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogInput, rawRequestBody []byte) error {
|
func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogInput, rawRequestBody []byte) error {
|
||||||
if entry == nil {
|
prepared, ok, err := s.prepareErrorLogInput(ctx, entry, rawRequestBody)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Ops] RecordError prepare failed: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if _, err := s.opsRepo.InsertErrorLog(ctx, prepared); err != nil {
|
||||||
|
// Never bubble up to gateway; best-effort logging.
|
||||||
|
log.Printf("[Ops] RecordError failed: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpsService) RecordErrorBatch(ctx context.Context, entries []*OpsInsertErrorLogInput) error {
|
||||||
|
if len(entries) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
prepared := make([]*OpsInsertErrorLogInput, 0, len(entries))
|
||||||
|
for _, entry := range entries {
|
||||||
|
item, ok, err := s.prepareErrorLogInput(ctx, entry, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Ops] RecordErrorBatch prepare failed: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
prepared = append(prepared, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(prepared) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(prepared) == 1 {
|
||||||
|
_, err := s.opsRepo.InsertErrorLog(ctx, prepared[0])
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Ops] RecordErrorBatch single insert failed: %v", err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := s.opsRepo.BatchInsertErrorLogs(ctx, prepared); err != nil {
|
||||||
|
log.Printf("[Ops] RecordErrorBatch failed, fallback to single inserts: %v", err)
|
||||||
|
var firstErr error
|
||||||
|
for _, entry := range prepared {
|
||||||
|
if _, insertErr := s.opsRepo.InsertErrorLog(ctx, entry); insertErr != nil {
|
||||||
|
log.Printf("[Ops] RecordErrorBatch fallback insert failed: %v", insertErr)
|
||||||
|
if firstErr == nil {
|
||||||
|
firstErr = insertErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return firstErr
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpsService) prepareErrorLogInput(ctx context.Context, entry *OpsInsertErrorLogInput, rawRequestBody []byte) (*OpsInsertErrorLogInput, bool, error) {
|
||||||
|
if entry == nil {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
if !s.IsMonitoringEnabled(ctx) {
|
if !s.IsMonitoringEnabled(ctx) {
|
||||||
return nil
|
return nil, false, nil
|
||||||
}
|
}
|
||||||
if s.opsRepo == nil {
|
if s.opsRepo == nil {
|
||||||
return nil
|
return nil, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure timestamps are always populated.
|
// Ensure timestamps are always populated.
|
||||||
@@ -185,85 +245,88 @@ func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogIn
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sanitize + serialize upstream error events list.
|
if err := sanitizeOpsUpstreamErrors(entry); err != nil {
|
||||||
if len(entry.UpstreamErrors) > 0 {
|
return nil, false, err
|
||||||
const maxEvents = 32
|
}
|
||||||
events := entry.UpstreamErrors
|
|
||||||
if len(events) > maxEvents {
|
return entry, true, nil
|
||||||
events = events[len(events)-maxEvents:]
|
}
|
||||||
|
|
||||||
|
func sanitizeOpsUpstreamErrors(entry *OpsInsertErrorLogInput) error {
|
||||||
|
if entry == nil || len(entry.UpstreamErrors) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxEvents = 32
|
||||||
|
events := entry.UpstreamErrors
|
||||||
|
if len(events) > maxEvents {
|
||||||
|
events = events[len(events)-maxEvents:]
|
||||||
|
}
|
||||||
|
|
||||||
|
sanitized := make([]*OpsUpstreamErrorEvent, 0, len(events))
|
||||||
|
for _, ev := range events {
|
||||||
|
if ev == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out := *ev
|
||||||
|
|
||||||
|
out.Platform = strings.TrimSpace(out.Platform)
|
||||||
|
out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128)
|
||||||
|
out.Kind = truncateString(strings.TrimSpace(out.Kind), 64)
|
||||||
|
|
||||||
|
if out.AccountID < 0 {
|
||||||
|
out.AccountID = 0
|
||||||
|
}
|
||||||
|
if out.UpstreamStatusCode < 0 {
|
||||||
|
out.UpstreamStatusCode = 0
|
||||||
|
}
|
||||||
|
if out.AtUnixMs < 0 {
|
||||||
|
out.AtUnixMs = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
sanitized := make([]*OpsUpstreamErrorEvent, 0, len(events))
|
msg := sanitizeUpstreamErrorMessage(strings.TrimSpace(out.Message))
|
||||||
for _, ev := range events {
|
msg = truncateString(msg, 2048)
|
||||||
if ev == nil {
|
out.Message = msg
|
||||||
continue
|
|
||||||
}
|
|
||||||
out := *ev
|
|
||||||
|
|
||||||
out.Platform = strings.TrimSpace(out.Platform)
|
detail := strings.TrimSpace(out.Detail)
|
||||||
out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128)
|
if detail != "" {
|
||||||
out.Kind = truncateString(strings.TrimSpace(out.Kind), 64)
|
// Keep upstream detail small; request bodies are not stored here, only upstream error payloads.
|
||||||
|
sanitizedDetail, _ := sanitizeErrorBodyForStorage(detail, opsMaxStoredErrorBodyBytes)
|
||||||
|
out.Detail = sanitizedDetail
|
||||||
|
} else {
|
||||||
|
out.Detail = ""
|
||||||
|
}
|
||||||
|
|
||||||
if out.AccountID < 0 {
|
out.UpstreamRequestBody = strings.TrimSpace(out.UpstreamRequestBody)
|
||||||
out.AccountID = 0
|
if out.UpstreamRequestBody != "" {
|
||||||
}
|
// Reuse the same sanitization/trimming strategy as request body storage.
|
||||||
if out.UpstreamStatusCode < 0 {
|
// Keep it small so it is safe to persist in ops_error_logs JSON.
|
||||||
out.UpstreamStatusCode = 0
|
sanitizedBody, truncated, _ := sanitizeAndTrimRequestBody([]byte(out.UpstreamRequestBody), 10*1024)
|
||||||
}
|
if sanitizedBody != "" {
|
||||||
if out.AtUnixMs < 0 {
|
out.UpstreamRequestBody = sanitizedBody
|
||||||
out.AtUnixMs = 0
|
if truncated {
|
||||||
}
|
out.Kind = strings.TrimSpace(out.Kind)
|
||||||
|
if out.Kind == "" {
|
||||||
msg := sanitizeUpstreamErrorMessage(strings.TrimSpace(out.Message))
|
out.Kind = "upstream"
|
||||||
msg = truncateString(msg, 2048)
|
|
||||||
out.Message = msg
|
|
||||||
|
|
||||||
detail := strings.TrimSpace(out.Detail)
|
|
||||||
if detail != "" {
|
|
||||||
// Keep upstream detail small; request bodies are not stored here, only upstream error payloads.
|
|
||||||
sanitizedDetail, _ := sanitizeErrorBodyForStorage(detail, opsMaxStoredErrorBodyBytes)
|
|
||||||
out.Detail = sanitizedDetail
|
|
||||||
} else {
|
|
||||||
out.Detail = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
out.UpstreamRequestBody = strings.TrimSpace(out.UpstreamRequestBody)
|
|
||||||
if out.UpstreamRequestBody != "" {
|
|
||||||
// Reuse the same sanitization/trimming strategy as request body storage.
|
|
||||||
// Keep it small so it is safe to persist in ops_error_logs JSON.
|
|
||||||
sanitized, truncated, _ := sanitizeAndTrimRequestBody([]byte(out.UpstreamRequestBody), 10*1024)
|
|
||||||
if sanitized != "" {
|
|
||||||
out.UpstreamRequestBody = sanitized
|
|
||||||
if truncated {
|
|
||||||
out.Kind = strings.TrimSpace(out.Kind)
|
|
||||||
if out.Kind == "" {
|
|
||||||
out.Kind = "upstream"
|
|
||||||
}
|
|
||||||
out.Kind = out.Kind + ":request_body_truncated"
|
|
||||||
}
|
}
|
||||||
} else {
|
out.Kind = out.Kind + ":request_body_truncated"
|
||||||
out.UpstreamRequestBody = ""
|
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
out.UpstreamRequestBody = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Drop fully-empty events (can happen if only status code was known).
|
|
||||||
if out.UpstreamStatusCode == 0 && out.Message == "" && out.Detail == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
evCopy := out
|
|
||||||
sanitized = append(sanitized, &evCopy)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
entry.UpstreamErrorsJSON = marshalOpsUpstreamErrors(sanitized)
|
// Drop fully-empty events (can happen if only status code was known).
|
||||||
entry.UpstreamErrors = nil
|
if out.UpstreamStatusCode == 0 && out.Message == "" && out.Detail == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
evCopy := out
|
||||||
|
sanitized = append(sanitized, &evCopy)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := s.opsRepo.InsertErrorLog(ctx, entry); err != nil {
|
entry.UpstreamErrorsJSON = marshalOpsUpstreamErrors(sanitized)
|
||||||
// Never bubble up to gateway; best-effort logging.
|
entry.UpstreamErrors = nil
|
||||||
log.Printf("[Ops] RecordError failed: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
103
backend/internal/service/ops_service_batch_test.go
Normal file
103
backend/internal/service/ops_service_batch_test.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOpsServiceRecordErrorBatch_SanitizesAndBatches(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var captured []*OpsInsertErrorLogInput
|
||||||
|
repo := &opsRepoMock{
|
||||||
|
BatchInsertErrorLogsFn: func(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) {
|
||||||
|
captured = append(captured, inputs...)
|
||||||
|
return int64(len(inputs)), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
|
msg := " upstream failed: https://example.com?access_token=secret-value "
|
||||||
|
detail := `{"authorization":"Bearer secret-token"}`
|
||||||
|
entries := []*OpsInsertErrorLogInput{
|
||||||
|
{
|
||||||
|
ErrorBody: `{"error":"bad","access_token":"secret"}`,
|
||||||
|
UpstreamStatusCode: intPtr(-10),
|
||||||
|
UpstreamErrorMessage: strPtr(msg),
|
||||||
|
UpstreamErrorDetail: strPtr(detail),
|
||||||
|
UpstreamErrors: []*OpsUpstreamErrorEvent{
|
||||||
|
{
|
||||||
|
AccountID: -2,
|
||||||
|
UpstreamStatusCode: 429,
|
||||||
|
Message: " token leaked ",
|
||||||
|
Detail: `{"refresh_token":"secret"}`,
|
||||||
|
UpstreamRequestBody: `{"api_key":"secret","messages":[{"role":"user","content":"hello"}]}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ErrorPhase: "upstream",
|
||||||
|
ErrorType: "upstream_error",
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, svc.RecordErrorBatch(context.Background(), entries))
|
||||||
|
require.Len(t, captured, 2)
|
||||||
|
|
||||||
|
first := captured[0]
|
||||||
|
require.Equal(t, "internal", first.ErrorPhase)
|
||||||
|
require.Equal(t, "api_error", first.ErrorType)
|
||||||
|
require.Nil(t, first.UpstreamStatusCode)
|
||||||
|
require.NotNil(t, first.UpstreamErrorMessage)
|
||||||
|
require.NotContains(t, *first.UpstreamErrorMessage, "secret-value")
|
||||||
|
require.Contains(t, *first.UpstreamErrorMessage, "access_token=***")
|
||||||
|
require.NotNil(t, first.UpstreamErrorDetail)
|
||||||
|
require.NotContains(t, *first.UpstreamErrorDetail, "secret-token")
|
||||||
|
require.NotContains(t, first.ErrorBody, "secret")
|
||||||
|
require.Nil(t, first.UpstreamErrors)
|
||||||
|
require.NotNil(t, first.UpstreamErrorsJSON)
|
||||||
|
require.NotContains(t, *first.UpstreamErrorsJSON, "secret")
|
||||||
|
require.Contains(t, *first.UpstreamErrorsJSON, "[REDACTED]")
|
||||||
|
|
||||||
|
second := captured[1]
|
||||||
|
require.Equal(t, "upstream", second.ErrorPhase)
|
||||||
|
require.Equal(t, "upstream_error", second.ErrorType)
|
||||||
|
require.False(t, second.CreatedAt.IsZero())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpsServiceRecordErrorBatch_FallsBackToSingleInsert(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var (
|
||||||
|
batchCalls int
|
||||||
|
singleCalls int
|
||||||
|
)
|
||||||
|
repo := &opsRepoMock{
|
||||||
|
BatchInsertErrorLogsFn: func(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) {
|
||||||
|
batchCalls++
|
||||||
|
return 0, errors.New("batch failed")
|
||||||
|
},
|
||||||
|
InsertErrorLogFn: func(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) {
|
||||||
|
singleCalls++
|
||||||
|
return int64(singleCalls), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
|
err := svc.RecordErrorBatch(context.Background(), []*OpsInsertErrorLogInput{
|
||||||
|
{ErrorMessage: "first"},
|
||||||
|
{ErrorMessage: "second"},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, batchCalls)
|
||||||
|
require.Equal(t, 2, singleCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func strPtr(v string) *string {
|
||||||
|
return &v
|
||||||
|
}
|
||||||
166
backend/internal/service/subscription_reset_quota_test.go
Normal file
166
backend/internal/service/subscription_reset_quota_test.go
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// resetQuotaUserSubRepoStub 支持 GetByID、ResetDailyUsage、ResetWeeklyUsage,
|
||||||
|
// 其余方法继承 userSubRepoNoop(panic)。
|
||||||
|
type resetQuotaUserSubRepoStub struct {
|
||||||
|
userSubRepoNoop
|
||||||
|
|
||||||
|
sub *UserSubscription
|
||||||
|
|
||||||
|
resetDailyCalled bool
|
||||||
|
resetWeeklyCalled bool
|
||||||
|
resetDailyErr error
|
||||||
|
resetWeeklyErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *resetQuotaUserSubRepoStub) GetByID(_ context.Context, id int64) (*UserSubscription, error) {
|
||||||
|
if r.sub == nil || r.sub.ID != id {
|
||||||
|
return nil, ErrSubscriptionNotFound
|
||||||
|
}
|
||||||
|
cp := *r.sub
|
||||||
|
return &cp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *resetQuotaUserSubRepoStub) ResetDailyUsage(_ context.Context, _ int64, windowStart time.Time) error {
|
||||||
|
r.resetDailyCalled = true
|
||||||
|
if r.resetDailyErr == nil && r.sub != nil {
|
||||||
|
r.sub.DailyUsageUSD = 0
|
||||||
|
r.sub.DailyWindowStart = &windowStart
|
||||||
|
}
|
||||||
|
return r.resetDailyErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *resetQuotaUserSubRepoStub) ResetWeeklyUsage(_ context.Context, _ int64, _ time.Time) error {
|
||||||
|
r.resetWeeklyCalled = true
|
||||||
|
return r.resetWeeklyErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func newResetQuotaSvc(stub *resetQuotaUserSubRepoStub) *SubscriptionService {
|
||||||
|
return NewSubscriptionService(groupRepoNoop{}, stub, nil, nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminResetQuota_ResetBoth(t *testing.T) {
|
||||||
|
stub := &resetQuotaUserSubRepoStub{
|
||||||
|
sub: &UserSubscription{ID: 1, UserID: 10, GroupID: 20},
|
||||||
|
}
|
||||||
|
svc := newResetQuotaSvc(stub)
|
||||||
|
|
||||||
|
result, err := svc.AdminResetQuota(context.Background(), 1, true, true)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, stub.resetDailyCalled, "应调用 ResetDailyUsage")
|
||||||
|
require.True(t, stub.resetWeeklyCalled, "应调用 ResetWeeklyUsage")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminResetQuota_ResetDailyOnly(t *testing.T) {
|
||||||
|
stub := &resetQuotaUserSubRepoStub{
|
||||||
|
sub: &UserSubscription{ID: 2, UserID: 10, GroupID: 20},
|
||||||
|
}
|
||||||
|
svc := newResetQuotaSvc(stub)
|
||||||
|
|
||||||
|
result, err := svc.AdminResetQuota(context.Background(), 2, true, false)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, stub.resetDailyCalled, "应调用 ResetDailyUsage")
|
||||||
|
require.False(t, stub.resetWeeklyCalled, "不应调用 ResetWeeklyUsage")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminResetQuota_ResetWeeklyOnly(t *testing.T) {
|
||||||
|
stub := &resetQuotaUserSubRepoStub{
|
||||||
|
sub: &UserSubscription{ID: 3, UserID: 10, GroupID: 20},
|
||||||
|
}
|
||||||
|
svc := newResetQuotaSvc(stub)
|
||||||
|
|
||||||
|
result, err := svc.AdminResetQuota(context.Background(), 3, false, true)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.False(t, stub.resetDailyCalled, "不应调用 ResetDailyUsage")
|
||||||
|
require.True(t, stub.resetWeeklyCalled, "应调用 ResetWeeklyUsage")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminResetQuota_BothFalseReturnsError(t *testing.T) {
|
||||||
|
stub := &resetQuotaUserSubRepoStub{
|
||||||
|
sub: &UserSubscription{ID: 7, UserID: 10, GroupID: 20},
|
||||||
|
}
|
||||||
|
svc := newResetQuotaSvc(stub)
|
||||||
|
|
||||||
|
_, err := svc.AdminResetQuota(context.Background(), 7, false, false)
|
||||||
|
|
||||||
|
require.ErrorIs(t, err, ErrInvalidInput)
|
||||||
|
require.False(t, stub.resetDailyCalled)
|
||||||
|
require.False(t, stub.resetWeeklyCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminResetQuota_SubscriptionNotFound(t *testing.T) {
|
||||||
|
stub := &resetQuotaUserSubRepoStub{sub: nil}
|
||||||
|
svc := newResetQuotaSvc(stub)
|
||||||
|
|
||||||
|
_, err := svc.AdminResetQuota(context.Background(), 999, true, true)
|
||||||
|
|
||||||
|
require.ErrorIs(t, err, ErrSubscriptionNotFound)
|
||||||
|
require.False(t, stub.resetDailyCalled)
|
||||||
|
require.False(t, stub.resetWeeklyCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminResetQuota_ResetDailyUsageError(t *testing.T) {
|
||||||
|
dbErr := errors.New("db error")
|
||||||
|
stub := &resetQuotaUserSubRepoStub{
|
||||||
|
sub: &UserSubscription{ID: 4, UserID: 10, GroupID: 20},
|
||||||
|
resetDailyErr: dbErr,
|
||||||
|
}
|
||||||
|
svc := newResetQuotaSvc(stub)
|
||||||
|
|
||||||
|
_, err := svc.AdminResetQuota(context.Background(), 4, true, true)
|
||||||
|
|
||||||
|
require.ErrorIs(t, err, dbErr)
|
||||||
|
require.True(t, stub.resetDailyCalled)
|
||||||
|
require.False(t, stub.resetWeeklyCalled, "daily 失败后不应继续调用 weekly")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminResetQuota_ResetWeeklyUsageError(t *testing.T) {
|
||||||
|
dbErr := errors.New("db error")
|
||||||
|
stub := &resetQuotaUserSubRepoStub{
|
||||||
|
sub: &UserSubscription{ID: 5, UserID: 10, GroupID: 20},
|
||||||
|
resetWeeklyErr: dbErr,
|
||||||
|
}
|
||||||
|
svc := newResetQuotaSvc(stub)
|
||||||
|
|
||||||
|
_, err := svc.AdminResetQuota(context.Background(), 5, false, true)
|
||||||
|
|
||||||
|
require.ErrorIs(t, err, dbErr)
|
||||||
|
require.True(t, stub.resetWeeklyCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminResetQuota_ReturnsRefreshedSub(t *testing.T) {
|
||||||
|
stub := &resetQuotaUserSubRepoStub{
|
||||||
|
sub: &UserSubscription{
|
||||||
|
ID: 6,
|
||||||
|
UserID: 10,
|
||||||
|
GroupID: 20,
|
||||||
|
DailyUsageUSD: 99.9,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := newResetQuotaSvc(stub)
|
||||||
|
result, err := svc.AdminResetQuota(context.Background(), 6, true, false)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
// ResetDailyUsage stub 会将 sub.DailyUsageUSD 归零,
|
||||||
|
// 服务应返回第二次 GetByID 的刷新值而非初始的 99.9
|
||||||
|
require.Equal(t, float64(0), result.DailyUsageUSD, "返回的订阅应反映已归零的用量")
|
||||||
|
require.True(t, stub.resetDailyCalled)
|
||||||
|
}
|
||||||
@@ -31,6 +31,7 @@ var (
|
|||||||
ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group")
|
ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group")
|
||||||
ErrSubscriptionAssignConflict = infraerrors.Conflict("SUBSCRIPTION_ASSIGN_CONFLICT", "subscription exists but request conflicts with existing assignment semantics")
|
ErrSubscriptionAssignConflict = infraerrors.Conflict("SUBSCRIPTION_ASSIGN_CONFLICT", "subscription exists but request conflicts with existing assignment semantics")
|
||||||
ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type")
|
ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type")
|
||||||
|
ErrInvalidInput = infraerrors.BadRequest("INVALID_INPUT", "at least one of resetDaily or resetWeekly must be true")
|
||||||
ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded")
|
ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded")
|
||||||
ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
|
ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
|
||||||
ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
|
ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
|
||||||
@@ -695,6 +696,36 @@ func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *U
|
|||||||
return s.userSubRepo.ActivateWindows(ctx, sub.ID, windowStart)
|
return s.userSubRepo.ActivateWindows(ctx, sub.ID, windowStart)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AdminResetQuota manually resets the daily and/or weekly usage windows.
|
||||||
|
// Uses startOfDay(now) as the new window start, matching automatic resets.
|
||||||
|
func (s *SubscriptionService) AdminResetQuota(ctx context.Context, subscriptionID int64, resetDaily, resetWeekly bool) (*UserSubscription, error) {
|
||||||
|
if !resetDaily && !resetWeekly {
|
||||||
|
return nil, ErrInvalidInput
|
||||||
|
}
|
||||||
|
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
windowStart := startOfDay(time.Now())
|
||||||
|
if resetDaily {
|
||||||
|
if err := s.userSubRepo.ResetDailyUsage(ctx, sub.ID, windowStart); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if resetWeekly {
|
||||||
|
if err := s.userSubRepo.ResetWeeklyUsage(ctx, sub.ID, windowStart); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Invalidate caches, same as CheckAndResetWindows
|
||||||
|
s.InvalidateSubCache(sub.UserID, sub.GroupID)
|
||||||
|
if s.billingCacheService != nil {
|
||||||
|
_ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID)
|
||||||
|
}
|
||||||
|
// Return the refreshed subscription from DB
|
||||||
|
return s.userSubRepo.GetByID(ctx, subscriptionID)
|
||||||
|
}
|
||||||
|
|
||||||
// CheckAndResetWindows 检查并重置过期的窗口
|
// CheckAndResetWindows 检查并重置过期的窗口
|
||||||
func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *UserSubscription) error {
|
func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *UserSubscription) error {
|
||||||
// 使用当天零点作为新窗口起始时间
|
// 使用当天零点作为新窗口起始时间
|
||||||
|
|||||||
@@ -0,0 +1,51 @@
|
|||||||
|
-- Add gemini-2.5-flash-image aliases to Antigravity model_mapping
|
||||||
|
--
|
||||||
|
-- Background:
|
||||||
|
-- Gemini native image generation now relies on gemini-2.5-flash-image, and
|
||||||
|
-- existing Antigravity accounts with persisted model_mapping need this alias in
|
||||||
|
-- order to participate in mixed scheduling from gemini groups.
|
||||||
|
--
|
||||||
|
-- Strategy:
|
||||||
|
-- Overwrite the stored model_mapping so it matches DefaultAntigravityModelMapping
|
||||||
|
-- in constants.go, including legacy gemini-3-pro-image aliases.
|
||||||
|
|
||||||
|
UPDATE accounts
|
||||||
|
SET credentials = jsonb_set(
|
||||||
|
credentials,
|
||||||
|
'{model_mapping}',
|
||||||
|
'{
|
||||||
|
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
|
||||||
|
"claude-opus-4-6": "claude-opus-4-6-thinking",
|
||||||
|
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking",
|
||||||
|
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking",
|
||||||
|
"claude-sonnet-4-6": "claude-sonnet-4-6",
|
||||||
|
"claude-sonnet-4-5": "claude-sonnet-4-5",
|
||||||
|
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
|
||||||
|
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
|
||||||
|
"claude-haiku-4-5": "claude-sonnet-4-5",
|
||||||
|
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
|
||||||
|
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||||
|
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
|
||||||
|
"gemini-2.5-flash-image-preview": "gemini-2.5-flash-image",
|
||||||
|
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
|
||||||
|
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
||||||
|
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||||
|
"gemini-3-flash": "gemini-3-flash",
|
||||||
|
"gemini-3-pro-high": "gemini-3-pro-high",
|
||||||
|
"gemini-3-pro-low": "gemini-3-pro-low",
|
||||||
|
"gemini-3-flash-preview": "gemini-3-flash",
|
||||||
|
"gemini-3-pro-preview": "gemini-3-pro-high",
|
||||||
|
"gemini-3.1-pro-high": "gemini-3.1-pro-high",
|
||||||
|
"gemini-3.1-pro-low": "gemini-3.1-pro-low",
|
||||||
|
"gemini-3.1-pro-preview": "gemini-3.1-pro-high",
|
||||||
|
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
|
||||||
|
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
|
||||||
|
"gemini-3-pro-image": "gemini-3.1-flash-image",
|
||||||
|
"gemini-3-pro-image-preview": "gemini-3.1-flash-image",
|
||||||
|
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
||||||
|
"tab_flash_lite_preview": "tab_flash_lite_preview"
|
||||||
|
}'::jsonb
|
||||||
|
)
|
||||||
|
WHERE platform = 'antigravity'
|
||||||
|
AND deleted_at IS NULL
|
||||||
|
AND credentials->'model_mapping' IS NOT NULL;
|
||||||
@@ -120,6 +120,23 @@ export async function revoke(id: number): Promise<{ message: string }> {
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reset daily and/or weekly usage quota for a subscription
|
||||||
|
* @param id - Subscription ID
|
||||||
|
* @param options - Which windows to reset
|
||||||
|
* @returns Updated subscription
|
||||||
|
*/
|
||||||
|
export async function resetQuota(
|
||||||
|
id: number,
|
||||||
|
options: { daily: boolean; weekly: boolean }
|
||||||
|
): Promise<UserSubscription> {
|
||||||
|
const { data } = await apiClient.post<UserSubscription>(
|
||||||
|
`/admin/subscriptions/${id}/reset-quota`,
|
||||||
|
options
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* List subscriptions by group
|
* List subscriptions by group
|
||||||
* @param groupId - Group ID
|
* @param groupId - Group ID
|
||||||
@@ -170,6 +187,7 @@ export const subscriptionsAPI = {
|
|||||||
bulkAssign,
|
bulkAssign,
|
||||||
extend,
|
extend,
|
||||||
revoke,
|
revoke,
|
||||||
|
resetQuota,
|
||||||
listByGroup,
|
listByGroup,
|
||||||
listByUser
|
listByUser
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -176,6 +176,7 @@ const formatScopeName = (scope: string): string => {
|
|||||||
'gemini-2.5-flash-lite': 'G25FL',
|
'gemini-2.5-flash-lite': 'G25FL',
|
||||||
'gemini-2.5-flash-thinking': 'G25FT',
|
'gemini-2.5-flash-thinking': 'G25FT',
|
||||||
'gemini-2.5-pro': 'G25P',
|
'gemini-2.5-pro': 'G25P',
|
||||||
|
'gemini-2.5-flash-image': 'G25I',
|
||||||
// Gemini 3 系列
|
// Gemini 3 系列
|
||||||
'gemini-3-flash': 'G3F',
|
'gemini-3-flash': 'G3F',
|
||||||
'gemini-3.1-pro-high': 'G3PH',
|
'gemini-3.1-pro-high': 'G3PH',
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
<div
|
<div
|
||||||
class="flex h-10 w-10 items-center justify-center rounded-lg bg-gradient-to-br from-primary-500 to-primary-600"
|
class="flex h-10 w-10 items-center justify-center rounded-lg bg-gradient-to-br from-primary-500 to-primary-600"
|
||||||
>
|
>
|
||||||
<Icon name="userCircle" size="md" class="text-white" :stroke-width="2" />
|
<Icon name="play" size="md" class="text-white" :stroke-width="2" />
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<div class="font-semibold text-gray-900 dark:text-gray-100">{{ account.name }}</div>
|
<div class="font-semibold text-gray-900 dark:text-gray-100">{{ account.name }}</div>
|
||||||
@@ -61,6 +61,17 @@
|
|||||||
{{ t('admin.accounts.soraTestHint') }}
|
{{ t('admin.accounts.soraTestHint') }}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div v-if="supportsGeminiImageTest" class="space-y-1.5">
|
||||||
|
<TextArea
|
||||||
|
v-model="testPrompt"
|
||||||
|
:label="t('admin.accounts.geminiImagePromptLabel')"
|
||||||
|
:placeholder="t('admin.accounts.geminiImagePromptPlaceholder')"
|
||||||
|
:hint="t('admin.accounts.geminiImageTestHint')"
|
||||||
|
:disabled="status === 'connecting'"
|
||||||
|
rows="3"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- Terminal Output -->
|
<!-- Terminal Output -->
|
||||||
<div class="group relative">
|
<div class="group relative">
|
||||||
<div
|
<div
|
||||||
@@ -69,25 +80,11 @@
|
|||||||
>
|
>
|
||||||
<!-- Status Line -->
|
<!-- Status Line -->
|
||||||
<div v-if="status === 'idle'" class="flex items-center gap-2 text-gray-500">
|
<div v-if="status === 'idle'" class="flex items-center gap-2 text-gray-500">
|
||||||
<Icon name="bolt" size="sm" :stroke-width="2" />
|
<Icon name="play" size="sm" :stroke-width="2" />
|
||||||
<span>{{ t('admin.accounts.readyToTest') }}</span>
|
<span>{{ t('admin.accounts.readyToTest') }}</span>
|
||||||
</div>
|
</div>
|
||||||
<div v-else-if="status === 'connecting'" class="flex items-center gap-2 text-yellow-400">
|
<div v-else-if="status === 'connecting'" class="flex items-center gap-2 text-yellow-400">
|
||||||
<svg class="h-4 w-4 animate-spin" fill="none" viewBox="0 0 24 24">
|
<Icon name="refresh" size="sm" class="animate-spin" :stroke-width="2" />
|
||||||
<circle
|
|
||||||
class="opacity-25"
|
|
||||||
cx="12"
|
|
||||||
cy="12"
|
|
||||||
r="10"
|
|
||||||
stroke="currentColor"
|
|
||||||
stroke-width="4"
|
|
||||||
></circle>
|
|
||||||
<path
|
|
||||||
class="opacity-75"
|
|
||||||
fill="currentColor"
|
|
||||||
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
|
|
||||||
></path>
|
|
||||||
</svg>
|
|
||||||
<span>{{ t('admin.accounts.connectingToApi') }}</span>
|
<span>{{ t('admin.accounts.connectingToApi') }}</span>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -106,21 +103,14 @@
|
|||||||
v-if="status === 'success'"
|
v-if="status === 'success'"
|
||||||
class="mt-3 flex items-center gap-2 border-t border-gray-700 pt-3 text-green-400"
|
class="mt-3 flex items-center gap-2 border-t border-gray-700 pt-3 text-green-400"
|
||||||
>
|
>
|
||||||
<svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
<Icon name="check" size="sm" :stroke-width="2" />
|
||||||
<path
|
|
||||||
stroke-linecap="round"
|
|
||||||
stroke-linejoin="round"
|
|
||||||
stroke-width="2"
|
|
||||||
d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z"
|
|
||||||
/>
|
|
||||||
</svg>
|
|
||||||
<span>{{ t('admin.accounts.testCompleted') }}</span>
|
<span>{{ t('admin.accounts.testCompleted') }}</span>
|
||||||
</div>
|
</div>
|
||||||
<div
|
<div
|
||||||
v-else-if="status === 'error'"
|
v-else-if="status === 'error'"
|
||||||
class="mt-3 flex items-center gap-2 border-t border-gray-700 pt-3 text-red-400"
|
class="mt-3 flex items-center gap-2 border-t border-gray-700 pt-3 text-red-400"
|
||||||
>
|
>
|
||||||
<Icon name="xCircle" size="sm" :stroke-width="2" />
|
<Icon name="x" size="sm" :stroke-width="2" />
|
||||||
<span>{{ errorMessage }}</span>
|
<span>{{ errorMessage }}</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -132,21 +122,48 @@
|
|||||||
class="absolute right-2 top-2 rounded-lg bg-gray-800/80 p-1.5 text-gray-400 opacity-0 transition-all hover:bg-gray-700 hover:text-white group-hover:opacity-100"
|
class="absolute right-2 top-2 rounded-lg bg-gray-800/80 p-1.5 text-gray-400 opacity-0 transition-all hover:bg-gray-700 hover:text-white group-hover:opacity-100"
|
||||||
:title="t('admin.accounts.copyOutput')"
|
:title="t('admin.accounts.copyOutput')"
|
||||||
>
|
>
|
||||||
<Icon name="copy" size="sm" :stroke-width="2" />
|
<Icon name="link" size="sm" :stroke-width="2" />
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div v-if="generatedImages.length > 0" class="space-y-2">
|
||||||
|
<div class="text-xs font-medium text-gray-600 dark:text-gray-300">
|
||||||
|
{{ t('admin.accounts.geminiImagePreview') }}
|
||||||
|
</div>
|
||||||
|
<div class="grid gap-3 sm:grid-cols-2">
|
||||||
|
<a
|
||||||
|
v-for="(image, index) in generatedImages"
|
||||||
|
:key="`${image.url}-${index}`"
|
||||||
|
:href="image.url"
|
||||||
|
target="_blank"
|
||||||
|
rel="noopener noreferrer"
|
||||||
|
class="overflow-hidden rounded-xl border border-gray-200 bg-white shadow-sm transition hover:border-primary-300 hover:shadow-md dark:border-dark-500 dark:bg-dark-700"
|
||||||
|
>
|
||||||
|
<img :src="image.url" :alt="`gemini-test-image-${index + 1}`" class="h-48 w-full object-cover" />
|
||||||
|
<div class="border-t border-gray-100 px-3 py-2 text-xs text-gray-500 dark:border-dark-500 dark:text-gray-300">
|
||||||
|
{{ image.mimeType || 'image/*' }}
|
||||||
|
</div>
|
||||||
|
</a>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- Test Info -->
|
<!-- Test Info -->
|
||||||
<div class="flex items-center justify-between px-1 text-xs text-gray-500 dark:text-gray-400">
|
<div class="flex items-center justify-between px-1 text-xs text-gray-500 dark:text-gray-400">
|
||||||
<div class="flex items-center gap-3">
|
<div class="flex items-center gap-3">
|
||||||
<span class="flex items-center gap-1">
|
<span class="flex items-center gap-1">
|
||||||
<Icon name="cpu" size="sm" :stroke-width="2" />
|
<Icon name="grid" size="sm" :stroke-width="2" />
|
||||||
{{ isSoraAccount ? t('admin.accounts.soraTestTarget') : t('admin.accounts.testModel') }}
|
{{ isSoraAccount ? t('admin.accounts.soraTestTarget') : t('admin.accounts.testModel') }}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
<span class="flex items-center gap-1">
|
<span class="flex items-center gap-1">
|
||||||
<Icon name="chatBubble" size="sm" :stroke-width="2" />
|
<Icon name="chat" size="sm" :stroke-width="2" />
|
||||||
{{ isSoraAccount ? t('admin.accounts.soraTestMode') : t('admin.accounts.testPrompt') }}
|
{{
|
||||||
|
isSoraAccount
|
||||||
|
? t('admin.accounts.soraTestMode')
|
||||||
|
: supportsGeminiImageTest
|
||||||
|
? t('admin.accounts.geminiImageTestMode')
|
||||||
|
: t('admin.accounts.testPrompt')
|
||||||
|
}}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -174,54 +191,15 @@
|
|||||||
: 'bg-primary-500 text-white hover:bg-primary-600'
|
: 'bg-primary-500 text-white hover:bg-primary-600'
|
||||||
]"
|
]"
|
||||||
>
|
>
|
||||||
<svg
|
<Icon
|
||||||
v-if="status === 'connecting'"
|
v-if="status === 'connecting'"
|
||||||
class="h-4 w-4 animate-spin"
|
name="refresh"
|
||||||
fill="none"
|
size="sm"
|
||||||
viewBox="0 0 24 24"
|
class="animate-spin"
|
||||||
>
|
:stroke-width="2"
|
||||||
<circle
|
/>
|
||||||
class="opacity-25"
|
<Icon v-else-if="status === 'idle'" name="play" size="sm" :stroke-width="2" />
|
||||||
cx="12"
|
<Icon v-else name="refresh" size="sm" :stroke-width="2" />
|
||||||
cy="12"
|
|
||||||
r="10"
|
|
||||||
stroke="currentColor"
|
|
||||||
stroke-width="4"
|
|
||||||
></circle>
|
|
||||||
<path
|
|
||||||
class="opacity-75"
|
|
||||||
fill="currentColor"
|
|
||||||
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
|
|
||||||
></path>
|
|
||||||
</svg>
|
|
||||||
<svg
|
|
||||||
v-else-if="status === 'idle'"
|
|
||||||
class="h-4 w-4"
|
|
||||||
fill="none"
|
|
||||||
viewBox="0 0 24 24"
|
|
||||||
stroke="currentColor"
|
|
||||||
>
|
|
||||||
<path
|
|
||||||
stroke-linecap="round"
|
|
||||||
stroke-linejoin="round"
|
|
||||||
stroke-width="2"
|
|
||||||
d="M14.752 11.168l-3.197-2.132A1 1 0 0010 9.87v4.263a1 1 0 001.555.832l3.197-2.132a1 1 0 000-1.664z"
|
|
||||||
/>
|
|
||||||
<path
|
|
||||||
stroke-linecap="round"
|
|
||||||
stroke-linejoin="round"
|
|
||||||
stroke-width="2"
|
|
||||||
d="M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
|
|
||||||
/>
|
|
||||||
</svg>
|
|
||||||
<svg v-else class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
|
||||||
<path
|
|
||||||
stroke-linecap="round"
|
|
||||||
stroke-linejoin="round"
|
|
||||||
stroke-width="2"
|
|
||||||
d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15"
|
|
||||||
/>
|
|
||||||
</svg>
|
|
||||||
<span>
|
<span>
|
||||||
{{
|
{{
|
||||||
status === 'connecting'
|
status === 'connecting'
|
||||||
@@ -242,7 +220,8 @@ import { computed, ref, watch, nextTick } from 'vue'
|
|||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import BaseDialog from '@/components/common/BaseDialog.vue'
|
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||||
import Select from '@/components/common/Select.vue'
|
import Select from '@/components/common/Select.vue'
|
||||||
import Icon from '@/components/icons/Icon.vue'
|
import TextArea from '@/components/common/TextArea.vue'
|
||||||
|
import { Icon } from '@/components/icons'
|
||||||
import { useClipboard } from '@/composables/useClipboard'
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
import { adminAPI } from '@/api/admin'
|
import { adminAPI } from '@/api/admin'
|
||||||
import type { Account, ClaudeModel } from '@/types'
|
import type { Account, ClaudeModel } from '@/types'
|
||||||
@@ -255,6 +234,11 @@ interface OutputLine {
|
|||||||
class: string
|
class: string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
interface PreviewImage {
|
||||||
|
url: string
|
||||||
|
mimeType?: string
|
||||||
|
}
|
||||||
|
|
||||||
const props = defineProps<{
|
const props = defineProps<{
|
||||||
show: boolean
|
show: boolean
|
||||||
account: Account | null
|
account: Account | null
|
||||||
@@ -271,15 +255,37 @@ const streamingContent = ref('')
|
|||||||
const errorMessage = ref('')
|
const errorMessage = ref('')
|
||||||
const availableModels = ref<ClaudeModel[]>([])
|
const availableModels = ref<ClaudeModel[]>([])
|
||||||
const selectedModelId = ref('')
|
const selectedModelId = ref('')
|
||||||
|
const testPrompt = ref('')
|
||||||
const loadingModels = ref(false)
|
const loadingModels = ref(false)
|
||||||
let eventSource: EventSource | null = null
|
let eventSource: EventSource | null = null
|
||||||
const isSoraAccount = computed(() => props.account?.platform === 'sora')
|
const isSoraAccount = computed(() => props.account?.platform === 'sora')
|
||||||
|
const generatedImages = ref<PreviewImage[]>([])
|
||||||
|
const prioritizedGeminiModels = ['gemini-3.1-flash-image', 'gemini-2.5-flash-image', 'gemini-2.5-flash', 'gemini-2.5-pro', 'gemini-3-flash-preview', 'gemini-3-pro-preview', 'gemini-2.0-flash']
|
||||||
|
const supportsGeminiImageTest = computed(() => {
|
||||||
|
if (isSoraAccount.value) return false
|
||||||
|
const modelID = selectedModelId.value.toLowerCase()
|
||||||
|
if (!modelID.startsWith('gemini-') || !modelID.includes('-image')) return false
|
||||||
|
|
||||||
|
return props.account?.platform === 'gemini' || (props.account?.platform === 'antigravity' && props.account?.type === 'apikey')
|
||||||
|
})
|
||||||
|
|
||||||
|
const sortTestModels = (models: ClaudeModel[]) => {
|
||||||
|
const priorityMap = new Map(prioritizedGeminiModels.map((id, index) => [id, index]))
|
||||||
|
|
||||||
|
return [...models].sort((a, b) => {
|
||||||
|
const aPriority = priorityMap.get(a.id) ?? Number.MAX_SAFE_INTEGER
|
||||||
|
const bPriority = priorityMap.get(b.id) ?? Number.MAX_SAFE_INTEGER
|
||||||
|
if (aPriority !== bPriority) return aPriority - bPriority
|
||||||
|
return 0
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Load available models when modal opens
|
// Load available models when modal opens
|
||||||
watch(
|
watch(
|
||||||
() => props.show,
|
() => props.show,
|
||||||
async (newVal) => {
|
async (newVal) => {
|
||||||
if (newVal && props.account) {
|
if (newVal && props.account) {
|
||||||
|
testPrompt.value = ''
|
||||||
resetState()
|
resetState()
|
||||||
await loadAvailableModels()
|
await loadAvailableModels()
|
||||||
} else {
|
} else {
|
||||||
@@ -288,6 +294,12 @@ watch(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
watch(selectedModelId, () => {
|
||||||
|
if (supportsGeminiImageTest.value && !testPrompt.value.trim()) {
|
||||||
|
testPrompt.value = t('admin.accounts.geminiImagePromptDefault')
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
const loadAvailableModels = async () => {
|
const loadAvailableModels = async () => {
|
||||||
if (!props.account) return
|
if (!props.account) return
|
||||||
if (props.account.platform === 'sora') {
|
if (props.account.platform === 'sora') {
|
||||||
@@ -300,17 +312,14 @@ const loadAvailableModels = async () => {
|
|||||||
loadingModels.value = true
|
loadingModels.value = true
|
||||||
selectedModelId.value = '' // Reset selection before loading
|
selectedModelId.value = '' // Reset selection before loading
|
||||||
try {
|
try {
|
||||||
availableModels.value = await adminAPI.accounts.getAvailableModels(props.account.id)
|
const models = await adminAPI.accounts.getAvailableModels(props.account.id)
|
||||||
|
availableModels.value = props.account.platform === 'gemini' || props.account.platform === 'antigravity'
|
||||||
|
? sortTestModels(models)
|
||||||
|
: models
|
||||||
// Default selection by platform
|
// Default selection by platform
|
||||||
if (availableModels.value.length > 0) {
|
if (availableModels.value.length > 0) {
|
||||||
if (props.account.platform === 'gemini') {
|
if (props.account.platform === 'gemini') {
|
||||||
const preferred =
|
selectedModelId.value = availableModels.value[0].id
|
||||||
availableModels.value.find((m) => m.id === 'gemini-2.0-flash') ||
|
|
||||||
availableModels.value.find((m) => m.id === 'gemini-2.5-flash') ||
|
|
||||||
availableModels.value.find((m) => m.id === 'gemini-2.5-pro') ||
|
|
||||||
availableModels.value.find((m) => m.id === 'gemini-3-flash-preview') ||
|
|
||||||
availableModels.value.find((m) => m.id === 'gemini-3-pro-preview')
|
|
||||||
selectedModelId.value = preferred?.id || availableModels.value[0].id
|
|
||||||
} else {
|
} else {
|
||||||
// Try to select Sonnet as default, otherwise use first model
|
// Try to select Sonnet as default, otherwise use first model
|
||||||
const sonnetModel = availableModels.value.find((m) => m.id.includes('sonnet'))
|
const sonnetModel = availableModels.value.find((m) => m.id.includes('sonnet'))
|
||||||
@@ -332,6 +341,7 @@ const resetState = () => {
|
|||||||
outputLines.value = []
|
outputLines.value = []
|
||||||
streamingContent.value = ''
|
streamingContent.value = ''
|
||||||
errorMessage.value = ''
|
errorMessage.value = ''
|
||||||
|
generatedImages.value = []
|
||||||
}
|
}
|
||||||
|
|
||||||
const handleClose = () => {
|
const handleClose = () => {
|
||||||
@@ -385,7 +395,12 @@ const startTest = async () => {
|
|||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
},
|
},
|
||||||
body: JSON.stringify(
|
body: JSON.stringify(
|
||||||
isSoraAccount.value ? {} : { model_id: selectedModelId.value }
|
isSoraAccount.value
|
||||||
|
? {}
|
||||||
|
: {
|
||||||
|
model_id: selectedModelId.value,
|
||||||
|
prompt: supportsGeminiImageTest.value ? testPrompt.value.trim() : ''
|
||||||
|
}
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -436,6 +451,8 @@ const handleEvent = (event: {
|
|||||||
model?: string
|
model?: string
|
||||||
success?: boolean
|
success?: boolean
|
||||||
error?: string
|
error?: string
|
||||||
|
image_url?: string
|
||||||
|
mime_type?: string
|
||||||
}) => {
|
}) => {
|
||||||
switch (event.type) {
|
switch (event.type) {
|
||||||
case 'test_start':
|
case 'test_start':
|
||||||
@@ -444,7 +461,11 @@ const handleEvent = (event: {
|
|||||||
addLine(t('admin.accounts.usingModel', { model: event.model }), 'text-cyan-400')
|
addLine(t('admin.accounts.usingModel', { model: event.model }), 'text-cyan-400')
|
||||||
}
|
}
|
||||||
addLine(
|
addLine(
|
||||||
isSoraAccount.value ? t('admin.accounts.soraTestingFlow') : t('admin.accounts.sendingTestMessage'),
|
isSoraAccount.value
|
||||||
|
? t('admin.accounts.soraTestingFlow')
|
||||||
|
: supportsGeminiImageTest.value
|
||||||
|
? t('admin.accounts.sendingGeminiImageRequest')
|
||||||
|
: t('admin.accounts.sendingTestMessage'),
|
||||||
'text-gray-400'
|
'text-gray-400'
|
||||||
)
|
)
|
||||||
addLine('', 'text-gray-300')
|
addLine('', 'text-gray-300')
|
||||||
@@ -458,6 +479,16 @@ const handleEvent = (event: {
|
|||||||
}
|
}
|
||||||
break
|
break
|
||||||
|
|
||||||
|
case 'image':
|
||||||
|
if (event.image_url) {
|
||||||
|
generatedImages.value.push({
|
||||||
|
url: event.image_url,
|
||||||
|
mimeType: event.mime_type
|
||||||
|
})
|
||||||
|
addLine(t('admin.accounts.geminiImageReceived', { count: generatedImages.value.length }), 'text-purple-300')
|
||||||
|
}
|
||||||
|
break
|
||||||
|
|
||||||
case 'test_complete':
|
case 'test_complete':
|
||||||
// Move streaming content to output lines
|
// Move streaming content to output lines
|
||||||
if (streamingContent.value) {
|
if (streamingContent.value) {
|
||||||
|
|||||||
@@ -521,7 +521,7 @@ const antigravity3FlashUsageFromAPI = computed(() => getAntigravityUsageFromAPI(
|
|||||||
|
|
||||||
// Gemini Image from API
|
// Gemini Image from API
|
||||||
const antigravity3ImageUsageFromAPI = computed(() =>
|
const antigravity3ImageUsageFromAPI = computed(() =>
|
||||||
getAntigravityUsageFromAPI(['gemini-3.1-flash-image', 'gemini-3-pro-image'])
|
getAntigravityUsageFromAPI(['gemini-2.5-flash-image', 'gemini-3.1-flash-image', 'gemini-3-pro-image'])
|
||||||
)
|
)
|
||||||
|
|
||||||
// Claude from API (all Claude model variants)
|
// Claude from API (all Claude model variants)
|
||||||
|
|||||||
@@ -959,10 +959,11 @@ const allModels = [
|
|||||||
{ value: 'gpt-5.1-2025-11-13', label: 'GPT-5.1' },
|
{ value: 'gpt-5.1-2025-11-13', label: 'GPT-5.1' },
|
||||||
{ value: 'gpt-5.1-codex-mini', label: 'GPT-5.1 Codex Mini' },
|
{ value: 'gpt-5.1-codex-mini', label: 'GPT-5.1 Codex Mini' },
|
||||||
{ value: 'gpt-5-2025-08-07', label: 'GPT-5' },
|
{ value: 'gpt-5-2025-08-07', label: 'GPT-5' },
|
||||||
|
{ value: 'gemini-3.1-flash-image', label: 'Gemini 3.1 Flash Image' },
|
||||||
|
{ value: 'gemini-2.5-flash-image', label: 'Gemini 2.5 Flash Image' },
|
||||||
{ value: 'gemini-2.0-flash', label: 'Gemini 2.0 Flash' },
|
{ value: 'gemini-2.0-flash', label: 'Gemini 2.0 Flash' },
|
||||||
{ value: 'gemini-2.5-flash', label: 'Gemini 2.5 Flash' },
|
{ value: 'gemini-2.5-flash', label: 'Gemini 2.5 Flash' },
|
||||||
{ value: 'gemini-2.5-pro', label: 'Gemini 2.5 Pro' },
|
{ value: 'gemini-2.5-pro', label: 'Gemini 2.5 Pro' },
|
||||||
{ value: 'gemini-3.1-flash-image', label: 'Gemini 3.1 Flash Image' },
|
|
||||||
{ value: 'gemini-3-pro-image', label: 'Gemini 3 Pro Image (Legacy)' },
|
{ value: 'gemini-3-pro-image', label: 'Gemini 3 Pro Image (Legacy)' },
|
||||||
{ value: 'gemini-3-flash-preview', label: 'Gemini 3 Flash Preview' },
|
{ value: 'gemini-3-flash-preview', label: 'Gemini 3 Flash Preview' },
|
||||||
{ value: 'gemini-3-pro-preview', label: 'Gemini 3 Pro Preview' }
|
{ value: 'gemini-3-pro-preview', label: 'Gemini 3 Pro Preview' }
|
||||||
@@ -1042,6 +1043,12 @@ const presetMappings = [
|
|||||||
to: 'claude-sonnet-4-5-20250929',
|
to: 'claude-sonnet-4-5-20250929',
|
||||||
color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400'
|
color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400'
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
label: 'Gemini 2.5 Image',
|
||||||
|
from: 'gemini-2.5-flash-image',
|
||||||
|
to: 'gemini-2.5-flash-image',
|
||||||
|
color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400'
|
||||||
|
},
|
||||||
{
|
{
|
||||||
label: 'Gemini 3.1 Image',
|
label: 'Gemini 3.1 Image',
|
||||||
from: 'gemini-3.1-flash-image',
|
from: 'gemini-3.1-flash-image',
|
||||||
|
|||||||
@@ -32,6 +32,10 @@ describe('AccountUsageCell', () => {
|
|||||||
it('Antigravity 图片用量会聚合新旧 image 模型', async () => {
|
it('Antigravity 图片用量会聚合新旧 image 模型', async () => {
|
||||||
getUsage.mockResolvedValue({
|
getUsage.mockResolvedValue({
|
||||||
antigravity_quota: {
|
antigravity_quota: {
|
||||||
|
'gemini-2.5-flash-image': {
|
||||||
|
utilization: 45,
|
||||||
|
reset_time: '2026-03-01T11:00:00Z'
|
||||||
|
},
|
||||||
'gemini-3.1-flash-image': {
|
'gemini-3.1-flash-image': {
|
||||||
utilization: 20,
|
utilization: 20,
|
||||||
reset_time: '2026-03-01T10:00:00Z'
|
reset_time: '2026-03-01T10:00:00Z'
|
||||||
|
|||||||
@@ -18,6 +18,10 @@ vi.mock('@/api/admin', () => ({
|
|||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/api/admin/accounts', () => ({
|
||||||
|
getAntigravityDefaultModelMapping: vi.fn()
|
||||||
|
}))
|
||||||
|
|
||||||
vi.mock('vue-i18n', async () => {
|
vi.mock('vue-i18n', async () => {
|
||||||
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
|
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -61,6 +61,17 @@
|
|||||||
{{ t('admin.accounts.soraTestHint') }}
|
{{ t('admin.accounts.soraTestHint') }}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div v-if="supportsGeminiImageTest" class="space-y-1.5">
|
||||||
|
<TextArea
|
||||||
|
v-model="testPrompt"
|
||||||
|
:label="t('admin.accounts.geminiImagePromptLabel')"
|
||||||
|
:placeholder="t('admin.accounts.geminiImagePromptPlaceholder')"
|
||||||
|
:hint="t('admin.accounts.geminiImageTestHint')"
|
||||||
|
:disabled="status === 'connecting'"
|
||||||
|
rows="3"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- Terminal Output -->
|
<!-- Terminal Output -->
|
||||||
<div class="group relative">
|
<div class="group relative">
|
||||||
<div
|
<div
|
||||||
@@ -115,6 +126,27 @@
|
|||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div v-if="generatedImages.length > 0" class="space-y-2">
|
||||||
|
<div class="text-xs font-medium text-gray-600 dark:text-gray-300">
|
||||||
|
{{ t('admin.accounts.geminiImagePreview') }}
|
||||||
|
</div>
|
||||||
|
<div class="grid gap-3 sm:grid-cols-2">
|
||||||
|
<a
|
||||||
|
v-for="(image, index) in generatedImages"
|
||||||
|
:key="`${image.url}-${index}`"
|
||||||
|
:href="image.url"
|
||||||
|
target="_blank"
|
||||||
|
rel="noopener noreferrer"
|
||||||
|
class="overflow-hidden rounded-xl border border-gray-200 bg-white shadow-sm transition hover:border-primary-300 hover:shadow-md dark:border-dark-500 dark:bg-dark-700"
|
||||||
|
>
|
||||||
|
<img :src="image.url" :alt="`gemini-test-image-${index + 1}`" class="h-48 w-full object-cover" />
|
||||||
|
<div class="border-t border-gray-100 px-3 py-2 text-xs text-gray-500 dark:border-dark-500 dark:text-gray-300">
|
||||||
|
{{ image.mimeType || 'image/*' }}
|
||||||
|
</div>
|
||||||
|
</a>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- Test Info -->
|
<!-- Test Info -->
|
||||||
<div class="flex items-center justify-between px-1 text-xs text-gray-500 dark:text-gray-400">
|
<div class="flex items-center justify-between px-1 text-xs text-gray-500 dark:text-gray-400">
|
||||||
<div class="flex items-center gap-3">
|
<div class="flex items-center gap-3">
|
||||||
@@ -125,7 +157,13 @@
|
|||||||
</div>
|
</div>
|
||||||
<span class="flex items-center gap-1">
|
<span class="flex items-center gap-1">
|
||||||
<Icon name="chat" size="sm" :stroke-width="2" />
|
<Icon name="chat" size="sm" :stroke-width="2" />
|
||||||
{{ isSoraAccount ? t('admin.accounts.soraTestMode') : t('admin.accounts.testPrompt') }}
|
{{
|
||||||
|
isSoraAccount
|
||||||
|
? t('admin.accounts.soraTestMode')
|
||||||
|
: supportsGeminiImageTest
|
||||||
|
? t('admin.accounts.geminiImageTestMode')
|
||||||
|
: t('admin.accounts.testPrompt')
|
||||||
|
}}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -182,6 +220,7 @@ import { computed, ref, watch, nextTick } from 'vue'
|
|||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import BaseDialog from '@/components/common/BaseDialog.vue'
|
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||||
import Select from '@/components/common/Select.vue'
|
import Select from '@/components/common/Select.vue'
|
||||||
|
import TextArea from '@/components/common/TextArea.vue'
|
||||||
import { Icon } from '@/components/icons'
|
import { Icon } from '@/components/icons'
|
||||||
import { useClipboard } from '@/composables/useClipboard'
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
import { adminAPI } from '@/api/admin'
|
import { adminAPI } from '@/api/admin'
|
||||||
@@ -195,6 +234,11 @@ interface OutputLine {
|
|||||||
class: string
|
class: string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
interface PreviewImage {
|
||||||
|
url: string
|
||||||
|
mimeType?: string
|
||||||
|
}
|
||||||
|
|
||||||
const props = defineProps<{
|
const props = defineProps<{
|
||||||
show: boolean
|
show: boolean
|
||||||
account: Account | null
|
account: Account | null
|
||||||
@@ -211,15 +255,37 @@ const streamingContent = ref('')
|
|||||||
const errorMessage = ref('')
|
const errorMessage = ref('')
|
||||||
const availableModels = ref<ClaudeModel[]>([])
|
const availableModels = ref<ClaudeModel[]>([])
|
||||||
const selectedModelId = ref('')
|
const selectedModelId = ref('')
|
||||||
|
const testPrompt = ref('')
|
||||||
const loadingModels = ref(false)
|
const loadingModels = ref(false)
|
||||||
let eventSource: EventSource | null = null
|
let eventSource: EventSource | null = null
|
||||||
const isSoraAccount = computed(() => props.account?.platform === 'sora')
|
const isSoraAccount = computed(() => props.account?.platform === 'sora')
|
||||||
|
const generatedImages = ref<PreviewImage[]>([])
|
||||||
|
const prioritizedGeminiModels = ['gemini-3.1-flash-image', 'gemini-2.5-flash-image', 'gemini-2.5-flash', 'gemini-2.5-pro', 'gemini-3-flash-preview', 'gemini-3-pro-preview', 'gemini-2.0-flash']
|
||||||
|
const supportsGeminiImageTest = computed(() => {
|
||||||
|
if (isSoraAccount.value) return false
|
||||||
|
const modelID = selectedModelId.value.toLowerCase()
|
||||||
|
if (!modelID.startsWith('gemini-') || !modelID.includes('-image')) return false
|
||||||
|
|
||||||
|
return props.account?.platform === 'gemini' || (props.account?.platform === 'antigravity' && props.account?.type === 'apikey')
|
||||||
|
})
|
||||||
|
|
||||||
|
const sortTestModels = (models: ClaudeModel[]) => {
|
||||||
|
const priorityMap = new Map(prioritizedGeminiModels.map((id, index) => [id, index]))
|
||||||
|
|
||||||
|
return [...models].sort((a, b) => {
|
||||||
|
const aPriority = priorityMap.get(a.id) ?? Number.MAX_SAFE_INTEGER
|
||||||
|
const bPriority = priorityMap.get(b.id) ?? Number.MAX_SAFE_INTEGER
|
||||||
|
if (aPriority !== bPriority) return aPriority - bPriority
|
||||||
|
return 0
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Load available models when modal opens
|
// Load available models when modal opens
|
||||||
watch(
|
watch(
|
||||||
() => props.show,
|
() => props.show,
|
||||||
async (newVal) => {
|
async (newVal) => {
|
||||||
if (newVal && props.account) {
|
if (newVal && props.account) {
|
||||||
|
testPrompt.value = ''
|
||||||
resetState()
|
resetState()
|
||||||
await loadAvailableModels()
|
await loadAvailableModels()
|
||||||
} else {
|
} else {
|
||||||
@@ -228,6 +294,12 @@ watch(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
watch(selectedModelId, () => {
|
||||||
|
if (supportsGeminiImageTest.value && !testPrompt.value.trim()) {
|
||||||
|
testPrompt.value = t('admin.accounts.geminiImagePromptDefault')
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
const loadAvailableModels = async () => {
|
const loadAvailableModels = async () => {
|
||||||
if (!props.account) return
|
if (!props.account) return
|
||||||
if (props.account.platform === 'sora') {
|
if (props.account.platform === 'sora') {
|
||||||
@@ -240,17 +312,14 @@ const loadAvailableModels = async () => {
|
|||||||
loadingModels.value = true
|
loadingModels.value = true
|
||||||
selectedModelId.value = '' // Reset selection before loading
|
selectedModelId.value = '' // Reset selection before loading
|
||||||
try {
|
try {
|
||||||
availableModels.value = await adminAPI.accounts.getAvailableModels(props.account.id)
|
const models = await adminAPI.accounts.getAvailableModels(props.account.id)
|
||||||
|
availableModels.value = props.account.platform === 'gemini' || props.account.platform === 'antigravity'
|
||||||
|
? sortTestModels(models)
|
||||||
|
: models
|
||||||
// Default selection by platform
|
// Default selection by platform
|
||||||
if (availableModels.value.length > 0) {
|
if (availableModels.value.length > 0) {
|
||||||
if (props.account.platform === 'gemini') {
|
if (props.account.platform === 'gemini') {
|
||||||
const preferred =
|
selectedModelId.value = availableModels.value[0].id
|
||||||
availableModels.value.find((m) => m.id === 'gemini-2.0-flash') ||
|
|
||||||
availableModels.value.find((m) => m.id === 'gemini-2.5-flash') ||
|
|
||||||
availableModels.value.find((m) => m.id === 'gemini-2.5-pro') ||
|
|
||||||
availableModels.value.find((m) => m.id === 'gemini-3-flash-preview') ||
|
|
||||||
availableModels.value.find((m) => m.id === 'gemini-3-pro-preview')
|
|
||||||
selectedModelId.value = preferred?.id || availableModels.value[0].id
|
|
||||||
} else {
|
} else {
|
||||||
// Try to select Sonnet as default, otherwise use first model
|
// Try to select Sonnet as default, otherwise use first model
|
||||||
const sonnetModel = availableModels.value.find((m) => m.id.includes('sonnet'))
|
const sonnetModel = availableModels.value.find((m) => m.id.includes('sonnet'))
|
||||||
@@ -272,6 +341,7 @@ const resetState = () => {
|
|||||||
outputLines.value = []
|
outputLines.value = []
|
||||||
streamingContent.value = ''
|
streamingContent.value = ''
|
||||||
errorMessage.value = ''
|
errorMessage.value = ''
|
||||||
|
generatedImages.value = []
|
||||||
}
|
}
|
||||||
|
|
||||||
const handleClose = () => {
|
const handleClose = () => {
|
||||||
@@ -325,7 +395,12 @@ const startTest = async () => {
|
|||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
},
|
},
|
||||||
body: JSON.stringify(
|
body: JSON.stringify(
|
||||||
isSoraAccount.value ? {} : { model_id: selectedModelId.value }
|
isSoraAccount.value
|
||||||
|
? {}
|
||||||
|
: {
|
||||||
|
model_id: selectedModelId.value,
|
||||||
|
prompt: supportsGeminiImageTest.value ? testPrompt.value.trim() : ''
|
||||||
|
}
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -376,6 +451,8 @@ const handleEvent = (event: {
|
|||||||
model?: string
|
model?: string
|
||||||
success?: boolean
|
success?: boolean
|
||||||
error?: string
|
error?: string
|
||||||
|
image_url?: string
|
||||||
|
mime_type?: string
|
||||||
}) => {
|
}) => {
|
||||||
switch (event.type) {
|
switch (event.type) {
|
||||||
case 'test_start':
|
case 'test_start':
|
||||||
@@ -384,7 +461,11 @@ const handleEvent = (event: {
|
|||||||
addLine(t('admin.accounts.usingModel', { model: event.model }), 'text-cyan-400')
|
addLine(t('admin.accounts.usingModel', { model: event.model }), 'text-cyan-400')
|
||||||
}
|
}
|
||||||
addLine(
|
addLine(
|
||||||
isSoraAccount.value ? t('admin.accounts.soraTestingFlow') : t('admin.accounts.sendingTestMessage'),
|
isSoraAccount.value
|
||||||
|
? t('admin.accounts.soraTestingFlow')
|
||||||
|
: supportsGeminiImageTest.value
|
||||||
|
? t('admin.accounts.sendingGeminiImageRequest')
|
||||||
|
: t('admin.accounts.sendingTestMessage'),
|
||||||
'text-gray-400'
|
'text-gray-400'
|
||||||
)
|
)
|
||||||
addLine('', 'text-gray-300')
|
addLine('', 'text-gray-300')
|
||||||
@@ -398,6 +479,16 @@ const handleEvent = (event: {
|
|||||||
}
|
}
|
||||||
break
|
break
|
||||||
|
|
||||||
|
case 'image':
|
||||||
|
if (event.image_url) {
|
||||||
|
generatedImages.value.push({
|
||||||
|
url: event.image_url,
|
||||||
|
mimeType: event.mime_type
|
||||||
|
})
|
||||||
|
addLine(t('admin.accounts.geminiImageReceived', { count: generatedImages.value.length }), 'text-purple-300')
|
||||||
|
}
|
||||||
|
break
|
||||||
|
|
||||||
case 'test_complete':
|
case 'test_complete':
|
||||||
// Move streaming content to output lines
|
// Move streaming content to output lines
|
||||||
if (streamingContent.value) {
|
if (streamingContent.value) {
|
||||||
|
|||||||
@@ -0,0 +1,147 @@
|
|||||||
|
import { flushPromises, mount } from '@vue/test-utils'
|
||||||
|
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
import AccountTestModal from '../AccountTestModal.vue'
|
||||||
|
|
||||||
|
const { getAvailableModels, copyToClipboard } = vi.hoisted(() => ({
|
||||||
|
getAvailableModels: vi.fn(),
|
||||||
|
copyToClipboard: vi.fn()
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/api/admin', () => ({
|
||||||
|
adminAPI: {
|
||||||
|
accounts: {
|
||||||
|
getAvailableModels
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/composables/useClipboard', () => ({
|
||||||
|
useClipboard: () => ({
|
||||||
|
copyToClipboard
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('vue-i18n', async () => {
|
||||||
|
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
|
||||||
|
const messages: Record<string, string> = {
|
||||||
|
'admin.accounts.geminiImagePromptDefault': 'Generate a cute orange cat astronaut sticker on a clean pastel background.'
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
...actual,
|
||||||
|
useI18n: () => ({
|
||||||
|
t: (key: string, params?: Record<string, string | number>) => {
|
||||||
|
if (key === 'admin.accounts.geminiImageReceived' && params?.count) {
|
||||||
|
return `received-${params.count}`
|
||||||
|
}
|
||||||
|
return messages[key] || key
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
function createStreamResponse(lines: string[]) {
|
||||||
|
const encoder = new TextEncoder()
|
||||||
|
const chunks = lines.map((line) => encoder.encode(line))
|
||||||
|
let index = 0
|
||||||
|
|
||||||
|
return {
|
||||||
|
ok: true,
|
||||||
|
body: {
|
||||||
|
getReader: () => ({
|
||||||
|
read: vi.fn().mockImplementation(async () => {
|
||||||
|
if (index < chunks.length) {
|
||||||
|
return { done: false, value: chunks[index++] }
|
||||||
|
}
|
||||||
|
return { done: true, value: undefined }
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} as Response
|
||||||
|
}
|
||||||
|
|
||||||
|
function mountModal() {
|
||||||
|
return mount(AccountTestModal, {
|
||||||
|
props: {
|
||||||
|
show: false,
|
||||||
|
account: {
|
||||||
|
id: 42,
|
||||||
|
name: 'Gemini Image Test',
|
||||||
|
platform: 'gemini',
|
||||||
|
type: 'apikey',
|
||||||
|
status: 'active'
|
||||||
|
}
|
||||||
|
} as any,
|
||||||
|
global: {
|
||||||
|
stubs: {
|
||||||
|
BaseDialog: { template: '<div><slot /><slot name="footer" /></div>' },
|
||||||
|
Select: { template: '<div class="select-stub"></div>' },
|
||||||
|
TextArea: {
|
||||||
|
props: ['modelValue'],
|
||||||
|
emits: ['update:modelValue'],
|
||||||
|
template: '<textarea class="textarea-stub" :value="modelValue" @input="$emit(\'update:modelValue\', $event.target.value)" />'
|
||||||
|
},
|
||||||
|
Icon: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('AccountTestModal', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
getAvailableModels.mockResolvedValue([
|
||||||
|
{ id: 'gemini-2.0-flash', display_name: 'Gemini 2.0 Flash' },
|
||||||
|
{ id: 'gemini-2.5-flash-image', display_name: 'Gemini 2.5 Flash Image' },
|
||||||
|
{ id: 'gemini-3.1-flash-image', display_name: 'Gemini 3.1 Flash Image' }
|
||||||
|
])
|
||||||
|
copyToClipboard.mockReset()
|
||||||
|
Object.defineProperty(globalThis, 'localStorage', {
|
||||||
|
value: {
|
||||||
|
getItem: vi.fn((key: string) => (key === 'auth_token' ? 'test-token' : null)),
|
||||||
|
setItem: vi.fn(),
|
||||||
|
removeItem: vi.fn(),
|
||||||
|
clear: vi.fn()
|
||||||
|
},
|
||||||
|
configurable: true
|
||||||
|
})
|
||||||
|
global.fetch = vi.fn().mockResolvedValue(
|
||||||
|
createStreamResponse([
|
||||||
|
'data: {"type":"test_start","model":"gemini-2.5-flash-image"}\n',
|
||||||
|
'data: {"type":"image","image_url":"data:image/png;base64,QUJD","mime_type":"image/png"}\n',
|
||||||
|
'data: {"type":"test_complete","success":true}\n'
|
||||||
|
])
|
||||||
|
) as any
|
||||||
|
})
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
vi.restoreAllMocks()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('gemini 图片模型测试会携带提示词并渲染图片预览', async () => {
|
||||||
|
const wrapper = mountModal()
|
||||||
|
await wrapper.setProps({ show: true })
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
const promptInput = wrapper.find('textarea.textarea-stub')
|
||||||
|
expect(promptInput.exists()).toBe(true)
|
||||||
|
await promptInput.setValue('draw a tiny orange cat astronaut')
|
||||||
|
|
||||||
|
const buttons = wrapper.findAll('button')
|
||||||
|
const startButton = buttons.find((button) => button.text().includes('admin.accounts.startTest'))
|
||||||
|
expect(startButton).toBeTruthy()
|
||||||
|
|
||||||
|
await startButton!.trigger('click')
|
||||||
|
await flushPromises()
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
expect(global.fetch).toHaveBeenCalledTimes(1)
|
||||||
|
const [, request] = (global.fetch as any).mock.calls[0]
|
||||||
|
expect(JSON.parse(request.body)).toEqual({
|
||||||
|
model_id: 'gemini-3.1-flash-image',
|
||||||
|
prompt: 'draw a tiny orange cat astronaut'
|
||||||
|
})
|
||||||
|
|
||||||
|
const preview = wrapper.find('img[alt="gemini-test-image-1"]')
|
||||||
|
expect(preview.exists()).toBe(true)
|
||||||
|
expect(preview.attributes('src')).toBe('data:image/png;base64,QUJD')
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -1,12 +1,39 @@
|
|||||||
<template>
|
<template>
|
||||||
<div class="card p-4">
|
<div class="card p-4">
|
||||||
<h3 class="mb-4 text-sm font-semibold text-gray-900 dark:text-white">
|
<div class="mb-4 flex items-center justify-between gap-3">
|
||||||
{{ t('admin.dashboard.groupDistribution') }}
|
<h3 class="text-sm font-semibold text-gray-900 dark:text-white">
|
||||||
</h3>
|
{{ t('admin.dashboard.groupDistribution') }}
|
||||||
|
</h3>
|
||||||
|
<div
|
||||||
|
v-if="showMetricToggle"
|
||||||
|
class="inline-flex rounded-lg border border-gray-200 bg-gray-50 p-0.5 dark:border-gray-700 dark:bg-dark-800"
|
||||||
|
>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
|
||||||
|
:class="metric === 'tokens'
|
||||||
|
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
|
||||||
|
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
|
||||||
|
@click="emit('update:metric', 'tokens')"
|
||||||
|
>
|
||||||
|
{{ t('admin.dashboard.metricTokens') }}
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
|
||||||
|
:class="metric === 'actual_cost'
|
||||||
|
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
|
||||||
|
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
|
||||||
|
@click="emit('update:metric', 'actual_cost')"
|
||||||
|
>
|
||||||
|
{{ t('admin.dashboard.metricActualCost') }}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
<div v-if="loading" class="flex h-48 items-center justify-center">
|
<div v-if="loading" class="flex h-48 items-center justify-center">
|
||||||
<LoadingSpinner />
|
<LoadingSpinner />
|
||||||
</div>
|
</div>
|
||||||
<div v-else-if="groupStats.length > 0 && chartData" class="flex items-center gap-6">
|
<div v-else-if="displayGroupStats.length > 0 && chartData" class="flex items-center gap-6">
|
||||||
<div class="h-48 w-48">
|
<div class="h-48 w-48">
|
||||||
<Doughnut :data="chartData" :options="doughnutOptions" />
|
<Doughnut :data="chartData" :options="doughnutOptions" />
|
||||||
</div>
|
</div>
|
||||||
@@ -23,7 +50,7 @@
|
|||||||
</thead>
|
</thead>
|
||||||
<tbody>
|
<tbody>
|
||||||
<tr
|
<tr
|
||||||
v-for="group in groupStats"
|
v-for="group in displayGroupStats"
|
||||||
:key="group.group_id"
|
:key="group.group_id"
|
||||||
class="border-t border-gray-100 dark:border-gray-700"
|
class="border-t border-gray-100 dark:border-gray-700"
|
||||||
>
|
>
|
||||||
@@ -71,9 +98,21 @@ ChartJS.register(ArcElement, Tooltip, Legend)
|
|||||||
|
|
||||||
const { t } = useI18n()
|
const { t } = useI18n()
|
||||||
|
|
||||||
const props = defineProps<{
|
type DistributionMetric = 'tokens' | 'actual_cost'
|
||||||
|
|
||||||
|
const props = withDefaults(defineProps<{
|
||||||
groupStats: GroupStat[]
|
groupStats: GroupStat[]
|
||||||
loading?: boolean
|
loading?: boolean
|
||||||
|
metric?: DistributionMetric
|
||||||
|
showMetricToggle?: boolean
|
||||||
|
}>(), {
|
||||||
|
loading: false,
|
||||||
|
metric: 'tokens',
|
||||||
|
showMetricToggle: false,
|
||||||
|
})
|
||||||
|
|
||||||
|
const emit = defineEmits<{
|
||||||
|
'update:metric': [value: DistributionMetric]
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
const chartColors = [
|
const chartColors = [
|
||||||
@@ -89,15 +128,22 @@ const chartColors = [
|
|||||||
'#84cc16'
|
'#84cc16'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
const displayGroupStats = computed(() => {
|
||||||
|
if (!props.groupStats?.length) return []
|
||||||
|
|
||||||
|
const metricKey = props.metric === 'actual_cost' ? 'actual_cost' : 'total_tokens'
|
||||||
|
return [...props.groupStats].sort((a, b) => b[metricKey] - a[metricKey])
|
||||||
|
})
|
||||||
|
|
||||||
const chartData = computed(() => {
|
const chartData = computed(() => {
|
||||||
if (!props.groupStats?.length) return null
|
if (!props.groupStats?.length) return null
|
||||||
|
|
||||||
return {
|
return {
|
||||||
labels: props.groupStats.map((g) => g.group_name || String(g.group_id)),
|
labels: displayGroupStats.value.map((g) => g.group_name || String(g.group_id)),
|
||||||
datasets: [
|
datasets: [
|
||||||
{
|
{
|
||||||
data: props.groupStats.map((g) => g.total_tokens),
|
data: displayGroupStats.value.map((g) => props.metric === 'actual_cost' ? g.actual_cost : g.total_tokens),
|
||||||
backgroundColor: chartColors.slice(0, props.groupStats.length),
|
backgroundColor: chartColors.slice(0, displayGroupStats.value.length),
|
||||||
borderWidth: 0
|
borderWidth: 0
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@@ -116,8 +162,11 @@ const doughnutOptions = computed(() => ({
|
|||||||
label: (context: any) => {
|
label: (context: any) => {
|
||||||
const value = context.raw as number
|
const value = context.raw as number
|
||||||
const total = context.dataset.data.reduce((a: number, b: number) => a + b, 0)
|
const total = context.dataset.data.reduce((a: number, b: number) => a + b, 0)
|
||||||
const percentage = ((value / total) * 100).toFixed(1)
|
const percentage = total > 0 ? ((value / total) * 100).toFixed(1) : '0.0'
|
||||||
return `${context.label}: ${formatTokens(value)} (${percentage}%)`
|
const formattedValue = props.metric === 'actual_cost'
|
||||||
|
? `$${formatCost(value)}`
|
||||||
|
: formatTokens(value)
|
||||||
|
return `${context.label}: ${formattedValue} (${percentage}%)`
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,39 @@
|
|||||||
<template>
|
<template>
|
||||||
<div class="card p-4">
|
<div class="card p-4">
|
||||||
<h3 class="mb-4 text-sm font-semibold text-gray-900 dark:text-white">
|
<div class="mb-4 flex items-center justify-between gap-3">
|
||||||
{{ t('admin.dashboard.modelDistribution') }}
|
<h3 class="text-sm font-semibold text-gray-900 dark:text-white">
|
||||||
</h3>
|
{{ t('admin.dashboard.modelDistribution') }}
|
||||||
|
</h3>
|
||||||
|
<div
|
||||||
|
v-if="showMetricToggle"
|
||||||
|
class="inline-flex rounded-lg border border-gray-200 bg-gray-50 p-0.5 dark:border-gray-700 dark:bg-dark-800"
|
||||||
|
>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
|
||||||
|
:class="metric === 'tokens'
|
||||||
|
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
|
||||||
|
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
|
||||||
|
@click="emit('update:metric', 'tokens')"
|
||||||
|
>
|
||||||
|
{{ t('admin.dashboard.metricTokens') }}
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
|
||||||
|
:class="metric === 'actual_cost'
|
||||||
|
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
|
||||||
|
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
|
||||||
|
@click="emit('update:metric', 'actual_cost')"
|
||||||
|
>
|
||||||
|
{{ t('admin.dashboard.metricActualCost') }}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
<div v-if="loading" class="flex h-48 items-center justify-center">
|
<div v-if="loading" class="flex h-48 items-center justify-center">
|
||||||
<LoadingSpinner />
|
<LoadingSpinner />
|
||||||
</div>
|
</div>
|
||||||
<div v-else-if="modelStats.length > 0 && chartData" class="flex items-center gap-6">
|
<div v-else-if="displayModelStats.length > 0 && chartData" class="flex items-center gap-6">
|
||||||
<div class="h-48 w-48">
|
<div class="h-48 w-48">
|
||||||
<Doughnut :data="chartData" :options="doughnutOptions" />
|
<Doughnut :data="chartData" :options="doughnutOptions" />
|
||||||
</div>
|
</div>
|
||||||
@@ -23,7 +50,7 @@
|
|||||||
</thead>
|
</thead>
|
||||||
<tbody>
|
<tbody>
|
||||||
<tr
|
<tr
|
||||||
v-for="model in modelStats"
|
v-for="model in displayModelStats"
|
||||||
:key="model.model"
|
:key="model.model"
|
||||||
class="border-t border-gray-100 dark:border-gray-700"
|
class="border-t border-gray-100 dark:border-gray-700"
|
||||||
>
|
>
|
||||||
@@ -71,9 +98,21 @@ ChartJS.register(ArcElement, Tooltip, Legend)
|
|||||||
|
|
||||||
const { t } = useI18n()
|
const { t } = useI18n()
|
||||||
|
|
||||||
const props = defineProps<{
|
type DistributionMetric = 'tokens' | 'actual_cost'
|
||||||
|
|
||||||
|
const props = withDefaults(defineProps<{
|
||||||
modelStats: ModelStat[]
|
modelStats: ModelStat[]
|
||||||
loading?: boolean
|
loading?: boolean
|
||||||
|
metric?: DistributionMetric
|
||||||
|
showMetricToggle?: boolean
|
||||||
|
}>(), {
|
||||||
|
loading: false,
|
||||||
|
metric: 'tokens',
|
||||||
|
showMetricToggle: false,
|
||||||
|
})
|
||||||
|
|
||||||
|
const emit = defineEmits<{
|
||||||
|
'update:metric': [value: DistributionMetric]
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
const chartColors = [
|
const chartColors = [
|
||||||
@@ -89,15 +128,22 @@ const chartColors = [
|
|||||||
'#84cc16'
|
'#84cc16'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
const displayModelStats = computed(() => {
|
||||||
|
if (!props.modelStats?.length) return []
|
||||||
|
|
||||||
|
const metricKey = props.metric === 'actual_cost' ? 'actual_cost' : 'total_tokens'
|
||||||
|
return [...props.modelStats].sort((a, b) => b[metricKey] - a[metricKey])
|
||||||
|
})
|
||||||
|
|
||||||
const chartData = computed(() => {
|
const chartData = computed(() => {
|
||||||
if (!props.modelStats?.length) return null
|
if (!props.modelStats?.length) return null
|
||||||
|
|
||||||
return {
|
return {
|
||||||
labels: props.modelStats.map((m) => m.model),
|
labels: displayModelStats.value.map((m) => m.model),
|
||||||
datasets: [
|
datasets: [
|
||||||
{
|
{
|
||||||
data: props.modelStats.map((m) => m.total_tokens),
|
data: displayModelStats.value.map((m) => props.metric === 'actual_cost' ? m.actual_cost : m.total_tokens),
|
||||||
backgroundColor: chartColors.slice(0, props.modelStats.length),
|
backgroundColor: chartColors.slice(0, displayModelStats.value.length),
|
||||||
borderWidth: 0
|
borderWidth: 0
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@@ -116,8 +162,11 @@ const doughnutOptions = computed(() => ({
|
|||||||
label: (context: any) => {
|
label: (context: any) => {
|
||||||
const value = context.raw as number
|
const value = context.raw as number
|
||||||
const total = context.dataset.data.reduce((a: number, b: number) => a + b, 0)
|
const total = context.dataset.data.reduce((a: number, b: number) => a + b, 0)
|
||||||
const percentage = ((value / total) * 100).toFixed(1)
|
const percentage = total > 0 ? ((value / total) * 100).toFixed(1) : '0.0'
|
||||||
return `${context.label}: ${formatTokens(value)} (${percentage}%)`
|
const formattedValue = props.metric === 'actual_cost'
|
||||||
|
? `$${formatCost(value)}`
|
||||||
|
: formatTokens(value)
|
||||||
|
return `${context.label}: ${formattedValue} (${percentage}%)`
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,114 @@
|
|||||||
|
import { describe, expect, it, vi } from 'vitest'
|
||||||
|
import { mount } from '@vue/test-utils'
|
||||||
|
|
||||||
|
import GroupDistributionChart from '../GroupDistributionChart.vue'
|
||||||
|
|
||||||
|
const messages: Record<string, string> = {
|
||||||
|
'admin.dashboard.groupDistribution': 'Group Distribution',
|
||||||
|
'admin.dashboard.group': 'Group',
|
||||||
|
'admin.dashboard.noGroup': 'No Group',
|
||||||
|
'admin.dashboard.requests': 'Requests',
|
||||||
|
'admin.dashboard.tokens': 'Tokens',
|
||||||
|
'admin.dashboard.actual': 'Actual',
|
||||||
|
'admin.dashboard.standard': 'Standard',
|
||||||
|
'admin.dashboard.metricTokens': 'By Tokens',
|
||||||
|
'admin.dashboard.metricActualCost': 'By Actual Cost',
|
||||||
|
'admin.dashboard.noDataAvailable': 'No data available',
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.mock('vue-i18n', async () => {
|
||||||
|
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
|
||||||
|
return {
|
||||||
|
...actual,
|
||||||
|
useI18n: () => ({
|
||||||
|
t: (key: string) => messages[key] ?? key,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
vi.mock('vue-chartjs', () => ({
|
||||||
|
Doughnut: {
|
||||||
|
props: ['data'],
|
||||||
|
template: '<div class="chart-data">{{ JSON.stringify(data) }}</div>',
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
describe('GroupDistributionChart', () => {
|
||||||
|
const groupStats = [
|
||||||
|
{
|
||||||
|
group_id: 1,
|
||||||
|
group_name: 'group-a',
|
||||||
|
requests: 9,
|
||||||
|
total_tokens: 1200,
|
||||||
|
cost: 1.8,
|
||||||
|
actual_cost: 0.1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
group_id: 2,
|
||||||
|
group_name: 'group-b',
|
||||||
|
requests: 4,
|
||||||
|
total_tokens: 600,
|
||||||
|
cost: 0.7,
|
||||||
|
actual_cost: 0.9,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
it('uses total_tokens and token ordering by default', () => {
|
||||||
|
const wrapper = mount(GroupDistributionChart, {
|
||||||
|
props: {
|
||||||
|
groupStats,
|
||||||
|
},
|
||||||
|
global: {
|
||||||
|
stubs: {
|
||||||
|
LoadingSpinner: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
const chartData = JSON.parse(wrapper.find('.chart-data').text())
|
||||||
|
expect(chartData.labels).toEqual(['group-a', 'group-b'])
|
||||||
|
expect(chartData.datasets[0].data).toEqual([1200, 600])
|
||||||
|
|
||||||
|
const rows = wrapper.findAll('tbody tr')
|
||||||
|
expect(rows[0].text()).toContain('group-a')
|
||||||
|
expect(rows[1].text()).toContain('group-b')
|
||||||
|
|
||||||
|
const options = (wrapper.vm as any).$?.setupState.doughnutOptions
|
||||||
|
const label = options.plugins.tooltip.callbacks.label({
|
||||||
|
label: 'group-a',
|
||||||
|
raw: 1200,
|
||||||
|
dataset: { data: [1200, 600] },
|
||||||
|
})
|
||||||
|
expect(label).toBe('group-a: 1.20K (66.7%)')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('uses actual_cost and reorders rows in actual cost mode', () => {
|
||||||
|
const wrapper = mount(GroupDistributionChart, {
|
||||||
|
props: {
|
||||||
|
groupStats,
|
||||||
|
metric: 'actual_cost',
|
||||||
|
},
|
||||||
|
global: {
|
||||||
|
stubs: {
|
||||||
|
LoadingSpinner: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
const chartData = JSON.parse(wrapper.find('.chart-data').text())
|
||||||
|
expect(chartData.labels).toEqual(['group-b', 'group-a'])
|
||||||
|
expect(chartData.datasets[0].data).toEqual([0.9, 0.1])
|
||||||
|
|
||||||
|
const rows = wrapper.findAll('tbody tr')
|
||||||
|
expect(rows[0].text()).toContain('group-b')
|
||||||
|
expect(rows[1].text()).toContain('group-a')
|
||||||
|
|
||||||
|
const options = (wrapper.vm as any).$?.setupState.doughnutOptions
|
||||||
|
const label = options.plugins.tooltip.callbacks.label({
|
||||||
|
label: 'group-b',
|
||||||
|
raw: 0.9,
|
||||||
|
dataset: { data: [0.9, 0.1] },
|
||||||
|
})
|
||||||
|
expect(label).toBe('group-b: $0.900 (90.0%)')
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -0,0 +1,119 @@
|
|||||||
|
import { describe, expect, it, vi } from 'vitest'
|
||||||
|
import { mount } from '@vue/test-utils'
|
||||||
|
|
||||||
|
import ModelDistributionChart from '../ModelDistributionChart.vue'
|
||||||
|
|
||||||
|
const messages: Record<string, string> = {
|
||||||
|
'admin.dashboard.modelDistribution': 'Model Distribution',
|
||||||
|
'admin.dashboard.model': 'Model',
|
||||||
|
'admin.dashboard.requests': 'Requests',
|
||||||
|
'admin.dashboard.tokens': 'Tokens',
|
||||||
|
'admin.dashboard.actual': 'Actual',
|
||||||
|
'admin.dashboard.standard': 'Standard',
|
||||||
|
'admin.dashboard.metricTokens': 'By Tokens',
|
||||||
|
'admin.dashboard.metricActualCost': 'By Actual Cost',
|
||||||
|
'admin.dashboard.noDataAvailable': 'No data available',
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.mock('vue-i18n', async () => {
|
||||||
|
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
|
||||||
|
return {
|
||||||
|
...actual,
|
||||||
|
useI18n: () => ({
|
||||||
|
t: (key: string) => messages[key] ?? key,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
vi.mock('vue-chartjs', () => ({
|
||||||
|
Doughnut: {
|
||||||
|
props: ['data'],
|
||||||
|
template: '<div class="chart-data">{{ JSON.stringify(data) }}</div>',
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
describe('ModelDistributionChart', () => {
|
||||||
|
const modelStats = [
|
||||||
|
{
|
||||||
|
model: 'model-a',
|
||||||
|
requests: 8,
|
||||||
|
input_tokens: 100,
|
||||||
|
output_tokens: 50,
|
||||||
|
cache_creation_tokens: 0,
|
||||||
|
cache_read_tokens: 0,
|
||||||
|
total_tokens: 1000,
|
||||||
|
cost: 1.5,
|
||||||
|
actual_cost: 0.2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
model: 'model-b',
|
||||||
|
requests: 3,
|
||||||
|
input_tokens: 40,
|
||||||
|
output_tokens: 20,
|
||||||
|
cache_creation_tokens: 0,
|
||||||
|
cache_read_tokens: 0,
|
||||||
|
total_tokens: 500,
|
||||||
|
cost: 0.5,
|
||||||
|
actual_cost: 1.4,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
it('uses total_tokens and token ordering by default', () => {
|
||||||
|
const wrapper = mount(ModelDistributionChart, {
|
||||||
|
props: {
|
||||||
|
modelStats,
|
||||||
|
},
|
||||||
|
global: {
|
||||||
|
stubs: {
|
||||||
|
LoadingSpinner: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
const chartData = JSON.parse(wrapper.find('.chart-data').text())
|
||||||
|
expect(chartData.labels).toEqual(['model-a', 'model-b'])
|
||||||
|
expect(chartData.datasets[0].data).toEqual([1000, 500])
|
||||||
|
|
||||||
|
const rows = wrapper.findAll('tbody tr')
|
||||||
|
expect(rows[0].text()).toContain('model-a')
|
||||||
|
expect(rows[1].text()).toContain('model-b')
|
||||||
|
|
||||||
|
const options = (wrapper.vm as any).$?.setupState.doughnutOptions
|
||||||
|
const label = options.plugins.tooltip.callbacks.label({
|
||||||
|
label: 'model-a',
|
||||||
|
raw: 1000,
|
||||||
|
dataset: { data: [1000, 500] },
|
||||||
|
})
|
||||||
|
expect(label).toBe('model-a: 1.00K (66.7%)')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('uses actual_cost and reorders rows in actual cost mode', () => {
|
||||||
|
const wrapper = mount(ModelDistributionChart, {
|
||||||
|
props: {
|
||||||
|
modelStats,
|
||||||
|
metric: 'actual_cost',
|
||||||
|
},
|
||||||
|
global: {
|
||||||
|
stubs: {
|
||||||
|
LoadingSpinner: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
const chartData = JSON.parse(wrapper.find('.chart-data').text())
|
||||||
|
expect(chartData.labels).toEqual(['model-b', 'model-a'])
|
||||||
|
expect(chartData.datasets[0].data).toEqual([1.4, 0.2])
|
||||||
|
|
||||||
|
const rows = wrapper.findAll('tbody tr')
|
||||||
|
expect(rows[0].text()).toContain('model-b')
|
||||||
|
expect(rows[1].text()).toContain('model-a')
|
||||||
|
|
||||||
|
const options = (wrapper.vm as any).$?.setupState.doughnutOptions
|
||||||
|
const label = options.plugins.tooltip.callbacks.label({
|
||||||
|
label: 'model-b',
|
||||||
|
raw: 1.4,
|
||||||
|
dataset: { data: [1.4, 0.2] },
|
||||||
|
})
|
||||||
|
expect(label).toBe('model-b: $1.40 (87.5%)')
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -959,6 +959,23 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
'gemini-2.5-flash-image': {
|
||||||
|
name: 'Gemini 2.5 Flash Image',
|
||||||
|
limit: {
|
||||||
|
context: 1048576,
|
||||||
|
output: 65536
|
||||||
|
},
|
||||||
|
modalities: {
|
||||||
|
input: ['text', 'image'],
|
||||||
|
output: ['image']
|
||||||
|
},
|
||||||
|
options: {
|
||||||
|
thinking: {
|
||||||
|
budgetTokens: 24576,
|
||||||
|
type: 'enabled'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
'gemini-3.1-flash-image': {
|
'gemini-3.1-flash-image': {
|
||||||
name: 'Gemini 3.1 Flash Image',
|
name: 'Gemini 3.1 Flash Image',
|
||||||
limit: {
|
limit: {
|
||||||
|
|||||||
@@ -1,4 +1,9 @@
|
|||||||
import { describe, expect, it } from 'vitest'
|
import { describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
vi.mock('@/api/admin/accounts', () => ({
|
||||||
|
getAntigravityDefaultModelMapping: vi.fn()
|
||||||
|
}))
|
||||||
|
|
||||||
import { buildModelMappingObject, getModelsByPlatform } from '../useModelWhitelist'
|
import { buildModelMappingObject, getModelsByPlatform } from '../useModelWhitelist'
|
||||||
|
|
||||||
describe('useModelWhitelist', () => {
|
describe('useModelWhitelist', () => {
|
||||||
@@ -12,10 +17,27 @@ describe('useModelWhitelist', () => {
|
|||||||
it('antigravity 模型列表包含图片模型兼容项', () => {
|
it('antigravity 模型列表包含图片模型兼容项', () => {
|
||||||
const models = getModelsByPlatform('antigravity')
|
const models = getModelsByPlatform('antigravity')
|
||||||
|
|
||||||
|
expect(models).toContain('gemini-2.5-flash-image')
|
||||||
expect(models).toContain('gemini-3.1-flash-image')
|
expect(models).toContain('gemini-3.1-flash-image')
|
||||||
expect(models).toContain('gemini-3-pro-image')
|
expect(models).toContain('gemini-3-pro-image')
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('gemini 模型列表包含原生生图模型', () => {
|
||||||
|
const models = getModelsByPlatform('gemini')
|
||||||
|
|
||||||
|
expect(models).toContain('gemini-2.5-flash-image')
|
||||||
|
expect(models).toContain('gemini-3.1-flash-image')
|
||||||
|
expect(models.indexOf('gemini-3.1-flash-image')).toBeLessThan(models.indexOf('gemini-2.0-flash'))
|
||||||
|
expect(models.indexOf('gemini-2.5-flash-image')).toBeLessThan(models.indexOf('gemini-2.5-flash'))
|
||||||
|
})
|
||||||
|
|
||||||
|
it('antigravity 模型列表会把新的 Gemini 图片模型排在前面', () => {
|
||||||
|
const models = getModelsByPlatform('antigravity')
|
||||||
|
|
||||||
|
expect(models.indexOf('gemini-3.1-flash-image')).toBeLessThan(models.indexOf('gemini-2.5-flash'))
|
||||||
|
expect(models.indexOf('gemini-2.5-flash-image')).toBeLessThan(models.indexOf('gemini-2.5-flash-lite'))
|
||||||
|
})
|
||||||
|
|
||||||
it('whitelist 模式会忽略通配符条目', () => {
|
it('whitelist 模式会忽略通配符条目', () => {
|
||||||
const mapping = buildModelMappingObject('whitelist', ['claude-*', 'gemini-3.1-flash-image'], [])
|
const mapping = buildModelMappingObject('whitelist', ['claude-*', 'gemini-3.1-flash-image'], [])
|
||||||
expect(mapping).toEqual({
|
expect(mapping).toEqual({
|
||||||
|
|||||||
@@ -51,6 +51,8 @@ export const claudeModels = [
|
|||||||
const geminiModels = [
|
const geminiModels = [
|
||||||
// Keep in sync with backend curated Gemini lists.
|
// Keep in sync with backend curated Gemini lists.
|
||||||
// This list is intentionally conservative (models commonly available across OAuth/API key).
|
// This list is intentionally conservative (models commonly available across OAuth/API key).
|
||||||
|
'gemini-3.1-flash-image',
|
||||||
|
'gemini-2.5-flash-image',
|
||||||
'gemini-2.0-flash',
|
'gemini-2.0-flash',
|
||||||
'gemini-2.5-flash',
|
'gemini-2.5-flash',
|
||||||
'gemini-2.5-pro',
|
'gemini-2.5-pro',
|
||||||
@@ -85,6 +87,8 @@ const antigravityModels = [
|
|||||||
'claude-sonnet-4-5',
|
'claude-sonnet-4-5',
|
||||||
'claude-sonnet-4-5-thinking',
|
'claude-sonnet-4-5-thinking',
|
||||||
// Gemini 2.5 系列
|
// Gemini 2.5 系列
|
||||||
|
'gemini-3.1-flash-image',
|
||||||
|
'gemini-2.5-flash-image',
|
||||||
'gemini-2.5-flash',
|
'gemini-2.5-flash',
|
||||||
'gemini-2.5-flash-lite',
|
'gemini-2.5-flash-lite',
|
||||||
'gemini-2.5-flash-thinking',
|
'gemini-2.5-flash-thinking',
|
||||||
@@ -96,7 +100,6 @@ const antigravityModels = [
|
|||||||
// Gemini 3.1 系列
|
// Gemini 3.1 系列
|
||||||
'gemini-3.1-pro-high',
|
'gemini-3.1-pro-high',
|
||||||
'gemini-3.1-pro-low',
|
'gemini-3.1-pro-low',
|
||||||
'gemini-3.1-flash-image',
|
|
||||||
'gemini-3-pro-image',
|
'gemini-3-pro-image',
|
||||||
// 其他
|
// 其他
|
||||||
'gpt-oss-120b-medium',
|
'gpt-oss-120b-medium',
|
||||||
@@ -291,7 +294,9 @@ const soraPresetMappings: { label: string; from: string; to: string; color: stri
|
|||||||
const geminiPresetMappings = [
|
const geminiPresetMappings = [
|
||||||
{ label: 'Flash 2.0', from: 'gemini-2.0-flash', to: 'gemini-2.0-flash', color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400' },
|
{ label: 'Flash 2.0', from: 'gemini-2.0-flash', to: 'gemini-2.0-flash', color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400' },
|
||||||
{ label: '2.5 Flash', from: 'gemini-2.5-flash', to: 'gemini-2.5-flash', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' },
|
{ label: '2.5 Flash', from: 'gemini-2.5-flash', to: 'gemini-2.5-flash', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' },
|
||||||
{ label: '2.5 Pro', from: 'gemini-2.5-pro', to: 'gemini-2.5-pro', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' }
|
{ label: '2.5 Image', from: 'gemini-2.5-flash-image', to: 'gemini-2.5-flash-image', color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400' },
|
||||||
|
{ label: '2.5 Pro', from: 'gemini-2.5-pro', to: 'gemini-2.5-pro', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
|
||||||
|
{ label: '3.1 Image', from: 'gemini-3.1-flash-image', to: 'gemini-3.1-flash-image', color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400' }
|
||||||
]
|
]
|
||||||
|
|
||||||
// Antigravity 预设映射(支持通配符)
|
// Antigravity 预设映射(支持通配符)
|
||||||
@@ -314,6 +319,9 @@ const antigravityPresetMappings = [
|
|||||||
// Gemini 通配符映射
|
// Gemini 通配符映射
|
||||||
{ label: 'Gemini 3→Flash', from: 'gemini-3*', to: 'gemini-3-flash', color: 'bg-yellow-100 text-yellow-700 hover:bg-yellow-200 dark:bg-yellow-900/30 dark:text-yellow-400' },
|
{ label: 'Gemini 3→Flash', from: 'gemini-3*', to: 'gemini-3-flash', color: 'bg-yellow-100 text-yellow-700 hover:bg-yellow-200 dark:bg-yellow-900/30 dark:text-yellow-400' },
|
||||||
{ label: 'Gemini 2.5→Flash', from: 'gemini-2.5*', to: 'gemini-2.5-flash', color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400' },
|
{ label: 'Gemini 2.5→Flash', from: 'gemini-2.5*', to: 'gemini-2.5-flash', color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400' },
|
||||||
|
{ label: '2.5-Flash-Image透传', from: 'gemini-2.5-flash-image', to: 'gemini-2.5-flash-image', color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400' },
|
||||||
|
{ label: '3.1-Flash-Image透传', from: 'gemini-3.1-flash-image', to: 'gemini-3.1-flash-image', color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400' },
|
||||||
|
{ label: '3-Pro-Image→3.1', from: 'gemini-3-pro-image', to: 'gemini-3.1-flash-image', color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400' },
|
||||||
{ label: '3-Flash透传', from: 'gemini-3-flash', to: 'gemini-3-flash', color: 'bg-lime-100 text-lime-700 hover:bg-lime-200 dark:bg-lime-900/30 dark:text-lime-400' },
|
{ label: '3-Flash透传', from: 'gemini-3-flash', to: 'gemini-3-flash', color: 'bg-lime-100 text-lime-700 hover:bg-lime-200 dark:bg-lime-900/30 dark:text-lime-400' },
|
||||||
{ label: '2.5-Flash-Lite透传', from: 'gemini-2.5-flash-lite', to: 'gemini-2.5-flash-lite', color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400' },
|
{ label: '2.5-Flash-Lite透传', from: 'gemini-2.5-flash-lite', to: 'gemini-2.5-flash-lite', color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400' },
|
||||||
// 精确映射
|
// 精确映射
|
||||||
|
|||||||
@@ -950,6 +950,8 @@ export default {
|
|||||||
hour: 'Hour',
|
hour: 'Hour',
|
||||||
modelDistribution: 'Model Distribution',
|
modelDistribution: 'Model Distribution',
|
||||||
groupDistribution: 'Group Usage Distribution',
|
groupDistribution: 'Group Usage Distribution',
|
||||||
|
metricTokens: 'By Tokens',
|
||||||
|
metricActualCost: 'By Actual Cost',
|
||||||
tokenUsageTrend: 'Token Usage Trend',
|
tokenUsageTrend: 'Token Usage Trend',
|
||||||
userUsageTrend: 'User Usage Trend (Top 12)',
|
userUsageTrend: 'User Usage Trend (Top 12)',
|
||||||
model: 'Model',
|
model: 'Model',
|
||||||
@@ -1570,6 +1572,11 @@ export default {
|
|||||||
adjust: 'Adjust',
|
adjust: 'Adjust',
|
||||||
adjusting: 'Adjusting...',
|
adjusting: 'Adjusting...',
|
||||||
revoke: 'Revoke',
|
revoke: 'Revoke',
|
||||||
|
resetQuota: 'Reset Quota',
|
||||||
|
resetQuotaTitle: 'Reset Usage Quota',
|
||||||
|
resetQuotaConfirm: "Reset the daily and weekly usage quota for '{user}'? Usage will be zeroed and windows restarted from today.",
|
||||||
|
quotaResetSuccess: 'Quota reset successfully',
|
||||||
|
failedToResetQuota: 'Failed to reset quota',
|
||||||
noSubscriptionsYet: 'No subscriptions yet',
|
noSubscriptionsYet: 'No subscriptions yet',
|
||||||
assignFirstSubscription: 'Assign a subscription to get started.',
|
assignFirstSubscription: 'Assign a subscription to get started.',
|
||||||
subscriptionAssigned: 'Subscription assigned successfully',
|
subscriptionAssigned: 'Subscription assigned successfully',
|
||||||
@@ -2411,6 +2418,7 @@ export default {
|
|||||||
connectedToApi: 'Connected to API',
|
connectedToApi: 'Connected to API',
|
||||||
usingModel: 'Using model: {model}',
|
usingModel: 'Using model: {model}',
|
||||||
sendingTestMessage: 'Sending test message: "hi"',
|
sendingTestMessage: 'Sending test message: "hi"',
|
||||||
|
sendingGeminiImageRequest: 'Sending Gemini image generation test request...',
|
||||||
response: 'Response:',
|
response: 'Response:',
|
||||||
startTest: 'Start Test',
|
startTest: 'Start Test',
|
||||||
testing: 'Testing...',
|
testing: 'Testing...',
|
||||||
@@ -2422,6 +2430,13 @@ export default {
|
|||||||
selectTestModel: 'Select Test Model',
|
selectTestModel: 'Select Test Model',
|
||||||
testModel: 'Test model',
|
testModel: 'Test model',
|
||||||
testPrompt: 'Prompt: "hi"',
|
testPrompt: 'Prompt: "hi"',
|
||||||
|
geminiImagePromptLabel: 'Image prompt',
|
||||||
|
geminiImagePromptPlaceholder: 'Example: Generate an orange cat astronaut sticker in pixel-art style on a solid background.',
|
||||||
|
geminiImagePromptDefault: 'Generate a cute orange cat astronaut sticker on a clean pastel background.',
|
||||||
|
geminiImageTestHint: 'When a Gemini image model is selected, this test sends a real image-generation request and previews the returned image below.',
|
||||||
|
geminiImageTestMode: 'Mode: Gemini image generation test',
|
||||||
|
geminiImagePreview: 'Generated images:',
|
||||||
|
geminiImageReceived: 'Received test image #{count}',
|
||||||
soraUpstreamBaseUrlHint: 'Upstream Sora service URL (another Sub2API instance or compatible API)',
|
soraUpstreamBaseUrlHint: 'Upstream Sora service URL (another Sub2API instance or compatible API)',
|
||||||
soraTestHint: 'Sora test runs connectivity and capability checks (/backend/me, subscription, Sora2 invite and remaining quota).',
|
soraTestHint: 'Sora test runs connectivity and capability checks (/backend/me, subscription, Sora2 invite and remaining quota).',
|
||||||
soraTestTarget: 'Target: Sora account capability',
|
soraTestTarget: 'Target: Sora account capability',
|
||||||
|
|||||||
@@ -963,6 +963,8 @@ export default {
|
|||||||
hour: '按小时',
|
hour: '按小时',
|
||||||
modelDistribution: '模型分布',
|
modelDistribution: '模型分布',
|
||||||
groupDistribution: '分组使用分布',
|
groupDistribution: '分组使用分布',
|
||||||
|
metricTokens: '按 Token',
|
||||||
|
metricActualCost: '按实际消费',
|
||||||
tokenUsageTrend: 'Token 使用趋势',
|
tokenUsageTrend: 'Token 使用趋势',
|
||||||
noDataAvailable: '暂无数据',
|
noDataAvailable: '暂无数据',
|
||||||
model: '模型',
|
model: '模型',
|
||||||
@@ -1658,6 +1660,11 @@ export default {
|
|||||||
adjust: '调整',
|
adjust: '调整',
|
||||||
adjusting: '调整中...',
|
adjusting: '调整中...',
|
||||||
revoke: '撤销',
|
revoke: '撤销',
|
||||||
|
resetQuota: '重置配额',
|
||||||
|
resetQuotaTitle: '重置用量配额',
|
||||||
|
resetQuotaConfirm: "确定要重置 '{user}' 的每日和每周用量配额吗?用量将归零并从今天开始重新计算。",
|
||||||
|
quotaResetSuccess: '配额重置成功',
|
||||||
|
failedToResetQuota: '重置配额失败',
|
||||||
noSubscriptionsYet: '暂无订阅',
|
noSubscriptionsYet: '暂无订阅',
|
||||||
assignFirstSubscription: '分配一个订阅以开始使用。',
|
assignFirstSubscription: '分配一个订阅以开始使用。',
|
||||||
subscriptionAssigned: '订阅分配成功',
|
subscriptionAssigned: '订阅分配成功',
|
||||||
@@ -2540,6 +2547,7 @@ export default {
|
|||||||
connectedToApi: '已连接到 API',
|
connectedToApi: '已连接到 API',
|
||||||
usingModel: '使用模型:{model}',
|
usingModel: '使用模型:{model}',
|
||||||
sendingTestMessage: '发送测试消息:"hi"',
|
sendingTestMessage: '发送测试消息:"hi"',
|
||||||
|
sendingGeminiImageRequest: '发送 Gemini 生图测试请求...',
|
||||||
response: '响应:',
|
response: '响应:',
|
||||||
startTest: '开始测试',
|
startTest: '开始测试',
|
||||||
retry: '重试',
|
retry: '重试',
|
||||||
@@ -2550,6 +2558,13 @@ export default {
|
|||||||
selectTestModel: '选择测试模型',
|
selectTestModel: '选择测试模型',
|
||||||
testModel: '测试模型',
|
testModel: '测试模型',
|
||||||
testPrompt: '提示词:"hi"',
|
testPrompt: '提示词:"hi"',
|
||||||
|
geminiImagePromptLabel: '生图提示词',
|
||||||
|
geminiImagePromptPlaceholder: '例如:生成一只戴宇航员头盔的橘猫,像素插画风格,纯色背景。',
|
||||||
|
geminiImagePromptDefault: 'Generate a cute orange cat astronaut sticker on a clean pastel background.',
|
||||||
|
geminiImageTestHint: '选择 Gemini 图片模型后,这里会直接发起生图测试,并在下方展示返回图片。',
|
||||||
|
geminiImageTestMode: '模式:Gemini 生图测试',
|
||||||
|
geminiImagePreview: '生成结果:',
|
||||||
|
geminiImageReceived: '已收到第 {count} 张测试图片',
|
||||||
soraUpstreamBaseUrlHint: '上游 Sora 服务地址(另一个 Sub2API 实例或兼容 API)',
|
soraUpstreamBaseUrlHint: '上游 Sora 服务地址(另一个 Sub2API 实例或兼容 API)',
|
||||||
soraTestHint: 'Sora 测试将执行连通性与能力检测(/backend/me、订阅信息、Sora2 邀请码与剩余额度)。',
|
soraTestHint: 'Sora 测试将执行连通性与能力检测(/backend/me、订阅信息、Sora2 邀请码与剩余额度)。',
|
||||||
soraTestTarget: '检测目标:Sora 账号能力',
|
soraTestTarget: '检测目标:Sora 账号能力',
|
||||||
|
|||||||
@@ -370,6 +370,15 @@
|
|||||||
<Icon name="calendar" size="sm" />
|
<Icon name="calendar" size="sm" />
|
||||||
<span class="text-xs">{{ t('admin.subscriptions.adjust') }}</span>
|
<span class="text-xs">{{ t('admin.subscriptions.adjust') }}</span>
|
||||||
</button>
|
</button>
|
||||||
|
<button
|
||||||
|
v-if="row.status === 'active'"
|
||||||
|
@click="handleResetQuota(row)"
|
||||||
|
:disabled="resettingQuota && resettingSubscription?.id === row.id"
|
||||||
|
class="flex flex-col items-center gap-0.5 rounded-lg p-1.5 text-gray-500 transition-colors hover:bg-orange-50 hover:text-orange-600 dark:hover:bg-orange-900/20 dark:hover:text-orange-400 disabled:cursor-not-allowed disabled:opacity-50"
|
||||||
|
>
|
||||||
|
<Icon name="refresh" size="sm" />
|
||||||
|
<span class="text-xs">{{ t('admin.subscriptions.resetQuota') }}</span>
|
||||||
|
</button>
|
||||||
<button
|
<button
|
||||||
v-if="row.status === 'active'"
|
v-if="row.status === 'active'"
|
||||||
@click="handleRevoke(row)"
|
@click="handleRevoke(row)"
|
||||||
@@ -618,6 +627,17 @@
|
|||||||
@confirm="confirmRevoke"
|
@confirm="confirmRevoke"
|
||||||
@cancel="showRevokeDialog = false"
|
@cancel="showRevokeDialog = false"
|
||||||
/>
|
/>
|
||||||
|
|
||||||
|
<!-- Reset Quota Confirmation Dialog -->
|
||||||
|
<ConfirmDialog
|
||||||
|
:show="showResetQuotaConfirm"
|
||||||
|
:title="t('admin.subscriptions.resetQuotaTitle')"
|
||||||
|
:message="t('admin.subscriptions.resetQuotaConfirm', { user: resettingSubscription?.user?.email })"
|
||||||
|
:confirm-text="t('admin.subscriptions.resetQuota')"
|
||||||
|
:cancel-text="t('common.cancel')"
|
||||||
|
@confirm="confirmResetQuota"
|
||||||
|
@cancel="showResetQuotaConfirm = false"
|
||||||
|
/>
|
||||||
</AppLayout>
|
</AppLayout>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
@@ -812,7 +832,10 @@ const pagination = reactive({
|
|||||||
const showAssignModal = ref(false)
|
const showAssignModal = ref(false)
|
||||||
const showExtendModal = ref(false)
|
const showExtendModal = ref(false)
|
||||||
const showRevokeDialog = ref(false)
|
const showRevokeDialog = ref(false)
|
||||||
|
const showResetQuotaConfirm = ref(false)
|
||||||
const submitting = ref(false)
|
const submitting = ref(false)
|
||||||
|
const resettingSubscription = ref<UserSubscription | null>(null)
|
||||||
|
const resettingQuota = ref(false)
|
||||||
const extendingSubscription = ref<UserSubscription | null>(null)
|
const extendingSubscription = ref<UserSubscription | null>(null)
|
||||||
const revokingSubscription = ref<UserSubscription | null>(null)
|
const revokingSubscription = ref<UserSubscription | null>(null)
|
||||||
|
|
||||||
@@ -1121,6 +1144,29 @@ const confirmRevoke = async () => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const handleResetQuota = (subscription: UserSubscription) => {
|
||||||
|
resettingSubscription.value = subscription
|
||||||
|
showResetQuotaConfirm.value = true
|
||||||
|
}
|
||||||
|
|
||||||
|
const confirmResetQuota = async () => {
|
||||||
|
if (!resettingSubscription.value) return
|
||||||
|
if (resettingQuota.value) return
|
||||||
|
resettingQuota.value = true
|
||||||
|
try {
|
||||||
|
await adminAPI.subscriptions.resetQuota(resettingSubscription.value.id, { daily: true, weekly: true })
|
||||||
|
appStore.showSuccess(t('admin.subscriptions.quotaResetSuccess'))
|
||||||
|
showResetQuotaConfirm.value = false
|
||||||
|
resettingSubscription.value = null
|
||||||
|
await loadSubscriptions()
|
||||||
|
} catch (error: any) {
|
||||||
|
appStore.showError(error.response?.data?.detail || t('admin.subscriptions.failedToResetQuota'))
|
||||||
|
console.error('Error resetting quota:', error)
|
||||||
|
} finally {
|
||||||
|
resettingQuota.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Helper functions
|
// Helper functions
|
||||||
const getDaysRemaining = (expiresAt: string): number | null => {
|
const getDaysRemaining = (expiresAt: string): number | null => {
|
||||||
const now = new Date()
|
const now = new Date()
|
||||||
|
|||||||
@@ -13,8 +13,18 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class="grid grid-cols-1 gap-6 lg:grid-cols-2">
|
<div class="grid grid-cols-1 gap-6 lg:grid-cols-2">
|
||||||
<ModelDistributionChart :model-stats="modelStats" :loading="chartsLoading" />
|
<ModelDistributionChart
|
||||||
<GroupDistributionChart :group-stats="groupStats" :loading="chartsLoading" />
|
v-model:metric="modelDistributionMetric"
|
||||||
|
:model-stats="modelStats"
|
||||||
|
:loading="chartsLoading"
|
||||||
|
:show-metric-toggle="true"
|
||||||
|
/>
|
||||||
|
<GroupDistributionChart
|
||||||
|
v-model:metric="groupDistributionMetric"
|
||||||
|
:group-stats="groupStats"
|
||||||
|
:loading="chartsLoading"
|
||||||
|
:show-metric-toggle="true"
|
||||||
|
/>
|
||||||
</div>
|
</div>
|
||||||
<TokenUsageTrend :trend-data="trendData" :loading="chartsLoading" />
|
<TokenUsageTrend :trend-data="trendData" :loading="chartsLoading" />
|
||||||
</div>
|
</div>
|
||||||
@@ -93,8 +103,12 @@ import type { AdminUsageLog, TrendDataPoint, ModelStat, GroupStat, AdminUser } f
|
|||||||
|
|
||||||
const { t } = useI18n()
|
const { t } = useI18n()
|
||||||
const appStore = useAppStore()
|
const appStore = useAppStore()
|
||||||
|
type DistributionMetric = 'tokens' | 'actual_cost'
|
||||||
|
|
||||||
const usageStats = ref<AdminUsageStatsResponse | null>(null); const usageLogs = ref<AdminUsageLog[]>([]); const loading = ref(false); const exporting = ref(false)
|
const usageStats = ref<AdminUsageStatsResponse | null>(null); const usageLogs = ref<AdminUsageLog[]>([]); const loading = ref(false); const exporting = ref(false)
|
||||||
const trendData = ref<TrendDataPoint[]>([]); const modelStats = ref<ModelStat[]>([]); const groupStats = ref<GroupStat[]>([]); const chartsLoading = ref(false); const granularity = ref<'day' | 'hour'>('day')
|
const trendData = ref<TrendDataPoint[]>([]); const modelStats = ref<ModelStat[]>([]); const groupStats = ref<GroupStat[]>([]); const chartsLoading = ref(false); const granularity = ref<'day' | 'hour'>('day')
|
||||||
|
const modelDistributionMetric = ref<DistributionMetric>('tokens')
|
||||||
|
const groupDistributionMetric = ref<DistributionMetric>('tokens')
|
||||||
let abortController: AbortController | null = null; let exportAbortController: AbortController | null = null
|
let abortController: AbortController | null = null; let exportAbortController: AbortController | null = null
|
||||||
let chartReqSeq = 0
|
let chartReqSeq = 0
|
||||||
const exportProgress = reactive({ show: false, progress: 0, current: 0, total: 0, estimatedTime: '' })
|
const exportProgress = reactive({ show: false, progress: 0, current: 0, total: 0, estimatedTime: '' })
|
||||||
|
|||||||
174
frontend/src/views/admin/__tests__/UsageView.spec.ts
Normal file
174
frontend/src/views/admin/__tests__/UsageView.spec.ts
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
import { describe, expect, it, vi, beforeEach, afterEach } from 'vitest'
|
||||||
|
import { flushPromises, mount } from '@vue/test-utils'
|
||||||
|
|
||||||
|
import UsageView from '../UsageView.vue'
|
||||||
|
|
||||||
|
const { list, getStats, getSnapshotV2, getById } = vi.hoisted(() => {
|
||||||
|
vi.stubGlobal('localStorage', {
|
||||||
|
getItem: vi.fn(() => null),
|
||||||
|
setItem: vi.fn(),
|
||||||
|
removeItem: vi.fn(),
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
list: vi.fn(),
|
||||||
|
getStats: vi.fn(),
|
||||||
|
getSnapshotV2: vi.fn(),
|
||||||
|
getById: vi.fn(),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
const messages: Record<string, string> = {
|
||||||
|
'admin.dashboard.day': 'Day',
|
||||||
|
'admin.dashboard.hour': 'Hour',
|
||||||
|
'admin.usage.failedToLoadUser': 'Failed to load user',
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.mock('@/api/admin', () => ({
|
||||||
|
adminAPI: {
|
||||||
|
usage: {
|
||||||
|
list,
|
||||||
|
getStats,
|
||||||
|
},
|
||||||
|
dashboard: {
|
||||||
|
getSnapshotV2,
|
||||||
|
},
|
||||||
|
users: {
|
||||||
|
getById,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/api/admin/usage', () => ({
|
||||||
|
adminUsageAPI: {
|
||||||
|
list: vi.fn(),
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/stores/app', () => ({
|
||||||
|
useAppStore: () => ({
|
||||||
|
showError: vi.fn(),
|
||||||
|
showWarning: vi.fn(),
|
||||||
|
showSuccess: vi.fn(),
|
||||||
|
showInfo: vi.fn(),
|
||||||
|
}),
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/utils/format', () => ({
|
||||||
|
formatReasoningEffort: (value: string | null | undefined) => value ?? '-',
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('vue-i18n', async () => {
|
||||||
|
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
|
||||||
|
return {
|
||||||
|
...actual,
|
||||||
|
useI18n: () => ({
|
||||||
|
t: (key: string) => messages[key] ?? key,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
const AppLayoutStub = { template: '<div><slot /></div>' }
|
||||||
|
const UsageFiltersStub = { template: '<div><slot name="after-reset" /></div>' }
|
||||||
|
const ModelDistributionChartStub = {
|
||||||
|
props: ['metric'],
|
||||||
|
emits: ['update:metric'],
|
||||||
|
template: `
|
||||||
|
<div data-test="model-chart">
|
||||||
|
<span class="metric">{{ metric }}</span>
|
||||||
|
<button class="switch-metric" @click="$emit('update:metric', 'actual_cost')">switch</button>
|
||||||
|
</div>
|
||||||
|
`,
|
||||||
|
}
|
||||||
|
const GroupDistributionChartStub = {
|
||||||
|
props: ['metric'],
|
||||||
|
emits: ['update:metric'],
|
||||||
|
template: `
|
||||||
|
<div data-test="group-chart">
|
||||||
|
<span class="metric">{{ metric }}</span>
|
||||||
|
<button class="switch-metric" @click="$emit('update:metric', 'actual_cost')">switch</button>
|
||||||
|
</div>
|
||||||
|
`,
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('admin UsageView distribution metric toggles', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.useFakeTimers()
|
||||||
|
list.mockReset()
|
||||||
|
getStats.mockReset()
|
||||||
|
getSnapshotV2.mockReset()
|
||||||
|
getById.mockReset()
|
||||||
|
|
||||||
|
list.mockResolvedValue({
|
||||||
|
items: [],
|
||||||
|
total: 0,
|
||||||
|
pages: 0,
|
||||||
|
})
|
||||||
|
getStats.mockResolvedValue({
|
||||||
|
total_requests: 0,
|
||||||
|
total_input_tokens: 0,
|
||||||
|
total_output_tokens: 0,
|
||||||
|
total_cache_tokens: 0,
|
||||||
|
total_tokens: 0,
|
||||||
|
total_cost: 0,
|
||||||
|
total_actual_cost: 0,
|
||||||
|
average_duration_ms: 0,
|
||||||
|
})
|
||||||
|
getSnapshotV2.mockResolvedValue({
|
||||||
|
trend: [],
|
||||||
|
models: [],
|
||||||
|
groups: [],
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
vi.useRealTimers()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('keeps model and group metric toggles independent without refetching chart data', async () => {
|
||||||
|
const wrapper = mount(UsageView, {
|
||||||
|
global: {
|
||||||
|
stubs: {
|
||||||
|
AppLayout: AppLayoutStub,
|
||||||
|
UsageStatsCards: true,
|
||||||
|
UsageFilters: UsageFiltersStub,
|
||||||
|
UsageTable: true,
|
||||||
|
UsageExportProgress: true,
|
||||||
|
UsageCleanupDialog: true,
|
||||||
|
UserBalanceHistoryModal: true,
|
||||||
|
Pagination: true,
|
||||||
|
Select: true,
|
||||||
|
Icon: true,
|
||||||
|
TokenUsageTrend: true,
|
||||||
|
ModelDistributionChart: ModelDistributionChartStub,
|
||||||
|
GroupDistributionChart: GroupDistributionChartStub,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
vi.advanceTimersByTime(120)
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
expect(getSnapshotV2).toHaveBeenCalledTimes(1)
|
||||||
|
|
||||||
|
const modelChart = wrapper.find('[data-test="model-chart"]')
|
||||||
|
const groupChart = wrapper.find('[data-test="group-chart"]')
|
||||||
|
|
||||||
|
expect(modelChart.find('.metric').text()).toBe('tokens')
|
||||||
|
expect(groupChart.find('.metric').text()).toBe('tokens')
|
||||||
|
|
||||||
|
await modelChart.find('.switch-metric').trigger('click')
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
expect(modelChart.find('.metric').text()).toBe('actual_cost')
|
||||||
|
expect(groupChart.find('.metric').text()).toBe('tokens')
|
||||||
|
expect(getSnapshotV2).toHaveBeenCalledTimes(1)
|
||||||
|
|
||||||
|
await groupChart.find('.switch-metric').trigger('click')
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
expect(modelChart.find('.metric').text()).toBe('actual_cost')
|
||||||
|
expect(groupChart.find('.metric').text()).toBe('actual_cost')
|
||||||
|
expect(getSnapshotV2).toHaveBeenCalledTimes(1)
|
||||||
|
})
|
||||||
|
})
|
||||||
Reference in New Issue
Block a user