mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-05 05:30:44 +08:00
fix: resolve cherry-pick compilation and test issues
- Add int64(0) param to SelectAccountWithLoadAwareness callers (signature change from channel scheduling refactor) - Add UsageMapHook type and struct field to StreamingProcessor - Revert Claude Max cache billing code to upstream/main (not part of channel feature) - Revert credits overages logic to upstream/main (non-channel change) - Remove Instructions field reference (non-channel OpenAI feature) - Restore sora_client_handler_test.go from upstream + add channel service nil params
This commit is contained in:
@@ -295,7 +295,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if len(fs.FailedAccountIDs) == 0 {
|
if len(fs.FailedAccountIDs) == 0 {
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||||
@@ -518,7 +518,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
|
|
||||||
for {
|
for {
|
||||||
// 选择支持该模型的账号
|
// 选择支持该模型的账号
|
||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID)
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, int64(0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if len(fs.FailedAccountIDs) == 0 {
|
if len(fs.FailedAccountIDs) == 0 {
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||||
|
|||||||
@@ -157,7 +157,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
fs := NewFailoverState(h.maxAccountSwitches, false)
|
fs := NewFailoverState(h.maxAccountSwitches, false)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "")
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if len(fs.FailedAccountIDs) == 0 {
|
if len(fs.FailedAccountIDs) == 0 {
|
||||||
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
|||||||
fs := NewFailoverState(h.maxAccountSwitches, false)
|
fs := NewFailoverState(h.maxAccountSwitches, false)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "")
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if len(fs.FailedAccountIDs) == 0 {
|
if len(fs.FailedAccountIDs) == 0 {
|
||||||
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
||||||
|
|||||||
@@ -360,7 +360,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if len(fs.FailedAccountIDs) == 0 {
|
if len(fs.FailedAccountIDs) == 0 {
|
||||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||||
|
|||||||
@@ -125,13 +125,6 @@ func (r *stubSoraGenRepo) CountByUserAndStatus(_ context.Context, _ int64, _ []s
|
|||||||
return r.countValue, nil
|
return r.countValue, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *stubSoraGenRepo) CountByStorageType(_ context.Context, _ string, _ []string) (int64, error) {
|
|
||||||
if r.countErr != nil {
|
|
||||||
return 0, r.countErr
|
|
||||||
}
|
|
||||||
return r.countValue, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== 辅助函数 ====================
|
// ==================== 辅助函数 ====================
|
||||||
|
|
||||||
func newTestSoraClientHandler(repo *stubSoraGenRepo) *SoraClientHandler {
|
func newTestSoraClientHandler(repo *stubSoraGenRepo) *SoraClientHandler {
|
||||||
@@ -1664,8 +1657,8 @@ func TestStoreMediaWithDegradation_S3SuccessSingleURL(t *testing.T) {
|
|||||||
fakeS3 := newFakeS3Server("ok")
|
fakeS3 := newFakeS3Server("ok")
|
||||||
defer fakeS3.Close()
|
defer fakeS3.Close()
|
||||||
|
|
||||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||||
h := &SoraClientHandler{objectStorage: objectStorage}
|
h := &SoraClientHandler{s3Storage: s3Storage}
|
||||||
|
|
||||||
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
|
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
|
||||||
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
|
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
|
||||||
@@ -1686,8 +1679,8 @@ func TestStoreMediaWithDegradation_S3SuccessMultiURL(t *testing.T) {
|
|||||||
fakeS3 := newFakeS3Server("ok")
|
fakeS3 := newFakeS3Server("ok")
|
||||||
defer fakeS3.Close()
|
defer fakeS3.Close()
|
||||||
|
|
||||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||||
h := &SoraClientHandler{objectStorage: objectStorage}
|
h := &SoraClientHandler{s3Storage: s3Storage}
|
||||||
|
|
||||||
urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
|
urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
|
||||||
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
|
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
|
||||||
@@ -1711,8 +1704,8 @@ func TestStoreMediaWithDegradation_S3DownloadFails(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer badSource.Close()
|
defer badSource.Close()
|
||||||
|
|
||||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||||
h := &SoraClientHandler{objectStorage: objectStorage}
|
h := &SoraClientHandler{s3Storage: s3Storage}
|
||||||
|
|
||||||
_, _, storageType, _, _ := h.storeMediaWithDegradation(
|
_, _, storageType, _, _ := h.storeMediaWithDegradation(
|
||||||
context.Background(), 1, "video", badSource.URL+"/missing.mp4", nil,
|
context.Background(), 1, "video", badSource.URL+"/missing.mp4", nil,
|
||||||
@@ -1726,8 +1719,8 @@ func TestStoreMediaWithDegradation_S3FailsSingleURL(t *testing.T) {
|
|||||||
fakeS3 := newFakeS3Server("fail")
|
fakeS3 := newFakeS3Server("fail")
|
||||||
defer fakeS3.Close()
|
defer fakeS3.Close()
|
||||||
|
|
||||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||||
h := &SoraClientHandler{objectStorage: objectStorage}
|
h := &SoraClientHandler{s3Storage: s3Storage}
|
||||||
|
|
||||||
_, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
|
_, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
|
||||||
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
|
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
|
||||||
@@ -1743,8 +1736,8 @@ func TestStoreMediaWithDegradation_S3PartialFailureCleanup(t *testing.T) {
|
|||||||
fakeS3 := newFakeS3Server("fail-second")
|
fakeS3 := newFakeS3Server("fail-second")
|
||||||
defer fakeS3.Close()
|
defer fakeS3.Close()
|
||||||
|
|
||||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||||
h := &SoraClientHandler{objectStorage: objectStorage}
|
h := &SoraClientHandler{s3Storage: s3Storage}
|
||||||
|
|
||||||
urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
|
urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
|
||||||
_, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
|
_, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
|
||||||
@@ -1815,7 +1808,7 @@ func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) {
|
|||||||
fakeS3 := newFakeS3Server("fail")
|
fakeS3 := newFakeS3Server("fail")
|
||||||
defer fakeS3.Close()
|
defer fakeS3.Close()
|
||||||
|
|
||||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
Sora: config.SoraConfig{
|
Sora: config.SoraConfig{
|
||||||
Storage: config.SoraStorageConfig{
|
Storage: config.SoraStorageConfig{
|
||||||
@@ -1828,8 +1821,8 @@ func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) {
|
|||||||
}
|
}
|
||||||
mediaStorage := service.NewSoraMediaStorage(cfg)
|
mediaStorage := service.NewSoraMediaStorage(cfg)
|
||||||
h := &SoraClientHandler{
|
h := &SoraClientHandler{
|
||||||
objectStorage: objectStorage,
|
s3Storage: s3Storage,
|
||||||
mediaStorage: mediaStorage,
|
mediaStorage: mediaStorage,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, _, storageType, _, _ := h.storeMediaWithDegradation(
|
_, _, storageType, _, _ := h.storeMediaWithDegradation(
|
||||||
@@ -1853,9 +1846,9 @@ func TestSaveToStorage_S3EnabledButUploadFails(t *testing.T) {
|
|||||||
StorageType: "upstream",
|
StorageType: "upstream",
|
||||||
MediaURL: sourceServer.URL + "/v.mp4",
|
MediaURL: sourceServer.URL + "/v.mp4",
|
||||||
}
|
}
|
||||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||||||
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
|
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
|
||||||
|
|
||||||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||||||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||||||
@@ -1879,9 +1872,9 @@ func TestSaveToStorage_UpstreamURLExpired(t *testing.T) {
|
|||||||
StorageType: "upstream",
|
StorageType: "upstream",
|
||||||
MediaURL: expiredServer.URL + "/v.mp4",
|
MediaURL: expiredServer.URL + "/v.mp4",
|
||||||
}
|
}
|
||||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||||||
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
|
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
|
||||||
|
|
||||||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||||||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||||||
@@ -1903,9 +1896,9 @@ func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) {
|
|||||||
StorageType: "upstream",
|
StorageType: "upstream",
|
||||||
MediaURL: sourceServer.URL + "/v.mp4",
|
MediaURL: sourceServer.URL + "/v.mp4",
|
||||||
}
|
}
|
||||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||||||
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
|
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
|
||||||
|
|
||||||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||||||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||||||
@@ -1913,7 +1906,7 @@ func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) {
|
|||||||
require.Equal(t, http.StatusOK, rec.Code)
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
resp := parseResponse(t, rec)
|
resp := parseResponse(t, rec)
|
||||||
data := resp["data"].(map[string]any)
|
data := resp["data"].(map[string]any)
|
||||||
require.Contains(t, data["message"], "云存储")
|
require.Contains(t, data["message"], "S3")
|
||||||
require.NotEmpty(t, data["object_key"])
|
require.NotEmpty(t, data["object_key"])
|
||||||
// 验证记录已更新为 S3 存储
|
// 验证记录已更新为 S3 存储
|
||||||
require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
|
require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
|
||||||
@@ -1935,9 +1928,9 @@ func TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs(t *testing.T) {
|
|||||||
sourceServer.URL + "/v2.mp4",
|
sourceServer.URL + "/v2.mp4",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||||||
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
|
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
|
||||||
|
|
||||||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||||||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||||||
@@ -1963,7 +1956,7 @@ func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) {
|
|||||||
StorageType: "upstream",
|
StorageType: "upstream",
|
||||||
MediaURL: sourceServer.URL + "/v.mp4",
|
MediaURL: sourceServer.URL + "/v.mp4",
|
||||||
}
|
}
|
||||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||||||
|
|
||||||
userRepo := newStubUserRepoForHandler()
|
userRepo := newStubUserRepoForHandler()
|
||||||
@@ -1973,7 +1966,7 @@ func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) {
|
|||||||
SoraStorageUsedBytes: 0,
|
SoraStorageUsedBytes: 0,
|
||||||
}
|
}
|
||||||
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
|
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
|
||||||
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage, quotaService: quotaService}
|
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
|
||||||
|
|
||||||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||||||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||||||
@@ -1997,9 +1990,9 @@ func TestSaveToStorage_S3UploadSuccessMarkCompletedFails(t *testing.T) {
|
|||||||
}
|
}
|
||||||
// S3 上传成功后,MarkCompleted 会调用 repo.Update → 失败
|
// S3 上传成功后,MarkCompleted 会调用 repo.Update → 失败
|
||||||
repo.updateErr = fmt.Errorf("db error")
|
repo.updateErr = fmt.Errorf("db error")
|
||||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||||||
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
|
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
|
||||||
|
|
||||||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||||||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||||||
@@ -2014,8 +2007,8 @@ func TestGetStorageStatus_S3EnabledNotHealthy(t *testing.T) {
|
|||||||
fakeS3 := newFakeS3Server("fail")
|
fakeS3 := newFakeS3Server("fail")
|
||||||
defer fakeS3.Close()
|
defer fakeS3.Close()
|
||||||
|
|
||||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||||
h := &SoraClientHandler{objectStorage: objectStorage}
|
h := &SoraClientHandler{s3Storage: s3Storage}
|
||||||
|
|
||||||
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
|
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
|
||||||
h.GetStorageStatus(c)
|
h.GetStorageStatus(c)
|
||||||
@@ -2030,8 +2023,8 @@ func TestGetStorageStatus_S3EnabledHealthy(t *testing.T) {
|
|||||||
fakeS3 := newFakeS3Server("ok")
|
fakeS3 := newFakeS3Server("ok")
|
||||||
defer fakeS3.Close()
|
defer fakeS3.Close()
|
||||||
|
|
||||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||||
h := &SoraClientHandler{objectStorage: objectStorage}
|
h := &SoraClientHandler{s3Storage: s3Storage}
|
||||||
|
|
||||||
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
|
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
|
||||||
h.GetStorageStatus(c)
|
h.GetStorageStatus(c)
|
||||||
@@ -2460,7 +2453,7 @@ func TestProcessGeneration_FullSuccessWithS3(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
soraGatewayService := newMinimalSoraGatewayService(soraClient)
|
soraGatewayService := newMinimalSoraGatewayService(soraClient)
|
||||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||||
|
|
||||||
userRepo := newStubUserRepoForHandler()
|
userRepo := newStubUserRepoForHandler()
|
||||||
userRepo.users[1] = &service.User{
|
userRepo.users[1] = &service.User{
|
||||||
@@ -2472,7 +2465,7 @@ func TestProcessGeneration_FullSuccessWithS3(t *testing.T) {
|
|||||||
genService: genService,
|
genService: genService,
|
||||||
gatewayService: gatewayService,
|
gatewayService: gatewayService,
|
||||||
soraGatewayService: soraGatewayService,
|
soraGatewayService: soraGatewayService,
|
||||||
objectStorage: objectStorage,
|
s3Storage: s3Storage,
|
||||||
quotaService: quotaService,
|
quotaService: quotaService,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2522,7 +2515,7 @@ func TestProcessGeneration_MarkCompletedFails(t *testing.T) {
|
|||||||
// ==================== cleanupStoredMedia 直接测试 ====================
|
// ==================== cleanupStoredMedia 直接测试 ====================
|
||||||
|
|
||||||
func TestCleanupStoredMedia_S3Path(t *testing.T) {
|
func TestCleanupStoredMedia_S3Path(t *testing.T) {
|
||||||
// S3 清理路径:objectStorage 为 nil 时不 panic
|
// S3 清理路径:s3Storage 为 nil 时不 panic
|
||||||
h := &SoraClientHandler{}
|
h := &SoraClientHandler{}
|
||||||
// 不应 panic
|
// 不应 panic
|
||||||
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
|
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
|
||||||
@@ -2969,7 +2962,7 @@ func TestSaveToStorage_QuotaExceeded(t *testing.T) {
|
|||||||
StorageType: "upstream",
|
StorageType: "upstream",
|
||||||
MediaURL: sourceServer.URL + "/v.mp4",
|
MediaURL: sourceServer.URL + "/v.mp4",
|
||||||
}
|
}
|
||||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||||||
|
|
||||||
// 用户配额已满
|
// 用户配额已满
|
||||||
@@ -2980,7 +2973,7 @@ func TestSaveToStorage_QuotaExceeded(t *testing.T) {
|
|||||||
SoraStorageUsedBytes: 10,
|
SoraStorageUsedBytes: 10,
|
||||||
}
|
}
|
||||||
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
|
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
|
||||||
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage, quotaService: quotaService}
|
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
|
||||||
|
|
||||||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||||||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||||||
@@ -3002,13 +2995,13 @@ func TestSaveToStorage_QuotaNonQuotaError(t *testing.T) {
|
|||||||
StorageType: "upstream",
|
StorageType: "upstream",
|
||||||
MediaURL: sourceServer.URL + "/v.mp4",
|
MediaURL: sourceServer.URL + "/v.mp4",
|
||||||
}
|
}
|
||||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||||||
|
|
||||||
// 用户不存在 → GetByID 失败 → AddUsage 返回普通 error
|
// 用户不存在 → GetByID 失败 → AddUsage 返回普通 error
|
||||||
userRepo := newStubUserRepoForHandler()
|
userRepo := newStubUserRepoForHandler()
|
||||||
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
|
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
|
||||||
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage, quotaService: quotaService}
|
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
|
||||||
|
|
||||||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||||||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||||||
@@ -3029,9 +3022,9 @@ func TestSaveToStorage_EmptyMediaURLs(t *testing.T) {
|
|||||||
MediaURL: "",
|
MediaURL: "",
|
||||||
MediaURLs: []string{},
|
MediaURLs: []string{},
|
||||||
}
|
}
|
||||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||||||
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
|
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
|
||||||
|
|
||||||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||||||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||||||
@@ -3056,9 +3049,9 @@ func TestSaveToStorage_MultiURL_SecondUploadFails(t *testing.T) {
|
|||||||
MediaURL: sourceServer.URL + "/v1.mp4",
|
MediaURL: sourceServer.URL + "/v1.mp4",
|
||||||
MediaURLs: []string{sourceServer.URL + "/v1.mp4", sourceServer.URL + "/v2.mp4"},
|
MediaURLs: []string{sourceServer.URL + "/v1.mp4", sourceServer.URL + "/v2.mp4"},
|
||||||
}
|
}
|
||||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||||||
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
|
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
|
||||||
|
|
||||||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||||||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||||||
@@ -3081,7 +3074,7 @@ func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) {
|
|||||||
MediaURL: sourceServer.URL + "/v.mp4",
|
MediaURL: sourceServer.URL + "/v.mp4",
|
||||||
}
|
}
|
||||||
repo.updateErr = fmt.Errorf("db error")
|
repo.updateErr = fmt.Errorf("db error")
|
||||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||||||
|
|
||||||
userRepo := newStubUserRepoForHandler()
|
userRepo := newStubUserRepoForHandler()
|
||||||
@@ -3091,7 +3084,7 @@ func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) {
|
|||||||
SoraStorageUsedBytes: 0,
|
SoraStorageUsedBytes: 0,
|
||||||
}
|
}
|
||||||
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
|
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
|
||||||
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage, quotaService: quotaService}
|
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
|
||||||
|
|
||||||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||||||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||||||
@@ -3104,8 +3097,8 @@ func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) {
|
|||||||
func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) {
|
func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) {
|
||||||
fakeS3 := newFakeS3Server("ok")
|
fakeS3 := newFakeS3Server("ok")
|
||||||
defer fakeS3.Close()
|
defer fakeS3.Close()
|
||||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||||
h := &SoraClientHandler{objectStorage: objectStorage}
|
h := &SoraClientHandler{s3Storage: s3Storage}
|
||||||
|
|
||||||
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1", "key2"}, nil)
|
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1", "key2"}, nil)
|
||||||
}
|
}
|
||||||
@@ -3113,8 +3106,8 @@ func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) {
|
|||||||
func TestCleanupStoredMedia_S3DeleteFails_LogOnly(t *testing.T) {
|
func TestCleanupStoredMedia_S3DeleteFails_LogOnly(t *testing.T) {
|
||||||
fakeS3 := newFakeS3Server("fail")
|
fakeS3 := newFakeS3Server("fail")
|
||||||
defer fakeS3.Close()
|
defer fakeS3.Close()
|
||||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||||
h := &SoraClientHandler{objectStorage: objectStorage}
|
h := &SoraClientHandler{s3Storage: s3Storage}
|
||||||
|
|
||||||
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
|
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -228,7 +228,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
var lastFailoverHeaders http.Header
|
var lastFailoverHeaders http.Header
|
||||||
|
|
||||||
for {
|
for {
|
||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "", int64(0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
reqLog.Warn("sora.account_select_failed",
|
reqLog.Warn("sora.account_select_failed",
|
||||||
zap.Error(err),
|
zap.Error(err),
|
||||||
|
|||||||
@@ -18,6 +18,9 @@ const (
|
|||||||
BlockTypeFunction
|
BlockTypeFunction
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// UsageMapHook is a callback that can modify usage data before it's emitted in SSE events.
|
||||||
|
type UsageMapHook func(usageMap map[string]any)
|
||||||
|
|
||||||
// StreamingProcessor 流式响应处理器
|
// StreamingProcessor 流式响应处理器
|
||||||
type StreamingProcessor struct {
|
type StreamingProcessor struct {
|
||||||
blockType BlockType
|
blockType BlockType
|
||||||
@@ -30,6 +33,7 @@ type StreamingProcessor struct {
|
|||||||
originalModel string
|
originalModel string
|
||||||
webSearchQueries []string
|
webSearchQueries []string
|
||||||
groundingChunks []GeminiGroundingChunk
|
groundingChunks []GeminiGroundingChunk
|
||||||
|
usageMapHook UsageMapHook
|
||||||
|
|
||||||
// 累计 usage
|
// 累计 usage
|
||||||
inputTokens int
|
inputTokens int
|
||||||
|
|||||||
@@ -27,13 +27,12 @@ func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest,
|
|||||||
}
|
}
|
||||||
|
|
||||||
out := &ResponsesRequest{
|
out := &ResponsesRequest{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
Instructions: req.Instructions,
|
Input: inputJSON,
|
||||||
Input: inputJSON,
|
Temperature: req.Temperature,
|
||||||
Temperature: req.Temperature,
|
TopP: req.TopP,
|
||||||
TopP: req.TopP,
|
Stream: true, // upstream always streams
|
||||||
Stream: true, // upstream always streams
|
Include: []string{"reasoning.encrypted_content"},
|
||||||
Include: []string{"reasoning.encrypted_content"},
|
|
||||||
ServiceTier: req.ServiceTier,
|
ServiceTier: req.ServiceTier,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -278,13 +278,9 @@ func resolveCreditsOveragesModelKey(ctx context.Context, account *Account, upstr
|
|||||||
}
|
}
|
||||||
|
|
||||||
// shouldMarkCreditsExhausted 判断一次 credits 请求失败是否应标记为 credits 耗尽。
|
// shouldMarkCreditsExhausted 判断一次 credits 请求失败是否应标记为 credits 耗尽。
|
||||||
// 此函数在积分注入后失败时调用(预检查注入 + attemptCreditsOveragesRetry 两条路径)。
|
// 注意:不再检查 isURLLevelRateLimit。此函数仅在积分重试失败后调用,
|
||||||
// - 429 + 非单模型限流:积分注入后仍 429 → 标记耗尽。
|
// 如果注入 enabledCreditTypes 后仍返回 "Resource has been exhausted",
|
||||||
// - 429 + 单模型限流("exhausted your capacity on this model"):该模型免费配额用完,
|
// 说明积分也已耗尽,应该标记。clearCreditsExhausted 会在后续成功时自动清除。
|
||||||
// 积分注入对此无效,但账号积分对其他模型可能仍可用 → 不标记积分耗尽。
|
|
||||||
// - 403 等其他 4xx:检查 body 是否包含积分不足的关键词。
|
|
||||||
//
|
|
||||||
// clearCreditsExhausted 会在后续成功时自动清除。
|
|
||||||
func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr error) bool {
|
func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr error) bool {
|
||||||
if reqErr != nil || resp == nil {
|
if reqErr != nil || resp == nil {
|
||||||
return false
|
return false
|
||||||
@@ -292,16 +288,10 @@ func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr err
|
|||||||
if resp.StatusCode >= 500 || resp.StatusCode == http.StatusRequestTimeout {
|
if resp.StatusCode >= 500 || resp.StatusCode == http.StatusRequestTimeout {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
bodyLower := strings.ToLower(string(respBody))
|
if info := parseAntigravitySmartRetryInfo(respBody); info != nil {
|
||||||
// 积分注入后仍 429
|
return false
|
||||||
if resp.StatusCode == http.StatusTooManyRequests {
|
|
||||||
// 单模型配额耗尽:积分注入对此无效,不标记整个账号积分耗尽
|
|
||||||
if strings.Contains(bodyLower, "exhausted your capacity on this model") {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
// 其他 4xx:关键词匹配(如 403 + "Insufficient credits")
|
bodyLower := strings.ToLower(string(respBody))
|
||||||
for _, keyword := range creditsExhaustedKeywords {
|
for _, keyword := range creditsExhaustedKeywords {
|
||||||
if strings.Contains(bodyLower, keyword) {
|
if strings.Contains(bodyLower, keyword) {
|
||||||
return true
|
return true
|
||||||
|
|||||||
@@ -418,13 +418,7 @@ func TestShouldMarkCreditsExhausted(t *testing.T) {
|
|||||||
require.True(t, shouldMarkCreditsExhausted(resp, body, nil))
|
require.True(t, shouldMarkCreditsExhausted(resp, body, nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("单模型配额耗尽不标记(积分对此无效)", func(t *testing.T) {
|
t.Run("结构化限流不标记", func(t *testing.T) {
|
||||||
resp := &http.Response{StatusCode: http.StatusTooManyRequests}
|
|
||||||
body := []byte(`{"error":{"code":429,"message":"You have exhausted your capacity on this model. Your quota will reset after 146h11m17s.","status":"RESOURCE_EXHAUSTED"}}`)
|
|
||||||
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("429 结构化限流也标记(积分注入后仍 429 即为耗尽)", func(t *testing.T) {
|
|
||||||
resp := &http.Response{StatusCode: http.StatusTooManyRequests}
|
resp := &http.Response{StatusCode: http.StatusTooManyRequests}
|
||||||
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"RATE_LIMIT_EXCEEDED"},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"0.5s"}]}}`)
|
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"RATE_LIMIT_EXCEEDED"},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"0.5s"}]}}`)
|
||||||
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
|
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
|
||||||
|
|||||||
@@ -732,7 +732,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
|
|||||||
modelsListCacheTTL: time.Minute,
|
modelsListCacheTTL: time.Minute,
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(baseCtx, nil, "sess-hash", "", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(baseCtx, nil, "sess-hash", "", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -754,7 +754,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
|
|||||||
|
|
||||||
ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, account.ID)
|
ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, account.ID)
|
||||||
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0))
|
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0))
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -776,7 +776,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
|
|||||||
|
|
||||||
ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, int64(999))
|
ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, int64(999))
|
||||||
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(77))
|
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(77))
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
|
|||||||
@@ -2031,7 +2031,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: nil, // No concurrency service
|
concurrencyService: nil, // No concurrency service
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -2084,7 +2084,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: nil, // legacy path
|
concurrencyService: nil, // legacy path
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -2116,7 +2116,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: nil,
|
concurrencyService: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -2148,7 +2148,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
excludedIDs := map[int64]struct{}{1: {}}
|
excludedIDs := map[int64]struct{}{1: {}}
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -2182,7 +2182,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -2218,7 +2218,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -2259,7 +2259,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(testCtx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(testCtx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -2287,7 +2287,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: nil,
|
concurrencyService: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Nil(t, result)
|
require.Nil(t, result)
|
||||||
require.ErrorIs(t, err, ErrNoAvailableAccounts)
|
require.ErrorIs(t, err, ErrNoAvailableAccounts)
|
||||||
@@ -2319,7 +2319,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: nil,
|
concurrencyService: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -2352,7 +2352,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: nil,
|
concurrencyService: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -2390,7 +2390,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.WaitPlan)
|
require.NotNil(t, result.WaitPlan)
|
||||||
@@ -2426,7 +2426,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "legacy", "claude-3-5-sonnet-20241022", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "legacy", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -2485,7 +2485,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.WaitPlan)
|
require.NotNil(t, result.WaitPlan)
|
||||||
@@ -2539,7 +2539,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -2593,7 +2593,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -2651,7 +2651,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route", "claude-3-5-sonnet-20241022", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -2709,7 +2709,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route-full", "claude-3-5-sonnet-20241022", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route-full", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.WaitPlan)
|
require.NotNil(t, result.WaitPlan)
|
||||||
@@ -2767,7 +2767,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "fallback", "claude-3-5-sonnet-20241022", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "fallback", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -2804,7 +2804,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.WaitPlan)
|
require.NotNil(t, result.WaitPlan)
|
||||||
@@ -2856,7 +2856,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "gemini", "gemini-2.5-pro", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "gemini", "gemini-2.5-pro", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -2934,7 +2934,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
excluded := map[int64]struct{}{1: {}}
|
excluded := map[int64]struct{}{1: {}}
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", excluded, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", excluded, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -2988,7 +2988,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: nil,
|
concurrencyService: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "gemini-2.5-pro", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "gemini-2.5-pro", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -3021,7 +3021,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: nil,
|
concurrencyService: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Nil(t, result)
|
require.Nil(t, result)
|
||||||
require.ErrorIs(t, err, ErrClaudeCodeOnly)
|
require.ErrorIs(t, err, ErrClaudeCodeOnly)
|
||||||
@@ -3059,7 +3059,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "wait", "claude-3-5-sonnet-20241022", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "wait", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.WaitPlan)
|
require.NotNil(t, result.WaitPlan)
|
||||||
@@ -3097,7 +3097,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "missing-load", "claude-3-5-sonnet-20241022", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "missing-load", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
|
|||||||
@@ -7782,7 +7782,6 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
APIKeyService: input.APIKeyService,
|
APIKeyService: input.APIKeyService,
|
||||||
ChannelUsageFields: input.ChannelUsageFields,
|
ChannelUsageFields: input.ChannelUsageFields,
|
||||||
}, &recordUsageOpts{
|
}, &recordUsageOpts{
|
||||||
ParsedRequest: input.ParsedRequest,
|
|
||||||
EnableClaudePath: true,
|
EnableClaudePath: true,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -7867,21 +7866,9 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
|||||||
result.Usage.InputTokens = 0
|
result.Usage.InputTokens = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Claude Max cache billing policy(仅 Claude 路径启用)
|
|
||||||
cacheTTLOverridden := false
|
|
||||||
simulatedClaudeMax := false
|
|
||||||
if opts.EnableClaudePath {
|
|
||||||
var apiKeyGroup *Group
|
|
||||||
if apiKey != nil {
|
|
||||||
apiKeyGroup = apiKey.Group
|
|
||||||
}
|
|
||||||
claudeMaxOutcome := applyClaudeMaxCacheBillingPolicyToUsage(&result.Usage, opts.ParsedRequest, apiKeyGroup, result.Model, account.ID)
|
|
||||||
simulatedClaudeMax = claudeMaxOutcome.Simulated ||
|
|
||||||
(shouldApplyClaudeMaxBillingRulesForUsage(apiKeyGroup, result.Model, opts.ParsedRequest) && hasCacheCreationTokens(result.Usage))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
|
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
|
||||||
if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax {
|
cacheTTLOverridden := false
|
||||||
|
if account.IsCacheTTLOverrideEnabled() {
|
||||||
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
|
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
|
||||||
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
|
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -519,7 +519,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
|
|||||||
if s.gatewayService == nil {
|
if s.gatewayService == nil {
|
||||||
return nil, fmt.Errorf("gateway service not available")
|
return nil, fmt.Errorf("gateway service not available")
|
||||||
}
|
}
|
||||||
return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "") // 重试不使用会话限制
|
return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "", int64(0)) // 重试不使用会话限制
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported retry type: %s", reqType)
|
return nil, fmt.Errorf("unsupported retry type: %s", reqType)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user