feat: auth service + security audit fixes + cleanup legacy services
Major changes:
- Add auth-svc: JWT auth, register/login/refresh, password reset
- Add auth UI: modals, pages (/login, /register, /forgot-password)
- Add usage tracking (usage_metrics table, daily limits)
- Add tiered rate limiting (free/pro/business)
- Add LLM usage limits per tier
Security fixes:
- All repos now require userID for Update/Delete operations
- JWT middleware in chat-svc, llm-svc, agent-svc, discover-svc
- ErrNotFound/ErrForbidden errors for proper access control
Cleanup:
- Remove legacy TypeScript services/ directory
- Remove computer-svc (to be reimplemented)
- Remove old deploy/docker configs
New files:
- backend/cmd/auth-svc/main.go
- backend/internal/auth/{types,repository}.go
- backend/internal/usage/{types,repository}.go
- backend/pkg/middleware/{llm_limits,ratelimit_tiered}.go
- backend/webui/src/components/auth/*
- backend/webui/src/app/(auth)/*
Made-with: Cursor
This commit is contained in:
440
backend/cmd/admin-svc/main.go
Normal file
440
backend/cmd/admin-svc/main.go
Normal file
@@ -0,0 +1,440 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||
"github.com/gofiber/fiber/v2/middleware/logger"
|
||||
"github.com/gooseek/backend/internal/admin"
|
||||
"github.com/gooseek/backend/internal/db"
|
||||
"github.com/gooseek/backend/pkg/config"
|
||||
"github.com/gooseek/backend/pkg/middleware"
|
||||
)
|
||||
|
||||
func main() {
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
log.Fatal("Failed to load config:", err)
|
||||
}
|
||||
|
||||
var database *db.PostgresDB
|
||||
if cfg.DatabaseURL != "" {
|
||||
database, err = db.NewPostgresDB(cfg.DatabaseURL)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to connect to database:", err)
|
||||
}
|
||||
defer database.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
if err := database.RunMigrations(ctx); err != nil {
|
||||
log.Printf("Migration warning: %v", err)
|
||||
}
|
||||
if err := admin.RunAdminMigrations(ctx, database.DB()); err != nil {
|
||||
log.Printf("Admin migrations warning: %v", err)
|
||||
}
|
||||
cancel()
|
||||
log.Println("PostgreSQL connected")
|
||||
} else {
|
||||
log.Fatal("DATABASE_URL is required for admin-svc")
|
||||
}
|
||||
|
||||
userRepo := admin.NewUserRepository(database.DB())
|
||||
postRepo := admin.NewPostRepository(database.DB())
|
||||
settingsRepo := admin.NewSettingsRepository(database.DB())
|
||||
discoverRepo := admin.NewDiscoverConfigRepository(database.DB())
|
||||
auditRepo := admin.NewAuditRepository(database.DB())
|
||||
|
||||
app := fiber.New(fiber.Config{
|
||||
BodyLimit: 50 * 1024 * 1024,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
IdleTimeout: 60 * time.Second,
|
||||
})
|
||||
|
||||
app.Use(logger.New())
|
||||
app.Use(cors.New())
|
||||
|
||||
app.Get("/health", func(c *fiber.Ctx) error {
|
||||
return c.JSON(fiber.Map{"status": "ok"})
|
||||
})
|
||||
|
||||
app.Get("/ready", func(c *fiber.Ctx) error {
|
||||
return c.JSON(fiber.Map{"status": "ready"})
|
||||
})
|
||||
|
||||
api := app.Group("/api/v1/admin")
|
||||
|
||||
api.Use(middleware.JWT(middleware.JWTConfig{
|
||||
Secret: cfg.JWTSecret,
|
||||
AuthSvcURL: cfg.AuthSvcURL,
|
||||
AllowGuest: false,
|
||||
}))
|
||||
api.Use(middleware.RequireRole("admin"))
|
||||
|
||||
api.Get("/dashboard", func(c *fiber.Ctx) error {
|
||||
stats, err := getDashboardStats(c.Context(), userRepo, postRepo)
|
||||
if err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(stats)
|
||||
})
|
||||
|
||||
usersGroup := api.Group("/users")
|
||||
{
|
||||
usersGroup.Get("/", func(c *fiber.Ctx) error {
|
||||
page := c.QueryInt("page", 1)
|
||||
perPage := c.QueryInt("perPage", 20)
|
||||
search := c.Query("search")
|
||||
|
||||
users, total, err := userRepo.List(c.Context(), page, perPage, search)
|
||||
if err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
return c.JSON(admin.UserListResponse{
|
||||
Users: users,
|
||||
Total: total,
|
||||
Page: page,
|
||||
PerPage: perPage,
|
||||
})
|
||||
})
|
||||
|
||||
usersGroup.Get("/:id", func(c *fiber.Ctx) error {
|
||||
user, err := userRepo.GetByID(c.Context(), c.Params("id"))
|
||||
if err != nil {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "User not found"})
|
||||
}
|
||||
return c.JSON(user)
|
||||
})
|
||||
|
||||
usersGroup.Post("/", func(c *fiber.Ctx) error {
|
||||
var req admin.UserCreateRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
user, err := userRepo.Create(c.Context(), &req)
|
||||
if err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
logAudit(c, auditRepo, "create", "user", user.ID)
|
||||
return c.Status(201).JSON(user)
|
||||
})
|
||||
|
||||
usersGroup.Patch("/:id", func(c *fiber.Ctx) error {
|
||||
var req admin.UserUpdateRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
user, err := userRepo.Update(c.Context(), c.Params("id"), &req)
|
||||
if err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
logAudit(c, auditRepo, "update", "user", user.ID)
|
||||
return c.JSON(user)
|
||||
})
|
||||
|
||||
usersGroup.Delete("/:id", func(c *fiber.Ctx) error {
|
||||
if err := userRepo.Delete(c.Context(), c.Params("id")); err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
logAudit(c, auditRepo, "delete", "user", c.Params("id"))
|
||||
return c.SendStatus(204)
|
||||
})
|
||||
}
|
||||
|
||||
postsGroup := api.Group("/posts")
|
||||
{
|
||||
postsGroup.Get("/", func(c *fiber.Ctx) error {
|
||||
page := c.QueryInt("page", 1)
|
||||
perPage := c.QueryInt("perPage", 20)
|
||||
status := c.Query("status")
|
||||
category := c.Query("category")
|
||||
|
||||
posts, total, err := postRepo.List(c.Context(), page, perPage, status, category)
|
||||
if err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
return c.JSON(admin.PostListResponse{
|
||||
Posts: posts,
|
||||
Total: total,
|
||||
Page: page,
|
||||
PerPage: perPage,
|
||||
})
|
||||
})
|
||||
|
||||
postsGroup.Get("/:id", func(c *fiber.Ctx) error {
|
||||
post, err := postRepo.GetByID(c.Context(), c.Params("id"))
|
||||
if err != nil {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "Post not found"})
|
||||
}
|
||||
return c.JSON(post)
|
||||
})
|
||||
|
||||
postsGroup.Post("/", func(c *fiber.Ctx) error {
|
||||
var req admin.PostCreateRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
authorID := middleware.GetUserID(c)
|
||||
post, err := postRepo.Create(c.Context(), authorID, &req)
|
||||
if err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
logAudit(c, auditRepo, "create", "post", post.ID)
|
||||
return c.Status(201).JSON(post)
|
||||
})
|
||||
|
||||
postsGroup.Patch("/:id", func(c *fiber.Ctx) error {
|
||||
var req admin.PostUpdateRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
post, err := postRepo.Update(c.Context(), c.Params("id"), &req)
|
||||
if err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
logAudit(c, auditRepo, "update", "post", post.ID)
|
||||
return c.JSON(post)
|
||||
})
|
||||
|
||||
postsGroup.Delete("/:id", func(c *fiber.Ctx) error {
|
||||
if err := postRepo.Delete(c.Context(), c.Params("id")); err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
logAudit(c, auditRepo, "delete", "post", c.Params("id"))
|
||||
return c.SendStatus(204)
|
||||
})
|
||||
|
||||
postsGroup.Post("/:id/publish", func(c *fiber.Ctx) error {
|
||||
post, err := postRepo.Publish(c.Context(), c.Params("id"))
|
||||
if err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
logAudit(c, auditRepo, "publish", "post", post.ID)
|
||||
return c.JSON(post)
|
||||
})
|
||||
}
|
||||
|
||||
settingsGroup := api.Group("/settings")
|
||||
{
|
||||
settingsGroup.Get("/", func(c *fiber.Ctx) error {
|
||||
settings, err := settingsRepo.Get(c.Context())
|
||||
if err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(settings)
|
||||
})
|
||||
|
||||
settingsGroup.Patch("/", func(c *fiber.Ctx) error {
|
||||
var settings admin.PlatformSettings
|
||||
if err := c.BodyParser(&settings); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
updated, err := settingsRepo.Update(c.Context(), &settings)
|
||||
if err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
logAudit(c, auditRepo, "update", "settings", "platform")
|
||||
return c.JSON(updated)
|
||||
})
|
||||
|
||||
settingsGroup.Get("/features", func(c *fiber.Ctx) error {
|
||||
features, err := settingsRepo.GetFeatures(c.Context())
|
||||
if err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(features)
|
||||
})
|
||||
|
||||
settingsGroup.Patch("/features", func(c *fiber.Ctx) error {
|
||||
var features admin.FeatureFlags
|
||||
if err := c.BodyParser(&features); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
if err := settingsRepo.UpdateFeatures(c.Context(), &features); err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
logAudit(c, auditRepo, "update", "settings", "features")
|
||||
return c.JSON(features)
|
||||
})
|
||||
}
|
||||
|
||||
discoverGroup := api.Group("/discover")
|
||||
{
|
||||
discoverGroup.Get("/categories", func(c *fiber.Ctx) error {
|
||||
categories, err := discoverRepo.ListCategories(c.Context())
|
||||
if err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(fiber.Map{"categories": categories})
|
||||
})
|
||||
|
||||
discoverGroup.Post("/categories", func(c *fiber.Ctx) error {
|
||||
var req admin.DiscoverCategoryCreateRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
category, err := discoverRepo.CreateCategory(c.Context(), &req)
|
||||
if err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
logAudit(c, auditRepo, "create", "discover_category", category.ID)
|
||||
return c.Status(201).JSON(category)
|
||||
})
|
||||
|
||||
discoverGroup.Patch("/categories/:id", func(c *fiber.Ctx) error {
|
||||
var req admin.DiscoverCategoryUpdateRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
category, err := discoverRepo.UpdateCategory(c.Context(), c.Params("id"), &req)
|
||||
if err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
logAudit(c, auditRepo, "update", "discover_category", category.ID)
|
||||
return c.JSON(category)
|
||||
})
|
||||
|
||||
discoverGroup.Delete("/categories/:id", func(c *fiber.Ctx) error {
|
||||
if err := discoverRepo.DeleteCategory(c.Context(), c.Params("id")); err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
logAudit(c, auditRepo, "delete", "discover_category", c.Params("id"))
|
||||
return c.SendStatus(204)
|
||||
})
|
||||
|
||||
discoverGroup.Post("/categories/reorder", func(c *fiber.Ctx) error {
|
||||
var req struct {
|
||||
Order []string `json:"order"`
|
||||
}
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
if err := discoverRepo.ReorderCategories(c.Context(), req.Order); err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
logAudit(c, auditRepo, "reorder", "discover_categories", "")
|
||||
return c.SendStatus(204)
|
||||
})
|
||||
|
||||
discoverGroup.Get("/sources", func(c *fiber.Ctx) error {
|
||||
sources, err := discoverRepo.ListSources(c.Context())
|
||||
if err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(fiber.Map{"sources": sources})
|
||||
})
|
||||
|
||||
discoverGroup.Post("/sources", func(c *fiber.Ctx) error {
|
||||
var req admin.DiscoverSourceCreateRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
source, err := discoverRepo.CreateSource(c.Context(), &req)
|
||||
if err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
logAudit(c, auditRepo, "create", "discover_source", source.ID)
|
||||
return c.Status(201).JSON(source)
|
||||
})
|
||||
|
||||
discoverGroup.Delete("/sources/:id", func(c *fiber.Ctx) error {
|
||||
if err := discoverRepo.DeleteSource(c.Context(), c.Params("id")); err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
logAudit(c, auditRepo, "delete", "discover_source", c.Params("id"))
|
||||
return c.SendStatus(204)
|
||||
})
|
||||
}
|
||||
|
||||
auditGroup := api.Group("/audit")
|
||||
{
|
||||
auditGroup.Get("/", func(c *fiber.Ctx) error {
|
||||
page := c.QueryInt("page", 1)
|
||||
perPage := c.QueryInt("perPage", 50)
|
||||
action := c.Query("action")
|
||||
resource := c.Query("resource")
|
||||
|
||||
logs, total, err := auditRepo.List(c.Context(), page, perPage, action, resource)
|
||||
if err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
"logs": logs,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"perPage": perPage,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
port := config.GetEnvInt("ADMIN_SVC_PORT", 3040)
|
||||
log.Printf("admin-svc listening on :%d", port)
|
||||
log.Fatal(app.Listen(fmt.Sprintf(":%d", port)))
|
||||
}
|
||||
|
||||
func getDashboardStats(ctx context.Context, userRepo *admin.UserRepository, postRepo *admin.PostRepository) (*admin.DashboardStats, error) {
|
||||
totalUsers, _ := userRepo.Count(ctx, "")
|
||||
activeUsers, _ := userRepo.CountActive(ctx)
|
||||
totalPosts, _ := postRepo.Count(ctx, "")
|
||||
publishedPosts, _ := postRepo.Count(ctx, "published")
|
||||
|
||||
return &admin.DashboardStats{
|
||||
TotalUsers: totalUsers,
|
||||
ActiveUsers: activeUsers,
|
||||
TotalPosts: totalPosts,
|
||||
PublishedPosts: publishedPosts,
|
||||
StorageUsedMB: 0,
|
||||
StorageLimitMB: 10240,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func logAudit(c *fiber.Ctx, repo *admin.AuditRepository, action, resource, resourceID string) {
|
||||
user := middleware.GetUser(c)
|
||||
if user == nil {
|
||||
return
|
||||
}
|
||||
|
||||
log := &admin.AuditLog{
|
||||
UserID: user.UserID,
|
||||
UserEmail: user.Email,
|
||||
Action: action,
|
||||
Resource: resource,
|
||||
ResourceID: resourceID,
|
||||
IPAddress: c.IP(),
|
||||
UserAgent: c.Get("User-Agent"),
|
||||
}
|
||||
|
||||
go repo.Create(context.Background(), log)
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/gooseek/backend/internal/search"
|
||||
"github.com/gooseek/backend/internal/session"
|
||||
"github.com/gooseek/backend/pkg/config"
|
||||
"github.com/gooseek/backend/pkg/middleware"
|
||||
"github.com/gooseek/backend/pkg/ndjson"
|
||||
)
|
||||
|
||||
@@ -69,7 +70,13 @@ func main() {
|
||||
return c.JSON(fiber.Map{"status": "ok"})
|
||||
})
|
||||
|
||||
app.Post("/api/v1/agents/search", func(c *fiber.Ctx) error {
|
||||
agents := app.Group("/api/v1/agents", middleware.JWT(middleware.JWTConfig{
|
||||
Secret: cfg.JWTSecret,
|
||||
AuthSvcURL: cfg.AuthSvcURL,
|
||||
AllowGuest: false,
|
||||
}))
|
||||
|
||||
agents.Post("/search", func(c *fiber.Ctx) error {
|
||||
var req SearchRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
|
||||
@@ -182,6 +189,10 @@ func main() {
|
||||
return nil
|
||||
})
|
||||
|
||||
agents.Get("/status", func(c *fiber.Ctx) error {
|
||||
return c.JSON(fiber.Map{"status": "ready", "user": middleware.GetUserID(c)})
|
||||
})
|
||||
|
||||
port := cfg.AgentSvcPort
|
||||
log.Printf("agent-svc listening on :%d", port)
|
||||
log.Fatal(app.Listen(fmt.Sprintf(":%d", port)))
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||
"github.com/gofiber/fiber/v2/middleware/logger"
|
||||
"github.com/gooseek/backend/pkg/config"
|
||||
"github.com/gooseek/backend/pkg/middleware"
|
||||
)
|
||||
|
||||
var svcURLs map[string]string
|
||||
@@ -25,6 +26,7 @@ func main() {
|
||||
}
|
||||
|
||||
svcURLs = map[string]string{
|
||||
"auth": cfg.AuthSvcURL,
|
||||
"chat": cfg.ChatSvcURL,
|
||||
"agents": cfg.AgentSvcURL,
|
||||
"search": cfg.SearchSvcURL,
|
||||
@@ -36,7 +38,7 @@ func main() {
|
||||
"discover": cfg.DiscoverSvcURL,
|
||||
"finance": cfg.FinanceHeatmapURL,
|
||||
"learning": cfg.LearningSvcURL,
|
||||
"computer": cfg.ComputerSvcURL,
|
||||
"admin": cfg.AdminSvcURL,
|
||||
}
|
||||
|
||||
app := fiber.New(fiber.Config{
|
||||
@@ -54,6 +56,21 @@ func main() {
|
||||
AllowMethods: "GET, POST, PUT, PATCH, DELETE, OPTIONS",
|
||||
}))
|
||||
|
||||
app.Use(middleware.JWT(middleware.JWTConfig{
|
||||
Secret: cfg.JWTSecret,
|
||||
AuthSvcURL: cfg.AuthSvcURL,
|
||||
AllowGuest: true,
|
||||
}))
|
||||
|
||||
app.Use(middleware.TieredRateLimit(middleware.TieredRateLimitConfig{
|
||||
Tiers: map[string]middleware.TierConfig{
|
||||
"free": {Max: 60, Window: time.Minute},
|
||||
"pro": {Max: 300, Window: time.Minute},
|
||||
"business": {Max: 1000, Window: time.Minute},
|
||||
},
|
||||
DefaultTier: "free",
|
||||
}))
|
||||
|
||||
app.Get("/health", func(c *fiber.Ctx) error {
|
||||
return c.JSON(fiber.Map{"status": "ok"})
|
||||
})
|
||||
@@ -72,6 +89,8 @@ func main() {
|
||||
|
||||
func getTarget(path string) (base, rewrite string) {
|
||||
switch {
|
||||
case strings.HasPrefix(path, "/api/v1/auth"):
|
||||
return svcURLs["auth"], path
|
||||
case path == "/api/chat" || strings.HasPrefix(path, "/api/chat?"):
|
||||
return svcURLs["chat"], "/api/v1/chat"
|
||||
case strings.HasPrefix(path, "/api/v1/agents"):
|
||||
@@ -102,8 +121,8 @@ func getTarget(path string) (base, rewrite string) {
|
||||
return svcURLs["finance"], path
|
||||
case strings.HasPrefix(path, "/api/v1/learning"):
|
||||
return svcURLs["learning"], path
|
||||
case strings.HasPrefix(path, "/api/v1/computer"):
|
||||
return svcURLs["computer"], path
|
||||
case strings.HasPrefix(path, "/api/v1/admin"):
|
||||
return svcURLs["admin"], path
|
||||
default:
|
||||
return "", ""
|
||||
}
|
||||
@@ -195,11 +214,44 @@ func handleProxy(c *fiber.Ctx) error {
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: time.Minute}
|
||||
isSSE := strings.Contains(path, "/stream") ||
|
||||
c.Get("Accept") == "text/event-stream"
|
||||
|
||||
timeout := time.Minute
|
||||
if isSSE {
|
||||
timeout = 30 * time.Minute
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: timeout}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return handleFallback(c, path)
|
||||
}
|
||||
|
||||
if isSSE && resp.Header.Get("Content-Type") == "text/event-stream" {
|
||||
c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache")
|
||||
c.Set("Connection", "keep-alive")
|
||||
c.Set("Transfer-Encoding", "chunked")
|
||||
c.Set("X-Accel-Buffering", "no")
|
||||
|
||||
c.Context().SetBodyStreamWriter(func(w *bufio.Writer) {
|
||||
defer resp.Body.Close()
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, readErr := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
w.Write(buf[:n])
|
||||
w.Flush()
|
||||
}
|
||||
if readErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
for _, h := range []string{"Content-Type", "Cache-Control", "Set-Cookie"} {
|
||||
|
||||
460
backend/cmd/auth-svc/main.go
Normal file
460
backend/cmd/auth-svc/main.go
Normal file
@@ -0,0 +1,460 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||
"github.com/gofiber/fiber/v2/middleware/logger"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/gooseek/backend/internal/auth"
|
||||
"github.com/gooseek/backend/pkg/config"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
var (
|
||||
jwtSecret string
|
||||
accessTokenTTL = 15 * time.Minute
|
||||
refreshTokenTTL = 7 * 24 * time.Hour
|
||||
emailRegex = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
|
||||
)
|
||||
|
||||
type JWTClaims struct {
|
||||
UserID string `json:"userId"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
Tier string `json:"tier"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
func main() {
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
log.Fatal("Failed to load config:", err)
|
||||
}
|
||||
|
||||
jwtSecret = cfg.JWTSecret
|
||||
if jwtSecret == "" {
|
||||
jwtSecret = os.Getenv("JWT_SECRET")
|
||||
if jwtSecret == "" {
|
||||
log.Fatal("JWT_SECRET is required")
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.DatabaseURL == "" {
|
||||
log.Fatal("DATABASE_URL is required")
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", cfg.DatabaseURL)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to connect to database:", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
db.SetMaxOpenConns(25)
|
||||
db.SetMaxIdleConns(5)
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
maxRetries := 30
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
if err := db.Ping(); err == nil {
|
||||
break
|
||||
}
|
||||
log.Printf("Waiting for database (attempt %d/%d)...", i+1, maxRetries)
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
|
||||
authRepo := auth.NewRepository(db)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
if err := authRepo.RunMigrations(ctx); err != nil {
|
||||
log.Printf("Migration warning: %v", err)
|
||||
}
|
||||
cancel()
|
||||
log.Println("Auth database ready")
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
for range ticker.C {
|
||||
authRepo.CleanupExpiredTokens(context.Background())
|
||||
}
|
||||
}()
|
||||
|
||||
app := fiber.New(fiber.Config{
|
||||
BodyLimit: 10 * 1024 * 1024,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
})
|
||||
|
||||
app.Use(logger.New())
|
||||
app.Use(cors.New(cors.Config{
|
||||
AllowOrigins: "*",
|
||||
AllowHeaders: "Origin, Content-Type, Accept, Authorization",
|
||||
AllowMethods: "GET, POST, PUT, DELETE, OPTIONS",
|
||||
AllowCredentials: true,
|
||||
}))
|
||||
|
||||
app.Get("/health", func(c *fiber.Ctx) error {
|
||||
return c.JSON(fiber.Map{"status": "ok"})
|
||||
})
|
||||
|
||||
app.Get("/ready", func(c *fiber.Ctx) error {
|
||||
if err := db.Ping(); err != nil {
|
||||
return c.Status(503).JSON(fiber.Map{"status": "database unavailable"})
|
||||
}
|
||||
return c.JSON(fiber.Map{"status": "ready"})
|
||||
})
|
||||
|
||||
api := app.Group("/api/v1/auth")
|
||||
|
||||
api.Post("/register", func(c *fiber.Ctx) error {
|
||||
var req auth.RegisterRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
if req.Email == "" || req.Password == "" || req.Name == "" {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Email, password and name are required"})
|
||||
}
|
||||
|
||||
if !emailRegex.MatchString(req.Email) {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid email format"})
|
||||
}
|
||||
|
||||
if len(req.Password) < 8 {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Password must be at least 8 characters"})
|
||||
}
|
||||
|
||||
user, err := authRepo.CreateUser(c.Context(), req.Email, req.Password, req.Name)
|
||||
if err != nil {
|
||||
if errors.Is(err, auth.ErrEmailExists) {
|
||||
return c.Status(409).JSON(fiber.Map{"error": "Email already registered"})
|
||||
}
|
||||
if errors.Is(err, auth.ErrWeakPassword) {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Password too weak"})
|
||||
}
|
||||
log.Printf("Register error: %v", err)
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Registration failed"})
|
||||
}
|
||||
|
||||
tokens, err := generateTokens(c, authRepo, user)
|
||||
if err != nil {
|
||||
log.Printf("Token generation error: %v", err)
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Failed to generate tokens"})
|
||||
}
|
||||
|
||||
return c.Status(201).JSON(tokens)
|
||||
})
|
||||
|
||||
api.Post("/login", func(c *fiber.Ctx) error {
|
||||
var req auth.LoginRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
if req.Email == "" || req.Password == "" {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Email and password are required"})
|
||||
}
|
||||
|
||||
user, err := authRepo.ValidatePassword(c.Context(), req.Email, req.Password)
|
||||
if err != nil {
|
||||
if errors.Is(err, auth.ErrUserNotFound) || errors.Is(err, auth.ErrInvalidPassword) {
|
||||
return c.Status(401).JSON(fiber.Map{"error": "Invalid email or password"})
|
||||
}
|
||||
log.Printf("Login error: %v", err)
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Login failed"})
|
||||
}
|
||||
|
||||
tokens, err := generateTokens(c, authRepo, user)
|
||||
if err != nil {
|
||||
log.Printf("Token generation error: %v", err)
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Failed to generate tokens"})
|
||||
}
|
||||
|
||||
return c.JSON(tokens)
|
||||
})
|
||||
|
||||
api.Post("/refresh", func(c *fiber.Ctx) error {
|
||||
var req auth.RefreshRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
if req.RefreshToken == "" {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Refresh token is required"})
|
||||
}
|
||||
|
||||
rt, err := authRepo.ValidateRefreshToken(c.Context(), req.RefreshToken)
|
||||
if err != nil {
|
||||
if errors.Is(err, auth.ErrTokenExpired) {
|
||||
return c.Status(401).JSON(fiber.Map{"error": "Refresh token expired"})
|
||||
}
|
||||
if errors.Is(err, auth.ErrTokenInvalid) {
|
||||
return c.Status(401).JSON(fiber.Map{"error": "Invalid refresh token"})
|
||||
}
|
||||
log.Printf("Refresh error: %v", err)
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Token refresh failed"})
|
||||
}
|
||||
|
||||
authRepo.RevokeRefreshToken(c.Context(), req.RefreshToken)
|
||||
|
||||
user, err := authRepo.GetUserByID(c.Context(), rt.UserID)
|
||||
if err != nil {
|
||||
return c.Status(401).JSON(fiber.Map{"error": "User not found"})
|
||||
}
|
||||
|
||||
tokens, err := generateTokens(c, authRepo, user)
|
||||
if err != nil {
|
||||
log.Printf("Token generation error: %v", err)
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Failed to generate tokens"})
|
||||
}
|
||||
|
||||
return c.JSON(tokens)
|
||||
})
|
||||
|
||||
api.Post("/logout", jwtMiddleware, func(c *fiber.Ctx) error {
|
||||
var req auth.RefreshRequest
|
||||
if err := c.BodyParser(&req); err == nil && req.RefreshToken != "" {
|
||||
authRepo.RevokeRefreshToken(c.Context(), req.RefreshToken)
|
||||
}
|
||||
return c.JSON(fiber.Map{"message": "Logged out successfully"})
|
||||
})
|
||||
|
||||
api.Post("/logout-all", jwtMiddleware, func(c *fiber.Ctx) error {
|
||||
userID := c.Locals("userId").(string)
|
||||
authRepo.RevokeAllRefreshTokens(c.Context(), userID)
|
||||
return c.JSON(fiber.Map{"message": "Logged out from all devices"})
|
||||
})
|
||||
|
||||
api.Get("/validate", func(c *fiber.Ctx) error {
|
||||
tokenString := extractToken(c)
|
||||
if tokenString == "" {
|
||||
return c.JSON(auth.ValidateResponse{Valid: false})
|
||||
}
|
||||
|
||||
claims, err := validateJWT(tokenString)
|
||||
if err != nil {
|
||||
return c.JSON(auth.ValidateResponse{Valid: false})
|
||||
}
|
||||
|
||||
user, err := authRepo.GetUserByID(c.Context(), claims.UserID)
|
||||
if err != nil {
|
||||
return c.JSON(auth.ValidateResponse{Valid: false})
|
||||
}
|
||||
|
||||
return c.JSON(auth.ValidateResponse{
|
||||
Valid: true,
|
||||
User: user,
|
||||
})
|
||||
})
|
||||
|
||||
api.Get("/me", jwtMiddleware, func(c *fiber.Ctx) error {
|
||||
userID := c.Locals("userId").(string)
|
||||
|
||||
user, err := authRepo.GetUserByID(c.Context(), userID)
|
||||
if err != nil {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "User not found"})
|
||||
}
|
||||
|
||||
return c.JSON(user)
|
||||
})
|
||||
|
||||
api.Put("/me", jwtMiddleware, func(c *fiber.Ctx) error {
|
||||
userID := c.Locals("userId").(string)
|
||||
|
||||
var req auth.UpdateProfileRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
if err := authRepo.UpdateProfile(c.Context(), userID, req.Name, req.Avatar); err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Failed to update profile"})
|
||||
}
|
||||
|
||||
user, _ := authRepo.GetUserByID(c.Context(), userID)
|
||||
return c.JSON(user)
|
||||
})
|
||||
|
||||
api.Post("/change-password", jwtMiddleware, func(c *fiber.Ctx) error {
|
||||
userID := c.Locals("userId").(string)
|
||||
|
||||
var req auth.ChangePasswordRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
if req.CurrentPassword == "" || req.NewPassword == "" {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Current and new passwords are required"})
|
||||
}
|
||||
|
||||
if len(req.NewPassword) < 8 {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "New password must be at least 8 characters"})
|
||||
}
|
||||
|
||||
user, err := authRepo.GetUserByID(c.Context(), userID)
|
||||
if err != nil {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "User not found"})
|
||||
}
|
||||
|
||||
_, err = authRepo.ValidatePassword(c.Context(), user.Email, req.CurrentPassword)
|
||||
if err != nil {
|
||||
return c.Status(401).JSON(fiber.Map{"error": "Current password is incorrect"})
|
||||
}
|
||||
|
||||
if err := authRepo.UpdatePassword(c.Context(), userID, req.NewPassword); err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Failed to change password"})
|
||||
}
|
||||
|
||||
authRepo.RevokeAllRefreshTokens(c.Context(), userID)
|
||||
|
||||
return c.JSON(fiber.Map{"message": "Password changed successfully"})
|
||||
})
|
||||
|
||||
api.Post("/forgot-password", func(c *fiber.Ctx) error {
|
||||
var req auth.ResetPasswordRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
if req.Email == "" {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Email is required"})
|
||||
}
|
||||
|
||||
user, err := authRepo.GetUserByEmail(c.Context(), req.Email)
|
||||
if err == nil && user != nil {
|
||||
token, err := authRepo.CreatePasswordResetToken(c.Context(), user.ID)
|
||||
if err == nil {
|
||||
log.Printf("Password reset token for %s: %s", req.Email, token.Token)
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{"message": "If the email exists, a reset link has been sent"})
|
||||
})
|
||||
|
||||
api.Post("/reset-password", func(c *fiber.Ctx) error {
|
||||
var req auth.ResetPasswordConfirm
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
if req.Token == "" || req.NewPassword == "" {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Token and new password are required"})
|
||||
}
|
||||
|
||||
if len(req.NewPassword) < 8 {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Password must be at least 8 characters"})
|
||||
}
|
||||
|
||||
prt, err := authRepo.ValidatePasswordResetToken(c.Context(), req.Token)
|
||||
if err != nil {
|
||||
if errors.Is(err, auth.ErrTokenExpired) {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Reset token has expired"})
|
||||
}
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid reset token"})
|
||||
}
|
||||
|
||||
if err := authRepo.UpdatePassword(c.Context(), prt.UserID, req.NewPassword); err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Failed to reset password"})
|
||||
}
|
||||
|
||||
authRepo.MarkPasswordResetTokenUsed(c.Context(), prt.ID)
|
||||
authRepo.RevokeAllRefreshTokens(c.Context(), prt.UserID)
|
||||
|
||||
return c.JSON(fiber.Map{"message": "Password has been reset successfully"})
|
||||
})
|
||||
|
||||
port := config.GetEnvInt("AUTH_SVC_PORT", 3050)
|
||||
log.Printf("auth-svc listening on :%d", port)
|
||||
log.Fatal(app.Listen(fmt.Sprintf(":%d", port)))
|
||||
}
|
||||
|
||||
func generateTokens(c *fiber.Ctx, repo *auth.Repository, user *auth.User) (*auth.TokenResponse, error) {
|
||||
claims := JWTClaims{
|
||||
UserID: user.ID,
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
Tier: user.Tier,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: user.ID,
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokenTTL)),
|
||||
Issuer: "gooseek",
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
accessToken, err := token.SignedString([]byte(jwtSecret))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userAgent := c.Get("User-Agent")
|
||||
ip := c.IP()
|
||||
|
||||
refreshToken, err := repo.CreateRefreshToken(c.Context(), user.ID, userAgent, ip, refreshTokenTTL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &auth.TokenResponse{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken.Token,
|
||||
ExpiresIn: int(accessTokenTTL.Seconds()),
|
||||
TokenType: "Bearer",
|
||||
User: user,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func validateJWT(tokenString string) (*JWTClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(jwtSecret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*JWTClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func extractToken(c *fiber.Ctx) string {
|
||||
auth := c.Get("Authorization")
|
||||
if len(auth) > 7 && auth[:7] == "Bearer " {
|
||||
return auth[7:]
|
||||
}
|
||||
return c.Query("token")
|
||||
}
|
||||
|
||||
func jwtMiddleware(c *fiber.Ctx) error {
|
||||
tokenString := extractToken(c)
|
||||
if tokenString == "" {
|
||||
return c.Status(401).JSON(fiber.Map{"error": "Missing authorization token"})
|
||||
}
|
||||
|
||||
claims, err := validateJWT(tokenString)
|
||||
if err != nil {
|
||||
return c.Status(401).JSON(fiber.Map{"error": "Invalid token"})
|
||||
}
|
||||
|
||||
c.Locals("userId", claims.UserID)
|
||||
c.Locals("userEmail", claims.Email)
|
||||
c.Locals("userRole", claims.Role)
|
||||
c.Locals("userTier", claims.Tier)
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
@@ -1,53 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gooseek/backend/internal/computer/browser"
|
||||
)
|
||||
|
||||
func main() {
|
||||
port := 3050
|
||||
if p := os.Getenv("PORT"); p != "" {
|
||||
if parsed, err := strconv.Atoi(p); err == nil {
|
||||
port = parsed
|
||||
}
|
||||
}
|
||||
if p := os.Getenv("BROWSER_SVC_PORT"); p != "" {
|
||||
if parsed, err := strconv.Atoi(p); err == nil {
|
||||
port = parsed
|
||||
}
|
||||
}
|
||||
|
||||
cfg := browser.ServerConfig{
|
||||
Port: port,
|
||||
MaxSessions: 20,
|
||||
SessionTimeout: 30 * time.Minute,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
}
|
||||
|
||||
server := browser.NewBrowserServer(cfg)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
go func() {
|
||||
<-sigCh
|
||||
log.Println("[browser-svc] Shutting down...")
|
||||
cancel()
|
||||
}()
|
||||
|
||||
log.Printf("[browser-svc] Starting browser service on port %d", port)
|
||||
if err := server.Start(ctx); err != nil {
|
||||
log.Fatalf("[browser-svc] Server error: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||
"github.com/gofiber/fiber/v2/middleware/logger"
|
||||
"github.com/gooseek/backend/pkg/config"
|
||||
"github.com/gooseek/backend/pkg/middleware"
|
||||
)
|
||||
|
||||
type ChatRequest struct {
|
||||
@@ -93,7 +94,13 @@ func main() {
|
||||
})
|
||||
})
|
||||
|
||||
app.Post("/api/v1/chat", func(c *fiber.Ctx) error {
|
||||
chat := app.Group("/api/v1/chat", middleware.JWT(middleware.JWTConfig{
|
||||
Secret: cfg.JWTSecret,
|
||||
AuthSvcURL: cfg.AuthSvcURL,
|
||||
AllowGuest: false,
|
||||
}))
|
||||
|
||||
chat.Post("/", func(c *fiber.Ctx) error {
|
||||
var req ChatRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
|
||||
|
||||
@@ -153,8 +153,11 @@ func main() {
|
||||
return c.Status(403).JSON(fiber.Map{"error": "Access denied"})
|
||||
}
|
||||
|
||||
items, err := collectionRepo.GetItems(c.Context(), collectionID)
|
||||
items, err := collectionRepo.GetItems(c.Context(), collectionID, userID)
|
||||
if err != nil {
|
||||
if err == db.ErrForbidden {
|
||||
return c.Status(403).JSON(fiber.Map{"error": "Access denied"})
|
||||
}
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Failed to get items"})
|
||||
}
|
||||
collection.Items = items
|
||||
@@ -195,7 +198,10 @@ func main() {
|
||||
collection.IsPublic = req.IsPublic
|
||||
collection.ContextEnabled = req.ContextEnabled
|
||||
|
||||
if err := collectionRepo.Update(c.Context(), collection); err != nil {
|
||||
if err := collectionRepo.Update(c.Context(), collection, userID); err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "Collection not found"})
|
||||
}
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Failed to update collection"})
|
||||
}
|
||||
|
||||
@@ -210,16 +216,10 @@ func main() {
|
||||
collectionID := c.Params("id")
|
||||
userID := middleware.GetUserID(c)
|
||||
|
||||
collection, err := collectionRepo.GetByID(c.Context(), collectionID)
|
||||
if err != nil || collection == nil {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "Collection not found"})
|
||||
}
|
||||
|
||||
if collection.UserID != userID {
|
||||
return c.Status(403).JSON(fiber.Map{"error": "Access denied"})
|
||||
}
|
||||
|
||||
if err := collectionRepo.Delete(c.Context(), collectionID); err != nil {
|
||||
if err := collectionRepo.Delete(c.Context(), collectionID, userID); err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "Collection not found"})
|
||||
}
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Failed to delete collection"})
|
||||
}
|
||||
|
||||
@@ -293,7 +293,10 @@ func main() {
|
||||
return c.Status(403).JSON(fiber.Map{"error": "Access denied"})
|
||||
}
|
||||
|
||||
if err := collectionRepo.RemoveItem(c.Context(), itemID); err != nil {
|
||||
if err := collectionRepo.RemoveItem(c.Context(), itemID, userID); err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "Item not found"})
|
||||
}
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Failed to remove item"})
|
||||
}
|
||||
|
||||
@@ -321,7 +324,7 @@ func main() {
|
||||
return c.JSON(fiber.Map{"context": "", "enabled": false})
|
||||
}
|
||||
|
||||
context, err := collectionRepo.GetCollectionContext(c.Context(), collectionID)
|
||||
context, err := collectionRepo.GetCollectionContext(c.Context(), collectionID, userID)
|
||||
if err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Failed to get context"})
|
||||
}
|
||||
|
||||
@@ -1,552 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gooseek/backend/internal/computer"
|
||||
"github.com/gooseek/backend/internal/computer/connectors"
|
||||
"github.com/gooseek/backend/internal/db"
|
||||
"github.com/gooseek/backend/internal/llm"
|
||||
"github.com/gooseek/backend/pkg/config"
|
||||
"github.com/gooseek/backend/pkg/middleware"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||
"github.com/gofiber/fiber/v2/middleware/recover"
|
||||
)
|
||||
|
||||
func main() {
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
var database *db.PostgresDB
|
||||
maxRetries := 30
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
database, err = db.NewPostgresDB(cfg.DatabaseURL)
|
||||
if err == nil {
|
||||
log.Println("PostgreSQL connected successfully")
|
||||
break
|
||||
}
|
||||
log.Printf("Waiting for database (attempt %d/%d): %v", i+1, maxRetries, err)
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to connect to database after %d attempts: %v", maxRetries, err)
|
||||
}
|
||||
|
||||
taskRepo := db.NewComputerTaskRepo(database.DB())
|
||||
memoryRepo := db.NewComputerMemoryRepo(database.DB())
|
||||
artifactRepo := db.NewComputerArtifactRepo(database.DB())
|
||||
|
||||
if err := taskRepo.Migrate(); err != nil {
|
||||
log.Printf("Task repo migration warning: %v", err)
|
||||
}
|
||||
if err := memoryRepo.Migrate(); err != nil {
|
||||
log.Printf("Memory repo migration warning: %v", err)
|
||||
}
|
||||
if err := artifactRepo.Migrate(); err != nil {
|
||||
log.Printf("Artifact repo migration warning: %v", err)
|
||||
}
|
||||
|
||||
registry := llm.NewModelRegistry()
|
||||
setupModels(registry, cfg)
|
||||
|
||||
connectorHub := connectors.NewConnectorHub()
|
||||
setupConnectors(connectorHub, cfg)
|
||||
|
||||
comp := computer.NewComputer(computer.ComputerConfig{
|
||||
MaxParallelTasks: 10,
|
||||
MaxSubTasks: 20,
|
||||
TaskTimeout: 30 * time.Minute,
|
||||
SubTaskTimeout: 5 * time.Minute,
|
||||
TotalBudget: 1.0,
|
||||
EnableSandbox: true,
|
||||
EnableScheduling: true,
|
||||
SandboxImage: getEnv("SANDBOX_IMAGE", "gooseek/sandbox:latest"),
|
||||
}, computer.Dependencies{
|
||||
Registry: registry,
|
||||
TaskRepo: taskRepo,
|
||||
MemoryRepo: memoryRepo,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
comp.StartScheduler(ctx)
|
||||
|
||||
app := fiber.New(fiber.Config{
|
||||
ErrorHandler: func(c *fiber.Ctx, err error) error {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
"error": err.Error(),
|
||||
})
|
||||
},
|
||||
})
|
||||
|
||||
app.Use(recover.New())
|
||||
app.Use(cors.New(cors.Config{
|
||||
AllowOrigins: "*",
|
||||
AllowHeaders: "Origin, Content-Type, Accept, Authorization",
|
||||
AllowMethods: "GET, POST, PUT, DELETE, OPTIONS",
|
||||
}))
|
||||
app.Use(middleware.Logging(middleware.LoggingConfig{}))
|
||||
|
||||
app.Get("/health", func(c *fiber.Ctx) error {
|
||||
return c.JSON(fiber.Map{
|
||||
"status": "ok",
|
||||
"service": "computer-svc",
|
||||
"models": registry.Count(),
|
||||
})
|
||||
})
|
||||
|
||||
api := app.Group("/api/v1/computer")
|
||||
|
||||
api.Post("/execute", func(c *fiber.Ctx) error {
|
||||
var req struct {
|
||||
Query string `json:"query"`
|
||||
UserID string `json:"userId"`
|
||||
Options computer.ExecuteOptions `json:"options"`
|
||||
}
|
||||
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "invalid request body"})
|
||||
}
|
||||
|
||||
if req.Query == "" {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "query is required"})
|
||||
}
|
||||
|
||||
if req.UserID == "" || req.UserID == "anonymous" {
|
||||
req.UserID = "00000000-0000-0000-0000-000000000000"
|
||||
}
|
||||
|
||||
task, err := comp.Execute(c.Context(), req.UserID, req.Query, req.Options)
|
||||
if err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
return c.JSON(task)
|
||||
})
|
||||
|
||||
api.Get("/tasks", func(c *fiber.Ctx) error {
|
||||
userID := c.Query("userId", "")
|
||||
limit := c.QueryInt("limit", 20)
|
||||
offset := c.QueryInt("offset", 0)
|
||||
|
||||
if userID == "" || userID == "anonymous" {
|
||||
userID = "00000000-0000-0000-0000-000000000000"
|
||||
}
|
||||
|
||||
tasks, err := comp.GetUserTasks(c.Context(), userID, limit, offset)
|
||||
if err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
"tasks": tasks,
|
||||
"count": len(tasks),
|
||||
})
|
||||
})
|
||||
|
||||
api.Get("/tasks/:id", func(c *fiber.Ctx) error {
|
||||
taskID := c.Params("id")
|
||||
|
||||
task, err := comp.GetStatus(c.Context(), taskID)
|
||||
if err != nil {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "task not found"})
|
||||
}
|
||||
|
||||
return c.JSON(task)
|
||||
})
|
||||
|
||||
api.Get("/tasks/:id/stream", func(c *fiber.Ctx) error {
|
||||
taskID := c.Params("id")
|
||||
|
||||
c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache")
|
||||
c.Set("Connection", "keep-alive")
|
||||
c.Set("Transfer-Encoding", "chunked")
|
||||
|
||||
eventCh, err := comp.Stream(c.Context(), taskID)
|
||||
if err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
c.Context().SetBodyStreamWriter(func(w *bufio.Writer) {
|
||||
for event := range eventCh {
|
||||
data, _ := json.Marshal(event)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
w.Flush()
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
api.Post("/tasks/:id/resume", func(c *fiber.Ctx) error {
|
||||
taskID := c.Params("id")
|
||||
|
||||
var req struct {
|
||||
UserInput string `json:"userInput"`
|
||||
}
|
||||
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "invalid request body"})
|
||||
}
|
||||
|
||||
if err := comp.Resume(c.Context(), taskID, req.UserInput); err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{"status": "resumed"})
|
||||
})
|
||||
|
||||
api.Delete("/tasks/:id", func(c *fiber.Ctx) error {
|
||||
taskID := c.Params("id")
|
||||
|
||||
if err := comp.Cancel(c.Context(), taskID); err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{"status": "cancelled"})
|
||||
})
|
||||
|
||||
api.Get("/tasks/:id/artifacts", func(c *fiber.Ctx) error {
|
||||
taskID := c.Params("id")
|
||||
|
||||
artifacts, err := artifactRepo.GetByTaskID(c.Context(), taskID)
|
||||
if err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
"artifacts": artifacts,
|
||||
"count": len(artifacts),
|
||||
})
|
||||
})
|
||||
|
||||
api.Get("/artifacts/:id", func(c *fiber.Ctx) error {
|
||||
artifactID := c.Params("id")
|
||||
|
||||
artifact, err := artifactRepo.GetByID(c.Context(), artifactID)
|
||||
if err != nil {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "artifact not found"})
|
||||
}
|
||||
|
||||
return c.JSON(artifact)
|
||||
})
|
||||
|
||||
api.Get("/artifacts/:id/download", func(c *fiber.Ctx) error {
|
||||
artifactID := c.Params("id")
|
||||
|
||||
artifact, err := artifactRepo.GetByID(c.Context(), artifactID)
|
||||
if err != nil {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "artifact not found"})
|
||||
}
|
||||
|
||||
if artifact.MimeType != "" {
|
||||
c.Set("Content-Type", artifact.MimeType)
|
||||
} else {
|
||||
c.Set("Content-Type", "application/octet-stream")
|
||||
}
|
||||
c.Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", artifact.Name))
|
||||
|
||||
return c.Send(artifact.Content)
|
||||
})
|
||||
|
||||
api.Get("/models", func(c *fiber.Ctx) error {
|
||||
models := registry.GetAll()
|
||||
return c.JSON(fiber.Map{
|
||||
"models": models,
|
||||
"count": len(models),
|
||||
})
|
||||
})
|
||||
|
||||
api.Get("/connectors", func(c *fiber.Ctx) error {
|
||||
info := connectorHub.GetInfo()
|
||||
return c.JSON(fiber.Map{
|
||||
"connectors": info,
|
||||
"count": len(info),
|
||||
})
|
||||
})
|
||||
|
||||
api.Post("/connectors/:id/execute", func(c *fiber.Ctx) error {
|
||||
connectorID := c.Params("id")
|
||||
|
||||
var req struct {
|
||||
Action string `json:"action"`
|
||||
Params map[string]interface{} `json:"params"`
|
||||
}
|
||||
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "invalid request body"})
|
||||
}
|
||||
|
||||
result, err := connectorHub.Execute(c.Context(), connectorID, req.Action, req.Params)
|
||||
if err != nil {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
return c.JSON(result)
|
||||
})
|
||||
|
||||
port := getEnv("COMPUTER_SVC_PORT", "3030")
|
||||
addr := ":" + port
|
||||
|
||||
go func() {
|
||||
log.Printf("Computer service starting on %s", addr)
|
||||
if err := app.Listen(addr); err != nil {
|
||||
log.Fatalf("Failed to start server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
|
||||
log.Println("Shutting down...")
|
||||
comp.StopScheduler()
|
||||
app.Shutdown()
|
||||
}
|
||||
|
||||
func setupModels(registry *llm.ModelRegistry, cfg *config.Config) {
|
||||
// Timeweb Cloud AI (приоритетный провайдер для России)
|
||||
if cfg.TimewebAgentAccessID != "" && cfg.TimewebAPIKey != "" {
|
||||
timewebClient, err := llm.NewTimewebClient(llm.TimewebConfig{
|
||||
ProviderID: "timeweb",
|
||||
ModelKey: "gpt-4o",
|
||||
BaseURL: cfg.TimewebAPIBaseURL,
|
||||
AgentAccessID: cfg.TimewebAgentAccessID,
|
||||
APIKey: cfg.TimewebAPIKey,
|
||||
ProxySource: cfg.TimewebProxySource,
|
||||
})
|
||||
if err == nil {
|
||||
registry.Register(llm.ModelSpec{
|
||||
ID: "timeweb-gpt-4o",
|
||||
Provider: "timeweb",
|
||||
Model: "gpt-4o",
|
||||
Capabilities: []llm.ModelCapability{llm.CapSearch, llm.CapFast, llm.CapVision, llm.CapCoding, llm.CapCreative, llm.CapReasoning},
|
||||
CostPer1K: 0.005,
|
||||
MaxContext: 128000,
|
||||
MaxTokens: 16384,
|
||||
Priority: 0,
|
||||
Description: "GPT-4o via Timeweb Cloud AI",
|
||||
}, timewebClient)
|
||||
log.Println("Timeweb GPT-4o registered")
|
||||
} else {
|
||||
log.Printf("Failed to create Timeweb client: %v", err)
|
||||
}
|
||||
|
||||
timewebMiniClient, err := llm.NewTimewebClient(llm.TimewebConfig{
|
||||
ProviderID: "timeweb",
|
||||
ModelKey: "gpt-4o-mini",
|
||||
BaseURL: cfg.TimewebAPIBaseURL,
|
||||
AgentAccessID: cfg.TimewebAgentAccessID,
|
||||
APIKey: cfg.TimewebAPIKey,
|
||||
ProxySource: cfg.TimewebProxySource,
|
||||
})
|
||||
if err == nil {
|
||||
registry.Register(llm.ModelSpec{
|
||||
ID: "timeweb-gpt-4o-mini",
|
||||
Provider: "timeweb",
|
||||
Model: "gpt-4o-mini",
|
||||
Capabilities: []llm.ModelCapability{llm.CapFast, llm.CapCoding},
|
||||
CostPer1K: 0.00015,
|
||||
MaxContext: 128000,
|
||||
MaxTokens: 16384,
|
||||
Priority: 0,
|
||||
Description: "GPT-4o-mini via Timeweb Cloud AI",
|
||||
}, timewebMiniClient)
|
||||
log.Println("Timeweb GPT-4o-mini registered")
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI прямой (fallback если Timeweb недоступен)
|
||||
if cfg.OpenAIAPIKey != "" {
|
||||
openaiClient, err := llm.NewOpenAIClient(llm.ProviderConfig{
|
||||
ProviderID: "openai",
|
||||
ModelKey: "gpt-4o",
|
||||
APIKey: cfg.OpenAIAPIKey,
|
||||
})
|
||||
if err == nil {
|
||||
registry.Register(llm.ModelSpec{
|
||||
ID: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o",
|
||||
Capabilities: []llm.ModelCapability{llm.CapSearch, llm.CapFast, llm.CapVision, llm.CapCoding, llm.CapCreative},
|
||||
CostPer1K: 0.005,
|
||||
MaxContext: 128000,
|
||||
MaxTokens: 16384,
|
||||
Priority: 10,
|
||||
}, openaiClient)
|
||||
}
|
||||
|
||||
miniClient, err := llm.NewOpenAIClient(llm.ProviderConfig{
|
||||
ProviderID: "openai",
|
||||
ModelKey: "gpt-4o-mini",
|
||||
APIKey: cfg.OpenAIAPIKey,
|
||||
})
|
||||
if err == nil {
|
||||
registry.Register(llm.ModelSpec{
|
||||
ID: "gpt-4o-mini",
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o-mini",
|
||||
Capabilities: []llm.ModelCapability{llm.CapFast, llm.CapCoding},
|
||||
CostPer1K: 0.00015,
|
||||
MaxContext: 128000,
|
||||
MaxTokens: 16384,
|
||||
Priority: 10,
|
||||
}, miniClient)
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.AnthropicAPIKey != "" {
|
||||
opusClient, err := llm.NewAnthropicClient(llm.ProviderConfig{
|
||||
ProviderID: "anthropic",
|
||||
ModelKey: "claude-3-opus-20240229",
|
||||
APIKey: cfg.AnthropicAPIKey,
|
||||
})
|
||||
if err == nil {
|
||||
registry.Register(llm.ModelSpec{
|
||||
ID: "claude-3-opus",
|
||||
Provider: "anthropic",
|
||||
Model: "claude-3-opus-20240229",
|
||||
Capabilities: []llm.ModelCapability{llm.CapReasoning, llm.CapCoding, llm.CapCreative, llm.CapLongContext},
|
||||
CostPer1K: 0.015,
|
||||
MaxContext: 200000,
|
||||
MaxTokens: 4096,
|
||||
Priority: 1,
|
||||
}, opusClient)
|
||||
}
|
||||
|
||||
sonnetClient, err := llm.NewAnthropicClient(llm.ProviderConfig{
|
||||
ProviderID: "anthropic",
|
||||
ModelKey: "claude-3-5-sonnet-20241022",
|
||||
APIKey: cfg.AnthropicAPIKey,
|
||||
})
|
||||
if err == nil {
|
||||
registry.Register(llm.ModelSpec{
|
||||
ID: "claude-3-sonnet",
|
||||
Provider: "anthropic",
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Capabilities: []llm.ModelCapability{llm.CapCoding, llm.CapCreative, llm.CapFast},
|
||||
CostPer1K: 0.003,
|
||||
MaxContext: 200000,
|
||||
MaxTokens: 8192,
|
||||
Priority: 1,
|
||||
}, sonnetClient)
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.GeminiAPIKey != "" {
|
||||
geminiClient, err := llm.NewGeminiClient(llm.ProviderConfig{
|
||||
ProviderID: "gemini",
|
||||
ModelKey: "gemini-1.5-pro",
|
||||
APIKey: cfg.GeminiAPIKey,
|
||||
})
|
||||
if err == nil {
|
||||
registry.Register(llm.ModelSpec{
|
||||
ID: "gemini-1.5-pro",
|
||||
Provider: "gemini",
|
||||
Model: "gemini-1.5-pro",
|
||||
Capabilities: []llm.ModelCapability{llm.CapLongContext, llm.CapSearch, llm.CapVision, llm.CapMath},
|
||||
CostPer1K: 0.00125,
|
||||
MaxContext: 2000000,
|
||||
MaxTokens: 8192,
|
||||
Priority: 1,
|
||||
}, geminiClient)
|
||||
}
|
||||
|
||||
flashClient, err := llm.NewGeminiClient(llm.ProviderConfig{
|
||||
ProviderID: "gemini",
|
||||
ModelKey: "gemini-1.5-flash",
|
||||
APIKey: cfg.GeminiAPIKey,
|
||||
})
|
||||
if err == nil {
|
||||
registry.Register(llm.ModelSpec{
|
||||
ID: "gemini-1.5-flash",
|
||||
Provider: "gemini",
|
||||
Model: "gemini-1.5-flash",
|
||||
Capabilities: []llm.ModelCapability{llm.CapFast, llm.CapVision},
|
||||
CostPer1K: 0.000075,
|
||||
MaxContext: 1000000,
|
||||
MaxTokens: 8192,
|
||||
Priority: 2,
|
||||
}, flashClient)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("Registered %d models", registry.Count())
|
||||
}
|
||||
|
||||
func setupConnectors(hub *connectors.ConnectorHub, cfg *config.Config) {
|
||||
if smtpHost := getEnv("SMTP_HOST", ""); smtpHost != "" {
|
||||
emailConn := connectors.NewEmailConnector(connectors.EmailConfig{
|
||||
SMTPHost: smtpHost,
|
||||
SMTPPort: getEnvInt("SMTP_PORT", 587),
|
||||
Username: getEnv("SMTP_USERNAME", ""),
|
||||
Password: getEnv("SMTP_PASSWORD", ""),
|
||||
FromAddress: getEnv("SMTP_FROM", ""),
|
||||
FromName: getEnv("SMTP_FROM_NAME", "GooSeek Computer"),
|
||||
UseTLS: true,
|
||||
AllowHTML: true,
|
||||
})
|
||||
hub.Register(emailConn)
|
||||
log.Println("Email connector registered")
|
||||
}
|
||||
|
||||
if botToken := getEnv("TELEGRAM_BOT_TOKEN", ""); botToken != "" {
|
||||
tgConn := connectors.NewTelegramConnector(connectors.TelegramConfig{
|
||||
BotToken: botToken,
|
||||
})
|
||||
hub.Register(tgConn)
|
||||
log.Println("Telegram connector registered")
|
||||
}
|
||||
|
||||
webhookConn := connectors.NewWebhookConnector(connectors.WebhookConfig{
|
||||
Timeout: 30 * time.Second,
|
||||
MaxRetries: 3,
|
||||
})
|
||||
hub.Register(webhookConn)
|
||||
log.Println("Webhook connector registered")
|
||||
|
||||
if s3Endpoint := getEnv("S3_ENDPOINT", ""); s3Endpoint != "" {
|
||||
storageConn, err := connectors.NewStorageConnector(connectors.StorageConfig{
|
||||
Endpoint: s3Endpoint,
|
||||
AccessKeyID: getEnv("S3_ACCESS_KEY", ""),
|
||||
SecretAccessKey: getEnv("S3_SECRET_KEY", ""),
|
||||
BucketName: getEnv("S3_BUCKET", "gooseek-artifacts"),
|
||||
UseSSL: getEnv("S3_USE_SSL", "true") == "true",
|
||||
Region: getEnv("S3_REGION", "us-east-1"),
|
||||
PublicURL: getEnv("S3_PUBLIC_URL", ""),
|
||||
})
|
||||
if err == nil {
|
||||
hub.Register(storageConn)
|
||||
log.Println("Storage connector registered")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getEnv(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvInt(key string, defaultValue int) int {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
var i int
|
||||
fmt.Sscanf(value, "%d", &i)
|
||||
return i
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/gooseek/backend/internal/search"
|
||||
"github.com/gooseek/backend/pkg/cache"
|
||||
"github.com/gooseek/backend/pkg/config"
|
||||
"github.com/gooseek/backend/pkg/middleware"
|
||||
)
|
||||
|
||||
type DigestCitation struct {
|
||||
@@ -237,7 +238,9 @@ func main() {
|
||||
)
|
||||
})
|
||||
|
||||
app.Get("/api/v1/discover/digest", func(c *fiber.Ctx) error {
|
||||
discover := app.Group("/api/v1/discover")
|
||||
|
||||
discover.Get("/digest", func(c *fiber.Ctx) error {
|
||||
url := c.Query("url")
|
||||
if url != "" {
|
||||
digest := store.GetDigestByURL(url)
|
||||
@@ -263,7 +266,13 @@ func main() {
|
||||
return c.JSON(digest)
|
||||
})
|
||||
|
||||
app.Post("/api/v1/discover/digest", func(c *fiber.Ctx) error {
|
||||
discoverAuth := app.Group("/api/v1/discover", middleware.JWT(middleware.JWTConfig{
|
||||
Secret: cfg.JWTSecret,
|
||||
AuthSvcURL: cfg.AuthSvcURL,
|
||||
AllowGuest: false,
|
||||
}))
|
||||
|
||||
discoverAuth.Post("/digest", func(c *fiber.Ctx) error {
|
||||
var d Digest
|
||||
if err := c.BodyParser(&d); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
|
||||
@@ -277,7 +286,7 @@ func main() {
|
||||
return c.Status(204).Send(nil)
|
||||
})
|
||||
|
||||
app.Delete("/api/v1/discover/digest", func(c *fiber.Ctx) error {
|
||||
discoverAuth.Delete("/digest", func(c *fiber.Ctx) error {
|
||||
topic := c.Query("topic")
|
||||
region := c.Query("region")
|
||||
|
||||
@@ -289,7 +298,7 @@ func main() {
|
||||
return c.JSON(fiber.Map{"deleted": deleted})
|
||||
})
|
||||
|
||||
app.Get("/api/v1/discover/article-summary", func(c *fiber.Ctx) error {
|
||||
discover.Get("/article-summary", func(c *fiber.Ctx) error {
|
||||
url := c.Query("url")
|
||||
if url == "" {
|
||||
return c.Status(400).JSON(fiber.Map{"message": "url required"})
|
||||
@@ -320,7 +329,7 @@ func main() {
|
||||
return c.JSON(fiber.Map{"events": summary.Events})
|
||||
})
|
||||
|
||||
app.Post("/api/v1/discover/article-summary", func(c *fiber.Ctx) error {
|
||||
discoverAuth.Post("/article-summary", func(c *fiber.Ctx) error {
|
||||
var body struct {
|
||||
URL string `json:"url"`
|
||||
Events []string `json:"events"`
|
||||
@@ -354,7 +363,7 @@ func main() {
|
||||
return c.Status(204).Send(nil)
|
||||
})
|
||||
|
||||
app.Delete("/api/v1/discover/article-summary", func(c *fiber.Ctx) error {
|
||||
discoverAuth.Delete("/article-summary", func(c *fiber.Ctx) error {
|
||||
url := c.Query("url")
|
||||
if url == "" {
|
||||
return c.Status(400).JSON(fiber.Map{"message": "url required"})
|
||||
@@ -365,7 +374,7 @@ func main() {
|
||||
return c.Status(204).Send(nil)
|
||||
})
|
||||
|
||||
app.Get("/api/v1/discover/search", func(c *fiber.Ctx) error {
|
||||
discover.Get("/search", func(c *fiber.Ctx) error {
|
||||
q := c.Query("q")
|
||||
if q == "" {
|
||||
return c.Status(400).JSON(fiber.Map{"message": "Query q is required"})
|
||||
@@ -386,14 +395,38 @@ func main() {
|
||||
return c.JSON(fiber.Map{"results": result.Results})
|
||||
})
|
||||
|
||||
app.Get("/api/v1/discover", func(c *fiber.Ctx) error {
|
||||
discover.Get("/", func(c *fiber.Ctx) error {
|
||||
topic := c.Query("topic", "tech")
|
||||
region := c.Query("region", "world")
|
||||
page := c.QueryInt("page", 1)
|
||||
limit := c.QueryInt("limit", 10)
|
||||
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if limit < 1 || limit > 30 {
|
||||
limit = 10
|
||||
}
|
||||
|
||||
digests := store.GetDigests(topic, region)
|
||||
if len(digests) > 0 {
|
||||
blogs := make([]fiber.Map, len(digests))
|
||||
for i, d := range digests {
|
||||
start := (page - 1) * limit
|
||||
end := start + limit
|
||||
if start >= len(digests) {
|
||||
return c.JSON(fiber.Map{
|
||||
"blogs": []fiber.Map{},
|
||||
"hasMore": false,
|
||||
"page": page,
|
||||
"total": len(digests),
|
||||
})
|
||||
}
|
||||
if end > len(digests) {
|
||||
end = len(digests)
|
||||
}
|
||||
|
||||
pagedDigests := digests[start:end]
|
||||
blogs := make([]fiber.Map, len(pagedDigests))
|
||||
for i, d := range pagedDigests {
|
||||
content := d.ShortDescription
|
||||
if content == "" && len(d.SummaryRu) > 200 {
|
||||
content = d.SummaryRu[:200] + "…"
|
||||
@@ -410,7 +443,12 @@ func main() {
|
||||
"digestId": fmt.Sprintf("%s:%s:%s", d.Topic, d.Region, d.ClusterTitle),
|
||||
}
|
||||
}
|
||||
return c.JSON(fiber.Map{"blogs": blogs})
|
||||
return c.JSON(fiber.Map{
|
||||
"blogs": blogs,
|
||||
"hasMore": end < len(digests),
|
||||
"page": page,
|
||||
"total": len(digests),
|
||||
})
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cfg.SearchTimeout*2)
|
||||
@@ -419,15 +457,15 @@ func main() {
|
||||
queries := getQueriesForTopic(topic, region)
|
||||
results, err := searchClient.Search(ctx, queries[0], &search.SearchOptions{
|
||||
Categories: []string{"news"},
|
||||
PageNo: 1,
|
||||
PageNo: page,
|
||||
})
|
||||
if err != nil {
|
||||
return c.Status(503).JSON(fiber.Map{"message": "Search failed"})
|
||||
}
|
||||
|
||||
blogs := make([]fiber.Map, 0, 7)
|
||||
blogs := make([]fiber.Map, 0, limit)
|
||||
for i, r := range results.Results {
|
||||
if i >= 7 {
|
||||
if i >= limit {
|
||||
break
|
||||
}
|
||||
thumbnail := r.Thumbnail
|
||||
@@ -454,7 +492,12 @@ func main() {
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{"blogs": blogs})
|
||||
hasMore := len(results.Results) > limit
|
||||
return c.JSON(fiber.Map{
|
||||
"blogs": blogs,
|
||||
"hasMore": hasMore,
|
||||
"page": page,
|
||||
})
|
||||
})
|
||||
|
||||
port := getEnvInt("DISCOVER_SVC_PORT", 3002)
|
||||
@@ -466,19 +509,54 @@ func getQueriesForTopic(topic, region string) []string {
|
||||
queries := map[string]map[string][]string{
|
||||
"tech": {
|
||||
"world": {"technology news AI innovation"},
|
||||
"russia": {"технологии новости IT инновации"},
|
||||
"russia": {"технологии новости IT инновации искусственный интеллект"},
|
||||
"eu": {"technology news Europe AI"},
|
||||
},
|
||||
"finance": {
|
||||
"world": {"finance news economy markets"},
|
||||
"russia": {"финансы новости экономика рынки"},
|
||||
"eu": {"finance news Europe economy"},
|
||||
"world": {"finance news economy markets stocks"},
|
||||
"russia": {"финансы новости экономика рынки акции"},
|
||||
"eu": {"finance news Europe economy markets"},
|
||||
},
|
||||
"sports": {
|
||||
"world": {"sports news football Olympics"},
|
||||
"russia": {"спорт новости футбол хоккей"},
|
||||
"world": {"sports news football basketball Olympics"},
|
||||
"russia": {"спорт новости футбол хоккей КХЛ РПЛ"},
|
||||
"eu": {"sports news football Champions League"},
|
||||
},
|
||||
"politics": {
|
||||
"world": {"politics news government elections policy"},
|
||||
"russia": {"политика новости Россия правительство законы"},
|
||||
"eu": {"politics news Europe EU parliament"},
|
||||
},
|
||||
"science": {
|
||||
"world": {"science news research discovery space"},
|
||||
"russia": {"наука новости исследования открытия космос"},
|
||||
"eu": {"science news Europe research discovery"},
|
||||
},
|
||||
"health": {
|
||||
"world": {"health news medicine medical research"},
|
||||
"russia": {"здоровье новости медицина лечение"},
|
||||
"eu": {"health news Europe medicine healthcare"},
|
||||
},
|
||||
"entertainment": {
|
||||
"world": {"entertainment news movies music celebrities"},
|
||||
"russia": {"развлечения новости кино музыка шоу-бизнес"},
|
||||
"eu": {"entertainment news Europe movies music"},
|
||||
},
|
||||
"world": {
|
||||
"world": {"world news international global events"},
|
||||
"russia": {"мировые новости международные события"},
|
||||
"eu": {"world news Europe international"},
|
||||
},
|
||||
"business": {
|
||||
"world": {"business news companies startups industry"},
|
||||
"russia": {"бизнес новости компании стартапы предпринимательство"},
|
||||
"eu": {"business news Europe companies industry"},
|
||||
},
|
||||
"culture": {
|
||||
"world": {"culture news art exhibitions theatre"},
|
||||
"russia": {"культура новости искусство выставки театр"},
|
||||
"eu": {"culture news Europe art exhibitions"},
|
||||
},
|
||||
}
|
||||
|
||||
if topicQueries, ok := queries[topic]; ok {
|
||||
|
||||
@@ -161,7 +161,7 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
fileRepo.UpdateExtractedText(ctx, uploadedFile.ID, result.ExtractedText)
|
||||
fileRepo.UpdateExtractedText(ctx, uploadedFile.ID, result.ExtractedText, uploadedFile.UserID)
|
||||
}()
|
||||
|
||||
return c.Status(201).JSON(fiber.Map{
|
||||
@@ -260,7 +260,7 @@ func main() {
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Analysis failed: " + err.Error()})
|
||||
}
|
||||
|
||||
fileRepo.UpdateExtractedText(c.Context(), fileID, result.ExtractedText)
|
||||
fileRepo.UpdateExtractedText(c.Context(), fileID, result.ExtractedText, userID)
|
||||
|
||||
return c.JSON(result)
|
||||
})
|
||||
@@ -284,7 +284,10 @@ func main() {
|
||||
|
||||
fileAnalyzer.DeleteFile(file.StoragePath)
|
||||
|
||||
if err := fileRepo.Delete(c.Context(), fileID); err != nil {
|
||||
if err := fileRepo.Delete(c.Context(), fileID, userID); err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "File not found"})
|
||||
}
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Failed to delete file"})
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
@@ -12,8 +13,11 @@ import (
|
||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||
"github.com/gofiber/fiber/v2/middleware/logger"
|
||||
"github.com/gooseek/backend/internal/llm"
|
||||
"github.com/gooseek/backend/internal/usage"
|
||||
"github.com/gooseek/backend/pkg/config"
|
||||
"github.com/gooseek/backend/pkg/middleware"
|
||||
"github.com/gooseek/backend/pkg/ndjson"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
type GenerateRequest struct {
|
||||
@@ -36,6 +40,26 @@ func main() {
|
||||
log.Fatal("Failed to load config:", err)
|
||||
}
|
||||
|
||||
var usageRepo *usage.Repository
|
||||
if cfg.DatabaseURL != "" {
|
||||
db, err := sql.Open("postgres", cfg.DatabaseURL)
|
||||
if err != nil {
|
||||
log.Printf("Usage tracking unavailable: %v", err)
|
||||
} else {
|
||||
db.SetMaxOpenConns(5)
|
||||
db.SetMaxIdleConns(2)
|
||||
defer db.Close()
|
||||
|
||||
usageRepo = usage.NewRepository(db)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
if err := usageRepo.RunMigrations(ctx); err != nil {
|
||||
log.Printf("Usage migrations warning: %v", err)
|
||||
}
|
||||
cancel()
|
||||
log.Println("Usage tracking enabled")
|
||||
}
|
||||
}
|
||||
|
||||
app := fiber.New(fiber.Config{
|
||||
StreamRequestBody: true,
|
||||
BodyLimit: 10 * 1024 * 1024,
|
||||
@@ -90,7 +114,20 @@ func main() {
|
||||
})
|
||||
})
|
||||
|
||||
app.Post("/api/v1/generate", func(c *fiber.Ctx) error {
|
||||
llmAPI := app.Group("/api/v1", middleware.JWT(middleware.JWTConfig{
|
||||
Secret: cfg.JWTSecret,
|
||||
AuthSvcURL: cfg.AuthSvcURL,
|
||||
AllowGuest: false,
|
||||
}), middleware.LLMLimits(middleware.LLMLimitsConfig{
|
||||
UsageRepo: usageRepo,
|
||||
}))
|
||||
|
||||
llmAPI.Post("/generate", func(c *fiber.Ctx) error {
|
||||
userID := middleware.GetUserID(c)
|
||||
tier := middleware.GetUserTier(c)
|
||||
if tier == "" {
|
||||
tier = "free"
|
||||
}
|
||||
var req GenerateRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid request body"})
|
||||
@@ -100,6 +137,11 @@ func main() {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Messages required"})
|
||||
}
|
||||
|
||||
limits := usage.GetLimits(tier)
|
||||
if req.Options.MaxTokens == 0 || req.Options.MaxTokens > limits.MaxTokensPerReq {
|
||||
req.Options.MaxTokens = limits.MaxTokensPerReq
|
||||
}
|
||||
|
||||
client, err := llm.NewClient(llm.ProviderConfig{
|
||||
ProviderID: req.ProviderID,
|
||||
ModelKey: req.ModelKey,
|
||||
@@ -161,11 +203,19 @@ func main() {
|
||||
return c.Status(500).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
if usageRepo != nil {
|
||||
go usageRepo.IncrementLLMUsage(context.Background(), userID, tier, len(response)/4)
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
"content": response,
|
||||
})
|
||||
})
|
||||
|
||||
llmAPI.Post("/embed", func(c *fiber.Ctx) error {
|
||||
return c.Status(501).JSON(fiber.Map{"error": "Not implemented"})
|
||||
})
|
||||
|
||||
port := cfg.LLMSvcPort
|
||||
log.Printf("llm-svc listening on :%d", port)
|
||||
log.Fatal(app.Listen(fmt.Sprintf(":%d", port)))
|
||||
|
||||
@@ -181,7 +181,10 @@ func main() {
|
||||
return c.Status(403).JSON(fiber.Map{"error": "Access denied"})
|
||||
}
|
||||
|
||||
messages, _ := threadRepo.GetMessages(c.Context(), threadID, 100, 0)
|
||||
messages, err := threadRepo.GetMessages(c.Context(), threadID, userID, 100, 0)
|
||||
if err != nil && err != db.ErrForbidden {
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Failed to get messages"})
|
||||
}
|
||||
thread.Messages = messages
|
||||
|
||||
return c.JSON(thread)
|
||||
@@ -225,12 +228,15 @@ func main() {
|
||||
TokensUsed: req.TokensUsed,
|
||||
}
|
||||
|
||||
if err := threadRepo.AddMessage(c.Context(), msg); err != nil {
|
||||
if err := threadRepo.AddMessage(c.Context(), msg, userID); err != nil {
|
||||
if err == db.ErrForbidden {
|
||||
return c.Status(403).JSON(fiber.Map{"error": "Access denied"})
|
||||
}
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Failed to add message"})
|
||||
}
|
||||
|
||||
if thread.Title == "New Thread" && req.Role == "user" {
|
||||
threadRepo.GenerateTitle(c.Context(), threadID, req.Content)
|
||||
threadRepo.GenerateTitle(c.Context(), threadID, req.Content, userID)
|
||||
}
|
||||
|
||||
return c.Status(201).JSON(msg)
|
||||
@@ -250,7 +256,10 @@ func main() {
|
||||
}
|
||||
|
||||
shareID := generateShareID()
|
||||
if err := threadRepo.SetShareID(c.Context(), threadID, shareID); err != nil {
|
||||
if err := threadRepo.SetShareID(c.Context(), threadID, shareID, userID); err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "Thread not found"})
|
||||
}
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Failed to share thread"})
|
||||
}
|
||||
|
||||
@@ -264,16 +273,10 @@ func main() {
|
||||
threadID := c.Params("id")
|
||||
userID := middleware.GetUserID(c)
|
||||
|
||||
thread, err := threadRepo.GetByID(c.Context(), threadID)
|
||||
if err != nil || thread == nil {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "Thread not found"})
|
||||
}
|
||||
|
||||
if thread.UserID != userID {
|
||||
return c.Status(403).JSON(fiber.Map{"error": "Access denied"})
|
||||
}
|
||||
|
||||
if err := threadRepo.Delete(c.Context(), threadID); err != nil {
|
||||
if err := threadRepo.Delete(c.Context(), threadID, userID); err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "Thread not found"})
|
||||
}
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Failed to delete thread"})
|
||||
}
|
||||
|
||||
@@ -290,7 +293,7 @@ func main() {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "Shared thread not found"})
|
||||
}
|
||||
|
||||
messages, _ := threadRepo.GetMessages(c.Context(), thread.ID, 100, 0)
|
||||
messages, _ := threadRepo.GetMessages(c.Context(), thread.ID, thread.UserID, 100, 0)
|
||||
thread.Messages = messages
|
||||
|
||||
return c.JSON(thread)
|
||||
@@ -353,15 +356,6 @@ func main() {
|
||||
spaceID := c.Params("id")
|
||||
userID := middleware.GetUserID(c)
|
||||
|
||||
space, err := spaceRepo.GetByID(c.Context(), spaceID)
|
||||
if err != nil || space == nil {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "Space not found"})
|
||||
}
|
||||
|
||||
if space.UserID != userID {
|
||||
return c.Status(403).JSON(fiber.Map{"error": "Access denied"})
|
||||
}
|
||||
|
||||
var req db.Space
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid request"})
|
||||
@@ -370,7 +364,10 @@ func main() {
|
||||
req.ID = spaceID
|
||||
req.UserID = userID
|
||||
|
||||
if err := spaceRepo.Update(c.Context(), &req); err != nil {
|
||||
if err := spaceRepo.Update(c.Context(), &req, userID); err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "Space not found"})
|
||||
}
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Failed to update space"})
|
||||
}
|
||||
|
||||
@@ -381,16 +378,10 @@ func main() {
|
||||
spaceID := c.Params("id")
|
||||
userID := middleware.GetUserID(c)
|
||||
|
||||
space, err := spaceRepo.GetByID(c.Context(), spaceID)
|
||||
if err != nil || space == nil {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "Space not found"})
|
||||
}
|
||||
|
||||
if space.UserID != userID {
|
||||
return c.Status(403).JSON(fiber.Map{"error": "Access denied"})
|
||||
}
|
||||
|
||||
if err := spaceRepo.Delete(c.Context(), spaceID); err != nil {
|
||||
if err := spaceRepo.Delete(c.Context(), spaceID, userID); err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "Space not found"})
|
||||
}
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Failed to delete space"})
|
||||
}
|
||||
|
||||
@@ -445,8 +436,12 @@ func main() {
|
||||
|
||||
memory.Delete("/:id", func(c *fiber.Ctx) error {
|
||||
memID := c.Params("id")
|
||||
userID := middleware.GetUserID(c)
|
||||
|
||||
if err := memoryRepo.Delete(c.Context(), memID); err != nil {
|
||||
if err := memoryRepo.Delete(c.Context(), memID, userID); err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "Memory not found"})
|
||||
}
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Failed to delete memory"})
|
||||
}
|
||||
|
||||
@@ -493,7 +488,7 @@ func main() {
|
||||
return c.Status(403).JSON(fiber.Map{"error": "Access denied"})
|
||||
}
|
||||
|
||||
messages, _ := threadRepo.GetMessages(c.Context(), threadID, 100, 0)
|
||||
messages, _ := threadRepo.GetMessages(c.Context(), threadID, userID, 100, 0)
|
||||
|
||||
var query, answer string
|
||||
for _, msg := range messages {
|
||||
@@ -559,7 +554,10 @@ func main() {
|
||||
}
|
||||
|
||||
shareID := generateShareID()
|
||||
if err := pageRepo.SetShareID(c.Context(), pageID, shareID); err != nil {
|
||||
if err := pageRepo.SetShareID(c.Context(), pageID, shareID, userID); err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "Page not found"})
|
||||
}
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Failed to share page"})
|
||||
}
|
||||
|
||||
@@ -586,16 +584,10 @@ func main() {
|
||||
pageID := c.Params("id")
|
||||
userID := middleware.GetUserID(c)
|
||||
|
||||
page, err := pageRepo.GetByID(c.Context(), pageID)
|
||||
if err != nil || page == nil {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "Page not found"})
|
||||
}
|
||||
|
||||
if page.UserID != userID {
|
||||
return c.Status(403).JSON(fiber.Map{"error": "Access denied"})
|
||||
}
|
||||
|
||||
if err := pageRepo.Delete(c.Context(), pageID); err != nil {
|
||||
if err := pageRepo.Delete(c.Context(), pageID, userID); err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "Page not found"})
|
||||
}
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Failed to delete page"})
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user