mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-05 05:30:44 +08:00
feat(gateway): port Parrot tool-name obfuscation + message cache breakpoints
Implements the remaining three parity items with Parrot cc_mimicry:
D) Tool-name obfuscation
- Dynamic mapping when tools.length > 5 (matches Parrot threshold).
Fake names follow {prefix}{name[:3]}{i:02d} (e.g. 'manage_bas00').
Go port of random.Random(hash(tuple(names))) uses fnv64a seed +
math/rand; byte-exact reproduction is impossible (Python hash vs
Go hash), but the two invariants that matter are preserved:
* same input tool_names yield identical mapping (cache hit)
* prefix pool is shuffled (names look distributed)
- Static prefix map (sessions_ -> cc_sess_, session_ -> cc_ses_)
applied as fallback, matching Parrot TOOL_NAME_REWRITES verbatim.
- Server tools (web_search_20250305, computer_*, etc.) are NOT
renamed; only type=='function' and type=='custom' tools are.
- tool_choice.name is rewritten in sync (only when type=='tool').
- Response side: bytes-level replace on every SSE chunk / JSON
body at 8 injection points (standard stream/non-stream,
passthrough stream/non-stream, chat_completions stream +
non-stream, responses stream + non-stream). Reverse mapping
applied longest-fake-name-first to prevent substring conflicts
(parity with Parrot _restore_tool_names_in_chunk).
- tool_choice is no longer unconditionally deleted in
normalizeClaudeOAuthRequestBody — Parrot passes it through.
E) tools[-1] cache_control breakpoint
- Injected as {type:ephemeral, ttl:<DefaultCacheControlTTL>} when
the last tool has no cache_control. Client-provided ttl is
passed through unchanged (repo-wide policy).
F) messages cache_control strategy
- stripMessageCacheControl removes every client-provided
messages[*].content[*].cache_control (multi-turn stability).
- addMessageCacheBreakpoints then injects two stable breakpoints:
(1) last message, and (2) second-to-last user turn when
messages.length >= 4.
- Combined with the system block breakpoint and tools[-1]
breakpoint, this gives exactly the 4 breakpoints Anthropic
allows per request.
Non-trivial implementation details to be aware of when rebasing:
* Two new files, no upstream collision:
gateway_tool_rewrite.go (D + E algorithms)
gateway_messages_cache.go (F strip + breakpoints)
* Two new feature calls bolted onto the tail of
applyClaudeCodeOAuthMimicryToBody in gateway_service.go — rebase
conflicts will be ~10 lines maximum.
* Response-side injection points all wrap their existing write with
reverseToolNamesIfPresent(c, ...), preserving original behavior
when no mapping is stored (static prefix rollback still runs).
* Non-stream chat/responses switched from c.JSON to
json.Marshal + c.Data so bytes-level replace is possible.
* Retry bodies (FilterThinkingBlocksForRetry,
FilterSignatureSensitiveBlocksForRetry, RectifyThinkingBudget)
only prune blocks — they preserve the already-obfuscated tool
names, so no extra mapping re-application is needed.
Manual QA: end-to-end scenario verified with 6 tools (above threshold)
and tool_choice.type=='tool'. Obfuscation + restore roundtrip shown
in test logs; then removed the temp test file.
Tests (16 new):
- buildDynamicToolMap stability + below-threshold guard
- sanitizeToolName precedence (dynamic > static)
- restoreToolNamesInBytes longest-first + static rollback
- applyToolNameRewriteToBody skips server tools + syncs tool_choice
- applyToolsLastCacheBreakpoint defaults to 5m + passes client ttl
- stripMessageCacheControl + addMessageCacheBreakpoints in the
1/4/string-content cases + second-to-last user turn selection
- buildToolNameRewriteFromBody ReverseOrdered is desc-by-fake-length
- fake name shape follows Parrot {prefix}{head3}{i:02d}
This commit is contained in:
@@ -313,7 +313,14 @@ func (s *GatewayService) handleCCBufferedFromAnthropic(
|
|||||||
if s.responseHeaderFilter != nil {
|
if s.responseHeaderFilter != nil {
|
||||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, ccResp)
|
// Marshal then bytes-replace so tool name mapping is reversed at byte level
|
||||||
|
// (parity with Parrot non-stream flow that marshals → restore → emit).
|
||||||
|
if respBytes, err := json.Marshal(ccResp); err == nil {
|
||||||
|
respBytes = reverseToolNamesIfPresent(c, respBytes)
|
||||||
|
c.Data(http.StatusOK, "application/json; charset=utf-8", respBytes)
|
||||||
|
} else {
|
||||||
|
c.JSON(http.StatusOK, ccResp)
|
||||||
|
}
|
||||||
|
|
||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
@@ -384,7 +391,10 @@ func (s *GatewayService) handleCCStreamingFromAnthropic(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
// Reverse tool name mapping: fake → real, per-chunk bytes.Replace.
|
||||||
|
// c 可能持有请求侧注入的 ToolNameRewrite;无则仅做静态前缀还原。
|
||||||
|
out := string(reverseToolNamesIfPresent(c, []byte(sse)))
|
||||||
|
if _, err := fmt.Fprint(c.Writer, out); err != nil {
|
||||||
return true // client disconnected
|
return true // client disconnected
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -332,7 +332,12 @@ func (s *GatewayService) handleResponsesBufferedStreamingResponse(
|
|||||||
if s.responseHeaderFilter != nil {
|
if s.responseHeaderFilter != nil {
|
||||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, responsesResp)
|
if respBytes, err := json.Marshal(responsesResp); err == nil {
|
||||||
|
respBytes = reverseToolNamesIfPresent(c, respBytes)
|
||||||
|
c.Data(http.StatusOK, "application/json; charset=utf-8", respBytes)
|
||||||
|
} else {
|
||||||
|
c.JSON(http.StatusOK, responsesResp)
|
||||||
|
}
|
||||||
|
|
||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
@@ -420,7 +425,8 @@ func (s *GatewayService) handleResponsesStreamingResponse(
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
out := string(reverseToolNamesIfPresent(c, []byte(sse)))
|
||||||
|
if _, err := fmt.Fprint(c.Writer, out); err != nil {
|
||||||
logger.L().Info("forward_as_responses stream: client disconnected",
|
logger.L().Info("forward_as_responses stream: client disconnected",
|
||||||
zap.String("request_id", requestID),
|
zap.String("request_id", requestID),
|
||||||
)
|
)
|
||||||
@@ -440,7 +446,8 @@ func (s *GatewayService) handleResponsesStreamingResponse(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
fmt.Fprint(c.Writer, sse) //nolint:errcheck
|
out := string(reverseToolNamesIfPresent(c, []byte(sse)))
|
||||||
|
fmt.Fprint(c.Writer, out) //nolint:errcheck
|
||||||
}
|
}
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
}
|
}
|
||||||
|
|||||||
141
backend/internal/service/gateway_messages_cache.go
Normal file
141
backend/internal/service/gateway_messages_cache.go
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
	"encoding/json"
	"fmt"

	"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)
|
||||||
|
|
||||||
|
// stripMessageCacheControl 移除 $.messages[*].content[*].cache_control。
|
||||||
|
// 与 Parrot _strip_message_cache_control 语义一致。
|
||||||
|
//
|
||||||
|
// 为什么必须整体清空:客户端(特别是 Claude Code)经常把 cache_control 打在
|
||||||
|
// "当前最后一条 user message" 上;下一轮对话 messages 追加后,原本的最后一条
|
||||||
|
// 变成中间某条,cache_control 还挂着就导致"前缀签名变化",破坏缓存命中。
|
||||||
|
// 统一由代理重新打断点(addMessageCacheBreakpoints)才能在多轮间稳定。
|
||||||
|
func stripMessageCacheControl(body []byte) []byte {
|
||||||
|
messages := gjson.GetBytes(body, "messages")
|
||||||
|
if !messages.IsArray() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
msgIdx := -1
|
||||||
|
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||||
|
msgIdx++
|
||||||
|
content := msg.Get("content")
|
||||||
|
if !content.IsArray() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
blockIdx := -1
|
||||||
|
content.ForEach(func(_, block gjson.Result) bool {
|
||||||
|
blockIdx++
|
||||||
|
if !block.Get("cache_control").Exists() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
path := fmt.Sprintf("messages.%d.content.%d.cache_control", msgIdx, blockIdx)
|
||||||
|
if next, err := sjson.DeleteBytes(body, path); err == nil {
|
||||||
|
body = next
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// addMessageCacheBreakpoints 在 messages 上注入两个稳定的 cache 断点:
|
||||||
|
// 1. 最后一条 message
|
||||||
|
// 2. 当 messages 数量 ≥ 4 时,倒数第二个 role=user 的 message
|
||||||
|
//
|
||||||
|
// 与 Parrot add_cache_breakpoints 一致。两个断点 + system prompt block 的断点
|
||||||
|
// + tools[-1] 的断点共同构成最多 4 个断点(Anthropic 上限)。
|
||||||
|
//
|
||||||
|
// cache_control ttl 策略:
|
||||||
|
// - 若目标 block 已有 cache_control.ttl → 不覆盖
|
||||||
|
// - 否则写入 {"type":"ephemeral","ttl": claude.DefaultCacheControlTTL}
|
||||||
|
//
|
||||||
|
// 调用前应先 stripMessageCacheControl 以保证幂等和稳定。
|
||||||
|
func addMessageCacheBreakpoints(body []byte) []byte {
|
||||||
|
messages := gjson.GetBytes(body, "messages")
|
||||||
|
if !messages.IsArray() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
arr := messages.Array()
|
||||||
|
if len(arr) == 0 {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
body = injectCacheControlOnLastContentBlock(body, len(arr)-1, &arr[len(arr)-1])
|
||||||
|
|
||||||
|
if len(arr) >= 4 {
|
||||||
|
userCount := 0
|
||||||
|
for i := len(arr) - 1; i >= 0; i-- {
|
||||||
|
if arr[i].Get("role").String() != "user" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
userCount++
|
||||||
|
if userCount == 2 {
|
||||||
|
body = injectCacheControlOnLastContentBlock(body, i, &arr[i])
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// injectCacheControlOnLastContentBlock 把 cache_control 断点打在 messages[idx]
|
||||||
|
// 的最后一个 content block 上。若 content 是 string,先升级成单块 text 数组
|
||||||
|
// (对齐 Parrot _inject_cache_on_msg 的行为)。
|
||||||
|
//
|
||||||
|
// msg 是调用方已持有的 gjson.Result 快照,用于省一次 GetBytes。
|
||||||
|
func injectCacheControlOnLastContentBlock(body []byte, idx int, msg *gjson.Result) []byte {
|
||||||
|
content := msg.Get("content")
|
||||||
|
|
||||||
|
if content.Type == gjson.String {
|
||||||
|
text := content.String()
|
||||||
|
blockRaw := fmt.Sprintf(
|
||||||
|
`[{"type":"text","text":%s,"cache_control":{"type":"ephemeral","ttl":%q}}]`,
|
||||||
|
mustJSONString(text), claude.DefaultCacheControlTTL,
|
||||||
|
)
|
||||||
|
if next, err := sjson.SetRawBytes(body, fmt.Sprintf("messages.%d.content", idx), []byte(blockRaw)); err == nil {
|
||||||
|
body = next
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
if !content.IsArray() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
contentArr := content.Array()
|
||||||
|
if len(contentArr) == 0 {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
lastBlockIdx := len(contentArr) - 1
|
||||||
|
lastBlock := contentArr[lastBlockIdx]
|
||||||
|
|
||||||
|
if cc := lastBlock.Get("cache_control"); cc.Exists() && cc.Get("ttl").String() != "" {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
pathPrefix := fmt.Sprintf("messages.%d.content.%d.cache_control", idx, lastBlockIdx)
|
||||||
|
existingCC := lastBlock.Get("cache_control")
|
||||||
|
if existingCC.Exists() {
|
||||||
|
if next, err := sjson.SetBytes(body, pathPrefix+".ttl", claude.DefaultCacheControlTTL); err == nil {
|
||||||
|
body = next
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
raw := fmt.Sprintf(`{"type":"ephemeral","ttl":%q}`, claude.DefaultCacheControlTTL)
|
||||||
|
if next, err := sjson.SetRawBytes(body, pathPrefix, []byte(raw)); err == nil {
|
||||||
|
body = next
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// mustJSONString encodes s as a JSON string literal (surrounding quotes
// included) for use when JSON is assembled by hand, e.g. before
// sjson.SetRawBytes.
//
// encoding/json is used instead of fmt's %q because %q emits Go syntax,
// which is not always valid JSON: control characters become \x00 / \a / \v
// and invalid UTF-8 becomes \xNN, all of which a JSON parser rejects.
func mustJSONString(s string) string {
	encoded, err := json.Marshal(s)
	if err != nil {
		// json.Marshal does not fail for plain strings (invalid UTF-8 is
		// replaced with U+FFFD); keep the output well-formed regardless.
		return `""`
	}
	return string(encoded)
}
|
||||||
@@ -1110,10 +1110,17 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if gjson.GetBytes(out, "tool_choice").Exists() {
|
// tool_choice:与 Parrot 对齐,不再无条件删除。
|
||||||
if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok {
|
// - 客户端传了 {"type":"tool","name":"X"} → 保留结构,name 由
|
||||||
out = next
|
// applyToolNameRewriteToBody 同步映射为假名
|
||||||
modified = true
|
// - 其他形态(auto/any/none)原样透传
|
||||||
|
// 如果 body 里完全没有 tools(空数组),tool_choice 没意义时才删除
|
||||||
|
if !gjson.GetBytes(out, "tools").IsArray() || len(gjson.GetBytes(out, "tools").Array()) == 0 {
|
||||||
|
if gjson.GetBytes(out, "tool_choice").Exists() {
|
||||||
|
if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok {
|
||||||
|
out = next
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1214,6 +1221,25 @@ func (s *GatewayService) applyClaudeCodeOAuthMimicryToBody(
|
|||||||
}
|
}
|
||||||
|
|
||||||
body, _ = normalizeClaudeOAuthRequestBody(body, model, normalizeOpts)
|
body, _ = normalizeClaudeOAuthRequestBody(body, model, normalizeOpts)
|
||||||
|
|
||||||
|
// Phase D+E+F: messages cache 策略 + 工具名混淆 + tools[-1] 断点
|
||||||
|
// 对齐 Parrot transform_request 里剩余的字段级改写。三步顺序有语义约束:
|
||||||
|
// 1) strip:先清除客户端的 messages[*].cache_control(多轮稳定性)
|
||||||
|
// 2) breakpoints:再注入 2 个断点(最后一条 + 倒数第二个 user turn)
|
||||||
|
// 3) tool rewrite:最后改 tools[*].name / tool_choice.name 并在 tools[-1]
|
||||||
|
// 上打断点;mapping 存入 gin.Context 供响应侧 bytes.Replace 还原。
|
||||||
|
body = stripMessageCacheControl(body)
|
||||||
|
body = addMessageCacheBreakpoints(body)
|
||||||
|
|
||||||
|
if rw := buildToolNameRewriteFromBody(body); rw != nil {
|
||||||
|
body = applyToolNameRewriteToBody(body, rw)
|
||||||
|
if c != nil {
|
||||||
|
c.Set(toolNameRewriteKey, rw)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
body = applyToolsLastCacheBreakpoint(body)
|
||||||
|
}
|
||||||
|
|
||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -5099,7 +5125,8 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !clientDisconnected {
|
if !clientDisconnected {
|
||||||
if _, err := io.WriteString(w, line); err != nil {
|
restored := string(reverseToolNamesIfPresent(c, []byte(line)))
|
||||||
|
if _, err := io.WriteString(w, restored); err != nil {
|
||||||
clientDisconnected = true
|
clientDisconnected = true
|
||||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
|
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
|
||||||
} else if _, err := io.WriteString(w, "\n"); err != nil {
|
} else if _, err := io.WriteString(w, "\n"); err != nil {
|
||||||
@@ -5269,6 +5296,7 @@ func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough(
|
|||||||
if contentType == "" {
|
if contentType == "" {
|
||||||
contentType = "application/json"
|
contentType = "application/json"
|
||||||
}
|
}
|
||||||
|
body = reverseToolNamesIfPresent(c, body)
|
||||||
c.Data(resp.StatusCode, contentType, body)
|
c.Data(resp.StatusCode, contentType, body)
|
||||||
return usage, nil
|
return usage, nil
|
||||||
}
|
}
|
||||||
@@ -7013,7 +7041,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
|
|
||||||
for _, block := range outputBlocks {
|
for _, block := range outputBlocks {
|
||||||
if !clientDisconnected {
|
if !clientDisconnected {
|
||||||
if _, werr := fmt.Fprint(w, block); werr != nil {
|
restored := reverseToolNamesIfPresent(c, []byte(block))
|
||||||
|
if _, werr := fmt.Fprint(w, string(restored)); werr != nil {
|
||||||
clientDisconnected = true
|
clientDisconnected = true
|
||||||
logger.LegacyPrintf("service.gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
|
logger.LegacyPrintf("service.gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
|
||||||
break
|
break
|
||||||
@@ -7355,6 +7384,8 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
body = reverseToolNamesIfPresent(c, body)
|
||||||
|
|
||||||
// 写入响应
|
// 写入响应
|
||||||
c.Data(resp.StatusCode, contentType, body)
|
c.Data(resp.StatusCode, contentType, body)
|
||||||
|
|
||||||
|
|||||||
313
backend/internal/service/gateway_tool_rewrite.go
Normal file
313
backend/internal/service/gateway_tool_rewrite.go
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"hash/fnv"
|
||||||
|
"math/rand"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// toolNameRewriteKey is the gin.Context key under which the per-request
	// ToolNameRewrite is stored. It is written while the request is being
	// transformed and read on the response side to map fake names back to
	// real ones at byte level.
	toolNameRewriteKey = "claude_tool_name_rewrite"

	// dynamicToolMapThreshold matches Parrot: dynamic mapping is only enabled
	// when the number of tools exceeds 5. A handful of tools (usually Claude
	// Code's own core tools such as bash/edit/read) is not obfuscated.
	dynamicToolMapThreshold = 5
)

var (
	// staticToolNameRewrites is the static prefix mapping (real prefix ->
	// replacement prefix), kept identical to TOOL_NAME_REWRITES in Parrot's
	// src/transform/cc_mimicry.py. Only tool names starting with one of these
	// prefixes are rewritten by the static rule.
	staticToolNameRewrites = map[string]string{
		"sessions_": "cc_sess_",
		"session_":  "cc_ses_",
	}

	// fakeToolNamePrefixes is the prefix pool for dynamic mapping, kept in
	// line with Parrot's _FAKE_PREFIXES. When there are more than
	// dynamicToolMapThreshold tools, prefixes are drawn from a shuffled copy
	// of this pool to build readable fake names.
	fakeToolNamePrefixes = []string{
		"analyze_", "compute_", "fetch_", "generate_", "lookup_", "modify_",
		"process_", "query_", "render_", "resolve_", "sync_", "update_",
		"validate_", "convert_", "extract_", "manage_", "monitor_", "parse_",
		"review_", "search_", "transform_", "handle_", "invoke_", "notify_",
	}
)

// ToolNameRewrite holds the tool-name obfuscation mapping of one request.
//   - Forward maps real -> fake and is applied to the request body.
//   - Reverse maps fake -> real and is used to restore response chunks.
//
// ReverseOrdered lists the (fake, real) pairs sorted by descending fake-name
// length, so that a short fake name that is a substring of a longer one is
// not replaced first (aligned with the
// `sorted(..., key=lambda x: len(x[1]), reverse=True)` in Parrot's
// _restore_tool_names_in_chunk).
type ToolNameRewrite struct {
	Forward        map[string]string
	Reverse        map[string]string
	ReverseOrdered [][2]string
}
|
||||||
|
|
||||||
|
// buildDynamicToolMap 构造 tools 的动态假名映射。
|
||||||
|
//
|
||||||
|
// 与 Parrot _build_dynamic_tool_map 语义等价:
|
||||||
|
// - tools 数量 ≤ dynamicToolMapThreshold 时返回 nil(不做动态映射,走静态 fallback)
|
||||||
|
// - 同一组 tool_names 在同进程内映射稳定(保证 cache 命中)
|
||||||
|
//
|
||||||
|
// Parrot 用 `random.Random(hash(tuple(tool_names)))` 作 seed + shuffle 前缀池;
|
||||||
|
// Go 无法字节级复刻 Python hash,但"稳定性"和"前缀池打散"两个不变量都保留:
|
||||||
|
// 用 fnv64a(strings.Join(names, "\x00")) 作 seed 喂 math/rand.New。
|
||||||
|
// 字节级不同不影响上游判定(Anthropic 不会验证我们的随机种子算法)。
|
||||||
|
func buildDynamicToolMap(toolNames []string) map[string]string {
|
||||||
|
if len(toolNames) <= dynamicToolMapThreshold {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
h := fnv.New64a()
|
||||||
|
for i, n := range toolNames {
|
||||||
|
if i > 0 {
|
||||||
|
_, _ = h.Write([]byte{0})
|
||||||
|
}
|
||||||
|
_, _ = h.Write([]byte(n))
|
||||||
|
}
|
||||||
|
rng := rand.New(rand.NewSource(int64(h.Sum64())))
|
||||||
|
|
||||||
|
available := make([]string, len(fakeToolNamePrefixes))
|
||||||
|
copy(available, fakeToolNamePrefixes)
|
||||||
|
rng.Shuffle(len(available), func(i, j int) { available[i], available[j] = available[j], available[i] })
|
||||||
|
|
||||||
|
mapping := make(map[string]string, len(toolNames))
|
||||||
|
for i, name := range toolNames {
|
||||||
|
prefix := available[i%len(available)]
|
||||||
|
headLen := 3
|
||||||
|
if len(name) < 3 {
|
||||||
|
headLen = len(name)
|
||||||
|
}
|
||||||
|
fake := fmt.Sprintf("%s%s%02d", prefix, name[:headLen], i)
|
||||||
|
mapping[name] = fake
|
||||||
|
}
|
||||||
|
return mapping
|
||||||
|
}
|
||||||
|
|
||||||
|
// sanitizeToolName 把真名转成假名。
|
||||||
|
// 与 Parrot _sanitize_tool_name 语义一致:动态映射优先,再走静态前缀映射。
|
||||||
|
func sanitizeToolName(name string, dynamic map[string]string) string {
|
||||||
|
if dynamic != nil {
|
||||||
|
if fake, ok := dynamic[name]; ok {
|
||||||
|
return fake
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for prefix, replacement := range staticToolNameRewrites {
|
||||||
|
if strings.HasPrefix(name, prefix) {
|
||||||
|
return replacement + name[len(prefix):]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldMimicToolName reports whether a tool with the given "type" field may
// be renamed. Only an empty type, "function" and "custom" qualify. Every
// other type is treated as a server tool (e.g. "web_search_20250305",
// "computer_20250124") whose name is part of the Anthropic protocol;
// renaming one would make the upstream reject the request.
func shouldMimicToolName(toolType string) bool {
	switch toolType {
	case "", "function", "custom":
		return true
	default:
		return false
	}
}
|
||||||
|
|
||||||
|
// buildToolNameRewriteFromBody 扫描 body 的 tools[*].name,构造 ToolNameRewrite
|
||||||
|
// 并返回它。若不需要混淆(tools 数量不足 + 没有匹配静态前缀的工具)返回 nil。
|
||||||
|
//
|
||||||
|
// 注意:只扫描,不改 body。真正的 body 改写在 applyToolNameRewriteToBody。
|
||||||
|
func buildToolNameRewriteFromBody(body []byte) *ToolNameRewrite {
|
||||||
|
tools := gjson.GetBytes(body, "tools")
|
||||||
|
if !tools.IsArray() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mimicableNames := make([]string, 0)
|
||||||
|
toolsArr := tools.Array()
|
||||||
|
for _, t := range toolsArr {
|
||||||
|
if !shouldMimicToolName(t.Get("type").String()) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := t.Get("name").String()
|
||||||
|
if name == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
mimicableNames = append(mimicableNames, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
dynamic := buildDynamicToolMap(mimicableNames)
|
||||||
|
|
||||||
|
rw := &ToolNameRewrite{
|
||||||
|
Forward: make(map[string]string),
|
||||||
|
Reverse: make(map[string]string),
|
||||||
|
}
|
||||||
|
for _, name := range mimicableNames {
|
||||||
|
fake := sanitizeToolName(name, dynamic)
|
||||||
|
if fake == name {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rw.Forward[name] = fake
|
||||||
|
rw.Reverse[fake] = name
|
||||||
|
}
|
||||||
|
if len(rw.Forward) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rw.ReverseOrdered = make([][2]string, 0, len(rw.Reverse))
|
||||||
|
for fake, real := range rw.Reverse {
|
||||||
|
rw.ReverseOrdered = append(rw.ReverseOrdered, [2]string{fake, real})
|
||||||
|
}
|
||||||
|
sort.SliceStable(rw.ReverseOrdered, func(i, j int) bool {
|
||||||
|
return len(rw.ReverseOrdered[i][0]) > len(rw.ReverseOrdered[j][0])
|
||||||
|
})
|
||||||
|
|
||||||
|
return rw
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyToolNameRewriteToBody 把已构造的 ToolNameRewrite 应用到 body 上:
|
||||||
|
// - 改写 $.tools[*].name(仅对 shouldMimicToolName 通过的 tool)
|
||||||
|
// - 在 $.tools[last].cache_control 上打 ephemeral 缓存断点(Parrot 行为对齐,
|
||||||
|
// ttl 客户端已有则透传,否则默认 claude.DefaultCacheControlTTL)
|
||||||
|
// - 改写 $.tool_choice.name(仅当 $.tool_choice.type == "tool")
|
||||||
|
//
|
||||||
|
// 历史 $.messages[*].content[*].name(tool_use)不在请求侧改写——这与 Parrot 一致;
|
||||||
|
// 响应侧 bytes.Replace 会连带还原它们。
|
||||||
|
func applyToolNameRewriteToBody(body []byte, rw *ToolNameRewrite) []byte {
|
||||||
|
if rw == nil || len(rw.Forward) == 0 {
|
||||||
|
body = applyToolsLastCacheBreakpoint(body)
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
tools := gjson.GetBytes(body, "tools")
|
||||||
|
if tools.IsArray() {
|
||||||
|
idx := -1
|
||||||
|
tools.ForEach(func(_, t gjson.Result) bool {
|
||||||
|
idx++
|
||||||
|
if !shouldMimicToolName(t.Get("type").String()) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
name := t.Get("name").String()
|
||||||
|
if name == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
fake, ok := rw.Forward[name]
|
||||||
|
if !ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if next, err := sjson.SetBytes(body, fmt.Sprintf("tools.%d.name", idx), fake); err == nil {
|
||||||
|
body = next
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc := gjson.GetBytes(body, "tool_choice"); tc.Exists() && tc.Get("type").String() == "tool" {
|
||||||
|
name := tc.Get("name").String()
|
||||||
|
if fake, ok := rw.Forward[name]; ok {
|
||||||
|
if next, err := sjson.SetBytes(body, "tool_choice.name", fake); err == nil {
|
||||||
|
body = next
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
body = applyToolsLastCacheBreakpoint(body)
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyToolsLastCacheBreakpoint 在 tools 数组最后一个工具上注入 cache_control
|
||||||
|
// 断点,对齐 Parrot `tools[-1]["cache_control"] = {"type":"ephemeral","ttl":"1h"}`
|
||||||
|
// 行为,但 ttl 按本仓规则:
|
||||||
|
// - 客户端已为该 tool 显式设置 cache_control.ttl → 完全透传不覆盖
|
||||||
|
// - 否则注入 {"type":"ephemeral","ttl": claude.DefaultCacheControlTTL}
|
||||||
|
//
|
||||||
|
// 纯副作用函数,tools 不存在或为空数组时 no-op。
|
||||||
|
func applyToolsLastCacheBreakpoint(body []byte) []byte {
|
||||||
|
tools := gjson.GetBytes(body, "tools")
|
||||||
|
if !tools.IsArray() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
arr := tools.Array()
|
||||||
|
if len(arr) == 0 {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
lastIdx := len(arr) - 1
|
||||||
|
existingCC := arr[lastIdx].Get("cache_control")
|
||||||
|
|
||||||
|
if existingCC.Exists() && existingCC.Get("ttl").String() != "" {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
if existingCC.Exists() {
|
||||||
|
if next, err := sjson.SetBytes(body, fmt.Sprintf("tools.%d.cache_control.ttl", lastIdx), claude.DefaultCacheControlTTL); err == nil {
|
||||||
|
body = next
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
raw := fmt.Sprintf(`{"type":"ephemeral","ttl":%q}`, claude.DefaultCacheControlTTL)
|
||||||
|
if next, err := sjson.SetRawBytes(body, fmt.Sprintf("tools.%d.cache_control", lastIdx), []byte(raw)); err == nil {
|
||||||
|
body = next
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// restoreToolNamesInBytes 对 bytes chunk 做逆向还原:假名 → 真名。
|
||||||
|
// 按 ReverseOrdered 的假名长度倒序逐个 bytes.Replace,防止子串冲突
|
||||||
|
// (与 Parrot _restore_tool_names_in_chunk 的 sorted(..., reverse=True) 等价)。
|
||||||
|
// 再做静态前缀还原(cc_sess_ → sessions_ / cc_ses_ → session_)。
|
||||||
|
//
|
||||||
|
// rw 可为 nil;nil 时仍会做静态前缀还原。
|
||||||
|
func restoreToolNamesInBytes(data []byte, rw *ToolNameRewrite) []byte {
|
||||||
|
if rw != nil {
|
||||||
|
for _, pair := range rw.ReverseOrdered {
|
||||||
|
fake, real := pair[0], pair[1]
|
||||||
|
if fake == "" || fake == real {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data = replaceAllBytes(data, fake, real)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for prefix, replacement := range staticToolNameRewrites {
|
||||||
|
data = replaceAllBytes(data, replacement, prefix)
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// replaceAllBytes replaces every occurrence of from with to in data and
// returns the result. It wraps strings.ReplaceAll so call sites do not have
// to convert between []byte and string themselves.
//
// data is returned unchanged (same slice, no copy of the result) when it is
// empty, when from is empty or equal to to, or when from does not occur.
// The empty-from guard matters: strings.ReplaceAll with an empty pattern
// would insert to between every rune of the payload.
func replaceAllBytes(data []byte, from, to string) []byte {
	if len(data) == 0 || from == "" || from == to {
		return data
	}
	// Convert once and reuse the string for both the probe and the rewrite.
	text := string(data)
	if !strings.Contains(text, from) {
		return data
	}
	return []byte(strings.ReplaceAll(text, from, to))
}
|
||||||
|
|
||||||
|
// toolNameRewriteFromContext 从 gin.Context 取出请求阶段保存的工具名映射。
|
||||||
|
// 找不到(c==nil 或 key 不存在或类型不对)时返回 nil;调用方必须能处理 nil。
|
||||||
|
func toolNameRewriteFromContext(c interface {
|
||||||
|
Get(string) (any, bool)
|
||||||
|
}) *ToolNameRewrite {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
raw, ok := c.Get(toolNameRewriteKey)
|
||||||
|
if !ok || raw == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
rw, _ := raw.(*ToolNameRewrite)
|
||||||
|
return rw
|
||||||
|
}
|
||||||
|
|
||||||
|
// reverseToolNamesIfPresent 是响应侧 5 处注入点的统一封装:从 c 取出 mapping
|
||||||
|
// 并对 chunk 做 bytes 级假名→真名替换。c 没有 mapping 时仍会做静态前缀还原。
|
||||||
|
func reverseToolNamesIfPresent(c interface {
|
||||||
|
Get(string) (any, bool)
|
||||||
|
}, chunk []byte) []byte {
|
||||||
|
rw := toolNameRewriteFromContext(c)
|
||||||
|
if rw == nil && len(staticToolNameRewrites) == 0 {
|
||||||
|
return chunk
|
||||||
|
}
|
||||||
|
return restoreToolNamesInBytes(chunk, rw)
|
||||||
|
}
|
||||||
185
backend/internal/service/gateway_tool_rewrite_test.go
Normal file
185
backend/internal/service/gateway_tool_rewrite_test.go
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildDynamicToolMap_BelowThreshold(t *testing.T) {
|
||||||
|
// Parrot 行为:tools 数量 ≤ 5 时不做动态映射。
|
||||||
|
names := []string{"bash", "edit", "read", "write", "search"}
|
||||||
|
require.Nil(t, buildDynamicToolMap(names))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildDynamicToolMap_AboveThresholdIsStable(t *testing.T) {
|
||||||
|
// Parrot 不变量:同一组 tool_names 在同进程内映射稳定(保证 cache 命中)。
|
||||||
|
names := []string{"alpha", "beta", "gamma", "delta", "epsilon", "zeta"}
|
||||||
|
a := buildDynamicToolMap(names)
|
||||||
|
b := buildDynamicToolMap(names)
|
||||||
|
require.NotNil(t, a)
|
||||||
|
require.Equal(t, a, b, "same input tool_names must yield identical mapping")
|
||||||
|
require.Len(t, a, 6)
|
||||||
|
for _, name := range names {
|
||||||
|
require.Contains(t, a, name)
|
||||||
|
require.NotEqual(t, name, a[name])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeToolName_StaticPrefix(t *testing.T) {
|
||||||
|
require.Equal(t, "cc_sess_list", sanitizeToolName("sessions_list", nil))
|
||||||
|
require.Equal(t, "cc_ses_get", sanitizeToolName("session_get", nil))
|
||||||
|
require.Equal(t, "bash", sanitizeToolName("bash", nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeToolName_DynamicTakesPrecedence(t *testing.T) {
|
||||||
|
dyn := map[string]string{"sessions_list": "analyze_ses00"}
|
||||||
|
got := sanitizeToolName("sessions_list", dyn)
|
||||||
|
require.Equal(t, "analyze_ses00", got, "dynamic mapping wins over static prefix")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRestoreToolNamesInBytes_LongestFirst(t *testing.T) {
|
||||||
|
// 当假名 "abc_12" 是另一个更长假名的子串(真实场景极少但算法必须防御)时,
|
||||||
|
// 长的必须先替换。本测试用显式构造的映射来验证排序不变量。
|
||||||
|
rw := &ToolNameRewrite{
|
||||||
|
Forward: map[string]string{"foo": "abc_12", "bar": "abc_12_ext"},
|
||||||
|
Reverse: map[string]string{"abc_12": "foo", "abc_12_ext": "bar"},
|
||||||
|
}
|
||||||
|
// 手工构造 ReverseOrdered:长的在前
|
||||||
|
rw.ReverseOrdered = [][2]string{
|
||||||
|
{"abc_12_ext", "bar"},
|
||||||
|
{"abc_12", "foo"},
|
||||||
|
}
|
||||||
|
data := []byte(`{"tool":"abc_12_ext","other":"abc_12"}`)
|
||||||
|
restored := string(restoreToolNamesInBytes(data, rw))
|
||||||
|
require.Equal(t, `{"tool":"bar","other":"foo"}`, restored)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRestoreToolNamesInBytes_StaticPrefixRollback(t *testing.T) {
|
||||||
|
data := []byte(`{"name":"sessions_list","id":"cc_ses_xyz"}`)
|
||||||
|
got := string(restoreToolNamesInBytes(data, nil))
|
||||||
|
require.Equal(t, `{"name":"sessions_list","id":"session_xyz"}`, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyToolNameRewriteToBody_RenamesToolsAndToolChoice(t *testing.T) {
|
||||||
|
body := []byte(`{"tools":[{"name":"sessions_list","input_schema":{}},{"name":"session_get","input_schema":{}},{"name":"web_search","type":"web_search_20250305"}],"tool_choice":{"type":"tool","name":"sessions_list"}}`)
|
||||||
|
rw := buildToolNameRewriteFromBody(body)
|
||||||
|
require.NotNil(t, rw)
|
||||||
|
require.Contains(t, rw.Forward, "sessions_list")
|
||||||
|
require.Contains(t, rw.Forward, "session_get")
|
||||||
|
// web_search is a server tool, not rewritten
|
||||||
|
require.NotContains(t, rw.Forward, "web_search")
|
||||||
|
|
||||||
|
out := applyToolNameRewriteToBody(body, rw)
|
||||||
|
|
||||||
|
// tools[0].name and tools[1].name rewritten; tools[2].name untouched
|
||||||
|
require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "tools.0.name").String())
|
||||||
|
require.Equal(t, "cc_ses_get", gjson.GetBytes(out, "tools.1.name").String())
|
||||||
|
require.Equal(t, "web_search", gjson.GetBytes(out, "tools.2.name").String())
|
||||||
|
|
||||||
|
// tool_choice.name rewritten
|
||||||
|
require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "tool_choice.name").String())
|
||||||
|
require.Equal(t, "tool", gjson.GetBytes(out, "tool_choice.type").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyToolsLastCacheBreakpoint_InjectsDefault(t *testing.T) {
|
||||||
|
body := []byte(`{"tools":[{"name":"a","input_schema":{}},{"name":"b","input_schema":{}}]}`)
|
||||||
|
out := applyToolsLastCacheBreakpoint(body)
|
||||||
|
require.Equal(t, "ephemeral", gjson.GetBytes(out, "tools.1.cache_control.type").String())
|
||||||
|
require.Equal(t, "5m", gjson.GetBytes(out, "tools.1.cache_control.ttl").String())
|
||||||
|
// First tool untouched
|
||||||
|
require.False(t, gjson.GetBytes(out, "tools.0.cache_control").Exists())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyToolsLastCacheBreakpoint_PassesThroughClientTTL(t *testing.T) {
|
||||||
|
body := []byte(`{"tools":[{"name":"a","input_schema":{},"cache_control":{"type":"ephemeral","ttl":"1h"}}]}`)
|
||||||
|
out := applyToolsLastCacheBreakpoint(body)
|
||||||
|
// User-provided ttl must be preserved.
|
||||||
|
require.Equal(t, "1h", gjson.GetBytes(out, "tools.0.cache_control.ttl").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStripMessageCacheControl(t *testing.T) {
|
||||||
|
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral"}}]}]}`)
|
||||||
|
out := stripMessageCacheControl(body)
|
||||||
|
require.False(t, gjson.GetBytes(out, "messages.0.content.0.cache_control").Exists())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddMessageCacheBreakpoints_LastMessageOnly(t *testing.T) {
|
||||||
|
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
|
||||||
|
out := addMessageCacheBreakpoints(body)
|
||||||
|
require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.0.content.0.cache_control.type").String())
|
||||||
|
require.Equal(t, "5m", gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddMessageCacheBreakpoints_SecondToLastUserTurn(t *testing.T) {
|
||||||
|
// Parrot 不变量:messages ≥ 4 时才打第二个断点,且位置是"倒数第二个 user turn"。
|
||||||
|
body := []byte(`{"messages":[
|
||||||
|
{"role":"user","content":[{"type":"text","text":"q1"}]},
|
||||||
|
{"role":"assistant","content":[{"type":"text","text":"a1"}]},
|
||||||
|
{"role":"user","content":[{"type":"text","text":"q2"}]},
|
||||||
|
{"role":"assistant","content":[{"type":"text","text":"a2"}]}
|
||||||
|
]}`)
|
||||||
|
out := addMessageCacheBreakpoints(body)
|
||||||
|
// 最后一条 assistant 被打断点
|
||||||
|
require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.3.content.0.cache_control.type").String())
|
||||||
|
// 倒数第二个 user turn = index 0(唯一另一个 user)
|
||||||
|
require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.0.content.0.cache_control.type").String())
|
||||||
|
// 其他不打断点
|
||||||
|
require.False(t, gjson.GetBytes(out, "messages.1.content.0.cache_control").Exists())
|
||||||
|
require.False(t, gjson.GetBytes(out, "messages.2.content.0.cache_control").Exists())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddMessageCacheBreakpoints_StringContentPromoted(t *testing.T) {
|
||||||
|
body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
out := addMessageCacheBreakpoints(body)
|
||||||
|
// content 升级成数组
|
||||||
|
require.True(t, gjson.GetBytes(out, "messages.0.content").IsArray())
|
||||||
|
require.Equal(t, "text", gjson.GetBytes(out, "messages.0.content.0.type").String())
|
||||||
|
require.Equal(t, "hi", gjson.GetBytes(out, "messages.0.content.0.text").String())
|
||||||
|
require.Equal(t, "5m", gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildToolNameRewriteFromBody_ReverseOrderedByLengthDesc(t *testing.T) {
|
||||||
|
// 超过阈值触发动态映射,验证 ReverseOrdered 按假名长度倒序排列
|
||||||
|
body := []byte(`{"tools":[
|
||||||
|
{"name":"t1","input_schema":{}},
|
||||||
|
{"name":"t2","input_schema":{}},
|
||||||
|
{"name":"t3","input_schema":{}},
|
||||||
|
{"name":"t4","input_schema":{}},
|
||||||
|
{"name":"t5","input_schema":{}},
|
||||||
|
{"name":"t6","input_schema":{}}
|
||||||
|
]}`)
|
||||||
|
rw := buildToolNameRewriteFromBody(body)
|
||||||
|
require.NotNil(t, rw)
|
||||||
|
require.NotEmpty(t, rw.ReverseOrdered)
|
||||||
|
for i := 1; i < len(rw.ReverseOrdered); i++ {
|
||||||
|
require.GreaterOrEqual(t, len(rw.ReverseOrdered[i-1][0]), len(rw.ReverseOrdered[i][0]),
|
||||||
|
"ReverseOrdered must be sorted by fake-name length descending")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRestoreToolNamesInBytes_NoMapping_NoStaticMatch_IsNoop(t *testing.T) {
|
||||||
|
data := []byte("plain text without any tool names")
|
||||||
|
require.Equal(t, string(data), string(restoreToolNamesInBytes(data, nil)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the fake name format follows Parrot's "{prefix}{name[:3]}{i:02d}".
|
||||||
|
func TestBuildDynamicToolMap_FakeNameShape(t *testing.T) {
|
||||||
|
names := []string{"alphabet", "bravo", "charlie", "delta", "echo", "foxtrot"}
|
||||||
|
m := buildDynamicToolMap(names)
|
||||||
|
require.NotNil(t, m)
|
||||||
|
for _, name := range names {
|
||||||
|
fake, ok := m[name]
|
||||||
|
require.True(t, ok)
|
||||||
|
// fake = prefix + head3 + "%02d"
|
||||||
|
// ends with two decimal digits
|
||||||
|
require.Regexp(t, `^[a-z]+_[a-z0-9]{1,3}\d{2}$`, fake)
|
||||||
|
head := name
|
||||||
|
if len(head) > 3 {
|
||||||
|
head = head[:3]
|
||||||
|
}
|
||||||
|
require.True(t, strings.Contains(fake, head), "fake %q should contain head3 %q of %q", fake, head, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user