package middleware import ( "sync" "time" "github.com/gofiber/fiber/v2" ) type TierConfig struct { Max int Window time.Duration } type TieredRateLimitConfig struct { Tiers map[string]TierConfig DefaultTier string KeyFunc func(*fiber.Ctx) string GetTierFunc func(*fiber.Ctx) string } type tieredRateLimiter struct { requests map[string][]time.Time mu sync.RWMutex tiers map[string]TierConfig } func newTieredRateLimiter(tiers map[string]TierConfig) *tieredRateLimiter { rl := &tieredRateLimiter{ requests: make(map[string][]time.Time), tiers: tiers, } go rl.cleanup() return rl } func (rl *tieredRateLimiter) cleanup() { ticker := time.NewTicker(time.Minute) for range ticker.C { rl.mu.Lock() now := time.Now() for key, times := range rl.requests { var valid []time.Time for _, t := range times { if now.Sub(t) < 5*time.Minute { valid = append(valid, t) } } if len(valid) == 0 { delete(rl.requests, key) } else { rl.requests[key] = valid } } rl.mu.Unlock() } } func (rl *tieredRateLimiter) allow(key string, tier string) (bool, int, int) { rl.mu.Lock() defer rl.mu.Unlock() cfg, ok := rl.tiers[tier] if !ok { cfg = rl.tiers["free"] } now := time.Now() windowStart := now.Add(-cfg.Window) times := rl.requests[key] var valid []time.Time for _, t := range times { if t.After(windowStart) { valid = append(valid, t) } } remaining := cfg.Max - len(valid) if remaining <= 0 { rl.requests[key] = valid return false, 0, cfg.Max } rl.requests[key] = append(valid, now) return true, remaining - 1, cfg.Max } func TieredRateLimit(config TieredRateLimitConfig) fiber.Handler { if config.Tiers == nil { config.Tiers = map[string]TierConfig{ "free": {Max: 60, Window: time.Minute}, "pro": {Max: 300, Window: time.Minute}, "business": {Max: 1000, Window: time.Minute}, } } if config.DefaultTier == "" { config.DefaultTier = "free" } if config.KeyFunc == nil { config.KeyFunc = func(c *fiber.Ctx) string { userID := GetUserID(c) if userID != "" { return "user:" + userID } return "ip:" + c.IP() } } if config.GetTierFunc == nil { config.GetTierFunc = func(c *fiber.Ctx) string { tier := GetUserTier(c) if tier == "" { return config.DefaultTier } return tier } } limiter := newTieredRateLimiter(config.Tiers) return func(c *fiber.Ctx) error { key := config.KeyFunc(c) tier := config.GetTierFunc(c) allowed, remaining, limit := limiter.allow(key, tier) c.Set("X-RateLimit-Limit", formatInt(limit)) c.Set("X-RateLimit-Remaining", formatInt(remaining)) c.Set("X-RateLimit-Tier", tier) if !allowed { c.Set("Retry-After", "60") return c.Status(429).JSON(fiber.Map{ "error": "Rate limit exceeded", "tier": tier, "limit": limit, "retryAfter": 60, }) } return c.Next() } } func formatInt(n int) string { if n < 0 { n = 0 } s := "" if n == 0 { return "0" } for n > 0 { s = string(rune('0'+n%10)) + s n /= 10 } return s }