package middleware import ( "context" "database/sql" "log" "github.com/gofiber/fiber/v2" "github.com/gooseek/backend/internal/usage" "github.com/gooseek/backend/pkg/email" "github.com/gooseek/backend/pkg/metrics" ) type LLMLimitsConfig struct { UsageRepo *usage.Repository EmailSender *email.Sender } func LLMLimits(config LLMLimitsConfig) fiber.Handler { return func(c *fiber.Ctx) error { userID := GetUserID(c) clientIP := c.IP() if userID == "" { metrics.RecordLLMUnauthorized("no_user_id", clientIP) metrics.RecordSecurityEvent("unauthorized_llm_access", clientIP, "anonymous") return c.Status(401).JSON(fiber.Map{ "error": "Authentication required", }) } tier := GetUserTier(c) if tier == "" { tier = "free" } if config.UsageRepo != nil { allowed, reason := config.UsageRepo.CheckLLMLimits(c.Context(), userID, tier) if !allowed { limits := usage.GetLimits(tier) if tier == "free" { metrics.RecordFreeTierLimitExceeded(userID, reason) metrics.RecordSecurityEvent("free_tier_limit_exceeded", clientIP, userID) } metrics.RecordRateLimitHit("llm-svc", clientIP, reason) sendLimitEmail(config, userID, tier, limits.LLMRequestsPerDay, limits.LLMRequestsPerDay) return c.Status(429).JSON(fiber.Map{ "error": reason, "tier": tier, "dailyLimit": limits.LLMRequestsPerDay, "tokenLimit": limits.LLMTokensPerDay, "upgradeUrl": "/settings/billing", }) } checkLimitWarning(config, userID, tier) } return c.Next() } } func checkLimitWarning(config LLMLimitsConfig, userID, tier string) { if config.UsageRepo == nil || config.EmailSender == nil || !config.EmailSender.IsConfigured() { return } go func() { todayUsage, err := config.UsageRepo.GetTodayUsage(context.Background(), userID) if err != nil || todayUsage == nil { return } limits := usage.GetLimits(tier) if limits.LLMRequestsPerDay == 0 { return } percentage := (todayUsage.LLMRequests * 100) / limits.LLMRequestsPerDay if percentage >= 80 && percentage < 100 { userEmail := getUserEmail(config.UsageRepo, userID) if userEmail != "" { if err := config.EmailSender.SendLimitWarning(userEmail, "", todayUsage.LLMRequests, limits.LLMRequestsPerDay, tier); err != nil { log.Printf("[email] Limit warning send error: %v", err) } } } }() } func sendLimitEmail(config LLMLimitsConfig, userID, tier string, usageCount, limitCount int) { if config.EmailSender == nil || !config.EmailSender.IsConfigured() { return } go func() { userEmail := getUserEmail(config.UsageRepo, userID) if userEmail != "" { if err := config.EmailSender.SendLimitWarning(userEmail, "", usageCount, limitCount, tier); err != nil { log.Printf("[email] Limit exceeded email error: %v", err) } } }() } func getUserEmail(repo *usage.Repository, userID string) string { if repo == nil { return "" } emailAddr, err := repo.GetUserEmail(context.Background(), userID) if err != nil { return "" } return emailAddr } type UsageTracker struct { repo *usage.Repository } func NewUsageTracker(db *sql.DB) *UsageTracker { return &UsageTracker{ repo: usage.NewRepository(db), } } func (t *UsageTracker) RunMigrations(ctx context.Context) error { return t.repo.RunMigrations(ctx) } func (t *UsageTracker) TrackAPIRequest(c *fiber.Ctx) { userID := GetUserID(c) if userID == "" { return } tier := GetUserTier(c) if tier == "" { tier = "free" } go t.repo.IncrementAPIRequests(context.Background(), userID, tier) } func (t *UsageTracker) TrackLLMRequest(ctx context.Context, userID, tier string, tokens int) { if userID == "" { return } if tier == "" { tier = "free" } go t.repo.IncrementLLMUsage(ctx, userID, tier, tokens) } func (t *UsageTracker) TrackSearchRequest(c *fiber.Ctx) { userID := GetUserID(c) if userID == "" { return } tier := GetUserTier(c) if tier == "" { tier = "free" } go t.repo.IncrementSearchRequests(context.Background(), userID, tier) } func (t *UsageTracker) GetRepository() *usage.Repository { return t.repo } func UsageTracking(tracker *UsageTracker) fiber.Handler { return func(c *fiber.Ctx) error { if tracker != nil { tracker.TrackAPIRequest(c) } return c.Next() } }