mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-26 17:34:47 +08:00
Merge branch 'develop' into release/custom-0.1.80
# Conflicts: # backend/internal/pkg/antigravity/client.go # backend/internal/service/antigravity_oauth_service.go # backend/internal/service/antigravity_oauth_service_test.go # backend/internal/service/antigravity_token_provider.go
This commit is contained in:
@@ -1372,6 +1372,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
ForceCacheBilling: switchErr.IsStickySession,
|
||||
}
|
||||
}
|
||||
// 区分客户端取消和真正的上游失败,返回更准确的错误消息
|
||||
if c.Request.Context().Err() != nil {
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "client_disconnected", "Client disconnected before upstream response")
|
||||
}
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
|
||||
}
|
||||
resp := result.resp
|
||||
@@ -2044,6 +2048,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
ForceCacheBilling: switchErr.IsStickySession,
|
||||
}
|
||||
}
|
||||
// 区分客户端取消和真正的上游失败,返回更准确的错误消息
|
||||
if c.Request.Context().Err() != nil {
|
||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Client disconnected before upstream response")
|
||||
}
|
||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
|
||||
}
|
||||
resp := result.resp
|
||||
|
||||
@@ -220,7 +220,7 @@ func TestApplyErrorPassthroughRule_SkipMonitoringSetsContextKey(t *testing.T) {
|
||||
v, exists := c.Get(OpsSkipPassthroughKey)
|
||||
assert.True(t, exists, "OpsSkipPassthroughKey should be set when skip_monitoring=true")
|
||||
boolVal, ok := v.(bool)
|
||||
assert.True(t, ok, "value should be bool")
|
||||
assert.True(t, ok, "value should be a bool")
|
||||
assert.True(t, boolVal)
|
||||
}
|
||||
|
||||
|
||||
@@ -344,8 +344,16 @@ func (s *OpsService) getUsersLoadMapBestEffort(ctx context.Context, users []User
|
||||
return out
|
||||
}
|
||||
|
||||
// GetUserConcurrencyStats returns real-time concurrency usage for all active users.
|
||||
func (s *OpsService) GetUserConcurrencyStats(ctx context.Context) (map[int64]*UserConcurrencyInfo, *time.Time, error) {
|
||||
// GetUserConcurrencyStats returns real-time concurrency usage for active users.
|
||||
//
|
||||
// Optional filters:
|
||||
// - platformFilter: only include users who have access to groups belonging to that platform
|
||||
// - groupIDFilter: only include users who have access to that specific group
|
||||
func (s *OpsService) GetUserConcurrencyStats(
|
||||
ctx context.Context,
|
||||
platformFilter string,
|
||||
groupIDFilter *int64,
|
||||
) (map[int64]*UserConcurrencyInfo, *time.Time, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -355,6 +363,15 @@ func (s *OpsService) GetUserConcurrencyStats(ctx context.Context) (map[int64]*Us
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Build a set of allowed group IDs when filtering is requested.
|
||||
var allowedGroupIDs map[int64]struct{}
|
||||
if platformFilter != "" || (groupIDFilter != nil && *groupIDFilter > 0) {
|
||||
allowedGroupIDs, err = s.buildAllowedGroupIDsForFilter(ctx, platformFilter, groupIDFilter)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
collectedAt := time.Now()
|
||||
loadMap := s.getUsersLoadMapBestEffort(ctx, users)
|
||||
|
||||
@@ -365,6 +382,12 @@ func (s *OpsService) GetUserConcurrencyStats(ctx context.Context) (map[int64]*Us
|
||||
continue
|
||||
}
|
||||
|
||||
// Apply group/platform filter: skip users whose AllowedGroups
|
||||
// have no intersection with the matching group IDs.
|
||||
if allowedGroupIDs != nil && !userMatchesGroupFilter(u.AllowedGroups, allowedGroupIDs) {
|
||||
continue
|
||||
}
|
||||
|
||||
load := loadMap[u.ID]
|
||||
currentInUse := int64(0)
|
||||
waiting := int64(0)
|
||||
@@ -394,3 +417,46 @@ func (s *OpsService) GetUserConcurrencyStats(ctx context.Context) (map[int64]*Us
|
||||
|
||||
return result, &collectedAt, nil
|
||||
}
|
||||
|
||||
// buildAllowedGroupIDsForFilter returns the set of group IDs that match the given
|
||||
// platform and/or group ID filter. It reuses listAllAccountsForOps (which already
|
||||
// supports platform filtering at the DB level) to collect group IDs from accounts.
|
||||
func (s *OpsService) buildAllowedGroupIDsForFilter(ctx context.Context, platformFilter string, groupIDFilter *int64) (map[int64]struct{}, error) {
|
||||
// Fast path: only group ID filter, no platform filter needed.
|
||||
if platformFilter == "" && groupIDFilter != nil && *groupIDFilter > 0 {
|
||||
return map[int64]struct{}{*groupIDFilter: {}}, nil
|
||||
}
|
||||
|
||||
// Use the same account-based approach as GetConcurrencyStats to collect group IDs.
|
||||
accounts, err := s.listAllAccountsForOps(ctx, platformFilter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupIDs := make(map[int64]struct{})
|
||||
for _, acc := range accounts {
|
||||
for _, grp := range acc.Groups {
|
||||
if grp == nil || grp.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
// If groupIDFilter is set, only include that specific group.
|
||||
if groupIDFilter != nil && *groupIDFilter > 0 && grp.ID != *groupIDFilter {
|
||||
continue
|
||||
}
|
||||
groupIDs[grp.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
return groupIDs, nil
|
||||
}
|
||||
|
||||
// userMatchesGroupFilter returns true if the user's AllowedGroups contains
|
||||
// at least one group ID in the allowed set.
|
||||
func userMatchesGroupFilter(userGroups []int64, allowedGroupIDs map[int64]struct{}) bool {
|
||||
for _, gid := range userGroups {
|
||||
if _, ok := allowedGroupIDs[gid]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user