2025-12-18 13:50:39 +08:00
|
|
|
|
package service
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"bufio"
|
|
|
|
|
|
"bytes"
|
|
|
|
|
|
"context"
|
|
|
|
|
|
"crypto/sha256"
|
|
|
|
|
|
"encoding/hex"
|
|
|
|
|
|
"encoding/json"
|
|
|
|
|
|
"errors"
|
|
|
|
|
|
"fmt"
|
|
|
|
|
|
"io"
|
|
|
|
|
|
"log"
|
|
|
|
|
|
"net/http"
|
|
|
|
|
|
"regexp"
|
|
|
|
|
|
"strings"
|
|
|
|
|
|
"time"
|
|
|
|
|
|
|
2025-12-24 21:07:21 +08:00
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
2025-12-25 14:47:19 +08:00
|
|
|
|
"github.com/tidwall/gjson"
|
|
|
|
|
|
"github.com/tidwall/sjson"
|
2025-12-18 13:50:39 +08:00
|
|
|
|
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// Upstream endpoints and sticky-session configuration.
const (
	// claudeAPIURL is the Anthropic Messages endpoint used for OAuth-style accounts.
	claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"

	// claudeAPICountTokensURL is the Anthropic token-counting endpoint.
	claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"

	// stickySessionTTL is the lifetime of a session-to-account binding in the cache.
	stickySessionTTL = time.Hour
)
|
|
|
|
|
|
|
2025-12-29 16:52:55 +08:00
|
|
|
|
// ctxKeyForcePlatform is the context key from which the forced platform is
// read; the value is stored by middleware.ForcePlatform.
//
// It must remain the exact same string as middleware.ctxKeyForcePlatformStr,
// otherwise the lookup silently finds nothing.
const ctxKeyForcePlatform = "ctx_force_platform"
|
|
|
|
|
|
|
2025-12-26 03:49:55 -08:00
|
|
|
|
// sseDataRe matches the "data:" prefix of an SSE line together with any
// whitespace that follows the colon. The standard form is "data: ", but some
// upstream APIs emit "data:" with no space, so the whitespace is optional.
var sseDataRe = regexp.MustCompile(`^data:\s*`)
|
|
|
|
|
|
|
2025-12-18 13:50:39 +08:00
|
|
|
|
// allowedHeaders is the whitelist of request headers (lower-cased names),
// modelled on the CRS project. Only names mapped to true are on the list.
var allowedHeaders = map[string]bool{
	// Generic HTTP headers.
	"accept":          true,
	"accept-language": true,
	"content-type":    true,
	"sec-fetch-mode":  true,
	"user-agent":      true,

	// Anthropic-specific headers.
	"anthropic-beta":    true,
	"anthropic-version": true,
	"anthropic-dangerous-direct-browser-access": true,
	"x-app": true,

	// Stainless SDK telemetry headers.
	"x-stainless-arch":            true,
	"x-stainless-helper-method":   true,
	"x-stainless-lang":            true,
	"x-stainless-os":              true,
	"x-stainless-package-version": true,
	"x-stainless-retry-count":     true,
	"x-stainless-runtime":         true,
	"x-stainless-runtime-version": true,
	"x-stainless-timeout":         true,
}
|
|
|
|
|
|
|
2025-12-25 17:15:01 +08:00
|
|
|
|
// GatewayCache is the cache abstraction the gateway service uses to keep
// sticky-session bindings (session hash -> account ID).
type GatewayCache interface {
	// GetSessionAccountID returns the account ID bound to sessionHash.
	GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error)

	// SetSessionAccountID binds sessionHash to accountID for the given ttl.
	SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error

	// RefreshSessionTTL resets the expiry of an existing binding to ttl.
	RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
}
|
|
|
|
|
|
|
2025-12-18 13:50:39 +08:00
|
|
|
|
// ClaudeUsage mirrors the "usage" object returned by the Claude API.
type ClaudeUsage struct {
	// InputTokens is decoded from "input_tokens".
	InputTokens int `json:"input_tokens"`
	// OutputTokens is decoded from "output_tokens".
	OutputTokens int `json:"output_tokens"`
	// CacheCreationInputTokens is decoded from "cache_creation_input_tokens".
	CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
	// CacheReadInputTokens is decoded from "cache_read_input_tokens".
	CacheReadInputTokens int `json:"cache_read_input_tokens"`
}
|
|
|
|
|
|
|
|
|
|
|
|
// ForwardResult 转发结果
|
|
|
|
|
|
type ForwardResult struct {
|
|
|
|
|
|
RequestID string
|
|
|
|
|
|
Usage ClaudeUsage
|
|
|
|
|
|
Model string
|
|
|
|
|
|
Stream bool
|
|
|
|
|
|
Duration time.Duration
|
|
|
|
|
|
FirstTokenMs *int // 首字时间(流式请求)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-27 11:44:00 +08:00
|
|
|
|
// UpstreamFailoverError signals an upstream failure that should make the
// caller switch to another account.
type UpstreamFailoverError struct {
	// StatusCode is the HTTP status returned by the upstream.
	StatusCode int
}

// Error implements the error interface; the message embeds the status code.
func (e *UpstreamFailoverError) Error() string {
	const format = "upstream error: %d (failover)"
	return fmt.Sprintf(format, e.StatusCode)
}
|
|
|
|
|
|
|
2025-12-18 13:50:39 +08:00
|
|
|
|
// GatewayService handles API gateway operations
|
|
|
|
|
|
type GatewayService struct {
|
2025-12-25 17:15:01 +08:00
|
|
|
|
accountRepo AccountRepository
|
2025-12-29 09:44:39 +08:00
|
|
|
|
groupRepo GroupRepository
|
2025-12-25 17:15:01 +08:00
|
|
|
|
usageLogRepo UsageLogRepository
|
|
|
|
|
|
userRepo UserRepository
|
|
|
|
|
|
userSubRepo UserSubscriptionRepository
|
|
|
|
|
|
cache GatewayCache
|
2025-12-18 13:50:39 +08:00
|
|
|
|
cfg *config.Config
|
|
|
|
|
|
billingService *BillingService
|
|
|
|
|
|
rateLimitService *RateLimitService
|
|
|
|
|
|
billingCacheService *BillingCacheService
|
|
|
|
|
|
identityService *IdentityService
|
2025-12-25 17:15:01 +08:00
|
|
|
|
httpUpstream HTTPUpstream
|
2025-12-28 08:07:15 +08:00
|
|
|
|
deferredService *DeferredService
|
2025-12-18 13:50:39 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// NewGatewayService creates a new GatewayService
|
2025-12-19 21:26:19 +08:00
|
|
|
|
func NewGatewayService(
|
2025-12-25 17:15:01 +08:00
|
|
|
|
accountRepo AccountRepository,
|
2025-12-29 09:44:39 +08:00
|
|
|
|
groupRepo GroupRepository,
|
2025-12-25 17:15:01 +08:00
|
|
|
|
usageLogRepo UsageLogRepository,
|
|
|
|
|
|
userRepo UserRepository,
|
|
|
|
|
|
userSubRepo UserSubscriptionRepository,
|
|
|
|
|
|
cache GatewayCache,
|
2025-12-19 21:26:19 +08:00
|
|
|
|
cfg *config.Config,
|
|
|
|
|
|
billingService *BillingService,
|
|
|
|
|
|
rateLimitService *RateLimitService,
|
|
|
|
|
|
billingCacheService *BillingCacheService,
|
|
|
|
|
|
identityService *IdentityService,
|
2025-12-25 17:15:01 +08:00
|
|
|
|
httpUpstream HTTPUpstream,
|
2025-12-28 08:07:15 +08:00
|
|
|
|
deferredService *DeferredService,
|
2025-12-19 21:26:19 +08:00
|
|
|
|
) *GatewayService {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
return &GatewayService{
|
2025-12-19 21:26:19 +08:00
|
|
|
|
accountRepo: accountRepo,
|
2025-12-29 09:44:39 +08:00
|
|
|
|
groupRepo: groupRepo,
|
2025-12-19 21:26:19 +08:00
|
|
|
|
usageLogRepo: usageLogRepo,
|
|
|
|
|
|
userRepo: userRepo,
|
|
|
|
|
|
userSubRepo: userSubRepo,
|
2025-12-19 23:39:28 +08:00
|
|
|
|
cache: cache,
|
2025-12-18 13:50:39 +08:00
|
|
|
|
cfg: cfg,
|
|
|
|
|
|
billingService: billingService,
|
|
|
|
|
|
rateLimitService: rateLimitService,
|
|
|
|
|
|
billingCacheService: billingCacheService,
|
|
|
|
|
|
identityService: identityService,
|
2025-12-22 22:58:31 +08:00
|
|
|
|
httpUpstream: httpUpstream,
|
2025-12-28 08:07:15 +08:00
|
|
|
|
deferredService: deferredService,
|
2025-12-18 13:50:39 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// GenerateSessionHash 从请求体计算粘性会话hash
|
|
|
|
|
|
func (s *GatewayService) GenerateSessionHash(body []byte) string {
|
2025-12-20 16:19:40 +08:00
|
|
|
|
var req map[string]any
|
2025-12-18 13:50:39 +08:00
|
|
|
|
if err := json.Unmarshal(body, &req); err != nil {
|
|
|
|
|
|
return ""
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 1. 最高优先级:从metadata.user_id提取session_xxx
|
2025-12-20 16:19:40 +08:00
|
|
|
|
if metadata, ok := req["metadata"].(map[string]any); ok {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
if userID, ok := metadata["user_id"].(string); ok {
|
|
|
|
|
|
re := regexp.MustCompile(`session_([a-f0-9-]{36})`)
|
|
|
|
|
|
if match := re.FindStringSubmatch(userID); len(match) > 1 {
|
|
|
|
|
|
return match[1]
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 2. 提取带cache_control: {type: "ephemeral"}的内容
|
|
|
|
|
|
cacheableContent := s.extractCacheableContent(req)
|
|
|
|
|
|
if cacheableContent != "" {
|
|
|
|
|
|
return s.hashContent(cacheableContent)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 3. Fallback: 使用system内容
|
|
|
|
|
|
if system := req["system"]; system != nil {
|
|
|
|
|
|
systemText := s.extractTextFromSystem(system)
|
|
|
|
|
|
if systemText != "" {
|
|
|
|
|
|
return s.hashContent(systemText)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 4. 最后fallback: 使用第一条消息
|
2025-12-20 16:19:40 +08:00
|
|
|
|
if messages, ok := req["messages"].([]any); ok && len(messages) > 0 {
|
|
|
|
|
|
if firstMsg, ok := messages[0].(map[string]any); ok {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
msgText := s.extractTextFromContent(firstMsg["content"])
|
|
|
|
|
|
if msgText != "" {
|
|
|
|
|
|
return s.hashContent(msgText)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return ""
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-20 16:19:40 +08:00
|
|
|
|
func (s *GatewayService) extractCacheableContent(req map[string]any) string {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
var content string
|
|
|
|
|
|
|
|
|
|
|
|
// 检查system中的cacheable内容
|
2025-12-20 16:19:40 +08:00
|
|
|
|
if system, ok := req["system"].([]any); ok {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
for _, part := range system {
|
2025-12-20 16:19:40 +08:00
|
|
|
|
if partMap, ok := part.(map[string]any); ok {
|
|
|
|
|
|
if cc, ok := partMap["cache_control"].(map[string]any); ok {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
if cc["type"] == "ephemeral" {
|
|
|
|
|
|
if text, ok := partMap["text"].(string); ok {
|
|
|
|
|
|
content += text
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 检查messages中的cacheable内容
|
2025-12-20 16:19:40 +08:00
|
|
|
|
if messages, ok := req["messages"].([]any); ok {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
for _, msg := range messages {
|
2025-12-20 16:19:40 +08:00
|
|
|
|
if msgMap, ok := msg.(map[string]any); ok {
|
|
|
|
|
|
if msgContent, ok := msgMap["content"].([]any); ok {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
for _, part := range msgContent {
|
2025-12-20 16:19:40 +08:00
|
|
|
|
if partMap, ok := part.(map[string]any); ok {
|
|
|
|
|
|
if cc, ok := partMap["cache_control"].(map[string]any); ok {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
if cc["type"] == "ephemeral" {
|
|
|
|
|
|
// 找到cacheable内容,提取第一条消息的文本
|
|
|
|
|
|
return s.extractTextFromContent(msgMap["content"])
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return content
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-20 16:19:40 +08:00
|
|
|
|
func (s *GatewayService) extractTextFromSystem(system any) string {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
switch v := system.(type) {
|
|
|
|
|
|
case string:
|
|
|
|
|
|
return v
|
2025-12-20 16:19:40 +08:00
|
|
|
|
case []any:
|
2025-12-18 13:50:39 +08:00
|
|
|
|
var texts []string
|
|
|
|
|
|
for _, part := range v {
|
2025-12-20 16:19:40 +08:00
|
|
|
|
if partMap, ok := part.(map[string]any); ok {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
if text, ok := partMap["text"].(string); ok {
|
|
|
|
|
|
texts = append(texts, text)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
return strings.Join(texts, "")
|
|
|
|
|
|
}
|
|
|
|
|
|
return ""
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-20 16:19:40 +08:00
|
|
|
|
func (s *GatewayService) extractTextFromContent(content any) string {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
switch v := content.(type) {
|
|
|
|
|
|
case string:
|
|
|
|
|
|
return v
|
2025-12-20 16:19:40 +08:00
|
|
|
|
case []any:
|
2025-12-18 13:50:39 +08:00
|
|
|
|
var texts []string
|
|
|
|
|
|
for _, part := range v {
|
2025-12-20 16:19:40 +08:00
|
|
|
|
if partMap, ok := part.(map[string]any); ok {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
if partMap["type"] == "text" {
|
|
|
|
|
|
if text, ok := partMap["text"].(string); ok {
|
|
|
|
|
|
texts = append(texts, text)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
return strings.Join(texts, "")
|
|
|
|
|
|
}
|
|
|
|
|
|
return ""
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (s *GatewayService) hashContent(content string) string {
|
|
|
|
|
|
hash := sha256.Sum256([]byte(content))
|
|
|
|
|
|
return hex.EncodeToString(hash[:16]) // 32字符
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// replaceModelInBody 替换请求体中的model字段
|
|
|
|
|
|
func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte {
|
2025-12-20 16:19:40 +08:00
|
|
|
|
var req map[string]any
|
2025-12-18 13:50:39 +08:00
|
|
|
|
if err := json.Unmarshal(body, &req); err != nil {
|
|
|
|
|
|
return body
|
|
|
|
|
|
}
|
|
|
|
|
|
req["model"] = newModel
|
|
|
|
|
|
newBody, err := json.Marshal(req)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return body
|
|
|
|
|
|
}
|
|
|
|
|
|
return newBody
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// SelectAccount 选择账号(粘性会话+优先级)
|
2025-12-26 15:40:24 +08:00
|
|
|
|
func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射)
|
2025-12-26 15:40:24 +08:00
|
|
|
|
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
|
2025-12-27 11:44:00 +08:00
|
|
|
|
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
|
|
|
|
|
|
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
2025-12-29 16:52:55 +08:00
|
|
|
|
// 优先检查 context 中的强制平台(/antigravity 路由)
|
2025-12-29 09:44:39 +08:00
|
|
|
|
var platform string
|
2025-12-29 16:52:55 +08:00
|
|
|
|
forcePlatform, hasForcePlatform := ctx.Value(ctxKeyForcePlatform).(string)
|
|
|
|
|
|
if hasForcePlatform && forcePlatform != "" {
|
|
|
|
|
|
platform = forcePlatform
|
|
|
|
|
|
} else if groupID != nil {
|
|
|
|
|
|
// 根据分组 platform 决定查询哪种账号
|
2025-12-29 09:44:39 +08:00
|
|
|
|
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, fmt.Errorf("get group failed: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
platform = group.Platform
|
|
|
|
|
|
} else {
|
|
|
|
|
|
// 无分组时只使用原生 anthropic 平台
|
|
|
|
|
|
platform = PlatformAnthropic
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
2025-12-29 16:52:55 +08:00
|
|
|
|
// 注意:强制平台模式不走混合调度
|
|
|
|
|
|
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
|
2025-12-29 09:44:39 +08:00
|
|
|
|
return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-29 16:52:55 +08:00
|
|
|
|
// 强制平台模式:优先按分组查找,找不到再查全部该平台账户
|
|
|
|
|
|
if hasForcePlatform && groupID != nil {
|
|
|
|
|
|
account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
|
|
|
|
|
if err == nil {
|
|
|
|
|
|
return account, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
// 分组中找不到,回退查询全部该平台账户
|
|
|
|
|
|
groupID = nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// antigravity 分组、强制平台模式或无分组使用单平台选择
|
2025-12-29 09:44:39 +08:00
|
|
|
|
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
2025-12-28 17:48:52 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-29 09:44:39 +08:00
|
|
|
|
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
|
|
|
|
|
|
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
// 1. 查询粘性会话
|
|
|
|
|
|
if sessionHash != "" {
|
2025-12-19 23:39:28 +08:00
|
|
|
|
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
2025-12-18 13:50:39 +08:00
|
|
|
|
if err == nil && accountID > 0 {
|
2025-12-27 11:44:00 +08:00
|
|
|
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
|
|
|
|
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
2025-12-29 09:44:39 +08:00
|
|
|
|
// 检查账号平台是否匹配(确保粘性会话不会跨平台)
|
|
|
|
|
|
if err == nil && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
2025-12-27 11:44:00 +08:00
|
|
|
|
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
|
|
|
|
|
|
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
|
|
|
|
|
}
|
|
|
|
|
|
return account, nil
|
2025-12-20 15:29:52 +08:00
|
|
|
|
}
|
2025-12-18 13:50:39 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-29 09:44:39 +08:00
|
|
|
|
// 2. 获取可调度账号列表(单平台)
|
|
|
|
|
|
var accounts []Account
|
|
|
|
|
|
var err error
|
2025-12-29 03:17:25 +08:00
|
|
|
|
if s.cfg.RunMode == config.RunModeSimple {
|
|
|
|
|
|
// 简易模式:忽略 groupID,查询所有可用账号
|
2025-12-29 17:04:40 +08:00
|
|
|
|
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
2025-12-29 03:17:25 +08:00
|
|
|
|
} else if groupID != nil {
|
2025-12-29 09:44:39 +08:00
|
|
|
|
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
|
|
|
|
|
|
} else {
|
|
|
|
|
|
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
|
|
|
|
|
}
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, fmt.Errorf("query accounts failed: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 3. 按优先级+最久未用选择(考虑模型支持)
|
|
|
|
|
|
var selected *Account
|
|
|
|
|
|
for i := range accounts {
|
|
|
|
|
|
acc := &accounts[i]
|
|
|
|
|
|
if _, excluded := excludedIDs[acc.ID]; excluded {
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
if selected == nil {
|
|
|
|
|
|
selected = acc
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
if acc.Priority < selected.Priority {
|
|
|
|
|
|
selected = acc
|
|
|
|
|
|
} else if acc.Priority == selected.Priority {
|
|
|
|
|
|
switch {
|
|
|
|
|
|
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
|
|
|
|
|
|
selected = acc
|
|
|
|
|
|
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
|
|
|
|
|
// keep selected (never used is preferred)
|
|
|
|
|
|
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
|
|
|
|
|
// keep selected (both never used)
|
|
|
|
|
|
default:
|
|
|
|
|
|
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
|
|
|
|
|
selected = acc
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if selected == nil {
|
|
|
|
|
|
if requestedModel != "" {
|
|
|
|
|
|
return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel)
|
|
|
|
|
|
}
|
|
|
|
|
|
return nil, errors.New("no available accounts")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 4. 建立粘性绑定
|
|
|
|
|
|
if sessionHash != "" {
|
|
|
|
|
|
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
|
|
|
|
|
|
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return selected, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// selectAccountWithMixedScheduling 选择账户(支持混合调度)
|
|
|
|
|
|
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
|
|
|
|
|
|
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
|
|
|
|
|
|
platforms := []string{nativePlatform, PlatformAntigravity}
|
|
|
|
|
|
|
|
|
|
|
|
// 1. 查询粘性会话
|
|
|
|
|
|
if sessionHash != "" {
|
|
|
|
|
|
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
|
|
|
|
|
if err == nil && accountID > 0 {
|
|
|
|
|
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
|
|
|
|
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
|
|
|
|
|
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
|
|
|
|
|
|
if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
|
|
|
|
|
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
|
|
|
|
|
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
|
|
|
|
|
|
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
|
|
|
|
|
}
|
|
|
|
|
|
return account, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 2. 获取可调度账号列表
|
2025-12-26 15:40:24 +08:00
|
|
|
|
var accounts []Account
|
2025-12-18 13:50:39 +08:00
|
|
|
|
var err error
|
|
|
|
|
|
if groupID != nil {
|
2025-12-28 17:48:52 +08:00
|
|
|
|
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
|
2025-12-18 13:50:39 +08:00
|
|
|
|
} else {
|
2025-12-28 17:48:52 +08:00
|
|
|
|
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
|
2025-12-18 13:50:39 +08:00
|
|
|
|
}
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, fmt.Errorf("query accounts failed: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-29 09:44:39 +08:00
|
|
|
|
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
|
2025-12-26 15:40:24 +08:00
|
|
|
|
var selected *Account
|
2025-12-18 13:50:39 +08:00
|
|
|
|
for i := range accounts {
|
|
|
|
|
|
acc := &accounts[i]
|
2025-12-27 11:44:00 +08:00
|
|
|
|
if _, excluded := excludedIDs[acc.ID]; excluded {
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
2025-12-29 09:44:39 +08:00
|
|
|
|
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
|
|
|
|
|
|
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
2025-12-28 17:48:52 +08:00
|
|
|
|
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
if selected == nil {
|
|
|
|
|
|
selected = acc
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
if acc.Priority < selected.Priority {
|
|
|
|
|
|
selected = acc
|
|
|
|
|
|
} else if acc.Priority == selected.Priority {
|
2025-12-25 21:24:44 -08:00
|
|
|
|
switch {
|
|
|
|
|
|
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
|
2025-12-18 13:50:39 +08:00
|
|
|
|
selected = acc
|
2025-12-25 21:24:44 -08:00
|
|
|
|
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
|
|
|
|
|
// keep selected (never used is preferred)
|
|
|
|
|
|
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
|
|
|
|
|
// keep selected (both never used)
|
|
|
|
|
|
default:
|
|
|
|
|
|
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
|
|
|
|
|
selected = acc
|
|
|
|
|
|
}
|
2025-12-18 13:50:39 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if selected == nil {
|
|
|
|
|
|
if requestedModel != "" {
|
|
|
|
|
|
return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel)
|
|
|
|
|
|
}
|
|
|
|
|
|
return nil, errors.New("no available accounts")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 4. 建立粘性绑定
|
|
|
|
|
|
if sessionHash != "" {
|
2025-12-20 15:29:52 +08:00
|
|
|
|
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
|
|
|
|
|
|
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
|
|
|
|
|
}
|
2025-12-18 13:50:39 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return selected, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-28 17:48:52 +08:00
|
|
|
|
// isModelSupportedByAccount 根据账户平台检查模型支持
|
|
|
|
|
|
func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
|
|
|
|
|
|
if account.Platform == PlatformAntigravity {
|
|
|
|
|
|
// Antigravity 平台使用专门的模型支持检查
|
|
|
|
|
|
return IsAntigravityModelSupported(requestedModel)
|
|
|
|
|
|
}
|
|
|
|
|
|
// 其他平台使用账户的模型支持检查
|
|
|
|
|
|
return account.IsModelSupported(requestedModel)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型
|
|
|
|
|
|
func IsAntigravityModelSupported(requestedModel string) bool {
|
|
|
|
|
|
// 直接支持的模型
|
|
|
|
|
|
if antigravitySupportedModels[requestedModel] {
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|
|
|
|
|
|
// 可映射的模型
|
|
|
|
|
|
if _, ok := antigravityModelMapping[requestedModel]; ok {
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|
|
|
|
|
|
// Gemini 前缀透传
|
|
|
|
|
|
if strings.HasPrefix(requestedModel, "gemini-") {
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|
|
|
|
|
|
// Claude 模型支持(通过默认映射到 claude-sonnet-4-5)
|
|
|
|
|
|
if strings.HasPrefix(requestedModel, "claude-") {
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|
|
|
|
|
|
return false
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-18 13:50:39 +08:00
|
|
|
|
// GetAccessToken 获取账号凭证
|
2025-12-26 15:40:24 +08:00
|
|
|
|
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
switch account.Type {
|
2025-12-26 15:40:24 +08:00
|
|
|
|
case AccountTypeOAuth, AccountTypeSetupToken:
|
2025-12-18 13:50:39 +08:00
|
|
|
|
// Both oauth and setup-token use OAuth token flow
|
|
|
|
|
|
return s.getOAuthToken(ctx, account)
|
2025-12-26 15:40:24 +08:00
|
|
|
|
case AccountTypeApiKey:
|
2025-12-18 13:50:39 +08:00
|
|
|
|
apiKey := account.GetCredential("api_key")
|
|
|
|
|
|
if apiKey == "" {
|
|
|
|
|
|
return "", "", errors.New("api_key not found in credentials")
|
|
|
|
|
|
}
|
|
|
|
|
|
return apiKey, "apikey", nil
|
|
|
|
|
|
default:
|
|
|
|
|
|
return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-26 15:40:24 +08:00
|
|
|
|
func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
accessToken := account.GetCredential("access_token")
|
2025-12-20 13:01:58 +08:00
|
|
|
|
if accessToken == "" {
|
|
|
|
|
|
return "", "", errors.New("access_token not found in credentials")
|
2025-12-18 13:50:39 +08:00
|
|
|
|
}
|
2025-12-20 13:01:58 +08:00
|
|
|
|
// Token刷新由后台 TokenRefreshService 处理,此处只返回当前token
|
2025-12-18 13:50:39 +08:00
|
|
|
|
return accessToken, "oauth", nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-24 16:55:46 +08:00
|
|
|
|
// Retry settings for upstream requests.
const (
	// maxRetries is the maximum number of attempts per forwarded request.
	maxRetries = 10
	// retryDelay is the pause between two attempts.
	retryDelay = 3 * time.Second
)
|
|
|
|
|
|
|
2025-12-26 15:40:24 +08:00
|
|
|
|
func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode int) bool {
|
2025-12-24 16:55:46 +08:00
|
|
|
|
// OAuth/Setup Token 账号:仅 403 重试
|
|
|
|
|
|
if account.IsOAuth() {
|
|
|
|
|
|
return statusCode == 403
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// API Key 账号:未配置的错误码重试
|
|
|
|
|
|
return !account.ShouldHandleErrorCode(statusCode)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-27 11:44:00 +08:00
|
|
|
|
// shouldFailoverUpstreamError determines whether an upstream error should trigger account failover.
|
|
|
|
|
|
func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
|
|
|
|
|
switch statusCode {
|
|
|
|
|
|
case 401, 403, 429, 529:
|
|
|
|
|
|
return true
|
|
|
|
|
|
default:
|
|
|
|
|
|
return statusCode >= 500
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-18 13:50:39 +08:00
|
|
|
|
// Forward 转发请求到Claude API
|
2025-12-26 15:40:24 +08:00
|
|
|
|
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
startTime := time.Now()
|
|
|
|
|
|
|
|
|
|
|
|
// 解析请求获取model和stream
|
|
|
|
|
|
var req struct {
|
|
|
|
|
|
Model string `json:"model"`
|
|
|
|
|
|
Stream bool `json:"stream"`
|
|
|
|
|
|
}
|
|
|
|
|
|
if err := json.Unmarshal(body, &req); err != nil {
|
|
|
|
|
|
return nil, fmt.Errorf("parse request: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-25 14:47:19 +08:00
|
|
|
|
if !gjson.GetBytes(body, "system").Exists() {
|
|
|
|
|
|
body, _ = sjson.SetBytes(body, "system", []any{
|
|
|
|
|
|
map[string]any{
|
|
|
|
|
|
"type": "text",
|
|
|
|
|
|
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
|
|
|
|
|
|
"cache_control": map[string]string{
|
|
|
|
|
|
"type": "ephemeral",
|
|
|
|
|
|
},
|
|
|
|
|
|
},
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-18 13:50:39 +08:00
|
|
|
|
// 应用模型映射(仅对apikey类型账号)
|
|
|
|
|
|
originalModel := req.Model
|
2025-12-26 15:40:24 +08:00
|
|
|
|
if account.Type == AccountTypeApiKey {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
mappedModel := account.GetMappedModel(req.Model)
|
|
|
|
|
|
if mappedModel != req.Model {
|
|
|
|
|
|
// 替换请求体中的模型名
|
|
|
|
|
|
body = s.replaceModelInBody(body, mappedModel)
|
|
|
|
|
|
req.Model = mappedModel
|
|
|
|
|
|
log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 获取凭证
|
|
|
|
|
|
token, tokenType, err := s.GetAccessToken(ctx, account)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-20 11:56:11 +08:00
|
|
|
|
// 获取代理URL
|
|
|
|
|
|
proxyURL := ""
|
|
|
|
|
|
if account.ProxyID != nil && account.Proxy != nil {
|
|
|
|
|
|
proxyURL = account.Proxy.URL()
|
2025-12-18 18:14:20 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-24 16:55:46 +08:00
|
|
|
|
// 重试循环
|
|
|
|
|
|
var resp *http.Response
|
|
|
|
|
|
for attempt := 1; attempt <= maxRetries; attempt++ {
|
|
|
|
|
|
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
|
|
|
|
|
|
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 发送请求
|
|
|
|
|
|
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, fmt.Errorf("upstream request failed: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 检查是否需要重试
|
|
|
|
|
|
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
|
|
|
|
|
|
if attempt < maxRetries {
|
|
|
|
|
|
log.Printf("Account %d: upstream error %d, retry %d/%d after %v",
|
|
|
|
|
|
account.ID, resp.StatusCode, attempt, maxRetries, retryDelay)
|
|
|
|
|
|
_ = resp.Body.Close()
|
|
|
|
|
|
time.Sleep(retryDelay)
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
// 最后一次尝试也失败,跳出循环处理重试耗尽
|
|
|
|
|
|
break
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 不需要重试(成功或不可重试的错误),跳出循环
|
|
|
|
|
|
break
|
2025-12-18 13:50:39 +08:00
|
|
|
|
}
|
2025-12-20 15:29:52 +08:00
|
|
|
|
defer func() { _ = resp.Body.Close() }()
|
2025-12-18 13:50:39 +08:00
|
|
|
|
|
2025-12-24 16:55:46 +08:00
|
|
|
|
// 处理重试耗尽的情况
|
|
|
|
|
|
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
|
2025-12-27 11:44:00 +08:00
|
|
|
|
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
|
|
|
|
|
s.handleRetryExhaustedSideEffects(ctx, resp, account)
|
|
|
|
|
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
|
|
|
|
|
}
|
2025-12-24 16:55:46 +08:00
|
|
|
|
return s.handleRetryExhaustedError(ctx, resp, c, account)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-27 11:44:00 +08:00
|
|
|
|
// 处理可切换账号的错误
|
|
|
|
|
|
if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) {
|
|
|
|
|
|
s.handleFailoverSideEffects(ctx, resp, account)
|
|
|
|
|
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-24 16:55:46 +08:00
|
|
|
|
// 处理错误响应(不可重试的错误)
|
2025-12-18 13:50:39 +08:00
|
|
|
|
if resp.StatusCode >= 400 {
|
|
|
|
|
|
return s.handleErrorResponse(ctx, resp, c, account)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 处理正常响应
|
|
|
|
|
|
var usage *ClaudeUsage
|
|
|
|
|
|
var firstTokenMs *int
|
|
|
|
|
|
if req.Stream {
|
|
|
|
|
|
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, req.Model)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
usage = streamResult.usage
|
|
|
|
|
|
firstTokenMs = streamResult.firstTokenMs
|
|
|
|
|
|
} else {
|
|
|
|
|
|
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, req.Model)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return &ForwardResult{
|
|
|
|
|
|
RequestID: resp.Header.Get("x-request-id"),
|
|
|
|
|
|
Usage: *usage,
|
|
|
|
|
|
Model: originalModel, // 使用原始模型用于计费和日志
|
|
|
|
|
|
Stream: req.Stream,
|
|
|
|
|
|
Duration: time.Since(startTime),
|
|
|
|
|
|
FirstTokenMs: firstTokenMs,
|
|
|
|
|
|
}, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-26 15:40:24 +08:00
|
|
|
|
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
// 确定目标URL
|
|
|
|
|
|
targetURL := claudeAPIURL
|
2025-12-26 15:40:24 +08:00
|
|
|
|
if account.Type == AccountTypeApiKey {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
baseURL := account.GetBaseURL()
|
|
|
|
|
|
targetURL = baseURL + "/v1/messages"
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// OAuth账号:应用统一指纹
|
2025-12-25 17:15:01 +08:00
|
|
|
|
var fingerprint *Fingerprint
|
2025-12-18 13:50:39 +08:00
|
|
|
|
if account.IsOAuth() && s.identityService != nil {
|
|
|
|
|
|
// 1. 获取或创建指纹(包含随机生成的ClientID)
|
|
|
|
|
|
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Printf("Warning: failed to get fingerprint for account %d: %v", account.ID, err)
|
|
|
|
|
|
// 失败时降级为透传原始headers
|
|
|
|
|
|
} else {
|
|
|
|
|
|
fingerprint = fp
|
|
|
|
|
|
|
|
|
|
|
|
// 2. 重写metadata.user_id(需要指纹中的ClientID和账号的account_uuid)
|
|
|
|
|
|
accountUUID := account.GetExtraString("account_uuid")
|
|
|
|
|
|
if accountUUID != "" && fp.ClientID != "" {
|
|
|
|
|
|
if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
|
|
|
|
|
|
body = newBody
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 设置认证头
|
|
|
|
|
|
if tokenType == "oauth" {
|
2025-12-22 22:58:31 +08:00
|
|
|
|
req.Header.Set("authorization", "Bearer "+token)
|
2025-12-18 13:50:39 +08:00
|
|
|
|
} else {
|
|
|
|
|
|
req.Header.Set("x-api-key", token)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 白名单透传headers
|
|
|
|
|
|
for key, values := range c.Request.Header {
|
|
|
|
|
|
lowerKey := strings.ToLower(key)
|
|
|
|
|
|
if allowedHeaders[lowerKey] {
|
|
|
|
|
|
for _, v := range values {
|
|
|
|
|
|
req.Header.Add(key, v)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// OAuth账号:应用缓存的指纹到请求头(覆盖白名单透传的头)
|
|
|
|
|
|
if fingerprint != nil {
|
|
|
|
|
|
s.identityService.ApplyFingerprint(req, fingerprint)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 确保必要的headers存在
|
2025-12-22 22:58:31 +08:00
|
|
|
|
if req.Header.Get("content-type") == "" {
|
|
|
|
|
|
req.Header.Set("content-type", "application/json")
|
2025-12-18 13:50:39 +08:00
|
|
|
|
}
|
|
|
|
|
|
if req.Header.Get("anthropic-version") == "" {
|
|
|
|
|
|
req.Header.Set("anthropic-version", "2023-06-01")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 处理anthropic-beta header(OAuth账号需要特殊处理)
|
|
|
|
|
|
if tokenType == "oauth" {
|
|
|
|
|
|
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-20 11:56:11 +08:00
|
|
|
|
return req, nil
|
2025-12-18 13:50:39 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// getBetaHeader 处理anthropic-beta header
|
|
|
|
|
|
// 对于OAuth账号,需要确保包含oauth-2025-04-20
|
|
|
|
|
|
func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) string {
|
|
|
|
|
|
// 如果客户端传了anthropic-beta
|
|
|
|
|
|
if clientBetaHeader != "" {
|
|
|
|
|
|
// 已包含oauth beta则直接返回
|
2025-12-19 15:22:52 +08:00
|
|
|
|
if strings.Contains(clientBetaHeader, claude.BetaOAuth) {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
return clientBetaHeader
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 需要添加oauth beta
|
|
|
|
|
|
parts := strings.Split(clientBetaHeader, ",")
|
|
|
|
|
|
for i, p := range parts {
|
|
|
|
|
|
parts[i] = strings.TrimSpace(p)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 在claude-code-20250219后面插入oauth beta
|
|
|
|
|
|
claudeCodeIdx := -1
|
|
|
|
|
|
for i, p := range parts {
|
2025-12-19 15:22:52 +08:00
|
|
|
|
if p == claude.BetaClaudeCode {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
claudeCodeIdx = i
|
|
|
|
|
|
break
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if claudeCodeIdx >= 0 {
|
|
|
|
|
|
// 在claude-code后面插入
|
|
|
|
|
|
newParts := make([]string, 0, len(parts)+1)
|
|
|
|
|
|
newParts = append(newParts, parts[:claudeCodeIdx+1]...)
|
2025-12-19 15:22:52 +08:00
|
|
|
|
newParts = append(newParts, claude.BetaOAuth)
|
2025-12-18 13:50:39 +08:00
|
|
|
|
newParts = append(newParts, parts[claudeCodeIdx+1:]...)
|
|
|
|
|
|
return strings.Join(newParts, ",")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 没有claude-code,放在第一位
|
2025-12-19 15:22:52 +08:00
|
|
|
|
return claude.BetaOAuth + "," + clientBetaHeader
|
2025-12-18 13:50:39 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 客户端没传,根据模型生成
|
|
|
|
|
|
var modelID string
|
2025-12-20 16:19:40 +08:00
|
|
|
|
var reqMap map[string]any
|
2025-12-18 13:50:39 +08:00
|
|
|
|
if json.Unmarshal(body, &reqMap) == nil {
|
|
|
|
|
|
if m, ok := reqMap["model"].(string); ok {
|
|
|
|
|
|
modelID = m
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// haiku模型不需要claude-code beta
|
|
|
|
|
|
if strings.Contains(strings.ToLower(modelID), "haiku") {
|
2025-12-19 15:22:52 +08:00
|
|
|
|
return claude.HaikuBetaHeader
|
2025-12-18 13:50:39 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-19 15:22:52 +08:00
|
|
|
|
return claude.DefaultBetaHeader
|
2025-12-18 13:50:39 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-26 15:40:24 +08:00
|
|
|
|
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
|
|
|
|
|
|
|
|
|
|
// 处理上游错误,标记账号状态
|
|
|
|
|
|
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
|
|
|
|
|
|
|
|
|
|
|
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
|
|
|
|
|
|
var errType, errMsg string
|
|
|
|
|
|
var statusCode int
|
|
|
|
|
|
|
|
|
|
|
|
switch resp.StatusCode {
|
2025-12-25 14:47:19 +08:00
|
|
|
|
case 400:
|
|
|
|
|
|
c.Data(http.StatusBadRequest, "application/json", body)
|
|
|
|
|
|
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
2025-12-18 13:50:39 +08:00
|
|
|
|
case 401:
|
|
|
|
|
|
statusCode = http.StatusBadGateway
|
|
|
|
|
|
errType = "upstream_error"
|
|
|
|
|
|
errMsg = "Upstream authentication failed, please contact administrator"
|
|
|
|
|
|
case 403:
|
|
|
|
|
|
statusCode = http.StatusBadGateway
|
|
|
|
|
|
errType = "upstream_error"
|
|
|
|
|
|
errMsg = "Upstream access forbidden, please contact administrator"
|
|
|
|
|
|
case 429:
|
|
|
|
|
|
statusCode = http.StatusTooManyRequests
|
|
|
|
|
|
errType = "rate_limit_error"
|
|
|
|
|
|
errMsg = "Upstream rate limit exceeded, please retry later"
|
|
|
|
|
|
case 529:
|
|
|
|
|
|
statusCode = http.StatusServiceUnavailable
|
|
|
|
|
|
errType = "overloaded_error"
|
|
|
|
|
|
errMsg = "Upstream service overloaded, please retry later"
|
|
|
|
|
|
case 500, 502, 503, 504:
|
|
|
|
|
|
statusCode = http.StatusBadGateway
|
|
|
|
|
|
errType = "upstream_error"
|
|
|
|
|
|
errMsg = "Upstream service temporarily unavailable"
|
|
|
|
|
|
default:
|
|
|
|
|
|
statusCode = http.StatusBadGateway
|
|
|
|
|
|
errType = "upstream_error"
|
|
|
|
|
|
errMsg = "Upstream request failed"
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 返回自定义错误响应
|
|
|
|
|
|
c.JSON(statusCode, gin.H{
|
|
|
|
|
|
"type": "error",
|
|
|
|
|
|
"error": gin.H{
|
|
|
|
|
|
"type": errType,
|
|
|
|
|
|
"message": errMsg,
|
|
|
|
|
|
},
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-27 11:44:00 +08:00
|
|
|
|
func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
2025-12-24 16:55:46 +08:00
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
|
|
|
|
statusCode := resp.StatusCode
|
|
|
|
|
|
|
|
|
|
|
|
// OAuth/Setup Token 账号的 403:标记账号异常
|
|
|
|
|
|
if account.IsOAuth() && statusCode == 403 {
|
|
|
|
|
|
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body)
|
|
|
|
|
|
log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetries, statusCode)
|
|
|
|
|
|
} else {
|
|
|
|
|
|
// API Key 未配置错误码:不标记账号状态
|
|
|
|
|
|
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries)
|
|
|
|
|
|
}
|
2025-12-27 11:44:00 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
|
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
|
|
|
|
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// handleRetryExhaustedError 处理重试耗尽后的错误
|
|
|
|
|
|
// OAuth 403:标记账号异常
|
|
|
|
|
|
// API Key 未配置错误码:仅返回错误,不标记账号
|
|
|
|
|
|
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
|
|
|
|
|
|
s.handleRetryExhaustedSideEffects(ctx, resp, account)
|
2025-12-24 16:55:46 +08:00
|
|
|
|
|
|
|
|
|
|
// 返回统一的重试耗尽错误响应
|
|
|
|
|
|
c.JSON(http.StatusBadGateway, gin.H{
|
|
|
|
|
|
"type": "error",
|
|
|
|
|
|
"error": gin.H{
|
|
|
|
|
|
"type": "upstream_error",
|
|
|
|
|
|
"message": "Upstream request failed after retries",
|
|
|
|
|
|
},
|
|
|
|
|
|
})
|
|
|
|
|
|
|
2025-12-27 11:44:00 +08:00
|
|
|
|
return nil, fmt.Errorf("upstream error: %d (retries exhausted)", resp.StatusCode)
|
2025-12-24 16:55:46 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-18 13:50:39 +08:00
|
|
|
|
// streamingResult 流式响应结果
|
|
|
|
|
|
type streamingResult struct {
|
|
|
|
|
|
usage *ClaudeUsage
|
|
|
|
|
|
firstTokenMs *int
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-26 15:40:24 +08:00
|
|
|
|
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
// 更新5h窗口状态
|
|
|
|
|
|
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
|
|
|
|
|
|
|
|
|
|
|
// 设置SSE响应头
|
|
|
|
|
|
c.Header("Content-Type", "text/event-stream")
|
|
|
|
|
|
c.Header("Cache-Control", "no-cache")
|
|
|
|
|
|
c.Header("Connection", "keep-alive")
|
|
|
|
|
|
c.Header("X-Accel-Buffering", "no")
|
|
|
|
|
|
|
|
|
|
|
|
// 透传其他响应头
|
|
|
|
|
|
if v := resp.Header.Get("x-request-id"); v != "" {
|
|
|
|
|
|
c.Header("x-request-id", v)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
w := c.Writer
|
|
|
|
|
|
flusher, ok := w.(http.Flusher)
|
|
|
|
|
|
if !ok {
|
|
|
|
|
|
return nil, errors.New("streaming not supported")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
usage := &ClaudeUsage{}
|
|
|
|
|
|
var firstTokenMs *int
|
|
|
|
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
|
|
|
|
// 设置更大的buffer以处理长行
|
|
|
|
|
|
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
|
|
|
|
|
|
|
|
|
|
|
|
needModelReplace := originalModel != mappedModel
|
|
|
|
|
|
|
|
|
|
|
|
for scanner.Scan() {
|
|
|
|
|
|
line := scanner.Text()
|
|
|
|
|
|
|
2025-12-26 03:49:55 -08:00
|
|
|
|
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
|
|
|
|
|
if sseDataRe.MatchString(line) {
|
|
|
|
|
|
data := sseDataRe.ReplaceAllString(line, "")
|
2025-12-18 13:50:39 +08:00
|
|
|
|
|
2025-12-26 03:49:55 -08:00
|
|
|
|
// 如果有模型映射,替换响应中的model字段
|
|
|
|
|
|
if needModelReplace {
|
|
|
|
|
|
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 转发行
|
|
|
|
|
|
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
|
|
|
|
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
|
|
|
|
|
}
|
|
|
|
|
|
flusher.Flush()
|
2025-12-18 13:50:39 +08:00
|
|
|
|
|
|
|
|
|
|
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
|
|
|
|
|
|
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
|
|
|
|
|
ms := int(time.Since(startTime).Milliseconds())
|
|
|
|
|
|
firstTokenMs = &ms
|
|
|
|
|
|
}
|
|
|
|
|
|
s.parseSSEUsage(data, usage)
|
2025-12-26 03:49:55 -08:00
|
|
|
|
} else {
|
|
|
|
|
|
// 非 data 行直接转发
|
|
|
|
|
|
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
|
|
|
|
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
|
|
|
|
|
}
|
|
|
|
|
|
flusher.Flush()
|
2025-12-18 13:50:39 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
|
|
|
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// replaceModelInSSELine 替换SSE数据行中的model字段
|
|
|
|
|
|
func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
|
2025-12-26 03:49:55 -08:00
|
|
|
|
if !sseDataRe.MatchString(line) {
|
|
|
|
|
|
return line
|
|
|
|
|
|
}
|
|
|
|
|
|
data := sseDataRe.ReplaceAllString(line, "")
|
2025-12-18 13:50:39 +08:00
|
|
|
|
if data == "" || data == "[DONE]" {
|
|
|
|
|
|
return line
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-20 16:19:40 +08:00
|
|
|
|
var event map[string]any
|
2025-12-18 13:50:39 +08:00
|
|
|
|
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
|
|
|
|
|
return line
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 只替换 message_start 事件中的 message.model
|
|
|
|
|
|
if event["type"] != "message_start" {
|
|
|
|
|
|
return line
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-20 16:19:40 +08:00
|
|
|
|
msg, ok := event["message"].(map[string]any)
|
2025-12-18 13:50:39 +08:00
|
|
|
|
if !ok {
|
|
|
|
|
|
return line
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
model, ok := msg["model"].(string)
|
|
|
|
|
|
if !ok || model != fromModel {
|
|
|
|
|
|
return line
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
msg["model"] = toModel
|
|
|
|
|
|
newData, err := json.Marshal(event)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return line
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return "data: " + string(newData)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
2025-12-23 16:53:53 +08:00
|
|
|
|
// 解析message_start获取input tokens(标准Claude API格式)
|
2025-12-18 13:50:39 +08:00
|
|
|
|
var msgStart struct {
|
|
|
|
|
|
Type string `json:"type"`
|
|
|
|
|
|
Message struct {
|
|
|
|
|
|
Usage ClaudeUsage `json:"usage"`
|
|
|
|
|
|
} `json:"message"`
|
|
|
|
|
|
}
|
|
|
|
|
|
if json.Unmarshal([]byte(data), &msgStart) == nil && msgStart.Type == "message_start" {
|
|
|
|
|
|
usage.InputTokens = msgStart.Message.Usage.InputTokens
|
|
|
|
|
|
usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens
|
|
|
|
|
|
usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-23 16:53:53 +08:00
|
|
|
|
// 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API)
|
2025-12-18 13:50:39 +08:00
|
|
|
|
var msgDelta struct {
|
|
|
|
|
|
Type string `json:"type"`
|
|
|
|
|
|
Usage struct {
|
2025-12-23 16:53:53 +08:00
|
|
|
|
InputTokens int `json:"input_tokens"`
|
|
|
|
|
|
OutputTokens int `json:"output_tokens"`
|
|
|
|
|
|
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
|
|
|
|
|
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
2025-12-18 13:50:39 +08:00
|
|
|
|
} `json:"usage"`
|
|
|
|
|
|
}
|
|
|
|
|
|
if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" {
|
2025-12-23 16:53:53 +08:00
|
|
|
|
// output_tokens 总是从 message_delta 获取
|
2025-12-18 13:50:39 +08:00
|
|
|
|
usage.OutputTokens = msgDelta.Usage.OutputTokens
|
2025-12-23 16:53:53 +08:00
|
|
|
|
|
|
|
|
|
|
// 如果 message_start 中没有值,则从 message_delta 获取(兼容GLM等API)
|
|
|
|
|
|
if usage.InputTokens == 0 {
|
|
|
|
|
|
usage.InputTokens = msgDelta.Usage.InputTokens
|
|
|
|
|
|
}
|
|
|
|
|
|
if usage.CacheCreationInputTokens == 0 {
|
|
|
|
|
|
usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens
|
|
|
|
|
|
}
|
|
|
|
|
|
if usage.CacheReadInputTokens == 0 {
|
|
|
|
|
|
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
|
|
|
|
|
|
}
|
2025-12-18 13:50:39 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-26 15:40:24 +08:00
|
|
|
|
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
// 更新5h窗口状态
|
|
|
|
|
|
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
|
|
|
|
|
|
|
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 解析usage
|
|
|
|
|
|
var response struct {
|
|
|
|
|
|
Usage ClaudeUsage `json:"usage"`
|
|
|
|
|
|
}
|
|
|
|
|
|
if err := json.Unmarshal(body, &response); err != nil {
|
|
|
|
|
|
return nil, fmt.Errorf("parse response: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 如果有模型映射,替换响应中的model字段
|
|
|
|
|
|
if originalModel != mappedModel {
|
|
|
|
|
|
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 透传响应头
|
|
|
|
|
|
for key, values := range resp.Header {
|
|
|
|
|
|
for _, value := range values {
|
|
|
|
|
|
c.Header(key, value)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 写入响应
|
|
|
|
|
|
c.Data(resp.StatusCode, "application/json", body)
|
|
|
|
|
|
|
|
|
|
|
|
return &response.Usage, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// replaceModelInResponseBody 替换响应体中的model字段
|
|
|
|
|
|
func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
|
2025-12-20 16:19:40 +08:00
|
|
|
|
var resp map[string]any
|
2025-12-18 13:50:39 +08:00
|
|
|
|
if err := json.Unmarshal(body, &resp); err != nil {
|
|
|
|
|
|
return body
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
model, ok := resp["model"].(string)
|
|
|
|
|
|
if !ok || model != fromModel {
|
|
|
|
|
|
return body
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
resp["model"] = toModel
|
|
|
|
|
|
newBody, err := json.Marshal(resp)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return body
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return newBody
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// RecordUsageInput 记录使用量的输入参数
|
|
|
|
|
|
type RecordUsageInput struct {
|
|
|
|
|
|
Result *ForwardResult
|
2025-12-26 15:40:24 +08:00
|
|
|
|
ApiKey *ApiKey
|
|
|
|
|
|
User *User
|
|
|
|
|
|
Account *Account
|
|
|
|
|
|
Subscription *UserSubscription // 可选:订阅信息
|
2025-12-18 13:50:39 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
|
|
|
|
|
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
|
|
|
|
|
|
result := input.Result
|
|
|
|
|
|
apiKey := input.ApiKey
|
|
|
|
|
|
user := input.User
|
|
|
|
|
|
account := input.Account
|
|
|
|
|
|
subscription := input.Subscription
|
|
|
|
|
|
|
|
|
|
|
|
// 计算费用
|
|
|
|
|
|
tokens := UsageTokens{
|
|
|
|
|
|
InputTokens: result.Usage.InputTokens,
|
|
|
|
|
|
OutputTokens: result.Usage.OutputTokens,
|
|
|
|
|
|
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
|
|
|
|
|
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 获取费率倍数
|
|
|
|
|
|
multiplier := s.cfg.Default.RateMultiplier
|
|
|
|
|
|
if apiKey.GroupID != nil && apiKey.Group != nil {
|
|
|
|
|
|
multiplier = apiKey.Group.RateMultiplier
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Printf("Calculate cost failed: %v", err)
|
|
|
|
|
|
// 使用默认费用继续
|
|
|
|
|
|
cost = &CostBreakdown{ActualCost: 0}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 判断计费方式:订阅模式 vs 余额模式
|
|
|
|
|
|
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
2025-12-26 15:40:24 +08:00
|
|
|
|
billingType := BillingTypeBalance
|
2025-12-18 13:50:39 +08:00
|
|
|
|
if isSubscriptionBilling {
|
2025-12-26 15:40:24 +08:00
|
|
|
|
billingType = BillingTypeSubscription
|
2025-12-18 13:50:39 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 创建使用日志
|
|
|
|
|
|
durationMs := int(result.Duration.Milliseconds())
|
2025-12-26 15:40:24 +08:00
|
|
|
|
usageLog := &UsageLog{
|
2025-12-18 13:50:39 +08:00
|
|
|
|
UserID: user.ID,
|
|
|
|
|
|
ApiKeyID: apiKey.ID,
|
|
|
|
|
|
AccountID: account.ID,
|
|
|
|
|
|
RequestID: result.RequestID,
|
|
|
|
|
|
Model: result.Model,
|
|
|
|
|
|
InputTokens: result.Usage.InputTokens,
|
|
|
|
|
|
OutputTokens: result.Usage.OutputTokens,
|
|
|
|
|
|
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
|
|
|
|
|
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
|
|
|
|
|
InputCost: cost.InputCost,
|
|
|
|
|
|
OutputCost: cost.OutputCost,
|
|
|
|
|
|
CacheCreationCost: cost.CacheCreationCost,
|
|
|
|
|
|
CacheReadCost: cost.CacheReadCost,
|
|
|
|
|
|
TotalCost: cost.TotalCost,
|
|
|
|
|
|
ActualCost: cost.ActualCost,
|
|
|
|
|
|
RateMultiplier: multiplier,
|
|
|
|
|
|
BillingType: billingType,
|
|
|
|
|
|
Stream: result.Stream,
|
|
|
|
|
|
DurationMs: &durationMs,
|
|
|
|
|
|
FirstTokenMs: result.FirstTokenMs,
|
|
|
|
|
|
CreatedAt: time.Now(),
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 添加分组和订阅关联
|
|
|
|
|
|
if apiKey.GroupID != nil {
|
|
|
|
|
|
usageLog.GroupID = apiKey.GroupID
|
|
|
|
|
|
}
|
|
|
|
|
|
if subscription != nil {
|
|
|
|
|
|
usageLog.SubscriptionID = &subscription.ID
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-19 21:26:19 +08:00
|
|
|
|
if err := s.usageLogRepo.Create(ctx, usageLog); err != nil {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
log.Printf("Create usage log failed: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-29 03:17:25 +08:00
|
|
|
|
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
|
|
|
|
|
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
|
|
|
|
|
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-18 13:50:39 +08:00
|
|
|
|
// 根据计费类型执行扣费
|
|
|
|
|
|
if isSubscriptionBilling {
|
|
|
|
|
|
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
|
|
|
|
|
|
if cost.TotalCost > 0 {
|
2025-12-19 21:26:19 +08:00
|
|
|
|
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
log.Printf("Increment subscription usage failed: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
// 异步更新订阅缓存
|
|
|
|
|
|
go func() {
|
|
|
|
|
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
|
|
|
|
defer cancel()
|
|
|
|
|
|
if err := s.billingCacheService.UpdateSubscriptionUsage(cacheCtx, user.ID, *apiKey.GroupID, cost.TotalCost); err != nil {
|
|
|
|
|
|
log.Printf("Update subscription cache failed: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
}()
|
|
|
|
|
|
}
|
|
|
|
|
|
} else {
|
|
|
|
|
|
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
|
|
|
|
|
|
if cost.ActualCost > 0 {
|
2025-12-19 21:26:19 +08:00
|
|
|
|
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
2025-12-18 13:50:39 +08:00
|
|
|
|
log.Printf("Deduct balance failed: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
// 异步更新余额缓存
|
|
|
|
|
|
go func() {
|
|
|
|
|
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
|
|
|
|
defer cancel()
|
|
|
|
|
|
if err := s.billingCacheService.DeductBalanceCache(cacheCtx, user.ID, cost.ActualCost); err != nil {
|
|
|
|
|
|
log.Printf("Update balance cache failed: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
}()
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-28 08:07:15 +08:00
|
|
|
|
// Schedule batch update for account last_used_at
|
|
|
|
|
|
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
2025-12-18 13:50:39 +08:00
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
2025-12-19 11:12:41 +08:00
|
|
|
|
|
|
|
|
|
|
// ForwardCountTokens 转发 count_tokens 请求到上游 API
|
|
|
|
|
|
// 特点:不记录使用量、仅支持非流式响应
|
2025-12-26 15:40:24 +08:00
|
|
|
|
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, body []byte) error {
|
2025-12-28 21:56:52 +08:00
|
|
|
|
// Antigravity 账户不支持 count_tokens 转发,返回估算值
|
|
|
|
|
|
// 参考 Antigravity-Manager 和 proxycast 实现
|
|
|
|
|
|
if account.Platform == PlatformAntigravity {
|
|
|
|
|
|
c.JSON(http.StatusOK, gin.H{"input_tokens": 100})
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-19 11:12:41 +08:00
|
|
|
|
// 应用模型映射(仅对 apikey 类型账号)
|
2025-12-26 15:40:24 +08:00
|
|
|
|
if account.Type == AccountTypeApiKey {
|
2025-12-19 11:12:41 +08:00
|
|
|
|
var req struct {
|
|
|
|
|
|
Model string `json:"model"`
|
|
|
|
|
|
}
|
|
|
|
|
|
if err := json.Unmarshal(body, &req); err == nil && req.Model != "" {
|
|
|
|
|
|
mappedModel := account.GetMappedModel(req.Model)
|
|
|
|
|
|
if mappedModel != req.Model {
|
|
|
|
|
|
body = s.replaceModelInBody(body, mappedModel)
|
|
|
|
|
|
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", req.Model, mappedModel, account.Name)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 获取凭证
|
|
|
|
|
|
token, tokenType, err := s.GetAccessToken(ctx, account)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to get access token")
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 构建上游请求
|
2025-12-20 11:56:11 +08:00
|
|
|
|
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
|
2025-12-19 11:12:41 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-20 11:56:11 +08:00
|
|
|
|
// 获取代理URL
|
|
|
|
|
|
proxyURL := ""
|
|
|
|
|
|
if account.ProxyID != nil && account.Proxy != nil {
|
|
|
|
|
|
proxyURL = account.Proxy.URL()
|
2025-12-19 11:12:41 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 发送请求
|
2025-12-22 22:58:31 +08:00
|
|
|
|
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL)
|
2025-12-19 11:12:41 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
|
|
|
|
|
return fmt.Errorf("upstream request failed: %w", err)
|
|
|
|
|
|
}
|
2025-12-20 15:29:52 +08:00
|
|
|
|
defer func() {
|
|
|
|
|
|
_ = resp.Body.Close()
|
|
|
|
|
|
}()
|
2025-12-19 11:12:41 +08:00
|
|
|
|
|
|
|
|
|
|
// 读取响应体
|
|
|
|
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 处理错误响应
|
|
|
|
|
|
if resp.StatusCode >= 400 {
|
|
|
|
|
|
// 标记账号状态(429/529等)
|
|
|
|
|
|
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
|
|
|
|
|
|
|
|
|
|
|
// 返回简化的错误响应
|
|
|
|
|
|
errMsg := "Upstream request failed"
|
|
|
|
|
|
switch resp.StatusCode {
|
|
|
|
|
|
case 429:
|
|
|
|
|
|
errMsg = "Rate limit exceeded"
|
|
|
|
|
|
case 529:
|
|
|
|
|
|
errMsg = "Service overloaded"
|
|
|
|
|
|
}
|
|
|
|
|
|
s.countTokensError(c, resp.StatusCode, "upstream_error", errMsg)
|
|
|
|
|
|
return fmt.Errorf("upstream error: %d", resp.StatusCode)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 透传成功响应
|
|
|
|
|
|
c.Data(resp.StatusCode, "application/json", respBody)
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// buildCountTokensRequest 构建 count_tokens 上游请求
|
2025-12-26 15:40:24 +08:00
|
|
|
|
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
|
2025-12-19 11:12:41 +08:00
|
|
|
|
// 确定目标 URL
|
|
|
|
|
|
targetURL := claudeAPICountTokensURL
|
2025-12-26 15:40:24 +08:00
|
|
|
|
if account.Type == AccountTypeApiKey {
|
2025-12-19 11:12:41 +08:00
|
|
|
|
baseURL := account.GetBaseURL()
|
|
|
|
|
|
targetURL = baseURL + "/v1/messages/count_tokens"
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// OAuth 账号:应用统一指纹和重写 userID
|
|
|
|
|
|
if account.IsOAuth() && s.identityService != nil {
|
|
|
|
|
|
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
|
|
|
|
|
if err == nil {
|
|
|
|
|
|
accountUUID := account.GetExtraString("account_uuid")
|
|
|
|
|
|
if accountUUID != "" && fp.ClientID != "" {
|
|
|
|
|
|
if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
|
|
|
|
|
|
body = newBody
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 设置认证头
|
|
|
|
|
|
if tokenType == "oauth" {
|
2025-12-22 22:58:31 +08:00
|
|
|
|
req.Header.Set("authorization", "Bearer "+token)
|
2025-12-19 11:12:41 +08:00
|
|
|
|
} else {
|
|
|
|
|
|
req.Header.Set("x-api-key", token)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 白名单透传 headers
|
|
|
|
|
|
for key, values := range c.Request.Header {
|
|
|
|
|
|
lowerKey := strings.ToLower(key)
|
|
|
|
|
|
if allowedHeaders[lowerKey] {
|
|
|
|
|
|
for _, v := range values {
|
|
|
|
|
|
req.Header.Add(key, v)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// OAuth 账号:应用指纹到请求头
|
|
|
|
|
|
if account.IsOAuth() && s.identityService != nil {
|
|
|
|
|
|
fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
|
|
|
|
|
if fp != nil {
|
|
|
|
|
|
s.identityService.ApplyFingerprint(req, fp)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 确保必要的 headers 存在
|
2025-12-22 22:58:31 +08:00
|
|
|
|
if req.Header.Get("content-type") == "" {
|
|
|
|
|
|
req.Header.Set("content-type", "application/json")
|
2025-12-19 11:12:41 +08:00
|
|
|
|
}
|
|
|
|
|
|
if req.Header.Get("anthropic-version") == "" {
|
|
|
|
|
|
req.Header.Set("anthropic-version", "2023-06-01")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// OAuth 账号:处理 anthropic-beta header
|
|
|
|
|
|
if tokenType == "oauth" {
|
|
|
|
|
|
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-20 11:56:11 +08:00
|
|
|
|
return req, nil
|
2025-12-19 11:12:41 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// countTokensError 返回 count_tokens 错误响应
|
|
|
|
|
|
func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, message string) {
|
|
|
|
|
|
c.JSON(status, gin.H{
|
|
|
|
|
|
"type": "error",
|
|
|
|
|
|
"error": gin.H{
|
|
|
|
|
|
"type": errType,
|
|
|
|
|
|
"message": message,
|
|
|
|
|
|
},
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|