Files
gooseek/backend/pkg/middleware/llm_limits.go
home 52134df4d1
Some checks failed
Build and Deploy GooSeek / build-and-deploy (push) Failing after 8m22s
feat: add email notification service with SMTP support
- Create pkg/email package (sender, templates, types)
- SMTP client with TLS, rate limiting, async sending
- HTML email templates with GooSeek branding
- Integrate welcome + password reset emails in auth-svc
- Add limit warning emails (80%/100%) in llm-svc middleware
- Add space invite endpoint with email notification in thread-svc
- Add GetUserEmail helper in JWT middleware
- Add SMTP config to .env, config.go, K8s configmap

Made-with: Cursor
2026-03-03 02:50:17 +03:00

180 lines
4.2 KiB
Go

package middleware
import (
	"context"
	"database/sql"
	"log"
	"sync"
	"time"

	"github.com/gofiber/fiber/v2"
	"github.com/gooseek/backend/internal/usage"
	"github.com/gooseek/backend/pkg/email"
	"github.com/gooseek/backend/pkg/metrics"
)
// LLMLimitsConfig holds the dependencies used by the LLMLimits middleware.
type LLMLimitsConfig struct {
// UsageRepo checks per-user LLM quotas and looks up today's usage and the
// user's email address. When nil, LLMLimits only enforces authentication:
// no quota checks are made and no limit emails are sent.
UsageRepo *usage.Repository
// EmailSender delivers limit notification emails. When nil, or when
// IsConfigured() reports false, no emails are sent.
EmailSender *email.Sender
}
// LLMLimits returns middleware that gates LLM endpoints on per-user quotas.
//
// Requests without a user ID (GetUserID returns "") are rejected with 401.
// When config.UsageRepo is set, the user's tier (GetUserTier, defaulting to
// "free") is checked with UsageRepo.CheckLLMLimits:
//   - a rejected request gets 429 with the tier's daily limits and an upgrade
//     URL, and an "exceeded" email is queued;
//   - an allowed request may queue an 80% warning email.
//
// Emails are best-effort, asynchronous, and sent at most once per user, kind
// and UTC day by this process (see limitEmailDedup).
func LLMLimits(config LLMLimitsConfig) fiber.Handler {
	return func(c *fiber.Ctx) error {
		userID := GetUserID(c)
		clientIP := c.IP()
		if userID == "" {
			metrics.RecordLLMUnauthorized("no_user_id", clientIP)
			metrics.RecordSecurityEvent("unauthorized_llm_access", clientIP, "anonymous")
			return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
				"error": "Authentication required",
			})
		}

		tier := GetUserTier(c)
		if tier == "" {
			tier = "free"
		}

		if config.UsageRepo != nil {
			allowed, reason := config.UsageRepo.CheckLLMLimits(c.Context(), userID, tier)
			if !allowed {
				limits := usage.GetLimits(tier)
				if tier == "free" {
					metrics.RecordFreeTierLimitExceeded(userID, reason)
					metrics.RecordSecurityEvent("free_tier_limit_exceeded", clientIP, userID)
				}
				metrics.RecordRateLimitHit("llm-svc", clientIP, reason)
				// NOTE(review): usage is reported as equal to the request limit even
				// when reason is a token-limit breach; confirm the email wording fits.
				sendLimitEmail(config, userID, tier, limits.LLMRequestsPerDay, limits.LLMRequestsPerDay)
				return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{
					"error":      reason,
					"tier":       tier,
					"dailyLimit": limits.LLMRequestsPerDay,
					"tokenLimit": limits.LLMTokensPerDay,
					"upgradeUrl": "/settings/billing",
				})
			}
			checkLimitWarning(config, userID, tier)
		}
		return c.Next()
	}
}

// checkLimitWarning asynchronously emails userID a warning once today's LLM
// request count reaches 80% (but not yet 100%) of the tier's daily limit.
// It is a no-op unless UsageRepo is set and EmailSender is present and
// configured. At most one warning per user per UTC day is sent by this process.
func checkLimitWarning(config LLMLimitsConfig, userID, tier string) {
	if config.UsageRepo == nil || config.EmailSender == nil || !config.EmailSender.IsConfigured() {
		return
	}
	// Once today's warning is out, skip the goroutine and the usage query entirely.
	if limitEmailsSent.sent(limitEmailWarning, userID) {
		return
	}
	go func() {
		ctx, cancel := context.WithTimeout(context.Background(), limitEmailLookupTimeout)
		defer cancel()

		todayUsage, err := config.UsageRepo.GetTodayUsage(ctx, userID)
		if err != nil {
			log.Printf("[email] Limit warning usage lookup error: %v", err)
			return
		}
		if todayUsage == nil {
			return
		}

		limits := usage.GetLimits(tier)
		// A non-positive limit has no meaningful percentage (and would divide by zero).
		if limits.LLMRequestsPerDay <= 0 {
			return
		}
		percentage := (todayUsage.LLMRequests * 100) / limits.LLMRequestsPerDay
		if percentage < 80 || percentage >= 100 {
			return
		}

		// Claim only after the threshold check so that users still below 80% are
		// re-evaluated on later requests; concurrent goroutines race on claim
		// and exactly one wins.
		if !limitEmailsSent.claim(limitEmailWarning, userID) {
			return
		}
		userEmail := getUserEmail(config.UsageRepo, userID)
		if userEmail == "" {
			return
		}
		if err := config.EmailSender.SendLimitWarning(userEmail, "", todayUsage.LLMRequests, limits.LLMRequestsPerDay, tier); err != nil {
			log.Printf("[email] Limit warning send error: %v", err)
		}
	}()
}

// sendLimitEmail asynchronously notifies userID that a limit was hit, passing
// usageCount and limitCount through to EmailSender.SendLimitWarning. Every
// rejected request reaches this function, so the email is sent at most once
// per user per UTC day by this process. A failed address lookup or send is
// logged and not retried that day.
func sendLimitEmail(config LLMLimitsConfig, userID, tier string, usageCount, limitCount int) {
	if config.EmailSender == nil || !config.EmailSender.IsConfigured() {
		return
	}
	if !limitEmailsSent.claim(limitEmailExceeded, userID) {
		return
	}
	go func() {
		userEmail := getUserEmail(config.UsageRepo, userID)
		if userEmail == "" {
			return
		}
		if err := config.EmailSender.SendLimitWarning(userEmail, "", usageCount, limitCount, tier); err != nil {
			log.Printf("[email] Limit exceeded email error: %v", err)
		}
	}()
}

// getUserEmail looks up userID's email address through repo. It returns ""
// when repo is nil or the lookup fails; failures are logged. The lookup is
// bounded by limitEmailLookupTimeout because callers run detached from any
// request context.
func getUserEmail(repo *usage.Repository, userID string) string {
	if repo == nil {
		return ""
	}
	ctx, cancel := context.WithTimeout(context.Background(), limitEmailLookupTimeout)
	defer cancel()

	emailAddr, err := repo.GetUserEmail(ctx, userID)
	if err != nil {
		log.Printf("[email] User email lookup error for %q: %v", userID, err)
		return ""
	}
	return emailAddr
}

// limitEmailLookupTimeout bounds the DB lookups made by the background email
// goroutines, which would otherwise run with no deadline at all.
const limitEmailLookupTimeout = 10 * time.Second

// Notification kinds tracked by limitEmailsSent.
const (
	limitEmailWarning  = "warning"
	limitEmailExceeded = "exceeded"
)

// limitEmailDedup records which limit emails were already sent today, so a
// user is not emailed on every request once they cross a threshold.
//
// State is in-memory and per process: each replica may send its own copy, and
// a restart forgets what was sent. Days are UTC calendar days; the set is
// reset when the day changes, so memory is bounded by one day's users.
// NOTE(review): confirm the usage repository also resets counters on UTC days.
//
// The zero value is ready to use.
type limitEmailDedup struct {
	mu   sync.Mutex
	day  string
	keys map[string]struct{}
}

// limitEmailsSent is the process-wide dedup state shared by all LLMLimits handlers.
var limitEmailsSent limitEmailDedup

// rollLocked starts a fresh set when the UTC day has changed. d.mu must be held.
func (d *limitEmailDedup) rollLocked() {
	today := time.Now().UTC().Format("2006-01-02")
	if d.keys == nil || d.day != today {
		d.day = today
		d.keys = make(map[string]struct{})
	}
}

// sent reports whether kind has already been claimed for userID today.
func (d *limitEmailDedup) sent(kind, userID string) bool {
	d.mu.Lock()
	defer d.mu.Unlock()
	d.rollLocked()
	_, ok := d.keys[kind+"\x00"+userID]
	return ok
}

// claim marks kind as sent to userID for today. It returns true only for the
// first caller of the day, so exactly one goroutine goes on to send the email.
func (d *limitEmailDedup) claim(kind, userID string) bool {
	d.mu.Lock()
	defer d.mu.Unlock()
	d.rollLocked()
	key := kind + "\x00" + userID
	if _, ok := d.keys[key]; ok {
		return false
	}
	d.keys[key] = struct{}{}
	return true
}
// UsageTracker records per-user API, LLM and search usage in a
// usage.Repository. The Track* methods are fire-and-forget: each repository
// write runs in its own goroutine, so the caller never waits on it, and the
// write's result is not observed here.
type UsageTracker struct {
	repo *usage.Repository
}

// NewUsageTracker returns a UsageTracker backed by usage.NewRepository(db).
func NewUsageTracker(db *sql.DB) *UsageTracker {
	return &UsageTracker{
		repo: usage.NewRepository(db),
	}
}

// RunMigrations delegates to the repository's RunMigrations.
func (t *UsageTracker) RunMigrations(ctx context.Context) error {
	return t.repo.RunMigrations(ctx)
}

// TrackAPIRequest counts one API request for the user on c. Requests without
// a user ID are ignored, and an empty tier is recorded as "free".
func (t *UsageTracker) TrackAPIRequest(c *fiber.Ctx) {
	userID := GetUserID(c)
	if userID == "" {
		return
	}
	tier := GetUserTier(c)
	if tier == "" {
		tier = "free"
	}
	go t.repo.IncrementAPIRequests(context.Background(), userID, tier)
}

// TrackLLMRequest counts one LLM request consuming tokens for userID. An empty
// userID is ignored, and an empty tier is recorded as "free".
//
// ctx is accepted for API compatibility but deliberately not used for the
// background write. A request-scoped context is typically cancelled (and a
// fasthttp RequestCtx recycled) as soon as the handler returns, which could
// abort the increment and let LLM usage go uncounted. The write uses
// context.Background(), matching TrackAPIRequest and TrackSearchRequest.
func (t *UsageTracker) TrackLLMRequest(ctx context.Context, userID, tier string, tokens int) {
	if userID == "" {
		return
	}
	if tier == "" {
		tier = "free"
	}
	go t.repo.IncrementLLMUsage(context.Background(), userID, tier, tokens)
}

// TrackSearchRequest counts one search request for the user on c. Requests
// without a user ID are ignored, and an empty tier is recorded as "free".
func (t *UsageTracker) TrackSearchRequest(c *fiber.Ctx) {
	userID := GetUserID(c)
	if userID == "" {
		return
	}
	tier := GetUserTier(c)
	if tier == "" {
		tier = "free"
	}
	go t.repo.IncrementSearchRequests(context.Background(), userID, tier)
}

// GetRepository returns the underlying usage repository.
func (t *UsageTracker) GetRepository() *usage.Repository {
	return t.repo
}
// UsageTracking returns middleware that records the request through
// tracker.TrackAPIRequest and then passes control to the next handler.
// With a nil tracker the returned handler simply calls c.Next().
func UsageTracking(tracker *UsageTracker) fiber.Handler {
	if tracker == nil {
		// Nothing to record; decide once instead of on every request.
		return func(c *fiber.Ctx) error {
			return c.Next()
		}
	}
	return func(c *fiber.Ctx) error {
		tracker.TrackAPIRequest(c)
		return c.Next()
	}
}