mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-11 18:44:45 +08:00
284 lines
9.5 KiB
Go
284 lines
9.5 KiB
Go
|
|
package service
|
|||
|
|
|
|||
|
|
import (
	"context"
	"errors"
	"fmt"
	"sort"
	"time"

	"sub2api/internal/model"
	"sub2api/internal/repository"

	"gorm.io/gorm"
)
|
|||
|
|
|
|||
|
|
// ErrUsageLogNotFound is returned by GetByID when no usage log exists for the
// requested ID.
var ErrUsageLogNotFound = errors.New("usage log not found")
|
|||
|
|
|
|||
|
|
// CreateUsageLogRequest carries every value needed to record one usage log
// entry. UsageService.Create copies the fields one-to-one into a
// model.UsageLog.
type CreateUsageLogRequest struct {
	// Identifiers tying the request to its owner and upstream resources.
	UserID    int64  `json:"user_id"`
	ApiKeyID  int64  `json:"api_key_id"`
	AccountID int64  `json:"account_id"`
	RequestID string `json:"request_id"`
	Model     string `json:"model"`

	// Token counters.
	InputTokens           int `json:"input_tokens"`
	OutputTokens          int `json:"output_tokens"`
	CacheCreationTokens   int `json:"cache_creation_tokens"`
	CacheReadTokens       int `json:"cache_read_tokens"`
	CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
	CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`

	// Cost breakdown.
	InputCost         float64 `json:"input_cost"`
	OutputCost        float64 `json:"output_cost"`
	CacheCreationCost float64 `json:"cache_creation_cost"`
	CacheReadCost     float64 `json:"cache_read_cost"`
	TotalCost         float64 `json:"total_cost"`
	// ActualCost is the amount UsageService.Create deducts from the user's
	// balance when it is greater than zero.
	ActualCost     float64 `json:"actual_cost"`
	RateMultiplier float64 `json:"rate_multiplier"`

	// Request characteristics.
	Stream bool `json:"stream"`
	// DurationMs is optional; nil means no duration was supplied.
	DurationMs *int `json:"duration_ms"`
}
|
|||
|
|
|
|||
|
|
// UsageStats is an aggregate over a set of usage logs: request count, token
// totals, cost totals and the mean request duration.
type UsageStats struct {
	// Number of usage logs that were aggregated.
	TotalRequests int64 `json:"total_requests"`

	// Token totals. TotalCacheTokens is the sum of cache-creation and
	// cache-read tokens.
	TotalInputTokens  int64 `json:"total_input_tokens"`
	TotalOutputTokens int64 `json:"total_output_tokens"`
	TotalCacheTokens  int64 `json:"total_cache_tokens"`
	TotalTokens       int64 `json:"total_tokens"`

	// Cost totals.
	TotalCost       float64 `json:"total_cost"`
	TotalActualCost float64 `json:"total_actual_cost"`

	// Mean request duration in milliseconds.
	AverageDurationMs float64 `json:"average_duration_ms"`
}
|
|||
|
|
|
|||
|
|
// UsageService 使用统计服务
|
|||
|
|
type UsageService struct {
|
|||
|
|
usageRepo *repository.UsageLogRepository
|
|||
|
|
userRepo *repository.UserRepository
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// NewUsageService 创建使用统计服务实例
|
|||
|
|
func NewUsageService(usageRepo *repository.UsageLogRepository, userRepo *repository.UserRepository) *UsageService {
|
|||
|
|
return &UsageService{
|
|||
|
|
usageRepo: usageRepo,
|
|||
|
|
userRepo: userRepo,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Create 创建使用日志
|
|||
|
|
func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*model.UsageLog, error) {
|
|||
|
|
// 验证用户存在
|
|||
|
|
_, err := s.userRepo.GetByID(ctx, req.UserID)
|
|||
|
|
if err != nil {
|
|||
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|||
|
|
return nil, ErrUserNotFound
|
|||
|
|
}
|
|||
|
|
return nil, fmt.Errorf("get user: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 创建使用日志
|
|||
|
|
usageLog := &model.UsageLog{
|
|||
|
|
UserID: req.UserID,
|
|||
|
|
ApiKeyID: req.ApiKeyID,
|
|||
|
|
AccountID: req.AccountID,
|
|||
|
|
RequestID: req.RequestID,
|
|||
|
|
Model: req.Model,
|
|||
|
|
InputTokens: req.InputTokens,
|
|||
|
|
OutputTokens: req.OutputTokens,
|
|||
|
|
CacheCreationTokens: req.CacheCreationTokens,
|
|||
|
|
CacheReadTokens: req.CacheReadTokens,
|
|||
|
|
CacheCreation5mTokens: req.CacheCreation5mTokens,
|
|||
|
|
CacheCreation1hTokens: req.CacheCreation1hTokens,
|
|||
|
|
InputCost: req.InputCost,
|
|||
|
|
OutputCost: req.OutputCost,
|
|||
|
|
CacheCreationCost: req.CacheCreationCost,
|
|||
|
|
CacheReadCost: req.CacheReadCost,
|
|||
|
|
TotalCost: req.TotalCost,
|
|||
|
|
ActualCost: req.ActualCost,
|
|||
|
|
RateMultiplier: req.RateMultiplier,
|
|||
|
|
Stream: req.Stream,
|
|||
|
|
DurationMs: req.DurationMs,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if err := s.usageRepo.Create(ctx, usageLog); err != nil {
|
|||
|
|
return nil, fmt.Errorf("create usage log: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 扣除用户余额
|
|||
|
|
if req.ActualCost > 0 {
|
|||
|
|
if err := s.userRepo.UpdateBalance(ctx, req.UserID, -req.ActualCost); err != nil {
|
|||
|
|
return nil, fmt.Errorf("update user balance: %w", err)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return usageLog, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetByID 根据ID获取使用日志
|
|||
|
|
func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) {
|
|||
|
|
log, err := s.usageRepo.GetByID(ctx, id)
|
|||
|
|
if err != nil {
|
|||
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|||
|
|
return nil, ErrUsageLogNotFound
|
|||
|
|
}
|
|||
|
|
return nil, fmt.Errorf("get usage log: %w", err)
|
|||
|
|
}
|
|||
|
|
return log, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ListByUser 获取用户的使用日志列表
|
|||
|
|
func (s *UsageService) ListByUser(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
|
|||
|
|
logs, pagination, err := s.usageRepo.ListByUser(ctx, userID, params)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, nil, fmt.Errorf("list usage logs: %w", err)
|
|||
|
|
}
|
|||
|
|
return logs, pagination, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ListByApiKey 获取API Key的使用日志列表
|
|||
|
|
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
|
|||
|
|
logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, nil, fmt.Errorf("list usage logs: %w", err)
|
|||
|
|
}
|
|||
|
|
return logs, pagination, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ListByAccount 获取账号的使用日志列表
|
|||
|
|
func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
|
|||
|
|
logs, pagination, err := s.usageRepo.ListByAccount(ctx, accountID, params)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, nil, fmt.Errorf("list usage logs: %w", err)
|
|||
|
|
}
|
|||
|
|
return logs, pagination, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetStatsByUser 获取用户的使用统计
|
|||
|
|
func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTime, endTime time.Time) (*UsageStats, error) {
|
|||
|
|
logs, _, err := s.usageRepo.ListByUserAndTimeRange(ctx, userID, startTime, endTime)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("list usage logs: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return s.calculateStats(logs), nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetStatsByApiKey 获取API Key的使用统计
|
|||
|
|
func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
|
|||
|
|
logs, _, err := s.usageRepo.ListByApiKeyAndTimeRange(ctx, apiKeyID, startTime, endTime)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("list usage logs: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return s.calculateStats(logs), nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetStatsByAccount 获取账号的使用统计
|
|||
|
|
func (s *UsageService) GetStatsByAccount(ctx context.Context, accountID int64, startTime, endTime time.Time) (*UsageStats, error) {
|
|||
|
|
logs, _, err := s.usageRepo.ListByAccountAndTimeRange(ctx, accountID, startTime, endTime)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("list usage logs: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return s.calculateStats(logs), nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetStatsByModel 获取模型的使用统计
|
|||
|
|
func (s *UsageService) GetStatsByModel(ctx context.Context, modelName string, startTime, endTime time.Time) (*UsageStats, error) {
|
|||
|
|
logs, _, err := s.usageRepo.ListByModelAndTimeRange(ctx, modelName, startTime, endTime)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("list usage logs: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return s.calculateStats(logs), nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetDailyStats 获取每日使用统计(最近N天)
|
|||
|
|
func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int) ([]map[string]interface{}, error) {
|
|||
|
|
endTime := time.Now()
|
|||
|
|
startTime := endTime.AddDate(0, 0, -days)
|
|||
|
|
|
|||
|
|
logs, _, err := s.usageRepo.ListByUserAndTimeRange(ctx, userID, startTime, endTime)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("list usage logs: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 按日期分组统计
|
|||
|
|
dailyStats := make(map[string]*UsageStats)
|
|||
|
|
for _, log := range logs {
|
|||
|
|
dateKey := log.CreatedAt.Format("2006-01-02")
|
|||
|
|
if _, exists := dailyStats[dateKey]; !exists {
|
|||
|
|
dailyStats[dateKey] = &UsageStats{}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
stats := dailyStats[dateKey]
|
|||
|
|
stats.TotalRequests++
|
|||
|
|
stats.TotalInputTokens += int64(log.InputTokens)
|
|||
|
|
stats.TotalOutputTokens += int64(log.OutputTokens)
|
|||
|
|
stats.TotalCacheTokens += int64(log.CacheCreationTokens + log.CacheReadTokens)
|
|||
|
|
stats.TotalTokens += int64(log.TotalTokens())
|
|||
|
|
stats.TotalCost += log.TotalCost
|
|||
|
|
stats.TotalActualCost += log.ActualCost
|
|||
|
|
|
|||
|
|
if log.DurationMs != nil {
|
|||
|
|
stats.AverageDurationMs += float64(*log.DurationMs)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 计算平均值并转换为数组
|
|||
|
|
result := make([]map[string]interface{}, 0, len(dailyStats))
|
|||
|
|
for date, stats := range dailyStats {
|
|||
|
|
if stats.TotalRequests > 0 {
|
|||
|
|
stats.AverageDurationMs /= float64(stats.TotalRequests)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
result = append(result, map[string]interface{}{
|
|||
|
|
"date": date,
|
|||
|
|
"total_requests": stats.TotalRequests,
|
|||
|
|
"total_input_tokens": stats.TotalInputTokens,
|
|||
|
|
"total_output_tokens": stats.TotalOutputTokens,
|
|||
|
|
"total_cache_tokens": stats.TotalCacheTokens,
|
|||
|
|
"total_tokens": stats.TotalTokens,
|
|||
|
|
"total_cost": stats.TotalCost,
|
|||
|
|
"total_actual_cost": stats.TotalActualCost,
|
|||
|
|
"average_duration_ms": stats.AverageDurationMs,
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return result, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// calculateStats 计算统计数据
|
|||
|
|
func (s *UsageService) calculateStats(logs []model.UsageLog) *UsageStats {
|
|||
|
|
stats := &UsageStats{}
|
|||
|
|
|
|||
|
|
for _, log := range logs {
|
|||
|
|
stats.TotalRequests++
|
|||
|
|
stats.TotalInputTokens += int64(log.InputTokens)
|
|||
|
|
stats.TotalOutputTokens += int64(log.OutputTokens)
|
|||
|
|
stats.TotalCacheTokens += int64(log.CacheCreationTokens + log.CacheReadTokens)
|
|||
|
|
stats.TotalTokens += int64(log.TotalTokens())
|
|||
|
|
stats.TotalCost += log.TotalCost
|
|||
|
|
stats.TotalActualCost += log.ActualCost
|
|||
|
|
|
|||
|
|
if log.DurationMs != nil {
|
|||
|
|
stats.AverageDurationMs += float64(*log.DurationMs)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 计算平均持续时间
|
|||
|
|
if stats.TotalRequests > 0 {
|
|||
|
|
stats.AverageDurationMs /= float64(stats.TotalRequests)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return stats
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Delete 删除使用日志(管理员功能,谨慎使用)
|
|||
|
|
func (s *UsageService) Delete(ctx context.Context, id int64) error {
|
|||
|
|
if err := s.usageRepo.Delete(ctx, id); err != nil {
|
|||
|
|
return fmt.Errorf("delete usage log: %w", err)
|
|||
|
|
}
|
|||
|
|
return nil
|
|||
|
|
}
|