gooseek/backend/cmd/llm-svc/main.go
home 7a40ff629e
feat: LLM routing by tier (free→Ollama, pro→Timeweb)
- Add tier-based provider routing in llm-svc
  - free tier → Ollama (local qwen3.5:9b)
  - pro/business → Timeweb Cloud AI
- Add /api/v1/embed endpoint for embeddings via Ollama
- Update Ollama client: qwen3.5:9b default, remove auth
- Add GenerateEmbedding() function for qwen3-embedding:0.6b
- Add Ollama K8s deployment with GPU support (RTX 4060 Ti)
- Add monitoring stack (Prometheus, Grafana, Alertmanager)
- Add Grafana dashboards for LLM and security metrics
- Update deploy.sh with monitoring and Ollama deployment

Made-with: Cursor
2026-03-03 02:25:22 +03:00


package main

import (
	"bufio"
	"context"
	"database/sql"
	"fmt"
	"log"
	"os"
	"time"

	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/middleware/cors"
	"github.com/gofiber/fiber/v2/middleware/logger"

	"github.com/gooseek/backend/internal/llm"
	"github.com/gooseek/backend/internal/usage"
	"github.com/gooseek/backend/pkg/config"
	"github.com/gooseek/backend/pkg/metrics"
	"github.com/gooseek/backend/pkg/middleware"
	"github.com/gooseek/backend/pkg/ndjson"

	_ "github.com/lib/pq" // Postgres driver, registered for database/sql
)
type GenerateRequest struct {
	ProviderID string `json:"providerId"`
	ModelKey   string `json:"key"`
	Messages   []struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	} `json:"messages"`
	Options struct {
		MaxTokens   int     `json:"maxTokens"`
		Temperature float64 `json:"temperature"`
		Stream      bool    `json:"stream"`
	} `json:"options"`
}
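
// An illustrative request body for /api/v1/generate (field values are
// examples, not defaults):
//
//	{"providerId": "auto", "key": "gpt-4o-mini",
//	 "messages": [{"role": "user", "content": "Hello"}],
//	 "options": {"maxTokens": 512, "temperature": 0.7, "stream": true}}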
type EmbedRequest struct {
	Input string `json:"input"`
	Model string `json:"model,omitempty"`
}

type ProviderRouting struct {
	ProviderID string
	ModelKey   string
}
// resolveProvider picks the LLM provider and model for a request.
// Free-tier users are always routed to the local Ollama instance; paid
// tiers get their explicitly requested provider, then Timeweb, then
// OpenAI, and finally fall back to Ollama.
func resolveProvider(cfg *config.Config, tier string, requestedProvider string, requestedModel string) ProviderRouting {
	if tier == "free" || tier == "" {
		return ProviderRouting{
			ProviderID: "ollama",
			ModelKey:   cfg.OllamaModelKey,
		}
	}
	if requestedProvider != "" && requestedProvider != "auto" {
		return ProviderRouting{
			ProviderID: requestedProvider,
			ModelKey:   requestedModel,
		}
	}
	if cfg.TimewebAgentAccessID != "" && cfg.TimewebAPIKey != "" {
		return ProviderRouting{
			ProviderID: "timeweb",
			ModelKey:   requestedModel,
		}
	}
	if cfg.OpenAIAPIKey != "" {
		return ProviderRouting{
			ProviderID: "openai",
			ModelKey:   "gpt-4o-mini",
		}
	}
	return ProviderRouting{
		ProviderID: "ollama",
		ModelKey:   cfg.OllamaModelKey,
	}
}
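
// Illustrative routing outcomes (assuming Timeweb credentials are set):
//
//	tier=free, provider=openai    -> ollama / cfg.OllamaModelKey (free is pinned to Ollama)
//	tier=pro,  provider=anthropic -> anthropic / requested model
//	tier=pro,  provider=auto      -> timeweb / requested model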
func main() {
	cfg, err := config.Load()
	if err != nil {
		log.Fatal("Failed to load config:", err)
	}

	var usageRepo *usage.Repository
	if cfg.DatabaseURL != "" {
		db, err := sql.Open("postgres", cfg.DatabaseURL)
		if err != nil {
			log.Printf("Usage tracking unavailable: %v", err)
		} else {
			db.SetMaxOpenConns(5)
			db.SetMaxIdleConns(2)
			defer db.Close()
			usageRepo = usage.NewRepository(db)
			ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
			if err := usageRepo.RunMigrations(ctx); err != nil {
				log.Printf("Usage migrations warning: %v", err)
			}
			cancel()
			log.Println("Usage tracking enabled")
		}
	}

	app := fiber.New(fiber.Config{
		StreamRequestBody: true,
		BodyLimit:         10 * 1024 * 1024,
		ReadTimeout:       time.Minute,
		WriteTimeout:      5 * time.Minute,
		IdleTimeout:       2 * time.Minute,
	})
	app.Use(logger.New())
	app.Use(cors.New())
	app.Use(metrics.PrometheusMiddleware(metrics.MetricsConfig{
		ServiceName: "llm-svc",
	}))
app.Get("/health", func(c *fiber.Ctx) error {
return c.JSON(fiber.Map{"status": "ok"})
})
app.Get("/ready", func(c *fiber.Ctx) error {
return c.JSON(fiber.Map{"status": "ready"})
})
app.Get("/metrics", metrics.MetricsHandler())
app.Get("/api/v1/providers", func(c *fiber.Ctx) error {
providers := []fiber.Map{}
providers = append(providers, fiber.Map{
"id": "ollama",
"name": "GooSeek AI (Бесплатно)",
"models": []string{cfg.OllamaModelKey},
"tier": "free",
"isLocal": true,
})
if cfg.TimewebAgentAccessID != "" && cfg.TimewebAPIKey != "" {
providers = append(providers, fiber.Map{
"id": "timeweb",
"name": "Timeweb Cloud AI (Pro)",
"models": []string{"gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet", "gemini-1.5-pro"},
"tier": "pro",
})
}
if cfg.OpenAIAPIKey != "" {
providers = append(providers, fiber.Map{
"id": "openai",
"name": "OpenAI",
"models": []string{"gpt-4o", "gpt-4o-mini", "gpt-4-turbo"},
"tier": "pro",
})
}
if cfg.AnthropicAPIKey != "" {
providers = append(providers, fiber.Map{
"id": "anthropic",
"name": "Anthropic",
"models": []string{"claude-3-5-sonnet-20241022", "claude-3-opus-20240229"},
"tier": "pro",
})
}
if cfg.GeminiAPIKey != "" {
providers = append(providers, fiber.Map{
"id": "gemini",
"name": "Google Gemini",
"models": []string{"gemini-1.5-pro", "gemini-1.5-flash", "gemini-2.0-flash-exp"},
"tier": "pro",
})
}
return c.JSON(fiber.Map{
"providers": providers,
"envOnlyMode": true,
})
})
app.Get("/api/v1/providers/ui-config", func(c *fiber.Ctx) error {
return c.JSON(fiber.Map{
"sections": []interface{}{},
})
})
llmAPI := app.Group("/api/v1", middleware.JWT(middleware.JWTConfig{
Secret: cfg.JWTSecret,
AuthSvcURL: cfg.AuthSvcURL,
AllowGuest: false,
}), middleware.LLMLimits(middleware.LLMLimitsConfig{
UsageRepo: usageRepo,
}))
llmAPI.Post("/generate", func(c *fiber.Ctx) error {
startTime := time.Now()
userID := middleware.GetUserID(c)
tier := middleware.GetUserTier(c)
clientIP := c.IP()
if tier == "" {
tier = "free"
}
var req GenerateRequest
if err := c.BodyParser(&req); err != nil {
metrics.RecordLLMError(req.ProviderID, "invalid_request")
return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
}
if len(req.Messages) == 0 {
metrics.RecordLLMError(req.ProviderID, "empty_messages")
return c.Status(400).JSON(fiber.Map{"error": "Messages required"})
}
limits := usage.GetLimits(tier)
if req.Options.MaxTokens == 0 || req.Options.MaxTokens > limits.MaxTokensPerReq {
if tier == "free" && req.Options.MaxTokens > limits.MaxTokensPerReq {
metrics.RecordFreeTierLimitExceeded(userID, "max_tokens")
}
req.Options.MaxTokens = limits.MaxTokensPerReq
}
routing := resolveProvider(cfg, tier, req.ProviderID, req.ModelKey)
providerID := routing.ProviderID
modelKey := routing.ModelKey
metrics.RecordLLMRequest(providerID, modelKey, tier, userID)
client, err := llm.NewClient(llm.ProviderConfig{
ProviderID: providerID,
ModelKey: modelKey,
APIKey: getAPIKey(cfg, providerID),
BaseURL: getBaseURL(cfg, providerID),
AgentAccessID: cfg.TimewebAgentAccessID,
})
if err != nil {
metrics.RecordLLMError(req.ProviderID, "client_init_error")
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
}
messages := make([]llm.Message, len(req.Messages))
for i, m := range req.Messages {
messages[i] = llm.Message{
Role: llm.Role(m.Role),
Content: m.Content,
}
}
ctx, cancel := context.WithTimeout(context.Background(), cfg.LLMTimeout)
defer cancel()
		if req.Options.Stream {
			stream, err := client.StreamText(ctx, llm.StreamRequest{
				Messages: messages,
				Options: llm.StreamOptions{
					MaxTokens:   req.Options.MaxTokens,
					Temperature: req.Options.Temperature,
				},
			})
			if err != nil {
				metrics.RecordLLMError(providerID, "stream_error")
				metrics.RecordSecurityEvent("llm_error", clientIP, userID)
				return c.Status(500).JSON(fiber.Map{"error": err.Error()})
			}
			c.Set("Content-Type", "application/x-ndjson")
			c.Set("Cache-Control", "no-cache")
			c.Context().SetBodyStreamWriter(func(w *bufio.Writer) {
				writer := ndjson.NewWriter(w)
				tokenCount := 0
				for chunk := range stream {
					writer.Write(fiber.Map{
						"type":  "chunk",
						"chunk": chunk.ContentChunk,
					})
					w.Flush()
					// Rough token estimate: ~4 characters per token.
					tokenCount += len(chunk.ContentChunk) / 4
				}
				writer.Write(fiber.Map{"type": "done"})
				metrics.RecordLLMLatency(providerID, modelKey, time.Since(startTime))
				metrics.RecordLLMTokens(providerID, tier, userID, tokenCount)
				if usageRepo != nil {
					go usageRepo.IncrementLLMUsage(context.Background(), userID, tier, tokenCount)
				}
			})
			return nil
		}
		response, err := client.GenerateText(ctx, llm.StreamRequest{
			Messages: messages,
			Options: llm.StreamOptions{
				MaxTokens:   req.Options.MaxTokens,
				Temperature: req.Options.Temperature,
			},
		})
		if err != nil {
			metrics.RecordLLMError(providerID, "generate_error")
			return c.Status(500).JSON(fiber.Map{"error": err.Error()})
		}
		// Rough token estimate: ~4 characters per token.
		tokenCount := len(response) / 4
		metrics.RecordLLMLatency(providerID, modelKey, time.Since(startTime))
		metrics.RecordLLMTokens(providerID, tier, userID, tokenCount)
		if usageRepo != nil {
			go usageRepo.IncrementLLMUsage(context.Background(), userID, tier, tokenCount)
		}
		return c.JSON(fiber.Map{
			"content": response,
		})
	})
llmAPI.Post("/embed", func(c *fiber.Ctx) error {
userID := middleware.GetUserID(c)
tier := middleware.GetUserTier(c)
if tier == "" {
tier = "free"
}
var req EmbedRequest
if err := c.BodyParser(&req); err != nil {
return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
}
if req.Input == "" {
return c.Status(400).JSON(fiber.Map{"error": "Input text required"})
}
model := req.Model
if model == "" {
model = cfg.OllamaEmbeddingModel
}
embeddings, err := llm.GenerateEmbedding(cfg.OllamaBaseURL, model, req.Input)
if err != nil {
metrics.RecordLLMError("ollama", "embed_error")
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
}
metrics.RecordLLMRequest("ollama", model, tier, userID)
return c.JSON(fiber.Map{
"embedding": embeddings,
"model": model,
})
})
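
	// An illustrative /api/v1/embed exchange (the vector values are made up):
	//
	//	request:  {"input": "hello world"}
	//	response: {"embedding": [0.013, -0.094, ...], "model": "qwen3-embedding:0.6b"}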
	port := cfg.LLMSvcPort
	log.Printf("llm-svc listening on :%d", port)
	log.Fatal(app.Listen(fmt.Sprintf(":%d", port)))
}
func getAPIKey(cfg *config.Config, providerID string) string {
	switch providerID {
	case "openai":
		return cfg.OpenAIAPIKey
	case "timeweb":
		return cfg.TimewebAPIKey
	case "anthropic":
		return cfg.AnthropicAPIKey
	case "gemini", "google":
		return cfg.GeminiAPIKey
	default:
		return ""
	}
}

func getBaseURL(cfg *config.Config, providerID string) string {
	switch providerID {
	case "timeweb":
		return cfg.TimewebAPIBaseURL
	case "ollama":
		return cfg.OllamaBaseURL
	default:
		return ""
	}
}
// init runs before main and defaults the PORT environment variable so the
// service can start without explicit configuration.
func init() {
	if os.Getenv("PORT") == "" {
		os.Setenv("PORT", "3020")
	}
}
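
For reference, a minimal sketch of a client for the streaming endpoint above. It assumes the service is reachable at localhost:3020 (the default set in init) and that a valid JWT is at hand; the token value and host are placeholders, not anything guaranteed by this file.

package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

func main() {
	// Build a streaming /api/v1/generate request.
	body, _ := json.Marshal(map[string]any{
		"providerId": "auto",
		"messages": []map[string]string{
			{"role": "user", "content": "Hello"},
		},
		"options": map[string]any{"maxTokens": 256, "stream": true},
	})
	req, err := http.NewRequest("POST", "http://localhost:3020/api/v1/generate", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer TOKEN") // placeholder JWT

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	// Read the NDJSON stream line by line until the {"type":"done"} marker.
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		var msg struct {
			Type  string `json:"type"`
			Chunk string `json:"chunk"`
		}
		if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
			log.Fatal(err)
		}
		if msg.Type == "done" {
			break
		}
		fmt.Print(msg.Chunk)
	}
}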