mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-10 10:04:46 +08:00
Merge remote-tracking branch 'origin/main' into feat/billing-ledger-decouple-usage-log-20260312
This commit is contained in:
@@ -628,6 +628,7 @@ func (h *AccountHandler) Delete(c *gin.Context) {
|
||||
// TestAccountRequest represents the request body for testing an account
|
||||
type TestAccountRequest struct {
|
||||
ModelID string `json:"model_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
type SyncFromCRSRequest struct {
|
||||
@@ -658,7 +659,7 @@ func (h *AccountHandler) Test(c *gin.Context) {
|
||||
_ = c.ShouldBindJSON(&req)
|
||||
|
||||
// Use AccountTestService to test the account with SSE streaming
|
||||
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID); err != nil {
|
||||
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt); err != nil {
|
||||
// Error already sent via SSE, just log
|
||||
return
|
||||
}
|
||||
|
||||
@@ -249,11 +249,12 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
trend, hit, err := h.getUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
@@ -321,11 +322,12 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"models": stats,
|
||||
@@ -391,11 +393,12 @@ func (h *DashboardHandler) GetGroupStats(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
stats, hit, err := h.getGroupStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get group statistics")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"groups": stats,
|
||||
@@ -416,11 +419,12 @@ func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) {
|
||||
limit = 5
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetAPIKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
trend, hit, err := h.getAPIKeyUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage trend")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
@@ -442,11 +446,12 @@ func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
|
||||
limit = 12
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
trend, hit, err := h.getUserUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage trend")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
|
||||
118
backend/internal/handler/admin/dashboard_handler_cache_test.go
Normal file
118
backend/internal/handler/admin/dashboard_handler_cache_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type dashboardUsageRepoCacheProbe struct {
|
||||
service.UsageLogRepository
|
||||
trendCalls atomic.Int32
|
||||
usersTrendCalls atomic.Int32
|
||||
}
|
||||
|
||||
func (r *dashboardUsageRepoCacheProbe) GetUsageTrendWithFilters(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
model string,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.TrendDataPoint, error) {
|
||||
r.trendCalls.Add(1)
|
||||
return []usagestats.TrendDataPoint{{
|
||||
Date: "2026-03-11",
|
||||
Requests: 1,
|
||||
TotalTokens: 2,
|
||||
Cost: 3,
|
||||
ActualCost: 4,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
func (r *dashboardUsageRepoCacheProbe) GetUserUsageTrend(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
limit int,
|
||||
) ([]usagestats.UserUsageTrendPoint, error) {
|
||||
r.usersTrendCalls.Add(1)
|
||||
return []usagestats.UserUsageTrendPoint{{
|
||||
Date: "2026-03-11",
|
||||
UserID: 1,
|
||||
Email: "cache@test.dev",
|
||||
Requests: 2,
|
||||
Tokens: 20,
|
||||
Cost: 2,
|
||||
ActualCost: 1,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
func resetDashboardReadCachesForTest() {
|
||||
dashboardTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
|
||||
}
|
||||
|
||||
func TestDashboardHandler_GetUsageTrend_UsesCache(t *testing.T) {
|
||||
t.Cleanup(resetDashboardReadCachesForTest)
|
||||
resetDashboardReadCachesForTest()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := &dashboardUsageRepoCacheProbe{}
|
||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||
handler := NewDashboardHandler(dashboardSvc, nil)
|
||||
router := gin.New()
|
||||
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
||||
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
|
||||
rec1 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec1, req1)
|
||||
require.Equal(t, http.StatusOK, rec1.Code)
|
||||
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec2, req2)
|
||||
require.Equal(t, http.StatusOK, rec2.Code)
|
||||
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||
require.Equal(t, int32(1), repo.trendCalls.Load())
|
||||
}
|
||||
|
||||
func TestDashboardHandler_GetUserUsageTrend_UsesCache(t *testing.T) {
|
||||
t.Cleanup(resetDashboardReadCachesForTest)
|
||||
resetDashboardReadCachesForTest()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := &dashboardUsageRepoCacheProbe{}
|
||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||
handler := NewDashboardHandler(dashboardSvc, nil)
|
||||
router := gin.New()
|
||||
router.GET("/admin/dashboard/users-trend", handler.GetUserUsageTrend)
|
||||
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
|
||||
rec1 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec1, req1)
|
||||
require.Equal(t, http.StatusOK, rec1.Code)
|
||||
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec2, req2)
|
||||
require.Equal(t, http.StatusOK, rec2.Code)
|
||||
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||
require.Equal(t, int32(1), repo.usersTrendCalls.Load())
|
||||
}
|
||||
200
backend/internal/handler/admin/dashboard_query_cache.go
Normal file
200
backend/internal/handler/admin/dashboard_query_cache.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
)
|
||||
|
||||
var (
|
||||
dashboardTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
|
||||
)
|
||||
|
||||
type dashboardTrendCacheKey struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
Granularity string `json:"granularity"`
|
||||
UserID int64 `json:"user_id"`
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
Model string `json:"model"`
|
||||
RequestType *int16 `json:"request_type"`
|
||||
Stream *bool `json:"stream"`
|
||||
BillingType *int8 `json:"billing_type"`
|
||||
}
|
||||
|
||||
type dashboardModelGroupCacheKey struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
UserID int64 `json:"user_id"`
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
RequestType *int16 `json:"request_type"`
|
||||
Stream *bool `json:"stream"`
|
||||
BillingType *int8 `json:"billing_type"`
|
||||
}
|
||||
|
||||
type dashboardEntityTrendCacheKey struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
Granularity string `json:"granularity"`
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
|
||||
func cacheStatusValue(hit bool) string {
|
||||
if hit {
|
||||
return "hit"
|
||||
}
|
||||
return "miss"
|
||||
}
|
||||
|
||||
func mustMarshalDashboardCacheKey(value any) string {
|
||||
raw, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(raw)
|
||||
}
|
||||
|
||||
func snapshotPayloadAs[T any](payload any) (T, error) {
|
||||
typed, ok := payload.(T)
|
||||
if !ok {
|
||||
var zero T
|
||||
return zero, fmt.Errorf("unexpected cache payload type %T", payload)
|
||||
}
|
||||
return typed, nil
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getUsageTrendCached(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
model string,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.TrendDataPoint, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardTrendCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Granularity: granularity,
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Model: model,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
})
|
||||
entry, hit, err := dashboardTrendCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
trend, err := snapshotPayloadAs[[]usagestats.TrendDataPoint](entry.Payload)
|
||||
return trend, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getModelStatsCached(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.ModelStat, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
})
|
||||
entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
stats, err := snapshotPayloadAs[[]usagestats.ModelStat](entry.Payload)
|
||||
return stats, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getGroupStatsCached(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.GroupStat, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
})
|
||||
entry, hit, err := dashboardGroupStatsCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
stats, err := snapshotPayloadAs[[]usagestats.GroupStat](entry.Payload)
|
||||
return stats, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getAPIKeyUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Granularity: granularity,
|
||||
Limit: limit,
|
||||
})
|
||||
entry, hit, err := dashboardAPIKeysTrendCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
trend, err := snapshotPayloadAs[[]usagestats.APIKeyUsageTrendPoint](entry.Payload)
|
||||
return trend, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getUserUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Granularity: granularity,
|
||||
Limit: limit,
|
||||
})
|
||||
entry, hit, err := dashboardUsersTrendCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetUserUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
trend, err := snapshotPayloadAs[[]usagestats.UserUsageTrendPoint](entry.Payload)
|
||||
return trend, hit, err
|
||||
}
|
||||
@@ -1,7 +1,9 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -111,20 +113,45 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
|
||||
if cached, ok := dashboardSnapshotV2Cache.Get(cacheKey); ok {
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||
c.Status(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
cached, hit, err := dashboardSnapshotV2Cache.GetOrLoad(cacheKey, func() (any, error) {
|
||||
return h.buildSnapshotV2Response(
|
||||
c.Request.Context(),
|
||||
startTime,
|
||||
endTime,
|
||||
granularity,
|
||||
filters,
|
||||
includeStats,
|
||||
includeTrend,
|
||||
includeModels,
|
||||
includeGroups,
|
||||
includeUsersTrend,
|
||||
usersTrendLimit,
|
||||
)
|
||||
})
|
||||
if err != nil {
|
||||
response.Error(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||
c.Status(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
response.Success(c, cached.Payload)
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) buildSnapshotV2Response(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
filters *dashboardSnapshotV2Filters,
|
||||
includeStats, includeTrend, includeModels, includeGroups, includeUsersTrend bool,
|
||||
usersTrendLimit int,
|
||||
) (*dashboardSnapshotV2Response, error) {
|
||||
resp := &dashboardSnapshotV2Response{
|
||||
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
StartDate: startTime.Format("2006-01-02"),
|
||||
@@ -133,10 +160,9 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
}
|
||||
|
||||
if includeStats {
|
||||
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
|
||||
stats, err := h.dashboardService.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get dashboard statistics")
|
||||
return
|
||||
return nil, errors.New("failed to get dashboard statistics")
|
||||
}
|
||||
resp.Stats = &dashboardSnapshotV2Stats{
|
||||
DashboardStats: *stats,
|
||||
@@ -145,8 +171,8 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
}
|
||||
|
||||
if includeTrend {
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(
|
||||
c.Request.Context(),
|
||||
trend, _, err := h.getUsageTrendCached(
|
||||
ctx,
|
||||
startTime,
|
||||
endTime,
|
||||
granularity,
|
||||
@@ -160,15 +186,14 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
return nil, errors.New("failed to get usage trend")
|
||||
}
|
||||
resp.Trend = trend
|
||||
}
|
||||
|
||||
if includeModels {
|
||||
models, err := h.dashboardService.GetModelStatsWithFilters(
|
||||
c.Request.Context(),
|
||||
models, _, err := h.getModelStatsCached(
|
||||
ctx,
|
||||
startTime,
|
||||
endTime,
|
||||
filters.UserID,
|
||||
@@ -180,15 +205,14 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
return nil, errors.New("failed to get model statistics")
|
||||
}
|
||||
resp.Models = models
|
||||
}
|
||||
|
||||
if includeGroups {
|
||||
groups, err := h.dashboardService.GetGroupStatsWithFilters(
|
||||
c.Request.Context(),
|
||||
groups, _, err := h.getGroupStatsCached(
|
||||
ctx,
|
||||
startTime,
|
||||
endTime,
|
||||
filters.UserID,
|
||||
@@ -200,34 +224,20 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get group statistics")
|
||||
return
|
||||
return nil, errors.New("failed to get group statistics")
|
||||
}
|
||||
resp.Groups = groups
|
||||
}
|
||||
|
||||
if includeUsersTrend {
|
||||
usersTrend, err := h.dashboardService.GetUserUsageTrend(
|
||||
c.Request.Context(),
|
||||
startTime,
|
||||
endTime,
|
||||
granularity,
|
||||
usersTrendLimit,
|
||||
)
|
||||
usersTrend, _, err := h.getUserUsageTrendCached(ctx, startTime, endTime, granularity, usersTrendLimit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage trend")
|
||||
return
|
||||
return nil, errors.New("failed to get user usage trend")
|
||||
}
|
||||
resp.UsersTrend = usersTrend
|
||||
}
|
||||
|
||||
cached := dashboardSnapshotV2Cache.Set(cacheKey, resp)
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, resp)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) {
|
||||
|
||||
@@ -23,6 +23,13 @@ var validOpsAlertMetricTypes = []string{
|
||||
"cpu_usage_percent",
|
||||
"memory_usage_percent",
|
||||
"concurrency_queue_depth",
|
||||
"group_available_accounts",
|
||||
"group_available_ratio",
|
||||
"group_rate_limit_ratio",
|
||||
"account_rate_limited_count",
|
||||
"account_error_count",
|
||||
"account_error_ratio",
|
||||
"overload_account_count",
|
||||
}
|
||||
|
||||
var validOpsAlertMetricTypeSet = func() map[string]struct{} {
|
||||
@@ -82,7 +89,10 @@ func isPercentOrRateMetric(metricType string) bool {
|
||||
"error_rate",
|
||||
"upstream_error_rate",
|
||||
"cpu_usage_percent",
|
||||
"memory_usage_percent":
|
||||
"memory_usage_percent",
|
||||
"group_available_ratio",
|
||||
"group_rate_limit_ratio",
|
||||
"account_error_ratio":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
type snapshotCacheEntry struct {
|
||||
@@ -19,6 +21,12 @@ type snapshotCache struct {
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
items map[string]snapshotCacheEntry
|
||||
sf singleflight.Group
|
||||
}
|
||||
|
||||
type snapshotCacheLoadResult struct {
|
||||
Entry snapshotCacheEntry
|
||||
Hit bool
|
||||
}
|
||||
|
||||
func newSnapshotCache(ttl time.Duration) *snapshotCache {
|
||||
@@ -70,6 +78,41 @@ func (c *snapshotCache) Set(key string, payload any) snapshotCacheEntry {
|
||||
return entry
|
||||
}
|
||||
|
||||
func (c *snapshotCache) GetOrLoad(key string, load func() (any, error)) (snapshotCacheEntry, bool, error) {
|
||||
if load == nil {
|
||||
return snapshotCacheEntry{}, false, nil
|
||||
}
|
||||
if entry, ok := c.Get(key); ok {
|
||||
return entry, true, nil
|
||||
}
|
||||
if c == nil || key == "" {
|
||||
payload, err := load()
|
||||
if err != nil {
|
||||
return snapshotCacheEntry{}, false, err
|
||||
}
|
||||
return c.Set(key, payload), false, nil
|
||||
}
|
||||
|
||||
value, err, _ := c.sf.Do(key, func() (any, error) {
|
||||
if entry, ok := c.Get(key); ok {
|
||||
return snapshotCacheLoadResult{Entry: entry, Hit: true}, nil
|
||||
}
|
||||
payload, err := load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return snapshotCacheLoadResult{Entry: c.Set(key, payload), Hit: false}, nil
|
||||
})
|
||||
if err != nil {
|
||||
return snapshotCacheEntry{}, false, err
|
||||
}
|
||||
result, ok := value.(snapshotCacheLoadResult)
|
||||
if !ok {
|
||||
return snapshotCacheEntry{}, false, nil
|
||||
}
|
||||
return result.Entry, result.Hit, nil
|
||||
}
|
||||
|
||||
func buildETagFromAny(payload any) string {
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -95,6 +97,61 @@ func TestBuildETagFromAny_UnmarshalablePayload(t *testing.T) {
|
||||
require.Empty(t, etag)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_GetOrLoad_MissThenHit(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
var loads atomic.Int32
|
||||
|
||||
entry, hit, err := c.GetOrLoad("key1", func() (any, error) {
|
||||
loads.Add(1)
|
||||
return map[string]string{"hello": "world"}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, hit)
|
||||
require.NotEmpty(t, entry.ETag)
|
||||
require.Equal(t, int32(1), loads.Load())
|
||||
|
||||
entry2, hit, err := c.GetOrLoad("key1", func() (any, error) {
|
||||
loads.Add(1)
|
||||
return map[string]string{"unexpected": "value"}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, hit)
|
||||
require.Equal(t, entry.ETag, entry2.ETag)
|
||||
require.Equal(t, int32(1), loads.Load())
|
||||
}
|
||||
|
||||
func TestSnapshotCache_GetOrLoad_ConcurrentSingleflight(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
var loads atomic.Int32
|
||||
start := make(chan struct{})
|
||||
const callers = 8
|
||||
errCh := make(chan error, callers)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(callers)
|
||||
for range callers {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
_, _, err := c.GetOrLoad("shared", func() (any, error) {
|
||||
loads.Add(1)
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
return "value", nil
|
||||
})
|
||||
errCh <- err
|
||||
}()
|
||||
}
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
|
||||
for err := range errCh {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.Equal(t, int32(1), loads.Load())
|
||||
}
|
||||
|
||||
func TestParseBoolQueryWithDefault(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -216,6 +216,37 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// ResetSubscriptionQuotaRequest represents the reset quota request
|
||||
type ResetSubscriptionQuotaRequest struct {
|
||||
Daily bool `json:"daily"`
|
||||
Weekly bool `json:"weekly"`
|
||||
}
|
||||
|
||||
// ResetQuota resets daily and/or weekly usage for a subscription.
|
||||
// POST /api/v1/admin/subscriptions/:id/reset-quota
|
||||
func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid subscription ID")
|
||||
return
|
||||
}
|
||||
var req ResetSubscriptionQuotaRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if !req.Daily && !req.Weekly {
|
||||
response.BadRequest(c, "At least one of 'daily' or 'weekly' must be true")
|
||||
return
|
||||
}
|
||||
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, dto.UserSubscriptionFromServiceAdmin(sub))
|
||||
}
|
||||
|
||||
// Revoke handles revoking a subscription
|
||||
// DELETE /api/v1/admin/subscriptions/:id
|
||||
func (h *SubscriptionHandler) Revoke(c *gin.Context) {
|
||||
|
||||
290
backend/internal/handler/openai_chat_completions.go
Normal file
290
backend/internal/handler/openai_chat_completions.go
Normal file
@@ -0,0 +1,290 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ChatCompletions handles OpenAI Chat Completions API requests.
|
||||
// POST /v1/chat/completions
|
||||
func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
streamStarted := false
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
|
||||
requestStart := time.Now()
|
||||
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.openai_gateway.chat_completions",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||
return
|
||||
}
|
||||
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
if !gjson.ValidBytes(body) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
}
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||
routingStart := time.Now()
|
||||
|
||||
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int)
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
c.Set("openai_chat_completions_fallback_model", "")
|
||||
reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"",
|
||||
sessionHash,
|
||||
reqModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai_chat_completions.account_select_failed",
|
||||
zap.Error(err),
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
defaultModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if defaultModel != "" && defaultModel != reqModel {
|
||||
reqLog.Info("openai_chat_completions.fallback_to_default_model",
|
||||
zap.String("default_mapped_model", defaultModel),
|
||||
)
|
||||
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"",
|
||||
sessionHash,
|
||||
defaultModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err == nil && selection != nil {
|
||||
c.Set("openai_chat_completions_fallback_model", defaultModel)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
|
||||
} else {
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||
reqLog.Debug("openai_chat_completions.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
_ = scheduleDecision
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" {
|
||||
defaultMappedModel = fallbackModel
|
||||
}
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||
responseLatencyMs := forwardDurationMs
|
||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
||||
}
|
||||
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
||||
if err == nil && result != nil && result.FirstTokenMs != nil {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// Pool mode: retry on the same account
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_chat_completions.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
} else {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||
}
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.chat_completions"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("openai_chat_completions.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
})
|
||||
reqLog.Debug("openai_chat_completions.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -31,6 +31,7 @@ const (
|
||||
const (
|
||||
opsErrorLogTimeout = 5 * time.Second
|
||||
opsErrorLogDrainTimeout = 10 * time.Second
|
||||
opsErrorLogBatchWindow = 200 * time.Millisecond
|
||||
|
||||
opsErrorLogMinWorkerCount = 4
|
||||
opsErrorLogMaxWorkerCount = 32
|
||||
@@ -38,6 +39,7 @@ const (
|
||||
opsErrorLogQueueSizePerWorker = 128
|
||||
opsErrorLogMinQueueSize = 256
|
||||
opsErrorLogMaxQueueSize = 8192
|
||||
opsErrorLogBatchSize = 32
|
||||
)
|
||||
|
||||
type opsErrorLogJob struct {
|
||||
@@ -82,27 +84,82 @@ func startOpsErrorLogWorkers() {
|
||||
for i := 0; i < workerCount; i++ {
|
||||
go func() {
|
||||
defer opsErrorLogWorkersWg.Done()
|
||||
for job := range opsErrorLogQueue {
|
||||
opsErrorLogQueueLen.Add(-1)
|
||||
if job.ops == nil || job.entry == nil {
|
||||
continue
|
||||
for {
|
||||
job, ok := <-opsErrorLogQueue
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
|
||||
opsErrorLogQueueLen.Add(-1)
|
||||
batch := make([]opsErrorLogJob, 0, opsErrorLogBatchSize)
|
||||
batch = append(batch, job)
|
||||
|
||||
timer := time.NewTimer(opsErrorLogBatchWindow)
|
||||
batchLoop:
|
||||
for len(batch) < opsErrorLogBatchSize {
|
||||
select {
|
||||
case nextJob, ok := <-opsErrorLogQueue:
|
||||
if !ok {
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
flushOpsErrorLogBatch(batch)
|
||||
return
|
||||
}
|
||||
}()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
||||
_ = job.ops.RecordError(ctx, job.entry, nil)
|
||||
cancel()
|
||||
opsErrorLogProcessed.Add(1)
|
||||
}()
|
||||
opsErrorLogQueueLen.Add(-1)
|
||||
batch = append(batch, nextJob)
|
||||
case <-timer.C:
|
||||
break batchLoop
|
||||
}
|
||||
}
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
flushOpsErrorLogBatch(batch)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func flushOpsErrorLogBatch(batch []opsErrorLogJob) {
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
|
||||
}
|
||||
}()
|
||||
|
||||
grouped := make(map[*service.OpsService][]*service.OpsInsertErrorLogInput, len(batch))
|
||||
var processed int64
|
||||
for _, job := range batch {
|
||||
if job.ops == nil || job.entry == nil {
|
||||
continue
|
||||
}
|
||||
grouped[job.ops] = append(grouped[job.ops], job.entry)
|
||||
processed++
|
||||
}
|
||||
if processed == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for opsSvc, entries := range grouped {
|
||||
if opsSvc == nil || len(entries) == 0 {
|
||||
continue
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
||||
_ = opsSvc.RecordErrorBatch(ctx, entries)
|
||||
cancel()
|
||||
}
|
||||
opsErrorLogProcessed.Add(processed)
|
||||
}
|
||||
|
||||
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) {
|
||||
if ops == nil || entry == nil {
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user