feat(api-key): 添加 IP 白名单/黑名单限制功能 (#221)

* feat(api-key): add IP whitelist/blacklist restriction and usage log IP tracking

- Add IP restriction feature for API keys (whitelist/blacklist with CIDR support)
- Add IP address logging to usage logs (admin-only visibility)
- Remove billing_type column from usage logs UI (redundant)
- Use generic "Access denied" error message for security

Backend:
- New ip package with IP/CIDR validation and matching utilities
- Database migrations for ip_whitelist, ip_blacklist (api_keys) and ip_address (usage_logs)
- Middleware IP restriction check after API key validation
- Input validation for IP/CIDR patterns on create/update

Frontend:
- API key form with enable toggle for IP restriction
- Shield icon indicator in table for keys with IP restriction
- Removed billing_type filter and column from usage views

* fix: update API contract tests for ip_whitelist/ip_blacklist fields

Add ip_whitelist and ip_blacklist fields to expected JSON responses
in API contract tests to match the new API key schema.
This commit is contained in:
Edric.Li
2026-01-09 21:59:32 +08:00
committed by GitHub
parent 62dc0b953b
commit 0a4641c24e
45 changed files with 1500 additions and 183 deletions

View File

@@ -3,6 +3,7 @@
package ent package ent
import ( import (
"encoding/json"
"fmt" "fmt"
"strings" "strings"
"time" "time"
@@ -35,6 +36,10 @@ type APIKey struct {
GroupID *int64 `json:"group_id,omitempty"` GroupID *int64 `json:"group_id,omitempty"`
// Status holds the value of the "status" field. // Status holds the value of the "status" field.
Status string `json:"status,omitempty"` Status string `json:"status,omitempty"`
// Allowed IPs/CIDRs, e.g. ["192.168.1.100", "10.0.0.0/8"]
IPWhitelist []string `json:"ip_whitelist,omitempty"`
// Blocked IPs/CIDRs
IPBlacklist []string `json:"ip_blacklist,omitempty"`
// Edges holds the relations/edges for other nodes in the graph. // Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the APIKeyQuery when eager-loading is set. // The values are being populated by the APIKeyQuery when eager-loading is set.
Edges APIKeyEdges `json:"edges"` Edges APIKeyEdges `json:"edges"`
@@ -90,6 +95,8 @@ func (*APIKey) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns)) values := make([]any, len(columns))
for i := range columns { for i := range columns {
switch columns[i] { switch columns[i] {
case apikey.FieldIPWhitelist, apikey.FieldIPBlacklist:
values[i] = new([]byte)
case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID: case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus: case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus:
@@ -167,6 +174,22 @@ func (_m *APIKey) assignValues(columns []string, values []any) error {
} else if value.Valid { } else if value.Valid {
_m.Status = value.String _m.Status = value.String
} }
case apikey.FieldIPWhitelist:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field ip_whitelist", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.IPWhitelist); err != nil {
return fmt.Errorf("unmarshal field ip_whitelist: %w", err)
}
}
case apikey.FieldIPBlacklist:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field ip_blacklist", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.IPBlacklist); err != nil {
return fmt.Errorf("unmarshal field ip_blacklist: %w", err)
}
}
default: default:
_m.selectValues.Set(columns[i], values[i]) _m.selectValues.Set(columns[i], values[i])
} }
@@ -245,6 +268,12 @@ func (_m *APIKey) String() string {
builder.WriteString(", ") builder.WriteString(", ")
builder.WriteString("status=") builder.WriteString("status=")
builder.WriteString(_m.Status) builder.WriteString(_m.Status)
builder.WriteString(", ")
builder.WriteString("ip_whitelist=")
builder.WriteString(fmt.Sprintf("%v", _m.IPWhitelist))
builder.WriteString(", ")
builder.WriteString("ip_blacklist=")
builder.WriteString(fmt.Sprintf("%v", _m.IPBlacklist))
builder.WriteByte(')') builder.WriteByte(')')
return builder.String() return builder.String()
} }

View File

@@ -31,6 +31,10 @@ const (
FieldGroupID = "group_id" FieldGroupID = "group_id"
// FieldStatus holds the string denoting the status field in the database. // FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status" FieldStatus = "status"
// FieldIPWhitelist holds the string denoting the ip_whitelist field in the database.
FieldIPWhitelist = "ip_whitelist"
// FieldIPBlacklist holds the string denoting the ip_blacklist field in the database.
FieldIPBlacklist = "ip_blacklist"
// EdgeUser holds the string denoting the user edge name in mutations. // EdgeUser holds the string denoting the user edge name in mutations.
EdgeUser = "user" EdgeUser = "user"
// EdgeGroup holds the string denoting the group edge name in mutations. // EdgeGroup holds the string denoting the group edge name in mutations.
@@ -73,6 +77,8 @@ var Columns = []string{
FieldName, FieldName,
FieldGroupID, FieldGroupID,
FieldStatus, FieldStatus,
FieldIPWhitelist,
FieldIPBlacklist,
} }
// ValidColumn reports if the column name is valid (part of the table columns). // ValidColumn reports if the column name is valid (part of the table columns).

View File

@@ -470,6 +470,26 @@ func StatusContainsFold(v string) predicate.APIKey {
return predicate.APIKey(sql.FieldContainsFold(FieldStatus, v)) return predicate.APIKey(sql.FieldContainsFold(FieldStatus, v))
} }
// IPWhitelistIsNil applies the IsNil predicate on the "ip_whitelist" field.
func IPWhitelistIsNil() predicate.APIKey {
return predicate.APIKey(sql.FieldIsNull(FieldIPWhitelist))
}
// IPWhitelistNotNil applies the NotNil predicate on the "ip_whitelist" field.
func IPWhitelistNotNil() predicate.APIKey {
return predicate.APIKey(sql.FieldNotNull(FieldIPWhitelist))
}
// IPBlacklistIsNil applies the IsNil predicate on the "ip_blacklist" field.
func IPBlacklistIsNil() predicate.APIKey {
return predicate.APIKey(sql.FieldIsNull(FieldIPBlacklist))
}
// IPBlacklistNotNil applies the NotNil predicate on the "ip_blacklist" field.
func IPBlacklistNotNil() predicate.APIKey {
return predicate.APIKey(sql.FieldNotNull(FieldIPBlacklist))
}
// HasUser applies the HasEdge predicate on the "user" edge. // HasUser applies the HasEdge predicate on the "user" edge.
func HasUser() predicate.APIKey { func HasUser() predicate.APIKey {
return predicate.APIKey(func(s *sql.Selector) { return predicate.APIKey(func(s *sql.Selector) {

View File

@@ -113,6 +113,18 @@ func (_c *APIKeyCreate) SetNillableStatus(v *string) *APIKeyCreate {
return _c return _c
} }
// SetIPWhitelist sets the "ip_whitelist" field.
func (_c *APIKeyCreate) SetIPWhitelist(v []string) *APIKeyCreate {
_c.mutation.SetIPWhitelist(v)
return _c
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (_c *APIKeyCreate) SetIPBlacklist(v []string) *APIKeyCreate {
_c.mutation.SetIPBlacklist(v)
return _c
}
// SetUser sets the "user" edge to the User entity. // SetUser sets the "user" edge to the User entity.
func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate { func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate {
return _c.SetUserID(v.ID) return _c.SetUserID(v.ID)
@@ -285,6 +297,14 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) {
_spec.SetField(apikey.FieldStatus, field.TypeString, value) _spec.SetField(apikey.FieldStatus, field.TypeString, value)
_node.Status = value _node.Status = value
} }
if value, ok := _c.mutation.IPWhitelist(); ok {
_spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value)
_node.IPWhitelist = value
}
if value, ok := _c.mutation.IPBlacklist(); ok {
_spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value)
_node.IPBlacklist = value
}
if nodes := _c.mutation.UserIDs(); len(nodes) > 0 { if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O, Rel: sqlgraph.M2O,
@@ -483,6 +503,42 @@ func (u *APIKeyUpsert) UpdateStatus() *APIKeyUpsert {
return u return u
} }
// SetIPWhitelist sets the "ip_whitelist" field.
func (u *APIKeyUpsert) SetIPWhitelist(v []string) *APIKeyUpsert {
u.Set(apikey.FieldIPWhitelist, v)
return u
}
// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create.
func (u *APIKeyUpsert) UpdateIPWhitelist() *APIKeyUpsert {
u.SetExcluded(apikey.FieldIPWhitelist)
return u
}
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
func (u *APIKeyUpsert) ClearIPWhitelist() *APIKeyUpsert {
u.SetNull(apikey.FieldIPWhitelist)
return u
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (u *APIKeyUpsert) SetIPBlacklist(v []string) *APIKeyUpsert {
u.Set(apikey.FieldIPBlacklist, v)
return u
}
// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create.
func (u *APIKeyUpsert) UpdateIPBlacklist() *APIKeyUpsert {
u.SetExcluded(apikey.FieldIPBlacklist)
return u
}
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
func (u *APIKeyUpsert) ClearIPBlacklist() *APIKeyUpsert {
u.SetNull(apikey.FieldIPBlacklist)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create. // UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using: // Using this option is equivalent to using:
// //
@@ -640,6 +696,48 @@ func (u *APIKeyUpsertOne) UpdateStatus() *APIKeyUpsertOne {
}) })
} }
// SetIPWhitelist sets the "ip_whitelist" field.
func (u *APIKeyUpsertOne) SetIPWhitelist(v []string) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.SetIPWhitelist(v)
})
}
// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create.
func (u *APIKeyUpsertOne) UpdateIPWhitelist() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateIPWhitelist()
})
}
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
func (u *APIKeyUpsertOne) ClearIPWhitelist() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.ClearIPWhitelist()
})
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (u *APIKeyUpsertOne) SetIPBlacklist(v []string) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.SetIPBlacklist(v)
})
}
// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create.
func (u *APIKeyUpsertOne) UpdateIPBlacklist() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateIPBlacklist()
})
}
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
func (u *APIKeyUpsertOne) ClearIPBlacklist() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.ClearIPBlacklist()
})
}
// Exec executes the query. // Exec executes the query.
func (u *APIKeyUpsertOne) Exec(ctx context.Context) error { func (u *APIKeyUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 { if len(u.create.conflict) == 0 {
@@ -963,6 +1061,48 @@ func (u *APIKeyUpsertBulk) UpdateStatus() *APIKeyUpsertBulk {
}) })
} }
// SetIPWhitelist sets the "ip_whitelist" field.
func (u *APIKeyUpsertBulk) SetIPWhitelist(v []string) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.SetIPWhitelist(v)
})
}
// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create.
func (u *APIKeyUpsertBulk) UpdateIPWhitelist() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateIPWhitelist()
})
}
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
func (u *APIKeyUpsertBulk) ClearIPWhitelist() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.ClearIPWhitelist()
})
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (u *APIKeyUpsertBulk) SetIPBlacklist(v []string) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.SetIPBlacklist(v)
})
}
// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create.
func (u *APIKeyUpsertBulk) UpdateIPBlacklist() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateIPBlacklist()
})
}
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
func (u *APIKeyUpsertBulk) ClearIPBlacklist() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.ClearIPBlacklist()
})
}
// Exec executes the query. // Exec executes the query.
func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error { func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil { if u.create.err != nil {

View File

@@ -10,6 +10,7 @@ import (
"entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/dialect/sql/sqljson"
"entgo.io/ent/schema/field" "entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
@@ -133,6 +134,42 @@ func (_u *APIKeyUpdate) SetNillableStatus(v *string) *APIKeyUpdate {
return _u return _u
} }
// SetIPWhitelist sets the "ip_whitelist" field.
func (_u *APIKeyUpdate) SetIPWhitelist(v []string) *APIKeyUpdate {
_u.mutation.SetIPWhitelist(v)
return _u
}
// AppendIPWhitelist appends value to the "ip_whitelist" field.
func (_u *APIKeyUpdate) AppendIPWhitelist(v []string) *APIKeyUpdate {
_u.mutation.AppendIPWhitelist(v)
return _u
}
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
func (_u *APIKeyUpdate) ClearIPWhitelist() *APIKeyUpdate {
_u.mutation.ClearIPWhitelist()
return _u
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (_u *APIKeyUpdate) SetIPBlacklist(v []string) *APIKeyUpdate {
_u.mutation.SetIPBlacklist(v)
return _u
}
// AppendIPBlacklist appends value to the "ip_blacklist" field.
func (_u *APIKeyUpdate) AppendIPBlacklist(v []string) *APIKeyUpdate {
_u.mutation.AppendIPBlacklist(v)
return _u
}
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
func (_u *APIKeyUpdate) ClearIPBlacklist() *APIKeyUpdate {
_u.mutation.ClearIPBlacklist()
return _u
}
// SetUser sets the "user" edge to the User entity. // SetUser sets the "user" edge to the User entity.
func (_u *APIKeyUpdate) SetUser(v *User) *APIKeyUpdate { func (_u *APIKeyUpdate) SetUser(v *User) *APIKeyUpdate {
return _u.SetUserID(v.ID) return _u.SetUserID(v.ID)
@@ -291,6 +328,28 @@ func (_u *APIKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.Status(); ok { if value, ok := _u.mutation.Status(); ok {
_spec.SetField(apikey.FieldStatus, field.TypeString, value) _spec.SetField(apikey.FieldStatus, field.TypeString, value)
} }
if value, ok := _u.mutation.IPWhitelist(); ok {
_spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedIPWhitelist(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, apikey.FieldIPWhitelist, value)
})
}
if _u.mutation.IPWhitelistCleared() {
_spec.ClearField(apikey.FieldIPWhitelist, field.TypeJSON)
}
if value, ok := _u.mutation.IPBlacklist(); ok {
_spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedIPBlacklist(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, apikey.FieldIPBlacklist, value)
})
}
if _u.mutation.IPBlacklistCleared() {
_spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON)
}
if _u.mutation.UserCleared() { if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O, Rel: sqlgraph.M2O,
@@ -516,6 +575,42 @@ func (_u *APIKeyUpdateOne) SetNillableStatus(v *string) *APIKeyUpdateOne {
return _u return _u
} }
// SetIPWhitelist sets the "ip_whitelist" field.
func (_u *APIKeyUpdateOne) SetIPWhitelist(v []string) *APIKeyUpdateOne {
_u.mutation.SetIPWhitelist(v)
return _u
}
// AppendIPWhitelist appends value to the "ip_whitelist" field.
func (_u *APIKeyUpdateOne) AppendIPWhitelist(v []string) *APIKeyUpdateOne {
_u.mutation.AppendIPWhitelist(v)
return _u
}
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
func (_u *APIKeyUpdateOne) ClearIPWhitelist() *APIKeyUpdateOne {
_u.mutation.ClearIPWhitelist()
return _u
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (_u *APIKeyUpdateOne) SetIPBlacklist(v []string) *APIKeyUpdateOne {
_u.mutation.SetIPBlacklist(v)
return _u
}
// AppendIPBlacklist appends value to the "ip_blacklist" field.
func (_u *APIKeyUpdateOne) AppendIPBlacklist(v []string) *APIKeyUpdateOne {
_u.mutation.AppendIPBlacklist(v)
return _u
}
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
func (_u *APIKeyUpdateOne) ClearIPBlacklist() *APIKeyUpdateOne {
_u.mutation.ClearIPBlacklist()
return _u
}
// SetUser sets the "user" edge to the User entity. // SetUser sets the "user" edge to the User entity.
func (_u *APIKeyUpdateOne) SetUser(v *User) *APIKeyUpdateOne { func (_u *APIKeyUpdateOne) SetUser(v *User) *APIKeyUpdateOne {
return _u.SetUserID(v.ID) return _u.SetUserID(v.ID)
@@ -704,6 +799,28 @@ func (_u *APIKeyUpdateOne) sqlSave(ctx context.Context) (_node *APIKey, err erro
if value, ok := _u.mutation.Status(); ok { if value, ok := _u.mutation.Status(); ok {
_spec.SetField(apikey.FieldStatus, field.TypeString, value) _spec.SetField(apikey.FieldStatus, field.TypeString, value)
} }
if value, ok := _u.mutation.IPWhitelist(); ok {
_spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedIPWhitelist(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, apikey.FieldIPWhitelist, value)
})
}
if _u.mutation.IPWhitelistCleared() {
_spec.ClearField(apikey.FieldIPWhitelist, field.TypeJSON)
}
if value, ok := _u.mutation.IPBlacklist(); ok {
_spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedIPBlacklist(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, apikey.FieldIPBlacklist, value)
})
}
if _u.mutation.IPBlacklistCleared() {
_spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON)
}
if _u.mutation.UserCleared() { if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O, Rel: sqlgraph.M2O,

View File

@@ -18,6 +18,8 @@ var (
{Name: "key", Type: field.TypeString, Unique: true, Size: 128}, {Name: "key", Type: field.TypeString, Unique: true, Size: 128},
{Name: "name", Type: field.TypeString, Size: 100}, {Name: "name", Type: field.TypeString, Size: 100},
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"}, {Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
{Name: "ip_whitelist", Type: field.TypeJSON, Nullable: true},
{Name: "ip_blacklist", Type: field.TypeJSON, Nullable: true},
{Name: "group_id", Type: field.TypeInt64, Nullable: true}, {Name: "group_id", Type: field.TypeInt64, Nullable: true},
{Name: "user_id", Type: field.TypeInt64}, {Name: "user_id", Type: field.TypeInt64},
} }
@@ -29,13 +31,13 @@ var (
ForeignKeys: []*schema.ForeignKey{ ForeignKeys: []*schema.ForeignKey{
{ {
Symbol: "api_keys_groups_api_keys", Symbol: "api_keys_groups_api_keys",
Columns: []*schema.Column{APIKeysColumns[7]}, Columns: []*schema.Column{APIKeysColumns[9]},
RefColumns: []*schema.Column{GroupsColumns[0]}, RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull, OnDelete: schema.SetNull,
}, },
{ {
Symbol: "api_keys_users_api_keys", Symbol: "api_keys_users_api_keys",
Columns: []*schema.Column{APIKeysColumns[8]}, Columns: []*schema.Column{APIKeysColumns[10]},
RefColumns: []*schema.Column{UsersColumns[0]}, RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
@@ -44,12 +46,12 @@ var (
{ {
Name: "apikey_user_id", Name: "apikey_user_id",
Unique: false, Unique: false,
Columns: []*schema.Column{APIKeysColumns[8]}, Columns: []*schema.Column{APIKeysColumns[10]},
}, },
{ {
Name: "apikey_group_id", Name: "apikey_group_id",
Unique: false, Unique: false,
Columns: []*schema.Column{APIKeysColumns[7]}, Columns: []*schema.Column{APIKeysColumns[9]},
}, },
{ {
Name: "apikey_status", Name: "apikey_status",
@@ -376,6 +378,7 @@ var (
{Name: "duration_ms", Type: field.TypeInt, Nullable: true}, {Name: "duration_ms", Type: field.TypeInt, Nullable: true},
{Name: "first_token_ms", Type: field.TypeInt, Nullable: true}, {Name: "first_token_ms", Type: field.TypeInt, Nullable: true},
{Name: "user_agent", Type: field.TypeString, Nullable: true, Size: 512}, {Name: "user_agent", Type: field.TypeString, Nullable: true, Size: 512},
{Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45},
{Name: "image_count", Type: field.TypeInt, Default: 0}, {Name: "image_count", Type: field.TypeInt, Default: 0},
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10}, {Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
@@ -393,31 +396,31 @@ var (
ForeignKeys: []*schema.ForeignKey{ ForeignKeys: []*schema.ForeignKey{
{ {
Symbol: "usage_logs_api_keys_usage_logs", Symbol: "usage_logs_api_keys_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[24]}, Columns: []*schema.Column{UsageLogsColumns[25]},
RefColumns: []*schema.Column{APIKeysColumns[0]}, RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_accounts_usage_logs", Symbol: "usage_logs_accounts_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[25]}, Columns: []*schema.Column{UsageLogsColumns[26]},
RefColumns: []*schema.Column{AccountsColumns[0]}, RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_groups_usage_logs", Symbol: "usage_logs_groups_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[26]}, Columns: []*schema.Column{UsageLogsColumns[27]},
RefColumns: []*schema.Column{GroupsColumns[0]}, RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull, OnDelete: schema.SetNull,
}, },
{ {
Symbol: "usage_logs_users_usage_logs", Symbol: "usage_logs_users_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[27]}, Columns: []*schema.Column{UsageLogsColumns[28]},
RefColumns: []*schema.Column{UsersColumns[0]}, RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_user_subscriptions_usage_logs", Symbol: "usage_logs_user_subscriptions_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[28]}, Columns: []*schema.Column{UsageLogsColumns[29]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull, OnDelete: schema.SetNull,
}, },
@@ -426,32 +429,32 @@ var (
{ {
Name: "usagelog_user_id", Name: "usagelog_user_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[27]}, Columns: []*schema.Column{UsageLogsColumns[28]},
}, },
{ {
Name: "usagelog_api_key_id", Name: "usagelog_api_key_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[24]}, Columns: []*schema.Column{UsageLogsColumns[25]},
}, },
{ {
Name: "usagelog_account_id", Name: "usagelog_account_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[25]}, Columns: []*schema.Column{UsageLogsColumns[26]},
}, },
{ {
Name: "usagelog_group_id", Name: "usagelog_group_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[26]}, Columns: []*schema.Column{UsageLogsColumns[27]},
}, },
{ {
Name: "usagelog_subscription_id", Name: "usagelog_subscription_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28]}, Columns: []*schema.Column{UsageLogsColumns[29]},
}, },
{ {
Name: "usagelog_created_at", Name: "usagelog_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[23]}, Columns: []*schema.Column{UsageLogsColumns[24]},
}, },
{ {
Name: "usagelog_model", Name: "usagelog_model",
@@ -466,12 +469,12 @@ var (
{ {
Name: "usagelog_user_id_created_at", Name: "usagelog_user_id_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[27], UsageLogsColumns[23]}, Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[24]},
}, },
{ {
Name: "usagelog_api_key_id_created_at", Name: "usagelog_api_key_id_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[24], UsageLogsColumns[23]}, Columns: []*schema.Column{UsageLogsColumns[25], UsageLogsColumns[24]},
}, },
}, },
} }

View File

@@ -54,26 +54,30 @@ const (
// APIKeyMutation represents an operation that mutates the APIKey nodes in the graph. // APIKeyMutation represents an operation that mutates the APIKey nodes in the graph.
type APIKeyMutation struct { type APIKeyMutation struct {
config config
op Op op Op
typ string typ string
id *int64 id *int64
created_at *time.Time created_at *time.Time
updated_at *time.Time updated_at *time.Time
deleted_at *time.Time deleted_at *time.Time
key *string key *string
name *string name *string
status *string status *string
clearedFields map[string]struct{} ip_whitelist *[]string
user *int64 appendip_whitelist []string
cleareduser bool ip_blacklist *[]string
group *int64 appendip_blacklist []string
clearedgroup bool clearedFields map[string]struct{}
usage_logs map[int64]struct{} user *int64
removedusage_logs map[int64]struct{} cleareduser bool
clearedusage_logs bool group *int64
done bool clearedgroup bool
oldValue func(context.Context) (*APIKey, error) usage_logs map[int64]struct{}
predicates []predicate.APIKey removedusage_logs map[int64]struct{}
clearedusage_logs bool
done bool
oldValue func(context.Context) (*APIKey, error)
predicates []predicate.APIKey
} }
var _ ent.Mutation = (*APIKeyMutation)(nil) var _ ent.Mutation = (*APIKeyMutation)(nil)
@@ -488,6 +492,136 @@ func (m *APIKeyMutation) ResetStatus() {
m.status = nil m.status = nil
} }
// SetIPWhitelist sets the "ip_whitelist" field.
func (m *APIKeyMutation) SetIPWhitelist(s []string) {
m.ip_whitelist = &s
m.appendip_whitelist = nil
}
// IPWhitelist returns the value of the "ip_whitelist" field in the mutation.
func (m *APIKeyMutation) IPWhitelist() (r []string, exists bool) {
v := m.ip_whitelist
if v == nil {
return
}
return *v, true
}
// OldIPWhitelist returns the old "ip_whitelist" field's value of the APIKey entity.
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *APIKeyMutation) OldIPWhitelist(ctx context.Context) (v []string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldIPWhitelist is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldIPWhitelist requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldIPWhitelist: %w", err)
}
return oldValue.IPWhitelist, nil
}
// AppendIPWhitelist adds s to the "ip_whitelist" field.
func (m *APIKeyMutation) AppendIPWhitelist(s []string) {
m.appendip_whitelist = append(m.appendip_whitelist, s...)
}
// AppendedIPWhitelist returns the list of values that were appended to the "ip_whitelist" field in this mutation.
func (m *APIKeyMutation) AppendedIPWhitelist() ([]string, bool) {
if len(m.appendip_whitelist) == 0 {
return nil, false
}
return m.appendip_whitelist, true
}
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
func (m *APIKeyMutation) ClearIPWhitelist() {
m.ip_whitelist = nil
m.appendip_whitelist = nil
m.clearedFields[apikey.FieldIPWhitelist] = struct{}{}
}
// IPWhitelistCleared returns if the "ip_whitelist" field was cleared in this mutation.
func (m *APIKeyMutation) IPWhitelistCleared() bool {
_, ok := m.clearedFields[apikey.FieldIPWhitelist]
return ok
}
// ResetIPWhitelist resets all changes to the "ip_whitelist" field.
func (m *APIKeyMutation) ResetIPWhitelist() {
m.ip_whitelist = nil
m.appendip_whitelist = nil
delete(m.clearedFields, apikey.FieldIPWhitelist)
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (m *APIKeyMutation) SetIPBlacklist(s []string) {
m.ip_blacklist = &s
m.appendip_blacklist = nil
}
// IPBlacklist returns the value of the "ip_blacklist" field in the mutation.
func (m *APIKeyMutation) IPBlacklist() (r []string, exists bool) {
v := m.ip_blacklist
if v == nil {
return
}
return *v, true
}
// OldIPBlacklist returns the old "ip_blacklist" field's value of the APIKey entity.
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *APIKeyMutation) OldIPBlacklist(ctx context.Context) (v []string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldIPBlacklist is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldIPBlacklist requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldIPBlacklist: %w", err)
}
return oldValue.IPBlacklist, nil
}
// AppendIPBlacklist adds s to the "ip_blacklist" field.
func (m *APIKeyMutation) AppendIPBlacklist(s []string) {
m.appendip_blacklist = append(m.appendip_blacklist, s...)
}
// AppendedIPBlacklist returns the list of values that were appended to the "ip_blacklist" field in this mutation.
func (m *APIKeyMutation) AppendedIPBlacklist() ([]string, bool) {
if len(m.appendip_blacklist) == 0 {
return nil, false
}
return m.appendip_blacklist, true
}
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
func (m *APIKeyMutation) ClearIPBlacklist() {
m.ip_blacklist = nil
m.appendip_blacklist = nil
m.clearedFields[apikey.FieldIPBlacklist] = struct{}{}
}
// IPBlacklistCleared returns if the "ip_blacklist" field was cleared in this mutation.
func (m *APIKeyMutation) IPBlacklistCleared() bool {
_, ok := m.clearedFields[apikey.FieldIPBlacklist]
return ok
}
// ResetIPBlacklist resets all changes to the "ip_blacklist" field.
func (m *APIKeyMutation) ResetIPBlacklist() {
m.ip_blacklist = nil
m.appendip_blacklist = nil
delete(m.clearedFields, apikey.FieldIPBlacklist)
}
// ClearUser clears the "user" edge to the User entity. // ClearUser clears the "user" edge to the User entity.
func (m *APIKeyMutation) ClearUser() { func (m *APIKeyMutation) ClearUser() {
m.cleareduser = true m.cleareduser = true
@@ -630,7 +764,7 @@ func (m *APIKeyMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *APIKeyMutation) Fields() []string { func (m *APIKeyMutation) Fields() []string {
fields := make([]string, 0, 8) fields := make([]string, 0, 10)
if m.created_at != nil { if m.created_at != nil {
fields = append(fields, apikey.FieldCreatedAt) fields = append(fields, apikey.FieldCreatedAt)
} }
@@ -655,6 +789,12 @@ func (m *APIKeyMutation) Fields() []string {
if m.status != nil { if m.status != nil {
fields = append(fields, apikey.FieldStatus) fields = append(fields, apikey.FieldStatus)
} }
if m.ip_whitelist != nil {
fields = append(fields, apikey.FieldIPWhitelist)
}
if m.ip_blacklist != nil {
fields = append(fields, apikey.FieldIPBlacklist)
}
return fields return fields
} }
@@ -679,6 +819,10 @@ func (m *APIKeyMutation) Field(name string) (ent.Value, bool) {
return m.GroupID() return m.GroupID()
case apikey.FieldStatus: case apikey.FieldStatus:
return m.Status() return m.Status()
case apikey.FieldIPWhitelist:
return m.IPWhitelist()
case apikey.FieldIPBlacklist:
return m.IPBlacklist()
} }
return nil, false return nil, false
} }
@@ -704,6 +848,10 @@ func (m *APIKeyMutation) OldField(ctx context.Context, name string) (ent.Value,
return m.OldGroupID(ctx) return m.OldGroupID(ctx)
case apikey.FieldStatus: case apikey.FieldStatus:
return m.OldStatus(ctx) return m.OldStatus(ctx)
case apikey.FieldIPWhitelist:
return m.OldIPWhitelist(ctx)
case apikey.FieldIPBlacklist:
return m.OldIPBlacklist(ctx)
} }
return nil, fmt.Errorf("unknown APIKey field %s", name) return nil, fmt.Errorf("unknown APIKey field %s", name)
} }
@@ -769,6 +917,20 @@ func (m *APIKeyMutation) SetField(name string, value ent.Value) error {
} }
m.SetStatus(v) m.SetStatus(v)
return nil return nil
case apikey.FieldIPWhitelist:
v, ok := value.([]string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetIPWhitelist(v)
return nil
case apikey.FieldIPBlacklist:
v, ok := value.([]string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetIPBlacklist(v)
return nil
} }
return fmt.Errorf("unknown APIKey field %s", name) return fmt.Errorf("unknown APIKey field %s", name)
} }
@@ -808,6 +970,12 @@ func (m *APIKeyMutation) ClearedFields() []string {
if m.FieldCleared(apikey.FieldGroupID) { if m.FieldCleared(apikey.FieldGroupID) {
fields = append(fields, apikey.FieldGroupID) fields = append(fields, apikey.FieldGroupID)
} }
if m.FieldCleared(apikey.FieldIPWhitelist) {
fields = append(fields, apikey.FieldIPWhitelist)
}
if m.FieldCleared(apikey.FieldIPBlacklist) {
fields = append(fields, apikey.FieldIPBlacklist)
}
return fields return fields
} }
@@ -828,6 +996,12 @@ func (m *APIKeyMutation) ClearField(name string) error {
case apikey.FieldGroupID: case apikey.FieldGroupID:
m.ClearGroupID() m.ClearGroupID()
return nil return nil
case apikey.FieldIPWhitelist:
m.ClearIPWhitelist()
return nil
case apikey.FieldIPBlacklist:
m.ClearIPBlacklist()
return nil
} }
return fmt.Errorf("unknown APIKey nullable field %s", name) return fmt.Errorf("unknown APIKey nullable field %s", name)
} }
@@ -860,6 +1034,12 @@ func (m *APIKeyMutation) ResetField(name string) error {
case apikey.FieldStatus: case apikey.FieldStatus:
m.ResetStatus() m.ResetStatus()
return nil return nil
case apikey.FieldIPWhitelist:
m.ResetIPWhitelist()
return nil
case apikey.FieldIPBlacklist:
m.ResetIPBlacklist()
return nil
} }
return fmt.Errorf("unknown APIKey field %s", name) return fmt.Errorf("unknown APIKey field %s", name)
} }
@@ -8396,6 +8576,7 @@ type UsageLogMutation struct {
first_token_ms *int first_token_ms *int
addfirst_token_ms *int addfirst_token_ms *int
user_agent *string user_agent *string
ip_address *string
image_count *int image_count *int
addimage_count *int addimage_count *int
image_size *string image_size *string
@@ -9801,6 +9982,55 @@ func (m *UsageLogMutation) ResetUserAgent() {
delete(m.clearedFields, usagelog.FieldUserAgent) delete(m.clearedFields, usagelog.FieldUserAgent)
} }
// SetIPAddress sets the "ip_address" field.
func (m *UsageLogMutation) SetIPAddress(s string) {
m.ip_address = &s
}
// IPAddress returns the value of the "ip_address" field in the mutation.
func (m *UsageLogMutation) IPAddress() (r string, exists bool) {
v := m.ip_address
if v == nil {
return
}
return *v, true
}
// OldIPAddress returns the old "ip_address" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldIPAddress(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldIPAddress is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldIPAddress requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldIPAddress: %w", err)
}
return oldValue.IPAddress, nil
}
// ClearIPAddress clears the value of the "ip_address" field.
func (m *UsageLogMutation) ClearIPAddress() {
m.ip_address = nil
m.clearedFields[usagelog.FieldIPAddress] = struct{}{}
}
// IPAddressCleared returns if the "ip_address" field was cleared in this mutation.
func (m *UsageLogMutation) IPAddressCleared() bool {
_, ok := m.clearedFields[usagelog.FieldIPAddress]
return ok
}
// ResetIPAddress resets all changes to the "ip_address" field.
func (m *UsageLogMutation) ResetIPAddress() {
m.ip_address = nil
delete(m.clearedFields, usagelog.FieldIPAddress)
}
// SetImageCount sets the "image_count" field. // SetImageCount sets the "image_count" field.
func (m *UsageLogMutation) SetImageCount(i int) { func (m *UsageLogMutation) SetImageCount(i int) {
m.image_count = &i m.image_count = &i
@@ -10111,7 +10341,7 @@ func (m *UsageLogMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *UsageLogMutation) Fields() []string { func (m *UsageLogMutation) Fields() []string {
fields := make([]string, 0, 28) fields := make([]string, 0, 29)
if m.user != nil { if m.user != nil {
fields = append(fields, usagelog.FieldUserID) fields = append(fields, usagelog.FieldUserID)
} }
@@ -10187,6 +10417,9 @@ func (m *UsageLogMutation) Fields() []string {
if m.user_agent != nil { if m.user_agent != nil {
fields = append(fields, usagelog.FieldUserAgent) fields = append(fields, usagelog.FieldUserAgent)
} }
if m.ip_address != nil {
fields = append(fields, usagelog.FieldIPAddress)
}
if m.image_count != nil { if m.image_count != nil {
fields = append(fields, usagelog.FieldImageCount) fields = append(fields, usagelog.FieldImageCount)
} }
@@ -10254,6 +10487,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.FirstTokenMs() return m.FirstTokenMs()
case usagelog.FieldUserAgent: case usagelog.FieldUserAgent:
return m.UserAgent() return m.UserAgent()
case usagelog.FieldIPAddress:
return m.IPAddress()
case usagelog.FieldImageCount: case usagelog.FieldImageCount:
return m.ImageCount() return m.ImageCount()
case usagelog.FieldImageSize: case usagelog.FieldImageSize:
@@ -10319,6 +10554,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldFirstTokenMs(ctx) return m.OldFirstTokenMs(ctx)
case usagelog.FieldUserAgent: case usagelog.FieldUserAgent:
return m.OldUserAgent(ctx) return m.OldUserAgent(ctx)
case usagelog.FieldIPAddress:
return m.OldIPAddress(ctx)
case usagelog.FieldImageCount: case usagelog.FieldImageCount:
return m.OldImageCount(ctx) return m.OldImageCount(ctx)
case usagelog.FieldImageSize: case usagelog.FieldImageSize:
@@ -10509,6 +10746,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
} }
m.SetUserAgent(v) m.SetUserAgent(v)
return nil return nil
case usagelog.FieldIPAddress:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetIPAddress(v)
return nil
case usagelog.FieldImageCount: case usagelog.FieldImageCount:
v, ok := value.(int) v, ok := value.(int)
if !ok { if !ok {
@@ -10782,6 +11026,9 @@ func (m *UsageLogMutation) ClearedFields() []string {
if m.FieldCleared(usagelog.FieldUserAgent) { if m.FieldCleared(usagelog.FieldUserAgent) {
fields = append(fields, usagelog.FieldUserAgent) fields = append(fields, usagelog.FieldUserAgent)
} }
if m.FieldCleared(usagelog.FieldIPAddress) {
fields = append(fields, usagelog.FieldIPAddress)
}
if m.FieldCleared(usagelog.FieldImageSize) { if m.FieldCleared(usagelog.FieldImageSize) {
fields = append(fields, usagelog.FieldImageSize) fields = append(fields, usagelog.FieldImageSize)
} }
@@ -10814,6 +11061,9 @@ func (m *UsageLogMutation) ClearField(name string) error {
case usagelog.FieldUserAgent: case usagelog.FieldUserAgent:
m.ClearUserAgent() m.ClearUserAgent()
return nil return nil
case usagelog.FieldIPAddress:
m.ClearIPAddress()
return nil
case usagelog.FieldImageSize: case usagelog.FieldImageSize:
m.ClearImageSize() m.ClearImageSize()
return nil return nil
@@ -10900,6 +11150,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldUserAgent: case usagelog.FieldUserAgent:
m.ResetUserAgent() m.ResetUserAgent()
return nil return nil
case usagelog.FieldIPAddress:
m.ResetIPAddress()
return nil
case usagelog.FieldImageCount: case usagelog.FieldImageCount:
m.ResetImageCount() m.ResetImageCount()
return nil return nil

View File

@@ -533,16 +533,20 @@ func init() {
usagelogDescUserAgent := usagelogFields[24].Descriptor() usagelogDescUserAgent := usagelogFields[24].Descriptor()
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save. // usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error) usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
// usagelogDescIPAddress is the schema descriptor for ip_address field.
usagelogDescIPAddress := usagelogFields[25].Descriptor()
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
// usagelogDescImageCount is the schema descriptor for image_count field. // usagelogDescImageCount is the schema descriptor for image_count field.
usagelogDescImageCount := usagelogFields[25].Descriptor() usagelogDescImageCount := usagelogFields[26].Descriptor()
// usagelog.DefaultImageCount holds the default value on creation for the image_count field. // usagelog.DefaultImageCount holds the default value on creation for the image_count field.
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int) usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
// usagelogDescImageSize is the schema descriptor for image_size field. // usagelogDescImageSize is the schema descriptor for image_size field.
usagelogDescImageSize := usagelogFields[26].Descriptor() usagelogDescImageSize := usagelogFields[27].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
// usagelogDescCreatedAt is the schema descriptor for created_at field. // usagelogDescCreatedAt is the schema descriptor for created_at field.
usagelogDescCreatedAt := usagelogFields[27].Descriptor() usagelogDescCreatedAt := usagelogFields[28].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin() userMixin := schema.User{}.Mixin()

View File

@@ -46,6 +46,12 @@ func (APIKey) Fields() []ent.Field {
field.String("status"). field.String("status").
MaxLen(20). MaxLen(20).
Default(service.StatusActive), Default(service.StatusActive),
field.JSON("ip_whitelist", []string{}).
Optional().
Comment("Allowed IPs/CIDRs, e.g. [\"192.168.1.100\", \"10.0.0.0/8\"]"),
field.JSON("ip_blacklist", []string{}).
Optional().
Comment("Blocked IPs/CIDRs"),
} }
} }

View File

@@ -100,6 +100,10 @@ func (UsageLog) Fields() []ent.Field {
MaxLen(512). MaxLen(512).
Optional(). Optional().
Nillable(), Nillable(),
field.String("ip_address").
MaxLen(45). // 支持 IPv6
Optional().
Nillable(),
// 图片生成字段(仅 gemini-3-pro-image 等图片模型使用) // 图片生成字段(仅 gemini-3-pro-image 等图片模型使用)
field.Int("image_count"). field.Int("image_count").

View File

@@ -72,6 +72,8 @@ type UsageLog struct {
FirstTokenMs *int `json:"first_token_ms,omitempty"` FirstTokenMs *int `json:"first_token_ms,omitempty"`
// UserAgent holds the value of the "user_agent" field. // UserAgent holds the value of the "user_agent" field.
UserAgent *string `json:"user_agent,omitempty"` UserAgent *string `json:"user_agent,omitempty"`
// IPAddress holds the value of the "ip_address" field.
IPAddress *string `json:"ip_address,omitempty"`
// ImageCount holds the value of the "image_count" field. // ImageCount holds the value of the "image_count" field.
ImageCount int `json:"image_count,omitempty"` ImageCount int `json:"image_count,omitempty"`
// ImageSize holds the value of the "image_size" field. // ImageSize holds the value of the "image_size" field.
@@ -167,7 +169,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullFloat64) values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldImageSize: case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize:
values[i] = new(sql.NullString) values[i] = new(sql.NullString)
case usagelog.FieldCreatedAt: case usagelog.FieldCreatedAt:
values[i] = new(sql.NullTime) values[i] = new(sql.NullTime)
@@ -347,6 +349,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
_m.UserAgent = new(string) _m.UserAgent = new(string)
*_m.UserAgent = value.String *_m.UserAgent = value.String
} }
case usagelog.FieldIPAddress:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field ip_address", values[i])
} else if value.Valid {
_m.IPAddress = new(string)
*_m.IPAddress = value.String
}
case usagelog.FieldImageCount: case usagelog.FieldImageCount:
if value, ok := values[i].(*sql.NullInt64); !ok { if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field image_count", values[i]) return fmt.Errorf("unexpected type %T for field image_count", values[i])
@@ -512,6 +521,11 @@ func (_m *UsageLog) String() string {
builder.WriteString(*v) builder.WriteString(*v)
} }
builder.WriteString(", ") builder.WriteString(", ")
if v := _m.IPAddress; v != nil {
builder.WriteString("ip_address=")
builder.WriteString(*v)
}
builder.WriteString(", ")
builder.WriteString("image_count=") builder.WriteString("image_count=")
builder.WriteString(fmt.Sprintf("%v", _m.ImageCount)) builder.WriteString(fmt.Sprintf("%v", _m.ImageCount))
builder.WriteString(", ") builder.WriteString(", ")

View File

@@ -64,6 +64,8 @@ const (
FieldFirstTokenMs = "first_token_ms" FieldFirstTokenMs = "first_token_ms"
// FieldUserAgent holds the string denoting the user_agent field in the database. // FieldUserAgent holds the string denoting the user_agent field in the database.
FieldUserAgent = "user_agent" FieldUserAgent = "user_agent"
// FieldIPAddress holds the string denoting the ip_address field in the database.
FieldIPAddress = "ip_address"
// FieldImageCount holds the string denoting the image_count field in the database. // FieldImageCount holds the string denoting the image_count field in the database.
FieldImageCount = "image_count" FieldImageCount = "image_count"
// FieldImageSize holds the string denoting the image_size field in the database. // FieldImageSize holds the string denoting the image_size field in the database.
@@ -147,6 +149,7 @@ var Columns = []string{
FieldDurationMs, FieldDurationMs,
FieldFirstTokenMs, FieldFirstTokenMs,
FieldUserAgent, FieldUserAgent,
FieldIPAddress,
FieldImageCount, FieldImageCount,
FieldImageSize, FieldImageSize,
FieldCreatedAt, FieldCreatedAt,
@@ -199,6 +202,8 @@ var (
DefaultStream bool DefaultStream bool
// UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save. // UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
UserAgentValidator func(string) error UserAgentValidator func(string) error
// IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
IPAddressValidator func(string) error
// DefaultImageCount holds the default value on creation for the "image_count" field. // DefaultImageCount holds the default value on creation for the "image_count" field.
DefaultImageCount int DefaultImageCount int
// ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. // ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
@@ -340,6 +345,11 @@ func ByUserAgent(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUserAgent, opts...).ToFunc() return sql.OrderByField(FieldUserAgent, opts...).ToFunc()
} }
// ByIPAddress orders the results by the ip_address field.
func ByIPAddress(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldIPAddress, opts...).ToFunc()
}
// ByImageCount orders the results by the image_count field. // ByImageCount orders the results by the image_count field.
func ByImageCount(opts ...sql.OrderTermOption) OrderOption { func ByImageCount(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImageCount, opts...).ToFunc() return sql.OrderByField(FieldImageCount, opts...).ToFunc()

View File

@@ -180,6 +180,11 @@ func UserAgent(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldUserAgent, v)) return predicate.UsageLog(sql.FieldEQ(FieldUserAgent, v))
} }
// IPAddress applies equality check predicate on the "ip_address" field. It's identical to IPAddressEQ.
func IPAddress(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldIPAddress, v))
}
// ImageCount applies equality check predicate on the "image_count" field. It's identical to ImageCountEQ. // ImageCount applies equality check predicate on the "image_count" field. It's identical to ImageCountEQ.
func ImageCount(v int) predicate.UsageLog { func ImageCount(v int) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v)) return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v))
@@ -1190,6 +1195,81 @@ func UserAgentContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldUserAgent, v)) return predicate.UsageLog(sql.FieldContainsFold(FieldUserAgent, v))
} }
// IPAddressEQ applies the EQ predicate on the "ip_address" field.
func IPAddressEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldIPAddress, v))
}
// IPAddressNEQ applies the NEQ predicate on the "ip_address" field.
func IPAddressNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldIPAddress, v))
}
// IPAddressIn applies the In predicate on the "ip_address" field.
func IPAddressIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldIPAddress, vs...))
}
// IPAddressNotIn applies the NotIn predicate on the "ip_address" field.
func IPAddressNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldIPAddress, vs...))
}
// IPAddressGT applies the GT predicate on the "ip_address" field.
func IPAddressGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldIPAddress, v))
}
// IPAddressGTE applies the GTE predicate on the "ip_address" field.
func IPAddressGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldIPAddress, v))
}
// IPAddressLT applies the LT predicate on the "ip_address" field.
func IPAddressLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldIPAddress, v))
}
// IPAddressLTE applies the LTE predicate on the "ip_address" field.
func IPAddressLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldIPAddress, v))
}
// IPAddressContains applies the Contains predicate on the "ip_address" field.
func IPAddressContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldIPAddress, v))
}
// IPAddressHasPrefix applies the HasPrefix predicate on the "ip_address" field.
func IPAddressHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldIPAddress, v))
}
// IPAddressHasSuffix applies the HasSuffix predicate on the "ip_address" field.
func IPAddressHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldIPAddress, v))
}
// IPAddressIsNil applies the IsNil predicate on the "ip_address" field.
func IPAddressIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldIPAddress))
}
// IPAddressNotNil applies the NotNil predicate on the "ip_address" field.
func IPAddressNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldIPAddress))
}
// IPAddressEqualFold applies the EqualFold predicate on the "ip_address" field.
func IPAddressEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldIPAddress, v))
}
// IPAddressContainsFold applies the ContainsFold predicate on the "ip_address" field.
func IPAddressContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldIPAddress, v))
}
// ImageCountEQ applies the EQ predicate on the "image_count" field. // ImageCountEQ applies the EQ predicate on the "image_count" field.
func ImageCountEQ(v int) predicate.UsageLog { func ImageCountEQ(v int) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v)) return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v))

View File

@@ -337,6 +337,20 @@ func (_c *UsageLogCreate) SetNillableUserAgent(v *string) *UsageLogCreate {
return _c return _c
} }
// SetIPAddress sets the "ip_address" field.
func (_c *UsageLogCreate) SetIPAddress(v string) *UsageLogCreate {
_c.mutation.SetIPAddress(v)
return _c
}
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableIPAddress(v *string) *UsageLogCreate {
if v != nil {
_c.SetIPAddress(*v)
}
return _c
}
// SetImageCount sets the "image_count" field. // SetImageCount sets the "image_count" field.
func (_c *UsageLogCreate) SetImageCount(v int) *UsageLogCreate { func (_c *UsageLogCreate) SetImageCount(v int) *UsageLogCreate {
_c.mutation.SetImageCount(v) _c.mutation.SetImageCount(v)
@@ -586,6 +600,11 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
} }
} }
if v, ok := _c.mutation.IPAddress(); ok {
if err := usagelog.IPAddressValidator(v); err != nil {
return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)}
}
}
if _, ok := _c.mutation.ImageCount(); !ok { if _, ok := _c.mutation.ImageCount(); !ok {
return &ValidationError{Name: "image_count", err: errors.New(`ent: missing required field "UsageLog.image_count"`)} return &ValidationError{Name: "image_count", err: errors.New(`ent: missing required field "UsageLog.image_count"`)}
} }
@@ -713,6 +732,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldUserAgent, field.TypeString, value) _spec.SetField(usagelog.FieldUserAgent, field.TypeString, value)
_node.UserAgent = &value _node.UserAgent = &value
} }
if value, ok := _c.mutation.IPAddress(); ok {
_spec.SetField(usagelog.FieldIPAddress, field.TypeString, value)
_node.IPAddress = &value
}
if value, ok := _c.mutation.ImageCount(); ok { if value, ok := _c.mutation.ImageCount(); ok {
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value) _spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
_node.ImageCount = value _node.ImageCount = value
@@ -1288,6 +1311,24 @@ func (u *UsageLogUpsert) ClearUserAgent() *UsageLogUpsert {
return u return u
} }
// SetIPAddress sets the "ip_address" field.
func (u *UsageLogUpsert) SetIPAddress(v string) *UsageLogUpsert {
u.Set(usagelog.FieldIPAddress, v)
return u
}
// UpdateIPAddress sets the "ip_address" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateIPAddress() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldIPAddress)
return u
}
// ClearIPAddress clears the value of the "ip_address" field.
func (u *UsageLogUpsert) ClearIPAddress() *UsageLogUpsert {
u.SetNull(usagelog.FieldIPAddress)
return u
}
// SetImageCount sets the "image_count" field. // SetImageCount sets the "image_count" field.
func (u *UsageLogUpsert) SetImageCount(v int) *UsageLogUpsert { func (u *UsageLogUpsert) SetImageCount(v int) *UsageLogUpsert {
u.Set(usagelog.FieldImageCount, v) u.Set(usagelog.FieldImageCount, v)
@@ -1866,6 +1907,27 @@ func (u *UsageLogUpsertOne) ClearUserAgent() *UsageLogUpsertOne {
}) })
} }
// SetIPAddress sets the "ip_address" field.
func (u *UsageLogUpsertOne) SetIPAddress(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetIPAddress(v)
})
}
// UpdateIPAddress sets the "ip_address" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateIPAddress() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateIPAddress()
})
}
// ClearIPAddress clears the value of the "ip_address" field.
func (u *UsageLogUpsertOne) ClearIPAddress() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearIPAddress()
})
}
// SetImageCount sets the "image_count" field. // SetImageCount sets the "image_count" field.
func (u *UsageLogUpsertOne) SetImageCount(v int) *UsageLogUpsertOne { func (u *UsageLogUpsertOne) SetImageCount(v int) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) { return u.Update(func(s *UsageLogUpsert) {
@@ -2616,6 +2678,27 @@ func (u *UsageLogUpsertBulk) ClearUserAgent() *UsageLogUpsertBulk {
}) })
} }
// SetIPAddress sets the "ip_address" field.
func (u *UsageLogUpsertBulk) SetIPAddress(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetIPAddress(v)
})
}
// UpdateIPAddress sets the "ip_address" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateIPAddress() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateIPAddress()
})
}
// ClearIPAddress clears the value of the "ip_address" field.
func (u *UsageLogUpsertBulk) ClearIPAddress() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearIPAddress()
})
}
// SetImageCount sets the "image_count" field. // SetImageCount sets the "image_count" field.
func (u *UsageLogUpsertBulk) SetImageCount(v int) *UsageLogUpsertBulk { func (u *UsageLogUpsertBulk) SetImageCount(v int) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) { return u.Update(func(s *UsageLogUpsert) {

View File

@@ -524,6 +524,26 @@ func (_u *UsageLogUpdate) ClearUserAgent() *UsageLogUpdate {
return _u return _u
} }
// SetIPAddress sets the "ip_address" field.
func (_u *UsageLogUpdate) SetIPAddress(v string) *UsageLogUpdate {
_u.mutation.SetIPAddress(v)
return _u
}
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableIPAddress(v *string) *UsageLogUpdate {
if v != nil {
_u.SetIPAddress(*v)
}
return _u
}
// ClearIPAddress clears the value of the "ip_address" field.
func (_u *UsageLogUpdate) ClearIPAddress() *UsageLogUpdate {
_u.mutation.ClearIPAddress()
return _u
}
// SetImageCount sets the "image_count" field. // SetImageCount sets the "image_count" field.
func (_u *UsageLogUpdate) SetImageCount(v int) *UsageLogUpdate { func (_u *UsageLogUpdate) SetImageCount(v int) *UsageLogUpdate {
_u.mutation.ResetImageCount() _u.mutation.ResetImageCount()
@@ -669,6 +689,11 @@ func (_u *UsageLogUpdate) check() error {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
} }
} }
if v, ok := _u.mutation.IPAddress(); ok {
if err := usagelog.IPAddressValidator(v); err != nil {
return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)}
}
}
if v, ok := _u.mutation.ImageSize(); ok { if v, ok := _u.mutation.ImageSize(); ok {
if err := usagelog.ImageSizeValidator(v); err != nil { if err := usagelog.ImageSizeValidator(v); err != nil {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
@@ -815,6 +840,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.UserAgentCleared() { if _u.mutation.UserAgentCleared() {
_spec.ClearField(usagelog.FieldUserAgent, field.TypeString) _spec.ClearField(usagelog.FieldUserAgent, field.TypeString)
} }
if value, ok := _u.mutation.IPAddress(); ok {
_spec.SetField(usagelog.FieldIPAddress, field.TypeString, value)
}
if _u.mutation.IPAddressCleared() {
_spec.ClearField(usagelog.FieldIPAddress, field.TypeString)
}
if value, ok := _u.mutation.ImageCount(); ok { if value, ok := _u.mutation.ImageCount(); ok {
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value) _spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
} }
@@ -1484,6 +1515,26 @@ func (_u *UsageLogUpdateOne) ClearUserAgent() *UsageLogUpdateOne {
return _u return _u
} }
// SetIPAddress sets the "ip_address" field.
func (_u *UsageLogUpdateOne) SetIPAddress(v string) *UsageLogUpdateOne {
_u.mutation.SetIPAddress(v)
return _u
}
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableIPAddress(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetIPAddress(*v)
}
return _u
}
// ClearIPAddress clears the value of the "ip_address" field.
func (_u *UsageLogUpdateOne) ClearIPAddress() *UsageLogUpdateOne {
_u.mutation.ClearIPAddress()
return _u
}
// SetImageCount sets the "image_count" field. // SetImageCount sets the "image_count" field.
func (_u *UsageLogUpdateOne) SetImageCount(v int) *UsageLogUpdateOne { func (_u *UsageLogUpdateOne) SetImageCount(v int) *UsageLogUpdateOne {
_u.mutation.ResetImageCount() _u.mutation.ResetImageCount()
@@ -1642,6 +1693,11 @@ func (_u *UsageLogUpdateOne) check() error {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
} }
} }
if v, ok := _u.mutation.IPAddress(); ok {
if err := usagelog.IPAddressValidator(v); err != nil {
return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)}
}
}
if v, ok := _u.mutation.ImageSize(); ok { if v, ok := _u.mutation.ImageSize(); ok {
if err := usagelog.ImageSizeValidator(v); err != nil { if err := usagelog.ImageSizeValidator(v); err != nil {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
@@ -1805,6 +1861,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if _u.mutation.UserAgentCleared() { if _u.mutation.UserAgentCleared() {
_spec.ClearField(usagelog.FieldUserAgent, field.TypeString) _spec.ClearField(usagelog.FieldUserAgent, field.TypeString)
} }
if value, ok := _u.mutation.IPAddress(); ok {
_spec.SetField(usagelog.FieldIPAddress, field.TypeString, value)
}
if _u.mutation.IPAddressCleared() {
_spec.ClearField(usagelog.FieldIPAddress, field.TypeString)
}
if value, ok := _u.mutation.ImageCount(); ok { if value, ok := _u.mutation.ImageCount(); ok {
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value) _spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
} }

View File

@@ -27,16 +27,20 @@ func NewAPIKeyHandler(apiKeyService *service.APIKeyService) *APIKeyHandler {
// CreateAPIKeyRequest represents the create API key request payload // CreateAPIKeyRequest represents the create API key request payload
type CreateAPIKeyRequest struct { type CreateAPIKeyRequest struct {
Name string `json:"name" binding:"required"` Name string `json:"name" binding:"required"`
GroupID *int64 `json:"group_id"` // nullable GroupID *int64 `json:"group_id"` // nullable
CustomKey *string `json:"custom_key"` // 可选的自定义key CustomKey *string `json:"custom_key"` // 可选的自定义key
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
} }
// UpdateAPIKeyRequest represents the update API key request payload // UpdateAPIKeyRequest represents the update API key request payload
type UpdateAPIKeyRequest struct { type UpdateAPIKeyRequest struct {
Name string `json:"name"` Name string `json:"name"`
GroupID *int64 `json:"group_id"` GroupID *int64 `json:"group_id"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"` Status string `json:"status" binding:"omitempty,oneof=active inactive"`
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
} }
// List handles listing user's API keys with pagination // List handles listing user's API keys with pagination
@@ -110,9 +114,11 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
} }
svcReq := service.CreateAPIKeyRequest{ svcReq := service.CreateAPIKeyRequest{
Name: req.Name, Name: req.Name,
GroupID: req.GroupID, GroupID: req.GroupID,
CustomKey: req.CustomKey, CustomKey: req.CustomKey,
IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist,
} }
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq) key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
if err != nil { if err != nil {
@@ -144,7 +150,10 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
return return
} }
svcReq := service.UpdateAPIKeyRequest{} svcReq := service.UpdateAPIKeyRequest{
IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist,
}
if req.Name != "" { if req.Name != "" {
svcReq.Name = &req.Name svcReq.Name = &req.Name
} }

View File

@@ -53,16 +53,18 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
return nil return nil
} }
return &APIKey{ return &APIKey{
ID: k.ID, ID: k.ID,
UserID: k.UserID, UserID: k.UserID,
Key: k.Key, Key: k.Key,
Name: k.Name, Name: k.Name,
GroupID: k.GroupID, GroupID: k.GroupID,
Status: k.Status, Status: k.Status,
CreatedAt: k.CreatedAt, IPWhitelist: k.IPWhitelist,
UpdatedAt: k.UpdatedAt, IPBlacklist: k.IPBlacklist,
User: UserFromServiceShallow(k.User), CreatedAt: k.CreatedAt,
Group: GroupFromServiceShallow(k.Group), UpdatedAt: k.UpdatedAt,
User: UserFromServiceShallow(k.User),
Group: GroupFromServiceShallow(k.Group),
} }
} }
@@ -250,11 +252,12 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
// usageLogFromServiceBase is a helper that converts service UsageLog to DTO. // usageLogFromServiceBase is a helper that converts service UsageLog to DTO.
// The account parameter allows caller to control what Account info is included. // The account parameter allows caller to control what Account info is included.
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *UsageLog { // The includeIPAddress parameter controls whether to include the IP address (admin-only).
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, includeIPAddress bool) *UsageLog {
if l == nil { if l == nil {
return nil return nil
} }
return &UsageLog{ result := &UsageLog{
ID: l.ID, ID: l.ID,
UserID: l.UserID, UserID: l.UserID,
APIKeyID: l.APIKeyID, APIKeyID: l.APIKeyID,
@@ -290,21 +293,26 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *Usag
Group: GroupFromServiceShallow(l.Group), Group: GroupFromServiceShallow(l.Group),
Subscription: UserSubscriptionFromService(l.Subscription), Subscription: UserSubscriptionFromService(l.Subscription),
} }
// IP 地址仅对管理员可见
if includeIPAddress {
result.IPAddress = l.IPAddress
}
return result
} }
// UsageLogFromService converts a service UsageLog to DTO for regular users. // UsageLogFromService converts a service UsageLog to DTO for regular users.
// It excludes Account details - users should not see account information. // It excludes Account details and IP address - users should not see these.
func UsageLogFromService(l *service.UsageLog) *UsageLog { func UsageLogFromService(l *service.UsageLog) *UsageLog {
return usageLogFromServiceBase(l, nil) return usageLogFromServiceBase(l, nil, false)
} }
// UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users. // UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users.
// It includes minimal Account info (ID, Name only). // It includes minimal Account info (ID, Name only) and IP address.
func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog { func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog {
if l == nil { if l == nil {
return nil return nil
} }
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account)) return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true)
} }
func SettingFromService(s *service.Setting) *Setting { func SettingFromService(s *service.Setting) *Setting {

View File

@@ -20,14 +20,16 @@ type User struct {
} }
type APIKey struct { type APIKey struct {
ID int64 `json:"id"` ID int64 `json:"id"`
UserID int64 `json:"user_id"` UserID int64 `json:"user_id"`
Key string `json:"key"` Key string `json:"key"`
Name string `json:"name"` Name string `json:"name"`
GroupID *int64 `json:"group_id"` GroupID *int64 `json:"group_id"`
Status string `json:"status"` Status string `json:"status"`
CreatedAt time.Time `json:"created_at"` IPWhitelist []string `json:"ip_whitelist"`
UpdatedAt time.Time `json:"updated_at"` IPBlacklist []string `json:"ip_blacklist"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
User *User `json:"user,omitempty"` User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"` Group *Group `json:"group,omitempty"`
@@ -187,6 +189,9 @@ type UsageLog struct {
// User-Agent // User-Agent
UserAgent *string `json:"user_agent"` UserAgent *string `json:"user_agent"`
// IP 地址(仅管理员可见)
IPAddress *string `json:"ip_address,omitempty"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"` User *User `json:"user,omitempty"`

View File

@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
@@ -114,6 +115,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 获取 User-Agent // 获取 User-Agent
userAgent := c.Request.UserAgent() userAgent := c.Request.UserAgent()
// 获取客户端 IP
clientIP := ip.GetClientIP(c)
// 0. 检查wait队列是否已满 // 0. 检查wait队列是否已满
maxWait := service.CalculateMaxWait(subject.Concurrency) maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
@@ -273,7 +277,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
// 异步记录使用量subscription已在函数开头获取 // 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) { go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
@@ -283,10 +287,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Account: usedAccount, Account: usedAccount,
Subscription: subscription, Subscription: subscription,
UserAgent: ua, UserAgent: ua,
IPAddress: cip,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) log.Printf("Record usage failed: %v", err)
} }
}(result, account, userAgent) }(result, account, userAgent, clientIP)
return return
} }
} }
@@ -401,7 +406,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
// 异步记录使用量subscription已在函数开头获取 // 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) { go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
@@ -411,10 +416,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Account: usedAccount, Account: usedAccount,
Subscription: subscription, Subscription: subscription,
UserAgent: ua, UserAgent: ua,
IPAddress: cip,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) log.Printf("Record usage failed: %v", err)
} }
}(result, account, userAgent) }(result, account, userAgent, clientIP)
return return
} }
} }

View File

@@ -12,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini" "github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
@@ -167,6 +168,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 获取 User-Agent // 获取 User-Agent
userAgent := c.Request.UserAgent() userAgent := c.Request.UserAgent()
// 获取客户端 IP
clientIP := ip.GetClientIP(c)
// For Gemini native API, do not send Claude-style ping frames. // For Gemini native API, do not send Claude-style ping frames.
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0) geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0)
@@ -307,7 +311,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
} }
// 6) record usage async // 6) record usage async
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) { go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
@@ -317,10 +321,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
Account: usedAccount, Account: usedAccount,
Subscription: subscription, Subscription: subscription,
UserAgent: ua, UserAgent: ua,
IPAddress: cip,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) log.Printf("Record usage failed: %v", err)
} }
}(result, account, userAgent) }(result, account, userAgent, clientIP)
return return
} }
} }

View File

@@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
@@ -94,6 +95,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// For non-Codex CLI requests, set default instructions // For non-Codex CLI requests, set default instructions
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
// 获取客户端 IP
clientIP := ip.GetClientIP(c)
if !openai.IsCodexCLIRequest(userAgent) { if !openai.IsCodexCLIRequest(userAgent) {
reqBody["instructions"] = openai.DefaultInstructions reqBody["instructions"] = openai.DefaultInstructions
// Re-serialize body // Re-serialize body
@@ -242,7 +247,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
} }
// Async record usage // Async record usage
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua string) { go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua string, cip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
@@ -252,10 +257,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
Account: usedAccount, Account: usedAccount,
Subscription: subscription, Subscription: subscription,
UserAgent: ua, UserAgent: ua,
IPAddress: cip,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) log.Printf("Record usage failed: %v", err)
} }
}(result, account, userAgent) }(result, account, userAgent, clientIP)
return return
} }
} }

View File

@@ -0,0 +1,168 @@
// Package ip 提供客户端 IP 地址提取工具。
package ip
import (
"net"
"strings"
"github.com/gin-gonic/gin"
)
// GetClientIP 从 Gin Context 中提取客户端真实 IP 地址。
// 按以下优先级检查 Header
// 1. CF-Connecting-IP (Cloudflare)
// 2. X-Real-IP (Nginx)
// 3. X-Forwarded-For (取第一个非私有 IP)
// 4. c.ClientIP() (Gin 内置方法)
func GetClientIP(c *gin.Context) string {
// 1. Cloudflare
if ip := c.GetHeader("CF-Connecting-IP"); ip != "" {
return normalizeIP(ip)
}
// 2. Nginx X-Real-IP
if ip := c.GetHeader("X-Real-IP"); ip != "" {
return normalizeIP(ip)
}
// 3. X-Forwarded-For (多个 IP 时取第一个公网 IP)
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
ips := strings.Split(xff, ",")
for _, ip := range ips {
ip = strings.TrimSpace(ip)
if ip != "" && !isPrivateIP(ip) {
return normalizeIP(ip)
}
}
// 如果都是私有 IP返回第一个
if len(ips) > 0 {
return normalizeIP(strings.TrimSpace(ips[0]))
}
}
// 4. Gin 内置方法
return normalizeIP(c.ClientIP())
}
// normalizeIP 规范化 IP 地址,去除端口号和空格。
func normalizeIP(ip string) string {
ip = strings.TrimSpace(ip)
// 移除端口号(如 "192.168.1.1:8080" -> "192.168.1.1"
if host, _, err := net.SplitHostPort(ip); err == nil {
return host
}
return ip
}
// isPrivateIP 检查 IP 是否为私有地址。
func isPrivateIP(ipStr string) bool {
ip := net.ParseIP(ipStr)
if ip == nil {
return false
}
// 私有 IP 范围
privateBlocks := []string{
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"127.0.0.0/8",
"::1/128",
"fc00::/7",
}
for _, block := range privateBlocks {
_, cidr, err := net.ParseCIDR(block)
if err != nil {
continue
}
if cidr.Contains(ip) {
return true
}
}
return false
}
// MatchesPattern 检查 IP 是否匹配指定的模式(支持单个 IP 或 CIDR
// pattern 可以是:
// - 单个 IP: "192.168.1.100"
// - CIDR 范围: "192.168.1.0/24"
func MatchesPattern(clientIP, pattern string) bool {
ip := net.ParseIP(clientIP)
if ip == nil {
return false
}
// 尝试解析为 CIDR
if strings.Contains(pattern, "/") {
_, cidr, err := net.ParseCIDR(pattern)
if err != nil {
return false
}
return cidr.Contains(ip)
}
// 作为单个 IP 处理
patternIP := net.ParseIP(pattern)
if patternIP == nil {
return false
}
return ip.Equal(patternIP)
}
// MatchesAnyPattern 检查 IP 是否匹配任意一个模式。
func MatchesAnyPattern(clientIP string, patterns []string) bool {
for _, pattern := range patterns {
if MatchesPattern(clientIP, pattern) {
return true
}
}
return false
}
// CheckIPRestriction 检查 IP 是否被 API Key 的 IP 限制允许。
// 返回值:(是否允许, 拒绝原因)
// 逻辑:
// 1. 先检查黑名单,如果在黑名单中则直接拒绝
// 2. 如果白名单不为空IP 必须在白名单中
// 3. 如果白名单为空,允许访问(除非被黑名单拒绝)
func CheckIPRestriction(clientIP string, whitelist, blacklist []string) (bool, string) {
// 规范化 IP
clientIP = normalizeIP(clientIP)
if clientIP == "" {
return false, "access denied"
}
// 1. 检查黑名单
if len(blacklist) > 0 && MatchesAnyPattern(clientIP, blacklist) {
return false, "access denied"
}
// 2. 检查白名单如果设置了白名单IP 必须在其中)
if len(whitelist) > 0 && !MatchesAnyPattern(clientIP, whitelist) {
return false, "access denied"
}
return true, ""
}
// ValidateIPPattern 验证 IP 或 CIDR 格式是否有效。
func ValidateIPPattern(pattern string) bool {
if strings.Contains(pattern, "/") {
_, _, err := net.ParseCIDR(pattern)
return err == nil
}
return net.ParseIP(pattern) != nil
}
// ValidateIPPatterns 验证多个 IP 或 CIDR 格式。
// 返回无效的模式列表。
func ValidateIPPatterns(patterns []string) []string {
var invalid []string
for _, p := range patterns {
if !ValidateIPPattern(p) {
invalid = append(invalid, p)
}
}
return invalid
}

View File

@@ -26,13 +26,21 @@ func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
} }
func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error { func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
created, err := r.client.APIKey.Create(). builder := r.client.APIKey.Create().
SetUserID(key.UserID). SetUserID(key.UserID).
SetKey(key.Key). SetKey(key.Key).
SetName(key.Name). SetName(key.Name).
SetStatus(key.Status). SetStatus(key.Status).
SetNillableGroupID(key.GroupID). SetNillableGroupID(key.GroupID)
Save(ctx)
if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist)
}
if len(key.IPBlacklist) > 0 {
builder.SetIPBlacklist(key.IPBlacklist)
}
created, err := builder.Save(ctx)
if err == nil { if err == nil {
key.ID = created.ID key.ID = created.ID
key.CreatedAt = created.CreatedAt key.CreatedAt = created.CreatedAt
@@ -108,6 +116,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
builder.ClearGroupID() builder.ClearGroupID()
} }
// IP 限制字段
if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist)
} else {
builder.ClearIPWhitelist()
}
if len(key.IPBlacklist) > 0 {
builder.SetIPBlacklist(key.IPBlacklist)
} else {
builder.ClearIPBlacklist()
}
affected, err := builder.Save(ctx) affected, err := builder.Save(ctx)
if err != nil { if err != nil {
return err return err
@@ -268,14 +288,16 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
return nil return nil
} }
out := &service.APIKey{ out := &service.APIKey{
ID: m.ID, ID: m.ID,
UserID: m.UserID, UserID: m.UserID,
Key: m.Key, Key: m.Key,
Name: m.Name, Name: m.Name,
Status: m.Status, Status: m.Status,
CreatedAt: m.CreatedAt, IPWhitelist: m.IPWhitelist,
UpdatedAt: m.UpdatedAt, IPBlacklist: m.IPBlacklist,
GroupID: m.GroupID, CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
GroupID: m.GroupID,
} }
if m.Edges.User != nil { if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User) out.User = userEntityToService(m.Edges.User)

View File

@@ -22,7 +22,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, image_count, image_size, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
type usageLogRepository struct { type usageLogRepository struct {
client *dbent.Client client *dbent.Client
@@ -110,6 +110,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration_ms, duration_ms,
first_token_ms, first_token_ms,
user_agent, user_agent,
ip_address,
image_count, image_count,
image_size, image_size,
created_at created_at
@@ -119,7 +120,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11, $8, $9, $10, $11,
$12, $13, $12, $13,
$14, $15, $16, $17, $18, $19, $14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28 $20, $21, $22, $23, $24, $25, $26, $27, $28, $29
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
@@ -130,6 +131,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration := nullInt(log.DurationMs) duration := nullInt(log.DurationMs)
firstToken := nullInt(log.FirstTokenMs) firstToken := nullInt(log.FirstTokenMs)
userAgent := nullString(log.UserAgent) userAgent := nullString(log.UserAgent)
ipAddress := nullString(log.IPAddress)
imageSize := nullString(log.ImageSize) imageSize := nullString(log.ImageSize)
var requestIDArg any var requestIDArg any
@@ -163,6 +165,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration, duration,
firstToken, firstToken,
userAgent, userAgent,
ipAddress,
log.ImageCount, log.ImageCount,
imageSize, imageSize,
createdAt, createdAt,
@@ -1873,6 +1876,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
durationMs sql.NullInt64 durationMs sql.NullInt64
firstTokenMs sql.NullInt64 firstTokenMs sql.NullInt64
userAgent sql.NullString userAgent sql.NullString
ipAddress sql.NullString
imageCount int imageCount int
imageSize sql.NullString imageSize sql.NullString
createdAt time.Time createdAt time.Time
@@ -1905,6 +1909,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&durationMs, &durationMs,
&firstTokenMs, &firstTokenMs,
&userAgent, &userAgent,
&ipAddress,
&imageCount, &imageCount,
&imageSize, &imageSize,
&createdAt, &createdAt,
@@ -1959,6 +1964,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if userAgent.Valid { if userAgent.Valid {
log.UserAgent = &userAgent.String log.UserAgent = &userAgent.String
} }
if ipAddress.Valid {
log.IPAddress = &ipAddress.String
}
if imageSize.Valid { if imageSize.Valid {
log.ImageSize = &imageSize.String log.ImageSize = &imageSize.String
} }

View File

@@ -82,6 +82,8 @@ func TestAPIContracts(t *testing.T) {
"name": "Key One", "name": "Key One",
"group_id": null, "group_id": null,
"status": "active", "status": "active",
"ip_whitelist": null,
"ip_blacklist": null,
"created_at": "2025-01-02T03:04:05Z", "created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z" "updated_at": "2025-01-02T03:04:05Z"
} }
@@ -116,6 +118,8 @@ func TestAPIContracts(t *testing.T) {
"name": "Key One", "name": "Key One",
"group_id": null, "group_id": null,
"status": "active", "status": "active",
"ip_whitelist": null,
"ip_blacklist": null,
"created_at": "2025-01-02T03:04:05Z", "created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z" "updated_at": "2025-01-02T03:04:05Z"
} }

View File

@@ -6,6 +6,7 @@ import (
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -71,6 +72,17 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
return return
} }
// 检查 IP 限制(白名单/黑名单)
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
clientIP := ip.GetClientIP(c)
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
if !allowed {
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
return
}
}
// 检查关联的用户 // 检查关联的用户
if apiKey.User == nil { if apiKey.User == nil {
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found") AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")

View File

@@ -3,16 +3,18 @@ package service
import "time" import "time"
type APIKey struct { type APIKey struct {
ID int64 ID int64
UserID int64 UserID int64
Key string Key string
Name string Name string
GroupID *int64 GroupID *int64
Status string Status string
CreatedAt time.Time IPWhitelist []string
UpdatedAt time.Time IPBlacklist []string
User *User CreatedAt time.Time
Group *Group UpdatedAt time.Time
User *User
Group *Group
} }
func (k *APIKey) IsActive() bool { func (k *APIKey) IsActive() bool {

View File

@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
) )
@@ -20,6 +21,7 @@ var (
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters") ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens") ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later") ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
ErrInvalidIPPattern = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern")
) )
const ( const (
@@ -57,16 +59,20 @@ type APIKeyCache interface {
// CreateAPIKeyRequest 创建API Key请求 // CreateAPIKeyRequest 创建API Key请求
type CreateAPIKeyRequest struct { type CreateAPIKeyRequest struct {
Name string `json:"name"` Name string `json:"name"`
GroupID *int64 `json:"group_id"` GroupID *int64 `json:"group_id"`
CustomKey *string `json:"custom_key"` // 可选的自定义key CustomKey *string `json:"custom_key"` // 可选的自定义key
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
} }
// UpdateAPIKeyRequest 更新API Key请求 // UpdateAPIKeyRequest 更新API Key请求
type UpdateAPIKeyRequest struct { type UpdateAPIKeyRequest struct {
Name *string `json:"name"` Name *string `json:"name"`
GroupID *int64 `json:"group_id"` GroupID *int64 `json:"group_id"`
Status *string `json:"status"` Status *string `json:"status"`
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单(空数组清空)
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单(空数组清空)
} }
// APIKeyService API Key服务 // APIKeyService API Key服务
@@ -186,6 +192,20 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
return nil, fmt.Errorf("get user: %w", err) return nil, fmt.Errorf("get user: %w", err)
} }
// 验证 IP 白名单格式
if len(req.IPWhitelist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 验证 IP 黑名单格式
if len(req.IPBlacklist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 验证分组权限(如果指定了分组) // 验证分组权限(如果指定了分组)
if req.GroupID != nil { if req.GroupID != nil {
group, err := s.groupRepo.GetByID(ctx, *req.GroupID) group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
@@ -236,11 +256,13 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
// 创建API Key记录 // 创建API Key记录
apiKey := &APIKey{ apiKey := &APIKey{
UserID: userID, UserID: userID,
Key: key, Key: key,
Name: req.Name, Name: req.Name,
GroupID: req.GroupID, GroupID: req.GroupID,
Status: StatusActive, Status: StatusActive,
IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist,
} }
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil { if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
@@ -312,6 +334,20 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
return nil, ErrInsufficientPerms return nil, ErrInsufficientPerms
} }
// 验证 IP 白名单格式
if len(req.IPWhitelist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 验证 IP 黑名单格式
if len(req.IPBlacklist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 更新字段 // 更新字段
if req.Name != nil { if req.Name != nil {
apiKey.Name = *req.Name apiKey.Name = *req.Name
@@ -344,6 +380,10 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
} }
} }
// 更新 IP 限制(空数组会清空设置)
apiKey.IPWhitelist = req.IPWhitelist
apiKey.IPBlacklist = req.IPBlacklist
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil { if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
return nil, fmt.Errorf("update api key: %w", err) return nil, fmt.Errorf("update api key: %w", err)
} }

View File

@@ -2247,6 +2247,7 @@ type RecordUsageInput struct {
Account *Account Account *Account
Subscription *UserSubscription // 可选:订阅信息 Subscription *UserSubscription // 可选:订阅信息
UserAgent string // 请求的 User-Agent UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
} }
// RecordUsage 记录使用量并扣费(或更新订阅用量) // RecordUsage 记录使用量并扣费(或更新订阅用量)
@@ -2337,6 +2338,11 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
usageLog.UserAgent = &input.UserAgent usageLog.UserAgent = &input.UserAgent
} }
// 添加 IPAddress
if input.IPAddress != "" {
usageLog.IPAddress = &input.IPAddress
}
// 添加分组和订阅关联 // 添加分组和订阅关联
if apiKey.GroupID != nil { if apiKey.GroupID != nil {
usageLog.GroupID = apiKey.GroupID usageLog.GroupID = apiKey.GroupID

View File

@@ -1197,6 +1197,7 @@ type OpenAIRecordUsageInput struct {
Account *Account Account *Account
Subscription *UserSubscription Subscription *UserSubscription
UserAgent string // 请求的 User-Agent UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
} }
// RecordUsage records usage and deducts balance // RecordUsage records usage and deducts balance
@@ -1271,6 +1272,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog.UserAgent = &input.UserAgent usageLog.UserAgent = &input.UserAgent
} }
// 添加 IPAddress
if input.IPAddress != "" {
usageLog.IPAddress = &input.IPAddress
}
if apiKey.GroupID != nil { if apiKey.GroupID != nil {
usageLog.GroupID = apiKey.GroupID usageLog.GroupID = apiKey.GroupID
} }

View File

@@ -39,6 +39,7 @@ type UsageLog struct {
DurationMs *int DurationMs *int
FirstTokenMs *int FirstTokenMs *int
UserAgent *string UserAgent *string
IPAddress *string
// 图片生成字段 // 图片生成字段
ImageCount int ImageCount int

View File

@@ -0,0 +1,5 @@
-- Add IP address field to usage_logs table for request tracking (admin-only visibility)
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS ip_address VARCHAR(45);
-- Create index for IP address queries
CREATE INDEX IF NOT EXISTS idx_usage_logs_ip_address ON usage_logs(ip_address);

View File

@@ -0,0 +1,9 @@
-- Add IP restriction fields to api_keys table
-- ip_whitelist: JSON array of allowed IPs/CIDRs (if set, only these IPs can use the key)
-- ip_blacklist: JSON array of blocked IPs/CIDRs (these IPs are always blocked)
ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS ip_whitelist JSONB DEFAULT NULL;
ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS ip_blacklist JSONB DEFAULT NULL;
COMMENT ON COLUMN api_keys.ip_whitelist IS 'JSON array of allowed IPs/CIDRs, e.g. ["192.168.1.100", "10.0.0.0/8"]';
COMMENT ON COLUMN api_keys.ip_blacklist IS 'JSON array of blocked IPs/CIDRs, e.g. ["1.2.3.4", "5.6.0.0/16"]';

BIN
backend/repository.test Executable file

Binary file not shown.

View File

@@ -0,0 +1,93 @@
# =============================================================================
# Sub2API Docker Compose - Standalone Configuration
# =============================================================================
# This configuration runs only the Sub2API application.
# PostgreSQL and Redis must be provided externally.
#
# Usage:
# 1. Copy .env.example to .env and configure database/redis connection
# 2. docker-compose -f docker-compose.standalone.yml up -d
# 3. Access: http://localhost:8080
# =============================================================================
services:
sub2api:
image: weishaw/sub2api:latest
container_name: sub2api
restart: unless-stopped
ulimits:
nofile:
soft: 100000
hard: 100000
ports:
- "${BIND_HOST:-0.0.0.0}:${SERVER_PORT:-8080}:8080"
volumes:
- sub2api_data:/app/data
extra_hosts:
- "host.docker.internal:host-gateway"
environment:
# =======================================================================
# Auto Setup
# =======================================================================
- AUTO_SETUP=true
# =======================================================================
# Server Configuration
# =======================================================================
- SERVER_HOST=0.0.0.0
- SERVER_PORT=8080
- SERVER_MODE=${SERVER_MODE:-release}
- RUN_MODE=${RUN_MODE:-standard}
# =======================================================================
# Database Configuration (PostgreSQL) - Required
# =======================================================================
- DATABASE_HOST=${DATABASE_HOST:?DATABASE_HOST is required}
- DATABASE_PORT=${DATABASE_PORT:-5432}
- DATABASE_USER=${DATABASE_USER:-sub2api}
- DATABASE_PASSWORD=${DATABASE_PASSWORD:?DATABASE_PASSWORD is required}
- DATABASE_DBNAME=${DATABASE_DBNAME:-sub2api}
- DATABASE_SSLMODE=${DATABASE_SSLMODE:-disable}
# =======================================================================
# Redis Configuration - Required
# =======================================================================
- REDIS_HOST=${REDIS_HOST:?REDIS_HOST is required}
- REDIS_PORT=${REDIS_PORT:-6379}
- REDIS_PASSWORD=${REDIS_PASSWORD:-}
- REDIS_DB=${REDIS_DB:-0}
# =======================================================================
# Admin Account (auto-created on first run)
# =======================================================================
- ADMIN_EMAIL=${ADMIN_EMAIL:-admin@sub2api.local}
- ADMIN_PASSWORD=${ADMIN_PASSWORD:-}
# =======================================================================
# JWT Configuration
# =======================================================================
- JWT_SECRET=${JWT_SECRET:-}
- JWT_EXPIRE_HOUR=${JWT_EXPIRE_HOUR:-24}
# =======================================================================
# Timezone Configuration
# =======================================================================
- TZ=${TZ:-Asia/Shanghai}
# =======================================================================
# Gemini OAuth Configuration (optional)
# =======================================================================
- GEMINI_OAUTH_CLIENT_ID=${GEMINI_OAUTH_CLIENT_ID:-}
- GEMINI_OAUTH_CLIENT_SECRET=${GEMINI_OAUTH_CLIENT_SECRET:-}
- GEMINI_OAUTH_SCOPES=${GEMINI_OAUTH_SCOPES:-}
- GEMINI_QUOTA_POLICY=${GEMINI_QUOTA_POLICY:-}
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8080/health"]
interval: 30s
timeout: 10s
retries: 3
start_period: 30s
volumes:
sub2api_data:
driver: local

View File

@@ -64,7 +64,6 @@ export async function getStats(params: {
group_id?: number group_id?: number
model?: string model?: string
stream?: boolean stream?: boolean
billing_type?: number
period?: string period?: string
start_date?: string start_date?: string
end_date?: string end_date?: string

View File

@@ -42,12 +42,16 @@ export async function getById(id: number): Promise<ApiKey> {
* @param name - Key name * @param name - Key name
* @param groupId - Optional group ID * @param groupId - Optional group ID
* @param customKey - Optional custom key value * @param customKey - Optional custom key value
* @param ipWhitelist - Optional IP whitelist
* @param ipBlacklist - Optional IP blacklist
* @returns Created API key * @returns Created API key
*/ */
export async function create( export async function create(
name: string, name: string,
groupId?: number | null, groupId?: number | null,
customKey?: string customKey?: string,
ipWhitelist?: string[],
ipBlacklist?: string[]
): Promise<ApiKey> { ): Promise<ApiKey> {
const payload: CreateApiKeyRequest = { name } const payload: CreateApiKeyRequest = { name }
if (groupId !== undefined) { if (groupId !== undefined) {
@@ -56,6 +60,12 @@ export async function create(
if (customKey) { if (customKey) {
payload.custom_key = customKey payload.custom_key = customKey
} }
if (ipWhitelist && ipWhitelist.length > 0) {
payload.ip_whitelist = ipWhitelist
}
if (ipBlacklist && ipBlacklist.length > 0) {
payload.ip_blacklist = ipBlacklist
}
const { data } = await apiClient.post<ApiKey>('/keys', payload) const { data } = await apiClient.post<ApiKey>('/keys', payload)
return data return data

View File

@@ -127,12 +127,6 @@
<Select v-model="filters.stream" :options="streamTypeOptions" @change="emitChange" /> <Select v-model="filters.stream" :options="streamTypeOptions" @change="emitChange" />
</div> </div>
<!-- Billing Type Filter -->
<div class="w-full sm:w-auto sm:min-w-[180px]">
<label class="input-label">{{ t('usage.billingType') }}</label>
<Select v-model="filters.billing_type" :options="billingTypeOptions" @change="emitChange" />
</div>
<!-- Group Filter --> <!-- Group Filter -->
<div class="w-full sm:w-auto sm:min-w-[200px]"> <div class="w-full sm:w-auto sm:min-w-[200px]">
<label class="input-label">{{ t('admin.usage.group') }}</label> <label class="input-label">{{ t('admin.usage.group') }}</label>
@@ -227,12 +221,6 @@ const streamTypeOptions = ref<SelectOption[]>([
{ value: false, label: t('usage.sync') } { value: false, label: t('usage.sync') }
]) ])
const billingTypeOptions = ref<SelectOption[]>([
{ value: null, label: t('admin.usage.allBillingTypes') },
{ value: 1, label: t('usage.subscription') },
{ value: 0, label: t('usage.balance') }
])
const emitChange = () => emit('change') const emitChange = () => emit('change')
const updateStartDate = (value: string) => { const updateStartDate = (value: string) => {

View File

@@ -96,12 +96,6 @@
</div> </div>
</template> </template>
<template #cell-billing_type="{ row }">
<span class="inline-flex items-center rounded px-2 py-0.5 text-xs font-medium" :class="row.billing_type === 1 ? 'bg-purple-100 text-purple-800 dark:bg-purple-900 dark:text-purple-200' : 'bg-emerald-100 text-emerald-800 dark:bg-emerald-900 dark:text-emerald-200'">
{{ row.billing_type === 1 ? t('usage.subscription') : t('usage.balance') }}
</span>
</template>
<template #cell-first_token="{ row }"> <template #cell-first_token="{ row }">
<span v-if="row.first_token_ms != null" class="text-sm text-gray-600 dark:text-gray-400">{{ formatDuration(row.first_token_ms) }}</span> <span v-if="row.first_token_ms != null" class="text-sm text-gray-600 dark:text-gray-400">{{ formatDuration(row.first_token_ms) }}</span>
<span v-else class="text-sm text-gray-400 dark:text-gray-500">-</span> <span v-else class="text-sm text-gray-400 dark:text-gray-500">-</span>
@@ -120,6 +114,11 @@
<span v-else class="text-sm text-gray-400 dark:text-gray-500">-</span> <span v-else class="text-sm text-gray-400 dark:text-gray-500">-</span>
</template> </template>
<template #cell-ip_address="{ row }">
<span v-if="row.ip_address" class="text-sm font-mono text-gray-600 dark:text-gray-400">{{ row.ip_address }}</span>
<span v-else class="text-sm text-gray-400 dark:text-gray-500">-</span>
</template>
<template #empty><EmptyState :message="t('usage.noRecords')" /></template> <template #empty><EmptyState :message="t('usage.noRecords')" /></template>
</DataTable> </DataTable>
</div> </div>
@@ -249,11 +248,11 @@ const cols = computed(() => [
{ key: 'stream', label: t('usage.type'), sortable: false }, { key: 'stream', label: t('usage.type'), sortable: false },
{ key: 'tokens', label: t('usage.tokens'), sortable: false }, { key: 'tokens', label: t('usage.tokens'), sortable: false },
{ key: 'cost', label: t('usage.cost'), sortable: false }, { key: 'cost', label: t('usage.cost'), sortable: false },
{ key: 'billing_type', label: t('usage.billingType'), sortable: false },
{ key: 'first_token', label: t('usage.firstToken'), sortable: false }, { key: 'first_token', label: t('usage.firstToken'), sortable: false },
{ key: 'duration', label: t('usage.duration'), sortable: false }, { key: 'duration', label: t('usage.duration'), sortable: false },
{ key: 'created_at', label: t('usage.time'), sortable: true }, { key: 'created_at', label: t('usage.time'), sortable: true },
{ key: 'user_agent', label: t('usage.userAgent'), sortable: false } { key: 'user_agent', label: t('usage.userAgent'), sortable: false },
{ key: 'ip_address', label: t('admin.usage.ipAddress'), sortable: false }
]) ])
const formatCacheTokens = (tokens: number): string => { const formatCacheTokens = (tokens: number): string => {

View File

@@ -370,6 +370,14 @@ export default {
customKeyTooShort: 'Custom key must be at least 16 characters', customKeyTooShort: 'Custom key must be at least 16 characters',
customKeyInvalidChars: 'Custom key can only contain letters, numbers, underscores, and hyphens', customKeyInvalidChars: 'Custom key can only contain letters, numbers, underscores, and hyphens',
customKeyRequired: 'Please enter a custom key', customKeyRequired: 'Please enter a custom key',
ipRestriction: 'IP Restriction',
ipWhitelist: 'IP Whitelist',
ipWhitelistPlaceholder: '192.168.1.100\n10.0.0.0/8',
ipWhitelistHint: 'One IP or CIDR per line. Only these IPs can use this key when set.',
ipBlacklist: 'IP Blacklist',
ipBlacklistPlaceholder: '1.2.3.4\n5.6.0.0/16',
ipBlacklistHint: 'One IP or CIDR per line. These IPs will be blocked from using this key.',
ipRestrictionEnabled: 'IP restriction enabled',
ccSwitchNotInstalled: 'CC-Switch is not installed or the protocol handler is not registered. Please install CC-Switch first or manually copy the API key.', ccSwitchNotInstalled: 'CC-Switch is not installed or the protocol handler is not registered. Please install CC-Switch first or manually copy the API key.',
ccsClientSelect: { ccsClientSelect: {
title: 'Select Client', title: 'Select Client',
@@ -430,9 +438,6 @@ export default {
exportFailed: 'Failed to export usage data', exportFailed: 'Failed to export usage data',
exportExcelSuccess: 'Usage data exported successfully (Excel format)', exportExcelSuccess: 'Usage data exported successfully (Excel format)',
exportExcelFailed: 'Failed to export usage data', exportExcelFailed: 'Failed to export usage data',
billingType: 'Billing',
balance: 'Balance',
subscription: 'Subscription',
imageUnit: ' images', imageUnit: ' images',
userAgent: 'User-Agent' userAgent: 'User-Agent'
}, },
@@ -1735,7 +1740,6 @@ export default {
allAccounts: 'All Accounts', allAccounts: 'All Accounts',
allGroups: 'All Groups', allGroups: 'All Groups',
allTypes: 'All Types', allTypes: 'All Types',
allBillingTypes: 'All Billing',
inputCost: 'Input Cost', inputCost: 'Input Cost',
outputCost: 'Output Cost', outputCost: 'Output Cost',
cacheCreationCost: 'Cache Creation Cost', cacheCreationCost: 'Cache Creation Cost',
@@ -1744,7 +1748,8 @@ export default {
outputTokens: 'Output Tokens', outputTokens: 'Output Tokens',
cacheCreationTokens: 'Cache Creation Tokens', cacheCreationTokens: 'Cache Creation Tokens',
cacheReadTokens: 'Cache Read Tokens', cacheReadTokens: 'Cache Read Tokens',
failedToLoad: 'Failed to load usage records' failedToLoad: 'Failed to load usage records',
ipAddress: 'IP'
}, },
// Settings // Settings

View File

@@ -367,6 +367,14 @@ export default {
customKeyTooShort: '自定义密钥至少需要16个字符', customKeyTooShort: '自定义密钥至少需要16个字符',
customKeyInvalidChars: '自定义密钥只能包含字母、数字、下划线和连字符', customKeyInvalidChars: '自定义密钥只能包含字母、数字、下划线和连字符',
customKeyRequired: '请输入自定义密钥', customKeyRequired: '请输入自定义密钥',
ipRestriction: 'IP 限制',
ipWhitelist: 'IP 白名单',
ipWhitelistPlaceholder: '192.168.1.100\n10.0.0.0/8',
ipWhitelistHint: '每行一个 IP 或 CIDR设置后仅允许这些 IP 使用此密钥',
ipBlacklist: 'IP 黑名单',
ipBlacklistPlaceholder: '1.2.3.4\n5.6.0.0/16',
ipBlacklistHint: '每行一个 IP 或 CIDR这些 IP 将被禁止使用此密钥',
ipRestrictionEnabled: '已配置 IP 限制',
ccSwitchNotInstalled: 'CC-Switch 未安装或协议处理程序未注册。请先安装 CC-Switch 或手动复制 API 密钥。', ccSwitchNotInstalled: 'CC-Switch 未安装或协议处理程序未注册。请先安装 CC-Switch 或手动复制 API 密钥。',
ccsClientSelect: { ccsClientSelect: {
title: '选择客户端', title: '选择客户端',
@@ -427,9 +435,6 @@ export default {
exportFailed: '使用数据导出失败', exportFailed: '使用数据导出失败',
exportExcelSuccess: '使用数据导出成功Excel格式', exportExcelSuccess: '使用数据导出成功Excel格式',
exportExcelFailed: '使用数据导出失败', exportExcelFailed: '使用数据导出失败',
billingType: '消费类型',
balance: '余额',
subscription: '订阅',
imageUnit: '张', imageUnit: '张',
userAgent: 'User-Agent' userAgent: 'User-Agent'
}, },
@@ -1880,7 +1885,6 @@ export default {
allAccounts: '全部账户', allAccounts: '全部账户',
allGroups: '全部分组', allGroups: '全部分组',
allTypes: '全部类型', allTypes: '全部类型',
allBillingTypes: '全部计费',
inputCost: '输入成本', inputCost: '输入成本',
outputCost: '输出成本', outputCost: '输出成本',
cacheCreationCost: '缓存创建成本', cacheCreationCost: '缓存创建成本',
@@ -1889,7 +1893,8 @@ export default {
outputTokens: '输出 Token', outputTokens: '输出 Token',
cacheCreationTokens: '缓存创建 Token', cacheCreationTokens: '缓存创建 Token',
cacheReadTokens: '缓存读取 Token', cacheReadTokens: '缓存读取 Token',
failedToLoad: '加载使用记录失败' failedToLoad: '加载使用记录失败',
ipAddress: 'IP'
}, },
// Settings // Settings

View File

@@ -279,6 +279,8 @@ export interface ApiKey {
name: string name: string
group_id: number | null group_id: number | null
status: 'active' | 'inactive' status: 'active' | 'inactive'
ip_whitelist: string[]
ip_blacklist: string[]
created_at: string created_at: string
updated_at: string updated_at: string
group?: Group group?: Group
@@ -288,12 +290,16 @@ export interface CreateApiKeyRequest {
name: string name: string
group_id?: number | null group_id?: number | null
custom_key?: string // Optional custom API Key custom_key?: string // Optional custom API Key
ip_whitelist?: string[]
ip_blacklist?: string[]
} }
export interface UpdateApiKeyRequest { export interface UpdateApiKeyRequest {
name?: string name?: string
group_id?: number | null group_id?: number | null
status?: 'active' | 'inactive' status?: 'active' | 'inactive'
ip_whitelist?: string[]
ip_blacklist?: string[]
} }
export interface CreateGroupRequest { export interface CreateGroupRequest {
@@ -560,9 +566,6 @@ export interface UpdateProxyRequest {
export type RedeemCodeType = 'balance' | 'concurrency' | 'subscription' export type RedeemCodeType = 'balance' | 'concurrency' | 'subscription'
// 消费类型: 0=钱包余额, 1=订阅套餐
export type BillingType = 0 | 1
export interface UsageLog { export interface UsageLog {
id: number id: number
user_id: number user_id: number
@@ -589,7 +592,6 @@ export interface UsageLog {
actual_cost: number actual_cost: number
rate_multiplier: number rate_multiplier: number
billing_type: BillingType
stream: boolean stream: boolean
duration_ms: number duration_ms: number
first_token_ms: number | null first_token_ms: number | null
@@ -601,6 +603,9 @@ export interface UsageLog {
// User-Agent // User-Agent
user_agent: string | null user_agent: string | null
// IP 地址(仅管理员可见)
ip_address: string | null
created_at: string created_at: string
user?: User user?: User
@@ -830,7 +835,6 @@ export interface UsageQueryParams {
group_id?: number group_id?: number
model?: string model?: string
stream?: boolean stream?: boolean
billing_type?: number
start_date?: string start_date?: string
end_date?: string end_date?: string
} }

View File

@@ -95,8 +95,8 @@ const exportToExcel = async () => {
t('admin.usage.inputCost'), t('admin.usage.outputCost'), t('admin.usage.inputCost'), t('admin.usage.outputCost'),
t('admin.usage.cacheReadCost'), t('admin.usage.cacheCreationCost'), t('admin.usage.cacheReadCost'), t('admin.usage.cacheCreationCost'),
t('usage.rate'), t('usage.original'), t('usage.billed'), t('usage.rate'), t('usage.original'), t('usage.billed'),
t('usage.billingType'), t('usage.firstToken'), t('usage.duration'), t('usage.firstToken'), t('usage.duration'),
t('admin.usage.requestId'), t('usage.userAgent') t('admin.usage.requestId'), t('usage.userAgent'), t('admin.usage.ipAddress')
] ]
const rows = all.map(log => [ const rows = all.map(log => [
log.created_at, log.created_at,
@@ -117,11 +117,11 @@ const exportToExcel = async () => {
log.rate_multiplier?.toFixed(2) || '1.00', log.rate_multiplier?.toFixed(2) || '1.00',
log.total_cost?.toFixed(6) || '0.000000', log.total_cost?.toFixed(6) || '0.000000',
log.actual_cost?.toFixed(6) || '0.000000', log.actual_cost?.toFixed(6) || '0.000000',
log.billing_type === 1 ? t('usage.subscription') : t('usage.balance'),
log.first_token_ms ?? '', log.first_token_ms ?? '',
log.duration_ms, log.duration_ms,
log.request_id || '', log.request_id || '',
log.user_agent || '' log.user_agent || '',
log.ip_address || ''
]) ])
const ws = XLSX.utils.aoa_to_sheet([headers, ...rows]) const ws = XLSX.utils.aoa_to_sheet([headers, ...rows])
const wb = XLSX.utils.book_new() const wb = XLSX.utils.book_new()

View File

@@ -46,8 +46,17 @@
</div> </div>
</template> </template>
<template #cell-name="{ value }"> <template #cell-name="{ value, row }">
<span class="font-medium text-gray-900 dark:text-white">{{ value }}</span> <div class="flex items-center gap-1.5">
<span class="font-medium text-gray-900 dark:text-white">{{ value }}</span>
<Icon
v-if="row.ip_whitelist?.length > 0 || row.ip_blacklist?.length > 0"
name="shield"
size="sm"
class="text-blue-500"
:title="t('keys.ipRestrictionEnabled')"
/>
</div>
</template> </template>
<template #cell-group="{ row }"> <template #cell-group="{ row }">
@@ -278,6 +287,52 @@
:placeholder="t('keys.selectStatus')" :placeholder="t('keys.selectStatus')"
/> />
</div> </div>
<!-- IP Restriction Section -->
<div class="space-y-3">
<div class="flex items-center justify-between">
<label class="input-label mb-0">{{ t('keys.ipRestriction') }}</label>
<button
type="button"
@click="formData.enable_ip_restriction = !formData.enable_ip_restriction"
:class="[
'relative inline-flex h-5 w-9 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none',
formData.enable_ip_restriction ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
>
<span
:class="[
'pointer-events-none inline-block h-4 w-4 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
formData.enable_ip_restriction ? 'translate-x-4' : 'translate-x-0'
]"
/>
</button>
</div>
<div v-if="formData.enable_ip_restriction" class="space-y-4 pt-2">
<div>
<label class="input-label">{{ t('keys.ipWhitelist') }}</label>
<textarea
v-model="formData.ip_whitelist"
rows="3"
class="input font-mono text-sm"
:placeholder="t('keys.ipWhitelistPlaceholder')"
/>
<p class="input-hint">{{ t('keys.ipWhitelistHint') }}</p>
</div>
<div>
<label class="input-label">{{ t('keys.ipBlacklist') }}</label>
<textarea
v-model="formData.ip_blacklist"
rows="3"
class="input font-mono text-sm"
:placeholder="t('keys.ipBlacklistPlaceholder')"
/>
<p class="input-hint">{{ t('keys.ipBlacklistHint') }}</p>
</div>
</div>
</div>
</form> </form>
<template #footer> <template #footer>
<div class="flex justify-end gap-3"> <div class="flex justify-end gap-3">
@@ -528,7 +583,10 @@ const formData = ref({
group_id: null as number | null, group_id: null as number | null,
status: 'active' as 'active' | 'inactive', status: 'active' as 'active' | 'inactive',
use_custom_key: false, use_custom_key: false,
custom_key: '' custom_key: '',
enable_ip_restriction: false,
ip_whitelist: '',
ip_blacklist: ''
}) })
// 自定义Key验证 // 自定义Key验证
@@ -664,12 +722,16 @@ const handlePageSizeChange = (pageSize: number) => {
const editKey = (key: ApiKey) => { const editKey = (key: ApiKey) => {
selectedKey.value = key selectedKey.value = key
const hasIPRestriction = (key.ip_whitelist?.length > 0) || (key.ip_blacklist?.length > 0)
formData.value = { formData.value = {
name: key.name, name: key.name,
group_id: key.group_id, group_id: key.group_id,
status: key.status, status: key.status,
use_custom_key: false, use_custom_key: false,
custom_key: '' custom_key: '',
enable_ip_restriction: hasIPRestriction,
ip_whitelist: (key.ip_whitelist || []).join('\n'),
ip_blacklist: (key.ip_blacklist || []).join('\n')
} }
showEditModal.value = true showEditModal.value = true
} }
@@ -751,14 +813,26 @@ const handleSubmit = async () => {
} }
} }
// Parse IP lists only if IP restriction is enabled
const parseIPList = (text: string): string[] =>
text.split('\n').map(ip => ip.trim()).filter(ip => ip.length > 0)
const ipWhitelist = formData.value.enable_ip_restriction ? parseIPList(formData.value.ip_whitelist) : []
const ipBlacklist = formData.value.enable_ip_restriction ? parseIPList(formData.value.ip_blacklist) : []
submitting.value = true submitting.value = true
try { try {
if (showEditModal.value && selectedKey.value) { if (showEditModal.value && selectedKey.value) {
await keysAPI.update(selectedKey.value.id, formData.value) await keysAPI.update(selectedKey.value.id, {
name: formData.value.name,
group_id: formData.value.group_id,
status: formData.value.status,
ip_whitelist: ipWhitelist,
ip_blacklist: ipBlacklist
})
appStore.showSuccess(t('keys.keyUpdatedSuccess')) appStore.showSuccess(t('keys.keyUpdatedSuccess'))
} else { } else {
const customKey = formData.value.use_custom_key ? formData.value.custom_key : undefined const customKey = formData.value.use_custom_key ? formData.value.custom_key : undefined
await keysAPI.create(formData.value.name, formData.value.group_id, customKey) await keysAPI.create(formData.value.name, formData.value.group_id, customKey, ipWhitelist, ipBlacklist)
appStore.showSuccess(t('keys.keyCreatedSuccess')) appStore.showSuccess(t('keys.keyCreatedSuccess'))
// Only advance tour if active, on submit step, and creation succeeded // Only advance tour if active, on submit step, and creation succeeded
if (onboardingStore.isCurrentStep('[data-tour="key-form-submit"]')) { if (onboardingStore.isCurrentStep('[data-tour="key-form-submit"]')) {
@@ -805,7 +879,10 @@ const closeModals = () => {
group_id: null, group_id: null,
status: 'active', status: 'active',
use_custom_key: false, use_custom_key: false,
custom_key: '' custom_key: '',
enable_ip_restriction: false,
ip_whitelist: '',
ip_blacklist: ''
} }
} }

View File

@@ -273,19 +273,6 @@
</div> </div>
</template> </template>
<template #cell-billing_type="{ row }">
<span
class="inline-flex items-center rounded px-2 py-0.5 text-xs font-medium"
:class="
row.billing_type === 1
? 'bg-purple-100 text-purple-800 dark:bg-purple-900 dark:text-purple-200'
: 'bg-emerald-100 text-emerald-800 dark:bg-emerald-900 dark:text-emerald-200'
"
>
{{ row.billing_type === 1 ? t('usage.subscription') : t('usage.balance') }}
</span>
</template>
<template #cell-first_token="{ row }"> <template #cell-first_token="{ row }">
<span <span
v-if="row.first_token_ms != null" v-if="row.first_token_ms != null"
@@ -482,7 +469,6 @@ const columns = computed<Column[]>(() => [
{ key: 'stream', label: t('usage.type'), sortable: false }, { key: 'stream', label: t('usage.type'), sortable: false },
{ key: 'tokens', label: t('usage.tokens'), sortable: false }, { key: 'tokens', label: t('usage.tokens'), sortable: false },
{ key: 'cost', label: t('usage.cost'), sortable: false }, { key: 'cost', label: t('usage.cost'), sortable: false },
{ key: 'billing_type', label: t('usage.billingType'), sortable: false },
{ key: 'first_token', label: t('usage.firstToken'), sortable: false }, { key: 'first_token', label: t('usage.firstToken'), sortable: false },
{ key: 'duration', label: t('usage.duration'), sortable: false }, { key: 'duration', label: t('usage.duration'), sortable: false },
{ key: 'created_at', label: t('usage.time'), sortable: true }, { key: 'created_at', label: t('usage.time'), sortable: true },
@@ -745,7 +731,6 @@ const exportToCSV = async () => {
'Rate Multiplier', 'Rate Multiplier',
'Billed Cost', 'Billed Cost',
'Original Cost', 'Original Cost',
'Billing Type',
'First Token (ms)', 'First Token (ms)',
'Duration (ms)' 'Duration (ms)'
] ]
@@ -762,7 +747,6 @@ const exportToCSV = async () => {
log.rate_multiplier, log.rate_multiplier,
log.actual_cost.toFixed(8), log.actual_cost.toFixed(8),
log.total_cost.toFixed(8), log.total_cost.toFixed(8),
log.billing_type === 1 ? 'Subscription' : 'Balance',
log.first_token_ms ?? '', log.first_token_ms ?? '',
log.duration_ms log.duration_ms
].map(escapeCSVValue) ].map(escapeCSVValue)