mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-02 22:42:14 +08:00
Merge branch 'test' into dev
This commit is contained in:
@@ -8,7 +8,6 @@ import (
|
||||
"errors"
|
||||
"flag"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -19,6 +18,7 @@ import (
|
||||
_ "github.com/Wei-Shaw/sub2api/ent/runtime"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/setup"
|
||||
"github.com/Wei-Shaw/sub2api/internal/web"
|
||||
@@ -49,22 +49,9 @@ func init() {
|
||||
|
||||
// initLogger configures the default slog handler based on gin.Mode().
|
||||
// In non-release mode, Debug level logs are enabled.
|
||||
func initLogger() {
|
||||
var level slog.Level
|
||||
if gin.Mode() == gin.ReleaseMode {
|
||||
level = slog.LevelInfo
|
||||
} else {
|
||||
level = slog.LevelDebug
|
||||
}
|
||||
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
Level: level,
|
||||
})
|
||||
slog.SetDefault(slog.New(handler))
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Initialize slog logger based on gin mode
|
||||
initLogger()
|
||||
logger.InitBootstrap()
|
||||
defer logger.Sync()
|
||||
|
||||
// Parse command line flags
|
||||
setupMode := flag.Bool("setup", false, "Run setup wizard in CLI mode")
|
||||
@@ -141,6 +128,9 @@ func runMainServer() {
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
if err := logger.Init(logger.OptionsFromConfig(cfg.Log)); err != nil {
|
||||
log.Fatalf("Failed to initialize logger: %v", err)
|
||||
}
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
log.Println("⚠️ WARNING: Running in SIMPLE mode - billing and quota checks are DISABLED")
|
||||
}
|
||||
|
||||
@@ -67,6 +67,7 @@ func provideCleanup(
|
||||
opsAlertEvaluator *service.OpsAlertEvaluatorService,
|
||||
opsCleanup *service.OpsCleanupService,
|
||||
opsScheduledReport *service.OpsScheduledReportService,
|
||||
opsSystemLogSink *service.OpsSystemLogSink,
|
||||
soraMediaCleanup *service.SoraMediaCleanupService,
|
||||
schedulerSnapshot *service.SchedulerSnapshotService,
|
||||
tokenRefresh *service.TokenRefreshService,
|
||||
@@ -103,6 +104,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"OpsSystemLogSink", func() error {
|
||||
if opsSystemLogSink != nil {
|
||||
opsSystemLogSink.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"SoraMediaCleanupService", func() error {
|
||||
if soraMediaCleanup != nil {
|
||||
soraMediaCleanup.Stop()
|
||||
|
||||
@@ -160,7 +160,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
|
||||
opsHandler := admin.NewOpsHandler(opsService)
|
||||
updateCache := repository.NewUpdateCache(redisClient)
|
||||
@@ -204,7 +205,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -234,6 +235,7 @@ func provideCleanup(
|
||||
opsAlertEvaluator *service.OpsAlertEvaluatorService,
|
||||
opsCleanup *service.OpsCleanupService,
|
||||
opsScheduledReport *service.OpsScheduledReportService,
|
||||
opsSystemLogSink *service.OpsSystemLogSink,
|
||||
soraMediaCleanup *service.SoraMediaCleanupService,
|
||||
schedulerSnapshot *service.SchedulerSnapshotService,
|
||||
tokenRefresh *service.TokenRefreshService,
|
||||
@@ -269,6 +271,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"OpsSystemLogSink", func() error {
|
||||
if opsSystemLogSink != nil {
|
||||
opsSystemLogSink.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"SoraMediaCleanupService", func() error {
|
||||
if soraMediaCleanup != nil {
|
||||
soraMediaCleanup.Stop()
|
||||
|
||||
@@ -5,6 +5,7 @@ go 1.25.7
|
||||
require (
|
||||
entgo.io/ent v0.14.5
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/cespare/xxhash/v2 v2.3.0
|
||||
github.com/dgraph-io/ristretto v0.2.0
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2
|
||||
@@ -13,6 +14,7 @@ require (
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/imroc/req/v3 v3.57.0
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
github.com/pquerna/otp v1.5.0
|
||||
github.com/redis/go-redis/v9 v9.17.2
|
||||
github.com/refraction-networking/utls v1.8.1
|
||||
@@ -25,10 +27,12 @@ require (
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/zeromicro/go-zero v1.9.4
|
||||
go.uber.org/zap v1.24.0
|
||||
golang.org/x/crypto v0.47.0
|
||||
golang.org/x/net v0.49.0
|
||||
golang.org/x/sync v0.19.0
|
||||
golang.org/x/term v0.39.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
modernc.org/sqlite v1.44.3
|
||||
)
|
||||
@@ -45,7 +49,6 @@ require (
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||
github.com/containerd/errdefs v1.0.0 // indirect
|
||||
github.com/containerd/errdefs/pkg v0.3.0 // indirect
|
||||
@@ -104,7 +107,6 @@ require (
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
|
||||
@@ -18,6 +18,8 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
|
||||
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
|
||||
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
||||
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
|
||||
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
||||
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
|
||||
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
|
||||
@@ -137,8 +139,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
||||
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
|
||||
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
|
||||
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
@@ -174,8 +174,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||
@@ -209,8 +207,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
@@ -240,8 +236,6 @@ github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkr
|
||||
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
||||
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
@@ -264,8 +258,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
|
||||
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
||||
@@ -342,10 +334,14 @@ go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ=
|
||||
go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
|
||||
go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs=
|
||||
go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
||||
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
|
||||
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
|
||||
go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60=
|
||||
go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
@@ -393,6 +389,8 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
||||
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -39,6 +39,7 @@ const (
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Log LogConfig `mapstructure:"log"`
|
||||
CORS CORSConfig `mapstructure:"cors"`
|
||||
Security SecurityConfig `mapstructure:"security"`
|
||||
Billing BillingConfig `mapstructure:"billing"`
|
||||
@@ -68,6 +69,38 @@ type Config struct {
|
||||
Update UpdateConfig `mapstructure:"update"`
|
||||
}
|
||||
|
||||
type LogConfig struct {
|
||||
Level string `mapstructure:"level"`
|
||||
Format string `mapstructure:"format"`
|
||||
ServiceName string `mapstructure:"service_name"`
|
||||
Environment string `mapstructure:"env"`
|
||||
Caller bool `mapstructure:"caller"`
|
||||
StacktraceLevel string `mapstructure:"stacktrace_level"`
|
||||
Output LogOutputConfig `mapstructure:"output"`
|
||||
Rotation LogRotationConfig `mapstructure:"rotation"`
|
||||
Sampling LogSamplingConfig `mapstructure:"sampling"`
|
||||
}
|
||||
|
||||
type LogOutputConfig struct {
|
||||
ToStdout bool `mapstructure:"to_stdout"`
|
||||
ToFile bool `mapstructure:"to_file"`
|
||||
FilePath string `mapstructure:"file_path"`
|
||||
}
|
||||
|
||||
type LogRotationConfig struct {
|
||||
MaxSizeMB int `mapstructure:"max_size_mb"`
|
||||
MaxBackups int `mapstructure:"max_backups"`
|
||||
MaxAgeDays int `mapstructure:"max_age_days"`
|
||||
Compress bool `mapstructure:"compress"`
|
||||
LocalTime bool `mapstructure:"local_time"`
|
||||
}
|
||||
|
||||
type LogSamplingConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
Initial int `mapstructure:"initial"`
|
||||
Thereafter int `mapstructure:"thereafter"`
|
||||
}
|
||||
|
||||
type GeminiConfig struct {
|
||||
OAuth GeminiOAuthConfig `mapstructure:"oauth"`
|
||||
Quota GeminiQuotaConfig `mapstructure:"quota"`
|
||||
@@ -756,6 +789,12 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
|
||||
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
|
||||
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
|
||||
cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy)
|
||||
cfg.Log.Level = strings.ToLower(strings.TrimSpace(cfg.Log.Level))
|
||||
cfg.Log.Format = strings.ToLower(strings.TrimSpace(cfg.Log.Format))
|
||||
cfg.Log.ServiceName = strings.TrimSpace(cfg.Log.ServiceName)
|
||||
cfg.Log.Environment = strings.TrimSpace(cfg.Log.Environment)
|
||||
cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel))
|
||||
cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath)
|
||||
|
||||
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
|
||||
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
|
||||
@@ -766,7 +805,7 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
|
||||
}
|
||||
cfg.Totp.EncryptionKey = key
|
||||
cfg.Totp.EncryptionKeyConfigured = false
|
||||
log.Println("Warning: TOTP encryption key auto-generated. Consider setting a fixed key for production.")
|
||||
slog.Warn("TOTP encryption key auto-generated. Consider setting a fixed key for production.")
|
||||
} else {
|
||||
cfg.Totp.EncryptionKeyConfigured = true
|
||||
}
|
||||
@@ -786,19 +825,19 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
|
||||
}
|
||||
|
||||
if !cfg.Security.URLAllowlist.Enabled {
|
||||
log.Println("Warning: security.url_allowlist.enabled=false; allowlist/SSRF checks disabled (minimal format validation only).")
|
||||
slog.Warn("security.url_allowlist.enabled=false; allowlist/SSRF checks disabled (minimal format validation only).")
|
||||
}
|
||||
if !cfg.Security.ResponseHeaders.Enabled {
|
||||
log.Println("Warning: security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).")
|
||||
slog.Warn("security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).")
|
||||
}
|
||||
|
||||
if cfg.JWT.Secret != "" && isWeakJWTSecret(cfg.JWT.Secret) {
|
||||
log.Println("Warning: JWT secret appears weak; use a 32+ character random secret in production.")
|
||||
slog.Warn("JWT secret appears weak; use a 32+ character random secret in production.")
|
||||
}
|
||||
if len(cfg.Security.ResponseHeaders.AdditionalAllowed) > 0 || len(cfg.Security.ResponseHeaders.ForceRemove) > 0 {
|
||||
log.Printf("AUDIT: response header policy configured additional_allowed=%v force_remove=%v",
|
||||
cfg.Security.ResponseHeaders.AdditionalAllowed,
|
||||
cfg.Security.ResponseHeaders.ForceRemove,
|
||||
slog.Info("response header policy configured",
|
||||
"additional_allowed", cfg.Security.ResponseHeaders.AdditionalAllowed,
|
||||
"force_remove", cfg.Security.ResponseHeaders.ForceRemove,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -825,6 +864,25 @@ func setDefaults() {
|
||||
viper.SetDefault("server.h2c.max_upload_buffer_per_connection", 2<<20) // 2MB
|
||||
viper.SetDefault("server.h2c.max_upload_buffer_per_stream", 512<<10) // 512KB
|
||||
|
||||
// Log
|
||||
viper.SetDefault("log.level", "info")
|
||||
viper.SetDefault("log.format", "console")
|
||||
viper.SetDefault("log.service_name", "sub2api")
|
||||
viper.SetDefault("log.env", "production")
|
||||
viper.SetDefault("log.caller", true)
|
||||
viper.SetDefault("log.stacktrace_level", "error")
|
||||
viper.SetDefault("log.output.to_stdout", true)
|
||||
viper.SetDefault("log.output.to_file", true)
|
||||
viper.SetDefault("log.output.file_path", "")
|
||||
viper.SetDefault("log.rotation.max_size_mb", 100)
|
||||
viper.SetDefault("log.rotation.max_backups", 10)
|
||||
viper.SetDefault("log.rotation.max_age_days", 7)
|
||||
viper.SetDefault("log.rotation.compress", true)
|
||||
viper.SetDefault("log.rotation.local_time", true)
|
||||
viper.SetDefault("log.sampling.enabled", false)
|
||||
viper.SetDefault("log.sampling.initial", 100)
|
||||
viper.SetDefault("log.sampling.thereafter", 100)
|
||||
|
||||
// CORS
|
||||
viper.SetDefault("cors.allowed_origins", []string{})
|
||||
viper.SetDefault("cors.allow_credentials", true)
|
||||
@@ -1098,6 +1156,54 @@ func (c *Config) Validate() error {
|
||||
if len([]byte(jwtSecret)) < 32 {
|
||||
return fmt.Errorf("jwt.secret must be at least 32 bytes")
|
||||
}
|
||||
switch c.Log.Level {
|
||||
case "debug", "info", "warn", "error":
|
||||
case "":
|
||||
return fmt.Errorf("log.level is required")
|
||||
default:
|
||||
return fmt.Errorf("log.level must be one of: debug/info/warn/error")
|
||||
}
|
||||
switch c.Log.Format {
|
||||
case "json", "console":
|
||||
case "":
|
||||
return fmt.Errorf("log.format is required")
|
||||
default:
|
||||
return fmt.Errorf("log.format must be one of: json/console")
|
||||
}
|
||||
switch c.Log.StacktraceLevel {
|
||||
case "none", "error", "fatal":
|
||||
case "":
|
||||
return fmt.Errorf("log.stacktrace_level is required")
|
||||
default:
|
||||
return fmt.Errorf("log.stacktrace_level must be one of: none/error/fatal")
|
||||
}
|
||||
if !c.Log.Output.ToStdout && !c.Log.Output.ToFile {
|
||||
return fmt.Errorf("log.output.to_stdout and log.output.to_file cannot both be false")
|
||||
}
|
||||
if c.Log.Rotation.MaxSizeMB <= 0 {
|
||||
return fmt.Errorf("log.rotation.max_size_mb must be positive")
|
||||
}
|
||||
if c.Log.Rotation.MaxBackups < 0 {
|
||||
return fmt.Errorf("log.rotation.max_backups must be non-negative")
|
||||
}
|
||||
if c.Log.Rotation.MaxAgeDays < 0 {
|
||||
return fmt.Errorf("log.rotation.max_age_days must be non-negative")
|
||||
}
|
||||
if c.Log.Sampling.Enabled {
|
||||
if c.Log.Sampling.Initial <= 0 {
|
||||
return fmt.Errorf("log.sampling.initial must be positive when sampling is enabled")
|
||||
}
|
||||
if c.Log.Sampling.Thereafter <= 0 {
|
||||
return fmt.Errorf("log.sampling.thereafter must be positive when sampling is enabled")
|
||||
}
|
||||
} else {
|
||||
if c.Log.Sampling.Initial < 0 {
|
||||
return fmt.Errorf("log.sampling.initial must be non-negative")
|
||||
}
|
||||
if c.Log.Sampling.Thereafter < 0 {
|
||||
return fmt.Errorf("log.sampling.thereafter must be non-negative")
|
||||
}
|
||||
}
|
||||
|
||||
if c.SubscriptionMaintenance.WorkerCount < 0 {
|
||||
return fmt.Errorf("subscription_maintenance.worker_count must be non-negative")
|
||||
@@ -1137,20 +1243,20 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("jwt.expire_hour must be <= 168 (7 days)")
|
||||
}
|
||||
if c.JWT.ExpireHour > 24 {
|
||||
log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour)
|
||||
slog.Warn("jwt.expire_hour is high; consider shorter expiration for security", "expire_hour", c.JWT.ExpireHour)
|
||||
}
|
||||
// JWT Refresh Token配置验证
|
||||
if c.JWT.AccessTokenExpireMinutes <= 0 {
|
||||
return fmt.Errorf("jwt.access_token_expire_minutes must be positive")
|
||||
}
|
||||
if c.JWT.AccessTokenExpireMinutes > 720 {
|
||||
log.Printf("Warning: jwt.access_token_expire_minutes is %d (> 720). Consider shorter expiration for security.", c.JWT.AccessTokenExpireMinutes)
|
||||
slog.Warn("jwt.access_token_expire_minutes is high; consider shorter expiration for security", "access_token_expire_minutes", c.JWT.AccessTokenExpireMinutes)
|
||||
}
|
||||
if c.JWT.RefreshTokenExpireDays <= 0 {
|
||||
return fmt.Errorf("jwt.refresh_token_expire_days must be positive")
|
||||
}
|
||||
if c.JWT.RefreshTokenExpireDays > 90 {
|
||||
log.Printf("Warning: jwt.refresh_token_expire_days is %d (> 90). Consider shorter expiration for security.", c.JWT.RefreshTokenExpireDays)
|
||||
slog.Warn("jwt.refresh_token_expire_days is high; consider shorter expiration for security", "refresh_token_expire_days", c.JWT.RefreshTokenExpireDays)
|
||||
}
|
||||
if c.JWT.RefreshWindowMinutes < 0 {
|
||||
return fmt.Errorf("jwt.refresh_window_minutes must be non-negative")
|
||||
@@ -1445,7 +1551,7 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive")
|
||||
}
|
||||
if c.Gateway.IdleConnTimeoutSeconds > 180 {
|
||||
log.Printf("Warning: gateway.idle_conn_timeout_seconds is %d (> 180). Consider 60-120 seconds for better connection reuse.", c.Gateway.IdleConnTimeoutSeconds)
|
||||
slog.Warn("gateway.idle_conn_timeout_seconds is high; consider 60-120 seconds for better connection reuse", "idle_conn_timeout_seconds", c.Gateway.IdleConnTimeoutSeconds)
|
||||
}
|
||||
if c.Gateway.MaxUpstreamClients <= 0 {
|
||||
return fmt.Errorf("gateway.max_upstream_clients must be positive")
|
||||
@@ -1682,6 +1788,6 @@ func warnIfInsecureURL(field, raw string) {
|
||||
return
|
||||
}
|
||||
if strings.EqualFold(u.Scheme, "http") {
|
||||
log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field)
|
||||
slog.Warn("url uses http scheme; use https in production to avoid token leakage", "field", field)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -965,6 +965,37 @@ func TestValidateConfigErrors(t *testing.T) {
|
||||
},
|
||||
wantErr: "gateway.scheduling.outbox_lag_rebuild_seconds",
|
||||
},
|
||||
{
|
||||
name: "log level invalid",
|
||||
mutate: func(c *Config) { c.Log.Level = "trace" },
|
||||
wantErr: "log.level",
|
||||
},
|
||||
{
|
||||
name: "log format invalid",
|
||||
mutate: func(c *Config) { c.Log.Format = "plain" },
|
||||
wantErr: "log.format",
|
||||
},
|
||||
{
|
||||
name: "log output disabled",
|
||||
mutate: func(c *Config) {
|
||||
c.Log.Output.ToStdout = false
|
||||
c.Log.Output.ToFile = false
|
||||
},
|
||||
wantErr: "log.output.to_stdout and log.output.to_file cannot both be false",
|
||||
},
|
||||
{
|
||||
name: "log rotation size",
|
||||
mutate: func(c *Config) { c.Log.Rotation.MaxSizeMB = 0 },
|
||||
wantErr: "log.rotation.max_size_mb",
|
||||
},
|
||||
{
|
||||
name: "log sampling enabled invalid",
|
||||
mutate: func(c *Config) {
|
||||
c.Log.Sampling.Enabled = true
|
||||
c.Log.Sampling.Initial = 0
|
||||
},
|
||||
wantErr: "log.sampling.initial",
|
||||
},
|
||||
{
|
||||
name: "ops metrics collector ttl",
|
||||
mutate: func(c *Config) { c.Ops.MetricsCollectorCache.TTL = -1 },
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type testSettingRepo struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func newTestSettingRepo() *testSettingRepo {
|
||||
return &testSettingRepo{values: map[string]string{}}
|
||||
}
|
||||
|
||||
func (s *testSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) {
|
||||
v, err := s.GetValue(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &service.Setting{Key: key, Value: v}, nil
|
||||
}
|
||||
func (s *testSettingRepo) GetValue(ctx context.Context, key string) (string, error) {
|
||||
v, ok := s.values[key]
|
||||
if !ok {
|
||||
return "", service.ErrSettingNotFound
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
func (s *testSettingRepo) Set(ctx context.Context, key, value string) error {
|
||||
s.values[key] = value
|
||||
return nil
|
||||
}
|
||||
func (s *testSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
out := make(map[string]string, len(keys))
|
||||
for _, k := range keys {
|
||||
if v, ok := s.values[k]; ok {
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
func (s *testSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
for k, v := range settings {
|
||||
s.values[k] = v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *testSettingRepo) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
out := make(map[string]string, len(s.values))
|
||||
for k, v := range s.values {
|
||||
out[k] = v
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
func (s *testSettingRepo) Delete(ctx context.Context, key string) error {
|
||||
delete(s.values, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func newOpsRuntimeRouter(handler *OpsHandler, withUser bool) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
if withUser {
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 7})
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
r.GET("/runtime/logging", handler.GetRuntimeLogConfig)
|
||||
r.PUT("/runtime/logging", handler.UpdateRuntimeLogConfig)
|
||||
r.POST("/runtime/logging/reset", handler.ResetRuntimeLogConfig)
|
||||
return r
|
||||
}
|
||||
|
||||
func newRuntimeOpsService(t *testing.T) *service.OpsService {
|
||||
t.Helper()
|
||||
if err := logger.Init(logger.InitOptions{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: logger.OutputOptions{
|
||||
ToStdout: false,
|
||||
ToFile: false,
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("init logger: %v", err)
|
||||
}
|
||||
|
||||
settingRepo := newTestSettingRepo()
|
||||
cfg := &config.Config{
|
||||
Ops: config.OpsConfig{Enabled: true},
|
||||
Log: config.LogConfig{
|
||||
Level: "info",
|
||||
Caller: true,
|
||||
StacktraceLevel: "error",
|
||||
Sampling: config.LogSamplingConfig{
|
||||
Enabled: false,
|
||||
Initial: 100,
|
||||
Thereafter: 100,
|
||||
},
|
||||
},
|
||||
}
|
||||
return service.NewOpsService(nil, settingRepo, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) {
|
||||
h := NewOpsHandler(newRuntimeOpsService(t))
|
||||
r := newOpsRuntimeRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/runtime/logging", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d, want 200", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) {
|
||||
h := NewOpsHandler(newRuntimeOpsService(t))
|
||||
r := newOpsRuntimeRouter(h, false)
|
||||
|
||||
body := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}`
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status=%d, want 401", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsRuntimeLoggingHandler_UpdateAndResetSuccess(t *testing.T) {
|
||||
h := NewOpsHandler(newRuntimeOpsService(t))
|
||||
r := newOpsRuntimeRouter(h, true)
|
||||
|
||||
payload := map[string]any{
|
||||
"level": "debug",
|
||||
"enable_sampling": false,
|
||||
"sampling_initial": 100,
|
||||
"sampling_thereafter": 100,
|
||||
"caller": true,
|
||||
"stacktrace_level": "error",
|
||||
"retention_days": 30,
|
||||
}
|
||||
raw, _ := json.Marshal(payload)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewReader(raw))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("update status=%d, want 200, body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/runtime/logging/reset", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("reset status=%d, want 200, body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -101,6 +102,84 @@ func (h *OpsHandler) UpdateAlertRuntimeSettings(c *gin.Context) {
|
||||
response.Success(c, updated)
|
||||
}
|
||||
|
||||
// GetRuntimeLogConfig returns runtime log config (DB-backed).
|
||||
// GET /api/v1/admin/ops/runtime/logging
|
||||
func (h *OpsHandler) GetRuntimeLogConfig(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
cfg, err := h.opsService.GetRuntimeLogConfig(c.Request.Context())
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, "Failed to get runtime log config")
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
// UpdateRuntimeLogConfig updates runtime log config and applies changes immediately.
|
||||
// PUT /api/v1/admin/ops/runtime/logging
|
||||
func (h *OpsHandler) UpdateRuntimeLogConfig(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var req service.OpsRuntimeLogConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok || subject.UserID <= 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
updated, err := h.opsService.UpdateRuntimeLogConfig(c.Request.Context(), &req, subject.UserID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, updated)
|
||||
}
|
||||
|
||||
// ResetRuntimeLogConfig removes runtime override and falls back to env/yaml baseline.
|
||||
// POST /api/v1/admin/ops/runtime/logging/reset
|
||||
func (h *OpsHandler) ResetRuntimeLogConfig(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok || subject.UserID <= 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
updated, err := h.opsService.ResetRuntimeLogConfig(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, updated)
|
||||
}
|
||||
|
||||
// GetAdvancedSettings returns Ops advanced settings (DB-backed).
|
||||
// GET /api/v1/admin/ops/advanced-settings
|
||||
func (h *OpsHandler) GetAdvancedSettings(c *gin.Context) {
|
||||
|
||||
174
backend/internal/handler/admin/ops_system_log_handler.go
Normal file
174
backend/internal/handler/admin/ops_system_log_handler.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type opsSystemLogCleanupRequest struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
|
||||
Level string `json:"level"`
|
||||
Component string `json:"component"`
|
||||
RequestID string `json:"request_id"`
|
||||
ClientRequestID string `json:"client_request_id"`
|
||||
UserID *int64 `json:"user_id"`
|
||||
AccountID *int64 `json:"account_id"`
|
||||
Platform string `json:"platform"`
|
||||
Model string `json:"model"`
|
||||
Query string `json:"q"`
|
||||
}
|
||||
|
||||
// ListSystemLogs returns indexed system logs.
|
||||
// GET /api/v1/admin/ops/system-logs
|
||||
func (h *OpsHandler) ListSystemLogs(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
if pageSize > 200 {
|
||||
pageSize = 200
|
||||
}
|
||||
|
||||
start, end, err := parseOpsTimeRange(c, "1h")
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
filter := &service.OpsSystemLogFilter{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
StartTime: &start,
|
||||
EndTime: &end,
|
||||
Level: strings.TrimSpace(c.Query("level")),
|
||||
Component: strings.TrimSpace(c.Query("component")),
|
||||
RequestID: strings.TrimSpace(c.Query("request_id")),
|
||||
ClientRequestID: strings.TrimSpace(c.Query("client_request_id")),
|
||||
Platform: strings.TrimSpace(c.Query("platform")),
|
||||
Model: strings.TrimSpace(c.Query("model")),
|
||||
Query: strings.TrimSpace(c.Query("q")),
|
||||
}
|
||||
if v := strings.TrimSpace(c.Query("user_id")); v != "" {
|
||||
id, parseErr := strconv.ParseInt(v, 10, 64)
|
||||
if parseErr != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid user_id")
|
||||
return
|
||||
}
|
||||
filter.UserID = &id
|
||||
}
|
||||
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
|
||||
id, parseErr := strconv.ParseInt(v, 10, 64)
|
||||
if parseErr != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid account_id")
|
||||
return
|
||||
}
|
||||
filter.AccountID = &id
|
||||
}
|
||||
|
||||
result, err := h.opsService.ListSystemLogs(c.Request.Context(), filter)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Paginated(c, result.Logs, int64(result.Total), result.Page, result.PageSize)
|
||||
}
|
||||
|
||||
// CleanupSystemLogs deletes indexed system logs by filter.
|
||||
// POST /api/v1/admin/ops/system-logs/cleanup
|
||||
func (h *OpsHandler) CleanupSystemLogs(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok || subject.UserID <= 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
var req opsSystemLogCleanupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
parseTS := func(raw string) (*time.Time, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil, nil
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339Nano, raw); err == nil {
|
||||
return &t, nil
|
||||
}
|
||||
t, err := time.Parse(time.RFC3339, raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &t, nil
|
||||
}
|
||||
start, err := parseTS(req.StartTime)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_time")
|
||||
return
|
||||
}
|
||||
end, err := parseTS(req.EndTime)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_time")
|
||||
return
|
||||
}
|
||||
|
||||
filter := &service.OpsSystemLogCleanupFilter{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
Level: strings.TrimSpace(req.Level),
|
||||
Component: strings.TrimSpace(req.Component),
|
||||
RequestID: strings.TrimSpace(req.RequestID),
|
||||
ClientRequestID: strings.TrimSpace(req.ClientRequestID),
|
||||
UserID: req.UserID,
|
||||
AccountID: req.AccountID,
|
||||
Platform: strings.TrimSpace(req.Platform),
|
||||
Model: strings.TrimSpace(req.Model),
|
||||
Query: strings.TrimSpace(req.Query),
|
||||
}
|
||||
|
||||
deleted, err := h.opsService.CleanupSystemLogs(c.Request.Context(), filter, subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": deleted})
|
||||
}
|
||||
|
||||
// GetSystemLogIngestionHealth returns sink health metrics.
|
||||
// GET /api/v1/admin/ops/system-logs/health
|
||||
func (h *OpsHandler) GetSystemLogIngestionHealth(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, h.opsService.GetSystemLogSinkHealth())
|
||||
}
|
||||
233
backend/internal/handler/admin/ops_system_log_handler_test.go
Normal file
233
backend/internal/handler/admin/ops_system_log_handler_test.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type responseEnvelope struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
|
||||
func newOpsSystemLogTestRouter(handler *OpsHandler, withUser bool) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
if withUser {
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 99})
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
r.GET("/logs", handler.ListSystemLogs)
|
||||
r.POST("/logs/cleanup", handler.CleanupSystemLogs)
|
||||
r.GET("/logs/health", handler.GetSystemLogIngestionHealth)
|
||||
return r
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) {
|
||||
h := NewOpsHandler(nil)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Fatalf("status=%d, want 503", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs?user_id=abc", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status=%d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_ListInvalidAccountID(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs?account_id=-1", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status=%d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, &config.Config{
|
||||
Ops: config.OpsConfig{Enabled: false},
|
||||
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("status=%d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_ListSuccess(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs?time_range=30m&page=1&page_size=20", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d, want 200", w.Code)
|
||||
}
|
||||
|
||||
var resp responseEnvelope
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
if resp.Code != 0 {
|
||||
t.Fatalf("unexpected response code: %+v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status=%d, want 401", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, true)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{bad-json`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status=%d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, true)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"start_time":"bad","request_id":"r1"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status=%d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, true)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"end_time":"bad","request_id":"r1"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status=%d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_CleanupServiceUnavailable(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, true)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Fatalf("status=%d, want 503", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, &config.Config{
|
||||
Ops: config.OpsConfig{Enabled: false},
|
||||
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, true)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("status=%d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_Health(t *testing.T) {
|
||||
sink := service.NewOpsSystemLogSink(nil)
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs/health", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d, want 200", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T) {
|
||||
h := NewOpsHandler(nil)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs/health", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Fatalf("status=%d, want 503", w.Code)
|
||||
}
|
||||
|
||||
svc := service.NewOpsService(nil, nil, &config.Config{
|
||||
Ops: config.OpsConfig{Enabled: false},
|
||||
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h = NewOpsHandler(svc)
|
||||
r = newOpsSystemLogTestRouter(h, false)
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/logs/health", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("status=%d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package admin
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -16,6 +15,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
@@ -252,7 +252,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) {
|
||||
stats, err := opsService.GetWindowStats(ctx, now.Add(-c.requestCountWindow), now)
|
||||
if err != nil || stats == nil {
|
||||
if err != nil {
|
||||
log.Printf("[OpsWS] refresh: get window stats failed: %v", err)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: get window stats failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -278,7 +278,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) {
|
||||
|
||||
msg, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
log.Printf("[OpsWS] refresh: marshal payload failed: %v", err)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: marshal payload failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -338,7 +338,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
|
||||
|
||||
// Reserve a global slot before upgrading the connection to keep the limit strict.
|
||||
if !tryAcquireOpsWSTotalSlot(opsWSLimits.MaxConns) {
|
||||
log.Printf("[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns)
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
|
||||
return
|
||||
}
|
||||
@@ -350,7 +350,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
|
||||
|
||||
if opsWSLimits.MaxConnsPerIP > 0 && clientIP != "" {
|
||||
if !tryAcquireOpsWSIPSlot(clientIP, opsWSLimits.MaxConnsPerIP) {
|
||||
log.Printf("[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP)
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
|
||||
return
|
||||
}
|
||||
@@ -359,7 +359,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
|
||||
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
log.Printf("[OpsWS] upgrade failed: %v", err)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] upgrade failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -452,7 +452,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
|
||||
|
||||
conn.SetReadLimit(qpsWSMaxReadBytes)
|
||||
if err := conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)); err != nil {
|
||||
log.Printf("[OpsWS] set read deadline failed: %v", err)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] set read deadline failed: %v", err)
|
||||
return
|
||||
}
|
||||
conn.SetPongHandler(func(string) error {
|
||||
@@ -471,7 +471,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
|
||||
_, _, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
|
||||
log.Printf("[OpsWS] read failed: %v", err)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] read failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -508,7 +508,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
|
||||
continue
|
||||
}
|
||||
if err := writeWithTimeout(websocket.TextMessage, msg); err != nil {
|
||||
log.Printf("[OpsWS] write failed: %v", err)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] write failed: %v", err)
|
||||
cancel()
|
||||
closeConn()
|
||||
wg.Wait()
|
||||
@@ -517,7 +517,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
|
||||
|
||||
case <-pingTicker.C:
|
||||
if err := writeWithTimeout(websocket.PingMessage, nil); err != nil {
|
||||
log.Printf("[OpsWS] ping failed: %v", err)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] ping failed: %v", err)
|
||||
cancel()
|
||||
closeConn()
|
||||
wg.Wait()
|
||||
@@ -666,14 +666,14 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
|
||||
if parsed, err := strconv.ParseBool(v); err == nil {
|
||||
cfg.TrustProxy = parsed
|
||||
} else {
|
||||
log.Printf("[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy)
|
||||
}
|
||||
}
|
||||
|
||||
if raw := strings.TrimSpace(os.Getenv(envOpsWSTrustedProxies)); raw != "" {
|
||||
prefixes, invalid := parseTrustedProxyList(raw)
|
||||
if len(invalid) > 0 {
|
||||
log.Printf("[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", "))
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", "))
|
||||
}
|
||||
cfg.TrustedProxies = prefixes
|
||||
}
|
||||
@@ -684,7 +684,7 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
|
||||
case OriginPolicyStrict, OriginPolicyPermissive:
|
||||
cfg.OriginPolicy = normalized
|
||||
default:
|
||||
log.Printf("[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -701,14 +701,14 @@ func loadOpsWSRuntimeLimitsFromEnv() opsWSRuntimeLimits {
|
||||
if parsed, err := strconv.Atoi(v); err == nil && parsed > 0 {
|
||||
cfg.MaxConns = int32(parsed)
|
||||
} else {
|
||||
log.Printf("[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns)
|
||||
}
|
||||
}
|
||||
if v := strings.TrimSpace(os.Getenv(envOpsWSMaxConnsPerIP)); v != "" {
|
||||
if parsed, err := strconv.Atoi(v); err == nil && parsed >= 0 {
|
||||
cfg.MaxConnsPerIP = int32(parsed)
|
||||
} else {
|
||||
log.Printf("[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP)
|
||||
}
|
||||
}
|
||||
return cfg
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
@@ -378,11 +378,11 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
|
||||
operator = subject.UserID
|
||||
}
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
log.Printf("[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params)
|
||||
if err != nil {
|
||||
log.Printf("[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
@@ -390,7 +390,7 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
|
||||
for i := range tasks {
|
||||
out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i]))
|
||||
}
|
||||
log.Printf("[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
|
||||
response.Paginated(c, out, result.Total, page, pageSize)
|
||||
}
|
||||
|
||||
@@ -472,7 +472,7 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
||||
billingType = *filters.BillingType
|
||||
}
|
||||
|
||||
log.Printf("[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
|
||||
subject.UserID,
|
||||
filters.StartTime.Format(time.RFC3339),
|
||||
filters.EndTime.Format(time.RFC3339),
|
||||
@@ -488,12 +488,12 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
||||
|
||||
task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID)
|
||||
if err != nil {
|
||||
log.Printf("[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
|
||||
response.Success(c, dto.UsageCleanupTaskFromService(task))
|
||||
}
|
||||
|
||||
@@ -515,12 +515,12 @@ func (h *UsageHandler) CancelCleanupTask(c *gin.Context) {
|
||||
response.BadRequest(c, "Invalid task id")
|
||||
return
|
||||
}
|
||||
log.Printf("[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
|
||||
if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil {
|
||||
log.Printf("[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
log.Printf("[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
|
||||
response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled})
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -19,11 +18,13 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// GatewayHandler handles API gateway requests
|
||||
@@ -98,6 +99,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.gateway.messages",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
@@ -124,6 +132,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
reqModel := parsedReq.Model
|
||||
reqStream := parsedReq.Stream
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
|
||||
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
|
||||
@@ -163,9 +172,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
waitCounted := false
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
reqLog.Warn("gateway.user_wait_counter_increment_failed", zap.Error(err))
|
||||
// On error, allow request to proceed
|
||||
} else if !canWait {
|
||||
reqLog.Info("gateway.user_wait_queue_full", zap.Int("max_wait", maxWait))
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
@@ -182,7 +192,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 1. 首先获取用户并发槽位
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("User concurrency acquire failed: %v", err)
|
||||
reqLog.Warn("gateway.user_slot_acquire_failed", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -199,7 +209,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
// 2. 【新增】Wait后二次检查余额/订阅
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
log.Printf("Billing eligibility check failed after wait: %v", err)
|
||||
reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
@@ -251,7 +261,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
log.Printf("[Gateway] SelectAccount failed: %v", err)
|
||||
reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -260,7 +270,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
||||
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
|
||||
reqLog.Warn("gateway.single_account_retrying",
|
||||
zap.Int("retry_count", switchCount),
|
||||
zap.Int("max_retries", maxAccountSwitches),
|
||||
)
|
||||
failedAccountIDs = make(map[int64]struct{})
|
||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
@@ -276,7 +289,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||
if account.IsInterceptWarmupEnabled() {
|
||||
@@ -304,9 +317,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
accountWaitCounted := false
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
reqLog.Warn("gateway.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
reqLog.Info("gateway.account_wait_queue_full",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
||||
)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -329,7 +345,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
reqLog.Warn("gateway.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
releaseWait()
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
@@ -337,7 +353,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// Slot acquired: no longer waiting in queue.
|
||||
releaseWait()
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||
@@ -370,7 +386,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
reqLog.Warn("gateway.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||
return
|
||||
@@ -379,7 +400,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
continue
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
log.Printf("Forward request failed: %v", err)
|
||||
reqLog.Error("gateway.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -402,7 +423,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
ForceCacheBilling: fcb,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.messages"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", usedAccount.ID),
|
||||
).Error("gateway.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
}(result, account, userAgent, clientIP, forceCacheBilling)
|
||||
return
|
||||
@@ -437,7 +465,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
log.Printf("[Gateway] SelectAccount failed: %v", err)
|
||||
reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -446,7 +474,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
||||
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
|
||||
reqLog.Warn("gateway.single_account_retrying",
|
||||
zap.Int("retry_count", switchCount),
|
||||
zap.Int("max_retries", maxAccountSwitches),
|
||||
)
|
||||
failedAccountIDs = make(map[int64]struct{})
|
||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
@@ -462,7 +493,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||
if account.IsInterceptWarmupEnabled() {
|
||||
@@ -490,9 +521,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
accountWaitCounted := false
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
reqLog.Warn("gateway.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
reqLog.Info("gateway.account_wait_queue_full",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
||||
)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -515,7 +549,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
reqLog.Warn("gateway.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
releaseWait()
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
@@ -523,7 +557,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// Slot acquired: no longer waiting in queue.
|
||||
releaseWait()
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||
@@ -546,18 +580,26 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if err != nil {
|
||||
var promptTooLongErr *service.PromptTooLongError
|
||||
if errors.As(err, &promptTooLongErr) {
|
||||
log.Printf("Prompt too long from antigravity: group=%d fallback_group_id=%v fallback_used=%v", currentAPIKey.GroupID, fallbackGroupID, fallbackUsed)
|
||||
reqLog.Warn("gateway.prompt_too_long_from_antigravity",
|
||||
zap.Any("current_group_id", currentAPIKey.GroupID),
|
||||
zap.Any("fallback_group_id", fallbackGroupID),
|
||||
zap.Bool("fallback_used", fallbackUsed),
|
||||
)
|
||||
if !fallbackUsed && fallbackGroupID != nil && *fallbackGroupID > 0 {
|
||||
fallbackGroup, err := h.gatewayService.ResolveGroupByID(c.Request.Context(), *fallbackGroupID)
|
||||
if err != nil {
|
||||
log.Printf("Resolve fallback group failed: %v", err)
|
||||
reqLog.Warn("gateway.resolve_fallback_group_failed", zap.Int64("fallback_group_id", *fallbackGroupID), zap.Error(err))
|
||||
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
|
||||
return
|
||||
}
|
||||
if fallbackGroup.Platform != service.PlatformAnthropic ||
|
||||
fallbackGroup.SubscriptionType == service.SubscriptionTypeSubscription ||
|
||||
fallbackGroup.FallbackGroupIDOnInvalidRequest != nil {
|
||||
log.Printf("Fallback group invalid: group=%d platform=%s subscription=%s", fallbackGroup.ID, fallbackGroup.Platform, fallbackGroup.SubscriptionType)
|
||||
reqLog.Warn("gateway.fallback_group_invalid",
|
||||
zap.Int64("fallback_group_id", fallbackGroup.ID),
|
||||
zap.String("fallback_platform", fallbackGroup.Platform),
|
||||
zap.String("fallback_subscription_type", fallbackGroup.SubscriptionType),
|
||||
)
|
||||
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
|
||||
return
|
||||
}
|
||||
@@ -591,7 +633,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
reqLog.Warn("gateway.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||
return
|
||||
@@ -600,7 +647,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
continue
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
|
||||
reqLog.Error("gateway.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -623,9 +670,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
ForceCacheBilling: fcb,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.messages"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", currentAPIKey.ID),
|
||||
zap.Any("group_id", currentAPIKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", usedAccount.ID),
|
||||
).Error("gateway.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
}(result, account, userAgent, clientIP, forceCacheBilling)
|
||||
reqLog.Debug("gateway.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Bool("fallback_used", fallbackUsed),
|
||||
)
|
||||
return
|
||||
}
|
||||
if !retryWithFallback {
|
||||
@@ -902,7 +961,11 @@ func sleepAntigravitySingleAccountBackoff(ctx context.Context, retryCount int) b
|
||||
// Handler 层只需短暂间隔后重新进入 Service 层即可。
|
||||
const delay = 2 * time.Second
|
||||
|
||||
log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)", delay, retryCount)
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.failover"),
|
||||
zap.Duration("delay", delay),
|
||||
zap.Int("retry_count", retryCount),
|
||||
).Info("gateway.single_account_backoff_waiting")
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -1023,6 +1086,12 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.gateway.count_tokens",
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
@@ -1050,6 +1119,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
reqLog = reqLog.With(zap.String("model", parsedReq.Model), zap.Bool("stream", parsedReq.Stream))
|
||||
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
|
||||
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
|
||||
|
||||
@@ -1083,15 +1153,15 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
// 选择支持该模型的账号
|
||||
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
|
||||
if err != nil {
|
||||
log.Printf("[Gateway] SelectAccountForModel failed: %v", err)
|
||||
reqLog.Warn("gateway.count_tokens_select_account_failed", zap.Error(err))
|
||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
|
||||
return
|
||||
}
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
// 转发请求(不记录使用量)
|
||||
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil {
|
||||
log.Printf("Forward count_tokens request failed: %v", err)
|
||||
reqLog.Error("gateway.count_tokens_forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
// 错误响应已在 ForwardCountTokens 中处理
|
||||
return
|
||||
}
|
||||
@@ -1355,7 +1425,10 @@ func billingErrorDetails(err error) (status int, code, message string) {
|
||||
}
|
||||
msg := pkgerrors.Message(err)
|
||||
if msg == "" {
|
||||
log.Printf("[Gateway] billing error details: %v", err)
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.billing"),
|
||||
zap.Error(err),
|
||||
).Warn("gateway.billing_error_missing_message")
|
||||
msg = "Billing error"
|
||||
}
|
||||
return http.StatusForbidden, "billing_error", msg
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
@@ -20,11 +19,13 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
|
||||
@@ -143,6 +144,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
googleError(c, http.StatusInternalServerError, "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.gemini_v1beta.models",
|
||||
zap.Int64("user_id", authSubject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
// 检查平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则要求 gemini 分组
|
||||
if !middleware.HasForcePlatform(c) {
|
||||
@@ -159,6 +167,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
stream := action == "streamGenerateContent"
|
||||
reqLog = reqLog.With(zap.String("model", modelName), zap.String("action", action), zap.Bool("stream", stream))
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
@@ -187,8 +196,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait)
|
||||
waitCounted := false
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
reqLog.Warn("gemini.user_wait_counter_increment_failed", zap.Error(err))
|
||||
} else if !canWait {
|
||||
reqLog.Info("gemini.user_wait_queue_full", zap.Int("max_wait", maxWait))
|
||||
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
@@ -208,6 +218,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted)
|
||||
if err != nil {
|
||||
reqLog.Warn("gemini.user_slot_acquire_failed", zap.Error(err))
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
}
|
||||
@@ -223,6 +234,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
|
||||
// 2) billing eligibility check (after wait)
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, _, message := billingErrorDetails(err)
|
||||
googleError(c, status, message)
|
||||
return
|
||||
@@ -296,8 +308,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
matchedDigestChain = foundMatchedChain
|
||||
sessionBoundAccountID = foundAccountID
|
||||
geminiSessionUUID = foundUUID
|
||||
log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s",
|
||||
safeShortPrefix(foundUUID, 8), foundAccountID, truncateDigestChain(geminiDigestChain))
|
||||
reqLog.Info("gemini.digest_fallback_matched",
|
||||
zap.String("session_uuid_prefix", safeShortPrefix(foundUUID, 8)),
|
||||
zap.Int64("account_id", foundAccountID),
|
||||
zap.String("digest_chain", truncateDigestChain(geminiDigestChain)),
|
||||
)
|
||||
|
||||
// 关键:如果原 sessionKey 为空,使用 prefixHash + uuid 作为 sessionKey
|
||||
// 这样 SelectAccountWithLoadAwareness 的粘性会话逻辑会优先使用匹配到的账号
|
||||
@@ -346,7 +361,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
||||
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
|
||||
reqLog.Warn("gemini.single_account_retrying",
|
||||
zap.Int("retry_count", switchCount),
|
||||
zap.Int("max_retries", maxAccountSwitches),
|
||||
)
|
||||
failedAccountIDs = make(map[int64]struct{})
|
||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
@@ -358,18 +376,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
|
||||
// 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
|
||||
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
|
||||
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
|
||||
reqLog.Info("gemini.sticky_session_account_switched",
|
||||
zap.Int64("from_account_id", sessionBoundAccountID),
|
||||
zap.Int64("to_account_id", account.ID),
|
||||
zap.Bool("clean_thought_signature", true),
|
||||
)
|
||||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||
sessionBoundAccountID = account.ID
|
||||
} else if sessionKey != "" && sessionBoundAccountID == 0 && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
|
||||
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,客户端继续携带旧签名。
|
||||
// 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
|
||||
log.Printf("[Gemini] Sticky session binding missing, cleaning thoughtSignature proactively")
|
||||
reqLog.Info("gemini.sticky_session_binding_missing",
|
||||
zap.Bool("clean_thought_signature", true),
|
||||
)
|
||||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||
cleanedForUnknownBinding = true
|
||||
sessionBoundAccountID = account.ID
|
||||
@@ -388,9 +412,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
accountWaitCounted := false
|
||||
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
reqLog.Warn("gemini.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
reqLog.Info("gemini.account_wait_queue_full",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
||||
)
|
||||
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
@@ -412,6 +439,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("gemini.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
}
|
||||
@@ -420,7 +448,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
accountWaitCounted = false
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
reqLog.Warn("gemini.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||
@@ -454,7 +482,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
lastFailoverErr = failoverErr
|
||||
switchCount++
|
||||
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
reqLog.Warn("gemini.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||
return
|
||||
@@ -463,7 +496,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
continue
|
||||
}
|
||||
// ForwardNative already wrote the response
|
||||
log.Printf("Gemini native forward failed: %v", err)
|
||||
reqLog.Error("gemini.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -482,7 +515,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
account.ID,
|
||||
matchedDigestChain,
|
||||
); err != nil {
|
||||
log.Printf("[Gemini] Failed to save digest session: %v", err)
|
||||
reqLog.Warn("gemini.digest_session_save_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -504,9 +537,20 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
ForceCacheBilling: fcb,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gemini_v1beta.models"),
|
||||
zap.Int64("user_id", authSubject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", modelName),
|
||||
zap.Int64("account_id", usedAccount.ID),
|
||||
).Error("gemini.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
}(result, account, userAgent, clientIP, forceCacheBilling)
|
||||
reqLog.Debug("gemini.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
19
backend/internal/handler/logging.go
Normal file
19
backend/internal/handler/logging.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func requestLogger(c *gin.Context, component string, fields ...zap.Field) *zap.Logger {
|
||||
base := logger.L()
|
||||
if c != nil && c.Request != nil {
|
||||
base = logger.FromContext(c.Request.Context())
|
||||
}
|
||||
|
||||
if component != "" {
|
||||
fields = append([]zap.Field{zap.String("component", component)}, fields...)
|
||||
}
|
||||
return base.With(fields...)
|
||||
}
|
||||
@@ -6,18 +6,19 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// OpenAIGatewayHandler handles OpenAI API gateway requests
|
||||
@@ -74,6 +75,13 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.openai_gateway.responses",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
// Read request body
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
@@ -113,6 +121,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
reqStream := streamResult.Bool()
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
|
||||
@@ -128,13 +137,17 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
previousResponseID, _ := reqBody["previous_response_id"].(string)
|
||||
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
|
||||
if service.HasFunctionCallOutputMissingCallID(reqBody) {
|
||||
log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel)
|
||||
reqLog.Warn("openai.request_validation_failed",
|
||||
zap.String("reason", "function_call_output_missing_call_id"),
|
||||
)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
|
||||
return
|
||||
}
|
||||
callIDs := service.FunctionCallOutputCallIDs(reqBody)
|
||||
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
|
||||
log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel)
|
||||
reqLog.Warn("openai.request_validation_failed",
|
||||
zap.String("reason", "function_call_output_missing_item_reference"),
|
||||
)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
|
||||
return
|
||||
}
|
||||
@@ -160,7 +173,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// 0. 先尝试直接抢占用户槽位(快速路径)
|
||||
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(c.Request.Context(), subject.UserID, subject.Concurrency)
|
||||
if err != nil {
|
||||
log.Printf("User concurrency acquire failed: %v", err)
|
||||
reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -171,9 +184,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
if waitErr != nil {
|
||||
log.Printf("Increment wait count failed: %v", waitErr)
|
||||
reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr))
|
||||
// 按现有降级语义:等待计数异常时放行后续抢槽流程
|
||||
} else if !canWait {
|
||||
reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait))
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
@@ -188,7 +202,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
|
||||
userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("User concurrency acquire failed: %v", err)
|
||||
reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -207,7 +221,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
|
||||
// 2. Re-check billing eligibility after wait
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
log.Printf("Billing eligibility check failed after wait: %v", err)
|
||||
reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
@@ -223,10 +237,13 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
|
||||
for {
|
||||
// Select account supporting the requested model
|
||||
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
|
||||
reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
||||
if err != nil {
|
||||
log.Printf("[OpenAI Handler] SelectAccount failed: groupID=%v model=%s tried=%d err=%v", apiKey.GroupID, reqModel, len(failedAccountIDs), err)
|
||||
reqLog.Warn("openai.account_select_failed",
|
||||
zap.Error(err),
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
@@ -239,8 +256,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
// 3. Acquire account concurrency slot
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
@@ -257,22 +274,25 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency quick acquire failed: %v", err)
|
||||
reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if fastAcquired {
|
||||
accountReleaseFunc = fastReleaseFunc
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
} else {
|
||||
accountWaitCounted := false
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
reqLog.Info("openai.account_wait_queue_full",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
||||
)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -295,7 +315,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
releaseWait()
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
@@ -303,7 +323,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// Slot acquired: no longer waiting in queue.
|
||||
releaseWait()
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -337,11 +357,16 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
reqLog.Warn("openai.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
// Error response already handled in Forward, just log
|
||||
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
|
||||
reqLog.Error("openai.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -363,9 +388,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
IPAddress: ip,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.responses"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", usedAccount.ID),
|
||||
).Error("openai.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
}(result, account, userAgent, clientIP)
|
||||
reqLog.Debug("openai.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -255,18 +255,33 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
model = strings.TrimSpace(model)
|
||||
c.Set(opsModelKey, model)
|
||||
c.Set(opsStreamKey, stream)
|
||||
if len(requestBody) > 0 {
|
||||
c.Set(opsRequestBodyKey, requestBody)
|
||||
}
|
||||
if c.Request != nil && model != "" {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.Model, model)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func setOpsSelectedAccount(c *gin.Context, accountID int64) {
|
||||
func setOpsSelectedAccount(c *gin.Context, accountID int64, platform ...string) {
|
||||
if c == nil || accountID <= 0 {
|
||||
return
|
||||
}
|
||||
c.Set(opsAccountIDKey, accountID)
|
||||
if c.Request != nil {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.AccountID, accountID)
|
||||
if len(platform) > 0 {
|
||||
p := strings.TrimSpace(platform[0])
|
||||
if p != "" {
|
||||
ctx = context.WithValue(ctx, ctxkey.Platform, p)
|
||||
}
|
||||
}
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
type opsCaptureWriter struct {
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
@@ -18,12 +17,14 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// SoraGatewayHandler handles Sora chat completions requests
|
||||
@@ -89,6 +90,13 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.sora_gateway.chat_completions",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
@@ -127,6 +135,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
}
|
||||
|
||||
clientStream := gjson.GetBytes(body, "stream").Bool()
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", clientStream))
|
||||
if !clientStream {
|
||||
if h.streamMode == "error" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true")
|
||||
@@ -160,8 +169,9 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
waitCounted := false
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
reqLog.Warn("sora.user_wait_counter_increment_failed", zap.Error(err))
|
||||
} else if !canWait {
|
||||
reqLog.Info("sora.user_wait_queue_full", zap.Int("max_wait", maxWait))
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
@@ -176,7 +186,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, clientStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("User concurrency acquire failed: %v", err)
|
||||
reqLog.Warn("sora.user_slot_acquire_failed", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -190,7 +200,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
log.Printf("Billing eligibility check failed after wait: %v", err)
|
||||
reqLog.Info("sora.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
@@ -206,7 +216,10 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
|
||||
if err != nil {
|
||||
log.Printf("[Sora Handler] SelectAccount failed: %v", err)
|
||||
reqLog.Warn("sora.account_select_failed",
|
||||
zap.Error(err),
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
@@ -215,7 +228,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
@@ -226,9 +239,12 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
accountWaitCounted := false
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
reqLog.Warn("sora.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
reqLog.Info("sora.account_wait_queue_full",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
||||
)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -250,7 +266,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
reqLog.Warn("sora.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -276,10 +292,15 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
reqLog.Warn("sora.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
|
||||
reqLog.Error("sora.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -298,9 +319,20 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
UserAgent: ua,
|
||||
IPAddress: ip,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.sora_gateway.chat_completions"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", usedAccount.ID),
|
||||
).Error("sora.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
}(result, account, userAgent, clientIP)
|
||||
reqLog.Debug("sora.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,9 +8,21 @@ const (
|
||||
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
|
||||
ForcePlatform Key = "ctx_force_platform"
|
||||
|
||||
// RequestID 为服务端生成/透传的请求 ID。
|
||||
RequestID Key = "ctx_request_id"
|
||||
|
||||
// ClientRequestID 客户端请求的唯一标识,用于追踪请求全生命周期(用于 Ops 监控与排障)。
|
||||
ClientRequestID Key = "ctx_client_request_id"
|
||||
|
||||
// Model 请求模型标识(用于统一请求链路日志字段)。
|
||||
Model Key = "ctx_model"
|
||||
|
||||
// Platform 当前请求最终命中的平台(用于统一请求链路日志字段)。
|
||||
Platform Key = "ctx_platform"
|
||||
|
||||
// AccountID 当前请求最终命中的账号 ID(用于统一请求链路日志字段)。
|
||||
AccountID Key = "ctx_account_id"
|
||||
|
||||
// RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。
|
||||
RetryCount Key = "ctx_retry_count"
|
||||
|
||||
|
||||
31
backend/internal/pkg/logger/config_adapter.go
Normal file
31
backend/internal/pkg/logger/config_adapter.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package logger
|
||||
|
||||
import "github.com/Wei-Shaw/sub2api/internal/config"
|
||||
|
||||
func OptionsFromConfig(cfg config.LogConfig) InitOptions {
|
||||
return InitOptions{
|
||||
Level: cfg.Level,
|
||||
Format: cfg.Format,
|
||||
ServiceName: cfg.ServiceName,
|
||||
Environment: cfg.Environment,
|
||||
Caller: cfg.Caller,
|
||||
StacktraceLevel: cfg.StacktraceLevel,
|
||||
Output: OutputOptions{
|
||||
ToStdout: cfg.Output.ToStdout,
|
||||
ToFile: cfg.Output.ToFile,
|
||||
FilePath: cfg.Output.FilePath,
|
||||
},
|
||||
Rotation: RotationOptions{
|
||||
MaxSizeMB: cfg.Rotation.MaxSizeMB,
|
||||
MaxBackups: cfg.Rotation.MaxBackups,
|
||||
MaxAgeDays: cfg.Rotation.MaxAgeDays,
|
||||
Compress: cfg.Rotation.Compress,
|
||||
LocalTime: cfg.Rotation.LocalTime,
|
||||
},
|
||||
Sampling: SamplingOptions{
|
||||
Enabled: cfg.Sampling.Enabled,
|
||||
Initial: cfg.Sampling.Initial,
|
||||
Thereafter: cfg.Sampling.Thereafter,
|
||||
},
|
||||
}
|
||||
}
|
||||
518
backend/internal/pkg/logger/logger.go
Normal file
518
backend/internal/pkg/logger/logger.go
Normal file
@@ -0,0 +1,518 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
)
|
||||
|
||||
type Level = zapcore.Level
|
||||
|
||||
const (
|
||||
LevelDebug = zapcore.DebugLevel
|
||||
LevelInfo = zapcore.InfoLevel
|
||||
LevelWarn = zapcore.WarnLevel
|
||||
LevelError = zapcore.ErrorLevel
|
||||
LevelFatal = zapcore.FatalLevel
|
||||
)
|
||||
|
||||
type Sink interface {
|
||||
WriteLogEvent(event *LogEvent)
|
||||
}
|
||||
|
||||
type LogEvent struct {
|
||||
Time time.Time
|
||||
Level string
|
||||
Component string
|
||||
Message string
|
||||
LoggerName string
|
||||
Fields map[string]any
|
||||
}
|
||||
|
||||
var (
|
||||
mu sync.RWMutex
|
||||
global *zap.Logger
|
||||
sugar *zap.SugaredLogger
|
||||
atomicLevel zap.AtomicLevel
|
||||
initOptions InitOptions
|
||||
currentSink Sink
|
||||
stdLogUndo func()
|
||||
bootstrapOnce sync.Once
|
||||
)
|
||||
|
||||
func InitBootstrap() {
|
||||
bootstrapOnce.Do(func() {
|
||||
if err := Init(bootstrapOptions()); err != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "logger bootstrap init failed: %v\n", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Init(options InitOptions) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return initLocked(options)
|
||||
}
|
||||
|
||||
func initLocked(options InitOptions) error {
|
||||
normalized := options.normalized()
|
||||
zl, al, err := buildLogger(normalized)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prev := global
|
||||
global = zl
|
||||
sugar = zl.Sugar()
|
||||
atomicLevel = al
|
||||
initOptions = normalized
|
||||
|
||||
bridgeSlogLocked()
|
||||
bridgeStdLogLocked()
|
||||
|
||||
if prev != nil {
|
||||
_ = prev.Sync()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func Reconfigure(mutator func(*InitOptions) error) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
next := initOptions
|
||||
if mutator != nil {
|
||||
if err := mutator(&next); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return initLocked(next)
|
||||
}
|
||||
|
||||
func SetLevel(level string) error {
|
||||
lv, ok := parseLevel(level)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid log level: %s", level)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
atomicLevel.SetLevel(lv)
|
||||
initOptions.Level = strings.ToLower(strings.TrimSpace(level))
|
||||
return nil
|
||||
}
|
||||
|
||||
func CurrentLevel() string {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
if global == nil {
|
||||
return "info"
|
||||
}
|
||||
return atomicLevel.Level().String()
|
||||
}
|
||||
|
||||
func SetSink(sink Sink) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
currentSink = sink
|
||||
}
|
||||
|
||||
// WriteSinkEvent 直接写入日志 sink,不经过全局日志级别门控。
|
||||
// 用于需要“可观测性入库”与“业务输出级别”解耦的场景(例如 ops 系统日志索引)。
|
||||
func WriteSinkEvent(level, component, message string, fields map[string]any) {
|
||||
mu.RLock()
|
||||
sink := currentSink
|
||||
mu.RUnlock()
|
||||
if sink == nil {
|
||||
return
|
||||
}
|
||||
|
||||
level = strings.ToLower(strings.TrimSpace(level))
|
||||
if level == "" {
|
||||
level = "info"
|
||||
}
|
||||
component = strings.TrimSpace(component)
|
||||
message = strings.TrimSpace(message)
|
||||
if message == "" {
|
||||
return
|
||||
}
|
||||
|
||||
eventFields := make(map[string]any, len(fields)+1)
|
||||
for k, v := range fields {
|
||||
eventFields[k] = v
|
||||
}
|
||||
if component != "" {
|
||||
if _, ok := eventFields["component"]; !ok {
|
||||
eventFields["component"] = component
|
||||
}
|
||||
}
|
||||
|
||||
sink.WriteLogEvent(&LogEvent{
|
||||
Time: time.Now(),
|
||||
Level: level,
|
||||
Component: component,
|
||||
Message: message,
|
||||
LoggerName: component,
|
||||
Fields: eventFields,
|
||||
})
|
||||
}
|
||||
|
||||
func L() *zap.Logger {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
if global != nil {
|
||||
return global
|
||||
}
|
||||
return zap.NewNop()
|
||||
}
|
||||
|
||||
func S() *zap.SugaredLogger {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
if sugar != nil {
|
||||
return sugar
|
||||
}
|
||||
return zap.NewNop().Sugar()
|
||||
}
|
||||
|
||||
func With(fields ...zap.Field) *zap.Logger {
|
||||
return L().With(fields...)
|
||||
}
|
||||
|
||||
func Sync() {
|
||||
mu.RLock()
|
||||
l := global
|
||||
mu.RUnlock()
|
||||
if l != nil {
|
||||
_ = l.Sync()
|
||||
}
|
||||
}
|
||||
|
||||
func bridgeStdLogLocked() {
|
||||
if stdLogUndo != nil {
|
||||
stdLogUndo()
|
||||
stdLogUndo = nil
|
||||
}
|
||||
|
||||
prevFlags := log.Flags()
|
||||
prevPrefix := log.Prefix()
|
||||
prevWriter := log.Writer()
|
||||
|
||||
log.SetFlags(0)
|
||||
log.SetPrefix("")
|
||||
log.SetOutput(newStdLogBridge(global.Named("stdlog")))
|
||||
|
||||
stdLogUndo = func() {
|
||||
log.SetOutput(prevWriter)
|
||||
log.SetFlags(prevFlags)
|
||||
log.SetPrefix(prevPrefix)
|
||||
}
|
||||
}
|
||||
|
||||
func bridgeSlogLocked() {
|
||||
slog.SetDefault(slog.New(newSlogZapHandler(global.Named("slog"))))
|
||||
}
|
||||
|
||||
func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) {
|
||||
level, _ := parseLevel(options.Level)
|
||||
atomic := zap.NewAtomicLevelAt(level)
|
||||
|
||||
encoderCfg := zapcore.EncoderConfig{
|
||||
TimeKey: "time",
|
||||
LevelKey: "level",
|
||||
NameKey: "logger",
|
||||
CallerKey: "caller",
|
||||
MessageKey: "msg",
|
||||
StacktraceKey: "stacktrace",
|
||||
LineEnding: zapcore.DefaultLineEnding,
|
||||
EncodeLevel: zapcore.CapitalLevelEncoder,
|
||||
EncodeTime: zapcore.ISO8601TimeEncoder,
|
||||
EncodeDuration: zapcore.MillisDurationEncoder,
|
||||
EncodeCaller: zapcore.ShortCallerEncoder,
|
||||
}
|
||||
|
||||
var enc zapcore.Encoder
|
||||
if options.Format == "console" {
|
||||
enc = zapcore.NewConsoleEncoder(encoderCfg)
|
||||
} else {
|
||||
enc = zapcore.NewJSONEncoder(encoderCfg)
|
||||
}
|
||||
|
||||
sinkCore := newSinkCore()
|
||||
cores := make([]zapcore.Core, 0, 3)
|
||||
|
||||
if options.Output.ToStdout {
|
||||
infoPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
|
||||
return lvl >= atomic.Level() && lvl < zapcore.WarnLevel
|
||||
})
|
||||
errPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
|
||||
return lvl >= atomic.Level() && lvl >= zapcore.WarnLevel
|
||||
})
|
||||
cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stdout), infoPriority))
|
||||
cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stderr), errPriority))
|
||||
}
|
||||
|
||||
if options.Output.ToFile {
|
||||
fileCore, filePath, fileErr := buildFileCore(enc, atomic, options)
|
||||
if fileErr != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "time=%s level=WARN msg=\"日志文件输出初始化失败,降级为仅标准输出\" path=%s err=%v\n",
|
||||
time.Now().Format(time.RFC3339Nano),
|
||||
filePath,
|
||||
fileErr,
|
||||
)
|
||||
} else {
|
||||
cores = append(cores, fileCore)
|
||||
}
|
||||
}
|
||||
|
||||
if len(cores) == 0 {
|
||||
cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stdout), atomic))
|
||||
}
|
||||
|
||||
core := zapcore.NewTee(cores...)
|
||||
if options.Sampling.Enabled {
|
||||
core = zapcore.NewSamplerWithOptions(core, samplingTick(), options.Sampling.Initial, options.Sampling.Thereafter)
|
||||
}
|
||||
core = sinkCore.Wrap(core)
|
||||
|
||||
stacktraceLevel, _ := parseStacktraceLevel(options.StacktraceLevel)
|
||||
zapOpts := make([]zap.Option, 0, 5)
|
||||
if options.Caller {
|
||||
zapOpts = append(zapOpts, zap.AddCaller())
|
||||
}
|
||||
if stacktraceLevel <= zapcore.FatalLevel {
|
||||
zapOpts = append(zapOpts, zap.AddStacktrace(stacktraceLevel))
|
||||
}
|
||||
|
||||
logger := zap.New(core, zapOpts...).With(
|
||||
zap.String("service", options.ServiceName),
|
||||
zap.String("env", options.Environment),
|
||||
)
|
||||
return logger, atomic, nil
|
||||
}
|
||||
|
||||
func buildFileCore(enc zapcore.Encoder, atomic zap.AtomicLevel, options InitOptions) (zapcore.Core, string, error) {
|
||||
filePath := options.Output.FilePath
|
||||
if strings.TrimSpace(filePath) == "" {
|
||||
filePath = resolveLogFilePath("")
|
||||
}
|
||||
|
||||
dir := filepath.Dir(filePath)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return nil, filePath, err
|
||||
}
|
||||
lj := &lumberjack.Logger{
|
||||
Filename: filePath,
|
||||
MaxSize: options.Rotation.MaxSizeMB,
|
||||
MaxBackups: options.Rotation.MaxBackups,
|
||||
MaxAge: options.Rotation.MaxAgeDays,
|
||||
Compress: options.Rotation.Compress,
|
||||
LocalTime: options.Rotation.LocalTime,
|
||||
}
|
||||
return zapcore.NewCore(enc, zapcore.AddSync(lj), atomic), filePath, nil
|
||||
}
|
||||
|
||||
type sinkCore struct {
|
||||
core zapcore.Core
|
||||
fields []zapcore.Field
|
||||
}
|
||||
|
||||
func newSinkCore() *sinkCore {
|
||||
return &sinkCore{}
|
||||
}
|
||||
|
||||
func (s *sinkCore) Wrap(core zapcore.Core) zapcore.Core {
|
||||
cp := *s
|
||||
cp.core = core
|
||||
return &cp
|
||||
}
|
||||
|
||||
func (s *sinkCore) Enabled(level zapcore.Level) bool {
|
||||
return s.core.Enabled(level)
|
||||
}
|
||||
|
||||
func (s *sinkCore) With(fields []zapcore.Field) zapcore.Core {
|
||||
nextFields := append([]zapcore.Field{}, s.fields...)
|
||||
nextFields = append(nextFields, fields...)
|
||||
return &sinkCore{
|
||||
core: s.core.With(fields),
|
||||
fields: nextFields,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sinkCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
|
||||
if s.Enabled(entry.Level) {
|
||||
return ce.AddCore(entry, s)
|
||||
}
|
||||
return ce
|
||||
}
|
||||
|
||||
func (s *sinkCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
|
||||
if err := s.core.Write(entry, fields); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mu.RLock()
|
||||
sink := currentSink
|
||||
mu.RUnlock()
|
||||
if sink == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
enc := zapcore.NewMapObjectEncoder()
|
||||
for _, f := range s.fields {
|
||||
f.AddTo(enc)
|
||||
}
|
||||
for _, f := range fields {
|
||||
f.AddTo(enc)
|
||||
}
|
||||
|
||||
event := &LogEvent{
|
||||
Time: entry.Time,
|
||||
Level: strings.ToLower(entry.Level.String()),
|
||||
Component: entry.LoggerName,
|
||||
Message: entry.Message,
|
||||
LoggerName: entry.LoggerName,
|
||||
Fields: enc.Fields,
|
||||
}
|
||||
sink.WriteLogEvent(event)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sinkCore) Sync() error {
|
||||
return s.core.Sync()
|
||||
}
|
||||
|
||||
type stdLogBridge struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func newStdLogBridge(l *zap.Logger) io.Writer {
|
||||
if l == nil {
|
||||
l = zap.NewNop()
|
||||
}
|
||||
return &stdLogBridge{logger: l}
|
||||
}
|
||||
|
||||
func (b *stdLogBridge) Write(p []byte) (int, error) {
|
||||
msg := normalizeStdLogMessage(string(p))
|
||||
if msg == "" {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
level := inferStdLogLevel(msg)
|
||||
entry := b.logger.WithOptions(zap.AddCallerSkip(4))
|
||||
|
||||
switch level {
|
||||
case LevelDebug:
|
||||
entry.Debug(msg, zap.Bool("legacy_stdlog", true))
|
||||
case LevelWarn:
|
||||
entry.Warn(msg, zap.Bool("legacy_stdlog", true))
|
||||
case LevelError, LevelFatal:
|
||||
entry.Error(msg, zap.Bool("legacy_stdlog", true))
|
||||
default:
|
||||
entry.Info(msg, zap.Bool("legacy_stdlog", true))
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func normalizeStdLogMessage(raw string) string {
|
||||
msg := strings.TrimSpace(strings.ReplaceAll(raw, "\n", " "))
|
||||
if msg == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.Join(strings.Fields(msg), " ")
|
||||
}
|
||||
|
||||
func inferStdLogLevel(msg string) Level {
|
||||
lower := strings.ToLower(strings.TrimSpace(msg))
|
||||
if lower == "" {
|
||||
return LevelInfo
|
||||
}
|
||||
|
||||
if strings.HasPrefix(lower, "[debug]") || strings.HasPrefix(lower, "debug:") {
|
||||
return LevelDebug
|
||||
}
|
||||
if strings.HasPrefix(lower, "[warn]") || strings.HasPrefix(lower, "[warning]") || strings.HasPrefix(lower, "warn:") || strings.HasPrefix(lower, "warning:") {
|
||||
return LevelWarn
|
||||
}
|
||||
if strings.HasPrefix(lower, "[error]") || strings.HasPrefix(lower, "error:") || strings.HasPrefix(lower, "fatal:") || strings.HasPrefix(lower, "panic:") {
|
||||
return LevelError
|
||||
}
|
||||
|
||||
if strings.Contains(lower, " failed") || strings.Contains(lower, "error") || strings.Contains(lower, "panic") || strings.Contains(lower, "fatal") {
|
||||
return LevelError
|
||||
}
|
||||
if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " retry") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") {
|
||||
return LevelWarn
|
||||
}
|
||||
return LevelInfo
|
||||
}
|
||||
|
||||
// LegacyPrintf 用于平滑迁移历史的 printf 风格日志到结构化 logger。
|
||||
func LegacyPrintf(component, format string, args ...any) {
|
||||
msg := normalizeStdLogMessage(fmt.Sprintf(format, args...))
|
||||
if msg == "" {
|
||||
return
|
||||
}
|
||||
|
||||
mu.RLock()
|
||||
initialized := global != nil
|
||||
mu.RUnlock()
|
||||
if !initialized {
|
||||
// 在日志系统未初始化前,回退到标准库 log,避免测试/工具链丢日志。
|
||||
log.Print(msg)
|
||||
return
|
||||
}
|
||||
|
||||
l := L()
|
||||
if component != "" {
|
||||
l = l.With(zap.String("component", component))
|
||||
}
|
||||
l = l.WithOptions(zap.AddCallerSkip(1))
|
||||
|
||||
switch inferStdLogLevel(msg) {
|
||||
case LevelDebug:
|
||||
l.Debug(msg, zap.Bool("legacy_printf", true))
|
||||
case LevelWarn:
|
||||
l.Warn(msg, zap.Bool("legacy_printf", true))
|
||||
case LevelError, LevelFatal:
|
||||
l.Error(msg, zap.Bool("legacy_printf", true))
|
||||
default:
|
||||
l.Info(msg, zap.Bool("legacy_printf", true))
|
||||
}
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
|
||||
const loggerContextKey contextKey = "ctx_logger"
|
||||
|
||||
func IntoContext(ctx context.Context, l *zap.Logger) context.Context {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if l == nil {
|
||||
l = L()
|
||||
}
|
||||
return context.WithValue(ctx, loggerContextKey, l)
|
||||
}
|
||||
|
||||
func FromContext(ctx context.Context) *zap.Logger {
|
||||
if ctx == nil {
|
||||
return L()
|
||||
}
|
||||
if l, ok := ctx.Value(loggerContextKey).(*zap.Logger); ok && l != nil {
|
||||
return l
|
||||
}
|
||||
return L()
|
||||
}
|
||||
192
backend/internal/pkg/logger/logger_test.go
Normal file
192
backend/internal/pkg/logger/logger_test.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInit_DualOutput(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
logPath := filepath.Join(tmpDir, "logs", "sub2api.log")
|
||||
|
||||
origStdout := os.Stdout
|
||||
origStderr := os.Stderr
|
||||
stdoutR, stdoutW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stdout pipe: %v", err)
|
||||
}
|
||||
stderrR, stderrW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stderr pipe: %v", err)
|
||||
}
|
||||
os.Stdout = stdoutW
|
||||
os.Stderr = stderrW
|
||||
t.Cleanup(func() {
|
||||
os.Stdout = origStdout
|
||||
os.Stderr = origStderr
|
||||
_ = stdoutR.Close()
|
||||
_ = stderrR.Close()
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrW.Close()
|
||||
})
|
||||
|
||||
err = Init(InitOptions{
|
||||
Level: "debug",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: true,
|
||||
FilePath: logPath,
|
||||
},
|
||||
Rotation: RotationOptions{
|
||||
MaxSizeMB: 10,
|
||||
MaxBackups: 2,
|
||||
MaxAgeDays: 1,
|
||||
},
|
||||
Sampling: SamplingOptions{Enabled: false},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Init() error: %v", err)
|
||||
}
|
||||
|
||||
L().Info("dual-output-info")
|
||||
L().Warn("dual-output-warn")
|
||||
Sync()
|
||||
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrW.Close()
|
||||
stdoutBytes, _ := io.ReadAll(stdoutR)
|
||||
stderrBytes, _ := io.ReadAll(stderrR)
|
||||
stdoutText := string(stdoutBytes)
|
||||
stderrText := string(stderrBytes)
|
||||
|
||||
if !strings.Contains(stdoutText, "dual-output-info") {
|
||||
t.Fatalf("stdout missing info log: %s", stdoutText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "dual-output-warn") {
|
||||
t.Fatalf("stderr missing warn log: %s", stderrText)
|
||||
}
|
||||
|
||||
fileBytes, err := os.ReadFile(logPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read log file: %v", err)
|
||||
}
|
||||
fileText := string(fileBytes)
|
||||
if !strings.Contains(fileText, "dual-output-info") || !strings.Contains(fileText, "dual-output-warn") {
|
||||
t.Fatalf("file missing logs: %s", fileText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInit_FileOutputFailureDowngrade(t *testing.T) {
|
||||
origStdout := os.Stdout
|
||||
origStderr := os.Stderr
|
||||
_, stdoutW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stdout pipe: %v", err)
|
||||
}
|
||||
stderrR, stderrW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stderr pipe: %v", err)
|
||||
}
|
||||
os.Stdout = stdoutW
|
||||
os.Stderr = stderrW
|
||||
t.Cleanup(func() {
|
||||
os.Stdout = origStdout
|
||||
os.Stderr = origStderr
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrR.Close()
|
||||
_ = stderrW.Close()
|
||||
})
|
||||
|
||||
err = Init(InitOptions{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
Output: OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: true,
|
||||
FilePath: filepath.Join(os.DevNull, "logs", "sub2api.log"),
|
||||
},
|
||||
Rotation: RotationOptions{
|
||||
MaxSizeMB: 10,
|
||||
MaxBackups: 1,
|
||||
MaxAgeDays: 1,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Init() should downgrade instead of failing, got: %v", err)
|
||||
}
|
||||
|
||||
_ = stderrW.Close()
|
||||
stderrBytes, _ := io.ReadAll(stderrR)
|
||||
if !strings.Contains(string(stderrBytes), "日志文件输出初始化失败") {
|
||||
t.Fatalf("stderr should contain fallback warning, got: %s", string(stderrBytes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestInit_CallerShouldPointToCallsite(t *testing.T) {
|
||||
origStdout := os.Stdout
|
||||
origStderr := os.Stderr
|
||||
stdoutR, stdoutW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stdout pipe: %v", err)
|
||||
}
|
||||
_, stderrW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stderr pipe: %v", err)
|
||||
}
|
||||
os.Stdout = stdoutW
|
||||
os.Stderr = stderrW
|
||||
t.Cleanup(func() {
|
||||
os.Stdout = origStdout
|
||||
os.Stderr = origStderr
|
||||
_ = stdoutR.Close()
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrW.Close()
|
||||
})
|
||||
|
||||
if err := Init(InitOptions{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Caller: true,
|
||||
Output: OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: false,
|
||||
},
|
||||
Sampling: SamplingOptions{Enabled: false},
|
||||
}); err != nil {
|
||||
t.Fatalf("Init() error: %v", err)
|
||||
}
|
||||
|
||||
L().Info("caller-check")
|
||||
Sync()
|
||||
_ = stdoutW.Close()
|
||||
logBytes, _ := io.ReadAll(stdoutR)
|
||||
|
||||
var line string
|
||||
for _, item := range strings.Split(string(logBytes), "\n") {
|
||||
if strings.Contains(item, "caller-check") {
|
||||
line = item
|
||||
break
|
||||
}
|
||||
}
|
||||
if line == "" {
|
||||
t.Fatalf("log output missing caller-check: %s", string(logBytes))
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(line), &payload); err != nil {
|
||||
t.Fatalf("parse log json failed: %v, line=%s", err, line)
|
||||
}
|
||||
caller, _ := payload["caller"].(string)
|
||||
if !strings.Contains(caller, "logger_test.go:") {
|
||||
t.Fatalf("caller should point to this test file, got: %s", caller)
|
||||
}
|
||||
}
|
||||
161
backend/internal/pkg/logger/options.go
Normal file
161
backend/internal/pkg/logger/options.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultContainerLogPath 为容器内默认日志文件路径。
|
||||
DefaultContainerLogPath = "/app/data/logs/sub2api.log"
|
||||
defaultLogFilename = "sub2api.log"
|
||||
)
|
||||
|
||||
type InitOptions struct {
|
||||
Level string
|
||||
Format string
|
||||
ServiceName string
|
||||
Environment string
|
||||
Caller bool
|
||||
StacktraceLevel string
|
||||
Output OutputOptions
|
||||
Rotation RotationOptions
|
||||
Sampling SamplingOptions
|
||||
}
|
||||
|
||||
type OutputOptions struct {
|
||||
ToStdout bool
|
||||
ToFile bool
|
||||
FilePath string
|
||||
}
|
||||
|
||||
type RotationOptions struct {
|
||||
MaxSizeMB int
|
||||
MaxBackups int
|
||||
MaxAgeDays int
|
||||
Compress bool
|
||||
LocalTime bool
|
||||
}
|
||||
|
||||
type SamplingOptions struct {
|
||||
Enabled bool
|
||||
Initial int
|
||||
Thereafter int
|
||||
}
|
||||
|
||||
func (o InitOptions) normalized() InitOptions {
|
||||
out := o
|
||||
out.Level = strings.ToLower(strings.TrimSpace(out.Level))
|
||||
if out.Level == "" {
|
||||
out.Level = "info"
|
||||
}
|
||||
out.Format = strings.ToLower(strings.TrimSpace(out.Format))
|
||||
if out.Format == "" {
|
||||
out.Format = "console"
|
||||
}
|
||||
out.ServiceName = strings.TrimSpace(out.ServiceName)
|
||||
if out.ServiceName == "" {
|
||||
out.ServiceName = "sub2api"
|
||||
}
|
||||
out.Environment = strings.TrimSpace(out.Environment)
|
||||
if out.Environment == "" {
|
||||
out.Environment = "production"
|
||||
}
|
||||
out.StacktraceLevel = strings.ToLower(strings.TrimSpace(out.StacktraceLevel))
|
||||
if out.StacktraceLevel == "" {
|
||||
out.StacktraceLevel = "error"
|
||||
}
|
||||
if !out.Output.ToStdout && !out.Output.ToFile {
|
||||
out.Output.ToStdout = true
|
||||
}
|
||||
out.Output.FilePath = resolveLogFilePath(out.Output.FilePath)
|
||||
if out.Rotation.MaxSizeMB <= 0 {
|
||||
out.Rotation.MaxSizeMB = 100
|
||||
}
|
||||
if out.Rotation.MaxBackups < 0 {
|
||||
out.Rotation.MaxBackups = 10
|
||||
}
|
||||
if out.Rotation.MaxAgeDays < 0 {
|
||||
out.Rotation.MaxAgeDays = 7
|
||||
}
|
||||
if out.Sampling.Enabled {
|
||||
if out.Sampling.Initial <= 0 {
|
||||
out.Sampling.Initial = 100
|
||||
}
|
||||
if out.Sampling.Thereafter <= 0 {
|
||||
out.Sampling.Thereafter = 100
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func resolveLogFilePath(explicit string) string {
|
||||
explicit = strings.TrimSpace(explicit)
|
||||
if explicit != "" {
|
||||
return explicit
|
||||
}
|
||||
dataDir := strings.TrimSpace(os.Getenv("DATA_DIR"))
|
||||
if dataDir != "" {
|
||||
return filepath.Join(dataDir, "logs", defaultLogFilename)
|
||||
}
|
||||
return DefaultContainerLogPath
|
||||
}
|
||||
|
||||
func bootstrapOptions() InitOptions {
|
||||
return InitOptions{
|
||||
Level: "info",
|
||||
Format: "console",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "bootstrap",
|
||||
Output: OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: false,
|
||||
},
|
||||
Rotation: RotationOptions{
|
||||
MaxSizeMB: 100,
|
||||
MaxBackups: 10,
|
||||
MaxAgeDays: 7,
|
||||
Compress: true,
|
||||
LocalTime: true,
|
||||
},
|
||||
Sampling: SamplingOptions{
|
||||
Enabled: false,
|
||||
Initial: 100,
|
||||
Thereafter: 100,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func parseLevel(level string) (Level, bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(level)) {
|
||||
case "debug":
|
||||
return LevelDebug, true
|
||||
case "info":
|
||||
return LevelInfo, true
|
||||
case "warn":
|
||||
return LevelWarn, true
|
||||
case "error":
|
||||
return LevelError, true
|
||||
default:
|
||||
return LevelInfo, false
|
||||
}
|
||||
}
|
||||
|
||||
func parseStacktraceLevel(level string) (Level, bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(level)) {
|
||||
case "none":
|
||||
return LevelFatal + 1, true
|
||||
case "error":
|
||||
return LevelError, true
|
||||
case "fatal":
|
||||
return LevelFatal, true
|
||||
default:
|
||||
return LevelError, false
|
||||
}
|
||||
}
|
||||
|
||||
func samplingTick() time.Duration {
|
||||
return time.Second
|
||||
}
|
||||
102
backend/internal/pkg/logger/options_test.go
Normal file
102
backend/internal/pkg/logger/options_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
func TestResolveLogFilePath_Default(t *testing.T) {
|
||||
t.Setenv("DATA_DIR", "")
|
||||
got := resolveLogFilePath("")
|
||||
if got != DefaultContainerLogPath {
|
||||
t.Fatalf("resolveLogFilePath() = %q, want %q", got, DefaultContainerLogPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveLogFilePath_WithDataDir(t *testing.T) {
|
||||
t.Setenv("DATA_DIR", "/tmp/sub2api-data")
|
||||
got := resolveLogFilePath("")
|
||||
want := filepath.Join("/tmp/sub2api-data", "logs", "sub2api.log")
|
||||
if got != want {
|
||||
t.Fatalf("resolveLogFilePath() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveLogFilePath_ExplicitPath(t *testing.T) {
|
||||
t.Setenv("DATA_DIR", "/tmp/ignore")
|
||||
got := resolveLogFilePath("/var/log/custom.log")
|
||||
if got != "/var/log/custom.log" {
|
||||
t.Fatalf("resolveLogFilePath() = %q, want explicit path", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizedOptions_InvalidFallback(t *testing.T) {
|
||||
t.Setenv("DATA_DIR", "")
|
||||
opts := InitOptions{
|
||||
Level: "TRACE",
|
||||
Format: "TEXT",
|
||||
ServiceName: "",
|
||||
Environment: "",
|
||||
StacktraceLevel: "panic",
|
||||
Output: OutputOptions{
|
||||
ToStdout: false,
|
||||
ToFile: false,
|
||||
},
|
||||
Rotation: RotationOptions{
|
||||
MaxSizeMB: 0,
|
||||
MaxBackups: -1,
|
||||
MaxAgeDays: -1,
|
||||
},
|
||||
Sampling: SamplingOptions{
|
||||
Enabled: true,
|
||||
Initial: 0,
|
||||
Thereafter: 0,
|
||||
},
|
||||
}
|
||||
out := opts.normalized()
|
||||
if out.Level != "trace" {
|
||||
// normalized 仅做 trim/lower,不做校验;校验在 config 层。
|
||||
t.Fatalf("normalized level should preserve value for upstream validation, got %q", out.Level)
|
||||
}
|
||||
if !out.Output.ToStdout {
|
||||
t.Fatalf("normalized output should fallback to stdout")
|
||||
}
|
||||
if out.Output.FilePath != DefaultContainerLogPath {
|
||||
t.Fatalf("normalized file path = %q", out.Output.FilePath)
|
||||
}
|
||||
if out.Rotation.MaxSizeMB != 100 {
|
||||
t.Fatalf("normalized max_size_mb = %d", out.Rotation.MaxSizeMB)
|
||||
}
|
||||
if out.Rotation.MaxBackups != 10 {
|
||||
t.Fatalf("normalized max_backups = %d", out.Rotation.MaxBackups)
|
||||
}
|
||||
if out.Rotation.MaxAgeDays != 7 {
|
||||
t.Fatalf("normalized max_age_days = %d", out.Rotation.MaxAgeDays)
|
||||
}
|
||||
if out.Sampling.Initial != 100 || out.Sampling.Thereafter != 100 {
|
||||
t.Fatalf("normalized sampling defaults invalid: %+v", out.Sampling)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildFileCore_InvalidPathFallback(t *testing.T) {
|
||||
t.Setenv("DATA_DIR", "")
|
||||
opts := bootstrapOptions()
|
||||
opts.Output.ToFile = true
|
||||
opts.Output.FilePath = filepath.Join(os.DevNull, "logs", "sub2api.log")
|
||||
encoderCfg := zapcore.EncoderConfig{
|
||||
TimeKey: "time",
|
||||
LevelKey: "level",
|
||||
MessageKey: "msg",
|
||||
EncodeTime: zapcore.ISO8601TimeEncoder,
|
||||
EncodeLevel: zapcore.CapitalLevelEncoder,
|
||||
}
|
||||
encoder := zapcore.NewJSONEncoder(encoderCfg)
|
||||
_, _, err := buildFileCore(encoder, zap.NewAtomicLevel(), opts)
|
||||
if err == nil {
|
||||
t.Fatalf("buildFileCore() expected error for invalid path")
|
||||
}
|
||||
}
|
||||
132
backend/internal/pkg/logger/slog_handler.go
Normal file
132
backend/internal/pkg/logger/slog_handler.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
type slogZapHandler struct {
|
||||
logger *zap.Logger
|
||||
attrs []slog.Attr
|
||||
groups []string
|
||||
}
|
||||
|
||||
func newSlogZapHandler(logger *zap.Logger) slog.Handler {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
return &slogZapHandler{
|
||||
logger: logger,
|
||||
attrs: make([]slog.Attr, 0, 8),
|
||||
groups: make([]string, 0, 4),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *slogZapHandler) Enabled(_ context.Context, level slog.Level) bool {
|
||||
switch {
|
||||
case level >= slog.LevelError:
|
||||
return h.logger.Core().Enabled(LevelError)
|
||||
case level >= slog.LevelWarn:
|
||||
return h.logger.Core().Enabled(LevelWarn)
|
||||
case level <= slog.LevelDebug:
|
||||
return h.logger.Core().Enabled(LevelDebug)
|
||||
default:
|
||||
return h.logger.Core().Enabled(LevelInfo)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *slogZapHandler) Handle(_ context.Context, record slog.Record) error {
|
||||
fields := make([]zap.Field, 0, len(h.attrs)+record.NumAttrs()+3)
|
||||
fields = append(fields, slogAttrsToZapFields(h.groups, h.attrs)...)
|
||||
record.Attrs(func(attr slog.Attr) bool {
|
||||
fields = append(fields, slogAttrToZapField(h.groups, attr))
|
||||
return true
|
||||
})
|
||||
|
||||
entry := h.logger.With(fields...)
|
||||
switch {
|
||||
case record.Level >= slog.LevelError:
|
||||
entry.Error(record.Message)
|
||||
case record.Level >= slog.LevelWarn:
|
||||
entry.Warn(record.Message)
|
||||
case record.Level <= slog.LevelDebug:
|
||||
entry.Debug(record.Message)
|
||||
default:
|
||||
entry.Info(record.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *slogZapHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||
next := *h
|
||||
next.attrs = append(append([]slog.Attr{}, h.attrs...), attrs...)
|
||||
return &next
|
||||
}
|
||||
|
||||
func (h *slogZapHandler) WithGroup(name string) slog.Handler {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return h
|
||||
}
|
||||
next := *h
|
||||
next.groups = append(append([]string{}, h.groups...), name)
|
||||
return &next
|
||||
}
|
||||
|
||||
func slogAttrsToZapFields(groups []string, attrs []slog.Attr) []zap.Field {
|
||||
fields := make([]zap.Field, 0, len(attrs))
|
||||
for _, attr := range attrs {
|
||||
fields = append(fields, slogAttrToZapField(groups, attr))
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
func slogAttrToZapField(groups []string, attr slog.Attr) zap.Field {
|
||||
if len(groups) > 0 {
|
||||
attr.Key = strings.Join(append(append([]string{}, groups...), attr.Key), ".")
|
||||
}
|
||||
value := attr.Value.Resolve()
|
||||
switch value.Kind() {
|
||||
case slog.KindBool:
|
||||
return zap.Bool(attr.Key, value.Bool())
|
||||
case slog.KindInt64:
|
||||
return zap.Int64(attr.Key, value.Int64())
|
||||
case slog.KindUint64:
|
||||
return zap.Uint64(attr.Key, value.Uint64())
|
||||
case slog.KindFloat64:
|
||||
return zap.Float64(attr.Key, value.Float64())
|
||||
case slog.KindDuration:
|
||||
return zap.Duration(attr.Key, value.Duration())
|
||||
case slog.KindTime:
|
||||
return zap.Time(attr.Key, value.Time())
|
||||
case slog.KindString:
|
||||
return zap.String(attr.Key, value.String())
|
||||
case slog.KindGroup:
|
||||
groupFields := make([]zap.Field, 0, len(value.Group()))
|
||||
for _, nested := range value.Group() {
|
||||
groupFields = append(groupFields, slogAttrToZapField(nil, nested))
|
||||
}
|
||||
return zap.Object(attr.Key, zapObjectFields(groupFields))
|
||||
case slog.KindAny:
|
||||
if t, ok := value.Any().(time.Time); ok {
|
||||
return zap.Time(attr.Key, t)
|
||||
}
|
||||
return zap.Any(attr.Key, value.Any())
|
||||
default:
|
||||
return zap.String(attr.Key, value.String())
|
||||
}
|
||||
}
|
||||
|
||||
type zapObjectFields []zap.Field
|
||||
|
||||
func (z zapObjectFields) MarshalLogObject(enc zapcore.ObjectEncoder) error {
|
||||
for _, field := range z {
|
||||
field.AddTo(enc)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
90
backend/internal/pkg/logger/slog_handler_test.go
Normal file
90
backend/internal/pkg/logger/slog_handler_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
type captureState struct {
|
||||
writes []capturedWrite
|
||||
}
|
||||
|
||||
type capturedWrite struct {
|
||||
entry zapcore.Entry
|
||||
fields []zapcore.Field
|
||||
}
|
||||
|
||||
type captureCore struct {
|
||||
state *captureState
|
||||
withFields []zapcore.Field
|
||||
}
|
||||
|
||||
func newCaptureCore() *captureCore {
|
||||
return &captureCore{state: &captureState{}}
|
||||
}
|
||||
|
||||
func (c *captureCore) Enabled(zapcore.Level) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *captureCore) With(fields []zapcore.Field) zapcore.Core {
|
||||
nextFields := make([]zapcore.Field, 0, len(c.withFields)+len(fields))
|
||||
nextFields = append(nextFields, c.withFields...)
|
||||
nextFields = append(nextFields, fields...)
|
||||
return &captureCore{
|
||||
state: c.state,
|
||||
withFields: nextFields,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *captureCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
|
||||
return ce.AddCore(entry, c)
|
||||
}
|
||||
|
||||
func (c *captureCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
|
||||
allFields := make([]zapcore.Field, 0, len(c.withFields)+len(fields))
|
||||
allFields = append(allFields, c.withFields...)
|
||||
allFields = append(allFields, fields...)
|
||||
c.state.writes = append(c.state.writes, capturedWrite{
|
||||
entry: entry,
|
||||
fields: allFields,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *captureCore) Sync() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSlogZapHandler_Handle_DoesNotAppendTimeField(t *testing.T) {
|
||||
core := newCaptureCore()
|
||||
handler := newSlogZapHandler(zap.New(core))
|
||||
|
||||
record := slog.NewRecord(time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC), slog.LevelInfo, "hello", 0)
|
||||
record.AddAttrs(slog.String("component", "http.access"))
|
||||
|
||||
if err := handler.Handle(context.Background(), record); err != nil {
|
||||
t.Fatalf("handle slog record: %v", err)
|
||||
}
|
||||
if len(core.state.writes) != 1 {
|
||||
t.Fatalf("write calls = %d, want 1", len(core.state.writes))
|
||||
}
|
||||
|
||||
var hasComponent bool
|
||||
for _, field := range core.state.writes[0].fields {
|
||||
if field.Key == "time" {
|
||||
t.Fatalf("unexpected duplicate time field in slog adapter output")
|
||||
}
|
||||
if field.Key == "component" {
|
||||
hasComponent = true
|
||||
}
|
||||
}
|
||||
if !hasComponent {
|
||||
t.Fatalf("component field should be preserved")
|
||||
}
|
||||
}
|
||||
165
backend/internal/pkg/logger/stdlog_bridge_test.go
Normal file
165
backend/internal/pkg/logger/stdlog_bridge_test.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInferStdLogLevel(t *testing.T) {
|
||||
cases := []struct {
|
||||
msg string
|
||||
want Level
|
||||
}{
|
||||
{msg: "Warning: queue full", want: LevelWarn},
|
||||
{msg: "Forward request failed: timeout", want: LevelError},
|
||||
{msg: "[ERROR] upstream unavailable", want: LevelError},
|
||||
{msg: "service started", want: LevelInfo},
|
||||
{msg: "debug: cache miss", want: LevelDebug},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
got := inferStdLogLevel(tc.msg)
|
||||
if got != tc.want {
|
||||
t.Fatalf("inferStdLogLevel(%q)=%v want=%v", tc.msg, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeStdLogMessage(t *testing.T) {
|
||||
raw := " [TokenRefresh] cycle complete \n total=1 failed=0 \n"
|
||||
got := normalizeStdLogMessage(raw)
|
||||
want := "[TokenRefresh] cycle complete total=1 failed=0"
|
||||
if got != want {
|
||||
t.Fatalf("normalizeStdLogMessage()=%q want=%q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStdLogBridgeRoutesLevels(t *testing.T) {
|
||||
origStdout := os.Stdout
|
||||
origStderr := os.Stderr
|
||||
stdoutR, stdoutW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stdout pipe: %v", err)
|
||||
}
|
||||
stderrR, stderrW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stderr pipe: %v", err)
|
||||
}
|
||||
os.Stdout = stdoutW
|
||||
os.Stderr = stderrW
|
||||
t.Cleanup(func() {
|
||||
os.Stdout = origStdout
|
||||
os.Stderr = origStderr
|
||||
_ = stdoutR.Close()
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrR.Close()
|
||||
_ = stderrW.Close()
|
||||
})
|
||||
|
||||
if err := Init(InitOptions{
|
||||
Level: "debug",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: false,
|
||||
},
|
||||
Sampling: SamplingOptions{Enabled: false},
|
||||
}); err != nil {
|
||||
t.Fatalf("Init() error: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("service started")
|
||||
log.Printf("Warning: queue full")
|
||||
log.Printf("Forward request failed: timeout")
|
||||
Sync()
|
||||
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrW.Close()
|
||||
stdoutBytes, _ := io.ReadAll(stdoutR)
|
||||
stderrBytes, _ := io.ReadAll(stderrR)
|
||||
stdoutText := string(stdoutBytes)
|
||||
stderrText := string(stderrBytes)
|
||||
|
||||
if !strings.Contains(stdoutText, "service started") {
|
||||
t.Fatalf("stdout missing info log: %s", stdoutText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "Warning: queue full") {
|
||||
t.Fatalf("stderr missing warn log: %s", stderrText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "Forward request failed: timeout") {
|
||||
t.Fatalf("stderr missing error log: %s", stderrText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "\"legacy_stdlog\":true") {
|
||||
t.Fatalf("stderr missing legacy_stdlog marker: %s", stderrText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegacyPrintfRoutesLevels(t *testing.T) {
|
||||
origStdout := os.Stdout
|
||||
origStderr := os.Stderr
|
||||
stdoutR, stdoutW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stdout pipe: %v", err)
|
||||
}
|
||||
stderrR, stderrW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stderr pipe: %v", err)
|
||||
}
|
||||
os.Stdout = stdoutW
|
||||
os.Stderr = stderrW
|
||||
t.Cleanup(func() {
|
||||
os.Stdout = origStdout
|
||||
os.Stderr = origStderr
|
||||
_ = stdoutR.Close()
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrR.Close()
|
||||
_ = stderrW.Close()
|
||||
})
|
||||
|
||||
if err := Init(InitOptions{
|
||||
Level: "debug",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: false,
|
||||
},
|
||||
Sampling: SamplingOptions{Enabled: false},
|
||||
}); err != nil {
|
||||
t.Fatalf("Init() error: %v", err)
|
||||
}
|
||||
|
||||
LegacyPrintf("service.test", "request started")
|
||||
LegacyPrintf("service.test", "Warning: queue full")
|
||||
LegacyPrintf("service.test", "forward failed: timeout")
|
||||
Sync()
|
||||
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrW.Close()
|
||||
stdoutBytes, _ := io.ReadAll(stdoutR)
|
||||
stderrBytes, _ := io.ReadAll(stderrR)
|
||||
stdoutText := string(stdoutBytes)
|
||||
stderrText := string(stderrBytes)
|
||||
|
||||
if !strings.Contains(stdoutText, "request started") {
|
||||
t.Fatalf("stdout missing info log: %s", stdoutText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "Warning: queue full") {
|
||||
t.Fatalf("stderr missing warn log: %s", stderrText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "forward failed: timeout") {
|
||||
t.Fatalf("stderr missing error log: %s", stderrText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "\"legacy_printf\":true") {
|
||||
t.Fatalf("stderr missing legacy_printf marker: %s", stderrText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "\"component\":\"service.test\"") {
|
||||
t.Fatalf("stderr missing component field: %s", stderrText)
|
||||
}
|
||||
}
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -25,6 +24,7 @@ import (
|
||||
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
|
||||
dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
dbproxy "github.com/Wei-Shaw/sub2api/ent/proxy"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
@@ -127,7 +127,7 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
|
||||
account.CreatedAt = created.CreatedAt
|
||||
account.UpdatedAt = created.UpdatedAt
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -388,7 +388,7 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
}
|
||||
account.UpdatedAt = updated.UpdatedAt
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||
}
|
||||
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
|
||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||
@@ -429,7 +429,7 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
||||
}
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, buildSchedulerGroupPayload(groupIDs)); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -525,7 +525,7 @@ func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error
|
||||
},
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, &id, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -560,7 +560,7 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map
|
||||
}
|
||||
payload := map[string]any{"last_used": lastUsedPayload}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, nil, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue batch last used failed: err=%v", err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue batch last used failed: err=%v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -575,7 +575,7 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
@@ -595,11 +595,11 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac
|
||||
}
|
||||
account, err := r.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
|
||||
logger.LegacyPrintf("repository.account", "[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
|
||||
return
|
||||
}
|
||||
if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
|
||||
log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
|
||||
logger.LegacyPrintf("repository.account", "[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -623,7 +623,7 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i
|
||||
}
|
||||
payload := buildSchedulerGroupPayload([]int64{groupID})
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -640,7 +640,7 @@ func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, grou
|
||||
}
|
||||
payload := buildSchedulerGroupPayload([]int64{groupID})
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -713,7 +713,7 @@ func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, gro
|
||||
}
|
||||
payload := buildSchedulerGroupPayload(mergeGroupIDs(existingGroupIDs, groupIDs))
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -821,7 +821,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -868,7 +868,7 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -882,7 +882,7 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -901,7 +901,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
@@ -920,7 +920,7 @@ func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -936,7 +936,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -960,7 +960,7 @@ func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -984,7 +984,7 @@ func (r *accountRepository) ClearModelRateLimits(ctx context.Context, id int64)
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1006,7 +1006,7 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
|
||||
// 触发调度器缓存更新(仅当窗口时间有变化时)
|
||||
if start != nil || end != nil {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -1021,7 +1021,7 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
|
||||
}
|
||||
if !schedulable {
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
@@ -1049,7 +1049,7 @@ func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now ti
|
||||
}
|
||||
if rows > 0 {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventFullRebuild, nil, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err)
|
||||
}
|
||||
}
|
||||
return rows, nil
|
||||
@@ -1085,7 +1085,7 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1179,7 +1179,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
||||
if rows > 0 {
|
||||
payload := map[string]any{"account_ids": ids}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
|
||||
}
|
||||
shouldSync := false
|
||||
if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) {
|
||||
|
||||
@@ -4,12 +4,12 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||
@@ -41,7 +41,7 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
|
||||
}
|
||||
|
||||
targetURL := s.baseURL + "/api/organizations"
|
||||
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1: Getting organization UUID from %s", targetURL)
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
@@ -53,11 +53,11 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
|
||||
Get(targetURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err)
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 FAILED - Request error: %v", err)
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 1 Response - Status: %d", resp.StatusCode)
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 Response - Status: %d", resp.StatusCode)
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
@@ -69,21 +69,21 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
|
||||
|
||||
// 如果只有一个组织,直接使用
|
||||
if len(orgs) == 1 {
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
|
||||
return orgs[0].UUID, nil
|
||||
}
|
||||
|
||||
// 如果有多个组织,优先选择 raven_type 为 "team" 的组织
|
||||
for _, org := range orgs {
|
||||
if org.RavenType != nil && *org.RavenType == "team" {
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s",
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s",
|
||||
org.UUID, org.Name, *org.RavenType)
|
||||
return org.UUID, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有 team 类型的组织,使用第一个
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
|
||||
return orgs[0].UUID, nil
|
||||
}
|
||||
|
||||
@@ -103,9 +103,9 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
|
||||
"code_challenge_method": "S256",
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2: Getting authorization code from %s", authURL)
|
||||
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
|
||||
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
|
||||
|
||||
var result struct {
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
@@ -128,11 +128,11 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
|
||||
Post(authURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err)
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 FAILED - Request error: %v", err)
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
@@ -160,7 +160,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
|
||||
fullCode = authCode + "#" + responseState
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code")
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 SUCCESS - Got authorization code")
|
||||
return fullCode, nil
|
||||
}
|
||||
|
||||
@@ -192,9 +192,9 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
||||
reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
|
||||
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
|
||||
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
|
||||
|
||||
var tokenResp oauth.TokenResponse
|
||||
|
||||
@@ -208,17 +208,17 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 FAILED - Request error: %v", err)
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 3 SUCCESS - Got access token")
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 SUCCESS - Got access token")
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -4,11 +4,11 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"log"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
@@ -72,7 +72,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
groupIn.CreatedAt = created.CreatedAt
|
||||
groupIn.UpdatedAt = created.UpdatedAt
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err)
|
||||
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err)
|
||||
}
|
||||
}
|
||||
return translatePersistenceError(err, nil, service.ErrGroupExists)
|
||||
@@ -152,7 +152,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
}
|
||||
groupIn.UpdatedAt = updated.UpdatedAt
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err)
|
||||
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -163,7 +163,7 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error {
|
||||
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -296,7 +296,7 @@ func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, grou
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err)
|
||||
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err)
|
||||
}
|
||||
return affected, nil
|
||||
}
|
||||
@@ -406,7 +406,7 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
||||
}
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err)
|
||||
}
|
||||
|
||||
return affectedUserIDs, nil
|
||||
@@ -500,7 +500,7 @@ func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64
|
||||
|
||||
// 发送调度器事件
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err)
|
||||
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -938,6 +939,243 @@ WHERE id = $1`
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *opsRepository) BatchInsertSystemLogs(ctx context.Context, inputs []*service.OpsInsertSystemLogInput) (int64, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return 0, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if len(inputs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
stmt, err := tx.PrepareContext(ctx, pq.CopyIn(
|
||||
"ops_system_logs",
|
||||
"created_at",
|
||||
"level",
|
||||
"component",
|
||||
"message",
|
||||
"request_id",
|
||||
"client_request_id",
|
||||
"user_id",
|
||||
"account_id",
|
||||
"platform",
|
||||
"model",
|
||||
"extra",
|
||||
))
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var inserted int64
|
||||
for _, input := range inputs {
|
||||
if input == nil {
|
||||
continue
|
||||
}
|
||||
createdAt := input.CreatedAt
|
||||
if createdAt.IsZero() {
|
||||
createdAt = time.Now().UTC()
|
||||
}
|
||||
component := strings.TrimSpace(input.Component)
|
||||
level := strings.ToLower(strings.TrimSpace(input.Level))
|
||||
message := strings.TrimSpace(input.Message)
|
||||
if level == "" || message == "" {
|
||||
continue
|
||||
}
|
||||
if component == "" {
|
||||
component = "app"
|
||||
}
|
||||
extra := strings.TrimSpace(input.ExtraJSON)
|
||||
if extra == "" {
|
||||
extra = "{}"
|
||||
}
|
||||
if _, err := stmt.ExecContext(
|
||||
ctx,
|
||||
createdAt.UTC(),
|
||||
level,
|
||||
component,
|
||||
message,
|
||||
opsNullString(input.RequestID),
|
||||
opsNullString(input.ClientRequestID),
|
||||
opsNullInt64(input.UserID),
|
||||
opsNullInt64(input.AccountID),
|
||||
opsNullString(input.Platform),
|
||||
opsNullString(input.Model),
|
||||
extra,
|
||||
); err != nil {
|
||||
_ = stmt.Close()
|
||||
_ = tx.Rollback()
|
||||
return inserted, err
|
||||
}
|
||||
inserted++
|
||||
}
|
||||
|
||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||
_ = stmt.Close()
|
||||
_ = tx.Rollback()
|
||||
return inserted, err
|
||||
}
|
||||
if err := stmt.Close(); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return inserted, err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return inserted, err
|
||||
}
|
||||
return inserted, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) ListSystemLogs(ctx context.Context, filter *service.OpsSystemLogFilter) (*service.OpsSystemLogList, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if filter == nil {
|
||||
filter = &service.OpsSystemLogFilter{}
|
||||
}
|
||||
|
||||
page := filter.Page
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
pageSize := filter.PageSize
|
||||
if pageSize <= 0 {
|
||||
pageSize = 50
|
||||
}
|
||||
if pageSize > 200 {
|
||||
pageSize = 200
|
||||
}
|
||||
|
||||
where, args, _ := buildOpsSystemLogsWhere(filter)
|
||||
countSQL := "SELECT COUNT(*) FROM ops_system_logs l " + where
|
||||
var total int
|
||||
if err := r.db.QueryRowContext(ctx, countSQL, args...).Scan(&total); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
argsWithLimit := append(args, pageSize, offset)
|
||||
query := `
|
||||
SELECT
|
||||
l.id,
|
||||
l.created_at,
|
||||
l.level,
|
||||
COALESCE(l.component, ''),
|
||||
COALESCE(l.message, ''),
|
||||
COALESCE(l.request_id, ''),
|
||||
COALESCE(l.client_request_id, ''),
|
||||
l.user_id,
|
||||
l.account_id,
|
||||
COALESCE(l.platform, ''),
|
||||
COALESCE(l.model, ''),
|
||||
COALESCE(l.extra::text, '{}')
|
||||
FROM ops_system_logs l
|
||||
` + where + `
|
||||
ORDER BY l.created_at DESC, l.id DESC
|
||||
LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, argsWithLimit...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
logs := make([]*service.OpsSystemLog, 0, pageSize)
|
||||
for rows.Next() {
|
||||
item := &service.OpsSystemLog{}
|
||||
var userID sql.NullInt64
|
||||
var accountID sql.NullInt64
|
||||
var extraRaw string
|
||||
if err := rows.Scan(
|
||||
&item.ID,
|
||||
&item.CreatedAt,
|
||||
&item.Level,
|
||||
&item.Component,
|
||||
&item.Message,
|
||||
&item.RequestID,
|
||||
&item.ClientRequestID,
|
||||
&userID,
|
||||
&accountID,
|
||||
&item.Platform,
|
||||
&item.Model,
|
||||
&extraRaw,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if userID.Valid {
|
||||
v := userID.Int64
|
||||
item.UserID = &v
|
||||
}
|
||||
if accountID.Valid {
|
||||
v := accountID.Int64
|
||||
item.AccountID = &v
|
||||
}
|
||||
extraRaw = strings.TrimSpace(extraRaw)
|
||||
if extraRaw != "" && extraRaw != "null" && extraRaw != "{}" {
|
||||
extra := make(map[string]any)
|
||||
if err := json.Unmarshal([]byte(extraRaw), &extra); err == nil {
|
||||
item.Extra = extra
|
||||
}
|
||||
}
|
||||
logs = append(logs, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &service.OpsSystemLogList{
|
||||
Logs: logs,
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) DeleteSystemLogs(ctx context.Context, filter *service.OpsSystemLogCleanupFilter) (int64, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return 0, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if filter == nil {
|
||||
filter = &service.OpsSystemLogCleanupFilter{}
|
||||
}
|
||||
|
||||
where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(filter)
|
||||
if !hasConstraint {
|
||||
return 0, fmt.Errorf("cleanup requires at least one filter condition")
|
||||
}
|
||||
|
||||
query := "DELETE FROM ops_system_logs l " + where
|
||||
res, err := r.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (r *opsRepository) InsertSystemLogCleanupAudit(ctx context.Context, input *service.OpsSystemLogCleanupAudit) error {
|
||||
if r == nil || r.db == nil {
|
||||
return fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if input == nil {
|
||||
return fmt.Errorf("nil input")
|
||||
}
|
||||
createdAt := input.CreatedAt
|
||||
if createdAt.IsZero() {
|
||||
createdAt = time.Now().UTC()
|
||||
}
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
INSERT INTO ops_system_log_cleanup_audits (
|
||||
created_at,
|
||||
operator_id,
|
||||
conditions,
|
||||
deleted_rows
|
||||
) VALUES ($1,$2,$3,$4)
|
||||
`, createdAt.UTC(), input.OperatorID, input.Conditions, input.DeletedRows)
|
||||
return err
|
||||
}
|
||||
|
||||
func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
||||
clauses := make([]string, 0, 12)
|
||||
args := make([]any, 0, 12)
|
||||
@@ -1053,6 +1291,95 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
||||
return "WHERE " + strings.Join(clauses, " AND "), args
|
||||
}
|
||||
|
||||
func buildOpsSystemLogsWhere(filter *service.OpsSystemLogFilter) (string, []any, bool) {
|
||||
clauses := make([]string, 0, 10)
|
||||
args := make([]any, 0, 10)
|
||||
clauses = append(clauses, "1=1")
|
||||
hasConstraint := false
|
||||
|
||||
if filter != nil && filter.StartTime != nil && !filter.StartTime.IsZero() {
|
||||
args = append(args, filter.StartTime.UTC())
|
||||
clauses = append(clauses, "l.created_at >= $"+itoa(len(args)))
|
||||
hasConstraint = true
|
||||
}
|
||||
if filter != nil && filter.EndTime != nil && !filter.EndTime.IsZero() {
|
||||
args = append(args, filter.EndTime.UTC())
|
||||
clauses = append(clauses, "l.created_at < $"+itoa(len(args)))
|
||||
hasConstraint = true
|
||||
}
|
||||
if filter != nil {
|
||||
if v := strings.ToLower(strings.TrimSpace(filter.Level)); v != "" {
|
||||
args = append(args, v)
|
||||
clauses = append(clauses, "LOWER(COALESCE(l.level,'')) = $"+itoa(len(args)))
|
||||
hasConstraint = true
|
||||
}
|
||||
if v := strings.TrimSpace(filter.Component); v != "" {
|
||||
args = append(args, v)
|
||||
clauses = append(clauses, "COALESCE(l.component,'') = $"+itoa(len(args)))
|
||||
hasConstraint = true
|
||||
}
|
||||
if v := strings.TrimSpace(filter.RequestID); v != "" {
|
||||
args = append(args, v)
|
||||
clauses = append(clauses, "COALESCE(l.request_id,'') = $"+itoa(len(args)))
|
||||
hasConstraint = true
|
||||
}
|
||||
if v := strings.TrimSpace(filter.ClientRequestID); v != "" {
|
||||
args = append(args, v)
|
||||
clauses = append(clauses, "COALESCE(l.client_request_id,'') = $"+itoa(len(args)))
|
||||
hasConstraint = true
|
||||
}
|
||||
if filter.UserID != nil && *filter.UserID > 0 {
|
||||
args = append(args, *filter.UserID)
|
||||
clauses = append(clauses, "l.user_id = $"+itoa(len(args)))
|
||||
hasConstraint = true
|
||||
}
|
||||
if filter.AccountID != nil && *filter.AccountID > 0 {
|
||||
args = append(args, *filter.AccountID)
|
||||
clauses = append(clauses, "l.account_id = $"+itoa(len(args)))
|
||||
hasConstraint = true
|
||||
}
|
||||
if v := strings.TrimSpace(filter.Platform); v != "" {
|
||||
args = append(args, v)
|
||||
clauses = append(clauses, "COALESCE(l.platform,'') = $"+itoa(len(args)))
|
||||
hasConstraint = true
|
||||
}
|
||||
if v := strings.TrimSpace(filter.Model); v != "" {
|
||||
args = append(args, v)
|
||||
clauses = append(clauses, "COALESCE(l.model,'') = $"+itoa(len(args)))
|
||||
hasConstraint = true
|
||||
}
|
||||
if v := strings.TrimSpace(filter.Query); v != "" {
|
||||
like := "%" + v + "%"
|
||||
args = append(args, like)
|
||||
n := itoa(len(args))
|
||||
clauses = append(clauses, "(l.message ILIKE $"+n+" OR COALESCE(l.request_id,'') ILIKE $"+n+" OR COALESCE(l.client_request_id,'') ILIKE $"+n+" OR COALESCE(l.extra::text,'') ILIKE $"+n+")")
|
||||
hasConstraint = true
|
||||
}
|
||||
}
|
||||
|
||||
return "WHERE " + strings.Join(clauses, " AND "), args, hasConstraint
|
||||
}
|
||||
|
||||
func buildOpsSystemLogsCleanupWhere(filter *service.OpsSystemLogCleanupFilter) (string, []any, bool) {
|
||||
if filter == nil {
|
||||
filter = &service.OpsSystemLogCleanupFilter{}
|
||||
}
|
||||
listFilter := &service.OpsSystemLogFilter{
|
||||
StartTime: filter.StartTime,
|
||||
EndTime: filter.EndTime,
|
||||
Level: filter.Level,
|
||||
Component: filter.Component,
|
||||
RequestID: filter.RequestID,
|
||||
ClientRequestID: filter.ClientRequestID,
|
||||
UserID: filter.UserID,
|
||||
AccountID: filter.AccountID,
|
||||
Platform: filter.Platform,
|
||||
Model: filter.Model,
|
||||
Query: filter.Query,
|
||||
}
|
||||
return buildOpsSystemLogsWhere(listFilter)
|
||||
}
|
||||
|
||||
// Helpers for nullable args
|
||||
func opsNullString(v any) any {
|
||||
switch s := v.(type) {
|
||||
|
||||
86
backend/internal/repository/ops_repo_system_logs_test.go
Normal file
86
backend/internal/repository/ops_repo_system_logs_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func TestBuildOpsSystemLogsWhere_WithClientRequestIDAndUserID(t *testing.T) {
|
||||
start := time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := time.Date(2026, 2, 2, 0, 0, 0, 0, time.UTC)
|
||||
userID := int64(12)
|
||||
accountID := int64(34)
|
||||
|
||||
filter := &service.OpsSystemLogFilter{
|
||||
StartTime: &start,
|
||||
EndTime: &end,
|
||||
Level: "warn",
|
||||
Component: "http.access",
|
||||
RequestID: "req-1",
|
||||
ClientRequestID: "creq-1",
|
||||
UserID: &userID,
|
||||
AccountID: &accountID,
|
||||
Platform: "openai",
|
||||
Model: "gpt-5",
|
||||
Query: "timeout",
|
||||
}
|
||||
|
||||
where, args, hasConstraint := buildOpsSystemLogsWhere(filter)
|
||||
if !hasConstraint {
|
||||
t.Fatalf("expected hasConstraint=true")
|
||||
}
|
||||
if where == "" {
|
||||
t.Fatalf("where should not be empty")
|
||||
}
|
||||
if len(args) != 11 {
|
||||
t.Fatalf("args len = %d, want 11", len(args))
|
||||
}
|
||||
if !contains(where, "COALESCE(l.client_request_id,'') = $") {
|
||||
t.Fatalf("where should include client_request_id condition: %s", where)
|
||||
}
|
||||
if !contains(where, "l.user_id = $") {
|
||||
t.Fatalf("where should include user_id condition: %s", where)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpsSystemLogsCleanupWhere_RequireConstraint(t *testing.T) {
|
||||
where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(&service.OpsSystemLogCleanupFilter{})
|
||||
if hasConstraint {
|
||||
t.Fatalf("expected hasConstraint=false")
|
||||
}
|
||||
if where == "" {
|
||||
t.Fatalf("where should not be empty")
|
||||
}
|
||||
if len(args) != 0 {
|
||||
t.Fatalf("args len = %d, want 0", len(args))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpsSystemLogsCleanupWhere_WithClientRequestIDAndUserID(t *testing.T) {
|
||||
userID := int64(9)
|
||||
filter := &service.OpsSystemLogCleanupFilter{
|
||||
ClientRequestID: "creq-9",
|
||||
UserID: &userID,
|
||||
}
|
||||
|
||||
where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(filter)
|
||||
if !hasConstraint {
|
||||
t.Fatalf("expected hasConstraint=true")
|
||||
}
|
||||
if len(args) != 2 {
|
||||
t.Fatalf("args len = %d, want 2", len(args))
|
||||
}
|
||||
if !contains(where, "COALESCE(l.client_request_id,'') = $") {
|
||||
t.Fatalf("where should include client_request_id condition: %s", where)
|
||||
}
|
||||
if !contains(where, "l.user_id = $") {
|
||||
t.Fatalf("where should include user_id condition: %s", where)
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s string, sub string) bool {
|
||||
return strings.Contains(s, sub)
|
||||
}
|
||||
@@ -2,10 +2,13 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ClientRequestID ensures every request has a unique client_request_id in request.Context().
|
||||
@@ -24,7 +27,10 @@ func ClientRequestID() gin.HandlerFunc {
|
||||
}
|
||||
|
||||
id := uuid.New().String()
|
||||
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id))
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id)
|
||||
requestLogger := logger.FromContext(ctx).With(zap.String("client_request_id", strings.TrimSpace(id)))
|
||||
ctx = logger.IntoContext(ctx, requestLogger)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Logger 请求日志中间件
|
||||
@@ -24,38 +26,71 @@ func Logger() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// 结束时间
|
||||
endTime := time.Now()
|
||||
|
||||
// 执行时间
|
||||
latency := endTime.Sub(startTime)
|
||||
|
||||
// 请求方法
|
||||
method := c.Request.Method
|
||||
|
||||
// 状态码
|
||||
statusCode := c.Writer.Status()
|
||||
|
||||
// 客户端IP
|
||||
clientIP := c.ClientIP()
|
||||
|
||||
// 协议版本
|
||||
protocol := c.Request.Proto
|
||||
accountID, hasAccountID := c.Request.Context().Value(ctxkey.AccountID).(int64)
|
||||
platform, _ := c.Request.Context().Value(ctxkey.Platform).(string)
|
||||
model, _ := c.Request.Context().Value(ctxkey.Model).(string)
|
||||
|
||||
// 日志格式: [时间] 状态码 | 延迟 | IP | 协议 | 方法 路径
|
||||
log.Printf("[GIN] %v | %3d | %13v | %15s | %-6s | %-7s %s",
|
||||
endTime.Format("2006/01/02 - 15:04:05"),
|
||||
statusCode,
|
||||
latency,
|
||||
clientIP,
|
||||
protocol,
|
||||
method,
|
||||
path,
|
||||
)
|
||||
fields := []zap.Field{
|
||||
zap.String("component", "http.access"),
|
||||
zap.Int("status_code", statusCode),
|
||||
zap.Int64("latency_ms", latency.Milliseconds()),
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("protocol", protocol),
|
||||
zap.String("method", method),
|
||||
zap.String("path", path),
|
||||
}
|
||||
if hasAccountID && accountID > 0 {
|
||||
fields = append(fields, zap.Int64("account_id", accountID))
|
||||
}
|
||||
if platform != "" {
|
||||
fields = append(fields, zap.String("platform", platform))
|
||||
}
|
||||
if model != "" {
|
||||
fields = append(fields, zap.String("model", model))
|
||||
}
|
||||
|
||||
l := logger.FromContext(c.Request.Context()).With(fields...)
|
||||
l.Info("http request completed", zap.Time("completed_at", endTime))
|
||||
// 当全局日志级别高于 info(如 warn/error)时,access info 不会进入 zap core,
|
||||
// 这里补写一次 sink,保证 ops 系统日志仍可索引关键访问轨迹。
|
||||
if !logger.L().Core().Enabled(logger.LevelInfo) {
|
||||
sinkFields := map[string]any{
|
||||
"component": "http.access",
|
||||
"status_code": statusCode,
|
||||
"latency_ms": latency.Milliseconds(),
|
||||
"client_ip": clientIP,
|
||||
"protocol": protocol,
|
||||
"method": method,
|
||||
"path": path,
|
||||
"completed_at": endTime,
|
||||
}
|
||||
if requestID, ok := c.Request.Context().Value(ctxkey.RequestID).(string); ok && requestID != "" {
|
||||
sinkFields["request_id"] = requestID
|
||||
}
|
||||
if clientRequestID, ok := c.Request.Context().Value(ctxkey.ClientRequestID).(string); ok && clientRequestID != "" {
|
||||
sinkFields["client_request_id"] = clientRequestID
|
||||
}
|
||||
if hasAccountID && accountID > 0 {
|
||||
sinkFields["account_id"] = accountID
|
||||
}
|
||||
if platform != "" {
|
||||
sinkFields["platform"] = platform
|
||||
}
|
||||
if model != "" {
|
||||
sinkFields["model"] = model
|
||||
}
|
||||
logger.WriteSinkEvent("info", "http.access", "http request completed", sinkFields)
|
||||
}
|
||||
|
||||
// 如果有错误,额外记录错误信息
|
||||
if len(c.Errors) > 0 {
|
||||
log.Printf("[GIN] Errors: %v", c.Errors.String())
|
||||
l.Warn("http request contains gin errors", zap.String("errors", c.Errors.String()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
249
backend/internal/server/middleware/request_access_logger_test.go
Normal file
249
backend/internal/server/middleware/request_access_logger_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type testLogSink struct {
|
||||
mu sync.Mutex
|
||||
events []*logger.LogEvent
|
||||
}
|
||||
|
||||
func (s *testLogSink) WriteLogEvent(event *logger.LogEvent) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.events = append(s.events, event)
|
||||
}
|
||||
|
||||
func (s *testLogSink) list() []*logger.LogEvent {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := make([]*logger.LogEvent, len(s.events))
|
||||
copy(out, s.events)
|
||||
return out
|
||||
}
|
||||
|
||||
func initMiddlewareTestLogger(t *testing.T) *testLogSink {
|
||||
return initMiddlewareTestLoggerWithLevel(t, "debug")
|
||||
}
|
||||
|
||||
func initMiddlewareTestLoggerWithLevel(t *testing.T, level string) *testLogSink {
|
||||
t.Helper()
|
||||
level = strings.TrimSpace(level)
|
||||
if level == "" {
|
||||
level = "debug"
|
||||
}
|
||||
if err := logger.Init(logger.InitOptions{
|
||||
Level: level,
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: logger.OutputOptions{
|
||||
ToStdout: false,
|
||||
ToFile: false,
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("init logger: %v", err)
|
||||
}
|
||||
sink := &testLogSink{}
|
||||
logger.SetSink(sink)
|
||||
t.Cleanup(func() {
|
||||
logger.SetSink(nil)
|
||||
})
|
||||
return sink
|
||||
}
|
||||
|
||||
func TestRequestLogger_GenerateAndPropagateRequestID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(RequestLogger())
|
||||
r.GET("/t", func(c *gin.Context) {
|
||||
reqID, ok := c.Request.Context().Value(ctxkey.RequestID).(string)
|
||||
if !ok || reqID == "" {
|
||||
t.Fatalf("request_id missing in context")
|
||||
}
|
||||
if got := c.Writer.Header().Get(requestIDHeader); got != reqID {
|
||||
t.Fatalf("response header request_id mismatch, header=%q ctx=%q", got, reqID)
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d", w.Code)
|
||||
}
|
||||
if w.Header().Get(requestIDHeader) == "" {
|
||||
t.Fatalf("X-Request-ID should be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestLogger_KeepIncomingRequestID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(RequestLogger())
|
||||
r.GET("/t", func(c *gin.Context) {
|
||||
reqID, _ := c.Request.Context().Value(ctxkey.RequestID).(string)
|
||||
if reqID != "rid-fixed" {
|
||||
t.Fatalf("request_id=%q, want rid-fixed", reqID)
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set(requestIDHeader, "rid-fixed")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d", w.Code)
|
||||
}
|
||||
if got := w.Header().Get(requestIDHeader); got != "rid-fixed" {
|
||||
t.Fatalf("header=%q, want rid-fixed", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_AccessLogIncludesCoreFields(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
sink := initMiddlewareTestLogger(t)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(Logger())
|
||||
r.Use(func(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
ctx = context.WithValue(ctx, ctxkey.AccountID, int64(101))
|
||||
ctx = context.WithValue(ctx, ctxkey.Platform, "openai")
|
||||
ctx = context.WithValue(ctx, ctxkey.Model, "gpt-5")
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
c.Next()
|
||||
})
|
||||
r.GET("/api/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusCreated)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Fatalf("status=%d", w.Code)
|
||||
}
|
||||
|
||||
events := sink.list()
|
||||
if len(events) == 0 {
|
||||
t.Fatalf("expected at least one log event")
|
||||
}
|
||||
found := false
|
||||
for _, event := range events {
|
||||
if event == nil || event.Message != "http request completed" {
|
||||
continue
|
||||
}
|
||||
found = true
|
||||
switch v := event.Fields["status_code"].(type) {
|
||||
case int:
|
||||
if v != http.StatusCreated {
|
||||
t.Fatalf("status_code field mismatch: %v", v)
|
||||
}
|
||||
case int64:
|
||||
if v != int64(http.StatusCreated) {
|
||||
t.Fatalf("status_code field mismatch: %v", v)
|
||||
}
|
||||
default:
|
||||
t.Fatalf("status_code type mismatch: %T", v)
|
||||
}
|
||||
switch v := event.Fields["account_id"].(type) {
|
||||
case int64:
|
||||
if v != 101 {
|
||||
t.Fatalf("account_id field mismatch: %v", v)
|
||||
}
|
||||
case int:
|
||||
if v != 101 {
|
||||
t.Fatalf("account_id field mismatch: %v", v)
|
||||
}
|
||||
default:
|
||||
t.Fatalf("account_id type mismatch: %T", v)
|
||||
}
|
||||
if event.Fields["platform"] != "openai" || event.Fields["model"] != "gpt-5" {
|
||||
t.Fatalf("platform/model mismatch: %+v", event.Fields)
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("access log event not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_HealthPathSkipped(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
sink := initMiddlewareTestLogger(t)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(Logger())
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d", w.Code)
|
||||
}
|
||||
if len(sink.list()) != 0 {
|
||||
t.Fatalf("health endpoint should not write access log")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_AccessLogStillIndexedWhenLevelWarn(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
sink := initMiddlewareTestLoggerWithLevel(t, "warn")
|
||||
|
||||
r := gin.New()
|
||||
r.Use(RequestLogger())
|
||||
r.Use(Logger())
|
||||
r.GET("/api/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusCreated)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Fatalf("status=%d", w.Code)
|
||||
}
|
||||
|
||||
events := sink.list()
|
||||
if len(events) == 0 {
|
||||
t.Fatalf("expected access log event to be indexed when level=warn")
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, event := range events {
|
||||
if event == nil || event.Message != "http request completed" {
|
||||
continue
|
||||
}
|
||||
found = true
|
||||
if event.Level != "info" {
|
||||
t.Fatalf("event level=%q, want info", event.Level)
|
||||
}
|
||||
if event.Component != "http.access" && event.Fields["component"] != "http.access" {
|
||||
t.Fatalf("event component mismatch: component=%q fields=%v", event.Component, event.Fields["component"])
|
||||
}
|
||||
if _, ok := event.Fields["status_code"]; !ok {
|
||||
t.Fatalf("status_code field missing: %+v", event.Fields)
|
||||
}
|
||||
if _, ok := event.Fields["request_id"]; !ok {
|
||||
t.Fatalf("request_id field missing: %+v", event.Fields)
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("access log event not found")
|
||||
}
|
||||
}
|
||||
45
backend/internal/server/middleware/request_logger.go
Normal file
45
backend/internal/server/middleware/request_logger.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const requestIDHeader = "X-Request-ID"
|
||||
|
||||
// RequestLogger 在请求入口注入 request-scoped logger。
|
||||
func RequestLogger() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request == nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
requestID := strings.TrimSpace(c.GetHeader(requestIDHeader))
|
||||
if requestID == "" {
|
||||
requestID = uuid.NewString()
|
||||
}
|
||||
c.Header(requestIDHeader, requestID)
|
||||
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.RequestID, requestID)
|
||||
clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string)
|
||||
|
||||
requestLogger := logger.With(
|
||||
zap.String("component", "http"),
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("client_request_id", strings.TrimSpace(clientRequestID)),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("method", c.Request.Method),
|
||||
)
|
||||
|
||||
ctx = logger.IntoContext(ctx, requestLogger)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -29,6 +29,7 @@ func SetupRouter(
|
||||
redisClient *redis.Client,
|
||||
) *gin.Engine {
|
||||
// 应用中间件
|
||||
r.Use(middleware2.RequestLogger())
|
||||
r.Use(middleware2.Logger())
|
||||
r.Use(middleware2.CORS(cfg.CORS))
|
||||
r.Use(middleware2.SecurityHeaders(cfg.Security.CSP))
|
||||
|
||||
@@ -101,6 +101,9 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
{
|
||||
runtime.GET("/alert", h.Admin.Ops.GetAlertRuntimeSettings)
|
||||
runtime.PUT("/alert", h.Admin.Ops.UpdateAlertRuntimeSettings)
|
||||
runtime.GET("/logging", h.Admin.Ops.GetRuntimeLogConfig)
|
||||
runtime.PUT("/logging", h.Admin.Ops.UpdateRuntimeLogConfig)
|
||||
runtime.POST("/logging/reset", h.Admin.Ops.ResetRuntimeLogConfig)
|
||||
}
|
||||
|
||||
// Advanced settings (DB-backed)
|
||||
@@ -144,6 +147,11 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
// Request drilldown (success + error)
|
||||
ops.GET("/requests", h.Admin.Ops.ListRequestDetails)
|
||||
|
||||
// Indexed system logs
|
||||
ops.GET("/system-logs", h.Admin.Ops.ListSystemLogs)
|
||||
ops.POST("/system-logs/cleanup", h.Admin.Ops.CleanupSystemLogs)
|
||||
ops.GET("/system-logs/health", h.Admin.Ops.GetSystemLogIngestionHealth)
|
||||
|
||||
// Dashboard (vNext - raw path for MVP)
|
||||
ops.GET("/dashboard/overview", h.Admin.Ops.GetDashboardOverview)
|
||||
ops.GET("/dashboard/throughput-trend", h.Admin.Ops.GetDashboardThroughputTrend)
|
||||
|
||||
@@ -4,10 +4,10 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
@@ -361,7 +361,7 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
|
||||
for i := range users {
|
||||
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
|
||||
if err != nil {
|
||||
log.Printf("failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
|
||||
logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
|
||||
continue
|
||||
}
|
||||
users[i].GroupRates = rates
|
||||
@@ -379,7 +379,7 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error)
|
||||
if s.userGroupRateRepo != nil {
|
||||
rates, err := s.userGroupRateRepo.GetByUserID(ctx, id)
|
||||
if err != nil {
|
||||
log.Printf("failed to load user group rates: user_id=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", id, err)
|
||||
} else {
|
||||
user.GroupRates = rates
|
||||
}
|
||||
@@ -457,7 +457,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
// 同步用户专属分组倍率
|
||||
if input.GroupRates != nil && s.userGroupRateRepo != nil {
|
||||
if err := s.userGroupRateRepo.SyncUserGroupRates(ctx, user.ID, input.GroupRates); err != nil {
|
||||
log.Printf("failed to sync user group rates: user_id=%d err=%v", user.ID, err)
|
||||
logger.LegacyPrintf("service.admin", "failed to sync user group rates: user_id=%d err=%v", user.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -471,7 +471,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
if concurrencyDiff != 0 {
|
||||
code, err := GenerateRedeemCode()
|
||||
if err != nil {
|
||||
log.Printf("failed to generate adjustment redeem code: %v", err)
|
||||
logger.LegacyPrintf("service.admin", "failed to generate adjustment redeem code: %v", err)
|
||||
return user, nil
|
||||
}
|
||||
adjustmentRecord := &RedeemCode{
|
||||
@@ -484,7 +484,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
now := time.Now()
|
||||
adjustmentRecord.UsedAt = &now
|
||||
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
||||
log.Printf("failed to create concurrency adjustment redeem code: %v", err)
|
||||
logger.LegacyPrintf("service.admin", "failed to create concurrency adjustment redeem code: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -501,7 +501,7 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
|
||||
return errors.New("cannot delete admin user")
|
||||
}
|
||||
if err := s.userRepo.Delete(ctx, id); err != nil {
|
||||
log.Printf("delete user failed: user_id=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("service.admin", "delete user failed: user_id=%d err=%v", id, err)
|
||||
return err
|
||||
}
|
||||
if s.authCacheInvalidator != nil {
|
||||
@@ -544,7 +544,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, userID); err != nil {
|
||||
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
|
||||
logger.LegacyPrintf("service.admin", "invalidate user balance cache failed: user_id=%d err=%v", userID, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -552,7 +552,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
if balanceDiff != 0 {
|
||||
code, err := GenerateRedeemCode()
|
||||
if err != nil {
|
||||
log.Printf("failed to generate adjustment redeem code: %v", err)
|
||||
logger.LegacyPrintf("service.admin", "failed to generate adjustment redeem code: %v", err)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
@@ -568,7 +568,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
adjustmentRecord.UsedAt = &now
|
||||
|
||||
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
||||
log.Printf("failed to create balance adjustment redeem code: %v", err)
|
||||
logger.LegacyPrintf("service.admin", "failed to create balance adjustment redeem code: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1026,7 +1026,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
||||
defer cancel()
|
||||
for _, userID := range affectedUserIDs {
|
||||
if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID); err != nil {
|
||||
log.Printf("invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err)
|
||||
logger.LegacyPrintf("service.admin", "invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -1144,7 +1144,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
}
|
||||
if err := s.soraAccountRepo.Upsert(ctx, account.ID, soraUpdates); err != nil {
|
||||
// 只记录警告日志,不阻塞账号创建
|
||||
log.Printf("[AdminService] 创建 sora_accounts 记录失败: account_id=%d err=%v", account.ID, err)
|
||||
logger.LegacyPrintf("service.admin", "[AdminService] 创建 sora_accounts 记录失败: account_id=%d err=%v", account.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1779,7 +1779,7 @@ func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []Pro
|
||||
|
||||
latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, ids)
|
||||
if err != nil {
|
||||
log.Printf("Warning: load proxy latency cache failed: %v", err)
|
||||
logger.LegacyPrintf("service.admin", "Warning: load proxy latency cache failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1808,7 +1808,7 @@ func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64,
|
||||
return
|
||||
}
|
||||
if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, info); err != nil {
|
||||
log.Printf("Warning: store proxy latency cache failed: %v", err)
|
||||
logger.LegacyPrintf("service.admin", "Warning: store proxy latency cache failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"log/slog"
|
||||
mathrand "math/rand"
|
||||
"net"
|
||||
@@ -21,6 +20,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -154,7 +154,7 @@ type smartRetryResult struct {
|
||||
func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParams, resp *http.Response, respBody []byte, baseURL string, urlIdx int, availableURLs []string) *smartRetryResult {
|
||||
// "Resource has been exhausted" 是 URL 级别限流,切换 URL(仅 429)
|
||||
if resp.StatusCode == http.StatusTooManyRequests && isURLLevelRateLimit(respBody) && urlIdx < len(availableURLs)-1 {
|
||||
log.Printf("%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
|
||||
return &smartRetryResult{action: smartRetryActionContinueURL}
|
||||
}
|
||||
|
||||
@@ -174,13 +174,13 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
||||
if rateLimitDuration <= 0 {
|
||||
rateLimitDuration = antigravityDefaultRateLimitDuration
|
||||
}
|
||||
log.Printf("%s status=%d oauth_long_delay model=%s account=%d upstream_retry_delay=%v body=%s (model rate limit, switch account)",
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d oauth_long_delay model=%s account=%d upstream_retry_delay=%v body=%s (model rate limit, switch account)",
|
||||
p.prefix, resp.StatusCode, modelName, p.account.ID, rateLimitDuration, truncateForLog(respBody, 200))
|
||||
|
||||
resetAt := time.Now().Add(rateLimitDuration)
|
||||
if !setModelRateLimitByModelName(p.ctx, p.accountRepo, p.account.ID, modelName, p.prefix, resp.StatusCode, resetAt, false) {
|
||||
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
|
||||
log.Printf("%s status=%d rate_limited account=%d (no model mapping)", p.prefix, resp.StatusCode, p.account.ID)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d rate_limited account=%d (no model mapping)", p.prefix, resp.StatusCode, p.account.ID)
|
||||
} else {
|
||||
s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt)
|
||||
}
|
||||
@@ -202,12 +202,12 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
||||
var lastRetryBody []byte
|
||||
|
||||
for attempt := 1; attempt <= antigravitySmartRetryMaxAttempts; attempt++ {
|
||||
log.Printf("%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d",
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d",
|
||||
p.prefix, resp.StatusCode, attempt, antigravitySmartRetryMaxAttempts, waitDuration, modelName, p.account.ID)
|
||||
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
log.Printf("%s status=context_canceled_during_smart_retry", p.prefix)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_smart_retry", p.prefix)
|
||||
return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()}
|
||||
case <-time.After(waitDuration):
|
||||
}
|
||||
@@ -215,7 +215,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
||||
// 智能重试:创建新请求
|
||||
retryReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body)
|
||||
if err != nil {
|
||||
log.Printf("%s status=smart_retry_request_build_failed error=%v", p.prefix, err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=smart_retry_request_build_failed error=%v", p.prefix, err)
|
||||
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
|
||||
return &smartRetryResult{
|
||||
action: smartRetryActionBreakWithResp,
|
||||
@@ -229,13 +229,13 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
||||
|
||||
retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency)
|
||||
if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable {
|
||||
log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, antigravitySmartRetryMaxAttempts)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, antigravitySmartRetryMaxAttempts)
|
||||
return &smartRetryResult{action: smartRetryActionBreakWithResp, resp: retryResp}
|
||||
}
|
||||
|
||||
// 网络错误时,继续重试
|
||||
if retryErr != nil || retryResp == nil {
|
||||
log.Printf("%s status=smart_retry_network_error attempt=%d/%d error=%v", p.prefix, attempt, antigravitySmartRetryMaxAttempts, retryErr)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=smart_retry_network_error attempt=%d/%d error=%v", p.prefix, attempt, antigravitySmartRetryMaxAttempts, retryErr)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -271,7 +271,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
||||
// 单账号 503 退避重试模式:智能重试耗尽后不设限流、不切换账号,
|
||||
// 直接返回 503 让 Handler 层的单账号退避循环做最终处理。
|
||||
if resp.StatusCode == http.StatusServiceUnavailable && isSingleAccountRetry(p.ctx) {
|
||||
log.Printf("%s status=%d smart_retry_exhausted_single_account attempts=%d model=%s account=%d body=%s (return 503 directly)",
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d smart_retry_exhausted_single_account attempts=%d model=%s account=%d body=%s (return 503 directly)",
|
||||
p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID, truncateForLog(retryBody, 200))
|
||||
return &smartRetryResult{
|
||||
action: smartRetryActionBreakWithResp,
|
||||
@@ -283,15 +283,15 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d upstream_retry_delay=%v body=%s (switch account)",
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d upstream_retry_delay=%v body=%s (switch account)",
|
||||
p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID, rateLimitDuration, truncateForLog(retryBody, 200))
|
||||
|
||||
resetAt := time.Now().Add(rateLimitDuration)
|
||||
if p.accountRepo != nil && modelName != "" {
|
||||
if err := p.accountRepo.SetModelRateLimit(p.ctx, p.account.ID, modelName, resetAt); err != nil {
|
||||
log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", p.prefix, resp.StatusCode, modelName, err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limit_failed model=%s error=%v", p.prefix, resp.StatusCode, modelName, err)
|
||||
} else {
|
||||
log.Printf("%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v",
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v",
|
||||
p.prefix, resp.StatusCode, modelName, p.account.ID, rateLimitDuration)
|
||||
s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt)
|
||||
}
|
||||
@@ -346,7 +346,7 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
|
||||
waitDuration = antigravitySmartRetryMinWait
|
||||
}
|
||||
|
||||
log.Printf("%s status=%d single_account_503_retry_in_place model=%s account=%d upstream_retry_delay=%v (retrying in-place instead of rate-limiting)",
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d single_account_503_retry_in_place model=%s account=%d upstream_retry_delay=%v (retrying in-place instead of rate-limiting)",
|
||||
p.prefix, resp.StatusCode, modelName, p.account.ID, waitDuration)
|
||||
|
||||
var lastRetryResp *http.Response
|
||||
@@ -358,19 +358,19 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
|
||||
if totalWaited+waitDuration > antigravitySingleAccountSmartRetryTotalMaxWait {
|
||||
remaining := antigravitySingleAccountSmartRetryTotalMaxWait - totalWaited
|
||||
if remaining <= 0 {
|
||||
log.Printf("%s single_account_503_retry: total_wait_exceeded total=%v max=%v, giving up",
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s single_account_503_retry: total_wait_exceeded total=%v max=%v, giving up",
|
||||
p.prefix, totalWaited, antigravitySingleAccountSmartRetryTotalMaxWait)
|
||||
break
|
||||
}
|
||||
waitDuration = remaining
|
||||
}
|
||||
|
||||
log.Printf("%s status=%d single_account_503_retry attempt=%d/%d delay=%v total_waited=%v model=%s account=%d",
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d single_account_503_retry attempt=%d/%d delay=%v total_waited=%v model=%s account=%d",
|
||||
p.prefix, resp.StatusCode, attempt, antigravitySingleAccountSmartRetryMaxAttempts, waitDuration, totalWaited, modelName, p.account.ID)
|
||||
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
log.Printf("%s status=context_canceled_during_single_account_retry", p.prefix)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_single_account_retry", p.prefix)
|
||||
return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()}
|
||||
case <-time.After(waitDuration):
|
||||
}
|
||||
@@ -379,13 +379,13 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
|
||||
// 创建新请求
|
||||
retryReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body)
|
||||
if err != nil {
|
||||
log.Printf("%s single_account_503_retry: request_build_failed error=%v", p.prefix, err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s single_account_503_retry: request_build_failed error=%v", p.prefix, err)
|
||||
break
|
||||
}
|
||||
|
||||
retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency)
|
||||
if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable {
|
||||
log.Printf("%s status=%d single_account_503_retry_success attempt=%d/%d total_waited=%v",
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d single_account_503_retry_success attempt=%d/%d total_waited=%v",
|
||||
p.prefix, retryResp.StatusCode, attempt, antigravitySingleAccountSmartRetryMaxAttempts, totalWaited)
|
||||
// 关闭之前的响应
|
||||
if lastRetryResp != nil {
|
||||
@@ -396,7 +396,7 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
|
||||
|
||||
// 网络错误时继续重试
|
||||
if retryErr != nil || retryResp == nil {
|
||||
log.Printf("%s single_account_503_retry: network_error attempt=%d/%d error=%v",
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s single_account_503_retry: network_error attempt=%d/%d error=%v",
|
||||
p.prefix, attempt, antigravitySingleAccountSmartRetryMaxAttempts, retryErr)
|
||||
continue
|
||||
}
|
||||
@@ -430,7 +430,7 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
|
||||
if retryBody == nil {
|
||||
retryBody = respBody
|
||||
}
|
||||
log.Printf("%s status=%d single_account_503_retry_exhausted attempts=%d total_waited=%v model=%s account=%d body=%s (return 503 directly)",
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d single_account_503_retry_exhausted attempts=%d total_waited=%v model=%s account=%d body=%s (return 503 directly)",
|
||||
p.prefix, resp.StatusCode, antigravitySingleAccountSmartRetryMaxAttempts, totalWaited, modelName, p.account.ID, truncateForLog(retryBody, 200))
|
||||
|
||||
return &smartRetryResult{
|
||||
@@ -453,10 +453,10 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP
|
||||
// 如果上游确实还不可用,handleSmartRetry → handleSingleAccountRetryInPlace
|
||||
// 会在 Service 层原地等待+重试,不需要在预检查这里等。
|
||||
if isSingleAccountRetry(p.ctx) {
|
||||
log.Printf("%s pre_check: single_account_retry skipping rate_limit remaining=%v model=%s account=%d (will retry in-place if 503)",
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: single_account_retry skipping rate_limit remaining=%v model=%s account=%d (will retry in-place if 503)",
|
||||
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
|
||||
} else {
|
||||
log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d",
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: rate_limit_switch remaining=%v model=%s account=%d",
|
||||
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
|
||||
return nil, &AntigravityAccountSwitchError{
|
||||
OriginalAccountID: p.account.ID,
|
||||
@@ -492,7 +492,7 @@ urlFallbackLoop:
|
||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
log.Printf("%s status=context_canceled error=%v", p.prefix, p.ctx.Err())
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled error=%v", p.prefix, p.ctx.Err())
|
||||
return nil, p.ctx.Err()
|
||||
default:
|
||||
}
|
||||
@@ -522,18 +522,18 @@ urlFallbackLoop:
|
||||
Message: safeErr,
|
||||
})
|
||||
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
|
||||
continue urlFallbackLoop
|
||||
}
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err)
|
||||
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_backoff", p.prefix)
|
||||
return nil, p.ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
log.Printf("%s status=request_failed retries_exhausted error=%v", p.prefix, err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=request_failed retries_exhausted error=%v", p.prefix, err)
|
||||
setOpsUpstreamError(p.c, 0, safeErr, "")
|
||||
return nil, fmt.Errorf("upstream request failed after retries: %w", err)
|
||||
}
|
||||
@@ -590,9 +590,9 @@ urlFallbackLoop:
|
||||
Message: upstreamMsg,
|
||||
Detail: getUpstreamDetail(respBody),
|
||||
})
|
||||
log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200))
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200))
|
||||
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_backoff", p.prefix)
|
||||
return nil, p.ctx.Err()
|
||||
}
|
||||
continue
|
||||
@@ -600,7 +600,7 @@ urlFallbackLoop:
|
||||
|
||||
// 重试用尽,标记账户限流
|
||||
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
|
||||
log.Printf("%s status=%d rate_limited base_url=%s body=%s", p.prefix, resp.StatusCode, baseURL, truncateForLog(respBody, 200))
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d rate_limited base_url=%s body=%s", p.prefix, resp.StatusCode, baseURL, truncateForLog(respBody, 200))
|
||||
resp = &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header.Clone(),
|
||||
@@ -624,9 +624,9 @@ urlFallbackLoop:
|
||||
Message: upstreamMsg,
|
||||
Detail: getUpstreamDetail(respBody),
|
||||
})
|
||||
log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
|
||||
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_backoff", p.prefix)
|
||||
return nil, p.ctx.Err()
|
||||
}
|
||||
continue
|
||||
@@ -924,14 +924,14 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
|
||||
}
|
||||
|
||||
// 调试日志:Test 请求信息
|
||||
log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String())
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String())
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("请求失败: %w", err)
|
||||
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
log.Printf("[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
return nil, lastErr
|
||||
@@ -946,7 +946,7 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
|
||||
|
||||
// 检查是否需要 URL 降级
|
||||
if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||||
log.Printf("[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1328,7 +1328,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("Antigravity account %d: detected signature-related 400, retrying once (%s)", account.ID, stage.name)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: detected signature-related 400, retrying once (%s)", account.ID, stage.name)
|
||||
|
||||
retryGeminiBody, txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, s.getClaudeTransformOptions(ctx))
|
||||
if txErr != nil {
|
||||
@@ -1361,7 +1361,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
Kind: "signature_retry_request_error",
|
||||
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
|
||||
})
|
||||
log.Printf("Antigravity account %d: signature retry request failed (%s): %v", account.ID, stage.name, retryErr)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: signature retry request failed (%s): %v", account.ID, stage.name, retryErr)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1380,7 +1380,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
if retryResp.Request != nil && retryResp.Request.URL != nil {
|
||||
retryBaseURL = retryResp.Request.URL.Scheme + "://" + retryResp.Request.URL.Host
|
||||
}
|
||||
log.Printf("%s status=429 rate_limited base_url=%s retry_stage=%s body=%s", prefix, retryBaseURL, stage.name, truncateForLog(retryBody, 200))
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 rate_limited base_url=%s retry_stage=%s body=%s", prefix, retryBaseURL, stage.name, truncateForLog(retryBody, 200))
|
||||
}
|
||||
kind := "signature_retry"
|
||||
if strings.TrimSpace(stage.name) != "" {
|
||||
@@ -1433,7 +1433,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
upstreamDetail := s.getUpstreamErrorDetail(respBody)
|
||||
logBody, maxBytes := s.getLogConfig()
|
||||
if logBody {
|
||||
log.Printf("%s status=400 prompt_too_long=true upstream_message=%q request_id=%s body=%s", prefix, upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, maxBytes))
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=400 prompt_too_long=true upstream_message=%q request_id=%s body=%s", prefix, upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, maxBytes))
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
@@ -1487,7 +1487,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
// 客户端要求流式,直接透传转换
|
||||
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
|
||||
if err != nil {
|
||||
log.Printf("%s status=stream_error error=%v", prefix, err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err)
|
||||
return nil, err
|
||||
}
|
||||
usage = streamRes.usage
|
||||
@@ -1497,7 +1497,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
// 客户端要求非流式,收集流式响应后转换返回
|
||||
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
|
||||
if err != nil {
|
||||
log.Printf("%s status=stream_collect_error error=%v", prefix, err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err)
|
||||
return nil, err
|
||||
}
|
||||
usage = streamRes.usage
|
||||
@@ -1889,9 +1889,9 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
// 清理 Schema
|
||||
if cleanedBody, err := cleanGeminiRequest(injectedBody); err == nil {
|
||||
injectedBody = cleanedBody
|
||||
log.Printf("[Antigravity] Cleaned request schema in forwarded request for account %s", account.Name)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Cleaned request schema in forwarded request for account %s", account.Name)
|
||||
} else {
|
||||
log.Printf("[Antigravity] Failed to clean schema: %v", err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Failed to clean schema: %v", err)
|
||||
}
|
||||
|
||||
// 包装请求
|
||||
@@ -1953,7 +1953,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
isModelNotFoundError(resp.StatusCode, respBody) {
|
||||
fallbackModel := s.settingService.GetFallbackModel(ctx, PlatformAntigravity)
|
||||
if fallbackModel != "" && fallbackModel != mappedModel {
|
||||
log.Printf("[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
|
||||
|
||||
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody)
|
||||
if err == nil {
|
||||
@@ -2020,7 +2020,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
log.Printf("[antigravity-Forward] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(unwrappedForOps, 500))
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(unwrappedForOps, 500))
|
||||
c.Data(resp.StatusCode, contentType, unwrappedForOps)
|
||||
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
@@ -2039,7 +2039,7 @@ handleSuccess:
|
||||
// 客户端要求流式,直接透传
|
||||
streamRes, err := s.handleGeminiStreamingResponse(c, resp, startTime)
|
||||
if err != nil {
|
||||
log.Printf("%s status=stream_error error=%v", prefix, err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err)
|
||||
return nil, err
|
||||
}
|
||||
usage = streamRes.usage
|
||||
@@ -2049,7 +2049,7 @@ handleSuccess:
|
||||
// 客户端要求非流式,收集流式响应后返回
|
||||
streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime)
|
||||
if err != nil {
|
||||
log.Printf("%s status=stream_collect_error error=%v", prefix, err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err)
|
||||
return nil, err
|
||||
}
|
||||
usage = streamRes.usage
|
||||
@@ -2128,13 +2128,13 @@ func setModelRateLimitByModelName(ctx context.Context, repo AccountRepository, a
|
||||
}
|
||||
// 直接使用官方模型 ID 作为 key,不再转换为 scope
|
||||
if err := repo.SetModelRateLimit(ctx, accountID, modelName, resetAt); err != nil {
|
||||
log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err)
|
||||
return false
|
||||
}
|
||||
if afterSmartRetry {
|
||||
log.Printf("%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second))
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second))
|
||||
} else {
|
||||
log.Printf("%s status=%d model_rate_limited model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second))
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limited model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second))
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -2241,7 +2241,7 @@ func parseAntigravitySmartRetryInfo(body []byte) *antigravitySmartRetryInfo {
|
||||
// 例如: "0.5s", "10s", "4m50s", "1h30m", "200ms" 等
|
||||
dur, err := time.ParseDuration(delay)
|
||||
if err != nil {
|
||||
log.Printf("[Antigravity] failed to parse retryDelay: %s error=%v", delay, err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] failed to parse retryDelay: %s error=%v", delay, err)
|
||||
continue
|
||||
}
|
||||
retryDelay = dur
|
||||
@@ -2342,7 +2342,7 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit
|
||||
|
||||
// < antigravityRateLimitThreshold: 等待后重试
|
||||
if info.RetryDelay < antigravityRateLimitThreshold {
|
||||
log.Printf("%s status=%d model_rate_limit_wait model=%s wait=%v",
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limit_wait model=%s wait=%v",
|
||||
p.prefix, p.statusCode, info.ModelName, info.RetryDelay)
|
||||
return &handleModelRateLimitResult{
|
||||
Handled: true,
|
||||
@@ -2367,12 +2367,12 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit
|
||||
// setModelRateLimitAndClearSession 设置模型限流并清除粘性会话
|
||||
func (s *AntigravityGatewayService) setModelRateLimitAndClearSession(p *handleModelRateLimitParams, info *antigravitySmartRetryInfo) {
|
||||
resetAt := time.Now().Add(info.RetryDelay)
|
||||
log.Printf("%s status=%d model_rate_limited model=%s account=%d reset_in=%v",
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limited model=%s account=%d reset_in=%v",
|
||||
p.prefix, p.statusCode, info.ModelName, p.account.ID, info.RetryDelay)
|
||||
|
||||
// 设置模型限流状态(数据库)
|
||||
if err := s.accountRepo.SetModelRateLimit(p.ctx, p.account.ID, info.ModelName, resetAt); err != nil {
|
||||
log.Printf("%s model_rate_limit_failed model=%s error=%v", p.prefix, info.ModelName, err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s model_rate_limit_failed model=%s error=%v", p.prefix, info.ModelName, err)
|
||||
}
|
||||
|
||||
// 立即更新 Redis 快照中账号的限流状态,避免并发请求重复选中
|
||||
@@ -2408,7 +2408,7 @@ func (s *AntigravityGatewayService) updateAccountModelRateLimitInCache(ctx conte
|
||||
|
||||
// 更新 Redis 快照
|
||||
if err := s.schedulerSnapshot.UpdateAccountInCache(ctx, account); err != nil {
|
||||
log.Printf("[antigravity-Forward] cache_update_failed account=%d model=%s err=%v", account.ID, modelKey, err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] cache_update_failed account=%d model=%s err=%v", account.ID, modelKey, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2447,7 +2447,7 @@ func (s *AntigravityGatewayService) handleUpstreamError(
|
||||
// 429:尝试解析模型级限流,解析失败时兜底为账号级限流
|
||||
if statusCode == 429 {
|
||||
if logBody, maxBytes := s.getLogConfig(); logBody {
|
||||
log.Printf("[Antigravity-Debug] 429 response body: %s", truncateString(string(body), maxBytes))
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity-Debug] 429 response body: %s", truncateString(string(body), maxBytes))
|
||||
}
|
||||
|
||||
resetAt := ParseGeminiRateLimitResetTime(body)
|
||||
@@ -2458,9 +2458,9 @@ func (s *AntigravityGatewayService) handleUpstreamError(
|
||||
if modelKey != "" {
|
||||
ra := s.resolveResetTime(resetAt, defaultDur)
|
||||
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, ra); err != nil {
|
||||
log.Printf("%s status=429 model_rate_limit_set_failed model=%s error=%v", prefix, modelKey, err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 model_rate_limit_set_failed model=%s error=%v", prefix, modelKey, err)
|
||||
} else {
|
||||
log.Printf("%s status=429 model_rate_limited model=%s account=%d reset_at=%v reset_in=%v",
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 model_rate_limited model=%s account=%d reset_at=%v reset_in=%v",
|
||||
prefix, modelKey, account.ID, ra.Format("15:04:05"), time.Until(ra).Truncate(time.Second))
|
||||
s.updateAccountModelRateLimitInCache(ctx, account, modelKey, ra)
|
||||
}
|
||||
@@ -2469,10 +2469,10 @@ func (s *AntigravityGatewayService) handleUpstreamError(
|
||||
|
||||
// 无法解析模型 key,兜底为账号级限流
|
||||
ra := s.resolveResetTime(resetAt, defaultDur)
|
||||
log.Printf("%s status=429 rate_limited account=%d reset_at=%v reset_in=%v (fallback)",
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 rate_limited account=%d reset_at=%v reset_in=%v (fallback)",
|
||||
prefix, account.ID, ra.Format("15:04:05"), time.Until(ra).Truncate(time.Second))
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil {
|
||||
log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -2482,7 +2482,7 @@ func (s *AntigravityGatewayService) handleUpstreamError(
|
||||
}
|
||||
shouldDisable := s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
|
||||
if shouldDisable {
|
||||
log.Printf("%s status=%d marked_error", prefix, statusCode)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d marked_error", prefix, statusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -2556,18 +2556,18 @@ func (cw *antigravityClientWriter) Disconnected() bool { return cw.disconnected
|
||||
|
||||
func (cw *antigravityClientWriter) markDisconnected() {
|
||||
cw.disconnected = true
|
||||
log.Printf("Client disconnected during streaming (%s), continuing to drain upstream for billing", cw.prefix)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Client disconnected during streaming (%s), continuing to drain upstream for billing", cw.prefix)
|
||||
}
|
||||
|
||||
// handleStreamReadError 处理上游读取错误的通用逻辑。
|
||||
// 返回 (clientDisconnect, handled):handled=true 表示错误已处理,调用方应返回已收集的 usage。
|
||||
func handleStreamReadError(err error, clientDisconnected bool, prefix string) (disconnect bool, handled bool) {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
log.Printf("Context canceled during streaming (%s), returning collected usage", prefix)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Context canceled during streaming (%s), returning collected usage", prefix)
|
||||
return true, true
|
||||
}
|
||||
if clientDisconnected {
|
||||
log.Printf("Upstream read error after client disconnect (%s): %v, returning collected usage", prefix, err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Upstream read error after client disconnect (%s): %v, returning collected usage", prefix, err)
|
||||
return true, true
|
||||
}
|
||||
return false, false
|
||||
@@ -2672,7 +2672,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil
|
||||
}
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
|
||||
}
|
||||
@@ -2705,10 +2705,10 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
|
||||
if cand, ok := candidates[0].(map[string]any); ok {
|
||||
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
|
||||
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward stream")
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] MALFORMED_FUNCTION_CALL detected in forward stream")
|
||||
if content, ok := cand["content"]; ok {
|
||||
if b, err := json.Marshal(content); err == nil {
|
||||
log.Printf("[Antigravity] Malformed content: %s", string(b))
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Malformed content: %s", string(b))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2733,10 +2733,10 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
continue
|
||||
}
|
||||
if cw.Disconnected() {
|
||||
log.Printf("Upstream timeout after client disconnect (antigravity gemini), returning collected usage")
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Upstream timeout after client disconnect (antigravity gemini), returning collected usage")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
log.Printf("Stream data interval timeout (antigravity)")
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
|
||||
sendErrorEvent("stream_timeout")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
}
|
||||
@@ -2819,7 +2819,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
|
||||
}
|
||||
if ev.err != nil {
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
log.Printf("SSE line too long (antigravity non-stream): max_size=%d error=%v", maxLineSize, ev.err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity non-stream): max_size=%d error=%v", maxLineSize, ev.err)
|
||||
}
|
||||
return nil, ev.err
|
||||
}
|
||||
@@ -2864,10 +2864,10 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
|
||||
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
|
||||
if cand, ok := candidates[0].(map[string]any); ok {
|
||||
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
|
||||
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward non-stream collect")
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] MALFORMED_FUNCTION_CALL detected in forward non-stream collect")
|
||||
if content, ok := cand["content"]; ok {
|
||||
if b, err := json.Marshal(content); err == nil {
|
||||
log.Printf("[Antigravity] Malformed content: %s", string(b))
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Malformed content: %s", string(b))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2894,7 +2894,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
log.Printf("Stream data interval timeout (antigravity non-stream)")
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity non-stream)")
|
||||
return nil, fmt.Errorf("stream data interval timeout")
|
||||
}
|
||||
}
|
||||
@@ -2905,7 +2905,7 @@ returnResponse:
|
||||
|
||||
// 处理空响应情况
|
||||
if last == nil && lastWithParts == nil {
|
||||
log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received")
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] warning: empty stream response, no valid chunks received")
|
||||
}
|
||||
|
||||
// 如果收集到了图片 parts,需要合并到最终响应中
|
||||
@@ -3120,7 +3120,7 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
|
||||
|
||||
// 记录上游错误详情便于排障(可选:由配置控制;不回显到客户端)
|
||||
if logBody {
|
||||
log.Printf("[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes))
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes))
|
||||
}
|
||||
|
||||
var statusCode int
|
||||
@@ -3262,7 +3262,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
|
||||
}
|
||||
if ev.err != nil {
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
log.Printf("SSE line too long (antigravity claude non-stream): max_size=%d error=%v", maxLineSize, ev.err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity claude non-stream): max_size=%d error=%v", maxLineSize, ev.err)
|
||||
}
|
||||
return nil, ev.err
|
||||
}
|
||||
@@ -3311,7 +3311,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
log.Printf("Stream data interval timeout (antigravity claude non-stream)")
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity claude non-stream)")
|
||||
return nil, fmt.Errorf("stream data interval timeout")
|
||||
}
|
||||
}
|
||||
@@ -3322,7 +3322,7 @@ returnResponse:
|
||||
|
||||
// 处理空响应情况
|
||||
if last == nil && lastWithParts == nil {
|
||||
log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received")
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] warning: empty stream response, no valid chunks received")
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream")
|
||||
}
|
||||
|
||||
@@ -3340,7 +3340,7 @@ returnResponse:
|
||||
// 转换 Gemini 响应为 Claude 格式
|
||||
claudeResp, agUsage, err := antigravity.TransformGeminiToClaude(geminiBody, originalModel)
|
||||
if err != nil {
|
||||
log.Printf("[antigravity-Forward] transform_error error=%v body=%s", err, string(geminiBody))
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] transform_error error=%v body=%s", err, string(geminiBody))
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
|
||||
}
|
||||
|
||||
@@ -3475,7 +3475,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil
|
||||
}
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, ev.err
|
||||
}
|
||||
@@ -3499,10 +3499,10 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
continue
|
||||
}
|
||||
if cw.Disconnected() {
|
||||
log.Printf("Upstream timeout after client disconnect (antigravity claude), returning collected usage")
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Upstream timeout after client disconnect (antigravity claude), returning collected usage")
|
||||
return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
log.Printf("Stream data interval timeout (antigravity)")
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
|
||||
sendErrorEvent("stream_timeout")
|
||||
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
}
|
||||
@@ -3702,7 +3702,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
|
||||
// 发送请求
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
log.Printf("%s upstream request failed: %v", prefix, err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s upstream request failed: %v", prefix, err)
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
@@ -3760,7 +3760,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
|
||||
|
||||
// 构建计费结果
|
||||
duration := time.Since(startTime)
|
||||
log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds())
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=success duration_ms=%d", prefix, duration.Milliseconds())
|
||||
|
||||
return &ForwardResult{
|
||||
Model: billingModel,
|
||||
@@ -3846,7 +3846,7 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp
|
||||
if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity upstream"); handled {
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect}
|
||||
}
|
||||
log.Printf("Stream read error (antigravity upstream): %v", ev.err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Stream read error (antigravity upstream): %v", ev.err)
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
|
||||
}
|
||||
|
||||
@@ -3870,10 +3870,10 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp
|
||||
continue
|
||||
}
|
||||
if cw.Disconnected() {
|
||||
log.Printf("Upstream timeout after client disconnect (antigravity upstream), returning collected usage")
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Upstream timeout after client disconnect (antigravity upstream), returning collected usage")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}
|
||||
}
|
||||
log.Printf("Stream data interval timeout (antigravity upstream)")
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity upstream)")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,13 +7,13 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
@@ -118,12 +118,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
// 验证邀请码
|
||||
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Invalid invitation code: %s, error: %v", invitationCode, err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Invalid invitation code: %s, error: %v", invitationCode, err)
|
||||
return "", nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
// 检查类型和状态
|
||||
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
|
||||
log.Printf("[Auth] Invitation code invalid: type=%s, status=%s", redeemCode.Type, redeemCode.Status)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Invitation code invalid: type=%s, status=%s", redeemCode.Type, redeemCode.Status)
|
||||
return "", nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
invitationRedeemCode = redeemCode
|
||||
@@ -134,7 +134,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
|
||||
// 这是一个配置错误,不应该允许绕过验证
|
||||
if s.emailService == nil {
|
||||
log.Println("[Auth] Email verification enabled but email service not configured, rejecting registration")
|
||||
logger.LegacyPrintf("service.auth", "%s", "[Auth] Email verification enabled but email service not configured, rejecting registration")
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
if verifyCode == "" {
|
||||
@@ -149,7 +149,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Database error checking email exists: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error checking email exists: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
if existsEmail {
|
||||
@@ -185,7 +185,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
return "", nil, ErrEmailExists
|
||||
}
|
||||
log.Printf("[Auth] Database error creating user: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
@@ -193,14 +193,14 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
if invitationRedeemCode != nil {
|
||||
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
||||
// 邀请码标记失败不影响注册,只记录日志
|
||||
log.Printf("[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err)
|
||||
}
|
||||
}
|
||||
// 应用优惠码(如果提供且功能已启用)
|
||||
if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) {
|
||||
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
|
||||
// 优惠码应用失败不影响注册,只记录日志
|
||||
log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to apply promo code for user %d: %v", user.ID, err)
|
||||
} else {
|
||||
// 重新获取用户信息以获取更新后的余额
|
||||
if updatedUser, err := s.userRepo.GetByID(ctx, user.ID); err == nil {
|
||||
@@ -237,7 +237,7 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Database error checking email exists: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error checking email exists: %v", err)
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
if existsEmail {
|
||||
@@ -260,11 +260,11 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
|
||||
|
||||
// SendVerifyCodeAsync 异步发送邮箱验证码并返回倒计时
|
||||
func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
|
||||
log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] SendVerifyCodeAsync called for email: %s", email)
|
||||
|
||||
// 检查是否开放注册(默认关闭)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
log.Println("[Auth] Registration is disabled")
|
||||
logger.LegacyPrintf("service.auth", "%s", "[Auth] Registration is disabled")
|
||||
return nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
@@ -275,17 +275,17 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Database error checking email exists: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error checking email exists: %v", err)
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
if existsEmail {
|
||||
log.Printf("[Auth] Email already exists: %s", email)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Email already exists: %s", email)
|
||||
return nil, ErrEmailExists
|
||||
}
|
||||
|
||||
// 检查邮件队列服务是否配置
|
||||
if s.emailQueueService == nil {
|
||||
log.Println("[Auth] Email queue service not configured")
|
||||
logger.LegacyPrintf("service.auth", "%s", "[Auth] Email queue service not configured")
|
||||
return nil, errors.New("email queue service not configured")
|
||||
}
|
||||
|
||||
@@ -296,13 +296,13 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
|
||||
}
|
||||
|
||||
// 异步发送
|
||||
log.Printf("[Auth] Enqueueing verify code for: %s", email)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Enqueueing verify code for: %s", email)
|
||||
if err := s.emailQueueService.EnqueueVerifyCode(email, siteName); err != nil {
|
||||
log.Printf("[Auth] Failed to enqueue: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to enqueue: %v", err)
|
||||
return nil, fmt.Errorf("enqueue verify code: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[Auth] Verify code enqueued successfully for: %s", email)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Verify code enqueued successfully for: %s", email)
|
||||
return &SendVerifyCodeResult{
|
||||
Countdown: 60, // 60秒倒计时
|
||||
}, nil
|
||||
@@ -314,27 +314,27 @@ func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteI
|
||||
|
||||
if required {
|
||||
if s.settingService == nil {
|
||||
log.Println("[Auth] Turnstile required but settings service is not configured")
|
||||
logger.LegacyPrintf("service.auth", "%s", "[Auth] Turnstile required but settings service is not configured")
|
||||
return ErrTurnstileNotConfigured
|
||||
}
|
||||
enabled := s.settingService.IsTurnstileEnabled(ctx)
|
||||
secretConfigured := s.settingService.GetTurnstileSecretKey(ctx) != ""
|
||||
if !enabled || !secretConfigured {
|
||||
log.Printf("[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured)
|
||||
return ErrTurnstileNotConfigured
|
||||
}
|
||||
}
|
||||
|
||||
if s.turnstileService == nil {
|
||||
if required {
|
||||
log.Println("[Auth] Turnstile required but service not configured")
|
||||
logger.LegacyPrintf("service.auth", "%s", "[Auth] Turnstile required but service not configured")
|
||||
return ErrTurnstileNotConfigured
|
||||
}
|
||||
return nil // 服务未配置则跳过验证
|
||||
}
|
||||
|
||||
if !required && s.settingService != nil && s.settingService.IsTurnstileEnabled(ctx) && s.settingService.GetTurnstileSecretKey(ctx) == "" {
|
||||
log.Println("[Auth] Turnstile enabled but secret key not configured")
|
||||
logger.LegacyPrintf("service.auth", "%s", "[Auth] Turnstile enabled but secret key not configured")
|
||||
}
|
||||
|
||||
return s.turnstileService.VerifyToken(ctx, token, remoteIP)
|
||||
@@ -373,7 +373,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
|
||||
return "", nil, ErrInvalidCredentials
|
||||
}
|
||||
// 记录数据库错误但不暴露给用户
|
||||
log.Printf("[Auth] Database error during login: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error during login: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
@@ -426,7 +426,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
|
||||
randomPassword, err := randomHexString(32)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to generate random password for oauth signup: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
hashedPassword, err := s.HashPassword(randomPassword)
|
||||
@@ -457,18 +457,18 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
// 并发场景:GetByEmail 与 Create 之间用户被创建。
|
||||
user, err = s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Database error getting user after conflict: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Auth] Database error creating oauth user: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Auth] Database error during oauth login: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
}
|
||||
@@ -481,7 +481,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
if user.Username == "" && username != "" {
|
||||
user.Username = username
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
log.Printf("[Auth] Failed to update username after oauth login: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -523,7 +523,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
|
||||
randomPassword, err := randomHexString(32)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to generate random password for oauth signup: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
hashedPassword, err := s.HashPassword(randomPassword)
|
||||
@@ -552,18 +552,18 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
user, err = s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Database error getting user after conflict: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Auth] Database error creating oauth user: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Auth] Database error during oauth login: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
}
|
||||
@@ -575,7 +575,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
if user.Username == "" && username != "" {
|
||||
user.Username = username
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
log.Printf("[Auth] Failed to update username after oauth login: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -715,7 +715,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
return "", ErrInvalidToken
|
||||
}
|
||||
log.Printf("[Auth] Database error refreshing token: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error refreshing token: %v", err)
|
||||
return "", ErrServiceUnavailable
|
||||
}
|
||||
|
||||
@@ -756,16 +756,16 @@ func (s *AuthService) preparePasswordReset(ctx context.Context, email, frontendB
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
// Security: Log but don't reveal that user doesn't exist
|
||||
log.Printf("[Auth] Password reset requested for non-existent email: %s", email)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Password reset requested for non-existent email: %s", email)
|
||||
return "", "", false
|
||||
}
|
||||
log.Printf("[Auth] Database error checking email for password reset: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error checking email for password reset: %v", err)
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
// Check if user is active
|
||||
if !user.IsActive() {
|
||||
log.Printf("[Auth] Password reset requested for inactive user: %s", email)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Password reset requested for inactive user: %s", email)
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
@@ -797,11 +797,11 @@ func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendB
|
||||
}
|
||||
|
||||
if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
|
||||
log.Printf("[Auth] Failed to send password reset email to %s: %v", email, err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to send password reset email to %s: %v", email, err)
|
||||
return nil // Silent success to prevent enumeration
|
||||
}
|
||||
|
||||
log.Printf("[Auth] Password reset email sent to: %s", email)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Password reset email sent to: %s", email)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -821,11 +821,11 @@ func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, fron
|
||||
}
|
||||
|
||||
if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL); err != nil {
|
||||
log.Printf("[Auth] Failed to enqueue password reset email for %s: %v", email, err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to enqueue password reset email for %s: %v", email, err)
|
||||
return nil // Silent success to prevent enumeration
|
||||
}
|
||||
|
||||
log.Printf("[Auth] Password reset email enqueued for: %s", email)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Password reset email enqueued for: %s", email)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -852,7 +852,7 @@ func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPasswo
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
return ErrInvalidResetToken // Token was valid but user was deleted
|
||||
}
|
||||
log.Printf("[Auth] Database error getting user for password reset: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user for password reset: %v", err)
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
@@ -872,17 +872,17 @@ func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPasswo
|
||||
user.TokenVersion++ // Invalidate all existing tokens
|
||||
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
log.Printf("[Auth] Database error updating password for user %d: %v", user.ID, err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error updating password for user %d: %v", user.ID, err)
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// Also revoke all refresh tokens for this user
|
||||
if err := s.RevokeAllUserSessions(ctx, user.ID); err != nil {
|
||||
log.Printf("[Auth] Failed to revoke refresh tokens for user %d: %v", user.ID, err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to revoke refresh tokens for user %d: %v", user.ID, err)
|
||||
// Don't return error - password was already changed successfully
|
||||
}
|
||||
|
||||
log.Printf("[Auth] Password reset successful for user: %s", email)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Password reset successful for user: %s", email)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -961,13 +961,13 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
|
||||
|
||||
// 添加到用户Token集合
|
||||
if err := s.refreshTokenCache.AddToUserTokenSet(ctx, user.ID, tokenHash, ttl); err != nil {
|
||||
log.Printf("[Auth] Failed to add token to user set: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to add token to user set: %v", err)
|
||||
// 不影响主流程
|
||||
}
|
||||
|
||||
// 添加到家族Token集合
|
||||
if err := s.refreshTokenCache.AddToFamilyTokenSet(ctx, familyID, tokenHash, ttl); err != nil {
|
||||
log.Printf("[Auth] Failed to add token to family set: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to add token to family set: %v", err)
|
||||
// 不影响主流程
|
||||
}
|
||||
|
||||
@@ -994,10 +994,10 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrRefreshTokenNotFound) {
|
||||
// Token不存在,可能是已被使用(Token轮转)或已过期
|
||||
log.Printf("[Auth] Refresh token not found, possible reuse attack")
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Refresh token not found, possible reuse attack")
|
||||
return nil, ErrRefreshTokenInvalid
|
||||
}
|
||||
log.Printf("[Auth] Error getting refresh token: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Error getting refresh token: %v", err)
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
@@ -1016,7 +1016,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
|
||||
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
|
||||
return nil, ErrRefreshTokenInvalid
|
||||
}
|
||||
log.Printf("[Auth] Database error getting user for token refresh: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user for token refresh: %v", err)
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
@@ -1036,7 +1036,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
|
||||
|
||||
// Token轮转:立即使旧Token失效
|
||||
if err := s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash); err != nil {
|
||||
log.Printf("[Auth] Failed to delete old refresh token: %v", err)
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to delete old refresh token: %v", err)
|
||||
// 继续处理,不影响主流程
|
||||
}
|
||||
|
||||
|
||||
@@ -3,13 +3,13 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// 错误定义
|
||||
@@ -156,13 +156,13 @@ func (s *BillingCacheService) cacheWriteWorker() {
|
||||
case cacheWriteUpdateSubscriptionUsage:
|
||||
if s.cache != nil {
|
||||
if err := s.cache.UpdateSubscriptionUsage(ctx, task.userID, task.groupID, task.amount); err != nil {
|
||||
log.Printf("Warning: update subscription cache failed for user %d group %d: %v", task.userID, task.groupID, err)
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: update subscription cache failed for user %d group %d: %v", task.userID, task.groupID, err)
|
||||
}
|
||||
}
|
||||
case cacheWriteDeductBalance:
|
||||
if s.cache != nil {
|
||||
if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil {
|
||||
log.Printf("Warning: deduct balance cache failed for user %d: %v", task.userID, err)
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache failed for user %d: %v", task.userID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -216,7 +216,7 @@ func (s *BillingCacheService) logCacheWriteDrop(task cacheWriteTask, reason stri
|
||||
if dropped == 0 {
|
||||
return
|
||||
}
|
||||
log.Printf("Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)",
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)",
|
||||
reason,
|
||||
dropped,
|
||||
cacheWriteDropLogInterval,
|
||||
@@ -274,7 +274,7 @@ func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64,
|
||||
return
|
||||
}
|
||||
if err := s.cache.SetUserBalance(ctx, userID, balance); err != nil {
|
||||
log.Printf("Warning: set balance cache failed for user %d: %v", userID, err)
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: set balance cache failed for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -302,7 +302,7 @@ func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||||
defer cancel()
|
||||
if err := s.DeductBalanceCache(ctx, userID, amount); err != nil {
|
||||
log.Printf("Warning: deduct balance cache fallback failed for user %d: %v", userID, err)
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache fallback failed for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -312,7 +312,7 @@ func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID
|
||||
return nil
|
||||
}
|
||||
if err := s.cache.InvalidateUserBalance(ctx, userID); err != nil {
|
||||
log.Printf("Warning: invalidate balance cache failed for user %d: %v", userID, err)
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: invalidate balance cache failed for user %d: %v", userID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -396,7 +396,7 @@ func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID,
|
||||
return
|
||||
}
|
||||
if err := s.cache.SetSubscriptionCache(ctx, userID, groupID, s.convertToPortsData(data)); err != nil {
|
||||
log.Printf("Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -425,7 +425,7 @@ func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||||
defer cancel()
|
||||
if err := s.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD); err != nil {
|
||||
log.Printf("Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err)
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -435,7 +435,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
|
||||
return nil
|
||||
}
|
||||
if err := s.cache.InvalidateSubscriptionCache(ctx, userID, groupID); err != nil {
|
||||
log.Printf("Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -474,7 +474,7 @@ func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userI
|
||||
if s.circuitBreaker != nil {
|
||||
s.circuitBreaker.OnFailure(err)
|
||||
}
|
||||
log.Printf("ALERT: billing balance check failed for user %d: %v", userID, err)
|
||||
logger.LegacyPrintf("service.billing_cache", "ALERT: billing balance check failed for user %d: %v", userID, err)
|
||||
return ErrBillingServiceUnavailable.WithCause(err)
|
||||
}
|
||||
if s.circuitBreaker != nil {
|
||||
@@ -496,7 +496,7 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
|
||||
if s.circuitBreaker != nil {
|
||||
s.circuitBreaker.OnFailure(err)
|
||||
}
|
||||
log.Printf("ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err)
|
||||
logger.LegacyPrintf("service.billing_cache", "ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err)
|
||||
return ErrBillingServiceUnavailable.WithCause(err)
|
||||
}
|
||||
if s.circuitBreaker != nil {
|
||||
@@ -585,7 +585,7 @@ func (b *billingCircuitBreaker) Allow() bool {
|
||||
}
|
||||
b.state = billingCircuitHalfOpen
|
||||
b.halfOpenRemaining = b.halfOpenRequests
|
||||
log.Printf("ALERT: billing circuit breaker entering half-open state")
|
||||
logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker entering half-open state")
|
||||
fallthrough
|
||||
case billingCircuitHalfOpen:
|
||||
if b.halfOpenRemaining <= 0 {
|
||||
@@ -612,7 +612,7 @@ func (b *billingCircuitBreaker) OnFailure(err error) {
|
||||
b.state = billingCircuitOpen
|
||||
b.openedAt = time.Now()
|
||||
b.halfOpenRemaining = 0
|
||||
log.Printf("ALERT: billing circuit breaker opened after half-open failure: %v", err)
|
||||
logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker opened after half-open failure: %v", err)
|
||||
return
|
||||
default:
|
||||
b.failures++
|
||||
@@ -620,7 +620,7 @@ func (b *billingCircuitBreaker) OnFailure(err error) {
|
||||
b.state = billingCircuitOpen
|
||||
b.openedAt = time.Now()
|
||||
b.halfOpenRemaining = 0
|
||||
log.Printf("ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err)
|
||||
logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -641,9 +641,9 @@ func (b *billingCircuitBreaker) OnSuccess() {
|
||||
|
||||
// 只有状态真正发生变化时才记录日志
|
||||
if previousState != billingCircuitClosed {
|
||||
log.Printf("ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState))
|
||||
logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState))
|
||||
} else if previousFailures > 0 {
|
||||
log.Printf("INFO: billing circuit breaker failures reset from %d", previousFailures)
|
||||
logger.LegacyPrintf("service.billing_cache", "INFO: billing circuit breaker failures reset from %d", previousFailures)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,8 +5,9 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// ConcurrencyCache 定义并发控制的缓存接口
|
||||
@@ -124,7 +125,7 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.cache.ReleaseAccountSlot(bgCtx, accountID, requestID); err != nil {
|
||||
log.Printf("Warning: failed to release account slot for %d (req=%s): %v", accountID, requestID, err)
|
||||
logger.LegacyPrintf("service.concurrency", "Warning: failed to release account slot for %d (req=%s): %v", accountID, requestID, err)
|
||||
}
|
||||
},
|
||||
}, nil
|
||||
@@ -163,7 +164,7 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64,
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.cache.ReleaseUserSlot(bgCtx, userID, requestID); err != nil {
|
||||
log.Printf("Warning: failed to release user slot for %d (req=%s): %v", userID, requestID, err)
|
||||
logger.LegacyPrintf("service.concurrency", "Warning: failed to release user slot for %d (req=%s): %v", userID, requestID, err)
|
||||
}
|
||||
},
|
||||
}, nil
|
||||
@@ -191,7 +192,7 @@ func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int6
|
||||
result, err := s.cache.IncrementWaitCount(ctx, userID, maxWait)
|
||||
if err != nil {
|
||||
// On error, allow the request to proceed (fail open)
|
||||
log.Printf("Warning: increment wait count failed for user %d: %v", userID, err)
|
||||
logger.LegacyPrintf("service.concurrency", "Warning: increment wait count failed for user %d: %v", userID, err)
|
||||
return true, nil
|
||||
}
|
||||
return result, nil
|
||||
@@ -209,7 +210,7 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
|
||||
defer cancel()
|
||||
|
||||
if err := s.cache.DecrementWaitCount(bgCtx, userID); err != nil {
|
||||
log.Printf("Warning: decrement wait count failed for user %d: %v", userID, err)
|
||||
logger.LegacyPrintf("service.concurrency", "Warning: decrement wait count failed for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -221,7 +222,7 @@ func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, acco
|
||||
|
||||
result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait)
|
||||
if err != nil {
|
||||
log.Printf("Warning: increment wait count failed for account %d: %v", accountID, err)
|
||||
logger.LegacyPrintf("service.concurrency", "Warning: increment wait count failed for account %d: %v", accountID, err)
|
||||
return true, nil
|
||||
}
|
||||
return result, nil
|
||||
@@ -237,7 +238,7 @@ func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, acco
|
||||
defer cancel()
|
||||
|
||||
if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err != nil {
|
||||
log.Printf("Warning: decrement wait count failed for account %d: %v", accountID, err)
|
||||
logger.LegacyPrintf("service.concurrency", "Warning: decrement wait count failed for account %d: %v", accountID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -293,7 +294,7 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
|
||||
accounts, err := accountRepo.ListSchedulable(listCtx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
log.Printf("Warning: list schedulable accounts failed: %v", err)
|
||||
logger.LegacyPrintf("service.concurrency", "Warning: list schedulable accounts failed: %v", err)
|
||||
return
|
||||
}
|
||||
for _, account := range accounts {
|
||||
@@ -301,7 +302,7 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
|
||||
err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID)
|
||||
accountCancel()
|
||||
if err != nil {
|
||||
log.Printf("Warning: cleanup expired slots failed for account %d: %v", account.ID, err)
|
||||
logger.LegacyPrintf("service.concurrency", "Warning: cleanup expired slots failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,12 +3,12 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"log/slog"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -66,7 +66,7 @@ func (s *DashboardAggregationService) Start() {
|
||||
return
|
||||
}
|
||||
if !s.cfg.Enabled {
|
||||
log.Printf("[DashboardAggregation] 聚合作业已禁用")
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合作业已禁用")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -82,9 +82,9 @@ func (s *DashboardAggregationService) Start() {
|
||||
s.timingWheel.ScheduleRecurring("dashboard:aggregation", interval, func() {
|
||||
s.runScheduledAggregation()
|
||||
})
|
||||
log.Printf("[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)", interval, s.cfg.LookbackSeconds)
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)", interval, s.cfg.LookbackSeconds)
|
||||
if !s.cfg.BackfillEnabled {
|
||||
log.Printf("[DashboardAggregation] 回填已禁用,如需补齐保留窗口以外历史数据请手动回填")
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填已禁用,如需补齐保留窗口以外历史数据请手动回填")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,7 +94,7 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro
|
||||
return errors.New("聚合服务未初始化")
|
||||
}
|
||||
if !s.cfg.BackfillEnabled {
|
||||
log.Printf("[DashboardAggregation] 回填被拒绝: backfill_enabled=false")
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填被拒绝: backfill_enabled=false")
|
||||
return ErrDashboardBackfillDisabled
|
||||
}
|
||||
if !end.After(start) {
|
||||
@@ -111,7 +111,7 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
|
||||
defer cancel()
|
||||
if err := s.backfillRange(ctx, start, end); err != nil {
|
||||
log.Printf("[DashboardAggregation] 回填失败: %v", err)
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填失败: %v", err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
@@ -142,12 +142,12 @@ func (s *DashboardAggregationService) TriggerRecomputeRange(start, end time.Time
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, errDashboardAggregationRunning) {
|
||||
log.Printf("[DashboardAggregation] 重新计算失败: %v", err)
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 重新计算失败: %v", err)
|
||||
return
|
||||
}
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
log.Printf("[DashboardAggregation] 重新计算放弃: 聚合作业持续占用")
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 重新计算放弃: 聚合作业持续占用")
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
@@ -163,7 +163,7 @@ func (s *DashboardAggregationService) recomputeRecentDays() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
|
||||
defer cancel()
|
||||
if err := s.backfillRange(ctx, start, now); err != nil {
|
||||
log.Printf("[DashboardAggregation] 启动重算失败: %v", err)
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 启动重算失败: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -178,7 +178,7 @@ func (s *DashboardAggregationService) recomputeRange(ctx context.Context, start,
|
||||
if err := s.repo.RecomputeRange(ctx, start, end); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)",
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)",
|
||||
start.UTC().Format(time.RFC3339),
|
||||
end.UTC().Format(time.RFC3339),
|
||||
time.Since(jobStart).String(),
|
||||
@@ -199,7 +199,7 @@ func (s *DashboardAggregationService) runScheduledAggregation() {
|
||||
now := time.Now().UTC()
|
||||
last, err := s.repo.GetAggregationWatermark(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[DashboardAggregation] 读取水位失败: %v", err)
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 读取水位失败: %v", err)
|
||||
last = time.Unix(0, 0).UTC()
|
||||
}
|
||||
|
||||
@@ -217,13 +217,13 @@ func (s *DashboardAggregationService) runScheduledAggregation() {
|
||||
}
|
||||
|
||||
if err := s.aggregateRange(ctx, start, now); err != nil {
|
||||
log.Printf("[DashboardAggregation] 聚合失败: %v", err)
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
updateErr := s.repo.UpdateAggregationWatermark(ctx, now)
|
||||
if updateErr != nil {
|
||||
log.Printf("[DashboardAggregation] 更新水位失败: %v", updateErr)
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 更新水位失败: %v", updateErr)
|
||||
}
|
||||
slog.Debug("[DashboardAggregation] 聚合完成",
|
||||
"start", start.Format(time.RFC3339),
|
||||
@@ -262,9 +262,9 @@ func (s *DashboardAggregationService) backfillRange(ctx context.Context, start,
|
||||
|
||||
updateErr := s.repo.UpdateAggregationWatermark(ctx, endUTC)
|
||||
if updateErr != nil {
|
||||
log.Printf("[DashboardAggregation] 更新水位失败: %v", updateErr)
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 更新水位失败: %v", updateErr)
|
||||
}
|
||||
log.Printf("[DashboardAggregation] 回填聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)",
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)",
|
||||
startUTC.Format(time.RFC3339),
|
||||
endUTC.Format(time.RFC3339),
|
||||
time.Since(jobStart).String(),
|
||||
@@ -280,7 +280,7 @@ func (s *DashboardAggregationService) aggregateRange(ctx context.Context, start,
|
||||
return nil
|
||||
}
|
||||
if err := s.repo.EnsureUsageLogsPartitions(ctx, end); err != nil {
|
||||
log.Printf("[DashboardAggregation] 分区检查失败: %v", err)
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 分区检查失败: %v", err)
|
||||
}
|
||||
return s.repo.AggregateRange(ctx, start, end)
|
||||
}
|
||||
@@ -299,11 +299,11 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
|
||||
|
||||
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
|
||||
if aggErr != nil {
|
||||
log.Printf("[DashboardAggregation] 聚合保留清理失败: %v", aggErr)
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合保留清理失败: %v", aggErr)
|
||||
}
|
||||
usageErr := s.repo.CleanupUsageLogs(ctx, usageCutoff)
|
||||
if usageErr != nil {
|
||||
log.Printf("[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
|
||||
}
|
||||
if aggErr == nil && usageErr == nil {
|
||||
s.lastRetentionCleanup.Store(now)
|
||||
|
||||
@@ -5,11 +5,11 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
)
|
||||
|
||||
@@ -113,7 +113,7 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
|
||||
return cached, nil
|
||||
}
|
||||
if err != nil && !errors.Is(err, ErrDashboardStatsCacheMiss) {
|
||||
log.Printf("[Dashboard] 仪表盘缓存读取失败: %v", err)
|
||||
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存读取失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,7 +188,7 @@ func (s *DashboardService) refreshDashboardStatsAsync() {
|
||||
|
||||
stats, err := s.fetchDashboardStats(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[Dashboard] 仪表盘缓存异步刷新失败: %v", err)
|
||||
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存异步刷新失败: %v", err)
|
||||
return
|
||||
}
|
||||
s.applyAggregationStatus(ctx, stats)
|
||||
@@ -220,12 +220,12 @@ func (s *DashboardService) saveDashboardStatsCache(ctx context.Context, stats *u
|
||||
}
|
||||
data, err := json.Marshal(entry)
|
||||
if err != nil {
|
||||
log.Printf("[Dashboard] 仪表盘缓存序列化失败: %v", err)
|
||||
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存序列化失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.cache.SetDashboardStats(ctx, string(data), s.cacheTTL); err != nil {
|
||||
log.Printf("[Dashboard] 仪表盘缓存写入失败: %v", err)
|
||||
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存写入失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -237,10 +237,10 @@ func (s *DashboardService) evictDashboardStatsCache(reason error) {
|
||||
defer cancel()
|
||||
|
||||
if err := s.cache.DeleteDashboardStats(cacheCtx); err != nil {
|
||||
log.Printf("[Dashboard] 仪表盘缓存清理失败: %v", err)
|
||||
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存清理失败: %v", err)
|
||||
}
|
||||
if reason != nil {
|
||||
log.Printf("[Dashboard] 仪表盘缓存异常,已清理: %v", reason)
|
||||
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存异常,已清理: %v", reason)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -271,7 +271,7 @@ func (s *DashboardService) fetchAggregationUpdatedAt(ctx context.Context) time.T
|
||||
}
|
||||
updatedAt, err := s.aggRepo.GetAggregationWatermark(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[Dashboard] 读取聚合水位失败: %v", err)
|
||||
logger.LegacyPrintf("service.dashboard", "[Dashboard] 读取聚合水位失败: %v", err)
|
||||
return time.Unix(0, 0).UTC()
|
||||
}
|
||||
if updatedAt.IsZero() {
|
||||
|
||||
@@ -161,6 +161,9 @@ const (
|
||||
// SettingKeyOpsAdvancedSettings stores JSON config for ops advanced settings (data retention, aggregation).
|
||||
SettingKeyOpsAdvancedSettings = "ops_advanced_settings"
|
||||
|
||||
// SettingKeyOpsRuntimeLogConfig stores JSON config for runtime log settings.
|
||||
SettingKeyOpsRuntimeLogConfig = "ops_runtime_log_config"
|
||||
|
||||
// =========================
|
||||
// Stream Timeout Handling
|
||||
// =========================
|
||||
|
||||
@@ -3,9 +3,10 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// Task type constants
|
||||
@@ -56,7 +57,7 @@ func (s *EmailQueueService) start() {
|
||||
s.wg.Add(1)
|
||||
go s.worker(i)
|
||||
}
|
||||
log.Printf("[EmailQueue] Started %d workers", s.workers)
|
||||
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Started %d workers", s.workers)
|
||||
}
|
||||
|
||||
// worker 工作协程
|
||||
@@ -68,7 +69,7 @@ func (s *EmailQueueService) worker(id int) {
|
||||
case task := <-s.taskChan:
|
||||
s.processTask(id, task)
|
||||
case <-s.stopChan:
|
||||
log.Printf("[EmailQueue] Worker %d stopping", id)
|
||||
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d stopping", id)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -82,18 +83,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
|
||||
switch task.TaskType {
|
||||
case TaskTypeVerifyCode:
|
||||
if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil {
|
||||
log.Printf("[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
|
||||
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
|
||||
} else {
|
||||
log.Printf("[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
|
||||
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
|
||||
}
|
||||
case TaskTypePasswordReset:
|
||||
if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL); err != nil {
|
||||
log.Printf("[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err)
|
||||
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err)
|
||||
} else {
|
||||
log.Printf("[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email)
|
||||
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email)
|
||||
}
|
||||
default:
|
||||
log.Printf("[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType)
|
||||
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -107,7 +108,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
|
||||
|
||||
select {
|
||||
case s.taskChan <- task:
|
||||
log.Printf("[EmailQueue] Enqueued verify code task for %s", email)
|
||||
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Enqueued verify code task for %s", email)
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("email queue is full")
|
||||
@@ -125,7 +126,7 @@ func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL strin
|
||||
|
||||
select {
|
||||
case s.taskChan <- task:
|
||||
log.Printf("[EmailQueue] Enqueued password reset task for %s", email)
|
||||
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Enqueued password reset task for %s", email)
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("email queue is full")
|
||||
@@ -136,5 +137,5 @@ func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL strin
|
||||
func (s *EmailQueueService) Stop() {
|
||||
close(s.stopChan)
|
||||
s.wg.Wait()
|
||||
log.Println("[EmailQueue] All workers stopped")
|
||||
logger.LegacyPrintf("service.email_queue", "%s", "[EmailQueue] All workers stopped")
|
||||
}
|
||||
|
||||
@@ -2,13 +2,13 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// ErrorPassthroughRepository 定义错误透传规则的数据访问接口
|
||||
@@ -62,9 +62,9 @@ func NewErrorPassthroughService(
|
||||
// 启动时加载规则到本地缓存
|
||||
ctx := context.Background()
|
||||
if err := svc.reloadRulesFromDB(ctx); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to load rules from DB on startup: %v", err)
|
||||
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to load rules from DB on startup: %v", err)
|
||||
if fallbackErr := svc.refreshLocalCache(ctx); fallbackErr != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to load rules from cache fallback on startup: %v", fallbackErr)
|
||||
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to load rules from cache fallback on startup: %v", fallbackErr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,7 +72,7 @@ func NewErrorPassthroughService(
|
||||
if cache != nil {
|
||||
cache.SubscribeUpdates(ctx, func() {
|
||||
if err := svc.refreshLocalCache(context.Background()); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to refresh cache on notification: %v", err)
|
||||
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to refresh cache on notification: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -180,7 +180,7 @@ func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule
|
||||
// 如果本地缓存为空,尝试刷新
|
||||
ctx := context.Background()
|
||||
if err := s.refreshLocalCache(ctx); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to refresh cache: %v", err)
|
||||
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to refresh cache: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -213,7 +213,7 @@ func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error {
|
||||
// 更新 Redis 缓存
|
||||
if s.cache != nil {
|
||||
if err := s.cache.Set(ctx, rules); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to set cache: %v", err)
|
||||
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to set cache: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -254,13 +254,13 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
|
||||
// 先失效缓存,避免后续刷新读到陈旧规则。
|
||||
if s.cache != nil {
|
||||
if err := s.cache.Invalidate(ctx); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to invalidate cache: %v", err)
|
||||
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to invalidate cache: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 刷新本地缓存
|
||||
if err := s.reloadRulesFromDB(ctx); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to refresh local cache: %v", err)
|
||||
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to refresh local cache: %v", err)
|
||||
// 刷新失败时清空本地缓存,避免继续使用陈旧规则。
|
||||
s.clearLocalCache()
|
||||
}
|
||||
@@ -268,7 +268,7 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
|
||||
// 通知其他实例
|
||||
if s.cache != nil {
|
||||
if err := s.cache.NotifyUpdate(ctx); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to notify cache update: %v", err)
|
||||
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to notify cache update: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"log/slog"
|
||||
mathrand "math/rand"
|
||||
"net/http"
|
||||
@@ -24,6 +23,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/cespare/xxhash/v2"
|
||||
@@ -213,7 +213,7 @@ func logClaudeMimicDebug(req *http.Request, body []byte, account *Account, token
|
||||
if line == "" {
|
||||
return
|
||||
}
|
||||
log.Printf("[ClaudeMimicDebug] %s", line)
|
||||
logger.LegacyPrintf("service.gateway", "[ClaudeMimicDebug] %s", line)
|
||||
}
|
||||
|
||||
func isClaudeCodeCredentialScopeError(msg string) bool {
|
||||
@@ -936,7 +936,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if group != nil {
|
||||
groupPlatform = group.Platform
|
||||
}
|
||||
log.Printf("[ModelRoutingDebug] select entry: group_id=%v group_platform=%s model=%s session=%s sticky_account=%d load_batch=%v concurrency=%v",
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] select entry: group_id=%v group_platform=%s model=%s session=%s sticky_account=%d load_batch=%v concurrency=%v",
|
||||
derefGroupID(groupID), groupPlatform, requestedModel, shortSessionHash(sessionHash), stickyAccountID, cfg.LoadBatchEnabled, s.concurrencyService != nil)
|
||||
}
|
||||
|
||||
@@ -1006,7 +1006,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
preferOAuth := platform == PlatformGemini
|
||||
if s.debugModelRoutingEnabled() && platform == PlatformAnthropic && requestedModel != "" {
|
||||
log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform)
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform)
|
||||
}
|
||||
|
||||
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
@@ -1036,7 +1036,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic {
|
||||
routingAccountIDs = group.GetRoutingAccountIDs(requestedModel)
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] context group routing: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v session=%s sticky_account=%d",
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] context group routing: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v session=%s sticky_account=%d",
|
||||
group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), routingAccountIDs, shortSessionHash(sessionHash), stickyAccountID)
|
||||
if len(routingAccountIDs) == 0 && group.ModelRoutingEnabled && len(group.ModelRouting) > 0 {
|
||||
keys := make([]string, 0, len(group.ModelRouting))
|
||||
@@ -1048,7 +1048,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if len(keys) > maxKeys {
|
||||
keys = keys[:maxKeys]
|
||||
}
|
||||
log.Printf("[ModelRoutingDebug] context group routing miss: group_id=%d model=%s patterns(sample)=%v", group.ID, requestedModel, keys)
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] context group routing miss: group_id=%d model=%s patterns(sample)=%v", group.ID, requestedModel, keys)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1095,11 +1095,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)",
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)",
|
||||
derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates),
|
||||
filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost)
|
||||
if len(modelScopeSkippedIDs) > 0 {
|
||||
log.Printf("[ModelRoutingDebug] model_rate_limited accounts skipped: group_id=%v model=%s account_ids=%v",
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] model_rate_limited accounts skipped: group_id=%v model=%s account_ids=%v",
|
||||
derefGroupID(groupID), requestedModel, modelScopeSkippedIDs)
|
||||
}
|
||||
}
|
||||
@@ -1124,7 +1124,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
// 继续到负载感知选择
|
||||
} else {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: stickyAccount,
|
||||
@@ -1217,7 +1217,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: item.account,
|
||||
@@ -1234,7 +1234,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
continue // 会话限制已满,尝试下一个
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: item.account,
|
||||
@@ -1249,7 +1249,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
// 所有路由账号会话限制都已满,继续到 Layer 2 回退
|
||||
}
|
||||
// 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退
|
||||
log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel)
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1510,20 +1510,20 @@ func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupI
|
||||
group, err := s.resolveGroupByID(ctx, *groupID)
|
||||
if err != nil || group == nil {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] resolve group failed: group_id=%v model=%s platform=%s err=%v", derefGroupID(groupID), requestedModel, platform, err)
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] resolve group failed: group_id=%v model=%s platform=%s err=%v", derefGroupID(groupID), requestedModel, platform, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// Preserve existing behavior: model routing only applies to anthropic groups.
|
||||
if group.Platform != PlatformAnthropic {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] skip: non-anthropic group platform: group_id=%d group_platform=%s model=%s", group.ID, group.Platform, requestedModel)
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] skip: non-anthropic group platform: group_id=%d group_platform=%s model=%s", group.ID, group.Platform, requestedModel)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
ids := group.GetRoutingAccountIDs(requestedModel)
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routing lookup: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v",
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routing lookup: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v",
|
||||
group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), ids)
|
||||
}
|
||||
return ids
|
||||
@@ -2117,7 +2117,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
// so switching model can switch upstream account within the same sticky session.
|
||||
if len(routingAccountIDs) > 0 {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v",
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v",
|
||||
derefGroupID(groupID), requestedModel, platform, shortSessionHash(sessionHash), routingAccountIDs)
|
||||
}
|
||||
// 1) Sticky session only applies if the bound account is within the routing set.
|
||||
@@ -2134,7 +2134,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
@@ -2209,15 +2209,15 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if selected != nil {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
}
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID)
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID)
|
||||
}
|
||||
return selected, nil
|
||||
}
|
||||
log.Printf("[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel)
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel)
|
||||
}
|
||||
|
||||
// 1. 查询粘性会话
|
||||
@@ -2305,7 +2305,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
// 4. 建立粘性绑定
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2324,7 +2324,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
// ============ Model Routing (legacy path): apply before sticky session ============
|
||||
if len(routingAccountIDs) > 0 {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy mixed routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v",
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v",
|
||||
derefGroupID(groupID), requestedModel, nativePlatform, shortSessionHash(sessionHash), routingAccountIDs)
|
||||
}
|
||||
// 1) Sticky session only applies if the bound account is within the routing set.
|
||||
@@ -2342,7 +2342,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
@@ -2418,15 +2418,15 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if selected != nil {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
}
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy mixed routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID)
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID)
|
||||
}
|
||||
return selected, nil
|
||||
}
|
||||
log.Printf("[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel)
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel)
|
||||
}
|
||||
|
||||
// 1. 查询粘性会话
|
||||
@@ -2516,7 +2516,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
// 4. 建立粘性绑定
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2831,7 +2831,7 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
|
||||
|
||||
result, err := sjson.SetBytes(body, "system", newSystem)
|
||||
if err != nil {
|
||||
log.Printf("Warning: failed to inject Claude Code prompt: %v", err)
|
||||
logger.LegacyPrintf("service.gateway", "Warning: failed to inject Claude Code prompt: %v", err)
|
||||
return body
|
||||
}
|
||||
return result
|
||||
@@ -2987,7 +2987,7 @@ func removeCacheControlFromThinkingBlocks(data map[string]any) {
|
||||
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||
if _, has := m["cache_control"]; has {
|
||||
delete(m, "cache_control")
|
||||
log.Printf("[Warning] Removed illegal cache_control from thinking block in system")
|
||||
logger.LegacyPrintf("service.gateway", "[Warning] Removed illegal cache_control from thinking block in system")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3004,7 +3004,7 @@ func removeCacheControlFromThinkingBlocks(data map[string]any) {
|
||||
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||
if _, has := m["cache_control"]; has {
|
||||
delete(m, "cache_control")
|
||||
log.Printf("[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIdx, contentIdx)
|
||||
logger.LegacyPrintf("service.gateway", "[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIdx, contentIdx)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3083,7 +3083,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// 替换请求体中的模型名
|
||||
body = s.replaceModelInBody(body, mappedModel)
|
||||
reqModel = mappedModel
|
||||
log.Printf("Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource)
|
||||
logger.LegacyPrintf("service.gateway", "Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource)
|
||||
}
|
||||
|
||||
// 获取凭证
|
||||
@@ -3099,7 +3099,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
|
||||
// 调试日志:记录即将转发的账号信息
|
||||
log.Printf("[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s",
|
||||
logger.LegacyPrintf("service.gateway", "[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s",
|
||||
account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL)
|
||||
|
||||
// 重试循环
|
||||
@@ -3179,7 +3179,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
break
|
||||
}
|
||||
log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID)
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID)
|
||||
|
||||
// Conservative two-stage fallback:
|
||||
// 1) Disable thinking + thinking->text (preserve content)
|
||||
@@ -3192,7 +3192,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr == nil {
|
||||
if retryResp.StatusCode < 400 {
|
||||
log.Printf("Account %d: signature error retry succeeded (thinking downgraded)", account.ID)
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: signature error retry succeeded (thinking downgraded)", account.ID)
|
||||
resp = retryResp
|
||||
break
|
||||
}
|
||||
@@ -3217,7 +3217,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
})
|
||||
msg2 := extractUpstreamErrorMessage(retryRespBody)
|
||||
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
|
||||
log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
|
||||
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
|
||||
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
if buildErr2 == nil {
|
||||
@@ -3237,9 +3237,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
Kind: "signature_retry_tools_request_error",
|
||||
Message: sanitizeUpstreamErrorMessage(retryErr2.Error()),
|
||||
})
|
||||
log.Printf("Account %d: tool-downgrade signature retry failed: %v", account.ID, retryErr2)
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: tool-downgrade signature retry failed: %v", account.ID, retryErr2)
|
||||
} else {
|
||||
log.Printf("Account %d: tool-downgrade signature retry build failed: %v", account.ID, buildErr2)
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: tool-downgrade signature retry build failed: %v", account.ID, buildErr2)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3255,9 +3255,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
if retryResp != nil && retryResp.Body != nil {
|
||||
_ = retryResp.Body.Close()
|
||||
}
|
||||
log.Printf("Account %d: signature error retry failed: %v", account.ID, retryErr)
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: signature error retry failed: %v", account.ID, retryErr)
|
||||
} else {
|
||||
log.Printf("Account %d: signature error retry build request failed: %v", account.ID, buildErr)
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: signature error retry build request failed: %v", account.ID, buildErr)
|
||||
}
|
||||
|
||||
// Retry failed: restore original response body and continue handling.
|
||||
@@ -3303,7 +3303,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
log.Printf("Account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)",
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)",
|
||||
account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay, elapsed, maxRetryElapsed)
|
||||
if err := sleepWithContext(ctx, delay); err != nil {
|
||||
return nil, err
|
||||
@@ -3317,9 +3317,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// 不需要重试(成功或不可重试的错误),跳出循环
|
||||
// DEBUG: 输出响应 headers(用于检测 rate limit 信息)
|
||||
if account.Platform == PlatformGemini && resp.StatusCode < 400 {
|
||||
log.Printf("[DEBUG] Gemini API Response Headers for account %d:", account.ID)
|
||||
logger.LegacyPrintf("service.gateway", "[DEBUG] Gemini API Response Headers for account %d:", account.ID)
|
||||
for k, v := range resp.Header {
|
||||
log.Printf("[DEBUG] %s: %v", k, v)
|
||||
logger.LegacyPrintf("service.gateway", "[DEBUG] %s: %v", k, v)
|
||||
}
|
||||
}
|
||||
break
|
||||
@@ -3337,7 +3337,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
// 调试日志:打印重试耗尽后的错误响应
|
||||
log.Printf("[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
|
||||
logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
|
||||
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
|
||||
|
||||
s.handleRetryExhaustedSideEffects(ctx, resp, account)
|
||||
@@ -3368,7 +3368,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
// 调试日志:打印上游错误响应
|
||||
log.Printf("[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
|
||||
logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
|
||||
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
|
||||
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
@@ -3422,13 +3422,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
})
|
||||
|
||||
if s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
log.Printf(
|
||||
logger.LegacyPrintf("service.gateway",
|
||||
"Account %d: 400 error, attempting failover: %s",
|
||||
account.ID,
|
||||
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
|
||||
)
|
||||
} else {
|
||||
log.Printf("Account %d: 400 error, attempting failover", account.ID)
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: 400 error, attempting failover", account.ID)
|
||||
}
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
@@ -3497,7 +3497,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
// 1. 获取或创建指纹(包含随机生成的ClientID)
|
||||
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders)
|
||||
if err != nil {
|
||||
log.Printf("Warning: failed to get fingerprint for account %d: %v", account.ID, err)
|
||||
logger.LegacyPrintf("service.gateway", "Warning: failed to get fingerprint for account %d: %v", account.ID, err)
|
||||
// 失败时降级为透传原始headers
|
||||
} else {
|
||||
fingerprint = fp
|
||||
@@ -3768,33 +3768,33 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
|
||||
}
|
||||
|
||||
// Log for debugging
|
||||
log.Printf("[SignatureCheck] Checking error message: %s", msg)
|
||||
logger.LegacyPrintf("service.gateway", "[SignatureCheck] Checking error message: %s", msg)
|
||||
|
||||
// 检测signature相关的错误(更宽松的匹配)
|
||||
// 例如: "Invalid `signature` in `thinking` block", "***.signature" 等
|
||||
if strings.Contains(msg, "signature") {
|
||||
log.Printf("[SignatureCheck] Detected signature error")
|
||||
logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected signature error")
|
||||
return true
|
||||
}
|
||||
|
||||
// 检测 thinking block 顺序/类型错误
|
||||
// 例如: "Expected `thinking` or `redacted_thinking`, but found `text`"
|
||||
if strings.Contains(msg, "expected") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) {
|
||||
log.Printf("[SignatureCheck] Detected thinking block type error")
|
||||
logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected thinking block type error")
|
||||
return true
|
||||
}
|
||||
|
||||
// 检测 thinking block 被修改的错误
|
||||
// 例如: "thinking or redacted_thinking blocks in the latest assistant message cannot be modified"
|
||||
if strings.Contains(msg, "cannot be modified") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) {
|
||||
log.Printf("[SignatureCheck] Detected thinking block modification error")
|
||||
logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected thinking block modification error")
|
||||
return true
|
||||
}
|
||||
|
||||
// 检测空消息内容错误(可能是过滤 thinking blocks 后导致的)
|
||||
// 例如: "all messages must have non-empty content"
|
||||
if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") {
|
||||
log.Printf("[SignatureCheck] Detected empty content error")
|
||||
logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected empty content error")
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -3855,7 +3855,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
|
||||
// 调试日志:打印上游错误响应
|
||||
log.Printf("[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s",
|
||||
logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s",
|
||||
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(body), 1000))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||
@@ -3866,7 +3866,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
||||
if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil {
|
||||
if v, ok := c.Get(claudeMimicDebugInfoKey); ok {
|
||||
if line, ok := v.(string); ok && strings.TrimSpace(line) != "" {
|
||||
log.Printf("[ClaudeMimicDebugOnError] status=%d request_id=%s %s",
|
||||
logger.LegacyPrintf("service.gateway", "[ClaudeMimicDebugOnError] status=%d request_id=%s %s",
|
||||
resp.StatusCode,
|
||||
resp.Header.Get("x-request-id"),
|
||||
line,
|
||||
@@ -3906,7 +3906,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
||||
|
||||
// 记录上游错误响应体摘要便于排障(可选:由配置控制;不回显到客户端)
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
log.Printf(
|
||||
logger.LegacyPrintf("service.gateway",
|
||||
"Upstream error %d (account=%d platform=%s type=%s): %s",
|
||||
resp.StatusCode,
|
||||
account.ID,
|
||||
@@ -4007,10 +4007,10 @@ func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, re
|
||||
// OAuth/Setup Token 账号的 403:标记账号异常
|
||||
if account.IsOAuth() && statusCode == 403 {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body)
|
||||
log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetryAttempts, statusCode)
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: marked as error after %d retries for status %d", account.ID, maxRetryAttempts, statusCode)
|
||||
} else {
|
||||
// API Key 未配置错误码:不标记账号状态
|
||||
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetryAttempts)
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetryAttempts)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4036,7 +4036,7 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
|
||||
if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil {
|
||||
if v, ok := c.Get(claudeMimicDebugInfoKey); ok {
|
||||
if line, ok := v.(string); ok && strings.TrimSpace(line) != "" {
|
||||
log.Printf("[ClaudeMimicDebugOnError] status=%d request_id=%s %s",
|
||||
logger.LegacyPrintf("service.gateway", "[ClaudeMimicDebugOnError] status=%d request_id=%s %s",
|
||||
resp.StatusCode,
|
||||
resp.Header.Get("x-request-id"),
|
||||
line,
|
||||
@@ -4065,7 +4065,7 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
|
||||
})
|
||||
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
log.Printf(
|
||||
logger.LegacyPrintf("service.gateway",
|
||||
"Upstream error %d retries_exhausted (account=%d platform=%s type=%s): %s",
|
||||
resp.StatusCode,
|
||||
account.ID,
|
||||
@@ -4325,17 +4325,17 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
if ev.err != nil {
|
||||
// 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取)
|
||||
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
|
||||
log.Printf("Context canceled during streaming, returning collected usage")
|
||||
logger.LegacyPrintf("service.gateway", "Context canceled during streaming, returning collected usage")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
// 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage
|
||||
if clientDisconnected {
|
||||
log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
|
||||
logger.LegacyPrintf("service.gateway", "Upstream read error after client disconnect: %v, returning collected usage", ev.err)
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
// 客户端未断开,正常的错误处理
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||
logger.LegacyPrintf("service.gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
|
||||
}
|
||||
@@ -4363,7 +4363,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
if !clientDisconnected {
|
||||
if _, werr := fmt.Fprint(w, block); werr != nil {
|
||||
clientDisconnected = true
|
||||
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||
logger.LegacyPrintf("service.gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
|
||||
break
|
||||
}
|
||||
flusher.Flush()
|
||||
@@ -4388,10 +4388,10 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
if clientDisconnected {
|
||||
// 客户端已断开,上游也超时了,返回已收集的 usage
|
||||
log.Printf("Upstream timeout after client disconnect, returning collected usage")
|
||||
logger.LegacyPrintf("service.gateway", "Upstream timeout after client disconnect, returning collected usage")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||
logger.LegacyPrintf("service.gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||
// 处理流超时,可能标记账户为临时不可调度或错误状态
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
|
||||
@@ -4536,7 +4536,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
// 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens
|
||||
// 用于粘性会话切换时的特殊计费处理
|
||||
if input.ForceCacheBilling && result.Usage.InputTokens > 0 {
|
||||
log.Printf("force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)",
|
||||
logger.LegacyPrintf("service.gateway", "force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)",
|
||||
result.Usage.InputTokens, account.ID)
|
||||
result.Usage.CacheReadInputTokens += result.Usage.InputTokens
|
||||
result.Usage.InputTokens = 0
|
||||
@@ -4597,7 +4597,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
var err error
|
||||
cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
||||
if err != nil {
|
||||
log.Printf("Calculate cost failed: %v", err)
|
||||
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
}
|
||||
}
|
||||
@@ -4668,11 +4668,11 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
|
||||
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
|
||||
if err != nil {
|
||||
log.Printf("Create usage log failed: %v", err)
|
||||
logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err)
|
||||
}
|
||||
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
return nil
|
||||
}
|
||||
@@ -4684,7 +4684,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
|
||||
if shouldBill && cost.TotalCost > 0 {
|
||||
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
||||
log.Printf("Increment subscription usage failed: %v", err)
|
||||
logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err)
|
||||
}
|
||||
// 异步更新订阅缓存
|
||||
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
|
||||
@@ -4693,7 +4693,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
|
||||
if shouldBill && cost.ActualCost > 0 {
|
||||
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
||||
log.Printf("Deduct balance failed: %v", err)
|
||||
logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err)
|
||||
}
|
||||
// 异步更新余额缓存
|
||||
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
|
||||
@@ -4703,7 +4703,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
// 更新 API Key 配额(如果设置了配额限制)
|
||||
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
|
||||
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
log.Printf("Update API key quota failed: %v", err)
|
||||
logger.LegacyPrintf("service.gateway", "Update API key quota failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4739,7 +4739,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
// 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens
|
||||
// 用于粘性会话切换时的特殊计费处理
|
||||
if input.ForceCacheBilling && result.Usage.InputTokens > 0 {
|
||||
log.Printf("force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)",
|
||||
logger.LegacyPrintf("service.gateway", "force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)",
|
||||
result.Usage.InputTokens, account.ID)
|
||||
result.Usage.CacheReadInputTokens += result.Usage.InputTokens
|
||||
result.Usage.InputTokens = 0
|
||||
@@ -4783,7 +4783,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
var err error
|
||||
cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
|
||||
if err != nil {
|
||||
log.Printf("Calculate cost failed: %v", err)
|
||||
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
}
|
||||
}
|
||||
@@ -4849,11 +4849,11 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
|
||||
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
|
||||
if err != nil {
|
||||
log.Printf("Create usage log failed: %v", err)
|
||||
logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err)
|
||||
}
|
||||
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
return nil
|
||||
}
|
||||
@@ -4865,7 +4865,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
|
||||
if shouldBill && cost.TotalCost > 0 {
|
||||
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
||||
log.Printf("Increment subscription usage failed: %v", err)
|
||||
logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err)
|
||||
}
|
||||
// 异步更新订阅缓存
|
||||
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
|
||||
@@ -4874,14 +4874,14 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
|
||||
if shouldBill && cost.ActualCost > 0 {
|
||||
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
||||
log.Printf("Deduct balance failed: %v", err)
|
||||
logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err)
|
||||
}
|
||||
// 异步更新余额缓存
|
||||
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
|
||||
// API Key 独立配额扣费
|
||||
if input.APIKeyService != nil && apiKey.Quota > 0 {
|
||||
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
log.Printf("Add API key quota used failed: %v", err)
|
||||
logger.LegacyPrintf("service.gateway", "Add API key quota used failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4940,7 +4940,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
if mappedModel != reqModel {
|
||||
body = s.replaceModelInBody(body, mappedModel)
|
||||
reqModel = mappedModel
|
||||
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource)
|
||||
logger.LegacyPrintf("service.gateway", "CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4982,7 +4982,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
|
||||
// 检测 thinking block 签名错误(400)并重试一次(过滤 thinking blocks)
|
||||
if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) {
|
||||
log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
|
||||
|
||||
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode)
|
||||
@@ -5019,7 +5019,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
|
||||
// 记录上游错误摘要便于排障(不回显请求内容)
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
log.Printf(
|
||||
logger.LegacyPrintf("service.gateway",
|
||||
"count_tokens upstream error %d (account=%d platform=%s type=%s): %s",
|
||||
resp.StatusCode,
|
||||
account.ID,
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math"
|
||||
mathrand "math/rand"
|
||||
"net/http"
|
||||
@@ -22,6 +21,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
|
||||
@@ -282,7 +282,7 @@ func (s *GeminiMessagesCompatService) passesRateLimitPreCheck(ctx context.Contex
|
||||
}
|
||||
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
|
||||
if err != nil {
|
||||
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
|
||||
}
|
||||
return ok
|
||||
}
|
||||
@@ -698,7 +698,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
Message: safeErr,
|
||||
})
|
||||
if attempt < geminiMaxRetries {
|
||||
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
|
||||
sleepGeminiBackoff(attempt)
|
||||
continue
|
||||
}
|
||||
@@ -754,7 +754,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
retryGeminiReq, txErr := convertClaudeMessagesToGeminiGenerateContent(strippedClaudeBody)
|
||||
if txErr == nil {
|
||||
log.Printf("Gemini account %d: detected signature-related 400, retrying with downgraded Claude blocks (%s)", account.ID, stageName)
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: detected signature-related 400, retrying with downgraded Claude blocks (%s)", account.ID, stageName)
|
||||
geminiReq = retryGeminiReq
|
||||
// Consume one retry budget attempt and continue with the updated request payload.
|
||||
sleepGeminiBackoff(1)
|
||||
@@ -821,7 +821,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
|
||||
log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
|
||||
sleepGeminiBackoff(attempt)
|
||||
continue
|
||||
}
|
||||
@@ -1166,7 +1166,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
Message: safeErr,
|
||||
})
|
||||
if attempt < geminiMaxRetries {
|
||||
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
|
||||
sleepGeminiBackoff(attempt)
|
||||
continue
|
||||
}
|
||||
@@ -1235,7 +1235,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
|
||||
log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
|
||||
sleepGeminiBackoff(attempt)
|
||||
continue
|
||||
}
|
||||
@@ -1367,7 +1367,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
log.Printf("[Gemini] native upstream error %d: %s", resp.StatusCode, truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes))
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini] native upstream error %d: %s", resp.StatusCode, truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes))
|
||||
}
|
||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
@@ -1544,7 +1544,7 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, acc
|
||||
})
|
||||
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
log.Printf("[Gemini] upstream error %d: %s", upstreamStatus, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes))
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini] upstream error %d: %s", upstreamStatus, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes))
|
||||
}
|
||||
|
||||
if status, errType, errMsg, matched := applyErrorPassthroughRule(
|
||||
@@ -2299,13 +2299,13 @@ type UpstreamHTTPResult struct {
|
||||
|
||||
func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) {
|
||||
// Log response headers for debugging
|
||||
log.Printf("[GeminiAPI] ========== Response Headers ==========")
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========")
|
||||
for key, values := range resp.Header {
|
||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||
log.Printf("[GeminiAPI] %s: %v", key, values)
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
|
||||
}
|
||||
}
|
||||
log.Printf("[GeminiAPI] ========================================")
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================")
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
@@ -2339,13 +2339,13 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
|
||||
|
||||
func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) {
|
||||
// Log response headers for debugging
|
||||
log.Printf("[GeminiAPI] ========== Streaming Response Headers ==========")
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========")
|
||||
for key, values := range resp.Header {
|
||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||
log.Printf("[GeminiAPI] %s: %v", key, values)
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
|
||||
}
|
||||
}
|
||||
log.Printf("[GeminiAPI] ====================================================")
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================")
|
||||
|
||||
if s.cfg != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
@@ -2640,16 +2640,16 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont
|
||||
cooldown = s.rateLimitService.GeminiCooldown(ctx, account)
|
||||
}
|
||||
ra = time.Now().Add(cooldown)
|
||||
log.Printf("[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second))
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second))
|
||||
} else {
|
||||
// API Key / AI Studio OAuth: PST 午夜
|
||||
if ts := nextGeminiDailyResetUnix(); ts != nil {
|
||||
ra = time.Unix(*ts, 0)
|
||||
log.Printf("[Gemini 429] Account %d (API Key/AI Studio, type=%s) rate limited, reset at PST midnight (%v)", account.ID, account.Type, ra)
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d (API Key/AI Studio, type=%s) rate limited, reset at PST midnight (%v)", account.ID, account.Type, ra)
|
||||
} else {
|
||||
// 兜底:5 分钟
|
||||
ra = time.Now().Add(5 * time.Minute)
|
||||
log.Printf("[Gemini 429] Account %d rate limited, fallback to 5min", account.ID)
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d rate limited, fallback to 5min", account.ID)
|
||||
}
|
||||
}
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
|
||||
@@ -2659,7 +2659,7 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont
|
||||
// 使用解析到的重置时间
|
||||
resetTime := time.Unix(*resetAt, 0)
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime)
|
||||
log.Printf("[Gemini 429] Account %d rate limited until %v (oauth_type=%s, tier=%s)",
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d rate limited until %v (oauth_type=%s, tier=%s)",
|
||||
account.ID, resetTime, oauthType, tierID)
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
@@ -16,6 +15,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -328,27 +328,27 @@ func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string
|
||||
|
||||
// inferGoogleOneTier infers Google One tier from Drive storage limit
|
||||
func inferGoogleOneTier(storageBytes int64) string {
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - input: %d bytes (%.2f TB)", storageBytes, float64(storageBytes)/float64(TB))
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - input: %d bytes (%.2f TB)", storageBytes, float64(storageBytes)/float64(TB))
|
||||
|
||||
if storageBytes <= 0 {
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - storageBytes <= 0, returning UNKNOWN")
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - storageBytes <= 0, returning UNKNOWN")
|
||||
return GeminiTierGoogleOneUnknown
|
||||
}
|
||||
|
||||
if storageBytes > StorageTierUnlimited {
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - > %d bytes (100TB), returning UNLIMITED", StorageTierUnlimited)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - > %d bytes (100TB), returning UNLIMITED", StorageTierUnlimited)
|
||||
return GeminiTierGoogleAIUltra
|
||||
}
|
||||
if storageBytes >= StorageTierAIPremium {
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - >= %d bytes (2TB), returning google_ai_pro", StorageTierAIPremium)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - >= %d bytes (2TB), returning google_ai_pro", StorageTierAIPremium)
|
||||
return GeminiTierGoogleAIPro
|
||||
}
|
||||
if storageBytes >= StorageTierFree {
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - >= %d bytes (15GB), returning FREE", StorageTierFree)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - >= %d bytes (15GB), returning FREE", StorageTierFree)
|
||||
return GeminiTierGoogleOneFree
|
||||
}
|
||||
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - < %d bytes (15GB), returning UNKNOWN", StorageTierFree)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - < %d bytes (15GB), returning UNKNOWN", StorageTierFree)
|
||||
return GeminiTierGoogleOneUnknown
|
||||
}
|
||||
|
||||
@@ -358,30 +358,30 @@ func inferGoogleOneTier(storageBytes int64) string {
|
||||
// 2. Personal accounts will get 403/404 from cloudaicompanion.googleapis.com
|
||||
// 3. Google consumer (Google One) and enterprise (GCP) systems are physically isolated
|
||||
func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) {
|
||||
log.Printf("[GeminiOAuth] Starting FetchGoogleOneTier (Google One personal account)")
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Starting FetchGoogleOneTier (Google One personal account)")
|
||||
|
||||
// Use Drive API to infer tier from storage quota (requires drive.readonly scope)
|
||||
log.Printf("[GeminiOAuth] Calling Drive API for storage quota...")
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Calling Drive API for storage quota...")
|
||||
driveClient := geminicli.NewDriveClient()
|
||||
|
||||
storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL)
|
||||
if err != nil {
|
||||
// Check if it's a 403 (scope not granted)
|
||||
if strings.Contains(err.Error(), "status 403") {
|
||||
log.Printf("[GeminiOAuth] Drive API scope not available (403): %v", err)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Drive API scope not available (403): %v", err)
|
||||
return GeminiTierGoogleOneUnknown, nil, err
|
||||
}
|
||||
// Other errors
|
||||
log.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v", err)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Failed to fetch Drive storage: %v", err)
|
||||
return GeminiTierGoogleOneUnknown, nil, err
|
||||
}
|
||||
|
||||
log.Printf("[GeminiOAuth] Drive API response - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)",
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Drive API response - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)",
|
||||
storageInfo.Limit, float64(storageInfo.Limit)/float64(TB),
|
||||
storageInfo.Usage, float64(storageInfo.Usage)/float64(GB))
|
||||
|
||||
tierID := inferGoogleOneTier(storageInfo.Limit)
|
||||
log.Printf("[GeminiOAuth] Inferred tier from storage: %s", tierID)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Inferred tier from storage: %s", tierID)
|
||||
|
||||
return tierID, storageInfo, nil
|
||||
}
|
||||
@@ -441,16 +441,16 @@ func (s *GeminiOAuthService) RefreshAccountGoogleOneTier(
|
||||
}
|
||||
|
||||
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
|
||||
log.Printf("[GeminiOAuth] ========== ExchangeCode START ==========")
|
||||
log.Printf("[GeminiOAuth] SessionID: %s", input.SessionID)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== ExchangeCode START ==========")
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] SessionID: %s", input.SessionID)
|
||||
|
||||
session, ok := s.sessionStore.Get(input.SessionID)
|
||||
if !ok {
|
||||
log.Printf("[GeminiOAuth] ERROR: Session not found or expired")
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Session not found or expired")
|
||||
return nil, fmt.Errorf("session not found or expired")
|
||||
}
|
||||
if strings.TrimSpace(input.State) == "" || input.State != session.State {
|
||||
log.Printf("[GeminiOAuth] ERROR: Invalid state")
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Invalid state")
|
||||
return nil, fmt.Errorf("invalid state")
|
||||
}
|
||||
|
||||
@@ -461,7 +461,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
log.Printf("[GeminiOAuth] ProxyURL: %s", proxyURL)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ProxyURL: %s", proxyURL)
|
||||
|
||||
redirectURI := session.RedirectURI
|
||||
|
||||
@@ -470,8 +470,8 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
if oauthType == "" {
|
||||
oauthType = "code_assist"
|
||||
}
|
||||
log.Printf("[GeminiOAuth] OAuth Type: %s", oauthType)
|
||||
log.Printf("[GeminiOAuth] Project ID from session: %s", session.ProjectID)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] OAuth Type: %s", oauthType)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Project ID from session: %s", session.ProjectID)
|
||||
|
||||
// If the session was created for AI Studio OAuth, ensure a custom OAuth client is configured.
|
||||
if oauthType == "ai_studio" {
|
||||
@@ -496,12 +496,12 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
|
||||
tokenResp, err := s.oauthClient.ExchangeCode(ctx, oauthType, input.Code, session.CodeVerifier, redirectURI, proxyURL)
|
||||
if err != nil {
|
||||
log.Printf("[GeminiOAuth] ERROR: Failed to exchange code: %v", err)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Failed to exchange code: %v", err)
|
||||
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
||||
}
|
||||
log.Printf("[GeminiOAuth] Token exchange successful")
|
||||
log.Printf("[GeminiOAuth] Token scope: %s", tokenResp.Scope)
|
||||
log.Printf("[GeminiOAuth] Token expires_in: %d seconds", tokenResp.ExpiresIn)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Token exchange successful")
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Token scope: %s", tokenResp.Scope)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Token expires_in: %d seconds", tokenResp.ExpiresIn)
|
||||
|
||||
sessionProjectID := strings.TrimSpace(session.ProjectID)
|
||||
s.sessionStore.Delete(input.SessionID)
|
||||
@@ -523,40 +523,40 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
fallbackTierID = canonicalGeminiTierIDForOAuthType(oauthType, session.TierID)
|
||||
}
|
||||
|
||||
log.Printf("[GeminiOAuth] ========== Account Type Detection START ==========")
|
||||
log.Printf("[GeminiOAuth] OAuth Type: %s", oauthType)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== Account Type Detection START ==========")
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] OAuth Type: %s", oauthType)
|
||||
|
||||
// 对于 code_assist 模式,project_id 是必需的,需要调用 Code Assist API
|
||||
// 对于 google_one 模式,使用个人 Google 账号,不需要 project_id,配额由 Google 网关自动识别
|
||||
// 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API)
|
||||
switch oauthType {
|
||||
case "code_assist":
|
||||
log.Printf("[GeminiOAuth] Processing code_assist OAuth type")
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Processing code_assist OAuth type")
|
||||
if projectID == "" {
|
||||
log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...")
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...")
|
||||
var err error
|
||||
projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
// 记录警告但不阻断流程,允许后续补充 project_id
|
||||
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
|
||||
log.Printf("[GeminiOAuth] WARNING: Failed to fetch project_id: %v", err)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] WARNING: Failed to fetch project_id: %v", err)
|
||||
} else {
|
||||
log.Printf("[GeminiOAuth] Successfully fetched project_id: %s, tier_id: %s", projectID, tierID)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Successfully fetched project_id: %s, tier_id: %s", projectID, tierID)
|
||||
}
|
||||
} else {
|
||||
log.Printf("[GeminiOAuth] User provided project_id: %s, fetching tier_id...", projectID)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] User provided project_id: %s, fetching tier_id...", projectID)
|
||||
// 用户手动填了 project_id,仍需调用 LoadCodeAssist 获取 tierID
|
||||
_, fetchedTierID, err := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", err)
|
||||
log.Printf("[GeminiOAuth] WARNING: Failed to fetch tier_id: %v", err)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] WARNING: Failed to fetch tier_id: %v", err)
|
||||
} else {
|
||||
tierID = fetchedTierID
|
||||
log.Printf("[GeminiOAuth] Successfully fetched tier_id: %s", tierID)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Successfully fetched tier_id: %s", tierID)
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(projectID) == "" {
|
||||
log.Printf("[GeminiOAuth] ERROR: Missing project_id for Code Assist OAuth")
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Missing project_id for Code Assist OAuth")
|
||||
return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project")
|
||||
}
|
||||
// Prefer auto-detected tier; fall back to user-selected tier.
|
||||
@@ -564,31 +564,31 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
if tierID == "" {
|
||||
if fallbackTierID != "" {
|
||||
tierID = fallbackTierID
|
||||
log.Printf("[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID)
|
||||
} else {
|
||||
tierID = GeminiTierGCPStandard
|
||||
log.Printf("[GeminiOAuth] Using default tier_id: %s", tierID)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Using default tier_id: %s", tierID)
|
||||
}
|
||||
}
|
||||
log.Printf("[GeminiOAuth] Final code_assist result - project_id: %s, tier_id: %s", projectID, tierID)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Final code_assist result - project_id: %s, tier_id: %s", projectID, tierID)
|
||||
|
||||
case "google_one":
|
||||
log.Printf("[GeminiOAuth] Processing google_one OAuth type")
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Processing google_one OAuth type")
|
||||
|
||||
// Google One accounts use cloudaicompanion API, which requires a project_id.
|
||||
// For personal accounts, Google auto-assigns a project_id via the LoadCodeAssist API.
|
||||
if projectID == "" {
|
||||
log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...")
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...")
|
||||
var err error
|
||||
projectID, _, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
log.Printf("[GeminiOAuth] ERROR: Failed to fetch project_id: %v", err)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Failed to fetch project_id: %v", err)
|
||||
return nil, fmt.Errorf("google One accounts require a project_id, failed to auto-detect: %w", err)
|
||||
}
|
||||
log.Printf("[GeminiOAuth] Successfully fetched project_id: %s", projectID)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Successfully fetched project_id: %s", projectID)
|
||||
}
|
||||
|
||||
log.Printf("[GeminiOAuth] Attempting to fetch Google One tier from Drive API...")
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Attempting to fetch Google One tier from Drive API...")
|
||||
// Attempt to fetch Drive storage tier
|
||||
var storageInfo *geminicli.DriveStorageInfo
|
||||
var err error
|
||||
@@ -596,12 +596,12 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
if err != nil {
|
||||
// Log warning but don't block - use fallback
|
||||
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch Drive tier: %v\n", err)
|
||||
log.Printf("[GeminiOAuth] WARNING: Failed to fetch Drive tier: %v", err)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] WARNING: Failed to fetch Drive tier: %v", err)
|
||||
tierID = ""
|
||||
} else {
|
||||
log.Printf("[GeminiOAuth] Successfully fetched Drive tier: %s", tierID)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Successfully fetched Drive tier: %s", tierID)
|
||||
if storageInfo != nil {
|
||||
log.Printf("[GeminiOAuth] Drive storage - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)",
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Drive storage - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)",
|
||||
storageInfo.Limit, float64(storageInfo.Limit)/float64(TB),
|
||||
storageInfo.Usage, float64(storageInfo.Usage)/float64(GB))
|
||||
}
|
||||
@@ -610,10 +610,10 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
if tierID == "" || tierID == GeminiTierGoogleOneUnknown {
|
||||
if fallbackTierID != "" {
|
||||
tierID = fallbackTierID
|
||||
log.Printf("[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID)
|
||||
} else {
|
||||
tierID = GeminiTierGoogleOneFree
|
||||
log.Printf("[GeminiOAuth] Using default tier_id: %s", tierID)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Using default tier_id: %s", tierID)
|
||||
}
|
||||
}
|
||||
fmt.Printf("[GeminiOAuth] Google One tierID after normalization: %s\n", tierID)
|
||||
@@ -636,7 +636,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
log.Printf("[GeminiOAuth] ========== ExchangeCode END (google_one with storage info) ==========")
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== ExchangeCode END (google_one with storage info) ==========")
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
@@ -649,10 +649,10 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
}
|
||||
|
||||
default:
|
||||
log.Printf("[GeminiOAuth] Processing %s OAuth type (no tier detection)", oauthType)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Processing %s OAuth type (no tier detection)", oauthType)
|
||||
}
|
||||
|
||||
log.Printf("[GeminiOAuth] ========== Account Type Detection END ==========")
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== Account Type Detection END ==========")
|
||||
|
||||
result := &GeminiTokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
@@ -665,8 +665,8 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
TierID: tierID,
|
||||
OAuthType: oauthType,
|
||||
}
|
||||
log.Printf("[GeminiOAuth] Final result - OAuth Type: %s, Project ID: %s, Tier ID: %s", result.OAuthType, result.ProjectID, result.TierID)
|
||||
log.Printf("[GeminiOAuth] ========== ExchangeCode END ==========")
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Final result - OAuth Type: %s, Project ID: %s, Tier ID: %s", result.OAuthType, result.ProjectID, result.TierID)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== ExchangeCode END ==========")
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -949,23 +949,23 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
|
||||
registeredTierID := strings.TrimSpace(loadResp.GetTier())
|
||||
if registeredTierID != "" {
|
||||
// 已注册但未返回 cloudaicompanionProject,这在 Google One 用户中较常见:需要用户自行提供 project_id。
|
||||
log.Printf("[GeminiOAuth] User has tier (%s) but no cloudaicompanionProject, trying Cloud Resource Manager...", registeredTierID)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] User has tier (%s) but no cloudaicompanionProject, trying Cloud Resource Manager...", registeredTierID)
|
||||
|
||||
// Try to get project from Cloud Resource Manager
|
||||
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
|
||||
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
||||
log.Printf("[GeminiOAuth] Found project from Cloud Resource Manager: %s", fallback)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Found project from Cloud Resource Manager: %s", fallback)
|
||||
return strings.TrimSpace(fallback), tierID, nil
|
||||
}
|
||||
|
||||
// No project found - user must provide project_id manually
|
||||
log.Printf("[GeminiOAuth] No project found from Cloud Resource Manager, user must provide project_id manually")
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] No project found from Cloud Resource Manager, user must provide project_id manually")
|
||||
return "", tierID, fmt.Errorf("user is registered (tier: %s) but no project_id available. Please provide Project ID manually in the authorization form, or create a project at https://console.cloud.google.com", registeredTierID)
|
||||
}
|
||||
}
|
||||
|
||||
// 未检测到 currentTier/paidTier,视为新用户,继续调用 onboardUser
|
||||
log.Printf("[GeminiOAuth] No currentTier/paidTier found, proceeding with onboardUser (tierID: %s)", tierID)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] No currentTier/paidTier found, proceeding with onboardUser (tierID: %s)", tierID)
|
||||
|
||||
req := &geminicli.OnboardUserRequest{
|
||||
TierID: tierID,
|
||||
|
||||
@@ -7,13 +7,14 @@ import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// 预编译正则表达式(避免每次调用重新编译)
|
||||
@@ -84,7 +85,7 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
|
||||
cached.UserAgent = clientUA
|
||||
// 保存更新后的指纹
|
||||
_ = s.cache.SetFingerprint(ctx, accountID, cached)
|
||||
log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
|
||||
logger.LegacyPrintf("service.identity", "Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
|
||||
}
|
||||
return cached, nil
|
||||
}
|
||||
@@ -97,10 +98,10 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
|
||||
|
||||
// 保存到缓存(永不过期)
|
||||
if err := s.cache.SetFingerprint(ctx, accountID, fp); err != nil {
|
||||
log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err)
|
||||
logger.LegacyPrintf("service.identity", "Warning: failed to cache fingerprint for account %d: %v", accountID, err)
|
||||
}
|
||||
|
||||
log.Printf("Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID)
|
||||
logger.LegacyPrintf("service.identity", "Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID)
|
||||
return fp, nil
|
||||
}
|
||||
|
||||
@@ -277,19 +278,19 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
|
||||
// 获取或生成固定的伪装 session ID
|
||||
maskedSessionID, err := s.cache.GetMaskedSessionID(ctx, account.ID)
|
||||
if err != nil {
|
||||
log.Printf("Warning: failed to get masked session ID for account %d: %v", account.ID, err)
|
||||
logger.LegacyPrintf("service.identity", "Warning: failed to get masked session ID for account %d: %v", account.ID, err)
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
if maskedSessionID == "" {
|
||||
// 首次或已过期,生成新的伪装 session ID
|
||||
maskedSessionID = generateRandomUUID()
|
||||
log.Printf("Generated new masked session ID for account %d: %s", account.ID, maskedSessionID)
|
||||
logger.LegacyPrintf("service.identity", "Generated new masked session ID for account %d: %s", account.ID, maskedSessionID)
|
||||
}
|
||||
|
||||
// 刷新 TTL(每次请求都刷新,保持 15 分钟有效期)
|
||||
if err := s.cache.SetMaskedSessionID(ctx, account.ID, maskedSessionID); err != nil {
|
||||
log.Printf("Warning: failed to set masked session ID for account %d: %v", account.ID, err)
|
||||
logger.LegacyPrintf("service.identity", "Warning: failed to set masked session ID for account %d: %v", account.ID, err)
|
||||
}
|
||||
|
||||
// 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容
|
||||
@@ -335,7 +336,7 @@ func generateClientID() string {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// 极罕见的情况,使用时间戳+固定值作为fallback
|
||||
log.Printf("Warning: crypto/rand.Read failed: %v, using fallback", err)
|
||||
logger.LegacyPrintf("service.identity", "Warning: crypto/rand.Read failed: %v, using fallback", err)
|
||||
// 使用SHA256(当前纳秒时间)作为fallback
|
||||
h := sha256.Sum256([]byte(fmt.Sprintf("%d", time.Now().UnixNano())))
|
||||
return hex.EncodeToString(h[:])
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
@@ -19,12 +18,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -786,7 +787,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
// 对所有请求执行模型映射(包含 Codex CLI)。
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
log.Printf("[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI)
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI)
|
||||
reqBody["model"] = mappedModel
|
||||
bodyModified = true
|
||||
}
|
||||
@@ -795,7 +796,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
if model, ok := reqBody["model"].(string); ok {
|
||||
normalizedModel := normalizeCodexModel(model)
|
||||
if normalizedModel != "" && normalizedModel != model {
|
||||
log.Printf("[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
|
||||
model, normalizedModel, account.Name, account.Type, isCodexCLI)
|
||||
reqBody["model"] = normalizedModel
|
||||
mappedModel = normalizedModel
|
||||
@@ -808,7 +809,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" {
|
||||
reasoning["effort"] = "none"
|
||||
bodyModified = true
|
||||
log.Printf("[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name)
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1012,7 +1013,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
reqStream bool,
|
||||
startTime time.Time,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
log.Printf(
|
||||
logger.LegacyPrintf("service.openai_gateway",
|
||||
"[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
|
||||
account.ID,
|
||||
account.Name,
|
||||
@@ -1022,18 +1023,15 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
)
|
||||
if reqStream && c != nil && c.Request != nil {
|
||||
if timeoutHeaders := collectOpenAIPassthroughTimeoutHeaders(c.Request.Header); len(timeoutHeaders) > 0 {
|
||||
streamWarnLogger := logger.FromContext(ctx).With(
|
||||
zap.String("component", "service.openai_gateway"),
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Strings("timeout_headers", timeoutHeaders),
|
||||
)
|
||||
if s.isOpenAIPassthroughTimeoutHeadersAllowed() {
|
||||
log.Printf(
|
||||
"[WARN] [OpenAI passthrough] 透传请求包含超时相关请求头,且当前配置为放行,可能导致上游提前断流: account=%d headers=%s",
|
||||
account.ID,
|
||||
strings.Join(timeoutHeaders, ", "),
|
||||
)
|
||||
streamWarnLogger.Warn("OpenAI passthrough 透传请求包含超时相关请求头,且当前配置为放行,可能导致上游提前断流")
|
||||
} else {
|
||||
log.Printf(
|
||||
"[WARN] [OpenAI passthrough] 检测到超时相关请求头,将按配置过滤以降低断流风险: account=%d headers=%s",
|
||||
account.ID,
|
||||
strings.Join(timeoutHeaders, ", "),
|
||||
)
|
||||
streamWarnLogger.Warn("OpenAI passthrough 检测到超时相关请求头,将按配置过滤以降低断流风险")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1347,7 +1345,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
if !clientDisconnected {
|
||||
if _, err := fmt.Fprintln(w, line); err != nil {
|
||||
clientDisconnected = true
|
||||
log.Printf("[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
|
||||
} else {
|
||||
flusher.Flush()
|
||||
}
|
||||
@@ -1355,11 +1353,11 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
if clientDisconnected {
|
||||
log.Printf("[OpenAI passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err)
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err)
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
log.Printf(
|
||||
logger.LegacyPrintf("service.openai_gateway",
|
||||
"[WARN] [OpenAI passthrough] 流读取被取消,可能发生断流: account=%d request_id=%s err=%v ctx_err=%v",
|
||||
account.ID,
|
||||
upstreamRequestID,
|
||||
@@ -1369,10 +1367,10 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if errors.Is(err, bufio.ErrTooLong) {
|
||||
log.Printf("[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
log.Printf(
|
||||
logger.LegacyPrintf("service.openai_gateway",
|
||||
"[WARN] [OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v",
|
||||
account.ID,
|
||||
upstreamRequestID,
|
||||
@@ -1381,11 +1379,11 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
||||
}
|
||||
if !clientDisconnected && !sawDone && ctx.Err() == nil {
|
||||
log.Printf(
|
||||
"[WARN] [OpenAI passthrough] 上游流在未收到 [DONE] 时结束,疑似断流: account=%d request_id=%s",
|
||||
account.ID,
|
||||
upstreamRequestID,
|
||||
)
|
||||
logger.FromContext(ctx).With(
|
||||
zap.String("component", "service.openai_gateway"),
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("upstream_request_id", upstreamRequestID),
|
||||
).Warn("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
|
||||
}
|
||||
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
@@ -1584,7 +1582,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
|
||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
log.Printf(
|
||||
logger.LegacyPrintf("service.openai_gateway",
|
||||
"OpenAI upstream error %d (account=%d platform=%s type=%s): %s",
|
||||
resp.StatusCode,
|
||||
account.ID,
|
||||
@@ -1844,16 +1842,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
|
||||
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
|
||||
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
|
||||
log.Printf("Context canceled during streaming, returning collected usage")
|
||||
logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage")
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
|
||||
if clientDisconnected {
|
||||
log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
|
||||
logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", ev.err)
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||
logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
|
||||
}
|
||||
@@ -1882,7 +1880,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
if !clientDisconnected {
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
clientDisconnected = true
|
||||
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
|
||||
} else {
|
||||
flusher.Flush()
|
||||
}
|
||||
@@ -1899,7 +1897,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
if !clientDisconnected {
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
clientDisconnected = true
|
||||
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
|
||||
} else {
|
||||
flusher.Flush()
|
||||
}
|
||||
@@ -1912,10 +1910,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
log.Printf("Upstream timeout after client disconnect, returning collected usage")
|
||||
logger.LegacyPrintf("service.openai_gateway", "Upstream timeout after client disconnect, returning collected usage")
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||
logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||
// 处理流超时,可能标记账户为临时不可调度或错误状态
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
|
||||
@@ -1932,7 +1930,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
}
|
||||
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
|
||||
clientDisconnected = true
|
||||
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
|
||||
continue
|
||||
}
|
||||
flusher.Flush()
|
||||
@@ -2323,7 +2321,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
|
||||
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
return nil
|
||||
}
|
||||
@@ -2346,7 +2344,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
// Update API key quota if applicable (only for balance mode with quota set)
|
||||
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
|
||||
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
log.Printf("Update API key quota failed: %v", err)
|
||||
logger.LegacyPrintf("service.openai_gateway", "Update API key quota failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,17 +3,17 @@ package service
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -46,24 +46,76 @@ func (u *httpUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, acc
|
||||
return u.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
var stdLogCaptureMu sync.Mutex
|
||||
var structuredLogCaptureMu sync.Mutex
|
||||
|
||||
func captureStdLog(t *testing.T) (*bytes.Buffer, func()) {
|
||||
t.Helper()
|
||||
stdLogCaptureMu.Lock()
|
||||
buf := &bytes.Buffer{}
|
||||
prevWriter := log.Writer()
|
||||
prevFlags := log.Flags()
|
||||
log.SetFlags(0)
|
||||
log.SetOutput(buf)
|
||||
return buf, func() {
|
||||
log.SetOutput(prevWriter)
|
||||
log.SetFlags(prevFlags)
|
||||
// 防御性恢复,避免其他测试改动了底层 writer。
|
||||
if prevWriter == nil {
|
||||
log.SetOutput(os.Stderr)
|
||||
type inMemoryLogSink struct {
|
||||
mu sync.Mutex
|
||||
events []*logger.LogEvent
|
||||
}
|
||||
|
||||
func (s *inMemoryLogSink) WriteLogEvent(event *logger.LogEvent) {
|
||||
if event == nil {
|
||||
return
|
||||
}
|
||||
cloned := *event
|
||||
if event.Fields != nil {
|
||||
cloned.Fields = make(map[string]any, len(event.Fields))
|
||||
for k, v := range event.Fields {
|
||||
cloned.Fields[k] = v
|
||||
}
|
||||
stdLogCaptureMu.Unlock()
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.events = append(s.events, &cloned)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *inMemoryLogSink) ContainsMessage(substr string) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, ev := range s.events {
|
||||
if ev != nil && strings.Contains(ev.Message, substr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *inMemoryLogSink) ContainsFieldValue(field, substr string) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, ev := range s.events {
|
||||
if ev == nil || ev.Fields == nil {
|
||||
continue
|
||||
}
|
||||
if v, ok := ev.Fields[field]; ok && strings.Contains(fmt.Sprint(v), substr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func captureStructuredLog(t *testing.T) (*inMemoryLogSink, func()) {
|
||||
t.Helper()
|
||||
structuredLogCaptureMu.Lock()
|
||||
|
||||
err := logger.Init(logger.InitOptions{
|
||||
Level: "debug",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: logger.OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: false,
|
||||
},
|
||||
Sampling: logger.SamplingOptions{Enabled: false},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
sink := &inMemoryLogSink{}
|
||||
logger.SetSink(sink)
|
||||
return sink, func() {
|
||||
logger.SetSink(nil)
|
||||
structuredLogCaptureMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -486,7 +538,7 @@ func TestOpenAIGatewayService_APIKeyPassthrough_PreservesBodyAndUsesResponsesEnd
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_WarnOnTimeoutHeadersForStream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logBuf, restore := captureStdLog(t)
|
||||
logSink, restore := captureStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -521,13 +573,13 @@ func TestOpenAIGatewayService_OAuthPassthrough_WarnOnTimeoutHeadersForStream(t *
|
||||
|
||||
_, err := svc.Forward(context.Background(), c, account, originalBody)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, logBuf.String(), "检测到超时相关请求头,将按配置过滤以降低断流风险")
|
||||
require.Contains(t, logBuf.String(), "x-stainless-timeout=10000")
|
||||
require.True(t, logSink.ContainsMessage("检测到超时相关请求头,将按配置过滤以降低断流风险"))
|
||||
require.True(t, logSink.ContainsFieldValue("timeout_headers", "x-stainless-timeout=10000"))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_WarnWhenStreamEndsWithoutDone(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logBuf, restore := captureStdLog(t)
|
||||
logSink, restore := captureStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -562,8 +614,8 @@ func TestOpenAIGatewayService_OAuthPassthrough_WarnWhenStreamEndsWithoutDone(t *
|
||||
|
||||
_, err := svc.Forward(context.Background(), c, account, originalBody)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, logBuf.String(), "上游流在未收到 [DONE] 时结束,疑似断流")
|
||||
require.Contains(t, logBuf.String(), "rid-truncate")
|
||||
require.True(t, logSink.ContainsMessage("上游流在未收到 [DONE] 时结束,疑似断流"))
|
||||
require.True(t, logSink.ContainsFieldValue("upstream_request_id", "rid-truncate"))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_DefaultFiltersTimeoutHeaders(t *testing.T) {
|
||||
|
||||
@@ -3,8 +3,9 @@ package service
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// codexToolNameMapping 定义 Codex 原生工具名称到 OpenCode 工具名称的映射
|
||||
@@ -140,7 +141,7 @@ func (c *CodexToolCorrector) CorrectToolCallsInSSEData(data string) (string, boo
|
||||
// 序列化回 JSON
|
||||
correctedBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
log.Printf("[CodexToolCorrector] Failed to marshal corrected data: %v", err)
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Failed to marshal corrected data: %v", err)
|
||||
return data, false
|
||||
}
|
||||
|
||||
@@ -219,13 +220,13 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall
|
||||
argsMap["workdir"] = workDir
|
||||
delete(argsMap, "work_dir")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'work_dir' to 'workdir' in bash tool")
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'work_dir' to 'workdir' in bash tool")
|
||||
}
|
||||
} else {
|
||||
if _, exists := argsMap["work_dir"]; exists {
|
||||
delete(argsMap, "work_dir")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool")
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -236,17 +237,17 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall
|
||||
argsMap["filePath"] = filePath
|
||||
delete(argsMap, "file_path")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool")
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool")
|
||||
} else if filePath, exists := argsMap["path"]; exists {
|
||||
argsMap["filePath"] = filePath
|
||||
delete(argsMap, "path")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'path' to 'filePath' in edit tool")
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'path' to 'filePath' in edit tool")
|
||||
} else if filePath, exists := argsMap["file"]; exists {
|
||||
argsMap["filePath"] = filePath
|
||||
delete(argsMap, "file")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool")
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -255,7 +256,7 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall
|
||||
argsMap["oldString"] = oldString
|
||||
delete(argsMap, "old_string")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool")
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -264,7 +265,7 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall
|
||||
argsMap["newString"] = newString
|
||||
delete(argsMap, "new_string")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool")
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -273,7 +274,7 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall
|
||||
argsMap["replaceAll"] = replaceAll
|
||||
delete(argsMap, "replace_all")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool")
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -303,7 +304,7 @@ func (c *CodexToolCorrector) recordCorrection(from, to string) {
|
||||
key := fmt.Sprintf("%s->%s", from, to)
|
||||
c.stats.CorrectionsByTool[key]++
|
||||
|
||||
log.Printf("[CodexToolCorrector] Corrected tool call: %s -> %s (total: %d)",
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Corrected tool call: %s -> %s (total: %d)",
|
||||
from, to, c.stats.TotalCorrected)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,12 +5,12 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
@@ -190,7 +190,7 @@ func (s *OpsAggregationService) aggregateHourly() {
|
||||
latest, ok, err := s.opsRepo.GetLatestHourlyBucketStart(ctxMax)
|
||||
cancelMax()
|
||||
if err != nil {
|
||||
log.Printf("[OpsAggregation][hourly] failed to read latest bucket: %v", err)
|
||||
logger.LegacyPrintf("service.ops_aggregation", "[OpsAggregation][hourly] failed to read latest bucket: %v", err)
|
||||
} else if ok {
|
||||
candidate := latest.Add(-opsAggHourlyOverlap)
|
||||
if candidate.After(start) {
|
||||
@@ -209,7 +209,7 @@ func (s *OpsAggregationService) aggregateHourly() {
|
||||
chunkEnd := minTime(cursor.Add(opsAggHourlyChunk), end)
|
||||
if err := s.opsRepo.UpsertHourlyMetrics(ctx, cursor, chunkEnd); err != nil {
|
||||
aggErr = err
|
||||
log.Printf("[OpsAggregation][hourly] upsert failed (%s..%s): %v", cursor.Format(time.RFC3339), chunkEnd.Format(time.RFC3339), err)
|
||||
logger.LegacyPrintf("service.ops_aggregation", "[OpsAggregation][hourly] upsert failed (%s..%s): %v", cursor.Format(time.RFC3339), chunkEnd.Format(time.RFC3339), err)
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -288,7 +288,7 @@ func (s *OpsAggregationService) aggregateDaily() {
|
||||
latest, ok, err := s.opsRepo.GetLatestDailyBucketDate(ctxMax)
|
||||
cancelMax()
|
||||
if err != nil {
|
||||
log.Printf("[OpsAggregation][daily] failed to read latest bucket: %v", err)
|
||||
logger.LegacyPrintf("service.ops_aggregation", "[OpsAggregation][daily] failed to read latest bucket: %v", err)
|
||||
} else if ok {
|
||||
candidate := latest.Add(-opsAggDailyOverlap)
|
||||
if candidate.After(start) {
|
||||
@@ -307,7 +307,7 @@ func (s *OpsAggregationService) aggregateDaily() {
|
||||
chunkEnd := minTime(cursor.Add(opsAggDailyChunk), end)
|
||||
if err := s.opsRepo.UpsertDailyMetrics(ctx, cursor, chunkEnd); err != nil {
|
||||
aggErr = err
|
||||
log.Printf("[OpsAggregation][daily] upsert failed (%s..%s): %v", cursor.Format("2006-01-02"), chunkEnd.Format("2006-01-02"), err)
|
||||
logger.LegacyPrintf("service.ops_aggregation", "[OpsAggregation][daily] upsert failed (%s..%s): %v", cursor.Format("2006-01-02"), chunkEnd.Format("2006-01-02"), err)
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -427,7 +427,7 @@ func (s *OpsAggregationService) maybeLogSkip(prefix string) {
|
||||
if prefix == "" {
|
||||
prefix = "[OpsAggregation]"
|
||||
}
|
||||
log.Printf("%s leader lock held by another instance; skipping", prefix)
|
||||
logger.LegacyPrintf("service.ops_aggregation", "%s leader lock held by another instance; skipping", prefix)
|
||||
}
|
||||
|
||||
func utcFloorToHour(t time.Time) time.Time {
|
||||
|
||||
@@ -3,7 +3,6 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -11,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
@@ -186,7 +186,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
|
||||
rules, err := s.opsRepo.ListAlertRules(ctx)
|
||||
if err != nil {
|
||||
s.recordHeartbeatError(runAt, time.Since(startedAt), err)
|
||||
log.Printf("[OpsAlertEvaluator] list rules failed: %v", err)
|
||||
logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] list rules failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -236,7 +236,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
|
||||
|
||||
activeEvent, err := s.opsRepo.GetActiveAlertEvent(ctx, rule.ID)
|
||||
if err != nil {
|
||||
log.Printf("[OpsAlertEvaluator] get active event failed (rule=%d): %v", rule.ID, err)
|
||||
logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] get active event failed (rule=%d): %v", rule.ID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -258,7 +258,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
|
||||
|
||||
latestEvent, err := s.opsRepo.GetLatestAlertEvent(ctx, rule.ID)
|
||||
if err != nil {
|
||||
log.Printf("[OpsAlertEvaluator] get latest event failed (rule=%d): %v", rule.ID, err)
|
||||
logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] get latest event failed (rule=%d): %v", rule.ID, err)
|
||||
continue
|
||||
}
|
||||
if latestEvent != nil && rule.CooldownMinutes > 0 {
|
||||
@@ -283,7 +283,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
|
||||
|
||||
created, err := s.opsRepo.CreateAlertEvent(ctx, firedEvent)
|
||||
if err != nil {
|
||||
log.Printf("[OpsAlertEvaluator] create event failed (rule=%d): %v", rule.ID, err)
|
||||
logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] create event failed (rule=%d): %v", rule.ID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -300,7 +300,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
|
||||
if activeEvent != nil {
|
||||
resolvedAt := now
|
||||
if err := s.opsRepo.UpdateAlertEventStatus(ctx, activeEvent.ID, OpsAlertStatusResolved, &resolvedAt); err != nil {
|
||||
log.Printf("[OpsAlertEvaluator] resolve event failed (event=%d): %v", activeEvent.ID, err)
|
||||
logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] resolve event failed (event=%d): %v", activeEvent.ID, err)
|
||||
} else {
|
||||
eventsResolved++
|
||||
}
|
||||
@@ -779,7 +779,7 @@ func (s *OpsAlertEvaluatorService) tryAcquireLeaderLock(ctx context.Context, loc
|
||||
}
|
||||
if s.redisClient == nil {
|
||||
s.warnNoRedisOnce.Do(func() {
|
||||
log.Printf("[OpsAlertEvaluator] redis not configured; running without distributed lock")
|
||||
logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] redis not configured; running without distributed lock")
|
||||
})
|
||||
return nil, true
|
||||
}
|
||||
@@ -797,7 +797,7 @@ func (s *OpsAlertEvaluatorService) tryAcquireLeaderLock(ctx context.Context, loc
|
||||
// Prefer fail-closed to avoid duplicate evaluators stampeding the DB when Redis is flaky.
|
||||
// Single-node deployments can disable the distributed lock via runtime settings.
|
||||
s.warnNoRedisOnce.Do(func() {
|
||||
log.Printf("[OpsAlertEvaluator] leader lock SetNX failed; skipping this cycle: %v", err)
|
||||
logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] leader lock SetNX failed; skipping this cycle: %v", err)
|
||||
})
|
||||
return nil, false
|
||||
}
|
||||
@@ -819,7 +819,7 @@ func (s *OpsAlertEvaluatorService) maybeLogSkip(key string) {
|
||||
return
|
||||
}
|
||||
s.skipLogAt = now
|
||||
log.Printf("[OpsAlertEvaluator] leader lock held by another instance; skipping (key=%q)", key)
|
||||
logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] leader lock held by another instance; skipping (key=%q)", key)
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration, result string) {
|
||||
|
||||
@@ -4,12 +4,12 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/robfig/cron/v3"
|
||||
@@ -75,11 +75,11 @@ func (s *OpsCleanupService) Start() {
|
||||
return
|
||||
}
|
||||
if s.cfg != nil && !s.cfg.Ops.Cleanup.Enabled {
|
||||
log.Printf("[OpsCleanup] not started (disabled)")
|
||||
logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] not started (disabled)")
|
||||
return
|
||||
}
|
||||
if s.opsRepo == nil || s.db == nil {
|
||||
log.Printf("[OpsCleanup] not started (missing deps)")
|
||||
logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] not started (missing deps)")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -99,12 +99,12 @@ func (s *OpsCleanupService) Start() {
|
||||
c := cron.New(cron.WithParser(opsCleanupCronParser), cron.WithLocation(loc))
|
||||
_, err := c.AddFunc(schedule, func() { s.runScheduled() })
|
||||
if err != nil {
|
||||
log.Printf("[OpsCleanup] not started (invalid schedule=%q): %v", schedule, err)
|
||||
logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] not started (invalid schedule=%q): %v", schedule, err)
|
||||
return
|
||||
}
|
||||
s.cron = c
|
||||
s.cron.Start()
|
||||
log.Printf("[OpsCleanup] started (schedule=%q tz=%s)", schedule, loc.String())
|
||||
logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] started (schedule=%q tz=%s)", schedule, loc.String())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -118,7 +118,7 @@ func (s *OpsCleanupService) Stop() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(3 * time.Second):
|
||||
log.Printf("[OpsCleanup] cron stop timed out")
|
||||
logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] cron stop timed out")
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -146,17 +146,19 @@ func (s *OpsCleanupService) runScheduled() {
|
||||
counts, err := s.runCleanupOnce(ctx)
|
||||
if err != nil {
|
||||
s.recordHeartbeatError(runAt, time.Since(startedAt), err)
|
||||
log.Printf("[OpsCleanup] cleanup failed: %v", err)
|
||||
logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] cleanup failed: %v", err)
|
||||
return
|
||||
}
|
||||
s.recordHeartbeatSuccess(runAt, time.Since(startedAt), counts)
|
||||
log.Printf("[OpsCleanup] cleanup complete: %s", counts)
|
||||
logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] cleanup complete: %s", counts)
|
||||
}
|
||||
|
||||
type opsCleanupDeletedCounts struct {
|
||||
errorLogs int64
|
||||
retryAttempts int64
|
||||
alertEvents int64
|
||||
systemLogs int64
|
||||
logAudits int64
|
||||
systemMetrics int64
|
||||
hourlyPreagg int64
|
||||
dailyPreagg int64
|
||||
@@ -164,10 +166,12 @@ type opsCleanupDeletedCounts struct {
|
||||
|
||||
func (c opsCleanupDeletedCounts) String() string {
|
||||
return fmt.Sprintf(
|
||||
"error_logs=%d retry_attempts=%d alert_events=%d system_metrics=%d hourly_preagg=%d daily_preagg=%d",
|
||||
"error_logs=%d retry_attempts=%d alert_events=%d system_logs=%d log_audits=%d system_metrics=%d hourly_preagg=%d daily_preagg=%d",
|
||||
c.errorLogs,
|
||||
c.retryAttempts,
|
||||
c.alertEvents,
|
||||
c.systemLogs,
|
||||
c.logAudits,
|
||||
c.systemMetrics,
|
||||
c.hourlyPreagg,
|
||||
c.dailyPreagg,
|
||||
@@ -204,6 +208,18 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet
|
||||
return out, err
|
||||
}
|
||||
out.alertEvents = n
|
||||
|
||||
n, err = deleteOldRowsByID(ctx, s.db, "ops_system_logs", "created_at", cutoff, batchSize, false)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
out.systemLogs = n
|
||||
|
||||
n, err = deleteOldRowsByID(ctx, s.db, "ops_system_log_cleanup_audits", "created_at", cutoff, batchSize, false)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
out.logAudits = n
|
||||
}
|
||||
|
||||
// Minute-level metrics snapshots.
|
||||
@@ -315,11 +331,11 @@ func (s *OpsCleanupService) tryAcquireLeaderLock(ctx context.Context) (func(), b
|
||||
}
|
||||
// Redis error: fall back to DB advisory lock.
|
||||
s.warnNoRedisOnce.Do(func() {
|
||||
log.Printf("[OpsCleanup] leader lock SetNX failed; falling back to DB advisory lock: %v", err)
|
||||
logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] leader lock SetNX failed; falling back to DB advisory lock: %v", err)
|
||||
})
|
||||
} else {
|
||||
s.warnNoRedisOnce.Do(func() {
|
||||
log.Printf("[OpsCleanup] redis not configured; using DB advisory lock")
|
||||
logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] redis not configured; using DB advisory lock")
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
267
backend/internal/service/ops_log_runtime.go
Normal file
267
backend/internal/service/ops_log_runtime.go
Normal file
@@ -0,0 +1,267 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func defaultOpsRuntimeLogConfig(cfg *config.Config) *OpsRuntimeLogConfig {
|
||||
out := &OpsRuntimeLogConfig{
|
||||
Level: "info",
|
||||
EnableSampling: false,
|
||||
SamplingInitial: 100,
|
||||
SamplingNext: 100,
|
||||
Caller: true,
|
||||
StacktraceLevel: "error",
|
||||
RetentionDays: 30,
|
||||
}
|
||||
if cfg == nil {
|
||||
return out
|
||||
}
|
||||
out.Level = strings.ToLower(strings.TrimSpace(cfg.Log.Level))
|
||||
out.EnableSampling = cfg.Log.Sampling.Enabled
|
||||
out.SamplingInitial = cfg.Log.Sampling.Initial
|
||||
out.SamplingNext = cfg.Log.Sampling.Thereafter
|
||||
out.Caller = cfg.Log.Caller
|
||||
out.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel))
|
||||
if cfg.Ops.Cleanup.ErrorLogRetentionDays > 0 {
|
||||
out.RetentionDays = cfg.Ops.Cleanup.ErrorLogRetentionDays
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeOpsRuntimeLogConfig(cfg *OpsRuntimeLogConfig, defaults *OpsRuntimeLogConfig) {
|
||||
if cfg == nil || defaults == nil {
|
||||
return
|
||||
}
|
||||
cfg.Level = strings.ToLower(strings.TrimSpace(cfg.Level))
|
||||
if cfg.Level == "" {
|
||||
cfg.Level = defaults.Level
|
||||
}
|
||||
cfg.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.StacktraceLevel))
|
||||
if cfg.StacktraceLevel == "" {
|
||||
cfg.StacktraceLevel = defaults.StacktraceLevel
|
||||
}
|
||||
if cfg.SamplingInitial <= 0 {
|
||||
cfg.SamplingInitial = defaults.SamplingInitial
|
||||
}
|
||||
if cfg.SamplingNext <= 0 {
|
||||
cfg.SamplingNext = defaults.SamplingNext
|
||||
}
|
||||
if cfg.RetentionDays <= 0 {
|
||||
cfg.RetentionDays = defaults.RetentionDays
|
||||
}
|
||||
}
|
||||
|
||||
func validateOpsRuntimeLogConfig(cfg *OpsRuntimeLogConfig) error {
|
||||
if cfg == nil {
|
||||
return errors.New("invalid config")
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(cfg.Level)) {
|
||||
case "debug", "info", "warn", "error":
|
||||
default:
|
||||
return errors.New("level must be one of: debug/info/warn/error")
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(cfg.StacktraceLevel)) {
|
||||
case "none", "error", "fatal":
|
||||
default:
|
||||
return errors.New("stacktrace_level must be one of: none/error/fatal")
|
||||
}
|
||||
if cfg.SamplingInitial <= 0 {
|
||||
return errors.New("sampling_initial must be positive")
|
||||
}
|
||||
if cfg.SamplingNext <= 0 {
|
||||
return errors.New("sampling_thereafter must be positive")
|
||||
}
|
||||
if cfg.RetentionDays < 1 || cfg.RetentionDays > 3650 {
|
||||
return errors.New("retention_days must be between 1 and 3650")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OpsService) GetRuntimeLogConfig(ctx context.Context) (*OpsRuntimeLogConfig, error) {
|
||||
if s == nil || s.settingRepo == nil {
|
||||
var cfg *config.Config
|
||||
if s != nil {
|
||||
cfg = s.cfg
|
||||
}
|
||||
defaultCfg := defaultOpsRuntimeLogConfig(cfg)
|
||||
return defaultCfg, nil
|
||||
}
|
||||
defaultCfg := defaultOpsRuntimeLogConfig(s.cfg)
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
raw, err := s.settingRepo.GetValue(ctx, SettingKeyOpsRuntimeLogConfig)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
b, _ := json.Marshal(defaultCfg)
|
||||
_ = s.settingRepo.Set(ctx, SettingKeyOpsRuntimeLogConfig, string(b))
|
||||
return defaultCfg, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := &OpsRuntimeLogConfig{}
|
||||
if err := json.Unmarshal([]byte(raw), cfg); err != nil {
|
||||
return defaultCfg, nil
|
||||
}
|
||||
normalizeOpsRuntimeLogConfig(cfg, defaultCfg)
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) UpdateRuntimeLogConfig(ctx context.Context, req *OpsRuntimeLogConfig, operatorID int64) (*OpsRuntimeLogConfig, error) {
|
||||
if s == nil || s.settingRepo == nil {
|
||||
return nil, errors.New("setting repository not initialized")
|
||||
}
|
||||
if req == nil {
|
||||
return nil, errors.New("invalid config")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if operatorID <= 0 {
|
||||
return nil, errors.New("invalid operator id")
|
||||
}
|
||||
|
||||
oldCfg, err := s.GetRuntimeLogConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
next := *req
|
||||
normalizeOpsRuntimeLogConfig(&next, defaultOpsRuntimeLogConfig(s.cfg))
|
||||
if err := validateOpsRuntimeLogConfig(&next); err != nil {
|
||||
s.auditRuntimeLogConfigFailure(operatorID, oldCfg, &next, "validation_failed: "+err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := applyOpsRuntimeLogConfig(&next); err != nil {
|
||||
s.auditRuntimeLogConfigFailure(operatorID, oldCfg, &next, "apply_failed: "+err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
next.Source = "runtime_setting"
|
||||
next.UpdatedAt = time.Now().UTC().Format(time.RFC3339Nano)
|
||||
next.UpdatedByUserID = operatorID
|
||||
|
||||
encoded, err := json.Marshal(&next)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.settingRepo.Set(ctx, SettingKeyOpsRuntimeLogConfig, string(encoded)); err != nil {
|
||||
// 存储失败时回滚到旧配置,避免内存状态与持久化状态不一致。
|
||||
_ = applyOpsRuntimeLogConfig(oldCfg)
|
||||
s.auditRuntimeLogConfigFailure(operatorID, oldCfg, &next, "persist_failed: "+err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.auditRuntimeLogConfigChange(operatorID, oldCfg, &next, "updated")
|
||||
|
||||
return &next, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) ResetRuntimeLogConfig(ctx context.Context, operatorID int64) (*OpsRuntimeLogConfig, error) {
|
||||
if s == nil || s.settingRepo == nil {
|
||||
return nil, errors.New("setting repository not initialized")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if operatorID <= 0 {
|
||||
return nil, errors.New("invalid operator id")
|
||||
}
|
||||
|
||||
oldCfg, err := s.GetRuntimeLogConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resetCfg := defaultOpsRuntimeLogConfig(s.cfg)
|
||||
normalizeOpsRuntimeLogConfig(resetCfg, defaultOpsRuntimeLogConfig(s.cfg))
|
||||
if err := validateOpsRuntimeLogConfig(resetCfg); err != nil {
|
||||
s.auditRuntimeLogConfigFailure(operatorID, oldCfg, resetCfg, "reset_validation_failed: "+err.Error())
|
||||
return nil, err
|
||||
}
|
||||
if err := applyOpsRuntimeLogConfig(resetCfg); err != nil {
|
||||
s.auditRuntimeLogConfigFailure(operatorID, oldCfg, resetCfg, "reset_apply_failed: "+err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 清理 runtime 覆盖配置,回退到 env/yaml baseline。
|
||||
if err := s.settingRepo.Delete(ctx, SettingKeyOpsRuntimeLogConfig); err != nil && !errors.Is(err, ErrSettingNotFound) {
|
||||
_ = applyOpsRuntimeLogConfig(oldCfg)
|
||||
s.auditRuntimeLogConfigFailure(operatorID, oldCfg, resetCfg, "reset_persist_failed: "+err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now().UTC().Format(time.RFC3339Nano)
|
||||
resetCfg.Source = "baseline"
|
||||
resetCfg.UpdatedAt = now
|
||||
resetCfg.UpdatedByUserID = operatorID
|
||||
|
||||
s.auditRuntimeLogConfigChange(operatorID, oldCfg, resetCfg, "reset")
|
||||
return resetCfg, nil
|
||||
}
|
||||
|
||||
func applyOpsRuntimeLogConfig(cfg *OpsRuntimeLogConfig) error {
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("nil runtime log config")
|
||||
}
|
||||
if err := logger.Reconfigure(func(opts *logger.InitOptions) error {
|
||||
opts.Level = strings.ToLower(strings.TrimSpace(cfg.Level))
|
||||
opts.Caller = cfg.Caller
|
||||
opts.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.StacktraceLevel))
|
||||
opts.Sampling.Enabled = cfg.EnableSampling
|
||||
opts.Sampling.Initial = cfg.SamplingInitial
|
||||
opts.Sampling.Thereafter = cfg.SamplingNext
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OpsService) applyRuntimeLogConfigOnStartup(ctx context.Context) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
cfg, err := s.GetRuntimeLogConfig(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = applyOpsRuntimeLogConfig(cfg)
|
||||
}
|
||||
|
||||
func (s *OpsService) auditRuntimeLogConfigChange(operatorID int64, oldCfg *OpsRuntimeLogConfig, newCfg *OpsRuntimeLogConfig, action string) {
|
||||
oldRaw, _ := json.Marshal(oldCfg)
|
||||
newRaw, _ := json.Marshal(newCfg)
|
||||
logger.With(
|
||||
zap.String("component", "audit.log_config_change"),
|
||||
zap.String("action", strings.TrimSpace(action)),
|
||||
zap.Int64("operator_id", operatorID),
|
||||
zap.String("old", string(oldRaw)),
|
||||
zap.String("new", string(newRaw)),
|
||||
).Info("runtime log config changed")
|
||||
}
|
||||
|
||||
func (s *OpsService) auditRuntimeLogConfigFailure(operatorID int64, oldCfg *OpsRuntimeLogConfig, newCfg *OpsRuntimeLogConfig, reason string) {
|
||||
oldRaw, _ := json.Marshal(oldCfg)
|
||||
newRaw, _ := json.Marshal(newCfg)
|
||||
logger.With(
|
||||
zap.String("component", "audit.log_config_change"),
|
||||
zap.String("action", "failed"),
|
||||
zap.Int64("operator_id", operatorID),
|
||||
zap.String("reason", strings.TrimSpace(reason)),
|
||||
zap.String("old", string(oldRaw)),
|
||||
zap.String("new", string(newRaw)),
|
||||
).Warn("runtime log config change failed")
|
||||
}
|
||||
570
backend/internal/service/ops_log_runtime_test.go
Normal file
570
backend/internal/service/ops_log_runtime_test.go
Normal file
@@ -0,0 +1,570 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
type runtimeSettingRepoStub struct {
|
||||
values map[string]string
|
||||
deleted map[string]bool
|
||||
setCalls int
|
||||
getValueFn func(key string) (string, error)
|
||||
setFn func(key, value string) error
|
||||
deleteFn func(key string) error
|
||||
}
|
||||
|
||||
func newRuntimeSettingRepoStub() *runtimeSettingRepoStub {
|
||||
return &runtimeSettingRepoStub{
|
||||
values: map[string]string{},
|
||||
deleted: map[string]bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *runtimeSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||
value, err := s.GetValue(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Setting{Key: key, Value: value}, nil
|
||||
}
|
||||
|
||||
func (s *runtimeSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
|
||||
if s.getValueFn != nil {
|
||||
return s.getValueFn(key)
|
||||
}
|
||||
value, ok := s.values[key]
|
||||
if !ok {
|
||||
return "", ErrSettingNotFound
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (s *runtimeSettingRepoStub) Set(_ context.Context, key, value string) error {
|
||||
if s.setFn != nil {
|
||||
if err := s.setFn(key, value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
s.values[key] = value
|
||||
s.setCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *runtimeSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
|
||||
out := make(map[string]string, len(keys))
|
||||
for _, key := range keys {
|
||||
if value, ok := s.values[key]; ok {
|
||||
out[key] = value
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *runtimeSettingRepoStub) SetMultiple(_ context.Context, settings map[string]string) error {
|
||||
for key, value := range settings {
|
||||
s.values[key] = value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *runtimeSettingRepoStub) GetAll(_ context.Context) (map[string]string, error) {
|
||||
out := make(map[string]string, len(s.values))
|
||||
for key, value := range s.values {
|
||||
out[key] = value
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *runtimeSettingRepoStub) Delete(_ context.Context, key string) error {
|
||||
if s.deleteFn != nil {
|
||||
if err := s.deleteFn(key); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, ok := s.values[key]; !ok {
|
||||
return ErrSettingNotFound
|
||||
}
|
||||
delete(s.values, key)
|
||||
s.deleted[key] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestUpdateRuntimeLogConfig_InvalidConfigShouldNotApply(t *testing.T) {
|
||||
repo := newRuntimeSettingRepoStub()
|
||||
svc := &OpsService{
|
||||
settingRepo: repo,
|
||||
cfg: &config.Config{
|
||||
Log: config.LogConfig{
|
||||
Level: "info",
|
||||
Caller: true,
|
||||
StacktraceLevel: "error",
|
||||
Sampling: config.LogSamplingConfig{
|
||||
Enabled: false,
|
||||
Initial: 100,
|
||||
Thereafter: 100,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := logger.Init(logger.InitOptions{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: logger.OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: false,
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("init logger: %v", err)
|
||||
}
|
||||
|
||||
_, err := svc.UpdateRuntimeLogConfig(context.Background(), &OpsRuntimeLogConfig{
|
||||
Level: "trace",
|
||||
EnableSampling: true,
|
||||
SamplingInitial: 100,
|
||||
SamplingNext: 100,
|
||||
Caller: true,
|
||||
StacktraceLevel: "error",
|
||||
RetentionDays: 30,
|
||||
}, 1)
|
||||
if err == nil {
|
||||
t.Fatalf("expected validation error")
|
||||
}
|
||||
if logger.CurrentLevel() != "info" {
|
||||
t.Fatalf("logger level changed unexpectedly: %s", logger.CurrentLevel())
|
||||
}
|
||||
if repo.setCalls != 1 {
|
||||
// GetRuntimeLogConfig() 会在 key 缺失时写入默认值,此处应只有这一次持久化。
|
||||
t.Fatalf("unexpected set calls: %d", repo.setCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetRuntimeLogConfig_ShouldFallbackToBaseline(t *testing.T) {
|
||||
repo := newRuntimeSettingRepoStub()
|
||||
existing := &OpsRuntimeLogConfig{
|
||||
Level: "debug",
|
||||
EnableSampling: true,
|
||||
SamplingInitial: 50,
|
||||
SamplingNext: 50,
|
||||
Caller: true,
|
||||
StacktraceLevel: "error",
|
||||
RetentionDays: 60,
|
||||
Source: "runtime_setting",
|
||||
}
|
||||
raw, _ := json.Marshal(existing)
|
||||
repo.values[SettingKeyOpsRuntimeLogConfig] = string(raw)
|
||||
|
||||
svc := &OpsService{
|
||||
settingRepo: repo,
|
||||
cfg: &config.Config{
|
||||
Log: config.LogConfig{
|
||||
Level: "warn",
|
||||
Caller: false,
|
||||
StacktraceLevel: "fatal",
|
||||
Sampling: config.LogSamplingConfig{
|
||||
Enabled: false,
|
||||
Initial: 100,
|
||||
Thereafter: 100,
|
||||
},
|
||||
},
|
||||
Ops: config.OpsConfig{
|
||||
Cleanup: config.OpsCleanupConfig{
|
||||
ErrorLogRetentionDays: 45,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := logger.Init(logger.InitOptions{
|
||||
Level: "debug",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: logger.OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: false,
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("init logger: %v", err)
|
||||
}
|
||||
|
||||
resetCfg, err := svc.ResetRuntimeLogConfig(context.Background(), 9)
|
||||
if err != nil {
|
||||
t.Fatalf("ResetRuntimeLogConfig() error: %v", err)
|
||||
}
|
||||
if resetCfg.Source != "baseline" {
|
||||
t.Fatalf("source = %q, want baseline", resetCfg.Source)
|
||||
}
|
||||
if resetCfg.Level != "warn" {
|
||||
t.Fatalf("level = %q, want warn", resetCfg.Level)
|
||||
}
|
||||
if resetCfg.RetentionDays != 45 {
|
||||
t.Fatalf("retention_days = %d, want 45", resetCfg.RetentionDays)
|
||||
}
|
||||
if logger.CurrentLevel() != "warn" {
|
||||
t.Fatalf("logger level = %q, want warn", logger.CurrentLevel())
|
||||
}
|
||||
if !repo.deleted[SettingKeyOpsRuntimeLogConfig] {
|
||||
t.Fatalf("runtime setting key should be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetRuntimeLogConfig_InvalidOperator(t *testing.T) {
|
||||
svc := &OpsService{settingRepo: newRuntimeSettingRepoStub()}
|
||||
_, err := svc.ResetRuntimeLogConfig(context.Background(), 0)
|
||||
if err == nil {
|
||||
t.Fatalf("expected invalid operator error")
|
||||
}
|
||||
if err.Error() != "invalid operator id" {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRuntimeLogConfig_InvalidJSONFallback(t *testing.T) {
|
||||
repo := newRuntimeSettingRepoStub()
|
||||
repo.values[SettingKeyOpsRuntimeLogConfig] = `{invalid-json}`
|
||||
|
||||
svc := &OpsService{
|
||||
settingRepo: repo,
|
||||
cfg: &config.Config{
|
||||
Log: config.LogConfig{
|
||||
Level: "warn",
|
||||
Caller: true,
|
||||
StacktraceLevel: "error",
|
||||
Sampling: config.LogSamplingConfig{
|
||||
Enabled: false,
|
||||
Initial: 100,
|
||||
Thereafter: 100,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got, err := svc.GetRuntimeLogConfig(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetRuntimeLogConfig() error: %v", err)
|
||||
}
|
||||
if got.Level != "warn" {
|
||||
t.Fatalf("level = %q, want warn", got.Level)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRuntimeLogConfig_PersistFailureRollback(t *testing.T) {
|
||||
repo := newRuntimeSettingRepoStub()
|
||||
oldCfg := &OpsRuntimeLogConfig{
|
||||
Level: "info",
|
||||
EnableSampling: false,
|
||||
SamplingInitial: 100,
|
||||
SamplingNext: 100,
|
||||
Caller: true,
|
||||
StacktraceLevel: "error",
|
||||
RetentionDays: 30,
|
||||
}
|
||||
raw, _ := json.Marshal(oldCfg)
|
||||
repo.values[SettingKeyOpsRuntimeLogConfig] = string(raw)
|
||||
repo.setFn = func(key, value string) error {
|
||||
if key == SettingKeyOpsRuntimeLogConfig {
|
||||
return errors.New("db down")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
svc := &OpsService{
|
||||
settingRepo: repo,
|
||||
cfg: &config.Config{
|
||||
Log: config.LogConfig{
|
||||
Level: "info",
|
||||
Caller: true,
|
||||
StacktraceLevel: "error",
|
||||
Sampling: config.LogSamplingConfig{
|
||||
Enabled: false,
|
||||
Initial: 100,
|
||||
Thereafter: 100,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := logger.Init(logger.InitOptions{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: logger.OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: false,
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("init logger: %v", err)
|
||||
}
|
||||
|
||||
_, err := svc.UpdateRuntimeLogConfig(context.Background(), &OpsRuntimeLogConfig{
|
||||
Level: "debug",
|
||||
EnableSampling: false,
|
||||
SamplingInitial: 100,
|
||||
SamplingNext: 100,
|
||||
Caller: true,
|
||||
StacktraceLevel: "error",
|
||||
RetentionDays: 30,
|
||||
}, 5)
|
||||
if err == nil {
|
||||
t.Fatalf("expected persist error")
|
||||
}
|
||||
// Persist failure should rollback runtime level back to old effective level.
|
||||
if logger.CurrentLevel() != "info" {
|
||||
t.Fatalf("logger level should rollback to info, got %s", logger.CurrentLevel())
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyRuntimeLogConfigOnStartup(t *testing.T) {
|
||||
repo := newRuntimeSettingRepoStub()
|
||||
cfgRaw := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}`
|
||||
repo.values[SettingKeyOpsRuntimeLogConfig] = cfgRaw
|
||||
|
||||
svc := &OpsService{
|
||||
settingRepo: repo,
|
||||
cfg: &config.Config{
|
||||
Log: config.LogConfig{
|
||||
Level: "info",
|
||||
Caller: true,
|
||||
StacktraceLevel: "error",
|
||||
Sampling: config.LogSamplingConfig{
|
||||
Enabled: false,
|
||||
Initial: 100,
|
||||
Thereafter: 100,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := logger.Init(logger.InitOptions{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: logger.OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: false,
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("init logger: %v", err)
|
||||
}
|
||||
|
||||
svc.applyRuntimeLogConfigOnStartup(context.Background())
|
||||
if logger.CurrentLevel() != "debug" {
|
||||
t.Fatalf("expected startup apply debug, got %s", logger.CurrentLevel())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultNormalizeAndValidateRuntimeLogConfig(t *testing.T) {
|
||||
defaults := defaultOpsRuntimeLogConfig(&config.Config{
|
||||
Log: config.LogConfig{
|
||||
Level: "DEBUG",
|
||||
Caller: false,
|
||||
StacktraceLevel: "FATAL",
|
||||
Sampling: config.LogSamplingConfig{
|
||||
Enabled: true,
|
||||
Initial: 50,
|
||||
Thereafter: 20,
|
||||
},
|
||||
},
|
||||
Ops: config.OpsConfig{
|
||||
Cleanup: config.OpsCleanupConfig{
|
||||
ErrorLogRetentionDays: 7,
|
||||
},
|
||||
},
|
||||
})
|
||||
if defaults.Level != "debug" || defaults.StacktraceLevel != "fatal" || defaults.RetentionDays != 7 {
|
||||
t.Fatalf("unexpected defaults: %+v", defaults)
|
||||
}
|
||||
|
||||
cfg := &OpsRuntimeLogConfig{
|
||||
Level: " ",
|
||||
EnableSampling: true,
|
||||
SamplingInitial: 0,
|
||||
SamplingNext: -1,
|
||||
Caller: true,
|
||||
StacktraceLevel: "",
|
||||
RetentionDays: 0,
|
||||
}
|
||||
normalizeOpsRuntimeLogConfig(cfg, defaults)
|
||||
if cfg.Level != "debug" || cfg.StacktraceLevel != "fatal" {
|
||||
t.Fatalf("normalize level/stacktrace failed: %+v", cfg)
|
||||
}
|
||||
if cfg.SamplingInitial != 50 || cfg.SamplingNext != 20 || cfg.RetentionDays != 7 {
|
||||
t.Fatalf("normalize numeric defaults failed: %+v", cfg)
|
||||
}
|
||||
if err := validateOpsRuntimeLogConfig(cfg); err != nil {
|
||||
t.Fatalf("validate normalized config should pass: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRuntimeLogConfigErrors(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
cfg *OpsRuntimeLogConfig
|
||||
}{
|
||||
{name: "nil", cfg: nil},
|
||||
{name: "bad level", cfg: &OpsRuntimeLogConfig{Level: "trace", StacktraceLevel: "error", SamplingInitial: 1, SamplingNext: 1, RetentionDays: 1}},
|
||||
{name: "bad stack", cfg: &OpsRuntimeLogConfig{Level: "info", StacktraceLevel: "warn", SamplingInitial: 1, SamplingNext: 1, RetentionDays: 1}},
|
||||
{name: "bad initial", cfg: &OpsRuntimeLogConfig{Level: "info", StacktraceLevel: "error", SamplingInitial: 0, SamplingNext: 1, RetentionDays: 1}},
|
||||
{name: "bad next", cfg: &OpsRuntimeLogConfig{Level: "info", StacktraceLevel: "error", SamplingInitial: 1, SamplingNext: 0, RetentionDays: 1}},
|
||||
{name: "bad retention", cfg: &OpsRuntimeLogConfig{Level: "info", StacktraceLevel: "error", SamplingInitial: 1, SamplingNext: 1, RetentionDays: 0}},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if err := validateOpsRuntimeLogConfig(tc.cfg); err == nil {
|
||||
t.Fatalf("expected validation error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRuntimeLogConfigFallbackAndErrors(t *testing.T) {
|
||||
var nilSvc *OpsService
|
||||
cfg, err := nilSvc.GetRuntimeLogConfig(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("nil svc should fallback default: %v", err)
|
||||
}
|
||||
if cfg.Level != "info" {
|
||||
t.Fatalf("unexpected nil svc default level: %s", cfg.Level)
|
||||
}
|
||||
|
||||
repo := newRuntimeSettingRepoStub()
|
||||
repo.getValueFn = func(key string) (string, error) {
|
||||
return "", errors.New("boom")
|
||||
}
|
||||
svc := &OpsService{
|
||||
settingRepo: repo,
|
||||
cfg: &config.Config{
|
||||
Log: config.LogConfig{
|
||||
Level: "warn",
|
||||
Caller: true,
|
||||
StacktraceLevel: "error",
|
||||
Sampling: config.LogSamplingConfig{
|
||||
Enabled: false,
|
||||
Initial: 100,
|
||||
Thereafter: 100,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if _, err := svc.GetRuntimeLogConfig(context.Background()); err == nil {
|
||||
t.Fatalf("expected get value error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRuntimeLogConfig_PreconditionErrors(t *testing.T) {
|
||||
svc := &OpsService{}
|
||||
if _, err := svc.UpdateRuntimeLogConfig(context.Background(), &OpsRuntimeLogConfig{}, 1); err == nil {
|
||||
t.Fatalf("expected setting repo not initialized")
|
||||
}
|
||||
|
||||
svc = &OpsService{settingRepo: newRuntimeSettingRepoStub()}
|
||||
if _, err := svc.UpdateRuntimeLogConfig(context.Background(), nil, 1); err == nil {
|
||||
t.Fatalf("expected invalid config")
|
||||
}
|
||||
if _, err := svc.UpdateRuntimeLogConfig(context.Background(), &OpsRuntimeLogConfig{
|
||||
Level: "info",
|
||||
StacktraceLevel: "error",
|
||||
SamplingInitial: 1,
|
||||
SamplingNext: 1,
|
||||
RetentionDays: 1,
|
||||
}, 0); err == nil {
|
||||
t.Fatalf("expected invalid operator")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRuntimeLogConfig_Success(t *testing.T) {
|
||||
repo := newRuntimeSettingRepoStub()
|
||||
svc := &OpsService{
|
||||
settingRepo: repo,
|
||||
cfg: &config.Config{
|
||||
Log: config.LogConfig{
|
||||
Level: "info",
|
||||
Caller: true,
|
||||
StacktraceLevel: "error",
|
||||
Sampling: config.LogSamplingConfig{
|
||||
Enabled: false,
|
||||
Initial: 100,
|
||||
Thereafter: 100,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := logger.Init(logger.InitOptions{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: logger.OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: false,
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("init logger: %v", err)
|
||||
}
|
||||
|
||||
next, err := svc.UpdateRuntimeLogConfig(context.Background(), &OpsRuntimeLogConfig{
|
||||
Level: "debug",
|
||||
EnableSampling: false,
|
||||
SamplingInitial: 100,
|
||||
SamplingNext: 100,
|
||||
Caller: true,
|
||||
StacktraceLevel: "error",
|
||||
RetentionDays: 30,
|
||||
}, 2)
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateRuntimeLogConfig() error: %v", err)
|
||||
}
|
||||
if next.Source != "runtime_setting" || next.UpdatedByUserID != 2 || next.UpdatedAt == "" {
|
||||
t.Fatalf("unexpected metadata: %+v", next)
|
||||
}
|
||||
if logger.CurrentLevel() != "debug" {
|
||||
t.Fatalf("expected applied level debug, got %s", logger.CurrentLevel())
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetRuntimeLogConfig_IgnoreNotFoundDelete(t *testing.T) {
|
||||
repo := newRuntimeSettingRepoStub()
|
||||
repo.deleteFn = func(key string) error { return ErrSettingNotFound }
|
||||
svc := &OpsService{
|
||||
settingRepo: repo,
|
||||
cfg: &config.Config{
|
||||
Log: config.LogConfig{
|
||||
Level: "info",
|
||||
Caller: true,
|
||||
StacktraceLevel: "error",
|
||||
Sampling: config.LogSamplingConfig{
|
||||
Enabled: false,
|
||||
Initial: 100,
|
||||
Thereafter: 100,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if _, err := svc.ResetRuntimeLogConfig(context.Background(), 1); err != nil {
|
||||
t.Fatalf("reset should ignore ErrSettingNotFound: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyRuntimeLogConfigHelpers(t *testing.T) {
|
||||
if err := applyOpsRuntimeLogConfig(nil); err == nil {
|
||||
t.Fatalf("expected nil config error")
|
||||
}
|
||||
|
||||
normalizeOpsRuntimeLogConfig(nil, &OpsRuntimeLogConfig{Level: "info"})
|
||||
normalizeOpsRuntimeLogConfig(&OpsRuntimeLogConfig{Level: "debug"}, nil)
|
||||
|
||||
var nilSvc *OpsService
|
||||
nilSvc.applyRuntimeLogConfigOnStartup(context.Background())
|
||||
}
|
||||
@@ -2,6 +2,21 @@ package service
|
||||
|
||||
import "time"
|
||||
|
||||
type OpsSystemLog struct {
|
||||
ID int64 `json:"id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Level string `json:"level"`
|
||||
Component string `json:"component"`
|
||||
Message string `json:"message"`
|
||||
RequestID string `json:"request_id"`
|
||||
ClientRequestID string `json:"client_request_id"`
|
||||
UserID *int64 `json:"user_id"`
|
||||
AccountID *int64 `json:"account_id"`
|
||||
Platform string `json:"platform"`
|
||||
Model string `json:"model"`
|
||||
Extra map[string]any `json:"extra,omitempty"`
|
||||
}
|
||||
|
||||
type OpsErrorLog struct {
|
||||
ID int64 `json:"id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
@@ -10,6 +10,10 @@ type OpsRepository interface {
|
||||
ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error)
|
||||
GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error)
|
||||
ListRequestDetails(ctx context.Context, filter *OpsRequestDetailFilter) ([]*OpsRequestDetail, int64, error)
|
||||
BatchInsertSystemLogs(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error)
|
||||
ListSystemLogs(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error)
|
||||
DeleteSystemLogs(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error)
|
||||
InsertSystemLogCleanupAudit(ctx context.Context, input *OpsSystemLogCleanupAudit) error
|
||||
|
||||
InsertRetryAttempt(ctx context.Context, input *OpsInsertRetryAttemptInput) (int64, error)
|
||||
UpdateRetryAttempt(ctx context.Context, input *OpsUpdateRetryAttemptInput) error
|
||||
@@ -205,6 +209,69 @@ type OpsInsertSystemMetricsInput struct {
|
||||
ConcurrencyQueueDepth *int
|
||||
}
|
||||
|
||||
type OpsInsertSystemLogInput struct {
|
||||
CreatedAt time.Time
|
||||
Level string
|
||||
Component string
|
||||
Message string
|
||||
RequestID string
|
||||
ClientRequestID string
|
||||
UserID *int64
|
||||
AccountID *int64
|
||||
Platform string
|
||||
Model string
|
||||
ExtraJSON string
|
||||
}
|
||||
|
||||
type OpsSystemLogFilter struct {
|
||||
StartTime *time.Time
|
||||
EndTime *time.Time
|
||||
|
||||
Level string
|
||||
Component string
|
||||
|
||||
RequestID string
|
||||
ClientRequestID string
|
||||
UserID *int64
|
||||
AccountID *int64
|
||||
Platform string
|
||||
Model string
|
||||
Query string
|
||||
|
||||
Page int
|
||||
PageSize int
|
||||
}
|
||||
|
||||
type OpsSystemLogCleanupFilter struct {
|
||||
StartTime *time.Time
|
||||
EndTime *time.Time
|
||||
|
||||
Level string
|
||||
Component string
|
||||
|
||||
RequestID string
|
||||
ClientRequestID string
|
||||
UserID *int64
|
||||
AccountID *int64
|
||||
Platform string
|
||||
Model string
|
||||
Query string
|
||||
}
|
||||
|
||||
type OpsSystemLogList struct {
|
||||
Logs []*OpsSystemLog `json:"logs"`
|
||||
Total int `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
||||
|
||||
type OpsSystemLogCleanupAudit struct {
|
||||
CreatedAt time.Time
|
||||
OperatorID int64
|
||||
Conditions string
|
||||
DeletedRows int64
|
||||
}
|
||||
|
||||
type OpsSystemMetricsSnapshot struct {
|
||||
ID int64 `json:"id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
196
backend/internal/service/ops_repo_mock_test.go
Normal file
196
backend/internal/service/ops_repo_mock_test.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// opsRepoMock is a test-only OpsRepository implementation with optional function hooks.
|
||||
type opsRepoMock struct {
|
||||
BatchInsertSystemLogsFn func(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error)
|
||||
ListSystemLogsFn func(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error)
|
||||
DeleteSystemLogsFn func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error)
|
||||
InsertSystemLogCleanupAuditFn func(ctx context.Context, input *OpsSystemLogCleanupAudit) error
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error) {
|
||||
return &OpsErrorLogList{Errors: []*OpsErrorLog{}, Page: 1, PageSize: 20}, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error) {
|
||||
return &OpsErrorLogDetail{}, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) ListRequestDetails(ctx context.Context, filter *OpsRequestDetailFilter) ([]*OpsRequestDetail, int64, error) {
|
||||
return []*OpsRequestDetail{}, 0, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) BatchInsertSystemLogs(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) {
|
||||
if m.BatchInsertSystemLogsFn != nil {
|
||||
return m.BatchInsertSystemLogsFn(ctx, inputs)
|
||||
}
|
||||
return int64(len(inputs)), nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) ListSystemLogs(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) {
|
||||
if m.ListSystemLogsFn != nil {
|
||||
return m.ListSystemLogsFn(ctx, filter)
|
||||
}
|
||||
return &OpsSystemLogList{Logs: []*OpsSystemLog{}, Total: 0, Page: 1, PageSize: 50}, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) DeleteSystemLogs(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) {
|
||||
if m.DeleteSystemLogsFn != nil {
|
||||
return m.DeleteSystemLogsFn(ctx, filter)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) InsertSystemLogCleanupAudit(ctx context.Context, input *OpsSystemLogCleanupAudit) error {
|
||||
if m.InsertSystemLogCleanupAuditFn != nil {
|
||||
return m.InsertSystemLogCleanupAuditFn(ctx, input)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) InsertRetryAttempt(ctx context.Context, input *OpsInsertRetryAttemptInput) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) UpdateRetryAttempt(ctx context.Context, input *OpsUpdateRetryAttemptInput) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) GetLatestRetryAttemptForError(ctx context.Context, sourceErrorID int64) (*OpsRetryAttempt, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) ListRetryAttemptsByErrorID(ctx context.Context, sourceErrorID int64, limit int) ([]*OpsRetryAttempt, error) {
|
||||
return []*OpsRetryAttempt{}, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64, resolvedAt *time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) GetWindowStats(ctx context.Context, filter *OpsDashboardFilter) (*OpsWindowStats, error) {
|
||||
return &OpsWindowStats{}, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) GetRealtimeTrafficSummary(ctx context.Context, filter *OpsDashboardFilter) (*OpsRealtimeTrafficSummary, error) {
|
||||
return &OpsRealtimeTrafficSummary{}, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) GetDashboardOverview(ctx context.Context, filter *OpsDashboardFilter) (*OpsDashboardOverview, error) {
|
||||
return &OpsDashboardOverview{}, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) GetThroughputTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsThroughputTrendResponse, error) {
|
||||
return &OpsThroughputTrendResponse{}, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) GetLatencyHistogram(ctx context.Context, filter *OpsDashboardFilter) (*OpsLatencyHistogramResponse, error) {
|
||||
return &OpsLatencyHistogramResponse{}, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) GetErrorTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsErrorTrendResponse, error) {
|
||||
return &OpsErrorTrendResponse{}, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) GetErrorDistribution(ctx context.Context, filter *OpsDashboardFilter) (*OpsErrorDistributionResponse, error) {
|
||||
return &OpsErrorDistributionResponse{}, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) GetOpenAITokenStats(ctx context.Context, filter *OpsOpenAITokenStatsFilter) (*OpsOpenAITokenStatsResponse, error) {
|
||||
return &OpsOpenAITokenStatsResponse{}, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) InsertSystemMetrics(ctx context.Context, input *OpsInsertSystemMetricsInput) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) GetLatestSystemMetrics(ctx context.Context, windowMinutes int) (*OpsSystemMetricsSnapshot, error) {
|
||||
return &OpsSystemMetricsSnapshot{}, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) UpsertJobHeartbeat(ctx context.Context, input *OpsUpsertJobHeartbeatInput) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) ListJobHeartbeats(ctx context.Context) ([]*OpsJobHeartbeat, error) {
|
||||
return []*OpsJobHeartbeat{}, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) ListAlertRules(ctx context.Context) ([]*OpsAlertRule, error) {
|
||||
return []*OpsAlertRule{}, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) CreateAlertRule(ctx context.Context, input *OpsAlertRule) (*OpsAlertRule, error) {
|
||||
return input, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) UpdateAlertRule(ctx context.Context, input *OpsAlertRule) (*OpsAlertRule, error) {
|
||||
return input, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) DeleteAlertRule(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) ListAlertEvents(ctx context.Context, filter *OpsAlertEventFilter) ([]*OpsAlertEvent, error) {
|
||||
return []*OpsAlertEvent{}, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) GetAlertEventByID(ctx context.Context, eventID int64) (*OpsAlertEvent, error) {
|
||||
return &OpsAlertEvent{}, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) (*OpsAlertEvent, error) {
|
||||
return event, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) CreateAlertSilence(ctx context.Context, input *OpsAlertSilence) (*OpsAlertSilence, error) {
|
||||
return input, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) UpsertHourlyMetrics(ctx context.Context, startTime, endTime time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) UpsertDailyMetrics(ctx context.Context, startTime, endTime time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) GetLatestHourlyBucketStart(ctx context.Context) (time.Time, bool, error) {
|
||||
return time.Time{}, false, nil
|
||||
}
|
||||
|
||||
func (m *opsRepoMock) GetLatestDailyBucketDate(ctx context.Context) (time.Time, bool, error) {
|
||||
return time.Time{}, false, nil
|
||||
}
|
||||
|
||||
var _ OpsRepository = (*opsRepoMock)(nil)
|
||||
@@ -37,6 +37,7 @@ type OpsService struct {
|
||||
openAIGatewayService *OpenAIGatewayService
|
||||
geminiCompatService *GeminiMessagesCompatService
|
||||
antigravityGatewayService *AntigravityGatewayService
|
||||
systemLogSink *OpsSystemLogSink
|
||||
}
|
||||
|
||||
func NewOpsService(
|
||||
@@ -50,8 +51,9 @@ func NewOpsService(
|
||||
openAIGatewayService *OpenAIGatewayService,
|
||||
geminiCompatService *GeminiMessagesCompatService,
|
||||
antigravityGatewayService *AntigravityGatewayService,
|
||||
systemLogSink *OpsSystemLogSink,
|
||||
) *OpsService {
|
||||
return &OpsService{
|
||||
svc := &OpsService{
|
||||
opsRepo: opsRepo,
|
||||
settingRepo: settingRepo,
|
||||
cfg: cfg,
|
||||
@@ -64,7 +66,10 @@ func NewOpsService(
|
||||
openAIGatewayService: openAIGatewayService,
|
||||
geminiCompatService: geminiCompatService,
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
systemLogSink: systemLogSink,
|
||||
}
|
||||
svc.applyRuntimeLogConfigOnStartup(context.Background())
|
||||
return svc
|
||||
}
|
||||
|
||||
func (s *OpsService) RequireMonitoringEnabled(ctx context.Context) error {
|
||||
|
||||
@@ -68,6 +68,20 @@ type OpsMetricThresholds struct {
|
||||
UpstreamErrorRatePercentMax *float64 `json:"upstream_error_rate_percent_max,omitempty"` // 上游错误率高于此值变红
|
||||
}
|
||||
|
||||
type OpsRuntimeLogConfig struct {
|
||||
Level string `json:"level"`
|
||||
EnableSampling bool `json:"enable_sampling"`
|
||||
SamplingInitial int `json:"sampling_initial"`
|
||||
SamplingNext int `json:"sampling_thereafter"`
|
||||
Caller bool `json:"caller"`
|
||||
StacktraceLevel string `json:"stacktrace_level"`
|
||||
RetentionDays int `json:"retention_days"`
|
||||
Source string `json:"source,omitempty"`
|
||||
UpdatedAt string `json:"updated_at,omitempty"`
|
||||
UpdatedByUserID int64 `json:"updated_by_user_id,omitempty"`
|
||||
Extra map[string]any `json:"extra,omitempty"`
|
||||
}
|
||||
|
||||
type OpsAlertRuntimeSettings struct {
|
||||
EvaluationIntervalSeconds int `json:"evaluation_interval_seconds"`
|
||||
|
||||
|
||||
124
backend/internal/service/ops_system_log_service.go
Normal file
124
backend/internal/service/ops_system_log_service.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
func (s *OpsService) ListSystemLogs(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return &OpsSystemLogList{
|
||||
Logs: []*OpsSystemLog{},
|
||||
Total: 0,
|
||||
Page: 1,
|
||||
PageSize: 50,
|
||||
}, nil
|
||||
}
|
||||
if filter == nil {
|
||||
filter = &OpsSystemLogFilter{}
|
||||
}
|
||||
if filter.Page <= 0 {
|
||||
filter.Page = 1
|
||||
}
|
||||
if filter.PageSize <= 0 {
|
||||
filter.PageSize = 50
|
||||
}
|
||||
if filter.PageSize > 200 {
|
||||
filter.PageSize = 200
|
||||
}
|
||||
|
||||
result, err := s.opsRepo.ListSystemLogs(ctx, filter)
|
||||
if err != nil {
|
||||
return nil, infraerrors.InternalServer("OPS_SYSTEM_LOG_LIST_FAILED", "Failed to list system logs").WithCause(err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) CleanupSystemLogs(ctx context.Context, filter *OpsSystemLogCleanupFilter, operatorID int64) (int64, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return 0, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if operatorID <= 0 {
|
||||
return 0, infraerrors.BadRequest("OPS_SYSTEM_LOG_CLEANUP_INVALID_OPERATOR", "invalid operator")
|
||||
}
|
||||
if filter == nil {
|
||||
filter = &OpsSystemLogCleanupFilter{}
|
||||
}
|
||||
if filter.EndTime != nil && filter.StartTime != nil && filter.StartTime.After(*filter.EndTime) {
|
||||
return 0, infraerrors.BadRequest("OPS_SYSTEM_LOG_CLEANUP_INVALID_RANGE", "invalid time range")
|
||||
}
|
||||
|
||||
deletedRows, err := s.opsRepo.DeleteSystemLogs(ctx, filter)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return 0, nil
|
||||
}
|
||||
if strings.Contains(strings.ToLower(err.Error()), "requires at least one filter") {
|
||||
return 0, infraerrors.BadRequest("OPS_SYSTEM_LOG_CLEANUP_FILTER_REQUIRED", "cleanup requires at least one filter condition")
|
||||
}
|
||||
return 0, infraerrors.InternalServer("OPS_SYSTEM_LOG_CLEANUP_FAILED", "Failed to cleanup system logs").WithCause(err)
|
||||
}
|
||||
|
||||
if auditErr := s.opsRepo.InsertSystemLogCleanupAudit(ctx, &OpsSystemLogCleanupAudit{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
OperatorID: operatorID,
|
||||
Conditions: marshalSystemLogCleanupConditions(filter),
|
||||
DeletedRows: deletedRows,
|
||||
}); auditErr != nil {
|
||||
// 审计失败不影响主流程,避免运维清理被阻塞。
|
||||
log.Printf("[OpsSystemLog] cleanup audit failed: %v", auditErr)
|
||||
}
|
||||
return deletedRows, nil
|
||||
}
|
||||
|
||||
func marshalSystemLogCleanupConditions(filter *OpsSystemLogCleanupFilter) string {
|
||||
if filter == nil {
|
||||
return "{}"
|
||||
}
|
||||
payload := map[string]any{
|
||||
"level": strings.TrimSpace(filter.Level),
|
||||
"component": strings.TrimSpace(filter.Component),
|
||||
"request_id": strings.TrimSpace(filter.RequestID),
|
||||
"client_request_id": strings.TrimSpace(filter.ClientRequestID),
|
||||
"platform": strings.TrimSpace(filter.Platform),
|
||||
"model": strings.TrimSpace(filter.Model),
|
||||
"query": strings.TrimSpace(filter.Query),
|
||||
}
|
||||
if filter.UserID != nil {
|
||||
payload["user_id"] = *filter.UserID
|
||||
}
|
||||
if filter.AccountID != nil {
|
||||
payload["account_id"] = *filter.AccountID
|
||||
}
|
||||
if filter.StartTime != nil && !filter.StartTime.IsZero() {
|
||||
payload["start_time"] = filter.StartTime.UTC().Format(time.RFC3339Nano)
|
||||
}
|
||||
if filter.EndTime != nil && !filter.EndTime.IsZero() {
|
||||
payload["end_time"] = filter.EndTime.UTC().Format(time.RFC3339Nano)
|
||||
}
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "{}"
|
||||
}
|
||||
return string(raw)
|
||||
}
|
||||
|
||||
func (s *OpsService) GetSystemLogSinkHealth() OpsSystemLogSinkHealth {
|
||||
if s == nil || s.systemLogSink == nil {
|
||||
return OpsSystemLogSinkHealth{}
|
||||
}
|
||||
return s.systemLogSink.Health()
|
||||
}
|
||||
243
backend/internal/service/ops_system_log_service_test.go
Normal file
243
backend/internal/service/ops_system_log_service_test.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
func TestOpsServiceListSystemLogs_DefaultClampAndSuccess(t *testing.T) {
|
||||
var gotFilter *OpsSystemLogFilter
|
||||
repo := &opsRepoMock{
|
||||
ListSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) {
|
||||
gotFilter = filter
|
||||
return &OpsSystemLogList{
|
||||
Logs: []*OpsSystemLog{{ID: 1, Level: "warn", Message: "x"}},
|
||||
Total: 1,
|
||||
Page: filter.Page,
|
||||
PageSize: filter.PageSize,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
out, err := svc.ListSystemLogs(context.Background(), &OpsSystemLogFilter{
|
||||
Page: 0,
|
||||
PageSize: 999,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ListSystemLogs() error: %v", err)
|
||||
}
|
||||
if gotFilter == nil {
|
||||
t.Fatalf("expected repository to receive filter")
|
||||
}
|
||||
if gotFilter.Page != 1 || gotFilter.PageSize != 200 {
|
||||
t.Fatalf("filter normalized unexpectedly: page=%d pageSize=%d", gotFilter.Page, gotFilter.PageSize)
|
||||
}
|
||||
if out.Total != 1 || len(out.Logs) != 1 {
|
||||
t.Fatalf("unexpected result: %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsServiceListSystemLogs_MonitoringDisabled(t *testing.T) {
|
||||
svc := NewOpsService(
|
||||
&opsRepoMock{},
|
||||
nil,
|
||||
&config.Config{Ops: config.OpsConfig{Enabled: false}},
|
||||
nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
)
|
||||
_, err := svc.ListSystemLogs(context.Background(), &OpsSystemLogFilter{})
|
||||
if err == nil {
|
||||
t.Fatalf("expected disabled error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsServiceListSystemLogs_NilRepoReturnsEmpty(t *testing.T) {
|
||||
svc := NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
out, err := svc.ListSystemLogs(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ListSystemLogs() error: %v", err)
|
||||
}
|
||||
if out == nil || out.Page != 1 || out.PageSize != 50 || out.Total != 0 || len(out.Logs) != 0 {
|
||||
t.Fatalf("unexpected nil-repo result: %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsServiceListSystemLogs_RepoErrorMapped(t *testing.T) {
|
||||
repo := &opsRepoMock{
|
||||
ListSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) {
|
||||
return nil, errors.New("db down")
|
||||
},
|
||||
}
|
||||
svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
_, err := svc.ListSystemLogs(context.Background(), &OpsSystemLogFilter{})
|
||||
if err == nil {
|
||||
t.Fatalf("expected mapped internal error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "OPS_SYSTEM_LOG_LIST_FAILED") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsServiceCleanupSystemLogs_SuccessAndAudit(t *testing.T) {
|
||||
var audit *OpsSystemLogCleanupAudit
|
||||
repo := &opsRepoMock{
|
||||
DeleteSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) {
|
||||
return 3, nil
|
||||
},
|
||||
InsertSystemLogCleanupAuditFn: func(ctx context.Context, input *OpsSystemLogCleanupAudit) error {
|
||||
audit = input
|
||||
return nil
|
||||
},
|
||||
}
|
||||
svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
userID := int64(7)
|
||||
now := time.Now().UTC()
|
||||
filter := &OpsSystemLogCleanupFilter{
|
||||
StartTime: &now,
|
||||
Level: "warn",
|
||||
RequestID: "req-1",
|
||||
ClientRequestID: "creq-1",
|
||||
UserID: &userID,
|
||||
Query: "timeout",
|
||||
}
|
||||
|
||||
deleted, err := svc.CleanupSystemLogs(context.Background(), filter, 99)
|
||||
if err != nil {
|
||||
t.Fatalf("CleanupSystemLogs() error: %v", err)
|
||||
}
|
||||
if deleted != 3 {
|
||||
t.Fatalf("deleted=%d, want 3", deleted)
|
||||
}
|
||||
if audit == nil {
|
||||
t.Fatalf("expected cleanup audit")
|
||||
}
|
||||
if !strings.Contains(audit.Conditions, `"client_request_id":"creq-1"`) {
|
||||
t.Fatalf("audit conditions should include client_request_id: %s", audit.Conditions)
|
||||
}
|
||||
if !strings.Contains(audit.Conditions, `"user_id":7`) {
|
||||
t.Fatalf("audit conditions should include user_id: %s", audit.Conditions)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsServiceCleanupSystemLogs_RepoUnavailableAndInvalidOperator(t *testing.T) {
|
||||
svc := NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
if _, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{RequestID: "r"}, 1); err == nil {
|
||||
t.Fatalf("expected repo unavailable error")
|
||||
}
|
||||
|
||||
svc = NewOpsService(&opsRepoMock{}, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
if _, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{RequestID: "r"}, 0); err == nil {
|
||||
t.Fatalf("expected invalid operator error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsServiceCleanupSystemLogs_FilterRequired(t *testing.T) {
|
||||
repo := &opsRepoMock{
|
||||
DeleteSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) {
|
||||
return 0, errors.New("cleanup requires at least one filter condition")
|
||||
},
|
||||
}
|
||||
svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
_, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{}, 1)
|
||||
if err == nil {
|
||||
t.Fatalf("expected filter required error")
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "filter") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsServiceCleanupSystemLogs_InvalidRange(t *testing.T) {
|
||||
repo := &opsRepoMock{}
|
||||
svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
start := time.Now().UTC()
|
||||
end := start.Add(-time.Hour)
|
||||
_, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{
|
||||
StartTime: &start,
|
||||
EndTime: &end,
|
||||
}, 1)
|
||||
if err == nil {
|
||||
t.Fatalf("expected invalid range error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsServiceCleanupSystemLogs_NoRowsAndInternalError(t *testing.T) {
|
||||
repo := &opsRepoMock{
|
||||
DeleteSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) {
|
||||
return 0, sql.ErrNoRows
|
||||
},
|
||||
}
|
||||
svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
deleted, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{
|
||||
RequestID: "req-1",
|
||||
}, 1)
|
||||
if err != nil || deleted != 0 {
|
||||
t.Fatalf("expected no rows shortcut, deleted=%d err=%v", deleted, err)
|
||||
}
|
||||
|
||||
repo.DeleteSystemLogsFn = func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) {
|
||||
return 0, errors.New("boom")
|
||||
}
|
||||
if _, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{
|
||||
RequestID: "req-1",
|
||||
}, 1); err == nil {
|
||||
t.Fatalf("expected internal cleanup error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsServiceCleanupSystemLogs_AuditFailureIgnored(t *testing.T) {
|
||||
repo := &opsRepoMock{
|
||||
DeleteSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) {
|
||||
return 5, nil
|
||||
},
|
||||
InsertSystemLogCleanupAuditFn: func(ctx context.Context, input *OpsSystemLogCleanupAudit) error {
|
||||
return errors.New("audit down")
|
||||
},
|
||||
}
|
||||
svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
deleted, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{
|
||||
RequestID: "r1",
|
||||
}, 1)
|
||||
if err != nil || deleted != 5 {
|
||||
t.Fatalf("audit failure should not break cleanup, deleted=%d err=%v", deleted, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalSystemLogCleanupConditions_NilAndMarshalError(t *testing.T) {
|
||||
if got := marshalSystemLogCleanupConditions(nil); got != "{}" {
|
||||
t.Fatalf("nil filter should return {}, got %s", got)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
userID := int64(1)
|
||||
filter := &OpsSystemLogCleanupFilter{
|
||||
StartTime: &now,
|
||||
EndTime: &now,
|
||||
UserID: &userID,
|
||||
}
|
||||
got := marshalSystemLogCleanupConditions(filter)
|
||||
if !strings.Contains(got, `"start_time"`) || !strings.Contains(got, `"user_id":1`) {
|
||||
t.Fatalf("unexpected marshal payload: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsServiceGetSystemLogSinkHealth(t *testing.T) {
|
||||
svc := NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
health := svc.GetSystemLogSinkHealth()
|
||||
if health.QueueCapacity != 0 || health.QueueDepth != 0 {
|
||||
t.Fatalf("unexpected health for nil sink: %+v", health)
|
||||
}
|
||||
|
||||
sink := NewOpsSystemLogSink(&opsRepoMock{})
|
||||
svc = NewOpsService(&opsRepoMock{}, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink)
|
||||
health = svc.GetSystemLogSinkHealth()
|
||||
if health.QueueCapacity <= 0 {
|
||||
t.Fatalf("expected non-zero queue capacity: %+v", health)
|
||||
}
|
||||
}
|
||||
335
backend/internal/service/ops_system_log_sink.go
Normal file
335
backend/internal/service/ops_system_log_sink.go
Normal file
@@ -0,0 +1,335 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||
)
|
||||
|
||||
type OpsSystemLogSinkHealth struct {
|
||||
QueueDepth int64 `json:"queue_depth"`
|
||||
QueueCapacity int64 `json:"queue_capacity"`
|
||||
DroppedCount uint64 `json:"dropped_count"`
|
||||
WriteFailed uint64 `json:"write_failed_count"`
|
||||
WrittenCount uint64 `json:"written_count"`
|
||||
AvgWriteDelayMs uint64 `json:"avg_write_delay_ms"`
|
||||
LastError string `json:"last_error"`
|
||||
}
|
||||
|
||||
type OpsSystemLogSink struct {
|
||||
opsRepo OpsRepository
|
||||
|
||||
queue chan *logger.LogEvent
|
||||
|
||||
batchSize int
|
||||
flushInterval time.Duration
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
|
||||
droppedCount uint64
|
||||
writeFailed uint64
|
||||
writtenCount uint64
|
||||
totalDelayNs uint64
|
||||
|
||||
lastError atomic.Value
|
||||
}
|
||||
|
||||
func NewOpsSystemLogSink(opsRepo OpsRepository) *OpsSystemLogSink {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s := &OpsSystemLogSink{
|
||||
opsRepo: opsRepo,
|
||||
queue: make(chan *logger.LogEvent, 5000),
|
||||
batchSize: 200,
|
||||
flushInterval: time.Second,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
s.lastError.Store("")
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *OpsSystemLogSink) Start() {
|
||||
if s == nil || s.opsRepo == nil {
|
||||
return
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go s.run()
|
||||
}
|
||||
|
||||
func (s *OpsSystemLogSink) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.cancel()
|
||||
s.wg.Wait()
|
||||
}
|
||||
|
||||
func (s *OpsSystemLogSink) WriteLogEvent(event *logger.LogEvent) {
|
||||
if s == nil || event == nil || !s.shouldIndex(event) {
|
||||
return
|
||||
}
|
||||
if s.ctx != nil {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case s.queue <- event:
|
||||
default:
|
||||
atomic.AddUint64(&s.droppedCount, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsSystemLogSink) shouldIndex(event *logger.LogEvent) bool {
|
||||
level := strings.ToLower(strings.TrimSpace(event.Level))
|
||||
switch level {
|
||||
case "warn", "warning", "error", "fatal", "panic", "dpanic":
|
||||
return true
|
||||
}
|
||||
|
||||
component := strings.ToLower(strings.TrimSpace(event.Component))
|
||||
// zap 的 LoggerName 往往为空或不等于业务组件名;业务组件名通常以字段 component 透传。
|
||||
if event.Fields != nil {
|
||||
if fc := strings.ToLower(strings.TrimSpace(asString(event.Fields["component"]))); fc != "" {
|
||||
component = fc
|
||||
}
|
||||
}
|
||||
if strings.Contains(component, "http.access") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(component, "audit") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *OpsSystemLogSink) run() {
|
||||
defer s.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(s.flushInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
batch := make([]*logger.LogEvent, 0, s.batchSize)
|
||||
flush := func(baseCtx context.Context) {
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
}
|
||||
started := time.Now()
|
||||
inserted, err := s.flushBatch(baseCtx, batch)
|
||||
delay := time.Since(started)
|
||||
if err != nil {
|
||||
atomic.AddUint64(&s.writeFailed, uint64(len(batch)))
|
||||
s.lastError.Store(err.Error())
|
||||
_, _ = fmt.Fprintf(os.Stderr, "time=%s level=WARN msg=\"ops system log sink flush failed\" err=%v batch=%d\n",
|
||||
time.Now().Format(time.RFC3339Nano), err, len(batch),
|
||||
)
|
||||
} else {
|
||||
atomic.AddUint64(&s.writtenCount, uint64(inserted))
|
||||
atomic.AddUint64(&s.totalDelayNs, uint64(delay.Nanoseconds()))
|
||||
s.lastError.Store("")
|
||||
}
|
||||
batch = batch[:0]
|
||||
}
|
||||
drainAndFlush := func() {
|
||||
for {
|
||||
select {
|
||||
case item := <-s.queue:
|
||||
if item == nil {
|
||||
continue
|
||||
}
|
||||
batch = append(batch, item)
|
||||
if len(batch) >= s.batchSize {
|
||||
flush(context.Background())
|
||||
}
|
||||
default:
|
||||
flush(context.Background())
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
drainAndFlush()
|
||||
return
|
||||
case item := <-s.queue:
|
||||
if item == nil {
|
||||
continue
|
||||
}
|
||||
batch = append(batch, item)
|
||||
if len(batch) >= s.batchSize {
|
||||
flush(s.ctx)
|
||||
}
|
||||
case <-ticker.C:
|
||||
flush(s.ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsSystemLogSink) flushBatch(baseCtx context.Context, batch []*logger.LogEvent) (int, error) {
|
||||
inputs := make([]*OpsInsertSystemLogInput, 0, len(batch))
|
||||
for _, event := range batch {
|
||||
if event == nil {
|
||||
continue
|
||||
}
|
||||
createdAt := event.Time.UTC()
|
||||
if createdAt.IsZero() {
|
||||
createdAt = time.Now().UTC()
|
||||
}
|
||||
|
||||
fields := copyMap(event.Fields)
|
||||
requestID := asString(fields["request_id"])
|
||||
clientRequestID := asString(fields["client_request_id"])
|
||||
platform := asString(fields["platform"])
|
||||
model := asString(fields["model"])
|
||||
component := strings.TrimSpace(event.Component)
|
||||
if fieldComponent := asString(fields["component"]); fieldComponent != "" {
|
||||
component = fieldComponent
|
||||
}
|
||||
if component == "" {
|
||||
component = "app"
|
||||
}
|
||||
|
||||
userID := asInt64Ptr(fields["user_id"])
|
||||
accountID := asInt64Ptr(fields["account_id"])
|
||||
|
||||
// 统一脱敏后写入索引。
|
||||
message := logredact.RedactText(strings.TrimSpace(event.Message))
|
||||
redactedExtra := logredact.RedactMap(fields)
|
||||
extraJSONBytes, _ := json.Marshal(redactedExtra)
|
||||
extraJSON := string(extraJSONBytes)
|
||||
if strings.TrimSpace(extraJSON) == "" {
|
||||
extraJSON = "{}"
|
||||
}
|
||||
|
||||
inputs = append(inputs, &OpsInsertSystemLogInput{
|
||||
CreatedAt: createdAt,
|
||||
Level: strings.ToLower(strings.TrimSpace(event.Level)),
|
||||
Component: component,
|
||||
Message: message,
|
||||
RequestID: requestID,
|
||||
ClientRequestID: clientRequestID,
|
||||
UserID: userID,
|
||||
AccountID: accountID,
|
||||
Platform: platform,
|
||||
Model: model,
|
||||
ExtraJSON: extraJSON,
|
||||
})
|
||||
}
|
||||
|
||||
if len(inputs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if baseCtx == nil || baseCtx.Err() != nil {
|
||||
baseCtx = context.Background()
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(baseCtx, 5*time.Second)
|
||||
defer cancel()
|
||||
inserted, err := s.opsRepo.BatchInsertSystemLogs(ctx, inputs)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(inserted), nil
|
||||
}
|
||||
|
||||
func (s *OpsSystemLogSink) Health() OpsSystemLogSinkHealth {
|
||||
if s == nil {
|
||||
return OpsSystemLogSinkHealth{}
|
||||
}
|
||||
written := atomic.LoadUint64(&s.writtenCount)
|
||||
totalDelay := atomic.LoadUint64(&s.totalDelayNs)
|
||||
var avgDelay uint64
|
||||
if written > 0 {
|
||||
avgDelay = (totalDelay / written) / uint64(time.Millisecond)
|
||||
}
|
||||
|
||||
lastErr, _ := s.lastError.Load().(string)
|
||||
return OpsSystemLogSinkHealth{
|
||||
QueueDepth: int64(len(s.queue)),
|
||||
QueueCapacity: int64(cap(s.queue)),
|
||||
DroppedCount: atomic.LoadUint64(&s.droppedCount),
|
||||
WriteFailed: atomic.LoadUint64(&s.writeFailed),
|
||||
WrittenCount: written,
|
||||
AvgWriteDelayMs: avgDelay,
|
||||
LastError: strings.TrimSpace(lastErr),
|
||||
}
|
||||
}
|
||||
|
||||
func copyMap(in map[string]any) map[string]any {
|
||||
if len(in) == 0 {
|
||||
return map[string]any{}
|
||||
}
|
||||
out := make(map[string]any, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func asString(v any) string {
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(t)
|
||||
case fmt.Stringer:
|
||||
return strings.TrimSpace(t.String())
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func asInt64Ptr(v any) *int64 {
|
||||
switch t := v.(type) {
|
||||
case int:
|
||||
n := int64(t)
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
return &n
|
||||
case int64:
|
||||
n := t
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
return &n
|
||||
case float64:
|
||||
n := int64(t)
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
return &n
|
||||
case json.Number:
|
||||
if n, err := t.Int64(); err == nil {
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
return &n
|
||||
}
|
||||
case string:
|
||||
raw := strings.TrimSpace(t)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
if n, err := strconv.ParseInt(raw, 10, 64); err == nil {
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
return &n
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
313
backend/internal/service/ops_system_log_sink_test.go
Normal file
313
backend/internal/service/ops_system_log_sink_test.go
Normal file
@@ -0,0 +1,313 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
func TestOpsSystemLogSink_ShouldIndex(t *testing.T) {
|
||||
sink := &OpsSystemLogSink{}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
event *logger.LogEvent
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "warn level",
|
||||
event: &logger.LogEvent{Level: "warn", Component: "app"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "error level",
|
||||
event: &logger.LogEvent{Level: "error", Component: "app"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "access component",
|
||||
event: &logger.LogEvent{Level: "info", Component: "http.access"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "access component from fields (real zap path)",
|
||||
event: &logger.LogEvent{
|
||||
Level: "info",
|
||||
Component: "",
|
||||
Fields: map[string]any{"component": "http.access"},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "audit component",
|
||||
event: &logger.LogEvent{Level: "info", Component: "audit.log_config_change"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "audit component from fields (real zap path)",
|
||||
event: &logger.LogEvent{
|
||||
Level: "info",
|
||||
Component: "",
|
||||
Fields: map[string]any{"component": "audit.log_config_change"},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "plain info",
|
||||
event: &logger.LogEvent{Level: "info", Component: "app"},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
if got := sink.shouldIndex(tc.event); got != tc.want {
|
||||
t.Fatalf("%s: shouldIndex()=%v, want %v", tc.name, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogSink_WriteLogEvent_ShouldDropWhenQueueFull(t *testing.T) {
|
||||
sink := &OpsSystemLogSink{
|
||||
queue: make(chan *logger.LogEvent, 1),
|
||||
}
|
||||
|
||||
sink.WriteLogEvent(&logger.LogEvent{Level: "warn", Component: "app"})
|
||||
sink.WriteLogEvent(&logger.LogEvent{Level: "warn", Component: "app"})
|
||||
|
||||
if got := len(sink.queue); got != 1 {
|
||||
t.Fatalf("queue len = %d, want 1", got)
|
||||
}
|
||||
if dropped := atomic.LoadUint64(&sink.droppedCount); dropped != 1 {
|
||||
t.Fatalf("droppedCount = %d, want 1", dropped)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogSink_Health(t *testing.T) {
|
||||
sink := &OpsSystemLogSink{
|
||||
queue: make(chan *logger.LogEvent, 10),
|
||||
}
|
||||
sink.lastError.Store("db timeout")
|
||||
atomic.StoreUint64(&sink.droppedCount, 3)
|
||||
atomic.StoreUint64(&sink.writeFailed, 2)
|
||||
atomic.StoreUint64(&sink.writtenCount, 5)
|
||||
atomic.StoreUint64(&sink.totalDelayNs, uint64(5000000)) // 5ms total -> avg 1ms
|
||||
sink.queue <- &logger.LogEvent{Level: "warn", Component: "app"}
|
||||
sink.queue <- &logger.LogEvent{Level: "warn", Component: "app"}
|
||||
|
||||
health := sink.Health()
|
||||
if health.QueueDepth != 2 {
|
||||
t.Fatalf("queue depth = %d, want 2", health.QueueDepth)
|
||||
}
|
||||
if health.QueueCapacity != 10 {
|
||||
t.Fatalf("queue capacity = %d, want 10", health.QueueCapacity)
|
||||
}
|
||||
if health.DroppedCount != 3 {
|
||||
t.Fatalf("dropped = %d, want 3", health.DroppedCount)
|
||||
}
|
||||
if health.WriteFailed != 2 {
|
||||
t.Fatalf("write failed = %d, want 2", health.WriteFailed)
|
||||
}
|
||||
if health.WrittenCount != 5 {
|
||||
t.Fatalf("written = %d, want 5", health.WrittenCount)
|
||||
}
|
||||
if health.AvgWriteDelayMs != 1 {
|
||||
t.Fatalf("avg delay ms = %d, want 1", health.AvgWriteDelayMs)
|
||||
}
|
||||
if health.LastError != "db timeout" {
|
||||
t.Fatalf("last error = %q, want db timeout", health.LastError)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogSink_StartStopAndFlushSuccess(t *testing.T) {
|
||||
done := make(chan struct{}, 1)
|
||||
var captured []*OpsInsertSystemLogInput
|
||||
repo := &opsRepoMock{
|
||||
BatchInsertSystemLogsFn: func(_ context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) {
|
||||
captured = append(captured, inputs...)
|
||||
select {
|
||||
case done <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return int64(len(inputs)), nil
|
||||
},
|
||||
}
|
||||
|
||||
sink := NewOpsSystemLogSink(repo)
|
||||
sink.batchSize = 1
|
||||
sink.flushInterval = 10 * time.Millisecond
|
||||
sink.Start()
|
||||
defer sink.Stop()
|
||||
|
||||
sink.WriteLogEvent(&logger.LogEvent{
|
||||
Time: time.Now().UTC(),
|
||||
Level: "warn",
|
||||
Component: "http.access",
|
||||
Message: `authorization="Bearer sk-test-123"`,
|
||||
Fields: map[string]any{
|
||||
"component": "http.access",
|
||||
"request_id": "req-1",
|
||||
"client_request_id": "creq-1",
|
||||
"user_id": "12",
|
||||
"account_id": json.Number("34"),
|
||||
"platform": "openai",
|
||||
"model": "gpt-5",
|
||||
},
|
||||
})
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("timeout waiting for sink flush")
|
||||
}
|
||||
|
||||
if len(captured) != 1 {
|
||||
t.Fatalf("captured len = %d, want 1", len(captured))
|
||||
}
|
||||
item := captured[0]
|
||||
if item.RequestID != "req-1" || item.ClientRequestID != "creq-1" {
|
||||
t.Fatalf("unexpected request ids: %+v", item)
|
||||
}
|
||||
if item.UserID == nil || *item.UserID != 12 {
|
||||
t.Fatalf("unexpected user_id: %+v", item.UserID)
|
||||
}
|
||||
if item.AccountID == nil || *item.AccountID != 34 {
|
||||
t.Fatalf("unexpected account_id: %+v", item.AccountID)
|
||||
}
|
||||
if strings.TrimSpace(item.Message) == "" {
|
||||
t.Fatalf("message should not be empty")
|
||||
}
|
||||
health := sink.Health()
|
||||
if health.WrittenCount == 0 {
|
||||
t.Fatalf("written_count should be >0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogSink_FlushFailureUpdatesHealth(t *testing.T) {
|
||||
repo := &opsRepoMock{
|
||||
BatchInsertSystemLogsFn: func(_ context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) {
|
||||
return 0, errors.New("db unavailable")
|
||||
},
|
||||
}
|
||||
sink := NewOpsSystemLogSink(repo)
|
||||
sink.batchSize = 1
|
||||
sink.flushInterval = 10 * time.Millisecond
|
||||
sink.Start()
|
||||
defer sink.Stop()
|
||||
|
||||
sink.WriteLogEvent(&logger.LogEvent{
|
||||
Time: time.Now().UTC(),
|
||||
Level: "warn",
|
||||
Component: "app",
|
||||
Message: "boom",
|
||||
Fields: map[string]any{},
|
||||
})
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
health := sink.Health()
|
||||
if health.WriteFailed > 0 {
|
||||
if !strings.Contains(health.LastError, "db unavailable") {
|
||||
t.Fatalf("unexpected last error: %s", health.LastError)
|
||||
}
|
||||
return
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("write_failed_count not updated")
|
||||
}
|
||||
|
||||
func TestOpsSystemLogSink_StopFlushUsesActiveContextAndDrainsQueue(t *testing.T) {
|
||||
var inserted int64
|
||||
var canceledCtxCalls int64
|
||||
repo := &opsRepoMock{
|
||||
BatchInsertSystemLogsFn: func(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
atomic.AddInt64(&canceledCtxCalls, 1)
|
||||
return 0, err
|
||||
}
|
||||
atomic.AddInt64(&inserted, int64(len(inputs)))
|
||||
return int64(len(inputs)), nil
|
||||
},
|
||||
}
|
||||
|
||||
sink := NewOpsSystemLogSink(repo)
|
||||
sink.batchSize = 200
|
||||
sink.flushInterval = time.Hour
|
||||
sink.Start()
|
||||
|
||||
sink.WriteLogEvent(&logger.LogEvent{
|
||||
Time: time.Now().UTC(),
|
||||
Level: "warn",
|
||||
Component: "app",
|
||||
Message: "pending-on-shutdown",
|
||||
Fields: map[string]any{"component": "http.access"},
|
||||
})
|
||||
|
||||
sink.Stop()
|
||||
|
||||
if got := atomic.LoadInt64(&inserted); got != 1 {
|
||||
t.Fatalf("inserted = %d, want 1", got)
|
||||
}
|
||||
if got := atomic.LoadInt64(&canceledCtxCalls); got != 0 {
|
||||
t.Fatalf("canceled ctx calls = %d, want 0", got)
|
||||
}
|
||||
health := sink.Health()
|
||||
if health.WrittenCount != 1 {
|
||||
t.Fatalf("written_count = %d, want 1", health.WrittenCount)
|
||||
}
|
||||
}
|
||||
|
||||
type stringerValue string
|
||||
|
||||
func (s stringerValue) String() string { return string(s) }
|
||||
|
||||
func TestOpsSystemLogSink_HelperFunctions(t *testing.T) {
|
||||
src := map[string]any{"a": 1}
|
||||
cloned := copyMap(src)
|
||||
src["a"] = 2
|
||||
v, ok := cloned["a"].(int)
|
||||
if !ok || v != 1 {
|
||||
t.Fatalf("copyMap should create copy")
|
||||
}
|
||||
if got := asString(stringerValue(" hello ")); got != "hello" {
|
||||
t.Fatalf("asString stringer = %q", got)
|
||||
}
|
||||
if got := asString(fmt.Errorf("x")); got != "" {
|
||||
t.Fatalf("asString error should be empty, got %q", got)
|
||||
}
|
||||
if got := asString(123); got != "" {
|
||||
t.Fatalf("asString non-string should be empty, got %q", got)
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
in any
|
||||
want int64
|
||||
ok bool
|
||||
}{
|
||||
{in: 5, want: 5, ok: true},
|
||||
{in: int64(6), want: 6, ok: true},
|
||||
{in: float64(7), want: 7, ok: true},
|
||||
{in: json.Number("8"), want: 8, ok: true},
|
||||
{in: "9", want: 9, ok: true},
|
||||
{in: "0", ok: false},
|
||||
{in: -1, ok: false},
|
||||
{in: "abc", ok: false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
got := asInt64Ptr(tc.in)
|
||||
if tc.ok {
|
||||
if got == nil || *got != tc.want {
|
||||
t.Fatalf("asInt64Ptr(%v) = %+v, want %d", tc.in, got, tc.want)
|
||||
}
|
||||
} else if got != nil {
|
||||
t.Fatalf("asInt64Ptr(%v) should be nil, got %d", tc.in, *got)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
@@ -15,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
)
|
||||
@@ -84,12 +84,12 @@ func NewPricingService(cfg *config.Config, remoteClient PricingRemoteClient) *Pr
|
||||
func (s *PricingService) Initialize() error {
|
||||
// 确保数据目录存在
|
||||
if err := os.MkdirAll(s.cfg.Pricing.DataDir, 0755); err != nil {
|
||||
log.Printf("[Pricing] Failed to create data directory: %v", err)
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to create data directory: %v", err)
|
||||
}
|
||||
|
||||
// 首次加载价格数据
|
||||
if err := s.checkAndUpdatePricing(); err != nil {
|
||||
log.Printf("[Pricing] Initial load failed, using fallback: %v", err)
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Initial load failed, using fallback: %v", err)
|
||||
if err := s.useFallbackPricing(); err != nil {
|
||||
return fmt.Errorf("failed to load pricing data: %w", err)
|
||||
}
|
||||
@@ -98,7 +98,7 @@ func (s *PricingService) Initialize() error {
|
||||
// 启动定时更新
|
||||
s.startUpdateScheduler()
|
||||
|
||||
log.Printf("[Pricing] Service initialized with %d models", len(s.pricingData))
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Service initialized with %d models", len(s.pricingData))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -106,7 +106,7 @@ func (s *PricingService) Initialize() error {
|
||||
func (s *PricingService) Stop() {
|
||||
close(s.stopCh)
|
||||
s.wg.Wait()
|
||||
log.Println("[Pricing] Service stopped")
|
||||
logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Service stopped")
|
||||
}
|
||||
|
||||
// startUpdateScheduler 启动定时更新调度器
|
||||
@@ -127,7 +127,7 @@ func (s *PricingService) startUpdateScheduler() {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := s.syncWithRemote(); err != nil {
|
||||
log.Printf("[Pricing] Sync failed: %v", err)
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Sync failed: %v", err)
|
||||
}
|
||||
case <-s.stopCh:
|
||||
return
|
||||
@@ -135,7 +135,7 @@ func (s *PricingService) startUpdateScheduler() {
|
||||
}
|
||||
}()
|
||||
|
||||
log.Printf("[Pricing] Update scheduler started (check every %v)", hashInterval)
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Update scheduler started (check every %v)", hashInterval)
|
||||
}
|
||||
|
||||
// checkAndUpdatePricing 检查并更新价格数据
|
||||
@@ -144,7 +144,7 @@ func (s *PricingService) checkAndUpdatePricing() error {
|
||||
|
||||
// 检查本地文件是否存在
|
||||
if _, err := os.Stat(pricingFile); os.IsNotExist(err) {
|
||||
log.Println("[Pricing] Local pricing file not found, downloading...")
|
||||
logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Local pricing file not found, downloading...")
|
||||
return s.downloadPricingData()
|
||||
}
|
||||
|
||||
@@ -158,9 +158,9 @@ func (s *PricingService) checkAndUpdatePricing() error {
|
||||
maxAge := time.Duration(s.cfg.Pricing.UpdateIntervalHours) * time.Hour
|
||||
|
||||
if fileAge > maxAge {
|
||||
log.Printf("[Pricing] Local file is %v old, updating...", fileAge.Round(time.Hour))
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Local file is %v old, updating...", fileAge.Round(time.Hour))
|
||||
if err := s.downloadPricingData(); err != nil {
|
||||
log.Printf("[Pricing] Download failed, using existing file: %v", err)
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Download failed, using existing file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,7 +175,7 @@ func (s *PricingService) syncWithRemote() error {
|
||||
// 计算本地文件哈希
|
||||
localHash, err := s.computeFileHash(pricingFile)
|
||||
if err != nil {
|
||||
log.Printf("[Pricing] Failed to compute local hash: %v", err)
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to compute local hash: %v", err)
|
||||
return s.downloadPricingData()
|
||||
}
|
||||
|
||||
@@ -183,15 +183,15 @@ func (s *PricingService) syncWithRemote() error {
|
||||
if s.cfg.Pricing.HashURL != "" {
|
||||
remoteHash, err := s.fetchRemoteHash()
|
||||
if err != nil {
|
||||
log.Printf("[Pricing] Failed to fetch remote hash: %v", err)
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to fetch remote hash: %v", err)
|
||||
return nil // 哈希获取失败不影响正常使用
|
||||
}
|
||||
|
||||
if remoteHash != localHash {
|
||||
log.Println("[Pricing] Remote hash differs, downloading new version...")
|
||||
logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Remote hash differs, downloading new version...")
|
||||
return s.downloadPricingData()
|
||||
}
|
||||
log.Println("[Pricing] Hash check passed, no update needed")
|
||||
logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Hash check passed, no update needed")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -205,7 +205,7 @@ func (s *PricingService) syncWithRemote() error {
|
||||
maxAge := time.Duration(s.cfg.Pricing.UpdateIntervalHours) * time.Hour
|
||||
|
||||
if fileAge > maxAge {
|
||||
log.Printf("[Pricing] File is %v old, downloading...", fileAge.Round(time.Hour))
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] File is %v old, downloading...", fileAge.Round(time.Hour))
|
||||
return s.downloadPricingData()
|
||||
}
|
||||
|
||||
@@ -218,7 +218,7 @@ func (s *PricingService) downloadPricingData() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("[Pricing] Downloading from %s", remoteURL)
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Downloading from %s", remoteURL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
@@ -252,7 +252,7 @@ func (s *PricingService) downloadPricingData() error {
|
||||
// 保存到本地文件
|
||||
pricingFile := s.getPricingFilePath()
|
||||
if err := os.WriteFile(pricingFile, body, 0644); err != nil {
|
||||
log.Printf("[Pricing] Failed to save file: %v", err)
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to save file: %v", err)
|
||||
}
|
||||
|
||||
// 保存哈希
|
||||
@@ -260,7 +260,7 @@ func (s *PricingService) downloadPricingData() error {
|
||||
hashStr := hex.EncodeToString(hash[:])
|
||||
hashFile := s.getHashFilePath()
|
||||
if err := os.WriteFile(hashFile, []byte(hashStr+"\n"), 0644); err != nil {
|
||||
log.Printf("[Pricing] Failed to save hash: %v", err)
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to save hash: %v", err)
|
||||
}
|
||||
|
||||
// 更新内存数据
|
||||
@@ -270,7 +270,7 @@ func (s *PricingService) downloadPricingData() error {
|
||||
s.localHash = hashStr
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Printf("[Pricing] Downloaded %d models successfully", len(data))
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Downloaded %d models successfully", len(data))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -329,7 +329,7 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel
|
||||
}
|
||||
|
||||
if skipped > 0 {
|
||||
log.Printf("[Pricing] Skipped %d invalid entries", skipped)
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Skipped %d invalid entries", skipped)
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
@@ -368,7 +368,7 @@ func (s *PricingService) loadPricingData(filePath string) error {
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Printf("[Pricing] Loaded %d models from %s", len(pricingData), filePath)
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Loaded %d models from %s", len(pricingData), filePath)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -380,7 +380,7 @@ func (s *PricingService) useFallbackPricing() error {
|
||||
return fmt.Errorf("fallback file not found: %s", fallbackFile)
|
||||
}
|
||||
|
||||
log.Printf("[Pricing] Using fallback file: %s", fallbackFile)
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Using fallback file: %s", fallbackFile)
|
||||
|
||||
// 复制到数据目录
|
||||
data, err := os.ReadFile(fallbackFile)
|
||||
@@ -390,7 +390,7 @@ func (s *PricingService) useFallbackPricing() error {
|
||||
|
||||
pricingFile := s.getPricingFilePath()
|
||||
if err := os.WriteFile(pricingFile, data, 0644); err != nil {
|
||||
log.Printf("[Pricing] Failed to copy fallback: %v", err)
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to copy fallback: %v", err)
|
||||
}
|
||||
|
||||
return s.loadPricingData(fallbackFile)
|
||||
@@ -639,7 +639,7 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
for key, pricing := range s.pricingData {
|
||||
keyLower := strings.ToLower(key)
|
||||
if strings.Contains(keyLower, pattern) {
|
||||
log.Printf("[Pricing] Fuzzy matched %s -> %s", model, key)
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] Fuzzy matched %s -> %s", model, key)
|
||||
return pricing
|
||||
}
|
||||
}
|
||||
@@ -660,14 +660,14 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
|
||||
|
||||
for _, variant := range variants {
|
||||
if pricing, ok := s.pricingData[variant]; ok {
|
||||
log.Printf("[Pricing] OpenAI fallback matched %s -> %s", model, variant)
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] OpenAI fallback matched %s -> %s", model, variant)
|
||||
return pricing
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(model, "gpt-5.3-codex") {
|
||||
if pricing, ok := s.pricingData["gpt-5.2-codex"]; ok {
|
||||
log.Printf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.2-codex")
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.2-codex")
|
||||
return pricing
|
||||
}
|
||||
}
|
||||
@@ -675,7 +675,7 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
|
||||
// 最终回退到 DefaultTestModel
|
||||
defaultModel := strings.ToLower(openai.DefaultTestModel)
|
||||
if pricing, ok := s.pricingData[defaultModel]; ok {
|
||||
log.Printf("[Pricing] OpenAI fallback to default model %s -> %s", model, defaultModel)
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] OpenAI fallback to default model %s -> %s", model, defaultModel)
|
||||
return pricing
|
||||
}
|
||||
|
||||
|
||||
@@ -4,13 +4,13 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -104,7 +104,7 @@ func (s *SchedulerSnapshotService) ListSchedulableAccounts(ctx context.Context,
|
||||
if s.cache != nil {
|
||||
cached, hit, err := s.cache.GetSnapshot(ctx, bucket)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] cache read failed: bucket=%s err=%v", bucket.String(), err)
|
||||
logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] cache read failed: bucket=%s err=%v", bucket.String(), err)
|
||||
} else if hit {
|
||||
return derefAccounts(cached), useMixed, nil
|
||||
}
|
||||
@@ -124,7 +124,7 @@ func (s *SchedulerSnapshotService) ListSchedulableAccounts(ctx context.Context,
|
||||
|
||||
if s.cache != nil {
|
||||
if err := s.cache.SetSnapshot(fallbackCtx, bucket, accounts); err != nil {
|
||||
log.Printf("[Scheduler] cache write failed: bucket=%s err=%v", bucket.String(), err)
|
||||
logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] cache write failed: bucket=%s err=%v", bucket.String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -138,7 +138,7 @@ func (s *SchedulerSnapshotService) GetAccount(ctx context.Context, accountID int
|
||||
if s.cache != nil {
|
||||
account, err := s.cache.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] account cache read failed: id=%d err=%v", accountID, err)
|
||||
logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] account cache read failed: id=%d err=%v", accountID, err)
|
||||
} else if account != nil {
|
||||
return account, nil
|
||||
}
|
||||
@@ -168,17 +168,17 @@ func (s *SchedulerSnapshotService) runInitialRebuild() {
|
||||
defer cancel()
|
||||
buckets, err := s.cache.ListBuckets(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] list buckets failed: %v", err)
|
||||
logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] list buckets failed: %v", err)
|
||||
}
|
||||
if len(buckets) == 0 {
|
||||
buckets, err = s.defaultBuckets(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] default buckets failed: %v", err)
|
||||
logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] default buckets failed: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := s.rebuildBuckets(ctx, buckets, "startup"); err != nil {
|
||||
log.Printf("[Scheduler] rebuild startup failed: %v", err)
|
||||
logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] rebuild startup failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -205,7 +205,7 @@ func (s *SchedulerSnapshotService) runFullRebuildWorker(interval time.Duration)
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := s.triggerFullRebuild("interval"); err != nil {
|
||||
log.Printf("[Scheduler] full rebuild failed: %v", err)
|
||||
logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] full rebuild failed: %v", err)
|
||||
}
|
||||
case <-s.stopCh:
|
||||
return
|
||||
@@ -222,13 +222,13 @@ func (s *SchedulerSnapshotService) pollOutbox() {
|
||||
|
||||
watermark, err := s.cache.GetOutboxWatermark(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] outbox watermark read failed: %v", err)
|
||||
logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox watermark read failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
events, err := s.outboxRepo.ListAfter(ctx, watermark, 200)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] outbox poll failed: %v", err)
|
||||
logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox poll failed: %v", err)
|
||||
return
|
||||
}
|
||||
if len(events) == 0 {
|
||||
@@ -241,14 +241,14 @@ func (s *SchedulerSnapshotService) pollOutbox() {
|
||||
err := s.handleOutboxEvent(eventCtx, event)
|
||||
cancel()
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] outbox handle failed: id=%d type=%s err=%v", event.ID, event.EventType, err)
|
||||
logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox handle failed: id=%d type=%s err=%v", event.ID, event.EventType, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
lastID := events[len(events)-1].ID
|
||||
if err := s.cache.SetOutboxWatermark(ctx, lastID); err != nil {
|
||||
log.Printf("[Scheduler] outbox watermark write failed: %v", err)
|
||||
logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox watermark write failed: %v", err)
|
||||
} else {
|
||||
watermarkForCheck = lastID
|
||||
}
|
||||
@@ -445,11 +445,11 @@ func (s *SchedulerSnapshotService) rebuildBucket(ctx context.Context, bucket Sch
|
||||
|
||||
accounts, err := s.loadAccountsFromDB(rebuildCtx, bucket, bucket.Mode == SchedulerModeMixed)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] rebuild failed: bucket=%s reason=%s err=%v", bucket.String(), reason, err)
|
||||
logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] rebuild failed: bucket=%s reason=%s err=%v", bucket.String(), reason, err)
|
||||
return err
|
||||
}
|
||||
if err := s.cache.SetSnapshot(rebuildCtx, bucket, accounts); err != nil {
|
||||
log.Printf("[Scheduler] rebuild cache failed: bucket=%s reason=%s err=%v", bucket.String(), reason, err)
|
||||
logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] rebuild cache failed: bucket=%s reason=%s err=%v", bucket.String(), reason, err)
|
||||
return err
|
||||
}
|
||||
slog.Debug("[Scheduler] rebuild ok", "bucket", bucket.String(), "reason", reason, "size", len(accounts))
|
||||
@@ -465,13 +465,13 @@ func (s *SchedulerSnapshotService) triggerFullRebuild(reason string) error {
|
||||
|
||||
buckets, err := s.cache.ListBuckets(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] list buckets failed: %v", err)
|
||||
logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] list buckets failed: %v", err)
|
||||
return err
|
||||
}
|
||||
if len(buckets) == 0 {
|
||||
buckets, err = s.defaultBuckets(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] default buckets failed: %v", err)
|
||||
logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] default buckets failed: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -485,7 +485,7 @@ func (s *SchedulerSnapshotService) checkOutboxLag(ctx context.Context, oldest Sc
|
||||
|
||||
lag := time.Since(oldest.CreatedAt)
|
||||
if lagSeconds := int(lag.Seconds()); lagSeconds >= s.cfg.Gateway.Scheduling.OutboxLagWarnSeconds && s.cfg.Gateway.Scheduling.OutboxLagWarnSeconds > 0 {
|
||||
log.Printf("[Scheduler] outbox lag warning: %ds", lagSeconds)
|
||||
logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox lag warning: %ds", lagSeconds)
|
||||
}
|
||||
|
||||
if s.cfg.Gateway.Scheduling.OutboxLagRebuildSeconds > 0 && int(lag.Seconds()) >= s.cfg.Gateway.Scheduling.OutboxLagRebuildSeconds {
|
||||
@@ -495,12 +495,12 @@ func (s *SchedulerSnapshotService) checkOutboxLag(ctx context.Context, oldest Sc
|
||||
s.lagMu.Unlock()
|
||||
|
||||
if failures >= s.cfg.Gateway.Scheduling.OutboxLagRebuildFailures {
|
||||
log.Printf("[Scheduler] outbox lag rebuild triggered: lag=%s failures=%d", lag, failures)
|
||||
logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox lag rebuild triggered: lag=%s failures=%d", lag, failures)
|
||||
s.lagMu.Lock()
|
||||
s.lagFailures = 0
|
||||
s.lagMu.Unlock()
|
||||
if err := s.triggerFullRebuild("outbox_lag"); err != nil {
|
||||
log.Printf("[Scheduler] outbox lag rebuild failed: %v", err)
|
||||
logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox lag rebuild failed: %v", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -518,9 +518,9 @@ func (s *SchedulerSnapshotService) checkOutboxLag(ctx context.Context, oldest Sc
|
||||
return
|
||||
}
|
||||
if maxID-watermark >= int64(threshold) {
|
||||
log.Printf("[Scheduler] outbox backlog rebuild triggered: backlog=%d", maxID-watermark)
|
||||
logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox backlog rebuild triggered: backlog=%d", maxID-watermark)
|
||||
if err := s.triggerFullRebuild("outbox_backlog"); err != nil {
|
||||
log.Printf("[Scheduler] outbox backlog rebuild failed: %v", err)
|
||||
logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox backlog rebuild failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -9,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/robfig/cron/v3"
|
||||
)
|
||||
|
||||
@@ -37,18 +37,18 @@ func (s *SoraMediaCleanupService) Start() {
|
||||
return
|
||||
}
|
||||
if !s.cfg.Sora.Storage.Cleanup.Enabled {
|
||||
log.Printf("[SoraCleanup] not started (disabled)")
|
||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] not started (disabled)")
|
||||
return
|
||||
}
|
||||
if s.storage == nil || !s.storage.Enabled() {
|
||||
log.Printf("[SoraCleanup] not started (storage disabled)")
|
||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] not started (storage disabled)")
|
||||
return
|
||||
}
|
||||
|
||||
s.startOnce.Do(func() {
|
||||
schedule := strings.TrimSpace(s.cfg.Sora.Storage.Cleanup.Schedule)
|
||||
if schedule == "" {
|
||||
log.Printf("[SoraCleanup] not started (empty schedule)")
|
||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] not started (empty schedule)")
|
||||
return
|
||||
}
|
||||
loc := time.Local
|
||||
@@ -59,12 +59,12 @@ func (s *SoraMediaCleanupService) Start() {
|
||||
}
|
||||
c := cron.New(cron.WithParser(soraCleanupCronParser), cron.WithLocation(loc))
|
||||
if _, err := c.AddFunc(schedule, func() { s.runCleanup() }); err != nil {
|
||||
log.Printf("[SoraCleanup] not started (invalid schedule=%q): %v", schedule, err)
|
||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] not started (invalid schedule=%q): %v", schedule, err)
|
||||
return
|
||||
}
|
||||
s.cron = c
|
||||
s.cron.Start()
|
||||
log.Printf("[SoraCleanup] started (schedule=%q tz=%s)", schedule, loc.String())
|
||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] started (schedule=%q tz=%s)", schedule, loc.String())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -78,7 +78,7 @@ func (s *SoraMediaCleanupService) Stop() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(3 * time.Second):
|
||||
log.Printf("[SoraCleanup] cron stop timed out")
|
||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] cron stop timed out")
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -90,7 +90,7 @@ func (s *SoraMediaCleanupService) runCleanup() {
|
||||
}
|
||||
retention := s.cfg.Sora.Storage.Cleanup.RetentionDays
|
||||
if retention <= 0 {
|
||||
log.Printf("[SoraCleanup] skipped (retention_days=%d)", retention)
|
||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] skipped (retention_days=%d)", retention)
|
||||
return
|
||||
}
|
||||
cutoff := time.Now().AddDate(0, 0, -retention)
|
||||
@@ -116,5 +116,5 @@ func (s *SoraMediaCleanupService) runCleanup() {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
log.Printf("[SoraCleanup] cleanup finished, deleted=%d", deleted)
|
||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] cleanup finished, deleted=%d", deleted)
|
||||
}
|
||||
|
||||
@@ -2,10 +2,10 @@ package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/zeromicro/go-zero/core/collection"
|
||||
)
|
||||
|
||||
@@ -34,21 +34,21 @@ func NewTimingWheelService() (*TimingWheelService, error) {
|
||||
|
||||
// Start starts the timing wheel
|
||||
func (s *TimingWheelService) Start() {
|
||||
log.Println("[TimingWheel] Started (auto-start by go-zero)")
|
||||
logger.LegacyPrintf("service.timing_wheel", "%s", "[TimingWheel] Started (auto-start by go-zero)")
|
||||
}
|
||||
|
||||
// Stop stops the timing wheel
|
||||
func (s *TimingWheelService) Stop() {
|
||||
s.stopOnce.Do(func() {
|
||||
s.tw.Stop()
|
||||
log.Println("[TimingWheel] Stopped")
|
||||
logger.LegacyPrintf("service.timing_wheel", "%s", "[TimingWheel] Stopped")
|
||||
})
|
||||
}
|
||||
|
||||
// Schedule schedules a one-time task
|
||||
func (s *TimingWheelService) Schedule(name string, delay time.Duration, fn func()) {
|
||||
if err := s.tw.SetTimer(name, fn, delay); err != nil {
|
||||
log.Printf("[TimingWheel] SetTimer failed for %q: %v", name, err)
|
||||
logger.LegacyPrintf("service.timing_wheel", "[TimingWheel] SetTimer failed for %q: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,11 +58,11 @@ func (s *TimingWheelService) ScheduleRecurring(name string, interval time.Durati
|
||||
schedule = func() {
|
||||
fn()
|
||||
if err := s.tw.SetTimer(name, schedule, interval); err != nil {
|
||||
log.Printf("[TimingWheel] recurring SetTimer failed for %q: %v", name, err)
|
||||
logger.LegacyPrintf("service.timing_wheel", "[TimingWheel] recurring SetTimer failed for %q: %v", name, err)
|
||||
}
|
||||
}
|
||||
if err := s.tw.SetTimer(name, schedule, interval); err != nil {
|
||||
log.Printf("[TimingWheel] initial SetTimer failed for %q: %v", name, err)
|
||||
logger.LegacyPrintf("service.timing_wheel", "[TimingWheel] initial SetTimer failed for %q: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -70,22 +69,24 @@ func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) {
|
||||
// Start 启动后台刷新服务
|
||||
func (s *TokenRefreshService) Start() {
|
||||
if !s.cfg.Enabled {
|
||||
log.Println("[TokenRefresh] Service disabled by configuration")
|
||||
slog.Info("token_refresh.service_disabled")
|
||||
return
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go s.refreshLoop()
|
||||
|
||||
log.Printf("[TokenRefresh] Service started (check every %d minutes, refresh %v hours before expiry)",
|
||||
s.cfg.CheckIntervalMinutes, s.cfg.RefreshBeforeExpiryHours)
|
||||
slog.Info("token_refresh.service_started",
|
||||
"check_interval_minutes", s.cfg.CheckIntervalMinutes,
|
||||
"refresh_before_expiry_hours", s.cfg.RefreshBeforeExpiryHours,
|
||||
)
|
||||
}
|
||||
|
||||
// Stop 停止刷新服务
|
||||
func (s *TokenRefreshService) Stop() {
|
||||
close(s.stopCh)
|
||||
s.wg.Wait()
|
||||
log.Println("[TokenRefresh] Service stopped")
|
||||
slog.Info("token_refresh.service_stopped")
|
||||
}
|
||||
|
||||
// refreshLoop 刷新循环
|
||||
@@ -124,7 +125,7 @@ func (s *TokenRefreshService) processRefresh() {
|
||||
// 获取所有active状态的账号
|
||||
accounts, err := s.listActiveAccounts(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[TokenRefresh] Failed to list accounts: %v", err)
|
||||
slog.Error("token_refresh.list_accounts_failed", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -153,10 +154,17 @@ func (s *TokenRefreshService) processRefresh() {
|
||||
|
||||
// 执行刷新
|
||||
if err := s.refreshWithRetry(ctx, account, refresher); err != nil {
|
||||
log.Printf("[TokenRefresh] Account %d (%s) failed: %v", account.ID, account.Name, err)
|
||||
slog.Warn("token_refresh.account_refresh_failed",
|
||||
"account_id", account.ID,
|
||||
"account_name", account.Name,
|
||||
"error", err,
|
||||
)
|
||||
failed++
|
||||
} else {
|
||||
log.Printf("[TokenRefresh] Account %d (%s) refreshed successfully", account.ID, account.Name)
|
||||
slog.Info("token_refresh.account_refreshed",
|
||||
"account_id", account.ID,
|
||||
"account_name", account.Name,
|
||||
)
|
||||
refreshed++
|
||||
}
|
||||
|
||||
@@ -167,12 +175,17 @@ func (s *TokenRefreshService) processRefresh() {
|
||||
|
||||
// 无刷新活动时降级为 Debug,有实际刷新活动时保持 Info
|
||||
if needsRefresh == 0 && failed == 0 {
|
||||
slog.Debug("[TokenRefresh] Cycle complete",
|
||||
slog.Debug("token_refresh.cycle_completed",
|
||||
"total", totalAccounts, "oauth", oauthAccounts,
|
||||
"needs_refresh", needsRefresh, "refreshed", refreshed, "failed", failed)
|
||||
} else {
|
||||
log.Printf("[TokenRefresh] Cycle complete: total=%d, oauth=%d, needs_refresh=%d, refreshed=%d, failed=%d",
|
||||
totalAccounts, oauthAccounts, needsRefresh, refreshed, failed)
|
||||
slog.Info("token_refresh.cycle_completed",
|
||||
"total", totalAccounts,
|
||||
"oauth", oauthAccounts,
|
||||
"needs_refresh", needsRefresh,
|
||||
"refreshed", refreshed,
|
||||
"failed", failed,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -207,26 +220,35 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
||||
account.Status == StatusError &&
|
||||
strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
||||
if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil {
|
||||
log.Printf("[TokenRefresh] Failed to clear error status for account %d: %v", account.ID, clearErr)
|
||||
slog.Warn("token_refresh.clear_account_error_failed",
|
||||
"account_id", account.ID,
|
||||
"error", clearErr,
|
||||
)
|
||||
} else {
|
||||
log.Printf("[TokenRefresh] Account %d: cleared missing_project_id error", account.ID)
|
||||
slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID)
|
||||
}
|
||||
}
|
||||
// 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
|
||||
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
|
||||
if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
||||
log.Printf("[TokenRefresh] Failed to invalidate token cache for account %d: %v", account.ID, err)
|
||||
slog.Warn("token_refresh.invalidate_token_cache_failed",
|
||||
"account_id", account.ID,
|
||||
"error", err,
|
||||
)
|
||||
} else {
|
||||
log.Printf("[TokenRefresh] Token cache invalidated for account %d", account.ID)
|
||||
slog.Debug("token_refresh.token_cache_invalidated", "account_id", account.ID)
|
||||
}
|
||||
}
|
||||
// 同步更新调度器缓存,确保调度获取的 Account 对象包含最新的 credentials
|
||||
// 这解决了 token 刷新后调度器缓存数据不一致的问题(#445)
|
||||
if s.schedulerCache != nil {
|
||||
if err := s.schedulerCache.SetAccount(ctx, account); err != nil {
|
||||
log.Printf("[TokenRefresh] Failed to sync scheduler cache for account %d: %v", account.ID, err)
|
||||
slog.Warn("token_refresh.sync_scheduler_cache_failed",
|
||||
"account_id", account.ID,
|
||||
"error", err,
|
||||
)
|
||||
} else {
|
||||
log.Printf("[TokenRefresh] Scheduler cache synced for account %d", account.ID)
|
||||
slog.Debug("token_refresh.scheduler_cache_synced", "account_id", account.ID)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -236,14 +258,21 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
||||
if account.Platform == PlatformAntigravity && isNonRetryableRefreshError(err) {
|
||||
errorMsg := fmt.Sprintf("Token refresh failed (non-retryable): %v", err)
|
||||
if setErr := s.accountRepo.SetError(ctx, account.ID, errorMsg); setErr != nil {
|
||||
log.Printf("[TokenRefresh] Failed to set error status for account %d: %v", account.ID, setErr)
|
||||
slog.Error("token_refresh.set_error_status_failed",
|
||||
"account_id", account.ID,
|
||||
"error", setErr,
|
||||
)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
log.Printf("[TokenRefresh] Account %d attempt %d/%d failed: %v",
|
||||
account.ID, attempt, s.cfg.MaxRetries, err)
|
||||
slog.Warn("token_refresh.retry_attempt_failed",
|
||||
"account_id", account.ID,
|
||||
"attempt", attempt,
|
||||
"max_retries", s.cfg.MaxRetries,
|
||||
"error", err,
|
||||
)
|
||||
|
||||
// 如果还有重试机会,等待后重试
|
||||
if attempt < s.cfg.MaxRetries {
|
||||
@@ -256,11 +285,18 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
||||
// Antigravity 账户:其他错误仅记录日志,不标记 error(可能是临时网络问题)
|
||||
// 其他平台账户:重试失败后标记 error
|
||||
if account.Platform == PlatformAntigravity {
|
||||
log.Printf("[TokenRefresh] Account %d: refresh failed after %d retries: %v", account.ID, s.cfg.MaxRetries, lastErr)
|
||||
slog.Warn("token_refresh.retry_exhausted_antigravity",
|
||||
"account_id", account.ID,
|
||||
"max_retries", s.cfg.MaxRetries,
|
||||
"error", lastErr,
|
||||
)
|
||||
} else {
|
||||
errorMsg := fmt.Sprintf("Token refresh failed after %d retries: %v", s.cfg.MaxRetries, lastErr)
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
log.Printf("[TokenRefresh] Failed to set error status for account %d: %v", account.ID, err)
|
||||
slog.Error("token_refresh.set_error_status_failed",
|
||||
"account_id", account.ID,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,9 +3,9 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -47,36 +47,36 @@ func NewTurnstileService(settingService *SettingService, verifier TurnstileVerif
|
||||
func (s *TurnstileService) VerifyToken(ctx context.Context, token string, remoteIP string) error {
|
||||
// 检查是否启用 Turnstile
|
||||
if !s.settingService.IsTurnstileEnabled(ctx) {
|
||||
log.Println("[Turnstile] Disabled, skipping verification")
|
||||
logger.LegacyPrintf("service.turnstile", "%s", "[Turnstile] Disabled, skipping verification")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 获取 Secret Key
|
||||
secretKey := s.settingService.GetTurnstileSecretKey(ctx)
|
||||
if secretKey == "" {
|
||||
log.Println("[Turnstile] Secret key not configured")
|
||||
logger.LegacyPrintf("service.turnstile", "%s", "[Turnstile] Secret key not configured")
|
||||
return ErrTurnstileNotConfigured
|
||||
}
|
||||
|
||||
// 如果 token 为空,返回错误
|
||||
if token == "" {
|
||||
log.Println("[Turnstile] Token is empty")
|
||||
logger.LegacyPrintf("service.turnstile", "%s", "[Turnstile] Token is empty")
|
||||
return ErrTurnstileVerificationFailed
|
||||
}
|
||||
|
||||
log.Printf("[Turnstile] Verifying token for IP: %s", remoteIP)
|
||||
logger.LegacyPrintf("service.turnstile", "[Turnstile] Verifying token for IP: %s", remoteIP)
|
||||
result, err := s.verifier.VerifyToken(ctx, secretKey, token, remoteIP)
|
||||
if err != nil {
|
||||
log.Printf("[Turnstile] Request failed: %v", err)
|
||||
logger.LegacyPrintf("service.turnstile", "[Turnstile] Request failed: %v", err)
|
||||
return fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
log.Printf("[Turnstile] Verification failed, error codes: %v", result.ErrorCodes)
|
||||
logger.LegacyPrintf("service.turnstile", "[Turnstile] Verification failed, error codes: %v", result.ErrorCodes)
|
||||
return ErrTurnstileVerificationFailed
|
||||
}
|
||||
|
||||
log.Println("[Turnstile] Verification successful")
|
||||
logger.LegacyPrintf("service.turnstile", "%s", "[Turnstile] Verification successful")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -15,6 +14,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
@@ -82,18 +82,18 @@ func (s *UsageCleanupService) Start() {
|
||||
return
|
||||
}
|
||||
if s.cfg != nil && !s.cfg.UsageCleanup.Enabled {
|
||||
log.Printf("[UsageCleanup] not started (disabled)")
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] not started (disabled)")
|
||||
return
|
||||
}
|
||||
if s.repo == nil || s.timingWheel == nil {
|
||||
log.Printf("[UsageCleanup] not started (missing deps)")
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] not started (missing deps)")
|
||||
return
|
||||
}
|
||||
|
||||
interval := s.workerInterval()
|
||||
s.startOnce.Do(func() {
|
||||
s.timingWheel.ScheduleRecurring(usageCleanupWorkerName, interval, s.runOnce)
|
||||
log.Printf("[UsageCleanup] started (interval=%s max_range_days=%d batch_size=%d task_timeout=%s)", interval, s.maxRangeDays(), s.batchSize(), s.taskTimeout())
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] started (interval=%s max_range_days=%d batch_size=%d task_timeout=%s)", interval, s.maxRangeDays(), s.batchSize(), s.taskTimeout())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -108,7 +108,7 @@ func (s *UsageCleanupService) Stop() {
|
||||
if s.timingWheel != nil {
|
||||
s.timingWheel.Cancel(usageCleanupWorkerName)
|
||||
}
|
||||
log.Printf("[UsageCleanup] stopped")
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] stopped")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -130,10 +130,10 @@ func (s *UsageCleanupService) CreateTask(ctx context.Context, filters UsageClean
|
||||
return nil, infraerrors.BadRequest("USAGE_CLEANUP_INVALID_CREATOR", "invalid creator")
|
||||
}
|
||||
|
||||
log.Printf("[UsageCleanup] create_task requested: operator=%d %s", createdBy, describeUsageCleanupFilters(filters))
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] create_task requested: operator=%d %s", createdBy, describeUsageCleanupFilters(filters))
|
||||
sanitizeUsageCleanupFilters(&filters)
|
||||
if err := s.validateFilters(filters); err != nil {
|
||||
log.Printf("[UsageCleanup] create_task rejected: operator=%d err=%v %s", createdBy, err, describeUsageCleanupFilters(filters))
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] create_task rejected: operator=%d err=%v %s", createdBy, err, describeUsageCleanupFilters(filters))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -143,10 +143,10 @@ func (s *UsageCleanupService) CreateTask(ctx context.Context, filters UsageClean
|
||||
CreatedBy: createdBy,
|
||||
}
|
||||
if err := s.repo.CreateTask(ctx, task); err != nil {
|
||||
log.Printf("[UsageCleanup] create_task persist failed: operator=%d err=%v %s", createdBy, err, describeUsageCleanupFilters(filters))
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] create_task persist failed: operator=%d err=%v %s", createdBy, err, describeUsageCleanupFilters(filters))
|
||||
return nil, fmt.Errorf("create cleanup task: %w", err)
|
||||
}
|
||||
log.Printf("[UsageCleanup] create_task persisted: task=%d operator=%d status=%s deleted_rows=%d %s", task.ID, createdBy, task.Status, task.DeletedRows, describeUsageCleanupFilters(filters))
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] create_task persisted: task=%d operator=%d status=%s deleted_rows=%d %s", task.ID, createdBy, task.Status, task.DeletedRows, describeUsageCleanupFilters(filters))
|
||||
go s.runOnce()
|
||||
return task, nil
|
||||
}
|
||||
@@ -157,7 +157,7 @@ func (s *UsageCleanupService) runOnce() {
|
||||
return
|
||||
}
|
||||
if !atomic.CompareAndSwapInt32(&svc.running, 0, 1) {
|
||||
log.Printf("[UsageCleanup] run_once skipped: already_running=true")
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] run_once skipped: already_running=true")
|
||||
return
|
||||
}
|
||||
defer atomic.StoreInt32(&svc.running, 0)
|
||||
@@ -171,7 +171,7 @@ func (s *UsageCleanupService) runOnce() {
|
||||
|
||||
task, err := svc.repo.ClaimNextPendingTask(ctx, int64(svc.taskTimeout().Seconds()))
|
||||
if err != nil {
|
||||
log.Printf("[UsageCleanup] claim pending task failed: %v", err)
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] claim pending task failed: %v", err)
|
||||
return
|
||||
}
|
||||
if task == nil {
|
||||
@@ -179,7 +179,7 @@ func (s *UsageCleanupService) runOnce() {
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[UsageCleanup] task claimed: task=%d status=%s created_by=%d deleted_rows=%d %s", task.ID, task.Status, task.CreatedBy, task.DeletedRows, describeUsageCleanupFilters(task.Filters))
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] task claimed: task=%d status=%s created_by=%d deleted_rows=%d %s", task.ID, task.Status, task.CreatedBy, task.DeletedRows, describeUsageCleanupFilters(task.Filters))
|
||||
svc.executeTask(ctx, task)
|
||||
}
|
||||
|
||||
@@ -191,12 +191,12 @@ func (s *UsageCleanupService) executeTask(ctx context.Context, task *UsageCleanu
|
||||
batchSize := s.batchSize()
|
||||
deletedTotal := task.DeletedRows
|
||||
start := time.Now()
|
||||
log.Printf("[UsageCleanup] task started: task=%d batch_size=%d deleted_rows=%d %s", task.ID, batchSize, deletedTotal, describeUsageCleanupFilters(task.Filters))
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] task started: task=%d batch_size=%d deleted_rows=%d %s", task.ID, batchSize, deletedTotal, describeUsageCleanupFilters(task.Filters))
|
||||
var batchNum int
|
||||
|
||||
for {
|
||||
if ctx != nil && ctx.Err() != nil {
|
||||
log.Printf("[UsageCleanup] task interrupted: task=%d err=%v", task.ID, ctx.Err())
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] task interrupted: task=%d err=%v", task.ID, ctx.Err())
|
||||
return
|
||||
}
|
||||
canceled, err := s.isTaskCanceled(ctx, task.ID)
|
||||
@@ -205,7 +205,7 @@ func (s *UsageCleanupService) executeTask(ctx context.Context, task *UsageCleanu
|
||||
return
|
||||
}
|
||||
if canceled {
|
||||
log.Printf("[UsageCleanup] task canceled: task=%d deleted_rows=%d duration=%s", task.ID, deletedTotal, time.Since(start))
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] task canceled: task=%d deleted_rows=%d duration=%s", task.ID, deletedTotal, time.Since(start))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -214,7 +214,7 @@ func (s *UsageCleanupService) executeTask(ctx context.Context, task *UsageCleanu
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
// 任务被中断(例如服务停止/超时),保持 running 状态,后续通过 stale reclaim 续跑。
|
||||
log.Printf("[UsageCleanup] task interrupted: task=%d err=%v", task.ID, err)
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] task interrupted: task=%d err=%v", task.ID, err)
|
||||
return
|
||||
}
|
||||
s.markTaskFailed(task.ID, deletedTotal, err)
|
||||
@@ -224,12 +224,12 @@ func (s *UsageCleanupService) executeTask(ctx context.Context, task *UsageCleanu
|
||||
if deleted > 0 {
|
||||
updateCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
if err := s.repo.UpdateTaskProgress(updateCtx, task.ID, deletedTotal); err != nil {
|
||||
log.Printf("[UsageCleanup] task progress update failed: task=%d deleted_rows=%d err=%v", task.ID, deletedTotal, err)
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] task progress update failed: task=%d deleted_rows=%d err=%v", task.ID, deletedTotal, err)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
if batchNum <= 3 || batchNum%20 == 0 || deleted < int64(batchSize) {
|
||||
log.Printf("[UsageCleanup] task batch done: task=%d batch=%d deleted=%d deleted_total=%d", task.ID, batchNum, deleted, deletedTotal)
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] task batch done: task=%d batch=%d deleted=%d deleted_total=%d", task.ID, batchNum, deleted, deletedTotal)
|
||||
}
|
||||
if deleted == 0 || deleted < int64(batchSize) {
|
||||
break
|
||||
@@ -239,16 +239,16 @@ func (s *UsageCleanupService) executeTask(ctx context.Context, task *UsageCleanu
|
||||
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.repo.MarkTaskSucceeded(updateCtx, task.ID, deletedTotal); err != nil {
|
||||
log.Printf("[UsageCleanup] update task succeeded failed: task=%d err=%v", task.ID, err)
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] update task succeeded failed: task=%d err=%v", task.ID, err)
|
||||
} else {
|
||||
log.Printf("[UsageCleanup] task succeeded: task=%d deleted_rows=%d duration=%s", task.ID, deletedTotal, time.Since(start))
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] task succeeded: task=%d deleted_rows=%d duration=%s", task.ID, deletedTotal, time.Since(start))
|
||||
}
|
||||
|
||||
if s.dashboard != nil {
|
||||
if err := s.dashboard.TriggerRecomputeRange(task.Filters.StartTime, task.Filters.EndTime); err != nil {
|
||||
log.Printf("[UsageCleanup] trigger dashboard recompute failed: task=%d err=%v", task.ID, err)
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] trigger dashboard recompute failed: task=%d err=%v", task.ID, err)
|
||||
} else {
|
||||
log.Printf("[UsageCleanup] trigger dashboard recompute: task=%d start=%s end=%s", task.ID, task.Filters.StartTime.UTC().Format(time.RFC3339), task.Filters.EndTime.UTC().Format(time.RFC3339))
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] trigger dashboard recompute: task=%d start=%s end=%s", task.ID, task.Filters.StartTime.UTC().Format(time.RFC3339), task.Filters.EndTime.UTC().Format(time.RFC3339))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -258,11 +258,11 @@ func (s *UsageCleanupService) markTaskFailed(taskID int64, deletedRows int64, er
|
||||
if len(msg) > 500 {
|
||||
msg = msg[:500]
|
||||
}
|
||||
log.Printf("[UsageCleanup] task failed: task=%d deleted_rows=%d err=%s", taskID, deletedRows, msg)
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] task failed: task=%d deleted_rows=%d err=%s", taskID, deletedRows, msg)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if updateErr := s.repo.MarkTaskFailed(ctx, taskID, deletedRows, msg); updateErr != nil {
|
||||
log.Printf("[UsageCleanup] update task failed failed: task=%d err=%v", taskID, updateErr)
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] update task failed failed: task=%d err=%v", taskID, updateErr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -280,7 +280,7 @@ func (s *UsageCleanupService) isTaskCanceled(ctx context.Context, taskID int64)
|
||||
return false, err
|
||||
}
|
||||
if status == UsageCleanupStatusCanceled {
|
||||
log.Printf("[UsageCleanup] task cancel detected: task=%d", taskID)
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] task cancel detected: task=%d", taskID)
|
||||
}
|
||||
return status == UsageCleanupStatusCanceled, nil
|
||||
}
|
||||
@@ -319,7 +319,7 @@ func (s *UsageCleanupService) CancelTask(ctx context.Context, taskID int64, canc
|
||||
}
|
||||
return err
|
||||
}
|
||||
log.Printf("[UsageCleanup] cancel_task requested: task=%d operator=%d status=%s", taskID, canceledBy, status)
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task requested: task=%d operator=%d status=%s", taskID, canceledBy, status)
|
||||
if status != UsageCleanupStatusPending && status != UsageCleanupStatusRunning {
|
||||
return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status")
|
||||
}
|
||||
@@ -331,7 +331,7 @@ func (s *UsageCleanupService) CancelTask(ctx context.Context, taskID int64, canc
|
||||
// 状态可能并发改变
|
||||
return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status")
|
||||
}
|
||||
log.Printf("[UsageCleanup] cancel_task done: task=%d operator=%d", taskID, canceledBy)
|
||||
logger.LegacyPrintf("service.usage_cleanup", "[UsageCleanup] cancel_task done: task=%d operator=%d", taskID, canceledBy)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
@@ -193,6 +194,13 @@ func ProvideOpsCleanupService(
|
||||
return svc
|
||||
}
|
||||
|
||||
func ProvideOpsSystemLogSink(opsRepo OpsRepository) *OpsSystemLogSink {
|
||||
sink := NewOpsSystemLogSink(opsRepo)
|
||||
sink.Start()
|
||||
logger.SetSink(sink)
|
||||
return sink
|
||||
}
|
||||
|
||||
// ProvideSoraMediaStorage 初始化 Sora 媒体存储
|
||||
func ProvideSoraMediaStorage(cfg *config.Config) *SoraMediaStorage {
|
||||
return NewSoraMediaStorage(cfg)
|
||||
@@ -268,6 +276,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAccountUsageService,
|
||||
NewAccountTestService,
|
||||
NewSettingService,
|
||||
ProvideOpsSystemLogSink,
|
||||
NewOpsService,
|
||||
ProvideOpsMetricsCollector,
|
||||
ProvideOpsAggregationService,
|
||||
|
||||
@@ -7,11 +7,12 @@ import (
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -103,6 +104,36 @@ type JWTConfig struct {
|
||||
ExpireHour int `json:"expire_hour" yaml:"expire_hour"`
|
||||
}
|
||||
|
||||
const (
|
||||
adminBootstrapReasonEmptyDatabase = "empty_database"
|
||||
adminBootstrapReasonAdminExists = "admin_exists"
|
||||
adminBootstrapReasonUsersExistWithoutAdmin = "users_exist_without_admin"
|
||||
)
|
||||
|
||||
type adminBootstrapDecision struct {
|
||||
shouldCreate bool
|
||||
reason string
|
||||
}
|
||||
|
||||
func decideAdminBootstrap(totalUsers, adminUsers int64) adminBootstrapDecision {
|
||||
if adminUsers > 0 {
|
||||
return adminBootstrapDecision{
|
||||
shouldCreate: false,
|
||||
reason: adminBootstrapReasonAdminExists,
|
||||
}
|
||||
}
|
||||
if totalUsers > 0 {
|
||||
return adminBootstrapDecision{
|
||||
shouldCreate: false,
|
||||
reason: adminBootstrapReasonUsersExistWithoutAdmin,
|
||||
}
|
||||
}
|
||||
return adminBootstrapDecision{
|
||||
shouldCreate: true,
|
||||
reason: adminBootstrapReasonEmptyDatabase,
|
||||
}
|
||||
}
|
||||
|
||||
// NeedsSetup checks if the system needs initial setup
|
||||
// Uses multiple checks to prevent attackers from forcing re-setup by deleting config
|
||||
func NeedsSetup() bool {
|
||||
@@ -137,7 +168,7 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
|
||||
return
|
||||
}
|
||||
if err := db.Close(); err != nil {
|
||||
log.Printf("failed to close postgres connection: %v", err)
|
||||
logger.LegacyPrintf("setup", "failed to close postgres connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -164,12 +195,12 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create database '%s': %w", cfg.DBName, err)
|
||||
}
|
||||
log.Printf("Database '%s' created successfully", cfg.DBName)
|
||||
logger.LegacyPrintf("setup", "Database '%s' created successfully", cfg.DBName)
|
||||
}
|
||||
|
||||
// Now connect to the target database to verify
|
||||
if err := db.Close(); err != nil {
|
||||
log.Printf("failed to close postgres connection: %v", err)
|
||||
logger.LegacyPrintf("setup", "failed to close postgres connection: %v", err)
|
||||
}
|
||||
db = nil
|
||||
|
||||
@@ -185,7 +216,7 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
|
||||
|
||||
defer func() {
|
||||
if err := targetDB.Close(); err != nil {
|
||||
log.Printf("failed to close postgres connection: %v", err)
|
||||
logger.LegacyPrintf("setup", "failed to close postgres connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -217,7 +248,7 @@ func TestRedisConnection(cfg *RedisConfig) error {
|
||||
rdb := redis.NewClient(opts)
|
||||
defer func() {
|
||||
if err := rdb.Close(); err != nil {
|
||||
log.Printf("failed to close redis client: %v", err)
|
||||
logger.LegacyPrintf("setup", "failed to close redis client: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -245,7 +276,7 @@ func Install(cfg *SetupConfig) error {
|
||||
return fmt.Errorf("failed to generate jwt secret: %w", err)
|
||||
}
|
||||
cfg.JWT.Secret = secret
|
||||
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
|
||||
logger.LegacyPrintf("setup", "%s", "Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
|
||||
}
|
||||
|
||||
// Test connections
|
||||
@@ -262,8 +293,8 @@ func Install(cfg *SetupConfig) error {
|
||||
return fmt.Errorf("database initialization failed: %w", err)
|
||||
}
|
||||
|
||||
// Create admin user
|
||||
if err := createAdminUser(cfg); err != nil {
|
||||
// Create admin user (only when database is empty and no admin exists).
|
||||
if _, _, err := createAdminUser(cfg); err != nil {
|
||||
return fmt.Errorf("admin user creation failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -300,7 +331,7 @@ func initializeDatabase(cfg *SetupConfig) error {
|
||||
|
||||
defer func() {
|
||||
if err := db.Close(); err != nil {
|
||||
log.Printf("failed to close postgres connection: %v", err)
|
||||
logger.LegacyPrintf("setup", "failed to close postgres connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -309,7 +340,7 @@ func initializeDatabase(cfg *SetupConfig) error {
|
||||
return repository.ApplyMigrations(migrationCtx, db)
|
||||
}
|
||||
|
||||
func createAdminUser(cfg *SetupConfig) error {
|
||||
func createAdminUser(cfg *SetupConfig) (bool, string, error) {
|
||||
dsn := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
||||
cfg.Database.Host, cfg.Database.Port, cfg.Database.User,
|
||||
@@ -318,12 +349,12 @@ func createAdminUser(cfg *SetupConfig) error {
|
||||
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
return err
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := db.Close(); err != nil {
|
||||
log.Printf("failed to close postgres connection: %v", err)
|
||||
logger.LegacyPrintf("setup", "failed to close postgres connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -331,13 +362,27 @@ func createAdminUser(cfg *SetupConfig) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Check if admin already exists
|
||||
var count int64
|
||||
if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM users WHERE role = $1", service.RoleAdmin).Scan(&count); err != nil {
|
||||
return err
|
||||
var totalUsers int64
|
||||
if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM users").Scan(&totalUsers); err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
if count > 0 {
|
||||
return nil // Admin already exists
|
||||
var adminUsers int64
|
||||
if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM users WHERE role = $1", service.RoleAdmin).Scan(&adminUsers); err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
decision := decideAdminBootstrap(totalUsers, adminUsers)
|
||||
if !decision.shouldCreate {
|
||||
return false, decision.reason, nil
|
||||
}
|
||||
|
||||
if strings.TrimSpace(cfg.Admin.Password) == "" {
|
||||
password, genErr := generateSecret(16)
|
||||
if genErr != nil {
|
||||
return false, "", fmt.Errorf("failed to generate admin password: %w", genErr)
|
||||
}
|
||||
cfg.Admin.Password = password
|
||||
fmt.Printf("Generated admin password (one-time): %s\n", cfg.Admin.Password)
|
||||
fmt.Println("IMPORTANT: Save this password! It will not be shown again.")
|
||||
}
|
||||
|
||||
admin := &service.User{
|
||||
@@ -351,7 +396,7 @@ func createAdminUser(cfg *SetupConfig) error {
|
||||
}
|
||||
|
||||
if err := admin.SetPassword(cfg.Admin.Password); err != nil {
|
||||
return err
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
_, err = db.ExecContext(
|
||||
@@ -367,7 +412,10 @@ func createAdminUser(cfg *SetupConfig) error {
|
||||
admin.CreatedAt,
|
||||
admin.UpdatedAt,
|
||||
)
|
||||
return err
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
return true, decision.reason, nil
|
||||
}
|
||||
|
||||
func writeConfigFile(cfg *SetupConfig) error {
|
||||
@@ -476,8 +524,8 @@ func getEnvIntOrDefault(key string, defaultValue int) int {
|
||||
// AutoSetupFromEnv performs automatic setup using environment variables
|
||||
// This is designed for Docker deployment where all config is passed via env vars
|
||||
func AutoSetupFromEnv() error {
|
||||
log.Println("Auto setup enabled, configuring from environment variables...")
|
||||
log.Printf("Data directory: %s", GetDataDir())
|
||||
logger.LegacyPrintf("setup", "%s", "Auto setup enabled, configuring from environment variables...")
|
||||
logger.LegacyPrintf("setup", "Data directory: %s", GetDataDir())
|
||||
|
||||
// Get timezone from TZ or TIMEZONE env var (TZ is standard for Docker)
|
||||
tz := getEnvOrDefault("TZ", "")
|
||||
@@ -525,61 +573,62 @@ func AutoSetupFromEnv() error {
|
||||
return fmt.Errorf("failed to generate jwt secret: %w", err)
|
||||
}
|
||||
cfg.JWT.Secret = secret
|
||||
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
|
||||
}
|
||||
|
||||
// Generate admin password if not provided
|
||||
if cfg.Admin.Password == "" {
|
||||
password, err := generateSecret(16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate admin password: %w", err)
|
||||
}
|
||||
cfg.Admin.Password = password
|
||||
fmt.Printf("Generated admin password (one-time): %s\n", cfg.Admin.Password)
|
||||
fmt.Println("IMPORTANT: Save this password! It will not be shown again.")
|
||||
logger.LegacyPrintf("setup", "%s", "Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
|
||||
}
|
||||
|
||||
// Test database connection
|
||||
log.Println("Testing database connection...")
|
||||
logger.LegacyPrintf("setup", "%s", "Testing database connection...")
|
||||
if err := TestDatabaseConnection(&cfg.Database); err != nil {
|
||||
return fmt.Errorf("database connection failed: %w", err)
|
||||
}
|
||||
log.Println("Database connection successful")
|
||||
logger.LegacyPrintf("setup", "%s", "Database connection successful")
|
||||
|
||||
// Test Redis connection
|
||||
log.Println("Testing Redis connection...")
|
||||
logger.LegacyPrintf("setup", "%s", "Testing Redis connection...")
|
||||
if err := TestRedisConnection(&cfg.Redis); err != nil {
|
||||
return fmt.Errorf("redis connection failed: %w", err)
|
||||
}
|
||||
log.Println("Redis connection successful")
|
||||
logger.LegacyPrintf("setup", "%s", "Redis connection successful")
|
||||
|
||||
// Initialize database
|
||||
log.Println("Initializing database...")
|
||||
logger.LegacyPrintf("setup", "%s", "Initializing database...")
|
||||
if err := initializeDatabase(cfg); err != nil {
|
||||
return fmt.Errorf("database initialization failed: %w", err)
|
||||
}
|
||||
log.Println("Database initialized successfully")
|
||||
logger.LegacyPrintf("setup", "%s", "Database initialized successfully")
|
||||
|
||||
// Create admin user
|
||||
log.Println("Creating admin user...")
|
||||
if err := createAdminUser(cfg); err != nil {
|
||||
logger.LegacyPrintf("setup", "%s", "Creating admin user...")
|
||||
created, reason, err := createAdminUser(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("admin user creation failed: %w", err)
|
||||
}
|
||||
log.Printf("Admin user created: %s", cfg.Admin.Email)
|
||||
if created {
|
||||
logger.LegacyPrintf("setup", "Admin user created: %s", cfg.Admin.Email)
|
||||
} else {
|
||||
switch reason {
|
||||
case adminBootstrapReasonAdminExists:
|
||||
logger.LegacyPrintf("setup", "%s", "Admin user already exists, skipping admin bootstrap")
|
||||
case adminBootstrapReasonUsersExistWithoutAdmin:
|
||||
logger.LegacyPrintf("setup", "%s", "Database already has user data; skipping auto admin bootstrap to avoid password overwrite")
|
||||
default:
|
||||
logger.LegacyPrintf("setup", "%s", "Admin bootstrap skipped")
|
||||
}
|
||||
}
|
||||
|
||||
// Write config file
|
||||
log.Println("Writing configuration file...")
|
||||
logger.LegacyPrintf("setup", "%s", "Writing configuration file...")
|
||||
if err := writeConfigFile(cfg); err != nil {
|
||||
return fmt.Errorf("config file creation failed: %w", err)
|
||||
}
|
||||
log.Println("Configuration file created")
|
||||
logger.LegacyPrintf("setup", "%s", "Configuration file created")
|
||||
|
||||
// Create installation lock file
|
||||
if err := createInstallLock(); err != nil {
|
||||
return fmt.Errorf("failed to create install lock: %w", err)
|
||||
}
|
||||
log.Println("Installation lock created")
|
||||
logger.LegacyPrintf("setup", "%s", "Installation lock created")
|
||||
|
||||
log.Println("Auto setup completed successfully!")
|
||||
logger.LegacyPrintf("setup", "%s", "Auto setup completed successfully!")
|
||||
return nil
|
||||
}
|
||||
|
||||
51
backend/internal/setup/setup_test.go
Normal file
51
backend/internal/setup/setup_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package setup
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDecideAdminBootstrap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
totalUsers int64
|
||||
adminUsers int64
|
||||
should bool
|
||||
reason string
|
||||
}{
|
||||
{
|
||||
name: "empty database should create admin",
|
||||
totalUsers: 0,
|
||||
adminUsers: 0,
|
||||
should: true,
|
||||
reason: adminBootstrapReasonEmptyDatabase,
|
||||
},
|
||||
{
|
||||
name: "admin exists should skip",
|
||||
totalUsers: 10,
|
||||
adminUsers: 1,
|
||||
should: false,
|
||||
reason: adminBootstrapReasonAdminExists,
|
||||
},
|
||||
{
|
||||
name: "users exist without admin should skip",
|
||||
totalUsers: 5,
|
||||
adminUsers: 0,
|
||||
should: false,
|
||||
reason: adminBootstrapReasonUsersExistWithoutAdmin,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := decideAdminBootstrap(tc.totalUsers, tc.adminUsers)
|
||||
if got.shouldCreate != tc.should {
|
||||
t.Fatalf("shouldCreate=%v, want %v", got.shouldCreate, tc.should)
|
||||
}
|
||||
if got.reason != tc.reason {
|
||||
t.Fatalf("reason=%q, want %q", got.reason, tc.reason)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
55
backend/migrations/054_ops_system_logs.sql
Normal file
55
backend/migrations/054_ops_system_logs.sql
Normal file
@@ -0,0 +1,55 @@
|
||||
-- 054_ops_system_logs.sql
|
||||
-- 统一日志索引表与清理审计表
|
||||
|
||||
CREATE TABLE IF NOT EXISTS ops_system_logs (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
level VARCHAR(16) NOT NULL,
|
||||
component VARCHAR(128) NOT NULL DEFAULT '',
|
||||
message TEXT NOT NULL,
|
||||
request_id VARCHAR(128),
|
||||
client_request_id VARCHAR(128),
|
||||
user_id BIGINT,
|
||||
account_id BIGINT,
|
||||
platform VARCHAR(32),
|
||||
model VARCHAR(128),
|
||||
extra JSONB NOT NULL DEFAULT '{}'::jsonb
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_ops_system_logs_created_at_id
|
||||
ON ops_system_logs (created_at DESC, id DESC);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_ops_system_logs_level_created_at
|
||||
ON ops_system_logs (level, created_at DESC);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_ops_system_logs_component_created_at
|
||||
ON ops_system_logs (component, created_at DESC);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_ops_system_logs_request_id
|
||||
ON ops_system_logs (request_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_ops_system_logs_client_request_id
|
||||
ON ops_system_logs (client_request_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_ops_system_logs_user_id_created_at
|
||||
ON ops_system_logs (user_id, created_at DESC);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_ops_system_logs_account_id_created_at
|
||||
ON ops_system_logs (account_id, created_at DESC);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_ops_system_logs_platform_model_created_at
|
||||
ON ops_system_logs (platform, model, created_at DESC);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_ops_system_logs_message_search
|
||||
ON ops_system_logs USING GIN (to_tsvector('simple', COALESCE(message, '')));
|
||||
|
||||
CREATE TABLE IF NOT EXISTS ops_system_log_cleanup_audits (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
operator_id BIGINT NOT NULL,
|
||||
conditions JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
deleted_rows BIGINT NOT NULL DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_ops_system_log_cleanup_audits_created_at
|
||||
ON ops_system_log_cleanup_audits (created_at DESC, id DESC);
|
||||
@@ -20,6 +20,52 @@ SERVER_PORT=8080
|
||||
# Server mode: release or debug
|
||||
SERVER_MODE=release
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging Configuration
|
||||
# 日志配置
|
||||
# -----------------------------------------------------------------------------
|
||||
# 日志级别:debug/info/warn/error
|
||||
LOG_LEVEL=info
|
||||
# 日志格式:json/console
|
||||
LOG_FORMAT=json
|
||||
# 每条日志附带的 service 字段
|
||||
LOG_SERVICE_NAME=sub2api
|
||||
# 每条日志附带的 env 字段
|
||||
LOG_ENV=production
|
||||
# 是否输出调用方位置信息
|
||||
LOG_CALLER=true
|
||||
# 堆栈输出阈值:none/error/fatal
|
||||
LOG_STACKTRACE_LEVEL=error
|
||||
|
||||
# 输出开关(建议容器内保持双输出)
|
||||
# 是否输出到 stdout/stderr
|
||||
LOG_OUTPUT_TO_STDOUT=true
|
||||
# 是否输出到文件
|
||||
LOG_OUTPUT_TO_FILE=true
|
||||
# 日志文件路径(留空自动推导):
|
||||
# - 设置 DATA_DIR:${DATA_DIR}/logs/sub2api.log
|
||||
# - 未设置 DATA_DIR:/app/data/logs/sub2api.log
|
||||
LOG_OUTPUT_FILE_PATH=
|
||||
|
||||
# 滚动配置
|
||||
# 单文件最大体积(MB)
|
||||
LOG_ROTATION_MAX_SIZE_MB=100
|
||||
# 保留历史文件数量(0 表示不限制)
|
||||
LOG_ROTATION_MAX_BACKUPS=10
|
||||
# 历史日志保留天数(0 表示不限制)
|
||||
LOG_ROTATION_MAX_AGE_DAYS=7
|
||||
# 是否压缩历史日志
|
||||
LOG_ROTATION_COMPRESS=true
|
||||
# 滚动文件时间戳是否使用本地时间
|
||||
LOG_ROTATION_LOCAL_TIME=true
|
||||
|
||||
# 采样配置(高频重复日志降噪)
|
||||
LOG_SAMPLING_ENABLED=false
|
||||
# 每秒前 N 条日志不采样
|
||||
LOG_SAMPLING_INITIAL=100
|
||||
# 之后每 N 条保留 1 条
|
||||
LOG_SAMPLING_THEREAFTER=100
|
||||
|
||||
# Global max request body size in bytes (default: 100MB)
|
||||
# 全局最大请求体大小(字节,默认 100MB)
|
||||
# Applies to all requests, especially important for h2c first request memory protection
|
||||
|
||||
@@ -286,6 +286,70 @@ gateway:
|
||||
# profile_2:
|
||||
# name: "Custom Profile 2"
|
||||
|
||||
# =============================================================================
|
||||
# Logging Configuration
|
||||
# 日志配置
|
||||
# =============================================================================
|
||||
log:
|
||||
# Log level: debug/info/warn/error
|
||||
# 日志级别:debug/info/warn/error
|
||||
level: "info"
|
||||
# Log format: json/console
|
||||
# 日志格式:json/console
|
||||
format: "console"
|
||||
# Service name field written into each log line
|
||||
# 每条日志都会附带 service 字段
|
||||
service_name: "sub2api"
|
||||
# Environment field written into each log line
|
||||
# 每条日志都会附带 env 字段
|
||||
env: "production"
|
||||
# Include caller information
|
||||
# 是否输出调用方位置信息
|
||||
caller: true
|
||||
# Stacktrace threshold: none/error/fatal
|
||||
# 堆栈输出阈值:none/error/fatal
|
||||
stacktrace_level: "error"
|
||||
output:
|
||||
# Keep stdout/stderr output for container log collection
|
||||
# 保持标准输出用于容器日志采集
|
||||
to_stdout: true
|
||||
# Enable file output (default path auto-derived)
|
||||
# 启用文件输出(默认路径自动推导)
|
||||
to_file: true
|
||||
# Empty means:
|
||||
# - DATA_DIR set: {{DATA_DIR}}/logs/sub2api.log
|
||||
# - otherwise: /app/data/logs/sub2api.log
|
||||
# 留空时:
|
||||
# - 设置 DATA_DIR:{{DATA_DIR}}/logs/sub2api.log
|
||||
# - 否则:/app/data/logs/sub2api.log
|
||||
file_path: ""
|
||||
rotation:
|
||||
# Max file size before rotation (MB)
|
||||
# 单文件滚动阈值(MB)
|
||||
max_size_mb: 100
|
||||
# Number of rotated files to keep (0 means unlimited)
|
||||
# 保留历史文件数量(0 表示不限制)
|
||||
max_backups: 10
|
||||
# Number of days to keep old log files (0 means unlimited)
|
||||
# 历史日志保留天数(0 表示不限制)
|
||||
max_age_days: 7
|
||||
# Compress rotated files
|
||||
# 是否压缩历史日志
|
||||
compress: true
|
||||
# Use local time for timestamp in rotated filename
|
||||
# 滚动文件名时间戳使用本地时区
|
||||
local_time: true
|
||||
sampling:
|
||||
# Enable zap sampler (reduce high-frequency repetitive logs)
|
||||
# 启用 zap 采样(减少高频重复日志)
|
||||
enabled: false
|
||||
# Number of first entries per second to always log
|
||||
# 每秒无采样保留的前 N 条日志
|
||||
initial: 100
|
||||
# Thereafter keep 1 out of N entries per second
|
||||
# 之后每 N 条保留 1 条
|
||||
thereafter: 100
|
||||
|
||||
# =============================================================================
|
||||
# Sora Direct Client Configuration
|
||||
# Sora 直连配置
|
||||
|
||||
@@ -162,6 +162,10 @@ services:
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
environment:
|
||||
# postgres:18-alpine 默认 PGDATA=/var/lib/postgresql/18/docker(位于镜像声明的匿名卷 /var/lib/postgresql 内)。
|
||||
# 若不显式设置 PGDATA,则即使挂载了 postgres_data 到 /var/lib/postgresql/data,数据也不会落盘到该命名卷,
|
||||
# docker compose down/up 后会触发 initdb 重新初始化,导致用户/密码等数据丢失。
|
||||
- PGDATA=/var/lib/postgresql/data
|
||||
- POSTGRES_USER=${POSTGRES_USER:-sub2api}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
||||
- POSTGRES_DB=${POSTGRES_DB:-sub2api}
|
||||
|
||||
@@ -142,6 +142,10 @@ services:
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
environment:
|
||||
# postgres:18-alpine 默认 PGDATA=/var/lib/postgresql/18/docker(位于镜像声明的匿名卷 /var/lib/postgresql 内)。
|
||||
# 若不显式设置 PGDATA,则即使挂载了 postgres_data 到 /var/lib/postgresql/data,数据也不会落盘到该命名卷,
|
||||
# docker compose down/up 后会触发 initdb 重新初始化,导致用户/密码等数据丢失。
|
||||
- PGDATA=/var/lib/postgresql/data
|
||||
- POSTGRES_USER=${POSTGRES_USER:-sub2api}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
||||
- POSTGRES_DB=${POSTGRES_DB:-sub2api}
|
||||
|
||||
@@ -166,6 +166,10 @@ services:
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
environment:
|
||||
# postgres:18-alpine 默认 PGDATA=/var/lib/postgresql/18/docker(位于镜像声明的匿名卷 /var/lib/postgresql 内)。
|
||||
# 若不显式设置 PGDATA,则即使挂载了 postgres_data 到 /var/lib/postgresql/data,数据也不会落盘到该命名卷,
|
||||
# docker compose down/up 后会触发 initdb 重新初始化,导致用户/密码等数据丢失。
|
||||
- PGDATA=/var/lib/postgresql/data
|
||||
- POSTGRES_USER=${POSTGRES_USER:-sub2api}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
||||
- POSTGRES_DB=${POSTGRES_DB:-sub2api}
|
||||
|
||||
@@ -850,6 +850,77 @@ export interface OpsAggregationSettings {
|
||||
aggregation_enabled: boolean
|
||||
}
|
||||
|
||||
export interface OpsRuntimeLogConfig {
|
||||
level: 'debug' | 'info' | 'warn' | 'error'
|
||||
enable_sampling: boolean
|
||||
sampling_initial: number
|
||||
sampling_thereafter: number
|
||||
caller: boolean
|
||||
stacktrace_level: 'none' | 'error' | 'fatal'
|
||||
retention_days: number
|
||||
source?: string
|
||||
updated_at?: string
|
||||
updated_by_user_id?: number
|
||||
}
|
||||
|
||||
export interface OpsSystemLog {
|
||||
id: number
|
||||
created_at: string
|
||||
level: string
|
||||
component: string
|
||||
message: string
|
||||
request_id?: string
|
||||
client_request_id?: string
|
||||
user_id?: number | null
|
||||
account_id?: number | null
|
||||
platform?: string
|
||||
model?: string
|
||||
extra?: Record<string, any>
|
||||
}
|
||||
|
||||
export type OpsSystemLogListResponse = PaginatedResponse<OpsSystemLog>
|
||||
|
||||
export interface OpsSystemLogQuery {
|
||||
page?: number
|
||||
page_size?: number
|
||||
time_range?: '5m' | '30m' | '1h' | '6h' | '24h' | '7d' | '30d'
|
||||
start_time?: string
|
||||
end_time?: string
|
||||
level?: string
|
||||
component?: string
|
||||
request_id?: string
|
||||
client_request_id?: string
|
||||
user_id?: number | null
|
||||
account_id?: number | null
|
||||
platform?: string
|
||||
model?: string
|
||||
q?: string
|
||||
}
|
||||
|
||||
export interface OpsSystemLogCleanupRequest {
|
||||
start_time?: string
|
||||
end_time?: string
|
||||
level?: string
|
||||
component?: string
|
||||
request_id?: string
|
||||
client_request_id?: string
|
||||
user_id?: number | null
|
||||
account_id?: number | null
|
||||
platform?: string
|
||||
model?: string
|
||||
q?: string
|
||||
}
|
||||
|
||||
export interface OpsSystemLogSinkHealth {
|
||||
queue_depth: number
|
||||
queue_capacity: number
|
||||
dropped_count: number
|
||||
write_failed_count: number
|
||||
written_count: number
|
||||
avg_write_delay_ms: number
|
||||
last_error?: string
|
||||
}
|
||||
|
||||
export interface OpsErrorLog {
|
||||
id: number
|
||||
created_at: string
|
||||
@@ -1205,6 +1276,36 @@ export async function updateAlertRuntimeSettings(config: OpsAlertRuntimeSettings
|
||||
return data
|
||||
}
|
||||
|
||||
export async function getRuntimeLogConfig(): Promise<OpsRuntimeLogConfig> {
|
||||
const { data } = await apiClient.get<OpsRuntimeLogConfig>('/admin/ops/runtime/logging')
|
||||
return data
|
||||
}
|
||||
|
||||
export async function updateRuntimeLogConfig(config: OpsRuntimeLogConfig): Promise<OpsRuntimeLogConfig> {
|
||||
const { data } = await apiClient.put<OpsRuntimeLogConfig>('/admin/ops/runtime/logging', config)
|
||||
return data
|
||||
}
|
||||
|
||||
export async function resetRuntimeLogConfig(): Promise<OpsRuntimeLogConfig> {
|
||||
const { data } = await apiClient.post<OpsRuntimeLogConfig>('/admin/ops/runtime/logging/reset')
|
||||
return data
|
||||
}
|
||||
|
||||
export async function listSystemLogs(params: OpsSystemLogQuery): Promise<OpsSystemLogListResponse> {
|
||||
const { data } = await apiClient.get<OpsSystemLogListResponse>('/admin/ops/system-logs', { params })
|
||||
return data
|
||||
}
|
||||
|
||||
export async function cleanupSystemLogs(payload: OpsSystemLogCleanupRequest): Promise<{ deleted: number }> {
|
||||
const { data } = await apiClient.post<{ deleted: number }>('/admin/ops/system-logs/cleanup', payload)
|
||||
return data
|
||||
}
|
||||
|
||||
export async function getSystemLogSinkHealth(): Promise<OpsSystemLogSinkHealth> {
|
||||
const { data } = await apiClient.get<OpsSystemLogSinkHealth>('/admin/ops/system-logs/health')
|
||||
return data
|
||||
}
|
||||
|
||||
// Advanced settings (DB-backed)
|
||||
export async function getAdvancedSettings(): Promise<OpsAdvancedSettings> {
|
||||
const { data } = await apiClient.get<OpsAdvancedSettings>('/admin/ops/advanced-settings')
|
||||
@@ -1272,10 +1373,16 @@ export const opsAPI = {
|
||||
updateEmailNotificationConfig,
|
||||
getAlertRuntimeSettings,
|
||||
updateAlertRuntimeSettings,
|
||||
getRuntimeLogConfig,
|
||||
updateRuntimeLogConfig,
|
||||
resetRuntimeLogConfig,
|
||||
getAdvancedSettings,
|
||||
updateAdvancedSettings,
|
||||
getMetricThresholds,
|
||||
updateMetricThresholds
|
||||
updateMetricThresholds,
|
||||
listSystemLogs,
|
||||
cleanupSystemLogs,
|
||||
getSystemLogSinkHealth
|
||||
}
|
||||
|
||||
export default opsAPI
|
||||
|
||||
@@ -96,6 +96,13 @@
|
||||
<!-- Alert Events -->
|
||||
<OpsAlertEventsCard v-if="opsEnabled && !(loading && !hasLoadedOnce)" />
|
||||
|
||||
<!-- System Logs -->
|
||||
<OpsSystemLogTable
|
||||
v-if="opsEnabled && !(loading && !hasLoadedOnce)"
|
||||
:platform-filter="platform"
|
||||
:refresh-token="dashboardRefreshToken"
|
||||
/>
|
||||
|
||||
<!-- Settings Dialog (hidden in fullscreen mode) -->
|
||||
<template v-if="!isFullscreen">
|
||||
<OpsSettingsDialog :show="showSettingsDialog" @close="showSettingsDialog = false" @saved="onSettingsSaved" />
|
||||
@@ -158,6 +165,7 @@ import OpsThroughputTrendChart from './components/OpsThroughputTrendChart.vue'
|
||||
import OpsSwitchRateTrendChart from './components/OpsSwitchRateTrendChart.vue'
|
||||
import OpsAlertEventsCard from './components/OpsAlertEventsCard.vue'
|
||||
import OpsOpenAITokenStatsCard from './components/OpsOpenAITokenStatsCard.vue'
|
||||
import OpsSystemLogTable from './components/OpsSystemLogTable.vue'
|
||||
import OpsRequestDetailsModal, { type OpsRequestDetailsPreset } from './components/OpsRequestDetailsModal.vue'
|
||||
import OpsSettingsDialog from './components/OpsSettingsDialog.vue'
|
||||
import OpsAlertRulesCard from './components/OpsAlertRulesCard.vue'
|
||||
|
||||
@@ -130,8 +130,7 @@ watch(
|
||||
next.viewMode !== prev.viewMode ||
|
||||
next.pageSize !== prev.pageSize ||
|
||||
next.platform !== prev.platform ||
|
||||
next.groupId !== prev.groupId ||
|
||||
next.refreshToken !== prev.refreshToken
|
||||
next.groupId !== prev.groupId
|
||||
|
||||
if (next.viewMode === 'pagination' && filtersChanged && next.page !== 1) {
|
||||
page.value = 1
|
||||
|
||||
506
frontend/src/views/admin/ops/components/OpsSystemLogTable.vue
Normal file
506
frontend/src/views/admin/ops/components/OpsSystemLogTable.vue
Normal file
@@ -0,0 +1,506 @@
|
||||
<script setup lang="ts">
|
||||
import { computed, onMounted, reactive, ref, watch } from 'vue'
|
||||
import { opsAPI, type OpsRuntimeLogConfig, type OpsSystemLog, type OpsSystemLogSinkHealth } from '@/api/admin/ops'
|
||||
import Pagination from '@/components/common/Pagination.vue'
|
||||
import { useAppStore } from '@/stores'
|
||||
|
||||
const appStore = useAppStore()
|
||||
|
||||
const props = withDefaults(defineProps<{
|
||||
platformFilter?: string
|
||||
refreshToken?: number
|
||||
}>(), {
|
||||
platformFilter: '',
|
||||
refreshToken: 0
|
||||
})
|
||||
|
||||
const loading = ref(false)
|
||||
const logs = ref<OpsSystemLog[]>([])
|
||||
const total = ref(0)
|
||||
const page = ref(1)
|
||||
const pageSize = ref(20)
|
||||
|
||||
const health = ref<OpsSystemLogSinkHealth>({
|
||||
queue_depth: 0,
|
||||
queue_capacity: 0,
|
||||
dropped_count: 0,
|
||||
write_failed_count: 0,
|
||||
written_count: 0,
|
||||
avg_write_delay_ms: 0
|
||||
})
|
||||
|
||||
const runtimeLoading = ref(false)
|
||||
const runtimeSaving = ref(false)
|
||||
const runtimeConfig = reactive<OpsRuntimeLogConfig>({
|
||||
level: 'info',
|
||||
enable_sampling: false,
|
||||
sampling_initial: 100,
|
||||
sampling_thereafter: 100,
|
||||
caller: true,
|
||||
stacktrace_level: 'error',
|
||||
retention_days: 30
|
||||
})
|
||||
|
||||
const filters = reactive({
|
||||
time_range: '1h' as '5m' | '30m' | '1h' | '6h' | '24h' | '7d' | '30d',
|
||||
start_time: '',
|
||||
end_time: '',
|
||||
level: '',
|
||||
component: '',
|
||||
request_id: '',
|
||||
client_request_id: '',
|
||||
user_id: '',
|
||||
account_id: '',
|
||||
platform: '',
|
||||
model: '',
|
||||
q: ''
|
||||
})
|
||||
|
||||
const levelBadgeClass = (level: string) => {
|
||||
const v = String(level || '').toLowerCase()
|
||||
if (v === 'error' || v === 'fatal') return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-300'
|
||||
if (v === 'warn' || v === 'warning') return 'bg-amber-100 text-amber-700 dark:bg-amber-900/30 dark:text-amber-300'
|
||||
if (v === 'debug') return 'bg-slate-100 text-slate-700 dark:bg-slate-800 dark:text-slate-300'
|
||||
return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-300'
|
||||
}
|
||||
|
||||
const formatTime = (value: string) => {
|
||||
if (!value) return '-'
|
||||
const d = new Date(value)
|
||||
if (Number.isNaN(d.getTime())) return value
|
||||
return d.toLocaleString()
|
||||
}
|
||||
|
||||
const getExtraString = (extra: Record<string, any> | undefined, key: string) => {
|
||||
if (!extra) return ''
|
||||
const v = extra[key]
|
||||
if (v == null) return ''
|
||||
if (typeof v === 'string') return v.trim()
|
||||
if (typeof v === 'number' || typeof v === 'boolean') return String(v)
|
||||
return ''
|
||||
}
|
||||
|
||||
const formatSystemLogDetail = (row: OpsSystemLog) => {
|
||||
const parts: string[] = []
|
||||
const msg = String(row.message || '').trim()
|
||||
if (msg) parts.push(msg)
|
||||
|
||||
const extra = row.extra || {}
|
||||
const statusCode = getExtraString(extra, 'status_code')
|
||||
const latencyMs = getExtraString(extra, 'latency_ms')
|
||||
const method = getExtraString(extra, 'method')
|
||||
const path = getExtraString(extra, 'path')
|
||||
const clientIP = getExtraString(extra, 'client_ip')
|
||||
const protocol = getExtraString(extra, 'protocol')
|
||||
|
||||
const accessParts: string[] = []
|
||||
if (statusCode) accessParts.push(`status=${statusCode}`)
|
||||
if (latencyMs) accessParts.push(`latency_ms=${latencyMs}`)
|
||||
if (method) accessParts.push(`method=${method}`)
|
||||
if (path) accessParts.push(`path=${path}`)
|
||||
if (clientIP) accessParts.push(`ip=${clientIP}`)
|
||||
if (protocol) accessParts.push(`proto=${protocol}`)
|
||||
if (accessParts.length > 0) parts.push(accessParts.join(' '))
|
||||
|
||||
const corrParts: string[] = []
|
||||
if (row.request_id) corrParts.push(`req=${row.request_id}`)
|
||||
if (row.client_request_id) corrParts.push(`client_req=${row.client_request_id}`)
|
||||
if (row.user_id != null) corrParts.push(`user=${row.user_id}`)
|
||||
if (row.account_id != null) corrParts.push(`acc=${row.account_id}`)
|
||||
if (row.platform) corrParts.push(`platform=${row.platform}`)
|
||||
if (row.model) corrParts.push(`model=${row.model}`)
|
||||
if (corrParts.length > 0) parts.push(corrParts.join(' '))
|
||||
|
||||
const errors = getExtraString(extra, 'errors')
|
||||
if (errors) parts.push(`errors=${errors}`)
|
||||
const err = getExtraString(extra, 'err') || getExtraString(extra, 'error')
|
||||
if (err) parts.push(`error=${err}`)
|
||||
|
||||
// 用空格拼接,交给 CSS 自动换行,尽量“填满再换行”。
|
||||
return parts.join(' ')
|
||||
}
|
||||
|
||||
const toRFC3339 = (value: string) => {
|
||||
if (!value) return undefined
|
||||
const d = new Date(value)
|
||||
if (Number.isNaN(d.getTime())) return undefined
|
||||
return d.toISOString()
|
||||
}
|
||||
|
||||
const buildQuery = () => {
|
||||
const query: Record<string, any> = {
|
||||
page: page.value,
|
||||
page_size: pageSize.value,
|
||||
time_range: filters.time_range
|
||||
}
|
||||
|
||||
if (filters.time_range === '30d') {
|
||||
query.time_range = '30d'
|
||||
}
|
||||
if (filters.start_time) query.start_time = toRFC3339(filters.start_time)
|
||||
if (filters.end_time) query.end_time = toRFC3339(filters.end_time)
|
||||
if (filters.level.trim()) query.level = filters.level.trim()
|
||||
if (filters.component.trim()) query.component = filters.component.trim()
|
||||
if (filters.request_id.trim()) query.request_id = filters.request_id.trim()
|
||||
if (filters.client_request_id.trim()) query.client_request_id = filters.client_request_id.trim()
|
||||
if (filters.user_id.trim()) {
|
||||
const v = Number.parseInt(filters.user_id.trim(), 10)
|
||||
if (Number.isFinite(v) && v > 0) query.user_id = v
|
||||
}
|
||||
if (filters.account_id.trim()) {
|
||||
const v = Number.parseInt(filters.account_id.trim(), 10)
|
||||
if (Number.isFinite(v) && v > 0) query.account_id = v
|
||||
}
|
||||
if (filters.platform.trim()) query.platform = filters.platform.trim()
|
||||
if (filters.model.trim()) query.model = filters.model.trim()
|
||||
if (filters.q.trim()) query.q = filters.q.trim()
|
||||
return query
|
||||
}
|
||||
|
||||
const fetchLogs = async () => {
|
||||
loading.value = true
|
||||
try {
|
||||
const res = await opsAPI.listSystemLogs(buildQuery())
|
||||
logs.value = res.items || []
|
||||
total.value = res.total || 0
|
||||
} catch (err: any) {
|
||||
console.error('[OpsSystemLogTable] Failed to fetch logs', err)
|
||||
appStore.showError(err?.response?.data?.detail || '系统日志加载失败')
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const fetchHealth = async () => {
|
||||
try {
|
||||
health.value = await opsAPI.getSystemLogSinkHealth()
|
||||
} catch {
|
||||
// 忽略健康数据读取失败,不影响主流程。
|
||||
}
|
||||
}
|
||||
|
||||
const loadRuntimeConfig = async () => {
|
||||
runtimeLoading.value = true
|
||||
try {
|
||||
const cfg = await opsAPI.getRuntimeLogConfig()
|
||||
runtimeConfig.level = cfg.level
|
||||
runtimeConfig.enable_sampling = cfg.enable_sampling
|
||||
runtimeConfig.sampling_initial = cfg.sampling_initial
|
||||
runtimeConfig.sampling_thereafter = cfg.sampling_thereafter
|
||||
runtimeConfig.caller = cfg.caller
|
||||
runtimeConfig.stacktrace_level = cfg.stacktrace_level
|
||||
runtimeConfig.retention_days = cfg.retention_days
|
||||
} catch (err: any) {
|
||||
console.error('[OpsSystemLogTable] Failed to load runtime log config', err)
|
||||
} finally {
|
||||
runtimeLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const saveRuntimeConfig = async () => {
|
||||
runtimeSaving.value = true
|
||||
try {
|
||||
const saved = await opsAPI.updateRuntimeLogConfig({ ...runtimeConfig })
|
||||
runtimeConfig.level = saved.level
|
||||
runtimeConfig.enable_sampling = saved.enable_sampling
|
||||
runtimeConfig.sampling_initial = saved.sampling_initial
|
||||
runtimeConfig.sampling_thereafter = saved.sampling_thereafter
|
||||
runtimeConfig.caller = saved.caller
|
||||
runtimeConfig.stacktrace_level = saved.stacktrace_level
|
||||
runtimeConfig.retention_days = saved.retention_days
|
||||
appStore.showSuccess('日志运行时配置已生效')
|
||||
} catch (err: any) {
|
||||
console.error('[OpsSystemLogTable] Failed to save runtime log config', err)
|
||||
appStore.showError(err?.response?.data?.detail || '保存日志配置失败')
|
||||
} finally {
|
||||
runtimeSaving.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const resetRuntimeConfig = async () => {
|
||||
const ok = window.confirm('确认回滚为启动配置(env/yaml)并立即生效?')
|
||||
if (!ok) return
|
||||
|
||||
runtimeSaving.value = true
|
||||
try {
|
||||
const saved = await opsAPI.resetRuntimeLogConfig()
|
||||
runtimeConfig.level = saved.level
|
||||
runtimeConfig.enable_sampling = saved.enable_sampling
|
||||
runtimeConfig.sampling_initial = saved.sampling_initial
|
||||
runtimeConfig.sampling_thereafter = saved.sampling_thereafter
|
||||
runtimeConfig.caller = saved.caller
|
||||
runtimeConfig.stacktrace_level = saved.stacktrace_level
|
||||
runtimeConfig.retention_days = saved.retention_days
|
||||
appStore.showSuccess('已回滚到启动日志配置')
|
||||
await fetchHealth()
|
||||
} catch (err: any) {
|
||||
console.error('[OpsSystemLogTable] Failed to reset runtime log config', err)
|
||||
appStore.showError(err?.response?.data?.detail || '回滚日志配置失败')
|
||||
} finally {
|
||||
runtimeSaving.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const cleanupCurrentFilter = async () => {
|
||||
const ok = window.confirm('确认按当前筛选条件清理系统日志?该操作不可撤销。')
|
||||
if (!ok) return
|
||||
try {
|
||||
const payload = {
|
||||
start_time: toRFC3339(filters.start_time),
|
||||
end_time: toRFC3339(filters.end_time),
|
||||
level: filters.level.trim() || undefined,
|
||||
component: filters.component.trim() || undefined,
|
||||
request_id: filters.request_id.trim() || undefined,
|
||||
client_request_id: filters.client_request_id.trim() || undefined,
|
||||
user_id: filters.user_id.trim() ? Number.parseInt(filters.user_id.trim(), 10) : undefined,
|
||||
account_id: filters.account_id.trim() ? Number.parseInt(filters.account_id.trim(), 10) : undefined,
|
||||
platform: filters.platform.trim() || undefined,
|
||||
model: filters.model.trim() || undefined,
|
||||
q: filters.q.trim() || undefined
|
||||
}
|
||||
const res = await opsAPI.cleanupSystemLogs(payload)
|
||||
appStore.showSuccess(`清理完成,删除 ${res.deleted || 0} 条日志`)
|
||||
page.value = 1
|
||||
await Promise.all([fetchLogs(), fetchHealth()])
|
||||
} catch (err: any) {
|
||||
console.error('[OpsSystemLogTable] Failed to cleanup logs', err)
|
||||
appStore.showError(err?.response?.data?.detail || '清理系统日志失败')
|
||||
}
|
||||
}
|
||||
|
||||
const resetFilters = () => {
|
||||
filters.time_range = '1h'
|
||||
filters.start_time = ''
|
||||
filters.end_time = ''
|
||||
filters.level = ''
|
||||
filters.component = ''
|
||||
filters.request_id = ''
|
||||
filters.client_request_id = ''
|
||||
filters.user_id = ''
|
||||
filters.account_id = ''
|
||||
filters.platform = props.platformFilter || ''
|
||||
filters.model = ''
|
||||
filters.q = ''
|
||||
page.value = 1
|
||||
fetchLogs()
|
||||
}
|
||||
|
||||
watch(() => props.platformFilter, (v) => {
|
||||
if (v && !filters.platform) {
|
||||
filters.platform = v
|
||||
page.value = 1
|
||||
fetchLogs()
|
||||
}
|
||||
})
|
||||
|
||||
watch(() => props.refreshToken, () => {
|
||||
fetchLogs()
|
||||
fetchHealth()
|
||||
})
|
||||
|
||||
const onPageChange = (next: number) => {
|
||||
page.value = next
|
||||
fetchLogs()
|
||||
}
|
||||
|
||||
const onPageSizeChange = (next: number) => {
|
||||
pageSize.value = next
|
||||
page.value = 1
|
||||
fetchLogs()
|
||||
}
|
||||
|
||||
const applyFilters = () => {
|
||||
page.value = 1
|
||||
fetchLogs()
|
||||
}
|
||||
|
||||
const hasData = computed(() => logs.value.length > 0)
|
||||
|
||||
onMounted(async () => {
|
||||
if (props.platformFilter) {
|
||||
filters.platform = props.platformFilter
|
||||
}
|
||||
await Promise.all([fetchLogs(), fetchHealth(), loadRuntimeConfig()])
|
||||
})
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<section class="rounded-2xl border border-gray-200 bg-white p-4 shadow-sm dark:border-dark-700 dark:bg-dark-900/60">
|
||||
<div class="mb-4 flex flex-wrap items-center justify-between gap-3">
|
||||
<div>
|
||||
<h3 class="text-sm font-bold text-gray-900 dark:text-white">系统日志</h3>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">默认按最新时间倒序,支持筛选搜索与按条件清理。</p>
|
||||
</div>
|
||||
<div class="flex flex-wrap items-center gap-2 text-xs">
|
||||
<span class="rounded-md bg-gray-100 px-2 py-1 text-gray-700 dark:bg-dark-700 dark:text-gray-200">队列 {{ health.queue_depth }}/{{ health.queue_capacity }}</span>
|
||||
<span class="rounded-md bg-gray-100 px-2 py-1 text-gray-700 dark:bg-dark-700 dark:text-gray-200">写入 {{ health.written_count }}</span>
|
||||
<span class="rounded-md bg-amber-100 px-2 py-1 text-amber-700 dark:bg-amber-900/30 dark:text-amber-300">丢弃 {{ health.dropped_count }}</span>
|
||||
<span class="rounded-md bg-red-100 px-2 py-1 text-red-700 dark:bg-red-900/30 dark:text-red-300">失败 {{ health.write_failed_count }}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="mb-4 rounded-xl border border-gray-200 bg-gray-50 p-3 dark:border-dark-700 dark:bg-dark-800/70">
|
||||
<div class="mb-2 flex items-center justify-between">
|
||||
<div class="text-xs font-semibold text-gray-700 dark:text-gray-200">运行时日志配置(实时生效)</div>
|
||||
<span v-if="runtimeLoading" class="text-xs text-gray-500">加载中...</span>
|
||||
</div>
|
||||
<div class="grid grid-cols-1 gap-3 md:grid-cols-6">
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
级别
|
||||
<select v-model="runtimeConfig.level" class="input mt-1">
|
||||
<option value="debug">debug</option>
|
||||
<option value="info">info</option>
|
||||
<option value="warn">warn</option>
|
||||
<option value="error">error</option>
|
||||
</select>
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
堆栈阈值
|
||||
<select v-model="runtimeConfig.stacktrace_level" class="input mt-1">
|
||||
<option value="none">none</option>
|
||||
<option value="error">error</option>
|
||||
<option value="fatal">fatal</option>
|
||||
</select>
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
采样初始
|
||||
<input v-model.number="runtimeConfig.sampling_initial" type="number" min="1" class="input mt-1" />
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
采样后续
|
||||
<input v-model.number="runtimeConfig.sampling_thereafter" type="number" min="1" class="input mt-1" />
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
保留天数
|
||||
<input v-model.number="runtimeConfig.retention_days" type="number" min="1" max="3650" class="input mt-1" />
|
||||
</label>
|
||||
<div class="flex items-end gap-2">
|
||||
<label class="inline-flex items-center gap-2 text-xs text-gray-600 dark:text-gray-300">
|
||||
<input v-model="runtimeConfig.caller" type="checkbox" />
|
||||
caller
|
||||
</label>
|
||||
<label class="inline-flex items-center gap-2 text-xs text-gray-600 dark:text-gray-300">
|
||||
<input v-model="runtimeConfig.enable_sampling" type="checkbox" />
|
||||
sampling
|
||||
</label>
|
||||
<button type="button" class="btn btn-primary btn-sm" :disabled="runtimeSaving" @click="saveRuntimeConfig">
|
||||
{{ runtimeSaving ? '保存中...' : '保存并生效' }}
|
||||
</button>
|
||||
<button type="button" class="btn btn-secondary btn-sm" :disabled="runtimeSaving" @click="resetRuntimeConfig">
|
||||
回滚默认值
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<p v-if="health.last_error" class="mt-2 text-xs text-red-600 dark:text-red-400">最近写入错误:{{ health.last_error }}</p>
|
||||
</div>
|
||||
|
||||
<div class="mb-4 grid grid-cols-1 gap-3 md:grid-cols-5">
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
时间范围
|
||||
<select v-model="filters.time_range" class="input mt-1">
|
||||
<option value="5m">5m</option>
|
||||
<option value="30m">30m</option>
|
||||
<option value="1h">1h</option>
|
||||
<option value="6h">6h</option>
|
||||
<option value="24h">24h</option>
|
||||
<option value="7d">7d</option>
|
||||
<option value="30d">30d</option>
|
||||
</select>
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
开始时间(可选)
|
||||
<input v-model="filters.start_time" type="datetime-local" class="input mt-1" />
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
结束时间(可选)
|
||||
<input v-model="filters.end_time" type="datetime-local" class="input mt-1" />
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
级别
|
||||
<select v-model="filters.level" class="input mt-1">
|
||||
<option value="">全部</option>
|
||||
<option value="debug">debug</option>
|
||||
<option value="info">info</option>
|
||||
<option value="warn">warn</option>
|
||||
<option value="error">error</option>
|
||||
</select>
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
组件
|
||||
<input v-model="filters.component" type="text" class="input mt-1" placeholder="如 http.access" />
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
request_id
|
||||
<input v-model="filters.request_id" type="text" class="input mt-1" />
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
client_request_id
|
||||
<input v-model="filters.client_request_id" type="text" class="input mt-1" />
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
user_id
|
||||
<input v-model="filters.user_id" type="text" class="input mt-1" />
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
account_id
|
||||
<input v-model="filters.account_id" type="text" class="input mt-1" />
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
平台
|
||||
<input v-model="filters.platform" type="text" class="input mt-1" />
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
模型
|
||||
<input v-model="filters.model" type="text" class="input mt-1" />
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
关键词
|
||||
<input v-model="filters.q" type="text" class="input mt-1" placeholder="消息/request_id" />
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<div class="mb-3 flex flex-wrap gap-2">
|
||||
<button type="button" class="btn btn-primary btn-sm" @click="applyFilters">查询</button>
|
||||
<button type="button" class="btn btn-secondary btn-sm" @click="resetFilters">重置</button>
|
||||
<button type="button" class="btn btn-danger btn-sm" @click="cleanupCurrentFilter">按当前筛选清理</button>
|
||||
<button type="button" class="btn btn-secondary btn-sm" @click="fetchHealth">刷新健康指标</button>
|
||||
</div>
|
||||
|
||||
<div class="overflow-hidden rounded-xl border border-gray-200 dark:border-dark-700">
|
||||
<div v-if="loading" class="px-4 py-8 text-center text-sm text-gray-500">加载中...</div>
|
||||
<div v-else-if="!hasData" class="px-4 py-8 text-center text-sm text-gray-500">暂无系统日志</div>
|
||||
<div v-else class="overflow-auto">
|
||||
<table class="min-w-full table-fixed divide-y divide-gray-200 dark:divide-dark-700">
|
||||
<thead class="bg-gray-50 dark:bg-dark-900">
|
||||
<tr>
|
||||
<th class="w-[170px] px-3 py-2 text-left text-[11px] font-semibold text-gray-500">时间</th>
|
||||
<th class="w-[80px] px-3 py-2 text-left text-[11px] font-semibold text-gray-500">级别</th>
|
||||
<th class="px-3 py-2 text-left text-[11px] font-semibold text-gray-500">日志详细信息</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody class="divide-y divide-gray-100 dark:divide-dark-800">
|
||||
<tr v-for="row in logs" :key="row.id" class="align-top">
|
||||
<td class="px-3 py-2 text-xs text-gray-700 dark:text-gray-300">{{ formatTime(row.created_at) }}</td>
|
||||
<td class="px-3 py-2 text-xs">
|
||||
<span class="inline-flex rounded-full px-2 py-0.5 font-semibold" :class="levelBadgeClass(row.level)">
|
||||
{{ row.level }}
|
||||
</span>
|
||||
</td>
|
||||
<td class="px-3 py-2 text-xs text-gray-700 dark:text-gray-300 whitespace-normal break-all">
|
||||
{{ formatSystemLogDetail(row) }}
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<Pagination
|
||||
:total="total"
|
||||
:page="page"
|
||||
:page-size="pageSize"
|
||||
:page-size-options="[10, 20, 50, 100, 200]"
|
||||
@update:page="onPageChange"
|
||||
@update:page-size="onPageSizeChange"
|
||||
/>
|
||||
</div>
|
||||
</section>
|
||||
</template>
|
||||
@@ -17,5 +17,8 @@ export type {
|
||||
OpsMetricThresholds,
|
||||
OpsAdvancedSettings,
|
||||
OpsDataRetentionSettings,
|
||||
OpsAggregationSettings
|
||||
OpsAggregationSettings,
|
||||
OpsRuntimeLogConfig,
|
||||
OpsSystemLog,
|
||||
OpsSystemLogSinkHealth
|
||||
} from '@/api/admin/ops'
|
||||
|
||||
135
logging_audit_20260212.md
Normal file
135
logging_audit_20260212.md
Normal file
@@ -0,0 +1,135 @@
|
||||
# 日志专项审计与整理(2026-02-12)
|
||||
|
||||
## 1. 全量扫描结论
|
||||
|
||||
- 扫描范围:`backend/` + `frontend/`
|
||||
- 日志相关调用总量(粗统计):约 `4100` 处
|
||||
- 后端标准库日志(`log.Printf/Println/Fatal*`):`808` 处(本轮整改后剩余 `269` 处)
|
||||
- 前端 `console.*`:`180` 处
|
||||
|
||||
关键观察:
|
||||
|
||||
1. 后端大量业务日志仍走标准库 `log`,在当前初始化流程里会被统一当作 `INFO` 输出,导致“错误/告警等级失真”。
|
||||
2. 网关关键链路(OpenAI/Gemini/Sora)原有日志以格式化字符串为主,上下文字段(`request_id/user_id/group_id/model/account_id`)不完整,排障时需要人工拼接上下文。
|
||||
3. Token 刷新服务同时混用 `log` 与 `slog`,同类事件日志风格不一致,不利于检索与聚合。
|
||||
4. 前端 `console.error/warn` 使用量高,缺少统一封装,生产环境噪音和敏感信息泄漏风险较高。
|
||||
|
||||
## 2. 本次已落地整改
|
||||
|
||||
### 2.1 全局层(后端标准库日志分级修复)
|
||||
|
||||
- 修改:`backend/internal/pkg/logger/logger.go`
|
||||
- 结果:
|
||||
1. 替换原 `zap.RedirectStdLogAt(..., INFO)` 机制,改为自定义 `stdlog bridge`。
|
||||
2. 对标准库日志自动推断等级(`DEBUG/WARN/ERROR/INFO`),并打上 `legacy_stdlog=true` 标记。
|
||||
3. 规范化消息文本(去换行、压缩空白),提升可读性和检索稳定性。
|
||||
4. 调整初始化顺序:先桥接 `slog`,再桥接 `stdlog`,避免 `slog.SetDefault` 覆盖标准库桥接。
|
||||
5. 新增 `logger.LegacyPrintf(component, format, ...args)`,用于后端历史 `printf` 日志的平滑迁移,自动推断等级并打 `legacy_printf=true` 标记。
|
||||
|
||||
### 2.2 核心请求链路结构化改造
|
||||
|
||||
- 新增:`backend/internal/handler/logging.go`
|
||||
- 统一提供请求级 logger 获取入口,继承中间件注入的 `request_id` 上下文。
|
||||
|
||||
- 改造文件:
|
||||
- `backend/internal/handler/gateway_handler.go`
|
||||
- `backend/internal/handler/openai_gateway_handler.go`
|
||||
- `backend/internal/handler/gemini_v1beta_handler.go`
|
||||
- `backend/internal/handler/sora_gateway_handler.go`
|
||||
- `backend/internal/service/antigravity_gateway_service.go`
|
||||
- `backend/internal/service/gateway_service.go`
|
||||
- `backend/internal/service/gemini_oauth_service.go`
|
||||
- `backend/internal/service/auth_service.go`
|
||||
- `backend/internal/setup/setup.go`
|
||||
- `backend/internal/service/usage_cleanup_service.go`
|
||||
- `backend/internal/service/pricing_service.go`
|
||||
- `backend/internal/repository/account_repo.go`
|
||||
- `backend/internal/service/openai_gateway_service.go`
|
||||
- `backend/internal/service/scheduler_snapshot_service.go`
|
||||
- `backend/internal/service/gemini_messages_compat_service.go`
|
||||
- `backend/internal/service/dashboard_aggregation_service.go`
|
||||
- `backend/internal/service/billing_cache_service.go`
|
||||
- `backend/internal/repository/claude_oauth_service.go`
|
||||
- `backend/internal/service/admin_service.go`
|
||||
- `backend/internal/handler/admin/ops_ws_handler.go`
|
||||
|
||||
- 改造内容:
|
||||
1. 把关键日志从字符串拼接改为结构化字段。
|
||||
2. 统一带上 `component/user_id/api_key_id/group_id/model/account_id` 等字段。
|
||||
3. 按语义拆分等级:
|
||||
- 预期业务拒绝(如账单校验失败、队列满)使用 `Info`
|
||||
- 降级路径/可恢复异常(如抢槽失败、粘性会话绑定失败)使用 `Warn`
|
||||
- 真正故障(如转发失败、使用量记录失败)使用 `Error`
|
||||
4. 新增请求完成日志(`*.request_completed`)用于链路闭环追踪。
|
||||
5. 对高密度 `log.Printf` 完成批量迁移到 `logger.LegacyPrintf`(本轮累计 511 处),并统一组件字段:
|
||||
- `component=service.antigravity_gateway`
|
||||
- `component=service.gateway`
|
||||
- `component=service.gemini_oauth`
|
||||
- `component=service.auth`
|
||||
- `component=setup`
|
||||
- `component=service.usage_cleanup`
|
||||
- `component=service.pricing`
|
||||
- `component=repository.account`
|
||||
- `component=service.openai_gateway`
|
||||
- `component=service.scheduler_snapshot`
|
||||
- `component=service.gemini_messages_compat`
|
||||
- `component=service.dashboard_aggregation`
|
||||
- `component=service.billing_cache`
|
||||
- `component=repository.claude_oauth`
|
||||
- `component=service.admin`
|
||||
- `component=handler.admin.ops_ws`
|
||||
6. OpenAI 透传断流相关两条关键告警统一回到新日志系统输出(`service.openai_gateway`),并通过兼容逻辑保证测试环境可捕获。
|
||||
|
||||
### 2.3 后台任务日志统一
|
||||
|
||||
- 改造:`backend/internal/service/token_refresh_service.go`
|
||||
- 结果:
|
||||
1. 统一改为 `slog` 结构化输出。
|
||||
2. `retry/cycle/account` 等事件改为字段化日志,便于按账号和批次检索。
|
||||
3. 对“无实际刷新活动”的周期日志降级到 `Debug`,减少噪音。
|
||||
|
||||
### 2.4 测试保障
|
||||
|
||||
- 新增:`backend/internal/pkg/logger/stdlog_bridge_test.go`
|
||||
- 覆盖标准库日志等级推断、消息标准化、输出路由行为。
|
||||
- 已验证:
|
||||
- `go test ./internal/pkg/logger ./internal/handler ./internal/service` 通过。
|
||||
|
||||
## 3. 仍需继续整改(建议下一批)
|
||||
|
||||
### 3.1 后端剩余 `std log` 高密度区域(优先级 P1)
|
||||
|
||||
建议优先处理以下文件(调用量高):
|
||||
|
||||
1. `backend/internal/service/usage_cleanup_service.go`(26)
|
||||
2. `backend/internal/service/pricing_service.go`(26)
|
||||
3. `backend/internal/repository/account_repo.go`(24)
|
||||
4. `backend/internal/service/openai_gateway_service.go`(23)
|
||||
5. `backend/internal/service/scheduler_snapshot_service.go`(20)
|
||||
|
||||
(以上已完成。当前 Top 5 已变为:`backend/cmd/server/main.go`、`backend/internal/service/openai_tool_corrector.go`、`backend/internal/service/email_queue_service.go`、`backend/internal/config/config.go`、`backend/internal/service/ops_cleanup_service.go`)
|
||||
|
||||
目标:逐步替换为结构化日志,减少对 `legacy_stdlog` 兼容桥接的依赖。
|
||||
|
||||
### 3.2 前端日志治理(优先级 P1)
|
||||
|
||||
建议新增统一前端日志工具(如 `src/utils/logger.ts`)并分三步替换:
|
||||
|
||||
1. `console.error/warn/debug/log` 全部收敛到统一 API;
|
||||
2. 生产环境默认降噪(仅保留关键告警/错误);
|
||||
3. 统一字段(模块名、请求ID、用户ID、路由、错误码)并避免打印敏感数据。
|
||||
|
||||
### 3.3 日志规范与门禁(优先级 P2)
|
||||
|
||||
建议补充:
|
||||
|
||||
1. 日志规范文档(等级定义、字段最小集、脱敏要求);
|
||||
2. CI 检查规则:限制新增裸 `log.Printf` / `console.*`;
|
||||
3. 面向运营告警的事件白名单(例如 `*.forward_failed`、`*.retry_exhausted*`)。
|
||||
|
||||
## 4. 本次整理后可直接使用的检索建议
|
||||
|
||||
1. 过滤历史兼容日志:`legacy_stdlog=true`
|
||||
2. 网关入口故障:`component=handler.* AND level in (WARN,ERROR)`
|
||||
3. 请求闭环:按 `request_id` + `*.request_completed` + `*.forward_failed`
|
||||
4. token 刷新故障:`component=*token_refresh* AND (retry_attempt_failed OR set_error_status_failed)`
|
||||
Reference in New Issue
Block a user