mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-25 00:54:45 +08:00
Merge pull request #975 from Ylarod/aws-bedrock
sub2api: add bedrock support
This commit is contained in:
@@ -3370,6 +3370,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
|
||||
if account.Platform == PlatformSora {
|
||||
return s.isSoraModelSupportedByAccount(account, requestedModel)
|
||||
}
|
||||
if account.IsBedrock() {
|
||||
_, ok := ResolveBedrockModelID(account, requestedModel)
|
||||
return ok
|
||||
}
|
||||
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
|
||||
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
requestedModel = claude.NormalizeModelID(requestedModel)
|
||||
@@ -3527,6 +3531,10 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
|
||||
return "", "", errors.New("api_key not found in credentials")
|
||||
}
|
||||
return apiKey, "apikey", nil
|
||||
case AccountTypeBedrock:
|
||||
return "", "bedrock", nil // Bedrock 使用 SigV4 签名,不需要 token
|
||||
case AccountTypeBedrockAPIKey:
|
||||
return "", "bedrock-apikey", nil // Bedrock API Key 使用 Bearer Token,由 forwardBedrock 处理
|
||||
default:
|
||||
return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
|
||||
}
|
||||
@@ -3982,6 +3990,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime)
|
||||
}
|
||||
|
||||
if account != nil && account.IsBedrock() {
|
||||
return s.forwardBedrock(ctx, c, account, parsed, startTime)
|
||||
}
|
||||
|
||||
// Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest.
|
||||
// Always overwrite the cache to prevent stale values from a previous retry with a different account.
|
||||
if account.Platform == PlatformAnthropic && c != nil {
|
||||
@@ -5123,6 +5135,366 @@ func writeAnthropicPassthroughResponseHeaders(dst http.Header, src http.Header,
|
||||
}
|
||||
}
|
||||
|
||||
// forwardBedrock 转发请求到 AWS Bedrock
|
||||
func (s *GatewayService) forwardBedrock(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
parsed *ParsedRequest,
|
||||
startTime time.Time,
|
||||
) (*ForwardResult, error) {
|
||||
reqModel := parsed.Model
|
||||
reqStream := parsed.Stream
|
||||
body := parsed.Body
|
||||
|
||||
region := bedrockRuntimeRegion(account)
|
||||
mappedModel, ok := ResolveBedrockModelID(account, reqModel)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported bedrock model: %s", reqModel)
|
||||
}
|
||||
if mappedModel != reqModel {
|
||||
logger.LegacyPrintf("service.gateway", "[Bedrock] Model mapping: %s -> %s (account: %s)", reqModel, mappedModel, account.Name)
|
||||
}
|
||||
|
||||
betaHeader := ""
|
||||
if c != nil && c.Request != nil {
|
||||
betaHeader = c.GetHeader("anthropic-beta")
|
||||
}
|
||||
|
||||
// 准备请求体(注入 anthropic_version/anthropic_beta,移除 Bedrock 不支持的字段,清理 cache_control)
|
||||
betaTokens, err := s.resolveBedrockBetaTokensForRequest(ctx, account, betaHeader, body, mappedModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bedrockBody, err := PrepareBedrockRequestBodyWithTokens(body, mappedModel, betaTokens)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("prepare bedrock request body: %w", err)
|
||||
}
|
||||
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("service.gateway", "[Bedrock] 命中 Bedrock 分支: account=%d name=%s model=%s->%s stream=%v",
|
||||
account.ID, account.Name, reqModel, mappedModel, reqStream)
|
||||
|
||||
// 根据账号类型选择认证方式
|
||||
var signer *BedrockSigner
|
||||
var bedrockAPIKey string
|
||||
if account.IsBedrockAPIKey() {
|
||||
bedrockAPIKey = account.GetCredential("api_key")
|
||||
if bedrockAPIKey == "" {
|
||||
return nil, fmt.Errorf("api_key not found in bedrock-apikey credentials")
|
||||
}
|
||||
} else {
|
||||
signer, err = NewBedrockSignerFromAccount(account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create bedrock signer: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 执行上游请求(含重试)
|
||||
resp, err := s.executeBedrockUpstream(ctx, c, account, bedrockBody, mappedModel, region, reqStream, signer, bedrockAPIKey, proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 将 Bedrock 的 x-amzn-requestid 映射到 x-request-id,
|
||||
// 使通用错误处理函数(handleErrorResponse、handleRetryExhaustedError)能正确提取 AWS request ID。
|
||||
if awsReqID := resp.Header.Get("x-amzn-requestid"); awsReqID != "" && resp.Header.Get("x-request-id") == "" {
|
||||
resp.Header.Set("x-request-id", awsReqID)
|
||||
}
|
||||
|
||||
// 错误/failover 处理
|
||||
if resp.StatusCode >= 400 {
|
||||
return s.handleBedrockUpstreamErrors(ctx, resp, c, account)
|
||||
}
|
||||
|
||||
// 响应处理
|
||||
var usage *ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
var clientDisconnect bool
|
||||
if reqStream {
|
||||
streamResult, err := s.handleBedrockStreamingResponse(ctx, resp, c, account, startTime, reqModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usage = streamResult.usage
|
||||
firstTokenMs = streamResult.firstTokenMs
|
||||
clientDisconnect = streamResult.clientDisconnect
|
||||
} else {
|
||||
usage, err = s.handleBedrockNonStreamingResponse(ctx, resp, c, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if usage == nil {
|
||||
usage = &ClaudeUsage{}
|
||||
}
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: resp.Header.Get("x-amzn-requestid"),
|
||||
Usage: *usage,
|
||||
Model: reqModel,
|
||||
Stream: reqStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
ClientDisconnect: clientDisconnect,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// executeBedrockUpstream 执行 Bedrock 上游请求(含重试逻辑)
|
||||
func (s *GatewayService) executeBedrockUpstream(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
body []byte,
|
||||
modelID string,
|
||||
region string,
|
||||
stream bool,
|
||||
signer *BedrockSigner,
|
||||
apiKey string,
|
||||
proxyURL string,
|
||||
) (*http.Response, error) {
|
||||
var resp *http.Response
|
||||
var err error
|
||||
retryStart := time.Now()
|
||||
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
|
||||
var upstreamReq *http.Request
|
||||
if account.IsBedrockAPIKey() {
|
||||
upstreamReq, err = s.buildUpstreamRequestBedrockAPIKey(ctx, body, modelID, region, stream, apiKey)
|
||||
} else {
|
||||
upstreamReq, err = s.buildUpstreamRequestBedrock(ctx, body, modelID, region, stream, signer)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, false)
|
||||
if err != nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream request failed",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
|
||||
if attempt < maxRetryAttempts {
|
||||
elapsed := time.Since(retryStart)
|
||||
if elapsed >= maxRetryElapsed {
|
||||
break
|
||||
}
|
||||
|
||||
delay := retryBackoffDelay(attempt)
|
||||
remaining := maxRetryElapsed - elapsed
|
||||
if delay > remaining {
|
||||
delay = remaining
|
||||
}
|
||||
if delay <= 0 {
|
||||
break
|
||||
}
|
||||
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
Kind: "retry",
|
||||
Message: extractUpstreamErrorMessage(respBody),
|
||||
Detail: func() string {
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
logger.LegacyPrintf("service.gateway", "[Bedrock] account %d: upstream error %d, retry %d/%d after %v",
|
||||
account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay)
|
||||
if err := sleepWithContext(ctx, delay); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
if resp == nil || resp.Body == nil {
|
||||
return nil, errors.New("upstream request failed: empty response")
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// handleBedrockUpstreamErrors 处理 Bedrock 上游 4xx/5xx 错误(failover + 错误响应)
|
||||
func (s *GatewayService) handleBedrockUpstreamErrors(
|
||||
ctx context.Context,
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
) (*ForwardResult, error) {
|
||||
// retry exhausted + failover
|
||||
if s.shouldRetryUpstreamError(account, resp.StatusCode) {
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
logger.LegacyPrintf("service.gateway", "[Bedrock] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d Body=%s",
|
||||
account.ID, account.Name, resp.StatusCode, truncateString(string(respBody), 1000))
|
||||
|
||||
s.handleRetryExhaustedSideEffects(ctx, resp, account)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
Kind: "retry_exhausted_failover",
|
||||
Message: extractUpstreamErrorMessage(respBody),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
}
|
||||
}
|
||||
return s.handleRetryExhaustedError(ctx, resp, c, account)
|
||||
}
|
||||
|
||||
// non-retryable failover
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
Kind: "failover",
|
||||
Message: extractUpstreamErrorMessage(respBody),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
}
|
||||
}
|
||||
|
||||
// other errors
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
}
|
||||
|
||||
// buildUpstreamRequestBedrock 构建 Bedrock 上游请求
|
||||
func (s *GatewayService) buildUpstreamRequestBedrock(
|
||||
ctx context.Context,
|
||||
body []byte,
|
||||
modelID string,
|
||||
region string,
|
||||
stream bool,
|
||||
signer *BedrockSigner,
|
||||
) (*http.Request, error) {
|
||||
targetURL := BuildBedrockURL(region, modelID, stream)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
// SigV4 签名
|
||||
if err := signer.SignRequest(ctx, req, body); err != nil {
|
||||
return nil, fmt.Errorf("sign bedrock request: %w", err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// buildUpstreamRequestBedrockAPIKey 构建 Bedrock API Key (Bearer Token) 上游请求
|
||||
func (s *GatewayService) buildUpstreamRequestBedrockAPIKey(
|
||||
ctx context.Context,
|
||||
body []byte,
|
||||
modelID string,
|
||||
region string,
|
||||
stream bool,
|
||||
apiKey string,
|
||||
) (*http.Request, error) {
|
||||
targetURL := BuildBedrockURL(region, modelID, stream)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// handleBedrockNonStreamingResponse 处理 Bedrock 非流式响应
|
||||
// Bedrock InvokeModel 非流式响应的 body 格式与 Claude API 兼容
|
||||
func (s *GatewayService) handleBedrockNonStreamingResponse(
|
||||
ctx context.Context,
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
) (*ClaudeUsage, error) {
|
||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream response too large",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换 Bedrock 特有的 amazon-bedrock-invocationMetrics 为标准 Anthropic usage 格式
|
||||
// 并移除该字段避免透传给客户端
|
||||
body = transformBedrockInvocationMetrics(body)
|
||||
|
||||
usage := parseClaudeUsageFromResponseBody(body)
|
||||
|
||||
c.Header("Content-Type", "application/json")
|
||||
if v := resp.Header.Get("x-amzn-requestid"); v != "" {
|
||||
c.Header("x-request-id", v)
|
||||
}
|
||||
c.Data(resp.StatusCode, "application/json", body)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) {
|
||||
// 确定目标URL
|
||||
targetURL := claudeAPIURL
|
||||
@@ -5536,6 +5908,76 @@ func containsBetaToken(header, token string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// filterBetaTokens returns the tokens that are not members of filterSet,
// preserving their original order.
//
// When either tokens or filterSet is empty the input slice is returned
// unchanged; otherwise a newly allocated slice is returned (non-nil even if
// every token was filtered out).
func filterBetaTokens(tokens []string, filterSet map[string]struct{}) []string {
	// Nothing to filter, or nothing to filter with: hand the input back as-is.
	if len(filterSet) == 0 || len(tokens) == 0 {
		return tokens
	}

	remaining := make([]string, 0, len(tokens))
	for i := range tokens {
		if _, drop := filterSet[tokens[i]]; drop {
			continue
		}
		remaining = append(remaining, tokens[i])
	}
	return remaining
}
|
||||
|
||||
func (s *GatewayService) resolveBedrockBetaTokensForRequest(
|
||||
ctx context.Context,
|
||||
account *Account,
|
||||
betaHeader string,
|
||||
body []byte,
|
||||
modelID string,
|
||||
) ([]string, error) {
|
||||
// 1. 对原始 header 中的 beta token 做 block 检查(快速失败)
|
||||
policy := s.evaluateBetaPolicy(ctx, betaHeader, account)
|
||||
if policy.blockErr != nil {
|
||||
return nil, policy.blockErr
|
||||
}
|
||||
|
||||
// 2. 解析 header + body 自动注入 + Bedrock 转换/过滤
|
||||
betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID)
|
||||
|
||||
// 3. 对最终 token 列表再做 block 检查,捕获通过 body 自动注入绕过 header block 的情况。
|
||||
// 例如:管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token,
|
||||
// 但请求体中包含 thinking 字段 → autoInjectBedrockBetaTokens 会自动补齐 →
|
||||
// 如果不做此检查,block 规则会被绕过。
|
||||
if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account); blockErr != nil {
|
||||
return nil, blockErr
|
||||
}
|
||||
|
||||
return filterBetaTokens(betaTokens, policy.filterSet), nil
|
||||
}
|
||||
|
||||
// checkBetaPolicyBlockForTokens 检查 token 列表中是否有被管理员 block 规则命中的 token。
|
||||
// 用于补充 evaluateBetaPolicy 对 header 的检查,覆盖 body 自动注入的 token。
|
||||
func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account) *BetaBlockedError {
|
||||
if s.settingService == nil || len(tokens) == 0 {
|
||||
return nil
|
||||
}
|
||||
settings, err := s.settingService.GetBetaPolicySettings(ctx)
|
||||
if err != nil || settings == nil {
|
||||
return nil
|
||||
}
|
||||
isOAuth := account.IsOAuth()
|
||||
tokenSet := buildBetaTokenSet(tokens)
|
||||
for _, rule := range settings.Rules {
|
||||
if rule.Action != BetaPolicyActionBlock {
|
||||
continue
|
||||
}
|
||||
if !betaPolicyScopeMatches(rule.Scope, isOAuth) {
|
||||
continue
|
||||
}
|
||||
if _, present := tokenSet[rule.BetaToken]; present {
|
||||
msg := rule.ErrorMessage
|
||||
if msg == "" {
|
||||
msg = "beta feature " + rule.BetaToken + " is not allowed"
|
||||
}
|
||||
return &BetaBlockedError{Message: msg}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildBetaTokenSet(tokens []string) map[string]struct{} {
|
||||
m := make(map[string]struct{}, len(tokens))
|
||||
for _, t := range tokens {
|
||||
@@ -7321,6 +7763,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody)
|
||||
}
|
||||
|
||||
// Bedrock 不支持 count_tokens 端点
|
||||
if account != nil && account.IsBedrock() {
|
||||
s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for Bedrock")
|
||||
return nil
|
||||
}
|
||||
|
||||
body := parsed.Body
|
||||
reqModel := parsed.Model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user