package middleware import ( "context" "database/sql" "github.com/gofiber/fiber/v2" "github.com/gooseek/backend/internal/usage" "github.com/gooseek/backend/pkg/metrics" ) type LLMLimitsConfig struct { UsageRepo *usage.Repository } func LLMLimits(config LLMLimitsConfig) fiber.Handler { return func(c *fiber.Ctx) error { userID := GetUserID(c) clientIP := c.IP() if userID == "" { metrics.RecordLLMUnauthorized("no_user_id", clientIP) metrics.RecordSecurityEvent("unauthorized_llm_access", clientIP, "anonymous") return c.Status(401).JSON(fiber.Map{ "error": "Authentication required", }) } tier := GetUserTier(c) if tier == "" { tier = "free" } if config.UsageRepo != nil { allowed, reason := config.UsageRepo.CheckLLMLimits(c.Context(), userID, tier) if !allowed { limits := usage.GetLimits(tier) if tier == "free" { metrics.RecordFreeTierLimitExceeded(userID, reason) metrics.RecordSecurityEvent("free_tier_limit_exceeded", clientIP, userID) } metrics.RecordRateLimitHit("llm-svc", clientIP, reason) return c.Status(429).JSON(fiber.Map{ "error": reason, "tier": tier, "dailyLimit": limits.LLMRequestsPerDay, "tokenLimit": limits.LLMTokensPerDay, "upgradeUrl": "/settings/billing", }) } } return c.Next() } } type UsageTracker struct { repo *usage.Repository } func NewUsageTracker(db *sql.DB) *UsageTracker { return &UsageTracker{ repo: usage.NewRepository(db), } } func (t *UsageTracker) RunMigrations(ctx context.Context) error { return t.repo.RunMigrations(ctx) } func (t *UsageTracker) TrackAPIRequest(c *fiber.Ctx) { userID := GetUserID(c) if userID == "" { return } tier := GetUserTier(c) if tier == "" { tier = "free" } go t.repo.IncrementAPIRequests(context.Background(), userID, tier) } func (t *UsageTracker) TrackLLMRequest(ctx context.Context, userID, tier string, tokens int) { if userID == "" { return } if tier == "" { tier = "free" } go t.repo.IncrementLLMUsage(ctx, userID, tier, tokens) } func (t *UsageTracker) TrackSearchRequest(c *fiber.Ctx) { userID := GetUserID(c) if userID == "" { return } tier := GetUserTier(c) if tier == "" { tier = "free" } go t.repo.IncrementSearchRequests(context.Background(), userID, tier) } func (t *UsageTracker) GetRepository() *usage.Repository { return t.repo } func UsageTracking(tracker *UsageTracker) fiber.Handler { return func(c *fiber.Ctx) error { if tracker != nil { tracker.TrackAPIRequest(c) } return c.Next() } }