Some checks failed
Build and Deploy GooSeek / build-and-deploy (push) Failing after 8m22s
- Create pkg/email package (sender, templates, types) - SMTP client with TLS, rate limiting, async sending - HTML email templates with GooSeek branding - Integrate welcome + password reset emails in auth-svc - Add limit warning emails (80%/100%) in llm-svc middleware - Add space invite endpoint with email notification in thread-svc - Add GetUserEmail helper in JWT middleware - Add SMTP config to .env, config.go, K8s configmap Made-with: Cursor
180 lines
4.2 KiB
Go
180 lines
4.2 KiB
Go
package middleware
|
|
|
|
import (
	"context"
	"database/sql"
	"log"
	"sync"
	"time"

	"github.com/gofiber/fiber/v2"
	"github.com/gooseek/backend/internal/usage"
	"github.com/gooseek/backend/pkg/email"
	"github.com/gooseek/backend/pkg/metrics"
)
|
|
|
|
type LLMLimitsConfig struct {
|
|
UsageRepo *usage.Repository
|
|
EmailSender *email.Sender
|
|
}
|
|
|
|
func LLMLimits(config LLMLimitsConfig) fiber.Handler {
|
|
return func(c *fiber.Ctx) error {
|
|
userID := GetUserID(c)
|
|
clientIP := c.IP()
|
|
|
|
if userID == "" {
|
|
metrics.RecordLLMUnauthorized("no_user_id", clientIP)
|
|
metrics.RecordSecurityEvent("unauthorized_llm_access", clientIP, "anonymous")
|
|
return c.Status(401).JSON(fiber.Map{
|
|
"error": "Authentication required",
|
|
})
|
|
}
|
|
|
|
tier := GetUserTier(c)
|
|
if tier == "" {
|
|
tier = "free"
|
|
}
|
|
|
|
if config.UsageRepo != nil {
|
|
allowed, reason := config.UsageRepo.CheckLLMLimits(c.Context(), userID, tier)
|
|
if !allowed {
|
|
limits := usage.GetLimits(tier)
|
|
|
|
if tier == "free" {
|
|
metrics.RecordFreeTierLimitExceeded(userID, reason)
|
|
metrics.RecordSecurityEvent("free_tier_limit_exceeded", clientIP, userID)
|
|
}
|
|
metrics.RecordRateLimitHit("llm-svc", clientIP, reason)
|
|
|
|
sendLimitEmail(config, userID, tier, limits.LLMRequestsPerDay, limits.LLMRequestsPerDay)
|
|
|
|
return c.Status(429).JSON(fiber.Map{
|
|
"error": reason,
|
|
"tier": tier,
|
|
"dailyLimit": limits.LLMRequestsPerDay,
|
|
"tokenLimit": limits.LLMTokensPerDay,
|
|
"upgradeUrl": "/settings/billing",
|
|
})
|
|
}
|
|
|
|
checkLimitWarning(config, userID, tier)
|
|
}
|
|
|
|
return c.Next()
|
|
}
|
|
}
|
|
|
|
func checkLimitWarning(config LLMLimitsConfig, userID, tier string) {
|
|
if config.UsageRepo == nil || config.EmailSender == nil || !config.EmailSender.IsConfigured() {
|
|
return
|
|
}
|
|
|
|
go func() {
|
|
todayUsage, err := config.UsageRepo.GetTodayUsage(context.Background(), userID)
|
|
if err != nil || todayUsage == nil {
|
|
return
|
|
}
|
|
|
|
limits := usage.GetLimits(tier)
|
|
if limits.LLMRequestsPerDay == 0 {
|
|
return
|
|
}
|
|
|
|
percentage := (todayUsage.LLMRequests * 100) / limits.LLMRequestsPerDay
|
|
if percentage >= 80 && percentage < 100 {
|
|
userEmail := getUserEmail(config.UsageRepo, userID)
|
|
if userEmail != "" {
|
|
if err := config.EmailSender.SendLimitWarning(userEmail, "", todayUsage.LLMRequests, limits.LLMRequestsPerDay, tier); err != nil {
|
|
log.Printf("[email] Limit warning send error: %v", err)
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func sendLimitEmail(config LLMLimitsConfig, userID, tier string, usageCount, limitCount int) {
|
|
if config.EmailSender == nil || !config.EmailSender.IsConfigured() {
|
|
return
|
|
}
|
|
|
|
go func() {
|
|
userEmail := getUserEmail(config.UsageRepo, userID)
|
|
if userEmail != "" {
|
|
if err := config.EmailSender.SendLimitWarning(userEmail, "", usageCount, limitCount, tier); err != nil {
|
|
log.Printf("[email] Limit exceeded email error: %v", err)
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func getUserEmail(repo *usage.Repository, userID string) string {
|
|
if repo == nil {
|
|
return ""
|
|
}
|
|
emailAddr, err := repo.GetUserEmail(context.Background(), userID)
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
return emailAddr
|
|
}
|
|
|
|
type UsageTracker struct {
|
|
repo *usage.Repository
|
|
}
|
|
|
|
func NewUsageTracker(db *sql.DB) *UsageTracker {
|
|
return &UsageTracker{
|
|
repo: usage.NewRepository(db),
|
|
}
|
|
}
|
|
|
|
func (t *UsageTracker) RunMigrations(ctx context.Context) error {
|
|
return t.repo.RunMigrations(ctx)
|
|
}
|
|
|
|
func (t *UsageTracker) TrackAPIRequest(c *fiber.Ctx) {
|
|
userID := GetUserID(c)
|
|
if userID == "" {
|
|
return
|
|
}
|
|
tier := GetUserTier(c)
|
|
if tier == "" {
|
|
tier = "free"
|
|
}
|
|
go t.repo.IncrementAPIRequests(context.Background(), userID, tier)
|
|
}
|
|
|
|
func (t *UsageTracker) TrackLLMRequest(ctx context.Context, userID, tier string, tokens int) {
|
|
if userID == "" {
|
|
return
|
|
}
|
|
if tier == "" {
|
|
tier = "free"
|
|
}
|
|
go t.repo.IncrementLLMUsage(ctx, userID, tier, tokens)
|
|
}
|
|
|
|
func (t *UsageTracker) TrackSearchRequest(c *fiber.Ctx) {
|
|
userID := GetUserID(c)
|
|
if userID == "" {
|
|
return
|
|
}
|
|
tier := GetUserTier(c)
|
|
if tier == "" {
|
|
tier = "free"
|
|
}
|
|
go t.repo.IncrementSearchRequests(context.Background(), userID, tier)
|
|
}
|
|
|
|
func (t *UsageTracker) GetRepository() *usage.Repository {
|
|
return t.repo
|
|
}
|
|
|
|
func UsageTracking(tracker *UsageTracker) fiber.Handler {
|
|
return func(c *fiber.Ctx) error {
|
|
if tracker != nil {
|
|
tracker.TrackAPIRequest(c)
|
|
}
|
|
return c.Next()
|
|
}
|
|
}
|