// Command llm-svc is an HTTP gateway in front of multiple LLM providers
// (OpenAI, Anthropic, Google Gemini). It advertises providers based on which
// API keys are configured, enforces per-tier token limits, and optionally
// records usage in Postgres when DATABASE_URL is set.
package main

import (
	"bufio"
	"context"
	"database/sql"
	"fmt"
	"log"
	"os"
	"time"

	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/middleware/cors"
	"github.com/gofiber/fiber/v2/middleware/logger"
	_ "github.com/lib/pq"

	"github.com/gooseek/backend/internal/llm"
	"github.com/gooseek/backend/internal/usage"
	"github.com/gooseek/backend/pkg/config"
	"github.com/gooseek/backend/pkg/middleware"
	"github.com/gooseek/backend/pkg/ndjson"
)

// GenerateRequest is the JSON body accepted by POST /api/v1/generate.
type GenerateRequest struct {
	ProviderID string `json:"providerId"`
	ModelKey   string `json:"key"`
	Messages   []struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	} `json:"messages"`
	Options struct {
		MaxTokens   int     `json:"maxTokens"`
		Temperature float64 `json:"temperature"`
		Stream      bool    `json:"stream"`
	} `json:"options"`
}

func main() {
	cfg, err := config.Load()
	if err != nil {
		log.Fatal("Failed to load config:", err)
	}

	// Usage tracking is optional: without DATABASE_URL (or if the DB is
	// unreachable) the service still serves requests and just skips the
	// quota bookkeeping. usageRepo stays nil in that case.
	var usageRepo *usage.Repository
	if cfg.DatabaseURL != "" {
		db, err := sql.Open("postgres", cfg.DatabaseURL)
		if err != nil {
			log.Printf("Usage tracking unavailable: %v", err)
		} else {
			db.SetMaxOpenConns(5)
			db.SetMaxIdleConns(2)
			defer db.Close() // held open for the lifetime of the process

			usageRepo = usage.NewRepository(db)
			ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
			if err := usageRepo.RunMigrations(ctx); err != nil {
				// Migrations are best-effort at startup; a failure is logged
				// but does not prevent the service from starting.
				log.Printf("Usage migrations warning: %v", err)
			}
			cancel()
			log.Println("Usage tracking enabled")
		}
	}

	app := fiber.New(fiber.Config{
		StreamRequestBody: true,
		BodyLimit:         10 * 1024 * 1024,
		ReadTimeout:       time.Minute,
		WriteTimeout:      5 * time.Minute, // generous: streamed LLM responses can be slow
		IdleTimeout:       2 * time.Minute,
	})
	app.Use(logger.New())
	app.Use(cors.New())

	app.Get("/health", func(c *fiber.Ctx) error {
		return c.JSON(fiber.Map{"status": "ok"})
	})

	// Providers are advertised based purely on which API keys are present in
	// the environment ("envOnlyMode"). The slice is deliberately non-nil so
	// an empty result encodes as [] rather than null.
	app.Get("/api/v1/providers", func(c *fiber.Ctx) error {
		providers := []fiber.Map{}
		if cfg.OpenAIAPIKey != "" {
			providers = append(providers, fiber.Map{
				"id":     "openai",
				"name":   "OpenAI",
				"models": []string{"gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-3.5-turbo"},
			})
		}
		if cfg.AnthropicAPIKey != "" {
			providers = append(providers, fiber.Map{
				"id":     "anthropic",
				"name":   "Anthropic",
				"models": []string{"claude-3-5-sonnet-20241022", "claude-3-opus-20240229", "claude-3-haiku-20240307"},
			})
		}
		if cfg.GeminiAPIKey != "" {
			providers = append(providers, fiber.Map{
				"id":     "gemini",
				"name":   "Google Gemini",
				"models": []string{"gemini-1.5-pro", "gemini-1.5-flash", "gemini-2.0-flash-exp"},
			})
		}
		return c.JSON(fiber.Map{
			"providers":   providers,
			"envOnlyMode": true,
		})
	})

	app.Get("/api/v1/providers/ui-config", func(c *fiber.Ctx) error {
		return c.JSON(fiber.Map{
			"sections": []interface{}{},
		})
	})

	// All LLM endpoints require a valid JWT (no guests) and pass through the
	// per-tier rate/usage limiter.
	llmAPI := app.Group("/api/v1", middleware.JWT(middleware.JWTConfig{
		Secret:     cfg.JWTSecret,
		AuthSvcURL: cfg.AuthSvcURL,
		AllowGuest: false,
	}), middleware.LLMLimits(middleware.LLMLimitsConfig{
		UsageRepo: usageRepo,
	}))

	llmAPI.Post("/generate", func(c *fiber.Ctx) error {
		userID := middleware.GetUserID(c)
		tier := middleware.GetUserTier(c)
		if tier == "" {
			tier = "free"
		}

		var req GenerateRequest
		if err := c.BodyParser(&req); err != nil {
			return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
		}
		if len(req.Messages) == 0 {
			return c.Status(400).JSON(fiber.Map{"error": "Messages required"})
		}

		// Clamp the requested budget to the tier ceiling; 0 means "use max".
		limits := usage.GetLimits(tier)
		if req.Options.MaxTokens == 0 || req.Options.MaxTokens > limits.MaxTokensPerReq {
			req.Options.MaxTokens = limits.MaxTokensPerReq
		}

		client, err := llm.NewClient(llm.ProviderConfig{
			ProviderID: req.ProviderID,
			ModelKey:   req.ModelKey,
			APIKey:     getAPIKey(cfg, req.ProviderID),
		})
		if err != nil {
			return c.Status(500).JSON(fiber.Map{"error": err.Error()})
		}

		messages := make([]llm.Message, len(req.Messages))
		for i, m := range req.Messages {
			messages[i] = llm.Message{
				Role:    llm.Role(m.Role),
				Content: m.Content,
			}
		}

		// NOTE: no `defer cancel()` here. The streaming branch hands ctx to a
		// body-stream writer that fasthttp invokes AFTER this handler returns;
		// a handler-scoped defer would cancel the context and abort the LLM
		// stream mid-flight. Each branch below is responsible for cancelling.
		ctx, cancel := context.WithTimeout(context.Background(), cfg.LLMTimeout)

		if req.Options.Stream {
			stream, err := client.StreamText(ctx, llm.StreamRequest{
				Messages: messages,
				Options: llm.StreamOptions{
					MaxTokens:   req.Options.MaxTokens,
					Temperature: req.Options.Temperature,
				},
			})
			if err != nil {
				cancel()
				return c.Status(500).JSON(fiber.Map{"error": err.Error()})
			}
			c.Set("Content-Type", "application/x-ndjson")
			c.Set("Cache-Control", "no-cache")
			// The writer below runs after the handler returns.
			// NOTE(review): streamed responses are not recorded by usageRepo,
			// unlike the non-streaming branch — confirm whether intentional.
			c.Context().SetBodyStreamWriter(func(w *bufio.Writer) {
				defer cancel() // release the timeout once streaming finishes
				writer := ndjson.NewWriter(w)
				for chunk := range stream {
					writer.Write(fiber.Map{
						"type":  "chunk",
						"chunk": chunk.ContentChunk,
					})
					if err := w.Flush(); err != nil {
						// Client went away: cancel the upstream request and
						// drain the channel so the producer can't block.
						cancel()
						for range stream {
						}
						return
					}
				}
				writer.Write(fiber.Map{"type": "done"})
			})
			return nil
		}

		defer cancel()
		response, err := client.GenerateText(ctx, llm.StreamRequest{
			Messages: messages,
			Options: llm.StreamOptions{
				MaxTokens:   req.Options.MaxTokens,
				Temperature: req.Options.Temperature,
			},
		})
		if err != nil {
			return c.Status(500).JSON(fiber.Map{"error": err.Error()})
		}
		if usageRepo != nil {
			// Fire-and-forget, best-effort accounting; len/4 is a rough
			// chars-per-token estimate of the output size.
			go usageRepo.IncrementLLMUsage(context.Background(), userID, tier, len(response)/4)
		}
		return c.JSON(fiber.Map{
			"content": response,
		})
	})

	llmAPI.Post("/embed", func(c *fiber.Ctx) error {
		return c.Status(501).JSON(fiber.Map{"error": "Not implemented"})
	})

	port := cfg.LLMSvcPort
	log.Printf("llm-svc listening on :%d", port)
	log.Fatal(app.Listen(fmt.Sprintf(":%d", port)))
}

// getAPIKey maps a request's provider ID to the matching configured API key.
// Unknown providers yield "" and will fail downstream in llm.NewClient.
func getAPIKey(cfg *config.Config, providerID string) string {
	switch providerID {
	case "openai", "timeweb": // "timeweb" is an OpenAI-compatible provider
		return cfg.OpenAIAPIKey
	case "anthropic":
		return cfg.AnthropicAPIKey
	case "gemini", "google":
		return cfg.GeminiAPIKey
	default:
		return ""
	}
}

// init supplies a default PORT before config.Load runs.
// NOTE(review): PORT is not read anywhere in this file (the listener uses
// cfg.LLMSvcPort) — presumably config.Load or the hosting platform consumes
// it; confirm, otherwise this init can be removed.
func init() {
	if os.Getenv("PORT") == "" {
		os.Setenv("PORT", "3020")
	}
}