2025-12-18 13:50:39 +08:00
package middleware
2025-12-29 16:52:55 +08:00
import (
"context"
2026-03-03 19:56:27 +08:00
"net/http"
2025-12-29 16:52:55 +08:00
2025-12-29 17:46:52 +08:00
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
2026-03-03 19:56:27 +08:00
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/service"
2025-12-29 16:52:55 +08:00
"github.com/gin-gonic/gin"
)
2025-12-18 13:50:39 +08:00
// ContextKey is the named string type for the context keys declared by this
// package. ForcePlatform, for example, converts ContextKeyForcePlatform to a
// plain string before storing it with gin.Context.Set.
type ContextKey string

// Context keys used by the middleware in this package.
const (
	// ContextKeyUser is the context key for the user.
	ContextKeyUser ContextKey = "user"

	// ContextKeyUserRole is the context key for the current user's role (a string).
	ContextKeyUserRole ContextKey = "user_role"

	// ContextKeyAPIKey is the context key for the API key.
	ContextKeyAPIKey ContextKey = "api_key"

	// ContextKeySubscription is the context key for the subscription.
	ContextKeySubscription ContextKey = "subscription"

	// ContextKeyForcePlatform is the context key for the forced platform
	// (used by the /antigravity routes).
	ContextKeyForcePlatform ContextKey = "force_platform"
)
2025-12-29 16:52:55 +08:00
// ForcePlatform 返回设置强制平台的中间件
// 同时设置 request.Context( 供 Service 使用)和 gin.Context( 供 Handler 快速检查)
func ForcePlatform ( platform string ) gin . HandlerFunc {
return func ( c * gin . Context ) {
2025-12-29 17:46:52 +08:00
// 设置到 request.Context, 使用 ctxkey.ForcePlatform 供 Service 层读取
ctx := context . WithValue ( c . Request . Context ( ) , ctxkey . ForcePlatform , platform )
2025-12-29 16:52:55 +08:00
c . Request = c . Request . WithContext ( ctx )
// 同时设置到 gin.Context, 供 Handler 快速检查
c . Set ( string ( ContextKeyForcePlatform ) , platform )
c . Next ( )
}
}
// HasForcePlatform reports whether a forced platform was stored on the
// gin.Context (by ForcePlatform). Handlers use it to skip the group check.
// Only the presence of the key is checked; its value is not inspected.
func HasForcePlatform(c *gin.Context) bool {
	key := string(ContextKeyForcePlatform)
	_, present := c.Get(key)
	return present
}
// GetForcePlatformFromContext returns the forced platform stored on the
// gin.Context. The boolean is true only when the key is present AND its
// value is a string; otherwise the result is ("", false).
func GetForcePlatformFromContext(c *gin.Context) (string, bool) {
	if raw, found := c.Get(string(ContextKeyForcePlatform)); found {
		platform, isString := raw.(string)
		return platform, isString
	}
	return "", false
}
2025-12-18 13:50:39 +08:00
// ErrorResponse is the standard JSON error body written by AbortWithError,
// serialized as {"code": ..., "message": ...}.
type ErrorResponse struct {
	// Code is the error code.
	Code string `json:"code"`
	// Message is the error message.
	Message string `json:"message"`
}

// NewErrorResponse builds an ErrorResponse from the given code and message.
func NewErrorResponse(code, message string) ErrorResponse {
	var resp ErrorResponse
	resp.Code = code
	resp.Message = message
	return resp
}
// AbortWithError stops the handler chain and writes an ErrorResponse built
// from code and message as the JSON body with the given HTTP status.
func AbortWithError(c *gin.Context, statusCode int, code, message string) {
	// AbortWithStatusJSON is gin's Abort() followed by JSON().
	c.AbortWithStatusJSON(statusCode, NewErrorResponse(code, message))
}
2026-03-03 19:56:27 +08:00
// ──────────────────────────────────────────────────────────
// RequireGroupAssignment — 未分组 Key 拦截中间件
// ──────────────────────────────────────────────────────────
// GatewayErrorWriter writes an error response for a gateway endpoint.
// Different upstream protocols expect different error body formats, so each
// protocol supplies its own writer (e.g. AnthropicErrorWriter,
// GoogleErrorWriter). Implementations receive the HTTP status and a
// human-readable message.
type GatewayErrorWriter func(ctx *gin.Context, httpStatus int, msg string)
// AnthropicErrorWriter writes an error in the Anthropic API error format:
//
//	{"type": "error", "error": {"type": "<error type>", "message": "<message>"}}
//
// The inner error type is derived from the HTTP status (403 → "permission_error",
// 429 → "rate_limit_error", ...). Previously this was always "permission_error".
// That mislabelled any non-403 status passed through the generic
// GatewayErrorWriter contract.
func AnthropicErrorWriter(c *gin.Context, status int, message string) {
	c.JSON(status, gin.H{
		"type": "error",
		"error": gin.H{
			"type":    anthropicErrorTypeForStatus(status),
			"message": message,
		},
	})
}

// anthropicErrorTypeForStatus maps an HTTP status code to the Anthropic API
// error type string. Unlisted statuses fall back to "api_error" for 5xx and
// "invalid_request_error" otherwise.
func anthropicErrorTypeForStatus(status int) string {
	switch status {
	case http.StatusBadRequest:
		return "invalid_request_error"
	case http.StatusUnauthorized:
		return "authentication_error"
	case http.StatusForbidden:
		return "permission_error"
	case http.StatusNotFound:
		return "not_found_error"
	case http.StatusRequestEntityTooLarge:
		return "request_too_large"
	case http.StatusTooManyRequests:
		return "rate_limit_error"
	case 529: // Anthropic-specific "overloaded" status; no net/http constant.
		return "overloaded_error"
	}
	if status >= http.StatusInternalServerError {
		return "api_error"
	}
	return "invalid_request_error"
}
// GoogleErrorWriter writes an error in the Google API error format:
//
//	{"error": {"code": <status>, "message": "<message>", "status": "<status name>"}}
//
// The "status" string comes from googleapi.HTTPStatusToGoogleStatus. This is
// presumably the canonical Google status name for the HTTP code (e.g.
// PERMISSION_DENIED for 403); confirm against that helper.
func GoogleErrorWriter(c *gin.Context, status int, message string) {
	inner := gin.H{
		"code":    status,
		"message": message,
		"status":  googleapi.HTTPStatusToGoogleStatus(status),
	}
	c.JSON(status, gin.H{"error": inner})
}
// RequireGroupAssignment returns middleware that blocks API keys which are
// not assigned to any group, unless the system setting allows ungrouped keys
// to be scheduled.
//
// The request passes through unchanged when any of these holds:
//   - no API key is found on the context (GetAPIKeyFromContext reports !ok);
//     presumably the auth middleware handles that case. Confirm the ordering.
//   - the key has a GroupID;
//   - settingService.IsUngroupedKeySchedulingAllowed returns true.
//
// Otherwise writeError is called with 403 and the handler chain is aborted.
// writeError selects the protocol-specific error body format.
func RequireGroupAssignment(settingService *service.SettingService, writeError GatewayErrorWriter) gin.HandlerFunc {
	return func(c *gin.Context) {
		apiKey, found := GetAPIKeyFromContext(c)

		// Cases are evaluated top to bottom and stop at the first match. The
		// settings lookup only runs for a present, ungrouped key, and
		// apiKey.GroupID is only read when found is true.
		switch {
		case !found, apiKey.GroupID != nil:
			c.Next()
		case settingService.IsUngroupedKeySchedulingAllowed(c.Request.Context()):
			c.Next()
		default:
			writeError(c, http.StatusForbidden, "API Key is not assigned to any group and cannot be used. Please contact the administrator to assign it to a group.")
			c.Abort()
		}
	}
}