mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-03 23:12:14 +08:00
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0236b97d49 | ||
|
|
26f6b1eeff | ||
|
|
dc447ccebe | ||
|
|
7ec29638f4 | ||
|
|
4c9562af20 | ||
|
|
71942fd322 | ||
|
|
550b979ac5 | ||
|
|
3878a5a46f | ||
|
|
e443a6a1ea | ||
|
|
963494ec6f | ||
|
|
525cdb8830 | ||
|
|
a6764e82f2 | ||
|
|
8027531d07 | ||
|
|
30706355a4 | ||
|
|
dfe99507b8 | ||
|
|
c1717c9a6c | ||
|
|
1fd1a58a7a | ||
|
|
fad07507be | ||
|
|
a20c211162 |
33
.github/workflows/release.yml
vendored
33
.github/workflows/release.yml
vendored
@@ -271,3 +271,36 @@ jobs:
|
|||||||
parse_mode: "Markdown",
|
parse_mode: "Markdown",
|
||||||
disable_web_page_preview: true
|
disable_web_page_preview: true
|
||||||
}')"
|
}')"
|
||||||
|
|
||||||
|
sync-version-file:
|
||||||
|
needs: [release]
|
||||||
|
if: ${{ needs.release.result == 'success' }}
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout default branch
|
||||||
|
uses: actions/checkout@v6
|
||||||
|
with:
|
||||||
|
ref: ${{ github.event.repository.default_branch }}
|
||||||
|
|
||||||
|
- name: Sync VERSION file to released tag
|
||||||
|
run: |
|
||||||
|
if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
|
||||||
|
VERSION=${{ github.event.inputs.tag }}
|
||||||
|
VERSION=${VERSION#v}
|
||||||
|
else
|
||||||
|
VERSION=${GITHUB_REF#refs/tags/v}
|
||||||
|
fi
|
||||||
|
|
||||||
|
CURRENT_VERSION=$(tr -d '\r\n' < backend/cmd/server/VERSION || true)
|
||||||
|
if [ "$CURRENT_VERSION" = "$VERSION" ]; then
|
||||||
|
echo "VERSION file already matches $VERSION"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "$VERSION" > backend/cmd/server/VERSION
|
||||||
|
|
||||||
|
git config user.name "github-actions[bot]"
|
||||||
|
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||||
|
git add backend/cmd/server/VERSION
|
||||||
|
git commit -m "chore: sync VERSION to ${VERSION} [skip ci]"
|
||||||
|
git push origin HEAD:${{ github.event.repository.default_branch }}
|
||||||
|
|||||||
@@ -165,6 +165,8 @@ type AccountWithConcurrency struct {
|
|||||||
CurrentRPM *int `json:"current_rpm,omitempty"` // 当前分钟 RPM 计数
|
CurrentRPM *int `json:"current_rpm,omitempty"` // 当前分钟 RPM 计数
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const accountListGroupUngroupedQueryValue = "ungrouped"
|
||||||
|
|
||||||
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
|
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
|
||||||
item := AccountWithConcurrency{
|
item := AccountWithConcurrency{
|
||||||
Account: dto.AccountFromService(account),
|
Account: dto.AccountFromService(account),
|
||||||
@@ -226,7 +228,20 @@ func (h *AccountHandler) List(c *gin.Context) {
|
|||||||
|
|
||||||
var groupID int64
|
var groupID int64
|
||||||
if groupIDStr := c.Query("group"); groupIDStr != "" {
|
if groupIDStr := c.Query("group"); groupIDStr != "" {
|
||||||
groupID, _ = strconv.ParseInt(groupIDStr, 10, 64)
|
if groupIDStr == accountListGroupUngroupedQueryValue {
|
||||||
|
groupID = service.AccountListGroupUngrouped
|
||||||
|
} else {
|
||||||
|
parsedGroupID, parseErr := strconv.ParseInt(groupIDStr, 10, 64)
|
||||||
|
if parseErr != nil {
|
||||||
|
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_GROUP_FILTER", "invalid group filter"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if parsedGroupID < 0 {
|
||||||
|
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_GROUP_FILTER", "invalid group filter"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
groupID = parsedGroupID
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID)
|
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID)
|
||||||
@@ -1496,7 +1511,7 @@ func (h *OAuthHandler) SetupTokenCookieAuth(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetUsage handles getting account usage information
|
// GetUsage handles getting account usage information
|
||||||
// GET /api/v1/admin/accounts/:id/usage
|
// GET /api/v1/admin/accounts/:id/usage?source=passive|active
|
||||||
func (h *AccountHandler) GetUsage(c *gin.Context) {
|
func (h *AccountHandler) GetUsage(c *gin.Context) {
|
||||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1504,7 +1519,14 @@ func (h *AccountHandler) GetUsage(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
usage, err := h.accountUsageService.GetUsage(c.Request.Context(), accountID)
|
source := c.DefaultQuery("source", "active")
|
||||||
|
|
||||||
|
var usage *service.UsageInfo
|
||||||
|
if source == "passive" {
|
||||||
|
usage, err = h.accountUsageService.GetPassiveUsage(c.Request.Context(), accountID)
|
||||||
|
} else {
|
||||||
|
usage, err = h.accountUsageService.GetUsage(c.Request.Context(), accountID)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1219,6 +1219,10 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误
|
||||||
|
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||||
|
|
||||||
// 使用默认的错误映射
|
// 使用默认的错误映射
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
@@ -1227,6 +1231,7 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se
|
|||||||
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
||||||
func (h *GatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
func (h *GatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, errMsg, "")
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -593,6 +593,10 @@ func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverE
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误
|
||||||
|
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||||
|
|
||||||
// 使用默认的错误映射
|
// 使用默认的错误映射
|
||||||
status, message := mapGeminiUpstreamError(statusCode)
|
status, message := mapGeminiUpstreamError(statusCode)
|
||||||
googleError(c, status, message)
|
googleError(c, status, message)
|
||||||
|
|||||||
@@ -1435,6 +1435,10 @@ func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误
|
||||||
|
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||||
|
|
||||||
// 使用默认的错误映射
|
// 使用默认的错误映射
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
@@ -1443,6 +1447,7 @@ func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE
|
|||||||
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
||||||
func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, errMsg, "")
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -484,6 +484,9 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
|
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
|
||||||
|
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||||
|
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
|
status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -275,21 +275,6 @@ func filterOpenCodePrompt(text string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表
|
|
||||||
var systemBlockFilterPrefixes = []string{
|
|
||||||
"x-anthropic-billing-header",
|
|
||||||
}
|
|
||||||
|
|
||||||
// filterSystemBlockByPrefix 如果文本匹配过滤前缀,返回空字符串
|
|
||||||
func filterSystemBlockByPrefix(text string) string {
|
|
||||||
for _, prefix := range systemBlockFilterPrefixes {
|
|
||||||
if strings.HasPrefix(text, prefix) {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return text
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致)
|
// buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致)
|
||||||
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
|
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
|
||||||
var parts []GeminiPart
|
var parts []GeminiPart
|
||||||
@@ -306,8 +291,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
|||||||
if strings.Contains(sysStr, "You are Antigravity") {
|
if strings.Contains(sysStr, "You are Antigravity") {
|
||||||
userHasAntigravityIdentity = true
|
userHasAntigravityIdentity = true
|
||||||
}
|
}
|
||||||
// 过滤 OpenCode 默认提示词和黑名单前缀
|
// 过滤 OpenCode 默认提示词
|
||||||
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(sysStr))
|
filtered := filterOpenCodePrompt(sysStr)
|
||||||
if filtered != "" {
|
if filtered != "" {
|
||||||
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
||||||
}
|
}
|
||||||
@@ -321,8 +306,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
|||||||
if strings.Contains(block.Text, "You are Antigravity") {
|
if strings.Contains(block.Text, "You are Antigravity") {
|
||||||
userHasAntigravityIdentity = true
|
userHasAntigravityIdentity = true
|
||||||
}
|
}
|
||||||
// 过滤 OpenCode 默认提示词和黑名单前缀
|
// 过滤 OpenCode 默认提示词
|
||||||
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(block.Text))
|
filtered := filterOpenCodePrompt(block.Text)
|
||||||
if filtered != "" {
|
if filtered != "" {
|
||||||
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,10 @@ package antigravity
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
|
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
|
||||||
@@ -349,3 +352,51 @@ func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTransformClaudeToGeminiWithOptions_PreservesBillingHeaderSystemBlock(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
system json.RawMessage
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "system array",
|
||||||
|
system: json.RawMessage(`[{"type":"text","text":"x-anthropic-billing-header keep"}]`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "system string",
|
||||||
|
system: json.RawMessage(`"x-anthropic-billing-header keep"`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
claudeReq := &ClaudeRequest{
|
||||||
|
Model: "claude-3-5-sonnet-latest",
|
||||||
|
System: tt.system,
|
||||||
|
Messages: []ClaudeMessage{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: json.RawMessage(`[{"type":"text","text":"hello"}]`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "gemini-2.5-flash", DefaultTransformOptions())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var req V1InternalRequest
|
||||||
|
require.NoError(t, json.Unmarshal(body, &req))
|
||||||
|
require.NotNil(t, req.Request.SystemInstruction)
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, part := range req.Request.SystemInstruction.Parts {
|
||||||
|
if strings.Contains(part.Text, "x-anthropic-billing-header keep") {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.True(t, found, "转换后的 systemInstruction 应保留 x-anthropic-billing-header 内容")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1008,3 +1008,114 @@ func TestAnthropicToResponses_ImageEmptyMediaType(t *testing.T) {
|
|||||||
// Should default to image/png when media_type is empty.
|
// Should default to image/png when media_type is empty.
|
||||||
assert.Equal(t, "data:image/png;base64,iVBOR", parts[0].ImageURL)
|
assert.Equal(t, "data:image/png;base64,iVBOR", parts[0].ImageURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// normalizeToolParameters tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestNormalizeToolParameters(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input json.RawMessage
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil input",
|
||||||
|
input: nil,
|
||||||
|
expected: `{"type":"object","properties":{}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty input",
|
||||||
|
input: json.RawMessage(``),
|
||||||
|
expected: `{"type":"object","properties":{}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "null input",
|
||||||
|
input: json.RawMessage(`null`),
|
||||||
|
expected: `{"type":"object","properties":{}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "object without properties",
|
||||||
|
input: json.RawMessage(`{"type":"object"}`),
|
||||||
|
expected: `{"type":"object","properties":{}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "object with properties",
|
||||||
|
input: json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}}}`),
|
||||||
|
expected: `{"type":"object","properties":{"city":{"type":"string"}}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-object type",
|
||||||
|
input: json.RawMessage(`{"type":"string"}`),
|
||||||
|
expected: `{"type":"string"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "object with additional fields preserved",
|
||||||
|
input: json.RawMessage(`{"type":"object","required":["name"]}`),
|
||||||
|
expected: `{"type":"object","required":["name"],"properties":{}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid JSON passthrough",
|
||||||
|
input: json.RawMessage(`not json`),
|
||||||
|
expected: `not json`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := normalizeToolParameters(tt.input)
|
||||||
|
if tt.name == "invalid JSON passthrough" {
|
||||||
|
assert.Equal(t, tt.expected, string(result))
|
||||||
|
} else {
|
||||||
|
assert.JSONEq(t, tt.expected, string(result))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnthropicToResponses_ToolWithoutProperties(t *testing.T) {
|
||||||
|
req := &AnthropicRequest{
|
||||||
|
Model: "gpt-5.2",
|
||||||
|
MaxTokens: 1024,
|
||||||
|
Messages: []AnthropicMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Hello"`)},
|
||||||
|
},
|
||||||
|
Tools: []AnthropicTool{
|
||||||
|
{Name: "mcp__pencil__get_style_guide_tags", Description: "Get style tags", InputSchema: json.RawMessage(`{"type":"object"}`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := AnthropicToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Len(t, resp.Tools, 1)
|
||||||
|
assert.Equal(t, "function", resp.Tools[0].Type)
|
||||||
|
assert.Equal(t, "mcp__pencil__get_style_guide_tags", resp.Tools[0].Name)
|
||||||
|
|
||||||
|
// Parameters must have "properties" field after normalization.
|
||||||
|
var params map[string]json.RawMessage
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Tools[0].Parameters, ¶ms))
|
||||||
|
assert.Contains(t, params, "properties")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnthropicToResponses_ToolWithNilSchema(t *testing.T) {
|
||||||
|
req := &AnthropicRequest{
|
||||||
|
Model: "gpt-5.2",
|
||||||
|
MaxTokens: 1024,
|
||||||
|
Messages: []AnthropicMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Hello"`)},
|
||||||
|
},
|
||||||
|
Tools: []AnthropicTool{
|
||||||
|
{Name: "simple_tool", Description: "A tool"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := AnthropicToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Len(t, resp.Tools, 1)
|
||||||
|
var params map[string]json.RawMessage
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Tools[0].Parameters, ¶ms))
|
||||||
|
assert.JSONEq(t, `"object"`, string(params["type"]))
|
||||||
|
assert.JSONEq(t, `{}`, string(params["properties"]))
|
||||||
|
}
|
||||||
|
|||||||
@@ -409,8 +409,41 @@ func convertAnthropicToolsToResponses(tools []AnthropicTool) []ResponsesTool {
|
|||||||
Type: "function",
|
Type: "function",
|
||||||
Name: t.Name,
|
Name: t.Name,
|
||||||
Description: t.Description,
|
Description: t.Description,
|
||||||
Parameters: t.InputSchema,
|
Parameters: normalizeToolParameters(t.InputSchema),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// normalizeToolParameters ensures the tool parameter schema is valid for
|
||||||
|
// OpenAI's Responses API, which requires "properties" on object schemas.
|
||||||
|
//
|
||||||
|
// - nil/empty → {"type":"object","properties":{}}
|
||||||
|
// - type=object without properties → adds "properties": {}
|
||||||
|
// - otherwise → returned unchanged
|
||||||
|
func normalizeToolParameters(schema json.RawMessage) json.RawMessage {
|
||||||
|
if len(schema) == 0 || string(schema) == "null" {
|
||||||
|
return json.RawMessage(`{"type":"object","properties":{}}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
var m map[string]json.RawMessage
|
||||||
|
if err := json.Unmarshal(schema, &m); err != nil {
|
||||||
|
return schema
|
||||||
|
}
|
||||||
|
|
||||||
|
typ := m["type"]
|
||||||
|
if string(typ) != `"object"` {
|
||||||
|
return schema
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := m["properties"]; ok {
|
||||||
|
return schema
|
||||||
|
}
|
||||||
|
|
||||||
|
m["properties"] = json.RawMessage(`{}`)
|
||||||
|
out, err := json.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
return schema
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ var schedulerNeutralExtraKeyPrefixes = []string{
|
|||||||
"codex_secondary_",
|
"codex_secondary_",
|
||||||
"codex_5h_",
|
"codex_5h_",
|
||||||
"codex_7d_",
|
"codex_7d_",
|
||||||
|
"passive_usage_",
|
||||||
}
|
}
|
||||||
|
|
||||||
var schedulerNeutralExtraKeys = map[string]struct{}{
|
var schedulerNeutralExtraKeys = map[string]struct{}{
|
||||||
@@ -473,7 +474,9 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
|
|||||||
if search != "" {
|
if search != "" {
|
||||||
q = q.Where(dbaccount.NameContainsFold(search))
|
q = q.Where(dbaccount.NameContainsFold(search))
|
||||||
}
|
}
|
||||||
if groupID > 0 {
|
if groupID == service.AccountListGroupUngrouped {
|
||||||
|
q = q.Where(dbaccount.Not(dbaccount.HasAccountGroups()))
|
||||||
|
} else if groupID > 0 {
|
||||||
q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID)))
|
q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -214,6 +214,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
|||||||
accType string
|
accType string
|
||||||
status string
|
status string
|
||||||
search string
|
search string
|
||||||
|
groupID int64
|
||||||
wantCount int
|
wantCount int
|
||||||
validate func(accounts []service.Account)
|
validate func(accounts []service.Account)
|
||||||
}{
|
}{
|
||||||
@@ -265,6 +266,21 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
|||||||
s.Require().Contains(accounts[0].Name, "alpha")
|
s.Require().Contains(accounts[0].Name, "alpha")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "filter_by_ungrouped",
|
||||||
|
setup: func(client *dbent.Client) {
|
||||||
|
group := mustCreateGroup(s.T(), client, &service.Group{Name: "g-ungrouped"})
|
||||||
|
grouped := mustCreateAccount(s.T(), client, &service.Account{Name: "grouped-account"})
|
||||||
|
mustCreateAccount(s.T(), client, &service.Account{Name: "ungrouped-account"})
|
||||||
|
mustBindAccountToGroup(s.T(), client, grouped.ID, group.ID, 1)
|
||||||
|
},
|
||||||
|
groupID: service.AccountListGroupUngrouped,
|
||||||
|
wantCount: 1,
|
||||||
|
validate: func(accounts []service.Account) {
|
||||||
|
s.Require().Equal("ungrouped-account", accounts[0].Name)
|
||||||
|
s.Require().Empty(accounts[0].GroupIDs)
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -277,7 +293,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
|||||||
|
|
||||||
tt.setup(client)
|
tt.setup(client)
|
||||||
|
|
||||||
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, 0)
|
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, tt.groupID)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Len(accounts, tt.wantCount)
|
s.Require().Len(accounts, tt.wantCount)
|
||||||
if tt.validate != nil {
|
if tt.validate != nil {
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ var (
|
|||||||
ErrAccountNilInput = infraerrors.BadRequest("ACCOUNT_NIL_INPUT", "account input cannot be nil")
|
ErrAccountNilInput = infraerrors.BadRequest("ACCOUNT_NIL_INPUT", "account input cannot be nil")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const AccountListGroupUngrouped int64 = -1
|
||||||
|
|
||||||
type AccountRepository interface {
|
type AccountRepository interface {
|
||||||
Create(ctx context.Context, account *Account) error
|
Create(ctx context.Context, account *Account) error
|
||||||
GetByID(ctx context.Context, id int64) (*Account, error)
|
GetByID(ctx context.Context, id int64) (*Account, error)
|
||||||
|
|||||||
@@ -308,7 +308,14 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
|||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
|
errMsg := fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))
|
||||||
|
|
||||||
|
// 403 表示账号被上游封禁,标记为 error 状态
|
||||||
|
if resp.StatusCode == http.StatusForbidden {
|
||||||
|
_ = s.accountRepo.SetError(ctx, account.ID, errMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.sendErrorAndEnd(c, errMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process SSE stream
|
// Process SSE stream
|
||||||
|
|||||||
@@ -177,6 +177,7 @@ type AICredit struct {
|
|||||||
|
|
||||||
// UsageInfo 账号使用量信息
|
// UsageInfo 账号使用量信息
|
||||||
type UsageInfo struct {
|
type UsageInfo struct {
|
||||||
|
Source string `json:"source,omitempty"` // "passive" or "active"
|
||||||
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
||||||
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
|
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
|
||||||
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
|
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
|
||||||
@@ -393,6 +394,9 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
|||||||
// 4. 添加窗口统计(有独立缓存,1 分钟)
|
// 4. 添加窗口统计(有独立缓存,1 分钟)
|
||||||
s.addWindowStats(ctx, account, usage)
|
s.addWindowStats(ctx, account, usage)
|
||||||
|
|
||||||
|
// 5. 将主动查询结果同步到被动缓存,下次 passive 加载即为最新值
|
||||||
|
s.syncActiveToPassive(ctx, account.ID, usage)
|
||||||
|
|
||||||
s.tryClearRecoverableAccountError(ctx, account)
|
s.tryClearRecoverableAccountError(ctx, account)
|
||||||
return usage, nil
|
return usage, nil
|
||||||
}
|
}
|
||||||
@@ -409,6 +413,81 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
|||||||
return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
|
return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPassiveUsage 从 Account.Extra 中的被动采样数据构建 UsageInfo,不调用外部 API。
|
||||||
|
// 仅适用于 Anthropic OAuth / SetupToken 账号。
|
||||||
|
func (s *AccountUsageService) GetPassiveUsage(ctx context.Context, accountID int64) (*UsageInfo, error) {
|
||||||
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get account failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !account.IsAnthropicOAuthOrSetupToken() {
|
||||||
|
return nil, fmt.Errorf("passive usage only supported for Anthropic OAuth/SetupToken accounts")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 复用 estimateSetupTokenUsage 构建 5h 窗口(OAuth 和 SetupToken 逻辑一致)
|
||||||
|
info := s.estimateSetupTokenUsage(account)
|
||||||
|
info.Source = "passive"
|
||||||
|
|
||||||
|
// 设置采样时间
|
||||||
|
if raw, ok := account.Extra["passive_usage_sampled_at"]; ok {
|
||||||
|
if str, ok := raw.(string); ok {
|
||||||
|
if t, err := time.Parse(time.RFC3339, str); err == nil {
|
||||||
|
info.UpdatedAt = &t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建 7d 窗口(从被动采样数据)
|
||||||
|
util7d := parseExtraFloat64(account.Extra["passive_usage_7d_utilization"])
|
||||||
|
reset7dRaw := parseExtraFloat64(account.Extra["passive_usage_7d_reset"])
|
||||||
|
if util7d > 0 || reset7dRaw > 0 {
|
||||||
|
var resetAt *time.Time
|
||||||
|
var remaining int
|
||||||
|
if reset7dRaw > 0 {
|
||||||
|
t := time.Unix(int64(reset7dRaw), 0)
|
||||||
|
resetAt = &t
|
||||||
|
remaining = int(time.Until(t).Seconds())
|
||||||
|
if remaining < 0 {
|
||||||
|
remaining = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
info.SevenDay = &UsageProgress{
|
||||||
|
Utilization: util7d * 100,
|
||||||
|
ResetsAt: resetAt,
|
||||||
|
RemainingSeconds: remaining,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加窗口统计
|
||||||
|
s.addWindowStats(ctx, account, info)
|
||||||
|
|
||||||
|
return info, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// syncActiveToPassive 将主动查询的最新数据回写到 Extra 被动缓存,
|
||||||
|
// 这样下次被动加载时能看到最新值。
|
||||||
|
func (s *AccountUsageService) syncActiveToPassive(ctx context.Context, accountID int64, usage *UsageInfo) {
|
||||||
|
extraUpdates := make(map[string]any, 4)
|
||||||
|
|
||||||
|
if usage.FiveHour != nil {
|
||||||
|
extraUpdates["session_window_utilization"] = usage.FiveHour.Utilization / 100
|
||||||
|
}
|
||||||
|
if usage.SevenDay != nil {
|
||||||
|
extraUpdates["passive_usage_7d_utilization"] = usage.SevenDay.Utilization / 100
|
||||||
|
if usage.SevenDay.ResetsAt != nil {
|
||||||
|
extraUpdates["passive_usage_7d_reset"] = usage.SevenDay.ResetsAt.Unix()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(extraUpdates) > 0 {
|
||||||
|
extraUpdates["passive_usage_sampled_at"] = time.Now().UTC().Format(time.RFC3339)
|
||||||
|
if err := s.accountRepo.UpdateExtra(ctx, accountID, extraUpdates); err != nil {
|
||||||
|
slog.Warn("sync_active_to_passive_failed", "account_id", accountID, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
|
func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
usage := &UsageInfo{UpdatedAt: &now}
|
usage := &UsageInfo{UpdatedAt: &now}
|
||||||
|
|||||||
@@ -688,6 +688,83 @@ func TestGatewayService_AnthropicOAuth_NotAffectedByAPIKeyPassthroughToggle(t *t
|
|||||||
require.Contains(t, req.Header.Get("anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta")
|
require.Contains(t, req.Header.Get("anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "system array",
|
||||||
|
body: `{"model":"claude-3-5-sonnet-latest","system":[{"type":"text","text":"x-anthropic-billing-header keep"}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "system string",
|
||||||
|
body: `{"model":"claude-3-5-sonnet-latest","system":"x-anthropic-billing-header keep","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
parsed, err := ParseGatewayRequest([]byte(tt.body), PlatformAnthropic)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
upstream := &anthropicHTTPUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"application/json"},
|
||||||
|
"x-request-id": []string{"rid-oauth-preserve"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"id":"msg_1","type":"message","role":"assistant","model":"claude-3-5-sonnet-20241022","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":12,"output_tokens":7}}`)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
MaxLineSize: defaultMaxLineSize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &GatewayService{
|
||||||
|
cfg: cfg,
|
||||||
|
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
rateLimitService: &RateLimitService{},
|
||||||
|
deferredService: &DeferredService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 301,
|
||||||
|
Name: "anthropic-oauth-preserve",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
},
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.Forward(context.Background(), c, account, parsed)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, upstream.lastReq)
|
||||||
|
require.Equal(t, "Bearer oauth-token", upstream.lastReq.Header.Get("authorization"))
|
||||||
|
require.Contains(t, upstream.lastReq.Header.Get("anthropic-beta"), claude.BetaOAuth)
|
||||||
|
|
||||||
|
system := gjson.GetBytes(upstream.lastBody, "system")
|
||||||
|
require.True(t, system.Exists())
|
||||||
|
require.Contains(t, system.Raw, "x-anthropic-billing-header keep")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAfterClientDisconnect(t *testing.T) {
|
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAfterClientDisconnect(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
72
backend/internal/service/gateway_body_order_test.go
Normal file
72
backend/internal/service/gateway_body_order_test.go
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func assertJSONTokenOrder(t *testing.T, body string, tokens ...string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
last := -1
|
||||||
|
for _, token := range tokens {
|
||||||
|
pos := strings.Index(body, token)
|
||||||
|
require.NotEqualf(t, -1, pos, "missing token %s in body %s", token, body)
|
||||||
|
require.Greaterf(t, pos, last, "token %s should appear after previous tokens in body %s", token, body)
|
||||||
|
last = pos
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReplaceModelInBody_PreservesTopLevelFieldOrder(t *testing.T) {
|
||||||
|
svc := &GatewayService{}
|
||||||
|
body := []byte(`{"alpha":1,"model":"claude-3-5-sonnet-latest","messages":[],"omega":2}`)
|
||||||
|
|
||||||
|
result := svc.replaceModelInBody(body, "claude-3-5-sonnet-20241022")
|
||||||
|
resultStr := string(result)
|
||||||
|
|
||||||
|
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"messages"`, `"omega"`)
|
||||||
|
require.Contains(t, resultStr, `"model":"claude-3-5-sonnet-20241022"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeClaudeOAuthRequestBody_PreservesTopLevelFieldOrder(t *testing.T) {
|
||||||
|
body := []byte(`{"alpha":1,"model":"claude-3-5-sonnet-latest","temperature":0.2,"system":"You are OpenCode, the best coding agent on the planet.","messages":[],"tool_choice":{"type":"auto"},"omega":2}`)
|
||||||
|
|
||||||
|
result, modelID := normalizeClaudeOAuthRequestBody(body, "claude-3-5-sonnet-latest", claudeOAuthNormalizeOptions{
|
||||||
|
injectMetadata: true,
|
||||||
|
metadataUserID: "user-1",
|
||||||
|
})
|
||||||
|
resultStr := string(result)
|
||||||
|
|
||||||
|
require.Equal(t, claude.NormalizeModelID("claude-3-5-sonnet-latest"), modelID)
|
||||||
|
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"system"`, `"messages"`, `"omega"`, `"tools"`, `"metadata"`)
|
||||||
|
require.NotContains(t, resultStr, `"temperature"`)
|
||||||
|
require.NotContains(t, resultStr, `"tool_choice"`)
|
||||||
|
require.Contains(t, resultStr, `"system":"`+claudeCodeSystemPrompt+`"`)
|
||||||
|
require.Contains(t, resultStr, `"tools":[]`)
|
||||||
|
require.Contains(t, resultStr, `"metadata":{"user_id":"user-1"}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInjectClaudeCodePrompt_PreservesFieldOrder(t *testing.T) {
|
||||||
|
body := []byte(`{"alpha":1,"system":[{"id":"block-1","type":"text","text":"Custom"}],"messages":[],"omega":2}`)
|
||||||
|
|
||||||
|
result := injectClaudeCodePrompt(body, []any{
|
||||||
|
map[string]any{"id": "block-1", "type": "text", "text": "Custom"},
|
||||||
|
})
|
||||||
|
resultStr := string(result)
|
||||||
|
|
||||||
|
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`)
|
||||||
|
require.Contains(t, resultStr, `{"id":"block-1","type":"text","text":"`+claudeCodeSystemPrompt+`\n\nCustom"}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnforceCacheControlLimit_PreservesTopLevelFieldOrder(t *testing.T) {
|
||||||
|
body := []byte(`{"alpha":1,"system":[{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"s2","cache_control":{"type":"ephemeral"}}],"messages":[{"role":"user","content":[{"type":"text","text":"m1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"m2","cache_control":{"type":"ephemeral"}},{"type":"text","text":"m3","cache_control":{"type":"ephemeral"}}]}],"omega":2}`)
|
||||||
|
|
||||||
|
result := enforceCacheControlLimit(body)
|
||||||
|
resultStr := string(result)
|
||||||
|
|
||||||
|
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`)
|
||||||
|
require.Equal(t, 4, strings.Count(resultStr, `"cache_control"`))
|
||||||
|
}
|
||||||
34
backend/internal/service/gateway_debug_env_test.go
Normal file
34
backend/internal/service/gateway_debug_env_test.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestDebugGatewayBodyLoggingEnabled(t *testing.T) {
|
||||||
|
t.Run("default disabled", func(t *testing.T) {
|
||||||
|
t.Setenv(debugGatewayBodyEnv, "")
|
||||||
|
if debugGatewayBodyLoggingEnabled() {
|
||||||
|
t.Fatalf("expected debug gateway body logging to be disabled by default")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("enabled with true-like values", func(t *testing.T) {
|
||||||
|
for _, value := range []string{"1", "true", "TRUE", "yes", "on"} {
|
||||||
|
t.Run(value, func(t *testing.T) {
|
||||||
|
t.Setenv(debugGatewayBodyEnv, value)
|
||||||
|
if !debugGatewayBodyLoggingEnabled() {
|
||||||
|
t.Fatalf("expected debug gateway body logging to be enabled for %q", value)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("disabled with other values", func(t *testing.T) {
|
||||||
|
for _, value := range []string{"0", "false", "off", "debug"} {
|
||||||
|
t.Run(value, func(t *testing.T) {
|
||||||
|
t.Setenv(debugGatewayBodyEnv, value)
|
||||||
|
if debugGatewayBodyLoggingEnabled() {
|
||||||
|
t.Fatalf("expected debug gateway body logging to be disabled for %q", value)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -51,6 +51,7 @@ const (
|
|||||||
defaultUserGroupRateCacheTTL = 30 * time.Second
|
defaultUserGroupRateCacheTTL = 30 * time.Second
|
||||||
defaultModelsListCacheTTL = 15 * time.Second
|
defaultModelsListCacheTTL = 15 * time.Second
|
||||||
postUsageBillingTimeout = 15 * time.Second
|
postUsageBillingTimeout = 15 * time.Second
|
||||||
|
debugGatewayBodyEnv = "SUB2API_DEBUG_GATEWAY_BODY"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -339,12 +340,6 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表
|
|
||||||
// OAuth/SetupToken 账号转发时,匹配这些前缀的 system 元素会被移除
|
|
||||||
var systemBlockFilterPrefixes = []string{
|
|
||||||
"x-anthropic-billing-header",
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrNoAvailableAccounts 表示没有可用的账号
|
// ErrNoAvailableAccounts 表示没有可用的账号
|
||||||
var ErrNoAvailableAccounts = errors.New("no available accounts")
|
var ErrNoAvailableAccounts = errors.New("no available accounts")
|
||||||
|
|
||||||
@@ -840,20 +835,30 @@ func (s *GatewayService) hashContent(content string) string {
|
|||||||
return strconv.FormatUint(h, 36)
|
return strconv.FormatUint(h, 36)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type anthropicCacheControlPayload struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type anthropicSystemTextBlockPayload struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
CacheControl *anthropicCacheControlPayload `json:"cache_control,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type anthropicMetadataPayload struct {
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
}
|
||||||
|
|
||||||
// replaceModelInBody 替换请求体中的model字段
|
// replaceModelInBody 替换请求体中的model字段
|
||||||
// 使用 json.RawMessage 保留其他字段的原始字节,避免 thinking 块等内容被修改
|
// 优先使用定点修改,尽量保持客户端原始字段顺序。
|
||||||
func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte {
|
func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte {
|
||||||
var req map[string]json.RawMessage
|
if len(body) == 0 {
|
||||||
if err := json.Unmarshal(body, &req); err != nil {
|
|
||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
// 只序列化 model 字段
|
if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel {
|
||||||
modelBytes, err := json.Marshal(newModel)
|
|
||||||
if err != nil {
|
|
||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
req["model"] = modelBytes
|
newBody, err := sjson.SetBytes(body, "model", newModel)
|
||||||
newBody, err := json.Marshal(req)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
@@ -884,24 +889,146 @@ func sanitizeSystemText(text string) string {
|
|||||||
return text
|
return text
|
||||||
}
|
}
|
||||||
|
|
||||||
func stripCacheControlFromSystemBlocks(system any) bool {
|
func marshalAnthropicSystemTextBlock(text string, includeCacheControl bool) ([]byte, error) {
|
||||||
blocks, ok := system.([]any)
|
block := anthropicSystemTextBlockPayload{
|
||||||
if !ok {
|
Type: "text",
|
||||||
return false
|
Text: text,
|
||||||
}
|
}
|
||||||
changed := false
|
if includeCacheControl {
|
||||||
for _, item := range blocks {
|
block.CacheControl = &anthropicCacheControlPayload{Type: "ephemeral"}
|
||||||
block, ok := item.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if _, exists := block["cache_control"]; !exists {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
delete(block, "cache_control")
|
|
||||||
changed = true
|
|
||||||
}
|
}
|
||||||
return changed
|
return json.Marshal(block)
|
||||||
|
}
|
||||||
|
|
||||||
|
func marshalAnthropicMetadata(userID string) ([]byte, error) {
|
||||||
|
return json.Marshal(anthropicMetadataPayload{UserID: userID})
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildJSONArrayRaw(items [][]byte) []byte {
|
||||||
|
if len(items) == 0 {
|
||||||
|
return []byte("[]")
|
||||||
|
}
|
||||||
|
|
||||||
|
total := 2
|
||||||
|
for _, item := range items {
|
||||||
|
total += len(item)
|
||||||
|
}
|
||||||
|
total += len(items) - 1
|
||||||
|
|
||||||
|
buf := make([]byte, 0, total)
|
||||||
|
buf = append(buf, '[')
|
||||||
|
for i, item := range items {
|
||||||
|
if i > 0 {
|
||||||
|
buf = append(buf, ',')
|
||||||
|
}
|
||||||
|
buf = append(buf, item...)
|
||||||
|
}
|
||||||
|
buf = append(buf, ']')
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func setJSONValueBytes(body []byte, path string, value any) ([]byte, bool) {
|
||||||
|
next, err := sjson.SetBytes(body, path, value)
|
||||||
|
if err != nil {
|
||||||
|
return body, false
|
||||||
|
}
|
||||||
|
return next, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func setJSONRawBytes(body []byte, path string, raw []byte) ([]byte, bool) {
|
||||||
|
next, err := sjson.SetRawBytes(body, path, raw)
|
||||||
|
if err != nil {
|
||||||
|
return body, false
|
||||||
|
}
|
||||||
|
return next, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func deleteJSONPathBytes(body []byte, path string) ([]byte, bool) {
|
||||||
|
next, err := sjson.DeleteBytes(body, path)
|
||||||
|
if err != nil {
|
||||||
|
return body, false
|
||||||
|
}
|
||||||
|
return next, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeClaudeOAuthSystemBody(body []byte, opts claudeOAuthNormalizeOptions) ([]byte, bool) {
|
||||||
|
sys := gjson.GetBytes(body, "system")
|
||||||
|
if !sys.Exists() {
|
||||||
|
return body, false
|
||||||
|
}
|
||||||
|
|
||||||
|
out := body
|
||||||
|
modified := false
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case sys.Type == gjson.String:
|
||||||
|
sanitized := sanitizeSystemText(sys.String())
|
||||||
|
if sanitized != sys.String() {
|
||||||
|
if next, ok := setJSONValueBytes(out, "system", sanitized); ok {
|
||||||
|
out = next
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case sys.IsArray():
|
||||||
|
index := 0
|
||||||
|
sys.ForEach(func(_, item gjson.Result) bool {
|
||||||
|
if item.Get("type").String() == "text" {
|
||||||
|
textResult := item.Get("text")
|
||||||
|
if textResult.Exists() && textResult.Type == gjson.String {
|
||||||
|
text := textResult.String()
|
||||||
|
sanitized := sanitizeSystemText(text)
|
||||||
|
if sanitized != text {
|
||||||
|
if next, ok := setJSONValueBytes(out, fmt.Sprintf("system.%d.text", index), sanitized); ok {
|
||||||
|
out = next
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.stripSystemCacheControl && item.Get("cache_control").Exists() {
|
||||||
|
if next, ok := deleteJSONPathBytes(out, fmt.Sprintf("system.%d.cache_control", index)); ok {
|
||||||
|
out = next
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
index++
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, modified
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureClaudeOAuthMetadataUserID(body []byte, userID string) ([]byte, bool) {
|
||||||
|
if strings.TrimSpace(userID) == "" {
|
||||||
|
return body, false
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata := gjson.GetBytes(body, "metadata")
|
||||||
|
if !metadata.Exists() || metadata.Type == gjson.Null {
|
||||||
|
raw, err := marshalAnthropicMetadata(userID)
|
||||||
|
if err != nil {
|
||||||
|
return body, false
|
||||||
|
}
|
||||||
|
return setJSONRawBytes(body, "metadata", raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
trimmedRaw := strings.TrimSpace(metadata.Raw)
|
||||||
|
if strings.HasPrefix(trimmedRaw, "{") {
|
||||||
|
existing := metadata.Get("user_id")
|
||||||
|
if existing.Exists() && existing.Type == gjson.String && existing.String() != "" {
|
||||||
|
return body, false
|
||||||
|
}
|
||||||
|
return setJSONValueBytes(body, "metadata.user_id", userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
raw, err := marshalAnthropicMetadata(userID)
|
||||||
|
if err != nil {
|
||||||
|
return body, false
|
||||||
|
}
|
||||||
|
return setJSONRawBytes(body, "metadata", raw)
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string) {
|
func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string) {
|
||||||
@@ -909,96 +1036,59 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
|||||||
return body, modelID
|
return body, modelID
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析为 map[string]any 用于修改字段
|
out := body
|
||||||
var req map[string]any
|
|
||||||
if err := json.Unmarshal(body, &req); err != nil {
|
|
||||||
return body, modelID
|
|
||||||
}
|
|
||||||
|
|
||||||
modified := false
|
modified := false
|
||||||
|
|
||||||
if system, ok := req["system"]; ok {
|
if next, changed := normalizeClaudeOAuthSystemBody(out, opts); changed {
|
||||||
switch v := system.(type) {
|
out = next
|
||||||
case string:
|
modified = true
|
||||||
sanitized := sanitizeSystemText(v)
|
|
||||||
if sanitized != v {
|
|
||||||
req["system"] = sanitized
|
|
||||||
modified = true
|
|
||||||
}
|
|
||||||
case []any:
|
|
||||||
for _, item := range v {
|
|
||||||
block, ok := item.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if blockType, _ := block["type"].(string); blockType != "text" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
text, ok := block["text"].(string)
|
|
||||||
if !ok || text == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
sanitized := sanitizeSystemText(text)
|
|
||||||
if sanitized != text {
|
|
||||||
block["text"] = sanitized
|
|
||||||
modified = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if rawModel, ok := req["model"].(string); ok {
|
rawModel := gjson.GetBytes(out, "model")
|
||||||
normalized := claude.NormalizeModelID(rawModel)
|
if rawModel.Exists() && rawModel.Type == gjson.String {
|
||||||
if normalized != rawModel {
|
normalized := claude.NormalizeModelID(rawModel.String())
|
||||||
req["model"] = normalized
|
if normalized != rawModel.String() {
|
||||||
|
if next, ok := setJSONValueBytes(out, "model", normalized); ok {
|
||||||
|
out = next
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
modelID = normalized
|
modelID = normalized
|
||||||
modified = true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 确保 tools 字段存在(即使为空数组)
|
// 确保 tools 字段存在(即使为空数组)
|
||||||
if _, exists := req["tools"]; !exists {
|
if !gjson.GetBytes(out, "tools").Exists() {
|
||||||
req["tools"] = []any{}
|
if next, ok := setJSONRawBytes(out, "tools", []byte("[]")); ok {
|
||||||
modified = true
|
out = next
|
||||||
}
|
|
||||||
|
|
||||||
if opts.stripSystemCacheControl {
|
|
||||||
if system, ok := req["system"]; ok {
|
|
||||||
_ = stripCacheControlFromSystemBlocks(system)
|
|
||||||
modified = true
|
modified = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if opts.injectMetadata && opts.metadataUserID != "" {
|
if opts.injectMetadata && opts.metadataUserID != "" {
|
||||||
metadata, ok := req["metadata"].(map[string]any)
|
if next, changed := ensureClaudeOAuthMetadataUserID(out, opts.metadataUserID); changed {
|
||||||
if !ok {
|
out = next
|
||||||
metadata = map[string]any{}
|
|
||||||
req["metadata"] = metadata
|
|
||||||
}
|
|
||||||
if existing, ok := metadata["user_id"].(string); !ok || existing == "" {
|
|
||||||
metadata["user_id"] = opts.metadataUserID
|
|
||||||
modified = true
|
modified = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, hasTemp := req["temperature"]; hasTemp {
|
if gjson.GetBytes(out, "temperature").Exists() {
|
||||||
delete(req, "temperature")
|
if next, ok := deleteJSONPathBytes(out, "temperature"); ok {
|
||||||
modified = true
|
out = next
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if _, hasChoice := req["tool_choice"]; hasChoice {
|
if gjson.GetBytes(out, "tool_choice").Exists() {
|
||||||
delete(req, "tool_choice")
|
if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok {
|
||||||
modified = true
|
out = next
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !modified {
|
if !modified {
|
||||||
return body, modelID
|
return body, modelID
|
||||||
}
|
}
|
||||||
|
|
||||||
newBody, err := json.Marshal(req)
|
return out, modelID
|
||||||
if err != nil {
|
|
||||||
return body, modelID
|
|
||||||
}
|
|
||||||
return newBody, modelID
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string {
|
func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string {
|
||||||
@@ -3676,82 +3766,28 @@ func hasClaudeCodePrefix(text string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// matchesFilterPrefix 检查文本是否匹配任一过滤前缀
|
|
||||||
func matchesFilterPrefix(text string) bool {
|
|
||||||
for _, prefix := range systemBlockFilterPrefixes {
|
|
||||||
if strings.HasPrefix(text, prefix) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// filterSystemBlocksByPrefix 从 body 的 system 中移除文本匹配 systemBlockFilterPrefixes 前缀的元素
|
|
||||||
// 直接从 body 解析 system,不依赖外部传入的 parsed.System(因为前置步骤可能已修改 body 中的 system)
|
|
||||||
func filterSystemBlocksByPrefix(body []byte) []byte {
|
|
||||||
sys := gjson.GetBytes(body, "system")
|
|
||||||
if !sys.Exists() {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case sys.Type == gjson.String:
|
|
||||||
if matchesFilterPrefix(sys.Str) {
|
|
||||||
result, err := sjson.DeleteBytes(body, "system")
|
|
||||||
if err != nil {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
case sys.IsArray():
|
|
||||||
var parsed []any
|
|
||||||
if err := json.Unmarshal([]byte(sys.Raw), &parsed); err != nil {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
filtered := make([]any, 0, len(parsed))
|
|
||||||
changed := false
|
|
||||||
for _, item := range parsed {
|
|
||||||
if m, ok := item.(map[string]any); ok {
|
|
||||||
if text, ok := m["text"].(string); ok && matchesFilterPrefix(text) {
|
|
||||||
changed = true
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
filtered = append(filtered, item)
|
|
||||||
}
|
|
||||||
if changed {
|
|
||||||
result, err := sjson.SetBytes(body, "system", filtered)
|
|
||||||
if err != nil {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词
|
// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词
|
||||||
// 处理 null、字符串、数组三种格式
|
// 处理 null、字符串、数组三种格式
|
||||||
func injectClaudeCodePrompt(body []byte, system any) []byte {
|
func injectClaudeCodePrompt(body []byte, system any) []byte {
|
||||||
claudeCodeBlock := map[string]any{
|
claudeCodeBlock, err := marshalAnthropicSystemTextBlock(claudeCodeSystemPrompt, true)
|
||||||
"type": "text",
|
if err != nil {
|
||||||
"text": claudeCodeSystemPrompt,
|
logger.LegacyPrintf("service.gateway", "Warning: failed to build Claude Code prompt block: %v", err)
|
||||||
"cache_control": map[string]string{"type": "ephemeral"},
|
return body
|
||||||
}
|
}
|
||||||
// Opencode plugin applies an extra safeguard: it not only prepends the Claude Code
|
// Opencode plugin applies an extra safeguard: it not only prepends the Claude Code
|
||||||
// banner, it also prefixes the next system instruction with the same banner plus
|
// banner, it also prefixes the next system instruction with the same banner plus
|
||||||
// a blank line. This helps when upstream concatenates system instructions.
|
// a blank line. This helps when upstream concatenates system instructions.
|
||||||
claudeCodePrefix := strings.TrimSpace(claudeCodeSystemPrompt)
|
claudeCodePrefix := strings.TrimSpace(claudeCodeSystemPrompt)
|
||||||
|
|
||||||
var newSystem []any
|
var items [][]byte
|
||||||
|
|
||||||
switch v := system.(type) {
|
switch v := system.(type) {
|
||||||
case nil:
|
case nil:
|
||||||
newSystem = []any{claudeCodeBlock}
|
items = [][]byte{claudeCodeBlock}
|
||||||
case string:
|
case string:
|
||||||
// Be tolerant of older/newer clients that may differ only by trailing whitespace/newlines.
|
// Be tolerant of older/newer clients that may differ only by trailing whitespace/newlines.
|
||||||
if strings.TrimSpace(v) == "" || strings.TrimSpace(v) == strings.TrimSpace(claudeCodeSystemPrompt) {
|
if strings.TrimSpace(v) == "" || strings.TrimSpace(v) == strings.TrimSpace(claudeCodeSystemPrompt) {
|
||||||
newSystem = []any{claudeCodeBlock}
|
items = [][]byte{claudeCodeBlock}
|
||||||
} else {
|
} else {
|
||||||
// Mirror opencode behavior: keep the banner as a separate system entry,
|
// Mirror opencode behavior: keep the banner as a separate system entry,
|
||||||
// but also prefix the next system text with the banner.
|
// but also prefix the next system text with the banner.
|
||||||
@@ -3759,18 +3795,54 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
|
|||||||
if !strings.HasPrefix(v, claudeCodePrefix) {
|
if !strings.HasPrefix(v, claudeCodePrefix) {
|
||||||
merged = claudeCodePrefix + "\n\n" + v
|
merged = claudeCodePrefix + "\n\n" + v
|
||||||
}
|
}
|
||||||
newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": merged}}
|
nextBlock, buildErr := marshalAnthropicSystemTextBlock(merged, false)
|
||||||
|
if buildErr != nil {
|
||||||
|
logger.LegacyPrintf("service.gateway", "Warning: failed to build prefixed Claude Code system block: %v", buildErr)
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
items = [][]byte{claudeCodeBlock, nextBlock}
|
||||||
}
|
}
|
||||||
case []any:
|
case []any:
|
||||||
newSystem = make([]any, 0, len(v)+1)
|
items = make([][]byte, 0, len(v)+1)
|
||||||
newSystem = append(newSystem, claudeCodeBlock)
|
items = append(items, claudeCodeBlock)
|
||||||
prefixedNext := false
|
prefixedNext := false
|
||||||
for _, item := range v {
|
systemResult := gjson.GetBytes(body, "system")
|
||||||
if m, ok := item.(map[string]any); ok {
|
if systemResult.IsArray() {
|
||||||
|
systemResult.ForEach(func(_, item gjson.Result) bool {
|
||||||
|
textResult := item.Get("text")
|
||||||
|
if textResult.Exists() && textResult.Type == gjson.String &&
|
||||||
|
strings.TrimSpace(textResult.String()) == strings.TrimSpace(claudeCodeSystemPrompt) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
raw := []byte(item.Raw)
|
||||||
|
// Prefix the first subsequent text system block once.
|
||||||
|
if !prefixedNext && item.Get("type").String() == "text" && textResult.Exists() && textResult.Type == gjson.String {
|
||||||
|
text := textResult.String()
|
||||||
|
if strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) {
|
||||||
|
next, setErr := sjson.SetBytes(raw, "text", claudeCodePrefix+"\n\n"+text)
|
||||||
|
if setErr == nil {
|
||||||
|
raw = next
|
||||||
|
prefixedNext = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
items = append(items, raw)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
for _, item := range v {
|
||||||
|
m, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
raw, marshalErr := json.Marshal(item)
|
||||||
|
if marshalErr == nil {
|
||||||
|
items = append(items, raw)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
if text, ok := m["text"].(string); ok && strings.TrimSpace(text) == strings.TrimSpace(claudeCodeSystemPrompt) {
|
if text, ok := m["text"].(string); ok && strings.TrimSpace(text) == strings.TrimSpace(claudeCodeSystemPrompt) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// Prefix the first subsequent text system block once.
|
|
||||||
if !prefixedNext {
|
if !prefixedNext {
|
||||||
if blockType, _ := m["type"].(string); blockType == "text" {
|
if blockType, _ := m["type"].(string); blockType == "text" {
|
||||||
if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) {
|
if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) {
|
||||||
@@ -3779,197 +3851,150 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
raw, marshalErr := json.Marshal(m)
|
||||||
|
if marshalErr == nil {
|
||||||
|
items = append(items, raw)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
newSystem = append(newSystem, item)
|
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
newSystem = []any{claudeCodeBlock}
|
items = [][]byte{claudeCodeBlock}
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := sjson.SetBytes(body, "system", newSystem)
|
result, ok := setJSONRawBytes(body, "system", buildJSONArrayRaw(items))
|
||||||
if err != nil {
|
if !ok {
|
||||||
logger.LegacyPrintf("service.gateway", "Warning: failed to inject Claude Code prompt: %v", err)
|
logger.LegacyPrintf("service.gateway", "Warning: failed to inject Claude Code prompt")
|
||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type cacheControlPath struct {
|
||||||
|
path string
|
||||||
|
log string
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectCacheControlPaths(body []byte) (invalidThinking []cacheControlPath, messagePaths []string, systemPaths []string) {
|
||||||
|
system := gjson.GetBytes(body, "system")
|
||||||
|
if system.IsArray() {
|
||||||
|
sysIndex := 0
|
||||||
|
system.ForEach(func(_, item gjson.Result) bool {
|
||||||
|
if item.Get("cache_control").Exists() {
|
||||||
|
path := fmt.Sprintf("system.%d.cache_control", sysIndex)
|
||||||
|
if item.Get("type").String() == "thinking" {
|
||||||
|
invalidThinking = append(invalidThinking, cacheControlPath{
|
||||||
|
path: path,
|
||||||
|
log: "[Warning] Removed illegal cache_control from thinking block in system",
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
systemPaths = append(systemPaths, path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sysIndex++
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
messages := gjson.GetBytes(body, "messages")
|
||||||
|
if messages.IsArray() {
|
||||||
|
msgIndex := 0
|
||||||
|
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||||
|
content := msg.Get("content")
|
||||||
|
if content.IsArray() {
|
||||||
|
contentIndex := 0
|
||||||
|
content.ForEach(func(_, item gjson.Result) bool {
|
||||||
|
if item.Get("cache_control").Exists() {
|
||||||
|
path := fmt.Sprintf("messages.%d.content.%d.cache_control", msgIndex, contentIndex)
|
||||||
|
if item.Get("type").String() == "thinking" {
|
||||||
|
invalidThinking = append(invalidThinking, cacheControlPath{
|
||||||
|
path: path,
|
||||||
|
log: fmt.Sprintf("[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIndex, contentIndex),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
messagePaths = append(messagePaths, path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
contentIndex++
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
msgIndex++
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return invalidThinking, messagePaths, systemPaths
|
||||||
|
}
|
||||||
|
|
||||||
// enforceCacheControlLimit 强制执行 cache_control 块数量限制(最多 4 个)
|
// enforceCacheControlLimit 强制执行 cache_control 块数量限制(最多 4 个)
|
||||||
// 超限时优先从 messages 中移除 cache_control,保护 system 中的缓存控制
|
// 超限时优先从 messages 中移除 cache_control,保护 system 中的缓存控制
|
||||||
func enforceCacheControlLimit(body []byte) []byte {
|
func enforceCacheControlLimit(body []byte) []byte {
|
||||||
var data map[string]any
|
if len(body) == 0 {
|
||||||
if err := json.Unmarshal(body, &data); err != nil {
|
|
||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
// 清理 thinking 块中的非法 cache_control(thinking 块不支持该字段)
|
invalidThinking, messagePaths, systemPaths := collectCacheControlPaths(body)
|
||||||
removeCacheControlFromThinkingBlocks(data)
|
out := body
|
||||||
|
modified := false
|
||||||
|
|
||||||
// 计算当前 cache_control 块数量
|
// 先清理 thinking 块中的非法 cache_control(thinking 块不支持该字段)
|
||||||
count := countCacheControlBlocks(data)
|
for _, item := range invalidThinking {
|
||||||
|
if !gjson.GetBytes(out, item.path).Exists() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
next, ok := deleteJSONPathBytes(out, item.path)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = next
|
||||||
|
modified = true
|
||||||
|
logger.LegacyPrintf("service.gateway", "%s", item.log)
|
||||||
|
}
|
||||||
|
|
||||||
|
count := len(messagePaths) + len(systemPaths)
|
||||||
if count <= maxCacheControlBlocks {
|
if count <= maxCacheControlBlocks {
|
||||||
|
if modified {
|
||||||
|
return out
|
||||||
|
}
|
||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
// 超限:优先从 messages 中移除,再从 system 中移除
|
// 超限:优先从 messages 中移除,再从 system 中移除
|
||||||
for count > maxCacheControlBlocks {
|
remaining := count - maxCacheControlBlocks
|
||||||
if removeCacheControlFromMessages(data) {
|
for _, path := range messagePaths {
|
||||||
count--
|
if remaining <= 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if !gjson.GetBytes(out, path).Exists() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if removeCacheControlFromSystem(data) {
|
next, ok := deleteJSONPathBytes(out, path)
|
||||||
count--
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := json.Marshal(data)
|
|
||||||
if err != nil {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// countCacheControlBlocks 统计 system 和 messages 中的 cache_control 块数量
|
|
||||||
// 注意:thinking 块不支持 cache_control,统计时跳过
|
|
||||||
func countCacheControlBlocks(data map[string]any) int {
|
|
||||||
count := 0
|
|
||||||
|
|
||||||
// 统计 system 中的块
|
|
||||||
if system, ok := data["system"].([]any); ok {
|
|
||||||
for _, item := range system {
|
|
||||||
if m, ok := item.(map[string]any); ok {
|
|
||||||
// thinking 块不支持 cache_control,跳过
|
|
||||||
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if _, has := m["cache_control"]; has {
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 统计 messages 中的块
|
|
||||||
if messages, ok := data["messages"].([]any); ok {
|
|
||||||
for _, msg := range messages {
|
|
||||||
if msgMap, ok := msg.(map[string]any); ok {
|
|
||||||
if content, ok := msgMap["content"].([]any); ok {
|
|
||||||
for _, item := range content {
|
|
||||||
if m, ok := item.(map[string]any); ok {
|
|
||||||
// thinking 块不支持 cache_control,跳过
|
|
||||||
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if _, has := m["cache_control"]; has {
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return count
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeCacheControlFromMessages 从 messages 中移除一个 cache_control(从头开始)
|
|
||||||
// 返回 true 表示成功移除,false 表示没有可移除的
|
|
||||||
// 注意:跳过 thinking 块(它不支持 cache_control)
|
|
||||||
func removeCacheControlFromMessages(data map[string]any) bool {
|
|
||||||
messages, ok := data["messages"].([]any)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, msg := range messages {
|
|
||||||
msgMap, ok := msg.(map[string]any)
|
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
content, ok := msgMap["content"].([]any)
|
out = next
|
||||||
|
modified = true
|
||||||
|
remaining--
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := len(systemPaths) - 1; i >= 0 && remaining > 0; i-- {
|
||||||
|
path := systemPaths[i]
|
||||||
|
if !gjson.GetBytes(out, path).Exists() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
next, ok := deleteJSONPathBytes(out, path)
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for _, item := range content {
|
out = next
|
||||||
if m, ok := item.(map[string]any); ok {
|
modified = true
|
||||||
// thinking 块不支持 cache_control,跳过
|
remaining--
|
||||||
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if _, has := m["cache_control"]; has {
|
|
||||||
delete(m, "cache_control")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeCacheControlFromSystem 从 system 中移除一个 cache_control(从尾部开始,保护注入的 prompt)
|
|
||||||
// 返回 true 表示成功移除,false 表示没有可移除的
|
|
||||||
// 注意:跳过 thinking 块(它不支持 cache_control)
|
|
||||||
func removeCacheControlFromSystem(data map[string]any) bool {
|
|
||||||
system, ok := data["system"].([]any)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 从尾部开始移除,保护开头注入的 Claude Code prompt
|
if modified {
|
||||||
for i := len(system) - 1; i >= 0; i-- {
|
return out
|
||||||
if m, ok := system[i].(map[string]any); ok {
|
|
||||||
// thinking 块不支持 cache_control,跳过
|
|
||||||
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if _, has := m["cache_control"]; has {
|
|
||||||
delete(m, "cache_control")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeCacheControlFromThinkingBlocks 强制清理所有 thinking 块中的非法 cache_control
|
|
||||||
// thinking 块不支持 cache_control 字段,这个函数确保所有 thinking 块都不含该字段
|
|
||||||
func removeCacheControlFromThinkingBlocks(data map[string]any) {
|
|
||||||
// 清理 system 中的 thinking 块
|
|
||||||
if system, ok := data["system"].([]any); ok {
|
|
||||||
for _, item := range system {
|
|
||||||
if m, ok := item.(map[string]any); ok {
|
|
||||||
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
|
||||||
if _, has := m["cache_control"]; has {
|
|
||||||
delete(m, "cache_control")
|
|
||||||
logger.LegacyPrintf("service.gateway", "[Warning] Removed illegal cache_control from thinking block in system")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 清理 messages 中的 thinking 块
|
|
||||||
if messages, ok := data["messages"].([]any); ok {
|
|
||||||
for msgIdx, msg := range messages {
|
|
||||||
if msgMap, ok := msg.(map[string]any); ok {
|
|
||||||
if content, ok := msgMap["content"].([]any); ok {
|
|
||||||
for contentIdx, item := range content {
|
|
||||||
if m, ok := item.(map[string]any); ok {
|
|
||||||
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
|
||||||
if _, has := m["cache_control"]; has {
|
|
||||||
delete(m, "cache_control")
|
|
||||||
logger.LegacyPrintf("service.gateway", "[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIdx, contentIdx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
// Forward 转发请求到Claude API
|
// Forward 转发请求到Claude API
|
||||||
@@ -4021,6 +4046,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
reqStream := parsed.Stream
|
reqStream := parsed.Stream
|
||||||
originalModel := reqModel
|
originalModel := reqModel
|
||||||
|
|
||||||
|
// === DEBUG: 打印客户端原始请求 body ===
|
||||||
|
debugLogRequestBody("CLIENT_ORIGINAL", body)
|
||||||
|
|
||||||
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
|
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
|
||||||
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
|
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
|
||||||
|
|
||||||
@@ -4046,12 +4074,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OAuth/SetupToken 账号:移除黑名单前缀匹配的 system 元素(如客户端注入的计费元数据)
|
|
||||||
// 放在 inject/normalize 之后,确保不会被覆盖
|
|
||||||
if account.IsOAuth() {
|
|
||||||
body = filterSystemBlocksByPrefix(body)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 强制执行 cache_control 块数量限制(最多 4 个)
|
// 强制执行 cache_control 块数量限制(最多 4 个)
|
||||||
body = enforceCacheControlLimit(body)
|
body = enforceCacheControlLimit(body)
|
||||||
|
|
||||||
@@ -5573,6 +5595,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// === DEBUG: 打印转发给上游的 body(metadata 已重写) ===
|
||||||
|
debugLogRequestBody("UPSTREAM_FORWARD", body)
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -8447,3 +8472,43 @@ func reconcileCachedTokens(usage map[string]any) bool {
|
|||||||
usage["cache_read_input_tokens"] = cached
|
usage["cache_read_input_tokens"] = cached
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func debugGatewayBodyLoggingEnabled() bool {
|
||||||
|
raw := strings.TrimSpace(os.Getenv(debugGatewayBodyEnv))
|
||||||
|
if raw == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch strings.ToLower(raw) {
|
||||||
|
case "1", "true", "yes", "on":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// debugLogRequestBody 打印请求 body 用于调试 metadata.user_id 重写。
|
||||||
|
// 默认关闭,仅在设置环境变量时启用:
|
||||||
|
//
|
||||||
|
// SUB2API_DEBUG_GATEWAY_BODY=1
|
||||||
|
func debugLogRequestBody(tag string, body []byte) {
|
||||||
|
if !debugGatewayBodyLoggingEnabled() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(body) == 0 {
|
||||||
|
logger.LegacyPrintf("service.gateway", "[DEBUG_%s] body is empty", tag)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取 metadata 字段完整打印
|
||||||
|
metadataResult := gjson.GetBytes(body, "metadata")
|
||||||
|
if metadataResult.Exists() {
|
||||||
|
logger.LegacyPrintf("service.gateway", "[DEBUG_%s] metadata = %s", tag, metadataResult.Raw)
|
||||||
|
} else {
|
||||||
|
logger.LegacyPrintf("service.gateway", "[DEBUG_%s] metadata field not found", tag)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 全量打印 body
|
||||||
|
logger.LegacyPrintf("service.gateway", "[DEBUG_%s] body (%d bytes) = %s", tag, len(body), string(body))
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -15,6 +14,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// 预编译正则表达式(避免每次调用重新编译)
|
// 预编译正则表达式(避免每次调用重新编译)
|
||||||
@@ -215,25 +216,20 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
|
|||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用 RawMessage 保留其他字段的原始字节
|
metadata := gjson.GetBytes(body, "metadata")
|
||||||
var reqMap map[string]json.RawMessage
|
if !metadata.Exists() || metadata.Type == gjson.Null {
|
||||||
if err := json.Unmarshal(body, &reqMap); err != nil {
|
return body, nil
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(strings.TrimSpace(metadata.Raw), "{") {
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析 metadata 字段
|
userIDResult := metadata.Get("user_id")
|
||||||
metadataRaw, ok := reqMap["metadata"]
|
if !userIDResult.Exists() || userIDResult.Type != gjson.String {
|
||||||
if !ok {
|
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
userID := userIDResult.String()
|
||||||
var metadata map[string]any
|
if userID == "" {
|
||||||
if err := json.Unmarshal(metadataRaw, &metadata); err != nil {
|
|
||||||
return body, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
userID, ok := metadata["user_id"].(string)
|
|
||||||
if !ok || userID == "" {
|
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -252,17 +248,15 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
|
|||||||
// 根据客户端版本选择输出格式
|
// 根据客户端版本选择输出格式
|
||||||
version := ExtractCLIVersion(fingerprintUA)
|
version := ExtractCLIVersion(fingerprintUA)
|
||||||
newUserID := FormatMetadataUserID(cachedClientID, accountUUID, newSessionHash, version)
|
newUserID := FormatMetadataUserID(cachedClientID, accountUUID, newSessionHash, version)
|
||||||
|
if newUserID == userID {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
metadata["user_id"] = newUserID
|
newBody, err := sjson.SetBytes(body, "metadata.user_id", newUserID)
|
||||||
|
|
||||||
// 只重新序列化 metadata 字段
|
|
||||||
newMetadataRaw, err := json.Marshal(metadata)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
reqMap["metadata"] = newMetadataRaw
|
return newBody, nil
|
||||||
|
|
||||||
return json.Marshal(reqMap)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装
|
// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装
|
||||||
@@ -283,25 +277,20 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
|
|||||||
return newBody, nil
|
return newBody, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用 RawMessage 保留其他字段的原始字节
|
metadata := gjson.GetBytes(newBody, "metadata")
|
||||||
var reqMap map[string]json.RawMessage
|
if !metadata.Exists() || metadata.Type == gjson.Null {
|
||||||
if err := json.Unmarshal(newBody, &reqMap); err != nil {
|
return newBody, nil
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(strings.TrimSpace(metadata.Raw), "{") {
|
||||||
return newBody, nil
|
return newBody, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析 metadata 字段
|
userIDResult := metadata.Get("user_id")
|
||||||
metadataRaw, ok := reqMap["metadata"]
|
if !userIDResult.Exists() || userIDResult.Type != gjson.String {
|
||||||
if !ok {
|
|
||||||
return newBody, nil
|
return newBody, nil
|
||||||
}
|
}
|
||||||
|
userID := userIDResult.String()
|
||||||
var metadata map[string]any
|
if userID == "" {
|
||||||
if err := json.Unmarshal(metadataRaw, &metadata); err != nil {
|
|
||||||
return newBody, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
userID, ok := metadata["user_id"].(string)
|
|
||||||
if !ok || userID == "" {
|
|
||||||
return newBody, nil
|
return newBody, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -339,16 +328,15 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
|
|||||||
"after", newUserID,
|
"after", newUserID,
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata["user_id"] = newUserID
|
if newUserID == userID {
|
||||||
|
|
||||||
// 只重新序列化 metadata 字段
|
|
||||||
newMetadataRaw, marshalErr := json.Marshal(metadata)
|
|
||||||
if marshalErr != nil {
|
|
||||||
return newBody, nil
|
return newBody, nil
|
||||||
}
|
}
|
||||||
reqMap["metadata"] = newMetadataRaw
|
|
||||||
|
|
||||||
return json.Marshal(reqMap)
|
maskedBody, setErr := sjson.SetBytes(newBody, "metadata.user_id", newUserID)
|
||||||
|
if setErr != nil {
|
||||||
|
return newBody, nil
|
||||||
|
}
|
||||||
|
return maskedBody, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateRandomUUID 生成随机 UUID v4 格式字符串
|
// generateRandomUUID 生成随机 UUID v4 格式字符串
|
||||||
|
|||||||
82
backend/internal/service/identity_service_order_test.go
Normal file
82
backend/internal/service/identity_service_order_test.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type identityCacheStub struct {
|
||||||
|
maskedSessionID string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *identityCacheStub) GetFingerprint(_ context.Context, _ int64) (*Fingerprint, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *identityCacheStub) SetFingerprint(_ context.Context, _ int64, _ *Fingerprint) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *identityCacheStub) GetMaskedSessionID(_ context.Context, _ int64) (string, error) {
|
||||||
|
return s.maskedSessionID, nil
|
||||||
|
}
|
||||||
|
func (s *identityCacheStub) SetMaskedSessionID(_ context.Context, _ int64, sessionID string) error {
|
||||||
|
s.maskedSessionID = sessionID
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIdentityService_RewriteUserID_PreservesTopLevelFieldOrder(t *testing.T) {
|
||||||
|
cache := &identityCacheStub{}
|
||||||
|
svc := NewIdentityService(cache)
|
||||||
|
|
||||||
|
originalUserID := FormatMetadataUserID(
|
||||||
|
"d61f76d0730d2b920763648949bad5c79742155c27037fc77ac3f9805cb90169",
|
||||||
|
"",
|
||||||
|
"7578cf37-aaca-46e4-a45c-71285d9dbb83",
|
||||||
|
"2.1.78",
|
||||||
|
)
|
||||||
|
body := []byte(`{"alpha":1,"messages":[],"metadata":{"user_id":` + strconvQuote(originalUserID) + `},"max_tokens":64000,"thinking":{"type":"adaptive"},"output_config":{"effort":"high"},"stream":true}`)
|
||||||
|
|
||||||
|
result, err := svc.RewriteUserID(body, 123, "acc-uuid", "client-xyz", "claude-cli/2.1.78 (external, cli)")
|
||||||
|
require.NoError(t, err)
|
||||||
|
resultStr := string(result)
|
||||||
|
|
||||||
|
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"messages"`, `"metadata"`, `"max_tokens"`, `"thinking"`, `"output_config"`, `"stream"`)
|
||||||
|
require.NotContains(t, resultStr, originalUserID)
|
||||||
|
require.Contains(t, resultStr, `"metadata":{"user_id":"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIdentityService_RewriteUserIDWithMasking_PreservesTopLevelFieldOrder(t *testing.T) {
|
||||||
|
cache := &identityCacheStub{maskedSessionID: "11111111-2222-4333-8444-555555555555"}
|
||||||
|
svc := NewIdentityService(cache)
|
||||||
|
|
||||||
|
originalUserID := FormatMetadataUserID(
|
||||||
|
"d61f76d0730d2b920763648949bad5c79742155c27037fc77ac3f9805cb90169",
|
||||||
|
"",
|
||||||
|
"7578cf37-aaca-46e4-a45c-71285d9dbb83",
|
||||||
|
"2.1.78",
|
||||||
|
)
|
||||||
|
body := []byte(`{"alpha":1,"messages":[],"metadata":{"user_id":` + strconvQuote(originalUserID) + `},"max_tokens":64000,"thinking":{"type":"adaptive"},"output_config":{"effort":"high"},"stream":true}`)
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 123,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"session_id_masking_enabled": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.RewriteUserIDWithMasking(context.Background(), body, account, "acc-uuid", "client-xyz", "claude-cli/2.1.78 (external, cli)")
|
||||||
|
require.NoError(t, err)
|
||||||
|
resultStr := string(result)
|
||||||
|
|
||||||
|
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"messages"`, `"metadata"`, `"max_tokens"`, `"thinking"`, `"output_config"`, `"stream"`)
|
||||||
|
require.Contains(t, resultStr, cache.maskedSessionID)
|
||||||
|
require.True(t, strings.Contains(resultStr, `"metadata":{"user_id":"`))
|
||||||
|
}
|
||||||
|
|
||||||
|
func strconvQuote(v string) string {
|
||||||
|
return `"` + strings.ReplaceAll(strings.ReplaceAll(v, `\`, `\\`), `"`, `\"`) + `"`
|
||||||
|
}
|
||||||
81
backend/internal/service/openai_compat_prompt_cache_key.go
Normal file
81
backend/internal/service/openai_compat_prompt_cache_key.go
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||||
|
)
|
||||||
|
|
||||||
|
const compatPromptCacheKeyPrefix = "compat_cc_"
|
||||||
|
|
||||||
|
func shouldAutoInjectPromptCacheKeyForCompat(model string) bool {
|
||||||
|
switch normalizeCodexModel(strings.TrimSpace(model)) {
|
||||||
|
case "gpt-5.4", "gpt-5.3-codex":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func deriveCompatPromptCacheKey(req *apicompat.ChatCompletionsRequest, mappedModel string) string {
|
||||||
|
if req == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizedModel := normalizeCodexModel(strings.TrimSpace(mappedModel))
|
||||||
|
if normalizedModel == "" {
|
||||||
|
normalizedModel = normalizeCodexModel(strings.TrimSpace(req.Model))
|
||||||
|
}
|
||||||
|
if normalizedModel == "" {
|
||||||
|
normalizedModel = strings.TrimSpace(req.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
seedParts := []string{"model=" + normalizedModel}
|
||||||
|
if req.ReasoningEffort != "" {
|
||||||
|
seedParts = append(seedParts, "reasoning_effort="+strings.TrimSpace(req.ReasoningEffort))
|
||||||
|
}
|
||||||
|
if len(req.ToolChoice) > 0 {
|
||||||
|
seedParts = append(seedParts, "tool_choice="+normalizeCompatSeedJSON(req.ToolChoice))
|
||||||
|
}
|
||||||
|
if len(req.Tools) > 0 {
|
||||||
|
if raw, err := json.Marshal(req.Tools); err == nil {
|
||||||
|
seedParts = append(seedParts, "tools="+normalizeCompatSeedJSON(raw))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(req.Functions) > 0 {
|
||||||
|
if raw, err := json.Marshal(req.Functions); err == nil {
|
||||||
|
seedParts = append(seedParts, "functions="+normalizeCompatSeedJSON(raw))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
firstUserCaptured := false
|
||||||
|
for _, msg := range req.Messages {
|
||||||
|
switch strings.TrimSpace(msg.Role) {
|
||||||
|
case "system":
|
||||||
|
seedParts = append(seedParts, "system="+normalizeCompatSeedJSON(msg.Content))
|
||||||
|
case "user":
|
||||||
|
if !firstUserCaptured {
|
||||||
|
seedParts = append(seedParts, "first_user="+normalizeCompatSeedJSON(msg.Content))
|
||||||
|
firstUserCaptured = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return compatPromptCacheKeyPrefix + hashSensitiveValueForLog(strings.Join(seedParts, "|"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeCompatSeedJSON(v json.RawMessage) string {
|
||||||
|
if len(v) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var tmp any
|
||||||
|
if err := json.Unmarshal(v, &tmp); err != nil {
|
||||||
|
return string(v)
|
||||||
|
}
|
||||||
|
out, err := json.Marshal(tmp)
|
||||||
|
if err != nil {
|
||||||
|
return string(v)
|
||||||
|
}
|
||||||
|
return string(out)
|
||||||
|
}
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustRawJSON(t *testing.T, s string) json.RawMessage {
|
||||||
|
t.Helper()
|
||||||
|
return json.RawMessage(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldAutoInjectPromptCacheKeyForCompat(t *testing.T) {
|
||||||
|
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.4"))
|
||||||
|
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3"))
|
||||||
|
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex"))
|
||||||
|
require.False(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-4o"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeriveCompatPromptCacheKey_StableAcrossLaterTurns(t *testing.T) {
|
||||||
|
base := &apicompat.ChatCompletionsRequest{
|
||||||
|
Model: "gpt-5.4",
|
||||||
|
Messages: []apicompat.ChatMessage{
|
||||||
|
{Role: "system", Content: mustRawJSON(t, `"You are helpful."`)},
|
||||||
|
{Role: "user", Content: mustRawJSON(t, `"Hello"`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
extended := &apicompat.ChatCompletionsRequest{
|
||||||
|
Model: "gpt-5.4",
|
||||||
|
Messages: []apicompat.ChatMessage{
|
||||||
|
{Role: "system", Content: mustRawJSON(t, `"You are helpful."`)},
|
||||||
|
{Role: "user", Content: mustRawJSON(t, `"Hello"`)},
|
||||||
|
{Role: "assistant", Content: mustRawJSON(t, `"Hi there!"`)},
|
||||||
|
{Role: "user", Content: mustRawJSON(t, `"How are you?"`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
k1 := deriveCompatPromptCacheKey(base, "gpt-5.4")
|
||||||
|
k2 := deriveCompatPromptCacheKey(extended, "gpt-5.4")
|
||||||
|
require.Equal(t, k1, k2, "cache key should be stable across later turns")
|
||||||
|
require.NotEmpty(t, k1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeriveCompatPromptCacheKey_DiffersAcrossSessions(t *testing.T) {
|
||||||
|
req1 := &apicompat.ChatCompletionsRequest{
|
||||||
|
Model: "gpt-5.4",
|
||||||
|
Messages: []apicompat.ChatMessage{
|
||||||
|
{Role: "user", Content: mustRawJSON(t, `"Question A"`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
req2 := &apicompat.ChatCompletionsRequest{
|
||||||
|
Model: "gpt-5.4",
|
||||||
|
Messages: []apicompat.ChatMessage{
|
||||||
|
{Role: "user", Content: mustRawJSON(t, `"Question B"`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
k1 := deriveCompatPromptCacheKey(req1, "gpt-5.4")
|
||||||
|
k2 := deriveCompatPromptCacheKey(req2, "gpt-5.4")
|
||||||
|
require.NotEqual(t, k1, k2, "different first user messages should yield different keys")
|
||||||
|
}
|
||||||
@@ -43,23 +43,38 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
|||||||
clientStream := chatReq.Stream
|
clientStream := chatReq.Stream
|
||||||
includeUsage := chatReq.StreamOptions != nil && chatReq.StreamOptions.IncludeUsage
|
includeUsage := chatReq.StreamOptions != nil && chatReq.StreamOptions.IncludeUsage
|
||||||
|
|
||||||
// 2. Convert to Responses and forward
|
// 2. Resolve model mapping early so compat prompt_cache_key injection can
|
||||||
|
// derive a stable seed from the final upstream model family.
|
||||||
|
mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
|
||||||
|
|
||||||
|
promptCacheKey = strings.TrimSpace(promptCacheKey)
|
||||||
|
compatPromptCacheInjected := false
|
||||||
|
if promptCacheKey == "" && account.Type == AccountTypeOAuth && shouldAutoInjectPromptCacheKeyForCompat(mappedModel) {
|
||||||
|
promptCacheKey = deriveCompatPromptCacheKey(&chatReq, mappedModel)
|
||||||
|
compatPromptCacheInjected = promptCacheKey != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Convert to Responses and forward
|
||||||
// ChatCompletionsToResponses always sets Stream=true (upstream always streams).
|
// ChatCompletionsToResponses always sets Stream=true (upstream always streams).
|
||||||
responsesReq, err := apicompat.ChatCompletionsToResponses(&chatReq)
|
responsesReq, err := apicompat.ChatCompletionsToResponses(&chatReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
|
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. Model mapping
|
|
||||||
mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
|
|
||||||
responsesReq.Model = mappedModel
|
responsesReq.Model = mappedModel
|
||||||
|
|
||||||
logger.L().Debug("openai chat_completions: model mapping applied",
|
logFields := []zap.Field{
|
||||||
zap.Int64("account_id", account.ID),
|
zap.Int64("account_id", account.ID),
|
||||||
zap.String("original_model", originalModel),
|
zap.String("original_model", originalModel),
|
||||||
zap.String("mapped_model", mappedModel),
|
zap.String("mapped_model", mappedModel),
|
||||||
zap.Bool("stream", clientStream),
|
zap.Bool("stream", clientStream),
|
||||||
)
|
}
|
||||||
|
if compatPromptCacheInjected {
|
||||||
|
logFields = append(logFields,
|
||||||
|
zap.Bool("compat_prompt_cache_key_injected", true),
|
||||||
|
zap.String("compat_prompt_cache_key_sha256", hashSensitiveValueForLog(promptCacheKey)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
logger.L().Debug("openai chat_completions: model mapping applied", logFields...)
|
||||||
|
|
||||||
// 4. Marshal Responses request body, then apply OAuth codex transform
|
// 4. Marshal Responses request body, then apply OAuth codex transform
|
||||||
responsesBody, err := json.Marshal(responsesReq)
|
responsesBody, err := json.Marshal(responsesReq)
|
||||||
|
|||||||
@@ -53,6 +53,13 @@ func SetOpsLatencyMs(c *gin.Context, key string, value int64) {
|
|||||||
c.Set(key, value)
|
c.Set(key, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetOpsUpstreamError is the exported wrapper for setOpsUpstreamError, used by
|
||||||
|
// handler-layer code (e.g. failover-exhausted paths) that needs to record the
|
||||||
|
// original upstream status code before mapping it to a client-facing code.
|
||||||
|
func SetOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) {
|
||||||
|
setOpsUpstreamError(c, upstreamStatusCode, upstreamMessage, upstreamDetail)
|
||||||
|
}
|
||||||
|
|
||||||
func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) {
|
func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1110,10 +1110,13 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
|
|||||||
slog.Info("account_session_window_initialized", "account_id", account.ID, "window_start", start, "window_end", end, "status", status)
|
slog.Info("account_session_window_initialized", "account_id", account.ID, "window_start", start, "window_end", end, "status", status)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 窗口重置时清除旧的 utilization,避免残留上个窗口的数据
|
// 窗口重置时清除旧的 utilization 和被动采样数据,避免残留上个窗口的数据
|
||||||
if windowEnd != nil && needInitWindow {
|
if windowEnd != nil && needInitWindow {
|
||||||
_ = s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{
|
_ = s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{
|
||||||
"session_window_utilization": nil,
|
"session_window_utilization": nil,
|
||||||
|
"passive_usage_7d_utilization": nil,
|
||||||
|
"passive_usage_7d_reset": nil,
|
||||||
|
"passive_usage_sampled_at": nil,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1121,14 +1124,33 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
|
|||||||
slog.Warn("session_window_update_failed", "account_id", account.ID, "error", err)
|
slog.Warn("session_window_update_failed", "account_id", account.ID, "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 存储真实的 utilization 值(0-1 小数),供 estimateSetupTokenUsage 使用
|
// 被动采样:从响应头收集 5h + 7d utilization,合并为一次 DB 写入
|
||||||
|
extraUpdates := make(map[string]any, 4)
|
||||||
|
// 5h utilization(0-1 小数),供 estimateSetupTokenUsage 使用
|
||||||
if utilStr := headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilStr != "" {
|
if utilStr := headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilStr != "" {
|
||||||
if util, err := strconv.ParseFloat(utilStr, 64); err == nil {
|
if util, err := strconv.ParseFloat(utilStr, 64); err == nil {
|
||||||
if err := s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{
|
extraUpdates["session_window_utilization"] = util
|
||||||
"session_window_utilization": util,
|
}
|
||||||
}); err != nil {
|
}
|
||||||
slog.Warn("session_window_utilization_update_failed", "account_id", account.ID, "error", err)
|
// 7d utilization(0-1 小数)
|
||||||
|
if utilStr := headers.Get("anthropic-ratelimit-unified-7d-utilization"); utilStr != "" {
|
||||||
|
if util, err := strconv.ParseFloat(utilStr, 64); err == nil {
|
||||||
|
extraUpdates["passive_usage_7d_utilization"] = util
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 7d reset timestamp
|
||||||
|
if resetStr := headers.Get("anthropic-ratelimit-unified-7d-reset"); resetStr != "" {
|
||||||
|
if ts, err := strconv.ParseInt(resetStr, 10, 64); err == nil {
|
||||||
|
if ts > 1e11 {
|
||||||
|
ts = ts / 1000
|
||||||
}
|
}
|
||||||
|
extraUpdates["passive_usage_7d_reset"] = ts
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(extraUpdates) > 0 {
|
||||||
|
extraUpdates["passive_usage_sampled_at"] = time.Now().UTC().Format(time.RFC3339)
|
||||||
|
if err := s.accountRepo.UpdateExtra(ctx, account.ID, extraUpdates); err != nil {
|
||||||
|
slog.Warn("passive_usage_update_failed", "account_id", account.ID, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@lobehub/icons": "^4.0.2",
|
"@lobehub/icons": "^4.0.2",
|
||||||
|
"@tanstack/vue-virtual": "^3.13.23",
|
||||||
"@vueuse/core": "^10.7.0",
|
"@vueuse/core": "^10.7.0",
|
||||||
"axios": "^1.13.5",
|
"axios": "^1.13.5",
|
||||||
"chart.js": "^4.4.1",
|
"chart.js": "^4.4.1",
|
||||||
|
|||||||
18
frontend/pnpm-lock.yaml
generated
18
frontend/pnpm-lock.yaml
generated
@@ -11,6 +11,9 @@ importers:
|
|||||||
'@lobehub/icons':
|
'@lobehub/icons':
|
||||||
specifier: ^4.0.2
|
specifier: ^4.0.2
|
||||||
version: 4.0.2(@lobehub/ui@4.9.2)(@types/react@19.2.7)(antd@6.1.3(react-dom@19.2.3(react@19.2.3))(react@19.2.3))(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
|
version: 4.0.2(@lobehub/ui@4.9.2)(@types/react@19.2.7)(antd@6.1.3(react-dom@19.2.3(react@19.2.3))(react@19.2.3))(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
|
||||||
|
'@tanstack/vue-virtual':
|
||||||
|
specifier: ^3.13.23
|
||||||
|
version: 3.13.23(vue@3.5.26(typescript@5.6.3))
|
||||||
'@vueuse/core':
|
'@vueuse/core':
|
||||||
specifier: ^10.7.0
|
specifier: ^10.7.0
|
||||||
version: 10.11.1(vue@3.5.26(typescript@5.6.3))
|
version: 10.11.1(vue@3.5.26(typescript@5.6.3))
|
||||||
@@ -1376,6 +1379,14 @@ packages:
|
|||||||
peerDependencies:
|
peerDependencies:
|
||||||
react: '>= 16.3.0'
|
react: '>= 16.3.0'
|
||||||
|
|
||||||
|
'@tanstack/virtual-core@3.13.23':
|
||||||
|
resolution: {integrity: sha512-zSz2Z2HNyLjCplANTDyl3BcdQJc2k1+yyFoKhNRmCr7V7dY8o8q5m8uFTI1/Pg1kL+Hgrz6u3Xo6eFUB7l66cg==}
|
||||||
|
|
||||||
|
'@tanstack/vue-virtual@3.13.23':
|
||||||
|
resolution: {integrity: sha512-b5jPluAR6U3eOq6GWAYSpj3ugnAIZgGR0e6aGAgyRse0Yu6MVQQ0ZWm9SArSXWtageogn6bkVD8D//c4IjW3xQ==}
|
||||||
|
peerDependencies:
|
||||||
|
vue: ^2.7.0 || ^3.0.0
|
||||||
|
|
||||||
'@types/d3-array@3.2.2':
|
'@types/d3-array@3.2.2':
|
||||||
resolution: {integrity: sha512-hOLWVbm7uRza0BYXpIIW5pxfrKe0W+D5lrFiAEYR+pb6w3N2SwSMaJbXdUfSEv+dT4MfHBLtn5js0LAWaO6otw==}
|
resolution: {integrity: sha512-hOLWVbm7uRza0BYXpIIW5pxfrKe0W+D5lrFiAEYR+pb6w3N2SwSMaJbXdUfSEv+dT4MfHBLtn5js0LAWaO6otw==}
|
||||||
|
|
||||||
@@ -5808,6 +5819,13 @@ snapshots:
|
|||||||
dependencies:
|
dependencies:
|
||||||
react: 19.2.3
|
react: 19.2.3
|
||||||
|
|
||||||
|
'@tanstack/virtual-core@3.13.23': {}
|
||||||
|
|
||||||
|
'@tanstack/vue-virtual@3.13.23(vue@3.5.26(typescript@5.6.3))':
|
||||||
|
dependencies:
|
||||||
|
'@tanstack/virtual-core': 3.13.23
|
||||||
|
vue: 3.5.26(typescript@5.6.3)
|
||||||
|
|
||||||
'@types/d3-array@3.2.2': {}
|
'@types/d3-array@3.2.2': {}
|
||||||
|
|
||||||
'@types/d3-axis@3.0.6':
|
'@types/d3-axis@3.0.6':
|
||||||
|
|||||||
@@ -66,6 +66,7 @@ export async function listWithEtag(
|
|||||||
platform?: string
|
platform?: string
|
||||||
type?: string
|
type?: string
|
||||||
status?: string
|
status?: string
|
||||||
|
group?: string
|
||||||
search?: string
|
search?: string
|
||||||
lite?: string
|
lite?: string
|
||||||
},
|
},
|
||||||
@@ -223,8 +224,10 @@ export async function clearError(id: number): Promise<Account> {
|
|||||||
* @param id - Account ID
|
* @param id - Account ID
|
||||||
* @returns Account usage info
|
* @returns Account usage info
|
||||||
*/
|
*/
|
||||||
export async function getUsage(id: number): Promise<AccountUsageInfo> {
|
export async function getUsage(id: number, source?: 'passive' | 'active'): Promise<AccountUsageInfo> {
|
||||||
const { data } = await apiClient.get<AccountUsageInfo>(`/admin/accounts/${id}/usage`)
|
const { data } = await apiClient.get<AccountUsageInfo>(`/admin/accounts/${id}/usage`, {
|
||||||
|
params: source ? { source } : undefined
|
||||||
|
})
|
||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -67,6 +67,38 @@
|
|||||||
:resets-at="usageInfo.seven_day_sonnet.resets_at"
|
:resets-at="usageInfo.seven_day_sonnet.resets_at"
|
||||||
color="purple"
|
color="purple"
|
||||||
/>
|
/>
|
||||||
|
|
||||||
|
<!-- Passive sampling label + active query button -->
|
||||||
|
<div class="flex items-center gap-1.5 mt-0.5">
|
||||||
|
<span
|
||||||
|
v-if="usageInfo.source === 'passive'"
|
||||||
|
class="text-[9px] text-gray-400 dark:text-gray-500 italic"
|
||||||
|
>
|
||||||
|
{{ t('admin.accounts.usageWindow.passiveSampled') }}
|
||||||
|
</span>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="inline-flex items-center gap-0.5 rounded px-1.5 py-0.5 text-[9px] font-medium text-blue-600 hover:bg-blue-50 dark:text-blue-400 dark:hover:bg-blue-900/30 transition-colors"
|
||||||
|
:disabled="activeQueryLoading"
|
||||||
|
@click="loadActiveUsage"
|
||||||
|
>
|
||||||
|
<svg
|
||||||
|
class="h-2.5 w-2.5"
|
||||||
|
:class="{ 'animate-spin': activeQueryLoading }"
|
||||||
|
fill="none"
|
||||||
|
stroke="currentColor"
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
>
|
||||||
|
<path
|
||||||
|
stroke-linecap="round"
|
||||||
|
stroke-linejoin="round"
|
||||||
|
stroke-width="2"
|
||||||
|
d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15"
|
||||||
|
/>
|
||||||
|
</svg>
|
||||||
|
{{ t('admin.accounts.usageWindow.activeQuery') }}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- No data yet -->
|
<!-- No data yet -->
|
||||||
@@ -433,6 +465,7 @@ const props = withDefaults(
|
|||||||
const { t } = useI18n()
|
const { t } = useI18n()
|
||||||
|
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
|
const activeQueryLoading = ref(false)
|
||||||
const error = ref<string | null>(null)
|
const error = ref<string | null>(null)
|
||||||
const usageInfo = ref<AccountUsageInfo | null>(null)
|
const usageInfo = ref<AccountUsageInfo | null>(null)
|
||||||
|
|
||||||
@@ -888,14 +921,18 @@ const copyValidationURL = async () => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const loadUsage = async () => {
|
const isAnthropicOAuthOrSetupToken = computed(() => {
|
||||||
|
return props.account.platform === 'anthropic' && (props.account.type === 'oauth' || props.account.type === 'setup-token')
|
||||||
|
})
|
||||||
|
|
||||||
|
const loadUsage = async (source?: 'passive' | 'active') => {
|
||||||
if (!shouldFetchUsage.value) return
|
if (!shouldFetchUsage.value) return
|
||||||
|
|
||||||
loading.value = true
|
loading.value = true
|
||||||
error.value = null
|
error.value = null
|
||||||
|
|
||||||
try {
|
try {
|
||||||
usageInfo.value = await adminAPI.accounts.getUsage(props.account.id)
|
usageInfo.value = await adminAPI.accounts.getUsage(props.account.id, source)
|
||||||
} catch (e: any) {
|
} catch (e: any) {
|
||||||
error.value = t('common.error')
|
error.value = t('common.error')
|
||||||
console.error('Failed to load usage:', e)
|
console.error('Failed to load usage:', e)
|
||||||
@@ -904,6 +941,17 @@ const loadUsage = async () => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const loadActiveUsage = async () => {
|
||||||
|
activeQueryLoading.value = true
|
||||||
|
try {
|
||||||
|
usageInfo.value = await adminAPI.accounts.getUsage(props.account.id, 'active')
|
||||||
|
} catch (e: any) {
|
||||||
|
console.error('Failed to load active usage:', e)
|
||||||
|
} finally {
|
||||||
|
activeQueryLoading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ===== API Key quota progress bars =====
|
// ===== API Key quota progress bars =====
|
||||||
|
|
||||||
interface QuotaBarInfo {
|
interface QuotaBarInfo {
|
||||||
@@ -993,7 +1041,8 @@ const formatKeyUserCost = computed(() => {
|
|||||||
|
|
||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
if (!shouldAutoLoadUsageOnMount.value) return
|
if (!shouldAutoLoadUsageOnMount.value) return
|
||||||
loadUsage()
|
const source = isAnthropicOAuthOrSetupToken.value ? 'passive' : undefined
|
||||||
|
loadUsage(source)
|
||||||
})
|
})
|
||||||
|
|
||||||
watch(openAIUsageRefreshKey, (nextKey, prevKey) => {
|
watch(openAIUsageRefreshKey, (nextKey, prevKey) => {
|
||||||
@@ -1011,7 +1060,8 @@ watch(
|
|||||||
if (nextToken === prevToken) return
|
if (nextToken === prevToken) return
|
||||||
if (!shouldFetchUsage.value) return
|
if (!shouldFetchUsage.value) return
|
||||||
|
|
||||||
loadUsage().catch((e) => {
|
const source = isAnthropicOAuthOrSetupToken.value ? 'passive' : undefined
|
||||||
|
loadUsage(source).catch((e) => {
|
||||||
console.error('Failed to refresh usage after manual refresh:', e)
|
console.error('Failed to refresh usage after manual refresh:', e)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,5 +26,9 @@ const updateGroup = (value: string | number | boolean | null) => { emit('update:
|
|||||||
const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }, { value: 'sora', label: 'Sora' }])
|
const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }, { value: 'sora', label: 'Sora' }])
|
||||||
const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }, { value: 'bedrock', label: 'AWS Bedrock' }])
|
const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }, { value: 'bedrock', label: 'AWS Bedrock' }])
|
||||||
const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }, { value: 'temp_unschedulable', label: t('admin.accounts.status.tempUnschedulable') }])
|
const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }, { value: 'temp_unschedulable', label: t('admin.accounts.status.tempUnschedulable') }])
|
||||||
const gOpts = computed(() => [{ value: '', label: t('admin.accounts.allGroups') }, ...(props.groups || []).map(g => ({ value: String(g.id), label: g.name }))])
|
const gOpts = computed(() => [
|
||||||
|
{ value: '', label: t('admin.accounts.allGroups') },
|
||||||
|
{ value: 'ungrouped', label: t('admin.accounts.ungroupedGroup') },
|
||||||
|
...(props.groups || []).map(g => ({ value: String(g.id), label: g.name }))
|
||||||
|
])
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -69,6 +69,7 @@ import { adminAPI } from '@/api/admin'
|
|||||||
import { formatDateTime } from '@/utils/format'
|
import { formatDateTime } from '@/utils/format'
|
||||||
import type { AnnouncementUserReadStatus } from '@/types'
|
import type { AnnouncementUserReadStatus } from '@/types'
|
||||||
import type { Column } from '@/components/common/types'
|
import type { Column } from '@/components/common/types'
|
||||||
|
import { getPersistedPageSize } from '@/composables/usePersistedPageSize'
|
||||||
|
|
||||||
import BaseDialog from '@/components/common/BaseDialog.vue'
|
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||||
import DataTable from '@/components/common/DataTable.vue'
|
import DataTable from '@/components/common/DataTable.vue'
|
||||||
@@ -92,7 +93,7 @@ const search = ref('')
|
|||||||
|
|
||||||
const pagination = reactive({
|
const pagination = reactive({
|
||||||
page: 1,
|
page: 1,
|
||||||
page_size: 20,
|
page_size: getPersistedPageSize(),
|
||||||
total: 0,
|
total: 0,
|
||||||
pages: 0
|
pages: 0
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -147,28 +147,46 @@
|
|||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
|
|
||||||
<!-- Data rows -->
|
<!-- Data rows (virtual scroll) -->
|
||||||
<tr
|
<template v-else>
|
||||||
v-else
|
<tr v-if="virtualPaddingTop > 0" aria-hidden="true">
|
||||||
v-for="(row, index) in sortedData"
|
<td :colspan="columns.length"
|
||||||
:key="resolveRowKey(row, index)"
|
:style="{ height: virtualPaddingTop + 'px', padding: 0, border: 'none' }">
|
||||||
:data-row-id="resolveRowKey(row, index)"
|
</td>
|
||||||
class="hover:bg-gray-50 dark:hover:bg-dark-800"
|
</tr>
|
||||||
>
|
<tr
|
||||||
<td
|
v-for="virtualRow in virtualItems"
|
||||||
v-for="(column, colIndex) in columns"
|
:key="resolveRowKey(sortedData[virtualRow.index], virtualRow.index)"
|
||||||
:key="column.key"
|
:data-row-id="resolveRowKey(sortedData[virtualRow.index], virtualRow.index)"
|
||||||
:class="[
|
:data-index="virtualRow.index"
|
||||||
'whitespace-nowrap py-4 text-sm text-gray-900 dark:text-gray-100',
|
:ref="measureElement"
|
||||||
getAdaptivePaddingClass(),
|
class="hover:bg-gray-50 dark:hover:bg-dark-800"
|
||||||
getStickyColumnClass(column, colIndex)
|
|
||||||
]"
|
|
||||||
>
|
>
|
||||||
<slot :name="`cell-${column.key}`" :row="row" :value="row[column.key]" :expanded="actionsExpanded">
|
<td
|
||||||
{{ column.formatter ? column.formatter(row[column.key], row) : row[column.key] }}
|
v-for="(column, colIndex) in columns"
|
||||||
</slot>
|
:key="column.key"
|
||||||
</td>
|
:class="[
|
||||||
</tr>
|
'whitespace-nowrap py-4 text-sm text-gray-900 dark:text-gray-100',
|
||||||
|
getAdaptivePaddingClass(),
|
||||||
|
getStickyColumnClass(column, colIndex)
|
||||||
|
]"
|
||||||
|
>
|
||||||
|
<slot :name="`cell-${column.key}`"
|
||||||
|
:row="sortedData[virtualRow.index]"
|
||||||
|
:value="sortedData[virtualRow.index][column.key]"
|
||||||
|
:expanded="actionsExpanded">
|
||||||
|
{{ column.formatter
|
||||||
|
? column.formatter(sortedData[virtualRow.index][column.key], sortedData[virtualRow.index])
|
||||||
|
: sortedData[virtualRow.index][column.key] }}
|
||||||
|
</slot>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
<tr v-if="virtualPaddingBottom > 0" aria-hidden="true">
|
||||||
|
<td :colspan="columns.length"
|
||||||
|
:style="{ height: virtualPaddingBottom + 'px', padding: 0, border: 'none' }">
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
</template>
|
||||||
</tbody>
|
</tbody>
|
||||||
</table>
|
</table>
|
||||||
</div>
|
</div>
|
||||||
@@ -176,6 +194,7 @@
|
|||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { computed, ref, onMounted, onUnmounted, watch, nextTick } from 'vue'
|
import { computed, ref, onMounted, onUnmounted, watch, nextTick } from 'vue'
|
||||||
|
import { useVirtualizer } from '@tanstack/vue-virtual'
|
||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import type { Column } from './types'
|
import type { Column } from './types'
|
||||||
import Icon from '@/components/icons/Icon.vue'
|
import Icon from '@/components/icons/Icon.vue'
|
||||||
@@ -299,6 +318,10 @@ interface Props {
|
|||||||
* will emit 'sort' events instead of performing client-side sorting.
|
* will emit 'sort' events instead of performing client-side sorting.
|
||||||
*/
|
*/
|
||||||
serverSideSort?: boolean
|
serverSideSort?: boolean
|
||||||
|
/** Estimated row height in px for the virtualizer (default 56) */
|
||||||
|
estimateRowHeight?: number
|
||||||
|
/** Number of rows to render beyond the visible area (default 5) */
|
||||||
|
overscan?: number
|
||||||
}
|
}
|
||||||
|
|
||||||
const props = withDefaults(defineProps<Props>(), {
|
const props = withDefaults(defineProps<Props>(), {
|
||||||
@@ -499,6 +522,33 @@ const sortedData = computed(() => {
|
|||||||
.map(item => item.row)
|
.map(item => item.row)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// --- Virtual scrolling ---
|
||||||
|
const rowVirtualizer = useVirtualizer(computed(() => ({
|
||||||
|
count: sortedData.value?.length ?? 0,
|
||||||
|
getScrollElement: () => tableWrapperRef.value,
|
||||||
|
estimateSize: () => props.estimateRowHeight ?? 56,
|
||||||
|
overscan: props.overscan ?? 5,
|
||||||
|
})))
|
||||||
|
|
||||||
|
const virtualItems = computed(() => rowVirtualizer.value.getVirtualItems())
|
||||||
|
|
||||||
|
const virtualPaddingTop = computed(() => {
|
||||||
|
const items = virtualItems.value
|
||||||
|
return items.length > 0 ? items[0].start : 0
|
||||||
|
})
|
||||||
|
|
||||||
|
const virtualPaddingBottom = computed(() => {
|
||||||
|
const items = virtualItems.value
|
||||||
|
if (items.length === 0) return 0
|
||||||
|
return rowVirtualizer.value.getTotalSize() - items[items.length - 1].end
|
||||||
|
})
|
||||||
|
|
||||||
|
const measureElement = (el: any) => {
|
||||||
|
if (el) {
|
||||||
|
rowVirtualizer.value.measureElement(el as Element)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const hasActionsColumn = computed(() => {
|
const hasActionsColumn = computed(() => {
|
||||||
return props.columns.some(column => column.key === 'actions')
|
return props.columns.some(column => column.key === 'actions')
|
||||||
})
|
})
|
||||||
@@ -595,6 +645,13 @@ watch(
|
|||||||
},
|
},
|
||||||
{ flush: 'post' }
|
{ flush: 'post' }
|
||||||
)
|
)
|
||||||
|
|
||||||
|
defineExpose({
|
||||||
|
virtualizer: rowVirtualizer,
|
||||||
|
sortedData,
|
||||||
|
resolveRowKey,
|
||||||
|
tableWrapperEl: tableWrapperRef,
|
||||||
|
})
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<style scoped>
|
<style scoped>
|
||||||
@@ -603,6 +660,9 @@ watch(
|
|||||||
--select-col-width: 52px; /* 勾选列宽度:px-6 (24px*2) + checkbox (16px) */
|
--select-col-width: 52px; /* 勾选列宽度:px-6 (24px*2) + checkbox (16px) */
|
||||||
position: relative;
|
position: relative;
|
||||||
overflow-x: auto;
|
overflow-x: auto;
|
||||||
|
overflow-y: auto;
|
||||||
|
flex: 1;
|
||||||
|
min-height: 0;
|
||||||
isolation: isolate;
|
isolation: isolate;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -122,6 +122,7 @@ import { computed, ref } from 'vue'
|
|||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import Icon from '@/components/icons/Icon.vue'
|
import Icon from '@/components/icons/Icon.vue'
|
||||||
import Select from './Select.vue'
|
import Select from './Select.vue'
|
||||||
|
import { setPersistedPageSize } from '@/composables/usePersistedPageSize'
|
||||||
|
|
||||||
const { t } = useI18n()
|
const { t } = useI18n()
|
||||||
|
|
||||||
@@ -216,6 +217,7 @@ const goToPage = (newPage: number) => {
|
|||||||
const handlePageSizeChange = (value: string | number | boolean | null) => {
|
const handlePageSizeChange = (value: string | number | boolean | null) => {
|
||||||
if (value === null || typeof value === 'boolean') return
|
if (value === null || typeof value === 'boolean') return
|
||||||
const newPageSize = typeof value === 'string' ? parseInt(value) : value
|
const newPageSize = typeof value === 'string' ? parseInt(value) : value
|
||||||
|
setPersistedPageSize(newPageSize)
|
||||||
emit('update:pageSize', newPageSize)
|
emit('update:pageSize', newPageSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -126,6 +126,7 @@
|
|||||||
import { ref, computed, onMounted } from 'vue'
|
import { ref, computed, onMounted } from 'vue'
|
||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import soraAPI, { type SoraGeneration } from '@/api/sora'
|
import soraAPI, { type SoraGeneration } from '@/api/sora'
|
||||||
|
import { getPersistedPageSize } from '@/composables/usePersistedPageSize'
|
||||||
import SoraMediaPreview from './SoraMediaPreview.vue'
|
import SoraMediaPreview from './SoraMediaPreview.vue'
|
||||||
|
|
||||||
const emit = defineEmits<{
|
const emit = defineEmits<{
|
||||||
@@ -190,7 +191,7 @@ async function loadItems(pageNum: number) {
|
|||||||
status: 'completed',
|
status: 'completed',
|
||||||
storage_type: 's3,local',
|
storage_type: 's3,local',
|
||||||
page: pageNum,
|
page: pageNum,
|
||||||
page_size: 20
|
page_size: getPersistedPageSize()
|
||||||
})
|
})
|
||||||
const rows = Array.isArray(res.data) ? res.data : []
|
const rows = Array.isArray(res.data) ? res.data : []
|
||||||
if (pageNum === 1) {
|
if (pageNum === 1) {
|
||||||
|
|||||||
27
frontend/src/composables/usePersistedPageSize.ts
Normal file
27
frontend/src/composables/usePersistedPageSize.ts
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
const STORAGE_KEY = 'table-page-size'
|
||||||
|
const DEFAULT_PAGE_SIZE = 20
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 从 localStorage 读取/写入 pageSize
|
||||||
|
* 全局共享一个 key,所有表格统一偏好
|
||||||
|
*/
|
||||||
|
export function getPersistedPageSize(fallback = DEFAULT_PAGE_SIZE): number {
|
||||||
|
try {
|
||||||
|
const stored = localStorage.getItem(STORAGE_KEY)
|
||||||
|
if (stored) {
|
||||||
|
const parsed = Number(stored)
|
||||||
|
if (Number.isFinite(parsed) && parsed > 0) return parsed
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// localStorage 不可用(隐私模式等)
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
export function setPersistedPageSize(size: number): void {
|
||||||
|
try {
|
||||||
|
localStorage.setItem(STORAGE_KEY, String(size))
|
||||||
|
} catch {
|
||||||
|
// 静默失败
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import { ref, onMounted, onUnmounted, type Ref } from 'vue'
|
import { ref, onMounted, onUnmounted, type Ref } from 'vue'
|
||||||
|
import type { Virtualizer } from '@tanstack/vue-virtual'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* WeChat-style swipe/drag to select rows in a DataTable,
|
* WeChat-style swipe/drag to select rows in a DataTable,
|
||||||
@@ -25,11 +26,22 @@ export interface SwipeSelectAdapter {
|
|||||||
isSelected: (id: number) => boolean
|
isSelected: (id: number) => boolean
|
||||||
select: (id: number) => void
|
select: (id: number) => void
|
||||||
deselect: (id: number) => void
|
deselect: (id: number) => void
|
||||||
|
batchUpdate?: (updater: (draft: Set<number>) => void) => void
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface SwipeSelectVirtualContext {
|
||||||
|
/** Get the virtualizer instance */
|
||||||
|
getVirtualizer: () => Virtualizer<HTMLElement, Element> | null
|
||||||
|
/** Get all sorted data */
|
||||||
|
getSortedData: () => any[]
|
||||||
|
/** Get row ID from data row */
|
||||||
|
getRowId: (row: any, index: number) => number
|
||||||
}
|
}
|
||||||
|
|
||||||
export function useSwipeSelect(
|
export function useSwipeSelect(
|
||||||
containerRef: Ref<HTMLElement | null>,
|
containerRef: Ref<HTMLElement | null>,
|
||||||
adapter: SwipeSelectAdapter
|
adapter: SwipeSelectAdapter,
|
||||||
|
virtualContext?: SwipeSelectVirtualContext
|
||||||
) {
|
) {
|
||||||
const isDragging = ref(false)
|
const isDragging = ref(false)
|
||||||
|
|
||||||
@@ -95,6 +107,32 @@ export function useSwipeSelect(
|
|||||||
return (clientY - rHi.bottom < rLo.top - clientY) ? hi : lo
|
return (clientY - rHi.bottom < rLo.top - clientY) ? hi : lo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Virtual mode: find row index from Y coordinate using virtualizer data */
|
||||||
|
function findRowIndexAtYVirtual(clientY: number): number {
|
||||||
|
const virt = virtualContext!.getVirtualizer()
|
||||||
|
if (!virt) return -1
|
||||||
|
const scrollEl = virt.scrollElement
|
||||||
|
if (!scrollEl) return -1
|
||||||
|
|
||||||
|
const scrollRect = scrollEl.getBoundingClientRect()
|
||||||
|
const thead = scrollEl.querySelector('thead')
|
||||||
|
const theadHeight = thead ? thead.getBoundingClientRect().height : 0
|
||||||
|
const contentY = clientY - scrollRect.top - theadHeight + scrollEl.scrollTop
|
||||||
|
|
||||||
|
// Search in rendered virtualItems first
|
||||||
|
const items = virt.getVirtualItems()
|
||||||
|
for (const item of items) {
|
||||||
|
if (contentY >= item.start && contentY < item.end) return item.index
|
||||||
|
}
|
||||||
|
|
||||||
|
// Outside visible range: estimate
|
||||||
|
const totalCount = virtualContext!.getSortedData().length
|
||||||
|
if (totalCount === 0) return -1
|
||||||
|
const est = virt.options.estimateSize(0)
|
||||||
|
const guess = Math.floor(contentY / est)
|
||||||
|
return Math.max(0, Math.min(totalCount - 1, guess))
|
||||||
|
}
|
||||||
|
|
||||||
// --- Prevent text selection via selectstart (no body style mutation) ---
|
// --- Prevent text selection via selectstart (no body style mutation) ---
|
||||||
function onSelectStart(e: Event) { e.preventDefault() }
|
function onSelectStart(e: Event) { e.preventDefault() }
|
||||||
|
|
||||||
@@ -140,16 +178,68 @@ export function useSwipeSelect(
|
|||||||
const lo = Math.min(rangeMin, prevMin)
|
const lo = Math.min(rangeMin, prevMin)
|
||||||
const hi = Math.max(rangeMax, prevMax)
|
const hi = Math.max(rangeMax, prevMax)
|
||||||
|
|
||||||
for (let i = lo; i <= hi && i < cachedRows.length; i++) {
|
if (adapter.batchUpdate) {
|
||||||
const id = getRowId(cachedRows[i])
|
adapter.batchUpdate((draft) => {
|
||||||
if (id === null) continue
|
for (let i = lo; i <= hi && i < cachedRows.length; i++) {
|
||||||
if (i >= rangeMin && i <= rangeMax) {
|
const id = getRowId(cachedRows[i])
|
||||||
if (dragMode === 'select') adapter.select(id)
|
if (id === null) continue
|
||||||
else adapter.deselect(id)
|
const shouldBeSelected = (i >= rangeMin && i <= rangeMax)
|
||||||
} else {
|
? (dragMode === 'select')
|
||||||
const wasSelected = initialSelectedSnapshot.get(id) ?? false
|
: (initialSelectedSnapshot.get(id) ?? false)
|
||||||
if (wasSelected) adapter.select(id)
|
if (shouldBeSelected) draft.add(id)
|
||||||
else adapter.deselect(id)
|
else draft.delete(id)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
for (let i = lo; i <= hi && i < cachedRows.length; i++) {
|
||||||
|
const id = getRowId(cachedRows[i])
|
||||||
|
if (id === null) continue
|
||||||
|
if (i >= rangeMin && i <= rangeMax) {
|
||||||
|
if (dragMode === 'select') adapter.select(id)
|
||||||
|
else adapter.deselect(id)
|
||||||
|
} else {
|
||||||
|
const wasSelected = initialSelectedSnapshot.get(id) ?? false
|
||||||
|
if (wasSelected) adapter.select(id)
|
||||||
|
else adapter.deselect(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
lastEndIndex = endIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Virtual mode: apply selection range using data array instead of DOM */
|
||||||
|
function applyRangeVirtual(endIndex: number) {
|
||||||
|
if (startRowIndex < 0 || endIndex < 0) return
|
||||||
|
const rangeMin = Math.min(startRowIndex, endIndex)
|
||||||
|
const rangeMax = Math.max(startRowIndex, endIndex)
|
||||||
|
const prevMin = lastEndIndex >= 0 ? Math.min(startRowIndex, lastEndIndex) : rangeMin
|
||||||
|
const prevMax = lastEndIndex >= 0 ? Math.max(startRowIndex, lastEndIndex) : rangeMax
|
||||||
|
const lo = Math.min(rangeMin, prevMin)
|
||||||
|
const hi = Math.max(rangeMax, prevMax)
|
||||||
|
const data = virtualContext!.getSortedData()
|
||||||
|
|
||||||
|
if (adapter.batchUpdate) {
|
||||||
|
adapter.batchUpdate((draft) => {
|
||||||
|
for (let i = lo; i <= hi && i < data.length; i++) {
|
||||||
|
const id = virtualContext!.getRowId(data[i], i)
|
||||||
|
const shouldBeSelected = (i >= rangeMin && i <= rangeMax)
|
||||||
|
? (dragMode === 'select')
|
||||||
|
: (initialSelectedSnapshot.get(id) ?? false)
|
||||||
|
if (shouldBeSelected) draft.add(id)
|
||||||
|
else draft.delete(id)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
for (let i = lo; i <= hi && i < data.length; i++) {
|
||||||
|
const id = virtualContext!.getRowId(data[i], i)
|
||||||
|
if (i >= rangeMin && i <= rangeMax) {
|
||||||
|
if (dragMode === 'select') adapter.select(id)
|
||||||
|
else adapter.deselect(id)
|
||||||
|
} else {
|
||||||
|
const wasSelected = initialSelectedSnapshot.get(id) ?? false
|
||||||
|
if (wasSelected) adapter.select(id)
|
||||||
|
else adapter.deselect(id)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
lastEndIndex = endIndex
|
lastEndIndex = endIndex
|
||||||
@@ -234,8 +324,14 @@ export function useSwipeSelect(
|
|||||||
if (shouldPreferNativeTextSelection(target)) return
|
if (shouldPreferNativeTextSelection(target)) return
|
||||||
if (shouldPreferNativeSelectionOutsideRows(target)) return
|
if (shouldPreferNativeSelectionOutsideRows(target)) return
|
||||||
|
|
||||||
cachedRows = getDataRows()
|
if (virtualContext) {
|
||||||
if (cachedRows.length === 0) return
|
// Virtual mode: check data availability instead of DOM rows
|
||||||
|
const data = virtualContext.getSortedData()
|
||||||
|
if (data.length === 0) return
|
||||||
|
} else {
|
||||||
|
cachedRows = getDataRows()
|
||||||
|
if (cachedRows.length === 0) return
|
||||||
|
}
|
||||||
|
|
||||||
pendingStartY = e.clientY
|
pendingStartY = e.clientY
|
||||||
// Prevent text selection as soon as the mouse is down,
|
// Prevent text selection as soon as the mouse is down,
|
||||||
@@ -253,13 +349,19 @@ export function useSwipeSelect(
|
|||||||
document.removeEventListener('mousemove', onThresholdMove)
|
document.removeEventListener('mousemove', onThresholdMove)
|
||||||
document.removeEventListener('mouseup', onThresholdUp)
|
document.removeEventListener('mouseup', onThresholdUp)
|
||||||
|
|
||||||
beginDrag(pendingStartY)
|
if (virtualContext) {
|
||||||
|
beginDragVirtual(pendingStartY)
|
||||||
|
} else {
|
||||||
|
beginDrag(pendingStartY)
|
||||||
|
}
|
||||||
|
|
||||||
// Process the move that crossed the threshold
|
// Process the move that crossed the threshold
|
||||||
lastMouseY = e.clientY
|
lastMouseY = e.clientY
|
||||||
updateMarquee(e.clientY)
|
updateMarquee(e.clientY)
|
||||||
const rowIdx = findRowIndexAtY(e.clientY)
|
const findIdx = virtualContext ? findRowIndexAtYVirtual : findRowIndexAtY
|
||||||
if (rowIdx >= 0) applyRange(rowIdx)
|
const apply = virtualContext ? applyRangeVirtual : applyRange
|
||||||
|
const rowIdx = findIdx(e.clientY)
|
||||||
|
if (rowIdx >= 0) apply(rowIdx)
|
||||||
autoScroll(e)
|
autoScroll(e)
|
||||||
|
|
||||||
document.addEventListener('mousemove', onMouseMove)
|
document.addEventListener('mousemove', onMouseMove)
|
||||||
@@ -306,22 +408,62 @@ export function useSwipeSelect(
|
|||||||
window.getSelection()?.removeAllRanges()
|
window.getSelection()?.removeAllRanges()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Virtual mode: begin drag using data array */
|
||||||
|
function beginDragVirtual(clientY: number) {
|
||||||
|
startRowIndex = findRowIndexAtYVirtual(clientY)
|
||||||
|
const data = virtualContext!.getSortedData()
|
||||||
|
const startRowId = startRowIndex >= 0 && startRowIndex < data.length
|
||||||
|
? virtualContext!.getRowId(data[startRowIndex], startRowIndex)
|
||||||
|
: null
|
||||||
|
dragMode = (startRowId !== null && adapter.isSelected(startRowId)) ? 'deselect' : 'select'
|
||||||
|
|
||||||
|
// Build full snapshot from all data rows
|
||||||
|
initialSelectedSnapshot = new Map()
|
||||||
|
for (let i = 0; i < data.length; i++) {
|
||||||
|
const id = virtualContext!.getRowId(data[i], i)
|
||||||
|
initialSelectedSnapshot.set(id, adapter.isSelected(id))
|
||||||
|
}
|
||||||
|
|
||||||
|
isDragging.value = true
|
||||||
|
startY = clientY
|
||||||
|
lastMouseY = clientY
|
||||||
|
lastEndIndex = -1
|
||||||
|
|
||||||
|
// In virtual mode, scroll parent is the virtualizer's scroll element
|
||||||
|
const virt = virtualContext!.getVirtualizer()
|
||||||
|
cachedScrollParent = virt?.scrollElement ?? (containerRef.value ? getScrollParent(containerRef.value) : null)
|
||||||
|
|
||||||
|
createMarquee()
|
||||||
|
updateMarquee(clientY)
|
||||||
|
applyRangeVirtual(startRowIndex)
|
||||||
|
window.getSelection()?.removeAllRanges()
|
||||||
|
}
|
||||||
|
|
||||||
|
let moveRAF = 0
|
||||||
|
|
||||||
function onMouseMove(e: MouseEvent) {
|
function onMouseMove(e: MouseEvent) {
|
||||||
if (!isDragging.value) return
|
if (!isDragging.value) return
|
||||||
lastMouseY = e.clientY
|
lastMouseY = e.clientY
|
||||||
updateMarquee(e.clientY)
|
const findIdx = virtualContext ? findRowIndexAtYVirtual : findRowIndexAtY
|
||||||
const rowIdx = findRowIndexAtY(e.clientY)
|
const apply = virtualContext ? applyRangeVirtual : applyRange
|
||||||
if (rowIdx >= 0 && rowIdx !== lastEndIndex) applyRange(rowIdx)
|
cancelAnimationFrame(moveRAF)
|
||||||
|
moveRAF = requestAnimationFrame(() => {
|
||||||
|
updateMarquee(lastMouseY)
|
||||||
|
const rowIdx = findIdx(lastMouseY)
|
||||||
|
if (rowIdx >= 0 && rowIdx !== lastEndIndex) apply(rowIdx)
|
||||||
|
})
|
||||||
autoScroll(e)
|
autoScroll(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
function onWheel() {
|
function onWheel() {
|
||||||
if (!isDragging.value) return
|
if (!isDragging.value) return
|
||||||
|
const findIdx = virtualContext ? findRowIndexAtYVirtual : findRowIndexAtY
|
||||||
|
const apply = virtualContext ? applyRangeVirtual : applyRange
|
||||||
// After wheel scroll, rows shift in viewport — re-check selection
|
// After wheel scroll, rows shift in viewport — re-check selection
|
||||||
requestAnimationFrame(() => {
|
requestAnimationFrame(() => {
|
||||||
if (!isDragging.value) return // guard: drag may have ended before this frame
|
if (!isDragging.value) return // guard: drag may have ended before this frame
|
||||||
const rowIdx = findRowIndexAtY(lastMouseY)
|
const rowIdx = findIdx(lastMouseY)
|
||||||
if (rowIdx >= 0) applyRange(rowIdx)
|
if (rowIdx >= 0) apply(rowIdx)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -332,6 +474,7 @@ export function useSwipeSelect(
|
|||||||
cachedRows = []
|
cachedRows = []
|
||||||
initialSelectedSnapshot.clear()
|
initialSelectedSnapshot.clear()
|
||||||
cachedScrollParent = null
|
cachedScrollParent = null
|
||||||
|
cancelAnimationFrame(moveRAF)
|
||||||
stopAutoScroll()
|
stopAutoScroll()
|
||||||
removeMarquee()
|
removeMarquee()
|
||||||
document.removeEventListener('selectstart', onSelectStart)
|
document.removeEventListener('selectstart', onSelectStart)
|
||||||
@@ -372,13 +515,15 @@ export function useSwipeSelect(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (dy !== 0) {
|
if (dy !== 0) {
|
||||||
|
const findIdx = virtualContext ? findRowIndexAtYVirtual : findRowIndexAtY
|
||||||
|
const apply = virtualContext ? applyRangeVirtual : applyRange
|
||||||
const step = () => {
|
const step = () => {
|
||||||
const prevScrollTop = scrollEl.scrollTop
|
const prevScrollTop = scrollEl.scrollTop
|
||||||
scrollEl.scrollTop += dy
|
scrollEl.scrollTop += dy
|
||||||
// Only re-check selection if scroll actually moved
|
// Only re-check selection if scroll actually moved
|
||||||
if (scrollEl.scrollTop !== prevScrollTop) {
|
if (scrollEl.scrollTop !== prevScrollTop) {
|
||||||
const rowIdx = findRowIndexAtY(lastMouseY)
|
const rowIdx = findIdx(lastMouseY)
|
||||||
if (rowIdx >= 0 && rowIdx !== lastEndIndex) applyRange(rowIdx)
|
if (rowIdx >= 0 && rowIdx !== lastEndIndex) apply(rowIdx)
|
||||||
}
|
}
|
||||||
scrollRAF = requestAnimationFrame(step)
|
scrollRAF = requestAnimationFrame(step)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import { ref, reactive, onUnmounted, toRaw } from 'vue'
|
import { ref, reactive, onUnmounted, toRaw } from 'vue'
|
||||||
import { useDebounceFn } from '@vueuse/core'
|
import { useDebounceFn } from '@vueuse/core'
|
||||||
import type { BasePaginationResponse, FetchOptions } from '@/types'
|
import type { BasePaginationResponse, FetchOptions } from '@/types'
|
||||||
|
import { getPersistedPageSize, setPersistedPageSize } from './usePersistedPageSize'
|
||||||
|
|
||||||
interface PaginationState {
|
interface PaginationState {
|
||||||
page: number
|
page: number
|
||||||
@@ -21,14 +22,14 @@ interface TableLoaderOptions<T, P> {
|
|||||||
* 统一处理分页、筛选、搜索防抖和请求取消
|
* 统一处理分页、筛选、搜索防抖和请求取消
|
||||||
*/
|
*/
|
||||||
export function useTableLoader<T, P extends Record<string, any>>(options: TableLoaderOptions<T, P>) {
|
export function useTableLoader<T, P extends Record<string, any>>(options: TableLoaderOptions<T, P>) {
|
||||||
const { fetchFn, initialParams, pageSize = 20, debounceMs = 300 } = options
|
const { fetchFn, initialParams, pageSize, debounceMs = 300 } = options
|
||||||
|
|
||||||
const items = ref<T[]>([])
|
const items = ref<T[]>([])
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
const params = reactive<P>({ ...(initialParams || {}) } as P)
|
const params = reactive<P>({ ...(initialParams || {}) } as P)
|
||||||
const pagination = reactive<PaginationState>({
|
const pagination = reactive<PaginationState>({
|
||||||
page: 1,
|
page: 1,
|
||||||
page_size: pageSize,
|
page_size: pageSize ?? getPersistedPageSize(),
|
||||||
total: 0,
|
total: 0,
|
||||||
pages: 0
|
pages: 0
|
||||||
})
|
})
|
||||||
@@ -87,6 +88,7 @@ export function useTableLoader<T, P extends Record<string, any>>(options: TableL
|
|||||||
const handlePageSizeChange = (size: number) => {
|
const handlePageSizeChange = (size: number) => {
|
||||||
pagination.page_size = size
|
pagination.page_size = size
|
||||||
pagination.page = 1
|
pagination.page = 1
|
||||||
|
setPersistedPageSize(size)
|
||||||
load()
|
load()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1883,6 +1883,7 @@ export default {
|
|||||||
allTypes: 'All Types',
|
allTypes: 'All Types',
|
||||||
allStatus: 'All Status',
|
allStatus: 'All Status',
|
||||||
allGroups: 'All Groups',
|
allGroups: 'All Groups',
|
||||||
|
ungroupedGroup: 'Ungrouped',
|
||||||
oauthType: 'OAuth',
|
oauthType: 'OAuth',
|
||||||
setupToken: 'Setup Token',
|
setupToken: 'Setup Token',
|
||||||
apiKey: 'API Key',
|
apiKey: 'API Key',
|
||||||
@@ -2760,7 +2761,9 @@ export default {
|
|||||||
gemini3Pro: 'G3P',
|
gemini3Pro: 'G3P',
|
||||||
gemini3Flash: 'G3F',
|
gemini3Flash: 'G3F',
|
||||||
gemini3Image: 'G31FI',
|
gemini3Image: 'G31FI',
|
||||||
claude: 'Claude'
|
claude: 'Claude',
|
||||||
|
passiveSampled: 'Passive',
|
||||||
|
activeQuery: 'Query'
|
||||||
},
|
},
|
||||||
tier: {
|
tier: {
|
||||||
free: 'Free',
|
free: 'Free',
|
||||||
|
|||||||
@@ -1965,6 +1965,7 @@ export default {
|
|||||||
allTypes: '全部类型',
|
allTypes: '全部类型',
|
||||||
allStatus: '全部状态',
|
allStatus: '全部状态',
|
||||||
allGroups: '全部分组',
|
allGroups: '全部分组',
|
||||||
|
ungroupedGroup: '未分配分组',
|
||||||
oauthType: 'OAuth',
|
oauthType: 'OAuth',
|
||||||
// Schedulable toggle
|
// Schedulable toggle
|
||||||
schedulable: '参与调度',
|
schedulable: '参与调度',
|
||||||
@@ -2163,7 +2164,9 @@ export default {
|
|||||||
gemini3Pro: 'G3P',
|
gemini3Pro: 'G3P',
|
||||||
gemini3Flash: 'G3F',
|
gemini3Flash: 'G3F',
|
||||||
gemini3Image: 'G31FI',
|
gemini3Image: 'G31FI',
|
||||||
claude: 'Claude'
|
claude: 'Claude',
|
||||||
|
passiveSampled: '被动采样',
|
||||||
|
activeQuery: '查询'
|
||||||
},
|
},
|
||||||
tier: {
|
tier: {
|
||||||
free: 'Free',
|
free: 'Free',
|
||||||
|
|||||||
@@ -781,6 +781,7 @@ export interface AntigravityModelQuota {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export interface AccountUsageInfo {
|
export interface AccountUsageInfo {
|
||||||
|
source?: 'passive' | 'active'
|
||||||
updated_at: string | null
|
updated_at: string | null
|
||||||
five_hour: UsageProgress | null
|
five_hour: UsageProgress | null
|
||||||
seven_day: UsageProgress | null
|
seven_day: UsageProgress | null
|
||||||
|
|||||||
@@ -758,6 +758,7 @@ const refreshAccountsIncrementally = async () => {
|
|||||||
platform?: string
|
platform?: string
|
||||||
type?: string
|
type?: string
|
||||||
status?: string
|
status?: string
|
||||||
|
group?: string
|
||||||
search?: string
|
search?: string
|
||||||
|
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -239,6 +239,7 @@
|
|||||||
import { computed, onMounted, reactive, ref } from 'vue'
|
import { computed, onMounted, reactive, ref } from 'vue'
|
||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import { useAppStore } from '@/stores/app'
|
import { useAppStore } from '@/stores/app'
|
||||||
|
import { getPersistedPageSize } from '@/composables/usePersistedPageSize'
|
||||||
import { adminAPI } from '@/api/admin'
|
import { adminAPI } from '@/api/admin'
|
||||||
import { formatDateTime, formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
|
import { formatDateTime, formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
|
||||||
import type { AdminGroup, Announcement, AnnouncementTargeting } from '@/types'
|
import type { AdminGroup, Announcement, AnnouncementTargeting } from '@/types'
|
||||||
@@ -270,7 +271,7 @@ const searchQuery = ref('')
|
|||||||
|
|
||||||
const pagination = reactive({
|
const pagination = reactive({
|
||||||
page: 1,
|
page: 1,
|
||||||
page_size: 20,
|
page_size: getPersistedPageSize(),
|
||||||
total: 0,
|
total: 0,
|
||||||
pages: 0
|
pages: 0
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1855,6 +1855,7 @@ import GroupCapacityBadge from '@/components/common/GroupCapacityBadge.vue'
|
|||||||
import { VueDraggable } from 'vue-draggable-plus'
|
import { VueDraggable } from 'vue-draggable-plus'
|
||||||
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||||
import { useKeyedDebouncedSearch } from '@/composables/useKeyedDebouncedSearch'
|
import { useKeyedDebouncedSearch } from '@/composables/useKeyedDebouncedSearch'
|
||||||
|
import { getPersistedPageSize } from '@/composables/usePersistedPageSize'
|
||||||
|
|
||||||
const { t } = useI18n()
|
const { t } = useI18n()
|
||||||
const appStore = useAppStore()
|
const appStore = useAppStore()
|
||||||
@@ -2016,7 +2017,7 @@ const filters = reactive({
|
|||||||
})
|
})
|
||||||
const pagination = reactive({
|
const pagination = reactive({
|
||||||
page: 1,
|
page: 1,
|
||||||
page_size: 20,
|
page_size: getPersistedPageSize(),
|
||||||
total: 0,
|
total: 0,
|
||||||
pages: 0
|
pages: 0
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -383,6 +383,7 @@ import { ref, reactive, computed, onMounted, onUnmounted } from 'vue'
|
|||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import { useAppStore } from '@/stores/app'
|
import { useAppStore } from '@/stores/app'
|
||||||
import { useClipboard } from '@/composables/useClipboard'
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
|
import { getPersistedPageSize } from '@/composables/usePersistedPageSize'
|
||||||
import { adminAPI } from '@/api/admin'
|
import { adminAPI } from '@/api/admin'
|
||||||
import { formatDateTime } from '@/utils/format'
|
import { formatDateTime } from '@/utils/format'
|
||||||
import type { PromoCode, PromoCodeUsage } from '@/types'
|
import type { PromoCode, PromoCodeUsage } from '@/types'
|
||||||
@@ -414,7 +415,7 @@ const filters = reactive({
|
|||||||
|
|
||||||
const pagination = reactive({
|
const pagination = reactive({
|
||||||
page: 1,
|
page: 1,
|
||||||
page_size: 20,
|
page_size: getPersistedPageSize(),
|
||||||
total: 0
|
total: 0
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -884,6 +884,7 @@ import PlatformTypeBadge from '@/components/common/PlatformTypeBadge.vue'
|
|||||||
import { useClipboard } from '@/composables/useClipboard'
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
import { useSwipeSelect } from '@/composables/useSwipeSelect'
|
import { useSwipeSelect } from '@/composables/useSwipeSelect'
|
||||||
import { useTableSelection } from '@/composables/useTableSelection'
|
import { useTableSelection } from '@/composables/useTableSelection'
|
||||||
|
import { getPersistedPageSize } from '@/composables/usePersistedPageSize'
|
||||||
|
|
||||||
const { t } = useI18n()
|
const { t } = useI18n()
|
||||||
const appStore = useAppStore()
|
const appStore = useAppStore()
|
||||||
@@ -941,7 +942,7 @@ const filters = reactive({
|
|||||||
})
|
})
|
||||||
const pagination = reactive({
|
const pagination = reactive({
|
||||||
page: 1,
|
page: 1,
|
||||||
page_size: 20,
|
page_size: getPersistedPageSize(),
|
||||||
total: 0,
|
total: 0,
|
||||||
pages: 0
|
pages: 0
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -395,6 +395,7 @@ import { ref, reactive, computed, onMounted, onUnmounted, watch } from 'vue'
|
|||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import { useAppStore } from '@/stores/app'
|
import { useAppStore } from '@/stores/app'
|
||||||
import { useClipboard } from '@/composables/useClipboard'
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
|
import { getPersistedPageSize } from '@/composables/usePersistedPageSize'
|
||||||
import { adminAPI } from '@/api/admin'
|
import { adminAPI } from '@/api/admin'
|
||||||
import { formatDateTime } from '@/utils/format'
|
import { formatDateTime } from '@/utils/format'
|
||||||
import type { RedeemCode, RedeemCodeType, Group, GroupPlatform, SubscriptionType } from '@/types'
|
import type { RedeemCode, RedeemCodeType, Group, GroupPlatform, SubscriptionType } from '@/types'
|
||||||
@@ -532,7 +533,7 @@ const filters = reactive({
|
|||||||
})
|
})
|
||||||
const pagination = reactive({
|
const pagination = reactive({
|
||||||
page: 1,
|
page: 1,
|
||||||
page_size: 20,
|
page_size: getPersistedPageSize(),
|
||||||
total: 0,
|
total: 0,
|
||||||
pages: 0
|
pages: 0
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -744,6 +744,7 @@ import type { UserSubscription, Group, GroupPlatform, SubscriptionType } from '@
|
|||||||
import type { SimpleUser } from '@/api/admin/usage'
|
import type { SimpleUser } from '@/api/admin/usage'
|
||||||
import type { Column } from '@/components/common/types'
|
import type { Column } from '@/components/common/types'
|
||||||
import { formatDateOnly } from '@/utils/format'
|
import { formatDateOnly } from '@/utils/format'
|
||||||
|
import { getPersistedPageSize } from '@/composables/usePersistedPageSize'
|
||||||
import AppLayout from '@/components/layout/AppLayout.vue'
|
import AppLayout from '@/components/layout/AppLayout.vue'
|
||||||
import TablePageLayout from '@/components/layout/TablePageLayout.vue'
|
import TablePageLayout from '@/components/layout/TablePageLayout.vue'
|
||||||
import DataTable from '@/components/common/DataTable.vue'
|
import DataTable from '@/components/common/DataTable.vue'
|
||||||
@@ -928,7 +929,7 @@ const sortState = reactive({
|
|||||||
|
|
||||||
const pagination = reactive({
|
const pagination = reactive({
|
||||||
page: 1,
|
page: 1,
|
||||||
page_size: 20,
|
page_size: getPersistedPageSize(),
|
||||||
total: 0,
|
total: 0,
|
||||||
pages: 0
|
pages: 0
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -124,6 +124,7 @@ import { useI18n } from 'vue-i18n'
|
|||||||
import { saveAs } from 'file-saver'
|
import { saveAs } from 'file-saver'
|
||||||
import { useRoute } from 'vue-router'
|
import { useRoute } from 'vue-router'
|
||||||
import { useAppStore } from '@/stores/app'; import { adminAPI } from '@/api/admin'; import { adminUsageAPI } from '@/api/admin/usage'
|
import { useAppStore } from '@/stores/app'; import { adminAPI } from '@/api/admin'; import { adminUsageAPI } from '@/api/admin/usage'
|
||||||
|
import { getPersistedPageSize } from '@/composables/usePersistedPageSize'
|
||||||
import { formatReasoningEffort } from '@/utils/format'
|
import { formatReasoningEffort } from '@/utils/format'
|
||||||
import { resolveUsageRequestType, requestTypeToLegacyStream } from '@/utils/usageRequestType'
|
import { resolveUsageRequestType, requestTypeToLegacyStream } from '@/utils/usageRequestType'
|
||||||
import AppLayout from '@/components/layout/AppLayout.vue'; import Pagination from '@/components/common/Pagination.vue'; import Select from '@/components/common/Select.vue'; import DateRangePicker from '@/components/common/DateRangePicker.vue'
|
import AppLayout from '@/components/layout/AppLayout.vue'; import Pagination from '@/components/common/Pagination.vue'; import Select from '@/components/common/Select.vue'; import DateRangePicker from '@/components/common/DateRangePicker.vue'
|
||||||
@@ -203,7 +204,7 @@ const getGranularityForRange = (start: string, end: string): 'day' | 'hour' => {
|
|||||||
const defaultRange = getLast24HoursRangeDates()
|
const defaultRange = getLast24HoursRangeDates()
|
||||||
const startDate = ref(defaultRange.start); const endDate = ref(defaultRange.end)
|
const startDate = ref(defaultRange.start); const endDate = ref(defaultRange.end)
|
||||||
const filters = ref<AdminUsageQueryParams>({ user_id: undefined, model: undefined, group_id: undefined, request_type: undefined, billing_type: null, start_date: startDate.value, end_date: endDate.value })
|
const filters = ref<AdminUsageQueryParams>({ user_id: undefined, model: undefined, group_id: undefined, request_type: undefined, billing_type: null, start_date: startDate.value, end_date: endDate.value })
|
||||||
const pagination = reactive({ page: 1, page_size: 20, total: 0 })
|
const pagination = reactive({ page: 1, page_size: getPersistedPageSize(), total: 0 })
|
||||||
|
|
||||||
const getSingleQueryValue = (value: string | null | Array<string | null> | undefined): string | undefined => {
|
const getSingleQueryValue = (value: string | null | Array<string | null> | undefined): string | undefined => {
|
||||||
if (Array.isArray(value)) return value.find((item): item is string => typeof item === 'string' && item.length > 0)
|
if (Array.isArray(value)) return value.find((item): item is string => typeof item === 'string' && item.length > 0)
|
||||||
|
|||||||
@@ -521,6 +521,7 @@
|
|||||||
import { ref, reactive, computed, onMounted, onUnmounted } from 'vue'
|
import { ref, reactive, computed, onMounted, onUnmounted } from 'vue'
|
||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import { useAppStore } from '@/stores/app'
|
import { useAppStore } from '@/stores/app'
|
||||||
|
import { getPersistedPageSize } from '@/composables/usePersistedPageSize'
|
||||||
import { formatDateTime } from '@/utils/format'
|
import { formatDateTime } from '@/utils/format'
|
||||||
import Icon from '@/components/icons/Icon.vue'
|
import Icon from '@/components/icons/Icon.vue'
|
||||||
|
|
||||||
@@ -774,7 +775,7 @@ const attributeDefinitions = ref<UserAttributeDefinition[]>([])
|
|||||||
const userAttributeValues = ref<Record<number, Record<number, string>>>({})
|
const userAttributeValues = ref<Record<number, Record<number, string>>>({})
|
||||||
const pagination = reactive({
|
const pagination = reactive({
|
||||||
page: 1,
|
page: 1,
|
||||||
page_size: 20,
|
page_size: getPersistedPageSize(),
|
||||||
total: 0,
|
total: 0,
|
||||||
pages: 0
|
pages: 0
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1035,6 +1035,7 @@
|
|||||||
import { useAppStore } from '@/stores/app'
|
import { useAppStore } from '@/stores/app'
|
||||||
import { useOnboardingStore } from '@/stores/onboarding'
|
import { useOnboardingStore } from '@/stores/onboarding'
|
||||||
import { useClipboard } from '@/composables/useClipboard'
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
|
import { getPersistedPageSize } from '@/composables/usePersistedPageSize'
|
||||||
|
|
||||||
const { t } = useI18n()
|
const { t } = useI18n()
|
||||||
import { keysAPI, authAPI, usageAPI, userGroupsAPI } from '@/api'
|
import { keysAPI, authAPI, usageAPI, userGroupsAPI } from '@/api'
|
||||||
@@ -1101,7 +1102,7 @@ const userGroupRates = ref<Record<number, number>>({})
|
|||||||
|
|
||||||
const pagination = ref({
|
const pagination = ref({
|
||||||
page: 1,
|
page: 1,
|
||||||
page_size: 10,
|
page_size: getPersistedPageSize(),
|
||||||
total: 0,
|
total: 0,
|
||||||
pages: 0
|
pages: 0
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -496,6 +496,7 @@ import Icon from '@/components/icons/Icon.vue'
|
|||||||
import type { UsageLog, ApiKey, UsageQueryParams, UsageStatsResponse } from '@/types'
|
import type { UsageLog, ApiKey, UsageQueryParams, UsageStatsResponse } from '@/types'
|
||||||
import type { Column } from '@/components/common/types'
|
import type { Column } from '@/components/common/types'
|
||||||
import { formatDateTime, formatReasoningEffort } from '@/utils/format'
|
import { formatDateTime, formatReasoningEffort } from '@/utils/format'
|
||||||
|
import { getPersistedPageSize } from '@/composables/usePersistedPageSize'
|
||||||
import { formatTokenPricePerMillion } from '@/utils/usagePricing'
|
import { formatTokenPricePerMillion } from '@/utils/usagePricing'
|
||||||
import { getUsageServiceTierLabel } from '@/utils/usageServiceTier'
|
import { getUsageServiceTierLabel } from '@/utils/usageServiceTier'
|
||||||
import { resolveUsageRequestType } from '@/utils/usageRequestType'
|
import { resolveUsageRequestType } from '@/utils/usageRequestType'
|
||||||
@@ -584,7 +585,7 @@ const onDateRangeChange = (range: {
|
|||||||
|
|
||||||
const pagination = reactive({
|
const pagination = reactive({
|
||||||
page: 1,
|
page: 1,
|
||||||
page_size: 20,
|
page_size: getPersistedPageSize(),
|
||||||
total: 0,
|
total: 0,
|
||||||
pages: 0
|
pages: 0
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user