// Package main implements llm-svc: an HTTP gateway that routes chat/embedding
// requests to one of several LLM providers (Ollama, Timeweb, OpenAI, Anthropic,
// Gemini) based on the caller's subscription tier, with usage tracking,
// Prometheus metrics, and NDJSON streaming support.
package main

import (
	"bufio"
	"context"
	"database/sql"
	"fmt"
	"log"
	"os"
	"time"

	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/middleware/cors"
	"github.com/gofiber/fiber/v2/middleware/logger"
	"github.com/gooseek/backend/internal/llm"
	"github.com/gooseek/backend/internal/usage"
	"github.com/gooseek/backend/pkg/config"
	"github.com/gooseek/backend/pkg/email"
	"github.com/gooseek/backend/pkg/metrics"
	"github.com/gooseek/backend/pkg/middleware"
	"github.com/gooseek/backend/pkg/ndjson"

	_ "github.com/lib/pq" // registers the "postgres" database/sql driver
)

// GenerateRequest is the JSON body accepted by POST /api/v1/generate.
type GenerateRequest struct {
	ProviderID string `json:"providerId"` // requested provider; "" or "auto" lets the server choose
	ModelKey   string `json:"key"`
	Messages   []struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	} `json:"messages"`
	Options struct {
		MaxTokens   int     `json:"maxTokens"`
		Temperature float64 `json:"temperature"`
		Stream      bool    `json:"stream"` // true: NDJSON chunk stream; false: single JSON response
	} `json:"options"`
}

// EmbedRequest is the JSON body accepted by POST /api/v1/embed.
type EmbedRequest struct {
	Input string `json:"input"`
	Model string `json:"model,omitempty"` // optional; defaults to cfg.OllamaEmbeddingModel
}

// ProviderRouting is the result of tier-based provider resolution:
// which backend to call and with which model.
type ProviderRouting struct {
	ProviderID string
	ModelKey   string
}

// resolveProvider picks the provider/model to serve a request.
//
// Rules, in order:
//  1. Free (or unknown) tier is always pinned to the local Ollama model.
//  2. A paid tier with an explicit provider (not "auto") gets exactly what it asked for.
//  3. Otherwise prefer Timeweb when configured, then OpenAI (gpt-4o-mini),
//     and finally fall back to Ollama.
func resolveProvider(cfg *config.Config, tier string, requestedProvider string, requestedModel string) ProviderRouting {
	if tier == "free" || tier == "" {
		return ProviderRouting{
			ProviderID: "ollama",
			ModelKey:   cfg.OllamaModelKey,
		}
	}
	if requestedProvider != "" && requestedProvider != "auto" {
		return ProviderRouting{
			ProviderID: requestedProvider,
			ModelKey:   requestedModel,
		}
	}
	if cfg.TimewebAgentAccessID != "" && cfg.TimewebAPIKey != "" {
		return ProviderRouting{
			ProviderID: "timeweb",
			ModelKey:   requestedModel,
		}
	}
	if cfg.OpenAIAPIKey != "" {
		return ProviderRouting{
			ProviderID: "openai",
			ModelKey:   "gpt-4o-mini",
		}
	}
	return ProviderRouting{
		ProviderID: "ollama",
		ModelKey:   cfg.OllamaModelKey,
	}
}

func main() {
	cfg, err := config.Load()
	if err != nil {
		log.Fatal("Failed to load config:", err)
	}

	// Usage tracking is optional: without DATABASE_URL the service still runs,
	// it just skips per-user accounting (usageRepo stays nil and is checked below).
	var usageRepo *usage.Repository
	if cfg.DatabaseURL != "" {
		db, err := sql.Open("postgres", cfg.DatabaseURL)
		if err != nil {
			log.Printf("Usage tracking unavailable: %v", err)
		} else {
			db.SetMaxOpenConns(5)
			db.SetMaxIdleConns(2)
			defer db.Close()
			usageRepo = usage.NewRepository(db)
			// Best-effort migrations: a failure is logged but does not abort startup.
			ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
			if err := usageRepo.RunMigrations(ctx); err != nil {
				log.Printf("Usage migrations warning: %v", err)
			}
			cancel()
			log.Println("Usage tracking enabled")
		}
	}

	app := fiber.New(fiber.Config{
		StreamRequestBody: true,
		BodyLimit:         10 * 1024 * 1024,
		ReadTimeout:       time.Minute,
		WriteTimeout:      5 * time.Minute, // generous: covers long LLM streams
		IdleTimeout:       2 * time.Minute,
	})
	app.Use(logger.New())
	app.Use(cors.New())
	app.Use(metrics.PrometheusMiddleware(metrics.MetricsConfig{
		ServiceName: "llm-svc",
	}))

	app.Get("/health", func(c *fiber.Ctx) error {
		return c.JSON(fiber.Map{"status": "ok"})
	})
	app.Get("/ready", func(c *fiber.Ctx) error {
		return c.JSON(fiber.Map{"status": "ready"})
	})
	app.Get("/metrics", metrics.MetricsHandler())

	// Provider catalog: Ollama is always advertised; paid providers appear
	// only when their credentials are present in the environment.
	app.Get("/api/v1/providers", func(c *fiber.Ctx) error {
		providers := []fiber.Map{}
		providers = append(providers, fiber.Map{
			"id":      "ollama",
			"name":    "GooSeek AI (Бесплатно)",
			"models":  []string{cfg.OllamaModelKey},
			"tier":    "free",
			"isLocal": true,
		})
		if cfg.TimewebAgentAccessID != "" && cfg.TimewebAPIKey != "" {
			providers = append(providers, fiber.Map{
				"id":     "timeweb",
				"name":   "Timeweb Cloud AI (Pro)",
				"models": []string{"gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet", "gemini-1.5-pro"},
				"tier":   "pro",
			})
		}
		if cfg.OpenAIAPIKey != "" {
			providers = append(providers, fiber.Map{
				"id":     "openai",
				"name":   "OpenAI",
				"models": []string{"gpt-4o", "gpt-4o-mini", "gpt-4-turbo"},
				"tier":   "pro",
			})
		}
		if cfg.AnthropicAPIKey != "" {
			providers = append(providers, fiber.Map{
				"id":     "anthropic",
				"name":   "Anthropic",
				"models": []string{"claude-3-5-sonnet-20241022", "claude-3-opus-20240229"},
				"tier":   "pro",
			})
		}
		if cfg.GeminiAPIKey != "" {
			providers = append(providers, fiber.Map{
				"id":     "gemini",
				"name":   "Google Gemini",
				"models": []string{"gemini-1.5-pro", "gemini-1.5-flash", "gemini-2.0-flash-exp"},
				"tier":   "pro",
			})
		}
		return c.JSON(fiber.Map{
			"providers":   providers,
			"envOnlyMode": true,
		})
	})

	app.Get("/api/v1/providers/ui-config", func(c *fiber.Ctx) error {
		return c.JSON(fiber.Map{
			"sections": []interface{}{},
		})
	})

	emailSender := email.NewSender(email.SMTPConfig{
		Host:     cfg.SMTPHost,
		Port:     cfg.SMTPPort,
		User:     cfg.SMTPUser,
		Password: cfg.SMTPPassword,
		From:     cfg.SMTPFrom,
		TLS:      cfg.SMTPTLS,
		SiteURL:  cfg.SiteURL,
		SiteName: cfg.SiteName,
	})

	// Authenticated, rate-limited group for the actual LLM endpoints.
	llmAPI := app.Group("/api/v1", middleware.JWT(middleware.JWTConfig{
		Secret:     cfg.JWTSecret,
		AuthSvcURL: cfg.AuthSvcURL,
		AllowGuest: false,
	}), middleware.LLMLimits(middleware.LLMLimitsConfig{
		UsageRepo:   usageRepo,
		EmailSender: emailSender,
	}))

	llmAPI.Post("/generate", func(c *fiber.Ctx) error {
		startTime := time.Now()
		userID := middleware.GetUserID(c)
		tier := middleware.GetUserTier(c)
		clientIP := c.IP()
		if tier == "" {
			tier = "free"
		}

		var req GenerateRequest
		if err := c.BodyParser(&req); err != nil {
			metrics.RecordLLMError(req.ProviderID, "invalid_request")
			return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
		}
		if len(req.Messages) == 0 {
			metrics.RecordLLMError(req.ProviderID, "empty_messages")
			return c.Status(400).JSON(fiber.Map{"error": "Messages required"})
		}

		// Clamp the token budget to the caller's tier limit.
		limits := usage.GetLimits(tier)
		if req.Options.MaxTokens == 0 || req.Options.MaxTokens > limits.MaxTokensPerReq {
			if tier == "free" && req.Options.MaxTokens > limits.MaxTokensPerReq {
				metrics.RecordFreeTierLimitExceeded(userID, "max_tokens")
			}
			req.Options.MaxTokens = limits.MaxTokensPerReq
		}

		routing := resolveProvider(cfg, tier, req.ProviderID, req.ModelKey)
		providerID := routing.ProviderID
		modelKey := routing.ModelKey
		metrics.RecordLLMRequest(providerID, modelKey, tier, userID)

		client, err := llm.NewClient(llm.ProviderConfig{
			ProviderID:    providerID,
			ModelKey:      modelKey,
			APIKey:        getAPIKey(cfg, providerID),
			BaseURL:       getBaseURL(cfg, providerID),
			AgentAccessID: cfg.TimewebAgentAccessID,
		})
		if err != nil {
			// FIX: record against the resolved provider (providerID), not the
			// raw requested one, consistent with every other post-routing metric.
			metrics.RecordLLMError(providerID, "client_init_error")
			return c.Status(500).JSON(fiber.Map{"error": err.Error()})
		}

		messages := make([]llm.Message, len(req.Messages))
		for i, m := range req.Messages {
			messages[i] = llm.Message{
				Role:    llm.Role(m.Role),
				Content: m.Content,
			}
		}

		// NOTE: no `defer cancel()` here. The streaming body writer below runs
		// AFTER this handler returns (fasthttp SetBodyStreamWriter semantics),
		// so a deferred cancel would kill the upstream request mid-stream.
		// Each exit path owns its own cancel instead.
		ctx, cancel := context.WithTimeout(context.Background(), cfg.LLMTimeout)

		if req.Options.Stream {
			stream, err := client.StreamText(ctx, llm.StreamRequest{
				Messages: messages,
				Options: llm.StreamOptions{
					MaxTokens:   req.Options.MaxTokens,
					Temperature: req.Options.Temperature,
				},
			})
			if err != nil {
				cancel()
				metrics.RecordLLMError(providerID, "stream_error")
				metrics.RecordSecurityEvent("llm_error", clientIP, userID)
				return c.Status(500).JSON(fiber.Map{"error": err.Error()})
			}
			c.Set("Content-Type", "application/x-ndjson")
			c.Set("Cache-Control", "no-cache")
			c.Context().SetBodyStreamWriter(func(w *bufio.Writer) {
				defer cancel() // release the upstream context once streaming finishes
				writer := ndjson.NewWriter(w)
				tokenCount := 0
				for chunk := range stream {
					writer.Write(fiber.Map{
						"type":  "chunk",
						"chunk": chunk.ContentChunk,
					})
					w.Flush()
					// Rough token estimate: ~4 characters per token.
					tokenCount += len(chunk.ContentChunk) / 4
				}
				writer.Write(fiber.Map{"type": "done"})
				metrics.RecordLLMLatency(providerID, modelKey, time.Since(startTime))
				metrics.RecordLLMTokens(providerID, tier, userID, tokenCount)
				if usageRepo != nil {
					// Fire-and-forget accounting; a fresh context since the
					// request context is being torn down.
					go usageRepo.IncrementLLMUsage(context.Background(), userID, tier, tokenCount)
				}
			})
			return nil
		}
		defer cancel()

		response, err := client.GenerateText(ctx, llm.StreamRequest{
			Messages: messages,
			Options: llm.StreamOptions{
				MaxTokens:   req.Options.MaxTokens,
				Temperature: req.Options.Temperature,
			},
		})
		if err != nil {
			metrics.RecordLLMError(providerID, "generate_error")
			return c.Status(500).JSON(fiber.Map{"error": err.Error()})
		}

		// Rough token estimate: ~4 characters per token.
		tokenCount := len(response) / 4
		metrics.RecordLLMLatency(providerID, modelKey, time.Since(startTime))
		metrics.RecordLLMTokens(providerID, tier, userID, tokenCount)
		if usageRepo != nil {
			go usageRepo.IncrementLLMUsage(context.Background(), userID, tier, tokenCount)
		}
		return c.JSON(fiber.Map{
			"content": response,
		})
	})

	llmAPI.Post("/embed", func(c *fiber.Ctx) error {
		userID := middleware.GetUserID(c)
		tier := middleware.GetUserTier(c)
		if tier == "" {
			tier = "free"
		}

		var req EmbedRequest
		if err := c.BodyParser(&req); err != nil {
			return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
		}
		if req.Input == "" {
			return c.Status(400).JSON(fiber.Map{"error": "Input text required"})
		}

		model := req.Model
		if model == "" {
			model = cfg.OllamaEmbeddingModel
		}

		// Embeddings are always served by the local Ollama backend.
		embeddings, err := llm.GenerateEmbedding(cfg.OllamaBaseURL, model, req.Input)
		if err != nil {
			metrics.RecordLLMError("ollama", "embed_error")
			return c.Status(500).JSON(fiber.Map{"error": err.Error()})
		}
		metrics.RecordLLMRequest("ollama", model, tier, userID)
		return c.JSON(fiber.Map{
			"embedding": embeddings,
			"model":     model,
		})
	})

	port := cfg.LLMSvcPort
	log.Printf("llm-svc listening on :%d", port)
	log.Fatal(app.Listen(fmt.Sprintf(":%d", port)))
}

// getAPIKey returns the configured API key for providerID,
// or "" for providers that need none (e.g. local Ollama).
func getAPIKey(cfg *config.Config, providerID string) string {
	switch providerID {
	case "openai":
		return cfg.OpenAIAPIKey
	case "timeweb":
		return cfg.TimewebAPIKey
	case "anthropic":
		return cfg.AnthropicAPIKey
	case "gemini", "google":
		return cfg.GeminiAPIKey
	default:
		return ""
	}
}

// getBaseURL returns the configured base URL for providerID;
// "" means the client should use the provider's default endpoint.
func getBaseURL(cfg *config.Config, providerID string) string {
	switch providerID {
	case "timeweb":
		return cfg.TimewebAPIBaseURL
	case "ollama":
		return cfg.OllamaBaseURL
	default:
		return ""
	}
}

// init defaults PORT to 3020 before config.Load runs.
// NOTE(review): env-mutating init() is an anti-pattern — presumably
// config.Load reads PORT; consider moving this default into config instead.
func init() {
	if os.Getenv("PORT") == "" {
		os.Setenv("PORT", "3020")
	}
}