mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-06 00:10:21 +08:00
Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2e3e8687e1 | ||
|
|
ca42a45802 | ||
|
|
9350ecb62b | ||
|
|
a4a026e8da | ||
|
|
342fd03e72 | ||
|
|
e3f1fd9b63 | ||
|
|
a377e99088 | ||
|
|
1d3d7a3033 | ||
|
|
e7086cb3a3 | ||
|
|
01ef7340aa | ||
|
|
1c960d22c1 | ||
|
|
ece0606fed | ||
|
|
4e8615f276 |
@@ -1718,13 +1718,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
|
||||
// Handle OpenAI accounts
|
||||
if account.IsOpenAI() {
|
||||
// For OAuth accounts: return default OpenAI models
|
||||
if account.IsOAuth() {
|
||||
// OpenAI 自动透传会绕过常规模型改写,测试/模型列表也应回落到默认模型集。
|
||||
if account.IsOpenAIPassthroughEnabled() {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
return
|
||||
}
|
||||
|
||||
// For API Key accounts: check model_mapping
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type availableModelsAdminService struct {
|
||||
*stubAdminService
|
||||
account service.Account
|
||||
}
|
||||
|
||||
func (s *availableModelsAdminService) GetAccount(_ context.Context, id int64) (*service.Account, error) {
|
||||
if s.account.ID == id {
|
||||
acc := s.account
|
||||
return &acc, nil
|
||||
}
|
||||
return s.stubAdminService.GetAccount(context.Background(), id)
|
||||
}
|
||||
|
||||
func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
router.GET("/api/v1/admin/accounts/:id/models", handler.GetAvailableModels)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 42,
|
||||
Name: "openai-oauth",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/42/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Len(t, resp.Data, 1)
|
||||
require.Equal(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefaults(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 43,
|
||||
Name: "openai-oauth-passthrough",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"openai_passthrough": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/43/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.NotEmpty(t, resp.Data)
|
||||
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
@@ -181,13 +181,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" {
|
||||
defaultMappedModel = fallbackModel
|
||||
}
|
||||
defaultMappedModel := c.GetString("openai_chat_completions_fallback_model")
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
|
||||
@@ -655,14 +655,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
// 如果使用了降级模型调度,强制使用降级模型
|
||||
if fallbackModel := c.GetString("openai_messages_fallback_model"); fallbackModel != "" {
|
||||
defaultMappedModel = fallbackModel
|
||||
}
|
||||
// 仅在调度时实际触发了降级(原模型无可用账号、改用默认模型重试成功)时,
|
||||
// 才将降级模型传给 Forward 层做模型替换;否则保持用户请求的原始模型。
|
||||
defaultMappedModel := c.GetString("openai_messages_fallback_model")
|
||||
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
|
||||
@@ -105,6 +105,7 @@ func TestAnthropicToResponses_ToolUse(t *testing.T) {
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
assert.Equal(t, "function_call", items[2].Type)
|
||||
assert.Equal(t, "fc_call_1", items[2].CallID)
|
||||
assert.Empty(t, items[2].ID)
|
||||
assert.Equal(t, "function_call_output", items[3].Type)
|
||||
assert.Equal(t, "fc_call_1", items[3].CallID)
|
||||
assert.Equal(t, "Sunny, 72°F", items[3].Output)
|
||||
|
||||
@@ -277,7 +277,6 @@ func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, e
|
||||
CallID: fcID,
|
||||
Name: b.Name,
|
||||
Arguments: args,
|
||||
ID: fcID,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -99,6 +99,7 @@ func TestChatCompletionsToResponses_ToolCalls(t *testing.T) {
|
||||
// Check function_call item
|
||||
assert.Equal(t, "function_call", items[1].Type)
|
||||
assert.Equal(t, "call_1", items[1].CallID)
|
||||
assert.Empty(t, items[1].ID)
|
||||
assert.Equal(t, "ping", items[1].Name)
|
||||
|
||||
// Check function_call_output item
|
||||
@@ -252,6 +253,55 @@ func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T)
|
||||
assert.Equal(t, "user", items[0].Role)
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
assert.Equal(t, "function_call", items[2].Type)
|
||||
assert.Empty(t, items[2].ID)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_AssistantArrayContentPreserved(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"text","text":"A"},{"type":"text","text":"B"}]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "output_text", parts[0].Type)
|
||||
assert.Equal(t, "AB", parts[0].Text)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_AssistantThinkingTagPreserved(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"thinking","thinking":"internal plan"},{"type":"text","text":"final answer"}]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "output_text", parts[0].Type)
|
||||
assert.Contains(t, parts[0].Text, "<thinking>internal plan</thinking>")
|
||||
assert.Contains(t, parts[0].Text, "final answer")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -344,8 +394,8 @@ func TestResponsesToChatCompletions_Reasoning(t *testing.T) {
|
||||
|
||||
var content string
|
||||
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||
// Reasoning summary is prepended to text
|
||||
assert.Equal(t, "I thought about it.The answer is 42.", content)
|
||||
assert.Equal(t, "The answer is 42.", content)
|
||||
assert.Equal(t, "I thought about it.", chat.Choices[0].Message.ReasoningContent)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
|
||||
@@ -582,8 +632,35 @@ func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) {
|
||||
Delta: "Thinking...",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.reasoning_summary_text.done",
|
||||
}, state)
|
||||
require.Len(t, chunks, 0)
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_ReasoningThenTextAutoCloseTag(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.SentRole = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.reasoning_summary_text.delta",
|
||||
Delta: "plan",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
assert.Equal(t, "plan", *chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
Delta: "answer",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
|
||||
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.Content)
|
||||
assert.Equal(t, "answer", *chunks[0].Choices[0].Delta.Content)
|
||||
}
|
||||
|
||||
func TestFinalizeResponsesChatStream(t *testing.T) {
|
||||
|
||||
@@ -3,6 +3,7 @@ package apicompat
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ChatCompletionsToResponses converts a Chat Completions request into a
|
||||
@@ -174,8 +175,11 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
|
||||
// Emit assistant message with output_text if content is non-empty.
|
||||
if len(m.Content) > 0 {
|
||||
var s string
|
||||
if err := json.Unmarshal(m.Content, &s); err == nil && s != "" {
|
||||
s, err := parseAssistantContent(m.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s != "" {
|
||||
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
|
||||
partsJSON, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
@@ -196,13 +200,82 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
CallID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: args,
|
||||
ID: tc.ID,
|
||||
})
|
||||
}
|
||||
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// parseAssistantContent returns assistant content as plain text.
|
||||
//
|
||||
// Supported formats:
|
||||
// - JSON string
|
||||
// - JSON array of typed parts (e.g. [{"type":"text","text":"..."}])
|
||||
//
|
||||
// For structured thinking/reasoning parts, it preserves semantics by wrapping
|
||||
// the text in explicit tags so downstream can still distinguish it from normal text.
|
||||
func parseAssistantContent(raw json.RawMessage) (string, error) {
|
||||
if len(raw) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var parts []map[string]any
|
||||
if err := json.Unmarshal(raw, &parts); err != nil {
|
||||
// Keep compatibility with prior behavior: unsupported assistant content
|
||||
// formats are ignored instead of failing the whole request conversion.
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
write := func(v string) error {
|
||||
_, err := b.WriteString(v)
|
||||
return err
|
||||
}
|
||||
for _, p := range parts {
|
||||
typ, _ := p["type"].(string)
|
||||
text, _ := p["text"].(string)
|
||||
thinking, _ := p["thinking"].(string)
|
||||
|
||||
switch typ {
|
||||
case "thinking", "reasoning":
|
||||
if thinking != "" {
|
||||
if err := write("<thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write(thinking); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write("</thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else if text != "" {
|
||||
if err := write("<thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write(text); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write("</thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
default:
|
||||
if text != "" {
|
||||
if err := write(text); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
// chatToolToResponses converts a tool result message (role=tool) into a
|
||||
// function_call_output item.
|
||||
func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
|
||||
@@ -29,6 +29,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
|
||||
}
|
||||
|
||||
var contentText string
|
||||
var reasoningText string
|
||||
var toolCalls []ChatToolCall
|
||||
|
||||
for _, item := range resp.Output {
|
||||
@@ -51,7 +52,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
|
||||
case "reasoning":
|
||||
for _, s := range item.Summary {
|
||||
if s.Type == "summary_text" && s.Text != "" {
|
||||
contentText += s.Text
|
||||
reasoningText += s.Text
|
||||
}
|
||||
}
|
||||
case "web_search_call":
|
||||
@@ -67,6 +68,9 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
|
||||
raw, _ := json.Marshal(contentText)
|
||||
msg.Content = raw
|
||||
}
|
||||
if reasoningText != "" {
|
||||
msg.ReasoningContent = reasoningText
|
||||
}
|
||||
|
||||
finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls)
|
||||
|
||||
@@ -153,6 +157,8 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent
|
||||
return resToChatHandleFuncArgsDelta(evt, state)
|
||||
case "response.reasoning_summary_text.delta":
|
||||
return resToChatHandleReasoningDelta(evt, state)
|
||||
case "response.reasoning_summary_text.done":
|
||||
return nil
|
||||
case "response.completed", "response.incomplete", "response.failed":
|
||||
return resToChatHandleCompleted(evt, state)
|
||||
default:
|
||||
@@ -276,8 +282,8 @@ func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEv
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
content := evt.Delta
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
|
||||
reasoning := evt.Delta
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{ReasoningContent: &reasoning})}
|
||||
}
|
||||
|
||||
func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
|
||||
@@ -361,11 +361,12 @@ type ChatStreamOptions struct {
|
||||
|
||||
// ChatMessage is a single message in the Chat Completions conversation.
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
|
||||
// Legacy function calling
|
||||
FunctionCall *ChatFunctionCall `json:"function_call,omitempty"`
|
||||
@@ -466,9 +467,10 @@ type ChatChunkChoice struct {
|
||||
|
||||
// ChatDelta carries incremental content in a streaming chunk.
|
||||
type ChatDelta struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
|
||||
ReasoningContent *string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -397,9 +397,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||
}
|
||||
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
|
||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||
}
|
||||
// 普通账号编辑(如 model_mapping / credentials)也需要立即刷新单账号快照,
|
||||
// 否则网关在 outbox worker 延迟或异常时仍可能读到旧配置。
|
||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -142,6 +142,35 @@ func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
|
||||
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnCredentialsChange() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "sync-credentials-update",
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
account.Credentials = map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.2",
|
||||
},
|
||||
}
|
||||
err := s.repo.Update(s.ctx, account)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
mapping, ok := cacheRecorder.setAccounts[0].Credentials["model_mapping"].(map[string]any)
|
||||
s.Require().True(ok)
|
||||
s.Require().Equal("gpt-5.2", mapping["gpt-5"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
|
||||
|
||||
|
||||
@@ -522,16 +522,23 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
|
||||
// 如果未配置 mapping,返回原始模型名
|
||||
func (a *Account) GetMappedModel(requestedModel string) string {
|
||||
mappedModel, _ := a.ResolveMappedModel(requestedModel)
|
||||
return mappedModel
|
||||
}
|
||||
|
||||
// ResolveMappedModel 获取映射后的模型名,并返回是否命中了账号级映射。
|
||||
// matched=true 表示命中了精确映射或通配符映射,即使映射结果与原模型名相同。
|
||||
func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string, matched bool) {
|
||||
mapping := a.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
return requestedModel
|
||||
return requestedModel, false
|
||||
}
|
||||
// 精确匹配优先
|
||||
if mappedModel, exists := mapping[requestedModel]; exists {
|
||||
return mappedModel
|
||||
return mappedModel, true
|
||||
}
|
||||
// 通配符匹配(最长优先)
|
||||
return matchWildcardMapping(mapping, requestedModel)
|
||||
return matchWildcardMappingResult(mapping, requestedModel)
|
||||
}
|
||||
|
||||
func (a *Account) GetBaseURL() string {
|
||||
@@ -605,9 +612,7 @@ func matchWildcard(pattern, str string) bool {
|
||||
return matchAntigravityWildcard(pattern, str)
|
||||
}
|
||||
|
||||
// matchWildcardMapping 通配符映射匹配(最长优先)
|
||||
// 如果没有匹配,返回原始字符串
|
||||
func matchWildcardMapping(mapping map[string]string, requestedModel string) string {
|
||||
func matchWildcardMappingResult(mapping map[string]string, requestedModel string) (string, bool) {
|
||||
// 收集所有匹配的 pattern,按长度降序排序(最长优先)
|
||||
type patternMatch struct {
|
||||
pattern string
|
||||
@@ -622,7 +627,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
|
||||
}
|
||||
|
||||
if len(matches) == 0 {
|
||||
return requestedModel // 无匹配,返回原始模型名
|
||||
return requestedModel, false // 无匹配,返回原始模型名
|
||||
}
|
||||
|
||||
// 按 pattern 长度降序排序
|
||||
@@ -633,7 +638,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
|
||||
return matches[i].pattern < matches[j].pattern
|
||||
})
|
||||
|
||||
return matches[0].target
|
||||
return matches[0].target, true
|
||||
}
|
||||
|
||||
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
|
||||
@@ -43,12 +43,13 @@ func TestMatchWildcard(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchWildcardMapping(t *testing.T) {
|
||||
func TestMatchWildcardMappingResult(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mapping map[string]string
|
||||
requestedModel string
|
||||
expected string
|
||||
matched bool
|
||||
}{
|
||||
// 精确匹配优先于通配符
|
||||
{
|
||||
@@ -59,6 +60,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-4-5-exact",
|
||||
matched: true,
|
||||
},
|
||||
|
||||
// 最长通配符优先
|
||||
@@ -71,6 +73,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-4-series",
|
||||
matched: true,
|
||||
},
|
||||
|
||||
// 单个通配符
|
||||
@@ -81,6 +84,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "claude-opus-4-5",
|
||||
expected: "claude-mapped",
|
||||
matched: true,
|
||||
},
|
||||
|
||||
// 无匹配返回原始模型
|
||||
@@ -91,6 +95,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "gemini-3-flash",
|
||||
expected: "gemini-3-flash",
|
||||
matched: false,
|
||||
},
|
||||
|
||||
// 空映射返回原始模型
|
||||
@@ -99,6 +104,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
mapping: map[string]string{},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-4-5",
|
||||
matched: false,
|
||||
},
|
||||
|
||||
// Gemini 模型映射
|
||||
@@ -110,14 +116,15 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "gemini-3-flash-preview",
|
||||
expected: "gemini-3-pro-high",
|
||||
matched: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := matchWildcardMapping(tt.mapping, tt.requestedModel)
|
||||
if result != tt.expected {
|
||||
t.Errorf("matchWildcardMapping(%v, %q) = %q, want %q", tt.mapping, tt.requestedModel, result, tt.expected)
|
||||
result, matched := matchWildcardMappingResult(tt.mapping, tt.requestedModel)
|
||||
if result != tt.expected || matched != tt.matched {
|
||||
t.Errorf("matchWildcardMappingResult(%v, %q) = (%q, %v), want (%q, %v)", tt.mapping, tt.requestedModel, result, matched, tt.expected, tt.matched)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -268,6 +275,69 @@ func TestAccountGetMappedModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountResolveMappedModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
credentials map[string]any
|
||||
requestedModel string
|
||||
expectedModel string
|
||||
expectedMatch bool
|
||||
}{
|
||||
{
|
||||
name: "no mapping reports unmatched",
|
||||
credentials: nil,
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: false,
|
||||
},
|
||||
{
|
||||
name: "exact passthrough mapping still counts as matched",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5.4": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard passthrough mapping still counts as matched",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-*": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: true,
|
||||
},
|
||||
{
|
||||
name: "missing mapping reports unmatched",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5.2": "gpt-5.2",
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: tt.credentials,
|
||||
}
|
||||
mappedModel, matched := account.ResolveMappedModel(tt.requestedModel)
|
||||
if mappedModel != tt.expectedModel || matched != tt.expectedMatch {
|
||||
t.Fatalf("ResolveMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, mappedModel, matched, tt.expectedModel, tt.expectedMatch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
|
||||
@@ -339,8 +339,9 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
}
|
||||
typ, _ := m["type"].(string)
|
||||
|
||||
// 修复 OpenAI 上游的最新校验:"Expected an ID that begins with 'fc'"
|
||||
fixIDPrefix := func(id string) string {
|
||||
// 仅修正真正的 tool/function call 标识,避免误改普通 message/reasoning id;
|
||||
// 若 item_reference 指向 legacy call_* 标识,则仅修正该引用本身。
|
||||
fixCallIDPrefix := func(id string) string {
|
||||
if id == "" || strings.HasPrefix(id, "fc") {
|
||||
return id
|
||||
}
|
||||
@@ -358,8 +359,8 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
for key, value := range m {
|
||||
newItem[key] = value
|
||||
}
|
||||
if id, ok := newItem["id"].(string); ok && id != "" {
|
||||
newItem["id"] = fixIDPrefix(id)
|
||||
if id, ok := newItem["id"].(string); ok && strings.HasPrefix(id, "call_") {
|
||||
newItem["id"] = fixCallIDPrefix(id)
|
||||
}
|
||||
filtered = append(filtered, newItem)
|
||||
continue
|
||||
@@ -390,7 +391,7 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
}
|
||||
|
||||
if callID != "" {
|
||||
fixedCallID := fixIDPrefix(callID)
|
||||
fixedCallID := fixCallIDPrefix(callID)
|
||||
if fixedCallID != callID {
|
||||
ensureCopy()
|
||||
newItem["call_id"] = fixedCallID
|
||||
@@ -404,14 +405,6 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
if !isCodexToolCallItemType(typ) {
|
||||
delete(newItem, "call_id")
|
||||
}
|
||||
} else {
|
||||
if id, ok := newItem["id"].(string); ok && id != "" {
|
||||
fixedID := fixIDPrefix(id)
|
||||
if fixedID != id {
|
||||
ensureCopy()
|
||||
newItem["id"] = fixedID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
filtered = append(filtered, newItem)
|
||||
|
||||
@@ -33,12 +33,63 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
|
||||
first, ok := input[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "item_reference", first["type"])
|
||||
require.Equal(t, "fc_ref1", first["id"])
|
||||
require.Equal(t, "ref1", first["id"])
|
||||
|
||||
// 校验 input[1] 为 map,确保后续字段断言安全。
|
||||
second, ok := input[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "fc_o1", second["id"])
|
||||
require.Equal(t, "o1", second["id"])
|
||||
require.Equal(t, "fc1", second["call_id"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_ToolContinuationPreservesNativeMessageAndReasoningIDs(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.2",
|
||||
"input": []any{
|
||||
map[string]any{"type": "message", "id": "msg_0", "role": "user", "content": "hi"},
|
||||
map[string]any{"type": "item_reference", "id": "rs_123"},
|
||||
},
|
||||
"tool_choice": "auto",
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, false, false)
|
||||
|
||||
input, ok := reqBody["input"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, input, 2)
|
||||
|
||||
first, ok := input[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "msg_0", first["id"])
|
||||
|
||||
second, ok := input[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "rs_123", second["id"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_ToolContinuationNormalizesToolReferenceIDsOnly(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.2",
|
||||
"input": []any{
|
||||
map[string]any{"type": "item_reference", "id": "call_1"},
|
||||
map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "ok"},
|
||||
},
|
||||
"tool_choice": "auto",
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, false, false)
|
||||
|
||||
input, ok := reqBody["input"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, input, 2)
|
||||
|
||||
first, ok := input[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "fc1", first["id"])
|
||||
|
||||
second, ok := input[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "fc1", second["call_id"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
|
||||
|
||||
@@ -51,10 +51,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
}
|
||||
|
||||
// 3. Model mapping
|
||||
mappedModel := account.GetMappedModel(originalModel)
|
||||
if mappedModel == originalModel && defaultMappedModel != "" {
|
||||
mappedModel = defaultMappedModel
|
||||
}
|
||||
mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
|
||||
responsesReq.Model = mappedModel
|
||||
|
||||
logger.L().Debug("openai chat_completions: model mapping applied",
|
||||
|
||||
@@ -59,11 +59,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
}
|
||||
|
||||
// 3. Model mapping
|
||||
mappedModel := account.GetMappedModel(originalModel)
|
||||
// 分组级降级:账号未映射时使用分组默认映射模型
|
||||
if mappedModel == originalModel && defaultMappedModel != "" {
|
||||
mappedModel = defaultMappedModel
|
||||
}
|
||||
mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
|
||||
responsesReq.Model = mappedModel
|
||||
|
||||
logger.L().Debug("openai messages: model mapping applied",
|
||||
|
||||
19
backend/internal/service/openai_model_mapping.go
Normal file
19
backend/internal/service/openai_model_mapping.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package service
|
||||
|
||||
// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible
|
||||
// forwarding. Group-level default mapping only applies when the account itself
|
||||
// did not match any explicit model_mapping rule.
|
||||
func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string {
|
||||
if account == nil {
|
||||
if defaultMappedModel != "" {
|
||||
return defaultMappedModel
|
||||
}
|
||||
return requestedModel
|
||||
}
|
||||
|
||||
mappedModel, matched := account.ResolveMappedModel(requestedModel)
|
||||
if !matched && defaultMappedModel != "" {
|
||||
return defaultMappedModel
|
||||
}
|
||||
return mappedModel
|
||||
}
|
||||
70
backend/internal/service/openai_model_mapping_test.go
Normal file
70
backend/internal/service/openai_model_mapping_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestResolveOpenAIForwardModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
requestedModel string
|
||||
defaultMappedModel string
|
||||
expectedModel string
|
||||
}{
|
||||
{
|
||||
name: "falls back to group default when account has no mapping",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
defaultMappedModel: "gpt-4o-mini",
|
||||
expectedModel: "gpt-4o-mini",
|
||||
},
|
||||
{
|
||||
name: "preserves exact passthrough mapping instead of group default",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5.4": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
defaultMappedModel: "gpt-4o-mini",
|
||||
expectedModel: "gpt-5.4",
|
||||
},
|
||||
{
|
||||
name: "preserves wildcard passthrough mapping instead of group default",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-*": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
defaultMappedModel: "gpt-4o-mini",
|
||||
expectedModel: "gpt-5.4",
|
||||
},
|
||||
{
|
||||
name: "uses account remap when explicit target differs",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5",
|
||||
defaultMappedModel: "gpt-4o-mini",
|
||||
expectedModel: "gpt-5.4",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := resolveOpenAIForwardModel(tt.account, tt.requestedModel, tt.defaultMappedModel); got != tt.expectedModel {
|
||||
t.Fatalf("resolveOpenAIForwardModel(...) = %q, want %q", got, tt.expectedModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user