package middleware

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/gofiber/fiber/v2"
	"github.com/golang-jwt/jwt/v5"
)

// JWTConfig configures the middleware returned by JWT.
type JWTConfig struct {
	// Secret is the HMAC key used to verify tokens locally. When non-empty it
	// takes precedence over AuthSvcURL.
	Secret string

	// AuthSvcURL is the base URL of a remote auth service. Tokens are sent to
	// <AuthSvcURL>/api/v1/auth/validate. It is only used when Secret is empty.
	AuthSvcURL string

	// SkipPaths lists exact request paths that bypass authentication.
	SkipPaths []string

	// AllowGuest lets requests without a token continue as a guest user
	// instead of being rejected with 401.
	AllowGuest bool

	// CacheDuration defaults to 5 minutes when zero.
	// NOTE(review): nothing in this file reads it; validation results are
	// not cached.
	CacheDuration time.Duration
}

// UserClaims describes the caller. JWT stores it in the request locals under
// string(UserContextKey), and GetUser retrieves it.
type UserClaims struct {
	UserID    string `json:"userId"`
	Email     string `json:"email"`
	Role      string `json:"role"`
	Tier      string `json:"tier"`
	IsGuest   bool   `json:"isGuest"`
	ExpiresAt int64  `json:"exp"` // copied from the token's "exp" claim on the local path
}

type contextKey string

// UserContextKey is the key JWT stores the *UserClaims under. Values are read
// and written with the key string(UserContextKey).
const UserContextKey contextKey = "user"

const (
	// defaultCacheDuration is applied to JWTConfig.CacheDuration when unset.
	defaultCacheDuration = 5 * time.Minute

	// authValidatePath is appended to JWTConfig.AuthSvcURL.
	authValidatePath = "/api/v1/auth/validate"

	// defaultRole and defaultTier are assigned when a validated token
	// carries no (or an empty) role/tier.
	defaultRole = "user"
	defaultTier = "free"

	// maxDrainBytes bounds how much of an error response body is discarded
	// before closing it.
	maxDrainBytes = 4 << 10
)

var (
	// errAuthServiceUnavailable marks failures caused by the remote auth
	// service itself, not by the token. This covers request construction,
	// transport errors, 5xx responses and undecodable bodies.
	errAuthServiceUnavailable = errors.New("auth service unavailable")

	// authHTTPClient is shared by all validation calls instead of being
	// allocated per request. The timeout bounds each call.
	authHTTPClient = &http.Client{Timeout: 5 * time.Second}
)

// JWT returns a Fiber handler that authenticates requests with a token
// obtained from ExtractToken.
//
// Requests whose path is in cfg.SkipPaths, or starts with "/health" or
// "/ready", pass through untouched. Other requests are validated as follows:
//   - with cfg.Secret, the token is checked locally;
//   - otherwise, with cfg.AuthSvcURL, it is checked remotely.
//
// On success the *UserClaims is stored under string(UserContextKey). The user
// ID, role and tier are also stored under "userId", "userRole" and "userTier".
//
// Failure responses:
//   - 401 if the token is missing (unless cfg.AllowGuest) or invalid;
//   - 503 if the auth service fails or misbehaves;
//   - 500 if neither Secret nor AuthSvcURL is configured.
func JWT(cfg JWTConfig) fiber.Handler {
	skipMap := make(map[string]bool, len(cfg.SkipPaths))
	for _, path := range cfg.SkipPaths {
		skipMap[path] = true
	}
	if cfg.CacheDuration == 0 {
		cfg.CacheDuration = defaultCacheDuration
	}

	return func(c *fiber.Ctx) error {
		path := c.Path()
		if skipMap[path] {
			return c.Next()
		}
		// NOTE(review): prefix matching also skips auth for any route that
		// merely begins with these strings (e.g. "/healthcare/...",
		// "/ready-orders"). Consider exact matches or a "/" boundary.
		if strings.HasPrefix(path, "/health") || strings.HasPrefix(path, "/ready") {
			return c.Next()
		}

		token := ExtractToken(c)
		if token == "" {
			if cfg.AllowGuest {
				c.Locals(string(UserContextKey), &UserClaims{
					IsGuest: true,
					Role:    "guest",
					Tier:    "free",
				})
				return c.Next()
			}
			return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
				"error":   "Unauthorized",
				"message": "Missing authorization token",
			})
		}

		var (
			claims *UserClaims
			err    error
		)
		switch {
		case cfg.Secret != "":
			claims, err = validateLocalJWT(token, cfg.Secret)
		case cfg.AuthSvcURL != "":
			claims, err = validateWithAuthService(c.Context(), token, cfg.AuthSvcURL)
		default:
			return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
				"error":   "Configuration Error",
				"message": "JWT validation not configured",
			})
		}

		if errors.Is(err, errAuthServiceUnavailable) {
			// An infrastructure failure is not a bad token. Report it as 503
			// and do not echo internal hostnames or network errors to the caller.
			return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{
				"error":   "Service Unavailable",
				"message": "Authentication service unavailable",
			})
		}
		if err != nil {
			return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
				"error":   "Unauthorized",
				"message": err.Error(),
			})
		}

		c.Locals(string(UserContextKey), claims)
		c.Locals("userId", claims.UserID)
		c.Locals("userRole", claims.Role)
		c.Locals("userTier", claims.Tier)
		return c.Next()
	}
}

// validateLocalJWT verifies tokenString as an HMAC-signed JWT using secret and
// maps its claims onto UserClaims.
//
// The user ID is taken from "userId", falling back to "sub" when that is
// missing or empty. Role and tier default to "user" and "free". Any error from
// jwt.Parse is wrapped as "invalid token: ..."; this includes the keyfunc's
// rejection of non-HMAC algorithms.
func validateLocalJWT(tokenString, secret string) (*UserClaims, error) {
	token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
		// Only HMAC methods are accepted. Anything else is rejected before
		// the shared secret is handed back as a verification key.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return []byte(secret), nil
	})
	if err != nil {
		return nil, fmt.Errorf("invalid token: %w", err)
	}
	if !token.Valid {
		return nil, fmt.Errorf("token is not valid")
	}

	mapClaims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return nil, fmt.Errorf("invalid claims format")
	}

	claims := &UserClaims{
		UserID: stringClaim(mapClaims, "userId"),
		Email:  stringClaim(mapClaims, "email"),
		Role:   stringClaim(mapClaims, "role"),
		Tier:   stringClaim(mapClaims, "tier"),
	}
	if claims.UserID == "" {
		claims.UserID = stringClaim(mapClaims, "sub")
	}
	if exp, ok := mapClaims["exp"].(float64); ok {
		claims.ExpiresAt = int64(exp)
	}
	applyClaimDefaults(claims)
	return claims, nil
}

// stringClaim returns the claim named key when it is a string, else "".
func stringClaim(m jwt.MapClaims, key string) string {
	s, _ := m[key].(string)
	return s
}

// applyClaimDefaults fills in an empty Role or Tier. Both validation paths
// call it, so they produce the same defaults.
func applyClaimDefaults(u *UserClaims) {
	if u.Role == "" {
		u.Role = defaultRole
	}
	if u.Tier == "" {
		u.Tier = defaultTier
	}
}

// validateWithAuthService validates token by sending it as a bearer token in
// GET <authURL>/api/v1/auth/validate. It returns the "user" object from the
// JSON response. ctx bounds the call in addition to the shared client's
// timeout.
//
// Errors caused by the service itself wrap errAuthServiceUnavailable: request
// construction, transport errors, 5xx statuses and undecodable bodies. Other
// non-200 statuses, and "valid": false, are reported as token errors.
func validateWithAuthService(ctx context.Context, token, authURL string) (*UserClaims, error) {
	reqURL := strings.TrimSuffix(authURL, "/") + authValidatePath
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil)
	if err != nil {
		return nil, fmt.Errorf("%w: building request: %v", errAuthServiceUnavailable, err)
	}
	req.Header.Set("Authorization", "Bearer "+token)

	resp, err := authHTTPClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("%w: %v", errAuthServiceUnavailable, err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best-effort drain so the connection can be reused. The body
		// content is irrelevant here, so the copy error is ignored.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxDrainBytes))
		if resp.StatusCode >= http.StatusInternalServerError {
			return nil, fmt.Errorf("%w: status %d", errAuthServiceUnavailable, resp.StatusCode)
		}
		return nil, fmt.Errorf("token validation failed: status %d", resp.StatusCode)
	}

	var result struct {
		Valid bool       `json:"valid"`
		User  UserClaims `json:"user"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("%w: decoding response: %v", errAuthServiceUnavailable, err)
	}
	if !result.Valid {
		return nil, fmt.Errorf("token is not valid")
	}

	user := result.User
	applyClaimDefaults(&user)
	return &user, nil
}

// GetUser returns the *UserClaims stored by JWT, or nil if the request locals
// hold no such value.
func GetUser(c *fiber.Ctx) *UserClaims {
	user, _ := c.Locals(string(UserContextKey)).(*UserClaims)
	return user
}

// GetUserID returns the current user's ID, or "" when there is no user.
func GetUserID(c *fiber.Ctx) string {
	user := GetUser(c)
	if user == nil {
		return ""
	}
	return user.UserID
}

// GetUserTier returns the current user's tier. It returns "free" when there is
// no user or the tier is empty.
func GetUserTier(c *fiber.Ctx) string {
	user := GetUser(c)
	if user == nil || user.Tier == "" {
		return defaultTier
	}
	return user.Tier
}

// authRequired writes the standard 401 response used by the Require* handlers.
func authRequired(c *fiber.Ctx) error {
	return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
		"error":   "Unauthorized",
		"message": "Authentication required",
	})
}

// RequireAuth returns a handler that responds 401 unless the request carries
// a non-guest user. It must run after JWT.
func RequireAuth() fiber.Handler {
	return func(c *fiber.Ctx) error {
		user := GetUser(c)
		if user == nil || user.IsGuest {
			return authRequired(c)
		}
		return c.Next()
	}
}

// RequireRole returns a handler that responds 401 when there is no user and
// 403 when the user's role is not one of roles. Guests are not rejected by
// this check alone; their role ("guest") is compared like any other.
func RequireRole(roles ...string) fiber.Handler {
	roleMap := make(map[string]bool, len(roles))
	for _, r := range roles {
		roleMap[r] = true
	}
	return func(c *fiber.Ctx) error {
		user := GetUser(c)
		if user == nil {
			return authRequired(c)
		}
		if !roleMap[user.Role] {
			return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
				"error":   "Forbidden",
				"message": "Insufficient permissions",
			})
		}
		return c.Next()
	}
}

// RequireTier returns a handler that responds 401 when there is no user and
// 403 when the user's tier is not one of tiers. The 403 body reports the
// current and required tiers.
func RequireTier(tiers ...string) fiber.Handler {
	tierMap := make(map[string]bool, len(tiers))
	for _, t := range tiers {
		tierMap[t] = true
	}
	return func(c *fiber.Ctx) error {
		user := GetUser(c)
		if user == nil {
			return authRequired(c)
		}
		if !tierMap[user.Tier] {
			return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
				"error":    "Forbidden",
				"message":  "This feature requires a higher tier subscription",
				"current":  user.Tier,
				"required": tiers,
			})
		}
		return c.Next()
	}
}