feat: auth service + security audit fixes + cleanup legacy services
Major changes:
- Add auth-svc: JWT auth, register/login/refresh, password reset
- Add auth UI: modals, pages (/login, /register, /forgot-password)
- Add usage tracking (usage_metrics table, daily limits)
- Add tiered rate limiting (free/pro/business)
- Add LLM usage limits per tier
Security fixes:
- All repos now require userID for Update/Delete operations
- JWT middleware in chat-svc, llm-svc, agent-svc, discover-svc
- ErrNotFound/ErrForbidden errors for proper access control
Cleanup:
- Remove legacy TypeScript services/ directory
- Remove computer-svc (to be reimplemented)
- Remove old deploy/docker configs
New files:
- backend/cmd/auth-svc/main.go
- backend/internal/auth/{types,repository}.go
- backend/internal/usage/{types,repository}.go
- backend/pkg/middleware/{llm_limits,ratelimit_tiered}.go
- backend/webui/src/components/auth/*
- backend/webui/src/app/(auth)/*
Made-with: Cursor
This commit is contained in:
@@ -105,21 +105,35 @@ func (r *CollectionRepository) GetByUserID(ctx context.Context, userID string, l
|
||||
return collections, nil
|
||||
}
|
||||
|
||||
func (r *CollectionRepository) Update(ctx context.Context, c *Collection) error {
|
||||
func (r *CollectionRepository) Update(ctx context.Context, c *Collection, userID string) error {
|
||||
query := `
|
||||
UPDATE collections
|
||||
SET name = $2, description = $3, is_public = $4, context_enabled = $5, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
WHERE id = $1 AND user_id = $6
|
||||
`
|
||||
_, err := r.db.db.ExecContext(ctx, query,
|
||||
c.ID, c.Name, c.Description, c.IsPublic, c.ContextEnabled,
|
||||
result, err := r.db.db.ExecContext(ctx, query,
|
||||
c.ID, c.Name, c.Description, c.IsPublic, c.ContextEnabled, userID,
|
||||
)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *CollectionRepository) Delete(ctx context.Context, id string) error {
|
||||
_, err := r.db.db.ExecContext(ctx, "DELETE FROM collections WHERE id = $1", id)
|
||||
return err
|
||||
func (r *CollectionRepository) Delete(ctx context.Context, id, userID string) error {
|
||||
result, err := r.db.db.ExecContext(ctx, "DELETE FROM collections WHERE id = $1 AND user_id = $2", id, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *CollectionRepository) AddItem(ctx context.Context, item *CollectionItem) error {
|
||||
@@ -135,7 +149,20 @@ func (r *CollectionRepository) AddItem(ctx context.Context, item *CollectionItem
|
||||
).Scan(&item.ID, &item.CreatedAt, &item.SortOrder)
|
||||
}
|
||||
|
||||
func (r *CollectionRepository) GetItems(ctx context.Context, collectionID string) ([]CollectionItem, error) {
|
||||
func (r *CollectionRepository) GetItems(ctx context.Context, collectionID, userID string) ([]CollectionItem, error) {
|
||||
var ownerID string
|
||||
var isPublic bool
|
||||
err := r.db.db.QueryRowContext(ctx, "SELECT user_id, is_public FROM collections WHERE id = $1", collectionID).Scan(&ownerID, &isPublic)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ownerID != userID && !isPublic {
|
||||
return nil, ErrForbidden
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT id, collection_id, item_type, title, content, url, metadata, created_at, sort_order
|
||||
FROM collection_items
|
||||
@@ -168,13 +195,25 @@ func (r *CollectionRepository) GetItems(ctx context.Context, collectionID string
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (r *CollectionRepository) RemoveItem(ctx context.Context, itemID string) error {
|
||||
_, err := r.db.db.ExecContext(ctx, "DELETE FROM collection_items WHERE id = $1", itemID)
|
||||
return err
|
||||
func (r *CollectionRepository) RemoveItem(ctx context.Context, itemID, userID string) error {
|
||||
query := `
|
||||
DELETE FROM collection_items
|
||||
WHERE id = $1
|
||||
AND collection_id IN (SELECT id FROM collections WHERE user_id = $2)
|
||||
`
|
||||
result, err := r.db.db.ExecContext(ctx, query, itemID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *CollectionRepository) GetCollectionContext(ctx context.Context, collectionID string) (string, error) {
|
||||
items, err := r.GetItems(ctx, collectionID)
|
||||
func (r *CollectionRepository) GetCollectionContext(ctx context.Context, collectionID, userID string) (string, error) {
|
||||
items, err := r.GetItems(ctx, collectionID, userID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -1,322 +0,0 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/gooseek/backend/internal/computer"
|
||||
)
|
||||
|
||||
type ComputerArtifactRepo struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewComputerArtifactRepo(db *sql.DB) *ComputerArtifactRepo {
|
||||
return &ComputerArtifactRepo{db: db}
|
||||
}
|
||||
|
||||
func (r *ComputerArtifactRepo) Migrate() error {
|
||||
query := `
|
||||
CREATE TABLE IF NOT EXISTS computer_artifacts (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
task_id UUID NOT NULL,
|
||||
type VARCHAR(50) NOT NULL,
|
||||
name VARCHAR(255),
|
||||
content BYTEA,
|
||||
url TEXT,
|
||||
size BIGINT DEFAULT 0,
|
||||
mime_type VARCHAR(100),
|
||||
metadata JSONB,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_artifacts_task_id ON computer_artifacts(task_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_artifacts_type ON computer_artifacts(type);
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_artifacts_created ON computer_artifacts(created_at DESC);
|
||||
`
|
||||
|
||||
_, err := r.db.Exec(query)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ComputerArtifactRepo) Create(ctx context.Context, artifact *computer.Artifact) error {
|
||||
metadataJSON, _ := json.Marshal(artifact.Metadata)
|
||||
|
||||
query := `
|
||||
INSERT INTO computer_artifacts (id, task_id, type, name, content, url, size, mime_type, metadata, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
`
|
||||
|
||||
_, err := r.db.ExecContext(ctx, query,
|
||||
artifact.ID,
|
||||
artifact.TaskID,
|
||||
artifact.Type,
|
||||
artifact.Name,
|
||||
artifact.Content,
|
||||
artifact.URL,
|
||||
artifact.Size,
|
||||
artifact.MimeType,
|
||||
metadataJSON,
|
||||
artifact.CreatedAt,
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ComputerArtifactRepo) GetByID(ctx context.Context, id string) (*computer.Artifact, error) {
|
||||
query := `
|
||||
SELECT id, task_id, type, name, content, url, size, mime_type, metadata, created_at
|
||||
FROM computer_artifacts
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
var artifact computer.Artifact
|
||||
var content []byte
|
||||
var url, mimeType sql.NullString
|
||||
var metadataJSON []byte
|
||||
|
||||
err := r.db.QueryRowContext(ctx, query, id).Scan(
|
||||
&artifact.ID,
|
||||
&artifact.TaskID,
|
||||
&artifact.Type,
|
||||
&artifact.Name,
|
||||
&content,
|
||||
&url,
|
||||
&artifact.Size,
|
||||
&mimeType,
|
||||
&metadataJSON,
|
||||
&artifact.CreatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
artifact.Content = content
|
||||
if url.Valid {
|
||||
artifact.URL = url.String
|
||||
}
|
||||
if mimeType.Valid {
|
||||
artifact.MimeType = mimeType.String
|
||||
}
|
||||
if len(metadataJSON) > 0 {
|
||||
json.Unmarshal(metadataJSON, &artifact.Metadata)
|
||||
}
|
||||
|
||||
return &artifact, nil
|
||||
}
|
||||
|
||||
func (r *ComputerArtifactRepo) GetByTaskID(ctx context.Context, taskID string) ([]computer.Artifact, error) {
|
||||
query := `
|
||||
SELECT id, task_id, type, name, url, size, mime_type, metadata, created_at
|
||||
FROM computer_artifacts
|
||||
WHERE task_id = $1
|
||||
ORDER BY created_at ASC
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var artifacts []computer.Artifact
|
||||
|
||||
for rows.Next() {
|
||||
var artifact computer.Artifact
|
||||
var url, mimeType sql.NullString
|
||||
var metadataJSON []byte
|
||||
|
||||
err := rows.Scan(
|
||||
&artifact.ID,
|
||||
&artifact.TaskID,
|
||||
&artifact.Type,
|
||||
&artifact.Name,
|
||||
&url,
|
||||
&artifact.Size,
|
||||
&mimeType,
|
||||
&metadataJSON,
|
||||
&artifact.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if url.Valid {
|
||||
artifact.URL = url.String
|
||||
}
|
||||
if mimeType.Valid {
|
||||
artifact.MimeType = mimeType.String
|
||||
}
|
||||
if len(metadataJSON) > 0 {
|
||||
json.Unmarshal(metadataJSON, &artifact.Metadata)
|
||||
}
|
||||
|
||||
artifacts = append(artifacts, artifact)
|
||||
}
|
||||
|
||||
return artifacts, nil
|
||||
}
|
||||
|
||||
func (r *ComputerArtifactRepo) GetByType(ctx context.Context, taskID, artifactType string) ([]computer.Artifact, error) {
|
||||
query := `
|
||||
SELECT id, task_id, type, name, url, size, mime_type, metadata, created_at
|
||||
FROM computer_artifacts
|
||||
WHERE task_id = $1 AND type = $2
|
||||
ORDER BY created_at ASC
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, taskID, artifactType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var artifacts []computer.Artifact
|
||||
|
||||
for rows.Next() {
|
||||
var artifact computer.Artifact
|
||||
var url, mimeType sql.NullString
|
||||
var metadataJSON []byte
|
||||
|
||||
err := rows.Scan(
|
||||
&artifact.ID,
|
||||
&artifact.TaskID,
|
||||
&artifact.Type,
|
||||
&artifact.Name,
|
||||
&url,
|
||||
&artifact.Size,
|
||||
&mimeType,
|
||||
&metadataJSON,
|
||||
&artifact.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if url.Valid {
|
||||
artifact.URL = url.String
|
||||
}
|
||||
if mimeType.Valid {
|
||||
artifact.MimeType = mimeType.String
|
||||
}
|
||||
if len(metadataJSON) > 0 {
|
||||
json.Unmarshal(metadataJSON, &artifact.Metadata)
|
||||
}
|
||||
|
||||
artifacts = append(artifacts, artifact)
|
||||
}
|
||||
|
||||
return artifacts, nil
|
||||
}
|
||||
|
||||
func (r *ComputerArtifactRepo) GetContent(ctx context.Context, id string) ([]byte, error) {
|
||||
query := `SELECT content FROM computer_artifacts WHERE id = $1`
|
||||
var content []byte
|
||||
err := r.db.QueryRowContext(ctx, query, id).Scan(&content)
|
||||
return content, err
|
||||
}
|
||||
|
||||
func (r *ComputerArtifactRepo) UpdateURL(ctx context.Context, id, url string) error {
|
||||
query := `UPDATE computer_artifacts SET url = $1 WHERE id = $2`
|
||||
_, err := r.db.ExecContext(ctx, query, url, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ComputerArtifactRepo) Delete(ctx context.Context, id string) error {
|
||||
query := `DELETE FROM computer_artifacts WHERE id = $1`
|
||||
_, err := r.db.ExecContext(ctx, query, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ComputerArtifactRepo) DeleteByTaskID(ctx context.Context, taskID string) error {
|
||||
query := `DELETE FROM computer_artifacts WHERE task_id = $1`
|
||||
_, err := r.db.ExecContext(ctx, query, taskID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ComputerArtifactRepo) DeleteOlderThan(ctx context.Context, days int) (int64, error) {
|
||||
query := `
|
||||
DELETE FROM computer_artifacts
|
||||
WHERE created_at < NOW() - INTERVAL '1 day' * $1
|
||||
`
|
||||
result, err := r.db.ExecContext(ctx, query, days)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
func (r *ComputerArtifactRepo) GetTotalSize(ctx context.Context, taskID string) (int64, error) {
|
||||
query := `SELECT COALESCE(SUM(size), 0) FROM computer_artifacts WHERE task_id = $1`
|
||||
var size int64
|
||||
err := r.db.QueryRowContext(ctx, query, taskID).Scan(&size)
|
||||
return size, err
|
||||
}
|
||||
|
||||
func (r *ComputerArtifactRepo) Count(ctx context.Context, taskID string) (int64, error) {
|
||||
query := `SELECT COUNT(*) FROM computer_artifacts WHERE task_id = $1`
|
||||
var count int64
|
||||
err := r.db.QueryRowContext(ctx, query, taskID).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
type ArtifactSummary struct {
|
||||
ID string `json:"id"`
|
||||
TaskID string `json:"taskId"`
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
URL string `json:"url"`
|
||||
Size int64 `json:"size"`
|
||||
MimeType string `json:"mimeType"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
func (r *ComputerArtifactRepo) GetSummaries(ctx context.Context, taskID string) ([]ArtifactSummary, error) {
|
||||
query := `
|
||||
SELECT id, task_id, type, name, url, size, mime_type, created_at
|
||||
FROM computer_artifacts
|
||||
WHERE task_id = $1
|
||||
ORDER BY created_at ASC
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var summaries []ArtifactSummary
|
||||
|
||||
for rows.Next() {
|
||||
var s ArtifactSummary
|
||||
var url, mimeType sql.NullString
|
||||
|
||||
err := rows.Scan(
|
||||
&s.ID,
|
||||
&s.TaskID,
|
||||
&s.Type,
|
||||
&s.Name,
|
||||
&url,
|
||||
&s.Size,
|
||||
&mimeType,
|
||||
&s.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if url.Valid {
|
||||
s.URL = url.String
|
||||
}
|
||||
if mimeType.Valid {
|
||||
s.MimeType = mimeType.String
|
||||
}
|
||||
|
||||
summaries = append(summaries, s)
|
||||
}
|
||||
|
||||
return summaries, nil
|
||||
}
|
||||
@@ -1,306 +0,0 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gooseek/backend/internal/computer"
|
||||
)
|
||||
|
||||
type ComputerMemoryRepo struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewComputerMemoryRepo(db *sql.DB) *ComputerMemoryRepo {
|
||||
return &ComputerMemoryRepo{db: db}
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) Migrate() error {
|
||||
query := `
|
||||
CREATE TABLE IF NOT EXISTS computer_memory (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL,
|
||||
task_id UUID,
|
||||
key VARCHAR(255) NOT NULL,
|
||||
value JSONB NOT NULL,
|
||||
type VARCHAR(50),
|
||||
tags TEXT[],
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
expires_at TIMESTAMPTZ
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_memory_user_id ON computer_memory(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_memory_task_id ON computer_memory(task_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_memory_type ON computer_memory(type);
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_memory_expires ON computer_memory(expires_at) WHERE expires_at IS NOT NULL;
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_memory_key ON computer_memory(key);
|
||||
`
|
||||
|
||||
_, err := r.db.Exec(query)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) Store(ctx context.Context, entry *computer.MemoryEntry) error {
|
||||
valueJSON, err := json.Marshal(entry.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO computer_memory (id, user_id, task_id, key, value, type, tags, created_at, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
value = EXCLUDED.value,
|
||||
type = EXCLUDED.type,
|
||||
tags = EXCLUDED.tags,
|
||||
expires_at = EXCLUDED.expires_at
|
||||
`
|
||||
|
||||
var taskID interface{}
|
||||
if entry.TaskID != "" {
|
||||
taskID = entry.TaskID
|
||||
}
|
||||
|
||||
_, err = r.db.ExecContext(ctx, query,
|
||||
entry.ID,
|
||||
entry.UserID,
|
||||
taskID,
|
||||
entry.Key,
|
||||
valueJSON,
|
||||
entry.Type,
|
||||
entry.Tags,
|
||||
entry.CreatedAt,
|
||||
entry.ExpiresAt,
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) GetByUser(ctx context.Context, userID string, limit int) ([]computer.MemoryEntry, error) {
|
||||
query := `
|
||||
SELECT id, user_id, task_id, key, value, type, tags, created_at, expires_at
|
||||
FROM computer_memory
|
||||
WHERE user_id = $1
|
||||
AND (expires_at IS NULL OR expires_at > NOW())
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $2
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, userID, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return r.scanEntries(rows)
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) GetByTask(ctx context.Context, taskID string) ([]computer.MemoryEntry, error) {
|
||||
query := `
|
||||
SELECT id, user_id, task_id, key, value, type, tags, created_at, expires_at
|
||||
FROM computer_memory
|
||||
WHERE task_id = $1
|
||||
AND (expires_at IS NULL OR expires_at > NOW())
|
||||
ORDER BY created_at ASC
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return r.scanEntries(rows)
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) Search(ctx context.Context, userID, query string, limit int) ([]computer.MemoryEntry, error) {
|
||||
searchTerms := strings.Fields(strings.ToLower(query))
|
||||
if len(searchTerms) == 0 {
|
||||
return r.GetByUser(ctx, userID, limit)
|
||||
}
|
||||
|
||||
likePatterns := make([]string, len(searchTerms))
|
||||
args := make([]interface{}, len(searchTerms)+2)
|
||||
args[0] = userID
|
||||
|
||||
for i, term := range searchTerms {
|
||||
likePatterns[i] = "%" + term + "%"
|
||||
args[i+1] = likePatterns[i]
|
||||
}
|
||||
args[len(args)-1] = limit
|
||||
|
||||
var conditions []string
|
||||
for i := range searchTerms {
|
||||
conditions = append(conditions, "(LOWER(key) LIKE $"+string(rune('2'+i))+" OR LOWER(value::text) LIKE $"+string(rune('2'+i))+")")
|
||||
}
|
||||
|
||||
sqlQuery := `
|
||||
SELECT id, user_id, task_id, key, value, type, tags, created_at, expires_at
|
||||
FROM computer_memory
|
||||
WHERE user_id = $1
|
||||
AND (expires_at IS NULL OR expires_at > NOW())
|
||||
AND (` + strings.Join(conditions, " OR ") + `)
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $` + string(rune('2'+len(searchTerms)))
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, sqlQuery, args...)
|
||||
if err != nil {
|
||||
return r.GetByUser(ctx, userID, limit)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return r.scanEntries(rows)
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) GetByType(ctx context.Context, userID, memType string, limit int) ([]computer.MemoryEntry, error) {
|
||||
query := `
|
||||
SELECT id, user_id, task_id, key, value, type, tags, created_at, expires_at
|
||||
FROM computer_memory
|
||||
WHERE user_id = $1 AND type = $2
|
||||
AND (expires_at IS NULL OR expires_at > NOW())
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $3
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, userID, memType, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return r.scanEntries(rows)
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) GetByKey(ctx context.Context, userID, key string) (*computer.MemoryEntry, error) {
|
||||
query := `
|
||||
SELECT id, user_id, task_id, key, value, type, tags, created_at, expires_at
|
||||
FROM computer_memory
|
||||
WHERE user_id = $1 AND key = $2
|
||||
AND (expires_at IS NULL OR expires_at > NOW())
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
var entry computer.MemoryEntry
|
||||
var valueJSON []byte
|
||||
var taskID sql.NullString
|
||||
var expiresAt sql.NullTime
|
||||
var tags []string
|
||||
|
||||
err := r.db.QueryRowContext(ctx, query, userID, key).Scan(
|
||||
&entry.ID,
|
||||
&entry.UserID,
|
||||
&taskID,
|
||||
&entry.Key,
|
||||
&valueJSON,
|
||||
&entry.Type,
|
||||
&tags,
|
||||
&entry.CreatedAt,
|
||||
&expiresAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if taskID.Valid {
|
||||
entry.TaskID = taskID.String
|
||||
}
|
||||
if expiresAt.Valid {
|
||||
entry.ExpiresAt = &expiresAt.Time
|
||||
}
|
||||
entry.Tags = tags
|
||||
|
||||
json.Unmarshal(valueJSON, &entry.Value)
|
||||
|
||||
return &entry, nil
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) Delete(ctx context.Context, id string) error {
|
||||
query := `DELETE FROM computer_memory WHERE id = $1`
|
||||
_, err := r.db.ExecContext(ctx, query, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) DeleteByUser(ctx context.Context, userID string) error {
|
||||
query := `DELETE FROM computer_memory WHERE user_id = $1`
|
||||
_, err := r.db.ExecContext(ctx, query, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) DeleteByTask(ctx context.Context, taskID string) error {
|
||||
query := `DELETE FROM computer_memory WHERE task_id = $1`
|
||||
_, err := r.db.ExecContext(ctx, query, taskID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) DeleteExpired(ctx context.Context) (int64, error) {
|
||||
query := `DELETE FROM computer_memory WHERE expires_at IS NOT NULL AND expires_at < NOW()`
|
||||
result, err := r.db.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) scanEntries(rows *sql.Rows) ([]computer.MemoryEntry, error) {
|
||||
var entries []computer.MemoryEntry
|
||||
|
||||
for rows.Next() {
|
||||
var entry computer.MemoryEntry
|
||||
var valueJSON []byte
|
||||
var taskID sql.NullString
|
||||
var expiresAt sql.NullTime
|
||||
var tags []string
|
||||
|
||||
err := rows.Scan(
|
||||
&entry.ID,
|
||||
&entry.UserID,
|
||||
&taskID,
|
||||
&entry.Key,
|
||||
&valueJSON,
|
||||
&entry.Type,
|
||||
&tags,
|
||||
&entry.CreatedAt,
|
||||
&expiresAt,
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if taskID.Valid {
|
||||
entry.TaskID = taskID.String
|
||||
}
|
||||
if expiresAt.Valid {
|
||||
entry.ExpiresAt = &expiresAt.Time
|
||||
}
|
||||
entry.Tags = tags
|
||||
|
||||
json.Unmarshal(valueJSON, &entry.Value)
|
||||
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) Count(ctx context.Context, userID string) (int64, error) {
|
||||
query := `
|
||||
SELECT COUNT(*)
|
||||
FROM computer_memory
|
||||
WHERE user_id = $1
|
||||
AND (expires_at IS NULL OR expires_at > NOW())
|
||||
`
|
||||
var count int64
|
||||
err := r.db.QueryRowContext(ctx, query, userID).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) UpdateExpiry(ctx context.Context, id string, expiresAt time.Time) error {
|
||||
query := `UPDATE computer_memory SET expires_at = $1 WHERE id = $2`
|
||||
_, err := r.db.ExecContext(ctx, query, expiresAt, id)
|
||||
return err
|
||||
}
|
||||
@@ -1,411 +0,0 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gooseek/backend/internal/computer"
|
||||
)
|
||||
|
||||
type ComputerTaskRepo struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewComputerTaskRepo(db *sql.DB) *ComputerTaskRepo {
|
||||
return &ComputerTaskRepo{db: db}
|
||||
}
|
||||
|
||||
func (r *ComputerTaskRepo) Migrate() error {
|
||||
query := `
|
||||
CREATE TABLE IF NOT EXISTS computer_tasks (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL,
|
||||
query TEXT NOT NULL,
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'pending',
|
||||
plan JSONB,
|
||||
sub_tasks JSONB,
|
||||
artifacts JSONB,
|
||||
memory JSONB,
|
||||
progress INT DEFAULT 0,
|
||||
message TEXT,
|
||||
error TEXT,
|
||||
schedule JSONB,
|
||||
next_run_at TIMESTAMPTZ,
|
||||
run_count INT DEFAULT 0,
|
||||
total_cost DECIMAL(10,6) DEFAULT 0,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
completed_at TIMESTAMPTZ
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_tasks_user_id ON computer_tasks(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_tasks_status ON computer_tasks(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_tasks_next_run ON computer_tasks(next_run_at) WHERE next_run_at IS NOT NULL;
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_tasks_created ON computer_tasks(created_at DESC);
|
||||
`
|
||||
|
||||
_, err := r.db.Exec(query)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ComputerTaskRepo) Create(ctx context.Context, task *computer.ComputerTask) error {
|
||||
planJSON, _ := json.Marshal(task.Plan)
|
||||
subTasksJSON, _ := json.Marshal(task.SubTasks)
|
||||
artifactsJSON, _ := json.Marshal(task.Artifacts)
|
||||
memoryJSON, _ := json.Marshal(task.Memory)
|
||||
scheduleJSON, _ := json.Marshal(task.Schedule)
|
||||
|
||||
query := `
|
||||
INSERT INTO computer_tasks (
|
||||
id, user_id, query, status, plan, sub_tasks, artifacts, memory,
|
||||
progress, message, error, schedule, next_run_at, run_count, total_cost,
|
||||
created_at, updated_at, completed_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18
|
||||
)
|
||||
`
|
||||
|
||||
_, err := r.db.ExecContext(ctx, query,
|
||||
task.ID,
|
||||
task.UserID,
|
||||
task.Query,
|
||||
task.Status,
|
||||
planJSON,
|
||||
subTasksJSON,
|
||||
artifactsJSON,
|
||||
memoryJSON,
|
||||
task.Progress,
|
||||
task.Message,
|
||||
task.Error,
|
||||
scheduleJSON,
|
||||
task.NextRunAt,
|
||||
task.RunCount,
|
||||
task.TotalCost,
|
||||
task.CreatedAt,
|
||||
task.UpdatedAt,
|
||||
task.CompletedAt,
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ComputerTaskRepo) Update(ctx context.Context, task *computer.ComputerTask) error {
|
||||
planJSON, _ := json.Marshal(task.Plan)
|
||||
subTasksJSON, _ := json.Marshal(task.SubTasks)
|
||||
artifactsJSON, _ := json.Marshal(task.Artifacts)
|
||||
memoryJSON, _ := json.Marshal(task.Memory)
|
||||
scheduleJSON, _ := json.Marshal(task.Schedule)
|
||||
|
||||
query := `
|
||||
UPDATE computer_tasks SET
|
||||
status = $1,
|
||||
plan = $2,
|
||||
sub_tasks = $3,
|
||||
artifacts = $4,
|
||||
memory = $5,
|
||||
progress = $6,
|
||||
message = $7,
|
||||
error = $8,
|
||||
schedule = $9,
|
||||
next_run_at = $10,
|
||||
run_count = $11,
|
||||
total_cost = $12,
|
||||
updated_at = $13,
|
||||
completed_at = $14
|
||||
WHERE id = $15
|
||||
`
|
||||
|
||||
_, err := r.db.ExecContext(ctx, query,
|
||||
task.Status,
|
||||
planJSON,
|
||||
subTasksJSON,
|
||||
artifactsJSON,
|
||||
memoryJSON,
|
||||
task.Progress,
|
||||
task.Message,
|
||||
task.Error,
|
||||
scheduleJSON,
|
||||
task.NextRunAt,
|
||||
task.RunCount,
|
||||
task.TotalCost,
|
||||
time.Now(),
|
||||
task.CompletedAt,
|
||||
task.ID,
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ComputerTaskRepo) GetByID(ctx context.Context, id string) (*computer.ComputerTask, error) {
|
||||
query := `
|
||||
SELECT id, user_id, query, status, plan, sub_tasks, artifacts, memory,
|
||||
progress, message, error, schedule, next_run_at, run_count, total_cost,
|
||||
created_at, updated_at, completed_at
|
||||
FROM computer_tasks
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
var task computer.ComputerTask
|
||||
var planJSON, subTasksJSON, artifactsJSON, memoryJSON, scheduleJSON []byte
|
||||
var message, errStr sql.NullString
|
||||
var nextRunAt, completedAt sql.NullTime
|
||||
|
||||
err := r.db.QueryRowContext(ctx, query, id).Scan(
|
||||
&task.ID,
|
||||
&task.UserID,
|
||||
&task.Query,
|
||||
&task.Status,
|
||||
&planJSON,
|
||||
&subTasksJSON,
|
||||
&artifactsJSON,
|
||||
&memoryJSON,
|
||||
&task.Progress,
|
||||
&message,
|
||||
&errStr,
|
||||
&scheduleJSON,
|
||||
&nextRunAt,
|
||||
&task.RunCount,
|
||||
&task.TotalCost,
|
||||
&task.CreatedAt,
|
||||
&task.UpdatedAt,
|
||||
&completedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(planJSON) > 0 {
|
||||
json.Unmarshal(planJSON, &task.Plan)
|
||||
}
|
||||
if len(subTasksJSON) > 0 {
|
||||
json.Unmarshal(subTasksJSON, &task.SubTasks)
|
||||
}
|
||||
if len(artifactsJSON) > 0 {
|
||||
json.Unmarshal(artifactsJSON, &task.Artifacts)
|
||||
}
|
||||
if len(memoryJSON) > 0 {
|
||||
json.Unmarshal(memoryJSON, &task.Memory)
|
||||
}
|
||||
if len(scheduleJSON) > 0 {
|
||||
json.Unmarshal(scheduleJSON, &task.Schedule)
|
||||
}
|
||||
|
||||
if message.Valid {
|
||||
task.Message = message.String
|
||||
}
|
||||
if errStr.Valid {
|
||||
task.Error = errStr.String
|
||||
}
|
||||
if nextRunAt.Valid {
|
||||
task.NextRunAt = &nextRunAt.Time
|
||||
}
|
||||
if completedAt.Valid {
|
||||
task.CompletedAt = &completedAt.Time
|
||||
}
|
||||
|
||||
return &task, nil
|
||||
}
|
||||
|
||||
func (r *ComputerTaskRepo) GetByUserID(ctx context.Context, userID string, limit, offset int) ([]computer.ComputerTask, error) {
|
||||
query := `
|
||||
SELECT id, user_id, query, status, plan, sub_tasks, artifacts, memory,
|
||||
progress, message, error, schedule, next_run_at, run_count, total_cost,
|
||||
created_at, updated_at, completed_at
|
||||
FROM computer_tasks
|
||||
WHERE user_id = $1
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $2 OFFSET $3
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, userID, limit, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tasks []computer.ComputerTask
|
||||
|
||||
for rows.Next() {
|
||||
var task computer.ComputerTask
|
||||
var planJSON, subTasksJSON, artifactsJSON, memoryJSON, scheduleJSON []byte
|
||||
var message, errStr sql.NullString
|
||||
var nextRunAt, completedAt sql.NullTime
|
||||
|
||||
err := rows.Scan(
|
||||
&task.ID,
|
||||
&task.UserID,
|
||||
&task.Query,
|
||||
&task.Status,
|
||||
&planJSON,
|
||||
&subTasksJSON,
|
||||
&artifactsJSON,
|
||||
&memoryJSON,
|
||||
&task.Progress,
|
||||
&message,
|
||||
&errStr,
|
||||
&scheduleJSON,
|
||||
&nextRunAt,
|
||||
&task.RunCount,
|
||||
&task.TotalCost,
|
||||
&task.CreatedAt,
|
||||
&task.UpdatedAt,
|
||||
&completedAt,
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(planJSON) > 0 {
|
||||
json.Unmarshal(planJSON, &task.Plan)
|
||||
}
|
||||
if len(subTasksJSON) > 0 {
|
||||
json.Unmarshal(subTasksJSON, &task.SubTasks)
|
||||
}
|
||||
if len(artifactsJSON) > 0 {
|
||||
json.Unmarshal(artifactsJSON, &task.Artifacts)
|
||||
}
|
||||
if len(memoryJSON) > 0 {
|
||||
json.Unmarshal(memoryJSON, &task.Memory)
|
||||
}
|
||||
if len(scheduleJSON) > 0 {
|
||||
json.Unmarshal(scheduleJSON, &task.Schedule)
|
||||
}
|
||||
|
||||
if message.Valid {
|
||||
task.Message = message.String
|
||||
}
|
||||
if errStr.Valid {
|
||||
task.Error = errStr.String
|
||||
}
|
||||
if nextRunAt.Valid {
|
||||
task.NextRunAt = &nextRunAt.Time
|
||||
}
|
||||
if completedAt.Valid {
|
||||
task.CompletedAt = &completedAt.Time
|
||||
}
|
||||
|
||||
tasks = append(tasks, task)
|
||||
}
|
||||
|
||||
return tasks, nil
|
||||
}
|
||||
|
||||
func (r *ComputerTaskRepo) GetScheduled(ctx context.Context) ([]computer.ComputerTask, error) {
|
||||
query := `
|
||||
SELECT id, user_id, query, status, plan, sub_tasks, artifacts, memory,
|
||||
progress, message, error, schedule, next_run_at, run_count, total_cost,
|
||||
created_at, updated_at, completed_at
|
||||
FROM computer_tasks
|
||||
WHERE status = 'scheduled' AND schedule IS NOT NULL
|
||||
ORDER BY next_run_at ASC
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tasks []computer.ComputerTask
|
||||
|
||||
for rows.Next() {
|
||||
var task computer.ComputerTask
|
||||
var planJSON, subTasksJSON, artifactsJSON, memoryJSON, scheduleJSON []byte
|
||||
var message, errStr sql.NullString
|
||||
var nextRunAt, completedAt sql.NullTime
|
||||
|
||||
err := rows.Scan(
|
||||
&task.ID,
|
||||
&task.UserID,
|
||||
&task.Query,
|
||||
&task.Status,
|
||||
&planJSON,
|
||||
&subTasksJSON,
|
||||
&artifactsJSON,
|
||||
&memoryJSON,
|
||||
&task.Progress,
|
||||
&message,
|
||||
&errStr,
|
||||
&scheduleJSON,
|
||||
&nextRunAt,
|
||||
&task.RunCount,
|
||||
&task.TotalCost,
|
||||
&task.CreatedAt,
|
||||
&task.UpdatedAt,
|
||||
&completedAt,
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(planJSON) > 0 {
|
||||
json.Unmarshal(planJSON, &task.Plan)
|
||||
}
|
||||
if len(subTasksJSON) > 0 {
|
||||
json.Unmarshal(subTasksJSON, &task.SubTasks)
|
||||
}
|
||||
if len(artifactsJSON) > 0 {
|
||||
json.Unmarshal(artifactsJSON, &task.Artifacts)
|
||||
}
|
||||
if len(memoryJSON) > 0 {
|
||||
json.Unmarshal(memoryJSON, &task.Memory)
|
||||
}
|
||||
if len(scheduleJSON) > 0 {
|
||||
json.Unmarshal(scheduleJSON, &task.Schedule)
|
||||
}
|
||||
|
||||
if message.Valid {
|
||||
task.Message = message.String
|
||||
}
|
||||
if errStr.Valid {
|
||||
task.Error = errStr.String
|
||||
}
|
||||
if nextRunAt.Valid {
|
||||
task.NextRunAt = &nextRunAt.Time
|
||||
}
|
||||
if completedAt.Valid {
|
||||
task.CompletedAt = &completedAt.Time
|
||||
}
|
||||
|
||||
tasks = append(tasks, task)
|
||||
}
|
||||
|
||||
return tasks, nil
|
||||
}
|
||||
|
||||
func (r *ComputerTaskRepo) Delete(ctx context.Context, id string) error {
|
||||
query := `DELETE FROM computer_tasks WHERE id = $1`
|
||||
_, err := r.db.ExecContext(ctx, query, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ComputerTaskRepo) DeleteOlderThan(ctx context.Context, days int) (int64, error) {
|
||||
query := `
|
||||
DELETE FROM computer_tasks
|
||||
WHERE created_at < NOW() - INTERVAL '%d days'
|
||||
AND status IN ('completed', 'failed', 'cancelled')
|
||||
`
|
||||
result, err := r.db.ExecContext(ctx, fmt.Sprintf(query, days))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
func (r *ComputerTaskRepo) CountByUser(ctx context.Context, userID string) (int64, error) {
|
||||
query := `SELECT COUNT(*) FROM computer_tasks WHERE user_id = $1`
|
||||
var count int64
|
||||
err := r.db.QueryRowContext(ctx, query, userID).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *ComputerTaskRepo) CountByStatus(ctx context.Context, status string) (int64, error) {
|
||||
query := `SELECT COUNT(*) FROM computer_tasks WHERE status = $1`
|
||||
var count int64
|
||||
err := r.db.QueryRowContext(ctx, query, status).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
@@ -99,20 +99,34 @@ func (r *FileRepository) GetByUserID(ctx context.Context, userID string, limit,
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (r *FileRepository) UpdateExtractedText(ctx context.Context, id, text string) error {
|
||||
_, err := r.db.db.ExecContext(ctx,
|
||||
"UPDATE uploaded_files SET extracted_text = $2 WHERE id = $1",
|
||||
id, text,
|
||||
func (r *FileRepository) UpdateExtractedText(ctx context.Context, id, text, userID string) error {
|
||||
result, err := r.db.db.ExecContext(ctx,
|
||||
"UPDATE uploaded_files SET extracted_text = $2 WHERE id = $1 AND user_id = $3",
|
||||
id, text, userID,
|
||||
)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *FileRepository) Delete(ctx context.Context, id string) error {
|
||||
_, err := r.db.db.ExecContext(ctx, "DELETE FROM uploaded_files WHERE id = $1", id)
|
||||
return err
|
||||
func (r *FileRepository) Delete(ctx context.Context, id, userID string) error {
|
||||
result, err := r.db.db.ExecContext(ctx, "DELETE FROM uploaded_files WHERE id = $1 AND user_id = $2", id, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *FileRepository) GetByIDs(ctx context.Context, ids []string) ([]*UploadedFile, error) {
|
||||
func (r *FileRepository) GetByIDs(ctx context.Context, ids []string, userID string) ([]*UploadedFile, error) {
|
||||
if len(ids) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -120,10 +134,10 @@ func (r *FileRepository) GetByIDs(ctx context.Context, ids []string) ([]*Uploade
|
||||
query := `
|
||||
SELECT id, user_id, filename, file_type, file_size, storage_path, extracted_text, metadata, created_at
|
||||
FROM uploaded_files
|
||||
WHERE id = ANY($1)
|
||||
WHERE id = ANY($1) AND user_id = $2
|
||||
`
|
||||
|
||||
rows, err := r.db.db.QueryContext(ctx, query, ids)
|
||||
rows, err := r.db.db.QueryContext(ctx, query, ids, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -147,17 +147,31 @@ func (r *MemoryRepository) GetContextForUser(ctx context.Context, userID string)
|
||||
return context, nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) IncrementUseCount(ctx context.Context, id string) error {
|
||||
_, err := r.db.db.ExecContext(ctx,
|
||||
"UPDATE user_memories SET use_count = use_count + 1, last_used = NOW() WHERE id = $1",
|
||||
id,
|
||||
func (r *MemoryRepository) IncrementUseCount(ctx context.Context, id, userID string) error {
|
||||
result, err := r.db.db.ExecContext(ctx,
|
||||
"UPDATE user_memories SET use_count = use_count + 1, last_used = NOW() WHERE id = $1 AND user_id = $2",
|
||||
id, userID,
|
||||
)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) Delete(ctx context.Context, id string) error {
|
||||
_, err := r.db.db.ExecContext(ctx, "DELETE FROM user_memories WHERE id = $1", id)
|
||||
return err
|
||||
func (r *MemoryRepository) Delete(ctx context.Context, id, userID string) error {
|
||||
result, err := r.db.db.ExecContext(ctx, "DELETE FROM user_memories WHERE id = $1 AND user_id = $2", id, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) DeleteByUserID(ctx context.Context, userID string) error {
|
||||
|
||||
@@ -182,27 +182,41 @@ func (r *PageRepository) GetByUserID(ctx context.Context, userID string, limit,
|
||||
return pagesList, nil
|
||||
}
|
||||
|
||||
func (r *PageRepository) Update(ctx context.Context, p *pages.Page) error {
|
||||
func (r *PageRepository) Update(ctx context.Context, p *pages.Page, userID string) error {
|
||||
sectionsJSON, _ := json.Marshal(p.Sections)
|
||||
sourcesJSON, _ := json.Marshal(p.Sources)
|
||||
|
||||
query := `
|
||||
UPDATE pages
|
||||
SET title = $2, subtitle = $3, sections = $4, sources = $5, thumbnail = $6, is_public = $7, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
WHERE id = $1 AND user_id = $8
|
||||
`
|
||||
_, err := r.db.db.ExecContext(ctx, query,
|
||||
p.ID, p.Title, p.Subtitle, sectionsJSON, sourcesJSON, p.Thumbnail, p.IsPublic,
|
||||
result, err := r.db.db.ExecContext(ctx, query,
|
||||
p.ID, p.Title, p.Subtitle, sectionsJSON, sourcesJSON, p.Thumbnail, p.IsPublic, userID,
|
||||
)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *PageRepository) SetShareID(ctx context.Context, pageID, shareID string) error {
|
||||
_, err := r.db.db.ExecContext(ctx,
|
||||
"UPDATE pages SET share_id = $2, is_public = true WHERE id = $1",
|
||||
pageID, shareID,
|
||||
func (r *PageRepository) SetShareID(ctx context.Context, pageID, shareID, userID string) error {
|
||||
result, err := r.db.db.ExecContext(ctx,
|
||||
"UPDATE pages SET share_id = $2, is_public = true WHERE id = $1 AND user_id = $3",
|
||||
pageID, shareID, userID,
|
||||
)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *PageRepository) IncrementViewCount(ctx context.Context, id string) error {
|
||||
@@ -213,7 +227,14 @@ func (r *PageRepository) IncrementViewCount(ctx context.Context, id string) erro
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *PageRepository) Delete(ctx context.Context, id string) error {
|
||||
_, err := r.db.db.ExecContext(ctx, "DELETE FROM pages WHERE id = $1", id)
|
||||
return err
|
||||
func (r *PageRepository) Delete(ctx context.Context, id, userID string) error {
|
||||
result, err := r.db.db.ExecContext(ctx, "DELETE FROM pages WHERE id = $1 AND user_id = $2", id, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -139,7 +139,7 @@ func (r *SpaceRepository) GetByUserID(ctx context.Context, userID string) ([]*Sp
|
||||
return spaces, nil
|
||||
}
|
||||
|
||||
func (r *SpaceRepository) Update(ctx context.Context, s *Space) error {
|
||||
func (r *SpaceRepository) Update(ctx context.Context, s *Space, userID string) error {
|
||||
settingsJSON, _ := json.Marshal(s.Settings)
|
||||
|
||||
query := `
|
||||
@@ -147,17 +147,31 @@ func (r *SpaceRepository) Update(ctx context.Context, s *Space) error {
|
||||
SET name = $2, description = $3, icon = $4, color = $5,
|
||||
custom_instructions = $6, default_focus_mode = $7, default_model = $8,
|
||||
is_public = $9, settings = $10, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
WHERE id = $1 AND user_id = $11
|
||||
`
|
||||
_, err := r.db.db.ExecContext(ctx, query,
|
||||
result, err := r.db.db.ExecContext(ctx, query,
|
||||
s.ID, s.Name, s.Description, s.Icon, s.Color,
|
||||
s.CustomInstructions, s.DefaultFocusMode, s.DefaultModel,
|
||||
s.IsPublic, settingsJSON,
|
||||
s.IsPublic, settingsJSON, userID,
|
||||
)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SpaceRepository) Delete(ctx context.Context, id string) error {
|
||||
_, err := r.db.db.ExecContext(ctx, "DELETE FROM spaces WHERE id = $1", id)
|
||||
return err
|
||||
func (r *SpaceRepository) Delete(ctx context.Context, id, userID string) error {
|
||||
result, err := r.db.db.ExecContext(ctx, "DELETE FROM spaces WHERE id = $1 AND user_id = $2", id, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,9 +4,15 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNotFound = errors.New("resource not found")
|
||||
ErrForbidden = errors.New("access denied")
|
||||
)
|
||||
|
||||
type Thread struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"userId"`
|
||||
@@ -176,30 +182,63 @@ func (r *ThreadRepository) GetByUserID(ctx context.Context, userID string, limit
|
||||
return threads, nil
|
||||
}
|
||||
|
||||
func (r *ThreadRepository) Update(ctx context.Context, t *Thread) error {
|
||||
func (r *ThreadRepository) Update(ctx context.Context, t *Thread, userID string) error {
|
||||
query := `
|
||||
UPDATE threads
|
||||
SET title = $2, focus_mode = $3, is_public = $4, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
WHERE id = $1 AND user_id = $5
|
||||
`
|
||||
_, err := r.db.db.ExecContext(ctx, query, t.ID, t.Title, t.FocusMode, t.IsPublic)
|
||||
return err
|
||||
result, err := r.db.db.ExecContext(ctx, query, t.ID, t.Title, t.FocusMode, t.IsPublic, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *ThreadRepository) SetShareID(ctx context.Context, threadID, shareID string) error {
|
||||
_, err := r.db.db.ExecContext(ctx,
|
||||
"UPDATE threads SET share_id = $2, is_public = true WHERE id = $1",
|
||||
threadID, shareID,
|
||||
func (r *ThreadRepository) SetShareID(ctx context.Context, threadID, shareID, userID string) error {
|
||||
result, err := r.db.db.ExecContext(ctx,
|
||||
"UPDATE threads SET share_id = $2, is_public = true WHERE id = $1 AND user_id = $3",
|
||||
threadID, shareID, userID,
|
||||
)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *ThreadRepository) Delete(ctx context.Context, id string) error {
|
||||
_, err := r.db.db.ExecContext(ctx, "DELETE FROM threads WHERE id = $1", id)
|
||||
return err
|
||||
func (r *ThreadRepository) Delete(ctx context.Context, id, userID string) error {
|
||||
result, err := r.db.db.ExecContext(ctx, "DELETE FROM threads WHERE id = $1 AND user_id = $2", id, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *ThreadRepository) AddMessage(ctx context.Context, msg *ThreadMessage) error {
|
||||
func (r *ThreadRepository) AddMessage(ctx context.Context, msg *ThreadMessage, userID string) error {
|
||||
var ownerID string
|
||||
err := r.db.db.QueryRowContext(ctx, "SELECT user_id FROM threads WHERE id = $1", msg.ThreadID).Scan(&ownerID)
|
||||
if err == sql.ErrNoRows {
|
||||
return ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ownerID != userID {
|
||||
return ErrForbidden
|
||||
}
|
||||
|
||||
sourcesJSON, _ := json.Marshal(msg.Sources)
|
||||
widgetsJSON, _ := json.Marshal(msg.Widgets)
|
||||
relatedJSON, _ := json.Marshal(msg.RelatedQuestions)
|
||||
@@ -209,7 +248,7 @@ func (r *ThreadRepository) AddMessage(ctx context.Context, msg *ThreadMessage) e
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id, created_at
|
||||
`
|
||||
err := r.db.db.QueryRowContext(ctx, query,
|
||||
err = r.db.db.QueryRowContext(ctx, query,
|
||||
msg.ThreadID, msg.Role, msg.Content, sourcesJSON, widgetsJSON, relatedJSON, msg.Model, msg.TokensUsed,
|
||||
).Scan(&msg.ID, &msg.CreatedAt)
|
||||
|
||||
@@ -220,7 +259,20 @@ func (r *ThreadRepository) AddMessage(ctx context.Context, msg *ThreadMessage) e
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ThreadRepository) GetMessages(ctx context.Context, threadID string, limit, offset int) ([]ThreadMessage, error) {
|
||||
func (r *ThreadRepository) GetMessages(ctx context.Context, threadID, userID string, limit, offset int) ([]ThreadMessage, error) {
|
||||
var ownerID string
|
||||
var isPublic bool
|
||||
err := r.db.db.QueryRowContext(ctx, "SELECT user_id, is_public FROM threads WHERE id = $1", threadID).Scan(&ownerID, &isPublic)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ownerID != userID && !isPublic {
|
||||
return nil, ErrForbidden
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT id, thread_id, role, content, sources, widgets, related_questions, model, tokens_used, created_at
|
||||
FROM thread_messages
|
||||
@@ -257,14 +309,21 @@ func (r *ThreadRepository) GetMessages(ctx context.Context, threadID string, lim
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (r *ThreadRepository) GenerateTitle(ctx context.Context, threadID, firstMessage string) error {
|
||||
func (r *ThreadRepository) GenerateTitle(ctx context.Context, threadID, firstMessage, userID string) error {
|
||||
title := firstMessage
|
||||
if len(title) > 100 {
|
||||
title = title[:97] + "..."
|
||||
}
|
||||
_, err := r.db.db.ExecContext(ctx,
|
||||
"UPDATE threads SET title = $2 WHERE id = $1",
|
||||
threadID, title,
|
||||
result, err := r.db.db.ExecContext(ctx,
|
||||
"UPDATE threads SET title = $2 WHERE id = $1 AND user_id = $3",
|
||||
threadID, title, userID,
|
||||
)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user