// Source: sub2api/backend/internal/service/openai_tool_continuation.go

package service
import "strings"
// ToolContinuationSignals aggregates the tool-continuation signals found in a
// request's input so that callers do not have to traverse input repeatedly.
type ToolContinuationSignals struct {
// HasFunctionCallOutput is true when input contains at least one item of
// type function_call_output.
HasFunctionCallOutput bool
// HasFunctionCallOutputMissingCallID is true when at least one
// function_call_output item has a missing, non-string, or blank call_id.
HasFunctionCallOutputMissingCallID bool
// HasToolCallContext is true when input contains a tool_call or
// function_call item carrying a non-blank call_id.
HasToolCallContext bool
// HasItemReference is true when input contains at least one item of type
// item_reference, whether or not its id is usable.
HasItemReference bool
// HasItemReferenceForAllCallIDs is true when every non-blank
// function_call_output call_id equals the id of some item_reference item.
// It stays false when there are no such call_ids or no usable references.
HasItemReferenceForAllCallIDs bool
// FunctionCallOutputCallIDs holds the distinct, whitespace-trimmed call_id
// values taken from function_call_output items; nil when there are none.
FunctionCallOutputCallIDs []string
}
// FunctionCallOutputValidation summarizes the result of checking whether the
// function_call_output items in a request can be associated with a tool call.
type FunctionCallOutputValidation struct {
// HasFunctionCallOutput is true when input contains an item of type
// function_call_output.
HasFunctionCallOutput bool
// HasToolCallContext is true when input contains a tool_call or
// function_call item carrying a non-blank call_id.
HasToolCallContext bool
// HasFunctionCallOutputMissingCallID is true when a function_call_output
// item has a missing, non-string, or blank call_id.
// ValidateFunctionCallOutputContext only evaluates it when
// function_call_output is present and no tool-call context was found.
HasFunctionCallOutputMissingCallID bool
// HasItemReferenceForAllCallIDs is true when every non-blank
// function_call_output call_id equals the id of some item_reference item.
// ValidateFunctionCallOutputContext only evaluates it when
// function_call_output is present and no tool-call context was found.
HasItemReferenceForAllCallIDs bool
}
// NeedsToolContinuation reports whether the request requires tool-call
// continuation handling.
//
// It returns true as soon as any one of the following signals is present:
//   - previous_response_id is a non-blank string;
//   - tools or tool_choice is explicitly declared;
//   - input contains an item whose type is function_call_output or
//     item_reference.
//
// A nil request body yields false. When none of the top-level signals match
// and input is not a []any, the result is false as well.
func NeedsToolContinuation(reqBody map[string]any) bool {
	if reqBody == nil {
		return false
	}

	// Top-level signals are checked first so that the input list is scanned
	// only when none of them matches.
	if hasNonEmptyString(reqBody["previous_response_id"]) ||
		hasToolsSignal(reqBody) ||
		hasToolChoiceSignal(reqBody) {
		return true
	}

	entries, isList := reqBody["input"].([]any)
	if !isList {
		return false
	}
	for _, entry := range entries {
		fields, isMap := entry.(map[string]any)
		if !isMap {
			// Non-object entries carry no type and cannot be a signal.
			continue
		}
		switch kind, _ := fields["type"].(string); kind {
		case "function_call_output", "item_reference":
			return true
		}
	}
	return false
}
// AnalyzeToolContinuationSignals walks input exactly once and extracts the
// signals related to function_call_output, tool_call/function_call and
// item_reference items.
//
// A nil request body, or an input that is not a []any, yields the zero value.
// Entries of input that are not JSON objects are ignored.
//
// FunctionCallOutputCallIDs lists the distinct, whitespace-trimmed call_id
// values of function_call_output items in the order they first appear in
// input, so the result is deterministic for a given request.
// HasItemReferenceForAllCallIDs is true only when there is at least one such
// call_id and each of them equals the trimmed id of some item_reference item.
func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSignals {
	signals := ToolContinuationSignals{}
	if reqBody == nil {
		return signals
	}
	input, ok := reqBody["input"].([]any)
	if !ok {
		return signals
	}

	// Both sets are allocated lazily: most requests carry neither kind of item.
	var (
		seenCallIDs  map[string]struct{}
		referenceIDs map[string]struct{}
	)
	for _, item := range input {
		itemMap, ok := item.(map[string]any)
		if !ok {
			continue
		}
		itemType, _ := itemMap["type"].(string)
		switch itemType {
		case "tool_call", "function_call":
			callID, _ := itemMap["call_id"].(string)
			if strings.TrimSpace(callID) != "" {
				signals.HasToolCallContext = true
			}
		case "function_call_output":
			signals.HasFunctionCallOutput = true
			callID, _ := itemMap["call_id"].(string)
			callID = strings.TrimSpace(callID)
			if callID == "" {
				signals.HasFunctionCallOutputMissingCallID = true
				continue
			}
			// Reading from a nil map is safe; skip call_ids already recorded
			// so the slice stays free of duplicates.
			if _, dup := seenCallIDs[callID]; dup {
				continue
			}
			if seenCallIDs == nil {
				seenCallIDs = make(map[string]struct{})
			}
			seenCallIDs[callID] = struct{}{}
			// Appending here (instead of ranging over the set afterwards)
			// keeps the IDs in input order rather than random map order.
			signals.FunctionCallOutputCallIDs = append(signals.FunctionCallOutputCallIDs, callID)
		case "item_reference":
			signals.HasItemReference = true
			idValue, _ := itemMap["id"].(string)
			idValue = strings.TrimSpace(idValue)
			if idValue == "" {
				continue
			}
			if referenceIDs == nil {
				referenceIDs = make(map[string]struct{})
			}
			referenceIDs[idValue] = struct{}{}
		}
	}

	if len(signals.FunctionCallOutputCallIDs) == 0 || len(referenceIDs) == 0 {
		return signals
	}
	for _, callID := range signals.FunctionCallOutputCallIDs {
		if _, ok := referenceIDs[callID]; !ok {
			return signals
		}
	}
	signals.HasItemReferenceForAllCallIDs = true
	return signals
}
// ValidateFunctionCallOutputContext gives handlers a low-overhead validation
// result for function_call_output items:
//  1. without any function_call_output it returns right after the first scan;
//  2. once a tool_call/function_call with a non-blank call_id is also seen it
//     returns early, leaving the remaining fields false;
//  3. only when function_call_output is present without such context does it
//     build the call_id and item_reference sets and compare them.
//
// A nil request body, or an input that is not a []any, yields the zero value.
func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutputValidation {
	var verdict FunctionCallOutputValidation
	if reqBody == nil {
		return verdict
	}
	entries, isList := reqBody["input"].([]any)
	if !isList {
		return verdict
	}

	// First scan: presence flags only, stopping as soon as both are known.
	for _, entry := range entries {
		fields, isMap := entry.(map[string]any)
		if !isMap {
			continue
		}
		switch kind, _ := fields["type"].(string); kind {
		case "tool_call", "function_call":
			if id, _ := fields["call_id"].(string); strings.TrimSpace(id) != "" {
				verdict.HasToolCallContext = true
			}
		case "function_call_output":
			verdict.HasFunctionCallOutput = true
		}
		if verdict.HasToolCallContext && verdict.HasFunctionCallOutput {
			return verdict
		}
	}
	if verdict.HasToolCallContext || !verdict.HasFunctionCallOutput {
		return verdict
	}

	// Second scan: outputs exist but no tool-call context, so the only way to
	// associate them is through item_reference ids.
	outputIDs := map[string]struct{}{}
	referenced := map[string]struct{}{}
	for _, entry := range entries {
		fields, isMap := entry.(map[string]any)
		if !isMap {
			continue
		}
		switch kind, _ := fields["type"].(string); kind {
		case "item_reference":
			ref, _ := fields["id"].(string)
			if ref = strings.TrimSpace(ref); ref != "" {
				referenced[ref] = struct{}{}
			}
		case "function_call_output":
			id, _ := fields["call_id"].(string)
			if id = strings.TrimSpace(id); id == "" {
				verdict.HasFunctionCallOutputMissingCallID = true
			} else {
				outputIDs[id] = struct{}{}
			}
		}
	}
	if len(outputIDs) == 0 || len(referenced) == 0 {
		return verdict
	}
	for id := range outputIDs {
		if _, found := referenced[id]; !found {
			// One unreferenced call_id is enough to leave the flag false.
			return verdict
		}
	}
	verdict.HasItemReferenceForAllCallIDs = true
	return verdict
}
// HasFunctionCallOutput reports whether input contains a function_call_output
// item; it is used to trigger continuation validation.
func HasFunctionCallOutput(reqBody map[string]any) bool {
	signals := AnalyzeToolContinuationSignals(reqBody)
	return signals.HasFunctionCallOutput
}
// HasToolCallContext reports whether input contains a tool_call or
// function_call item that carries a call_id, i.e. whether a
// function_call_output has context it can be associated with.
func HasToolCallContext(reqBody map[string]any) bool {
	signals := AnalyzeToolContinuationSignals(reqBody)
	return signals.HasToolCallContext
}
// FunctionCallOutputCallIDs extracts the set of call_id values carried by the
// function_call_output items in input. Only non-empty call_ids are returned;
// the result is meant to be matched against item_reference ids.
func FunctionCallOutputCallIDs(reqBody map[string]any) []string {
	signals := AnalyzeToolContinuationSignals(reqBody)
	return signals.FunctionCallOutputCallIDs
}
// HasFunctionCallOutputMissingCallID reports whether input contains a
// function_call_output item that lacks a call_id.
func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
	signals := AnalyzeToolContinuationSignals(reqBody)
	return signals.HasFunctionCallOutputMissingCallID
}
// HasItemReferenceForCallIDs reports whether the item_reference ids in input
// cover every entry of callIDs. It validates continuation requests that rely
// on reference items alone.
//
// Both the reference ids and the supplied callIDs are compared after trimming
// surrounding whitespace. The result is false when reqBody is nil, callIDs is
// empty, input is not a []any, or no item_reference carries a non-blank id.
func HasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool {
	if len(callIDs) == 0 || reqBody == nil {
		return false
	}
	entries, isList := reqBody["input"].([]any)
	if !isList {
		return false
	}

	// Collect every usable reference id before checking coverage.
	referenced := map[string]struct{}{}
	for _, entry := range entries {
		fields, isMap := entry.(map[string]any)
		if !isMap {
			continue
		}
		if kind, _ := fields["type"].(string); kind != "item_reference" {
			continue
		}
		ref, _ := fields["id"].(string)
		if ref = strings.TrimSpace(ref); ref != "" {
			referenced[ref] = struct{}{}
		}
	}
	if len(referenced) == 0 {
		return false
	}

	for _, wanted := range callIDs {
		if _, found := referenced[strings.TrimSpace(wanted)]; !found {
			return false
		}
	}
	return true
}
// hasNonEmptyString reports whether value is a string that still has content
// after surrounding whitespace is trimmed. Non-string values, including nil,
// yield false.
func hasNonEmptyString(value any) bool {
	if text, isString := value.(string); isString {
		return len(strings.TrimSpace(text)) > 0
	}
	return false
}
// hasToolsSignal reports whether the tools field is explicitly declared, that
// is, present as a []any with at least one element. A missing key, a nil
// value, an empty list, or any other type yields false.
func hasToolsSignal(reqBody map[string]any) bool {
	// The type assertion fails for both a missing key and a nil value, so no
	// separate existence check is needed.
	tools, isList := reqBody["tools"].([]any)
	return isList && len(tools) > 0
}
// hasToolChoiceSignal reports whether tool_choice is explicitly declared:
// either a string that is non-blank after trimming, or a non-empty JSON
// object. A missing key, a nil value, or any other type yields false.
func hasToolChoiceSignal(reqBody map[string]any) bool {
	choice := reqBody["tool_choice"]
	if text, isString := choice.(string); isString {
		return strings.TrimSpace(text) != ""
	}
	if object, isMap := choice.(map[string]any); isMap {
		return len(object) > 0
	}
	return false
}