package middleware import ( "context" "fmt" "strconv" "time" "github.com/gofiber/fiber/v2" "github.com/redis/go-redis/v9" ) type RedisRateLimiterConfig struct { RedisClient *redis.Client KeyPrefix string Max int Window time.Duration KeyFunc func(*fiber.Ctx) string SkipPaths []string } func RedisRateLimit(cfg RedisRateLimiterConfig) fiber.Handler { if cfg.KeyPrefix == "" { cfg.KeyPrefix = "ratelimit" } if cfg.Max == 0 { cfg.Max = 100 } if cfg.Window == 0 { cfg.Window = time.Minute } if cfg.KeyFunc == nil { cfg.KeyFunc = func(c *fiber.Ctx) string { return c.IP() } } skipMap := make(map[string]bool) for _, path := range cfg.SkipPaths { skipMap[path] = true } return func(c *fiber.Ctx) error { if skipMap[c.Path()] { return c.Next() } ctx := context.Background() key := fmt.Sprintf("%s:%s", cfg.KeyPrefix, cfg.KeyFunc(c)) pipe := cfg.RedisClient.Pipeline() incr := pipe.Incr(ctx, key) pipe.Expire(ctx, key, cfg.Window) _, err := pipe.Exec(ctx) if err != nil { return c.Next() } current := incr.Val() c.Set("X-RateLimit-Limit", strconv.Itoa(cfg.Max)) c.Set("X-RateLimit-Remaining", strconv.Itoa(max(0, cfg.Max-int(current)))) ttl, _ := cfg.RedisClient.TTL(ctx, key).Result() c.Set("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(ttl).Unix(), 10)) if int(current) > cfg.Max { c.Set("Retry-After", strconv.FormatInt(int64(ttl.Seconds()), 10)) return c.Status(429).JSON(fiber.Map{ "error": "Too Many Requests", "retry_after": int64(ttl.Seconds()), }) } return c.Next() } } type SlidingWindowConfig struct { RedisClient *redis.Client KeyPrefix string Max int Window time.Duration KeyFunc func(*fiber.Ctx) string } func SlidingWindowRateLimit(cfg SlidingWindowConfig) fiber.Handler { if cfg.KeyPrefix == "" { cfg.KeyPrefix = "ratelimit:sliding" } if cfg.Max == 0 { cfg.Max = 100 } if cfg.Window == 0 { cfg.Window = time.Minute } if cfg.KeyFunc == nil { cfg.KeyFunc = func(c *fiber.Ctx) string { return c.IP() } } return func(c *fiber.Ctx) error { ctx := context.Background() key := fmt.Sprintf("%s:%s", cfg.KeyPrefix, cfg.KeyFunc(c)) now := time.Now() windowStart := now.Add(-cfg.Window).UnixMicro() pipe := cfg.RedisClient.Pipeline() pipe.ZRemRangeByScore(ctx, key, "0", strconv.FormatInt(windowStart, 10)) pipe.ZAdd(ctx, key, redis.Z{ Score: float64(now.UnixMicro()), Member: fmt.Sprintf("%d", now.UnixNano()), }) countCmd := pipe.ZCard(ctx, key) pipe.Expire(ctx, key, cfg.Window) _, err := pipe.Exec(ctx) if err != nil { return c.Next() } count := countCmd.Val() c.Set("X-RateLimit-Limit", strconv.Itoa(cfg.Max)) c.Set("X-RateLimit-Remaining", strconv.Itoa(max(0, cfg.Max-int(count)))) if int(count) > cfg.Max { return c.Status(429).JSON(fiber.Map{ "error": "Too Many Requests", "retry_after": int64(cfg.Window.Seconds()), }) } return c.Next() } } type TieredRateLimitConfig struct { RedisClient *redis.Client KeyPrefix string Tiers map[string]TierConfig DefaultTier string GetTierFunc func(*fiber.Ctx) string KeyFunc func(*fiber.Ctx) string } type TierConfig struct { Max int Window time.Duration } func TieredRateLimit(cfg TieredRateLimitConfig) fiber.Handler { if cfg.KeyPrefix == "" { cfg.KeyPrefix = "ratelimit:tiered" } if cfg.DefaultTier == "" { cfg.DefaultTier = "free" } if cfg.GetTierFunc == nil { cfg.GetTierFunc = func(c *fiber.Ctx) string { tier := GetUserTier(c) if tier == "" { return cfg.DefaultTier } return tier } } if cfg.KeyFunc == nil { cfg.KeyFunc = func(c *fiber.Ctx) string { userID := GetUserID(c) if userID != "" { return "user:" + userID } return "ip:" + c.IP() } } defaultTierCfg := TierConfig{Max: 60, Window: time.Minute} if _, ok := cfg.Tiers[cfg.DefaultTier]; !ok { cfg.Tiers[cfg.DefaultTier] = defaultTierCfg } return func(c *fiber.Ctx) error { ctx := context.Background() tier := cfg.GetTierFunc(c) tierCfg, ok := cfg.Tiers[tier] if !ok { tierCfg = cfg.Tiers[cfg.DefaultTier] } key := fmt.Sprintf("%s:%s:%s", cfg.KeyPrefix, tier, cfg.KeyFunc(c)) pipe := cfg.RedisClient.Pipeline() incr := pipe.Incr(ctx, key) pipe.Expire(ctx, key, tierCfg.Window) _, err := pipe.Exec(ctx) if err != nil { return c.Next() } current := incr.Val() c.Set("X-RateLimit-Tier", tier) c.Set("X-RateLimit-Limit", strconv.Itoa(tierCfg.Max)) c.Set("X-RateLimit-Remaining", strconv.Itoa(max(0, tierCfg.Max-int(current)))) if int(current) > tierCfg.Max { return c.Status(429).JSON(fiber.Map{ "error": "Too Many Requests", "tier": tier, "limit": tierCfg.Max, }) } return c.Next() } } func max(a, b int) int { if a > b { return a } return b }