feat: Go backend, enhanced search, new widgets, Docker deploy
Major changes: - Add Go backend (backend/) with microservices architecture - Enhanced master-agents-svc: reranker, content-classifier, stealth-crawler, proxy-manager, media-search, fastClassifier, language detection - New web-svc widgets: KnowledgeCard, ProductCard, ProfileCard, VideoCard, UnifiedCard, CardGallery, InlineImageGallery, SourcesPanel, RelatedQuestions - Improved discover-svc with discover-db integration - Docker deployment improvements (Caddyfile, vendor.sh, BUILD.md) - Library-svc: project_id schema migration - Remove deprecated finance-svc and travel-svc - Localization improvements across services Made-with: Cursor
This commit is contained in:
183
backend/pkg/cache/redis.go
vendored
Normal file
183
backend/pkg/cache/redis.go
vendored
Normal file
@@ -0,0 +1,183 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
type RedisCache struct {
|
||||
client *redis.Client
|
||||
prefix string
|
||||
}
|
||||
|
||||
func NewRedisCache(redisURL, prefix string) (*RedisCache, error) {
|
||||
opts, err := redis.ParseURL(redisURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := redis.NewClient(opts)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &RedisCache{
|
||||
client: client,
|
||||
prefix: prefix,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *RedisCache) Close() error {
|
||||
return c.client.Close()
|
||||
}
|
||||
|
||||
func (c *RedisCache) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.client.Set(ctx, c.prefix+":"+key, data, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *RedisCache) Get(ctx context.Context, key string, dest interface{}) error {
|
||||
data, err := c.client.Get(ctx, c.prefix+":"+key).Bytes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(data, dest)
|
||||
}
|
||||
|
||||
func (c *RedisCache) Exists(ctx context.Context, key string) (bool, error) {
|
||||
n, err := c.client.Exists(ctx, c.prefix+":"+key).Result()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return n > 0, nil
|
||||
}
|
||||
|
||||
func (c *RedisCache) Delete(ctx context.Context, key string) error {
|
||||
return c.client.Del(ctx, c.prefix+":"+key).Err()
|
||||
}
|
||||
|
||||
func (c *RedisCache) SetJSON(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
|
||||
return c.Set(ctx, key, value, ttl)
|
||||
}
|
||||
|
||||
func (c *RedisCache) GetJSON(ctx context.Context, key string, dest interface{}) error {
|
||||
return c.Get(ctx, key, dest)
|
||||
}
|
||||
|
||||
type CacheKey string
|
||||
|
||||
const (
|
||||
KeySearchResults CacheKey = "search"
|
||||
KeyArticleSummary CacheKey = "summary"
|
||||
KeyDigest CacheKey = "digest"
|
||||
KeyChatResponse CacheKey = "chat"
|
||||
)
|
||||
|
||||
func HashKey(parts ...string) string {
|
||||
combined := ""
|
||||
for _, p := range parts {
|
||||
combined += p + ":"
|
||||
}
|
||||
hash := sha256.Sum256([]byte(combined))
|
||||
return hex.EncodeToString(hash[:16])
|
||||
}
|
||||
|
||||
func (c *RedisCache) CacheSearch(ctx context.Context, query string, results interface{}, ttl time.Duration) error {
|
||||
key := string(KeySearchResults) + ":" + HashKey(query)
|
||||
return c.Set(ctx, key, results, ttl)
|
||||
}
|
||||
|
||||
func (c *RedisCache) GetCachedSearch(ctx context.Context, query string, dest interface{}) error {
|
||||
key := string(KeySearchResults) + ":" + HashKey(query)
|
||||
return c.Get(ctx, key, dest)
|
||||
}
|
||||
|
||||
func (c *RedisCache) CacheArticleSummary(ctx context.Context, url string, events []string, ttl time.Duration) error {
|
||||
key := string(KeyArticleSummary) + ":" + HashKey(url)
|
||||
return c.Set(ctx, key, events, ttl)
|
||||
}
|
||||
|
||||
func (c *RedisCache) GetCachedArticleSummary(ctx context.Context, url string) ([]string, error) {
|
||||
key := string(KeyArticleSummary) + ":" + HashKey(url)
|
||||
var events []string
|
||||
if err := c.Get(ctx, key, &events); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return events, nil
|
||||
}
|
||||
|
||||
func (c *RedisCache) CacheDigest(ctx context.Context, topic, region, title string, digest interface{}, ttl time.Duration) error {
|
||||
key := string(KeyDigest) + ":" + HashKey(topic, region, title)
|
||||
return c.Set(ctx, key, digest, ttl)
|
||||
}
|
||||
|
||||
func (c *RedisCache) GetCachedDigest(ctx context.Context, topic, region, title string, dest interface{}) error {
|
||||
key := string(KeyDigest) + ":" + HashKey(topic, region, title)
|
||||
return c.Get(ctx, key, dest)
|
||||
}
|
||||
|
||||
type MemoryCache struct {
|
||||
data map[string]cacheEntry
|
||||
}
|
||||
|
||||
type cacheEntry struct {
|
||||
value interface{}
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
func NewMemoryCache() *MemoryCache {
|
||||
return &MemoryCache{
|
||||
data: make(map[string]cacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MemoryCache) Set(key string, value interface{}, ttl time.Duration) {
|
||||
c.data[key] = cacheEntry{
|
||||
value: value,
|
||||
expiresAt: time.Now().Add(ttl),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MemoryCache) Get(key string) (interface{}, bool) {
|
||||
entry, ok := c.data[key]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if time.Now().After(entry.expiresAt) {
|
||||
delete(c.data, key)
|
||||
return nil, false
|
||||
}
|
||||
return entry.value, true
|
||||
}
|
||||
|
||||
func (c *MemoryCache) Delete(key string) {
|
||||
delete(c.data, key)
|
||||
}
|
||||
|
||||
func (c *MemoryCache) Clear() {
|
||||
c.data = make(map[string]cacheEntry)
|
||||
}
|
||||
|
||||
func (c *MemoryCache) Cleanup() int {
|
||||
count := 0
|
||||
now := time.Now()
|
||||
for k, v := range c.data {
|
||||
if now.After(v.expiresAt) {
|
||||
delete(c.data, k)
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
173
backend/pkg/config/config.go
Normal file
173
backend/pkg/config/config.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Environment string
|
||||
LogLevel string
|
||||
|
||||
// Service ports
|
||||
APIGatewayPort int
|
||||
ChatSvcPort int
|
||||
AgentSvcPort int
|
||||
SearchSvcPort int
|
||||
LLMSvcPort int
|
||||
ScraperSvcPort int
|
||||
|
||||
// Service URLs
|
||||
ChatSvcURL string
|
||||
AgentSvcURL string
|
||||
SearchSvcURL string
|
||||
LLMSvcURL string
|
||||
ScraperSvcURL string
|
||||
MemorySvcURL string
|
||||
LibrarySvcURL string
|
||||
|
||||
// External services
|
||||
SearXNGURL string
|
||||
SearXNGFallbackURL []string
|
||||
Crawl4AIURL string
|
||||
RedisURL string
|
||||
DatabaseURL string
|
||||
DiscoverSvcURL string
|
||||
CollectionSvcURL string
|
||||
FileSvcURL string
|
||||
ThreadSvcURL string
|
||||
ComputerSvcURL string
|
||||
FinanceHeatmapURL string
|
||||
LearningSvcURL string
|
||||
|
||||
// Auth
|
||||
JWTSecret string
|
||||
AuthSvcURL string
|
||||
|
||||
// LLM defaults
|
||||
DefaultLLMProvider string
|
||||
DefaultLLMModel string
|
||||
OpenAIAPIKey string
|
||||
AnthropicAPIKey string
|
||||
GeminiAPIKey string
|
||||
|
||||
// Timeweb Cloud AI
|
||||
TimewebAPIBaseURL string
|
||||
TimewebAgentAccessID string
|
||||
TimewebAPIKey string
|
||||
TimewebProxySource string
|
||||
|
||||
// Timeouts
|
||||
HTTPTimeout time.Duration
|
||||
LLMTimeout time.Duration
|
||||
ScrapeTimeout time.Duration
|
||||
SearchTimeout time.Duration
|
||||
|
||||
// CORS
|
||||
AllowedOrigins []string
|
||||
}
|
||||
|
||||
var cfg *Config
|
||||
|
||||
func Load() (*Config, error) {
|
||||
_ = godotenv.Load()
|
||||
_ = godotenv.Load("../.env")
|
||||
_ = godotenv.Load("../../.env")
|
||||
|
||||
cfg = &Config{
|
||||
Environment: getEnv("ENVIRONMENT", "production"),
|
||||
LogLevel: getEnv("LOG_LEVEL", "info"),
|
||||
|
||||
APIGatewayPort: getEnvInt("API_GATEWAY_PORT", 3015),
|
||||
ChatSvcPort: getEnvInt("CHAT_SVC_PORT", 3005),
|
||||
AgentSvcPort: getEnvInt("AGENT_SVC_PORT", 3018),
|
||||
SearchSvcPort: getEnvInt("SEARCH_SVC_PORT", 3001),
|
||||
LLMSvcPort: getEnvInt("LLM_SVC_PORT", 3020),
|
||||
ScraperSvcPort: getEnvInt("SCRAPER_SVC_PORT", 3021),
|
||||
|
||||
ChatSvcURL: getEnv("CHAT_SVC_URL", "http://localhost:3005"),
|
||||
AgentSvcURL: getEnv("MASTER_AGENTS_SVC_URL", "http://localhost:3018"),
|
||||
SearchSvcURL: getEnv("SEARCH_SVC_URL", "http://localhost:3001"),
|
||||
LLMSvcURL: getEnv("LLM_SVC_URL", "http://localhost:3020"),
|
||||
ScraperSvcURL: getEnv("SCRAPER_SVC_URL", "http://localhost:3021"),
|
||||
MemorySvcURL: getEnv("MEMORY_SVC_URL", ""),
|
||||
LibrarySvcURL: getEnv("LIBRARY_SVC_URL", "http://localhost:3009"),
|
||||
|
||||
SearXNGURL: getEnv("SEARXNG_URL", "http://searxng:8080"),
|
||||
SearXNGFallbackURL: strings.Split(getEnv("SEARXNG_FALLBACK_URL", ""), ","),
|
||||
Crawl4AIURL: getEnv("CRAWL4AI_URL", "http://crawl4ai:11235"),
|
||||
RedisURL: getEnv("REDIS_URL", "redis://localhost:6379"),
|
||||
DatabaseURL: getEnv("DATABASE_URL", ""),
|
||||
DiscoverSvcURL: getEnv("DISCOVER_SVC_URL", "http://localhost:3002"),
|
||||
CollectionSvcURL: getEnv("COLLECTION_SVC_URL", "http://localhost:3025"),
|
||||
FileSvcURL: getEnv("FILE_SVC_URL", "http://localhost:3026"),
|
||||
ThreadSvcURL: getEnv("THREAD_SVC_URL", "http://localhost:3027"),
|
||||
ComputerSvcURL: getEnv("COMPUTER_SVC_URL", "http://localhost:3030"),
|
||||
FinanceHeatmapURL: getEnv("FINANCE_HEATMAP_SVC_URL", "http://localhost:3033"),
|
||||
LearningSvcURL: getEnv("LEARNING_SVC_URL", "http://localhost:3034"),
|
||||
|
||||
JWTSecret: getEnv("JWT_SECRET", ""),
|
||||
AuthSvcURL: getEnv("AUTH_SVC_URL", ""),
|
||||
|
||||
DefaultLLMProvider: getEnv("DEFAULT_LLM_PROVIDER", "openai"),
|
||||
DefaultLLMModel: getEnv("DEFAULT_LLM_MODEL", "gpt-4o-mini"),
|
||||
OpenAIAPIKey: getEnv("OPENAI_API_KEY", ""),
|
||||
AnthropicAPIKey: getEnv("ANTHROPIC_API_KEY", ""),
|
||||
GeminiAPIKey: getEnv("GEMINI_API_KEY", ""),
|
||||
|
||||
TimewebAPIBaseURL: getEnv("TIMEWEB_API_BASE_URL", "https://api.timeweb.cloud"),
|
||||
TimewebAgentAccessID: getEnv("TIMEWEB_AGENT_ACCESS_ID", ""),
|
||||
TimewebAPIKey: getEnv("TIMEWEB_API_KEY", ""),
|
||||
TimewebProxySource: getEnv("TIMEWEB_X_PROXY_SOURCE", "gooseek"),
|
||||
|
||||
HTTPTimeout: time.Duration(getEnvInt("HTTP_TIMEOUT_MS", 60000)) * time.Millisecond,
|
||||
LLMTimeout: time.Duration(getEnvInt("LLM_TIMEOUT_MS", 120000)) * time.Millisecond,
|
||||
ScrapeTimeout: time.Duration(getEnvInt("SCRAPE_TIMEOUT_MS", 25000)) * time.Millisecond,
|
||||
SearchTimeout: time.Duration(getEnvInt("SEARCH_TIMEOUT_MS", 10000)) * time.Millisecond,
|
||||
|
||||
AllowedOrigins: parseOrigins(getEnv("ALLOWED_ORIGINS", "*")),
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func Get() *Config {
|
||||
if cfg == nil {
|
||||
cfg, _ = Load()
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func getEnv(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvInt(key string, defaultValue int) int {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
if i, err := strconv.Atoi(value); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func parseOrigins(s string) []string {
|
||||
if s == "*" {
|
||||
return []string{"*"}
|
||||
}
|
||||
origins := strings.Split(s, ",")
|
||||
result := make([]string, 0, len(origins))
|
||||
for _, o := range origins {
|
||||
if trimmed := strings.TrimSpace(o); trimmed != "" {
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
50
backend/pkg/middleware/auth.go
Normal file
50
backend/pkg/middleware/auth.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type AuthConfig struct {
|
||||
RequireAuth bool
|
||||
SkipPaths []string
|
||||
}
|
||||
|
||||
func Auth(config AuthConfig) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
if !config.RequireAuth {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
path := c.Path()
|
||||
for _, skip := range config.SkipPaths {
|
||||
if strings.HasPrefix(path, skip) {
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
auth := c.Get("Authorization")
|
||||
if auth == "" {
|
||||
return c.Status(401).JSON(fiber.Map{
|
||||
"error": "Unauthorized",
|
||||
})
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(auth, "Bearer ") {
|
||||
return c.Status(401).JSON(fiber.Map{
|
||||
"error": "Invalid authorization format",
|
||||
})
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func ExtractToken(c *fiber.Ctx) string {
|
||||
auth := c.Get("Authorization")
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
return strings.TrimPrefix(auth, "Bearer ")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
280
backend/pkg/middleware/jwt.go
Normal file
280
backend/pkg/middleware/jwt.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
type JWTConfig struct {
|
||||
Secret string
|
||||
AuthSvcURL string
|
||||
SkipPaths []string
|
||||
AllowGuest bool
|
||||
CacheDuration time.Duration
|
||||
}
|
||||
|
||||
type UserClaims struct {
|
||||
UserID string `json:"userId"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
Tier string `json:"tier"`
|
||||
IsGuest bool `json:"isGuest"`
|
||||
ExpiresAt int64 `json:"exp"`
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
|
||||
const UserContextKey contextKey = "user"
|
||||
|
||||
func JWT(cfg JWTConfig) fiber.Handler {
|
||||
skipMap := make(map[string]bool)
|
||||
for _, path := range cfg.SkipPaths {
|
||||
skipMap[path] = true
|
||||
}
|
||||
|
||||
if cfg.CacheDuration == 0 {
|
||||
cfg.CacheDuration = 5 * time.Minute
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
if skipMap[c.Path()] {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
if strings.HasPrefix(c.Path(), "/health") || strings.HasPrefix(c.Path(), "/ready") {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
token := ExtractToken(c)
|
||||
|
||||
if token == "" {
|
||||
if cfg.AllowGuest {
|
||||
c.Locals(string(UserContextKey), &UserClaims{
|
||||
IsGuest: true,
|
||||
Role: "guest",
|
||||
Tier: "free",
|
||||
})
|
||||
return c.Next()
|
||||
}
|
||||
return c.Status(401).JSON(fiber.Map{
|
||||
"error": "Unauthorized",
|
||||
"message": "Missing authorization token",
|
||||
})
|
||||
}
|
||||
|
||||
var claims *UserClaims
|
||||
var err error
|
||||
|
||||
if cfg.Secret != "" {
|
||||
claims, err = validateLocalJWT(token, cfg.Secret)
|
||||
} else if cfg.AuthSvcURL != "" {
|
||||
claims, err = validateWithAuthService(c.Context(), token, cfg.AuthSvcURL)
|
||||
} else {
|
||||
return c.Status(500).JSON(fiber.Map{
|
||||
"error": "Configuration Error",
|
||||
"message": "JWT validation not configured",
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return c.Status(401).JSON(fiber.Map{
|
||||
"error": "Unauthorized",
|
||||
"message": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
c.Locals(string(UserContextKey), claims)
|
||||
c.Locals("userId", claims.UserID)
|
||||
c.Locals("userRole", claims.Role)
|
||||
c.Locals("userTier", claims.Tier)
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func validateLocalJWT(tokenString, secret string) (*UserClaims, error) {
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(secret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid token: %w", err)
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
return nil, fmt.Errorf("token is not valid")
|
||||
}
|
||||
|
||||
mapClaims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid claims format")
|
||||
}
|
||||
|
||||
claims := &UserClaims{}
|
||||
|
||||
if v, ok := mapClaims["userId"].(string); ok {
|
||||
claims.UserID = v
|
||||
} else if v, ok := mapClaims["sub"].(string); ok {
|
||||
claims.UserID = v
|
||||
}
|
||||
|
||||
if v, ok := mapClaims["email"].(string); ok {
|
||||
claims.Email = v
|
||||
}
|
||||
|
||||
if v, ok := mapClaims["role"].(string); ok {
|
||||
claims.Role = v
|
||||
} else {
|
||||
claims.Role = "user"
|
||||
}
|
||||
|
||||
if v, ok := mapClaims["tier"].(string); ok {
|
||||
claims.Tier = v
|
||||
} else {
|
||||
claims.Tier = "free"
|
||||
}
|
||||
|
||||
if v, ok := mapClaims["exp"].(float64); ok {
|
||||
claims.ExpiresAt = int64(v)
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func validateWithAuthService(ctx context.Context, token, authURL string) (*UserClaims, error) {
|
||||
reqURL := strings.TrimSuffix(authURL, "/") + "/api/v1/auth/validate"
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auth service unavailable: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token validation failed: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Valid bool `json:"valid"`
|
||||
User UserClaims `json:"user"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode auth response: %w", err)
|
||||
}
|
||||
|
||||
if !result.Valid {
|
||||
return nil, fmt.Errorf("token is not valid")
|
||||
}
|
||||
|
||||
return &result.User, nil
|
||||
}
|
||||
|
||||
func GetUser(c *fiber.Ctx) *UserClaims {
|
||||
user, ok := c.Locals(string(UserContextKey)).(*UserClaims)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
||||
func GetUserID(c *fiber.Ctx) string {
|
||||
user := GetUser(c)
|
||||
if user == nil {
|
||||
return ""
|
||||
}
|
||||
return user.UserID
|
||||
}
|
||||
|
||||
func GetUserTier(c *fiber.Ctx) string {
|
||||
user := GetUser(c)
|
||||
if user == nil {
|
||||
return "free"
|
||||
}
|
||||
return user.Tier
|
||||
}
|
||||
|
||||
func RequireAuth() fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
user := GetUser(c)
|
||||
if user == nil || user.IsGuest {
|
||||
return c.Status(401).JSON(fiber.Map{
|
||||
"error": "Unauthorized",
|
||||
"message": "Authentication required",
|
||||
})
|
||||
}
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func RequireRole(roles ...string) fiber.Handler {
|
||||
roleMap := make(map[string]bool)
|
||||
for _, r := range roles {
|
||||
roleMap[r] = true
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
user := GetUser(c)
|
||||
if user == nil {
|
||||
return c.Status(401).JSON(fiber.Map{
|
||||
"error": "Unauthorized",
|
||||
"message": "Authentication required",
|
||||
})
|
||||
}
|
||||
|
||||
if !roleMap[user.Role] {
|
||||
return c.Status(403).JSON(fiber.Map{
|
||||
"error": "Forbidden",
|
||||
"message": "Insufficient permissions",
|
||||
})
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func RequireTier(tiers ...string) fiber.Handler {
|
||||
tierMap := make(map[string]bool)
|
||||
for _, t := range tiers {
|
||||
tierMap[t] = true
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
user := GetUser(c)
|
||||
if user == nil {
|
||||
return c.Status(401).JSON(fiber.Map{
|
||||
"error": "Unauthorized",
|
||||
"message": "Authentication required",
|
||||
})
|
||||
}
|
||||
|
||||
if !tierMap[user.Tier] {
|
||||
return c.Status(403).JSON(fiber.Map{
|
||||
"error": "Forbidden",
|
||||
"message": "This feature requires a higher tier subscription",
|
||||
"current": user.Tier,
|
||||
"required": tiers,
|
||||
})
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
49
backend/pkg/middleware/logging.go
Normal file
49
backend/pkg/middleware/logging.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type LoggingConfig struct {
|
||||
Logger *zap.Logger
|
||||
SkipPaths []string
|
||||
}
|
||||
|
||||
func Logging(config LoggingConfig) fiber.Handler {
|
||||
logger := config.Logger
|
||||
if logger == nil {
|
||||
logger, _ = zap.NewProduction()
|
||||
}
|
||||
|
||||
skipPaths := make(map[string]bool)
|
||||
for _, path := range config.SkipPaths {
|
||||
skipPaths[path] = true
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
path := c.Path()
|
||||
if skipPaths[path] {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
err := c.Next()
|
||||
|
||||
duration := time.Since(start)
|
||||
|
||||
logger.Info("request",
|
||||
zap.String("method", c.Method()),
|
||||
zap.String("path", path),
|
||||
zap.Int("status", c.Response().StatusCode()),
|
||||
zap.Duration("latency", duration),
|
||||
zap.String("ip", c.IP()),
|
||||
zap.String("user-agent", c.Get("User-Agent")),
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
106
backend/pkg/middleware/ratelimit.go
Normal file
106
backend/pkg/middleware/ratelimit.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type RateLimitConfig struct {
|
||||
Max int
|
||||
WindowSecs int
|
||||
KeyFunc func(*fiber.Ctx) string
|
||||
}
|
||||
|
||||
type rateLimiter struct {
|
||||
requests map[string][]time.Time
|
||||
mu sync.RWMutex
|
||||
max int
|
||||
window time.Duration
|
||||
}
|
||||
|
||||
func newRateLimiter(max int, windowSecs int) *rateLimiter {
|
||||
rl := &rateLimiter{
|
||||
requests: make(map[string][]time.Time),
|
||||
max: max,
|
||||
window: time.Duration(windowSecs) * time.Second,
|
||||
}
|
||||
|
||||
go rl.cleanup()
|
||||
return rl
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) cleanup() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
for range ticker.C {
|
||||
rl.mu.Lock()
|
||||
now := time.Now()
|
||||
for key, times := range rl.requests {
|
||||
var valid []time.Time
|
||||
for _, t := range times {
|
||||
if now.Sub(t) < rl.window {
|
||||
valid = append(valid, t)
|
||||
}
|
||||
}
|
||||
if len(valid) == 0 {
|
||||
delete(rl.requests, key)
|
||||
} else {
|
||||
rl.requests[key] = valid
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) allow(key string) bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
windowStart := now.Add(-rl.window)
|
||||
|
||||
times := rl.requests[key]
|
||||
var valid []time.Time
|
||||
for _, t := range times {
|
||||
if t.After(windowStart) {
|
||||
valid = append(valid, t)
|
||||
}
|
||||
}
|
||||
|
||||
if len(valid) >= rl.max {
|
||||
rl.requests[key] = valid
|
||||
return false
|
||||
}
|
||||
|
||||
rl.requests[key] = append(valid, now)
|
||||
return true
|
||||
}
|
||||
|
||||
func RateLimit(config RateLimitConfig) fiber.Handler {
|
||||
if config.Max == 0 {
|
||||
config.Max = 100
|
||||
}
|
||||
if config.WindowSecs == 0 {
|
||||
config.WindowSecs = 60
|
||||
}
|
||||
if config.KeyFunc == nil {
|
||||
config.KeyFunc = func(c *fiber.Ctx) string {
|
||||
return c.IP()
|
||||
}
|
||||
}
|
||||
|
||||
limiter := newRateLimiter(config.Max, config.WindowSecs)
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
key := config.KeyFunc(c)
|
||||
|
||||
if !limiter.allow(key) {
|
||||
return c.Status(429).JSON(fiber.Map{
|
||||
"error": "Too many requests",
|
||||
})
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
213
backend/pkg/middleware/ratelimit_redis.go
Normal file
213
backend/pkg/middleware/ratelimit_redis.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
type RedisRateLimiterConfig struct {
|
||||
RedisClient *redis.Client
|
||||
KeyPrefix string
|
||||
Max int
|
||||
Window time.Duration
|
||||
KeyFunc func(*fiber.Ctx) string
|
||||
SkipPaths []string
|
||||
}
|
||||
|
||||
func RedisRateLimit(cfg RedisRateLimiterConfig) fiber.Handler {
|
||||
if cfg.KeyPrefix == "" {
|
||||
cfg.KeyPrefix = "ratelimit"
|
||||
}
|
||||
if cfg.Max == 0 {
|
||||
cfg.Max = 100
|
||||
}
|
||||
if cfg.Window == 0 {
|
||||
cfg.Window = time.Minute
|
||||
}
|
||||
if cfg.KeyFunc == nil {
|
||||
cfg.KeyFunc = func(c *fiber.Ctx) string {
|
||||
return c.IP()
|
||||
}
|
||||
}
|
||||
|
||||
skipMap := make(map[string]bool)
|
||||
for _, path := range cfg.SkipPaths {
|
||||
skipMap[path] = true
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
if skipMap[c.Path()] {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
key := fmt.Sprintf("%s:%s", cfg.KeyPrefix, cfg.KeyFunc(c))
|
||||
|
||||
pipe := cfg.RedisClient.Pipeline()
|
||||
incr := pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, cfg.Window)
|
||||
_, err := pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
current := incr.Val()
|
||||
|
||||
c.Set("X-RateLimit-Limit", strconv.Itoa(cfg.Max))
|
||||
c.Set("X-RateLimit-Remaining", strconv.Itoa(max(0, cfg.Max-int(current))))
|
||||
|
||||
ttl, _ := cfg.RedisClient.TTL(ctx, key).Result()
|
||||
c.Set("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(ttl).Unix(), 10))
|
||||
|
||||
if int(current) > cfg.Max {
|
||||
c.Set("Retry-After", strconv.FormatInt(int64(ttl.Seconds()), 10))
|
||||
return c.Status(429).JSON(fiber.Map{
|
||||
"error": "Too Many Requests",
|
||||
"retry_after": int64(ttl.Seconds()),
|
||||
})
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
type SlidingWindowConfig struct {
|
||||
RedisClient *redis.Client
|
||||
KeyPrefix string
|
||||
Max int
|
||||
Window time.Duration
|
||||
KeyFunc func(*fiber.Ctx) string
|
||||
}
|
||||
|
||||
func SlidingWindowRateLimit(cfg SlidingWindowConfig) fiber.Handler {
|
||||
if cfg.KeyPrefix == "" {
|
||||
cfg.KeyPrefix = "ratelimit:sliding"
|
||||
}
|
||||
if cfg.Max == 0 {
|
||||
cfg.Max = 100
|
||||
}
|
||||
if cfg.Window == 0 {
|
||||
cfg.Window = time.Minute
|
||||
}
|
||||
if cfg.KeyFunc == nil {
|
||||
cfg.KeyFunc = func(c *fiber.Ctx) string {
|
||||
return c.IP()
|
||||
}
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
ctx := context.Background()
|
||||
key := fmt.Sprintf("%s:%s", cfg.KeyPrefix, cfg.KeyFunc(c))
|
||||
now := time.Now()
|
||||
windowStart := now.Add(-cfg.Window).UnixMicro()
|
||||
|
||||
pipe := cfg.RedisClient.Pipeline()
|
||||
|
||||
pipe.ZRemRangeByScore(ctx, key, "0", strconv.FormatInt(windowStart, 10))
|
||||
|
||||
pipe.ZAdd(ctx, key, redis.Z{
|
||||
Score: float64(now.UnixMicro()),
|
||||
Member: fmt.Sprintf("%d", now.UnixNano()),
|
||||
})
|
||||
|
||||
countCmd := pipe.ZCard(ctx, key)
|
||||
|
||||
pipe.Expire(ctx, key, cfg.Window)
|
||||
|
||||
_, err := pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
count := countCmd.Val()
|
||||
|
||||
c.Set("X-RateLimit-Limit", strconv.Itoa(cfg.Max))
|
||||
c.Set("X-RateLimit-Remaining", strconv.Itoa(max(0, cfg.Max-int(count))))
|
||||
|
||||
if int(count) > cfg.Max {
|
||||
return c.Status(429).JSON(fiber.Map{
|
||||
"error": "Too Many Requests",
|
||||
"retry_after": int64(cfg.Window.Seconds()),
|
||||
})
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
type TieredRateLimitConfig struct {
|
||||
RedisClient *redis.Client
|
||||
KeyPrefix string
|
||||
Tiers map[string]TierConfig
|
||||
GetTierFunc func(*fiber.Ctx) string
|
||||
KeyFunc func(*fiber.Ctx) string
|
||||
}
|
||||
|
||||
type TierConfig struct {
|
||||
Max int
|
||||
Window time.Duration
|
||||
}
|
||||
|
||||
func TieredRateLimit(cfg TieredRateLimitConfig) fiber.Handler {
|
||||
if cfg.KeyPrefix == "" {
|
||||
cfg.KeyPrefix = "ratelimit:tiered"
|
||||
}
|
||||
if cfg.GetTierFunc == nil {
|
||||
cfg.GetTierFunc = func(c *fiber.Ctx) string { return "default" }
|
||||
}
|
||||
if cfg.KeyFunc == nil {
|
||||
cfg.KeyFunc = func(c *fiber.Ctx) string { return c.IP() }
|
||||
}
|
||||
|
||||
defaultTier := TierConfig{Max: 60, Window: time.Minute}
|
||||
if _, ok := cfg.Tiers["default"]; !ok {
|
||||
cfg.Tiers["default"] = defaultTier
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
ctx := context.Background()
|
||||
tier := cfg.GetTierFunc(c)
|
||||
tierCfg, ok := cfg.Tiers[tier]
|
||||
if !ok {
|
||||
tierCfg = cfg.Tiers["default"]
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s:%s:%s", cfg.KeyPrefix, tier, cfg.KeyFunc(c))
|
||||
|
||||
pipe := cfg.RedisClient.Pipeline()
|
||||
incr := pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, tierCfg.Window)
|
||||
_, err := pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
current := incr.Val()
|
||||
|
||||
c.Set("X-RateLimit-Tier", tier)
|
||||
c.Set("X-RateLimit-Limit", strconv.Itoa(tierCfg.Max))
|
||||
c.Set("X-RateLimit-Remaining", strconv.Itoa(max(0, tierCfg.Max-int(current))))
|
||||
|
||||
if int(current) > tierCfg.Max {
|
||||
return c.Status(429).JSON(fiber.Map{
|
||||
"error": "Too Many Requests",
|
||||
"tier": tier,
|
||||
"limit": tierCfg.Max,
|
||||
})
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
113
backend/pkg/ndjson/writer.go
Normal file
113
backend/pkg/ndjson/writer.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package ndjson
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Writer struct {
|
||||
w io.Writer
|
||||
buf *bufio.Writer
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewWriter(w io.Writer) *Writer {
|
||||
return &Writer{
|
||||
w: w,
|
||||
buf: bufio.NewWriter(w),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Writer) Write(v interface{}) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := w.buf.Write(data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := w.buf.WriteByte('\n'); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return w.buf.Flush()
|
||||
}
|
||||
|
||||
func (w *Writer) WriteRaw(data []byte) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if _, err := w.buf.Write(data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := w.buf.WriteByte('\n'); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return w.buf.Flush()
|
||||
}
|
||||
|
||||
func (w *Writer) Flush() error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
return w.buf.Flush()
|
||||
}
|
||||
|
||||
type StreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
Block interface{} `json:"block,omitempty"`
|
||||
BlockID string `json:"blockId,omitempty"`
|
||||
Chunk string `json:"chunk,omitempty"`
|
||||
Patch interface{} `json:"patch,omitempty"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
func WriteBlock(w *Writer, block interface{}) error {
|
||||
return w.Write(StreamEvent{
|
||||
Type: "block",
|
||||
Block: block,
|
||||
})
|
||||
}
|
||||
|
||||
func WriteTextChunk(w *Writer, blockID, chunk string) error {
|
||||
return w.Write(StreamEvent{
|
||||
Type: "textChunk",
|
||||
BlockID: blockID,
|
||||
Chunk: chunk,
|
||||
})
|
||||
}
|
||||
|
||||
func WriteUpdateBlock(w *Writer, blockID string, patch interface{}) error {
|
||||
return w.Write(StreamEvent{
|
||||
Type: "updateBlock",
|
||||
BlockID: blockID,
|
||||
Patch: patch,
|
||||
})
|
||||
}
|
||||
|
||||
func WriteResearchComplete(w *Writer) error {
|
||||
return w.Write(StreamEvent{
|
||||
Type: "researchComplete",
|
||||
})
|
||||
}
|
||||
|
||||
func WriteMessageEnd(w *Writer) error {
|
||||
return w.Write(StreamEvent{
|
||||
Type: "messageEnd",
|
||||
})
|
||||
}
|
||||
|
||||
func WriteError(w *Writer, err error) error {
|
||||
return w.Write(StreamEvent{
|
||||
Type: "error",
|
||||
Data: err.Error(),
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user