feat: Go backend, enhanced search, new widgets, Docker deploy

Major changes:
- Add Go backend (backend/) with microservices architecture
- Enhanced master-agents-svc: reranker, content-classifier, stealth-crawler,
  proxy-manager, media-search, fastClassifier, language detection
- New web-svc widgets: KnowledgeCard, ProductCard, ProfileCard, VideoCard,
  UnifiedCard, CardGallery, InlineImageGallery, SourcesPanel, RelatedQuestions
- Improved discover-svc with discover-db integration
- Docker deployment improvements (Caddyfile, vendor.sh, BUILD.md)
- Library-svc: project_id schema migration
- Remove deprecated finance-svc and travel-svc
- Localization improvements across services

Made-with: Cursor
This commit is contained in:
home
2026-02-27 04:15:32 +03:00
parent 328d968f3f
commit 06fe57c765
285 changed files with 53132 additions and 1871 deletions

View File

@@ -0,0 +1,233 @@
package agent
import (
"context"
"encoding/json"
"regexp"
"strings"
"github.com/gooseek/backend/internal/llm"
"github.com/gooseek/backend/internal/prompts"
)
// ClassificationResult holds the outcome of query classification: a
// self-contained rewrite of the user's query plus routing hints for search.
type ClassificationResult struct {
	StandaloneFollowUp string   `json:"standaloneFollowUp"` // query rewritten to stand alone; falls back to the raw query
	SkipSearch         bool     `json:"skipSearch"`         // true when the query can be answered without web search
	Topics             []string `json:"topics,omitempty"`   // topic keywords extracted by the classifier, if any
	QueryType          string   `json:"queryType,omitempty"` // classifier-assigned query category
	Engines            []string `json:"engines,omitempty"`  // suggested search engines for this query
}
// jsonObjectRE extracts the first JSON object embedded in an LLM response.
// Compiled once at package scope instead of on every classify call.
var jsonObjectRE = regexp.MustCompile(`\{[\s\S]*\}`)

// classify asks the LLM to classify the user query in the context of the
// conversation history, returning a ClassificationResult parsed from the
// JSON in the model's reply. Malformed or missing JSON degrades gracefully
// to a default result (original query, search enabled) rather than erroring.
func classify(ctx context.Context, client llm.Client, query string, history []llm.Message, locale, detectedLang string) (*ClassificationResult, error) {
	prompt := prompts.GetClassifierPrompt(locale, detectedLang)
	historyStr := formatHistory(history)
	userContent := "<conversation>\n" + historyStr + "\nUser: " + query + "\n</conversation>"
	messages := []llm.Message{
		{Role: llm.RoleSystem, Content: prompt},
		{Role: llm.RoleUser, Content: userContent},
	}
	response, err := client.GenerateText(ctx, llm.StreamRequest{
		Messages: messages,
		Options:  llm.StreamOptions{MaxTokens: 1024},
	})
	if err != nil {
		return nil, err
	}
	jsonMatch := jsonObjectRE.FindString(response)
	if jsonMatch == "" {
		// No JSON object in the response: fall back to the raw query.
		return &ClassificationResult{
			StandaloneFollowUp: query,
			SkipSearch:         false,
		}, nil
	}
	var result ClassificationResult
	if err := json.Unmarshal([]byte(jsonMatch), &result); err != nil {
		// Unparseable JSON is treated as "no classification", not an error.
		return &ClassificationResult{
			StandaloneFollowUp: query,
			SkipSearch:         false,
		}, nil
	}
	if result.StandaloneFollowUp == "" {
		result.StandaloneFollowUp = query
	}
	return &result, nil
}
// fastClassify is a heuristic, LLM-free classifier used on the hot path.
// It decides whether search can be skipped (short greetings / meta queries),
// rewrites pronoun-heavy follow-ups with topic context from the last
// assistant turn, and picks candidate search engines.
func fastClassify(query string, history []llm.Message) *ClassificationResult {
	queryLower := strings.ToLower(query)
	// Short conversational phrases (RU/EN) that need no web search.
	skipPatterns := []string{
		"привет", "как дела", "спасибо", "пока",
		"hello", "hi", "thanks", "bye",
		"объясни", "расскажи подробнее", "что ты имеешь",
	}
	skipSearch := false
	for _, p := range skipPatterns {
		// Only skip for short queries; longer ones likely carry real intent.
		if strings.Contains(queryLower, p) && len(query) < 50 {
			skipSearch = true
			break
		}
	}
	standalone := query
	if len(history) > 0 {
		// Pronouns suggest the query depends on earlier context.
		pronouns := []string{
			"это", "этот", "эта", "эти",
			"он", "она", "оно", "они",
			"it", "this", "that", "they", "them",
		}
		hasPronouns := false
		for _, p := range pronouns {
			// Contains(p+" ") subsumes the old HasPrefix(p+" ") check, so a
			// single substring test is sufficient.
			if strings.Contains(queryLower, p+" ") {
				hasPronouns = true
				break
			}
		}
		if hasPronouns && len(history) >= 2 {
			// Pull topics from the most recent assistant message to anchor
			// the otherwise ambiguous follow-up.
			lastAssistant := ""
			for i := len(history) - 1; i >= 0; i-- {
				if history[i].Role == llm.RoleAssistant {
					lastAssistant = history[i].Content
					break
				}
			}
			if lastAssistant != "" {
				topics := extractTopics(lastAssistant)
				if len(topics) > 0 {
					standalone = query + " (контекст: " + strings.Join(topics, ", ") + ")"
				}
			}
		}
	}
	engines := detectEngines(queryLower)
	return &ClassificationResult{
		StandaloneFollowUp: standalone,
		SkipSearch:         skipSearch,
		Engines:            engines,
	}
}
// generateSearchQueries derives up to three search-engine queries from the
// user query: the query itself, a 5-word head for very long queries, and a
// lowercased variant with a leading question word ("how", "что такое", …)
// stripped. The original query is always first.
func generateSearchQueries(query string) []string {
	queries := []string{query}
	if len(query) > 100 {
		// Long queries also get a truncated head variant.
		words := strings.Fields(query)
		if len(words) > 5 {
			queries = append(queries, strings.Join(words[:5], " "))
		}
	}
	// Question-word prefixes (RU/EN); only the first match is stripped.
	keywordPatterns := []string{
		"как", "что такое", "где", "когда", "почему", "кто",
		"how", "what is", "where", "when", "why", "who",
	}
	// Lowercase once instead of per pattern (the old code re-lowered the
	// query inside the loop on every iteration).
	queryLower := strings.ToLower(query)
	for _, p := range keywordPatterns {
		if strings.HasPrefix(queryLower, p) {
			withoutPrefix := strings.TrimSpace(strings.TrimPrefix(queryLower, p))
			if len(withoutPrefix) > 10 {
				queries = append(queries, withoutPrefix)
			}
			break
		}
	}
	if len(queries) > 3 {
		queries = queries[:3]
	}
	return queries
}
// detectEngines selects search engines for an (already lowercased) query.
// The two general-purpose engines are always present; topical engines are
// appended when matching RU/EN keywords appear in the query.
func detectEngines(query string) []string {
	engines := []string{"google", "duckduckgo"}
	rules := []struct {
		engine   string
		keywords []string
	}{
		{"google_news", []string{"новости", "news"}},
		{"youtube", []string{"видео", "video"}},
		{"google_shopping", []string{"товар", "купить", "цена", "price"}},
	}
	for _, rule := range rules {
		for _, kw := range rule.keywords {
			if strings.Contains(query, kw) {
				engines = append(engines, rule.engine)
				break
			}
		}
	}
	return engines
}
// extractTopics pulls up to three capitalized, medium-length words from the
// first 50 words of text, used as rough topic anchors for follow-up queries.
// Word length is measured in runes, not bytes: the old byte-based check
// over-counted Cyrillic words (2 bytes per letter under UTF-8).
func extractTopics(text string) []string {
	words := strings.Fields(text)
	if len(words) > 50 {
		words = words[:50]
	}
	topics := make([]string, 0, 3)
	for _, w := range words {
		r := []rune(w)
		// Keep only medium-length words: 6-19 runes.
		if len(r) <= 5 || len(r) >= 20 {
			continue
		}
		first := r[0]
		// Capitalized Latin or Cyrillic (including Ё, which sits outside
		// the А-Я range) marks a likely proper noun / topic word.
		if (first >= 'A' && first <= 'Z') || (first >= 'А' && first <= 'Я') || first == 'Ё' {
			topics = append(topics, w)
			if len(topics) >= 3 {
				break
			}
		}
	}
	return topics
}
// formatHistory renders chat history as "Role: content" lines for inclusion
// in a classifier prompt. Each message's content is capped at 500 runes;
// truncating by runes (not bytes, as before) avoids splitting a multi-byte
// UTF-8 sequence mid-character, which matters for Cyrillic content.
func formatHistory(messages []llm.Message) string {
	var sb strings.Builder
	for _, m := range messages {
		role := "User"
		if m.Role == llm.RoleAssistant {
			role = "Assistant"
		}
		sb.WriteString(role)
		sb.WriteString(": ")
		content := m.Content
		if r := []rune(content); len(r) > 500 {
			content = string(r[:500]) + "..."
		}
		sb.WriteString(content)
		sb.WriteString("\n")
	}
	return sb.String()
}
// detectLanguage returns "ru" when text contains more Cyrillic than Latin
// letters, otherwise "en". The Cyrillic check now also counts ё/Ё, which
// fall outside the contiguous а-я / А-Я Unicode ranges and were previously
// ignored.
func detectLanguage(text string) string {
	cyrillicCount := 0
	latinCount := 0
	for _, r := range text {
		switch {
		case (r >= 'а' && r <= 'я') || (r >= 'А' && r <= 'Я') || r == 'ё' || r == 'Ё':
			cyrillicCount++
		case (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z'):
			latinCount++
		}
	}
	if cyrillicCount > latinCount {
		return "ru"
	}
	return "en"
}

View File

@@ -0,0 +1,543 @@
package agent
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/gooseek/backend/internal/llm"
"github.com/gooseek/backend/internal/search"
"github.com/gooseek/backend/internal/session"
"github.com/gooseek/backend/internal/types"
"github.com/google/uuid"
"golang.org/x/sync/errgroup"
)
// DeepResearchConfig carries the dependencies and limits for a deep-research
// run. Zero-valued limits are replaced with defaults by NewDeepResearcher.
type DeepResearchConfig struct {
	LLM              llm.Client            // LLM used for planning, synthesis and report writing
	SearchClient     *search.SearXNGClient // metasearch backend
	FocusMode        FocusMode             // drives engines, categories and system prompt
	Locale           string                // "ru" switches the report language to Russian
	MaxSearchQueries int                   // hard cap on search calls (default 30)
	MaxSources       int                   // stop once this many sources are collected (default 100)
	MaxIterations    int                   // maximum research iterations (default 5)
	Timeout          time.Duration         // overall deadline for the run (default 5 minutes)
}
// DeepResearchResult is the final outcome of a deep-research run.
type DeepResearchResult struct {
	FinalReport     string        // full generated report text
	Sources         []types.Chunk // deduplicated sources gathered across all sub-queries
	SubQueries      []SubQuery    // every sub-query planned/executed, with per-query results
	Insights        []string      // key insights synthesized from the sources
	FollowUpQueries []string      // suggested follow-up questions for the user
	TotalSearches   int           // number of search calls actually issued
	TotalSources    int           // equals len(Sources) at completion
	Duration        time.Duration // wall-clock time of the run
}
// SubQuery is a single planned search step within a deep-research run.
type SubQuery struct {
	Query    string        // search-engine query text
	Purpose  string        // which aspect of the research this query addresses
	Status   string        // lifecycle: "pending" -> "searching" -> "complete" or "failed"
	Results  []types.Chunk // deduplicated results gathered for this query
	Insights []string      // per-query insights; not populated anywhere in this file — TODO confirm use
}
// DeepResearcher coordinates a multi-iteration research session. mu guards
// the mutable collections because sub-queries execute concurrently. Contains
// a Mutex, so the type must only be used via pointer (see NewDeepResearcher).
type DeepResearcher struct {
	cfg         DeepResearchConfig
	sess        *session.Session // streams progress blocks/status to the client
	mu          sync.Mutex       // guards allSources, seenURLs, subQueries, insights, searchCount
	allSources  []types.Chunk    // accumulated unique sources across all sub-queries
	seenURLs    map[string]bool  // URL dedup set
	subQueries  []SubQuery       // planned and executed sub-queries
	insights    []string         // synthesized insights cache
	searchCount int              // searches issued so far, compared against MaxSearchQueries
	startTime   time.Time        // run start, used to report Duration
}
// NewDeepResearcher builds a researcher bound to sess, filling in default
// limits for any zero-valued fields of cfg.
func NewDeepResearcher(cfg DeepResearchConfig, sess *session.Session) *DeepResearcher {
	// Replace unset (zero) limits with sensible defaults.
	defaultInt := func(v *int, fallback int) {
		if *v == 0 {
			*v = fallback
		}
	}
	defaultInt(&cfg.MaxSearchQueries, 30)
	defaultInt(&cfg.MaxSources, 100)
	defaultInt(&cfg.MaxIterations, 5)
	if cfg.Timeout == 0 {
		cfg.Timeout = 5 * time.Minute
	}
	return &DeepResearcher{
		cfg:        cfg,
		sess:       sess,
		seenURLs:   map[string]bool{},
		allSources: []types.Chunk{},
		subQueries: []SubQuery{},
		insights:   []string{},
		startTime:  time.Now(),
	}
}
// Research executes the full deep-research pipeline for query: plan
// sub-queries, run search iterations until limits or sufficient data,
// synthesize insights, then stream a final report. Progress is emitted to
// the session as a research block updated throughout the run.
func (dr *DeepResearcher) Research(ctx context.Context, query string) (*DeepResearchResult, error) {
	ctx, cancel := context.WithTimeout(ctx, dr.cfg.Timeout)
	defer cancel()
	researchBlockID := uuid.New().String()
	// Announce the research block before any work so the client can render it.
	dr.sess.EmitBlock(&types.Block{
	ID: researchBlockID,
	Type: types.BlockTypeResearch,
	Data: types.ResearchData{
	SubSteps: []types.ResearchSubStep{},
	},
	})
	subQueries, err := dr.planResearch(ctx, query)
	if err != nil {
	return nil, fmt.Errorf("planning failed: %w", err)
	}
	dr.updateResearchStatus(researchBlockID, "researching", fmt.Sprintf("Executing %d sub-queries", len(subQueries)))
	for i := 0; i < dr.cfg.MaxIterations && dr.searchCount < dr.cfg.MaxSearchQueries; i++ {
	if err := dr.executeIteration(ctx, i, researchBlockID); err != nil {
	// Per-iteration errors are tolerated; only stop when the context
	// itself expired (timeout/cancel), reporting what we have so far.
	if ctx.Err() != nil {
	break
	}
	}
	if dr.hasEnoughData() {
	break
	}
	// Ask the LLM for additional queries to fill coverage gaps; an error
	// or empty suggestion ends the loop.
	newQueries, err := dr.generateFollowUpQueries(ctx, query)
	if err != nil || len(newQueries) == 0 {
	break
	}
	for _, q := range newQueries {
	dr.mu.Lock()
	dr.subQueries = append(dr.subQueries, SubQuery{
	Query: q.Query,
	Purpose: q.Purpose,
	Status: "pending",
	})
	dr.mu.Unlock()
	}
	}
	dr.updateResearchStatus(researchBlockID, "synthesizing", "Analyzing findings")
	insights, err := dr.synthesizeInsights(ctx, query)
	if err != nil {
	// Fall back to any insights accumulated earlier in the run.
	insights = dr.insights
	}
	dr.updateResearchStatus(researchBlockID, "writing", "Generating report")
	report, err := dr.generateFinalReport(ctx, query, insights)
	if err != nil {
	return nil, fmt.Errorf("report generation failed: %w", err)
	}
	// Follow-up suggestions are best-effort; the error is intentionally ignored.
	followUp, _ := dr.generateFollowUpSuggestions(ctx, query, report)
	dr.updateResearchStatus(researchBlockID, "complete", "Research complete")
	return &DeepResearchResult{
	FinalReport: report,
	Sources: dr.allSources,
	SubQueries: dr.subQueries,
	Insights: insights,
	FollowUpQueries: followUp,
	TotalSearches: dr.searchCount,
	TotalSources: len(dr.allSources),
	Duration: time.Since(dr.startTime),
	}, nil
}
// planResearch asks the LLM to decompose query into 3-5 focused sub-queries.
// On LLM failure or unparseable output it falls back to a deterministic
// default plan; the chosen plan is stored on the researcher and returned.
func (dr *DeepResearcher) planResearch(ctx context.Context, query string) ([]SubQuery, error) {
	prompt := fmt.Sprintf(`Analyze this research query and break it into 3-5 sub-queries for comprehensive research.
Query: %s
For each sub-query, specify:
1. The search query (optimized for search engines)
2. The purpose (what aspect it addresses)
Respond in this exact format:
QUERY: [search query]
PURPOSE: [what this addresses]
QUERY: [search query]
PURPOSE: [what this addresses]
...
Be specific and actionable. Focus on different aspects: definitions, current state, history, expert opinions, data/statistics, controversies, future trends.`, query)
	result, err := dr.cfg.LLM.GenerateText(ctx, llm.StreamRequest{
		// Use the package's role constant rather than a raw "user" literal,
		// consistent with the rest of the agent package.
		Messages: []llm.Message{{Role: llm.RoleUser, Content: prompt}},
	})
	if err != nil {
		// Planning is best-effort: degrade to default sub-queries, no error.
		return dr.generateDefaultSubQueries(query), nil
	}
	subQueries := dr.parseSubQueries(result)
	if len(subQueries) == 0 {
		subQueries = dr.generateDefaultSubQueries(query)
	}
	dr.mu.Lock()
	dr.subQueries = subQueries
	dr.mu.Unlock()
	return subQueries, nil
}
// parseSubQueries reads the "QUERY: ... / PURPOSE: ..." pairs emitted by the
// planning prompt. Only pairs with both fields present become sub-queries;
// each new entry starts in the "pending" state.
func (dr *DeepResearcher) parseSubQueries(text string) []SubQuery {
	var (
		parsed  []SubQuery
		query   string
		purpose string
	)
	// flush commits the current pair when both halves have been seen.
	flush := func() {
		if query != "" && purpose != "" {
			parsed = append(parsed, SubQuery{
				Query:   query,
				Purpose: purpose,
				Status:  "pending",
			})
		}
	}
	for _, raw := range strings.Split(text, "\n") {
		line := strings.TrimSpace(raw)
		switch {
		case strings.HasPrefix(line, "QUERY:"):
			flush()
			query = strings.TrimSpace(strings.TrimPrefix(line, "QUERY:"))
			purpose = ""
		case strings.HasPrefix(line, "PURPOSE:"):
			purpose = strings.TrimSpace(strings.TrimPrefix(line, "PURPOSE:"))
		}
	}
	flush()
	return parsed
}
// generateDefaultSubQueries is the fallback plan used when the LLM planner
// fails or returns nothing parseable: the query itself plus four standard
// research angles.
func (dr *DeepResearcher) generateDefaultSubQueries(query string) []SubQuery {
	// Use the current year instead of a hard-coded one so the "latest news"
	// sub-query does not go stale over time.
	newsQuery := fmt.Sprintf("%s latest news %d", query, time.Now().Year())
	return []SubQuery{
		{Query: query, Purpose: "Main query", Status: "pending"},
		{Query: query + " definition explained", Purpose: "Definitions and basics", Status: "pending"},
		{Query: newsQuery, Purpose: "Current developments", Status: "pending"},
		{Query: query + " expert analysis", Purpose: "Expert opinions", Status: "pending"},
		{Query: query + " statistics data research", Purpose: "Data and evidence", Status: "pending"},
	}
}
// executeIteration runs one batch of up to three pending sub-queries in
// parallel via errgroup. Pending queries beyond the batch are left for later
// iterations. Returns the first sub-query error, if any.
func (dr *DeepResearcher) executeIteration(ctx context.Context, iteration int, blockID string) error {
	// Snapshot indices of pending sub-queries under the lock.
	dr.mu.Lock()
	pendingQueries := make([]int, 0)
	for i, sq := range dr.subQueries {
	if sq.Status == "pending" {
	pendingQueries = append(pendingQueries, i)
	}
	}
	dr.mu.Unlock()
	if len(pendingQueries) == 0 {
	return nil
	}
	// At most 3 concurrent searches per iteration.
	batchSize := 3
	if len(pendingQueries) < batchSize {
	batchSize = len(pendingQueries)
	}
	g, gctx := errgroup.WithContext(ctx)
	g.SetLimit(batchSize)
	for _, idx := range pendingQueries[:batchSize] {
	idx := idx // per-goroutine copy (pre-Go 1.22 loop-variable semantics)
	g.Go(func() error {
	return dr.executeSubQuery(gctx, idx, blockID)
	})
	}
	return g.Wait()
}
// executeSubQuery runs the search for the sub-query at index idx, filters
// out URLs already seen by any sub-query, and records up to 10 new chunks
// both on the sub-query and in the shared source pool.
func (dr *DeepResearcher) executeSubQuery(ctx context.Context, idx int, blockID string) error {
	dr.mu.Lock()
	if idx >= len(dr.subQueries) {
	dr.mu.Unlock()
	return nil
	}
	// NOTE(review): sq points into the subQueries slice and is written after
	// the lock is released. This is safe only because the slice is appended
	// to exclusively between iterations (after g.Wait) — confirm if the
	// scheduling ever changes.
	sq := &dr.subQueries[idx]
	sq.Status = "searching"
	query := sq.Query
	dr.searchCount++
	dr.mu.Unlock()
	dr.updateResearchStatus(blockID, "researching", fmt.Sprintf("Searching: %s", truncate(query, 50)))
	// Apply the focus mode's query prefix and engine/category selection.
	enhancedQuery := EnhanceQueryForFocusMode(query, dr.cfg.FocusMode)
	results, err := dr.cfg.SearchClient.Search(ctx, enhancedQuery, &search.SearchOptions{
	Engines: dr.cfg.FocusMode.GetSearchEngines(),
	Categories: FocusModeConfigs[dr.cfg.FocusMode].Categories,
	PageNo: 1,
	})
	if err != nil {
	dr.mu.Lock()
	sq.Status = "failed"
	dr.mu.Unlock()
	return err
	}
	chunks := make([]types.Chunk, 0)
	for _, r := range results.Results {
	// Deduplicate by URL across all sub-queries.
	dr.mu.Lock()
	if dr.seenURLs[r.URL] {
	dr.mu.Unlock()
	continue
	}
	dr.seenURLs[r.URL] = true
	dr.mu.Unlock()
	chunk := r.ToChunk()
	chunks = append(chunks, chunk)
	if len(chunks) >= 10 { // per-sub-query cap
	break
	}
	}
	dr.mu.Lock()
	sq.Results = chunks
	sq.Status = "complete"
	dr.allSources = append(dr.allSources, chunks...)
	dr.mu.Unlock()
	return nil
}
// generateFollowUpQueries asks the LLM for 2-3 extra queries that would fill
// gaps in coverage, based on a summary of the sources found so far. Returns
// (nil, nil) when the search budget is nearly exhausted.
func (dr *DeepResearcher) generateFollowUpQueries(ctx context.Context, originalQuery string) ([]SubQuery, error) {
	// Keep headroom of 5 searches below the hard cap.
	if dr.searchCount >= dr.cfg.MaxSearchQueries-5 {
	return nil, nil
	}
	var sourceSummary strings.Builder
	dr.mu.Lock()
	for i, s := range dr.allSources {
	if i >= 20 { // summarize at most 20 sources to bound the prompt size
	break
	}
	sourceSummary.WriteString(fmt.Sprintf("- %s: %s\n", s.Metadata["title"], truncate(s.Content, 100)))
	}
	dr.mu.Unlock()
	prompt := fmt.Sprintf(`Based on the original query and sources found so far, suggest 2-3 follow-up queries to deepen the research.
Original query: %s
Sources found so far:
%s
What aspects are missing? What would provide more comprehensive coverage?
Respond with queries in format:
QUERY: [query]
PURPOSE: [what gap it fills]`, originalQuery, sourceSummary.String())
	result, err := dr.cfg.LLM.GenerateText(ctx, llm.StreamRequest{
	Messages: []llm.Message{{Role: "user", Content: prompt}},
	})
	if err != nil {
	return nil, err
	}
	return dr.parseSubQueries(result), nil
}
// synthesizeInsights asks the LLM to distill the collected sources into 5-7
// key insights for query, parses the bullet list from the reply, and caches
// the result on the researcher.
func (dr *DeepResearcher) synthesizeInsights(ctx context.Context, query string) ([]string, error) {
	var sourcesText strings.Builder
	dr.mu.Lock()
	for i, s := range dr.allSources {
		if i >= 30 { // include at most 30 sources to bound the prompt
			break
		}
		sourcesText.WriteString(fmt.Sprintf("[%d] %s\n%s\n\n", i+1, s.Metadata["title"], truncate(s.Content, 300)))
	}
	dr.mu.Unlock()
	prompt := fmt.Sprintf(`Analyze these sources and extract 5-7 key insights for the query: %s
Sources:
%s
Provide insights as bullet points, each starting with a key finding.
Focus on: main conclusions, patterns, contradictions, expert consensus, data points.`, query, sourcesText.String())
	result, err := dr.cfg.LLM.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: "user", Content: prompt}},
	})
	if err != nil {
		return nil, err
	}
	insights := make([]string, 0)
	for _, line := range strings.Split(result, "\n") {
		line = strings.TrimSpace(line)
		if strings.HasPrefix(line, "-") || strings.HasPrefix(line, "•") || strings.HasPrefix(line, "*") {
			// Strip all leading bullet markers plus surrounding whitespace.
			// The old TrimPrefix chain removed only one marker and left a
			// leading space on every insight.
			insight := strings.TrimSpace(strings.TrimLeft(line, "-•* "))
			if insight != "" {
				insights = append(insights, insight)
			}
		}
	}
	dr.mu.Lock()
	dr.insights = insights
	dr.mu.Unlock()
	return insights, nil
}
// generateFinalReport streams a structured research report from the LLM,
// grounded in the collected sources and the synthesized insights. Chunks are
// forwarded to the session's text block as they arrive; the accumulated
// report text is returned.
func (dr *DeepResearcher) generateFinalReport(ctx context.Context, query string, insights []string) (string, error) {
	var sourcesText strings.Builder
	dr.mu.Lock()
	sources := dr.allSources
	dr.mu.Unlock()
	for i, s := range sources {
	if i >= 50 { // include at most 50 sources in the prompt
	break
	}
	sourcesText.WriteString(fmt.Sprintf("[%d] %s (%s)\n%s\n\n", i+1, s.Metadata["title"], s.Metadata["url"], truncate(s.Content, 400)))
	}
	insightsText := strings.Join(insights, "\n- ")
	focusCfg := FocusModeConfigs[dr.cfg.FocusMode]
	locale := dr.cfg.Locale
	if locale == "" {
	locale = "en"
	}
	langInstruction := ""
	if locale == "ru" {
	langInstruction = "Write the report in Russian."
	}
	prompt := fmt.Sprintf(`%s
Write a comprehensive research report answering: %s
Key insights discovered:
- %s
Sources (cite using [1], [2], etc.):
%s
Structure your report with:
1. Executive Summary (2-3 sentences)
2. Key Findings (organized by theme)
3. Analysis and Discussion
4. Conclusions
%s
Use citations [1], [2], etc. throughout.
Be thorough but concise. Focus on actionable information.`, focusCfg.SystemPrompt, query, insightsText, sourcesText.String(), langInstruction)
	stream, err := dr.cfg.LLM.StreamText(ctx, llm.StreamRequest{
	Messages: []llm.Message{{Role: "user", Content: prompt}},
	})
	if err != nil {
	return "", err
	}
	var report strings.Builder
	textBlockID := uuid.New().String()
	// Open an empty text block first so the client can render the stream.
	dr.sess.EmitBlock(&types.Block{
	ID: textBlockID,
	Type: types.BlockTypeText,
	Data: "",
	})
	for chunk := range stream {
	report.WriteString(chunk.ContentChunk)
	dr.sess.EmitTextChunk(textBlockID, chunk.ContentChunk)
	}
	return report.String(), nil
}
// generateFollowUpSuggestions asks the LLM for up to four follow-up
// questions based on the query and a truncated report summary. Lines that
// look like questions (contain "?" or exceed 20 bytes) are kept after
// stripping list markers and numbering.
func (dr *DeepResearcher) generateFollowUpSuggestions(ctx context.Context, query, report string) ([]string, error) {
	prompt := fmt.Sprintf(`Based on this research query and report, suggest 3-4 follow-up questions the user might want to explore:
Query: %s
Report summary: %s
Provide follow-up questions that:
1. Go deeper into specific aspects
2. Explore related topics
3. Address practical applications
4. Consider alternative perspectives
Format as simple questions, one per line.`, query, truncate(report, 1000))
	result, err := dr.cfg.LLM.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: "user", Content: prompt}},
	})
	if err != nil {
		return nil, err
	}
	const maxSuggestions = 4
	suggestions := make([]string, 0, maxSuggestions)
	for _, raw := range strings.Split(result, "\n") {
		line := strings.TrimSpace(raw)
		if line == "" {
			continue
		}
		if !strings.Contains(line, "?") && len(line) <= 20 {
			continue
		}
		line = strings.TrimPrefix(line, "- ")
		line = strings.TrimPrefix(line, "• ")
		line = strings.TrimLeft(line, "0123456789. ")
		if line != "" {
			suggestions = append(suggestions, line)
		}
	}
	if len(suggestions) > maxSuggestions {
		suggestions = suggestions[:maxSuggestions]
	}
	return suggestions, nil
}
// updateResearchStatus patches the research block's status and message
// fields so the client can show live progress for the run.
func (dr *DeepResearcher) updateResearchStatus(blockID, status, message string) {
	dr.sess.UpdateBlock(blockID, []session.Patch{
	{Op: "replace", Path: "/data/status", Value: status},
	{Op: "replace", Path: "/data/message", Value: message},
	})
}
// hasEnoughData reports whether the collected source count has reached the
// configured MaxSources, ending further search iterations.
func (dr *DeepResearcher) hasEnoughData() bool {
	dr.mu.Lock()
	collected := len(dr.allSources)
	dr.mu.Unlock()
	return collected >= dr.cfg.MaxSources
}
// truncate shortens s to at most maxLen runes, appending "..." when cut.
// Counting runes rather than bytes (as the old version did) prevents
// splitting a multi-byte UTF-8 character — important here because much of
// the handled content is Russian.
func truncate(s string, maxLen int) string {
	r := []rune(s)
	if len(r) <= maxLen {
		return s
	}
	return string(r[:maxLen]) + "..."
}
// RunDeepResearch is a convenience wrapper: it constructs a DeepResearcher
// for cfg/sess and immediately executes the research pipeline for query.
func RunDeepResearch(ctx context.Context, sess *session.Session, query string, cfg DeepResearchConfig) (*DeepResearchResult, error) {
	return NewDeepResearcher(cfg, sess).Research(ctx, query)
}

View File

@@ -0,0 +1,293 @@
package agent
import (
"strings"
)
// FocusMode selects a specialized search/answer profile (engines, result
// categories, system prompt) for the agent.
type FocusMode string

// Supported focus modes. Each has an entry in FocusModeConfigs.
const (
	FocusModeAll      FocusMode = "all"      // general-purpose web search (default)
	FocusModeAcademic FocusMode = "academic" // scholarly sources
	FocusModeWriting  FocusMode = "writing"  // writing assistance, minimal search
	FocusModeYouTube  FocusMode = "youtube"  // video search
	FocusModeReddit   FocusMode = "reddit"   // community discussions
	FocusModeCode     FocusMode = "code"     // programming / technical
	FocusModeNews     FocusMode = "news"     // current events
	FocusModeImages   FocusMode = "images"   // image search
	FocusModeMath     FocusMode = "math"     // mathematical problem solving
	FocusModeFinance  FocusMode = "finance"  // financial information
)
// FocusModeConfig describes how a focus mode drives search and answering:
// which engines and categories to query, the system prompt, and limits.
type FocusModeConfig struct {
	Mode              FocusMode // mode this config belongs to
	Engines           []string  // search engines to query
	Categories        []string  // result categories passed to the search backend
	SystemPrompt      string    // system prompt injected for this mode
	SearchQueryPrefix string    // optional prefix prepended to queries (e.g. "site:reddit.com")
	MaxSources        int       // maximum sources to use for answering
	RequiresCitation  bool      // whether answers must cite sources
	AllowScraping     bool      // whether full-page scraping of results is permitted
}
// FocusModeConfigs maps every focus mode to its full configuration. Callers
// should fall back to FocusModeAll for unknown modes (see GetFocusModeConfig).
var FocusModeConfigs = map[FocusMode]FocusModeConfig{
	// General-purpose default: broad web search.
	FocusModeAll: {
	Mode: FocusModeAll,
	Engines: []string{"google", "bing", "duckduckgo"},
	Categories: []string{"general"},
	MaxSources: 15,
	RequiresCitation: true,
	AllowScraping: true,
	SystemPrompt: `You are a helpful AI assistant that provides comprehensive answers based on web search results.
Always cite your sources using [1], [2], etc. format.
Provide balanced, accurate information from multiple perspectives.`,
	},
	// Scholarly sources (papers, journals).
	FocusModeAcademic: {
	Mode: FocusModeAcademic,
	Engines: []string{"google scholar", "arxiv", "pubmed", "semantic scholar"},
	Categories: []string{"science"},
	SearchQueryPrefix: "research paper",
	MaxSources: 20,
	RequiresCitation: true,
	AllowScraping: true,
	SystemPrompt: `You are an academic research assistant specializing in scholarly sources.
Focus on peer-reviewed papers, academic journals, and reputable research institutions.
Always cite sources in academic format with [1], [2], etc.
Distinguish between primary research, meta-analyses, and review articles.
Mention publication dates, authors, and journals when available.
Be precise about confidence levels and note when findings are preliminary or contested.`,
	},
	// Writing assistance: minimal search, no citations required.
	FocusModeWriting: {
	Mode: FocusModeWriting,
	Engines: []string{"google"},
	Categories: []string{"general"},
	MaxSources: 5,
	RequiresCitation: false,
	AllowScraping: false,
	SystemPrompt: `You are a creative writing assistant.
Help with drafting, editing, and improving written content.
Provide suggestions for style, tone, structure, and clarity.
Offer multiple variations when appropriate.
Focus on the user's voice and intent rather than web search results.`,
	},
	// YouTube video search; restricted via site: prefix.
	FocusModeYouTube: {
	Mode: FocusModeYouTube,
	Engines: []string{"youtube"},
	Categories: []string{"videos"},
	SearchQueryPrefix: "site:youtube.com",
	MaxSources: 10,
	RequiresCitation: true,
	AllowScraping: false,
	SystemPrompt: `You are a video content assistant focused on YouTube.
Summarize video content, recommend relevant videos, and help find tutorials.
Mention video titles, channels, and approximate timestamps when relevant.
Note view counts and upload dates to indicate video popularity and relevance.`,
	},
	// Reddit community discussions; restricted via site: prefix.
	FocusModeReddit: {
	Mode: FocusModeReddit,
	Engines: []string{"reddit"},
	Categories: []string{"social media"},
	SearchQueryPrefix: "site:reddit.com",
	MaxSources: 15,
	RequiresCitation: true,
	AllowScraping: true,
	SystemPrompt: `You are an assistant that specializes in Reddit discussions and community knowledge.
Focus on highly upvoted comments and posts from relevant subreddits.
Note the subreddit source, upvote counts, and community consensus.
Distinguish between personal opinions, experiences, and factual claims.
Be aware of potential biases in specific communities.`,
	},
	// Programming and technical documentation.
	FocusModeCode: {
	Mode: FocusModeCode,
	Engines: []string{"google", "github", "stackoverflow"},
	Categories: []string{"it"},
	SearchQueryPrefix: "",
	MaxSources: 10,
	RequiresCitation: true,
	AllowScraping: true,
	SystemPrompt: `You are a programming assistant focused on code, documentation, and technical solutions.
Provide working code examples with explanations.
Reference official documentation, Stack Overflow answers, and GitHub repositories.
Mention library versions and compatibility considerations.
Follow best practices and coding standards for the relevant language/framework.
Include error handling and edge cases in code examples.`,
	},
	// Current events from news engines.
	FocusModeNews: {
	Mode: FocusModeNews,
	Engines: []string{"google news", "bing news"},
	Categories: []string{"news"},
	MaxSources: 12,
	RequiresCitation: true,
	AllowScraping: true,
	SystemPrompt: `You are a news assistant that provides current events information.
Focus on recent, verified news from reputable sources.
Distinguish between breaking news, analysis, and opinion pieces.
Note publication dates and source credibility.
Present multiple perspectives on controversial topics.`,
	},
	// Image search; scraping disabled.
	FocusModeImages: {
	Mode: FocusModeImages,
	Engines: []string{"google images", "bing images"},
	Categories: []string{"images"},
	MaxSources: 20,
	RequiresCitation: true,
	AllowScraping: false,
	SystemPrompt: `You are an image search assistant.
Help find relevant images, describe image sources, and provide context.
Note image sources, licenses, and quality when relevant.`,
	},
	// Mathematical problem solving; few sources needed.
	FocusModeMath: {
	Mode: FocusModeMath,
	Engines: []string{"wolfram alpha", "google"},
	Categories: []string{"science"},
	MaxSources: 5,
	RequiresCitation: true,
	AllowScraping: false,
	SystemPrompt: `You are a mathematical problem-solving assistant.
Provide step-by-step solutions with clear explanations.
Use proper mathematical notation and formatting.
Show your work and explain the reasoning behind each step.
Mention relevant theorems, formulas, and mathematical concepts.
Verify your calculations and provide alternative solution methods when applicable.`,
	},
	// Financial information with risk disclaimers.
	FocusModeFinance: {
	Mode: FocusModeFinance,
	Engines: []string{"google", "google finance", "yahoo finance"},
	Categories: []string{"news"},
	SearchQueryPrefix: "stock market finance",
	MaxSources: 10,
	RequiresCitation: true,
	AllowScraping: true,
	SystemPrompt: `You are a financial information assistant.
Provide accurate financial data, market analysis, and investment information.
Note that you cannot provide personalized financial advice.
Cite data sources and note when data may be delayed or historical.
Include relevant disclaimers about investment risks.
Reference SEC filings, analyst reports, and official company statements.`,
	},
}
// GetFocusModeConfig resolves a mode name (case-insensitive) to its config,
// defaulting to the "all" mode for unknown names.
func GetFocusModeConfig(mode string) FocusModeConfig {
	cfg, ok := FocusModeConfigs[FocusMode(strings.ToLower(mode))]
	if !ok {
		cfg = FocusModeConfigs[FocusModeAll]
	}
	return cfg
}
// DetectFocusMode guesses the most appropriate focus mode from English and
// Russian keywords in the query. Rules are evaluated in priority order
// (academic, code, youtube, reddit, math, finance, news); the first match
// wins and FocusModeAll is the fallback.
func DetectFocusMode(query string) FocusMode {
	q := strings.ToLower(query)
	rules := []struct {
		mode     FocusMode
		keywords []string
	}{
		{FocusModeAcademic, []string{
			"research", "paper", "study", "journal", "scientific", "academic",
			"peer-reviewed", "citation", "исследование", "научн", "статья",
			"публикация", "диссертация",
		}},
		{FocusModeCode, []string{
			"code", "programming", "function", "error", "bug", "api",
			"library", "framework", "syntax", "compile", "debug",
			"код", "программ", "функция", "ошибка", "библиотека",
			"golang", "python", "javascript", "typescript", "react", "vue",
			"docker", "kubernetes", "sql", "database", "git",
		}},
		{FocusModeYouTube, []string{"youtube", "video tutorial", "видео"}},
		{FocusModeReddit, []string{"reddit", "subreddit", "/r/"}},
		{FocusModeMath, []string{
			"calculate", "solve", "equation", "integral", "derivative",
			"formula", "theorem", "proof", "вычисл", "решить", "уравнение",
			"интеграл", "производная", "формула", "теорема",
		}},
		{FocusModeFinance, []string{
			"stock", "market", "invest", "price", "trading", "finance",
			"акци", "рынок", "инвест", "биржа", "котировк", "финанс",
			"etf", "dividend", "portfolio",
		}},
		{FocusModeNews, []string{
			"news", "today", "latest", "breaking", "current events",
			"новост", "сегодня", "последн", "актуальн",
		}},
	}
	for _, rule := range rules {
		for _, kw := range rule.keywords {
			if strings.Contains(q, kw) {
				return rule.mode
			}
		}
	}
	return FocusModeAll
}
// GetSearchEngines returns the engine list for this mode, falling back to
// the "all" mode's engines for unknown modes.
func (f FocusMode) GetSearchEngines() []string {
	cfg, ok := FocusModeConfigs[f]
	if !ok {
		cfg = FocusModeConfigs[FocusModeAll]
	}
	return cfg.Engines
}
// GetSystemPrompt returns the system prompt for this mode, falling back to
// the "all" mode's prompt for unknown modes.
func (f FocusMode) GetSystemPrompt() string {
	cfg, ok := FocusModeConfigs[f]
	if !ok {
		cfg = FocusModeConfigs[FocusModeAll]
	}
	return cfg.SystemPrompt
}
// GetMaxSources returns the source cap for this mode, or 15 when the mode
// is unknown.
func (f FocusMode) GetMaxSources() int {
	cfg, ok := FocusModeConfigs[f]
	if !ok {
		return 15 // default cap for unrecognized modes
	}
	return cfg.MaxSources
}
// RequiresCitation reports whether answers in this mode must cite sources;
// unknown modes default to requiring citations.
func (f FocusMode) RequiresCitation() bool {
	cfg, ok := FocusModeConfigs[f]
	if !ok {
		return true // safe default: require citations
	}
	return cfg.RequiresCitation
}
// AllowsScraping reports whether full-page scraping of results is permitted
// in this mode; unknown modes default to allowing it.
func (f FocusMode) AllowsScraping() bool {
	cfg, ok := FocusModeConfigs[f]
	if !ok {
		return true
	}
	return cfg.AllowScraping
}
// EnhanceQueryForFocusMode prepends the mode's configured search-query
// prefix (e.g. "site:reddit.com") when one exists; otherwise the query is
// returned unchanged. Unknown modes yield the zero config (empty prefix).
func EnhanceQueryForFocusMode(query string, mode FocusMode) string {
	prefix := FocusModeConfigs[mode].SearchQueryPrefix
	if prefix == "" {
		return query
	}
	return prefix + " " + query
}

View File

@@ -0,0 +1,950 @@
package agent
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/gooseek/backend/internal/llm"
"github.com/gooseek/backend/internal/prompts"
"github.com/gooseek/backend/internal/search"
"github.com/gooseek/backend/internal/session"
"github.com/gooseek/backend/internal/types"
"github.com/google/uuid"
"golang.org/x/sync/errgroup"
)
// Mode selects the answering pipeline's speed/quality trade-off.
type Mode string

const (
	ModeSpeed    Mode = "speed"    // fastest pipeline, minimal processing
	ModeBalanced Mode = "balanced" // default trade-off
	ModeQuality  Mode = "quality"  // most thorough; enables deep research when allowed
)
// OrchestratorConfig aggregates every dependency, context string, and
// feature flag the orchestrator needs to answer a single request.
type OrchestratorConfig struct {
	LLM                llm.Client            // LLM used for all generation steps
	SearchClient       *search.SearXNGClient // metasearch backend
	Mode               Mode                  // speed / balanced / quality
	FocusMode          FocusMode             // empty means auto-detect from the query
	Sources            []string              // NOTE(review): source filter list — confirm exact semantics with callers
	FileIDs            []string              // IDs of user-uploaded files
	FileContext        string                // extracted text of uploaded files
	CollectionID       string                // active collection, if any
	CollectionContext  string                // context text derived from the collection
	SystemInstructions string                // extra system-level instructions
	Locale             string                // UI locale; "ru" switches generated language
	MemoryContext      string                // prior-conversation context
	UserMemory         string                // persisted user preferences
	AnswerMode         string                // NOTE(review): answer formatting mode — confirm valid values
	ResponsePrefs      *ResponsePrefs        // optional format/length/tone preferences
	LearningMode       bool                  // NOTE(review): behavior not visible in this chunk — confirm
	EnableDeepResearch bool                  // allow the deep-research pipeline (quality mode only)
	EnableClarifying   bool                  // allow the clarifying-questions step
	DiscoverSvcURL     string                // discover service endpoint
	Crawl4AIURL        string                // crawler service endpoint
	CollectionSvcURL   string                // collection service endpoint
	FileSvcURL         string                // file service endpoint
}
// DigestResponse is the JSON payload for a news/topic digest: a Russian
// summary with its citations and suggested follow-ups.
type DigestResponse struct {
	SummaryRu    string           `json:"summaryRu"`    // digest summary text in Russian
	Citations    []DigestCitation `json:"citations"`    // sources cited by the summary
	FollowUp     []string         `json:"followUp"`     // suggested follow-up questions
	SourcesCount int              `json:"sourcesCount"` // total number of sources considered
	ClusterTitle string           `json:"clusterTitle"` // title of the story cluster being digested
}
// DigestCitation identifies one cited source within a DigestResponse.
type DigestCitation struct {
	Index  int    `json:"index"`  // citation number referenced in the summary text
	URL    string `json:"url"`    // source URL
	Title  string `json:"title"`  // source title
	Domain string `json:"domain"` // source domain, for display
}
// PreScrapedArticle is an article whose content was already fetched
// upstream, so the orchestrator can skip scraping it again.
type PreScrapedArticle struct {
	Title   string // article title
	Content string // extracted article text
	URL     string // canonical article URL
}
// ResponsePrefs carries optional user preferences for how answers should be
// rendered. All fields are free-form strings and may be empty.
type ResponsePrefs struct {
	Format string `json:"format,omitempty"` // preferred answer format
	Length string `json:"length,omitempty"` // preferred answer length
	Tone   string `json:"tone,omitempty"`   // preferred answer tone
}
// OrchestratorInput bundles a single request to the orchestrator: the chat
// history, the latest user message, and the per-request configuration.
type OrchestratorInput struct {
	ChatHistory []llm.Message      // prior conversation turns
	FollowUp    string             // the latest user query to answer
	Config      OrchestratorConfig // dependencies and flags for this request
}
// RunOrchestrator is the entry point for answering a user query. It detects
// the query language, auto-detects the focus mode when unset, and dispatches
// to the deep-research, speed, or full pipeline based on the configured mode.
// Article-summary requests ("Summary: ..." prefix) never take the speed path.
func RunOrchestrator(ctx context.Context, sess *session.Session, input OrchestratorInput) error {
	detectedLang := detectLanguage(input.FollowUp)
	isArticleSummary := strings.HasPrefix(strings.TrimSpace(input.FollowUp), "Summary: ")
	if input.Config.FocusMode == "" {
		input.Config.FocusMode = DetectFocusMode(input.FollowUp)
	}
	switch {
	case input.Config.EnableDeepResearch && input.Config.Mode == ModeQuality:
		return runDeepResearchMode(ctx, sess, input, detectedLang)
	case input.Config.Mode == ModeSpeed && !isArticleSummary:
		return runSpeedMode(ctx, sess, input, detectedLang)
	default:
		return runFullMode(ctx, sess, input, detectedLang, isArticleSummary)
	}
}
// runDeepResearchMode executes the deep-research pipeline for the query and
// streams its output blocks (research progress, sources, related questions)
// to the session. The lang parameter is currently unused here — TODO confirm
// whether it should be forwarded to the researcher.
func runDeepResearchMode(ctx context.Context, sess *session.Session, input OrchestratorInput, lang string) error {
	sess.EmitBlock(types.NewResearchBlock(uuid.New().String()))
	researcher := NewDeepResearcher(DeepResearchConfig{
	LLM: input.Config.LLM,
	SearchClient: input.Config.SearchClient,
	FocusMode: input.Config.FocusMode,
	Locale: input.Config.Locale,
	MaxSearchQueries: 30,
	MaxSources: 100,
	MaxIterations: 5,
	Timeout: 5 * time.Minute,
	}, sess)
	result, err := researcher.Research(ctx, input.FollowUp)
	if err != nil {
	sess.EmitError(err)
	return err
	}
	sess.EmitBlock(types.NewSourceBlock(uuid.New().String(), result.Sources))
	if len(result.FollowUpQueries) > 0 {
	// Surface follow-up questions as a widget block.
	sess.EmitBlock(types.NewWidgetBlock(uuid.New().String(), "related_questions", map[string]interface{}{
	"questions": result.FollowUpQueries,
	}))
	}
	sess.EmitResearchComplete()
	sess.EmitEnd()
	return nil
}
// generateClarifyingQuestions asks the LLM whether query is ambiguous.
// It returns nil (no questions) for clear queries, otherwise the clarifying
// questions the model suggested.
func generateClarifyingQuestions(ctx context.Context, llmClient llm.Client, query string) ([]string, error) {
	prompt := fmt.Sprintf(`Analyze this query and determine if clarifying questions would help provide a better answer.
Query: %s
If the query is:
- Clear and specific → respond with "CLEAR"
- Ambiguous or could benefit from clarification → provide 2-3 short clarifying questions
Format:
CLEAR
or
QUESTION: [question 1]
QUESTION: [question 2]
QUESTION: [question 3]`, query)
	result, err := llmClient.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: "user", Content: prompt}},
	})
	if err != nil {
		return nil, err
	}
	var questions []string
	for _, line := range strings.Split(result, "\n") {
		line = strings.TrimSpace(line)
		// Only a standalone CLEAR line counts as the "no clarification"
		// verdict. The old substring check on the whole response triggered
		// whenever any question merely contained the word "clear".
		if strings.EqualFold(line, "CLEAR") {
			return nil, nil
		}
		if strings.HasPrefix(line, "QUESTION:") {
			if q := strings.TrimSpace(strings.TrimPrefix(line, "QUESTION:")); q != "" {
				questions = append(questions, q)
			}
		}
	}
	return questions, nil
}
// generateRelatedQuestions asks the LLM for follow-up questions to surface
// in the related-questions widget. Failures degrade to no suggestions; at
// most four questions are returned.
func generateRelatedQuestions(ctx context.Context, llmClient llm.Client, query, answer string, locale string) []string {
	langInstruction := ""
	if locale == "ru" {
		langInstruction = "Generate questions in Russian."
	}
	prompt := fmt.Sprintf(`Based on this query and answer, generate 3-4 related follow-up questions the user might want to explore.
Query: %s
Answer summary: %s
%s
Format: One question per line, no numbering or bullets.`, query, truncateForPrompt(answer, 500), langInstruction)
	result, err := llmClient.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: "user", Content: prompt}},
	})
	if err != nil {
		return nil
	}
	var questions []string
	for _, raw := range strings.Split(result, "\n") {
		candidate := strings.TrimSpace(raw)
		// Keep only plausible questions: non-trivial length and a "?".
		if len(candidate) <= 10 || !strings.Contains(candidate, "?") {
			continue
		}
		// Strip any numbering/bullet prefix the model added anyway.
		candidate = strings.TrimLeft(candidate, "0123456789.-•* ")
		if candidate == "" {
			continue
		}
		questions = append(questions, candidate)
		if len(questions) == 4 {
			break
		}
	}
	return questions
}
// truncateForPrompt shortens s to at most maxLen bytes for inclusion in an
// LLM prompt, appending "..." when truncation occurred. The cut point backs
// up to a UTF-8 rune boundary so the result is never invalid text (the
// previous version sliced blindly and could split a multi-byte rune).
func truncateForPrompt(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	cut := maxLen
	// 0b10xxxxxx bytes are UTF-8 continuation bytes; step back over them.
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "..."
}
// buildEnhancedContext assembles the optional per-request context sections
// (user memory, collection context, uploaded files, prior context) into one
// markdown preamble. Empty sections are skipped; the result is "" when
// nothing is configured.
func buildEnhancedContext(input OrchestratorInput) string {
	// Renamed the builder from "ctx" — that name is conventionally reserved
	// for a context.Context and was misleading on a strings.Builder.
	var sb strings.Builder
	writeSection := func(header, body string) {
		if body == "" {
			return
		}
		sb.WriteString(header)
		sb.WriteString(body)
		sb.WriteString("\n\n")
	}
	writeSection("## User Preferences\n", input.Config.UserMemory)
	writeSection("## Collection Context\n", input.Config.CollectionContext)
	writeSection("## Uploaded Files Content\n", input.Config.FileContext)
	writeSection("## Previous Context\n", input.Config.MemoryContext)
	return sb.String()
}
// fetchPreGeneratedDigest asks discover-svc for a pre-computed digest of the
// article URL. It is best-effort: an unconfigured service, a non-200
// response, or an incomplete digest all yield (nil, nil) so the caller falls
// back to live scraping. Only transport/decoding failures return an error.
func fetchPreGeneratedDigest(ctx context.Context, discoverURL, articleURL string) (*DigestResponse, error) {
	if discoverURL == "" {
		return nil, nil
	}
	reqURL := fmt.Sprintf("%s/api/v1/discover/digest?url=%s",
		strings.TrimSuffix(discoverURL, "/"),
		url.QueryEscape(articleURL))
	req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
	if err != nil {
		return nil, err
	}
	// Short timeout: the digest is an optimization, not a requirement.
	client := &http.Client{Timeout: 3 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, nil
	}
	var digest DigestResponse
	if err := json.NewDecoder(resp.Body).Decode(&digest); err != nil {
		return nil, err
	}
	// A digest is only usable when it has a Russian summary and citations.
	if digest.SummaryRu != "" && len(digest.Citations) > 0 {
		return &digest, nil
	}
	return nil, nil
}
// preScrapeArticleURL fetches article content, preferring the Crawl4AI
// service when configured and falling back to a direct HTTP fetch.
func preScrapeArticleURL(ctx context.Context, crawl4aiURL, articleURL string) (*PreScrapedArticle, error) {
	if crawl4aiURL != "" {
		if article, err := scrapeWithCrawl4AI(ctx, crawl4aiURL, articleURL); err == nil && article != nil {
			return article, nil
		}
	}
	return scrapeDirectly(ctx, articleURL)
}
// scrapeWithCrawl4AI submits the URL to a Crawl4AI service and extracts the
// markdown and title from its JSON response. It errors when the service
// fails or yields 100 bytes or less of markdown; content is capped at
// 15000 bytes for the prompt.
func scrapeWithCrawl4AI(ctx context.Context, crawl4aiURL, articleURL string) (*PreScrapedArticle, error) {
	// Build the body with json.Marshal instead of fmt.Sprintf so a URL
	// containing quotes or backslashes cannot corrupt (or inject into) the
	// JSON payload.
	payload, err := json.Marshal(map[string]interface{}{
		"urls": []string{articleURL},
		"crawler_config": map[string]interface{}{
			"type": "CrawlerRunConfig",
			"params": map[string]interface{}{
				"cache_mode":   "default",
				"page_timeout": 20000,
			},
		},
	})
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequestWithContext(ctx, "POST", crawl4aiURL+"/crawl", strings.NewReader(string(payload)))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	client := &http.Client{Timeout: 25 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("Crawl4AI returned status %d", resp.StatusCode)
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	raw := string(body)
	markdown := extractMarkdownFromCrawl4AI(raw)
	title := extractTitleFromCrawl4AI(raw)
	if len(markdown) <= 100 {
		return nil, fmt.Errorf("insufficient content from Crawl4AI")
	}
	content := markdown
	if len(content) > 15000 {
		content = content[:15000]
	}
	return &PreScrapedArticle{
		Title:   title,
		Content: content,
		URL:     articleURL,
	}, nil
}
// scrapeDirectly fetches the article over plain HTTP and reduces it to text
// via the regex-based HTML stripper. At least 100 bytes of text are required;
// content is capped at 15000 bytes for the prompt.
func scrapeDirectly(ctx context.Context, articleURL string) (*PreScrapedArticle, error) {
	req, err := http.NewRequestWithContext(ctx, "GET", articleURL, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("User-Agent", "GooSeek-Agent/1.0")
	req.Header.Set("Accept", "text/html")
	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
	}
	// Bound the read: the previous io.ReadAll buffered arbitrarily large
	// responses even though at most 15000 bytes are ever used.
	body, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
	if err != nil {
		return nil, err
	}
	html := string(body)
	title := extractHTMLTitle(html)
	content := extractTextContent(html)
	if len(content) < 100 {
		return nil, fmt.Errorf("insufficient content")
	}
	if len(content) > 15000 {
		content = content[:15000]
	}
	return &PreScrapedArticle{
		Title:   title,
		Content: content,
		URL:     articleURL,
	}, nil
}
// Pre-compiled regexes for the lightweight HTML-to-text fallback scraper.
var (
	titleRegex  = regexp.MustCompile(`<title[^>]*>([^<]+)</title>`)   // first <title> element
	scriptRegex = regexp.MustCompile(`(?s)<script[^>]*>.*?</script>`) // whole <script> blocks, incl. newlines
	styleRegex  = regexp.MustCompile(`(?s)<style[^>]*>.*?</style>`)   // whole <style> blocks, incl. newlines
	tagRegex    = regexp.MustCompile(`<[^>]+>`)                       // any remaining tag
	spaceRegex  = regexp.MustCompile(`\s+`)                           // whitespace runs
)
// extractHTMLTitle pulls the contents of the first <title> element, trimmed
// of surrounding whitespace, or "" when the document has none.
func extractHTMLTitle(html string) string {
	if m := titleRegex.FindStringSubmatch(html); len(m) > 1 {
		return strings.TrimSpace(m[1])
	}
	return ""
}
// extractTextContent reduces an HTML document to plain text: it narrows to
// the <body> when both tags are present, strips script/style blocks and all
// remaining tags, then collapses runs of whitespace into single spaces.
func extractTextContent(html string) string {
	lower := strings.ToLower(html)
	start := strings.Index(lower, "<body")
	end := strings.Index(lower, "</body>")
	if start != -1 && end != -1 && end > start {
		html = html[start:end]
	}
	for _, re := range []*regexp.Regexp{scriptRegex, styleRegex} {
		html = re.ReplaceAllString(html, "")
	}
	html = spaceRegex.ReplaceAllString(tagRegex.ReplaceAllString(html, " "), " ")
	return strings.TrimSpace(html)
}
// jsonStringField extracts the first string value for key from a raw JSON
// document without fully parsing it. Unlike the previous inline scanners, it
// honors backslash escapes (the old code truncated at the first '"' even when
// escaped) and unescapes the value via json.Unmarshal. Returns "" when the
// key is absent or its value is not a well-formed JSON string.
func jsonStringField(response, key string) string {
	needle := `"` + key + `"`
	idx := strings.Index(response, needle)
	if idx == -1 {
		return ""
	}
	rest := response[idx+len(needle):]
	colon := strings.Index(rest, ":")
	if colon == -1 {
		return ""
	}
	rest = strings.TrimLeft(rest[colon+1:], " \t\r\n")
	if len(rest) == 0 || rest[0] != '"' {
		return ""
	}
	// Scan to the closing quote, skipping escaped characters.
	for i := 1; i < len(rest); i++ {
		switch rest[i] {
		case '\\':
			i++ // skip the escaped character
		case '"':
			var out string
			if err := json.Unmarshal([]byte(rest[:i+1]), &out); err != nil {
				return ""
			}
			return out
		}
	}
	return ""
}

// extractMarkdownFromCrawl4AI returns the "raw_markdown" string from a
// Crawl4AI JSON response, or "" when absent.
func extractMarkdownFromCrawl4AI(response string) string {
	return jsonStringField(response, "raw_markdown")
}

// extractTitleFromCrawl4AI returns the "title" string from a Crawl4AI JSON
// response, or "" when absent.
func extractTitleFromCrawl4AI(response string) string {
	return jsonStringField(response, "title")
}
// runSpeedMode serves speed-mode requests with a single fast pass: a local
// (non-LLM) classification, parallel web + media search, BM25 reranking with
// an adaptive budget, and a streamed answer capped at 2048 tokens.
func runSpeedMode(ctx context.Context, sess *session.Session, input OrchestratorInput, detectedLang string) error {
	classification := fastClassify(input.FollowUp, input.ChatHistory)
	// Prefer the classifier's standalone rewrite of the follow-up.
	searchQuery := classification.StandaloneFollowUp
	if searchQuery == "" {
		searchQuery = input.FollowUp
	}
	queries := generateSearchQueries(searchQuery)
	researchBlockID := uuid.New().String()
	sess.EmitBlock(types.NewResearchBlock(researchBlockID))
	var searchResults []types.Chunk
	var mediaResult *search.MediaSearchResult
	// Web search and media search run concurrently; both are best-effort
	// and never fail the request (errors leave nil results).
	g, gctx := errgroup.WithContext(ctx)
	g.Go(func() error {
		results, err := parallelSearch(gctx, input.Config.SearchClient, queries)
		if err != nil {
			return nil
		}
		searchResults = results
		return nil
	})
	g.Go(func() error {
		result, err := input.Config.SearchClient.SearchMedia(gctx, searchQuery, &search.MediaSearchOptions{
			MaxImages: 6,
			MaxVideos: 4,
		})
		if err != nil {
			return nil
		}
		mediaResult = result
		return nil
	})
	_ = g.Wait() // the goroutines only ever return nil
	if len(searchResults) > 0 {
		sess.EmitBlock(types.NewSourceBlock(uuid.New().String(), searchResults))
	}
	if mediaResult != nil {
		if len(mediaResult.Images) > 0 {
			sess.EmitBlock(types.NewWidgetBlock(uuid.New().String(), "image_gallery", map[string]interface{}{
				"images": mediaResult.Images,
				"layout": "carousel",
			}))
		}
		if len(mediaResult.Videos) > 0 {
			sess.EmitBlock(types.NewWidgetBlock(uuid.New().String(), "videos", map[string]interface{}{
				"items": mediaResult.Videos,
				"title": "",
			}))
		}
	}
	sess.EmitResearchComplete()
	// Rerank with BM25, budgeting results by estimated query complexity.
	queryComplexity := search.EstimateQueryComplexity(searchQuery)
	adaptiveTopK := search.ComputeAdaptiveTopK(len(searchResults), queryComplexity, "speed")
	rankedResults := search.RerankBM25(searchResults, searchQuery, adaptiveTopK)
	finalContext := buildContext(rankedResults, 15, 250)
	writerPrompt := prompts.GetWriterPrompt(prompts.WriterConfig{
		Context:            finalContext,
		SystemInstructions: input.Config.SystemInstructions,
		Mode:               string(input.Config.Mode),
		Locale:             input.Config.Locale,
		MemoryContext:      input.Config.MemoryContext,
		AnswerMode:         input.Config.AnswerMode,
		DetectedLanguage:   detectedLang,
		IsArticleSummary:   false,
	})
	messages := []llm.Message{
		{Role: llm.RoleSystem, Content: writerPrompt},
	}
	messages = append(messages, input.ChatHistory...)
	messages = append(messages, llm.Message{Role: llm.RoleUser, Content: input.FollowUp})
	// Speed mode uses a smaller token budget (2048 vs 4096 in full mode).
	return streamResponse(ctx, sess, input.Config.LLM, messages, 2048, input.FollowUp, input.Config.Locale)
}
// runFullMode is the default balanced/quality pipeline: optional clarifying
// questions, enriched memory context, article-summary fast paths (discover
// digest + pre-scraping), LLM classification, parallel research and media
// search, relevance ranking, and a streamed writer response.
func runFullMode(ctx context.Context, sess *session.Session, input OrchestratorInput, detectedLang string, isArticleSummary bool) error {
	// Quality mode may pause to ask clarifying questions instead of answering.
	if input.Config.EnableClarifying && !isArticleSummary && input.Config.Mode == ModeQuality {
		clarifying, err := generateClarifyingQuestions(ctx, input.Config.LLM, input.FollowUp)
		if err == nil && len(clarifying) > 0 {
			sess.EmitBlock(types.NewWidgetBlock(uuid.New().String(), "clarifying", map[string]interface{}{
				"questions": clarifying,
				"query":     input.FollowUp,
			}))
			return nil
		}
	}
	// Fold user memory / collection / file context into the memory context.
	enhancedContext := buildEnhancedContext(input)
	if enhancedContext != "" {
		input.Config.MemoryContext = enhancedContext + input.Config.MemoryContext
	}
	var preScrapedArticle *PreScrapedArticle
	var articleURL string
	if isArticleSummary {
		articleURL = strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(input.FollowUp), "Summary: "))
		// Fetch a pre-generated digest and pre-scrape the article in parallel.
		digestCtx, digestCancel := context.WithTimeout(ctx, 3*time.Second)
		scrapeCtx, scrapeCancel := context.WithTimeout(ctx, 25*time.Second)
		digestCh := make(chan *DigestResponse, 1)
		scrapeCh := make(chan *PreScrapedArticle, 1)
		go func() {
			defer digestCancel()
			digest, _ := fetchPreGeneratedDigest(digestCtx, input.Config.DiscoverSvcURL, articleURL)
			digestCh <- digest
		}()
		go func() {
			defer scrapeCancel()
			article, _ := preScrapeArticleURL(scrapeCtx, input.Config.Crawl4AIURL, articleURL)
			scrapeCh <- article
		}()
		digest := <-digestCh
		preScrapedArticle = <-scrapeCh
		// A usable digest short-circuits the whole pipeline.
		if digest != nil {
			chunks := make([]types.Chunk, len(digest.Citations))
			for i, c := range digest.Citations {
				chunks[i] = types.Chunk{
					Content: c.Title,
					Metadata: map[string]string{
						"url":    c.URL,
						"title":  c.Title,
						"domain": c.Domain,
					},
				}
			}
			sess.EmitBlock(types.NewSourceBlock(uuid.New().String(), chunks))
			sess.EmitResearchComplete()
			summaryText := digest.SummaryRu
			if len(digest.FollowUp) > 0 {
				summaryText += "\n\n---\n"
				for _, q := range digest.FollowUp {
					summaryText += "> " + q + "\n"
				}
			}
			sess.EmitBlock(types.NewTextBlock(uuid.New().String(), summaryText))
			sess.EmitEnd()
			return nil
		}
	}
	// LLM classification; fall back to a pass-through result on failure.
	classification, err := classify(ctx, input.Config.LLM, input.FollowUp, input.ChatHistory, input.Config.Locale, detectedLang)
	if err != nil {
		classification = &ClassificationResult{
			StandaloneFollowUp: input.FollowUp,
			SkipSearch:         false,
		}
	}
	// Article summaries always search, regardless of the classifier.
	if isArticleSummary && classification.SkipSearch {
		classification.SkipSearch = false
	}
	g, gctx := errgroup.WithContext(ctx)
	var searchResults []types.Chunk
	var mediaResult *search.MediaSearchResult
	mediaQuery := classification.StandaloneFollowUp
	if mediaQuery == "" {
		mediaQuery = input.FollowUp
	}
	effectiveFollowUp := input.FollowUp
	// Enrich the article-summary query with the scraped title for search.
	if isArticleSummary && preScrapedArticle != nil && preScrapedArticle.Title != "" {
		effectiveFollowUp = fmt.Sprintf("Summary: %s\nArticle title: %s", preScrapedArticle.URL, preScrapedArticle.Title)
		if classification.StandaloneFollowUp != "" {
			classification.StandaloneFollowUp = preScrapedArticle.Title + " " + classification.StandaloneFollowUp
		} else {
			classification.StandaloneFollowUp = preScrapedArticle.Title
		}
	}
	if !classification.SkipSearch {
		g.Go(func() error {
			results, err := research(gctx, sess, input.Config.LLM, input.Config.SearchClient, ResearchInput{
				ChatHistory:      input.ChatHistory,
				FollowUp:         effectiveFollowUp,
				Classification:   classification,
				Mode:             input.Config.Mode,
				Sources:          input.Config.Sources,
				Locale:           input.Config.Locale,
				DetectedLang:     detectedLang,
				IsArticleSummary: isArticleSummary,
			})
			if err != nil {
				return nil // research is best-effort
			}
			searchResults = results
			return nil
		})
	}
	if !isArticleSummary {
		g.Go(func() error {
			result, err := input.Config.SearchClient.SearchMedia(gctx, mediaQuery, &search.MediaSearchOptions{
				MaxImages: 8,
				MaxVideos: 6,
			})
			if err != nil {
				return nil
			}
			mediaResult = result
			return nil
		})
	}
	_ = g.Wait() // the goroutines only ever return nil
	// Ensure the scraped article itself appears as the first source.
	if isArticleSummary && preScrapedArticle != nil {
		alreadyHasURL := false
		for _, r := range searchResults {
			if strings.Contains(r.Metadata["url"], preScrapedArticle.URL) {
				alreadyHasURL = true
				break
			}
		}
		if !alreadyHasURL {
			prependChunk := types.Chunk{
				Content: preScrapedArticle.Content,
				Metadata: map[string]string{
					"url":   preScrapedArticle.URL,
					"title": preScrapedArticle.Title,
				},
			}
			searchResults = append([]types.Chunk{prependChunk}, searchResults...)
		}
	}
	if len(searchResults) > 0 {
		sess.EmitBlock(types.NewSourceBlock(uuid.New().String(), searchResults))
	}
	if mediaResult != nil {
		if len(mediaResult.Images) > 0 {
			sess.EmitBlock(types.NewWidgetBlock(uuid.New().String(), "image_gallery", map[string]interface{}{
				"images": mediaResult.Images,
				"layout": "carousel",
			}))
		}
		if len(mediaResult.Videos) > 0 {
			sess.EmitBlock(types.NewWidgetBlock(uuid.New().String(), "videos", map[string]interface{}{
				"items": mediaResult.Videos,
				"title": "",
			}))
		}
	}
	sess.EmitResearchComplete()
	// Article summaries get more chunks and a larger per-chunk budget.
	maxResults := 25
	maxContent := 320
	if isArticleSummary {
		maxResults = 30
		maxContent = 2000
	}
	rankedResults := rankByRelevance(searchResults, input.FollowUp)
	if len(rankedResults) > maxResults {
		rankedResults = rankedResults[:maxResults]
	}
	finalContext := buildContext(rankedResults, maxResults, maxContent)
	writerPrompt := prompts.GetWriterPrompt(prompts.WriterConfig{
		Context:            finalContext,
		SystemInstructions: input.Config.SystemInstructions,
		Mode:               string(input.Config.Mode),
		Locale:             input.Config.Locale,
		MemoryContext:      input.Config.MemoryContext,
		AnswerMode:         input.Config.AnswerMode,
		DetectedLanguage:   detectedLang,
		IsArticleSummary:   isArticleSummary,
	})
	messages := []llm.Message{
		{Role: llm.RoleSystem, Content: writerPrompt},
	}
	messages = append(messages, input.ChatHistory...)
	messages = append(messages, llm.Message{Role: llm.RoleUser, Content: input.FollowUp})
	maxTokens := 4096
	return streamResponse(ctx, sess, input.Config.LLM, messages, maxTokens, input.FollowUp, input.Config.Locale)
}
// streamResponse streams the LLM completion to the client as one text block:
// the first non-empty chunk creates the block, later chunks are emitted as
// deltas, and a final replace patch ensures the block holds the full text.
// After streaming it kicks off async related-question generation.
func streamResponse(ctx context.Context, sess *session.Session, client llm.Client, messages []llm.Message, maxTokens int, query string, locale string) error {
	stream, err := client.StreamText(ctx, llm.StreamRequest{
		Messages: messages,
		Options:  llm.StreamOptions{MaxTokens: maxTokens},
	})
	if err != nil {
		return err
	}
	var responseBlockID string
	var accumulatedText string
	for chunk := range stream {
		// Skip leading empty chunks so we never create an empty block.
		if chunk.ContentChunk == "" && responseBlockID == "" {
			continue
		}
		if responseBlockID == "" {
			responseBlockID = uuid.New().String()
			accumulatedText = chunk.ContentChunk
			sess.EmitBlock(types.NewTextBlock(responseBlockID, accumulatedText))
		} else if chunk.ContentChunk != "" {
			accumulatedText += chunk.ContentChunk
			sess.EmitTextChunk(responseBlockID, chunk.ContentChunk)
		}
	}
	if responseBlockID != "" {
		// Replace the block data wholesale so clients that missed deltas
		// still end up with the complete text.
		sess.UpdateBlock(responseBlockID, []session.Patch{
			{Op: "replace", Path: "/data", Value: accumulatedText},
		})
	}
	// NOTE(review): this goroutine may emit the related-questions widget
	// *after* EmitEnd below fires, and it uses context.Background so it is
	// not cancelled with the request — confirm session.Session tolerates
	// post-end emits before relying on this widget always arriving.
	go func() {
		relatedCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		related := generateRelatedQuestions(relatedCtx, client, query, accumulatedText, locale)
		if len(related) > 0 {
			sess.EmitBlock(types.NewWidgetBlock(uuid.New().String(), "related_questions", map[string]interface{}{
				"questions": related,
			}))
		}
	}()
	sess.EmitEnd()
	return nil
}
// parallelSearch fans the queries out to SearXNG concurrently and merges the
// responses, deduplicating by URL. Individual query failures are swallowed so
// one bad query cannot sink the batch; the returned error is always nil.
func parallelSearch(ctx context.Context, client *search.SearXNGClient, queries []string) ([]types.Chunk, error) {
	g, gctx := errgroup.WithContext(ctx)
	batches := make(chan []types.SearchResult, len(queries))
	for _, q := range queries {
		query := q
		g.Go(func() error {
			opts := &search.SearchOptions{
				Categories: []string{"general", "news"},
				PageNo:     1,
			}
			resp, err := client.Search(gctx, query, opts)
			if err != nil {
				batches <- nil
				return nil
			}
			batches <- resp.Results
			return nil
		})
	}
	// Close the channel once every worker has delivered its batch.
	go func() {
		g.Wait()
		close(batches)
	}()
	merged := make([]types.Chunk, 0)
	seen := make(map[string]bool)
	for batch := range batches {
		for _, result := range batch {
			if result.URL == "" || seen[result.URL] {
				continue
			}
			seen[result.URL] = true
			merged = append(merged, result.ToChunk())
		}
	}
	return merged, nil
}
// buildContext renders the ranked chunks as a pseudo-XML block the writer
// prompt can cite, each result carrying its title and a 1-based index.
//
// Fixed two bugs in the emitted markup: the tag previously read
// `<result index=TITLE" index="N">` (missing `title="`, duplicate "index"
// attribute), and the index was rendered as rune('0'+i+1), which produces
// non-digit characters (':', ';', ...) for every index past 9.
func buildContext(chunks []types.Chunk, maxResults, maxContentLen int) string {
	if len(chunks) > maxResults {
		chunks = chunks[:maxResults]
	}
	var sb strings.Builder
	sb.WriteString("<search_results note=\"These are the search results and assistant can cite these\">\n")
	for i, chunk := range chunks {
		content := chunk.Content
		if len(content) > maxContentLen {
			content = content[:maxContentLen] + "…"
		}
		title := chunk.Metadata["title"]
		sb.WriteString(`<result title="`)
		// Double quotes would terminate the attribute; swap them for singles.
		sb.WriteString(strings.ReplaceAll(title, `"`, "'"))
		sb.WriteString(`" index="`)
		sb.WriteString(fmt.Sprintf("%d", i+1))
		sb.WriteString(`">`)
		sb.WriteString(content)
		sb.WriteString("</result>\n")
	}
	sb.WriteString("</search_results>")
	return sb.String()
}
// rankByRelevance orders chunks by a simple term-overlap score against the
// query: +3 per query term found in the title, +1 per term found in the
// content. Chunks come back best-first; with no chunks or no usable terms
// the input is returned unchanged.
func rankByRelevance(chunks []types.Chunk, query string) []types.Chunk {
	if len(chunks) == 0 {
		return chunks
	}
	terms := extractQueryTerms(query)
	if len(terms) == 0 {
		return chunks
	}
	type scoredChunk struct {
		chunk types.Chunk
		score int
	}
	// Renamed from snake_case (scored_chunks) to idiomatic Go naming.
	scoredChunks := make([]scoredChunk, len(chunks))
	for i, chunk := range chunks {
		score := 0
		content := strings.ToLower(chunk.Content)
		title := strings.ToLower(chunk.Metadata["title"])
		for term := range terms {
			if strings.Contains(title, term) {
				score += 3
			}
			if strings.Contains(content, term) {
				score++
			}
		}
		scoredChunks[i] = scoredChunk{chunk: chunk, score: score}
	}
	// Stable insertion sort, descending by score. Replaces the previous
	// unstable exchange sort so ties keep their original search order; n is
	// small here (results are capped upstream), so O(n^2) is acceptable.
	for i := 1; i < len(scoredChunks); i++ {
		for j := i; j > 0 && scoredChunks[j].score > scoredChunks[j-1].score; j-- {
			scoredChunks[j], scoredChunks[j-1] = scoredChunks[j-1], scoredChunks[j]
		}
	}
	result := make([]types.Chunk, len(scoredChunks))
	for i, s := range scoredChunks {
		result[i] = s.chunk
	}
	return result
}
// extractQueryTerms lowercases the query, drops an optional "summary: "
// prefix, and returns the set of words that are at least two bytes long and
// do not look like URLs ("http" prefix).
func extractQueryTerms(query string) map[string]bool {
	normalized := strings.TrimPrefix(strings.ToLower(query), "summary: ")
	terms := make(map[string]bool)
	for _, word := range strings.Fields(normalized) {
		if len(word) < 2 || strings.HasPrefix(word, "http") {
			continue
		}
		terms[word] = true
	}
	return terms
}

View File

@@ -0,0 +1,128 @@
package agent
import (
"context"
"github.com/gooseek/backend/internal/llm"
"github.com/gooseek/backend/internal/search"
"github.com/gooseek/backend/internal/session"
"github.com/gooseek/backend/internal/types"
"github.com/google/uuid"
)
// ResearchInput carries everything the research loop needs for one request.
type ResearchInput struct {
	ChatHistory      []llm.Message         // prior conversation turns
	FollowUp         string                // user query for this turn
	Classification   *ClassificationResult // classifier output (standalone rewrite, flags)
	Mode             Mode                  // speed/balanced/quality — controls the iteration budget
	Sources          []string              // user-selected source names (see categoriesToSearch)
	Locale           string                // UI locale
	DetectedLang     string                // language detected from the query
	IsArticleSummary bool                  // true when the turn is a "Summary: <url>" request
}
// research runs the iterative search loop for full mode. The iteration
// budget is mode-dependent (1/3/10 for speed/balanced/quality); results are
// deduplicated by URL and the loop exits early once enough sources exist
// (>=20 for balanced, >=50 otherwise).
//
// NOTE(review): generateSearchQueries receives the same searchQuery on every
// iteration, so later iterations re-issue identical queries and mostly add
// nothing past the URL dedup — confirm whether per-iteration query
// refinement was intended. llmClient is currently unused in this loop.
func research(
	ctx context.Context,
	sess *session.Session,
	llmClient llm.Client,
	searchClient *search.SearXNGClient,
	input ResearchInput,
) ([]types.Chunk, error) {
	maxIterations := 1
	switch input.Mode {
	case ModeBalanced:
		maxIterations = 3
	case ModeQuality:
		maxIterations = 10
	}
	researchBlockID := uuid.New().String()
	sess.EmitBlock(types.NewResearchBlock(researchBlockID))
	allResults := make([]types.Chunk, 0)
	seenURLs := make(map[string]bool)
	// Prefer the classifier's standalone rewrite of the follow-up.
	searchQuery := input.Classification.StandaloneFollowUp
	if searchQuery == "" {
		searchQuery = input.FollowUp
	}
	for i := 0; i < maxIterations; i++ {
		queries := generateSearchQueries(searchQuery)
		// Surface the active queries in the research-progress block.
		sess.UpdateBlock(researchBlockID, []session.Patch{
			{
				Op:   "replace",
				Path: "/data/subSteps",
				Value: []types.ResearchSubStep{
					{
						ID:        uuid.New().String(),
						Type:      "searching",
						Searching: queries,
					},
				},
			},
		})
		for _, q := range queries {
			resp, err := searchClient.Search(ctx, q, &search.SearchOptions{
				Categories: categoriesToSearch(input.Sources),
				PageNo:     1,
			})
			if err != nil {
				continue // best-effort: skip failed queries
			}
			for _, r := range resp.Results {
				if r.URL != "" && !seenURLs[r.URL] {
					seenURLs[r.URL] = true
					allResults = append(allResults, r.ToChunk())
				}
			}
		}
		// Early-exit thresholds by mode.
		if input.Mode == ModeSpeed {
			break
		}
		if len(allResults) >= 20 && input.Mode == ModeBalanced {
			break
		}
		if len(allResults) >= 50 {
			break
		}
	}
	return allResults, nil
}
// categoriesToSearch maps user-facing source names onto SearXNG categories.
// With no sources it defaults to general+news; unknown names are ignored,
// and an all-unknown list falls back to general only.
func categoriesToSearch(sources []string) []string {
	if len(sources) == 0 {
		return []string{"general", "news"}
	}
	mapping := map[string]string{
		"web":         "general",
		"discussions": "social media",
		"academic":    "science",
		"news":        "news",
		"images":      "images",
		"videos":      "videos",
	}
	categories := make([]string, 0, len(sources))
	for _, source := range sources {
		if category, known := mapping[source]; known {
			categories = append(categories, category)
		}
	}
	if len(categories) == 0 {
		return []string{"general"}
	}
	return categories
}

View File

@@ -0,0 +1,587 @@
package browser
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"strings"
"sync"
"time"
"github.com/google/uuid"
)
// PlaywrightBrowser is an HTTP client for a Playwright sidecar server that
// executes browser commands on our behalf and tracks the sessions it opened.
type PlaywrightBrowser struct {
	cmd       *exec.Cmd    // sidecar process handle — not set by the constructor in this file; presumably populated elsewhere
	serverURL string       // base URL of the Playwright command server
	client    *http.Client // shared HTTP client for all commands
	sessions  map[string]*BrowserSession
	mu        sync.RWMutex // guards sessions
	config    Config
}

// Config holds construction-time settings; zero values are replaced with
// defaults in NewPlaywrightBrowser.
type Config struct {
	PlaywrightServerURL string        // default "http://localhost:3050"
	DefaultTimeout      time.Duration // default 30s; applied to the HTTP client
	Headless            bool
	UserAgent           string // fallback UA when a session does not set one
	ProxyURL            string // fallback proxy when a session does not set one
	ScreenshotsDir      string // default "/tmp/gooseek-screenshots"
	RecordingsDir       string // default "/tmp/gooseek-recordings"
}

// BrowserSession tracks one remote browser context/page pair.
type BrowserSession struct {
	ID          string
	ContextID   string
	PageID      string
	CreatedAt   time.Time
	LastAction  time.Time
	Screenshots []string // local paths of screenshots saved for this session
	Recordings  []string
	Closed      bool
}

// ActionRequest is the generic command envelope sent to the sidecar.
type ActionRequest struct {
	SessionID string                 `json:"sessionId"`
	Action    string                 `json:"action"`
	Params    map[string]interface{} `json:"params"`
}

// ActionResponse is the generic result returned by browser actions.
type ActionResponse struct {
	Success    bool        `json:"success"`
	Data       interface{} `json:"data,omitempty"`
	Screenshot string      `json:"screenshot,omitempty"` // base64 image when requested
	Error      string      `json:"error,omitempty"`
	PageTitle  string      `json:"pageTitle,omitempty"`
	PageURL    string      `json:"pageUrl,omitempty"`
}
// NewPlaywrightBrowser builds a client for the Playwright sidecar, filling
// in defaults for the server URL, timeout, and artifact directories.
// Directory creation is best-effort; screenshot/recording writes simply fail
// later if the paths are unusable.
func NewPlaywrightBrowser(cfg Config) *PlaywrightBrowser {
	if cfg.DefaultTimeout == 0 {
		cfg.DefaultTimeout = 30 * time.Second
	}
	if cfg.PlaywrightServerURL == "" {
		cfg.PlaywrightServerURL = "http://localhost:3050"
	}
	if cfg.ScreenshotsDir == "" {
		cfg.ScreenshotsDir = "/tmp/gooseek-screenshots"
	}
	if cfg.RecordingsDir == "" {
		cfg.RecordingsDir = "/tmp/gooseek-recordings"
	}
	for _, dir := range []string{cfg.ScreenshotsDir, cfg.RecordingsDir} {
		os.MkdirAll(dir, 0755)
	}
	return &PlaywrightBrowser{
		serverURL: cfg.PlaywrightServerURL,
		client:    &http.Client{Timeout: cfg.DefaultTimeout},
		sessions:  map[string]*BrowserSession{},
		config:    cfg,
	}
}
// NewSession creates a fresh browser context + page on the sidecar and
// registers it locally. Per-session options override the client-level
// Config fallbacks for user agent and proxy.
func (b *PlaywrightBrowser) NewSession(ctx context.Context, opts SessionOptions) (*BrowserSession, error) {
	sessionID := uuid.New().String()
	params := map[string]interface{}{
		"headless":  b.config.Headless,
		"sessionId": sessionID,
	}
	if opts.Viewport != nil {
		params["viewport"] = opts.Viewport
	}
	if opts.UserAgent != "" {
		params["userAgent"] = opts.UserAgent
	} else if b.config.UserAgent != "" {
		params["userAgent"] = b.config.UserAgent
	}
	if opts.ProxyURL != "" {
		params["proxy"] = opts.ProxyURL
	} else if b.config.ProxyURL != "" {
		params["proxy"] = b.config.ProxyURL
	}
	if opts.RecordVideo {
		params["recordVideo"] = map[string]interface{}{
			"dir": b.config.RecordingsDir,
		}
	}
	resp, err := b.sendCommand(ctx, "browser.newContext", params)
	if err != nil {
		return nil, fmt.Errorf("failed to create browser context: %w", err)
	}
	// IDs are best-effort: a malformed response leaves them empty.
	contextID, _ := resp["contextId"].(string)
	pageID, _ := resp["pageId"].(string)
	session := &BrowserSession{
		ID:         sessionID,
		ContextID:  contextID,
		PageID:     pageID,
		CreatedAt:  time.Now(),
		LastAction: time.Now(),
	}
	b.mu.Lock()
	b.sessions[sessionID] = session
	b.mu.Unlock()
	return session, nil
}
// CloseSession removes the session from local tracking and tells the sidecar
// to tear down its browser context. The lock is released before the network
// call so other sessions are not blocked on it.
func (b *PlaywrightBrowser) CloseSession(ctx context.Context, sessionID string) error {
	b.mu.Lock()
	session, found := b.sessions[sessionID]
	if found {
		session.Closed = true
		delete(b.sessions, sessionID)
	}
	b.mu.Unlock()
	if !found {
		return errors.New("session not found")
	}
	_, err := b.sendCommand(ctx, "browser.closeContext", map[string]interface{}{
		"sessionId": sessionID,
	})
	return err
}
// Navigate loads url in the session's page, optionally passing a timeout and
// load-state to the sidecar, and optionally capturing a post-navigation
// screenshot.
func (b *PlaywrightBrowser) Navigate(ctx context.Context, sessionID, url string, opts NavigateOptions) (*ActionResponse, error) {
	params := map[string]interface{}{
		"sessionId": sessionID,
		"url":       url,
	}
	if opts.Timeout > 0 {
		params["timeout"] = opts.Timeout
	}
	if opts.WaitUntil != "" {
		params["waitUntil"] = opts.WaitUntil
	}
	resp, err := b.sendCommand(ctx, "page.goto", params)
	if err != nil {
		return &ActionResponse{Success: false, Error: err.Error()}, err
	}
	result := &ActionResponse{
		Success:   true,
		PageURL:   getString(resp, "url"),
		PageTitle: getString(resp, "title"),
	}
	if opts.Screenshot {
		// Screenshot failures are deliberately ignored; navigation succeeded.
		screenshot, _ := b.Screenshot(ctx, sessionID, ScreenshotOptions{FullPage: false})
		if screenshot != nil {
			result.Screenshot = screenshot.Data
		}
	}
	return result, nil
}
// Click clicks the element matched by selector, with optional mouse button,
// click count, forced click, post-click delay, and screenshot capture.
func (b *PlaywrightBrowser) Click(ctx context.Context, sessionID, selector string, opts ClickOptions) (*ActionResponse, error) {
	params := map[string]interface{}{
		"sessionId": sessionID,
		"selector":  selector,
	}
	if opts.Button != "" {
		params["button"] = opts.Button
	}
	if opts.ClickCount > 0 {
		params["clickCount"] = opts.ClickCount
	}
	if opts.Timeout > 0 {
		params["timeout"] = opts.Timeout
	}
	if opts.Force {
		params["force"] = true
	}
	_, err := b.sendCommand(ctx, "page.click", params)
	if err != nil {
		return &ActionResponse{Success: false, Error: err.Error()}, err
	}
	result := &ActionResponse{Success: true}
	// Give the page time to react before any screenshot (WaitAfter is in ms).
	if opts.WaitAfter > 0 {
		time.Sleep(time.Duration(opts.WaitAfter) * time.Millisecond)
	}
	if opts.Screenshot {
		// Screenshot failures are deliberately ignored; the click succeeded.
		screenshot, _ := b.Screenshot(ctx, sessionID, ScreenshotOptions{FullPage: false})
		if screenshot != nil {
			result.Screenshot = screenshot.Data
		}
	}
	return result, nil
}
// Type sends keystrokes into the element matched by selector. When
// opts.Clear is set the field is first emptied via page.fill.
//
// Fix: a failure to clear is now propagated instead of being silently
// ignored — typing into an uncleared field would append to stale content.
func (b *PlaywrightBrowser) Type(ctx context.Context, sessionID, selector, text string, opts TypeOptions) (*ActionResponse, error) {
	params := map[string]interface{}{
		"sessionId": sessionID,
		"selector":  selector,
		"text":      text,
	}
	if opts.Delay > 0 {
		params["delay"] = opts.Delay
	}
	if opts.Timeout > 0 {
		params["timeout"] = opts.Timeout
	}
	if opts.Clear {
		if _, err := b.sendCommand(ctx, "page.fill", map[string]interface{}{
			"sessionId": sessionID,
			"selector":  selector,
			"value":     "",
		}); err != nil {
			return &ActionResponse{Success: false, Error: err.Error()}, err
		}
	}
	if _, err := b.sendCommand(ctx, "page.type", params); err != nil {
		return &ActionResponse{Success: false, Error: err.Error()}, err
	}
	return &ActionResponse{Success: true}, nil
}
// Fill sets the value of the element matched by selector in one shot.
func (b *PlaywrightBrowser) Fill(ctx context.Context, sessionID, selector, value string) (*ActionResponse, error) {
	_, err := b.sendCommand(ctx, "page.fill", map[string]interface{}{
		"sessionId": sessionID,
		"selector":  selector,
		"value":     value,
	})
	if err != nil {
		return &ActionResponse{Success: false, Error: err.Error()}, err
	}
	return &ActionResponse{Success: true}, nil
}
// Screenshot captures the page (or an element) as an image. The raw base64
// data is returned and, best-effort, also written to ScreenshotsDir and
// recorded on the session.
// NOTE(review): the saved file always gets a .png extension and the result
// always claims "image/png" even when opts.Format requests another type —
// confirm this is intended.
func (b *PlaywrightBrowser) Screenshot(ctx context.Context, sessionID string, opts ScreenshotOptions) (*ScreenshotResult, error) {
	params := map[string]interface{}{
		"sessionId": sessionID,
		"fullPage":  opts.FullPage,
	}
	if opts.Selector != "" {
		params["selector"] = opts.Selector
	}
	if opts.Quality > 0 {
		params["quality"] = opts.Quality
	}
	params["type"] = "png"
	if opts.Format != "" {
		params["type"] = opts.Format
	}
	resp, err := b.sendCommand(ctx, "page.screenshot", params)
	if err != nil {
		return nil, err
	}
	data, _ := resp["data"].(string)
	filename := fmt.Sprintf("%s/%s-%d.png", b.config.ScreenshotsDir, sessionID, time.Now().UnixNano())
	if decoded, err := base64.StdEncoding.DecodeString(data); err == nil {
		os.WriteFile(filename, decoded, 0644) // best-effort; write errors ignored
	}
	b.mu.Lock()
	if session, ok := b.sessions[sessionID]; ok {
		session.Screenshots = append(session.Screenshots, filename)
	}
	b.mu.Unlock()
	return &ScreenshotResult{
		Data:     data,
		Path:     filename,
		MimeType: "image/png",
	}, nil
}
// ExtractText returns the text content of the element matched by selector.
func (b *PlaywrightBrowser) ExtractText(ctx context.Context, sessionID, selector string) (string, error) {
	resp, err := b.sendCommand(ctx, "page.textContent", map[string]interface{}{
		"sessionId": sessionID,
		"selector":  selector,
	})
	if err != nil {
		return "", err
	}
	return getString(resp, "text"), nil
}
// ExtractHTML returns the inner HTML of the element matched by selector.
func (b *PlaywrightBrowser) ExtractHTML(ctx context.Context, sessionID, selector string) (string, error) {
	resp, err := b.sendCommand(ctx, "page.innerHTML", map[string]interface{}{
		"sessionId": sessionID,
		"selector":  selector,
	})
	if err != nil {
		return "", err
	}
	return getString(resp, "html"), nil
}
// WaitForSelector blocks until the selector reaches the requested state,
// forwarding the optional timeout and state settings to the sidecar.
func (b *PlaywrightBrowser) WaitForSelector(ctx context.Context, sessionID, selector string, opts WaitOptions) error {
	params := map[string]interface{}{
		"sessionId": sessionID,
		"selector":  selector,
	}
	if opts.Timeout > 0 {
		params["timeout"] = opts.Timeout
	}
	if opts.State != "" {
		params["state"] = opts.State
	}
	_, err := b.sendCommand(ctx, "page.waitForSelector", params)
	return err
}
// WaitForNavigation blocks until the page finishes a navigation, forwarding
// the optional timeout and load-state settings to the sidecar.
func (b *PlaywrightBrowser) WaitForNavigation(ctx context.Context, sessionID string, opts WaitOptions) error {
	params := map[string]interface{}{"sessionId": sessionID}
	if opts.Timeout > 0 {
		params["timeout"] = opts.Timeout
	}
	if opts.WaitUntil != "" {
		params["waitUntil"] = opts.WaitUntil
	}
	_, err := b.sendCommand(ctx, "page.waitForNavigation", params)
	return err
}
// Scroll scrolls the page (or a specific element) by the given deltas, or
// jumps to the top/bottom of the document. ToBottom/ToTop take precedence
// over Selector and X/Y, matching the original evaluation order.
func (b *PlaywrightBrowser) Scroll(ctx context.Context, sessionID string, opts ScrollOptions) (*ActionResponse, error) {
	script := fmt.Sprintf("window.scrollBy(%d, %d)", opts.X, opts.Y)
	if opts.Selector != "" {
		// Fix: JSON-encode the selector so quotes/backslashes cannot break
		// out of the generated JavaScript (it was previously interpolated
		// raw inside single quotes).
		quoted, err := json.Marshal(opts.Selector)
		if err != nil {
			return &ActionResponse{Success: false, Error: err.Error()}, err
		}
		script = fmt.Sprintf("document.querySelector(%s).scrollBy(%d, %d)", quoted, opts.X, opts.Y)
	}
	if opts.ToBottom {
		script = "window.scrollTo(0, document.body.scrollHeight)"
	}
	if opts.ToTop {
		script = "window.scrollTo(0, 0)"
	}
	if _, err := b.Evaluate(ctx, sessionID, script); err != nil {
		return &ActionResponse{Success: false, Error: err.Error()}, err
	}
	if opts.WaitAfter > 0 {
		time.Sleep(time.Duration(opts.WaitAfter) * time.Millisecond)
	}
	return &ActionResponse{Success: true}, nil
}
// Evaluate runs a JavaScript expression in the page and returns its result.
func (b *PlaywrightBrowser) Evaluate(ctx context.Context, sessionID, script string) (interface{}, error) {
	resp, err := b.sendCommand(ctx, "page.evaluate", map[string]interface{}{
		"sessionId":  sessionID,
		"expression": script,
	})
	if err != nil {
		return nil, err
	}
	return resp["result"], nil
}
// Select chooses the given values in the select element matched by selector.
func (b *PlaywrightBrowser) Select(ctx context.Context, sessionID, selector string, values []string) (*ActionResponse, error) {
	_, err := b.sendCommand(ctx, "page.selectOption", map[string]interface{}{
		"sessionId": sessionID,
		"selector":  selector,
		"values":    values,
	})
	if err != nil {
		return &ActionResponse{Success: false, Error: err.Error()}, err
	}
	return &ActionResponse{Success: true}, nil
}
// GetPageInfo fetches the current page's URL, title, and content snapshot.
func (b *PlaywrightBrowser) GetPageInfo(ctx context.Context, sessionID string) (*PageInfo, error) {
	resp, err := b.sendCommand(ctx, "page.info", map[string]interface{}{"sessionId": sessionID})
	if err != nil {
		return nil, err
	}
	info := &PageInfo{
		URL:     getString(resp, "url"),
		Title:   getString(resp, "title"),
		Content: getString(resp, "content"),
	}
	return info, nil
}
// PDF renders the current page to PDF and returns the decoded bytes.
func (b *PlaywrightBrowser) PDF(ctx context.Context, sessionID string, opts PDFOptions) ([]byte, error) {
	params := map[string]interface{}{"sessionId": sessionID}
	if opts.Format != "" {
		params["format"] = opts.Format
	}
	if opts.Landscape {
		params["landscape"] = true
	}
	if opts.PrintBackground {
		params["printBackground"] = true
	}
	resp, err := b.sendCommand(ctx, "page.pdf", params)
	if err != nil {
		return nil, err
	}
	encoded, _ := resp["data"].(string)
	return base64.StdEncoding.DecodeString(encoded)
}
// sendCommand POSTs a JSON {method, params} envelope to the browser server's
// /api/browser endpoint and decodes the JSON reply into a generic map.
// An application-level failure arrives as a top-level "error" string (with
// HTTP 200); it is surfaced as a Go error while the decoded map is still
// returned so callers can inspect partial results.
func (b *PlaywrightBrowser) sendCommand(ctx context.Context, method string, params map[string]interface{}) (map[string]interface{}, error) {
	body := map[string]interface{}{
		"method": method,
		"params": params,
	}
	jsonBody, err := json.Marshal(body)
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequestWithContext(ctx, "POST", b.serverURL+"/api/browser", strings.NewReader(string(jsonBody)))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	resp, err := b.client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	var result map[string]interface{}
	if err := json.Unmarshal(respBody, &result); err != nil {
		return nil, err
	}
	// NOTE(review): the HTTP status code is never checked — a non-200 reply
	// with a JSON body is treated like success; confirm that is intended.
	if errMsg, ok := result["error"].(string); ok && errMsg != "" {
		return result, errors.New(errMsg)
	}
	return result, nil
}
// getString returns the value stored under key when it is a string, or the
// empty string when the key is absent or holds another type.
func getString(m map[string]interface{}, key string) string {
	v, ok := m[key].(string)
	if !ok {
		return ""
	}
	return v
}
// SessionOptions configures a new browser session (context + page).
type SessionOptions struct {
	Headless    bool      // run without a visible window
	Viewport    *Viewport // initial viewport size; nil uses the server default
	UserAgent   string    // override the User-Agent when non-empty
	ProxyURL    string    // route page traffic through this proxy when non-empty
	RecordVideo bool      // ask the server to record the session
	BlockAds    bool      // request ad blocking — NOTE(review): not forwarded by executeMethod; confirm
}

// Viewport is a page viewport size in pixels.
type Viewport struct {
	Width  int `json:"width"`
	Height int `json:"height"`
}

// NavigateOptions tunes page.goto behaviour.
type NavigateOptions struct {
	Timeout    int    // navigation timeout in milliseconds; 0 = server default
	WaitUntil  string // readiness event forwarded to the browser server
	Screenshot bool   // capture a screenshot after navigating
}

// ClickOptions tunes page.click behaviour.
type ClickOptions struct {
	Button     string // mouse button name; empty = server default
	ClickCount int    // number of clicks; 0 = server default
	Timeout    int    // milliseconds; 0 = server default
	Force      bool   // bypass actionability checks on the server side
	WaitAfter  int    // delay after the action, in milliseconds
	Screenshot bool   // capture a screenshot after the click
}

// TypeOptions tunes page.type behaviour.
type TypeOptions struct {
	Delay   int  // per-keystroke delay in milliseconds
	Timeout int  // milliseconds; 0 = server default
	Clear   bool // clear the field before typing
}

// ScreenshotOptions tunes page.screenshot behaviour.
type ScreenshotOptions struct {
	FullPage bool   // capture the whole scrollable page
	Selector string // capture only the matched element when non-empty
	Format   string // image format, sent as "type" to the server
	Quality  int    // image quality for lossy formats
}

// ScreenshotResult is the server's reply to a screenshot request.
type ScreenshotResult struct {
	Data     string // image data (as returned by the server)
	Path     string // server-side file path, when available
	MimeType string
}

// WaitOptions tunes the waitForSelector / waitForNavigation calls.
type WaitOptions struct {
	Timeout   int    // milliseconds; 0 = server default
	State     string // element state to wait for (waitForSelector)
	WaitUntil string // readiness event to wait for (waitForNavigation)
}

// ScrollOptions tunes Scroll; ToBottom/ToTop take precedence over X/Y.
type ScrollOptions struct {
	X         int    // horizontal scroll delta
	Y         int    // vertical scroll delta
	Selector  string // scroll within the matched element instead of the window
	ToBottom  bool   // jump to the bottom of the page
	ToTop     bool   // jump to the top of the page
	WaitAfter int    // delay after scrolling, in milliseconds
}

// PageInfo is a snapshot of the current page returned by GetPageInfo.
type PageInfo struct {
	URL     string
	Title   string
	Content string
}

// PDFOptions tunes page.pdf rendering.
type PDFOptions struct {
	Format          string // paper format; empty = server default
	Landscape       bool
	PrintBackground bool
}

View File

@@ -0,0 +1,555 @@
package browser
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"sync"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/logger"
)
// BrowserServer exposes the PlaywrightBrowser over an HTTP API and tracks
// every live session together with a log of the actions performed in it.
type BrowserServer struct {
	browser  *PlaywrightBrowser
	sessions map[string]*ManagedSession // keyed by session ID; guarded by mu
	mu       sync.RWMutex
	config   ServerConfig
}

// ServerConfig holds BrowserServer tunables; zero values are replaced with
// defaults in NewBrowserServer.
type ServerConfig struct {
	Port            int
	MaxSessions     int           // upper bound on concurrently open sessions
	SessionTimeout  time.Duration // idle time after which a session is reaped
	CleanupInterval time.Duration // how often the cleanup loop runs
}

// ManagedSession decorates a BrowserSession with server-side bookkeeping.
type ManagedSession struct {
	*BrowserSession
	LastActive time.Time   // refreshed on every command touching the session
	Actions    []ActionLog // chronological record of executed commands
}

// ActionLog records one command executed within a session.
type ActionLog struct {
	Action    string    `json:"action"`
	Params    string    `json:"params"` // JSON-encoded request parameters
	Success   bool      `json:"success"`
	Error     string    `json:"error,omitempty"`
	Duration  int64     `json:"durationMs"`
	Timestamp time.Time `json:"timestamp"`
}

// BrowserRequest is the generic envelope accepted by POST /api/browser.
type BrowserRequest struct {
	Method string                 `json:"method"`
	Params map[string]interface{} `json:"params"`
}
// NewBrowserServer builds a BrowserServer around a headless Playwright
// browser, filling unset config fields with defaults: port 3050, at most 20
// concurrent sessions, a 30-minute idle timeout, and a 5-minute cleanup tick.
func NewBrowserServer(cfg ServerConfig) *BrowserServer {
	if cfg.Port == 0 {
		cfg.Port = 3050
	}
	if cfg.MaxSessions == 0 {
		cfg.MaxSessions = 20
	}
	if cfg.SessionTimeout == 0 {
		cfg.SessionTimeout = 30 * time.Minute
	}
	if cfg.CleanupInterval == 0 {
		cfg.CleanupInterval = 5 * time.Minute
	}
	srv := &BrowserServer{
		sessions: make(map[string]*ManagedSession),
		config:   cfg,
	}
	srv.browser = NewPlaywrightBrowser(Config{
		DefaultTimeout: 30 * time.Second,
		Headless:       true,
	})
	return srv
}
// Start launches the session-cleanup loop in the background, wires up the
// HTTP routes, and blocks serving on the configured port. The cleanup
// goroutine exits when ctx is cancelled.
func (s *BrowserServer) Start(ctx context.Context) error {
	go s.cleanupLoop(ctx)

	app := fiber.New(fiber.Config{
		BodyLimit:    50 * 1024 * 1024,
		ReadTimeout:  2 * time.Minute,
		WriteTimeout: 2 * time.Minute,
	})
	app.Use(logger.New())
	app.Use(cors.New())

	health := func(c *fiber.Ctx) error {
		return c.JSON(fiber.Map{"status": "ok", "sessions": len(s.sessions)})
	}
	app.Get("/health", health)
	app.Post("/api/browser", s.handleBrowserCommand)
	app.Post("/api/session/new", s.handleNewSession)
	app.Delete("/api/session/:id", s.handleCloseSession)
	app.Get("/api/session/:id", s.handleGetSession)
	app.Get("/api/sessions", s.handleListSessions)
	app.Post("/api/action", s.handleAction)

	addr := fmt.Sprintf(":%d", s.config.Port)
	log.Printf("[BrowserServer] Starting on port %d", s.config.Port)
	return app.Listen(addr)
}
// handleBrowserCommand is the generic command entry point used by
// PlaywrightBrowser.sendCommand. It dispatches req.Method via executeMethod,
// refreshes the session's LastActive stamp, and appends an ActionLog entry
// (duration + outcome) when the command targets a known session.
func (s *BrowserServer) handleBrowserCommand(c *fiber.Ctx) error {
	var req BrowserRequest
	if err := c.BodyParser(&req); err != nil {
		return c.Status(400).JSON(fiber.Map{"error": "Invalid request"})
	}
	// Each command gets its own 60s budget, detached from the HTTP context.
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()
	sessionID, _ := req.Params["sessionId"].(string)
	// Touch the session so the cleanup loop does not reap it mid-command.
	s.mu.Lock()
	if session, ok := s.sessions[sessionID]; ok {
		session.LastActive = time.Now()
	}
	s.mu.Unlock()
	start := time.Now()
	result, err := s.executeMethod(ctx, req.Method, req.Params)
	// Record the action under the lock; the session may have been created by
	// executeMethod itself (browser.newContext) or removed meanwhile.
	s.mu.Lock()
	if session, ok := s.sessions[sessionID]; ok {
		paramsJSON, _ := json.Marshal(req.Params)
		session.Actions = append(session.Actions, ActionLog{
			Action:    req.Method,
			Params:    string(paramsJSON),
			Success:   err == nil,
			Error:     errToString(err),
			Duration:  time.Since(start).Milliseconds(),
			Timestamp: time.Now(),
		})
	}
	s.mu.Unlock()
	// Errors are reported with HTTP 200 and success=false; sendCommand on the
	// client side re-raises them from the "error" field.
	if err != nil {
		return c.JSON(fiber.Map{
			"success": false,
			"error":   err.Error(),
		})
	}
	return c.JSON(result)
}
// executeMethod dispatches a single /api/browser command to the matching
// PlaywrightBrowser call. Method names mirror the client protocol
// ("browser.newContext", "page.goto", ...). Parameters are extracted from the
// untyped JSON map via the getBool/getInt/getString helpers, which fall back
// to zero values when a field is absent or mistyped.
func (s *BrowserServer) executeMethod(ctx context.Context, method string, params map[string]interface{}) (map[string]interface{}, error) {
	sessionID, _ := params["sessionId"].(string)
	switch method {
	case "browser.newContext":
		opts := SessionOptions{
			Headless: getBool(params, "headless"),
		}
		if viewport, ok := params["viewport"].(map[string]interface{}); ok {
			opts.Viewport = &Viewport{
				Width:  getInt(viewport, "width"),
				Height: getInt(viewport, "height"),
			}
		}
		if ua, ok := params["userAgent"].(string); ok {
			opts.UserAgent = ua
		}
		if proxy, ok := params["proxy"].(string); ok {
			opts.ProxyURL = proxy
		}
		if rv, ok := params["recordVideo"].(map[string]interface{}); ok {
			// Only the presence of the recordVideo object matters; its
			// contents are currently ignored.
			_ = rv
			opts.RecordVideo = true
		}
		session, err := s.browser.NewSession(ctx, opts)
		if err != nil {
			return nil, err
		}
		// Register the new session so cleanup and action logging can find it.
		s.mu.Lock()
		s.sessions[session.ID] = &ManagedSession{
			BrowserSession: session,
			LastActive:     time.Now(),
			Actions:        make([]ActionLog, 0),
		}
		s.mu.Unlock()
		return map[string]interface{}{
			"sessionId": session.ID,
			"contextId": session.ContextID,
			"pageId":    session.PageID,
		}, nil
	case "browser.closeContext":
		// The map entry is dropped even when the browser-side close fails.
		err := s.browser.CloseSession(ctx, sessionID)
		s.mu.Lock()
		delete(s.sessions, sessionID)
		s.mu.Unlock()
		return map[string]interface{}{"success": err == nil}, err
	case "page.goto":
		url, _ := params["url"].(string)
		opts := NavigateOptions{
			Timeout:   getInt(params, "timeout"),
			WaitUntil: getString(params, "waitUntil"),
		}
		result, err := s.browser.Navigate(ctx, sessionID, url, opts)
		if err != nil {
			return nil, err
		}
		return map[string]interface{}{
			"success": result.Success,
			"url":     result.PageURL,
			"title":   result.PageTitle,
		}, nil
	case "page.click":
		selector, _ := params["selector"].(string)
		opts := ClickOptions{
			Button:     getString(params, "button"),
			ClickCount: getInt(params, "clickCount"),
			Timeout:    getInt(params, "timeout"),
			Force:      getBool(params, "force"),
		}
		result, err := s.browser.Click(ctx, sessionID, selector, opts)
		if err != nil {
			return nil, err
		}
		return map[string]interface{}{
			"success":    result.Success,
			"screenshot": result.Screenshot,
		}, nil
	case "page.type":
		selector, _ := params["selector"].(string)
		text, _ := params["text"].(string)
		opts := TypeOptions{
			Delay:   getInt(params, "delay"),
			Timeout: getInt(params, "timeout"),
		}
		_, err := s.browser.Type(ctx, sessionID, selector, text, opts)
		return map[string]interface{}{"success": err == nil}, err
	case "page.fill":
		selector, _ := params["selector"].(string)
		value, _ := params["value"].(string)
		_, err := s.browser.Fill(ctx, sessionID, selector, value)
		return map[string]interface{}{"success": err == nil}, err
	case "page.screenshot":
		opts := ScreenshotOptions{
			FullPage: getBool(params, "fullPage"),
			Selector: getString(params, "selector"),
			// The wire field is "type" (Playwright naming); the struct field
			// is Format.
			Format:  getString(params, "type"),
			Quality: getInt(params, "quality"),
		}
		result, err := s.browser.Screenshot(ctx, sessionID, opts)
		if err != nil {
			return nil, err
		}
		return map[string]interface{}{
			"data": result.Data,
			"path": result.Path,
		}, nil
	case "page.textContent":
		selector, _ := params["selector"].(string)
		text, err := s.browser.ExtractText(ctx, sessionID, selector)
		return map[string]interface{}{"text": text}, err
	case "page.innerHTML":
		selector, _ := params["selector"].(string)
		html, err := s.browser.ExtractHTML(ctx, sessionID, selector)
		return map[string]interface{}{"html": html}, err
	case "page.waitForSelector":
		selector, _ := params["selector"].(string)
		opts := WaitOptions{
			Timeout: getInt(params, "timeout"),
			State:   getString(params, "state"),
		}
		err := s.browser.WaitForSelector(ctx, sessionID, selector, opts)
		return map[string]interface{}{"success": err == nil}, err
	case "page.waitForNavigation":
		opts := WaitOptions{
			Timeout:   getInt(params, "timeout"),
			WaitUntil: getString(params, "waitUntil"),
		}
		err := s.browser.WaitForNavigation(ctx, sessionID, opts)
		return map[string]interface{}{"success": err == nil}, err
	case "page.evaluate":
		expression, _ := params["expression"].(string)
		result, err := s.browser.Evaluate(ctx, sessionID, expression)
		return map[string]interface{}{"result": result}, err
	case "page.selectOption":
		selector, _ := params["selector"].(string)
		values := getStringArray(params, "values")
		_, err := s.browser.Select(ctx, sessionID, selector, values)
		return map[string]interface{}{"success": err == nil}, err
	case "page.info":
		info, err := s.browser.GetPageInfo(ctx, sessionID)
		if err != nil {
			return nil, err
		}
		return map[string]interface{}{
			"url":     info.URL,
			"title":   info.Title,
			"content": info.Content,
		}, nil
	case "page.pdf":
		opts := PDFOptions{
			Format:          getString(params, "format"),
			Landscape:       getBool(params, "landscape"),
			PrintBackground: getBool(params, "printBackground"),
		}
		data, err := s.browser.PDF(ctx, sessionID, opts)
		if err != nil {
			return nil, err
		}
		// data is []byte; it is base64-encoded by encoding/json when the
		// response map is marshalled.
		return map[string]interface{}{
			"data": data,
		}, nil
	default:
		return nil, fmt.Errorf("unknown method: %s", method)
	}
}
// handleNewSession creates a browser session via POST /api/session/new.
// The body may carry headless/viewport/userAgent/proxyUrl options; an
// unparsable body falls back to a headless session. Responds 429 when the
// configured MaxSessions limit is reached.
func (s *BrowserServer) handleNewSession(c *fiber.Ctx) error {
	var req struct {
		Headless  bool      `json:"headless"`
		Viewport  *Viewport `json:"viewport,omitempty"`
		UserAgent string    `json:"userAgent,omitempty"`
		ProxyURL  string    `json:"proxyUrl,omitempty"`
	}
	if err := c.BodyParser(&req); err != nil {
		// Best-effort parsing: an empty/invalid body means "default headless".
		req.Headless = true
	}
	// Fast pre-check so we do not pay session startup cost when clearly full.
	s.mu.RLock()
	full := len(s.sessions) >= s.config.MaxSessions
	s.mu.RUnlock()
	if full {
		return c.Status(http.StatusTooManyRequests).JSON(fiber.Map{
			"error": "Maximum sessions limit reached",
		})
	}
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	session, err := s.browser.NewSession(ctx, SessionOptions{
		Headless:  req.Headless,
		Viewport:  req.Viewport,
		UserAgent: req.UserAgent,
		ProxyURL:  req.ProxyURL,
	})
	if err != nil {
		return c.Status(500).JSON(fiber.Map{"error": err.Error()})
	}
	// Re-check under the write lock before inserting: concurrent requests may
	// have filled the pool while the session was starting (the pre-check
	// alone is racy and previously allowed the limit to be exceeded).
	s.mu.Lock()
	if len(s.sessions) >= s.config.MaxSessions {
		s.mu.Unlock()
		closeCtx, closeCancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer closeCancel()
		// Best-effort cleanup of the session we can no longer admit.
		_ = s.browser.CloseSession(closeCtx, session.ID)
		return c.Status(http.StatusTooManyRequests).JSON(fiber.Map{
			"error": "Maximum sessions limit reached",
		})
	}
	s.sessions[session.ID] = &ManagedSession{
		BrowserSession: session,
		LastActive:     time.Now(),
		Actions:        make([]ActionLog, 0),
	}
	s.mu.Unlock()
	return c.JSON(fiber.Map{
		"sessionId": session.ID,
		"contextId": session.ContextID,
		"pageId":    session.PageID,
	})
}
// handleCloseSession tears down a session via DELETE /api/session/:id and
// removes it from the managed map. A failed close yields 404.
func (s *BrowserServer) handleCloseSession(c *fiber.Ctx) error {
	sessionID := c.Params("id")
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	if err := s.browser.CloseSession(ctx, sessionID); err != nil {
		return c.Status(404).JSON(fiber.Map{"error": err.Error()})
	}
	s.mu.Lock()
	delete(s.sessions, sessionID)
	s.mu.Unlock()
	return c.JSON(fiber.Map{"success": true})
}
// handleGetSession reports metadata for one managed session: timestamps,
// captured screenshots, and the number of recorded actions.
func (s *BrowserServer) handleGetSession(c *fiber.Ctx) error {
	id := c.Params("id")
	s.mu.RLock()
	ms, found := s.sessions[id]
	s.mu.RUnlock()
	if !found {
		return c.Status(404).JSON(fiber.Map{"error": "Session not found"})
	}
	return c.JSON(fiber.Map{
		"sessionId":   ms.ID,
		"createdAt":   ms.CreatedAt,
		"lastActive":  ms.LastActive,
		"screenshots": ms.Screenshots,
		"actions":     len(ms.Actions),
	})
}
// handleListSessions returns a summary of every live session plus a count.
func (s *BrowserServer) handleListSessions(c *fiber.Ctx) error {
	s.mu.RLock()
	defer s.mu.RUnlock()
	out := make([]map[string]interface{}, 0, len(s.sessions))
	for _, ms := range s.sessions {
		out = append(out, map[string]interface{}{
			"sessionId":  ms.ID,
			"createdAt":  ms.CreatedAt,
			"lastActive": ms.LastActive,
			"actions":    len(ms.Actions),
		})
	}
	return c.JSON(fiber.Map{"sessions": out, "count": len(out)})
}
// handleAction is a simplified alternative to /api/browser: a flat request
// shape (action + selector/url/value) covering the most common operations.
// Unknown actions yield 400; execution failures yield 500.
func (s *BrowserServer) handleAction(c *fiber.Ctx) error {
	var req struct {
		SessionID  string `json:"sessionId"`
		Action     string `json:"action"`
		Selector   string `json:"selector,omitempty"`
		URL        string `json:"url,omitempty"`
		Value      string `json:"value,omitempty"`
		Screenshot bool   `json:"screenshot"`
	}
	if err := c.BodyParser(&req); err != nil {
		return c.Status(400).JSON(fiber.Map{"error": "Invalid request"})
	}
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()
	// Touch the session so the cleanup loop does not reap it mid-action.
	s.mu.Lock()
	if session, ok := s.sessions[req.SessionID]; ok {
		session.LastActive = time.Now()
	}
	s.mu.Unlock()
	var result *ActionResponse
	var err error
	switch req.Action {
	case "navigate":
		result, err = s.browser.Navigate(ctx, req.SessionID, req.URL, NavigateOptions{Screenshot: req.Screenshot})
	case "click":
		result, err = s.browser.Click(ctx, req.SessionID, req.Selector, ClickOptions{Screenshot: req.Screenshot})
	case "type":
		result, err = s.browser.Type(ctx, req.SessionID, req.Selector, req.Value, TypeOptions{})
	case "fill":
		result, err = s.browser.Fill(ctx, req.SessionID, req.Selector, req.Value)
	case "screenshot":
		var screenshot *ScreenshotResult
		screenshot, err = s.browser.Screenshot(ctx, req.SessionID, ScreenshotOptions{})
		if err == nil {
			result = &ActionResponse{Success: true, Screenshot: screenshot.Data}
		}
	case "extract":
		var text string
		text, err = s.browser.ExtractText(ctx, req.SessionID, req.Selector)
		result = &ActionResponse{Success: err == nil, Data: text}
	default:
		return c.Status(400).JSON(fiber.Map{"error": "Unknown action: " + req.Action})
	}
	if err != nil {
		return c.Status(500).JSON(fiber.Map{"error": err.Error(), "success": false})
	}
	return c.JSON(result)
}
// cleanupLoop periodically evicts idle sessions until ctx is cancelled.
func (s *BrowserServer) cleanupLoop(ctx context.Context) {
	ticker := time.NewTicker(s.config.CleanupInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			s.cleanupExpiredSessions()
		case <-ctx.Done():
			return
		}
	}
}
// cleanupExpiredSessions closes and drops sessions idle longer than
// SessionTimeout. Expired sessions are collected and removed under the lock
// but closed outside it, so the slow CloseSession network calls (up to 5s
// each) no longer block every request handler waiting on s.mu.
func (s *BrowserServer) cleanupExpiredSessions() {
	now := time.Now()
	var expired []string
	s.mu.Lock()
	for sessionID, session := range s.sessions {
		if now.Sub(session.LastActive) > s.config.SessionTimeout {
			expired = append(expired, sessionID)
			delete(s.sessions, sessionID)
		}
	}
	s.mu.Unlock()
	for _, sessionID := range expired {
		log.Printf("[BrowserServer] Cleaning up expired session: %s", sessionID)
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		// Best-effort: the session is already removed from the map.
		_ = s.browser.CloseSession(ctx, sessionID)
		cancel()
	}
}
// errToString renders err as a string, mapping nil to "".
func errToString(err error) string {
	if err != nil {
		return err.Error()
	}
	return ""
}
// getBool reads m[key] as a bool; a missing key or any other type is false.
func getBool(m map[string]interface{}, key string) bool {
	b, ok := m[key].(bool)
	return ok && b
}
// getInt reads m[key] as an int. JSON numbers decode as float64, so both
// float64 and native int values are accepted; anything else yields 0.
func getInt(m map[string]interface{}, key string) int {
	switch v := m[key].(type) {
	case float64:
		return int(v)
	case int:
		return v
	default:
		return 0
	}
}
// getStringArray extracts m[key] as a []string. Non-string elements become
// "", and a missing or non-slice value yields nil.
func getStringArray(m map[string]interface{}, key string) []string {
	raw, ok := m[key].([]interface{})
	if !ok {
		return nil
	}
	out := make([]string, len(raw))
	for i := range raw {
		if s, isStr := raw[i].(string); isStr {
			out[i] = s
		}
	}
	return out
}

View File

@@ -0,0 +1,738 @@
package computer
import (
"context"
"encoding/json"
"errors"
"fmt"
"sync"
"time"
"github.com/gooseek/backend/internal/computer/connectors"
"github.com/gooseek/backend/internal/llm"
"github.com/google/uuid"
)
// ComputerConfig holds all tunables for the long-running task subsystem;
// see DefaultConfig for production values.
type ComputerConfig struct {
	MaxParallelTasks    int           // parallelism handed to the Executor
	MaxSubTasks         int           // cap on subtasks per task
	TaskTimeout         time.Duration // overall task wall-clock limit
	SubTaskTimeout      time.Duration // also used as the sandbox timeout
	TotalBudget         float64       // default per-task cost ceiling (USD)
	EnableSandbox       bool          // create a SandboxManager when true
	EnableScheduling    bool          // create a Scheduler when true
	EnableBrowser       bool
	SandboxImage        string // container image used by the sandbox
	ArtifactStorageURL  string
	BrowserServerURL    string
	CheckpointStorePath string
	MaxConcurrentTasks  int
	HeartbeatInterval   time.Duration
	CheckpointInterval  time.Duration
}
// DefaultConfig returns the production defaults for the computer subsystem:
// sandboxing, scheduling, and browser automation enabled, a $100 total
// budget, and an effectively unbounded (one year) task timeout.
func DefaultConfig() ComputerConfig {
	cfg := ComputerConfig{
		MaxParallelTasks:    10,
		MaxSubTasks:         100,
		TaskTimeout:         365 * 24 * time.Hour,
		SubTaskTimeout:      2 * time.Hour,
		TotalBudget:         100.0,
		SandboxImage:        "gooseek/sandbox:latest",
		BrowserServerURL:    "http://browser-svc:3050",
		CheckpointStorePath: "/data/checkpoints",
		MaxConcurrentTasks:  50,
		HeartbeatInterval:   30 * time.Second,
		CheckpointInterval:  15 * time.Minute,
	}
	cfg.EnableSandbox = true
	cfg.EnableScheduling = true
	cfg.EnableBrowser = true
	return cfg
}
// GetDurationConfig resolves the runtime limits for a duration mode, falling
// back to DurationMedium when the mode is unknown. It returns, in order:
// the total wall-clock budget, the checkpoint interval, the heartbeat
// interval, and the maximum iteration count.
func GetDurationConfig(mode DurationMode) (maxDuration, checkpointFreq, heartbeatFreq time.Duration, maxIter int) {
	cfg, ok := DurationModeConfigs[mode]
	if !ok {
		cfg = DurationModeConfigs[DurationMedium]
	}
	return cfg.MaxDuration, cfg.CheckpointFreq, cfg.HeartbeatFreq, cfg.MaxIterations
}
// Dependencies bundles the external collaborators a Computer needs.
type Dependencies struct {
	Registry     *llm.ModelRegistry
	TaskRepo     TaskRepository
	MemoryRepo   MemoryRepository
	ArtifactRepo ArtifactRepository
}

// TaskRepository persists ComputerTask records.
type TaskRepository interface {
	Create(ctx context.Context, task *ComputerTask) error
	Update(ctx context.Context, task *ComputerTask) error
	GetByID(ctx context.Context, id string) (*ComputerTask, error)
	GetByUserID(ctx context.Context, userID string, limit, offset int) ([]ComputerTask, error)
	// GetScheduled returns tasks awaiting scheduled execution.
	GetScheduled(ctx context.Context) ([]ComputerTask, error)
	Delete(ctx context.Context, id string) error
}

// MemoryRepository stores per-user / per-task memory entries.
type MemoryRepository interface {
	Store(ctx context.Context, entry *MemoryEntry) error
	GetByUser(ctx context.Context, userID string, limit int) ([]MemoryEntry, error)
	GetByTask(ctx context.Context, taskID string) ([]MemoryEntry, error)
	Search(ctx context.Context, userID, query string, limit int) ([]MemoryEntry, error)
	Delete(ctx context.Context, id string) error
}

// ArtifactRepository stores artifacts produced by subtasks.
type ArtifactRepository interface {
	Create(ctx context.Context, artifact *Artifact) error
	GetByID(ctx context.Context, id string) (*Artifact, error)
	GetByTaskID(ctx context.Context, taskID string) ([]Artifact, error)
	Delete(ctx context.Context, id string) error
}
// Computer orchestrates long-running agent tasks: it plans them, executes
// subtask waves, checkpoints progress, and streams TaskEvents to listeners.
type Computer struct {
	cfg        ComputerConfig
	planner    *Planner
	router     *Router
	executor   *Executor
	sandbox    *SandboxManager // nil unless cfg.EnableSandbox
	memory     *MemoryStore
	scheduler  *Scheduler // nil unless cfg.EnableScheduling
	connectors *connectors.ConnectorHub
	registry   *llm.ModelRegistry
	taskRepo   TaskRepository
	eventBus   *EventBus
	mu         sync.RWMutex             // guards tasks
	tasks      map[string]*ComputerTask // in-memory cache of live tasks
}
// NewComputer wires together the planner, router, executor, memory store,
// and connector hub. The sandbox manager and scheduler are only created when
// the corresponding flags in cfg are enabled.
func NewComputer(cfg ComputerConfig, deps Dependencies) *Computer {
	comp := &Computer{
		cfg:        cfg,
		registry:   deps.Registry,
		taskRepo:   deps.TaskRepo,
		eventBus:   NewEventBus(),
		tasks:      make(map[string]*ComputerTask),
		planner:    NewPlanner(deps.Registry),
		router:     NewRouter(deps.Registry),
		memory:     NewMemoryStore(deps.MemoryRepo),
		connectors: connectors.NewConnectorHub(),
	}
	comp.executor = NewExecutor(comp.router, cfg.MaxParallelTasks)
	if cfg.EnableSandbox {
		comp.sandbox = NewSandboxManager(SandboxConfig{
			Image:   cfg.SandboxImage,
			Timeout: cfg.SubTaskTimeout,
		})
		comp.executor.SetSandbox(comp.sandbox)
	}
	if cfg.EnableScheduling {
		comp.scheduler = NewScheduler(deps.TaskRepo, comp)
	}
	return comp
}
// Execute is the main entry point: it builds a ComputerTask for query,
// persists it, announces EventTaskCreated, and runs it via the
// checkpoint-aware executor. With opts.Async the run happens on a fresh
// background goroutine (detached from ctx) and the pending task is returned
// immediately. opts.ResumeFromID short-circuits into checkpoint resumption.
func (c *Computer) Execute(ctx context.Context, userID, query string, opts ExecuteOptions) (*ComputerTask, error) {
	if opts.ResumeFromID != "" {
		return c.resumeFromCheckpoint(ctx, opts.ResumeFromID, opts)
	}
	durationMode := opts.DurationMode
	if durationMode == "" {
		durationMode = DurationMedium
	}
	maxDuration, _, _, maxIter := GetDurationConfig(durationMode)
	task := &ComputerTask{
		ID:            uuid.New().String(),
		UserID:        userID,
		Query:         query,
		Status:        StatusPending,
		Memory:        make(map[string]interface{}),
		CreatedAt:     time.Now(),
		UpdatedAt:     time.Now(),
		DurationMode:  durationMode,
		MaxDuration:   maxDuration,
		MaxIterations: maxIter,
		Priority:      opts.Priority,
	}
	if opts.Priority == "" {
		task.Priority = PriorityNormal
	}
	if opts.ResourceLimits != nil {
		task.ResourceLimits = opts.ResourceLimits
	}
	if opts.Schedule != nil {
		// Scheduled tasks are stored now and launched later by the Scheduler.
		task.Schedule = opts.Schedule
		task.Status = StatusScheduled
	}
	if opts.Context != nil {
		// Caller-provided context replaces (not merges into) the fresh memory map.
		task.Memory = opts.Context
	}
	estimatedEnd := time.Now().Add(maxDuration)
	task.EstimatedEnd = &estimatedEnd
	if err := c.taskRepo.Create(ctx, task); err != nil {
		return nil, fmt.Errorf("failed to create task: %w", err)
	}
	c.mu.Lock()
	c.tasks[task.ID] = task
	c.mu.Unlock()
	c.emitEvent(TaskEvent{
		Type:      EventTaskCreated,
		TaskID:    task.ID,
		Status:    task.Status,
		Message:   fmt.Sprintf("Task created (mode: %s, max duration: %v)", durationMode, maxDuration),
		Timestamp: time.Now(),
		Data: map[string]interface{}{
			"durationMode":  durationMode,
			"maxDuration":   maxDuration.String(),
			"maxIterations": maxIter,
		},
	})
	if opts.Async {
		go c.executeTaskWithCheckpoints(context.Background(), task, opts)
		return task, nil
	}
	return c.executeTaskWithCheckpoints(ctx, task, opts)
}
// resumeFromCheckpoint reloads a previously checkpointed task and restarts
// execution from the recorded wave position. checkpointID is actually the
// ID of the task to resume (it is looked up as a task ID). Fails when the
// task has no checkpoint attached.
func (c *Computer) resumeFromCheckpoint(ctx context.Context, checkpointID string, opts ExecuteOptions) (*ComputerTask, error) {
	task, err := c.taskRepo.GetByID(ctx, checkpointID)
	if err != nil {
		return nil, fmt.Errorf("task not found: %w", err)
	}
	if task.Checkpoint == nil {
		return nil, errors.New("no checkpoint found for this task")
	}
	task.Status = StatusExecuting
	now := time.Now()
	task.ResumedAt = &now
	task.UpdatedAt = now
	c.emitEvent(TaskEvent{
		Type:      EventResumed,
		TaskID:    task.ID,
		Status:    task.Status,
		Message:   fmt.Sprintf("Resumed from checkpoint (wave: %d, subtask: %d)", task.Checkpoint.WaveIndex, task.Checkpoint.SubTaskIndex),
		Progress:  task.Checkpoint.Progress,
		Timestamp: time.Now(),
	})
	c.mu.Lock()
	c.tasks[task.ID] = task
	c.mu.Unlock()
	if opts.Async {
		go c.executeTaskWithCheckpoints(context.Background(), task, opts)
		return task, nil
	}
	return c.executeTaskWithCheckpoints(ctx, task, opts)
}
// executeTask is a thin wrapper over the checkpoint-aware executor, kept so
// existing call sites (e.g. Resume) need not change.
func (c *Computer) executeTask(ctx context.Context, task *ComputerTask, opts ExecuteOptions) (*ComputerTask, error) {
	return c.executeTaskWithCheckpoints(ctx, task, opts)
}
// executeTaskWithCheckpoints runs a task end to end: plan (unless resuming
// with an existing plan), then execute the plan's waves in order, emitting
// progress/heartbeat events and saving checkpoints so the task can be
// resumed after interruption, budget exhaustion, or timeout.
//
// NOTE(review): the heartbeat goroutine below reads and writes task fields
// (HeartbeatAt, Progress, TotalCost) concurrently with the main loop without
// holding c.mu — this looks like a data race; confirm and guard if so.
// NOTE(review): checkpointTicker is only sampled via a non-blocking select
// at the top of each wave, so "periodic" checkpoints fire at most once per
// wave boundary, not on the configured interval.
func (c *Computer) executeTaskWithCheckpoints(ctx context.Context, task *ComputerTask, opts ExecuteOptions) (*ComputerTask, error) {
	maxDuration, checkpointFreq, heartbeatFreq, _ := GetDurationConfig(task.DurationMode)
	if opts.Timeout > 0 {
		// An explicit per-call timeout (in seconds) overrides the mode default.
		maxDuration = time.Duration(opts.Timeout) * time.Second
	}
	ctx, cancel := context.WithTimeout(ctx, maxDuration)
	defer cancel()
	// Budget resolution order: global config < opts.MaxCost < task limits.
	budget := c.cfg.TotalBudget
	if opts.MaxCost > 0 {
		budget = opts.MaxCost
	}
	if task.ResourceLimits != nil && task.ResourceLimits.MaxTotalCost > 0 {
		budget = task.ResourceLimits.MaxTotalCost
	}
	startWave := 0
	if task.Checkpoint != nil {
		// Resume from where the checkpoint left off and restore its memory.
		startWave = task.Checkpoint.WaveIndex
		for k, v := range task.Checkpoint.Memory {
			task.Memory[k] = v
		}
	}
	if task.Plan == nil {
		// Fresh task: produce a plan from the query plus the combined
		// user-level and task-level memory context.
		task.Status = StatusPlanning
		task.UpdatedAt = time.Now()
		c.updateTask(ctx, task)
		c.emitEvent(TaskEvent{
			Type:      EventTaskStarted,
			TaskID:    task.ID,
			Status:    StatusPlanning,
			Message:   "Planning task execution",
			Timestamp: time.Now(),
		})
		userMemory, _ := c.memory.GetUserContext(ctx, task.UserID)
		memoryContext := make(map[string]interface{})
		for k, v := range userMemory {
			memoryContext[k] = v
		}
		// Task memory wins over user memory on key collisions.
		for k, v := range task.Memory {
			memoryContext[k] = v
		}
		plan, err := c.planner.Plan(ctx, task.Query, memoryContext)
		if err != nil {
			task.Status = StatusFailed
			task.Error = fmt.Sprintf("Planning failed: %v", err)
			task.UpdatedAt = time.Now()
			c.updateTask(ctx, task)
			c.emitEvent(TaskEvent{
				Type:      EventTaskFailed,
				TaskID:    task.ID,
				Status:    StatusFailed,
				Message:   task.Error,
				Timestamp: time.Now(),
			})
			return task, err
		}
		task.Plan = plan
		task.SubTasks = plan.SubTasks
	}
	task.Status = StatusLongRunning
	task.UpdatedAt = time.Now()
	c.updateTask(ctx, task)
	c.emitEvent(TaskEvent{
		Type:     EventTaskProgress,
		TaskID:   task.ID,
		Status:   StatusLongRunning,
		Progress: 10,
		Message:  fmt.Sprintf("Executing %d subtasks (long-running mode)", len(task.Plan.SubTasks)),
		Data: map[string]interface{}{
			"plan":           task.Plan,
			"durationMode":   task.DurationMode,
			"checkpointFreq": checkpointFreq.String(),
		},
		Timestamp: time.Now(),
	})
	heartbeatTicker := time.NewTicker(heartbeatFreq)
	defer heartbeatTicker.Stop()
	checkpointTicker := time.NewTicker(checkpointFreq)
	defer checkpointTicker.Stop()
	// Heartbeat publisher; exits when ctx is cancelled (including the normal
	// deferred cancel at function return).
	go func() {
		for {
			select {
			case <-ctx.Done():
				return
			case <-heartbeatTicker.C:
				now := time.Now()
				task.HeartbeatAt = &now
				c.emitEvent(TaskEvent{
					Type:     EventHeartbeat,
					TaskID:   task.ID,
					Progress: task.Progress,
					Message:  fmt.Sprintf("Heartbeat: %d%% complete, cost: $%.4f", task.Progress, task.TotalCost),
					Data: map[string]interface{}{
						"runtime": time.Since(task.CreatedAt).String(),
						"cost":    task.TotalCost,
					},
					Timestamp: now,
				})
			}
		}
	}()
	totalSubTasks := len(task.Plan.ExecutionOrder)
	for waveIdx := startWave; waveIdx < totalSubTasks; waveIdx++ {
		// Non-blocking poll: bail out on timeout/cancel, checkpoint on the
		// periodic tick, otherwise fall through and run the wave.
		select {
		case <-ctx.Done():
			c.saveCheckpoint(task, waveIdx, 0, "context_timeout")
			return task, ctx.Err()
		case <-checkpointTicker.C:
			c.saveCheckpoint(task, waveIdx, 0, "periodic")
		default:
		}
		if budget > 0 && task.TotalCost >= budget {
			// Out of budget: checkpoint and pause rather than fail, so the
			// task can be resumed with a bigger budget later.
			c.saveCheckpoint(task, waveIdx, 0, "budget_exceeded")
			task.Status = StatusPaused
			task.Message = fmt.Sprintf("Paused: budget exceeded ($%.2f / $%.2f)", task.TotalCost, budget)
			c.updateTask(ctx, task)
			return task, nil
		}
		// Materialize the wave: resolve subtask IDs to their SubTask values.
		wave := task.Plan.ExecutionOrder[waveIdx]
		waveTasks := make([]SubTask, 0)
		for _, subTaskID := range wave {
			for i := range task.SubTasks {
				if task.SubTasks[i].ID == subTaskID {
					waveTasks = append(waveTasks, task.SubTasks[i])
					break
				}
			}
		}
		results, err := c.executor.ExecuteGroup(ctx, waveTasks, budget-task.TotalCost)
		if err != nil {
			c.saveCheckpoint(task, waveIdx, 0, "execution_error")
			task.Status = StatusFailed
			task.Error = fmt.Sprintf("Execution failed at wave %d: %v", waveIdx, err)
			task.UpdatedAt = time.Now()
			c.updateTask(ctx, task)
			return task, err
		}
		// Merge each result back into its subtask, accumulate cost, and
		// announce artifacts as they appear.
		for _, result := range results {
			for i := range task.SubTasks {
				if task.SubTasks[i].ID == result.SubTaskID {
					task.SubTasks[i].Output = result.Output
					task.SubTasks[i].Cost = result.Cost
					task.SubTasks[i].Status = StatusCompleted
					now := time.Now()
					task.SubTasks[i].CompletedAt = &now
					if result.Error != nil {
						// Per-subtask failures are recorded but do not abort
						// the wave.
						task.SubTasks[i].Status = StatusFailed
						task.SubTasks[i].Error = result.Error.Error()
					}
					break
				}
			}
			task.TotalCost += result.Cost
			task.TotalRuntime = time.Since(task.CreatedAt)
			for _, artifact := range result.Artifacts {
				task.Artifacts = append(task.Artifacts, artifact)
				c.emitEvent(TaskEvent{
					Type:      EventArtifact,
					TaskID:    task.ID,
					SubTaskID: result.SubTaskID,
					Data: map[string]interface{}{
						"artifact": artifact,
					},
					Timestamp: time.Now(),
				})
			}
		}
		// Progress spans 10..90 across waves; planning took the first 10,
		// completion takes it to 100.
		progress := 10 + int(float64(waveIdx+1)/float64(totalSubTasks)*80)
		task.Progress = progress
		task.Iterations = waveIdx + 1
		task.UpdatedAt = time.Now()
		c.updateTask(ctx, task)
		c.emitEvent(TaskEvent{
			Type:     EventIteration,
			TaskID:   task.ID,
			Progress: progress,
			Message:  fmt.Sprintf("Completed wave %d/%d (runtime: %v)", waveIdx+1, totalSubTasks, time.Since(task.CreatedAt).Round(time.Second)),
			Data: map[string]interface{}{
				"wave":      waveIdx + 1,
				"total":     totalSubTasks,
				"cost":      task.TotalCost,
				"runtime":   time.Since(task.CreatedAt).String(),
				"artifacts": len(task.Artifacts),
			},
			Timestamp: time.Now(),
		})
	}
	task.Status = StatusCompleted
	task.Progress = 100
	now := time.Now()
	task.CompletedAt = &now
	task.UpdatedAt = now
	task.TotalRuntime = time.Since(task.CreatedAt)
	c.updateTask(ctx, task)
	c.emitEvent(TaskEvent{
		Type:     EventTaskCompleted,
		TaskID:   task.ID,
		Status:   StatusCompleted,
		Progress: 100,
		Message:  fmt.Sprintf("Task completed (runtime: %v, cost: $%.4f)", task.TotalRuntime.Round(time.Second), task.TotalCost),
		Data: map[string]interface{}{
			"artifacts":    task.Artifacts,
			"totalCost":    task.TotalCost,
			"totalRuntime": task.TotalRuntime.String(),
			"iterations":   task.Iterations,
		},
		Timestamp: time.Now(),
	})
	// Persist subtask outputs into the user's long-term memory.
	c.storeTaskResults(ctx, task)
	return task, nil
}
// saveCheckpoint snapshots the task's position (wave/subtask indices),
// progress, memory, and artifact IDs so execution can later resume via
// resumeFromCheckpoint. The checkpoint becomes both the task's current
// checkpoint and an entry in its history, and the task is persisted with a
// 10s budget independent of the caller's context.
//
// NOTE(review): Memory is stored by reference (the live map, not a copy), so
// later task mutations also mutate this checkpoint — confirm this is intended.
func (c *Computer) saveCheckpoint(task *ComputerTask, waveIdx, subTaskIdx int, reason string) {
	checkpoint := Checkpoint{
		ID:           uuid.New().String(),
		TaskID:       task.ID,
		WaveIndex:    waveIdx,
		SubTaskIndex: subTaskIdx,
		State:        make(map[string]interface{}),
		Progress:     task.Progress,
		Memory:       task.Memory,
		CreatedAt:    time.Now(),
		RuntimeSoFar: time.Since(task.CreatedAt),
		CostSoFar:    task.TotalCost,
		Reason:       reason,
	}
	// Only artifact IDs are recorded; the artifacts themselves live on the task.
	for _, artifact := range task.Artifacts {
		checkpoint.Artifacts = append(checkpoint.Artifacts, artifact.ID)
	}
	task.Checkpoint = &checkpoint
	task.Checkpoints = append(task.Checkpoints, checkpoint)
	task.UpdatedAt = time.Now()
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	// Best-effort persistence; the event below is emitted regardless.
	c.taskRepo.Update(ctx, task)
	c.emitEvent(TaskEvent{
		Type:     EventCheckpointSaved,
		TaskID:   task.ID,
		Progress: task.Progress,
		Message:  fmt.Sprintf("Checkpoint saved: %s (wave %d)", reason, waveIdx),
		Data: map[string]interface{}{
			"checkpointId": checkpoint.ID,
			"waveIndex":    waveIdx,
			"subTaskIndex": subTaskIdx,
			"reason":       reason,
			"runtime":      checkpoint.RuntimeSoFar.String(),
			"cost":         checkpoint.CostSoFar,
		},
		Timestamp: time.Now(),
	})
}
// Pause stops a running task: it flips the status to StatusPaused, saves a
// checkpoint (reason "user_paused"), emits EventPaused, and persists the
// task. Tasks not resident in the in-memory map are loaded from the
// repository first.
//
// NOTE(review): the lock is dropped and re-acquired around the repository
// lookup, and saveCheckpoint runs after the unlock while the executor
// goroutine may still be mutating the task — verify this window is safe.
func (c *Computer) Pause(ctx context.Context, taskID string) error {
	c.mu.Lock()
	task, ok := c.tasks[taskID]
	if !ok {
		c.mu.Unlock()
		var err error
		task, err = c.taskRepo.GetByID(ctx, taskID)
		if err != nil {
			return err
		}
		c.mu.Lock()
	}
	if task.Status != StatusExecuting && task.Status != StatusLongRunning {
		c.mu.Unlock()
		return errors.New("task is not running")
	}
	now := time.Now()
	task.Status = StatusPaused
	task.PausedAt = &now
	task.UpdatedAt = now
	c.mu.Unlock()
	// Iterations doubles as the wave index reached so far.
	c.saveCheckpoint(task, task.Iterations, 0, "user_paused")
	c.emitEvent(TaskEvent{
		Type:      EventPaused,
		TaskID:    taskID,
		Status:    StatusPaused,
		Progress:  task.Progress,
		Message:   "Task paused by user",
		Timestamp: now,
	})
	return c.taskRepo.Update(ctx, task)
}
// Resume feeds user input into a task blocked in StatusWaiting and restarts
// its execution asynchronously. Tasks not cached in memory are loaded from
// the repository first.
func (c *Computer) Resume(ctx context.Context, taskID string, userInput string) error {
	c.mu.RLock()
	task, cached := c.tasks[taskID]
	c.mu.RUnlock()
	if !cached {
		loaded, err := c.taskRepo.GetByID(ctx, taskID)
		if err != nil {
			return fmt.Errorf("task not found: %w", err)
		}
		task = loaded
	}
	if task.Status != StatusWaiting {
		return errors.New("task is not waiting for user input")
	}
	task.Memory["user_input"] = userInput
	task.Status = StatusExecuting
	task.UpdatedAt = time.Now()
	go c.executeTask(context.Background(), task, ExecuteOptions{Async: true})
	return nil
}
// Cancel marks a task as cancelled, notifies subscribers, and persists the
// change. Tasks not resident in memory are loaded from the repository so
// scheduled or historical tasks can be cancelled too.
func (c *Computer) Cancel(ctx context.Context, taskID string) error {
	c.mu.Lock()
	task, ok := c.tasks[taskID]
	if ok {
		task.Status = StatusCancelled
		task.UpdatedAt = time.Now()
	}
	c.mu.Unlock()
	if !ok {
		// Not in memory: load, mutate, and fall through to the shared
		// event/persist path. (The original returned here without emitting
		// the cancellation event, and shadowed `task` with `:=`.)
		var err error
		task, err = c.taskRepo.GetByID(ctx, taskID)
		if err != nil {
			return fmt.Errorf("task not found: %w", err)
		}
		task.Status = StatusCancelled
		task.UpdatedAt = time.Now()
	}
	c.emitEvent(TaskEvent{
		Type:      EventTaskFailed,
		TaskID:    taskID,
		Status:    StatusCancelled,
		Message:   "Task cancelled by user",
		Timestamp: time.Now(),
	})
	return c.taskRepo.Update(ctx, task)
}
// GetStatus returns the live in-memory task when present, otherwise the
// persisted copy from the repository.
func (c *Computer) GetStatus(ctx context.Context, taskID string) (*ComputerTask, error) {
	c.mu.RLock()
	cached, ok := c.tasks[taskID]
	c.mu.RUnlock()
	if !ok {
		return c.taskRepo.GetByID(ctx, taskID)
	}
	return cached, nil
}
// GetUserTasks lists a user's tasks straight from the repository, paginated
// by limit/offset.
func (c *Computer) GetUserTasks(ctx context.Context, userID string, limit, offset int) ([]ComputerTask, error) {
	return c.taskRepo.GetByUserID(ctx, userID, limit, offset)
}

// Stream subscribes to the task's event feed. The channel stays open until
// EventBus.Unsubscribe is called for it.
func (c *Computer) Stream(ctx context.Context, taskID string) (<-chan TaskEvent, error) {
	return c.eventBus.Subscribe(taskID), nil
}
// updateTask refreshes the in-memory cache and persists the task. The
// repository error is deliberately discarded so a transient store failure
// does not abort a running execution.
func (c *Computer) updateTask(ctx context.Context, task *ComputerTask) {
	c.mu.Lock()
	c.tasks[task.ID] = task
	c.mu.Unlock()
	_ = c.taskRepo.Update(ctx, task)
}

// emitEvent publishes event on the task's event-bus topic.
func (c *Computer) emitEvent(event TaskEvent) {
	c.eventBus.Publish(event.TaskID, event)
}
// storeTaskResults persists each completed subtask's output into the user's
// memory as a JSON-encoded entry keyed by subtask ID. Store errors are
// best-effort and ignored.
func (c *Computer) storeTaskResults(ctx context.Context, task *ComputerTask) {
	for _, sub := range task.SubTasks {
		if sub.Output == nil {
			continue
		}
		payload, _ := json.Marshal(sub.Output)
		_ = c.memory.Store(ctx, task.UserID, &MemoryEntry{
			ID:        uuid.New().String(),
			UserID:    task.UserID,
			TaskID:    task.ID,
			Key:       fmt.Sprintf("subtask_%s_result", sub.ID),
			Value:     string(payload),
			Type:      MemoryTypeResult,
			CreatedAt: time.Now(),
		})
	}
}
// StartScheduler begins processing scheduled tasks; no-op when scheduling is
// disabled in the config (scheduler is nil).
func (c *Computer) StartScheduler(ctx context.Context) {
	if c.scheduler != nil {
		c.scheduler.Start(ctx)
	}
}

// StopScheduler halts scheduled-task processing; no-op when disabled.
func (c *Computer) StopScheduler() {
	if c.scheduler != nil {
		c.scheduler.Stop()
	}
}
// EventBus is a simple in-process pub/sub keyed by task ID. Subscribers
// receive buffered channels; publishing never blocks.
type EventBus struct {
	subscribers map[string][]chan TaskEvent // guarded by mu
	mu          sync.RWMutex
}

// NewEventBus creates an empty bus.
func NewEventBus() *EventBus {
	return &EventBus{
		subscribers: make(map[string][]chan TaskEvent),
	}
}
// Subscribe registers a new listener for events on taskID. The returned
// channel is buffered (100 events); Publish drops events for listeners whose
// buffers are full.
func (eb *EventBus) Subscribe(taskID string) <-chan TaskEvent {
	ch := make(chan TaskEvent, 100)
	eb.mu.Lock()
	eb.subscribers[taskID] = append(eb.subscribers[taskID], ch)
	eb.mu.Unlock()
	return ch
}
// Unsubscribe removes ch from taskID's subscriber list and closes it.
// When the last subscriber for a task is removed, the map entry itself is
// deleted so IDs of finished tasks do not accumulate in the bus forever
// (the original kept an empty slice per task indefinitely).
func (eb *EventBus) Unsubscribe(taskID string, ch <-chan TaskEvent) {
    eb.mu.Lock()
    defer eb.mu.Unlock()
    subs := eb.subscribers[taskID]
    for i, sub := range subs {
        if sub == ch {
            eb.subscribers[taskID] = append(subs[:i], subs[i+1:]...)
            close(sub)
            break
        }
    }
    if len(eb.subscribers[taskID]) == 0 {
        delete(eb.subscribers, taskID)
    }
}
// Publish delivers event to every subscriber of taskID, dropping it for
// any subscriber whose buffer is full (non-blocking send).
//
// NOTE(review): the sends happen after the read lock is released, while
// Unsubscribe closes channels under the write lock — a Publish racing an
// Unsubscribe could send on a just-closed channel and panic. Confirm the
// two are never called concurrently for the same task.
func (eb *EventBus) Publish(taskID string, event TaskEvent) {
    eb.mu.RLock()
    subs := eb.subscribers[taskID]
    eb.mu.RUnlock()
    for _, ch := range subs {
        select {
        case ch <- event:
        default:
        }
    }
}

View File

@@ -0,0 +1,104 @@
package connectors
import (
"context"
"errors"
"sync"
)
// Connector is an external-service integration exposing a set of named
// actions that are executed with loosely-typed JSON-style parameters.
type Connector interface {
    ID() string          // stable identifier; used as the registry key
    Name() string        // human-readable name
    Description() string // short summary of what the connector does
    Execute(ctx context.Context, action string, params map[string]interface{}) (interface{}, error)
    GetActions() []Action
    Validate(params map[string]interface{}) error
}
// Action describes one operation a connector can perform, together with a
// JSON-Schema-style description of its parameters.
type Action struct {
    Name string `json:"name"`
    Description string `json:"description"`
    Schema map[string]interface{} `json:"schema"` // JSON Schema of the params object
    Required []string `json:"required"`           // parameter names that must be present
}
// ConnectorHub is a thread-safe registry of connectors keyed by their ID.
type ConnectorHub struct {
    connectors map[string]Connector
    mu sync.RWMutex
}
// NewConnectorHub returns an empty hub ready for Register calls.
func NewConnectorHub() *ConnectorHub {
    return &ConnectorHub{
        connectors: make(map[string]Connector),
    }
}
// Register adds connector to the hub under its ID; registering the same
// ID again replaces the previous connector.
func (h *ConnectorHub) Register(connector Connector) {
    h.mu.Lock()
    h.connectors[connector.ID()] = connector
    h.mu.Unlock()
}

// Unregister removes the connector with the given ID, if present.
func (h *ConnectorHub) Unregister(id string) {
    h.mu.Lock()
    delete(h.connectors, id)
    h.mu.Unlock()
}
// Get returns the connector registered under id, or an error if none
// exists.
func (h *ConnectorHub) Get(id string) (Connector, error) {
    h.mu.RLock()
    defer h.mu.RUnlock()
    if connector, ok := h.connectors[id]; ok {
        return connector, nil
    }
    return nil, errors.New("connector not found: " + id)
}
// List returns every registered connector. Ordering follows map
// iteration and is therefore unspecified.
func (h *ConnectorHub) List() []Connector {
    h.mu.RLock()
    defer h.mu.RUnlock()
    all := make([]Connector, 0, len(h.connectors))
    for _, connector := range h.connectors {
        all = append(all, connector)
    }
    return all
}
// Execute looks up connectorID, validates params through the connector's
// own Validate, and then runs the requested action.
func (h *ConnectorHub) Execute(ctx context.Context, connectorID, action string, params map[string]interface{}) (interface{}, error) {
    connector, err := h.Get(connectorID)
    if err != nil {
        return nil, err
    }
    if verr := connector.Validate(params); verr != nil {
        return nil, verr
    }
    return connector.Execute(ctx, action, params)
}
// ConnectorInfo is the serializable description of a registered connector,
// suitable for listing the available integrations to clients.
type ConnectorInfo struct {
    ID string `json:"id"`
    Name string `json:"name"`
    Description string `json:"description"`
    Actions []Action `json:"actions"`
}
// GetInfo returns a serializable snapshot of every registered connector.
// Ordering follows map iteration and is therefore unspecified.
func (h *ConnectorHub) GetInfo() []ConnectorInfo {
    h.mu.RLock()
    defer h.mu.RUnlock()
    infos := make([]ConnectorInfo, 0, len(h.connectors))
    for _, connector := range h.connectors {
        info := ConnectorInfo{
            ID:          connector.ID(),
            Name:        connector.Name(),
            Description: connector.Description(),
            Actions:     connector.GetActions(),
        }
        infos = append(infos, info)
    }
    return infos
}

View File

@@ -0,0 +1,215 @@
package connectors
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net/smtp"
"strings"
)
// EmailConfig holds SMTP connection and sender settings.
type EmailConfig struct {
    SMTPHost string    // SMTP server hostname
    SMTPPort int       // SMTP server port
    Username string    // auth user; auth is skipped when Username or Password is empty
    Password string    // auth password
    FromAddress string // envelope sender and From header address
    FromName string    // optional display name for the From header
    UseTLS bool        // dial with implicit TLS instead of plain SMTP
    AllowHTML bool     // permit HTML bodies when the caller requests them
}
// EmailConnector sends mail through the configured SMTP server.
type EmailConnector struct {
    cfg EmailConfig
}
// NewEmailConnector builds a connector from cfg.
func NewEmailConnector(cfg EmailConfig) *EmailConnector {
    return &EmailConnector{cfg: cfg}
}
// ID returns the registry key for this connector.
func (e *EmailConnector) ID() string {
    return "email"
}
// Name returns the human-readable connector name.
func (e *EmailConnector) Name() string {
    return "Email"
}
// Description summarizes what the connector does.
func (e *EmailConnector) Description() string {
    return "Send emails via SMTP"
}
// GetActions describes the single "send" action and its parameter schema.
// "to", "subject" and "body" are required; "html", "cc" and "bcc" are
// optional (HTML delivery additionally requires cfg.AllowHTML).
func (e *EmailConnector) GetActions() []Action {
    return []Action{
        {
            Name: "send",
            Description: "Send an email",
            Schema: map[string]interface{}{
                "type": "object",
                "properties": map[string]interface{}{
                    "to": map[string]interface{}{"type": "string", "description": "Recipient email address"},
                    "subject": map[string]interface{}{"type": "string", "description": "Email subject"},
                    "body": map[string]interface{}{"type": "string", "description": "Email body"},
                    "html": map[string]interface{}{"type": "boolean", "description": "Whether body is HTML"},
                    "cc": map[string]interface{}{"type": "string", "description": "CC recipients (comma-separated)"},
                    "bcc": map[string]interface{}{"type": "string", "description": "BCC recipients (comma-separated)"},
                },
            },
            Required: []string{"to", "subject", "body"},
        },
    }
}
// Validate checks that the parameters required by the "send" action are
// present and are strings. The original checked presence only, so send()
// would panic with an interface-conversion error on a non-string value.
func (e *EmailConnector) Validate(params map[string]interface{}) error {
    for _, key := range []string{"to", "subject", "body"} {
        v, ok := params[key]
        if !ok {
            return errors.New("'" + key + "' is required")
        }
        if _, ok := v.(string); !ok {
            return errors.New("'" + key + "' must be a string")
        }
    }
    return nil
}
// Execute dispatches the requested action; only "send" is supported.
func (e *EmailConnector) Execute(ctx context.Context, action string, params map[string]interface{}) (interface{}, error) {
    if action != "send" {
        return nil, errors.New("unknown action: " + action)
    }
    return e.send(ctx, params)
}
// send implements the "send" action: assemble an RFC 822-style message and
// deliver it via the configured SMTP server (plain or implicit TLS).
// Returns a summary map on success; on failure the map carries the error
// text alongside the non-nil error.
func (e *EmailConnector) send(ctx context.Context, params map[string]interface{}) (interface{}, error) {
    // Explicit type checks instead of panicking assertions on a malformed
    // payload.
    to, ok := params["to"].(string)
    if !ok {
        return nil, errors.New("'to' must be a string")
    }
    subject, ok := params["subject"].(string)
    if !ok {
        return nil, errors.New("'subject' must be a string")
    }
    body, ok := params["body"].(string)
    if !ok {
        return nil, errors.New("'body' must be a string")
    }
    // HTML bodies are honored only when the connector config allows them.
    isHTML := false
    if html, ok := params["html"].(bool); ok {
        isHTML = html && e.cfg.AllowHTML
    }
    // Split comma-separated address lists, dropping empty entries so a
    // trailing comma does not produce an empty, invalid RCPT TO address
    // (the original kept empty strings from strings.Split).
    splitAddrs := func(raw string) []string {
        var out []string
        for _, part := range strings.Split(raw, ",") {
            if addr := strings.TrimSpace(part); addr != "" {
                out = append(out, addr)
            }
        }
        return out
    }
    var cc, bcc []string
    if ccStr, ok := params["cc"].(string); ok && ccStr != "" {
        cc = splitAddrs(ccStr)
    }
    if bccStr, ok := params["bcc"].(string); ok && bccStr != "" {
        bcc = splitAddrs(bccStr)
    }
    from := e.cfg.FromAddress
    if e.cfg.FromName != "" {
        from = fmt.Sprintf("%s <%s>", e.cfg.FromName, e.cfg.FromAddress)
    }
    // Build the message; BCC recipients are deliberately omitted from the
    // headers and only appear in the envelope.
    var msg strings.Builder
    msg.WriteString(fmt.Sprintf("From: %s\r\n", from))
    msg.WriteString(fmt.Sprintf("To: %s\r\n", to))
    if len(cc) > 0 {
        msg.WriteString(fmt.Sprintf("Cc: %s\r\n", strings.Join(cc, ", ")))
    }
    msg.WriteString(fmt.Sprintf("Subject: %s\r\n", subject))
    msg.WriteString("MIME-Version: 1.0\r\n")
    if isHTML {
        msg.WriteString("Content-Type: text/html; charset=\"UTF-8\"\r\n")
    } else {
        msg.WriteString("Content-Type: text/plain; charset=\"UTF-8\"\r\n")
    }
    msg.WriteString("\r\n")
    msg.WriteString(body)
    recipients := []string{to}
    recipients = append(recipients, cc...)
    recipients = append(recipients, bcc...)
    addr := fmt.Sprintf("%s:%d", e.cfg.SMTPHost, e.cfg.SMTPPort)
    var auth smtp.Auth
    if e.cfg.Username != "" && e.cfg.Password != "" {
        auth = smtp.PlainAuth("", e.cfg.Username, e.cfg.Password, e.cfg.SMTPHost)
    }
    var err error
    if e.cfg.UseTLS {
        err = e.sendWithTLS(addr, auth, e.cfg.FromAddress, recipients, []byte(msg.String()))
    } else {
        err = smtp.SendMail(addr, auth, e.cfg.FromAddress, recipients, []byte(msg.String()))
    }
    if err != nil {
        return map[string]interface{}{
            "success": false,
            "error":   err.Error(),
        }, err
    }
    return map[string]interface{}{
        "success":    true,
        "to":         to,
        "subject":    subject,
        "recipients": len(recipients),
    }, nil
}
// sendWithTLS delivers msg over an implicit-TLS SMTP connection (e.g.
// port 465), performing the dial/auth/MAIL/RCPT/DATA/QUIT sequence by
// hand, since smtp.SendMail only supports STARTTLS upgrades.
func (e *EmailConnector) sendWithTLS(addr string, auth smtp.Auth, from string, to []string, msg []byte) error {
    // ServerName enables certificate hostname verification.
    tlsConfig := &tls.Config{
        ServerName: e.cfg.SMTPHost,
    }
    conn, err := tls.Dial("tcp", addr, tlsConfig)
    if err != nil {
        return err
    }
    defer conn.Close()
    client, err := smtp.NewClient(conn, e.cfg.SMTPHost)
    if err != nil {
        return err
    }
    // Safety net for early returns; after a successful Quit the extra
    // Close is redundant and its error is discarded by the defer.
    defer client.Close()
    if auth != nil {
        if err := client.Auth(auth); err != nil {
            return err
        }
    }
    // Envelope sender, then one RCPT per recipient (To + Cc + Bcc).
    if err := client.Mail(from); err != nil {
        return err
    }
    for _, recipient := range to {
        if err := client.Rcpt(recipient); err != nil {
            return err
        }
    }
    // Stream the message body; Close flushes and ends the DATA command.
    w, err := client.Data()
    if err != nil {
        return err
    }
    _, err = w.Write(msg)
    if err != nil {
        return err
    }
    err = w.Close()
    if err != nil {
        return err
    }
    return client.Quit()
}

View File

@@ -0,0 +1,432 @@
package connectors
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"time"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
)
// StorageConfig holds S3-compatible object-storage settings.
type StorageConfig struct {
    Endpoint string        // host:port of the S3-compatible endpoint
    AccessKeyID string
    SecretAccessKey string
    BucketName string      // bucket all operations are scoped to
    UseSSL bool
    Region string
    PublicURL string       // optional public base URL used to build object links
}
// StorageConnector stores and retrieves objects via the MinIO S3 client.
type StorageConnector struct {
    cfg StorageConfig
    client *minio.Client
}
// NewStorageConnector builds a connector and its underlying MinIO client.
// It does not contact the server or verify the bucket exists — see
// EnsureBucket for that.
func NewStorageConnector(cfg StorageConfig) (*StorageConnector, error) {
    client, err := minio.New(cfg.Endpoint, &minio.Options{
        Creds: credentials.NewStaticV4(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
        Secure: cfg.UseSSL,
        Region: cfg.Region,
    })
    if err != nil {
        return nil, fmt.Errorf("failed to create storage client: %w", err)
    }
    return &StorageConnector{
        cfg: cfg,
        client: client,
    }, nil
}
// ID returns the registry key for this connector.
func (s *StorageConnector) ID() string {
    return "storage"
}
// Name returns the human-readable connector name.
func (s *StorageConnector) Name() string {
    return "Storage"
}
// Description summarizes what the connector does.
func (s *StorageConnector) Description() string {
    return "Store and retrieve files from S3-compatible storage"
}
// GetActions describes the five storage actions and their parameter
// schemas.
// NOTE(review): "upload" advertises a "public" flag, but the upload
// handler never reads it — confirm whether per-object ACLs were intended.
func (s *StorageConnector) GetActions() []Action {
    return []Action{
        {
            Name: "upload",
            Description: "Upload a file",
            Schema: map[string]interface{}{
                "type": "object",
                "properties": map[string]interface{}{
                    "path": map[string]interface{}{"type": "string", "description": "Storage path/key"},
                    "content": map[string]interface{}{"type": "string", "description": "File content (base64 or text)"},
                    "content_type": map[string]interface{}{"type": "string", "description": "MIME type"},
                    "public": map[string]interface{}{"type": "boolean", "description": "Make file publicly accessible"},
                },
            },
            Required: []string{"path", "content"},
        },
        {
            Name: "download",
            Description: "Download a file",
            Schema: map[string]interface{}{
                "type": "object",
                "properties": map[string]interface{}{
                    "path": map[string]interface{}{"type": "string", "description": "Storage path/key"},
                },
            },
            Required: []string{"path"},
        },
        {
            Name: "delete",
            Description: "Delete a file",
            Schema: map[string]interface{}{
                "type": "object",
                "properties": map[string]interface{}{
                    "path": map[string]interface{}{"type": "string", "description": "Storage path/key"},
                },
            },
            Required: []string{"path"},
        },
        {
            Name: "list",
            Description: "List files in a directory",
            Schema: map[string]interface{}{
                "type": "object",
                "properties": map[string]interface{}{
                    "prefix": map[string]interface{}{"type": "string", "description": "Path prefix"},
                    "limit": map[string]interface{}{"type": "integer", "description": "Max results"},
                },
            },
        },
        {
            Name: "get_url",
            Description: "Get a presigned URL for a file",
            Schema: map[string]interface{}{
                "type": "object",
                "properties": map[string]interface{}{
                    "path": map[string]interface{}{"type": "string", "description": "Storage path/key"},
                    "expires": map[string]interface{}{"type": "integer", "description": "URL expiry in seconds"},
                },
            },
            Required: []string{"path"},
        },
    }
}
// Validate is a no-op: required keys differ per action ("path" for most,
// nothing for "list") and the Connector interface gives Validate no
// action argument, so parameter handling is left to the action handlers.
// NOTE(review): the handlers assert params["path"].(string) directly —
// confirm callers always supply it as a string.
func (s *StorageConnector) Validate(params map[string]interface{}) error {
    return nil
}
// Execute dispatches the requested action to its handler.
func (s *StorageConnector) Execute(ctx context.Context, action string, params map[string]interface{}) (interface{}, error) {
    handlers := map[string]func(context.Context, map[string]interface{}) (interface{}, error){
        "upload":   s.upload,
        "download": s.download,
        "delete":   s.deleteFile,
        "list":     s.list,
        "get_url":  s.getURL,
    }
    handler, ok := handlers[action]
    if !ok {
        return nil, errors.New("unknown action: " + action)
    }
    return handler(ctx, params)
}
// upload implements the "upload" action: write the given string content at
// "path" in the bucket and return size, etag and (when configured) a
// public URL.
//
// NOTE(review): the action schema advertises a "public" flag that is never
// read here — confirm whether object ACLs were intended.
func (s *StorageConnector) upload(ctx context.Context, params map[string]interface{}) (interface{}, error) {
    // Explicit type checks instead of panicking assertions (Validate is a
    // no-op for this connector).
    path, ok := params["path"].(string)
    if !ok || path == "" {
        return nil, errors.New("'path' must be a non-empty string")
    }
    content, ok := params["content"].(string)
    if !ok {
        return nil, errors.New("'content' must be a string")
    }
    contentType := "application/octet-stream"
    if ct, ok := params["content_type"].(string); ok {
        contentType = ct
    }
    // An explicitly empty content_type falls back to extension sniffing.
    if contentType == "" {
        contentType = s.detectContentType(path)
    }
    reader := bytes.NewReader([]byte(content))
    size := int64(len(content))
    info, err := s.client.PutObject(ctx, s.cfg.BucketName, path, reader, size, minio.PutObjectOptions{
        ContentType: contentType,
    })
    if err != nil {
        return nil, fmt.Errorf("upload failed: %w", err)
    }
    url := ""
    if s.cfg.PublicURL != "" {
        url = fmt.Sprintf("%s/%s/%s", strings.TrimSuffix(s.cfg.PublicURL, "/"), s.cfg.BucketName, path)
    }
    return map[string]interface{}{
        "success": true,
        "path":    path,
        "size":    info.Size,
        "etag":    info.ETag,
        "url":     url,
    }, nil
}
// UploadBytes stores raw bytes at path (programmatic API that bypasses the
// action layer). It returns the public URL when one is configured, and
// the bare path otherwise.
func (s *StorageConnector) UploadBytes(ctx context.Context, path string, content []byte, contentType string) (string, error) {
    if contentType == "" {
        contentType = s.detectContentType(path)
    }
    opts := minio.PutObjectOptions{ContentType: contentType}
    if _, err := s.client.PutObject(ctx, s.cfg.BucketName, path, bytes.NewReader(content), int64(len(content)), opts); err != nil {
        return "", err
    }
    if s.cfg.PublicURL == "" {
        return path, nil
    }
    return fmt.Sprintf("%s/%s/%s", strings.TrimSuffix(s.cfg.PublicURL, "/"), s.cfg.BucketName, path), nil
}
// download implements the "download" action: fetch the object at "path"
// and return its content as a string together with metadata.
func (s *StorageConnector) download(ctx context.Context, params map[string]interface{}) (interface{}, error) {
    // Explicit type check instead of a panicking assertion.
    path, ok := params["path"].(string)
    if !ok || path == "" {
        return nil, errors.New("'path' must be a non-empty string")
    }
    obj, err := s.client.GetObject(ctx, s.cfg.BucketName, path, minio.GetObjectOptions{})
    if err != nil {
        return nil, fmt.Errorf("download failed: %w", err)
    }
    defer obj.Close()
    content, err := io.ReadAll(obj)
    if err != nil {
        return nil, fmt.Errorf("read failed: %w", err)
    }
    // Metadata is best-effort: on a Stat failure the zero-value fields are
    // returned rather than failing a read that already succeeded.
    stat, _ := obj.Stat()
    return map[string]interface{}{
        "success":      true,
        "path":         path,
        "content":      string(content),
        "size":         len(content),
        "content_type": stat.ContentType,
        "modified":     stat.LastModified,
    }, nil
}
// DownloadBytes fetches the object at path and returns its raw bytes.
func (s *StorageConnector) DownloadBytes(ctx context.Context, path string) ([]byte, error) {
    obj, err := s.client.GetObject(ctx, s.cfg.BucketName, path, minio.GetObjectOptions{})
    if err != nil {
        return nil, err
    }
    defer obj.Close()
    data, err := io.ReadAll(obj)
    if err != nil {
        return nil, err
    }
    return data, nil
}
// deleteFile implements the "delete" action: remove the object at "path".
func (s *StorageConnector) deleteFile(ctx context.Context, params map[string]interface{}) (interface{}, error) {
    // Explicit type check instead of a panicking assertion.
    path, ok := params["path"].(string)
    if !ok || path == "" {
        return nil, errors.New("'path' must be a non-empty string")
    }
    err := s.client.RemoveObject(ctx, s.cfg.BucketName, path, minio.RemoveObjectOptions{})
    if err != nil {
        return nil, fmt.Errorf("delete failed: %w", err)
    }
    return map[string]interface{}{
        "success": true,
        "path":    path,
    }, nil
}
// list implements the "list" action: enumerate up to "limit" objects
// (default 100) under an optional "prefix".
//
// NOTE(review): when no objects match, "files" stays nil and JSON-encodes
// as null rather than [] — confirm clients tolerate that.
func (s *StorageConnector) list(ctx context.Context, params map[string]interface{}) (interface{}, error) {
    prefix := ""
    if p, ok := params["prefix"].(string); ok {
        prefix = p
    }
    // Accept the limit both as float64 (JSON-decoded) and as a native int
    // (Go callers); the original recognized only float64 and silently
    // ignored int values, falling back to the default.
    limit := 100
    switch l := params["limit"].(type) {
    case float64:
        limit = int(l)
    case int:
        limit = l
    }
    objects := s.client.ListObjects(ctx, s.cfg.BucketName, minio.ListObjectsOptions{
        Prefix:    prefix,
        Recursive: true,
    })
    var files []map[string]interface{}
    count := 0
    for obj := range objects {
        // Skip listing errors; partial results are still returned.
        if obj.Err != nil {
            continue
        }
        files = append(files, map[string]interface{}{
            "path":     obj.Key,
            "size":     obj.Size,
            "modified": obj.LastModified,
            "etag":     obj.ETag,
        })
        count++
        if count >= limit {
            break
        }
    }
    return map[string]interface{}{
        "success": true,
        "files":   files,
        "count":   len(files),
    }, nil
}
// getURL implements the "get_url" action: return a presigned GET URL for
// "path", valid for "expires" seconds (default 3600).
func (s *StorageConnector) getURL(ctx context.Context, params map[string]interface{}) (interface{}, error) {
    // Explicit type check instead of a panicking assertion.
    path, ok := params["path"].(string)
    if !ok || path == "" {
        return nil, errors.New("'path' must be a non-empty string")
    }
    // Accept both JSON-decoded float64 and native int expiries (the
    // original recognized only float64).
    expires := 3600
    switch v := params["expires"].(type) {
    case float64:
        expires = int(v)
    case int:
        expires = v
    }
    url, err := s.client.PresignedGetObject(ctx, s.cfg.BucketName, path, time.Duration(expires)*time.Second, nil)
    if err != nil {
        return nil, fmt.Errorf("failed to generate URL: %w", err)
    }
    return map[string]interface{}{
        "success": true,
        "url":     url.String(),
        "expires": expires,
    }, nil
}
// GetPublicURL builds the public link for path, or returns "" when no
// public base URL is configured.
func (s *StorageConnector) GetPublicURL(path string) string {
    if s.cfg.PublicURL == "" {
        return ""
    }
    base := strings.TrimSuffix(s.cfg.PublicURL, "/")
    return fmt.Sprintf("%s/%s/%s", base, s.cfg.BucketName, path)
}
// detectContentType maps a file extension (case-insensitive) to a MIME
// type, defaulting to application/octet-stream for unknown extensions.
func (s *StorageConnector) detectContentType(path string) string {
    switch strings.ToLower(filepath.Ext(path)) {
    case ".html":
        return "text/html"
    case ".css":
        return "text/css"
    case ".js":
        return "application/javascript"
    case ".json":
        return "application/json"
    case ".xml":
        return "application/xml"
    case ".pdf":
        return "application/pdf"
    case ".zip":
        return "application/zip"
    case ".png":
        return "image/png"
    case ".jpg", ".jpeg":
        return "image/jpeg"
    case ".gif":
        return "image/gif"
    case ".svg":
        return "image/svg+xml"
    case ".mp4":
        return "video/mp4"
    case ".mp3":
        return "audio/mpeg"
    case ".txt":
        return "text/plain"
    case ".md":
        return "text/markdown"
    case ".csv":
        return "text/csv"
    case ".py":
        return "text/x-python"
    case ".go":
        return "text/x-go"
    case ".rs":
        return "text/x-rust"
    }
    return "application/octet-stream"
}
// EnsureBucket creates the configured bucket if it does not already exist.
func (s *StorageConnector) EnsureBucket(ctx context.Context) error {
    exists, err := s.client.BucketExists(ctx, s.cfg.BucketName)
    switch {
    case err != nil:
        return err
    case exists:
        return nil
    }
    return s.client.MakeBucket(ctx, s.cfg.BucketName, minio.MakeBucketOptions{
        Region: s.cfg.Region,
    })
}
// NewLocalStorageConnector returns a connector rooted at basePath.
func NewLocalStorageConnector(basePath string) *LocalStorageConnector {
    return &LocalStorageConnector{basePath: basePath}
}
// LocalStorageConnector is a filesystem-backed alternative to the S3
// connector, storing files beneath a single base directory.
type LocalStorageConnector struct {
    basePath string // root directory; all paths are joined beneath it
}
// ID returns the registry key for this connector.
func (l *LocalStorageConnector) ID() string {
    return "local_storage"
}
// Name returns the human-readable connector name.
func (l *LocalStorageConnector) Name() string {
    return "Local Storage"
}
// Description summarizes what the connector does.
func (l *LocalStorageConnector) Description() string {
    return "Store files on local filesystem"
}
// GetActions lists the supported actions. Unlike the S3 connector, these
// entries carry no parameter schemas.
func (l *LocalStorageConnector) GetActions() []Action {
    return []Action{
        {Name: "upload", Description: "Upload a file"},
        {Name: "download", Description: "Download a file"},
        {Name: "delete", Description: "Delete a file"},
        {Name: "list", Description: "List files"},
    }
}
// Validate is a no-op; Execute interprets parameters per action.
func (l *LocalStorageConnector) Validate(params map[string]interface{}) error {
    return nil
}
// Execute handles the local-filesystem actions. Fixes over the original:
//   - adds the "list" case that GetActions advertises but the original
//     switch did not implement (it fell through to "unknown action");
//   - rejects paths that escape basePath (".." traversal);
//   - checks the MkdirAll error instead of discarding it;
//   - returns errors for malformed params instead of panicking on type
//     assertions.
func (l *LocalStorageConnector) Execute(ctx context.Context, action string, params map[string]interface{}) (interface{}, error) {
    // resolve validates params["path"] and confines it to basePath.
    resolve := func() (string, string, error) {
        path, ok := params["path"].(string)
        if !ok || path == "" {
            return "", "", errors.New("'path' must be a non-empty string")
        }
        full := filepath.Join(l.basePath, path)
        // filepath.Join cleans the result; anything outside the root is
        // a traversal attempt.
        if full != l.basePath && !strings.HasPrefix(full, l.basePath+string(filepath.Separator)) {
            return "", "", errors.New("path escapes storage root")
        }
        return path, full, nil
    }
    switch action {
    case "upload":
        path, fullPath, err := resolve()
        if err != nil {
            return nil, err
        }
        content, ok := params["content"].(string)
        if !ok {
            return nil, errors.New("'content' must be a string")
        }
        if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil {
            return nil, err
        }
        err = os.WriteFile(fullPath, []byte(content), 0644)
        return map[string]interface{}{"success": err == nil, "path": path}, err
    case "download":
        _, fullPath, err := resolve()
        if err != nil {
            return nil, err
        }
        content, err := os.ReadFile(fullPath)
        return map[string]interface{}{"success": err == nil, "content": string(content)}, err
    case "delete":
        _, fullPath, err := resolve()
        if err != nil {
            return nil, err
        }
        err = os.Remove(fullPath)
        return map[string]interface{}{"success": err == nil}, err
    case "list":
        prefix := ""
        if p, ok := params["prefix"].(string); ok {
            prefix = p
        }
        root := filepath.Join(l.basePath, prefix)
        var files []map[string]interface{}
        // Walk errors on individual entries are skipped; partial listings
        // are still useful.
        _ = filepath.Walk(root, func(p string, info os.FileInfo, err error) error {
            if err != nil || info.IsDir() {
                return nil
            }
            rel, relErr := filepath.Rel(l.basePath, p)
            if relErr != nil {
                return nil
            }
            files = append(files, map[string]interface{}{
                "path": rel,
                "size": info.Size(),
            })
            return nil
        })
        return map[string]interface{}{"success": true, "files": files, "count": len(files)}, nil
    default:
        return nil, errors.New("unknown action")
    }
}

View File

@@ -0,0 +1,263 @@
package connectors
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"strconv"
"time"
)
// TelegramConfig holds Bot API credentials and HTTP settings.
type TelegramConfig struct {
    BotToken string       // Bot API token
    Timeout time.Duration // HTTP timeout; defaults to 30s when zero
}
// TelegramConnector sends messages and files through the Telegram Bot API.
type TelegramConnector struct {
    cfg TelegramConfig
    client *http.Client
}
// NewTelegramConnector builds a connector, applying a 30-second default
// HTTP timeout when cfg.Timeout is zero.
func NewTelegramConnector(cfg TelegramConfig) *TelegramConnector {
    timeout := cfg.Timeout
    if timeout == 0 {
        timeout = 30 * time.Second
    }
    return &TelegramConnector{
        cfg: cfg,
        client: &http.Client{
            Timeout: timeout,
        },
    }
}
// ID returns the registry key for this connector.
func (t *TelegramConnector) ID() string {
    return "telegram"
}
// Name returns the human-readable connector name.
func (t *TelegramConnector) Name() string {
    return "Telegram"
}
// Description summarizes what the connector does.
func (t *TelegramConnector) Description() string {
    return "Send messages via Telegram Bot API"
}
// GetActions describes the three messaging actions and their parameter
// schemas. All of them require "chat_id"; captions and parse mode are
// optional.
func (t *TelegramConnector) GetActions() []Action {
    return []Action{
        {
            Name: "send_message",
            Description: "Send a text message",
            Schema: map[string]interface{}{
                "type": "object",
                "properties": map[string]interface{}{
                    "chat_id": map[string]interface{}{"type": "string", "description": "Chat ID or @username"},
                    "text": map[string]interface{}{"type": "string", "description": "Message text"},
                    "parse_mode": map[string]interface{}{"type": "string", "enum": []string{"HTML", "Markdown", "MarkdownV2"}},
                },
            },
            Required: []string{"chat_id", "text"},
        },
        {
            Name: "send_document",
            Description: "Send a document/file",
            Schema: map[string]interface{}{
                "type": "object",
                "properties": map[string]interface{}{
                    "chat_id": map[string]interface{}{"type": "string", "description": "Chat ID"},
                    "document": map[string]interface{}{"type": "string", "description": "File path or URL"},
                    "caption": map[string]interface{}{"type": "string", "description": "Document caption"},
                },
            },
            Required: []string{"chat_id", "document"},
        },
        {
            Name: "send_photo",
            Description: "Send a photo",
            Schema: map[string]interface{}{
                "type": "object",
                "properties": map[string]interface{}{
                    "chat_id": map[string]interface{}{"type": "string", "description": "Chat ID"},
                    "photo": map[string]interface{}{"type": "string", "description": "Photo URL or file_id"},
                    "caption": map[string]interface{}{"type": "string", "description": "Photo caption"},
                },
            },
            Required: []string{"chat_id", "photo"},
        },
    }
}
// Validate checks that "chat_id" is present and is a string. The original
// checked presence only; every sender then asserts .(string) and would
// panic on a numeric chat_id (callers can normalize with GetChatID first).
func (t *TelegramConnector) Validate(params map[string]interface{}) error {
    v, ok := params["chat_id"]
    if !ok {
        return errors.New("'chat_id' is required")
    }
    if _, ok := v.(string); !ok {
        return errors.New("'chat_id' must be a string")
    }
    return nil
}
// Execute dispatches the requested action to its sender.
func (t *TelegramConnector) Execute(ctx context.Context, action string, params map[string]interface{}) (interface{}, error) {
    switch action {
    case "send_message":
        return t.sendMessage(ctx, params)
    case "send_document":
        return t.sendDocument(ctx, params)
    case "send_photo":
        return t.sendPhoto(ctx, params)
    }
    return nil, errors.New("unknown action: " + action)
}
// sendMessage implements "send_message": post text to a chat, with an
// optional parse_mode. Parameters are type-checked explicitly instead of
// relying on panicking assertions (the original crashed on non-strings).
func (t *TelegramConnector) sendMessage(ctx context.Context, params map[string]interface{}) (interface{}, error) {
    chatID, ok := params["chat_id"].(string)
    if !ok {
        return nil, errors.New("'chat_id' must be a string")
    }
    text, ok := params["text"].(string)
    if !ok {
        return nil, errors.New("'text' must be a string")
    }
    payload := map[string]interface{}{
        "chat_id": chatID,
        "text":    text,
    }
    if parseMode, ok := params["parse_mode"].(string); ok {
        payload["parse_mode"] = parseMode
    }
    return t.apiCall(ctx, "sendMessage", payload)
}

// sendDocument implements "send_document": send a file by path/URL with an
// optional caption.
func (t *TelegramConnector) sendDocument(ctx context.Context, params map[string]interface{}) (interface{}, error) {
    chatID, ok := params["chat_id"].(string)
    if !ok {
        return nil, errors.New("'chat_id' must be a string")
    }
    document, ok := params["document"].(string)
    if !ok {
        return nil, errors.New("'document' must be a string")
    }
    payload := map[string]interface{}{
        "chat_id":  chatID,
        "document": document,
    }
    if caption, ok := params["caption"].(string); ok {
        payload["caption"] = caption
    }
    return t.apiCall(ctx, "sendDocument", payload)
}

// sendPhoto implements "send_photo": send a photo by URL/file_id with an
// optional caption.
func (t *TelegramConnector) sendPhoto(ctx context.Context, params map[string]interface{}) (interface{}, error) {
    chatID, ok := params["chat_id"].(string)
    if !ok {
        return nil, errors.New("'chat_id' must be a string")
    }
    photo, ok := params["photo"].(string)
    if !ok {
        return nil, errors.New("'photo' must be a string")
    }
    payload := map[string]interface{}{
        "chat_id": chatID,
        "photo":   photo,
    }
    if caption, ok := params["caption"].(string); ok {
        payload["caption"] = caption
    }
    return t.apiCall(ctx, "sendPhoto", payload)
}
// apiCall POSTs a JSON payload to the given Bot API method and decodes the
// response. A response with "ok": false is surfaced as an error (carrying
// the API's description when present) while still returning the decoded
// body to the caller.
func (t *TelegramConnector) apiCall(ctx context.Context, method string, payload map[string]interface{}) (interface{}, error) {
    url := fmt.Sprintf("https://api.telegram.org/bot%s/%s", t.cfg.BotToken, method)
    body, err := json.Marshal(payload)
    if err != nil {
        return nil, err
    }
    req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
    if err != nil {
        return nil, err
    }
    req.Header.Set("Content-Type", "application/json")
    resp, err := t.client.Do(req)
    if err != nil {
        return nil, err
    }
    defer resp.Body.Close()
    respBody, err := io.ReadAll(resp.Body)
    if err != nil {
        return nil, err
    }
    var result map[string]interface{}
    if err := json.Unmarshal(respBody, &result); err != nil {
        return nil, err
    }
    // The Bot API reports failures in the body, not via HTTP status alone.
    if ok, exists := result["ok"].(bool); exists && !ok {
        desc := "unknown error"
        if d, exists := result["description"].(string); exists {
            desc = d
        }
        return result, errors.New("Telegram API error: " + desc)
    }
    return result, nil
}
// SendFileFromBytes uploads an in-memory file as a Telegram document via a
// multipart/form-data request (the JSON API cannot carry raw file bytes).
// Unlike the original, multipart-writer errors are checked (a failed write
// would otherwise silently send a truncated form), and a Bot API response
// with "ok": false is surfaced as an error, mirroring apiCall.
func (t *TelegramConnector) SendFileFromBytes(ctx context.Context, chatID string, filename string, content []byte, caption string) (interface{}, error) {
    url := fmt.Sprintf("https://api.telegram.org/bot%s/sendDocument", t.cfg.BotToken)
    var b bytes.Buffer
    w := multipart.NewWriter(&b)
    if err := w.WriteField("chat_id", chatID); err != nil {
        return nil, err
    }
    if caption != "" {
        if err := w.WriteField("caption", caption); err != nil {
            return nil, err
        }
    }
    fw, err := w.CreateFormFile("document", filename)
    if err != nil {
        return nil, err
    }
    if _, err := fw.Write(content); err != nil {
        return nil, err
    }
    // Close finalizes the multipart boundary; the request is invalid
    // without it.
    if err := w.Close(); err != nil {
        return nil, err
    }
    req, err := http.NewRequestWithContext(ctx, "POST", url, &b)
    if err != nil {
        return nil, err
    }
    req.Header.Set("Content-Type", w.FormDataContentType())
    resp, err := t.client.Do(req)
    if err != nil {
        return nil, err
    }
    defer resp.Body.Close()
    respBody, err := io.ReadAll(resp.Body)
    if err != nil {
        return nil, err
    }
    var result map[string]interface{}
    if err := json.Unmarshal(respBody, &result); err != nil {
        return nil, err
    }
    if ok, exists := result["ok"].(bool); exists && !ok {
        desc := "unknown error"
        if d, exists := result["description"].(string); exists {
            desc = d
        }
        return result, errors.New("Telegram API error: " + desc)
    }
    return result, nil
}
// GetChatID normalizes a chat identifier in any common numeric or string
// form into the string the Bot API expects.
func (t *TelegramConnector) GetChatID(chatIDOrUsername interface{}) string {
    if s, ok := chatIDOrUsername.(string); ok {
        return s
    }
    switch v := chatIDOrUsername.(type) {
    case int:
        return strconv.Itoa(v)
    case int64:
        return strconv.FormatInt(v, 10)
    case float64:
        return strconv.FormatInt(int64(v), 10)
    }
    return fmt.Sprintf("%v", chatIDOrUsername)
}

View File

@@ -0,0 +1,275 @@
package connectors
import (
"bytes"
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// WebhookConfig tunes HTTP behavior for outgoing webhooks.
type WebhookConfig struct {
    Timeout time.Duration    // per-request timeout; defaults to 30s
    MaxRetries int           // extra attempts after the first; defaults to 3
    RetryDelay time.Duration // base delay, scaled linearly per attempt; defaults to 1s
    DefaultSecret string     // HMAC secret used when a call supplies none
}
// WebhookConnector sends (optionally HMAC-signed) HTTP requests to
// external services.
type WebhookConnector struct {
    cfg WebhookConfig
    client *http.Client
}
// NewWebhookConnector builds a connector, filling in the documented
// defaults for any zero-valued config fields.
func NewWebhookConnector(cfg WebhookConfig) *WebhookConnector {
    timeout := cfg.Timeout
    if timeout == 0 {
        timeout = 30 * time.Second
    }
    if cfg.MaxRetries == 0 {
        cfg.MaxRetries = 3
    }
    if cfg.RetryDelay == 0 {
        cfg.RetryDelay = time.Second
    }
    return &WebhookConnector{
        cfg: cfg,
        client: &http.Client{
            Timeout: timeout,
        },
    }
}
// ID returns the registry key for this connector.
func (w *WebhookConnector) ID() string {
    return "webhook"
}
// Name returns the human-readable connector name.
func (w *WebhookConnector) Name() string {
    return "Webhook"
}
// Description summarizes what the connector does.
func (w *WebhookConnector) Description() string {
    return "Send HTTP webhooks to external services"
}
// GetActions describes the HTTP actions this connector exposes. Execute
// also handles "delete" and "patch", so they are advertised here as well;
// the original list omitted them, hiding working actions from callers.
func (w *WebhookConnector) GetActions() []Action {
    return []Action{
        {
            Name: "post",
            Description: "Send POST request",
            Schema: map[string]interface{}{
                "type": "object",
                "properties": map[string]interface{}{
                    "url": map[string]interface{}{"type": "string", "description": "Webhook URL"},
                    "body": map[string]interface{}{"type": "object", "description": "Request body (JSON)"},
                    "headers": map[string]interface{}{"type": "object", "description": "Custom headers"},
                    "secret": map[string]interface{}{"type": "string", "description": "HMAC secret for signing"},
                },
            },
            Required: []string{"url"},
        },
        {
            Name: "get",
            Description: "Send GET request",
            Schema: map[string]interface{}{
                "type": "object",
                "properties": map[string]interface{}{
                    "url": map[string]interface{}{"type": "string", "description": "Request URL"},
                    "params": map[string]interface{}{"type": "object", "description": "Query parameters"},
                    "headers": map[string]interface{}{"type": "object", "description": "Custom headers"},
                },
            },
            Required: []string{"url"},
        },
        {
            Name: "put",
            Description: "Send PUT request",
            Schema: map[string]interface{}{
                "type": "object",
                "properties": map[string]interface{}{
                    "url": map[string]interface{}{"type": "string", "description": "Request URL"},
                    "body": map[string]interface{}{"type": "object", "description": "Request body (JSON)"},
                    "headers": map[string]interface{}{"type": "object", "description": "Custom headers"},
                },
            },
            Required: []string{"url"},
        },
        {
            Name: "delete",
            Description: "Send DELETE request",
            Schema: map[string]interface{}{
                "type": "object",
                "properties": map[string]interface{}{
                    "url": map[string]interface{}{"type": "string", "description": "Request URL"},
                    "headers": map[string]interface{}{"type": "object", "description": "Custom headers"},
                },
            },
            Required: []string{"url"},
        },
        {
            Name: "patch",
            Description: "Send PATCH request",
            Schema: map[string]interface{}{
                "type": "object",
                "properties": map[string]interface{}{
                    "url": map[string]interface{}{"type": "string", "description": "Request URL"},
                    "body": map[string]interface{}{"type": "object", "description": "Request body (JSON)"},
                    "headers": map[string]interface{}{"type": "object", "description": "Custom headers"},
                },
            },
            Required: []string{"url"},
        },
    }
}
// Validate checks that "url" is a string parsing as an http(s) URL.
func (w *WebhookConnector) Validate(params map[string]interface{}) error {
    urlStr, ok := params["url"].(string)
    if !ok {
        return errors.New("'url' is required")
    }
    parsed, err := url.Parse(urlStr)
    if err != nil {
        return fmt.Errorf("invalid URL: %w", err)
    }
    switch parsed.Scheme {
    case "http", "https":
        return nil
    }
    return errors.New("URL must use http or https scheme")
}
// Execute maps the action name to its HTTP verb and issues the request.
func (w *WebhookConnector) Execute(ctx context.Context, action string, params map[string]interface{}) (interface{}, error) {
    methods := map[string]string{
        "post":   "POST",
        "get":    "GET",
        "put":    "PUT",
        "delete": "DELETE",
        "patch":  "PATCH",
    }
    method, ok := methods[action]
    if !ok {
        return nil, errors.New("unknown action: " + action)
    }
    return w.doRequest(ctx, method, params)
}
// doRequest performs an HTTP request with up to cfg.MaxRetries retries
// (linear backoff), optional HMAC-SHA256 body signing, and JSON response
// decoding. Transport errors and 5xx responses are retried; 4xx responses
// are returned immediately with success=false and no error.
//
// Fixes over the original: the "url" assertion no longer panics on a
// non-string; the GET branch checks the url.Parse error instead of
// dereferencing a possibly-nil *url.URL; the retry backoff respects
// context cancellation instead of sleeping unconditionally.
func (w *WebhookConnector) doRequest(ctx context.Context, method string, params map[string]interface{}) (interface{}, error) {
    urlStr, ok := params["url"].(string)
    if !ok {
        return nil, errors.New("'url' must be a string")
    }
    // For GET requests, fold "params" into the query string.
    if method == "GET" {
        if queryParams, ok := params["params"].(map[string]interface{}); ok {
            parsedURL, err := url.Parse(urlStr)
            if err != nil {
                return nil, fmt.Errorf("invalid URL: %w", err)
            }
            q := parsedURL.Query()
            for k, v := range queryParams {
                q.Set(k, fmt.Sprintf("%v", v))
            }
            parsedURL.RawQuery = q.Encode()
            urlStr = parsedURL.String()
        }
    }
    var bodyReader io.Reader
    var bodyBytes []byte
    if body, ok := params["body"]; ok && method != "GET" {
        var err error
        bodyBytes, err = json.Marshal(body)
        if err != nil {
            return nil, fmt.Errorf("failed to marshal body: %w", err)
        }
        bodyReader = bytes.NewReader(bodyBytes)
    }
    var lastErr error
    for attempt := 0; attempt <= w.cfg.MaxRetries; attempt++ {
        if attempt > 0 {
            // Linear backoff, aborting promptly on context cancellation.
            select {
            case <-ctx.Done():
                return nil, ctx.Err()
            case <-time.After(w.cfg.RetryDelay * time.Duration(attempt)):
            }
            // The previous attempt consumed the reader; rewind for retry.
            if bodyBytes != nil {
                bodyReader = bytes.NewReader(bodyBytes)
            }
        }
        req, err := http.NewRequestWithContext(ctx, method, urlStr, bodyReader)
        if err != nil {
            return nil, err
        }
        req.Header.Set("Content-Type", "application/json")
        req.Header.Set("User-Agent", "GooSeek-Computer/1.0")
        if headers, ok := params["headers"].(map[string]interface{}); ok {
            for k, v := range headers {
                req.Header.Set(k, fmt.Sprintf("%v", v))
            }
        }
        // Sign the body (GitHub-style "sha256=<hex>" header) when a secret
        // is configured or supplied per call.
        if bodyBytes != nil {
            secret := w.cfg.DefaultSecret
            if s, ok := params["secret"].(string); ok {
                secret = s
            }
            if secret != "" {
                signature := w.signPayload(bodyBytes, secret)
                req.Header.Set("X-Signature-256", "sha256="+signature)
            }
        }
        resp, err := w.client.Do(req)
        if err != nil {
            lastErr = err
            continue
        }
        respBody, err := io.ReadAll(resp.Body)
        resp.Body.Close()
        if err != nil {
            lastErr = err
            continue
        }
        result := map[string]interface{}{
            "status_code": resp.StatusCode,
            "headers":     w.headersToMap(resp.Header),
        }
        // Prefer the decoded JSON body; fall back to the raw text.
        var jsonBody interface{}
        if err := json.Unmarshal(respBody, &jsonBody); err == nil {
            result["body"] = jsonBody
        } else {
            result["body"] = string(respBody)
        }
        if resp.StatusCode >= 200 && resp.StatusCode < 300 {
            result["success"] = true
            return result, nil
        }
        if resp.StatusCode >= 500 {
            lastErr = fmt.Errorf("server error: %d", resp.StatusCode)
            continue
        }
        // 3xx/4xx are not retryable; report them without an error.
        result["success"] = false
        return result, nil
    }
    if lastErr == nil {
        lastErr = errors.New("request failed")
    }
    return map[string]interface{}{
        "success": false,
        "error":   lastErr.Error(),
    }, lastErr
}
// signPayload returns the hex-encoded HMAC-SHA256 of payload under secret.
func (w *WebhookConnector) signPayload(payload []byte, secret string) string {
    h := hmac.New(sha256.New, []byte(secret))
    // hash.Hash.Write never returns an error.
    _, _ = h.Write(payload)
    return hex.EncodeToString(h.Sum(nil))
}
// headersToMap flattens multi-valued HTTP headers into a single
// comma-joined string per header name.
func (w *WebhookConnector) headersToMap(headers http.Header) map[string]string {
    flat := make(map[string]string, len(headers))
    for name, values := range headers {
        flat[name] = strings.Join(values, ", ")
    }
    return flat
}
// PostJSON is a convenience wrapper that POSTs data as JSON to webhookURL.
func (w *WebhookConnector) PostJSON(ctx context.Context, webhookURL string, data interface{}) (interface{}, error) {
    return w.Execute(ctx, "post", map[string]interface{}{
        "url": webhookURL,
        "body": data,
    })
}
// GetJSON is a convenience wrapper that GETs webhookURL with the given
// query parameters.
func (w *WebhookConnector) GetJSON(ctx context.Context, webhookURL string, params map[string]interface{}) (interface{}, error) {
    return w.Execute(ctx, "get", map[string]interface{}{
        "url": webhookURL,
        "params": params,
    })
}

View File

@@ -0,0 +1,574 @@
package computer
import (
"context"
"encoding/json"
"fmt"
"strings"
"sync"
"time"
"github.com/gooseek/backend/internal/llm"
"github.com/google/uuid"
"golang.org/x/sync/errgroup"
)
// Executor runs sub-tasks against LLM backends chosen by the router, with
// bounded parallelism for task groups.
type Executor struct {
    router *Router
    sandbox *SandboxManager // optional; attached via SetSandbox
    maxWorkers int          // parallelism cap for ExecuteGroup
}
// NewExecutor builds an executor; non-positive maxWorkers defaults to 5.
func NewExecutor(router *Router, maxWorkers int) *Executor {
    if maxWorkers <= 0 {
        maxWorkers = 5
    }
    return &Executor{
        router: router,
        maxWorkers: maxWorkers,
    }
}
// SetSandbox attaches a sandbox manager.
// NOTE(review): nothing in this chunk reads e.sandbox — confirm it is
// used elsewhere.
func (e *Executor) SetSandbox(sandbox *SandboxManager) {
    e.sandbox = sandbox
}
// ExecuteGroup runs tasks concurrently (bounded by maxWorkers), splitting
// budget evenly across them. Per-task failures are captured in the
// corresponding ExecutionResult rather than aborting the group.
func (e *Executor) ExecuteGroup(ctx context.Context, tasks []SubTask, budget float64) ([]ExecutionResult, error) {
    // Guard: budget/len(tasks) below would be ±Inf or NaN for an empty
    // group (the original performed the division unconditionally).
    if len(tasks) == 0 {
        return nil, nil
    }
    results := make([]ExecutionResult, len(tasks))
    perTaskBudget := budget / float64(len(tasks))
    g, gctx := errgroup.WithContext(ctx)
    g.SetLimit(e.maxWorkers)
    for i, task := range tasks {
        i, task := i, task
        g.Go(func() error {
            result, err := e.ExecuteTask(gctx, &task, perTaskBudget)
            // Each goroutine writes a distinct index, so no lock is
            // needed (the original took a mutex redundantly).
            if err != nil {
                results[i] = ExecutionResult{
                    TaskID:    task.ID,
                    SubTaskID: task.ID,
                    Error:     err,
                }
            } else {
                results[i] = *result
            }
            return nil
        })
    }
    if err := g.Wait(); err != nil {
        return results, err
    }
    return results, nil
}
// ExecuteTask routes a single sub-task to a model within budget,
// dispatches it to the type-specific handler, and fills in the elapsed
// duration and estimated cost on the result.
func (e *Executor) ExecuteTask(ctx context.Context, task *SubTask, budget float64) (*ExecutionResult, error) {
    startTime := time.Now()
    // Pick a model/client suited to the task type within budget.
    client, spec, err := e.router.Route(task, budget)
    if err != nil {
        return nil, fmt.Errorf("routing failed: %w", err)
    }
    task.ModelID = spec.ID
    now := time.Now()
    task.StartedAt = &now
    var result *ExecutionResult
    switch task.Type {
    case TaskResearch:
        result, err = e.executeResearch(ctx, client, task)
    case TaskCode:
        result, err = e.executeCode(ctx, client, task)
    case TaskAnalysis:
        result, err = e.executeAnalysis(ctx, client, task)
    case TaskDesign:
        result, err = e.executeDesign(ctx, client, task)
    case TaskDeploy:
        result, err = e.executeDeploy(ctx, client, task)
    case TaskReport:
        result, err = e.executeReport(ctx, client, task)
    case TaskCommunicate:
        result, err = e.executeCommunicate(ctx, client, task)
    case TaskTransform:
        result, err = e.executeTransform(ctx, client, task)
    case TaskValidate:
        result, err = e.executeValidate(ctx, client, task)
    default:
        result, err = e.executeGeneric(ctx, client, task)
    }
    if err != nil {
        return nil, err
    }
    result.Duration = time.Since(startTime)
    // Cost uses fixed token estimates (1000 in / 500 out), not actual
    // usage from the LLM response.
    result.Cost = e.router.EstimateCost(task, 1000, 500)
    return result, nil
}
// executeResearch asks the routed model to perform a research task and
// return structured JSON findings.
func (e *Executor) executeResearch(ctx context.Context, client llm.Client, task *SubTask) (*ExecutionResult, error) {
	req := llm.StreamRequest{
		Messages: []llm.Message{{Role: llm.RoleUser, Content: fmt.Sprintf(`You are a research assistant. Complete this research task:
Task: %s
Additional context: %v
Provide a comprehensive research result with:
1. Key findings
2. Sources/references
3. Summary
Respond in JSON:
{
"findings": ["finding 1", "finding 2"],
"sources": ["source 1", "source 2"],
"summary": "...",
"data": {}
}`, task.Description, task.Input)}},
		Options: llm.StreamOptions{MaxTokens: 4096},
	}
	raw, err := client.GenerateText(ctx, req)
	if err != nil {
		return nil, err
	}
	return &ExecutionResult{
		TaskID:    task.ID,
		SubTaskID: task.ID,
		Output:    parseJSONOutput(raw),
	}, nil
}
// executeCode asks the model for a code implementation and, when the
// reply contains a "code" field, wraps it in a code artifact named after
// the reply's "filename" (default main.py).
func (e *Executor) executeCode(ctx context.Context, client llm.Client, task *SubTask) (*ExecutionResult, error) {
	var inputContext string
	if task.Input != nil {
		encoded, _ := json.Marshal(task.Input)
		inputContext = fmt.Sprintf("\n\nContext from previous tasks:\n%s", string(encoded))
	}
	req := llm.StreamRequest{
		Messages: []llm.Message{{Role: llm.RoleUser, Content: fmt.Sprintf(`You are an expert programmer. Complete this coding task:
Task: %s%s
Requirements:
1. Write clean, production-ready code
2. Include error handling
3. Add necessary imports
4. Follow best practices
Respond in JSON:
{
"language": "python",
"code": "...",
"filename": "main.py",
"dependencies": ["package1", "package2"],
"explanation": "..."
}`, task.Description, inputContext)}},
		Options: llm.StreamOptions{MaxTokens: 8192},
	}
	raw, err := client.GenerateText(ctx, req)
	if err != nil {
		return nil, err
	}
	output := parseJSONOutput(raw)
	var artifacts []Artifact
	if code, ok := output["code"].(string); ok {
		name := "main.py"
		if fn, ok := output["filename"].(string); ok {
			name = fn
		}
		artifacts = append(artifacts, Artifact{
			ID:        uuid.New().String(),
			TaskID:    task.ID,
			Type:      ArtifactTypeCode,
			Name:      name,
			Content:   []byte(code),
			Size:      int64(len(code)),
			CreatedAt: time.Now(),
		})
	}
	return &ExecutionResult{
		TaskID:    task.ID,
		SubTaskID: task.ID,
		Output:    output,
		Artifacts: artifacts,
	}, nil
}
// executeAnalysis asks the model to analyze the task's input data and
// return structured insights as JSON.
func (e *Executor) executeAnalysis(ctx context.Context, client llm.Client, task *SubTask) (*ExecutionResult, error) {
	encoded, _ := json.Marshal(task.Input)
	req := llm.StreamRequest{
		Messages: []llm.Message{{Role: llm.RoleUser, Content: fmt.Sprintf(`You are a data analyst. Analyze this data/information:
Task: %s
Input data:
%s
Provide:
1. Key insights
2. Patterns observed
3. Recommendations
4. Visualizations needed (describe)
Respond in JSON:
{
"insights": ["insight 1", "insight 2"],
"patterns": ["pattern 1"],
"recommendations": ["rec 1"],
"visualizations": ["chart type 1"],
"summary": "..."
}`, task.Description, string(encoded))}},
		Options: llm.StreamOptions{MaxTokens: 4096},
	}
	raw, err := client.GenerateText(ctx, req)
	if err != nil {
		return nil, err
	}
	return &ExecutionResult{
		TaskID:    task.ID,
		SubTaskID: task.ID,
		Output:    parseJSONOutput(raw),
	}, nil
}
// executeDesign asks the model for an architecture/design proposal and
// returns it as structured JSON.
func (e *Executor) executeDesign(ctx context.Context, client llm.Client, task *SubTask) (*ExecutionResult, error) {
	encoded, _ := json.Marshal(task.Input)
	req := llm.StreamRequest{
		Messages: []llm.Message{{Role: llm.RoleUser, Content: fmt.Sprintf(`You are a software architect. Design a solution:
Task: %s
Context:
%s
Provide:
1. Architecture overview
2. Components and their responsibilities
3. Data flow
4. Technology recommendations
5. Implementation plan
Respond in JSON:
{
"architecture": "...",
"components": [{"name": "...", "responsibility": "..."}],
"dataFlow": "...",
"technologies": ["tech1", "tech2"],
"implementationSteps": ["step1", "step2"],
"diagram": "mermaid diagram code"
}`, task.Description, string(encoded))}},
		Options: llm.StreamOptions{MaxTokens: 4096},
	}
	raw, err := client.GenerateText(ctx, req)
	if err != nil {
		return nil, err
	}
	return &ExecutionResult{
		TaskID:    task.ID,
		SubTaskID: task.ID,
		Output:    parseJSONOutput(raw),
	}, nil
}
// executeDeploy runs the code produced by an upstream task inside a
// sandbox. Without a sandbox manager, or without code in the task input,
// it degrades to a plain LLM completion via executeGeneric.
func (e *Executor) executeDeploy(ctx context.Context, client llm.Client, task *SubTask) (*ExecutionResult, error) {
	code := ""
	if task.Input != nil {
		code, _ = task.Input["code"].(string)
	}
	if e.sandbox == nil || code == "" {
		return e.executeGeneric(ctx, client, task)
	}
	sb, err := e.sandbox.Create(ctx, task.ID)
	if err != nil {
		return nil, fmt.Errorf("failed to create sandbox: %w", err)
	}
	defer e.sandbox.Destroy(ctx, sb)
	run, err := e.sandbox.Execute(ctx, sb, code, "python")
	if err != nil {
		return nil, fmt.Errorf("sandbox execution failed: %w", err)
	}
	// Files written inside the sandbox become file artifacts.
	var artifacts []Artifact
	for name, content := range run.Files {
		artifacts = append(artifacts, Artifact{
			ID:        uuid.New().String(),
			TaskID:    task.ID,
			Type:      ArtifactTypeFile,
			Name:      name,
			Content:   content,
			Size:      int64(len(content)),
			CreatedAt: time.Now(),
		})
	}
	return &ExecutionResult{
		TaskID:    task.ID,
		SubTaskID: task.ID,
		Output: map[string]interface{}{
			"stdout":   run.Stdout,
			"stderr":   run.Stderr,
			"exitCode": run.ExitCode,
			"duration": run.Duration.String(),
		},
		Artifacts: artifacts,
	}, nil
}
// executeReport asks the model for a markdown report and attaches the
// full text as a report artifact (report.md).
func (e *Executor) executeReport(ctx context.Context, client llm.Client, task *SubTask) (*ExecutionResult, error) {
	encoded, _ := json.Marshal(task.Input)
	req := llm.StreamRequest{
		Messages: []llm.Message{{Role: llm.RoleUser, Content: fmt.Sprintf(`You are a report writer. Generate a comprehensive report:
Task: %s
Data/Context:
%s
Create a well-structured report with:
1. Executive Summary
2. Key Findings
3. Detailed Analysis
4. Conclusions
5. Recommendations
Use markdown formatting.`, task.Description, string(encoded))}},
		Options: llm.StreamOptions{MaxTokens: 8192},
	}
	raw, err := client.GenerateText(ctx, req)
	if err != nil {
		return nil, err
	}
	report := Artifact{
		ID:        uuid.New().String(),
		TaskID:    task.ID,
		Type:      ArtifactTypeReport,
		Name:      "report.md",
		Content:   []byte(raw),
		MimeType:  "text/markdown",
		Size:      int64(len(raw)),
		CreatedAt: time.Now(),
	}
	return &ExecutionResult{
		TaskID:    task.ID,
		SubTaskID: task.ID,
		Output: map[string]interface{}{
			"report":    raw,
			"format":    "markdown",
			"wordCount": len(strings.Fields(raw)),
		},
		Artifacts: []Artifact{report},
	}, nil
}
// executeCommunicate drafts a message/notification with the model. It
// only prepares the text — nothing is delivered here — and marks the
// output with status "prepared".
func (e *Executor) executeCommunicate(ctx context.Context, client llm.Client, task *SubTask) (*ExecutionResult, error) {
	encoded, _ := json.Marshal(task.Input)
	req := llm.StreamRequest{
		Messages: []llm.Message{{Role: llm.RoleUser, Content: fmt.Sprintf(`Generate a message/notification:
Task: %s
Context:
%s
Create an appropriate message. Respond in JSON:
{
"subject": "...",
"body": "...",
"format": "text|html",
"priority": "low|normal|high"
}`, task.Description, string(encoded))}},
		Options: llm.StreamOptions{MaxTokens: 2048},
	}
	raw, err := client.GenerateText(ctx, req)
	if err != nil {
		return nil, err
	}
	output := parseJSONOutput(raw)
	output["status"] = "prepared"
	return &ExecutionResult{
		TaskID:    task.ID,
		SubTaskID: task.ID,
		Output:    output,
	}, nil
}
// executeTransform asks the model to convert the task's input data into
// the requested format and returns the structured result.
func (e *Executor) executeTransform(ctx context.Context, client llm.Client, task *SubTask) (*ExecutionResult, error) {
	encoded, _ := json.Marshal(task.Input)
	req := llm.StreamRequest{
		Messages: []llm.Message{{Role: llm.RoleUser, Content: fmt.Sprintf(`Transform data as requested:
Task: %s
Input data:
%s
Perform the transformation and return the result in JSON:
{
"transformed": ...,
"format": "...",
"changes": ["change 1", "change 2"]
}`, task.Description, string(encoded))}},
		Options: llm.StreamOptions{MaxTokens: 4096},
	}
	raw, err := client.GenerateText(ctx, req)
	if err != nil {
		return nil, err
	}
	return &ExecutionResult{
		TaskID:    task.ID,
		SubTaskID: task.ID,
		Output:    parseJSONOutput(raw),
	}, nil
}
// executeValidate asks the model to review the task's input for
// correctness/completeness and returns its verdict as JSON.
func (e *Executor) executeValidate(ctx context.Context, client llm.Client, task *SubTask) (*ExecutionResult, error) {
	encoded, _ := json.Marshal(task.Input)
	req := llm.StreamRequest{
		Messages: []llm.Message{{Role: llm.RoleUser, Content: fmt.Sprintf(`Validate the following:
Task: %s
Data to validate:
%s
Check for:
1. Correctness
2. Completeness
3. Consistency
4. Quality
Respond in JSON:
{
"valid": true|false,
"score": 0-100,
"issues": ["issue 1", "issue 2"],
"suggestions": ["suggestion 1"],
"summary": "..."
}`, task.Description, string(encoded))}},
		Options: llm.StreamOptions{MaxTokens: 2048},
	}
	raw, err := client.GenerateText(ctx, req)
	if err != nil {
		return nil, err
	}
	return &ExecutionResult{
		TaskID:    task.ID,
		SubTaskID: task.ID,
		Output:    parseJSONOutput(raw),
	}, nil
}
// executeGeneric is the catch-all handler: it sends the task straight to
// the model and, when no JSON object can be extracted from the reply,
// wraps the raw text under "result".
func (e *Executor) executeGeneric(ctx context.Context, client llm.Client, task *SubTask) (*ExecutionResult, error) {
	encoded, _ := json.Marshal(task.Input)
	req := llm.StreamRequest{
		Messages: []llm.Message{{Role: llm.RoleUser, Content: fmt.Sprintf(`Complete this task:
Task type: %s
Description: %s
Context:
%s
Provide a comprehensive result in JSON format.`, task.Type, task.Description, string(encoded))}},
		Options: llm.StreamOptions{MaxTokens: 4096},
	}
	raw, err := client.GenerateText(ctx, req)
	if err != nil {
		return nil, err
	}
	output := parseJSONOutput(raw)
	if len(output) == 0 {
		output = map[string]interface{}{
			"result": raw,
		}
	}
	return &ExecutionResult{
		TaskID:    task.ID,
		SubTaskID: task.ID,
		Output:    output,
	}, nil
}
// parseJSONOutput extracts the outermost {...} span from an LLM reply
// and decodes it. When no object is found or decoding fails, the whole
// reply is returned under the "raw" key so nothing is lost.
func parseJSONOutput(response string) map[string]interface{} {
	first := strings.Index(response, "{")
	last := strings.LastIndex(response, "}")
	if first < 0 || last <= first {
		return map[string]interface{}{"raw": response}
	}
	var parsed map[string]interface{}
	if err := json.Unmarshal([]byte(response[first:last+1]), &parsed); err != nil {
		return map[string]interface{}{"raw": response}
	}
	return parsed
}

View File

@@ -0,0 +1,377 @@
package computer
import (
	"context"
	"encoding/json"
	"fmt"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/google/uuid"
)
// MemoryStore is a write-through memory layer: entries are persisted to
// an optional repository and mirrored in a bounded per-user in-memory
// cache used as a fallback for reads.
type MemoryStore struct {
	repo  MemoryRepository         // optional persistent backend; may be nil
	cache map[string][]MemoryEntry // userID -> recent entries (bounded in Store)
	mu    sync.RWMutex             // guards cache
}
// NewMemoryStore builds a MemoryStore over an optional repository.
// Pass nil for a purely in-memory store.
func NewMemoryStore(repo MemoryRepository) *MemoryStore {
	return &MemoryStore{
		repo:  repo,
		cache: make(map[string][]MemoryEntry),
	}
}
// Store persists a memory entry for userID, generating an ID and
// creation timestamp when missing. The entry is written through to the
// repository first (when configured) and then mirrored into the cache.
func (m *MemoryStore) Store(ctx context.Context, userID string, entry *MemoryEntry) error {
	if entry.ID == "" {
		entry.ID = uuid.New().String()
	}
	entry.UserID = userID
	if entry.CreatedAt.IsZero() {
		entry.CreatedAt = time.Now()
	}
	if m.repo != nil {
		if err := m.repo.Store(ctx, entry); err != nil {
			return err
		}
	}
	m.mu.Lock()
	defer m.mu.Unlock()
	entries := append(m.cache[userID], *entry)
	// Bound per-user cache growth: past 1000 entries, keep the newest 500.
	if len(entries) > 1000 {
		entries = entries[len(entries)-500:]
	}
	m.cache[userID] = entries
	return nil
}
// StoreResult persists a task result for later recall. The value is
// JSON-encoded so results of arbitrary shape are stored uniformly as
// strings. An encoding failure is reported to the caller instead of
// silently storing an empty value (the previous behavior discarded the
// json.Marshal error).
func (m *MemoryStore) StoreResult(ctx context.Context, userID, taskID, key string, value interface{}) error {
	valueJSON, err := json.Marshal(value)
	if err != nil {
		return fmt.Errorf("encoding memory value for key %q: %w", key, err)
	}
	entry := &MemoryEntry{
		UserID:    userID,
		TaskID:    taskID,
		Key:       key,
		Value:     string(valueJSON),
		Type:      MemoryTypeResult,
		CreatedAt: time.Now(),
	}
	return m.Store(ctx, userID, entry)
}
// StoreFact records a durable, tag-searchable fact for userID.
func (m *MemoryStore) StoreFact(ctx context.Context, userID, key string, value interface{}, tags []string) error {
	return m.Store(ctx, userID, &MemoryEntry{
		UserID:    userID,
		Key:       key,
		Value:     value,
		Type:      MemoryTypeFact,
		Tags:      tags,
		CreatedAt: time.Now(),
	})
}
// StorePreference records a user preference under key.
func (m *MemoryStore) StorePreference(ctx context.Context, userID, key string, value interface{}) error {
	return m.Store(ctx, userID, &MemoryEntry{
		UserID:    userID,
		Key:       key,
		Value:     value,
		Type:      MemoryTypePreference,
		CreatedAt: time.Now(),
	})
}
// StoreContext records a transient context value for a task. The entry
// expires after ttl and is skipped by Recall / removed by Cleanup.
func (m *MemoryStore) StoreContext(ctx context.Context, userID, taskID, key string, value interface{}, ttl time.Duration) error {
	expiry := time.Now().Add(ttl)
	return m.Store(ctx, userID, &MemoryEntry{
		UserID:    userID,
		TaskID:    taskID,
		Key:       key,
		Value:     value,
		Type:      MemoryTypeContext,
		CreatedAt: time.Now(),
		ExpiresAt: &expiry,
	})
}
// Recall returns up to limit entries relevant to query for userID. It
// prefers the repository's search; when that errors or yields nothing
// (or no repo is configured) it falls back to a keyword scan of the
// cache, scoring each query term: key match 3, tag match 2, value match
// 1. Expired entries are skipped. Results come back best-score first.
func (m *MemoryStore) Recall(ctx context.Context, userID string, query string, limit int) ([]MemoryEntry, error) {
	if m.repo != nil {
		// Best-effort: repo errors fall through to the cache scan.
		entries, err := m.repo.Search(ctx, userID, query, limit)
		if err == nil && len(entries) > 0 {
			return entries, nil
		}
	}
	m.mu.RLock()
	cached := m.cache[userID]
	m.mu.RUnlock()
	if len(cached) == 0 {
		return nil, nil
	}
	queryTerms := strings.Fields(strings.ToLower(query))
	type scored struct {
		entry MemoryEntry
		score int
	}
	var results []scored
	now := time.Now()
	for _, entry := range cached {
		if entry.ExpiresAt != nil && entry.ExpiresAt.Before(now) {
			continue // expired context entries are never recalled
		}
		score := 0
		keyLower := strings.ToLower(entry.Key)
		for _, term := range queryTerms {
			if strings.Contains(keyLower, term) {
				score += 3
			}
		}
		if valueStr, ok := entry.Value.(string); ok {
			valueLower := strings.ToLower(valueStr)
			for _, term := range queryTerms {
				if strings.Contains(valueLower, term) {
					score++
				}
			}
		}
		for _, tag := range entry.Tags {
			tagLower := strings.ToLower(tag)
			for _, term := range queryTerms {
				if strings.Contains(tagLower, term) {
					score += 2
				}
			}
		}
		if score > 0 {
			results = append(results, scored{entry: entry, score: score})
		}
	}
	// Sort best-first; replaces the previous O(n^2) exchange sort.
	sort.Slice(results, func(i, j int) bool {
		return results[i].score > results[j].score
	})
	if len(results) > limit {
		results = results[:limit]
	}
	entries := make([]MemoryEntry, len(results))
	for i, r := range results {
		entries[i] = r.entry
	}
	return entries, nil
}
// GetByUser returns up to limit entries for userID, preferring the
// repository; without one it serves the newest cached entries.
func (m *MemoryStore) GetByUser(ctx context.Context, userID string, limit int) ([]MemoryEntry, error) {
	if m.repo != nil {
		return m.repo.GetByUser(ctx, userID, limit)
	}
	m.mu.RLock()
	defer m.mu.RUnlock()
	cached := m.cache[userID]
	if n := len(cached); n > limit {
		return cached[n-limit:], nil
	}
	return cached, nil
}
// GetByTask returns every entry belonging to taskID, via the repository
// when configured, otherwise by scanning all users' cached entries.
func (m *MemoryStore) GetByTask(ctx context.Context, taskID string) ([]MemoryEntry, error) {
	if m.repo != nil {
		return m.repo.GetByTask(ctx, taskID)
	}
	m.mu.RLock()
	defer m.mu.RUnlock()
	var matches []MemoryEntry
	for _, entries := range m.cache {
		for _, e := range entries {
			if e.TaskID == taskID {
				matches = append(matches, e)
			}
		}
	}
	return matches, nil
}
// GetTaskContext flattens a task's memory entries into a key->value map.
// Later entries overwrite earlier ones for duplicate keys.
func (m *MemoryStore) GetTaskContext(ctx context.Context, taskID string) (map[string]interface{}, error) {
	entries, err := m.GetByTask(ctx, taskID)
	if err != nil {
		return nil, err
	}
	// Renamed from "context" to avoid shadowing the context package.
	out := make(map[string]interface{}, len(entries))
	for _, e := range entries {
		out[e.Key] = e.Value
	}
	return out, nil
}
// GetUserContext builds a key->value view of the user's durable memory:
// only facts and preferences are included; transient context and result
// entries are excluded.
func (m *MemoryStore) GetUserContext(ctx context.Context, userID string) (map[string]interface{}, error) {
	entries, err := m.GetByUser(ctx, userID, 100)
	if err != nil {
		return nil, err
	}
	out := make(map[string]interface{})
	for _, e := range entries {
		switch e.Type {
		case MemoryTypePreference, MemoryTypeFact:
			out[e.Key] = e.Value
		}
	}
	return out, nil
}
// GetPreferences returns the user's preference entries as a key->value map.
func (m *MemoryStore) GetPreferences(ctx context.Context, userID string) (map[string]interface{}, error) {
	entries, err := m.GetByUser(ctx, userID, 100)
	if err != nil {
		return nil, err
	}
	prefs := make(map[string]interface{})
	for _, e := range entries {
		if e.Type != MemoryTypePreference {
			continue
		}
		prefs[e.Key] = e.Value
	}
	return prefs, nil
}
// GetFacts returns the user's fact-type memory entries.
func (m *MemoryStore) GetFacts(ctx context.Context, userID string) ([]MemoryEntry, error) {
	entries, err := m.GetByUser(ctx, userID, 100)
	if err != nil {
		return nil, err
	}
	var facts []MemoryEntry
	for _, e := range entries {
		if e.Type != MemoryTypeFact {
			continue
		}
		facts = append(facts, e)
	}
	return facts, nil
}
// Delete removes the entry with the given ID from both the cache and the
// repository. Previously the cache was only purged when no repository was
// configured, leaving a stale copy that Recall's cache fallback could
// still return after a repo-side delete.
func (m *MemoryStore) Delete(ctx context.Context, id string) error {
	m.mu.Lock()
	for userID, entries := range m.cache {
		for i, e := range entries {
			if e.ID == id {
				m.cache[userID] = append(entries[:i], entries[i+1:]...)
				break
			}
		}
	}
	m.mu.Unlock()
	if m.repo != nil {
		return m.repo.Delete(ctx, id)
	}
	return nil
}
// Clear removes all cached memory entries for userID.
// NOTE(review): only the in-memory cache is cleared — entries in the
// persistent repo (if configured) are left untouched; confirm whether
// that asymmetry is intended.
func (m *MemoryStore) Clear(ctx context.Context, userID string) error {
	m.mu.Lock()
	delete(m.cache, userID)
	m.mu.Unlock()
	return nil
}
// ClearTask drops every cached entry belonging to taskID, across all
// users. The persistent repository is not touched.
func (m *MemoryStore) ClearTask(ctx context.Context, taskID string) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	for userID, entries := range m.cache {
		var kept []MemoryEntry
		for _, e := range entries {
			if e.TaskID != taskID {
				kept = append(kept, e)
			}
		}
		m.cache[userID] = kept
	}
	return nil
}
// Cleanup evicts expired entries from every user's cache.
func (m *MemoryStore) Cleanup(ctx context.Context) error {
	now := time.Now()
	m.mu.Lock()
	defer m.mu.Unlock()
	for userID, entries := range m.cache {
		var alive []MemoryEntry
		for _, e := range entries {
			if e.ExpiresAt != nil && !e.ExpiresAt.After(now) {
				continue // expired
			}
			alive = append(alive, e)
		}
		m.cache[userID] = alive
	}
	return nil
}
// Stats returns per-type counts of the user's cached entries (cache
// only; the repository is not consulted).
func (m *MemoryStore) Stats(userID string) map[string]int {
	m.mu.RLock()
	entries := m.cache[userID]
	m.mu.RUnlock()
	counts := map[string]int{
		"total":       len(entries),
		"facts":       0,
		"preferences": 0,
		"context":     0,
		"results":     0,
	}
	for _, e := range entries {
		switch e.Type {
		case MemoryTypeFact:
			counts["facts"]++
		case MemoryTypePreference:
			counts["preferences"]++
		case MemoryTypeContext:
			counts["context"]++
		case MemoryTypeResult:
			counts["results"]++
		}
	}
	return counts
}

View File

@@ -0,0 +1,371 @@
package computer
import (
"context"
"encoding/json"
"fmt"
"regexp"
"strings"
"github.com/gooseek/backend/internal/llm"
"github.com/google/uuid"
)
// Planner turns a natural-language query into a structured TaskPlan
// using the best available model from the registry.
type Planner struct {
	registry *llm.ModelRegistry
}
// NewPlanner returns a Planner backed by the given model registry.
func NewPlanner(registry *llm.ModelRegistry) *Planner {
	return &Planner{
		registry: registry,
	}
}
// Plan asks an LLM to decompose query into a TaskPlan of 3-10 subtasks
// with dependencies and required capabilities, folding optional user
// memory into the prompt. Planning is best-effort: model or parse
// failures degrade to a heuristic default plan rather than erroring.
func (p *Planner) Plan(ctx context.Context, query string, memory map[string]interface{}) (*TaskPlan, error) {
	// Prefer a reasoning-capable model; fall back to a coding model.
	client, _, err := p.registry.GetBest(llm.CapReasoning)
	if err != nil {
		client, _, err = p.registry.GetBest(llm.CapCoding)
		if err != nil {
			return nil, fmt.Errorf("no suitable model for planning: %w", err)
		}
	}
	memoryContext := ""
	if len(memory) > 0 {
		memoryJSON, _ := json.Marshal(memory)
		memoryContext = fmt.Sprintf("\n\nUser context and memory:\n%s", string(memoryJSON))
	}
	prompt := fmt.Sprintf(`You are a task planning AI. Analyze this query and create an execution plan.
Query: %s%s
Break this into subtasks. Each subtask should be:
1. Atomic - one clear action
2. Independent where possible (for parallel execution)
3. Have clear dependencies when needed
Available task types:
- research: Search web, gather information
- code: Write/generate code
- analysis: Analyze data, extract insights
- design: Design architecture, create plans
- deploy: Deploy applications, run code
- monitor: Set up monitoring, tracking
- report: Generate reports, summaries
- communicate: Send emails, messages
- transform: Convert data formats
- validate: Check, verify results
For each subtask specify:
- type: one of the task types above
- description: what to do
- dependencies: list of subtask IDs this depends on (empty if none)
- capabilities: required AI capabilities (reasoning, coding, search, creative, fast, long_context, vision, math)
Respond in JSON format:
{
"summary": "Brief summary of the plan",
"subtasks": [
{
"id": "1",
"type": "research",
"description": "Search for...",
"dependencies": [],
"capabilities": ["search"]
},
{
"id": "2",
"type": "code",
"description": "Write code to...",
"dependencies": ["1"],
"capabilities": ["coding"]
}
],
"estimatedCost": 0.05,
"estimatedTimeSeconds": 120
}
Create 3-10 subtasks. Be specific and actionable.`, query, memoryContext)
	messages := []llm.Message{
		{Role: llm.RoleUser, Content: prompt},
	}
	response, err := client.GenerateText(ctx, llm.StreamRequest{
		Messages: messages,
		Options:  llm.StreamOptions{MaxTokens: 4096},
	})
	if err != nil {
		// Best-effort: a model failure falls back to the keyword-driven plan.
		return p.createDefaultPlan(query), nil
	}
	plan, err := p.parsePlanResponse(response)
	if err != nil {
		// Unparseable reply also falls back rather than failing the request.
		return p.createDefaultPlan(query), nil
	}
	plan.Query = query
	plan.ExecutionOrder = p.calculateExecutionOrder(plan.SubTasks)
	return plan, nil
}
// parsePlanResponse extracts the first JSON object from an LLM reply and
// converts it into a TaskPlan. Parsed subtasks start pending with a
// retry budget of 3.
func (p *Planner) parsePlanResponse(response string) (*TaskPlan, error) {
	match := regexp.MustCompile(`\{[\s\S]*\}`).FindString(response)
	if match == "" {
		return nil, fmt.Errorf("no JSON found in response")
	}
	var decoded struct {
		Summary              string  `json:"summary"`
		EstimatedCost        float64 `json:"estimatedCost"`
		EstimatedTimeSeconds int     `json:"estimatedTimeSeconds"`
		SubTasks             []struct {
			ID           string   `json:"id"`
			Type         string   `json:"type"`
			Description  string   `json:"description"`
			Dependencies []string `json:"dependencies"`
			Capabilities []string `json:"capabilities"`
		} `json:"subtasks"`
	}
	if err := json.Unmarshal([]byte(match), &decoded); err != nil {
		return nil, fmt.Errorf("failed to parse plan JSON: %w", err)
	}
	plan := &TaskPlan{
		Summary:       decoded.Summary,
		EstimatedCost: decoded.EstimatedCost,
		EstimatedTime: decoded.EstimatedTimeSeconds,
		SubTasks:      make([]SubTask, len(decoded.SubTasks)),
	}
	for i, st := range decoded.SubTasks {
		caps := make([]llm.ModelCapability, len(st.Capabilities))
		for j, c := range st.Capabilities {
			caps[j] = llm.ModelCapability(c)
		}
		plan.SubTasks[i] = SubTask{
			ID:           st.ID,
			Type:         TaskType(st.Type),
			Description:  st.Description,
			Dependencies: st.Dependencies,
			RequiredCaps: caps,
			Status:       StatusPending,
			MaxRetries:   3,
		}
	}
	return plan, nil
}
// calculateExecutionOrder groups subtask IDs into sequential "waves":
// every task in a wave has all of its dependencies satisfied by earlier
// waves, so tasks within a wave may run in parallel. When no task is
// runnable (a dependency cycle or a dangling dependency ID), all
// remaining tasks are forced into one final wave so ordering never
// hangs. The previous version also built a task map and an in-degree
// table that were never read; that dead work is removed.
func (p *Planner) calculateExecutionOrder(subTasks []SubTask) [][]string {
	var order [][]string
	completed := make(map[string]bool, len(subTasks))
	for len(completed) < len(subTasks) {
		var wave []string
		for _, st := range subTasks {
			if completed[st.ID] {
				continue
			}
			ready := true
			for _, dep := range st.Dependencies {
				if !completed[dep] {
					ready = false
					break
				}
			}
			if ready {
				wave = append(wave, st.ID)
			}
		}
		if len(wave) == 0 {
			// Cycle or unknown dependency: flush everything that is left.
			for _, st := range subTasks {
				if !completed[st.ID] {
					wave = append(wave, st.ID)
				}
			}
		}
		for _, id := range wave {
			completed[id] = true
		}
		order = append(order, wave)
	}
	return order
}
// createDefaultPlan builds a heuristic fallback plan when LLM planning
// fails: always a research step, then optional design+code, analysis+
// report, and communicate steps, each triggered by keywords (English and
// Russian) found in the query. Steps are chained onto whichever task was
// appended last.
func (p *Planner) createDefaultPlan(query string) *TaskPlan {
	queryLower := strings.ToLower(query)
	// Every plan starts with a research task; later steps depend on it
	// (or on each other) by positional index into subTasks.
	subTasks := []SubTask{
		{
			ID:           uuid.New().String(),
			Type:         TaskResearch,
			Description:  "Research and gather information about: " + query,
			Dependencies: []string{},
			RequiredCaps: []llm.ModelCapability{llm.CapSearch},
			Status:       StatusPending,
			MaxRetries:   3,
		},
	}
	// Coding-flavored queries get a design step then a code step.
	if strings.Contains(queryLower, "код") || strings.Contains(queryLower, "code") ||
		strings.Contains(queryLower, "приложение") || strings.Contains(queryLower, "app") ||
		strings.Contains(queryLower, "скрипт") || strings.Contains(queryLower, "script") {
		subTasks = append(subTasks, SubTask{
			ID:           uuid.New().String(),
			Type:         TaskDesign,
			Description:  "Design architecture and structure",
			Dependencies: []string{subTasks[0].ID},
			RequiredCaps: []llm.ModelCapability{llm.CapReasoning},
			Status:       StatusPending,
			MaxRetries:   3,
		})
		// Index 1 is the design task appended just above.
		subTasks = append(subTasks, SubTask{
			ID:           uuid.New().String(),
			Type:         TaskCode,
			Description:  "Generate code implementation",
			Dependencies: []string{subTasks[1].ID},
			RequiredCaps: []llm.ModelCapability{llm.CapCoding},
			Status:       StatusPending,
			MaxRetries:   3,
		})
	}
	// Report/analysis queries get an analysis step then a report step.
	if strings.Contains(queryLower, "отчёт") || strings.Contains(queryLower, "report") ||
		strings.Contains(queryLower, "анализ") || strings.Contains(queryLower, "analysis") {
		subTasks = append(subTasks, SubTask{
			ID:           uuid.New().String(),
			Type:         TaskAnalysis,
			Description:  "Analyze gathered information",
			Dependencies: []string{subTasks[0].ID},
			RequiredCaps: []llm.ModelCapability{llm.CapReasoning},
			Status:       StatusPending,
			MaxRetries:   3,
		})
		subTasks = append(subTasks, SubTask{
			ID:           uuid.New().String(),
			Type:         TaskReport,
			Description:  "Generate comprehensive report",
			Dependencies: []string{subTasks[len(subTasks)-1].ID},
			RequiredCaps: []llm.ModelCapability{llm.CapCreative},
			Status:       StatusPending,
			MaxRetries:   3,
		})
	}
	// Messaging-flavored queries append a final communicate step.
	if strings.Contains(queryLower, "email") || strings.Contains(queryLower, "письмо") ||
		strings.Contains(queryLower, "telegram") || strings.Contains(queryLower, "отправ") {
		subTasks = append(subTasks, SubTask{
			ID:           uuid.New().String(),
			Type:         TaskCommunicate,
			Description:  "Send notification/message",
			Dependencies: []string{subTasks[len(subTasks)-1].ID},
			RequiredCaps: []llm.ModelCapability{llm.CapFast},
			Status:       StatusPending,
			MaxRetries:   3,
		})
	}
	// Flat estimates: $0.01 and 30 seconds per subtask.
	plan := &TaskPlan{
		Query:         query,
		Summary:       "Auto-generated plan for: " + query,
		SubTasks:      subTasks,
		EstimatedCost: float64(len(subTasks)) * 0.01,
		EstimatedTime: len(subTasks) * 30,
	}
	plan.ExecutionOrder = p.calculateExecutionOrder(subTasks)
	return plan
}
// Replan adjusts an in-flight plan based on new context/feedback. It
// splits subtasks into completed vs pending/failed, asks a reasoning
// model to revise the pending portion, and recomputes the execution
// order. Every failure path (no model, generation error, unparseable
// reply) returns the original plan unchanged rather than an error.
func (p *Planner) Replan(ctx context.Context, plan *TaskPlan, newContext string) (*TaskPlan, error) {
	completedTasks := make([]SubTask, 0)
	pendingTasks := make([]SubTask, 0)
	for _, st := range plan.SubTasks {
		if st.Status == StatusCompleted {
			completedTasks = append(completedTasks, st)
		} else if st.Status == StatusPending || st.Status == StatusFailed {
			pendingTasks = append(pendingTasks, st)
		}
	}
	completedJSON, _ := json.Marshal(completedTasks)
	pendingJSON, _ := json.Marshal(pendingTasks)
	client, _, err := p.registry.GetBest(llm.CapReasoning)
	if err != nil {
		return plan, nil
	}
	prompt := fmt.Sprintf(`You need to replan a task based on new context.
Original query: %s
Completed subtasks:
%s
Pending subtasks:
%s
New context/feedback:
%s
Adjust the plan. Keep completed tasks, modify or remove pending tasks as needed.
Add new subtasks if the new context requires it.
Respond in the same JSON format as before.`, plan.Query, string(completedJSON), string(pendingJSON), newContext)
	messages := []llm.Message{
		{Role: llm.RoleUser, Content: prompt},
	}
	response, err := client.GenerateText(ctx, llm.StreamRequest{
		Messages: messages,
		Options:  llm.StreamOptions{MaxTokens: 4096},
	})
	if err != nil {
		return plan, nil
	}
	newPlan, err := p.parsePlanResponse(response)
	if err != nil {
		return plan, nil
	}
	newPlan.Query = plan.Query
	newPlan.ExecutionOrder = p.calculateExecutionOrder(newPlan.SubTasks)
	return newPlan, nil
}

View File

@@ -0,0 +1,244 @@
package computer
import (
"errors"
"sort"
"github.com/gooseek/backend/internal/llm"
)
// RoutingRule describes how to pick a model for one task type.
type RoutingRule struct {
	TaskType   TaskType              // task type this rule applies to
	Preferred  []llm.ModelCapability // capabilities to try first, in order
	Fallback   []string              // explicit model IDs tried after capabilities
	MaxCost    float64               // NOTE(review): Route uses the caller's budget, not this field — confirm intent
	MaxLatency int                   // NOTE(review): never read in this file — confirm intent
}
// Router selects LLM clients for subtasks based on per-task-type rules
// and a cost budget.
type Router struct {
	registry *llm.ModelRegistry
	rules    map[TaskType]RoutingRule // keyed by task type; defaults set in NewRouter
}
// NewRouter builds a Router over registry with a default routing table:
// for each task type, the preferred model capabilities, named fallback
// model IDs, and a per-call cost ceiling.
func NewRouter(registry *llm.ModelRegistry) *Router {
	defaults := []RoutingRule{
		{TaskType: TaskResearch, Preferred: []llm.ModelCapability{llm.CapSearch, llm.CapLongContext}, Fallback: []string{"gemini-1.5-pro", "gpt-4o"}, MaxCost: 0.1},
		{TaskType: TaskCode, Preferred: []llm.ModelCapability{llm.CapCoding}, Fallback: []string{"claude-3-sonnet", "claude-3-opus", "gpt-4o"}, MaxCost: 0.2},
		{TaskType: TaskAnalysis, Preferred: []llm.ModelCapability{llm.CapReasoning, llm.CapMath}, Fallback: []string{"claude-3-opus", "gpt-4o"}, MaxCost: 0.15},
		{TaskType: TaskDesign, Preferred: []llm.ModelCapability{llm.CapReasoning, llm.CapCreative}, Fallback: []string{"claude-3-opus", "gpt-4o"}, MaxCost: 0.15},
		{TaskType: TaskDeploy, Preferred: []llm.ModelCapability{llm.CapCoding, llm.CapFast}, Fallback: []string{"claude-3-sonnet", "gpt-4o-mini"}, MaxCost: 0.05},
		{TaskType: TaskMonitor, Preferred: []llm.ModelCapability{llm.CapFast}, Fallback: []string{"gpt-4o-mini", "gemini-1.5-flash"}, MaxCost: 0.02},
		{TaskType: TaskReport, Preferred: []llm.ModelCapability{llm.CapCreative, llm.CapLongContext}, Fallback: []string{"claude-3-opus", "gpt-4o"}, MaxCost: 0.1},
		{TaskType: TaskCommunicate, Preferred: []llm.ModelCapability{llm.CapFast, llm.CapCreative}, Fallback: []string{"gpt-4o-mini", "gemini-1.5-flash"}, MaxCost: 0.02},
		{TaskType: TaskTransform, Preferred: []llm.ModelCapability{llm.CapFast, llm.CapCoding}, Fallback: []string{"gpt-4o-mini", "claude-3-sonnet"}, MaxCost: 0.03},
		{TaskType: TaskValidate, Preferred: []llm.ModelCapability{llm.CapReasoning}, Fallback: []string{"gpt-4o", "claude-3-sonnet"}, MaxCost: 0.05},
	}
	r := &Router{
		registry: registry,
		rules:    make(map[TaskType]RoutingRule, len(defaults)),
	}
	for _, rule := range defaults {
		r.rules[rule.TaskType] = rule
	}
	return r
}
// Route picks an LLM client for task within budget. Preference order:
// the task's pinned model, its required capabilities, the task type's
// rule (preferred capabilities, then named fallbacks), the cheapest
// affordable model, and finally the cheapest model regardless of budget.
func (r *Router) Route(task *SubTask, budget float64) (llm.Client, llm.ModelSpec, error) {
	affordable := func(spec llm.ModelSpec) bool { return spec.CostPer1K <= budget }
	// 1. Honor an explicitly pinned model when it fits the budget.
	if task.ModelID != "" {
		if client, spec, err := r.registry.GetByID(task.ModelID); err == nil && affordable(spec) {
			return client, spec, nil
		}
	}
	// 2. Try the task's own required capabilities.
	for _, capability := range task.RequiredCaps {
		if client, spec, err := r.registry.GetBest(capability); err == nil && affordable(spec) {
			return client, spec, nil
		}
	}
	// 3. Apply the task type's routing rule.
	if rule, ok := r.rules[task.Type]; ok {
		for _, capability := range rule.Preferred {
			if client, spec, err := r.registry.GetBest(capability); err == nil && affordable(spec) {
				return client, spec, nil
			}
		}
		for _, modelID := range rule.Fallback {
			if client, spec, err := r.registry.GetByID(modelID); err == nil && affordable(spec) {
				return client, spec, nil
			}
		}
	}
	// 4. Fall back to the cheapest model that fits the budget.
	models := r.registry.GetAll()
	if len(models) == 0 {
		return nil, llm.ModelSpec{}, errors.New("no models available")
	}
	sort.Slice(models, func(i, j int) bool {
		return models[i].CostPer1K < models[j].CostPer1K
	})
	for _, spec := range models {
		if !affordable(spec) {
			continue
		}
		if client, err := r.registry.GetClient(spec.ID); err == nil {
			return client, spec, nil
		}
	}
	// 5. Last resort: cheapest model even if it exceeds the budget.
	client, err := r.registry.GetClient(models[0].ID)
	if err != nil {
		return nil, llm.ModelSpec{}, err
	}
	return client, models[0], nil
}
// RouteMultiple selects up to count distinct models for a consensus-style
// run of task, splitting budget evenly per model. It first draws models
// matching the task type's preferred capabilities within the per-model
// budget, then pads with any remaining models (budget not enforced on
// the padding pass) until count is reached. Errors only when no model at
// all can be obtained.
func (r *Router) RouteMultiple(task *SubTask, count int, budget float64) ([]llm.Client, []llm.ModelSpec, error) {
	var clients []llm.Client
	var specs []llm.ModelSpec
	usedModels := make(map[string]bool)
	perModelBudget := budget / float64(count)
	rule, ok := r.rules[task.Type]
	if !ok {
		// No specific rule: use a broad default capability mix.
		rule = RoutingRule{
			Preferred: []llm.ModelCapability{llm.CapReasoning, llm.CapCoding, llm.CapFast},
		}
	}
	// Pass 1: capability-matched models within the per-model budget.
	for _, cap := range rule.Preferred {
		if len(clients) >= count {
			break
		}
		models := r.registry.GetAllWithCapability(cap)
		for _, spec := range models {
			if len(clients) >= count {
				break
			}
			if usedModels[spec.ID] {
				continue
			}
			if spec.CostPer1K > perModelBudget {
				continue
			}
			client, err := r.registry.GetClient(spec.ID)
			if err == nil {
				clients = append(clients, client)
				specs = append(specs, spec)
				usedModels[spec.ID] = true
			}
		}
	}
	// Pass 2: pad with any remaining models, ignoring cost.
	if len(clients) < count {
		models := r.registry.GetAll()
		for _, spec := range models {
			if len(clients) >= count {
				break
			}
			if usedModels[spec.ID] {
				continue
			}
			client, err := r.registry.GetClient(spec.ID)
			if err == nil {
				clients = append(clients, client)
				specs = append(specs, spec)
				usedModels[spec.ID] = true
			}
		}
	}
	if len(clients) == 0 {
		return nil, nil, errors.New("no models available for consensus")
	}
	return clients, specs, nil
}
// SetRule overrides the routing rule for taskType.
// NOTE(review): the rules map is not synchronized — callers appear
// expected to configure rules before routing begins; confirm there is
// no concurrent mutation.
func (r *Router) SetRule(taskType TaskType, rule RoutingRule) {
	r.rules[taskType] = rule
}
// GetRule returns the routing rule for taskType and whether one exists.
func (r *Router) GetRule(taskType TaskType) (RoutingRule, bool) {
	rule, ok := r.rules[taskType]
	return rule, ok
}
// EstimateCost approximates the USD cost of running task with the given
// token counts, priced by whichever model Route selects under a generous
// 1.0 budget. Falls back to a flat $0.01 when routing fails.
func (r *Router) EstimateCost(task *SubTask, inputTokens, outputTokens int) float64 {
	_, spec, err := r.Route(task, 1.0)
	if err != nil {
		return 0.01
	}
	totalTokens := inputTokens + outputTokens
	// CostPer1K is priced per 1000 tokens.
	return spec.CostPer1K * float64(totalTokens) / 1000.0
}

View File

@@ -0,0 +1,431 @@
package computer
import (
"bytes"
"context"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"time"
"github.com/google/uuid"
)
// SandboxConfig controls how code-execution sandboxes are provisioned.
type SandboxConfig struct {
	Image        string        // Docker image to run
	Timeout      time.Duration // cap on sandbox execution time
	MemoryLimit  string        // docker --memory value, e.g. "512m"
	CPULimit     string        // docker --cpus value, e.g. "1.0"
	NetworkMode  string        // NOTE(review): Create only consults AllowNetwork — confirm this field is used elsewhere
	WorkDir      string        // working directory inside the container
	MaxFileSize  int64         // max artifact file size, in bytes
	AllowNetwork bool          // when false, the container gets --network none
}
// DefaultSandboxConfig returns conservative sandbox defaults: no network,
// modest CPU/memory limits, a 5-minute timeout, and a 10 MiB file cap.
func DefaultSandboxConfig() SandboxConfig {
	cfg := SandboxConfig{
		Image:        "gooseek/sandbox:latest",
		WorkDir:      "/workspace",
		Timeout:      5 * time.Minute,
		MemoryLimit:  "512m",
		CPULimit:     "1.0",
		NetworkMode:  "none",
		MaxFileSize:  10 * 1024 * 1024,
		AllowNetwork: false,
	}
	return cfg
}
// Sandbox is one isolated workspace: a Docker container when Docker is
// available, otherwise just a host temp directory.
type Sandbox struct {
	ID          string // short random identifier
	ContainerID string // Docker container ID; empty in the non-Docker fallback
	WorkDir     string // host directory used (or bind-mounted) as the workspace
	Status      string // "creating", then "running"
	TaskID      string // owning task
	CreatedAt   time.Time
}
// SandboxManager creates, tracks, and tears down sandboxes.
type SandboxManager struct {
	cfg       SandboxConfig
	sandboxes map[string]*Sandbox // live sandboxes keyed by sandbox ID
	mu        sync.RWMutex        // guards sandboxes
	useDocker bool                // probed once at construction
}
// NewSandboxManager builds a manager, filling in any zero-valued config
// fields with the defaults from DefaultSandboxConfig, and probing once
// whether a local Docker daemon is usable.
func NewSandboxManager(cfg SandboxConfig) *SandboxManager {
	if cfg.Timeout == 0 {
		cfg.Timeout = 5 * time.Minute
	}
	if cfg.MemoryLimit == "" {
		cfg.MemoryLimit = "512m"
	}
	// Fix: CPULimit, Image, and MaxFileSize were previously not defaulted,
	// so a zero-value config produced `docker create --cpus ""` with an
	// empty image (fails at creation) and a 0-byte file cap (rejects every
	// write in WriteFile and skips every file in collectOutputFiles).
	if cfg.CPULimit == "" {
		cfg.CPULimit = "1.0"
	}
	if cfg.Image == "" {
		cfg.Image = "gooseek/sandbox:latest"
	}
	if cfg.WorkDir == "" {
		cfg.WorkDir = "/workspace"
	}
	if cfg.MaxFileSize == 0 {
		cfg.MaxFileSize = 10 * 1024 * 1024
	}
	return &SandboxManager{
		cfg:       cfg,
		sandboxes: make(map[string]*Sandbox),
		useDocker: isDockerAvailable(),
	}
}
// isDockerAvailable probes the local Docker daemon; any failure of
// `docker version` is treated as "not available".
func isDockerAvailable() bool {
	return exec.Command("docker", "version").Run() == nil
}
// Create provisions a new sandbox for taskID. With Docker available it
// creates and starts a long-lived container (kept alive by `tail -f
// /dev/null`) with a fresh host temp directory bind-mounted as the work
// directory; without Docker it falls back to the bare temp directory.
// The sandbox is registered in sm.sandboxes before returning.
func (sm *SandboxManager) Create(ctx context.Context, taskID string) (*Sandbox, error) {
	// Short random ID keeps container and directory names readable.
	sandboxID := uuid.New().String()[:8]
	sandbox := &Sandbox{
		ID: sandboxID,
		TaskID: taskID,
		Status: "creating",
		CreatedAt: time.Now(),
	}
	if sm.useDocker {
		workDir, err := os.MkdirTemp("", fmt.Sprintf("sandbox-%s-", sandboxID))
		if err != nil {
			return nil, fmt.Errorf("failed to create temp dir: %w", err)
		}
		sandbox.WorkDir = workDir
		args := []string{
			"create",
			"--name", fmt.Sprintf("gooseek-sandbox-%s", sandboxID),
			"-v", fmt.Sprintf("%s:%s", workDir, sm.cfg.WorkDir),
			"-w", sm.cfg.WorkDir,
			"--memory", sm.cfg.MemoryLimit,
			"--cpus", sm.cfg.CPULimit,
		}
		if !sm.cfg.AllowNetwork {
			args = append(args, "--network", "none")
		}
		args = append(args, sm.cfg.Image, "tail", "-f", "/dev/null")
		cmd := exec.CommandContext(ctx, "docker", args...)
		// NOTE(review): the container ID is parsed from CombinedOutput, so
		// any warning Docker prints to stderr would corrupt it — confirm
		// whether stdout-only capture is needed here.
		output, err := cmd.CombinedOutput()
		if err != nil {
			// Creation failed: don't leak the temp directory.
			os.RemoveAll(workDir)
			return nil, fmt.Errorf("failed to create container: %w - %s", err, string(output))
		}
		sandbox.ContainerID = strings.TrimSpace(string(output))
		startCmd := exec.CommandContext(ctx, "docker", "start", sandbox.ContainerID)
		if err := startCmd.Run(); err != nil {
			// Start failed: remove the half-created container.
			sm.cleanupContainer(sandbox)
			return nil, fmt.Errorf("failed to start container: %w", err)
		}
	} else {
		// Fallback mode: no isolation beyond a dedicated scratch directory.
		workDir, err := os.MkdirTemp("", fmt.Sprintf("sandbox-%s-", sandboxID))
		if err != nil {
			return nil, fmt.Errorf("failed to create temp dir: %w", err)
		}
		sandbox.WorkDir = workDir
	}
	sandbox.Status = "running"
	sm.mu.Lock()
	sm.sandboxes[sandboxID] = sandbox
	sm.mu.Unlock()
	return sandbox, nil
}
// Execute writes code into the sandbox, runs it with the interpreter for
// lang, and returns captured stdout/stderr, the exit code, any produced
// output files, and the wall-clock duration. A timeout is reported as a
// result with ExitCode -1 rather than as an error.
func (sm *SandboxManager) Execute(ctx context.Context, sandbox *Sandbox, code string, lang string) (*SandboxResult, error) {
	ctx, cancel := context.WithTimeout(ctx, sm.cfg.Timeout)
	defer cancel()
	startTime := time.Now()
	filename, err := sm.writeCodeFile(sandbox, code, lang)
	if err != nil {
		return nil, err
	}
	var cmd *exec.Cmd
	var stdout, stderr bytes.Buffer
	if sm.useDocker {
		runCmd := sm.getRunCommand(lang, filename)
		cmd = exec.CommandContext(ctx, "docker", "exec", sandbox.ContainerID, "sh", "-c", runCmd)
	} else {
		cmd = sm.getLocalCommand(ctx, lang, filepath.Join(sandbox.WorkDir, filename))
	}
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr
	err = cmd.Run()
	exitCode := 0
	if err != nil {
		if exitErr, ok := err.(*exec.ExitError); ok {
			exitCode = exitErr.ExitCode()
		} else if ctx.Err() == context.DeadlineExceeded {
			return &SandboxResult{
				Stderr:   "Execution timeout exceeded",
				ExitCode: -1,
				Duration: time.Since(startTime),
			}, nil
		} else {
			// Fix: failures to even start the process (missing docker or
			// interpreter binary, permission errors) were previously
			// swallowed and reported as a successful run with exit code 0.
			return nil, fmt.Errorf("sandbox execution: %w", err)
		}
	}
	// Output files are best-effort; a collection error must not mask the
	// run's own result.
	files, _ := sm.collectOutputFiles(sandbox)
	return &SandboxResult{
		Stdout:   stdout.String(),
		Stderr:   stderr.String(),
		ExitCode: exitCode,
		Files:    files,
		Duration: time.Since(startTime),
	}, nil
}
// RunCommand runs an arbitrary shell command in the sandbox — via
// `docker exec` when containerized, or directly in the work directory
// otherwise — and captures its output.
func (sm *SandboxManager) RunCommand(ctx context.Context, sandbox *Sandbox, command string) (*SandboxResult, error) {
	ctx, cancel := context.WithTimeout(ctx, sm.cfg.Timeout)
	defer cancel()
	startTime := time.Now()
	var cmd *exec.Cmd
	var stdout, stderr bytes.Buffer
	if sm.useDocker {
		cmd = exec.CommandContext(ctx, "docker", "exec", sandbox.ContainerID, "sh", "-c", command)
	} else {
		cmd = exec.CommandContext(ctx, "sh", "-c", command)
		cmd.Dir = sandbox.WorkDir
	}
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr
	err := cmd.Run()
	exitCode := 0
	if err != nil {
		if exitErr, ok := err.(*exec.ExitError); ok {
			exitCode = exitErr.ExitCode()
		} else if ctx.Err() == context.DeadlineExceeded {
			// Consistent with Execute: a timeout is a result, not an error.
			return &SandboxResult{
				Stderr:   "Execution timeout exceeded",
				ExitCode: -1,
				Duration: time.Since(startTime),
			}, nil
		} else {
			// Fix: start-up failures (e.g. missing docker/sh binary) were
			// previously swallowed and reported as exit code 0.
			return nil, fmt.Errorf("sandbox command: %w", err)
		}
	}
	return &SandboxResult{
		Stdout:   stdout.String(),
		Stderr:   stderr.String(),
		ExitCode: exitCode,
		Duration: time.Since(startTime),
	}, nil
}
// WriteFile stores content at path (relative to the sandbox work
// directory), creating parent directories as needed. It enforces the
// configured file size limit and rejects paths that would escape the
// sandbox.
func (sm *SandboxManager) WriteFile(ctx context.Context, sandbox *Sandbox, path string, content []byte) error {
	if int64(len(content)) > sm.cfg.MaxFileSize {
		return fmt.Errorf("file size exceeds limit: %d > %d", len(content), sm.cfg.MaxFileSize)
	}
	fullPath := filepath.Join(sandbox.WorkDir, path)
	// Security fix: a path containing "../" previously resolved outside the
	// sandbox work directory, allowing writes anywhere on the host.
	if rel, err := filepath.Rel(sandbox.WorkDir, fullPath); err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) {
		return fmt.Errorf("path escapes sandbox: %q", path)
	}
	dir := filepath.Dir(fullPath)
	if err := os.MkdirAll(dir, 0755); err != nil {
		return fmt.Errorf("failed to create directory: %w", err)
	}
	return os.WriteFile(fullPath, content, 0644)
}
// ReadFile returns the contents of path relative to the sandbox work
// directory, rejecting paths that would escape it.
func (sm *SandboxManager) ReadFile(ctx context.Context, sandbox *Sandbox, path string) ([]byte, error) {
	fullPath := filepath.Join(sandbox.WorkDir, path)
	// Security fix: mirror WriteFile — "../" traversal previously allowed
	// reading arbitrary host files.
	if rel, err := filepath.Rel(sandbox.WorkDir, fullPath); err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) {
		return nil, fmt.Errorf("path escapes sandbox: %q", path)
	}
	return os.ReadFile(fullPath)
}
// Destroy tears a sandbox down: it is dropped from the registry, its
// container (if any) is stopped and removed, and its scratch directory is
// wiped. Destroy always reports success.
func (sm *SandboxManager) Destroy(ctx context.Context, sandbox *Sandbox) error {
	sm.mu.Lock()
	delete(sm.sandboxes, sandbox.ID)
	sm.mu.Unlock()
	hasContainer := sm.useDocker && sandbox.ContainerID != ""
	if hasContainer {
		sm.cleanupContainer(sandbox)
	}
	if dir := sandbox.WorkDir; dir != "" {
		os.RemoveAll(dir)
	}
	return nil
}
// cleanupContainer force-stops and removes the sandbox's container.
// Errors are deliberately ignored: cleanup is best-effort and the
// container may already be gone.
func (sm *SandboxManager) cleanupContainer(sandbox *Sandbox) {
	exec.Command("docker", "stop", sandbox.ContainerID).Run()
	exec.Command("docker", "rm", "-f", sandbox.ContainerID).Run()
}
// writeCodeFile materializes code in the sandbox work directory under the
// conventional filename for lang and returns that filename.
func (sm *SandboxManager) writeCodeFile(sandbox *Sandbox, code string, lang string) (string, error) {
	name := codeFilename(lang)
	full := filepath.Join(sandbox.WorkDir, name)
	if err := os.WriteFile(full, []byte(code), 0755); err != nil {
		return "", fmt.Errorf("failed to write code file: %w", err)
	}
	return name, nil
}

// codeFilename maps a language tag to the source filename Execute expects.
func codeFilename(lang string) string {
	switch lang {
	case "python", "py":
		return "main.py"
	case "javascript", "js", "node":
		return "main.js"
	case "typescript", "ts":
		return "main.ts"
	case "go", "golang":
		return "main.go"
	case "bash", "sh", "shell":
		return "script.sh"
	case "ruby", "rb":
		return "main.rb"
	default:
		return "main.txt"
	}
}
// getRunCommand returns the shell command that executes the previously
// written source file inside the container.
func (sm *SandboxManager) getRunCommand(lang, filename string) string {
	target := sm.cfg.WorkDir + "/" + filename
	var runner string
	switch lang {
	case "python", "py":
		runner = "python3"
	case "javascript", "js", "node":
		runner = "node"
	case "typescript", "ts":
		runner = "npx ts-node"
	case "go", "golang":
		runner = "go run"
	case "bash", "sh", "shell":
		runner = "bash"
	case "ruby", "rb":
		runner = "ruby"
	default:
		// Unknown languages are just echoed back.
		runner = "cat"
	}
	return runner + " " + target
}
// getLocalCommand builds the fallback (no-Docker) command that runs the
// source file at srcPath directly on the host.
func (sm *SandboxManager) getLocalCommand(ctx context.Context, lang, srcPath string) *exec.Cmd {
	var argv []string
	switch lang {
	case "python", "py":
		argv = []string{"python3", srcPath}
	case "javascript", "js", "node":
		argv = []string{"node", srcPath}
	case "go", "golang":
		argv = []string{"go", "run", srcPath}
	case "bash", "sh", "shell":
		argv = []string{"bash", srcPath}
	case "ruby", "rb":
		argv = []string{"ruby", srcPath}
	default:
		argv = []string{"cat", srcPath}
	}
	return exec.CommandContext(ctx, argv[0], argv[1:]...)
}
// collectOutputFiles walks the sandbox work directory and returns every
// regular file that is not one of our own generated entry points, fits
// under the configured size cap, and is readable. Problems with individual
// entries are skipped so one bad file cannot hide the rest.
func (sm *SandboxManager) collectOutputFiles(sandbox *Sandbox) (map[string][]byte, error) {
	out := make(map[string][]byte)
	walkErr := filepath.Walk(sandbox.WorkDir, func(p string, info os.FileInfo, err error) error {
		if err != nil || info.IsDir() {
			return nil
		}
		rel, relErr := filepath.Rel(sandbox.WorkDir, p)
		if relErr != nil {
			return nil
		}
		// Skip the source/script files written by writeCodeFile.
		if strings.HasPrefix(rel, "main.") || strings.HasPrefix(rel, "script.") {
			return nil
		}
		if info.Size() > sm.cfg.MaxFileSize {
			return nil
		}
		if data, readErr := os.ReadFile(p); readErr == nil {
			out[rel] = data
		}
		return nil
	})
	return out, walkErr
}
// ListSandboxes returns a snapshot slice of all currently tracked sandboxes.
func (sm *SandboxManager) ListSandboxes() []*Sandbox {
	sm.mu.RLock()
	defer sm.mu.RUnlock()
	snapshot := make([]*Sandbox, 0, len(sm.sandboxes))
	for _, sb := range sm.sandboxes {
		snapshot = append(snapshot, sb)
	}
	return snapshot
}
// GetSandbox looks up a sandbox by ID; the second result reports existence.
func (sm *SandboxManager) GetSandbox(id string) (*Sandbox, bool) {
	sm.mu.RLock()
	sb, ok := sm.sandboxes[id]
	sm.mu.RUnlock()
	return sb, ok
}
// CopyToContainer places the host file src at dst inside the sandbox,
// using `docker cp` when a container backs the sandbox and a checked file
// write otherwise.
func (sm *SandboxManager) CopyToContainer(ctx context.Context, sandbox *Sandbox, src string, dst string) error {
	if sm.useDocker {
		target := fmt.Sprintf("%s:%s", sandbox.ContainerID, dst)
		return exec.CommandContext(ctx, "docker", "cp", src, target).Run()
	}
	data, err := os.ReadFile(src)
	if err != nil {
		return err
	}
	return sm.WriteFile(ctx, sandbox, dst, data)
}
// CopyFromContainer retrieves src from the sandbox into the host file dst,
// using `docker cp` in container mode and a plain copy from the work
// directory otherwise.
func (sm *SandboxManager) CopyFromContainer(ctx context.Context, sandbox *Sandbox, src string, dst string) error {
	if sm.useDocker {
		source := fmt.Sprintf("%s:%s", sandbox.ContainerID, src)
		return exec.CommandContext(ctx, "docker", "cp", source, dst).Run()
	}
	data, err := os.ReadFile(filepath.Join(sandbox.WorkDir, src))
	if err != nil {
		return err
	}
	return os.WriteFile(dst, data, 0644)
}
// StreamLogs follows the container's log output (`docker logs -f`) and
// returns the pipe as a ReadCloser. Only available in Docker mode.
// NOTE(review): cmd.Wait is never called, so the `docker logs` child is not
// reaped when the reader is closed; consider wrapping the pipe so Close
// also waits on the process — confirm whether callers depend on this.
func (sm *SandboxManager) StreamLogs(ctx context.Context, sandbox *Sandbox) (io.ReadCloser, error) {
	if !sm.useDocker {
		return nil, fmt.Errorf("streaming not supported without Docker")
	}
	cmd := exec.CommandContext(ctx, "docker", "logs", "-f", sandbox.ContainerID)
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		return nil, err
	}
	if err := cmd.Start(); err != nil {
		return nil, err
	}
	return stdout, nil
}

View File

@@ -0,0 +1,386 @@
package computer
import (
	"context"
	"log"
	"strconv"
	"sync"
	"time"

	"github.com/robfig/cron/v3"
)
// Scheduler drives recurring and one-shot ComputerTask executions on top of
// a seconds-resolution cron runner, tracking registered cron entries and
// in-flight runs so the same task never overlaps itself.
type Scheduler struct {
	taskRepo TaskRepository
	computer *Computer
	cron *cron.Cron
	jobs map[string]cron.EntryID // task ID -> registered cron entry; guarded by mu
	running map[string]bool // task IDs currently executing; guarded by mu
	mu sync.RWMutex
	stopCh chan struct{} // closed by Stop to end the polling goroutine
}
// NewScheduler wires a scheduler with second-resolution cron support.
func NewScheduler(taskRepo TaskRepository, computer *Computer) *Scheduler {
	s := &Scheduler{
		taskRepo: taskRepo,
		computer: computer,
		cron:     cron.New(cron.WithSeconds()),
		jobs:     map[string]cron.EntryID{},
		running:  map[string]bool{},
		stopCh:   make(chan struct{}),
	}
	return s
}
// Start launches the cron runner and the background poller that picks up
// scheduled tasks from the repository. Cancelling ctx stops the poller.
func (s *Scheduler) Start(ctx context.Context) {
	s.cron.Start()
	go s.pollScheduledTasks(ctx)
	log.Println("[Scheduler] Started")
}
// Stop signals the poller to exit and halts the cron runner.
// NOTE(review): Stop closes stopCh, so a second call panics — confirm
// callers invoke it at most once.
func (s *Scheduler) Stop() {
	close(s.stopCh)
	s.cron.Stop()
	log.Println("[Scheduler] Stopped")
}
// pollScheduledTasks loads persisted schedules once, then re-checks for
// due tasks every 30 seconds until the context or the scheduler stops.
func (s *Scheduler) pollScheduledTasks(ctx context.Context) {
	const pollEvery = 30 * time.Second
	ticker := time.NewTicker(pollEvery)
	defer ticker.Stop()
	s.loadScheduledTasks(ctx)
	for {
		select {
		case <-ctx.Done():
			return
		case <-s.stopCh:
			return
		case <-ticker.C:
			s.checkAndExecute(ctx)
		}
	}
}
// loadScheduledTasks registers cron entries for every enabled schedule
// persisted in the repository.
func (s *Scheduler) loadScheduledTasks(ctx context.Context) {
	tasks, err := s.taskRepo.GetScheduled(ctx)
	if err != nil {
		log.Printf("[Scheduler] Failed to load scheduled tasks: %v", err)
		return
	}
	for i := range tasks {
		// Fix: take the address of a per-iteration copy rather than of the
		// range variable. Pre-Go 1.22, `&task` aliased a single shared loop
		// variable, so every registered cron closure could end up firing
		// with the last task's data.
		task := tasks[i]
		if task.Schedule != nil && task.Schedule.Enabled {
			s.scheduleTask(&task)
		}
	}
	log.Printf("[Scheduler] Loaded %d scheduled tasks", len(tasks))
}
// scheduleTask (re)registers the cron entry for a task according to its
// Schedule.Type, replacing any previous entry for the same task ID.
// Supported types: "cron" (explicit 6-field expression), "interval"
// (converted to a cron expression), "once" (one-shot goroutine), and the
// fixed "daily"/"hourly"/"weekly"/"monthly" presets. Unknown or disabled
// schedules are silently ignored.
func (s *Scheduler) scheduleTask(task *ComputerTask) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	// Replace any stale registration for this task.
	if oldID, exists := s.jobs[task.ID]; exists {
		s.cron.Remove(oldID)
	}
	if task.Schedule == nil || !task.Schedule.Enabled {
		return nil
	}
	var entryID cron.EntryID
	var err error
	switch task.Schedule.Type {
	case "cron":
		if task.Schedule.CronExpr == "" {
			return nil
		}
		entryID, err = s.cron.AddFunc(task.Schedule.CronExpr, func() {
			s.executeScheduledTask(task.ID)
		})
	case "interval":
		if task.Schedule.Interval <= 0 {
			return nil
		}
		cronExpr := s.intervalToCron(task.Schedule.Interval)
		entryID, err = s.cron.AddFunc(cronExpr, func() {
			s.executeScheduledTask(task.ID)
		})
	case "once":
		// NOTE(review): this goroutine sleeps without watching stopCh or a
		// context, so a stopped scheduler can still fire the task — confirm
		// whether that is intended.
		go func() {
			if task.Schedule.NextRun.After(time.Now()) {
				time.Sleep(time.Until(task.Schedule.NextRun))
			}
			s.executeScheduledTask(task.ID)
		}()
		return nil
	case "daily":
		// 09:00 every day (6-field expression: sec min hour dom month dow).
		entryID, err = s.cron.AddFunc("0 0 9 * * *", func() {
			s.executeScheduledTask(task.ID)
		})
	case "hourly":
		entryID, err = s.cron.AddFunc("0 0 * * * *", func() {
			s.executeScheduledTask(task.ID)
		})
	case "weekly":
		// 09:00 every Monday.
		entryID, err = s.cron.AddFunc("0 0 9 * * 1", func() {
			s.executeScheduledTask(task.ID)
		})
	case "monthly":
		// 09:00 on the 1st of each month.
		entryID, err = s.cron.AddFunc("0 0 9 1 * *", func() {
			s.executeScheduledTask(task.ID)
		})
	default:
		return nil
	}
	if err != nil {
		log.Printf("[Scheduler] Failed to schedule task %s: %v", task.ID, err)
		return err
	}
	s.jobs[task.ID] = entryID
	log.Printf("[Scheduler] Scheduled task %s with type %s", task.ID, task.Schedule.Type)
	return nil
}
// intervalToCron converts an interval in seconds to a 6-field (seconds
// resolution) cron expression, bucketing sub-minute intervals to every 30s
// and anything over a day to daily at midnight.
func (s *Scheduler) intervalToCron(seconds int) string {
	switch {
	case seconds < 60:
		return "*/30 * * * * *"
	case seconds < 3600:
		// Fix: this previously went through a broken itoa helper that
		// returned "" for values >= 10, yielding invalid expressions such
		// as "0 */ * * * *".
		return "0 */" + strconv.Itoa(seconds/60) + " * * * *"
	case seconds < 86400:
		return "0 0 */" + strconv.Itoa(seconds/3600) + " * * *"
	default:
		return "0 0 0 * * *"
	}
}
// itoa renders an integer in decimal.
//
// Fix: the previous implementation handled only single digits and returned
// "" for any i >= 10, which corrupted cron expressions built from it.
func itoa(i int) string {
	return strconv.Itoa(i)
}
// executeScheduledTask runs one scheduled invocation of taskID: it skips
// the run if the task is already executing, enforces schedule expiry and
// max-run limits (cancelling the task when exceeded), executes it with a
// 30-minute budget, then persists updated run counts and the next run time.
func (s *Scheduler) executeScheduledTask(taskID string) {
	// Overlap guard: mark the task as running, or bail if it already is.
	s.mu.Lock()
	if s.running[taskID] {
		s.mu.Unlock()
		log.Printf("[Scheduler] Task %s is already running, skipping", taskID)
		return
	}
	s.running[taskID] = true
	s.mu.Unlock()
	defer func() {
		s.mu.Lock()
		delete(s.running, taskID)
		s.mu.Unlock()
	}()
	// Background context: scheduled runs outlive any caller's request scope.
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
	defer cancel()
	task, err := s.taskRepo.GetByID(ctx, taskID)
	if err != nil {
		log.Printf("[Scheduler] Failed to get task %s: %v", taskID, err)
		return
	}
	if task.Schedule != nil {
		// Retire the schedule when past its expiry or run quota.
		if task.Schedule.ExpiresAt != nil && time.Now().After(*task.Schedule.ExpiresAt) {
			log.Printf("[Scheduler] Task %s has expired, removing", taskID)
			s.Cancel(taskID)
			return
		}
		if task.Schedule.MaxRuns > 0 && task.Schedule.RunCount >= task.Schedule.MaxRuns {
			log.Printf("[Scheduler] Task %s reached max runs (%d), removing", taskID, task.Schedule.MaxRuns)
			s.Cancel(taskID)
			return
		}
	}
	log.Printf("[Scheduler] Executing scheduled task %s (run #%d)", taskID, task.RunCount+1)
	_, err = s.computer.Execute(ctx, task.UserID, task.Query, ExecuteOptions{
		Async:   false,
		Context: task.Memory,
	})
	if err != nil {
		log.Printf("[Scheduler] Task %s execution failed: %v", taskID, err)
	} else {
		log.Printf("[Scheduler] Task %s completed successfully", taskID)
	}
	// Bookkeeping happens even after a failed run: the attempt counts.
	task.RunCount++
	if task.Schedule != nil {
		task.Schedule.RunCount = task.RunCount
		task.Schedule.NextRun = s.calculateNextRun(task.Schedule)
		task.NextRunAt = &task.Schedule.NextRun
	}
	task.UpdatedAt = time.Now()
	if err := s.taskRepo.Update(ctx, task); err != nil {
		log.Printf("[Scheduler] Failed to update task %s: %v", taskID, err)
	}
}
// calculateNextRun derives the next fire time for a schedule from "now":
// intervals add their period, hourly snaps to the top of the next hour,
// daily/weekly pin 09:00, monthly pins 09:00 on the 1st. Unknown types
// default to one hour out.
// NOTE(review): "daily"/"weekly" always jump a full period before pinning
// 09:00, so a run computed at 08:00 still targets tomorrow (or next week),
// not today at 09:00 — confirm this is intended.
func (s *Scheduler) calculateNextRun(schedule *Schedule) time.Time {
	switch schedule.Type {
	case "interval":
		return time.Now().Add(time.Duration(schedule.Interval) * time.Second)
	case "hourly":
		return time.Now().Add(time.Hour).Truncate(time.Hour)
	case "daily":
		next := time.Now().Add(24 * time.Hour)
		return time.Date(next.Year(), next.Month(), next.Day(), 9, 0, 0, 0, next.Location())
	case "weekly":
		next := time.Now().Add(7 * 24 * time.Hour)
		return time.Date(next.Year(), next.Month(), next.Day(), 9, 0, 0, 0, next.Location())
	case "monthly":
		next := time.Now().AddDate(0, 1, 0)
		return time.Date(next.Year(), next.Month(), 1, 9, 0, 0, 0, next.Location())
	default:
		return time.Now().Add(time.Hour)
	}
}
// checkAndExecute fires any enabled task whose NextRunAt is already due.
// Each due task runs in its own goroutine; repository errors skip the pass.
func (s *Scheduler) checkAndExecute(ctx context.Context) {
	tasks, err := s.taskRepo.GetScheduled(ctx)
	if err != nil {
		return
	}
	now := time.Now()
	for _, t := range tasks {
		due := t.NextRunAt != nil && t.NextRunAt.Before(now)
		enabled := t.Schedule != nil && t.Schedule.Enabled
		if due && enabled {
			go s.executeScheduledTask(t.ID)
		}
	}
}
// Schedule attaches (or replaces) a schedule on an existing task, persists
// the change, and registers the corresponding cron entry.
func (s *Scheduler) Schedule(taskID string, schedule Schedule) error {
	ctx := context.Background()
	task, err := s.taskRepo.GetByID(ctx, taskID)
	if err != nil {
		return err
	}
	sched := schedule
	sched.Enabled = true
	sched.NextRun = s.calculateNextRun(&sched)
	task.Schedule = &sched
	task.NextRunAt = &sched.NextRun
	task.Status = StatusScheduled
	task.UpdatedAt = time.Now()
	if err := s.taskRepo.Update(ctx, task); err != nil {
		return err
	}
	return s.scheduleTask(task)
}
// Cancel unregisters the task's cron entry, disables its schedule, and
// marks it cancelled in the repository.
func (s *Scheduler) Cancel(taskID string) error {
	// Fix: hold the lock only while touching the in-memory job table.
	// Previously it was held across repository I/O, blocking every other
	// scheduler operation for the duration of a DB round trip.
	s.mu.Lock()
	if entryID, exists := s.jobs[taskID]; exists {
		s.cron.Remove(entryID)
		delete(s.jobs, taskID)
	}
	s.mu.Unlock()
	ctx := context.Background()
	task, err := s.taskRepo.GetByID(ctx, taskID)
	if err != nil {
		return err
	}
	if task.Schedule != nil {
		task.Schedule.Enabled = false
	}
	task.Status = StatusCancelled
	task.UpdatedAt = time.Now()
	return s.taskRepo.Update(ctx, task)
}
// Pause unregisters the task's cron entry and disables its schedule while
// leaving the task status untouched, so Resume can re-enable it later.
func (s *Scheduler) Pause(taskID string) error {
	// Fix: keep the lock scoped to the in-memory job table instead of
	// holding it across repository I/O (which serialized all scheduler
	// operations on the DB).
	s.mu.Lock()
	if entryID, exists := s.jobs[taskID]; exists {
		s.cron.Remove(entryID)
		delete(s.jobs, taskID)
	}
	s.mu.Unlock()
	ctx := context.Background()
	task, err := s.taskRepo.GetByID(ctx, taskID)
	if err != nil {
		return err
	}
	if task.Schedule != nil {
		task.Schedule.Enabled = false
	}
	task.UpdatedAt = time.Now()
	return s.taskRepo.Update(ctx, task)
}
// Resume re-enables a paused schedule, recomputes its next run, persists
// the task, and re-registers its cron entry.
func (s *Scheduler) Resume(taskID string) error {
	ctx := context.Background()
	task, err := s.taskRepo.GetByID(ctx, taskID)
	if err != nil {
		return err
	}
	if sched := task.Schedule; sched != nil {
		sched.Enabled = true
		sched.NextRun = s.calculateNextRun(sched)
		task.NextRunAt = &sched.NextRun
	}
	task.Status = StatusScheduled
	task.UpdatedAt = time.Now()
	if err := s.taskRepo.Update(ctx, task); err != nil {
		return err
	}
	return s.scheduleTask(task)
}
// GetScheduledTasks returns the IDs of all tasks with a registered cron job.
func (s *Scheduler) GetScheduledTasks() []string {
	s.mu.RLock()
	defer s.mu.RUnlock()
	ids := make([]string, 0, len(s.jobs))
	for id := range s.jobs {
		ids = append(ids, id)
	}
	return ids
}
// IsRunning reports whether the task is currently executing.
func (s *Scheduler) IsRunning(taskID string) bool {
	s.mu.RLock()
	running := s.running[taskID]
	s.mu.RUnlock()
	return running
}

View File

@@ -0,0 +1,376 @@
package computer
import (
"time"
"github.com/gooseek/backend/internal/llm"
)
// TaskStatus is the lifecycle state of a ComputerTask or SubTask.
type TaskStatus string
// Task lifecycle states.
const (
	StatusPending TaskStatus = "pending"
	StatusPlanning TaskStatus = "planning"
	StatusExecuting TaskStatus = "executing"
	StatusWaiting TaskStatus = "waiting_user"
	StatusCompleted TaskStatus = "completed"
	StatusFailed TaskStatus = "failed"
	StatusCancelled TaskStatus = "cancelled"
	StatusScheduled TaskStatus = "scheduled"
	StatusPaused TaskStatus = "paused"
	StatusCheckpoint TaskStatus = "checkpoint"
	StatusLongRunning TaskStatus = "long_running"
)
// TaskType categorizes what kind of work a SubTask performs.
type TaskType string
// Supported sub-task categories.
const (
	TaskResearch TaskType = "research"
	TaskCode TaskType = "code"
	TaskAnalysis TaskType = "analysis"
	TaskDesign TaskType = "design"
	TaskDeploy TaskType = "deploy"
	TaskMonitor TaskType = "monitor"
	TaskReport TaskType = "report"
	TaskCommunicate TaskType = "communicate"
	TaskSchedule TaskType = "schedule"
	TaskTransform TaskType = "transform"
	TaskValidate TaskType = "validate"
)
// ComputerTask is the top-level unit of autonomous work: a user query, the
// plan derived from it, its sub-tasks and artifacts, plus scheduling,
// checkpointing, cost, and resource-limit bookkeeping.
type ComputerTask struct {
	ID string `json:"id"`
	UserID string `json:"userId"`
	Query string `json:"query"`
	Status TaskStatus `json:"status"`
	Plan *TaskPlan `json:"plan,omitempty"`
	SubTasks []SubTask `json:"subTasks,omitempty"`
	Artifacts []Artifact `json:"artifacts,omitempty"`
	Memory map[string]interface{} `json:"memory,omitempty"` // carried as Context into scheduled re-runs
	Progress int `json:"progress"`
	Message string `json:"message,omitempty"`
	Error string `json:"error,omitempty"`
	// Scheduling state (see Scheduler).
	Schedule *Schedule `json:"schedule,omitempty"`
	NextRunAt *time.Time `json:"nextRunAt,omitempty"`
	RunCount int `json:"runCount"`
	TotalCost float64 `json:"totalCost"`
	CreatedAt time.Time `json:"createdAt"`
	UpdatedAt time.Time `json:"updatedAt"`
	CompletedAt *time.Time `json:"completedAt,omitempty"`
	// Long-running execution state: checkpoints, iteration budget, runtime.
	DurationMode DurationMode `json:"durationMode"`
	Checkpoint *Checkpoint `json:"checkpoint,omitempty"`
	Checkpoints []Checkpoint `json:"checkpoints,omitempty"`
	MaxDuration time.Duration `json:"maxDuration"`
	EstimatedEnd *time.Time `json:"estimatedEnd,omitempty"`
	Iterations int `json:"iterations"`
	MaxIterations int `json:"maxIterations"`
	PausedAt *time.Time `json:"pausedAt,omitempty"`
	ResumedAt *time.Time `json:"resumedAt,omitempty"`
	TotalRuntime time.Duration `json:"totalRuntime"`
	HeartbeatAt *time.Time `json:"heartbeatAt,omitempty"`
	Priority TaskPriority `json:"priority"`
	ResourceLimits *ResourceLimits `json:"resourceLimits,omitempty"`
}
// DurationMode selects a preset execution-length profile for a task
// (see DurationModeConfigs for the concrete limits of each mode).
type DurationMode string
// Duration profiles, from half an hour up to effectively unbounded.
const (
	DurationShort DurationMode = "short"
	DurationMedium DurationMode = "medium"
	DurationLong DurationMode = "long"
	DurationExtended DurationMode = "extended"
	DurationUnlimited DurationMode = "unlimited"
)
// TaskPriority ranks tasks for scheduling/resource decisions.
type TaskPriority string
// Priority levels.
const (
	PriorityLow TaskPriority = "low"
	PriorityNormal TaskPriority = "normal"
	PriorityHigh TaskPriority = "high"
	PriorityCritical TaskPriority = "critical"
)
// Checkpoint is a resumable snapshot of a long-running task's position,
// state, artifacts, and accumulated runtime/cost.
type Checkpoint struct {
	ID string `json:"id"`
	TaskID string `json:"taskId"`
	SubTaskIndex int `json:"subTaskIndex"`
	WaveIndex int `json:"waveIndex"`
	State map[string]interface{} `json:"state"`
	Progress int `json:"progress"`
	Artifacts []string `json:"artifacts"`
	Memory map[string]interface{} `json:"memory"`
	CreatedAt time.Time `json:"createdAt"`
	RuntimeSoFar time.Duration `json:"runtimeSoFar"`
	CostSoFar float64 `json:"costSoFar"`
	Reason string `json:"reason"` // why the checkpoint was taken
}
// ResourceLimits caps the compute, storage, network, cost, and concurrency
// a task may consume.
type ResourceLimits struct {
	MaxCPU float64 `json:"maxCpu"`
	MaxMemoryMB int `json:"maxMemoryMb"`
	MaxDiskMB int `json:"maxDiskMb"`
	MaxNetworkMbps int `json:"maxNetworkMbps"`
	MaxCostPerHour float64 `json:"maxCostPerHour"`
	MaxTotalCost float64 `json:"maxTotalCost"`
	MaxConcurrent int `json:"maxConcurrent"`
	IdleTimeoutMins int `json:"idleTimeoutMins"`
}
// DurationModeConfigs maps each DurationMode to its runtime budget,
// checkpoint/heartbeat cadence, and iteration cap (0 = no cap).
var DurationModeConfigs = map[DurationMode]struct {
	MaxDuration time.Duration
	CheckpointFreq time.Duration
	HeartbeatFreq time.Duration
	MaxIterations int
}{
	DurationShort: {30 * time.Minute, 5 * time.Minute, 30 * time.Second, 10},
	DurationMedium: {4 * time.Hour, 15 * time.Minute, time.Minute, 50},
	DurationLong: {24 * time.Hour, 30 * time.Minute, 2 * time.Minute, 200},
	DurationExtended: {7 * 24 * time.Hour, time.Hour, 5 * time.Minute, 1000},
	DurationUnlimited: {365 * 24 * time.Hour, 4 * time.Hour, 10 * time.Minute, 0},
}
// SubTask is one step of a task plan: a typed unit of work with
// dependencies, model requirements, retry budget, and cost accounting.
type SubTask struct {
	ID string `json:"id"`
	Type TaskType `json:"type"`
	Description string `json:"description"`
	Dependencies []string `json:"dependencies,omitempty"` // IDs of sub-tasks that must finish first
	ModelID string `json:"modelId,omitempty"`
	RequiredCaps []llm.ModelCapability `json:"requiredCaps,omitempty"`
	Input map[string]interface{} `json:"input,omitempty"`
	Output map[string]interface{} `json:"output,omitempty"`
	Status TaskStatus `json:"status"`
	Progress int `json:"progress"`
	Error string `json:"error,omitempty"`
	Cost float64 `json:"cost"`
	StartedAt *time.Time `json:"startedAt,omitempty"`
	CompletedAt *time.Time `json:"completedAt,omitempty"`
	Retries int `json:"retries"`
	MaxRetries int `json:"maxRetries"`
}
// TaskPlan is the decomposition of a query into sub-tasks plus the wave
// ordering they should run in and up-front cost/time estimates.
type TaskPlan struct {
	Query string `json:"query"`
	Summary string `json:"summary"`
	SubTasks []SubTask `json:"subTasks"`
	ExecutionOrder [][]string `json:"executionOrder"` // waves of sub-task IDs that may run in parallel
	EstimatedCost float64 `json:"estimatedCost"`
	EstimatedTime int `json:"estimatedTimeSeconds"`
}
// Artifact is a produced output (file, report, image, ...) attached to a
// task. Content is never serialized to JSON; consumers use URL instead.
type Artifact struct {
	ID string `json:"id"`
	TaskID string `json:"taskId"`
	Type string `json:"type"` // one of the ArtifactType* constants
	Name string `json:"name"`
	Content []byte `json:"-"`
	URL string `json:"url,omitempty"`
	Size int64 `json:"size"`
	MimeType string `json:"mimeType,omitempty"`
	Metadata map[string]interface{} `json:"metadata,omitempty"`
	CreatedAt time.Time `json:"createdAt"`
}
// Schedule describes when and how often a task should run, including run
// quotas, expiry, retry policy, and optional execution window/conditions.
type Schedule struct {
	Type string `json:"type"` // one of the Schedule* constants
	CronExpr string `json:"cronExpr,omitempty"`
	Interval int `json:"intervalSeconds,omitempty"`
	NextRun time.Time `json:"nextRun"`
	MaxRuns int `json:"maxRuns"` // 0 means unlimited
	RunCount int `json:"runCount"`
	ExpiresAt *time.Time `json:"expiresAt,omitempty"`
	Enabled bool `json:"enabled"`
	DurationMode DurationMode `json:"durationMode,omitempty"`
	RetryOnFail bool `json:"retryOnFail"`
	MaxRetries int `json:"maxRetries"`
	RetryDelay time.Duration `json:"retryDelay"`
	Timezone string `json:"timezone,omitempty"`
	WindowStart string `json:"windowStart,omitempty"`
	WindowEnd string `json:"windowEnd,omitempty"`
	Conditions []Condition `json:"conditions,omitempty"`
}
// Condition is a predicate (field/operator/value) attached to a schedule,
// used by conditional ("on_condition") scheduling.
type Condition struct {
	Type string `json:"type"`
	Field string `json:"field"`
	Operator string `json:"operator"`
	Value interface{} `json:"value"`
	Params map[string]interface{} `json:"params,omitempty"`
}
// Schedule.Type values.
const (
	ScheduleOnce = "once"
	ScheduleInterval = "interval"
	ScheduleCron = "cron"
	ScheduleHourly = "hourly"
	ScheduleDaily = "daily"
	ScheduleWeekly = "weekly"
	ScheduleMonthly = "monthly"
	ScheduleQuarterly = "quarterly"
	ScheduleYearly = "yearly"
	ScheduleContinuous = "continuous"
	ScheduleOnCondition = "on_condition"
)
// TaskEvent is a progress/status notification emitted during task
// execution (see the Event* constants for Type values).
type TaskEvent struct {
	Type string `json:"type"`
	TaskID string `json:"taskId"`
	SubTaskID string `json:"subTaskId,omitempty"`
	Status TaskStatus `json:"status,omitempty"`
	Progress int `json:"progress,omitempty"`
	Message string `json:"message,omitempty"`
	Data map[string]interface{} `json:"data,omitempty"`
	Timestamp time.Time `json:"timestamp"`
}
// ExecuteOptions tunes a single Computer.Execute invocation: sync/async
// mode, budgets, sandbox/browser toggles, scheduling, and notifications.
type ExecuteOptions struct {
	Async bool `json:"async"`
	MaxCost float64 `json:"maxCost"`
	Timeout int `json:"timeoutSeconds"`
	EnableSandbox bool `json:"enableSandbox"`
	Schedule *Schedule `json:"schedule,omitempty"`
	Context map[string]interface{} `json:"context,omitempty"`
	DurationMode DurationMode `json:"durationMode,omitempty"`
	Priority TaskPriority `json:"priority,omitempty"`
	ResourceLimits *ResourceLimits `json:"resourceLimits,omitempty"`
	ResumeFromID string `json:"resumeFromId,omitempty"` // checkpoint to resume from
	EnableBrowser bool `json:"enableBrowser"`
	BrowserOptions *BrowserOptions `json:"browserOptions,omitempty"`
	NotifyOnEvents []string `json:"notifyOnEvents,omitempty"`
	WebhookURL string `json:"webhookUrl,omitempty"`
	Tags []string `json:"tags,omitempty"`
}
// BrowserOptions configures the automated browser used for browser-enabled
// tasks.
type BrowserOptions struct {
	Headless bool `json:"headless"`
	UserAgent string `json:"userAgent,omitempty"`
	Viewport *Viewport `json:"viewport,omitempty"`
	ProxyURL string `json:"proxyUrl,omitempty"`
	Timeout int `json:"timeout"`
	Screenshots bool `json:"screenshots"`
	RecordVideo bool `json:"recordVideo"`
	BlockAds bool `json:"blockAds"`
	AcceptCookies bool `json:"acceptCookies"`
}
// Viewport is a browser window size in pixels.
type Viewport struct {
	Width int `json:"width"`
	Height int `json:"height"`
}
// ExecutionResult is the outcome of running a single sub-task.
type ExecutionResult struct {
	TaskID string
	SubTaskID string
	Output map[string]interface{}
	Artifacts []Artifact
	Duration time.Duration
	Cost float64
	Error error
}
// SandboxResult captures one sandboxed execution: streams, exit code,
// produced files, and wall-clock duration.
type SandboxResult struct {
	Stdout string
	Stderr string
	ExitCode int
	Files map[string][]byte // output files keyed by path relative to the work dir
	Duration time.Duration
}
// MemoryEntry is a persisted key/value memory item scoped to a user and
// optionally a task, with optional expiry.
type MemoryEntry struct {
	ID string `json:"id"`
	UserID string `json:"userId"`
	TaskID string `json:"taskId,omitempty"`
	Key string `json:"key"`
	Value interface{} `json:"value"`
	Type string `json:"type"` // one of the MemoryType* constants
	Tags []string `json:"tags,omitempty"`
	CreatedAt time.Time `json:"createdAt"`
	ExpiresAt *time.Time `json:"expiresAt,omitempty"`
}
// TaskEvent.Type values.
const (
	EventTaskCreated = "task_created"
	EventTaskStarted = "task_started"
	EventTaskProgress = "task_progress"
	EventTaskCompleted = "task_completed"
	EventTaskFailed = "task_failed"
	EventSubTaskStart = "subtask_start"
	EventSubTaskDone = "subtask_done"
	EventSubTaskFail = "subtask_fail"
	EventArtifact = "artifact"
	EventMessage = "message"
	EventUserInput = "user_input_required"
	EventCheckpoint = "checkpoint"
	EventCheckpointSaved = "checkpoint_saved"
	EventResumed = "resumed"
	EventPaused = "paused"
	EventHeartbeat = "heartbeat"
	EventIteration = "iteration"
	EventBrowserAction = "browser_action"
	EventScreenshot = "screenshot"
	EventResourceAlert = "resource_alert"
	EventScheduleUpdate = "schedule_update"
)
// BrowserAction is one step of an automated browser session, optionally
// capturing a screenshot and carrying its result once executed.
type BrowserAction struct {
	ID string `json:"id"`
	Type BrowserActionType `json:"type"`
	Selector string `json:"selector,omitempty"`
	URL string `json:"url,omitempty"`
	Value string `json:"value,omitempty"`
	Options map[string]interface{} `json:"options,omitempty"`
	Screenshot bool `json:"screenshot"`
	WaitAfter int `json:"waitAfterMs"`
	Timeout int `json:"timeoutMs"`
	Result *BrowserActionResult `json:"result,omitempty"`
}
// BrowserActionType identifies the kind of browser step to perform.
type BrowserActionType string
// Supported browser action kinds.
const (
	BrowserNavigate BrowserActionType = "navigate"
	BrowserClick BrowserActionType = "click"
	BrowserType BrowserActionType = "type"
	BrowserScroll BrowserActionType = "scroll"
	BrowserScreenshot BrowserActionType = "screenshot"
	BrowserWait BrowserActionType = "wait"
	BrowserWaitSelector BrowserActionType = "wait_selector"
	BrowserExtract BrowserActionType = "extract"
	BrowserEval BrowserActionType = "eval"
	BrowserSelect BrowserActionType = "select"
	BrowserUpload BrowserActionType = "upload"
	BrowserDownload BrowserActionType = "download"
	BrowserPDF BrowserActionType = "pdf"
	BrowserClose BrowserActionType = "close"
)
// BrowserActionResult records the outcome of one browser action.
type BrowserActionResult struct {
	Success bool `json:"success"`
	Data interface{} `json:"data,omitempty"`
	Screenshot string `json:"screenshot,omitempty"`
	Error string `json:"error,omitempty"`
	Duration time.Duration `json:"duration"`
	PageTitle string `json:"pageTitle,omitempty"`
	PageURL string `json:"pageUrl,omitempty"`
	Cookies []map[string]string `json:"cookies,omitempty"`
	LocalStorage map[string]string `json:"localStorage,omitempty"`
}
// Artifact.Type values.
const (
	ArtifactTypeFile = "file"
	ArtifactTypeCode = "code"
	ArtifactTypeReport = "report"
	ArtifactTypeDeployment = "deployment"
	ArtifactTypeImage = "image"
	ArtifactTypeData = "data"
)
// MemoryEntry.Type values.
const (
	MemoryTypeFact = "fact"
	MemoryTypePreference = "preference"
	MemoryTypeContext = "context"
	MemoryTypeResult = "result"
)

View File

@@ -0,0 +1,97 @@
package db
import (
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"encoding/json"
"strings"
"time"
)
// ArticleSummary is a cached, TTL-bound set of extracted "events" for a URL.
type ArticleSummary struct {
	ID int64 `json:"id"`
	URLHash string `json:"urlHash"` // sha256 hex of the normalized URL (see hashURL)
	URL string `json:"url"`
	Events []string `json:"events"`
	CreatedAt time.Time `json:"createdAt"`
	ExpiresAt time.Time `json:"expiresAt"`
}
// ArticleSummaryRepository persists article summaries in Postgres, keyed by
// a normalized URL hash.
type ArticleSummaryRepository struct {
	db *PostgresDB
}
// NewArticleSummaryRepository wraps the given database handle.
func NewArticleSummaryRepository(db *PostgresDB) *ArticleSummaryRepository {
	return &ArticleSummaryRepository{db: db}
}
// hashURL produces a stable cache key: the URL is normalized (surrounding
// whitespace and trailing slash removed, then the scheme and a leading
// "www." stripped) and SHA-256 hashed to hex.
func (r *ArticleSummaryRepository) hashURL(url string) string {
	n := strings.TrimSpace(url)
	n = strings.TrimSuffix(n, "/")
	for _, prefix := range []string{"https://", "http://", "www."} {
		n = strings.TrimPrefix(n, prefix)
	}
	sum := sha256.Sum256([]byte(n))
	return hex.EncodeToString(sum[:])
}
// GetByURL returns the non-expired cached summary for url, or (nil, nil)
// when no live cache entry exists.
func (r *ArticleSummaryRepository) GetByURL(ctx context.Context, url string) (*ArticleSummary, error) {
	urlHash := r.hashURL(url)
	query := `
		SELECT id, url_hash, url, events, created_at, expires_at
		FROM article_summaries
		WHERE url_hash = $1 AND expires_at > NOW()
	`
	var a ArticleSummary
	var eventsJSON []byte
	err := r.db.db.QueryRowContext(ctx, query, urlHash).Scan(
		&a.ID, &a.URLHash, &a.URL, &eventsJSON, &a.CreatedAt, &a.ExpiresAt,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	// Fix: the unmarshal error was previously discarded, silently returning
	// a summary with nil Events when the stored JSON is corrupt.
	if err := json.Unmarshal(eventsJSON, &a.Events); err != nil {
		return nil, err
	}
	return &a, nil
}
// Save upserts the events for url with the given TTL, replacing any
// existing cache entry for the same normalized URL.
func (r *ArticleSummaryRepository) Save(ctx context.Context, url string, events []string, ttl time.Duration) error {
	urlHash := r.hashURL(url)
	// Fix: the marshal error was previously ignored; a failure would have
	// written a NULL/empty events payload without any signal.
	eventsJSON, err := json.Marshal(events)
	if err != nil {
		return err
	}
	expiresAt := time.Now().Add(ttl)
	query := `
		INSERT INTO article_summaries (url_hash, url, events, expires_at)
		VALUES ($1, $2, $3, $4)
		ON CONFLICT (url_hash)
		DO UPDATE SET
			events = EXCLUDED.events,
			expires_at = EXCLUDED.expires_at
	`
	_, err = r.db.db.ExecContext(ctx, query, urlHash, url, eventsJSON, expiresAt)
	return err
}
// Delete removes the cached summary for url, if any.
func (r *ArticleSummaryRepository) Delete(ctx context.Context, url string) error {
	const q = "DELETE FROM article_summaries WHERE url_hash = $1"
	_, err := r.db.db.ExecContext(ctx, q, r.hashURL(url))
	return err
}
// CleanupExpired purges rows past their TTL and reports how many were
// removed.
func (r *ArticleSummaryRepository) CleanupExpired(ctx context.Context) (int64, error) {
	res, err := r.db.db.ExecContext(ctx, "DELETE FROM article_summaries WHERE expires_at < NOW()")
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}

View File

@@ -0,0 +1,204 @@
package db
import (
"context"
"database/sql"
"encoding/json"
"time"
)
// Collection is a user-owned, optionally public group of saved items; when
// ContextEnabled is set its contents can be injected as chat context.
type Collection struct {
	ID string `json:"id"`
	UserID string `json:"userId"`
	Name string `json:"name"`
	Description string `json:"description"`
	IsPublic bool `json:"isPublic"`
	ContextEnabled bool `json:"contextEnabled"`
	CreatedAt time.Time `json:"createdAt"`
	UpdatedAt time.Time `json:"updatedAt"`
	Items []CollectionItem `json:"items,omitempty"`
	ItemCount int `json:"itemCount,omitempty"` // computed via subquery in reads
}
// CollectionItem is a single saved entry (link, note, ...) inside a
// collection.
type CollectionItem struct {
	ID string `json:"id"`
	CollectionID string `json:"collectionId"`
	ItemType string `json:"itemType"`
	Title string `json:"title"`
	Content string `json:"content"`
	URL string `json:"url"`
	Metadata map[string]interface{} `json:"metadata"`
	CreatedAt time.Time `json:"createdAt"`
	SortOrder int `json:"sortOrder"`
}
// CollectionRepository persists collections and their items in Postgres.
type CollectionRepository struct {
	db *PostgresDB
}
// NewCollectionRepository wraps the given database handle.
func NewCollectionRepository(db *PostgresDB) *CollectionRepository {
	return &CollectionRepository{db: db}
}
// Create inserts c and fills in its database-generated ID, CreatedAt and
// UpdatedAt fields.
func (r *CollectionRepository) Create(ctx context.Context, c *Collection) error {
	query := `
	INSERT INTO collections (user_id, name, description, is_public, context_enabled)
	VALUES ($1, $2, $3, $4, $5)
	RETURNING id, created_at, updated_at
	`
	return r.db.db.QueryRowContext(ctx, query,
		c.UserID, c.Name, c.Description, c.IsPublic, c.ContextEnabled,
	).Scan(&c.ID, &c.CreatedAt, &c.UpdatedAt)
}
// GetByID loads a collection plus a live item count. It returns
// (nil, nil) when no row matches, so callers must check for a nil result.
func (r *CollectionRepository) GetByID(ctx context.Context, id string) (*Collection, error) {
	query := `
	SELECT id, user_id, name, description, is_public, context_enabled, created_at, updated_at,
	(SELECT COUNT(*) FROM collection_items WHERE collection_id = collections.id) as item_count
	FROM collections
	WHERE id = $1
	`
	var c Collection
	err := r.db.db.QueryRowContext(ctx, query, id).Scan(
		&c.ID, &c.UserID, &c.Name, &c.Description, &c.IsPublic,
		&c.ContextEnabled, &c.CreatedAt, &c.UpdatedAt, &c.ItemCount,
	)
	// Not-found is represented as (nil, nil) rather than an error.
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return &c, nil
}
// GetByUserID lists a user's collections, most recently updated first,
// each with a live item count. limit and offset provide paging.
func (r *CollectionRepository) GetByUserID(ctx context.Context, userID string, limit, offset int) ([]*Collection, error) {
	query := `
	SELECT id, user_id, name, description, is_public, context_enabled, created_at, updated_at,
	(SELECT COUNT(*) FROM collection_items WHERE collection_id = collections.id) as item_count
	FROM collections
	WHERE user_id = $1
	ORDER BY updated_at DESC
	LIMIT $2 OFFSET $3
	`
	rows, err := r.db.db.QueryContext(ctx, query, userID, limit, offset)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var collections []*Collection
	for rows.Next() {
		var c Collection
		if err := rows.Scan(
			&c.ID, &c.UserID, &c.Name, &c.Description, &c.IsPublic,
			&c.ContextEnabled, &c.CreatedAt, &c.UpdatedAt, &c.ItemCount,
		); err != nil {
			return nil, err
		}
		collections = append(collections, &c)
	}
	// Surface errors that interrupted iteration (e.g. a dropped connection)
	// instead of silently returning a truncated page.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return collections, nil
}
// Update persists the mutable fields of c and bumps updated_at.
func (r *CollectionRepository) Update(ctx context.Context, c *Collection) error {
	const q = `
	UPDATE collections
	SET name = $2, description = $3, is_public = $4, context_enabled = $5, updated_at = NOW()
	WHERE id = $1
	`
	_, err := r.db.db.ExecContext(ctx, q, c.ID, c.Name, c.Description, c.IsPublic, c.ContextEnabled)
	return err
}
// Delete removes the collection with the given id.
func (r *CollectionRepository) Delete(ctx context.Context, id string) error {
	const q = "DELETE FROM collections WHERE id = $1"
	_, err := r.db.db.ExecContext(ctx, q, id)
	return err
}
// AddItem appends item to its collection, assigning the next sort_order
// (max existing + 1, or 0 for the first item) and filling in the
// database-generated ID, CreatedAt and SortOrder.
func (r *CollectionRepository) AddItem(ctx context.Context, item *CollectionItem) error {
	// Marshaling a map[string]interface{} of JSON-encodable values does not
	// fail in practice; the error is deliberately discarded.
	metadataJSON, _ := json.Marshal(item.Metadata)
	query := `
	INSERT INTO collection_items (collection_id, item_type, title, content, url, metadata, sort_order)
	VALUES ($1, $2, $3, $4, $5, $6, COALESCE((SELECT MAX(sort_order) + 1 FROM collection_items WHERE collection_id = $1), 0))
	RETURNING id, created_at, sort_order
	`
	return r.db.db.QueryRowContext(ctx, query,
		item.CollectionID, item.ItemType, item.Title, item.Content, item.URL, metadataJSON,
	).Scan(&item.ID, &item.CreatedAt, &item.SortOrder)
}
// GetItems returns every item of a collection ordered by sort_order.
func (r *CollectionRepository) GetItems(ctx context.Context, collectionID string) ([]CollectionItem, error) {
	query := `
	SELECT id, collection_id, item_type, title, content, url, metadata, created_at, sort_order
	FROM collection_items
	WHERE collection_id = $1
	ORDER BY sort_order ASC
	`
	rows, err := r.db.db.QueryContext(ctx, query, collectionID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []CollectionItem
	for rows.Next() {
		var item CollectionItem
		var metadataJSON []byte
		if err := rows.Scan(
			&item.ID, &item.CollectionID, &item.ItemType, &item.Title,
			&item.Content, &item.URL, &metadataJSON, &item.CreatedAt, &item.SortOrder,
		); err != nil {
			return nil, err
		}
		// Metadata is best-effort: a malformed blob just leaves the map nil.
		json.Unmarshal(metadataJSON, &item.Metadata)
		items = append(items, item)
	}
	// Surface errors that interrupted iteration rather than silently
	// returning a truncated list.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
// RemoveItem deletes a single collection item by its id.
func (r *CollectionRepository) RemoveItem(ctx context.Context, itemID string) error {
	const q = "DELETE FROM collection_items WHERE id = $1"
	_, err := r.db.db.ExecContext(ctx, q, itemID)
	return err
}
// GetCollectionContext renders a collection's items as a plain-text blob
// for prompt assembly. Each item contributes a type-specific section
// ("search", "note", "url", "file"); unknown types contribute only the
// blank separator line, matching the previous behavior.
func (r *CollectionRepository) GetCollectionContext(ctx context.Context, collectionID string) (string, error) {
	items, err := r.GetItems(ctx, collectionID)
	if err != nil {
		return "", err
	}
	// Accumulate into a byte slice: linear instead of quadratic string
	// concatenation, and avoids shadowing the imported context package
	// (the old code named this local "context").
	var b []byte
	for _, item := range items {
		switch item.ItemType {
		case "search":
			b = append(b, "Previous search: "+item.Title+"\n"...)
			if item.Content != "" {
				b = append(b, "Summary: "+item.Content+"\n"...)
			}
		case "note":
			b = append(b, "User note: "+item.Content+"\n"...)
		case "url":
			b = append(b, "Saved URL: "+item.URL+" - "+item.Title+"\n"...)
		case "file":
			b = append(b, "Uploaded file: "+item.Title+"\n"...)
			if item.Content != "" {
				b = append(b, "Content: "+item.Content+"\n"...)
			}
		}
		b = append(b, '\n')
	}
	return string(b), nil
}

View File

@@ -0,0 +1,322 @@
package db
import (
"context"
"database/sql"
"encoding/json"
"time"
"github.com/gooseek/backend/internal/computer"
)
// ComputerArtifactRepo persists artifacts produced by computer tasks
// (files, screenshots, etc.) in the computer_artifacts table.
type ComputerArtifactRepo struct {
	db *sql.DB
}
// NewComputerArtifactRepo returns a repository bound to db.
func NewComputerArtifactRepo(db *sql.DB) *ComputerArtifactRepo {
	return &ComputerArtifactRepo{db: db}
}
// Migrate creates the computer_artifacts table and its indexes if they do
// not already exist. It is idempotent and safe to call on every startup.
func (r *ComputerArtifactRepo) Migrate() error {
	query := `
	CREATE TABLE IF NOT EXISTS computer_artifacts (
	id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
	task_id UUID NOT NULL,
	type VARCHAR(50) NOT NULL,
	name VARCHAR(255),
	content BYTEA,
	url TEXT,
	size BIGINT DEFAULT 0,
	mime_type VARCHAR(100),
	metadata JSONB,
	created_at TIMESTAMPTZ DEFAULT NOW()
	);
	CREATE INDEX IF NOT EXISTS idx_computer_artifacts_task_id ON computer_artifacts(task_id);
	CREATE INDEX IF NOT EXISTS idx_computer_artifacts_type ON computer_artifacts(type);
	CREATE INDEX IF NOT EXISTS idx_computer_artifacts_created ON computer_artifacts(created_at DESC);
	`
	_, err := r.db.Exec(query)
	return err
}
// Create inserts artifact with its caller-supplied ID and CreatedAt.
// Metadata is stored as JSON; a marshal failure would store "null", which
// downstream readers tolerate.
func (r *ComputerArtifactRepo) Create(ctx context.Context, artifact *computer.Artifact) error {
	metadataJSON, _ := json.Marshal(artifact.Metadata)
	query := `
	INSERT INTO computer_artifacts (id, task_id, type, name, content, url, size, mime_type, metadata, created_at)
	VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
	`
	_, err := r.db.ExecContext(ctx, query,
		artifact.ID,
		artifact.TaskID,
		artifact.Type,
		artifact.Name,
		artifact.Content,
		artifact.URL,
		artifact.Size,
		artifact.MimeType,
		metadataJSON,
		artifact.CreatedAt,
	)
	return err
}
// GetByID loads a single artifact including its raw content bytes.
// Unlike the collection repositories, a missing row is returned as
// sql.ErrNoRows rather than (nil, nil).
func (r *ComputerArtifactRepo) GetByID(ctx context.Context, id string) (*computer.Artifact, error) {
	query := `
	SELECT id, task_id, type, name, content, url, size, mime_type, metadata, created_at
	FROM computer_artifacts
	WHERE id = $1
	`
	var artifact computer.Artifact
	var content []byte
	// url and mime_type are nullable columns, so scan through NullString.
	var url, mimeType sql.NullString
	var metadataJSON []byte
	err := r.db.QueryRowContext(ctx, query, id).Scan(
		&artifact.ID,
		&artifact.TaskID,
		&artifact.Type,
		&artifact.Name,
		&content,
		&url,
		&artifact.Size,
		&mimeType,
		&metadataJSON,
		&artifact.CreatedAt,
	)
	if err != nil {
		return nil, err
	}
	artifact.Content = content
	if url.Valid {
		artifact.URL = url.String
	}
	if mimeType.Valid {
		artifact.MimeType = mimeType.String
	}
	if len(metadataJSON) > 0 {
		// Best-effort decode; malformed metadata leaves the field nil.
		json.Unmarshal(metadataJSON, &artifact.Metadata)
	}
	return &artifact, nil
}
// GetByTaskID lists a task's artifacts in creation order. Content bytes
// are intentionally not selected; use GetContent for the payload.
// NOTE(review): rows that fail to scan are skipped silently (continue)
// and rows.Err() is not checked, so a mid-iteration failure yields a
// shorter list rather than an error — confirm this is intended.
func (r *ComputerArtifactRepo) GetByTaskID(ctx context.Context, taskID string) ([]computer.Artifact, error) {
	query := `
	SELECT id, task_id, type, name, url, size, mime_type, metadata, created_at
	FROM computer_artifacts
	WHERE task_id = $1
	ORDER BY created_at ASC
	`
	rows, err := r.db.QueryContext(ctx, query, taskID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var artifacts []computer.Artifact
	for rows.Next() {
		var artifact computer.Artifact
		var url, mimeType sql.NullString
		var metadataJSON []byte
		err := rows.Scan(
			&artifact.ID,
			&artifact.TaskID,
			&artifact.Type,
			&artifact.Name,
			&url,
			&artifact.Size,
			&mimeType,
			&metadataJSON,
			&artifact.CreatedAt,
		)
		if err != nil {
			continue
		}
		if url.Valid {
			artifact.URL = url.String
		}
		if mimeType.Valid {
			artifact.MimeType = mimeType.String
		}
		if len(metadataJSON) > 0 {
			json.Unmarshal(metadataJSON, &artifact.Metadata)
		}
		artifacts = append(artifacts, artifact)
	}
	return artifacts, nil
}
// GetByType lists a task's artifacts filtered by artifact type, in
// creation order, without content bytes. Scan behavior mirrors
// GetByTaskID: bad rows are skipped and rows.Err() is not consulted.
func (r *ComputerArtifactRepo) GetByType(ctx context.Context, taskID, artifactType string) ([]computer.Artifact, error) {
	query := `
	SELECT id, task_id, type, name, url, size, mime_type, metadata, created_at
	FROM computer_artifacts
	WHERE task_id = $1 AND type = $2
	ORDER BY created_at ASC
	`
	rows, err := r.db.QueryContext(ctx, query, taskID, artifactType)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var artifacts []computer.Artifact
	for rows.Next() {
		var artifact computer.Artifact
		var url, mimeType sql.NullString
		var metadataJSON []byte
		err := rows.Scan(
			&artifact.ID,
			&artifact.TaskID,
			&artifact.Type,
			&artifact.Name,
			&url,
			&artifact.Size,
			&mimeType,
			&metadataJSON,
			&artifact.CreatedAt,
		)
		if err != nil {
			continue
		}
		if url.Valid {
			artifact.URL = url.String
		}
		if mimeType.Valid {
			artifact.MimeType = mimeType.String
		}
		if len(metadataJSON) > 0 {
			json.Unmarshal(metadataJSON, &artifact.Metadata)
		}
		artifacts = append(artifacts, artifact)
	}
	return artifacts, nil
}
// GetContent fetches only the raw payload bytes of an artifact.
func (r *ComputerArtifactRepo) GetContent(ctx context.Context, id string) ([]byte, error) {
	const q = `SELECT content FROM computer_artifacts WHERE id = $1`
	var data []byte
	err := r.db.QueryRowContext(ctx, q, id).Scan(&data)
	return data, err
}
// UpdateURL stores the externally accessible URL for an artifact.
func (r *ComputerArtifactRepo) UpdateURL(ctx context.Context, id, url string) error {
	_, err := r.db.ExecContext(ctx,
		`UPDATE computer_artifacts SET url = $1 WHERE id = $2`,
		url, id,
	)
	return err
}
// Delete removes a single artifact by id.
func (r *ComputerArtifactRepo) Delete(ctx context.Context, id string) error {
	_, err := r.db.ExecContext(ctx, `DELETE FROM computer_artifacts WHERE id = $1`, id)
	return err
}
// DeleteByTaskID removes every artifact belonging to the given task.
func (r *ComputerArtifactRepo) DeleteByTaskID(ctx context.Context, taskID string) error {
	_, err := r.db.ExecContext(ctx, `DELETE FROM computer_artifacts WHERE task_id = $1`, taskID)
	return err
}
// DeleteOlderThan removes artifacts created more than the given number of
// days ago and reports how many rows were deleted. The day count is bound
// as a parameter and multiplied against a one-day INTERVAL server-side.
func (r *ComputerArtifactRepo) DeleteOlderThan(ctx context.Context, days int) (int64, error) {
	const q = `
	DELETE FROM computer_artifacts
	WHERE created_at < NOW() - INTERVAL '1 day' * $1
	`
	res, err := r.db.ExecContext(ctx, q, days)
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}
// GetTotalSize sums the stored sizes of a task's artifacts (0 when none).
func (r *ComputerArtifactRepo) GetTotalSize(ctx context.Context, taskID string) (int64, error) {
	const q = `SELECT COALESCE(SUM(size), 0) FROM computer_artifacts WHERE task_id = $1`
	var total int64
	err := r.db.QueryRowContext(ctx, q, taskID).Scan(&total)
	return total, err
}
// Count reports how many artifacts a task has.
func (r *ComputerArtifactRepo) Count(ctx context.Context, taskID string) (int64, error) {
	const q = `SELECT COUNT(*) FROM computer_artifacts WHERE task_id = $1`
	var n int64
	err := r.db.QueryRowContext(ctx, q, taskID).Scan(&n)
	return n, err
}
// ArtifactSummary is a lightweight projection of an artifact without its
// content bytes or metadata, suitable for listings.
type ArtifactSummary struct {
	ID        string    `json:"id"`
	TaskID    string    `json:"taskId"`
	Type      string    `json:"type"`
	Name      string    `json:"name"`
	URL       string    `json:"url"`
	Size      int64     `json:"size"`
	MimeType  string    `json:"mimeType"`
	CreatedAt time.Time `json:"createdAt"`
}
// GetSummaries lists lightweight summaries of a task's artifacts in
// creation order. Like the other list methods here, rows that fail to
// scan are skipped silently.
func (r *ComputerArtifactRepo) GetSummaries(ctx context.Context, taskID string) ([]ArtifactSummary, error) {
	query := `
	SELECT id, task_id, type, name, url, size, mime_type, created_at
	FROM computer_artifacts
	WHERE task_id = $1
	ORDER BY created_at ASC
	`
	rows, err := r.db.QueryContext(ctx, query, taskID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var summaries []ArtifactSummary
	for rows.Next() {
		var s ArtifactSummary
		// Nullable columns go through NullString.
		var url, mimeType sql.NullString
		err := rows.Scan(
			&s.ID,
			&s.TaskID,
			&s.Type,
			&s.Name,
			&url,
			&s.Size,
			&mimeType,
			&s.CreatedAt,
		)
		if err != nil {
			continue
		}
		if url.Valid {
			s.URL = url.String
		}
		if mimeType.Valid {
			s.MimeType = mimeType.String
		}
		summaries = append(summaries, s)
	}
	return summaries, nil
}

View File

@@ -0,0 +1,306 @@
package db
import (
	"context"
	"database/sql"
	"encoding/json"
	"strconv"
	"strings"
	"time"

	"github.com/gooseek/backend/internal/computer"
)
// ComputerMemoryRepo persists key/value memory entries produced by
// computer tasks, scoped per user and optionally per task, with TTLs.
type ComputerMemoryRepo struct {
	db *sql.DB
}
// NewComputerMemoryRepo returns a repository bound to db.
func NewComputerMemoryRepo(db *sql.DB) *ComputerMemoryRepo {
	return &ComputerMemoryRepo{db: db}
}
// Migrate creates the computer_memory table and its indexes if they do
// not already exist. Idempotent; safe to run at startup.
func (r *ComputerMemoryRepo) Migrate() error {
	query := `
	CREATE TABLE IF NOT EXISTS computer_memory (
	id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
	user_id UUID NOT NULL,
	task_id UUID,
	key VARCHAR(255) NOT NULL,
	value JSONB NOT NULL,
	type VARCHAR(50),
	tags TEXT[],
	created_at TIMESTAMPTZ DEFAULT NOW(),
	expires_at TIMESTAMPTZ
	);
	CREATE INDEX IF NOT EXISTS idx_computer_memory_user_id ON computer_memory(user_id);
	CREATE INDEX IF NOT EXISTS idx_computer_memory_task_id ON computer_memory(task_id);
	CREATE INDEX IF NOT EXISTS idx_computer_memory_type ON computer_memory(type);
	CREATE INDEX IF NOT EXISTS idx_computer_memory_expires ON computer_memory(expires_at) WHERE expires_at IS NOT NULL;
	CREATE INDEX IF NOT EXISTS idx_computer_memory_key ON computer_memory(key);
	`
	_, err := r.db.Exec(query)
	return err
}
// Store upserts entry by its ID: a conflict refreshes value, type, tags
// and expiry. An empty TaskID is stored as SQL NULL so the UUID column
// does not reject the empty string.
// NOTE(review): entry.Tags ([]string) is passed straight to ExecContext
// for a TEXT[] column; database/sql does not convert slices natively, so
// this likely relies on driver-specific array support — verify against
// the driver in use.
func (r *ComputerMemoryRepo) Store(ctx context.Context, entry *computer.MemoryEntry) error {
	valueJSON, err := json.Marshal(entry.Value)
	if err != nil {
		return err
	}
	query := `
	INSERT INTO computer_memory (id, user_id, task_id, key, value, type, tags, created_at, expires_at)
	VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
	ON CONFLICT (id) DO UPDATE SET
	value = EXCLUDED.value,
	type = EXCLUDED.type,
	tags = EXCLUDED.tags,
	expires_at = EXCLUDED.expires_at
	`
	// nil interface maps to SQL NULL for the optional task reference.
	var taskID interface{}
	if entry.TaskID != "" {
		taskID = entry.TaskID
	}
	_, err = r.db.ExecContext(ctx, query,
		entry.ID,
		entry.UserID,
		taskID,
		entry.Key,
		valueJSON,
		entry.Type,
		entry.Tags,
		entry.CreatedAt,
		entry.ExpiresAt,
	)
	return err
}
// GetByUser returns the newest non-expired entries for a user, capped at
// limit.
func (r *ComputerMemoryRepo) GetByUser(ctx context.Context, userID string, limit int) ([]computer.MemoryEntry, error) {
	query := `
	SELECT id, user_id, task_id, key, value, type, tags, created_at, expires_at
	FROM computer_memory
	WHERE user_id = $1
	AND (expires_at IS NULL OR expires_at > NOW())
	ORDER BY created_at DESC
	LIMIT $2
	`
	rows, err := r.db.QueryContext(ctx, query, userID, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	return r.scanEntries(rows)
}
// GetByTask returns all non-expired entries attached to a task, oldest
// first.
func (r *ComputerMemoryRepo) GetByTask(ctx context.Context, taskID string) ([]computer.MemoryEntry, error) {
	query := `
	SELECT id, user_id, task_id, key, value, type, tags, created_at, expires_at
	FROM computer_memory
	WHERE task_id = $1
	AND (expires_at IS NULL OR expires_at > NOW())
	ORDER BY created_at ASC
	`
	rows, err := r.db.QueryContext(ctx, query, taskID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	return r.scanEntries(rows)
}
// Search performs a case-insensitive substring search over a user's
// non-expired memory keys and JSON values, matching any of the
// whitespace-separated terms in query. An empty query (or a query error)
// degrades to GetByUser, preserving the original best-effort behavior.
func (r *ComputerMemoryRepo) Search(ctx context.Context, userID, query string, limit int) ([]computer.MemoryEntry, error) {
	searchTerms := strings.Fields(strings.ToLower(query))
	if len(searchTerms) == 0 {
		return r.GetByUser(ctx, userID, limit)
	}
	// Build numbered placeholders with strconv.Itoa. The previous
	// string(rune('2'+i)) arithmetic produced invalid placeholders
	// ("$:", "$;", ...) for any index past $9, breaking searches with
	// more than eight terms.
	args := make([]interface{}, 0, len(searchTerms)+2)
	args = append(args, userID)
	conditions := make([]string, 0, len(searchTerms))
	for i, term := range searchTerms {
		ph := "$" + strconv.Itoa(i+2)
		conditions = append(conditions, "(LOWER(key) LIKE "+ph+" OR LOWER(value::text) LIKE "+ph+")")
		args = append(args, "%"+term+"%")
	}
	args = append(args, limit)
	sqlQuery := `
	SELECT id, user_id, task_id, key, value, type, tags, created_at, expires_at
	FROM computer_memory
	WHERE user_id = $1
	AND (expires_at IS NULL OR expires_at > NOW())
	AND (` + strings.Join(conditions, " OR ") + `)
	ORDER BY created_at DESC
	LIMIT $` + strconv.Itoa(len(searchTerms)+2)
	rows, err := r.db.QueryContext(ctx, sqlQuery, args...)
	if err != nil {
		// Best-effort: fall back to a plain recency listing rather than
		// failing the request, as the original implementation did.
		return r.GetByUser(ctx, userID, limit)
	}
	defer rows.Close()
	return r.scanEntries(rows)
}
// GetByType returns a user's newest non-expired entries of the given
// memory type, capped at limit.
func (r *ComputerMemoryRepo) GetByType(ctx context.Context, userID, memType string, limit int) ([]computer.MemoryEntry, error) {
	query := `
	SELECT id, user_id, task_id, key, value, type, tags, created_at, expires_at
	FROM computer_memory
	WHERE user_id = $1 AND type = $2
	AND (expires_at IS NULL OR expires_at > NOW())
	ORDER BY created_at DESC
	LIMIT $3
	`
	rows, err := r.db.QueryContext(ctx, query, userID, memType, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	return r.scanEntries(rows)
}
// GetByKey returns the most recent non-expired entry for (userID, key),
// or an error (sql.ErrNoRows when absent).
// NOTE(review): tags is scanned into a plain []string from a TEXT[]
// column; database/sql has no built-in array scanning, so this likely
// depends on driver-specific support — verify with the driver in use.
func (r *ComputerMemoryRepo) GetByKey(ctx context.Context, userID, key string) (*computer.MemoryEntry, error) {
	query := `
	SELECT id, user_id, task_id, key, value, type, tags, created_at, expires_at
	FROM computer_memory
	WHERE user_id = $1 AND key = $2
	AND (expires_at IS NULL OR expires_at > NOW())
	ORDER BY created_at DESC
	LIMIT 1
	`
	var entry computer.MemoryEntry
	var valueJSON []byte
	var taskID sql.NullString
	var expiresAt sql.NullTime
	var tags []string
	err := r.db.QueryRowContext(ctx, query, userID, key).Scan(
		&entry.ID,
		&entry.UserID,
		&taskID,
		&entry.Key,
		&valueJSON,
		&entry.Type,
		&tags,
		&entry.CreatedAt,
		&expiresAt,
	)
	if err != nil {
		return nil, err
	}
	if taskID.Valid {
		entry.TaskID = taskID.String
	}
	if expiresAt.Valid {
		entry.ExpiresAt = &expiresAt.Time
	}
	entry.Tags = tags
	// Best-effort decode; a malformed value leaves entry.Value zero.
	json.Unmarshal(valueJSON, &entry.Value)
	return &entry, nil
}
// Delete removes a single memory entry by id.
func (r *ComputerMemoryRepo) Delete(ctx context.Context, id string) error {
	_, err := r.db.ExecContext(ctx, `DELETE FROM computer_memory WHERE id = $1`, id)
	return err
}
// DeleteByUser removes every memory entry belonging to a user.
func (r *ComputerMemoryRepo) DeleteByUser(ctx context.Context, userID string) error {
	_, err := r.db.ExecContext(ctx, `DELETE FROM computer_memory WHERE user_id = $1`, userID)
	return err
}
// DeleteByTask removes every memory entry attached to a task.
func (r *ComputerMemoryRepo) DeleteByTask(ctx context.Context, taskID string) error {
	_, err := r.db.ExecContext(ctx, `DELETE FROM computer_memory WHERE task_id = $1`, taskID)
	return err
}
// DeleteExpired purges entries whose TTL has lapsed and reports how many
// rows were removed.
func (r *ComputerMemoryRepo) DeleteExpired(ctx context.Context) (int64, error) {
	const q = `DELETE FROM computer_memory WHERE expires_at IS NOT NULL AND expires_at < NOW()`
	res, err := r.db.ExecContext(ctx, q)
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}
// scanEntries drains rows into MemoryEntry values. Rows that fail to
// scan are skipped silently and rows.Err() is not consulted, so callers
// may receive a shorter list than the query matched.
// NOTE(review): tags is scanned into a plain []string from TEXT[]; see
// the note on GetByKey about driver array support.
func (r *ComputerMemoryRepo) scanEntries(rows *sql.Rows) ([]computer.MemoryEntry, error) {
	var entries []computer.MemoryEntry
	for rows.Next() {
		var entry computer.MemoryEntry
		var valueJSON []byte
		var taskID sql.NullString
		var expiresAt sql.NullTime
		var tags []string
		err := rows.Scan(
			&entry.ID,
			&entry.UserID,
			&taskID,
			&entry.Key,
			&valueJSON,
			&entry.Type,
			&tags,
			&entry.CreatedAt,
			&expiresAt,
		)
		if err != nil {
			continue
		}
		if taskID.Valid {
			entry.TaskID = taskID.String
		}
		if expiresAt.Valid {
			entry.ExpiresAt = &expiresAt.Time
		}
		entry.Tags = tags
		// Best-effort decode; a malformed value leaves entry.Value zero.
		json.Unmarshal(valueJSON, &entry.Value)
		entries = append(entries, entry)
	}
	return entries, nil
}
// Count reports how many non-expired entries a user has.
func (r *ComputerMemoryRepo) Count(ctx context.Context, userID string) (int64, error) {
	const q = `
	SELECT COUNT(*)
	FROM computer_memory
	WHERE user_id = $1
	AND (expires_at IS NULL OR expires_at > NOW())
	`
	var n int64
	err := r.db.QueryRowContext(ctx, q, userID).Scan(&n)
	return n, err
}
// UpdateExpiry sets a new expiry timestamp on a single entry.
func (r *ComputerMemoryRepo) UpdateExpiry(ctx context.Context, id string, expiresAt time.Time) error {
	_, err := r.db.ExecContext(ctx,
		`UPDATE computer_memory SET expires_at = $1 WHERE id = $2`,
		expiresAt, id,
	)
	return err
}

View File

@@ -0,0 +1,411 @@
package db
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/gooseek/backend/internal/computer"
)
// ComputerTaskRepo persists computer tasks — long-running, optionally
// scheduled agent jobs — in the computer_tasks table.
type ComputerTaskRepo struct {
	db *sql.DB
}
// NewComputerTaskRepo returns a repository bound to db.
func NewComputerTaskRepo(db *sql.DB) *ComputerTaskRepo {
	return &ComputerTaskRepo{db: db}
}
// Migrate creates the computer_tasks table and its indexes if they do not
// already exist. Idempotent; safe to run at startup.
func (r *ComputerTaskRepo) Migrate() error {
	query := `
	CREATE TABLE IF NOT EXISTS computer_tasks (
	id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
	user_id UUID NOT NULL,
	query TEXT NOT NULL,
	status VARCHAR(20) NOT NULL DEFAULT 'pending',
	plan JSONB,
	sub_tasks JSONB,
	artifacts JSONB,
	memory JSONB,
	progress INT DEFAULT 0,
	message TEXT,
	error TEXT,
	schedule JSONB,
	next_run_at TIMESTAMPTZ,
	run_count INT DEFAULT 0,
	total_cost DECIMAL(10,6) DEFAULT 0,
	created_at TIMESTAMPTZ DEFAULT NOW(),
	updated_at TIMESTAMPTZ DEFAULT NOW(),
	completed_at TIMESTAMPTZ
	);
	CREATE INDEX IF NOT EXISTS idx_computer_tasks_user_id ON computer_tasks(user_id);
	CREATE INDEX IF NOT EXISTS idx_computer_tasks_status ON computer_tasks(status);
	CREATE INDEX IF NOT EXISTS idx_computer_tasks_next_run ON computer_tasks(next_run_at) WHERE next_run_at IS NOT NULL;
	CREATE INDEX IF NOT EXISTS idx_computer_tasks_created ON computer_tasks(created_at DESC);
	`
	_, err := r.db.Exec(query)
	return err
}
// Create inserts task with its caller-supplied ID and timestamps. The
// structured fields (plan, sub-tasks, artifacts, memory, schedule) are
// serialized to JSON; marshal errors are deliberately ignored, storing
// "null" for any field that fails.
func (r *ComputerTaskRepo) Create(ctx context.Context, task *computer.ComputerTask) error {
	planJSON, _ := json.Marshal(task.Plan)
	subTasksJSON, _ := json.Marshal(task.SubTasks)
	artifactsJSON, _ := json.Marshal(task.Artifacts)
	memoryJSON, _ := json.Marshal(task.Memory)
	scheduleJSON, _ := json.Marshal(task.Schedule)
	query := `
	INSERT INTO computer_tasks (
	id, user_id, query, status, plan, sub_tasks, artifacts, memory,
	progress, message, error, schedule, next_run_at, run_count, total_cost,
	created_at, updated_at, completed_at
	) VALUES (
	$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18
	)
	`
	_, err := r.db.ExecContext(ctx, query,
		task.ID,
		task.UserID,
		task.Query,
		task.Status,
		planJSON,
		subTasksJSON,
		artifactsJSON,
		memoryJSON,
		task.Progress,
		task.Message,
		task.Error,
		scheduleJSON,
		task.NextRunAt,
		task.RunCount,
		task.TotalCost,
		task.CreatedAt,
		task.UpdatedAt,
		task.CompletedAt,
	)
	return err
}
// Update rewrites every mutable column of an existing task, stamping
// updated_at with the current time (task.UpdatedAt is not used here).
// Structured fields are JSON-serialized with errors ignored, as in Create.
func (r *ComputerTaskRepo) Update(ctx context.Context, task *computer.ComputerTask) error {
	planJSON, _ := json.Marshal(task.Plan)
	subTasksJSON, _ := json.Marshal(task.SubTasks)
	artifactsJSON, _ := json.Marshal(task.Artifacts)
	memoryJSON, _ := json.Marshal(task.Memory)
	scheduleJSON, _ := json.Marshal(task.Schedule)
	query := `
	UPDATE computer_tasks SET
	status = $1,
	plan = $2,
	sub_tasks = $3,
	artifacts = $4,
	memory = $5,
	progress = $6,
	message = $7,
	error = $8,
	schedule = $9,
	next_run_at = $10,
	run_count = $11,
	total_cost = $12,
	updated_at = $13,
	completed_at = $14
	WHERE id = $15
	`
	_, err := r.db.ExecContext(ctx, query,
		task.Status,
		planJSON,
		subTasksJSON,
		artifactsJSON,
		memoryJSON,
		task.Progress,
		task.Message,
		task.Error,
		scheduleJSON,
		task.NextRunAt,
		task.RunCount,
		task.TotalCost,
		time.Now(),
		task.CompletedAt,
		task.ID,
	)
	return err
}
// GetByID loads a task and rehydrates its JSON columns. A missing row
// surfaces as sql.ErrNoRows. JSON decode failures are ignored, leaving
// the corresponding struct fields zero.
func (r *ComputerTaskRepo) GetByID(ctx context.Context, id string) (*computer.ComputerTask, error) {
	query := `
	SELECT id, user_id, query, status, plan, sub_tasks, artifacts, memory,
	progress, message, error, schedule, next_run_at, run_count, total_cost,
	created_at, updated_at, completed_at
	FROM computer_tasks
	WHERE id = $1
	`
	var task computer.ComputerTask
	var planJSON, subTasksJSON, artifactsJSON, memoryJSON, scheduleJSON []byte
	// message/error and the two timestamps are nullable columns.
	var message, errStr sql.NullString
	var nextRunAt, completedAt sql.NullTime
	err := r.db.QueryRowContext(ctx, query, id).Scan(
		&task.ID,
		&task.UserID,
		&task.Query,
		&task.Status,
		&planJSON,
		&subTasksJSON,
		&artifactsJSON,
		&memoryJSON,
		&task.Progress,
		&message,
		&errStr,
		&scheduleJSON,
		&nextRunAt,
		&task.RunCount,
		&task.TotalCost,
		&task.CreatedAt,
		&task.UpdatedAt,
		&completedAt,
	)
	if err != nil {
		return nil, err
	}
	if len(planJSON) > 0 {
		json.Unmarshal(planJSON, &task.Plan)
	}
	if len(subTasksJSON) > 0 {
		json.Unmarshal(subTasksJSON, &task.SubTasks)
	}
	if len(artifactsJSON) > 0 {
		json.Unmarshal(artifactsJSON, &task.Artifacts)
	}
	if len(memoryJSON) > 0 {
		json.Unmarshal(memoryJSON, &task.Memory)
	}
	if len(scheduleJSON) > 0 {
		json.Unmarshal(scheduleJSON, &task.Schedule)
	}
	if message.Valid {
		task.Message = message.String
	}
	if errStr.Valid {
		task.Error = errStr.String
	}
	if nextRunAt.Valid {
		task.NextRunAt = &nextRunAt.Time
	}
	if completedAt.Valid {
		task.CompletedAt = &completedAt.Time
	}
	return &task, nil
}
// GetByUserID pages through a user's tasks, newest first. Rows that fail
// to scan are skipped silently (continue) and rows.Err() is not checked,
// so a mid-iteration failure returns a shorter page rather than an error.
func (r *ComputerTaskRepo) GetByUserID(ctx context.Context, userID string, limit, offset int) ([]computer.ComputerTask, error) {
	query := `
	SELECT id, user_id, query, status, plan, sub_tasks, artifacts, memory,
	progress, message, error, schedule, next_run_at, run_count, total_cost,
	created_at, updated_at, completed_at
	FROM computer_tasks
	WHERE user_id = $1
	ORDER BY created_at DESC
	LIMIT $2 OFFSET $3
	`
	rows, err := r.db.QueryContext(ctx, query, userID, limit, offset)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var tasks []computer.ComputerTask
	for rows.Next() {
		var task computer.ComputerTask
		var planJSON, subTasksJSON, artifactsJSON, memoryJSON, scheduleJSON []byte
		var message, errStr sql.NullString
		var nextRunAt, completedAt sql.NullTime
		err := rows.Scan(
			&task.ID,
			&task.UserID,
			&task.Query,
			&task.Status,
			&planJSON,
			&subTasksJSON,
			&artifactsJSON,
			&memoryJSON,
			&task.Progress,
			&message,
			&errStr,
			&scheduleJSON,
			&nextRunAt,
			&task.RunCount,
			&task.TotalCost,
			&task.CreatedAt,
			&task.UpdatedAt,
			&completedAt,
		)
		if err != nil {
			continue
		}
		if len(planJSON) > 0 {
			json.Unmarshal(planJSON, &task.Plan)
		}
		if len(subTasksJSON) > 0 {
			json.Unmarshal(subTasksJSON, &task.SubTasks)
		}
		if len(artifactsJSON) > 0 {
			json.Unmarshal(artifactsJSON, &task.Artifacts)
		}
		if len(memoryJSON) > 0 {
			json.Unmarshal(memoryJSON, &task.Memory)
		}
		if len(scheduleJSON) > 0 {
			json.Unmarshal(scheduleJSON, &task.Schedule)
		}
		if message.Valid {
			task.Message = message.String
		}
		if errStr.Valid {
			task.Error = errStr.String
		}
		if nextRunAt.Valid {
			task.NextRunAt = &nextRunAt.Time
		}
		if completedAt.Valid {
			task.CompletedAt = &completedAt.Time
		}
		tasks = append(tasks, task)
	}
	return tasks, nil
}
// GetScheduled returns every task in status 'scheduled' that has a
// schedule payload, ordered by next run time, for the scheduler loop.
// Scan behavior mirrors GetByUserID: bad rows are skipped silently.
func (r *ComputerTaskRepo) GetScheduled(ctx context.Context) ([]computer.ComputerTask, error) {
	query := `
	SELECT id, user_id, query, status, plan, sub_tasks, artifacts, memory,
	progress, message, error, schedule, next_run_at, run_count, total_cost,
	created_at, updated_at, completed_at
	FROM computer_tasks
	WHERE status = 'scheduled' AND schedule IS NOT NULL
	ORDER BY next_run_at ASC
	`
	rows, err := r.db.QueryContext(ctx, query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var tasks []computer.ComputerTask
	for rows.Next() {
		var task computer.ComputerTask
		var planJSON, subTasksJSON, artifactsJSON, memoryJSON, scheduleJSON []byte
		var message, errStr sql.NullString
		var nextRunAt, completedAt sql.NullTime
		err := rows.Scan(
			&task.ID,
			&task.UserID,
			&task.Query,
			&task.Status,
			&planJSON,
			&subTasksJSON,
			&artifactsJSON,
			&memoryJSON,
			&task.Progress,
			&message,
			&errStr,
			&scheduleJSON,
			&nextRunAt,
			&task.RunCount,
			&task.TotalCost,
			&task.CreatedAt,
			&task.UpdatedAt,
			&completedAt,
		)
		if err != nil {
			continue
		}
		if len(planJSON) > 0 {
			json.Unmarshal(planJSON, &task.Plan)
		}
		if len(subTasksJSON) > 0 {
			json.Unmarshal(subTasksJSON, &task.SubTasks)
		}
		if len(artifactsJSON) > 0 {
			json.Unmarshal(artifactsJSON, &task.Artifacts)
		}
		if len(memoryJSON) > 0 {
			json.Unmarshal(memoryJSON, &task.Memory)
		}
		if len(scheduleJSON) > 0 {
			json.Unmarshal(scheduleJSON, &task.Schedule)
		}
		if message.Valid {
			task.Message = message.String
		}
		if errStr.Valid {
			task.Error = errStr.String
		}
		if nextRunAt.Valid {
			task.NextRunAt = &nextRunAt.Time
		}
		if completedAt.Valid {
			task.CompletedAt = &completedAt.Time
		}
		tasks = append(tasks, task)
	}
	return tasks, nil
}
// Delete removes a single task by id.
func (r *ComputerTaskRepo) Delete(ctx context.Context, id string) error {
	_, err := r.db.ExecContext(ctx, `DELETE FROM computer_tasks WHERE id = $1`, id)
	return err
}
// DeleteOlderThan removes finished tasks (completed/failed/cancelled)
// created more than the given number of days ago and reports how many
// rows were deleted.
func (r *ComputerTaskRepo) DeleteOlderThan(ctx context.Context, days int) (int64, error) {
	// Bind the day count as a parameter instead of fmt.Sprintf-ing it into
	// an INTERVAL literal inside the SQL text; this matches
	// ComputerArtifactRepo.DeleteOlderThan and keeps the query static.
	query := `
	DELETE FROM computer_tasks
	WHERE created_at < NOW() - INTERVAL '1 day' * $1
	AND status IN ('completed', 'failed', 'cancelled')
	`
	result, err := r.db.ExecContext(ctx, query, days)
	if err != nil {
		return 0, fmt.Errorf("deleting computer tasks older than %d days: %w", days, err)
	}
	return result.RowsAffected()
}
// CountByUser reports how many tasks a user has, regardless of status.
func (r *ComputerTaskRepo) CountByUser(ctx context.Context, userID string) (int64, error) {
	const q = `SELECT COUNT(*) FROM computer_tasks WHERE user_id = $1`
	var n int64
	err := r.db.QueryRowContext(ctx, q, userID).Scan(&n)
	return n, err
}
// CountByStatus reports how many tasks are in the given status.
func (r *ComputerTaskRepo) CountByStatus(ctx context.Context, status string) (int64, error) {
	const q = `SELECT COUNT(*) FROM computer_tasks WHERE status = $1`
	var n int64
	err := r.db.QueryRowContext(ctx, q, status).Scan(&n)
	return n, err
}

View File

@@ -0,0 +1,177 @@
package db
import (
"context"
"database/sql"
"encoding/json"
"time"
)
// DigestCitation is one numbered source reference inside a digest.
type DigestCitation struct {
	Index  int    `json:"index"`
	URL    string `json:"url"`
	Title  string `json:"title"`
	Domain string `json:"domain"`
}
// Digest is a summarized news cluster for a (topic, region) pair,
// uniquely identified together with ClusterTitle (see Upsert's conflict
// target). SummaryRu presumably holds a Russian-language summary —
// confirm against the producing service.
type Digest struct {
	ID           int64  `json:"id"`
	Topic        string `json:"topic"`
	Region       string `json:"region"`
	ClusterTitle string `json:"clusterTitle"`
	SummaryRu    string `json:"summaryRu"`
	// Citations and FollowUp are stored as JSON columns.
	Citations        []DigestCitation `json:"citations"`
	SourcesCount     int              `json:"sourcesCount"`
	FollowUp         []string         `json:"followUp"`
	Thumbnail        string           `json:"thumbnail"`
	ShortDescription string           `json:"shortDescription"`
	MainURL          string           `json:"mainUrl"`
	CreatedAt        time.Time        `json:"createdAt"`
	UpdatedAt        time.Time        `json:"updatedAt"`
}
// DigestRepository provides read/write access to the digests table.
type DigestRepository struct {
	db *PostgresDB
}
// NewDigestRepository returns a repository bound to the given database handle.
func NewDigestRepository(db *PostgresDB) *DigestRepository {
	return &DigestRepository{db: db}
}
// GetByTopicRegionTitle loads the digest identified by its natural key
// (topic, region, cluster title). Returns (nil, nil) when absent.
func (r *DigestRepository) GetByTopicRegionTitle(ctx context.Context, topic, region, title string) (*Digest, error) {
	query := `
	SELECT id, topic, region, cluster_title, summary_ru, citations, sources_count,
	follow_up, thumbnail, short_description, main_url, created_at, updated_at
	FROM digests
	WHERE topic = $1 AND region = $2 AND cluster_title = $3
	`
	var d Digest
	var citationsJSON, followUpJSON []byte
	err := r.db.db.QueryRowContext(ctx, query, topic, region, title).Scan(
		&d.ID, &d.Topic, &d.Region, &d.ClusterTitle, &d.SummaryRu,
		&citationsJSON, &d.SourcesCount, &followUpJSON,
		&d.Thumbnail, &d.ShortDescription, &d.MainURL,
		&d.CreatedAt, &d.UpdatedAt,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	// Best-effort decode of the JSON columns; failures leave nil slices.
	json.Unmarshal(citationsJSON, &d.Citations)
	json.Unmarshal(followUpJSON, &d.FollowUp)
	return &d, nil
}
// GetByURL loads the first digest whose main_url matches url, or
// (nil, nil) when none exists. main_url is not unique, hence the LIMIT 1.
func (r *DigestRepository) GetByURL(ctx context.Context, url string) (*Digest, error) {
	query := `
	SELECT id, topic, region, cluster_title, summary_ru, citations, sources_count,
	follow_up, thumbnail, short_description, main_url, created_at, updated_at
	FROM digests
	WHERE main_url = $1
	LIMIT 1
	`
	var d Digest
	var citationsJSON, followUpJSON []byte
	err := r.db.db.QueryRowContext(ctx, query, url).Scan(
		&d.ID, &d.Topic, &d.Region, &d.ClusterTitle, &d.SummaryRu,
		&citationsJSON, &d.SourcesCount, &followUpJSON,
		&d.Thumbnail, &d.ShortDescription, &d.MainURL,
		&d.CreatedAt, &d.UpdatedAt,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	// Best-effort decode of the JSON columns; failures leave nil slices.
	json.Unmarshal(citationsJSON, &d.Citations)
	json.Unmarshal(followUpJSON, &d.FollowUp)
	return &d, nil
}
// GetByTopicRegion returns up to limit digests for a (topic, region)
// pair, newest first.
func (r *DigestRepository) GetByTopicRegion(ctx context.Context, topic, region string, limit int) ([]*Digest, error) {
	query := `
	SELECT id, topic, region, cluster_title, summary_ru, citations, sources_count,
	follow_up, thumbnail, short_description, main_url, created_at, updated_at
	FROM digests
	WHERE topic = $1 AND region = $2
	ORDER BY created_at DESC
	LIMIT $3
	`
	rows, err := r.db.db.QueryContext(ctx, query, topic, region, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var digests []*Digest
	for rows.Next() {
		var d Digest
		var citationsJSON, followUpJSON []byte
		if err := rows.Scan(
			&d.ID, &d.Topic, &d.Region, &d.ClusterTitle, &d.SummaryRu,
			&citationsJSON, &d.SourcesCount, &followUpJSON,
			&d.Thumbnail, &d.ShortDescription, &d.MainURL,
			&d.CreatedAt, &d.UpdatedAt,
		); err != nil {
			return nil, err
		}
		// Best-effort decode of the JSON columns; failures leave nil slices.
		json.Unmarshal(citationsJSON, &d.Citations)
		json.Unmarshal(followUpJSON, &d.FollowUp)
		digests = append(digests, &d)
	}
	// Surface errors that interrupted iteration instead of silently
	// returning a truncated list.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return digests, nil
}
// Upsert inserts d or, on a (topic, region, cluster_title) conflict,
// refreshes every content column and bumps updated_at. d.ID is not
// populated by this call.
func (r *DigestRepository) Upsert(ctx context.Context, d *Digest) error {
	// Marshaling these slice fields does not fail in practice; errors are
	// deliberately discarded.
	citationsJSON, _ := json.Marshal(d.Citations)
	followUpJSON, _ := json.Marshal(d.FollowUp)
	query := `
	INSERT INTO digests (topic, region, cluster_title, summary_ru, citations, sources_count,
	follow_up, thumbnail, short_description, main_url)
	VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
	ON CONFLICT (topic, region, cluster_title)
	DO UPDATE SET
	summary_ru = EXCLUDED.summary_ru,
	citations = EXCLUDED.citations,
	sources_count = EXCLUDED.sources_count,
	follow_up = EXCLUDED.follow_up,
	thumbnail = EXCLUDED.thumbnail,
	short_description = EXCLUDED.short_description,
	main_url = EXCLUDED.main_url,
	updated_at = NOW()
	`
	_, err := r.db.db.ExecContext(ctx, query,
		d.Topic, d.Region, d.ClusterTitle, d.SummaryRu,
		citationsJSON, d.SourcesCount, followUpJSON,
		d.Thumbnail, d.ShortDescription, d.MainURL,
	)
	return err
}
// DeleteByTopicRegion removes every digest matching topic and region and
// reports how many rows were deleted.
func (r *DigestRepository) DeleteByTopicRegion(ctx context.Context, topic, region string) (int64, error) {
	const stmt = "DELETE FROM digests WHERE topic = $1 AND region = $2"
	res, err := r.db.db.ExecContext(ctx, stmt, topic, region)
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}

View File

@@ -0,0 +1,149 @@
package db
import (
"context"
"database/sql"
"encoding/json"
"time"
)
// UploadedFile is a user-uploaded document, with text extracted from it
// (when available) stored alongside the file's location on disk.
type UploadedFile struct {
	ID          string `json:"id"`
	UserID      string `json:"userId"`
	Filename    string `json:"filename"`
	// FileType is presumably a MIME type or extension — TODO confirm
	// against the upload handler.
	FileType    string `json:"fileType"`
	// FileSize is the size in bytes.
	FileSize    int64  `json:"fileSize"`
	StoragePath string `json:"storagePath"`
	ExtractedText string `json:"extractedText,omitempty"`
	Metadata    map[string]interface{} `json:"metadata"`
	CreatedAt   time.Time `json:"createdAt"`
}

// FileRepository provides CRUD access to the uploaded_files table.
type FileRepository struct {
	db *PostgresDB
}

// NewFileRepository returns a FileRepository backed by db.
func NewFileRepository(db *PostgresDB) *FileRepository {
	return &FileRepository{db: db}
}
// Create inserts f into uploaded_files and fills in the database-generated
// ID and CreatedAt on f.
func (r *FileRepository) Create(ctx context.Context, f *UploadedFile) error {
	// Marshal error ignored, matching the file-wide convention.
	metadataJSON, _ := json.Marshal(f.Metadata)
	query := `
INSERT INTO uploaded_files (user_id, filename, file_type, file_size, storage_path, extracted_text, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id, created_at
`
	return r.db.db.QueryRowContext(ctx, query,
		f.UserID, f.Filename, f.FileType, f.FileSize, f.StoragePath, f.ExtractedText, metadataJSON,
	).Scan(&f.ID, &f.CreatedAt)
}
// GetByID fetches one uploaded file by primary key. It returns (nil, nil)
// when no row matches — callers must check for a nil result.
func (r *FileRepository) GetByID(ctx context.Context, id string) (*UploadedFile, error) {
	query := `
SELECT id, user_id, filename, file_type, file_size, storage_path, extracted_text, metadata, created_at
FROM uploaded_files
WHERE id = $1
`
	var f UploadedFile
	var metadataJSON []byte
	// NOTE(review): extracted_text is nullable in the schema; scanning a
	// NULL into string will fail — confirm rows always carry text, or
	// switch to sql.NullString.
	err := r.db.db.QueryRowContext(ctx, query, id).Scan(
		&f.ID, &f.UserID, &f.Filename, &f.FileType, &f.FileSize,
		&f.StoragePath, &f.ExtractedText, &metadataJSON, &f.CreatedAt,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	// Best-effort decode; malformed metadata leaves the map nil.
	json.Unmarshal(metadataJSON, &f.Metadata)
	return &f, nil
}
// GetByUserID returns a page of the user's uploads, newest first.
// Fix: iteration errors from rows.Err() were previously dropped.
func (r *FileRepository) GetByUserID(ctx context.Context, userID string, limit, offset int) ([]*UploadedFile, error) {
	query := `
SELECT id, user_id, filename, file_type, file_size, storage_path, extracted_text, metadata, created_at
FROM uploaded_files
WHERE user_id = $1
ORDER BY created_at DESC
LIMIT $2 OFFSET $3
`
	rows, err := r.db.db.QueryContext(ctx, query, userID, limit, offset)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var files []*UploadedFile
	for rows.Next() {
		var f UploadedFile
		var metadataJSON []byte
		if err := rows.Scan(
			&f.ID, &f.UserID, &f.Filename, &f.FileType, &f.FileSize,
			&f.StoragePath, &f.ExtractedText, &metadataJSON, &f.CreatedAt,
		); err != nil {
			return nil, err
		}
		// Best-effort decode; malformed metadata leaves the map nil.
		json.Unmarshal(metadataJSON, &f.Metadata)
		files = append(files, &f)
	}
	// Surface any error encountered while iterating the result set.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return files, nil
}
// UpdateExtractedText stores the extracted text for an uploaded file.
func (r *FileRepository) UpdateExtractedText(ctx context.Context, id, text string) error {
	const stmt = "UPDATE uploaded_files SET extracted_text = $2 WHERE id = $1"
	_, err := r.db.db.ExecContext(ctx, stmt, id, text)
	return err
}

// Delete removes an uploaded file's database row by id.
func (r *FileRepository) Delete(ctx context.Context, id string) error {
	const stmt = "DELETE FROM uploaded_files WHERE id = $1"
	_, err := r.db.db.ExecContext(ctx, stmt, id)
	return err
}
// GetByIDs fetches the files whose ids are listed; an empty input returns
// (nil, nil) without touching the database.
// Fix: iteration errors from rows.Err() were previously dropped.
func (r *FileRepository) GetByIDs(ctx context.Context, ids []string) ([]*UploadedFile, error) {
	if len(ids) == 0 {
		return nil, nil
	}
	query := `
SELECT id, user_id, filename, file_type, file_size, storage_path, extracted_text, metadata, created_at
FROM uploaded_files
WHERE id = ANY($1)
`
	// NOTE(review): lib/pq does not accept a bare []string as a query
	// argument — this likely needs pq.Array(ids). Confirm against the
	// driver in use before shipping.
	rows, err := r.db.db.QueryContext(ctx, query, ids)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var files []*UploadedFile
	for rows.Next() {
		var f UploadedFile
		var metadataJSON []byte
		if err := rows.Scan(
			&f.ID, &f.UserID, &f.Filename, &f.FileType, &f.FileSize,
			&f.StoragePath, &f.ExtractedText, &metadataJSON, &f.CreatedAt,
		); err != nil {
			return nil, err
		}
		// Best-effort decode; malformed metadata leaves the map nil.
		json.Unmarshal(metadataJSON, &f.Metadata)
		files = append(files, &f)
	}
	// Surface any error encountered while iterating the result set.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return files, nil
}

View File

@@ -0,0 +1,170 @@
package db
import (
"context"
"encoding/json"
"time"
)
// UserMemory is one remembered item about a user (preference, fact,
// instruction, interest, ...), keyed uniquely by (user_id, memory_type, key).
type UserMemory struct {
	ID         string `json:"id"`
	UserID     string `json:"userId"`
	MemoryType string `json:"memoryType"`
	Key        string `json:"key"`
	Value      string `json:"value"`
	Metadata   map[string]interface{} `json:"metadata"`
	// Importance ranks memories for retrieval; schema default is 5.
	Importance int `json:"importance"`
	LastUsed   time.Time `json:"lastUsed"`
	UseCount   int `json:"useCount"`
	CreatedAt  time.Time `json:"createdAt"`
	UpdatedAt  time.Time `json:"updatedAt"`
}

// MemoryRepository persists user memories in the user_memories table.
type MemoryRepository struct {
	db *PostgresDB
}

// NewMemoryRepository returns a MemoryRepository backed by db.
func NewMemoryRepository(db *PostgresDB) *MemoryRepository {
	return &MemoryRepository{db: db}
}
// RunMigrations creates the user_memories table and its indexes if they do
// not already exist. Statements are idempotent (IF NOT EXISTS) and applied
// in order; the first failure aborts the rest.
func (r *MemoryRepository) RunMigrations(ctx context.Context) error {
	migrations := []string{
		`CREATE TABLE IF NOT EXISTS user_memories (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL,
memory_type VARCHAR(50) NOT NULL,
key VARCHAR(255) NOT NULL,
value TEXT NOT NULL,
metadata JSONB DEFAULT '{}',
importance INT DEFAULT 5,
last_used TIMESTAMPTZ DEFAULT NOW(),
use_count INT DEFAULT 0,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW(),
UNIQUE(user_id, memory_type, key)
)`,
		`CREATE INDEX IF NOT EXISTS idx_user_memories_user ON user_memories(user_id)`,
		`CREATE INDEX IF NOT EXISTS idx_user_memories_type ON user_memories(user_id, memory_type)`,
		`CREATE INDEX IF NOT EXISTS idx_user_memories_importance ON user_memories(user_id, importance DESC)`,
	}
	for _, m := range migrations {
		if _, err := r.db.db.ExecContext(ctx, m); err != nil {
			return err
		}
	}
	return nil
}
// Save upserts mem on its (user_id, memory_type, key) unique constraint and
// fills in the database-assigned ID/CreatedAt/UpdatedAt on mem.
func (r *MemoryRepository) Save(ctx context.Context, mem *UserMemory) error {
	// Marshal error ignored, matching the file-wide convention.
	metadataJSON, _ := json.Marshal(mem.Metadata)
	query := `
INSERT INTO user_memories (user_id, memory_type, key, value, metadata, importance)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (user_id, memory_type, key)
DO UPDATE SET
value = EXCLUDED.value,
metadata = EXCLUDED.metadata,
importance = EXCLUDED.importance,
updated_at = NOW()
RETURNING id, created_at, updated_at
`
	return r.db.db.QueryRowContext(ctx, query,
		mem.UserID, mem.MemoryType, mem.Key, mem.Value, metadataJSON, mem.Importance,
	).Scan(&mem.ID, &mem.CreatedAt, &mem.UpdatedAt)
}
// GetByUserID returns memories for a user, most important then most recently
// used first. memoryType filters by type when non-empty; limit > 0 caps the
// number of rows returned.
// Fixes: the LIMIT placeholder was built with string(rune('0'+n)) — an
// obscure trick that only works for single-digit indexes — and iteration
// errors from rows.Err() were dropped.
func (r *MemoryRepository) GetByUserID(ctx context.Context, userID string, memoryType string, limit int) ([]*UserMemory, error) {
	query := `
SELECT id, user_id, memory_type, key, value, metadata, importance, last_used, use_count, created_at, updated_at
FROM user_memories
WHERE user_id = $1
`
	args := []interface{}{userID}
	if memoryType != "" {
		query += " AND memory_type = $2"
		args = append(args, memoryType)
	}
	query += " ORDER BY importance DESC, last_used DESC"
	if limit > 0 {
		// The limit placeholder index depends on whether the memory_type
		// filter was added above; spell out the two cases explicitly.
		if memoryType != "" {
			query += " LIMIT $3"
		} else {
			query += " LIMIT $2"
		}
		args = append(args, limit)
	}
	rows, err := r.db.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var memories []*UserMemory
	for rows.Next() {
		var mem UserMemory
		var metadataJSON []byte
		if err := rows.Scan(
			&mem.ID, &mem.UserID, &mem.MemoryType, &mem.Key, &mem.Value,
			&metadataJSON, &mem.Importance, &mem.LastUsed, &mem.UseCount,
			&mem.CreatedAt, &mem.UpdatedAt,
		); err != nil {
			return nil, err
		}
		// Best-effort decode; malformed metadata leaves the map nil.
		json.Unmarshal(metadataJSON, &mem.Metadata)
		memories = append(memories, &mem)
	}
	// Surface any error encountered while iterating the result set.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return memories, nil
}
// GetContextForUser renders up to 20 of the user's highest-priority memories
// as a newline-separated prompt fragment.
// Fix: the local accumulator was named "context", shadowing the context
// package; renamed to avoid confusion and future breakage.
func (r *MemoryRepository) GetContextForUser(ctx context.Context, userID string) (string, error) {
	memories, err := r.GetByUserID(ctx, userID, "", 20)
	if err != nil {
		return "", err
	}
	var out string
	for _, mem := range memories {
		switch mem.MemoryType {
		case "preference":
			out += "User preference: " + mem.Key + " = " + mem.Value + "\n"
		case "fact":
			out += "Known fact about user: " + mem.Value + "\n"
		case "instruction":
			out += "User instruction: " + mem.Value + "\n"
		case "interest":
			out += "User interest: " + mem.Value + "\n"
		default:
			out += mem.Key + ": " + mem.Value + "\n"
		}
	}
	return out, nil
}
// IncrementUseCount bumps a memory's use counter and stamps last_used.
func (r *MemoryRepository) IncrementUseCount(ctx context.Context, id string) error {
	const stmt = "UPDATE user_memories SET use_count = use_count + 1, last_used = NOW() WHERE id = $1"
	_, err := r.db.db.ExecContext(ctx, stmt, id)
	return err
}

// Delete removes a single memory by id.
func (r *MemoryRepository) Delete(ctx context.Context, id string) error {
	const stmt = "DELETE FROM user_memories WHERE id = $1"
	_, err := r.db.db.ExecContext(ctx, stmt, id)
	return err
}

// DeleteByUserID removes every memory belonging to a user.
func (r *MemoryRepository) DeleteByUserID(ctx context.Context, userID string) error {
	const stmt = "DELETE FROM user_memories WHERE user_id = $1"
	_, err := r.db.db.ExecContext(ctx, stmt, userID)
	return err
}
// ExtractMemoriesFromConversation is a stub: it always returns (nil, nil).
// TODO: implement LLM-backed memory extraction; llmClient, conversation and
// answer are currently unused.
func ExtractMemoriesFromConversation(ctx context.Context, llmClient interface{}, conversation, answer string) ([]UserMemory, error) {
	return nil, nil
}

View File

@@ -0,0 +1,219 @@
package db
import (
"context"
"database/sql"
"encoding/json"
"github.com/gooseek/backend/internal/pages"
)
// PageRepository provides CRUD and sharing operations for the pages table.
type PageRepository struct {
	db *PostgresDB
}

// NewPageRepository returns a PageRepository backed by db.
func NewPageRepository(db *PostgresDB) *PageRepository {
	return &PageRepository{db: db}
}
// RunMigrations creates the pages table and its indexes if they do not
// already exist. Statements are idempotent and applied in order.
func (r *PageRepository) RunMigrations(ctx context.Context) error {
	migrations := []string{
		`CREATE TABLE IF NOT EXISTS pages (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL,
thread_id UUID REFERENCES threads(id) ON DELETE SET NULL,
title VARCHAR(500) NOT NULL,
subtitle TEXT,
sections JSONB NOT NULL DEFAULT '[]',
sources JSONB NOT NULL DEFAULT '[]',
thumbnail TEXT,
is_public BOOLEAN DEFAULT FALSE,
share_id VARCHAR(100) UNIQUE,
view_count INT DEFAULT 0,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW()
)`,
		`CREATE INDEX IF NOT EXISTS idx_pages_user ON pages(user_id)`,
		`CREATE INDEX IF NOT EXISTS idx_pages_share ON pages(share_id)`,
	}
	for _, m := range migrations {
		if _, err := r.db.db.ExecContext(ctx, m); err != nil {
			return err
		}
	}
	return nil
}
// Create inserts p into pages and fills in the database-generated ID,
// CreatedAt and UpdatedAt on p.
func (r *PageRepository) Create(ctx context.Context, p *pages.Page) error {
	// Marshal errors ignored, matching the file-wide convention.
	sectionsJSON, _ := json.Marshal(p.Sections)
	sourcesJSON, _ := json.Marshal(p.Sources)
	query := `
INSERT INTO pages (user_id, thread_id, title, subtitle, sections, sources, thumbnail, is_public)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, created_at, updated_at
`
	// An empty ThreadID must become SQL NULL (thread_id is a nullable FK),
	// so pass a *string that stays nil when unset.
	var threadID *string
	if p.ThreadID != "" {
		threadID = &p.ThreadID
	}
	return r.db.db.QueryRowContext(ctx, query,
		p.UserID, threadID, p.Title, p.Subtitle, sectionsJSON, sourcesJSON, p.Thumbnail, p.IsPublic,
	).Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt)
}
// GetByID fetches one page by primary key. It returns (nil, nil) when no
// row matches — callers must check for a nil result.
func (r *PageRepository) GetByID(ctx context.Context, id string) (*pages.Page, error) {
	query := `
SELECT id, user_id, thread_id, title, subtitle, sections, sources, thumbnail, is_public, share_id, view_count, created_at, updated_at
FROM pages
WHERE id = $1
`
	var p pages.Page
	var sectionsJSON, sourcesJSON []byte
	// thread_id and share_id are nullable, so scan through sql.NullString.
	var threadID, shareID sql.NullString
	err := r.db.db.QueryRowContext(ctx, query, id).Scan(
		&p.ID, &p.UserID, &threadID, &p.Title, &p.Subtitle,
		&sectionsJSON, &sourcesJSON, &p.Thumbnail,
		&p.IsPublic, &shareID, &p.ViewCount, &p.CreatedAt, &p.UpdatedAt,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	// Best-effort decode; malformed JSON leaves the slices nil.
	json.Unmarshal(sectionsJSON, &p.Sections)
	json.Unmarshal(sourcesJSON, &p.Sources)
	if threadID.Valid {
		p.ThreadID = threadID.String
	}
	if shareID.Valid {
		p.ShareID = shareID.String
	}
	return &p, nil
}
// GetByShareID fetches a page by its public share id. Only pages marked
// is_public are returned; otherwise (and when no row matches) the result
// is (nil, nil).
func (r *PageRepository) GetByShareID(ctx context.Context, shareID string) (*pages.Page, error) {
	query := `
SELECT id, user_id, thread_id, title, subtitle, sections, sources, thumbnail, is_public, share_id, view_count, created_at, updated_at
FROM pages
WHERE share_id = $1 AND is_public = true
`
	var p pages.Page
	var sectionsJSON, sourcesJSON []byte
	// thread_id and share_id are nullable, so scan through sql.NullString.
	var threadID, shareIDVal sql.NullString
	err := r.db.db.QueryRowContext(ctx, query, shareID).Scan(
		&p.ID, &p.UserID, &threadID, &p.Title, &p.Subtitle,
		&sectionsJSON, &sourcesJSON, &p.Thumbnail,
		&p.IsPublic, &shareIDVal, &p.ViewCount, &p.CreatedAt, &p.UpdatedAt,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	// Best-effort decode; malformed JSON leaves the slices nil.
	json.Unmarshal(sectionsJSON, &p.Sections)
	json.Unmarshal(sourcesJSON, &p.Sources)
	if threadID.Valid {
		p.ThreadID = threadID.String
	}
	if shareIDVal.Valid {
		p.ShareID = shareIDVal.String
	}
	return &p, nil
}
// GetByUserID returns a page of the user's pages, most recently updated
// first.
// Fix: iteration errors from rows.Err() were previously dropped.
func (r *PageRepository) GetByUserID(ctx context.Context, userID string, limit, offset int) ([]*pages.Page, error) {
	query := `
SELECT id, user_id, thread_id, title, subtitle, sections, sources, thumbnail, is_public, share_id, view_count, created_at, updated_at
FROM pages
WHERE user_id = $1
ORDER BY updated_at DESC
LIMIT $2 OFFSET $3
`
	rows, err := r.db.db.QueryContext(ctx, query, userID, limit, offset)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var pagesList []*pages.Page
	for rows.Next() {
		var p pages.Page
		var sectionsJSON, sourcesJSON []byte
		// thread_id and share_id are nullable, so scan through sql.NullString.
		var threadID, shareID sql.NullString
		if err := rows.Scan(
			&p.ID, &p.UserID, &threadID, &p.Title, &p.Subtitle,
			&sectionsJSON, &sourcesJSON, &p.Thumbnail,
			&p.IsPublic, &shareID, &p.ViewCount, &p.CreatedAt, &p.UpdatedAt,
		); err != nil {
			return nil, err
		}
		// Best-effort decode; malformed JSON leaves the slices nil.
		json.Unmarshal(sectionsJSON, &p.Sections)
		json.Unmarshal(sourcesJSON, &p.Sources)
		if threadID.Valid {
			p.ThreadID = threadID.String
		}
		if shareID.Valid {
			p.ShareID = shareID.String
		}
		pagesList = append(pagesList, &p)
	}
	// Surface any error encountered while iterating the result set.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return pagesList, nil
}
// Update persists the mutable fields of p (title, subtitle, content,
// thumbnail, visibility) and refreshes updated_at.
func (r *PageRepository) Update(ctx context.Context, p *pages.Page) error {
	// Marshal errors ignored, matching the file-wide convention.
	sections, _ := json.Marshal(p.Sections)
	sources, _ := json.Marshal(p.Sources)
	const query = `
UPDATE pages
SET title = $2, subtitle = $3, sections = $4, sources = $5, thumbnail = $6, is_public = $7, updated_at = NOW()
WHERE id = $1
`
	args := []interface{}{
		p.ID, p.Title, p.Subtitle, sections, sources, p.Thumbnail, p.IsPublic,
	}
	_, err := r.db.db.ExecContext(ctx, query, args...)
	return err
}
// SetShareID publishes a page: it stores shareID and marks the page public.
func (r *PageRepository) SetShareID(ctx context.Context, pageID, shareID string) error {
	const stmt = "UPDATE pages SET share_id = $2, is_public = true WHERE id = $1"
	_, err := r.db.db.ExecContext(ctx, stmt, pageID, shareID)
	return err
}

// IncrementViewCount bumps a page's view counter by one.
func (r *PageRepository) IncrementViewCount(ctx context.Context, id string) error {
	const stmt = "UPDATE pages SET view_count = view_count + 1 WHERE id = $1"
	_, err := r.db.db.ExecContext(ctx, stmt, id)
	return err
}

// Delete removes a page by id.
func (r *PageRepository) Delete(ctx context.Context, id string) error {
	const stmt = "DELETE FROM pages WHERE id = $1"
	_, err := r.db.db.ExecContext(ctx, stmt, id)
	return err
}

View File

@@ -0,0 +1,134 @@
package db
import (
"context"
"database/sql"
"fmt"
"time"
_ "github.com/lib/pq"
)
// PostgresDB wraps a *sql.DB connection pool to PostgreSQL and is shared
// by every repository in this package.
type PostgresDB struct {
	db *sql.DB
}
// NewPostgresDB opens a connection pool for databaseURL, applies pool
// limits, and verifies connectivity with a 5-second ping before returning.
// Fix: the pool was leaked when the ping failed; it is now closed on that
// error path.
func NewPostgresDB(databaseURL string) (*PostgresDB, error) {
	db, err := sql.Open("postgres", databaseURL)
	if err != nil {
		return nil, fmt.Errorf("failed to open database: %w", err)
	}
	// Conservative pool limits; tune if profiling shows contention.
	db.SetMaxOpenConns(25)
	db.SetMaxIdleConns(5)
	db.SetConnMaxLifetime(5 * time.Minute)
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := db.PingContext(ctx); err != nil {
		// Release the pool's resources instead of leaking them.
		db.Close()
		return nil, fmt.Errorf("failed to ping database: %w", err)
	}
	return &PostgresDB{db: db}, nil
}
// Close releases the underlying connection pool.
func (p *PostgresDB) Close() error {
	return p.db.Close()
}

// DB exposes the raw *sql.DB for repositories defined in other files.
func (p *PostgresDB) DB() *sql.DB {
	return p.db
}
// RunMigrations creates the core schema (digests, article_summaries,
// collections, collection_items, uploaded_files, research_sessions) and
// associated indexes. Every statement is idempotent (IF NOT EXISTS);
// statements run in order and the first failure aborts the rest.
func (p *PostgresDB) RunMigrations(ctx context.Context) error {
	migrations := []string{
		// News digests, deduplicated on (topic, region, cluster_title).
		`CREATE TABLE IF NOT EXISTS digests (
id SERIAL PRIMARY KEY,
topic VARCHAR(100) NOT NULL,
region VARCHAR(50) NOT NULL,
cluster_title VARCHAR(500) NOT NULL,
summary_ru TEXT NOT NULL,
citations JSONB DEFAULT '[]',
sources_count INT DEFAULT 0,
follow_up JSONB DEFAULT '[]',
thumbnail TEXT,
short_description TEXT,
main_url TEXT,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW(),
UNIQUE(topic, region, cluster_title)
)`,
		`CREATE INDEX IF NOT EXISTS idx_digests_topic_region ON digests(topic, region)`,
		`CREATE INDEX IF NOT EXISTS idx_digests_main_url ON digests(main_url)`,
		// Per-URL article summary cache with a 7-day expiry.
		`CREATE TABLE IF NOT EXISTS article_summaries (
id SERIAL PRIMARY KEY,
url_hash VARCHAR(64) NOT NULL UNIQUE,
url TEXT NOT NULL,
events JSONB NOT NULL DEFAULT '[]',
created_at TIMESTAMPTZ DEFAULT NOW(),
expires_at TIMESTAMPTZ DEFAULT NOW() + INTERVAL '7 days'
)`,
		`CREATE INDEX IF NOT EXISTS idx_article_summaries_url_hash ON article_summaries(url_hash)`,
		`CREATE INDEX IF NOT EXISTS idx_article_summaries_expires ON article_summaries(expires_at)`,
		// User collections and their items (items cascade on delete).
		`CREATE TABLE IF NOT EXISTS collections (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL,
name VARCHAR(255) NOT NULL,
description TEXT,
is_public BOOLEAN DEFAULT FALSE,
context_enabled BOOLEAN DEFAULT TRUE,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW()
)`,
		`CREATE INDEX IF NOT EXISTS idx_collections_user ON collections(user_id)`,
		`CREATE TABLE IF NOT EXISTS collection_items (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
collection_id UUID NOT NULL REFERENCES collections(id) ON DELETE CASCADE,
item_type VARCHAR(50) NOT NULL,
title VARCHAR(500),
content TEXT,
url TEXT,
metadata JSONB DEFAULT '{}',
created_at TIMESTAMPTZ DEFAULT NOW(),
sort_order INT DEFAULT 0
)`,
		`CREATE INDEX IF NOT EXISTS idx_collection_items_collection ON collection_items(collection_id)`,
		// Uploaded files (see FileRepository).
		`CREATE TABLE IF NOT EXISTS uploaded_files (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL,
filename VARCHAR(500) NOT NULL,
file_type VARCHAR(100) NOT NULL,
file_size BIGINT NOT NULL,
storage_path TEXT NOT NULL,
extracted_text TEXT,
metadata JSONB DEFAULT '{}',
created_at TIMESTAMPTZ DEFAULT NOW()
)`,
		`CREATE INDEX IF NOT EXISTS idx_uploaded_files_user ON uploaded_files(user_id)`,
		// Deep-research sessions; collection link is optional.
		`CREATE TABLE IF NOT EXISTS research_sessions (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID,
collection_id UUID REFERENCES collections(id) ON DELETE SET NULL,
query TEXT NOT NULL,
focus_mode VARCHAR(50) DEFAULT 'all',
optimization_mode VARCHAR(50) DEFAULT 'balanced',
sources JSONB DEFAULT '[]',
response_blocks JSONB DEFAULT '[]',
final_answer TEXT,
citations JSONB DEFAULT '[]',
created_at TIMESTAMPTZ DEFAULT NOW(),
completed_at TIMESTAMPTZ
)`,
		`CREATE INDEX IF NOT EXISTS idx_research_sessions_user ON research_sessions(user_id)`,
		`CREATE INDEX IF NOT EXISTS idx_research_sessions_collection ON research_sessions(collection_id)`,
	}
	for _, migration := range migrations {
		if _, err := p.db.ExecContext(ctx, migration); err != nil {
			return fmt.Errorf("migration failed: %w", err)
		}
	}
	return nil
}

View File

@@ -0,0 +1,163 @@
package db
import (
"context"
"database/sql"
"encoding/json"
"time"
)
// Space is a user-owned workspace that carries per-space defaults
// (focus mode, model, custom instructions) applied to its threads.
type Space struct {
	ID                 string `json:"id"`
	UserID             string `json:"userId"`
	Name               string `json:"name"`
	Description        string `json:"description"`
	Icon               string `json:"icon"`
	Color              string `json:"color"`
	CustomInstructions string `json:"customInstructions"`
	DefaultFocusMode   string `json:"defaultFocusMode"`
	DefaultModel       string `json:"defaultModel"`
	IsPublic           bool   `json:"isPublic"`
	Settings           map[string]interface{} `json:"settings"`
	CreatedAt          time.Time `json:"createdAt"`
	UpdatedAt          time.Time `json:"updatedAt"`
	// ThreadCount is computed by queries (subselect), not stored.
	ThreadCount        int `json:"threadCount,omitempty"`
}

// SpaceRepository provides CRUD access to the spaces table.
type SpaceRepository struct {
	db *PostgresDB
}

// NewSpaceRepository returns a SpaceRepository backed by db.
func NewSpaceRepository(db *PostgresDB) *SpaceRepository {
	return &SpaceRepository{db: db}
}
// RunMigrations creates the spaces table and its index if they do not
// already exist. Statements are idempotent and applied in order.
func (r *SpaceRepository) RunMigrations(ctx context.Context) error {
	migrations := []string{
		`CREATE TABLE IF NOT EXISTS spaces (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL,
name VARCHAR(255) NOT NULL,
description TEXT,
icon VARCHAR(50),
color VARCHAR(20),
custom_instructions TEXT,
default_focus_mode VARCHAR(50) DEFAULT 'all',
default_model VARCHAR(100),
is_public BOOLEAN DEFAULT FALSE,
settings JSONB DEFAULT '{}',
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW()
)`,
		`CREATE INDEX IF NOT EXISTS idx_spaces_user ON spaces(user_id)`,
	}
	for _, m := range migrations {
		if _, err := r.db.db.ExecContext(ctx, m); err != nil {
			return err
		}
	}
	return nil
}
// Create inserts s into spaces and fills in the database-generated ID,
// CreatedAt and UpdatedAt on s.
func (r *SpaceRepository) Create(ctx context.Context, s *Space) error {
	// Marshal error ignored, matching the file-wide convention.
	settingsJSON, _ := json.Marshal(s.Settings)
	query := `
INSERT INTO spaces (user_id, name, description, icon, color, custom_instructions, default_focus_mode, default_model, is_public, settings)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
RETURNING id, created_at, updated_at
`
	return r.db.db.QueryRowContext(ctx, query,
		s.UserID, s.Name, s.Description, s.Icon, s.Color,
		s.CustomInstructions, s.DefaultFocusMode, s.DefaultModel,
		s.IsPublic, settingsJSON,
	).Scan(&s.ID, &s.CreatedAt, &s.UpdatedAt)
}
// GetByID fetches one space by primary key, including a computed thread
// count. It returns (nil, nil) when no row matches.
func (r *SpaceRepository) GetByID(ctx context.Context, id string) (*Space, error) {
	query := `
SELECT id, user_id, name, description, icon, color, custom_instructions,
default_focus_mode, default_model, is_public, settings, created_at, updated_at,
(SELECT COUNT(*) FROM threads WHERE space_id = spaces.id) as thread_count
FROM spaces
WHERE id = $1
`
	var s Space
	var settingsJSON []byte
	err := r.db.db.QueryRowContext(ctx, query, id).Scan(
		&s.ID, &s.UserID, &s.Name, &s.Description, &s.Icon, &s.Color,
		&s.CustomInstructions, &s.DefaultFocusMode, &s.DefaultModel,
		&s.IsPublic, &settingsJSON, &s.CreatedAt, &s.UpdatedAt, &s.ThreadCount,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	// Best-effort decode; malformed settings leave the map nil.
	json.Unmarshal(settingsJSON, &s.Settings)
	return &s, nil
}
// GetByUserID returns all of the user's spaces, most recently updated
// first, each with a computed thread count.
// Fix: iteration errors from rows.Err() were previously dropped.
func (r *SpaceRepository) GetByUserID(ctx context.Context, userID string) ([]*Space, error) {
	query := `
SELECT id, user_id, name, description, icon, color, custom_instructions,
default_focus_mode, default_model, is_public, settings, created_at, updated_at,
(SELECT COUNT(*) FROM threads WHERE space_id = spaces.id) as thread_count
FROM spaces
WHERE user_id = $1
ORDER BY updated_at DESC
`
	rows, err := r.db.db.QueryContext(ctx, query, userID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var spaces []*Space
	for rows.Next() {
		var s Space
		var settingsJSON []byte
		if err := rows.Scan(
			&s.ID, &s.UserID, &s.Name, &s.Description, &s.Icon, &s.Color,
			&s.CustomInstructions, &s.DefaultFocusMode, &s.DefaultModel,
			&s.IsPublic, &settingsJSON, &s.CreatedAt, &s.UpdatedAt, &s.ThreadCount,
		); err != nil {
			return nil, err
		}
		// Best-effort decode; malformed settings leave the map nil.
		json.Unmarshal(settingsJSON, &s.Settings)
		spaces = append(spaces, &s)
	}
	// Surface any error encountered while iterating the result set.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return spaces, nil
}
// Update persists the mutable fields of s (name, appearance, instructions,
// defaults, visibility, settings) and refreshes updated_at.
func (r *SpaceRepository) Update(ctx context.Context, s *Space) error {
	// Marshal error ignored, matching the file-wide convention.
	settings, _ := json.Marshal(s.Settings)
	const query = `
UPDATE spaces
SET name = $2, description = $3, icon = $4, color = $5,
custom_instructions = $6, default_focus_mode = $7, default_model = $8,
is_public = $9, settings = $10, updated_at = NOW()
WHERE id = $1
`
	args := []interface{}{
		s.ID, s.Name, s.Description, s.Icon, s.Color,
		s.CustomInstructions, s.DefaultFocusMode, s.DefaultModel,
		s.IsPublic, settings,
	}
	_, err := r.db.db.ExecContext(ctx, query, args...)
	return err
}

// Delete removes the space with the given id. Threads referencing it
// survive with space_id set to NULL (ON DELETE SET NULL on threads).
func (r *SpaceRepository) Delete(ctx context.Context, id string) error {
	const stmt = "DELETE FROM spaces WHERE id = $1"
	_, err := r.db.db.ExecContext(ctx, stmt, id)
	return err
}

View File

@@ -0,0 +1,270 @@
package db
import (
"context"
"database/sql"
"encoding/json"
"time"
)
// Thread is one conversation, optionally grouped under a Space and
// optionally published via ShareID.
type Thread struct {
	ID        string  `json:"id"`
	UserID    string  `json:"userId"`
	// SpaceID is nil when the thread is not assigned to a space.
	SpaceID   *string `json:"spaceId,omitempty"`
	Title     string  `json:"title"`
	FocusMode string  `json:"focusMode"`
	IsPublic  bool    `json:"isPublic"`
	ShareID   *string `json:"shareId,omitempty"`
	CreatedAt time.Time `json:"createdAt"`
	UpdatedAt time.Time `json:"updatedAt"`
	Messages  []ThreadMessage `json:"messages,omitempty"`
	// MessageCount is computed by queries (subselect), not stored.
	MessageCount int `json:"messageCount,omitempty"`
}

// ThreadMessage is one turn in a thread, with the sources, widgets and
// related questions attached to the assistant's answer.
type ThreadMessage struct {
	ID       string `json:"id"`
	ThreadID string `json:"threadId"`
	Role     string `json:"role"`
	Content  string `json:"content"`
	Sources  []ThreadSource `json:"sources,omitempty"`
	Widgets  []map[string]interface{} `json:"widgets,omitempty"`
	RelatedQuestions []string `json:"relatedQuestions,omitempty"`
	Model      string `json:"model,omitempty"`
	TokensUsed int    `json:"tokensUsed,omitempty"`
	CreatedAt  time.Time `json:"createdAt"`
}

// ThreadSource is a single cited source attached to a message.
type ThreadSource struct {
	Index   int    `json:"index"`
	URL     string `json:"url"`
	Title   string `json:"title"`
	Domain  string `json:"domain"`
	Snippet string `json:"snippet,omitempty"`
}

// ThreadRepository provides CRUD, sharing and message operations for
// threads and thread_messages.
type ThreadRepository struct {
	db *PostgresDB
}

// NewThreadRepository returns a ThreadRepository backed by db.
func NewThreadRepository(db *PostgresDB) *ThreadRepository {
	return &ThreadRepository{db: db}
}
// RunMigrations creates the threads and thread_messages tables and their
// indexes if they do not already exist. Messages cascade on thread delete.
func (r *ThreadRepository) RunMigrations(ctx context.Context) error {
	migrations := []string{
		`CREATE TABLE IF NOT EXISTS threads (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL,
space_id UUID REFERENCES spaces(id) ON DELETE SET NULL,
title VARCHAR(500) NOT NULL DEFAULT 'New Thread',
focus_mode VARCHAR(50) DEFAULT 'all',
is_public BOOLEAN DEFAULT FALSE,
share_id VARCHAR(100) UNIQUE,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW()
)`,
		`CREATE INDEX IF NOT EXISTS idx_threads_user ON threads(user_id)`,
		`CREATE INDEX IF NOT EXISTS idx_threads_space ON threads(space_id)`,
		`CREATE INDEX IF NOT EXISTS idx_threads_share ON threads(share_id)`,
		`CREATE TABLE IF NOT EXISTS thread_messages (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
thread_id UUID NOT NULL REFERENCES threads(id) ON DELETE CASCADE,
role VARCHAR(20) NOT NULL,
content TEXT NOT NULL,
sources JSONB DEFAULT '[]',
widgets JSONB DEFAULT '[]',
related_questions JSONB DEFAULT '[]',
model VARCHAR(100),
tokens_used INT DEFAULT 0,
created_at TIMESTAMPTZ DEFAULT NOW()
)`,
		`CREATE INDEX IF NOT EXISTS idx_thread_messages_thread ON thread_messages(thread_id)`,
	}
	for _, m := range migrations {
		if _, err := r.db.db.ExecContext(ctx, m); err != nil {
			return err
		}
	}
	return nil
}
// Create inserts t into threads and fills in the database-generated ID,
// CreatedAt and UpdatedAt on t. A nil SpaceID stores NULL.
func (r *ThreadRepository) Create(ctx context.Context, t *Thread) error {
	query := `
INSERT INTO threads (user_id, space_id, title, focus_mode, is_public)
VALUES ($1, $2, $3, $4, $5)
RETURNING id, created_at, updated_at
`
	return r.db.db.QueryRowContext(ctx, query,
		t.UserID, t.SpaceID, t.Title, t.FocusMode, t.IsPublic,
	).Scan(&t.ID, &t.CreatedAt, &t.UpdatedAt)
}
// GetByID fetches one thread by primary key, including a computed message
// count but not the messages themselves (see GetMessages). Returns
// (nil, nil) when no row matches.
func (r *ThreadRepository) GetByID(ctx context.Context, id string) (*Thread, error) {
	query := `
SELECT id, user_id, space_id, title, focus_mode, is_public, share_id, created_at, updated_at,
(SELECT COUNT(*) FROM thread_messages WHERE thread_id = threads.id) as message_count
FROM threads
WHERE id = $1
`
	var t Thread
	// SpaceID and ShareID are *string, so NULL columns scan to nil.
	err := r.db.db.QueryRowContext(ctx, query, id).Scan(
		&t.ID, &t.UserID, &t.SpaceID, &t.Title, &t.FocusMode,
		&t.IsPublic, &t.ShareID, &t.CreatedAt, &t.UpdatedAt, &t.MessageCount,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return &t, nil
}
// GetByShareID fetches a thread by its public share id. Only threads marked
// is_public are returned; otherwise (and when no row matches) the result is
// (nil, nil).
func (r *ThreadRepository) GetByShareID(ctx context.Context, shareID string) (*Thread, error) {
	query := `
SELECT id, user_id, space_id, title, focus_mode, is_public, share_id, created_at, updated_at
FROM threads
WHERE share_id = $1 AND is_public = true
`
	var t Thread
	err := r.db.db.QueryRowContext(ctx, query, shareID).Scan(
		&t.ID, &t.UserID, &t.SpaceID, &t.Title, &t.FocusMode,
		&t.IsPublic, &t.ShareID, &t.CreatedAt, &t.UpdatedAt,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return &t, nil
}
// GetByUserID returns a page of the user's threads, most recently updated
// first, each with a computed message count.
// Fix: iteration errors from rows.Err() were previously dropped.
func (r *ThreadRepository) GetByUserID(ctx context.Context, userID string, limit, offset int) ([]*Thread, error) {
	query := `
SELECT id, user_id, space_id, title, focus_mode, is_public, share_id, created_at, updated_at,
(SELECT COUNT(*) FROM thread_messages WHERE thread_id = threads.id) as message_count
FROM threads
WHERE user_id = $1
ORDER BY updated_at DESC
LIMIT $2 OFFSET $3
`
	rows, err := r.db.db.QueryContext(ctx, query, userID, limit, offset)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var threads []*Thread
	for rows.Next() {
		var t Thread
		if err := rows.Scan(
			&t.ID, &t.UserID, &t.SpaceID, &t.Title, &t.FocusMode,
			&t.IsPublic, &t.ShareID, &t.CreatedAt, &t.UpdatedAt, &t.MessageCount,
		); err != nil {
			return nil, err
		}
		threads = append(threads, &t)
	}
	// Surface any error encountered while iterating the result set.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return threads, nil
}
// Update persists title, focus mode and visibility for t and refreshes
// updated_at.
func (r *ThreadRepository) Update(ctx context.Context, t *Thread) error {
	const query = `
UPDATE threads
SET title = $2, focus_mode = $3, is_public = $4, updated_at = NOW()
WHERE id = $1
`
	args := []interface{}{t.ID, t.Title, t.FocusMode, t.IsPublic}
	_, err := r.db.db.ExecContext(ctx, query, args...)
	return err
}

// SetShareID publishes a thread: it stores shareID and marks it public.
func (r *ThreadRepository) SetShareID(ctx context.Context, threadID, shareID string) error {
	const stmt = "UPDATE threads SET share_id = $2, is_public = true WHERE id = $1"
	_, err := r.db.db.ExecContext(ctx, stmt, threadID, shareID)
	return err
}

// Delete removes a thread; its messages go with it (ON DELETE CASCADE).
func (r *ThreadRepository) Delete(ctx context.Context, id string) error {
	const stmt = "DELETE FROM threads WHERE id = $1"
	_, err := r.db.db.ExecContext(ctx, stmt, id)
	return err
}
// AddMessage appends msg to its thread, filling in the database-generated
// ID and CreatedAt, and bumps the parent thread's updated_at so recency
// ordering stays correct.
func (r *ThreadRepository) AddMessage(ctx context.Context, msg *ThreadMessage) error {
	// Marshal errors ignored, matching the file-wide convention.
	sourcesJSON, _ := json.Marshal(msg.Sources)
	widgetsJSON, _ := json.Marshal(msg.Widgets)
	relatedJSON, _ := json.Marshal(msg.RelatedQuestions)
	query := `
INSERT INTO thread_messages (thread_id, role, content, sources, widgets, related_questions, model, tokens_used)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, created_at
`
	err := r.db.db.QueryRowContext(ctx, query,
		msg.ThreadID, msg.Role, msg.Content, sourcesJSON, widgetsJSON, relatedJSON, msg.Model, msg.TokensUsed,
	).Scan(&msg.ID, &msg.CreatedAt)
	if err == nil {
		// Best-effort recency bump: a failure here does not invalidate the
		// inserted message, so the error is intentionally ignored.
		r.db.db.ExecContext(ctx, "UPDATE threads SET updated_at = NOW() WHERE id = $1", msg.ThreadID)
	}
	return err
}
// GetMessages returns a page of a thread's messages in chronological order.
// Fix: iteration errors from rows.Err() were previously dropped.
func (r *ThreadRepository) GetMessages(ctx context.Context, threadID string, limit, offset int) ([]ThreadMessage, error) {
	query := `
SELECT id, thread_id, role, content, sources, widgets, related_questions, model, tokens_used, created_at
FROM thread_messages
WHERE thread_id = $1
ORDER BY created_at ASC
LIMIT $2 OFFSET $3
`
	rows, err := r.db.db.QueryContext(ctx, query, threadID, limit, offset)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var messages []ThreadMessage
	for rows.Next() {
		var msg ThreadMessage
		var sourcesJSON, widgetsJSON, relatedJSON []byte
		if err := rows.Scan(
			&msg.ID, &msg.ThreadID, &msg.Role, &msg.Content,
			&sourcesJSON, &widgetsJSON, &relatedJSON,
			&msg.Model, &msg.TokensUsed, &msg.CreatedAt,
		); err != nil {
			return nil, err
		}
		// Best-effort decode; malformed JSON leaves the fields nil.
		json.Unmarshal(sourcesJSON, &msg.Sources)
		json.Unmarshal(widgetsJSON, &msg.Widgets)
		json.Unmarshal(relatedJSON, &msg.RelatedQuestions)
		messages = append(messages, msg)
	}
	// Surface any error encountered while iterating the result set.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return messages, nil
}
// GenerateTitle derives a thread title from its first message, truncating
// long messages to at most 100 characters with a trailing ellipsis.
// Fix: the previous byte-index truncation (title[:97]) could split a
// multi-byte UTF-8 rune and store invalid text; truncate on rune
// boundaries instead.
func (r *ThreadRepository) GenerateTitle(ctx context.Context, threadID, firstMessage string) error {
	title := firstMessage
	if runes := []rune(title); len(runes) > 100 {
		title = string(runes[:97]) + "..."
	}
	_, err := r.db.db.ExecContext(ctx,
		"UPDATE threads SET title = $2 WHERE id = $1",
		threadID, title,
	)
	return err
}

View File

@@ -0,0 +1,323 @@
package db
import (
"context"
"database/sql"
"encoding/json"
"time"
)
// UserInterestsData is a user's personalization profile for the discover
// feed. JSON-typed fields are kept as raw messages and decoded by callers.
type UserInterestsData struct {
	UserID           string          `json:"userId"`
	Topics           json.RawMessage `json:"topics"`
	Sources          json.RawMessage `json:"sources"`
	Keywords         json.RawMessage `json:"keywords"`
	ViewHistory      json.RawMessage `json:"viewHistory"`
	SavedArticles    json.RawMessage `json:"savedArticles"`
	BlockedSources   json.RawMessage `json:"blockedSources"`
	BlockedTopics    json.RawMessage `json:"blockedTopics"`
	// Schema defaults: preferred_lang 'ru', region 'russia',
	// reading_level 'general'.
	PreferredLang    string          `json:"preferredLang"`
	Region           string          `json:"region"`
	ReadingLevel     string          `json:"readingLevel"`
	Notifications    json.RawMessage `json:"notifications"`
	CustomCategories json.RawMessage `json:"customCategories"`
	CreatedAt        time.Time       `json:"createdAt"`
	UpdatedAt        time.Time       `json:"updatedAt"`
}

// UserInterestsRepository persists personalization profiles in the
// user_interests table.
type UserInterestsRepository struct {
	db *PostgresDB
}

// NewUserInterestsRepository returns a UserInterestsRepository backed by db.
func NewUserInterestsRepository(db *PostgresDB) *UserInterestsRepository {
	return &UserInterestsRepository{db: db}
}
// createTable creates the user_interests table and its indexes if they do
// not already exist; the whole script runs as a single Exec call.
func (r *UserInterestsRepository) createTable(ctx context.Context) error {
	query := `
CREATE TABLE IF NOT EXISTS user_interests (
user_id VARCHAR(255) PRIMARY KEY,
topics JSONB DEFAULT '{}',
sources JSONB DEFAULT '{}',
keywords JSONB DEFAULT '{}',
view_history JSONB DEFAULT '[]',
saved_articles JSONB DEFAULT '[]',
blocked_sources JSONB DEFAULT '[]',
blocked_topics JSONB DEFAULT '[]',
preferred_lang VARCHAR(10) DEFAULT 'ru',
region VARCHAR(50) DEFAULT 'russia',
reading_level VARCHAR(20) DEFAULT 'general',
notifications JSONB DEFAULT '{}',
custom_categories JSONB DEFAULT '[]',
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_user_interests_updated ON user_interests(updated_at);
CREATE INDEX IF NOT EXISTS idx_user_interests_region ON user_interests(region);
`
	_, err := r.db.DB().ExecContext(ctx, query)
	return err
}
// Get fetches a user's personalization profile. It returns (nil, nil) when
// the user has no stored profile — callers must check for a nil result.
func (r *UserInterestsRepository) Get(ctx context.Context, userID string) (*UserInterestsData, error) {
	query := `
SELECT user_id, topics, sources, keywords, view_history, saved_articles,
blocked_sources, blocked_topics, preferred_lang, region, reading_level,
notifications, custom_categories, created_at, updated_at
FROM user_interests
WHERE user_id = $1
`
	row := r.db.DB().QueryRowContext(ctx, query, userID)
	var data UserInterestsData
	err := row.Scan(
		&data.UserID, &data.Topics, &data.Sources, &data.Keywords,
		&data.ViewHistory, &data.SavedArticles, &data.BlockedSources,
		&data.BlockedTopics, &data.PreferredLang, &data.Region,
		&data.ReadingLevel, &data.Notifications, &data.CustomCategories,
		&data.CreatedAt, &data.UpdatedAt,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return &data, nil
}
// Save upserts the full interest profile for data.UserID, stamping
// created_at on first insert and updated_at on every write.
func (r *UserInterestsRepository) Save(ctx context.Context, data *UserInterestsData) error {
	query := `
	INSERT INTO user_interests (
		user_id, topics, sources, keywords, view_history, saved_articles,
		blocked_sources, blocked_topics, preferred_lang, region, reading_level,
		notifications, custom_categories, created_at, updated_at
	) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
	ON CONFLICT (user_id) DO UPDATE SET
		topics = EXCLUDED.topics,
		sources = EXCLUDED.sources,
		keywords = EXCLUDED.keywords,
		view_history = EXCLUDED.view_history,
		saved_articles = EXCLUDED.saved_articles,
		blocked_sources = EXCLUDED.blocked_sources,
		blocked_topics = EXCLUDED.blocked_topics,
		preferred_lang = EXCLUDED.preferred_lang,
		region = EXCLUDED.region,
		reading_level = EXCLUDED.reading_level,
		notifications = EXCLUDED.notifications,
		custom_categories = EXCLUDED.custom_categories,
		updated_at = NOW()
	`
	now := time.Now()
	if data.CreatedAt.IsZero() {
		data.CreatedAt = now
	}
	data.UpdatedAt = now
	// Replace nil JSON payloads with empty containers so the JSONB columns
	// never receive SQL NULL (the in-place update queries elsewhere in this
	// repository rely on the columns always holding an object/array).
	fillJSON := func(field *json.RawMessage, empty string) {
		if *field == nil {
			*field = json.RawMessage(empty)
		}
	}
	fillJSON(&data.Topics, "{}")
	fillJSON(&data.Sources, "{}")
	fillJSON(&data.Keywords, "{}")
	fillJSON(&data.ViewHistory, "[]")
	fillJSON(&data.SavedArticles, "[]")
	fillJSON(&data.BlockedSources, "[]")
	fillJSON(&data.BlockedTopics, "[]")
	fillJSON(&data.Notifications, "{}")
	fillJSON(&data.CustomCategories, "[]")
	_, err := r.db.DB().ExecContext(ctx, query,
		data.UserID, data.Topics, data.Sources, data.Keywords,
		data.ViewHistory, data.SavedArticles, data.BlockedSources,
		data.BlockedTopics, data.PreferredLang, data.Region,
		data.ReadingLevel, data.Notifications, data.CustomCategories,
		data.CreatedAt, data.UpdatedAt,
	)
	return err
}
// Delete removes the interest row for userID; deleting an unknown user is a no-op.
func (r *UserInterestsRepository) Delete(ctx context.Context, userID string) error {
	_, err := r.db.DB().ExecContext(ctx, `DELETE FROM user_interests WHERE user_id = $1`, userID)
	return err
}
// AddViewEvent prepends a view event to the user's history, capping the
// stored list at 500 entries. If the user has no row yet, one is created
// with the event as the sole history entry.
//
// NOTE(review): $2 arrives as raw JSON bytes; the explicit ::jsonb cast makes
// the intent unambiguous regardless of how the driver types the parameter —
// confirm against the driver in use.
func (r *UserInterestsRepository) AddViewEvent(ctx context.Context, userID string, event json.RawMessage) error {
	// jsonb does not support slice subscripts (the original `view_history[0:499]`
	// is only valid for native arrays), so the trim is expressed by
	// re-aggregating the first 499 elements in order.
	query := `
	UPDATE user_interests
	SET view_history = CASE
		WHEN jsonb_array_length(view_history) >= 500
		THEN jsonb_build_array($2::jsonb) || (
			SELECT COALESCE(jsonb_agg(elem ORDER BY ord), '[]'::jsonb)
			FROM jsonb_array_elements(view_history) WITH ORDINALITY AS t(elem, ord)
			WHERE ord <= 499
		)
		ELSE jsonb_build_array($2::jsonb) || view_history
	END,
	updated_at = NOW()
	WHERE user_id = $1
	`
	result, err := r.db.DB().ExecContext(ctx, query, userID, event)
	if err != nil {
		return err
	}
	rowsAffected, _ := result.RowsAffected()
	if rowsAffected == 0 {
		// First event for this user: create the row. Not atomic with the
		// UPDATE above; a concurrent first write can make this INSERT fail.
		insertQuery := `
		INSERT INTO user_interests (user_id, view_history, updated_at)
		VALUES ($1, jsonb_build_array($2::jsonb), NOW())
		`
		_, err = r.db.DB().ExecContext(ctx, insertQuery, userID, event)
	}
	return err
}
// UpdateTopicScore adds delta to the stored weight of topic, creating the
// row (with topic as the only entry) when the user does not exist yet.
//
// NOTE(review): the UPDATE-then-INSERT pair is not atomic; two concurrent
// first writes for the same user can race (the loser's INSERT hits the
// primary key). Consider INSERT ... ON CONFLICT DO UPDATE instead.
func (r *UserInterestsRepository) UpdateTopicScore(ctx context.Context, userID, topic string, delta float64) error {
	// `topics || jsonb_build_object(...)` overwrites the single key with
	// old value + delta; a missing key coalesces to 0.
	query := `
	UPDATE user_interests
	SET topics = topics || jsonb_build_object($2, COALESCE((topics->>$2)::float, 0) + $3),
	updated_at = NOW()
	WHERE user_id = $1
	`
	result, err := r.db.DB().ExecContext(ctx, query, userID, topic, delta)
	if err != nil {
		return err
	}
	rowsAffected, _ := result.RowsAffected()
	if rowsAffected == 0 {
		// Unknown user: seed a fresh row holding just this topic weight.
		insertQuery := `
		INSERT INTO user_interests (user_id, topics, updated_at)
		VALUES ($1, jsonb_build_object($2, $3), NOW())
		`
		_, err = r.db.DB().ExecContext(ctx, insertQuery, userID, topic, delta)
	}
	return err
}
// SaveArticle appends articleURL to the user's saved_articles array unless it
// is already present (the jsonb `?` operator tests top-level membership).
// For an unknown user the UPDATE matches no row and the call is a silent no-op.
func (r *UserInterestsRepository) SaveArticle(ctx context.Context, userID, articleURL string) error {
	query := `
	UPDATE user_interests
	SET saved_articles = CASE
		WHEN NOT saved_articles ? $2
		THEN saved_articles || jsonb_build_array($2)
		ELSE saved_articles
	END,
	updated_at = NOW()
	WHERE user_id = $1
	`
	_, err := r.db.DB().ExecContext(ctx, query, userID, articleURL)
	return err
}
// UnsaveArticle removes articleURL from the user's saved list; the jsonb
// `array - text` operator deletes the matching string element.
func (r *UserInterestsRepository) UnsaveArticle(ctx context.Context, userID, articleURL string) error {
	const q = `
	UPDATE user_interests
	SET saved_articles = saved_articles - $2,
	updated_at = NOW()
	WHERE user_id = $1
	`
	_, err := r.db.DB().ExecContext(ctx, q, userID, articleURL)
	return err
}
// BlockSource appends source to blocked_sources unless already present
// (jsonb `?` membership test makes the call idempotent). Unknown users are
// a silent no-op because the UPDATE matches no row.
func (r *UserInterestsRepository) BlockSource(ctx context.Context, userID, source string) error {
	query := `
	UPDATE user_interests
	SET blocked_sources = CASE
		WHEN NOT blocked_sources ? $2
		THEN blocked_sources || jsonb_build_array($2)
		ELSE blocked_sources
	END,
	updated_at = NOW()
	WHERE user_id = $1
	`
	_, err := r.db.DB().ExecContext(ctx, query, userID, source)
	return err
}
// UnblockSource removes source from blocked_sources (jsonb `array - text`
// drops the matching string element). Idempotent.
func (r *UserInterestsRepository) UnblockSource(ctx context.Context, userID, source string) error {
	const q = `
	UPDATE user_interests
	SET blocked_sources = blocked_sources - $2,
	updated_at = NOW()
	WHERE user_id = $1
	`
	_, err := r.db.DB().ExecContext(ctx, q, userID, source)
	return err
}
// GetTopUsers returns up to limit user IDs ordered by most recent activity.
func (r *UserInterestsRepository) GetTopUsers(ctx context.Context, limit int) ([]string, error) {
	const q = `
	SELECT user_id FROM user_interests
	ORDER BY updated_at DESC
	LIMIT $1
	`
	rows, err := r.db.DB().QueryContext(ctx, q, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var ids []string
	for rows.Next() {
		var id string
		if err := rows.Scan(&id); err != nil {
			return nil, err
		}
		ids = append(ids, id)
	}
	// rows.Err surfaces iteration errors that Next swallowed.
	return ids, rows.Err()
}
// DecayAllInterests multiplies every topic/source/keyword weight of every
// user by decayFactor and drops entries that fall below 0.01.
func (r *UserInterestsRepository) DecayAllInterests(ctx context.Context, decayFactor float64) error {
	// COALESCE is required: jsonb_object_agg over zero rows yields SQL NULL,
	// which would overwrite the column with NULL once every entry decays away
	// and silently break later `col || ...` updates (NULL || x IS NULL).
	query := `
	UPDATE user_interests
	SET topics = COALESCE((
		SELECT jsonb_object_agg(key, (value::text::float * $1))
		FROM jsonb_each(topics) WHERE (value::text::float * $1) > 0.01
	), '{}'::jsonb),
	sources = COALESCE((
		SELECT jsonb_object_agg(key, (value::text::float * $1))
		FROM jsonb_each(sources) WHERE (value::text::float * $1) > 0.01
	), '{}'::jsonb),
	keywords = COALESCE((
		SELECT jsonb_object_agg(key, (value::text::float * $1))
		FROM jsonb_each(keywords) WHERE (value::text::float * $1) > 0.01
	), '{}'::jsonb),
	updated_at = NOW()
	`
	_, err := r.db.DB().ExecContext(ctx, query, decayFactor)
	return err
}

View File

@@ -0,0 +1,691 @@
package discover
import (
"context"
"encoding/json"
"fmt"
"math"
"sort"
"strings"
"sync"
"time"
)
// UserInterests is the full personalization profile for one user. The three
// weight maps hold decaying affinity scores learned from view events.
type UserInterests struct {
	UserID string `json:"userId"`
	// Topic/source/keyword -> affinity weight; entries below 0.01 are pruned
	// by the engine's decay pass.
	Topics   map[string]float64 `json:"topics"`
	Sources  map[string]float64 `json:"sources"`
	Keywords map[string]float64 `json:"keywords"`
	// ViewHistory is newest-first and capped at 500 entries by RecordView.
	ViewHistory    []ViewEvent `json:"viewHistory"`
	SavedArticles  []string    `json:"savedArticles"` // article URLs
	BlockedSources []string    `json:"blockedSources"`
	BlockedTopics  []string    `json:"blockedTopics"`
	PreferredLang  string      `json:"preferredLang"`
	Region         string      `json:"region"`
	ReadingLevel   string      `json:"readingLevel"`
	Notifications  NotificationPrefs `json:"notifications"`
	LastUpdated    time.Time         `json:"lastUpdated"`
	CustomCategories []CustomCategory `json:"customCategories,omitempty"`
}

// ViewEvent records one article interaction. Engagement is a derived 0..1
// composite filled in by PersonalizationEngine.calculateEngagement.
type ViewEvent struct {
	ArticleID  string    `json:"articleId"`
	URL        string    `json:"url"`
	Topic      string    `json:"topic"`
	Source     string    `json:"source"`
	Keywords   []string  `json:"keywords"`
	TimeSpent  int       `json:"timeSpentSeconds"`
	Completed  bool      `json:"completed"`
	Saved      bool      `json:"saved"`
	Shared     bool      `json:"shared"`
	Timestamp  time.Time `json:"timestamp"`
	Engagement float64   `json:"engagement"`
}

// NotificationPrefs configures per-user push/digest delivery.
type NotificationPrefs struct {
	Enabled      bool     `json:"enabled"`
	DailyDigest  bool     `json:"dailyDigest"`
	DigestTime   string   `json:"digestTime"`
	BreakingNews bool     `json:"breakingNews"`
	TopicAlerts  []string `json:"topicAlerts"`
	Frequency    string   `json:"frequency"`
}

// CustomCategory is a user-defined feed section matched by keywords/sources.
type CustomCategory struct {
	ID       string   `json:"id"`
	Name     string   `json:"name"`
	Keywords []string `json:"keywords"`
	Sources  []string `json:"sources"`
	Weight   float64  `json:"weight"`
}

// PersonalizedFeed is the assembled "for you" response for one user.
type PersonalizedFeed struct {
	UserID     string         `json:"userId"`
	Items      []FeedItem     `json:"items"`
	Categories []FeedCategory `json:"categories"`
	TrendingIn []string       `json:"trendingIn"` // the user's top topics
	UpdatedAt  time.Time      `json:"updatedAt"`
	NextUpdate time.Time      `json:"nextUpdate"`
}

// FeedItem is one scored article in the feed; RelevanceScore, Reason and the
// Is* flags are filled in by the engine during feed assembly.
type FeedItem struct {
	ID             string    `json:"id"`
	URL            string    `json:"url"`
	Title          string    `json:"title"`
	Summary        string    `json:"summary"`
	Thumbnail      string    `json:"thumbnail"`
	Source         string    `json:"source"`
	SourceLogo     string    `json:"sourceLogo"`
	Topic          string    `json:"topic"`
	Keywords       []string  `json:"keywords"`
	PublishedAt    time.Time `json:"publishedAt"`
	RelevanceScore float64   `json:"relevanceScore"`
	Reason         string    `json:"reason"` // human-readable "why recommended"
	SourcesCount   int       `json:"sourcesCount"`
	ReadTime       int       `json:"readTimeMinutes"`
	HasDigest      bool      `json:"hasDigest"`
	IsBreaking     bool      `json:"isBreaking"`
	IsTrending     bool      `json:"isTrending"`
	IsSaved        bool      `json:"isSaved"`
	IsRead         bool      `json:"isRead"`
}

// FeedCategory is a topical (or user-defined custom) section of the feed.
type FeedCategory struct {
	ID       string     `json:"id"`
	Name     string     `json:"name"`
	Icon     string     `json:"icon"`
	Color    string     `json:"color"`
	Items    []FeedItem `json:"items"`
	IsCustom bool       `json:"isCustom"`
}
// PersonalizationEngine combines a user-profile store and a content
// repository to build personalized feeds and learn from view events.
type PersonalizationEngine struct {
	userStore   UserInterestStore
	contentRepo ContentRepository
	// NOTE(review): mu is not used by any method visible in this file —
	// confirm whether it is needed or can be removed.
	mu     sync.RWMutex
	config PersonalizationConfig
}

// PersonalizationConfig holds the scoring weights and feed-shaping knobs;
// see DefaultConfig for the tuned baseline values.
type PersonalizationConfig struct {
	MaxFeedItems      int     // hard cap on items returned per feed
	DecayFactor       float64 // per-event multiplicative decay of learned weights
	RecencyWeight     float64
	EngagementWeight  float64
	TopicMatchWeight  float64
	SourceTrustWeight float64
	DiversityFactor   float64
	TrendingBoost     float64 // multiplicative boost for trending items
	BreakingBoost     float64 // multiplicative boost for breaking items
}

// UserInterestStore abstracts persistence of user profiles.
type UserInterestStore interface {
	Get(ctx context.Context, userID string) (*UserInterests, error)
	Save(ctx context.Context, interests *UserInterests) error
	Delete(ctx context.Context, userID string) error
}

// ContentRepository abstracts the content backends the feed draws from.
type ContentRepository interface {
	GetLatestContent(ctx context.Context, topics []string, limit int) ([]FeedItem, error)
	GetTrending(ctx context.Context, region string, limit int) ([]FeedItem, error)
	GetByKeywords(ctx context.Context, keywords []string, limit int) ([]FeedItem, error)
}
// DefaultConfig returns the baseline personalization weights: topic match
// dominates (0.30), then recency (0.25), engagement (0.20), source trust
// (0.15) and diversity (0.10); trending/breaking items receive 1.5x/2.0x
// multiplicative boosts, and feeds are capped at 50 items.
func DefaultConfig() PersonalizationConfig {
	var cfg PersonalizationConfig
	cfg.MaxFeedItems = 50
	cfg.DecayFactor = 0.95
	cfg.RecencyWeight = 0.25
	cfg.EngagementWeight = 0.20
	cfg.TopicMatchWeight = 0.30
	cfg.SourceTrustWeight = 0.15
	cfg.DiversityFactor = 0.10
	cfg.TrendingBoost = 1.5
	cfg.BreakingBoost = 2.0
	return cfg
}
// NewPersonalizationEngine assembles an engine from its two dependencies and
// a config (see DefaultConfig for sensible defaults).
func NewPersonalizationEngine(userStore UserInterestStore, contentRepo ContentRepository, cfg PersonalizationConfig) *PersonalizationEngine {
	engine := &PersonalizationEngine{config: cfg}
	engine.userStore = userStore
	engine.contentRepo = contentRepo
	return engine
}
// GenerateForYouFeed assembles a personalized feed for userID: it fans out to
// latest-content / trending / keyword queries in parallel, then dedupes,
// filters blocklists, scores, diversifies, truncates and groups the results.
func (e *PersonalizationEngine) GenerateForYouFeed(ctx context.Context, userID string) (*PersonalizedFeed, error) {
	interests, err := e.userStore.Get(ctx, userID)
	if err != nil || interests == nil {
		// Cold start or store failure: continue with an empty profile.
		// NOTE(review): the nil check also covers stores that signal
		// "not found" as (nil, nil) — previously that would panic below.
		interests = &UserInterests{
			UserID:        userID,
			Topics:        make(map[string]float64),
			Sources:       make(map[string]float64),
			Keywords:      make(map[string]float64),
			PreferredLang: "ru",
			Region:        "russia",
		}
	}
	var allItems []FeedItem
	var mu sync.Mutex
	var wg sync.WaitGroup
	topTopics := e.getTopInterests(interests.Topics, 5)
	// Fan out to the content sources concurrently; fetch errors are
	// deliberately dropped — a partial feed beats no feed.
	wg.Add(1)
	go func() {
		defer wg.Done()
		items, _ := e.contentRepo.GetLatestContent(ctx, topTopics, 30)
		mu.Lock()
		allItems = append(allItems, items...)
		mu.Unlock()
	}()
	wg.Add(1)
	go func() {
		defer wg.Done()
		items, _ := e.contentRepo.GetTrending(ctx, interests.Region, 20)
		for i := range items {
			items[i].IsTrending = true
		}
		mu.Lock()
		allItems = append(allItems, items...)
		mu.Unlock()
	}()
	topKeywords := e.getTopKeywords(interests.Keywords, 10)
	if len(topKeywords) > 0 {
		wg.Add(1)
		go func() {
			defer wg.Done()
			items, _ := e.contentRepo.GetByKeywords(ctx, topKeywords, 15)
			mu.Lock()
			allItems = append(allItems, items...)
			mu.Unlock()
		}()
	}
	wg.Wait()
	allItems = e.deduplicateItems(allItems)
	allItems = e.filterBlockedContent(allItems, interests)
	// Score and annotate every surviving item against the profile.
	for i := range allItems {
		allItems[i].RelevanceScore = e.calculateRelevance(allItems[i], interests)
		allItems[i].Reason = e.explainRecommendation(allItems[i], interests)
		allItems[i].IsRead = e.isArticleRead(allItems[i].URL, interests)
		allItems[i].IsSaved = e.isArticleSaved(allItems[i].URL, interests)
	}
	sort.Slice(allItems, func(i, j int) bool {
		return allItems[i].RelevanceScore > allItems[j].RelevanceScore
	})
	allItems = e.applyDiversity(allItems)
	if len(allItems) > e.config.MaxFeedItems {
		allItems = allItems[:e.config.MaxFeedItems]
	}
	categories := e.groupByCategory(allItems, interests)
	return &PersonalizedFeed{
		UserID:     userID,
		Items:      allItems,
		Categories: categories,
		TrendingIn: topTopics,
		UpdatedAt:  time.Now(),
		NextUpdate: time.Now().Add(15 * time.Minute),
	}, nil
}
// RecordView folds one view event into the user's profile: prepends it to the
// capped history, bumps topic/source/keyword weights in proportion to the
// computed engagement, applies decay, and persists the result.
func (e *PersonalizationEngine) RecordView(ctx context.Context, userID string, event ViewEvent) error {
	interests, err := e.userStore.Get(ctx, userID)
	if err != nil || interests == nil {
		// Unknown user or store failure: start from an empty profile.
		// NOTE(review): nil check guards stores that return (nil, nil) for
		// "not found" — confirm the store contract.
		interests = &UserInterests{
			UserID:   userID,
			Topics:   make(map[string]float64),
			Sources:  make(map[string]float64),
			Keywords: make(map[string]float64),
		}
	}
	// Profiles deserialized from storage can carry nil maps (JSON null);
	// writing to a nil map panics, so ensure they exist first.
	if interests.Topics == nil {
		interests.Topics = make(map[string]float64)
	}
	if interests.Sources == nil {
		interests.Sources = make(map[string]float64)
	}
	if interests.Keywords == nil {
		interests.Keywords = make(map[string]float64)
	}
	event.Engagement = e.calculateEngagement(event)
	// Newest first; keep at most 500 events.
	interests.ViewHistory = append([]ViewEvent{event}, interests.ViewHistory...)
	if len(interests.ViewHistory) > 500 {
		interests.ViewHistory = interests.ViewHistory[:500]
	}
	// Engagement feeds the weights at different rates: topics learn fastest.
	interests.Topics[event.Topic] += event.Engagement * 0.1
	interests.Sources[event.Source] += event.Engagement * 0.05
	for _, kw := range event.Keywords {
		interests.Keywords[kw] += event.Engagement * 0.02
	}
	if event.Saved {
		interests.SavedArticles = append(interests.SavedArticles, event.URL)
	}
	interests.LastUpdated = time.Now()
	e.decayInterests(interests)
	return e.userStore.Save(ctx, interests)
}
// UpdateTopicPreference sets an explicit weight for a topic (a user-driven
// override, unlike the incremental learning in RecordView).
func (e *PersonalizationEngine) UpdateTopicPreference(ctx context.Context, userID, topic string, weight float64) error {
	interests, err := e.userStore.Get(ctx, userID)
	if err != nil || interests == nil {
		// NOTE(review): treats a (nil, nil) "not found" result like a fresh
		// profile; previously a nil result would panic below.
		interests = &UserInterests{
			UserID:   userID,
			Topics:   make(map[string]float64),
			Sources:  make(map[string]float64),
			Keywords: make(map[string]float64),
		}
	}
	if interests.Topics == nil {
		interests.Topics = make(map[string]float64) // guard nil map from storage
	}
	interests.Topics[topic] = weight
	interests.LastUpdated = time.Now()
	return e.userStore.Save(ctx, interests)
}
// BlockSource adds source to the user's blocklist. Idempotent: an already
// blocked source returns nil without touching the store.
func (e *PersonalizationEngine) BlockSource(ctx context.Context, userID, source string) error {
	interests, err := e.userStore.Get(ctx, userID)
	if err != nil {
		return err
	}
	if interests == nil {
		// NOTE(review): guards stores that return (nil, nil) for unknown
		// users; previously this path would panic below.
		interests = &UserInterests{UserID: userID}
	}
	for _, blocked := range interests.BlockedSources {
		if blocked == source {
			return nil
		}
	}
	interests.BlockedSources = append(interests.BlockedSources, source)
	interests.LastUpdated = time.Now()
	return e.userStore.Save(ctx, interests)
}
// BlockTopic adds topic to the blocklist and forgets any learned affinity for
// it. Idempotent for already blocked topics.
func (e *PersonalizationEngine) BlockTopic(ctx context.Context, userID, topic string) error {
	interests, err := e.userStore.Get(ctx, userID)
	if err != nil {
		return err
	}
	if interests == nil {
		// NOTE(review): guards stores that return (nil, nil) for unknown
		// users; previously this path would panic below.
		interests = &UserInterests{UserID: userID}
	}
	for _, blocked := range interests.BlockedTopics {
		if blocked == topic {
			return nil
		}
	}
	interests.BlockedTopics = append(interests.BlockedTopics, topic)
	delete(interests.Topics, topic) // delete on a nil map is a safe no-op
	interests.LastUpdated = time.Now()
	return e.userStore.Save(ctx, interests)
}
// AddCustomCategory appends a user-defined feed category to the profile.
func (e *PersonalizationEngine) AddCustomCategory(ctx context.Context, userID string, category CustomCategory) error {
	interests, err := e.userStore.Get(ctx, userID)
	if err != nil {
		return err
	}
	if interests == nil {
		// NOTE(review): guards stores that return (nil, nil) for unknown
		// users; previously this path would panic below.
		interests = &UserInterests{UserID: userID}
	}
	interests.CustomCategories = append(interests.CustomCategories, category)
	interests.LastUpdated = time.Now()
	return e.userStore.Save(ctx, interests)
}
// GetUserTopics returns the user's learned topic weights; an unknown user
// yields an empty map rather than a nil-pointer panic.
func (e *PersonalizationEngine) GetUserTopics(ctx context.Context, userID string) (map[string]float64, error) {
	interests, err := e.userStore.Get(ctx, userID)
	if err != nil {
		return nil, err
	}
	if interests == nil {
		return map[string]float64{}, nil
	}
	return interests.Topics, nil
}
// calculateRelevance scores a feed item against the user's profile: a
// weighted sum of topic affinity, source affinity, keyword overlap and
// recency, multiplied by trending/breaking boosts.
// NOTE(review): the keyword weight (0.1) is hard-coded, unlike the other
// weights which come from the config — possibly intentional, worth checking.
func (e *PersonalizationEngine) calculateRelevance(item FeedItem, interests *UserInterests) float64 {
	score := 0.0
	if topicScore, ok := interests.Topics[item.Topic]; ok {
		score += topicScore * e.config.TopicMatchWeight
	}
	if sourceScore, ok := interests.Sources[item.Source]; ok {
		score += sourceScore * e.config.SourceTrustWeight
	}
	keywordScore := 0.0
	for _, kw := range item.Keywords {
		if kwScore, ok := interests.Keywords[strings.ToLower(kw)]; ok {
			keywordScore += kwScore
		}
	}
	score += keywordScore * 0.1
	// Linear recency falloff over 168h (one week); older items get 0.
	hoursSincePublish := time.Since(item.PublishedAt).Hours()
	recencyScore := math.Max(0, 1.0-hoursSincePublish/168.0)
	score += recencyScore * e.config.RecencyWeight
	// Boosts are multiplicative, so they amplify an already relevant item
	// but cannot rescue a zero-scored one.
	if item.IsTrending {
		score *= e.config.TrendingBoost
	}
	if item.IsBreaking {
		score *= e.config.BreakingBoost
	}
	return score
}
// calculateEngagement folds the raw interaction signals of a view event into
// a 0..1 score: up to 0.4 for read time (saturating at 5 minutes), plus
// fixed bonuses for completing (0.3), saving (0.2) and sharing (0.1).
func (e *PersonalizationEngine) calculateEngagement(event ViewEvent) float64 {
	var score float64
	if event.TimeSpent > 0 {
		score += 0.4 * math.Min(1.0, float64(event.TimeSpent)/300.0)
	}
	signals := []struct {
		fired  bool
		weight float64
	}{
		{event.Completed, 0.3},
		{event.Saved, 0.2},
		{event.Shared, 0.1},
	}
	for _, sig := range signals {
		if sig.fired {
			score += sig.weight
		}
	}
	return score
}
// explainRecommendation produces the user-facing (Russian) "why am I seeing
// this" label, checked in priority order: breaking, trending, strong topic
// affinity, trusted source, matched keyword, then a generic fallback.
func (e *PersonalizationEngine) explainRecommendation(item FeedItem, interests *UserInterests) string {
	switch {
	case item.IsBreaking:
		return "Срочная новость"
	case item.IsTrending:
		return "Популярно сейчас"
	}
	if score, ok := interests.Topics[item.Topic]; ok && score > 0.5 {
		return fmt.Sprintf("Из вашей категории: %s", item.Topic)
	}
	if score, ok := interests.Sources[item.Source]; ok && score > 0.3 {
		return fmt.Sprintf("Из источника, который вы читаете: %s", item.Source)
	}
	for _, kw := range item.Keywords {
		if score, ok := interests.Keywords[strings.ToLower(kw)]; ok && score > 0.2 {
			return fmt.Sprintf("По вашему интересу: %s", kw)
		}
	}
	return "Рекомендуем для вас"
}
// getTopInterests returns up to limit keys of the weight map, ordered from
// highest to lowest weight. Ties are broken arbitrarily (map order).
func (e *PersonalizationEngine) getTopInterests(interests map[string]float64, limit int) []string {
	keys := make([]string, 0, len(interests))
	for key := range interests {
		keys = append(keys, key)
	}
	sort.Slice(keys, func(i, j int) bool {
		return interests[keys[i]] > interests[keys[j]]
	})
	if len(keys) > limit {
		keys = keys[:limit]
	}
	return keys
}
// getTopKeywords returns the limit highest-weighted keywords; it reuses
// getTopInterests since both rank a string->weight map.
func (e *PersonalizationEngine) getTopKeywords(keywords map[string]float64, limit int) []string {
	return e.getTopInterests(keywords, limit)
}
// deduplicateItems drops items whose URL was already seen, keeping the first
// occurrence and preserving order.
func (e *PersonalizationEngine) deduplicateItems(items []FeedItem) []FeedItem {
	unique := make([]FeedItem, 0, len(items))
	seen := make(map[string]struct{}, len(items))
	for _, item := range items {
		if _, dup := seen[item.URL]; dup {
			continue
		}
		seen[item.URL] = struct{}{}
		unique = append(unique, item)
	}
	return unique
}
// filterBlockedContent removes items whose source or topic appears in the
// user's blocklists (case-insensitive comparison).
func (e *PersonalizationEngine) filterBlockedContent(items []FeedItem, interests *UserInterests) []FeedItem {
	toSet := func(values []string) map[string]struct{} {
		set := make(map[string]struct{}, len(values))
		for _, v := range values {
			set[strings.ToLower(v)] = struct{}{}
		}
		return set
	}
	badSources := toSet(interests.BlockedSources)
	badTopics := toSet(interests.BlockedTopics)
	kept := make([]FeedItem, 0, len(items))
	for _, item := range items {
		if _, blocked := badSources[strings.ToLower(item.Source)]; blocked {
			continue
		}
		if _, blocked := badTopics[strings.ToLower(item.Topic)]; blocked {
			continue
		}
		kept = append(kept, item)
	}
	return kept
}
// applyDiversity re-orders a relevance-sorted list so no single topic or
// source dominates the top of the feed: items over their per-topic or
// per-source quota are deferred and re-appended at the tail, up to
// MaxFeedItems. Feeds of 10 items or fewer are returned untouched.
func (e *PersonalizationEngine) applyDiversity(items []FeedItem) []FeedItem {
	if len(items) <= 10 {
		return items
	}
	topicCounts := make(map[string]int)
	sourceCounts := make(map[string]int)
	// Quotas scale with feed size but never drop below 3.
	maxPerTopic := len(items) / 5
	maxPerSource := len(items) / 4
	if maxPerTopic < 3 {
		maxPerTopic = 3
	}
	if maxPerSource < 3 {
		maxPerSource = 3
	}
	result := make([]FeedItem, 0, len(items))
	deferred := make([]FeedItem, 0)
	for _, item := range items {
		if topicCounts[item.Topic] >= maxPerTopic || sourceCounts[item.Source] >= maxPerSource {
			deferred = append(deferred, item)
			continue
		}
		topicCounts[item.Topic]++
		sourceCounts[item.Source]++
		result = append(result, item)
	}
	// Overflow items keep their relative (relevance) order at the tail.
	for _, item := range deferred {
		if len(result) >= e.config.MaxFeedItems {
			break
		}
		result = append(result, item)
	}
	return result
}
// groupByCategory buckets feed items into topical sections (with emoji/color
// metadata) plus the user's custom categories, and orders the sections by
// the user's topic affinity. Topics with fewer than two items get no section.
func (e *PersonalizationEngine) groupByCategory(items []FeedItem, interests *UserInterests) []FeedCategory {
	categoryMap := make(map[string][]FeedItem)
	for _, item := range items {
		categoryMap[item.Topic] = append(categoryMap[item.Topic], item)
	}
	categories := make([]FeedCategory, 0, len(categoryMap))
	// Presentation metadata per known topic; unknown topics fall back to 📰.
	categoryMeta := map[string]struct {
		Icon  string
		Color string
	}{
		"tech":          {"💻", "#3B82F6"},
		"finance":       {"💰", "#10B981"},
		"sports":        {"⚽", "#F59E0B"},
		"politics":      {"🏛️", "#6366F1"},
		"science":       {"🔬", "#8B5CF6"},
		"health":        {"🏥", "#EC4899"},
		"entertainment": {"🎬", "#F97316"},
		"world":         {"🌍", "#14B8A6"},
		"business":      {"📊", "#6B7280"},
		"culture":       {"🎭", "#A855F7"},
	}
	for topic, topicItems := range categoryMap {
		if len(topicItems) < 2 {
			continue // too thin to warrant its own section
		}
		meta, ok := categoryMeta[strings.ToLower(topic)]
		if !ok {
			meta = struct {
				Icon  string
				Color string
			}{"📰", "#6B7280"}
		}
		categories = append(categories, FeedCategory{
			ID:    topic,
			Name:  topic,
			Icon:  meta.Icon,
			Color: meta.Color,
			Items: topicItems,
		})
	}
	// Custom categories match items by keyword against title/summary/keywords;
	// an item may appear both in its topic section and in a custom one.
	for _, custom := range interests.CustomCategories {
		customItems := make([]FeedItem, 0)
		for _, item := range items {
			for _, kw := range custom.Keywords {
				if containsKeyword(item, kw) {
					customItems = append(customItems, item)
					break
				}
			}
		}
		if len(customItems) > 0 {
			categories = append(categories, FeedCategory{
				ID:       custom.ID,
				Name:     custom.Name,
				Icon:     "⭐",
				Color:    "#FBBF24",
				Items:    customItems,
				IsCustom: true,
			})
		}
	}
	// Sections the user cares most about come first (unknown IDs score 0).
	sort.Slice(categories, func(i, j int) bool {
		iScore := interests.Topics[categories[i].ID]
		jScore := interests.Topics[categories[j].ID]
		return iScore > jScore
	})
	return categories
}
// decayInterests multiplies every learned weight by the configured decay
// factor and drops entries that fall below 0.01, so stale interests fade out.
func (e *PersonalizationEngine) decayInterests(interests *UserInterests) {
	// One shared pass replaces three copy-pasted loops.
	decay := func(weights map[string]float64) {
		for key := range weights {
			weights[key] *= e.config.DecayFactor
			if weights[key] < 0.01 {
				delete(weights, key) // deleting while ranging is safe in Go
			}
		}
	}
	decay(interests.Topics)
	decay(interests.Sources)
	decay(interests.Keywords)
}
// isArticleRead reports whether url appears anywhere in the view history.
func (e *PersonalizationEngine) isArticleRead(url string, interests *UserInterests) bool {
	for i := range interests.ViewHistory {
		if interests.ViewHistory[i].URL == url {
			return true
		}
	}
	return false
}
// isArticleSaved reports whether url is in the user's saved-articles list.
func (e *PersonalizationEngine) isArticleSaved(url string, interests *UserInterests) bool {
	for i := range interests.SavedArticles {
		if interests.SavedArticles[i] == url {
			return true
		}
	}
	return false
}
// containsKeyword reports whether the item's title or summary contains
// keyword, or its keyword list has an exact match (all case-insensitive).
func containsKeyword(item FeedItem, keyword string) bool {
	needle := strings.ToLower(keyword)
	for _, haystack := range []string{item.Title, item.Summary} {
		if strings.Contains(strings.ToLower(haystack), needle) {
			return true
		}
	}
	for _, candidate := range item.Keywords {
		if strings.ToLower(candidate) == needle {
			return true
		}
	}
	return false
}
// ToJSON serializes the full interest profile for storage or transport.
func (u *UserInterests) ToJSON() ([]byte, error) {
	return json.Marshal(u)
}
// ParseUserInterests decodes a JSON-encoded interest profile (the inverse of
// UserInterests.ToJSON).
func ParseUserInterests(data []byte) (*UserInterests, error) {
	interests := new(UserInterests)
	if err := json.Unmarshal(data, interests); err != nil {
		return nil, err
	}
	return interests, nil
}

View File

@@ -0,0 +1,343 @@
package files
import (
"bytes"
"context"
"encoding/base64"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/gooseek/backend/internal/llm"
"github.com/ledongthuc/pdf"
)
// FileAnalyzer extracts text from uploaded files (PDF, image, plain text)
// and produces LLM-generated summaries. Files live under storagePath.
type FileAnalyzer struct {
	llmClient   llm.Client
	storagePath string
}

// AnalysisResult is the normalized output of analyzing any supported file.
type AnalysisResult struct {
	FileType      string                 `json:"fileType"` // "pdf", "image" or "text"
	ExtractedText string                 `json:"extractedText"`
	Summary       string                 `json:"summary"`   // empty if summarization failed
	KeyPoints     []string               `json:"keyPoints"` // nil if summarization failed
	Metadata      map[string]interface{} `json:"metadata"`  // type-specific (pages, size, ...)
}
// NewFileAnalyzer builds an analyzer that stores uploads under storagePath
// (default /tmp/gooseek-files) and uses llmClient for summaries and vision.
func NewFileAnalyzer(llmClient llm.Client, storagePath string) *FileAnalyzer {
	if storagePath == "" {
		storagePath = "/tmp/gooseek-files"
	}
	// NOTE(review): MkdirAll's error is discarded; a failure here only
	// surfaces later as os.Create errors in SaveFile — consider logging it.
	os.MkdirAll(storagePath, 0755)
	return &FileAnalyzer{
		llmClient:   llmClient,
		storagePath: storagePath,
	}
}
// AnalyzeFile dispatches on the MIME type prefix to the matching analyzer;
// unsupported types return an error.
func (fa *FileAnalyzer) AnalyzeFile(ctx context.Context, filePath string, fileType string) (*AnalysisResult, error) {
	if strings.HasPrefix(fileType, "application/pdf") {
		return fa.analyzePDF(ctx, filePath)
	}
	if strings.HasPrefix(fileType, "image/") {
		return fa.analyzeImage(ctx, filePath, fileType)
	}
	if strings.HasPrefix(fileType, "text/") {
		return fa.analyzeText(ctx, filePath)
	}
	return nil, fmt.Errorf("unsupported file type: %s", fileType)
}
// analyzePDF extracts the PDF's text, truncates oversized content, and asks
// the LLM for a best-effort summary (summary failures are not fatal).
func (fa *FileAnalyzer) analyzePDF(ctx context.Context, filePath string) (*AnalysisResult, error) {
	body, meta, err := extractPDFContent(filePath)
	if err != nil {
		return nil, fmt.Errorf("failed to extract PDF content: %w", err)
	}
	if len(body) > 50000 {
		body = body[:50000] + "\n\n[Content truncated...]"
	}
	summary, points, err := fa.generateSummary(ctx, body, "PDF document")
	if err != nil {
		// Summarization is best-effort: return the raw text regardless.
		summary, points = "", nil
	}
	return &AnalysisResult{
		FileType:      "pdf",
		ExtractedText: body,
		Summary:       summary,
		KeyPoints:     points,
		Metadata:      meta,
	}, nil
}
// extractPDFContent walks every page of the PDF (via ledongthuc/pdf) and
// concatenates plain text, stopping once ~100KB has been collected.
// Unreadable pages are skipped rather than failing the whole extraction.
func extractPDFContent(filePath string) (string, map[string]interface{}, error) {
	f, r, err := pdf.Open(filePath)
	if err != nil {
		return "", nil, err
	}
	defer f.Close()
	var textBuilder strings.Builder
	numPages := r.NumPage()
	// pdf page numbering is 1-based.
	for i := 1; i <= numPages; i++ {
		p := r.Page(i)
		if p.V.IsNull() {
			continue
		}
		text, err := p.GetPlainText(nil)
		if err != nil {
			continue // skip pages that fail to decode
		}
		textBuilder.WriteString(text)
		textBuilder.WriteString("\n\n")
		if textBuilder.Len() > 100000 {
			break // cap extraction cost for huge documents
		}
	}
	metadata := map[string]interface{}{
		"numPages": numPages,
	}
	return textBuilder.String(), metadata, nil
}
// analyzeImage base64-encodes the image and asks the vision-capable LLM for
// a description, which doubles as both extracted text and summary.
func (fa *FileAnalyzer) analyzeImage(ctx context.Context, filePath string, mimeType string) (*AnalysisResult, error) {
	raw, err := os.ReadFile(filePath)
	if err != nil {
		return nil, fmt.Errorf("failed to read image: %w", err)
	}
	encoded := base64.StdEncoding.EncodeToString(raw)
	description, err := fa.describeImage(ctx, encoded, mimeType)
	if err != nil {
		// Vision failures degrade to a placeholder instead of failing.
		description = "Image analysis unavailable"
	}
	return &AnalysisResult{
		FileType:      "image",
		ExtractedText: description,
		Summary:       description,
		KeyPoints:     extractKeyPointsFromDescription(description),
		Metadata: map[string]interface{}{
			"size": len(raw),
		},
	}, nil
}
// describeImage sends the base64-encoded image to the multimodal LLM and
// returns its textual description (including any OCR'd text).
func (fa *FileAnalyzer) describeImage(ctx context.Context, base64Image, mimeType string) (string, error) {
	prompt := `Analyze this image and provide:
1. A detailed description of what's shown
2. Any text visible in the image (OCR)
3. Key elements and their relationships
4. Any data, charts, or diagrams and their meaning
Be thorough but concise.`
	messages := []llm.Message{
		{
			// Use the typed role constant for consistency with the rest of
			// the package (generateSummary uses llm.RoleUser) instead of the
			// raw "user" string literal.
			Role:    llm.RoleUser,
			Content: prompt,
			Images: []llm.ImageContent{
				{
					Type:     mimeType,
					Data:     base64Image,
					IsBase64: true,
				},
			},
		},
	}
	return fa.llmClient.GenerateText(ctx, llm.StreamRequest{
		Messages: messages,
	})
}
// analyzeText reads a plain-text file, truncates oversized content, and asks
// the LLM for a best-effort summary (summary failures are not fatal).
func (fa *FileAnalyzer) analyzeText(ctx context.Context, filePath string) (*AnalysisResult, error) {
	raw, err := os.ReadFile(filePath)
	if err != nil {
		return nil, fmt.Errorf("failed to read file: %w", err)
	}
	body := string(raw)
	if len(body) > 50000 {
		body = body[:50000] + "\n\n[Content truncated...]"
	}
	summary, keyPoints, err := fa.generateSummary(ctx, body, "text document")
	if err != nil {
		summary, keyPoints = "", nil
	}
	meta := map[string]interface{}{
		"size":      len(raw),
		"lineCount": strings.Count(body, "\n") + 1,
	}
	return &AnalysisResult{
		FileType:      "text",
		ExtractedText: body,
		Summary:       summary,
		KeyPoints:     keyPoints,
		Metadata:      meta,
	}, nil
}
// generateSummary asks the LLM for a 2-3 paragraph summary plus bullet key
// points of the given document text. Inputs under 100 bytes are returned
// as-is with no key points; inputs over 15KB are truncated before prompting.
func (fa *FileAnalyzer) generateSummary(ctx context.Context, text, docType string) (string, []string, error) {
	if len(text) < 100 {
		return text, nil, nil
	}
	body := text
	if len(body) > 15000 {
		body = body[:15000] + "\n\n[Content truncated for analysis...]"
	}
	prompt := fmt.Sprintf(`Analyze this %s and provide:
1. A concise summary (2-3 paragraphs)
2. 5-7 key points as bullet points
Document content:
%s
Format your response as:
SUMMARY:
[your summary here]
KEY POINTS:
- [point 1]
- [point 2]
...`, docType, body)
	reply, err := fa.llmClient.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{
			{Role: llm.RoleUser, Content: prompt},
		},
	})
	if err != nil {
		return "", nil, err
	}
	summary, keyPoints := parseSummaryResponse(reply)
	return summary, keyPoints, nil
}
// parseSummaryResponse splits an LLM reply of the form
// "SUMMARY: ... KEY POINTS: - ..." into the summary text and bullet list.
// Replies without a "KEY POINTS:" section come back whole as the summary
// with nil key points.
func parseSummaryResponse(response string) (string, []string) {
	parts := strings.Split(response, "KEY POINTS:")
	if len(parts) < 2 {
		return response, nil
	}
	summary := strings.TrimSpace(strings.TrimPrefix(parts[0], "SUMMARY:"))
	var points []string
	for _, raw := range strings.Split(parts[1], "\n") {
		line := strings.TrimSpace(raw)
		if !strings.HasPrefix(line, "-") && !strings.HasPrefix(line, "•") && !strings.HasPrefix(line, "*") {
			continue
		}
		point := strings.TrimSpace(strings.TrimPrefix(strings.TrimPrefix(strings.TrimPrefix(line, "-"), "•"), "*"))
		if point != "" {
			points = append(points, point)
		}
	}
	return summary, points
}
// extractKeyPointsFromDescription derives quick bullet points from free-text
// by taking sentences longer than 20 characters from among the first five
// sentences of the description.
func extractKeyPointsFromDescription(description string) []string {
	var points []string
	sentences := strings.Split(description, ".")
	for i, sentence := range sentences {
		if i >= 5 {
			// The original scanned the whole slice even though only indexes
			// 0..4 can ever qualify; stop early instead. Behavior unchanged.
			break
		}
		sentence = strings.TrimSpace(sentence)
		if len(sentence) > 20 {
			points = append(points, sentence+".")
		}
	}
	return points
}
// DetectMimeType resolves a MIME type from the (case-insensitive) file
// extension, falling back to content sniffing via http.DetectContentType,
// which inspects at most the first 512 bytes.
func DetectMimeType(filename string, content []byte) string {
	var mime string
	switch strings.ToLower(filepath.Ext(filename)) {
	case ".pdf":
		mime = "application/pdf"
	case ".png":
		mime = "image/png"
	case ".jpg", ".jpeg":
		mime = "image/jpeg"
	case ".gif":
		mime = "image/gif"
	case ".webp":
		mime = "image/webp"
	case ".txt":
		mime = "text/plain"
	case ".md":
		mime = "text/markdown"
	case ".csv":
		mime = "text/csv"
	case ".json":
		mime = "application/json"
	}
	if mime != "" {
		return mime
	}
	return http.DetectContentType(content[:min(512, len(content))])
}
// min returns the smaller of a and b.
// (Retained for pre-1.21 toolchains; shadows the builtin on newer Go.)
func min(a, b int) int {
	if a > b {
		return b
	}
	return a
}
// SaveFile streams content into the analyzer's storage directory under the
// base name of filename (path components are stripped, keeping writes inside
// storagePath). It returns the destination path and bytes written.
func (fa *FileAnalyzer) SaveFile(filename string, content io.Reader) (string, int64, error) {
	safeName := filepath.Base(filename)
	destPath := filepath.Join(fa.storagePath, safeName)
	file, err := os.Create(destPath)
	if err != nil {
		return "", 0, err
	}
	// NOTE(review): the MultiWriter also captures the full payload in buf,
	// which is never read back — a memory waste kept here only because
	// removing it would orphan this file's "bytes" import; drop both together.
	var buf bytes.Buffer
	size, err := io.Copy(io.MultiWriter(file, &buf), content)
	// Close explicitly and propagate its error: a deferred Close would
	// silently swallow flush/write failures (e.g. disk full).
	if cerr := file.Close(); err == nil {
		err = cerr
	}
	if err != nil {
		return "", 0, err
	}
	return destPath, size, nil
}
// DeleteFile removes a previously saved file. The path must resolve to a
// location inside the analyzer's storage directory; anything else is
// rejected. (The previous HasPrefix check accepted sibling directories like
// "<storagePath>-evil/..." and paths escaping via "..".)
func (fa *FileAnalyzer) DeleteFile(filePath string) error {
	clean := filepath.Clean(filePath)
	rel, err := filepath.Rel(fa.storagePath, clean)
	if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return fmt.Errorf("invalid file path")
	}
	return os.Remove(clean)
}

View File

@@ -0,0 +1,537 @@
package finance
import (
"context"
"encoding/json"
"fmt"
"math"
"net/http"
"sort"
"strings"
"sync"
"time"
)
// HeatmapService serves market heatmaps with an in-memory TTL cache.
type HeatmapService struct {
	cache      map[string]*CachedHeatmap // keyed by "market:timeRange"
	mu         sync.RWMutex              // guards cache
	httpClient *http.Client
	config     HeatmapConfig
}

// HeatmapConfig configures the upstream data provider and cache timing.
type HeatmapConfig struct {
	DataProviderURL string
	CacheTTL        time.Duration // zero -> 5m default (see NewHeatmapService)
	RefreshInterval time.Duration // zero -> 1m default
}

// CachedHeatmap pairs a heatmap with its cache expiry.
type CachedHeatmap struct {
	Data      *MarketHeatmap
	ExpiresAt time.Time
}

// MarketHeatmap is a full heatmap payload for one market and time range.
type MarketHeatmap struct {
	ID         string        `json:"id"`
	Title      string        `json:"title"`
	Type       HeatmapType   `json:"type"`
	Market     string        `json:"market"`
	Sectors    []Sector      `json:"sectors"`
	Tickers    []TickerData  `json:"tickers"`
	Summary    MarketSummary `json:"summary"`
	UpdatedAt  time.Time     `json:"updatedAt"`
	TimeRange  string        `json:"timeRange"`
	Colorscale Colorscale    `json:"colorscale"`
}

// HeatmapType selects the client-side visualization for a heatmap.
type HeatmapType string

const (
	HeatmapTreemap     HeatmapType = "treemap"
	HeatmapGrid        HeatmapType = "grid"
	HeatmapBubble      HeatmapType = "bubble"
	HeatmapSectorChart HeatmapType = "sector_chart"
)
// Sector aggregates the tickers of one market sector for heatmap rendering.
type Sector struct {
	ID          string       `json:"id"`
	Name        string       `json:"name"`
	Change      float64      `json:"change"` // percent change
	MarketCap   float64      `json:"marketCap"`
	Volume      float64      `json:"volume"`
	TickerCount int          `json:"tickerCount"`
	TopGainers  []TickerData `json:"topGainers,omitempty"`
	TopLosers   []TickerData `json:"topLosers,omitempty"`
	Color       string       `json:"color"`  // hex color derived from Change
	Weight      float64      `json:"weight"` // relative size in the layout
}

// TickerData is one instrument's quote plus the presentation fields
// (Color, Size) the heatmap needs. Optional fundamentals are omitempty.
type TickerData struct {
	Symbol        string  `json:"symbol"`
	Name          string  `json:"name"`
	Price         float64 `json:"price"`
	Change        float64 `json:"change"`
	ChangePercent float64 `json:"changePercent"`
	Volume        float64 `json:"volume"`
	MarketCap     float64 `json:"marketCap"`
	Sector        string  `json:"sector"`
	Industry      string  `json:"industry"`
	Color         string  `json:"color"`
	Size          float64 `json:"size"`
	PrevClose     float64 `json:"prevClose,omitempty"`
	DayHigh       float64 `json:"dayHigh,omitempty"`
	DayLow        float64 `json:"dayLow,omitempty"`
	Week52High    float64 `json:"week52High,omitempty"`
	Week52Low     float64 `json:"week52Low,omitempty"`
	PE            float64 `json:"pe,omitempty"`
	EPS           float64 `json:"eps,omitempty"`
	Dividend      float64 `json:"dividend,omitempty"`
	DividendYield float64 `json:"dividendYield,omitempty"`
}

// MarketSummary carries market-wide breadth and sentiment statistics.
type MarketSummary struct {
	TotalMarketCap  float64     `json:"totalMarketCap"`
	TotalVolume     float64     `json:"totalVolume"`
	AdvancingCount  int         `json:"advancingCount"`
	DecliningCount  int         `json:"decliningCount"`
	UnchangedCount  int         `json:"unchangedCount"`
	AverageChange   float64     `json:"averageChange"`
	TopGainer       *TickerData `json:"topGainer,omitempty"`
	TopLoser        *TickerData `json:"topLoser,omitempty"`
	MostActive      *TickerData `json:"mostActive,omitempty"`
	MarketSentiment string      `json:"marketSentiment"`
	VIX             float64     `json:"vix,omitempty"`
	FearGreedIndex  int         `json:"fearGreedIndex,omitempty"`
}

// Colorscale maps a numeric change range onto a discrete color ramp;
// Thresholds are the breakpoints between adjacent Colors.
type Colorscale struct {
	Min        float64   `json:"min"`
	Max        float64   `json:"max"`
	MidPoint   float64   `json:"midPoint"`
	Colors     []string  `json:"colors"`
	Thresholds []float64 `json:"thresholds"`
}
// DefaultColorscale is the standard red→neutral→green palette attached to
// every generated heatmap. NOTE(review): the thresholds here ({-5..-1,1..5})
// do not exactly match the cutoffs used by getColorForChange (which has no
// ±2 steps and a >=0 step) — confirm which is authoritative for rendering.
var DefaultColorscale = Colorscale{
	Min: -10,
	Max: 10,
	MidPoint: 0,
	Colors: []string{
		"#ef4444",
		"#f87171",
		"#fca5a5",
		"#fecaca",
		"#e5e7eb",
		"#bbf7d0",
		"#86efac",
		"#4ade80",
		"#22c55e",
	},
	Thresholds: []float64{-5, -3, -2, -1, 1, 2, 3, 5},
}
// NewHeatmapService builds a HeatmapService from cfg, filling in defaults
// for any zero-valued durations: 5 minutes of cache TTL and a 1-minute
// refresh interval. The embedded HTTP client uses a 30s overall timeout.
func NewHeatmapService(cfg HeatmapConfig) *HeatmapService {
	const (
		defaultCacheTTL = 5 * time.Minute
		defaultRefresh  = time.Minute
	)
	if cfg.CacheTTL == 0 {
		cfg.CacheTTL = defaultCacheTTL
	}
	if cfg.RefreshInterval == 0 {
		cfg.RefreshInterval = defaultRefresh
	}
	svc := &HeatmapService{
		cache:      make(map[string]*CachedHeatmap),
		httpClient: &http.Client{Timeout: 30 * time.Second},
		config:     cfg,
	}
	return svc
}
// GetMarketHeatmap returns the heatmap for (market, timeRange), serving a
// cached copy while it is still fresh and refreshing it otherwise.
func (s *HeatmapService) GetMarketHeatmap(ctx context.Context, market string, timeRange string) (*MarketHeatmap, error) {
	key := fmt.Sprintf("%s:%s", market, timeRange)

	// Fast path: hand back a still-valid snapshot under the read lock.
	s.mu.RLock()
	entry, found := s.cache[key]
	fresh := found && time.Now().Before(entry.ExpiresAt)
	s.mu.RUnlock()
	if fresh {
		return entry.Data, nil
	}

	// Slow path: fetch outside any lock, then publish the new entry.
	data, err := s.fetchMarketData(ctx, market, timeRange)
	if err != nil {
		return nil, err
	}
	s.mu.Lock()
	s.cache[key] = &CachedHeatmap{
		Data:      data,
		ExpiresAt: time.Now().Add(s.config.CacheTTL),
	}
	s.mu.Unlock()
	return data, nil
}
// GetSectorHeatmap narrows the full market heatmap down to one sector
// (matched case-insensitively) and recomputes the summary over the subset.
func (s *HeatmapService) GetSectorHeatmap(ctx context.Context, market, sector, timeRange string) (*MarketHeatmap, error) {
	full, err := s.GetMarketHeatmap(ctx, market, timeRange)
	if err != nil {
		return nil, err
	}
	matching := make([]TickerData, 0)
	for _, td := range full.Tickers {
		if strings.EqualFold(td.Sector, sector) {
			matching = append(matching, td)
		}
	}
	result := &MarketHeatmap{
		ID:         fmt.Sprintf("%s-%s", market, sector),
		Title:      fmt.Sprintf("%s - %s", market, sector),
		Type:       HeatmapTreemap,
		Market:     market,
		Tickers:    matching,
		TimeRange:  timeRange,
		UpdatedAt:  time.Now(),
		Colorscale: DefaultColorscale,
		Summary:    s.calculateSummary(matching),
	}
	return result, nil
}
// fetchMarketData produces the heatmap for a market. It is currently backed
// by generated mock data; a call to the configured data provider would
// replace generateMockMarketData here.
func (s *HeatmapService) fetchMarketData(ctx context.Context, market, timeRange string) (*MarketHeatmap, error) {
	result := s.generateMockMarketData(market)
	result.TimeRange = timeRange
	return result, nil
}
// generateMockMarketData fabricates a plausible heatmap for the given
// market: a fixed catalog of well-known tickers grouped into five sectors,
// with randomized prices/changes/caps/volumes. Outputs depend on the exact
// sequence of randomFloat calls (shared RNG state), so statement order here
// is load-bearing.
func (s *HeatmapService) generateMockMarketData(market string) *MarketHeatmap {
	// Static catalog: sector name -> (symbol, company name) pairs.
	sectors := []struct {
		name string
		tickers []struct{ symbol, name string }
	}{
		{"Technology", []struct{ symbol, name string }{
			{"AAPL", "Apple Inc."},
			{"MSFT", "Microsoft Corp."},
			{"GOOGL", "Alphabet Inc."},
			{"AMZN", "Amazon.com Inc."},
			{"META", "Meta Platforms"},
			{"NVDA", "NVIDIA Corp."},
			{"TSLA", "Tesla Inc."},
		}},
		{"Healthcare", []struct{ symbol, name string }{
			{"JNJ", "Johnson & Johnson"},
			{"UNH", "UnitedHealth Group"},
			{"PFE", "Pfizer Inc."},
			{"MRK", "Merck & Co."},
			{"ABBV", "AbbVie Inc."},
		}},
		{"Finance", []struct{ symbol, name string }{
			{"JPM", "JPMorgan Chase"},
			{"BAC", "Bank of America"},
			{"WFC", "Wells Fargo"},
			{"GS", "Goldman Sachs"},
			{"MS", "Morgan Stanley"},
		}},
		{"Energy", []struct{ symbol, name string }{
			{"XOM", "Exxon Mobil"},
			{"CVX", "Chevron Corp."},
			{"COP", "ConocoPhillips"},
			{"SLB", "Schlumberger"},
		}},
		{"Consumer", []struct{ symbol, name string }{
			{"WMT", "Walmart Inc."},
			{"PG", "Procter & Gamble"},
			{"KO", "Coca-Cola Co."},
			{"PEP", "PepsiCo Inc."},
			{"COST", "Costco Wholesale"},
		}},
	}
	allTickers := make([]TickerData, 0)
	allSectors := make([]Sector, 0)
	for _, sec := range sectors {
		sectorTickers := make([]TickerData, 0)
		sectorChange := 0.0
		for _, t := range sec.tickers {
			// Randomized quote: percent change in [-5,5), price in [50,500),
			// market cap in [50B,3T), volume in [1M,100M).
			change := (randomFloat(-5, 5))
			price := randomFloat(50, 500)
			marketCap := randomFloat(50e9, 3000e9)
			volume := randomFloat(1e6, 100e6)
			ticker := TickerData{
				Symbol: t.symbol,
				Name: t.name,
				Price: price,
				Change: price * change / 100,
				ChangePercent: change,
				Volume: volume,
				MarketCap: marketCap,
				Sector: sec.name,
				Color: getColorForChange(change),
				// Log scale keeps treemap tile sizes comparable across caps.
				Size: math.Log10(marketCap) * 10,
			}
			sectorTickers = append(sectorTickers, ticker)
			sectorChange += change
		}
		// Average change across the sector's tickers.
		if len(sectorTickers) > 0 {
			sectorChange /= float64(len(sectorTickers))
		}
		// Sort by change descending so the ends hold gainers/losers.
		sort.Slice(sectorTickers, func(i, j int) bool {
			return sectorTickers[i].ChangePercent > sectorTickers[j].ChangePercent
		})
		// Top/bottom two (views into sectorTickers; may overlap for tiny sectors).
		var topGainers, topLosers []TickerData
		if len(sectorTickers) >= 2 {
			topGainers = sectorTickers[:2]
			topLosers = sectorTickers[len(sectorTickers)-2:]
		}
		sectorMarketCap := 0.0
		sectorVolume := 0.0
		for _, t := range sectorTickers {
			sectorMarketCap += t.MarketCap
			sectorVolume += t.Volume
		}
		sector := Sector{
			ID: strings.ToLower(strings.ReplaceAll(sec.name, " ", "_")),
			Name: sec.name,
			Change: sectorChange,
			MarketCap: sectorMarketCap,
			Volume: sectorVolume,
			TickerCount: len(sectorTickers),
			TopGainers: topGainers,
			TopLosers: topLosers,
			Color: getColorForChange(sectorChange),
			Weight: sectorMarketCap,
		}
		allSectors = append(allSectors, sector)
		allTickers = append(allTickers, sectorTickers...)
	}
	// Largest caps first in the flat ticker list.
	sort.Slice(allTickers, func(i, j int) bool {
		return allTickers[i].MarketCap > allTickers[j].MarketCap
	})
	return &MarketHeatmap{
		ID: market,
		Title: getMarketTitle(market),
		Type: HeatmapTreemap,
		Market: market,
		Sectors: allSectors,
		Tickers: allTickers,
		Summary: *s.calculateSummaryPtr(allTickers),
		UpdatedAt: time.Now(),
		Colorscale: DefaultColorscale,
	}
}
// calculateSummary aggregates tickers into a MarketSummary value; it is a
// by-value convenience wrapper over calculateSummaryPtr.
func (s *HeatmapService) calculateSummary(tickers []TickerData) MarketSummary {
	summary := s.calculateSummaryPtr(tickers)
	return *summary
}
// calculateSummaryPtr walks the ticker list once, accumulating totals,
// advance/decline counts, average change, the extreme movers and a coarse
// sentiment label. The TopGainer/TopLoser/MostActive pointers reference
// elements of the caller's slice.
func (s *HeatmapService) calculateSummaryPtr(tickers []TickerData) *MarketSummary {
	out := &MarketSummary{}
	gainIdx, loseIdx, activeIdx := -1, -1, -1
	totalChange := 0.0
	for i := range tickers {
		t := &tickers[i]
		out.TotalMarketCap += t.MarketCap
		out.TotalVolume += t.Volume
		totalChange += t.ChangePercent
		switch {
		case t.ChangePercent > 0:
			out.AdvancingCount++
		case t.ChangePercent < 0:
			out.DecliningCount++
		default:
			out.UnchangedCount++
		}
		if gainIdx < 0 || t.ChangePercent > tickers[gainIdx].ChangePercent {
			gainIdx = i
		}
		if loseIdx < 0 || t.ChangePercent < tickers[loseIdx].ChangePercent {
			loseIdx = i
		}
		if activeIdx < 0 || t.Volume > tickers[activeIdx].Volume {
			activeIdx = i
		}
	}
	if n := len(tickers); n > 0 {
		out.AverageChange = totalChange / float64(n)
		out.TopGainer = &tickers[gainIdx]
		out.TopLoser = &tickers[loseIdx]
		out.MostActive = &tickers[activeIdx]
	}
	// Sentiment buckets: above +1% bullish, below -1% bearish, else neutral.
	switch {
	case out.AverageChange > 1:
		out.MarketSentiment = "bullish"
	case out.AverageChange < -1:
		out.MarketSentiment = "bearish"
	default:
		out.MarketSentiment = "neutral"
	}
	return out
}
// GenerateTreemapData converts a heatmap into the nested name/children
// structure consumed by treemap renderers: market -> sectors -> tickers.
func (s *HeatmapService) GenerateTreemapData(heatmap *MarketHeatmap) interface{} {
	// Group tickers by sector name in one pass (preserving ticker order)
	// instead of rescanning the full ticker list for every sector, which was
	// accidentally O(sectors × tickers).
	bySector := make(map[string][]map[string]interface{}, len(heatmap.Sectors))
	for _, ticker := range heatmap.Tickers {
		bySector[ticker.Sector] = append(bySector[ticker.Sector], map[string]interface{}{
			"name":   ticker.Symbol,
			"value":  ticker.MarketCap,
			"change": ticker.ChangePercent,
			"color":  ticker.Color,
			"data":   ticker,
		})
	}
	children := make([]map[string]interface{}, 0, len(heatmap.Sectors))
	for _, sector := range heatmap.Sectors {
		sectorChildren := bySector[sector.Name]
		if sectorChildren == nil {
			// Keep an empty (non-nil) child list, as the previous
			// implementation produced, so JSON emits [] rather than null.
			sectorChildren = make([]map[string]interface{}, 0)
		}
		children = append(children, map[string]interface{}{
			"name":     sector.Name,
			"children": sectorChildren,
			"change":   sector.Change,
			"color":    sector.Color,
		})
	}
	return map[string]interface{}{
		"name":     heatmap.Market,
		"children": children,
	}
}
// GenerateGridData lays the heatmap's tickers out row-major into a
// rows×cols grid. Cells beyond the ticker count stay zero-valued; surplus
// tickers are dropped.
func (s *HeatmapService) GenerateGridData(heatmap *MarketHeatmap, rows, cols int) [][]TickerData {
	// Guard: make panics on a negative size argument.
	if rows < 0 || cols < 0 {
		return nil
	}
	grid := make([][]TickerData, rows)
	idx := 0
	for i := range grid {
		grid[i] = make([]TickerData, cols)
		for j := 0; j < cols && idx < len(heatmap.Tickers); j++ {
			grid[i][j] = heatmap.Tickers[idx]
			idx++
		}
	}
	return grid
}
// GetTopMovers returns the top `count` gainers, losers and most-active
// tickers for a market's 1-day heatmap.
//
// BUG FIX: the previous version sliced `gainers`/`losers` off the SAME
// backing array that was then re-sorted twice more, so all three result
// slices ended up as views of the final (volume-sorted) ordering. Each
// top-N is now copied out before the next sort runs.
func (s *HeatmapService) GetTopMovers(ctx context.Context, market string, count int) (*TopMovers, error) {
	heatmap, err := s.GetMarketHeatmap(ctx, market, "1d")
	if err != nil {
		return nil, err
	}
	tickers := make([]TickerData, len(heatmap.Tickers))
	copy(tickers, heatmap.Tickers)

	// topN snapshots the first min(count, len) elements into a fresh slice
	// so later in-place sorts cannot mutate it. Negative counts yield empty.
	topN := func(src []TickerData) []TickerData {
		n := count
		if n < 0 {
			n = 0
		}
		if n > len(src) {
			n = len(src)
		}
		out := make([]TickerData, n)
		copy(out, src[:n])
		return out
	}

	sort.Slice(tickers, func(i, j int) bool {
		return tickers[i].ChangePercent > tickers[j].ChangePercent
	})
	gainers := topN(tickers)

	sort.Slice(tickers, func(i, j int) bool {
		return tickers[i].ChangePercent < tickers[j].ChangePercent
	})
	losers := topN(tickers)

	sort.Slice(tickers, func(i, j int) bool {
		return tickers[i].Volume > tickers[j].Volume
	})
	active := topN(tickers)

	return &TopMovers{
		Gainers:    gainers,
		Losers:     losers,
		MostActive: active,
		UpdatedAt:  time.Now(),
	}, nil
}
// TopMovers bundles the biggest gainers, biggest losers and highest-volume
// tickers for a market, as returned by GetTopMovers.
type TopMovers struct {
	Gainers []TickerData `json:"gainers"`
	Losers []TickerData `json:"losers"`
	MostActive []TickerData `json:"mostActive"`
	UpdatedAt time.Time `json:"updatedAt"`
}
// getColorForChange maps a percent change to a hex color on a red→green
// ramp: deep green at ≥+5%, pale green down to 0, pale red below 0, deep
// red at less than -5%.
func getColorForChange(change float64) string {
	switch {
	case change >= 5:
		return "#22c55e"
	case change >= 3:
		return "#4ade80"
	case change >= 1:
		return "#86efac"
	case change >= 0:
		return "#bbf7d0"
	case change >= -1:
		return "#fecaca"
	case change >= -3:
		return "#fca5a5"
	case change >= -5:
		return "#f87171"
	default:
		return "#ef4444"
	}
}
// getMarketTitle resolves a market code (case-insensitive) to its display
// name, falling back to the input verbatim for unknown codes.
func getMarketTitle(market string) string {
	switch strings.ToLower(market) {
	case "sp500":
		return "S&P 500"
	case "nasdaq":
		return "NASDAQ"
	case "dow":
		return "Dow Jones"
	case "moex":
		return "MOEX"
	case "crypto":
		return "Cryptocurrency"
	case "forex":
		return "Forex"
	case "commodities":
		return "Commodities"
	default:
		return market
	}
}
// rngMu serializes access to rng. HeatmapService methods can run on many
// goroutines concurrently, and the previous unguarded xorshift state was a
// data race (go test -race flags it).
var (
	rngMu sync.Mutex
	rng   uint64 = uint64(time.Now().UnixNano())
)

// randomFloat returns a pseudo-random float64 in [min, max) from a
// mutex-guarded xorshift64 generator — mock-data quality, not cryptographic.
func randomFloat(min, max float64) float64 {
	rngMu.Lock()
	defer rngMu.Unlock()
	rng ^= rng << 13
	rng ^= rng >> 17
	rng ^= rng << 5
	// Use the top 53 bits so f is exactly representable and strictly < 1;
	// the old float64(rng)/2^64 could round up to 1.0 near the maximum,
	// letting the result reach max.
	f := float64(rng>>11) / float64(1<<53)
	return min + f*(max-min)
}
// ToJSON serializes the heatmap to its JSON wire form.
func (h *MarketHeatmap) ToJSON() ([]byte, error) {
	return json.Marshal(h)
}

View File

@@ -0,0 +1,759 @@
package labs
import (
"context"
"encoding/json"
"fmt"
"regexp"
"sort"
"strconv"
"strings"
"time"
"github.com/gooseek/backend/internal/llm"
"github.com/google/uuid"
)
// Generator builds structured visual reports from raw data, using an LLM to
// decide which visualizations fit the data best.
type Generator struct {
	llm llm.Client
}
// NewGenerator returns a Generator backed by the given LLM client.
func NewGenerator(llmClient llm.Client) *Generator {
	return &Generator{llm: llmClient}
}
// GenerateOptions carries the inputs for GenerateReport. PreferredTypes,
// Locale and MaxVisualizations are not consulted by the current
// implementation — presumably reserved for future use; confirm before relying
// on them.
type GenerateOptions struct {
	Query string
	Data interface{}
	PreferredTypes []VisualizationType
	Theme string
	Locale string
	MaxVisualizations int
}
// GenerateReport asks the LLM to design a report layout (sections plus
// visualization choices) for the query/data pair, then materializes each
// proposed visualization from the concrete data. When the LLM reply cannot
// be parsed as JSON it degrades to createDefaultReport (a markdown dump).
func (g *Generator) GenerateReport(ctx context.Context, opts GenerateOptions) (*Report, error) {
	// The prompt embeds opts.Data via %v — assumes callers pass reasonably
	// small payloads; TODO confirm upstream trimming.
	analysisPrompt := fmt.Sprintf(`Analyze this data and query to determine the best visualizations.
Query: %s
Data: %v
Determine:
1. What visualizations would best represent this data?
2. How should the data be structured for each visualization?
3. What insights can be highlighted?
Respond in JSON format:
{
"title": "Report title",
"sections": [
{
"title": "Section title",
"visualizations": [
{
"type": "chart_type",
"title": "Viz title",
"dataMapping": { "how to map the data" },
"insight": "Key insight"
}
]
}
]
}
Available visualization types: bar_chart, line_chart, pie_chart, donut_chart, table, stat_cards, kpi, comparison, timeline, progress, heatmap, code_block, markdown, collapsible, tabs, accordion`, opts.Query, opts.Data)
	result, err := g.llm.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: "user", Content: analysisPrompt}},
	})
	if err != nil {
		return nil, err
	}
	// Anonymous struct mirrors the JSON schema requested in the prompt above.
	var analysis struct {
		Title string `json:"title"`
		Sections []struct {
			Title string `json:"title"`
			Visualizations []struct {
				Type string `json:"type"`
				Title string `json:"title"`
				DataMapping map[string]interface{} `json:"dataMapping"`
				Insight string `json:"insight"`
			} `json:"visualizations"`
		} `json:"sections"`
	}
	jsonStr := extractJSON(result)
	if err := json.Unmarshal([]byte(jsonStr), &analysis); err != nil {
		// Unparseable LLM output: fall back to the default single-section report.
		return g.createDefaultReport(opts)
	}
	report := &Report{
		ID: uuid.New().String(),
		Title: analysis.Title,
		CreatedAt: time.Now(),
		UpdatedAt: time.Now(),
		Theme: opts.Theme,
		Sections: make([]ReportSection, 0),
	}
	for _, sec := range analysis.Sections {
		section := ReportSection{
			ID: uuid.New().String(),
			Title: sec.Title,
			Visualizations: make([]Visualization, 0),
		}
		for _, viz := range sec.Visualizations {
			// Materialize each proposal against the real data.
			visualization := g.createVisualization(VisualizationType(viz.Type), viz.Title, opts.Data, viz.DataMapping)
			if visualization != nil {
				section.Visualizations = append(section.Visualizations, *visualization)
			}
		}
		// Drop sections that produced no visualizations.
		if len(section.Visualizations) > 0 {
			report.Sections = append(report.Sections, section)
		}
	}
	return report, nil
}
// createDefaultReport is the fallback used when the LLM layout cannot be
// parsed: a single-section report that dumps the raw data as markdown.
func (g *Generator) createDefaultReport(opts GenerateOptions) (*Report, error) {
	now := time.Now()
	overview := ReportSection{
		ID:    uuid.New().String(),
		Title: "Обзор",
		Visualizations: []Visualization{
			g.CreateMarkdown("", formatDataAsMarkdown(opts.Data)),
		},
	}
	return &Report{
		ID:        uuid.New().String(),
		Title:     "Анализ данных",
		CreatedAt: now,
		UpdatedAt: now,
		Sections:  []ReportSection{overview},
	}, nil
}
// createVisualization dispatches to the type-specific builder; any
// unrecognized type degrades to a markdown dump of the raw data.
func (g *Generator) createVisualization(vizType VisualizationType, title string, data interface{}, mapping map[string]interface{}) *Visualization {
	switch vizType {
	case VizBarChart, VizLineChart, VizAreaChart:
		return g.createChartVisualization(vizType, title, data, mapping)
	case VizPieChart, VizDonutChart:
		return g.createPieVisualization(vizType, title, data, mapping)
	case VizTable:
		return g.createTableVisualization(title, data, mapping)
	case VizStatCards:
		return g.createStatCardsVisualization(title, data, mapping)
	case VizKPI:
		return g.createKPIVisualization(title, data, mapping)
	case VizTimeline:
		return g.createTimelineVisualization(title, data, mapping)
	case VizComparison:
		return g.createComparisonVisualization(title, data, mapping)
	case VizProgress:
		return g.createProgressVisualization(title, data, mapping)
	case VizMarkdown:
		viz := g.CreateMarkdown(title, extractStringFromData(data, mapping, "content"))
		return &viz
	}
	fallback := g.CreateMarkdown(title, formatDataAsMarkdown(data))
	return &fallback
}
// createChartVisualization builds a bar/line/area chart from either a
// map[string]number (key → value) or a []interface{} of {label, value} maps.
func (g *Generator) createChartVisualization(vizType VisualizationType, title string, data interface{}, mapping map[string]interface{}) *Visualization {
	chartData := &ChartData{
		Labels:   make([]string, 0),
		Datasets: make([]ChartDataset, 0),
	}
	if dataMap, ok := data.(map[string]interface{}); ok {
		// Sort keys: Go map iteration order is randomized, which previously
		// shuffled chart labels on every render.
		keys := make([]string, 0, len(dataMap))
		for k := range dataMap {
			keys = append(keys, k)
		}
		sort.Strings(keys)
		values := make([]float64, 0, len(keys))
		for _, k := range keys {
			values = append(values, toFloat64(dataMap[k]))
		}
		chartData.Labels = keys
		chartData.Datasets = append(chartData.Datasets, ChartDataset{
			Label: title,
			Data:  values,
		})
	}
	if dataSlice, ok := data.([]interface{}); ok {
		for _, item := range dataSlice {
			if itemMap, ok := item.(map[string]interface{}); ok {
				if label, ok := itemMap["label"].(string); ok {
					chartData.Labels = append(chartData.Labels, label)
				}
				if value, ok := itemMap["value"]; ok {
					// Lazily create the single dataset on the first value.
					if len(chartData.Datasets) == 0 {
						chartData.Datasets = append(chartData.Datasets, ChartDataset{Label: title, Data: []float64{}})
					}
					chartData.Datasets[0].Data = append(chartData.Datasets[0].Data, toFloat64(value))
				}
			}
		}
	}
	return &Visualization{
		ID:    uuid.New().String(),
		Type:  vizType,
		Title: title,
		Data:  chartData,
		Config: VisualizationConfig{
			ShowLegend:  true,
			ShowTooltip: true,
			ShowGrid:    true,
			Animated:    true,
		},
		Responsive: true,
	}
}
// createPieVisualization builds a pie/donut chart from a map[string]number.
func (g *Generator) createPieVisualization(vizType VisualizationType, title string, data interface{}, mapping map[string]interface{}) *Visualization {
	chartData := &ChartData{
		Labels:   make([]string, 0),
		Datasets: make([]ChartDataset, 0),
	}
	dataset := ChartDataset{Label: title, Data: []float64{}}
	if dataMap, ok := data.(map[string]interface{}); ok {
		// Sort keys so slice order (and thus segment order/colors on the
		// client) is deterministic across renders.
		keys := make([]string, 0, len(dataMap))
		for k := range dataMap {
			keys = append(keys, k)
		}
		sort.Strings(keys)
		for _, k := range keys {
			chartData.Labels = append(chartData.Labels, k)
			dataset.Data = append(dataset.Data, toFloat64(dataMap[k]))
		}
	}
	chartData.Datasets = append(chartData.Datasets, dataset)
	return &Visualization{
		ID:    uuid.New().String(),
		Type:  vizType,
		Title: title,
		Data:  chartData,
		Config: VisualizationConfig{
			ShowLegend:  true,
			ShowTooltip: true,
			ShowValues:  true,
			Animated:    true,
		},
		Style: VisualizationStyle{
			Height: "300px",
		},
		Responsive: true,
	}
}
// createTableVisualization builds a table from either a slice of row maps
// (columns inferred from the first row) or a flat map rendered as
// key/value rows.
func (g *Generator) createTableVisualization(title string, data interface{}, mapping map[string]interface{}) *Visualization {
	tableData := &TableData{
		Columns: make([]TableColumn, 0),
		Rows:    make([]TableRow, 0),
	}
	if dataSlice, ok := data.([]interface{}); ok && len(dataSlice) > 0 {
		if firstRow, ok := dataSlice[0].(map[string]interface{}); ok {
			// Sort column keys: map iteration is randomized and previously
			// shuffled column order between renders.
			keys := make([]string, 0, len(firstRow))
			for key := range firstRow {
				keys = append(keys, key)
			}
			sort.Strings(keys)
			for _, key := range keys {
				tableData.Columns = append(tableData.Columns, TableColumn{
					Key:      key,
					Label:    formatColumnLabel(key),
					Sortable: true,
				})
			}
		}
		for _, item := range dataSlice {
			if rowMap, ok := item.(map[string]interface{}); ok {
				tableData.Rows = append(tableData.Rows, TableRow(rowMap))
			}
		}
	}
	if dataMap, ok := data.(map[string]interface{}); ok {
		tableData.Columns = []TableColumn{
			{Key: "key", Label: "Параметр", Sortable: true},
			{Key: "value", Label: "Значение", Sortable: true},
		}
		// Deterministic row order for the key/value rendering as well.
		keys := make([]string, 0, len(dataMap))
		for k := range dataMap {
			keys = append(keys, k)
		}
		sort.Strings(keys)
		for _, k := range keys {
			tableData.Rows = append(tableData.Rows, TableRow{
				"key":   k,
				"value": dataMap[k],
			})
		}
	}
	tableData.Summary = &TableSummary{
		TotalRows: len(tableData.Rows),
	}
	return &Visualization{
		ID:    uuid.New().String(),
		Type:  VizTable,
		Title: title,
		Data:  tableData,
		Config: VisualizationConfig{
			Sortable:   true,
			Searchable: true,
			Paginated:  len(tableData.Rows) > 10,
			PageSize:   10,
		},
		Responsive: true,
	}
}
// createStatCardsVisualization renders each entry of a map[string]value as a
// stat card, cycling through a fixed accent-color palette.
func (g *Generator) createStatCardsVisualization(title string, data interface{}, mapping map[string]interface{}) *Visualization {
	cardsData := &StatCardsData{
		Cards: make([]StatCard, 0),
	}
	colors := []string{"#3B82F6", "#10B981", "#F59E0B", "#EF4444", "#8B5CF6", "#EC4899"}
	if dataMap, ok := data.(map[string]interface{}); ok {
		// Sort keys: randomized map iteration previously shuffled both card
		// order and the color assigned to each card on every render.
		keys := make([]string, 0, len(dataMap))
		for k := range dataMap {
			keys = append(keys, k)
		}
		sort.Strings(keys)
		for i, k := range keys {
			cardsData.Cards = append(cardsData.Cards, StatCard{
				ID:    uuid.New().String(),
				Title: formatColumnLabel(k),
				Value: dataMap[k],
				Color: colors[i%len(colors)],
			})
		}
	}
	return &Visualization{
		ID:    uuid.New().String(),
		Type:  VizStatCards,
		Title: title,
		Data:  cardsData,
		Config: VisualizationConfig{
			Animated: true,
		},
		Responsive: true,
	}
}
// createKPIVisualization builds a KPI tile from either a bare scalar or a
// map with optional "value", "change", "target" and "unit" fields.
func (g *Generator) createKPIVisualization(title string, data interface{}, mapping map[string]interface{}) *Visualization {
	kpi := &KPIData{Value: data}
	if fields, ok := data.(map[string]interface{}); ok {
		if raw, ok := fields["value"]; ok {
			kpi.Value = raw
		}
		if delta, ok := fields["change"].(float64); ok {
			kpi.Change = delta
			// Zero is labeled "increase" — preserved from the original logic.
			if delta >= 0 {
				kpi.ChangeType = "increase"
			} else {
				kpi.ChangeType = "decrease"
			}
		}
		if target, ok := fields["target"]; ok {
			kpi.Target = target
		}
		if unit, ok := fields["unit"].(string); ok {
			kpi.Unit = unit
		}
	}
	return &Visualization{
		ID:    uuid.New().String(),
		Type:  VizKPI,
		Title: title,
		Data:  kpi,
		Config: VisualizationConfig{
			Animated:   true,
			ShowValues: true,
		},
		Style: VisualizationStyle{
			MinHeight: "150px",
		},
		Responsive: true,
	}
}
// createTimelineVisualization builds a chronologically sorted timeline from
// a slice of {date, title, description} maps.
func (g *Generator) createTimelineVisualization(title string, data interface{}, mapping map[string]interface{}) *Visualization {
	timelineData := &TimelineData{
		Events: make([]TimelineEvent, 0),
	}
	if dataSlice, ok := data.([]interface{}); ok {
		for _, item := range dataSlice {
			itemMap, ok := item.(map[string]interface{})
			if !ok {
				continue
			}
			event := TimelineEvent{
				ID: uuid.New().String(),
			}
			if v, ok := itemMap["date"].(string); ok {
				event.Date = parseEventDate(v)
			}
			if v, ok := itemMap["title"].(string); ok {
				event.Title = v
			}
			if v, ok := itemMap["description"].(string); ok {
				event.Description = v
			}
			timelineData.Events = append(timelineData.Events, event)
		}
	}
	sort.Slice(timelineData.Events, func(i, j int) bool {
		return timelineData.Events[i].Date.Before(timelineData.Events[j].Date)
	})
	return &Visualization{
		ID:    uuid.New().String(),
		Type:  VizTimeline,
		Title: title,
		Data:  timelineData,
		Config: VisualizationConfig{
			Animated: true,
		},
		Responsive: true,
	}
}

// parseEventDate accepts RFC3339 timestamps and, as a generalization over
// the previous RFC3339-only behavior, plain YYYY-MM-DD dates. Unparseable
// input yields the zero time, matching the old ignore-errors behavior.
func parseEventDate(v string) time.Time {
	if ts, err := time.Parse(time.RFC3339, v); err == nil {
		return ts
	}
	ts, _ := time.Parse("2006-01-02", v)
	return ts
}
// createComparisonVisualization builds a side-by-side comparison from a
// slice of item maps; every key of the first item except name/id/image
// becomes a comparison category.
func (g *Generator) createComparisonVisualization(title string, data interface{}, mapping map[string]interface{}) *Visualization {
	compData := &ComparisonData{
		Items:      make([]ComparisonItem, 0),
		Categories: make([]string, 0),
	}
	if dataSlice, ok := data.([]interface{}); ok && len(dataSlice) > 0 {
		if firstItem, ok := dataSlice[0].(map[string]interface{}); ok {
			for k := range firstItem {
				if k != "name" && k != "id" && k != "image" {
					compData.Categories = append(compData.Categories, k)
				}
			}
			// Sort categories: randomized map iteration previously made the
			// comparison column order differ between renders.
			sort.Strings(compData.Categories)
		}
		for _, item := range dataSlice {
			if itemMap, ok := item.(map[string]interface{}); ok {
				compItem := ComparisonItem{
					ID:     uuid.New().String(),
					Values: make(map[string]interface{}),
				}
				if v, ok := itemMap["name"].(string); ok {
					compItem.Name = v
				}
				if v, ok := itemMap["image"].(string); ok {
					compItem.Image = v
				}
				for _, cat := range compData.Categories {
					if v, ok := itemMap[cat]; ok {
						compItem.Values[cat] = v
					}
				}
				compData.Items = append(compData.Items, compItem)
			}
		}
	}
	return &Visualization{
		ID:    uuid.New().String(),
		Type:  VizComparison,
		Title: title,
		Data:  compData,
		Config: VisualizationConfig{
			ShowLabels: true,
		},
		Responsive: true,
	}
}
// createProgressVisualization builds a progress bar from either a bare
// float64 (current out of the default total of 100) or a map with
// "current", "total" and "label" fields.
func (g *Generator) createProgressVisualization(title string, data interface{}, mapping map[string]interface{}) *Visualization {
	progress := &ProgressData{
		Current:   0,
		Total:     100,
		ShowValue: true,
		Animated:  true,
	}
	switch v := data.(type) {
	case map[string]interface{}:
		if cur, ok := v["current"]; ok {
			progress.Current = toFloat64(cur)
		}
		if total, ok := v["total"]; ok {
			progress.Total = toFloat64(total)
		}
		if label, ok := v["label"].(string); ok {
			progress.Label = label
		}
	case float64:
		progress.Current = v
	}
	return &Visualization{
		ID:    uuid.New().String(),
		Type:  VizProgress,
		Title: title,
		Data:  progress,
		Config: VisualizationConfig{
			Animated:   true,
			ShowValues: true,
		},
		Responsive: true,
	}
}
// CreateBarChart builds a single-dataset bar chart visualization.
func (g *Generator) CreateBarChart(title string, labels []string, values []float64) Visualization {
	chart := &ChartData{
		Labels:   labels,
		Datasets: []ChartDataset{{Label: title, Data: values}},
	}
	return Visualization{
		ID:         uuid.New().String(),
		Type:       VizBarChart,
		Title:      title,
		Data:       chart,
		Config:     VisualizationConfig{ShowLegend: true, ShowTooltip: true, Animated: true},
		Responsive: true,
	}
}
// CreateLineChart builds a line chart from pre-assembled datasets.
func (g *Generator) CreateLineChart(title string, labels []string, datasets []ChartDataset) Visualization {
	chart := &ChartData{
		Labels:   labels,
		Datasets: datasets,
	}
	return Visualization{
		ID:         uuid.New().String(),
		Type:       VizLineChart,
		Title:      title,
		Data:       chart,
		Config:     VisualizationConfig{ShowLegend: true, ShowTooltip: true, ShowGrid: true, Animated: true},
		Responsive: true,
	}
}
// CreatePieChart builds a single-dataset pie chart visualization.
func (g *Generator) CreatePieChart(title string, labels []string, values []float64) Visualization {
	chart := &ChartData{
		Labels:   labels,
		Datasets: []ChartDataset{{Label: title, Data: values}},
	}
	return Visualization{
		ID:         uuid.New().String(),
		Type:       VizPieChart,
		Title:      title,
		Data:       chart,
		Config:     VisualizationConfig{ShowLegend: true, ShowTooltip: true, ShowValues: true},
		Responsive: true,
	}
}
// CreateTable builds a sortable, searchable table; pagination kicks in
// past 10 rows.
func (g *Generator) CreateTable(title string, columns []TableColumn, rows []TableRow) Visualization {
	table := &TableData{
		Columns: columns,
		Rows:    rows,
		Summary: &TableSummary{TotalRows: len(rows)},
	}
	return Visualization{
		ID:         uuid.New().String(),
		Type:       VizTable,
		Title:      title,
		Data:       table,
		Config:     VisualizationConfig{Sortable: true, Searchable: true, Paginated: len(rows) > 10, PageSize: 10},
		Responsive: true,
	}
}
// CreateStatCards wraps a pre-built card list in a stat-cards visualization.
func (g *Generator) CreateStatCards(title string, cards []StatCard) Visualization {
	return Visualization{
		ID:         uuid.New().String(),
		Type:       VizStatCards,
		Title:      title,
		Data:       &StatCardsData{Cards: cards},
		Config:     VisualizationConfig{Animated: true},
		Responsive: true,
	}
}
// CreateKPI builds a KPI tile; ChangeType is derived from the sign of
// change ("neutral" for exactly zero).
func (g *Generator) CreateKPI(title string, value interface{}, change float64, unit string) Visualization {
	var changeType string
	switch {
	case change > 0:
		changeType = "increase"
	case change < 0:
		changeType = "decrease"
	default:
		changeType = "neutral"
	}
	kpi := &KPIData{
		Value:      value,
		Change:     change,
		ChangeType: changeType,
		Unit:       unit,
	}
	return Visualization{
		ID:         uuid.New().String(),
		Type:       VizKPI,
		Title:      title,
		Data:       kpi,
		Config:     VisualizationConfig{Animated: true},
		Responsive: true,
	}
}
// CreateMarkdown wraps raw markdown content in a visualization.
func (g *Generator) CreateMarkdown(title string, content string) Visualization {
	return Visualization{
		ID:         uuid.New().String(),
		Type:       VizMarkdown,
		Title:      title,
		Data:       &MarkdownData{Content: content},
		Responsive: true,
	}
}
// CreateCodeBlock builds a copyable, line-numbered code block visualization.
func (g *Generator) CreateCodeBlock(title, code, language string) Visualization {
	block := &CodeBlockData{
		Code:        code,
		Language:    language,
		ShowLineNum: true,
		Copyable:    true,
	}
	return Visualization{
		ID:         uuid.New().String(),
		Type:       VizCodeBlock,
		Title:      title,
		Data:       block,
		Responsive: true,
	}
}
// CreateTabs wraps a tab list in a tabs visualization.
func (g *Generator) CreateTabs(title string, tabs []TabItem) Visualization {
	return Visualization{
		ID:         uuid.New().String(),
		Type:       VizTabs,
		Title:      title,
		Data:       &TabsData{Tabs: tabs},
		Responsive: true,
	}
}
// CreateAccordion wraps an item list in an animated accordion visualization.
func (g *Generator) CreateAccordion(title string, items []AccordionItem) Visualization {
	return Visualization{
		ID:         uuid.New().String(),
		Type:       VizAccordion,
		Title:      title,
		Data:       &AccordionData{Items: items},
		Config:     VisualizationConfig{Animated: true},
		Responsive: true,
	}
}
// CreateHeatmap builds a 2-D heatmap from axis labels and a values matrix
// (values[y][x] addressed by yLabels/xLabels).
func (g *Generator) CreateHeatmap(title string, xLabels, yLabels []string, values [][]float64) Visualization {
	grid := &HeatmapData{
		XLabels: xLabels,
		YLabels: yLabels,
		Values:  values,
	}
	return Visualization{
		ID:         uuid.New().String(),
		Type:       VizHeatmap,
		Title:      title,
		Data:       grid,
		Config:     VisualizationConfig{ShowTooltip: true, ShowLabels: true},
		Responsive: true,
	}
}
// jsonObjectRe greedily matches the first '{' through the last '}' across
// newlines. Compiled once at package init instead of on every call (the
// previous per-call regexp.MustCompile recompiled the pattern each time).
var jsonObjectRe = regexp.MustCompile(`(?s)\{.*\}`)

// extractJSON returns the outermost {...} span found in text, or "{}" when
// no brace-delimited span is present.
func extractJSON(text string) string {
	if match := jsonObjectRe.FindString(text); match != "" {
		return match
	}
	return "{}"
}
// toFloat64 coerces common JSON-decoded and Go numeric values to float64.
// Unknown types (and unparseable strings) yield 0. Generalized beyond the
// original float64/float32/int/int64/string set to cover the remaining
// fixed-width integer types and json.Number, all of which appear in
// decoded payloads.
func toFloat64(v interface{}) float64 {
	switch val := v.(type) {
	case float64:
		return val
	case float32:
		return float64(val)
	case int:
		return float64(val)
	case int64:
		return float64(val)
	case int32:
		return float64(val)
	case int16:
		return float64(val)
	case int8:
		return float64(val)
	case uint:
		return float64(val)
	case uint64:
		return float64(val)
	case uint32:
		return float64(val)
	case uint16:
		return float64(val)
	case uint8:
		return float64(val)
	case json.Number:
		f, _ := val.Float64()
		return f
	case string:
		f, _ := strconv.ParseFloat(val, 64)
		return f
	default:
		return 0
	}
}
// formatColumnLabel turns a snake_case / kebab-case key into a Title Case
// label, e.g. "user_name" -> "User Name".
//
// BUG FIX: the previous version byte-indexed the word (word[0], word[1:]),
// which splits multibyte UTF-8 characters and mangled non-ASCII keys — and
// this package generates Cyrillic labels. Capitalization is now rune-based.
func formatColumnLabel(key string) string {
	key = strings.ReplaceAll(key, "_", " ")
	key = strings.ReplaceAll(key, "-", " ")
	words := strings.Fields(key)
	for i, word := range words {
		runes := []rune(word)
		if len(runes) > 0 {
			words[i] = strings.ToUpper(string(runes[0])) + strings.ToLower(string(runes[1:]))
		}
	}
	return strings.Join(words, " ")
}
// extractStringFromData pulls a string field out of a map payload, falling
// back to the %v rendering of the whole value when data is not a map or
// the key is missing/non-string.
func extractStringFromData(data interface{}, mapping map[string]interface{}, key string) string {
	dataMap, isMap := data.(map[string]interface{})
	if isMap {
		if s, isStr := dataMap[key].(string); isStr {
			return s
		}
	}
	return fmt.Sprintf("%v", data)
}
func formatDataAsMarkdown(data interface{}) string {
jsonBytes, err := json.MarshalIndent(data, "", " ")
if err != nil {
return fmt.Sprintf("%v", data)
}
return "```json\n" + string(jsonBytes) + "\n```"
}

View File

@@ -0,0 +1,335 @@
package labs
import "time"
// VisualizationType identifies a renderable widget kind; the string values
// are the wire names exchanged with the web client and the LLM prompt.
type VisualizationType string
const (
	VizBarChart VisualizationType = "bar_chart"
	VizLineChart VisualizationType = "line_chart"
	VizPieChart VisualizationType = "pie_chart"
	VizDonutChart VisualizationType = "donut_chart"
	VizAreaChart VisualizationType = "area_chart"
	VizScatterPlot VisualizationType = "scatter_plot"
	VizHeatmap VisualizationType = "heatmap"
	VizTreemap VisualizationType = "treemap"
	VizGauge VisualizationType = "gauge"
	VizRadar VisualizationType = "radar"
	VizSankey VisualizationType = "sankey"
	VizTable VisualizationType = "table"
	VizTimeline VisualizationType = "timeline"
	VizKPI VisualizationType = "kpi"
	VizProgress VisualizationType = "progress"
	VizComparison VisualizationType = "comparison"
	VizStatCards VisualizationType = "stat_cards"
	VizMap VisualizationType = "map"
	VizFlowChart VisualizationType = "flow_chart"
	VizOrgChart VisualizationType = "org_chart"
	VizCodeBlock VisualizationType = "code_block"
	VizMarkdown VisualizationType = "markdown"
	VizCollapsible VisualizationType = "collapsible"
	VizTabs VisualizationType = "tabs"
	VizAccordion VisualizationType = "accordion"
	VizStepper VisualizationType = "stepper"
	VizForm VisualizationType = "form"
)
// Visualization is one renderable widget in a report. Data holds the
// type-specific payload (*ChartData, *TableData, *KPIData, ...) matching Type.
type Visualization struct {
	ID string `json:"id"`
	Type VisualizationType `json:"type"`
	Title string `json:"title,omitempty"`
	Description string `json:"description,omitempty"`
	Data interface{} `json:"data"`
	Config VisualizationConfig `json:"config,omitempty"`
	Style VisualizationStyle `json:"style,omitempty"`
	Actions []VisualizationAction `json:"actions,omitempty"`
	Responsive bool `json:"responsive"`
}
// VisualizationConfig carries behavior toggles interpreted by the client
// renderer; only the flags relevant to the widget's Type apply.
type VisualizationConfig struct {
	ShowLegend bool `json:"showLegend,omitempty"`
	ShowGrid bool `json:"showGrid,omitempty"`
	ShowTooltip bool `json:"showTooltip,omitempty"`
	ShowLabels bool `json:"showLabels,omitempty"`
	ShowValues bool `json:"showValues,omitempty"`
	Animated bool `json:"animated,omitempty"`
	Stacked bool `json:"stacked,omitempty"`
	Horizontal bool `json:"horizontal,omitempty"`
	Sortable bool `json:"sortable,omitempty"`
	Filterable bool `json:"filterable,omitempty"`
	Searchable bool `json:"searchable,omitempty"`
	Paginated bool `json:"paginated,omitempty"`
	PageSize int `json:"pageSize,omitempty"`
	Expandable bool `json:"expandable,omitempty"`
	DefaultExpanded bool `json:"defaultExpanded,omitempty"`
	XAxisLabel string `json:"xAxisLabel,omitempty"`
	YAxisLabel string `json:"yAxisLabel,omitempty"`
	Colors []string `json:"colors,omitempty"`
	DateFormat string `json:"dateFormat,omitempty"`
	NumberFormat string `json:"numberFormat,omitempty"`
	CurrencySymbol string `json:"currencySymbol,omitempty"`
}
// VisualizationStyle carries CSS-style presentation hints (values are raw
// CSS strings such as "300px") forwarded to the client renderer.
type VisualizationStyle struct {
	Width string `json:"width,omitempty"`
	Height string `json:"height,omitempty"`
	MinHeight string `json:"minHeight,omitempty"`
	MaxHeight string `json:"maxHeight,omitempty"`
	Padding string `json:"padding,omitempty"`
	Margin string `json:"margin,omitempty"`
	BorderRadius string `json:"borderRadius,omitempty"`
	Background string `json:"background,omitempty"`
	Shadow string `json:"shadow,omitempty"`
	FontFamily string `json:"fontFamily,omitempty"`
	FontSize string `json:"fontSize,omitempty"`
	TextColor string `json:"textColor,omitempty"`
	AccentColor string `json:"accentColor,omitempty"`
	GridColor string `json:"gridColor,omitempty"`
}
// VisualizationAction describes an interactive affordance (button/link)
// attached to a widget.
type VisualizationAction struct {
	ID string `json:"id"`
	Label string `json:"label"`
	Icon string `json:"icon,omitempty"`
	Type string `json:"type"`
	Handler string `json:"handler,omitempty"`
	URL string `json:"url,omitempty"`
}
// ChartData is the payload for bar/line/pie/donut/area charts: shared
// x-axis labels plus one or more datasets aligned to them.
type ChartData struct {
	Labels []string `json:"labels"`
	Datasets []ChartDataset `json:"datasets"`
}
// ChartDataset is one series; Data is index-aligned with ChartData.Labels.
type ChartDataset struct {
	Label string `json:"label"`
	Data []float64 `json:"data"`
	BackgroundColor string `json:"backgroundColor,omitempty"`
	BorderColor string `json:"borderColor,omitempty"`
	Fill bool `json:"fill,omitempty"`
}
// TableData is the payload for table visualizations.
type TableData struct {
	Columns []TableColumn `json:"columns"`
	Rows []TableRow `json:"rows"`
	Summary *TableSummary `json:"summary,omitempty"`
}
// TableColumn describes one column; Key indexes into each TableRow map.
type TableColumn struct {
	Key string `json:"key"`
	Label string `json:"label"`
	Type string `json:"type,omitempty"`
	Width string `json:"width,omitempty"`
	Sortable bool `json:"sortable,omitempty"`
	Align string `json:"align,omitempty"`
	Format string `json:"format,omitempty"`
	Highlight bool `json:"highlight,omitempty"`
}
// TableRow maps column keys to cell values.
type TableRow map[string]interface{}
// TableSummary holds row counts and optional per-column aggregations.
type TableSummary struct {
	TotalRows int `json:"totalRows"`
	Aggregations map[string]interface{} `json:"aggregations,omitempty"`
}
// TimelineData is the payload for timeline visualizations.
type TimelineData struct {
	Events []TimelineEvent `json:"events"`
}
// TimelineEvent is one point on a timeline.
type TimelineEvent struct {
	ID string `json:"id"`
	Date time.Time `json:"date"`
	Title string `json:"title"`
	Description string `json:"description,omitempty"`
	Icon string `json:"icon,omitempty"`
	Color string `json:"color,omitempty"`
	Link string `json:"link,omitempty"`
}
// KPIData is the payload for a single KPI tile. ChangeType is
// "increase"/"decrease"/"neutral" as set by the generator.
type KPIData struct {
	Value interface{} `json:"value"`
	PrevValue interface{} `json:"prevValue,omitempty"`
	Change float64 `json:"change,omitempty"`
	ChangeType string `json:"changeType,omitempty"`
	Unit string `json:"unit,omitempty"`
	Prefix string `json:"prefix,omitempty"`
	Suffix string `json:"suffix,omitempty"`
	Target interface{} `json:"target,omitempty"`
	Trend []float64 `json:"trend,omitempty"`
	Icon string `json:"icon,omitempty"`
	Color string `json:"color,omitempty"`
	Description string `json:"description,omitempty"`
}
// StatCardsData is the payload for a stat-card grid.
type StatCardsData struct {
	Cards []StatCard `json:"cards"`
}
// StatCard is one metric card within a stat-card grid.
type StatCard struct {
	ID string `json:"id"`
	Title string `json:"title"`
	Value interface{} `json:"value"`
	Change float64 `json:"change,omitempty"`
	ChangeLabel string `json:"changeLabel,omitempty"`
	Icon string `json:"icon,omitempty"`
	Color string `json:"color,omitempty"`
	Sparkline []float64 `json:"sparkline,omitempty"`
}
type ComparisonData struct {
Items []ComparisonItem `json:"items"`
Categories []string `json:"categories"`
}
type ComparisonItem struct {
ID string `json:"id"`
Name string `json:"name"`
Image string `json:"image,omitempty"`
Values map[string]interface{} `json:"values"`
}
type ProgressData struct {
Current float64 `json:"current"`
Total float64 `json:"total"`
Label string `json:"label,omitempty"`
Color string `json:"color,omitempty"`
ShowValue bool `json:"showValue,omitempty"`
Animated bool `json:"animated,omitempty"`
}
// HeatmapData is a 2-D grid of values with axis labels; Min/Max optionally
// pin the color scale.
type HeatmapData struct {
	XLabels []string    `json:"xLabels"`
	YLabels []string    `json:"yLabels"`
	Values  [][]float64 `json:"values"`
	Min     float64     `json:"min,omitempty"`
	Max     float64     `json:"max,omitempty"`
}

// MapData configures a map widget: initial viewport plus optional markers
// and shaded regions.
type MapData struct {
	Center  []float64   `json:"center"`
	Zoom    int         `json:"zoom"`
	Markers []MapMarker `json:"markers,omitempty"`
	Regions []MapRegion `json:"regions,omitempty"`
}

// MapMarker is one pin on the map; Position holds coordinates
// (presumably [lat, lng] — confirm with the map renderer).
type MapMarker struct {
	ID       string    `json:"id"`
	Position []float64 `json:"position"`
	Label    string    `json:"label,omitempty"`
	Icon     string    `json:"icon,omitempty"`
	Color    string    `json:"color,omitempty"`
	Popup    string    `json:"popup,omitempty"`
}

// MapRegion is a named area shaded according to its value.
type MapRegion struct {
	ID    string  `json:"id"`
	Name  string  `json:"name"`
	Value float64 `json:"value"`
	Color string  `json:"color,omitempty"`
}
// CollapsibleData is an expandable section holding arbitrary content and/or
// nested visualizations.
type CollapsibleData struct {
	Title       string          `json:"title"`
	Content     interface{}     `json:"content"`
	DefaultOpen bool            `json:"defaultOpen,omitempty"`
	Icon        string          `json:"icon,omitempty"`
	Children    []Visualization `json:"children,omitempty"`
}

// TabsData is a tabbed container.
type TabsData struct {
	Tabs []TabItem `json:"tabs"`
}

// TabItem is one tab: its label plus content and/or nested visualizations.
type TabItem struct {
	ID       string          `json:"id"`
	Label    string          `json:"label"`
	Icon     string          `json:"icon,omitempty"`
	Content  interface{}     `json:"content"`
	Children []Visualization `json:"children,omitempty"`
}

// AccordionData is a stack of independently collapsible items.
type AccordionData struct {
	Items []AccordionItem `json:"items"`
}

// AccordionItem is one accordion entry; Open marks it expanded initially.
type AccordionItem struct {
	ID      string      `json:"id"`
	Title   string      `json:"title"`
	Content interface{} `json:"content"`
	Icon    string      `json:"icon,omitempty"`
	Open    bool        `json:"open,omitempty"`
}

// StepperData renders a multi-step flow with a current position and an
// optional horizontal/vertical orientation.
type StepperData struct {
	Steps       []StepperStep `json:"steps"`
	CurrentStep int           `json:"currentStep"`
	Orientation string        `json:"orientation,omitempty"`
}

// StepperStep is one stage in a stepper flow.
type StepperStep struct {
	ID          string      `json:"id"`
	Label       string      `json:"label"`
	Description string      `json:"description,omitempty"`
	Content     interface{} `json:"content,omitempty"`
	Status      string      `json:"status,omitempty"`
	Icon        string      `json:"icon,omitempty"`
}
// FormData describes a renderable form: its fields, submit label and layout.
type FormData struct {
	Fields      []FormField `json:"fields"`
	SubmitLabel string      `json:"submitLabel,omitempty"`
	Layout      string      `json:"layout,omitempty"`
}

// FormField is one input in a form; Options applies to choice-style fields.
type FormField struct {
	ID          string       `json:"id"`
	Type        string       `json:"type"`
	Label       string       `json:"label"`
	Placeholder string       `json:"placeholder,omitempty"`
	Value       interface{}  `json:"value,omitempty"`
	Options     []FormOption `json:"options,omitempty"`
	Required    bool         `json:"required,omitempty"`
	Validation  string       `json:"validation,omitempty"`
}

// FormOption is one selectable value/label pair for a choice field.
type FormOption struct {
	Value string `json:"value"`
	Label string `json:"label"`
}

// CodeBlockData renders a syntax-highlighted code snippet with optional
// filename, highlighted lines, line numbers and a copy button.
type CodeBlockData struct {
	Code        string `json:"code"`
	Language    string `json:"language"`
	Filename    string `json:"filename,omitempty"`
	Highlight   []int  `json:"highlight,omitempty"`
	ShowLineNum bool   `json:"showLineNum,omitempty"`
	Copyable    bool   `json:"copyable,omitempty"`
}

// MarkdownData renders raw markdown content.
type MarkdownData struct {
	Content string `json:"content"`
}
// Report is a shareable document composed of sections of visualizations,
// with authorship/visibility metadata and optional theming.
type Report struct {
	ID          string          `json:"id"`
	Title       string          `json:"title"`
	Description string          `json:"description,omitempty"`
	Sections    []ReportSection `json:"sections"`
	CreatedAt   time.Time       `json:"createdAt"`
	UpdatedAt   time.Time       `json:"updatedAt"`
	Author      string          `json:"author,omitempty"`
	Tags        []string        `json:"tags,omitempty"`
	IsPublic    bool            `json:"isPublic"`
	Theme       string          `json:"theme,omitempty"`
	CustomCSS   string          `json:"customCss,omitempty"`
}

// ReportSection groups visualizations under an optional title, with layout
// hints (e.g. column count) for the renderer.
type ReportSection struct {
	ID             string          `json:"id"`
	Title          string          `json:"title,omitempty"`
	Description    string          `json:"description,omitempty"`
	Visualizations []Visualization `json:"visualizations"`
	Layout         string          `json:"layout,omitempty"`
	Columns        int             `json:"columns,omitempty"`
}

View File

@@ -0,0 +1,701 @@
package learning
import (
"context"
"encoding/json"
"fmt"
"regexp"
"time"
"github.com/gooseek/backend/internal/llm"
"github.com/google/uuid"
)
// LearningMode selects how a lesson presents its material to the learner.
type LearningMode string

const (
	ModeExplain     LearningMode = "explain"
	ModeGuided      LearningMode = "guided"
	ModeInteractive LearningMode = "interactive"
	ModePractice    LearningMode = "practice"
	ModeQuiz        LearningMode = "quiz"
)

// DifficultyLevel grades lesson content from beginner through expert.
type DifficultyLevel string

const (
	DifficultyBeginner     DifficultyLevel = "beginner"
	DifficultyIntermediate DifficultyLevel = "intermediate"
	DifficultyAdvanced     DifficultyLevel = "advanced"
	DifficultyExpert       DifficultyLevel = "expert"
)

// StepByStepLesson is a generated lesson: ordered steps plus topic metadata,
// learning goals and the learner's progress through it.
type StepByStepLesson struct {
	ID            string          `json:"id"`
	Title         string          `json:"title"`
	Description   string          `json:"description"`
	Topic         string          `json:"topic"`
	Difficulty    DifficultyLevel `json:"difficulty"`
	Mode          LearningMode    `json:"mode"`
	Steps         []LearningStep  `json:"steps"`
	Prerequisites []string        `json:"prerequisites,omitempty"`
	LearningGoals []string        `json:"learningGoals"`
	EstimatedTime int             `json:"estimatedTimeMinutes"`
	Progress      LessonProgress  `json:"progress"`
	CreatedAt     time.Time       `json:"createdAt"`
	UpdatedAt     time.Time       `json:"updatedAt"`
}

// LearningStep is one unit of a lesson: its content plus optional
// interaction, hints, examples, practice exercise and quiz.
type LearningStep struct {
	ID          string                 `json:"id"`
	Number      int                    `json:"number"`
	Title       string                 `json:"title"`
	Type        StepType               `json:"type"`
	Content     StepContent            `json:"content"`
	Interaction *StepInteraction       `json:"interaction,omitempty"`
	Hints       []string               `json:"hints,omitempty"`
	Examples    []Example              `json:"examples,omitempty"`
	Practice    *PracticeExercise      `json:"practice,omitempty"`
	Quiz        *QuizQuestion          `json:"quiz,omitempty"`
	Duration    int                    `json:"durationSeconds,omitempty"`
	Status      StepStatus             `json:"status"`
	Metadata    map[string]interface{} `json:"metadata,omitempty"`
}

// StepType identifies what kind of content a step carries.
type StepType string

const (
	StepExplanation   StepType = "explanation"
	StepVisualization StepType = "visualization"
	StepCode          StepType = "code"
	StepInteractive   StepType = "interactive"
	StepPractice      StepType = "practice"
	StepQuiz          StepType = "quiz"
	StepSummary       StepType = "summary"
	StepCheckpoint    StepType = "checkpoint"
)

// StepStatus tracks a step's lifecycle from locked through completed/skipped.
type StepStatus string

const (
	StatusLocked     StepStatus = "locked"
	StatusAvailable  StepStatus = "available"
	StatusInProgress StepStatus = "in_progress"
	StatusCompleted  StepStatus = "completed"
	StatusSkipped    StepStatus = "skipped"
)
// StepContent is the renderable body of a step; which fields are populated
// depends on the step's Type, with Text as the plain-text baseline.
type StepContent struct {
	Text          string                `json:"text"`
	Markdown      string                `json:"markdown,omitempty"`
	HTML          string                `json:"html,omitempty"`
	Code          *CodeContent          `json:"code,omitempty"`
	Visualization *VisualizationContent `json:"visualization,omitempty"`
	Media         *MediaContent         `json:"media,omitempty"`
	Formula       string                `json:"formula,omitempty"`
	Highlights    []TextHighlight       `json:"highlights,omitempty"`
}

// CodeContent is an embedded code sample, optionally runnable/editable in
// the client UI, with line highlights and annotations.
type CodeContent struct {
	Language    string           `json:"language"`
	Code        string           `json:"code"`
	Filename    string           `json:"filename,omitempty"`
	Runnable    bool             `json:"runnable"`
	Editable    bool             `json:"editable"`
	Highlights  []int            `json:"highlights,omitempty"`
	Annotations []CodeAnnotation `json:"annotations,omitempty"`
}

// CodeAnnotation attaches a typed note to a single line of a code sample.
type CodeAnnotation struct {
	Line int    `json:"line"`
	Text string `json:"text"`
	Type string `json:"type"`
}

// VisualizationContent embeds a typed visualization payload in a step.
type VisualizationContent struct {
	Type   string                 `json:"type"`
	Data   interface{}            `json:"data"`
	Config map[string]interface{} `json:"config,omitempty"`
}

// MediaContent references external media shown in a step.
type MediaContent struct {
	Type     string `json:"type"`
	URL      string `json:"url"`
	Caption  string `json:"caption,omitempty"`
	Duration int    `json:"duration,omitempty"`
}

// TextHighlight marks a span of step text with a type and an optional note.
// Start/End are offsets — presumably character indexes; confirm with the
// renderer that consumes them.
type TextHighlight struct {
	Start int    `json:"start"`
	End   int    `json:"end"`
	Text  string `json:"text"`
	Type  string `json:"type"`
	Note  string `json:"note,omitempty"`
}

// StepInteraction describes a prompt the learner responds to within a step,
// with optional validation and feedback messages.
type StepInteraction struct {
	Type       string      `json:"type"`
	Prompt     string      `json:"prompt"`
	Options    []Option    `json:"options,omitempty"`
	Validation *Validation `json:"validation,omitempty"`
	Feedback   *Feedback   `json:"feedback,omitempty"`
}

// Option is one selectable answer in an interaction or quiz question.
type Option struct {
	ID        string `json:"id"`
	Text      string `json:"text"`
	IsCorrect bool   `json:"isCorrect,omitempty"`
	Feedback  string `json:"feedback,omitempty"`
}

// Validation describes how a free-form learner response is checked
// (pattern, exact expected value, or required keywords).
type Validation struct {
	Type     string   `json:"type"`
	Pattern  string   `json:"pattern,omitempty"`
	Expected string   `json:"expected,omitempty"`
	Keywords []string `json:"keywords,omitempty"`
}

// Feedback holds the messages shown for correct/incorrect/partial answers.
type Feedback struct {
	Correct   string `json:"correct"`
	Incorrect string `json:"incorrect"`
	Partial   string `json:"partial,omitempty"`
}

// Example is a worked example attached to a step.
type Example struct {
	Title       string `json:"title"`
	Description string `json:"description"`
	Input       string `json:"input,omitempty"`
	Output      string `json:"output,omitempty"`
	Code        string `json:"code,omitempty"`
	Language    string `json:"language,omitempty"`
}

// PracticeExercise is a hands-on coding task with optional starter code,
// solution, test cases and hints.
type PracticeExercise struct {
	Prompt       string     `json:"prompt"`
	Instructions string     `json:"instructions"`
	Starter      string     `json:"starter,omitempty"`
	Solution     string     `json:"solution,omitempty"`
	TestCases    []TestCase `json:"testCases,omitempty"`
	Hints        []string   `json:"hints,omitempty"`
}

// TestCase pairs an input with its expected output; Hidden cases are not
// shown to the learner.
type TestCase struct {
	Input    string `json:"input"`
	Expected string `json:"expected"`
	Hidden   bool   `json:"hidden,omitempty"`
}

// QuizQuestion is a scored question with selectable options and an
// explanation revealed after answering.
type QuizQuestion struct {
	Question     string   `json:"question"`
	Type         string   `json:"type"`
	Options      []Option `json:"options,omitempty"`
	CorrectIndex []int    `json:"correctIndex,omitempty"`
	Explanation  string   `json:"explanation"`
	Points       int      `json:"points"`
}

// LessonProgress tracks the learner's position, score and timing in a lesson.
type LessonProgress struct {
	CurrentStep    int        `json:"currentStep"`
	CompletedSteps []int      `json:"completedSteps"`
	Score          int        `json:"score"`
	MaxScore       int        `json:"maxScore"`
	TimeSpent      int        `json:"timeSpentSeconds"`
	StartedAt      time.Time  `json:"startedAt"`
	LastAccessed   time.Time  `json:"lastAccessed"`
	Completed      bool       `json:"completed"`
	CompletedAt    *time.Time `json:"completedAt,omitempty"`
}
// LearningGenerator produces lessons, explanations, quizzes and practice
// exercises by prompting an LLM client.
type LearningGenerator struct {
	llm llm.Client
}

// NewLearningGenerator wraps an LLM client in a LearningGenerator.
func NewLearningGenerator(llmClient llm.Client) *LearningGenerator {
	return &LearningGenerator{llm: llmClient}
}

// GenerateLessonOptions configures lesson generation. Zero values fall back
// to defaults inside GenerateLesson (MaxSteps=10, beginner, explain mode).
type GenerateLessonOptions struct {
	Topic       string
	Query       string
	Difficulty  DifficultyLevel
	Mode        LearningMode
	MaxSteps    int
	Locale      string
	IncludeCode bool
	IncludeQuiz bool
}
// GenerateLesson asks the LLM for a structured lesson on opts.Topic and maps
// the model's JSON reply onto a StepByStepLesson. Unset options default to
// MaxSteps=10, DifficultyBeginner and ModeExplain. LLM transport errors are
// returned as-is; an unparseable reply degrades to a minimal one-step
// fallback lesson instead of failing. The first step starts available and
// all later steps start locked.
func (g *LearningGenerator) GenerateLesson(ctx context.Context, opts GenerateLessonOptions) (*StepByStepLesson, error) {
	if opts.MaxSteps == 0 {
		opts.MaxSteps = 10
	}
	if opts.Difficulty == "" {
		opts.Difficulty = DifficultyBeginner
	}
	if opts.Mode == "" {
		opts.Mode = ModeExplain
	}
	langInstruction := ""
	if opts.Locale == "ru" {
		langInstruction = "Generate all content in Russian language."
	}
	prompt := fmt.Sprintf(`Create a step-by-step educational lesson on the following topic.
Topic: %s
Query: %s
Difficulty: %s
Mode: %s
Max Steps: %d
Include Code Examples: %v
Include Quiz: %v
%s
Generate a structured lesson with these requirements:
1. Break down the concept into clear, digestible steps
2. Each step should build on the previous one
3. Include explanations, examples, and visualizations where helpful
4. For code topics, include runnable code snippets
5. Add practice exercises for interactive learning
6. Include quiz questions to test understanding
Respond in this JSON format:
{
"title": "Lesson title",
"description": "Brief description",
"learningGoals": ["Goal 1", "Goal 2"],
"estimatedTimeMinutes": 15,
"steps": [
{
"title": "Step title",
"type": "explanation|code|interactive|practice|quiz|summary",
"content": {
"text": "Main explanation",
"markdown": "## Formatted content",
"code": {"language": "python", "code": "example", "runnable": true},
"formula": "optional LaTeX formula"
},
"hints": ["Hint 1"],
"examples": [{"title": "Example", "description": "...", "code": "..."}],
"quiz": {
"question": "...",
"type": "multiple_choice",
"options": [{"id": "a", "text": "Option A", "isCorrect": false}],
"explanation": "..."
}
}
]
}`, opts.Topic, opts.Query, opts.Difficulty, opts.Mode, opts.MaxSteps, opts.IncludeCode, opts.IncludeQuiz, langInstruction)
	result, err := g.llm.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: "user", Content: prompt}},
	})
	if err != nil {
		return nil, err
	}
	jsonStr := extractJSON(result)
	var parsed struct {
		Title                string   `json:"title"`
		Description          string   `json:"description"`
		LearningGoals        []string `json:"learningGoals"`
		EstimatedTimeMinutes int      `json:"estimatedTimeMinutes"`
		Steps                []struct {
			Title   string `json:"title"`
			Type    string `json:"type"`
			Content struct {
				Text     string `json:"text"`
				Markdown string `json:"markdown"`
				Code     *struct {
					Language string `json:"language"`
					Code     string `json:"code"`
					Runnable bool   `json:"runnable"`
				} `json:"code"`
				Formula string `json:"formula"`
			} `json:"content"`
			Hints    []string `json:"hints"`
			Examples []struct {
				Title       string `json:"title"`
				Description string `json:"description"`
				Code        string `json:"code"`
			} `json:"examples"`
			Quiz *struct {
				Question string `json:"question"`
				Type     string `json:"type"`
				Options  []struct {
					ID        string `json:"id"`
					Text      string `json:"text"`
					IsCorrect bool   `json:"isCorrect"`
				} `json:"options"`
				Explanation string `json:"explanation"`
			} `json:"quiz"`
		} `json:"steps"`
	}
	if err := json.Unmarshal([]byte(jsonStr), &parsed); err != nil {
		// Unparseable model output: degrade gracefully rather than failing.
		return g.createDefaultLesson(opts)
	}
	// The prompt only *requests* "Max Steps"; the model may ignore it.
	// Enforce the budget here so callers get at most opts.MaxSteps steps.
	if len(parsed.Steps) > opts.MaxSteps {
		parsed.Steps = parsed.Steps[:opts.MaxSteps]
	}
	lesson := &StepByStepLesson{
		ID:            uuid.New().String(),
		Title:         parsed.Title,
		Description:   parsed.Description,
		Topic:         opts.Topic,
		Difficulty:    opts.Difficulty,
		Mode:          opts.Mode,
		LearningGoals: parsed.LearningGoals,
		EstimatedTime: parsed.EstimatedTimeMinutes,
		Steps:         make([]LearningStep, 0, len(parsed.Steps)),
		Progress: LessonProgress{
			CurrentStep:    0,
			CompletedSteps: []int{},
		},
		CreatedAt: time.Now(),
		UpdatedAt: time.Now(),
	}
	for i, s := range parsed.Steps {
		step := LearningStep{
			ID:     uuid.New().String(),
			Number: i + 1,
			Title:  s.Title,
			Type:   StepType(s.Type),
			Content: StepContent{
				Text:     s.Content.Text,
				Markdown: s.Content.Markdown,
				Formula:  s.Content.Formula,
			},
			Hints:  s.Hints,
			Status: StatusAvailable,
		}
		// Only the first step is unlocked; the rest open via CompleteStep.
		if i > 0 {
			step.Status = StatusLocked
		}
		if s.Content.Code != nil {
			step.Content.Code = &CodeContent{
				Language: s.Content.Code.Language,
				Code:     s.Content.Code.Code,
				Runnable: s.Content.Code.Runnable,
				Editable: true,
			}
		}
		for _, ex := range s.Examples {
			step.Examples = append(step.Examples, Example{
				Title:       ex.Title,
				Description: ex.Description,
				Code:        ex.Code,
			})
		}
		if s.Quiz != nil {
			quiz := &QuizQuestion{
				Question:    s.Quiz.Question,
				Type:        s.Quiz.Type,
				Explanation: s.Quiz.Explanation,
				Points:      10,
			}
			for _, opt := range s.Quiz.Options {
				quiz.Options = append(quiz.Options, Option{
					ID:        opt.ID,
					Text:      opt.Text,
					IsCorrect: opt.IsCorrect,
				})
			}
			step.Quiz = quiz
		}
		lesson.Steps = append(lesson.Steps, step)
	}
	// Every step is worth 10 points, matching CompleteStep's scoring.
	lesson.Progress.MaxScore = len(lesson.Steps) * 10
	return lesson, nil
}
// createDefaultLesson builds a minimal single-step fallback lesson, used when
// the LLM reply cannot be parsed: one available introduction step, a max
// score of 10, and an estimated duration of 10 minutes.
func (g *LearningGenerator) createDefaultLesson(opts GenerateLessonOptions) (*StepByStepLesson, error) {
	now := time.Now()
	intro := LearningStep{
		ID:     uuid.New().String(),
		Number: 1,
		Title:  "Introduction",
		Type:   StepExplanation,
		Content: StepContent{
			Text:     opts.Query,
			Markdown: fmt.Sprintf("# %s\n\n%s", opts.Topic, opts.Query),
		},
		Status: StatusAvailable,
	}
	lesson := &StepByStepLesson{
		ID:            uuid.New().String(),
		Title:         fmt.Sprintf("Learn: %s", opts.Topic),
		Description:   opts.Query,
		Topic:         opts.Topic,
		Difficulty:    opts.Difficulty,
		Mode:          opts.Mode,
		LearningGoals: []string{"Understand the basics"},
		EstimatedTime: 10,
		Steps:         []LearningStep{intro},
		Progress: LessonProgress{
			CurrentStep:    0,
			CompletedSteps: []int{},
			MaxScore:       10,
		},
		CreatedAt: now,
		UpdatedAt: now,
	}
	return lesson, nil
}
// GenerateExplanation produces a single markdown explanation step for the
// given topic, tuned to the learner's difficulty level. A locale of "ru"
// instructs the model to answer in Russian.
func (g *LearningGenerator) GenerateExplanation(ctx context.Context, topic string, difficulty DifficultyLevel, locale string) (*LearningStep, error) {
	var langNote string
	if locale == "ru" {
		langNote = "Respond in Russian."
	}
	prompt := fmt.Sprintf(`Explain this topic step by step for a %s level learner.
Topic: %s
%s
Format your response with clear sections:
1. Start with a simple definition
2. Explain key concepts
3. Provide a real-world analogy
4. Give a concrete example
Use markdown formatting.`, difficulty, topic, langNote)
	answer, err := g.llm.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: "user", Content: prompt}},
	})
	if err != nil {
		return nil, err
	}
	step := &LearningStep{
		ID:      uuid.New().String(),
		Number:  1,
		Title:   topic,
		Type:    StepExplanation,
		Content: StepContent{Markdown: answer},
		Status:  StatusAvailable,
	}
	return step, nil
}
// GenerateQuiz asks the LLM for numQuestions multiple-choice questions on
// topic at the given difficulty, parses the JSON reply and returns each
// question worth 10 points. A locale of "ru" requests Russian text.
func (g *LearningGenerator) GenerateQuiz(ctx context.Context, topic string, numQuestions int, difficulty DifficultyLevel, locale string) ([]QuizQuestion, error) {
	var langNote string
	if locale == "ru" {
		langNote = "Generate all questions and answers in Russian."
	}
	prompt := fmt.Sprintf(`Generate %d multiple choice quiz questions about: %s
Difficulty level: %s
%s
Respond in JSON format:
{
"questions": [
{
"question": "Question text",
"options": [
{"id": "a", "text": "Option A", "isCorrect": false},
{"id": "b", "text": "Option B", "isCorrect": true},
{"id": "c", "text": "Option C", "isCorrect": false},
{"id": "d", "text": "Option D", "isCorrect": false}
],
"explanation": "Why the correct answer is correct"
}
]
}`, numQuestions, topic, difficulty, langNote)
	raw, err := g.llm.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: "user", Content: prompt}},
	})
	if err != nil {
		return nil, err
	}
	var parsed struct {
		Questions []struct {
			Question string `json:"question"`
			Options  []struct {
				ID        string `json:"id"`
				Text      string `json:"text"`
				IsCorrect bool   `json:"isCorrect"`
			} `json:"options"`
			Explanation string `json:"explanation"`
		} `json:"questions"`
	}
	if err := json.Unmarshal([]byte(extractJSON(raw)), &parsed); err != nil {
		return nil, err
	}
	quizzes := make([]QuizQuestion, 0, len(parsed.Questions))
	for _, q := range parsed.Questions {
		item := QuizQuestion{
			Question:    q.Question,
			Type:        "multiple_choice",
			Explanation: q.Explanation,
			Points:      10,
		}
		for _, opt := range q.Options {
			item.Options = append(item.Options, Option{
				ID:        opt.ID,
				Text:      opt.Text,
				IsCorrect: opt.IsCorrect,
			})
		}
		quizzes = append(quizzes, item)
	}
	return quizzes, nil
}
// GeneratePracticeExercise asks the LLM for a coding exercise on topic in the
// given programming language and difficulty. When the reply is not valid
// JSON, a minimal placeholder exercise is returned rather than an error.
// A locale of "ru" requests Russian-language instructions.
func (g *LearningGenerator) GeneratePracticeExercise(ctx context.Context, topic, language string, difficulty DifficultyLevel, locale string) (*PracticeExercise, error) {
	var langNote string
	if locale == "ru" {
		langNote = "Write instructions and explanations in Russian."
	}
	prompt := fmt.Sprintf(`Create a coding practice exercise for: %s
Programming language: %s
Difficulty: %s
%s
Generate:
1. A clear problem statement
2. Step-by-step instructions
3. Starter code template
4. Solution code
5. Test cases
Respond in JSON:
{
"prompt": "Problem statement",
"instructions": "Step-by-step instructions",
"starter": "// Starter code",
"solution": "// Solution code",
"testCases": [
{"input": "test input", "expected": "expected output"}
],
"hints": ["Hint 1", "Hint 2"]
}`, topic, language, difficulty, langNote)
	raw, err := g.llm.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: "user", Content: prompt}},
	})
	if err != nil {
		return nil, err
	}
	var exercise PracticeExercise
	if err := json.Unmarshal([]byte(extractJSON(raw)), &exercise); err != nil {
		// Best-effort fallback when the model's JSON is unusable.
		return &PracticeExercise{
			Prompt:       topic,
			Instructions: "Practice this concept",
		}, nil
	}
	return &exercise, nil
}
// CompleteStep marks the step at stepIndex as completed, awards its 10 points
// exactly once, unlocks the following step, and flags the lesson finished
// when every step has been completed. Out-of-range indexes are ignored.
func (l *StepByStepLesson) CompleteStep(stepIndex int) {
	if stepIndex < 0 || stepIndex >= len(l.Steps) {
		return
	}
	l.Steps[stepIndex].Status = StatusCompleted
	seen := false
	for _, done := range l.Progress.CompletedSteps {
		if done == stepIndex {
			seen = true
			break
		}
	}
	if !seen {
		// First completion of this step: record it and award points.
		l.Progress.CompletedSteps = append(l.Progress.CompletedSteps, stepIndex)
		l.Progress.Score += 10
	}
	if next := stepIndex + 1; next < len(l.Steps) {
		l.Steps[next].Status = StatusAvailable
		l.Progress.CurrentStep = next
	}
	if len(l.Progress.CompletedSteps) == len(l.Steps) {
		now := time.Now()
		l.Progress.Completed = true
		l.Progress.CompletedAt = &now
	}
	stamp := time.Now()
	l.UpdatedAt = stamp
	l.Progress.LastAccessed = stamp
}
// SubmitQuizAnswer grades a quiz answer for the step at stepIndex. The answer
// is correct only when the selection hits every correct option and nothing
// else. It returns the verdict plus the quiz explanation (or an error string
// for invalid input).
func (l *StepByStepLesson) SubmitQuizAnswer(stepIndex int, selectedOptions []string) (bool, string) {
	if stepIndex < 0 || stepIndex >= len(l.Steps) {
		return false, "Invalid step"
	}
	quiz := l.Steps[stepIndex].Quiz
	if quiz == nil {
		return false, "No quiz in this step"
	}
	// How many options are correct overall.
	wanted := 0
	for _, opt := range quiz.Options {
		if opt.IsCorrect {
			wanted++
		}
	}
	// How many of the learner's selections landed on a correct option.
	hits := 0
	for _, id := range selectedOptions {
		for _, opt := range quiz.Options {
			if opt.ID == id && opt.IsCorrect {
				hits++
			}
		}
	}
	ok := hits == wanted && len(selectedOptions) == wanted
	return ok, quiz.Explanation
}
// jsonObjectRe matches the span from the first '{' to the last '}' in a
// string ((?s) lets '.' cross newlines). Compiled once at package scope so
// every extractJSON call does not pay for recompilation.
var jsonObjectRe = regexp.MustCompile(`(?s)\{.*\}`)

// extractJSON returns the first-to-last-brace span embedded in text, or "{}"
// when no such span exists. The greedy match tolerates prose before and
// after a model's JSON reply.
func extractJSON(text string) string {
	if match := jsonObjectRe.FindString(text); match != "" {
		return match
	}
	return "{}"
}
// ToJSON serializes the lesson to its JSON representation.
func (l *StepByStepLesson) ToJSON() ([]byte, error) {
	return json.Marshal(l)
}

// ParseLesson deserializes a lesson previously produced by ToJSON, returning
// an error when data is not valid JSON for a StepByStepLesson.
func ParseLesson(data []byte) (*StepByStepLesson, error) {
	lesson := new(StepByStepLesson)
	if err := json.Unmarshal(data, lesson); err != nil {
		return nil, err
	}
	return lesson, nil
}

View File

@@ -0,0 +1,182 @@
package llm
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
)
// AnthropicClient talks to the Anthropic Messages API over HTTP, streaming
// responses via server-sent events.
type AnthropicClient struct {
	baseClient
	apiKey  string
	baseURL string
	client  *http.Client
}
// NewAnthropicClient builds a client for the Anthropic Messages API. An empty
// cfg.BaseURL falls back to the public endpoint; a trailing slash is trimmed
// so path joining stays predictable.
func NewAnthropicClient(cfg ProviderConfig) (*AnthropicClient, error) {
	endpoint := cfg.BaseURL
	if endpoint == "" {
		endpoint = "https://api.anthropic.com"
	}
	c := &AnthropicClient{
		baseClient: baseClient{providerID: cfg.ProviderID, modelKey: cfg.ModelKey},
		apiKey:     cfg.APIKey,
		baseURL:    strings.TrimSuffix(endpoint, "/"),
		client:     &http.Client{},
	}
	return c, nil
}
// anthropicRequest is the wire format of a Messages API request body.
type anthropicRequest struct {
	Model     string             `json:"model"`
	Messages  []anthropicMessage `json:"messages"`
	System    string             `json:"system,omitempty"`
	MaxTokens int                `json:"max_tokens"`
	Stream    bool               `json:"stream"`
	Tools     []anthropicTool    `json:"tools,omitempty"`
}

// anthropicMessage is one chat turn in Anthropic's format.
type anthropicMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// anthropicTool declares a callable tool in Anthropic's schema format.
type anthropicTool struct {
	Name        string      `json:"name"`
	Description string      `json:"description"`
	InputSchema interface{} `json:"input_schema"`
}

// anthropicStreamEvent is a decoded SSE event from a streaming response;
// only the fields this client consumes are modeled.
type anthropicStreamEvent struct {
	Type  string `json:"type"`
	Index int    `json:"index,omitempty"`
	Delta struct {
		Type string `json:"type,omitempty"`
		Text string `json:"text,omitempty"`
	} `json:"delta,omitempty"`
	ContentBlock struct {
		Type string `json:"type"`
		Text string `json:"text,omitempty"`
	} `json:"content_block,omitempty"`
}
// StreamText issues a streaming Messages request and returns a channel of
// content deltas. System messages are lifted into the request's top-level
// "system" field; tool-role messages are downgraded to "user" since the
// plain-text payload carries no tool metadata. The channel is closed when
// the stream ends or the connection drops.
func (c *AnthropicClient) StreamText(ctx context.Context, req StreamRequest) (<-chan StreamChunk, error) {
	var systemPrompt string
	messages := make([]anthropicMessage, 0)
	for _, m := range req.Messages {
		if m.Role == RoleSystem {
			// Anthropic takes the system prompt as a dedicated field, not a message.
			systemPrompt = m.Content
			continue
		}
		role := string(m.Role)
		if role == "tool" {
			role = "user"
		}
		messages = append(messages, anthropicMessage{
			Role:    role,
			Content: m.Content,
		})
	}
	maxTokens := req.Options.MaxTokens
	if maxTokens == 0 {
		// max_tokens is mandatory for the Messages API; supply a default.
		maxTokens = 4096
	}
	anthropicReq := anthropicRequest{
		Model:     c.modelKey,
		Messages:  messages,
		System:    systemPrompt,
		MaxTokens: maxTokens,
		Stream:    true,
	}
	if len(req.Tools) > 0 {
		anthropicReq.Tools = make([]anthropicTool, len(req.Tools))
		for i, t := range req.Tools {
			anthropicReq.Tools[i] = anthropicTool{
				Name:        t.Name,
				Description: t.Description,
				InputSchema: t.Schema,
			}
		}
	}
	body, err := json.Marshal(anthropicReq)
	if err != nil {
		return nil, err
	}
	httpReq, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+"/v1/messages", bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("x-api-key", c.apiKey)
	httpReq.Header.Set("anthropic-version", "2023-06-01")
	resp, err := c.client.Do(httpReq)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		return nil, fmt.Errorf("anthropic API error: %d - %s", resp.StatusCode, string(body))
	}
	ch := make(chan StreamChunk, 100)
	go func() {
		defer close(ch)
		defer resp.Body.Close()
		scanner := bufio.NewScanner(resp.Body)
		// bufio.Scanner's default 64 KiB token limit is too small for large
		// SSE data lines; without a bigger buffer the scan stops mid-stream
		// and silently truncates the response.
		scanner.Buffer(make([]byte, 0, 64*1024), 4*1024*1024)
		for scanner.Scan() {
			line := scanner.Text()
			if !strings.HasPrefix(line, "data: ") {
				continue
			}
			data := strings.TrimPrefix(line, "data: ")
			if data == "[DONE]" {
				return
			}
			var event anthropicStreamEvent
			if err := json.Unmarshal([]byte(data), &event); err != nil {
				// Skip malformed or unknown events rather than aborting.
				continue
			}
			switch event.Type {
			case "content_block_delta":
				if event.Delta.Text != "" {
					ch <- StreamChunk{ContentChunk: event.Delta.Text}
				}
			case "message_stop":
				ch <- StreamChunk{FinishReason: "stop"}
				return
			}
		}
		// NOTE(review): scanner.Err() is deliberately dropped — StreamChunk
		// has no error field; consumers only observe the channel closing.
	}()
	return ch, nil
}
// GenerateText runs StreamText to completion and returns the concatenated
// content of every chunk.
func (c *AnthropicClient) GenerateText(ctx context.Context, req StreamRequest) (string, error) {
	stream, err := c.StreamText(ctx, req)
	if err != nil {
		return "", err
	}
	return readAllChunks(stream), nil
}

View File

@@ -0,0 +1,145 @@
package llm
import (
	"context"
	"fmt"
	"io"
	"strings"
)
// Role identifies the author of a chat message.
type Role string

const (
	RoleSystem    Role = "system"
	RoleUser      Role = "user"
	RoleAssistant Role = "assistant"
	RoleTool      Role = "tool"
)

// Message is one provider-agnostic chat turn, optionally carrying tool calls,
// a tool-result linkage, or attached images.
type Message struct {
	Role       Role           `json:"role"`
	Content    string         `json:"content"`
	ToolCalls  []ToolCall     `json:"tool_calls,omitempty"`
	ToolCallID string         `json:"tool_call_id,omitempty"`
	Name       string         `json:"name,omitempty"`
	Images     []ImageContent `json:"images,omitempty"`
}

// ImageContent references an image either by URL or as inline data
// (base64-encoded when IsBase64 is set).
type ImageContent struct {
	Type     string `json:"type"`
	URL      string `json:"url,omitempty"`
	Data     string `json:"data,omitempty"`
	IsBase64 bool   `json:"isBase64,omitempty"`
}

// ToolCall is a model-requested invocation of a named tool with arguments.
type ToolCall struct {
	ID        string                 `json:"id"`
	Name      string                 `json:"name"`
	Arguments map[string]interface{} `json:"arguments"`
}

// Tool declares a callable tool: its name, description and argument schema.
type Tool struct {
	Name        string      `json:"name"`
	Description string      `json:"description"`
	Schema      interface{} `json:"schema"`
}

// StreamOptions carries generation parameters shared by all providers.
type StreamOptions struct {
	MaxTokens   int      `json:"max_tokens,omitempty"`
	Temperature float64  `json:"temperature,omitempty"`
	TopP        float64  `json:"top_p,omitempty"`
	StopWords   []string `json:"stop,omitempty"`
}

// StreamChunk is one increment of a streamed response: content text,
// tool-call deltas, or a finish reason on the final chunk.
type StreamChunk struct {
	ContentChunk  string     `json:"content_chunk,omitempty"`
	ToolCallChunk []ToolCall `json:"tool_call_chunk,omitempty"`
	FinishReason  string     `json:"finish_reason,omitempty"`
}

// StreamRequest is a provider-agnostic chat completion request.
type StreamRequest struct {
	Messages []Message     `json:"messages"`
	Tools    []Tool        `json:"tools,omitempty"`
	Options  StreamOptions `json:"options,omitempty"`
}

// Client is the provider-agnostic LLM interface: streaming and one-shot
// text generation plus identity accessors.
type Client interface {
	StreamText(ctx context.Context, req StreamRequest) (<-chan StreamChunk, error)
	GenerateText(ctx context.Context, req StreamRequest) (string, error)
	GetProviderID() string
	GetModelKey() string
}

// ProviderConfig selects and configures a concrete provider client.
type ProviderConfig struct {
	ProviderID    string `json:"providerId"`
	ModelKey      string `json:"key"`
	APIKey        string `json:"apiKey,omitempty"`
	BaseURL       string `json:"baseUrl,omitempty"`
	AgentAccessID string `json:"agentAccessId,omitempty"`
}
// NewClient constructs the provider-specific LLM client selected by
// cfg.ProviderID ("timeweb", "openai", "anthropic", "gemini"/"google").
// An unrecognized provider yields an error.
func NewClient(cfg ProviderConfig) (Client, error) {
	switch cfg.ProviderID {
	case "openai":
		return NewOpenAIClient(cfg)
	case "anthropic":
		return NewAnthropicClient(cfg)
	case "gemini", "google":
		return NewGeminiClient(cfg)
	case "timeweb":
		return NewTimewebClient(TimewebConfig{
			BaseURL:       cfg.BaseURL,
			AgentAccessID: cfg.AgentAccessID,
			APIKey:        cfg.APIKey,
			ModelKey:      cfg.ModelKey,
			ProxySource:   "gooseek",
		})
	}
	return nil, fmt.Errorf("unknown provider: %s", cfg.ProviderID)
}
// baseClient carries the provider/model identity shared by every concrete
// client via embedding.
type baseClient struct {
	providerID string
	modelKey   string
}

// GetProviderID returns the configured provider identifier.
// (The Get prefix is required by the Client interface.)
func (c *baseClient) GetProviderID() string {
	return c.providerID
}

// GetModelKey returns the configured model key.
func (c *baseClient) GetModelKey() string {
	return c.modelKey
}
// readAllChunks drains ch and concatenates every ContentChunk into a single
// string. A strings.Builder avoids the quadratic cost of repeated string
// concatenation on long streams.
func readAllChunks(ch <-chan StreamChunk) string {
	var b strings.Builder
	for chunk := range ch {
		b.WriteString(chunk.ContentChunk)
	}
	return b.String()
}
// streamReader adapts a StreamChunk channel to io.Reader; buffer holds the
// tail of a chunk that did not fit into the previous Read's destination.
type streamReader struct {
	ch     <-chan StreamChunk
	buffer []byte
}

// Read copies stream content into p, first draining any leftover bytes from
// the previous chunk, then pulling the next chunk from the channel. It
// returns io.EOF once the channel is closed and the buffer is empty.
func (r *streamReader) Read(p []byte) (n int, err error) {
	if len(r.buffer) > 0 {
		n = copy(p, r.buffer)
		r.buffer = r.buffer[n:]
		return n, nil
	}
	chunk, open := <-r.ch
	if !open {
		return 0, io.EOF
	}
	payload := []byte(chunk.ContentChunk)
	n = copy(p, payload)
	// Stash whatever did not fit for the next call.
	if n < len(payload) {
		r.buffer = payload[n:]
	}
	return n, nil
}

View File

@@ -0,0 +1,193 @@
package llm
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
)
// GeminiClient talks to the Google Gemini generateContent API over HTTP,
// streaming responses via server-sent events.
type GeminiClient struct {
	baseClient
	apiKey  string
	baseURL string
	client  *http.Client
}
// NewGeminiClient builds a client for the Gemini API. An empty cfg.BaseURL
// falls back to the public v1beta endpoint; a trailing slash is trimmed so
// path joining stays predictable.
func NewGeminiClient(cfg ProviderConfig) (*GeminiClient, error) {
	endpoint := cfg.BaseURL
	if endpoint == "" {
		endpoint = "https://generativelanguage.googleapis.com/v1beta"
	}
	c := &GeminiClient{
		baseClient: baseClient{providerID: cfg.ProviderID, modelKey: cfg.ModelKey},
		apiKey:     cfg.APIKey,
		baseURL:    strings.TrimSuffix(endpoint, "/"),
		client:     &http.Client{},
	}
	return c, nil
}
// geminiRequest is the wire format of a generateContent request body.
type geminiRequest struct {
	Contents          []geminiContent        `json:"contents"`
	SystemInstruction *geminiContent         `json:"systemInstruction,omitempty"`
	GenerationConfig  geminiGenerationConfig `json:"generationConfig,omitempty"`
	Tools             []geminiTool           `json:"tools,omitempty"`
}

// geminiContent is one turn ("user" or "model") composed of parts.
type geminiContent struct {
	Role  string       `json:"role,omitempty"`
	Parts []geminiPart `json:"parts"`
}

// geminiPart is a single text fragment of a content turn.
type geminiPart struct {
	Text string `json:"text,omitempty"`
}

// geminiGenerationConfig carries sampling parameters for a request.
type geminiGenerationConfig struct {
	MaxOutputTokens int     `json:"maxOutputTokens,omitempty"`
	Temperature     float64 `json:"temperature,omitempty"`
	TopP            float64 `json:"topP,omitempty"`
}

// geminiTool wraps function declarations exposed to the model.
type geminiTool struct {
	FunctionDeclarations []geminiFunctionDecl `json:"functionDeclarations,omitempty"`
}

// geminiFunctionDecl declares one callable function and its parameter schema.
type geminiFunctionDecl struct {
	Name        string      `json:"name"`
	Description string      `json:"description"`
	Parameters  interface{} `json:"parameters"`
}

// geminiStreamResponse is a decoded SSE payload from a streaming response;
// only the fields this client consumes are modeled.
type geminiStreamResponse struct {
	Candidates []struct {
		Content struct {
			Parts []struct {
				Text string `json:"text"`
			} `json:"parts"`
		} `json:"content"`
		FinishReason string `json:"finishReason,omitempty"`
	} `json:"candidates"`
}
// StreamText issues a streaming generateContent request and returns a channel
// of text deltas. System messages become the request's systemInstruction;
// assistant messages map to Gemini's "model" role and every other role to
// "user". The channel is closed when the stream ends.
func (c *GeminiClient) StreamText(ctx context.Context, req StreamRequest) (<-chan StreamChunk, error) {
	contents := make([]geminiContent, 0)
	var systemInstruction *geminiContent
	for _, m := range req.Messages {
		if m.Role == RoleSystem {
			systemInstruction = &geminiContent{
				Parts: []geminiPart{{Text: m.Content}},
			}
			continue
		}
		role := "user"
		if m.Role == RoleAssistant {
			role = "model"
		}
		contents = append(contents, geminiContent{
			Role:  role,
			Parts: []geminiPart{{Text: m.Content}},
		})
	}
	geminiReq := geminiRequest{
		Contents:          contents,
		SystemInstruction: systemInstruction,
		GenerationConfig: geminiGenerationConfig{
			MaxOutputTokens: req.Options.MaxTokens,
			Temperature:     req.Options.Temperature,
			TopP:            req.Options.TopP,
		},
	}
	if len(req.Tools) > 0 {
		decls := make([]geminiFunctionDecl, len(req.Tools))
		for i, t := range req.Tools {
			decls[i] = geminiFunctionDecl{
				Name:        t.Name,
				Description: t.Description,
				Parameters:  t.Schema,
			}
		}
		geminiReq.Tools = []geminiTool{{FunctionDeclarations: decls}}
	}
	body, err := json.Marshal(geminiReq)
	if err != nil {
		return nil, err
	}
	// NOTE(review): the API key travels in the query string, where proxies
	// and server logs may capture it; Gemini also accepts the x-goog-api-key
	// header — consider switching.
	url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s&alt=sse",
		c.baseURL, c.modelKey, c.apiKey)
	httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	httpReq.Header.Set("Content-Type", "application/json")
	resp, err := c.client.Do(httpReq)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		return nil, fmt.Errorf("gemini API error: %d - %s", resp.StatusCode, string(body))
	}
	ch := make(chan StreamChunk, 100)
	go func() {
		defer close(ch)
		defer resp.Body.Close()
		scanner := bufio.NewScanner(resp.Body)
		// bufio.Scanner's default 64 KiB token limit is too small for large
		// SSE data lines; without a bigger buffer the scan stops mid-stream
		// and silently truncates the response.
		scanner.Buffer(make([]byte, 0, 64*1024), 4*1024*1024)
		for scanner.Scan() {
			line := scanner.Text()
			if !strings.HasPrefix(line, "data: ") {
				continue
			}
			data := strings.TrimPrefix(line, "data: ")
			var response geminiStreamResponse
			if err := json.Unmarshal([]byte(data), &response); err != nil {
				// Ignore malformed or non-JSON SSE payloads.
				continue
			}
			if len(response.Candidates) > 0 {
				candidate := response.Candidates[0]
				for _, part := range candidate.Content.Parts {
					if part.Text != "" {
						ch <- StreamChunk{ContentChunk: part.Text}
					}
				}
				if candidate.FinishReason != "" {
					ch <- StreamChunk{FinishReason: candidate.FinishReason}
				}
			}
		}
		// NOTE(review): scanner.Err() is deliberately dropped — StreamChunk
		// has no error field; consumers only observe the channel closing.
	}()
	return ch, nil
}
// GenerateText runs StreamText to completion and returns the concatenated
// content of every chunk.
func (c *GeminiClient) GenerateText(ctx context.Context, req StreamRequest) (string, error) {
	stream, err := c.StreamText(ctx, req)
	if err != nil {
		return "", err
	}
	return readAllChunks(stream), nil
}

View File

@@ -0,0 +1,166 @@
package llm
import (
"context"
"encoding/json"
"errors"
"io"
"github.com/sashabaranov/go-openai"
)
// OpenAIClient talks to the OpenAI Chat Completions API (or any
// OpenAI-compatible endpoint configured via BaseURL) through the
// sashabaranov/go-openai SDK.
type OpenAIClient struct {
	baseClient
	client *openai.Client
}
// NewOpenAIClient builds an OpenAIClient from the provider configuration.
// An empty BaseURL keeps the SDK's default endpoint; a non-empty one
// redirects requests to an OpenAI-compatible server.
func NewOpenAIClient(cfg ProviderConfig) (*OpenAIClient, error) {
	sdkCfg := openai.DefaultConfig(cfg.APIKey)
	if cfg.BaseURL != "" {
		sdkCfg.BaseURL = cfg.BaseURL
	}
	c := &OpenAIClient{
		baseClient: baseClient{
			providerID: cfg.ProviderID,
			modelKey:   cfg.ModelKey,
		},
		client: openai.NewClientWithConfig(sdkCfg),
	}
	return c, nil
}
// StreamText streams a chat completion from OpenAI, forwarding content
// deltas on the returned channel as they arrive and emitting accumulated
// tool calls once the stream ends.
//
// Fixes over the previous version:
//   - Tool-call arguments arrive as JSON *string fragments* spread across
//     many deltas; unmarshaling each fragment independently fails on
//     partial JSON and silently drops the arguments. Fragments are now
//     buffered per tool-call index and parsed once at end of stream.
//   - tc.Index is a *int and may be absent; it is no longer dereferenced
//     without a nil check.
func (c *OpenAIClient) StreamText(ctx context.Context, req StreamRequest) (<-chan StreamChunk, error) {
	messages := make([]openai.ChatCompletionMessage, 0, len(req.Messages))
	for _, m := range req.Messages {
		msg := openai.ChatCompletionMessage{
			Role:    string(m.Role),
			Content: m.Content,
		}
		if m.Name != "" {
			msg.Name = m.Name
		}
		if m.ToolCallID != "" {
			msg.ToolCallID = m.ToolCallID
		}
		if len(m.ToolCalls) > 0 {
			msg.ToolCalls = make([]openai.ToolCall, len(m.ToolCalls))
			for i, tc := range m.ToolCalls {
				// Marshal of a map[string]interface{} holding JSON-decoded
				// values cannot fail, so the error is deliberately ignored.
				args, _ := json.Marshal(tc.Arguments)
				msg.ToolCalls[i] = openai.ToolCall{
					ID:   tc.ID,
					Type: openai.ToolTypeFunction,
					Function: openai.FunctionCall{
						Name:      tc.Name,
						Arguments: string(args),
					},
				}
			}
		}
		messages = append(messages, msg)
	}
	chatReq := openai.ChatCompletionRequest{
		Model:    c.modelKey,
		Messages: messages,
		Stream:   true,
	}
	if req.Options.MaxTokens > 0 {
		chatReq.MaxTokens = req.Options.MaxTokens
	}
	if req.Options.Temperature > 0 {
		chatReq.Temperature = float32(req.Options.Temperature)
	}
	if req.Options.TopP > 0 {
		chatReq.TopP = float32(req.Options.TopP)
	}
	if len(req.Tools) > 0 {
		chatReq.Tools = make([]openai.Tool, len(req.Tools))
		for i, t := range req.Tools {
			chatReq.Tools[i] = openai.Tool{
				Type: openai.ToolTypeFunction,
				Function: &openai.FunctionDefinition{
					Name:        t.Name,
					Description: t.Description,
					Parameters:  t.Schema,
				},
			}
		}
	}
	stream, err := c.client.CreateChatCompletionStream(ctx, chatReq)
	if err != nil {
		return nil, err
	}
	ch := make(chan StreamChunk, 100)
	go func() {
		defer close(ch)
		defer stream.Close()
		// Per-index accumulation: ID/name arrive on the first delta for an
		// index, argument JSON arrives as raw fragments on later deltas.
		toolCalls := make(map[int]*ToolCall)
		rawArgs := make(map[int]string)
		maxIdx := -1
		flush := func() {
			if len(toolCalls) == 0 {
				return
			}
			calls := make([]ToolCall, 0, len(toolCalls))
			// Emit in index order (the previous map iteration was random).
			for idx := 0; idx <= maxIdx; idx++ {
				tc, ok := toolCalls[idx]
				if !ok {
					continue
				}
				var args map[string]interface{}
				if err := json.Unmarshal([]byte(rawArgs[idx]), &args); err == nil && args != nil {
					tc.Arguments = args
				}
				calls = append(calls, *tc)
			}
			ch <- StreamChunk{ToolCallChunk: calls}
		}
		for {
			response, err := stream.Recv()
			if errors.Is(err, io.EOF) {
				flush()
				return
			}
			if err != nil {
				// Non-EOF stream error: the channel close signals termination.
				return
			}
			if len(response.Choices) == 0 {
				continue
			}
			delta := response.Choices[0].Delta
			if delta.Content != "" {
				ch <- StreamChunk{ContentChunk: delta.Content}
			}
			for _, tc := range delta.ToolCalls {
				idx := 0
				if tc.Index != nil {
					idx = *tc.Index
				}
				if idx > maxIdx {
					maxIdx = idx
				}
				entry, ok := toolCalls[idx]
				if !ok {
					entry = &ToolCall{Arguments: make(map[string]interface{})}
					toolCalls[idx] = entry
				}
				if tc.ID != "" {
					entry.ID = tc.ID
				}
				if tc.Function.Name != "" {
					entry.Name = tc.Function.Name
				}
				rawArgs[idx] += tc.Function.Arguments
			}
			if response.Choices[0].FinishReason != "" {
				ch <- StreamChunk{FinishReason: string(response.Choices[0].FinishReason)}
			}
		}
	}()
	return ch, nil
}
// GenerateText produces a complete response by consuming the streaming
// endpoint to exhaustion and joining the chunks.
func (c *OpenAIClient) GenerateText(ctx context.Context, req StreamRequest) (string, error) {
	stream, streamErr := c.StreamText(ctx, req)
	if streamErr != nil {
		return "", streamErr
	}
	return readAllChunks(stream), nil
}

View File

@@ -0,0 +1,229 @@
package llm
import (
"errors"
"sort"
"sync"
)
// ModelCapability labels a task category a model is suited for; specs
// advertise a set of these and the registry routes by them.
type ModelCapability string

const (
	CapReasoning   ModelCapability = "reasoning"
	CapCoding      ModelCapability = "coding"
	CapSearch      ModelCapability = "search"
	CapCreative    ModelCapability = "creative"
	CapFast        ModelCapability = "fast"
	CapLongContext ModelCapability = "long_context"
	CapVision      ModelCapability = "vision"
	CapMath        ModelCapability = "math"
	CapVideo       ModelCapability = "video"
	CapImage       ModelCapability = "image"
)

// ModelSpec describes a registered model: identity, advertised
// capabilities, pricing, context/output limits, and a selection priority
// (lower wins — see ModelRegistry.GetBest).
type ModelSpec struct {
	ID           string
	Provider     string
	Model        string
	Capabilities []ModelCapability
	CostPer1K    float64
	MaxContext   int
	Priority     int
	MaxTokens    int
	Description  string
}

// HasCapability reports whether cap appears in the spec's capability list.
func (m ModelSpec) HasCapability(cap ModelCapability) bool {
	for i := range m.Capabilities {
		if m.Capabilities[i] == cap {
			return true
		}
	}
	return false
}
// ModelRegistry is a concurrency-safe catalog mapping model IDs to their
// specs and to the Client instances that serve them. The two maps are
// kept in lockstep by Register/Unregister.
type ModelRegistry struct {
	models map[string]ModelSpec
	clients map[string]Client
	mu sync.RWMutex
}

// NewModelRegistry returns an empty registry ready for Register calls.
func NewModelRegistry() *ModelRegistry {
	return &ModelRegistry{
		models: make(map[string]ModelSpec),
		clients: make(map[string]Client),
	}
}
// Register adds (or replaces) a model spec and its serving client under
// spec.ID. Safe for concurrent use.
func (r *ModelRegistry) Register(spec ModelSpec, client Client) {
	r.mu.Lock()
	r.models[spec.ID] = spec
	r.clients[spec.ID] = client
	r.mu.Unlock()
}

// Unregister removes the model and client stored under id; unknown ids
// are a no-op. Safe for concurrent use.
func (r *ModelRegistry) Unregister(id string) {
	r.mu.Lock()
	delete(r.models, id)
	delete(r.clients, id)
	r.mu.Unlock()
}
// GetByID returns the client and spec registered under id, or an error
// naming whichever piece is missing.
func (r *ModelRegistry) GetByID(id string) (Client, ModelSpec, error) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	spec, haveSpec := r.models[id]
	if !haveSpec {
		return nil, ModelSpec{}, errors.New("model not found: " + id)
	}
	client, haveClient := r.clients[id]
	if !haveClient {
		return nil, ModelSpec{}, errors.New("client not found: " + id)
	}
	return client, spec, nil
}
// GetBest picks the registered model with the given capability that has
// the lowest Priority value, breaking ties by lowest CostPer1K. (Ties on
// both fields were, and remain, resolved arbitrarily, since map iteration
// order is unspecified.)
func (r *ModelRegistry) GetBest(cap ModelCapability) (Client, ModelSpec, error) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	var best ModelSpec
	found := false
	for _, spec := range r.models {
		if !spec.HasCapability(cap) {
			continue
		}
		if !found {
			best, found = spec, true
			continue
		}
		betterPriority := spec.Priority < best.Priority
		samePriorityCheaper := spec.Priority == best.Priority && spec.CostPer1K < best.CostPer1K
		if betterPriority || samePriorityCheaper {
			best = spec
		}
	}
	if !found {
		return nil, ModelSpec{}, errors.New("no model found with capability: " + string(cap))
	}
	return r.clients[best.ID], best, nil
}
// GetAllWithCapability returns every registered spec that advertises cap,
// ordered by ascending Priority (lower value = preferred).
func (r *ModelRegistry) GetAllWithCapability(cap ModelCapability) []ModelSpec {
	r.mu.RLock()
	defer r.mu.RUnlock()
	var matched []ModelSpec
	for _, spec := range r.models {
		if spec.HasCapability(cap) {
			matched = append(matched, spec)
		}
	}
	sort.Slice(matched, func(a, b int) bool {
		return matched[a].Priority < matched[b].Priority
	})
	return matched
}
// GetAll returns a snapshot of every registered spec in unspecified order.
func (r *ModelRegistry) GetAll() []ModelSpec {
	r.mu.RLock()
	defer r.mu.RUnlock()
	specs := make([]ModelSpec, 0, len(r.models))
	for _, s := range r.models {
		specs = append(specs, s)
	}
	return specs
}

// GetClient returns the client registered under id.
func (r *ModelRegistry) GetClient(id string) (Client, error) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	if c, ok := r.clients[id]; ok {
		return c, nil
	}
	return nil, errors.New("client not found: " + id)
}

// Count reports how many models are currently registered.
func (r *ModelRegistry) Count() int {
	r.mu.RLock()
	defer r.mu.RUnlock()
	return len(r.models)
}
// DefaultModels is the built-in catalog used to seed a ModelRegistry.
// Priority 1 entries are the preferred pick for their capabilities; ties
// are broken by CostPer1K (see GetBest). CostPer1K is USD per 1K tokens.
// NOTE(review): prices and context limits are vendor figures that drift
// over time — verify before relying on them for billing or routing.
var DefaultModels = []ModelSpec{
	{
		ID: "gpt-4o",
		Provider: "openai",
		Model: "gpt-4o",
		Capabilities: []ModelCapability{CapSearch, CapFast, CapVision, CapCoding, CapCreative},
		CostPer1K: 0.005,
		MaxContext: 128000,
		MaxTokens: 16384,
		Priority: 1,
		Description: "GPT-4o: fast multimodal model with search",
	},
	{
		ID: "gpt-4o-mini",
		Provider: "openai",
		Model: "gpt-4o-mini",
		Capabilities: []ModelCapability{CapFast, CapCoding},
		CostPer1K: 0.00015,
		MaxContext: 128000,
		MaxTokens: 16384,
		Priority: 2,
		Description: "GPT-4o Mini: cost-effective for simple tasks",
	},
	{
		ID: "claude-3-opus",
		Provider: "anthropic",
		Model: "claude-3-opus-20240229",
		Capabilities: []ModelCapability{CapReasoning, CapCoding, CapCreative, CapLongContext},
		CostPer1K: 0.015,
		MaxContext: 200000,
		MaxTokens: 4096,
		Priority: 1,
		Description: "Claude 3 Opus: best for complex reasoning and coding",
	},
	{
		// Note: the ID says "3" but the underlying model is 3.5 Sonnet.
		ID: "claude-3-sonnet",
		Provider: "anthropic",
		Model: "claude-3-5-sonnet-20241022",
		Capabilities: []ModelCapability{CapCoding, CapCreative, CapFast},
		CostPer1K: 0.003,
		MaxContext: 200000,
		MaxTokens: 8192,
		Priority: 1,
		Description: "Claude 3.5 Sonnet: balanced speed and quality",
	},
	{
		ID: "gemini-1.5-pro",
		Provider: "gemini",
		Model: "gemini-1.5-pro",
		Capabilities: []ModelCapability{CapLongContext, CapSearch, CapVision, CapMath},
		CostPer1K: 0.00125,
		MaxContext: 2000000,
		MaxTokens: 8192,
		Priority: 1,
		Description: "Gemini 1.5 Pro: best for long context and research",
	},
	{
		ID: "gemini-1.5-flash",
		Provider: "gemini",
		Model: "gemini-1.5-flash",
		Capabilities: []ModelCapability{CapFast, CapVision},
		CostPer1K: 0.000075,
		MaxContext: 1000000,
		MaxTokens: 8192,
		Priority: 2,
		Description: "Gemini 1.5 Flash: fastest for lightweight tasks",
	},
}

View File

@@ -0,0 +1,402 @@
package llm
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// TimewebClient calls a Timeweb Cloud AI agent through its
// OpenAI-compatible chat-completions endpoint.
type TimewebClient struct {
	baseClient
	httpClient *http.Client
	baseURL string
	agentAccessID string // agent identifier embedded in the request path
	apiKey string // sent as a Bearer token
	proxySource string // sent as the x-proxy-source header
}

// TimewebConfig is the constructor input for NewTimewebClient.
// AgentAccessID and APIKey are mandatory; BaseURL and ProxySource default
// to "https://api.timeweb.cloud" and "gooseek" respectively.
type TimewebConfig struct {
	ProviderID string
	ModelKey string
	BaseURL string
	AgentAccessID string
	APIKey string
	ProxySource string
}
// NewTimewebClient validates the config and builds a TimewebClient.
// It errors if AgentAccessID or APIKey is missing and fills in defaults
// for BaseURL ("https://api.timeweb.cloud") and ProxySource ("gooseek").
func NewTimewebClient(cfg TimewebConfig) (*TimewebClient, error) {
	if cfg.AgentAccessID == "" {
		return nil, errors.New("agent_access_id is required for Timeweb")
	}
	if cfg.APIKey == "" {
		return nil, errors.New("api_key is required for Timeweb")
	}
	base := cfg.BaseURL
	if base == "" {
		base = "https://api.timeweb.cloud"
	}
	source := cfg.ProxySource
	if source == "" {
		source = "gooseek"
	}
	tw := &TimewebClient{
		baseClient: baseClient{
			providerID: cfg.ProviderID,
			modelKey:   cfg.ModelKey,
		},
		// Generous timeout: chat completions can take minutes.
		httpClient:    &http.Client{Timeout: 120 * time.Second},
		baseURL:       base,
		agentAccessID: cfg.AgentAccessID,
		apiKey:        cfg.APIKey,
		proxySource:   source,
	}
	return tw, nil
}
// timewebChatRequest is the OpenAI-compatible chat-completion payload
// sent to the Timeweb agent endpoint.
type timewebChatRequest struct {
	Model string `json:"model,omitempty"`
	Messages []timewebMessage `json:"messages"`
	Stream bool `json:"stream,omitempty"`
	Temperature float64 `json:"temperature,omitempty"`
	MaxTokens int `json:"max_tokens,omitempty"`
	TopP float64 `json:"top_p,omitempty"`
	Tools []timewebTool `json:"tools,omitempty"`
	Stop []string `json:"stop,omitempty"`
}

// timewebMessage is one conversation turn. Content is interface{} so it
// can carry either a plain string or structured multi-part content.
type timewebMessage struct {
	Role string `json:"role"`
	Content interface{} `json:"content"`
	Name string `json:"name,omitempty"`
	ToolCalls []timewebToolCall `json:"tool_calls,omitempty"`
	ToolCallID string `json:"tool_call_id,omitempty"`
}

// timewebTool declares a callable function to the model (Type is always
// "function" in this codebase).
type timewebTool struct {
	Type string `json:"type"`
	Function timewebFunction `json:"function"`
}

// timewebFunction describes a tool's name, purpose, and JSON-schema
// parameters.
type timewebFunction struct {
	Name string `json:"name"`
	Description string `json:"description"`
	Parameters interface{} `json:"parameters"`
}

// timewebToolCall mirrors the OpenAI tool_call shape; Function.Arguments
// is a raw JSON string (possibly a fragment when streamed).
type timewebToolCall struct {
	ID string `json:"id"`
	Type string `json:"type"`
	Function struct {
		Name string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}

// timewebChatResponse is the non-streaming completion response.
type timewebChatResponse struct {
	ID string `json:"id"`
	Object string `json:"object"`
	Created int64 `json:"created"`
	Model string `json:"model"`
	Choices []struct {
		Index int `json:"index"`
		Message struct {
			Role string `json:"role"`
			Content string `json:"content"`
			ToolCalls []timewebToolCall `json:"tool_calls,omitempty"`
		} `json:"message"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
	Usage struct {
		PromptTokens int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens int `json:"total_tokens"`
	} `json:"usage"`
}

// timewebStreamResponse is one SSE event of a streaming completion;
// Delta carries the incremental piece of the assistant message.
type timewebStreamResponse struct {
	ID string `json:"id"`
	Object string `json:"object"`
	Created int64 `json:"created"`
	Model string `json:"model"`
	Choices []struct {
		Index int `json:"index"`
		Delta struct {
			Role string `json:"role,omitempty"`
			Content string `json:"content,omitempty"`
			ToolCalls []timewebToolCall `json:"tool_calls,omitempty"`
		} `json:"delta"`
		FinishReason string `json:"finish_reason,omitempty"`
	} `json:"choices"`
}
// StreamText streams a chat completion from the Timeweb agent endpoint
// over SSE, forwarding content deltas on the returned channel and
// emitting accumulated tool calls once the stream ends.
//
// Fixes over the previous version:
//   - Every tool-call delta was written into a hard-coded index 0, so
//     multiple (parallel) tool calls were merged into one. The wire
//     format has no index field, so the arrival of a new non-empty
//     tool-call ID now starts a new entry.
//   - Argument JSON arrives as string fragments across SSE events;
//     parsing each fragment independently fails on partial JSON and
//     drops the arguments. Fragments are now buffered and parsed once
//     at end of stream.
//   - Accumulated tool calls are also flushed on a non-EOF read error
//     instead of being silently discarded.
func (c *TimewebClient) StreamText(ctx context.Context, req StreamRequest) (<-chan StreamChunk, error) {
	messages := make([]timewebMessage, 0, len(req.Messages))
	for _, m := range req.Messages {
		msg := timewebMessage{
			Role:    string(m.Role),
			Content: m.Content,
		}
		if m.Name != "" {
			msg.Name = m.Name
		}
		if m.ToolCallID != "" {
			msg.ToolCallID = m.ToolCallID
		}
		if len(m.ToolCalls) > 0 {
			msg.ToolCalls = make([]timewebToolCall, len(m.ToolCalls))
			for i, tc := range m.ToolCalls {
				// Marshal of a JSON-decoded map cannot fail; error ignored.
				args, _ := json.Marshal(tc.Arguments)
				msg.ToolCalls[i] = timewebToolCall{
					ID:   tc.ID,
					Type: "function",
				}
				msg.ToolCalls[i].Function.Name = tc.Name
				msg.ToolCalls[i].Function.Arguments = string(args)
			}
		}
		messages = append(messages, msg)
	}
	chatReq := timewebChatRequest{
		Model:    c.modelKey,
		Messages: messages,
		Stream:   true,
	}
	if req.Options.MaxTokens > 0 {
		chatReq.MaxTokens = req.Options.MaxTokens
	}
	if req.Options.Temperature > 0 {
		chatReq.Temperature = req.Options.Temperature
	}
	if req.Options.TopP > 0 {
		chatReq.TopP = req.Options.TopP
	}
	if len(req.Options.StopWords) > 0 {
		chatReq.Stop = req.Options.StopWords
	}
	if len(req.Tools) > 0 {
		chatReq.Tools = make([]timewebTool, len(req.Tools))
		for i, t := range req.Tools {
			chatReq.Tools[i] = timewebTool{
				Type: "function",
				Function: timewebFunction{
					Name:        t.Name,
					Description: t.Description,
					Parameters:  t.Schema,
				},
			}
		}
	}
	body, err := json.Marshal(chatReq)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}
	url := fmt.Sprintf("%s/api/v1/cloud-ai/agents/%s/v1/chat/completions", c.baseURL, c.agentAccessID)
	httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
	httpReq.Header.Set("x-proxy-source", c.proxySource)
	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		defer resp.Body.Close()
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("Timeweb API error: status %d, body: %s", resp.StatusCode, string(body))
	}
	ch := make(chan StreamChunk, 100)
	go func() {
		defer close(ch)
		defer resp.Body.Close()
		// Accumulated tool calls and their raw argument fragments,
		// parallel slices indexed by arrival order.
		var calls []*ToolCall
		var rawArgs []string
		flush := func() {
			if len(calls) == 0 {
				return
			}
			out := make([]ToolCall, 0, len(calls))
			for i, tc := range calls {
				var args map[string]interface{}
				if err := json.Unmarshal([]byte(rawArgs[i]), &args); err == nil && args != nil {
					tc.Arguments = args
				}
				out = append(out, *tc)
			}
			ch <- StreamChunk{ToolCallChunk: out}
		}
		reader := bufio.NewReader(resp.Body)
		for {
			line, err := reader.ReadString('\n')
			if err != nil {
				// EOF or read error: emit whatever tool calls accumulated
				// so they are not lost on a truncated stream.
				flush()
				return
			}
			line = strings.TrimSpace(line)
			if !strings.HasPrefix(line, "data: ") {
				continue
			}
			data := strings.TrimPrefix(line, "data: ")
			if data == "[DONE]" {
				flush()
				return
			}
			var streamResp timewebStreamResponse
			if err := json.Unmarshal([]byte(data), &streamResp); err != nil {
				continue
			}
			if len(streamResp.Choices) == 0 {
				continue
			}
			delta := streamResp.Choices[0].Delta
			if delta.Content != "" {
				ch <- StreamChunk{ContentChunk: delta.Content}
			}
			for _, tc := range delta.ToolCalls {
				// A new (different, non-empty) ID marks the start of the
				// next tool call; continuation fragments carry an empty ID.
				// NOTE(review): assumes the provider does not repeat the
				// same ID on every fragment of a single call — confirm
				// against real Timeweb stream traces.
				startNew := len(calls) == 0 ||
					(tc.ID != "" && calls[len(calls)-1].ID != tc.ID)
				if startNew {
					calls = append(calls, &ToolCall{
						ID:        tc.ID,
						Name:      tc.Function.Name,
						Arguments: make(map[string]interface{}),
					})
					rawArgs = append(rawArgs, "")
				}
				last := len(calls) - 1
				if calls[last].Name == "" && tc.Function.Name != "" {
					calls[last].Name = tc.Function.Name
				}
				rawArgs[last] += tc.Function.Arguments
			}
			if streamResp.Choices[0].FinishReason != "" {
				ch <- StreamChunk{FinishReason: streamResp.Choices[0].FinishReason}
			}
		}
	}()
	return ch, nil
}
// GenerateText performs a single non-streaming chat completion against
// the Timeweb agent endpoint and returns the assistant message content.
//
// Fix: the message conversion now carries ToolCalls and honors
// StopWords, matching StreamText — previously tool-augmented
// conversation history was silently stripped from non-streaming requests.
func (c *TimewebClient) GenerateText(ctx context.Context, req StreamRequest) (string, error) {
	messages := make([]timewebMessage, 0, len(req.Messages))
	for _, m := range req.Messages {
		msg := timewebMessage{
			Role:    string(m.Role),
			Content: m.Content,
		}
		if m.Name != "" {
			msg.Name = m.Name
		}
		if m.ToolCallID != "" {
			msg.ToolCallID = m.ToolCallID
		}
		if len(m.ToolCalls) > 0 {
			msg.ToolCalls = make([]timewebToolCall, len(m.ToolCalls))
			for i, tc := range m.ToolCalls {
				// Marshal of a JSON-decoded map cannot fail; error ignored.
				args, _ := json.Marshal(tc.Arguments)
				msg.ToolCalls[i] = timewebToolCall{ID: tc.ID, Type: "function"}
				msg.ToolCalls[i].Function.Name = tc.Name
				msg.ToolCalls[i].Function.Arguments = string(args)
			}
		}
		messages = append(messages, msg)
	}
	chatReq := timewebChatRequest{
		Model:    c.modelKey,
		Messages: messages,
		Stream:   false,
	}
	if req.Options.MaxTokens > 0 {
		chatReq.MaxTokens = req.Options.MaxTokens
	}
	if req.Options.Temperature > 0 {
		chatReq.Temperature = req.Options.Temperature
	}
	if req.Options.TopP > 0 {
		chatReq.TopP = req.Options.TopP
	}
	if len(req.Options.StopWords) > 0 {
		chatReq.Stop = req.Options.StopWords
	}
	if len(req.Tools) > 0 {
		chatReq.Tools = make([]timewebTool, len(req.Tools))
		for i, t := range req.Tools {
			chatReq.Tools[i] = timewebTool{
				Type: "function",
				Function: timewebFunction{
					Name:        t.Name,
					Description: t.Description,
					Parameters:  t.Schema,
				},
			}
		}
	}
	body, err := json.Marshal(chatReq)
	if err != nil {
		return "", fmt.Errorf("failed to marshal request: %w", err)
	}
	url := fmt.Sprintf("%s/api/v1/cloud-ai/agents/%s/v1/chat/completions", c.baseURL, c.agentAccessID)
	httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
	httpReq.Header.Set("x-proxy-source", c.proxySource)
	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return "", fmt.Errorf("request failed: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return "", fmt.Errorf("Timeweb API error: status %d, body: %s", resp.StatusCode, string(body))
	}
	var chatResp timewebChatResponse
	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
		return "", fmt.Errorf("failed to decode response: %w", err)
	}
	if len(chatResp.Choices) == 0 {
		return "", errors.New("no choices in response")
	}
	return chatResp.Choices[0].Message.Content, nil
}

View File

@@ -0,0 +1,318 @@
package pages
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/gooseek/backend/internal/llm"
"github.com/gooseek/backend/internal/types"
"github.com/google/uuid"
)
// Page is a publishable article generated from a research thread,
// composed of ordered sections plus the sources it cites.
type Page struct {
	ID string `json:"id"`
	UserID string `json:"userId"`
	ThreadID string `json:"threadId,omitempty"` // originating chat thread, if any
	Title string `json:"title"`
	Subtitle string `json:"subtitle,omitempty"`
	Sections []PageSection `json:"sections"`
	Sources []PageSource `json:"sources"`
	Thumbnail string `json:"thumbnail,omitempty"`
	IsPublic bool `json:"isPublic"`
	ShareID string `json:"shareId,omitempty"` // public share token when IsPublic
	ViewCount int `json:"viewCount"`
	CreatedAt time.Time `json:"createdAt"`
	UpdatedAt time.Time `json:"updatedAt"`
}

// PageSection is one ordered block of the article; Type is currently
// always "text" (see parseStructure).
type PageSection struct {
	ID string `json:"id"`
	Type string `json:"type"`
	Title string `json:"title,omitempty"`
	Content string `json:"content"`
	ImageURL string `json:"imageUrl,omitempty"`
	Order int `json:"order"` // 1-based position within the page
}

// PageSource is a numbered citation entry ([1], [2], ...) shown with the page.
type PageSource struct {
	Index int `json:"index"` // 1-based citation number
	URL string `json:"url"`
	Title string `json:"title"`
	Domain string `json:"domain"` // derived from URL via extractDomain
	Favicon string `json:"favicon,omitempty"`
}
// PageGeneratorConfig configures page generation: the LLM used for
// structuring, the output locale, and optional article style/audience
// hints (defaulted to "informative"/"general" in buildStructurePrompt).
type PageGeneratorConfig struct {
	LLMClient llm.Client
	Locale string
	Style string
	Audience string
}

// PageGenerator produces Pages from research results and exports them
// to Markdown, HTML, or JSON.
type PageGenerator struct {
	cfg PageGeneratorConfig
}

// NewPageGenerator returns a PageGenerator with the given configuration.
func NewPageGenerator(cfg PageGeneratorConfig) *PageGenerator {
	return &PageGenerator{cfg: cfg}
}
// GenerateFromThread builds a shareable Page from a finished research
// thread: the LLM turns the query/answer/sources into a structured
// article which is parsed into sections, then up to 20 source links
// are attached as numbered citations.
func (g *PageGenerator) GenerateFromThread(ctx context.Context, query string, answer string, sources []types.Chunk) (*Page, error) {
	prompt := g.buildStructurePrompt(query, answer, sources)
	raw, err := g.cfg.LLMClient.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{
			{Role: "user", Content: prompt},
		},
	})
	if err != nil {
		return nil, fmt.Errorf("failed to generate structure: %w", err)
	}
	page := g.parseStructure(raw)
	page.ID = uuid.New().String()
	now := time.Now()
	page.CreatedAt = now
	page.UpdatedAt = now
	const maxSources = 20
	for i, src := range sources {
		if i >= maxSources {
			break
		}
		srcURL := src.Metadata["url"]
		page.Sources = append(page.Sources, PageSource{
			Index:  i + 1,
			URL:    srcURL,
			Title:  src.Metadata["title"],
			Domain: extractDomain(srcURL),
		})
	}
	return page, nil
}
// buildStructurePrompt assembles the article-generation prompt: topic,
// truncated answer, up to 15 truncated source excerpts, an optional
// language instruction, and the exact TITLE/SUBTITLE/SECTION output
// format that parseStructure expects on the way back.
func (g *PageGenerator) buildStructurePrompt(query, answer string, sources []types.Chunk) string {
	var sourcesText strings.Builder
	// Cap at 15 sources and 300 chars each to keep the prompt bounded.
	for i, s := range sources {
		if i >= 15 {
			break
		}
		sourcesText.WriteString(fmt.Sprintf("[%d] %s\n%s\n\n", i+1, s.Metadata["title"], truncate(s.Content, 300)))
	}
	langInstr := ""
	if g.cfg.Locale == "ru" {
		langInstr = "Write in Russian."
	}
	// Fall back to sensible defaults when config leaves these unset.
	style := g.cfg.Style
	if style == "" {
		style = "informative"
	}
	audience := g.cfg.Audience
	if audience == "" {
		audience = "general"
	}
	return fmt.Sprintf(`Create a well-structured article from this research.
Topic: %s
Research findings:
%s
Sources:
%s
%s
Style: %s
Target audience: %s
Generate the article in this exact format:
TITLE: [compelling title]
SUBTITLE: [brief subtitle]
SECTION: Introduction
[2-3 paragraphs introducing the topic]
SECTION: [Topic Name 1]
[detailed content with citations [1], [2], etc.]
SECTION: [Topic Name 2]
[detailed content with citations]
SECTION: [Topic Name 3]
[detailed content with citations]
SECTION: Conclusion
[summary and key takeaways]
SECTION: Key Points
- [bullet point 1]
- [bullet point 2]
- [bullet point 3]
Requirements:
- Use citations [1], [2], etc. throughout
- Make it comprehensive but readable
- Include specific facts and data
- Keep sections focused and well-organized`, query, truncate(answer, 2000), sourcesText.String(), langInstr, style, audience)
}
// parseStructure converts the LLM's plain-text article format
// (TITLE:/SUBTITLE:/SECTION: markers, see buildStructurePrompt) into a
// Page with ordered "text" sections. Lines before the first SECTION:
// marker that are not TITLE:/SUBTITLE: are discarded.
func (g *PageGenerator) parseStructure(text string) *Page {
	page := &Page{Sections: make([]PageSection, 0)}
	var (
		current *PageSection
		body    strings.Builder
		order   int
	)
	// closeSection finalizes the in-progress section, trimming its body.
	closeSection := func() {
		if current == nil {
			return
		}
		current.Content = strings.TrimSpace(body.String())
		page.Sections = append(page.Sections, *current)
		body.Reset()
	}
	for _, raw := range strings.Split(text, "\n") {
		line := strings.TrimSpace(raw)
		switch {
		case strings.HasPrefix(line, "TITLE:"):
			page.Title = strings.TrimSpace(strings.TrimPrefix(line, "TITLE:"))
		case strings.HasPrefix(line, "SUBTITLE:"):
			page.Subtitle = strings.TrimSpace(strings.TrimPrefix(line, "SUBTITLE:"))
		case strings.HasPrefix(line, "SECTION:"):
			closeSection()
			order++
			current = &PageSection{
				ID:    uuid.New().String(),
				Type:  "text",
				Title: strings.TrimSpace(strings.TrimPrefix(line, "SECTION:")),
				Order: order,
			}
		default:
			if current != nil {
				body.WriteString(line)
				body.WriteString("\n")
			}
		}
	}
	closeSection()
	return page
}
// ExportToMarkdown renders the page as a Markdown document: H1 title,
// optional italic subtitle, one H2 per section, and a numbered source
// list after a horizontal rule.
func (g *PageGenerator) ExportToMarkdown(page *Page) string {
	var out strings.Builder
	out.WriteString("# " + page.Title + "\n\n")
	if page.Subtitle != "" {
		out.WriteString("*" + page.Subtitle + "*\n\n")
	}
	for _, sec := range page.Sections {
		out.WriteString("## " + sec.Title + "\n\n")
		out.WriteString(sec.Content + "\n\n")
	}
	out.WriteString("---\n\n## Sources\n\n")
	for _, src := range page.Sources {
		fmt.Fprintf(&out, "%d. [%s](%s)\n", src.Index, src.Title, src.URL)
	}
	return out.String()
}
// ExportToHTML renders the page as a standalone styled HTML document.
//
// Fix: all dynamic text (titles, section content, source names/URLs) is
// now HTML-escaped. Page content originates from LLM output and crawled
// web pages, so raw interpolation of `<`, `&`, or quotes could break the
// markup or inject active content.
func (g *PageGenerator) ExportToHTML(page *Page) string {
	// Same escape set as html.EscapeString; built from strings to avoid a
	// new import in this file.
	escaper := strings.NewReplacer(
		"&", "&amp;",
		"<", "&lt;",
		">", "&gt;",
		`"`, "&#34;",
		"'", "&#39;",
	)
	esc := escaper.Replace
	var b strings.Builder
	b.WriteString("<!DOCTYPE html>\n<html>\n<head>\n")
	b.WriteString(fmt.Sprintf("<title>%s</title>\n", esc(page.Title)))
	b.WriteString("<style>\n")
	b.WriteString(`body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; line-height: 1.6; }
h1 { color: #1a1a1a; border-bottom: 2px solid #007bff; padding-bottom: 10px; }
h2 { color: #333; margin-top: 30px; }
.subtitle { color: #666; font-style: italic; margin-bottom: 30px; }
.sources { background: #f5f5f5; padding: 20px; border-radius: 8px; margin-top: 40px; }
.sources a { color: #007bff; text-decoration: none; }
.sources a:hover { text-decoration: underline; }
`)
	b.WriteString("</style>\n</head>\n<body>\n")
	b.WriteString(fmt.Sprintf("<h1>%s</h1>\n", esc(page.Title)))
	if page.Subtitle != "" {
		b.WriteString(fmt.Sprintf("<p class=\"subtitle\">%s</p>\n", esc(page.Subtitle)))
	}
	for _, section := range page.Sections {
		b.WriteString(fmt.Sprintf("<h2>%s</h2>\n", esc(section.Title)))
		// Blank-line-separated paragraphs; a leading "- " marks a bullet list.
		for _, p := range strings.Split(section.Content, "\n\n") {
			p = strings.TrimSpace(p)
			if p == "" {
				continue
			}
			if strings.HasPrefix(p, "- ") {
				b.WriteString("<ul>\n")
				for _, item := range strings.Split(p, "\n") {
					item = strings.TrimPrefix(item, "- ")
					b.WriteString(fmt.Sprintf("<li>%s</li>\n", esc(item)))
				}
				b.WriteString("</ul>\n")
			} else {
				b.WriteString(fmt.Sprintf("<p>%s</p>\n", esc(p)))
			}
		}
	}
	b.WriteString("<div class=\"sources\">\n<h3>Sources</h3>\n<ol>\n")
	for _, src := range page.Sources {
		b.WriteString(fmt.Sprintf("<li><a href=\"%s\" target=\"_blank\">%s</a> (%s)</li>\n", esc(src.URL), esc(src.Title), esc(src.Domain)))
	}
	b.WriteString("</ol>\n</div>\n")
	b.WriteString("</body>\n</html>")
	return b.String()
}
// ToJSON serializes the page as indented JSON.
func (g *PageGenerator) ToJSON(page *Page) (string, error) {
	raw, err := json.MarshalIndent(page, "", " ")
	if err != nil {
		return "", err
	}
	return string(raw), nil
}
// truncate shortens s to at most maxLen bytes, appending "..." when a cut
// was made. Strings that already fit are returned unchanged.
//
// Fix: the cut position backs up to a UTF-8 rune boundary so multi-byte
// characters (e.g. Cyrillic text handled elsewhere in this package) are
// never split in half, which previously produced invalid UTF-8.
func truncate(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	// UTF-8 continuation bytes have the form 10xxxxxx; step back past them
	// so the cut lands on the first byte of a rune.
	cut := maxLen
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "..."
}
// extractDomain reduces a URL to its bare host name for display: scheme
// and "www." prefixes are stripped, and anything after the host (path,
// query, fragment, or an explicit port) is dropped.
//
// Fix: the previous version only cut at "/", so "example.com?q=1",
// "example.com#frag", and "example.com:8080/x" leaked query, fragment,
// or port text into the displayed domain.
func extractDomain(url string) string {
	url = strings.TrimPrefix(url, "https://")
	url = strings.TrimPrefix(url, "http://")
	url = strings.TrimPrefix(url, "www.")
	// Cut at the first path, query, or fragment separator.
	if idx := strings.IndexAny(url, "/?#"); idx >= 0 {
		url = url[:idx]
	}
	// Drop an explicit port.
	if idx := strings.Index(url, ":"); idx >= 0 {
		url = url[:idx]
	}
	return url
}

View File

@@ -0,0 +1,507 @@
package podcast
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"github.com/gooseek/backend/internal/llm"
"github.com/google/uuid"
)
// PodcastType distinguishes the editorial formats the generator produces.
type PodcastType string

const (
	PodcastDaily PodcastType = "daily" // daily news digest
	PodcastWeekly PodcastType = "weekly" // weekly summary (see GenerateWeeklySummary)
	PodcastTopicDeep PodcastType = "topic_deep" // single-topic deep dive
	PodcastBreaking PodcastType = "breaking" // breaking-news special
)

// VoiceStyle is a named delivery style passed through to the TTS provider.
type VoiceStyle string

const (
	VoiceNeutral VoiceStyle = "neutral"
	VoiceEnthusiastic VoiceStyle = "enthusiastic"
	VoiceProfessional VoiceStyle = "professional"
	VoiceCasual VoiceStyle = "casual"
	VoiceStorytelling VoiceStyle = "storytelling"
)
// Podcast is a fully assembled episode: script, timed segments, source
// attribution, and (once synthesized) a link to the audio file.
type Podcast struct {
	ID string `json:"id"`
	Title string `json:"title"`
	Description string `json:"description"`
	Type PodcastType `json:"type"`
	Date time.Time `json:"date"`
	Duration int `json:"durationSeconds"` // target/total length in seconds
	AudioURL string `json:"audioUrl,omitempty"`
	Transcript string `json:"transcript"` // full spoken text of all segments
	Segments []PodcastSegment `json:"segments"`
	Topics []string `json:"topics"`
	Sources []Source `json:"sources"`
	Thumbnail string `json:"thumbnail,omitempty"`
	Status PodcastStatus `json:"status"`
	GeneratedAt time.Time `json:"generatedAt"`
	PublishedAt *time.Time `json:"publishedAt,omitempty"` // nil until published
	Locale string `json:"locale"`
	VoiceConfig VoiceConfig `json:"voiceConfig"`
}

// PodcastStatus tracks an episode through its generation lifecycle.
type PodcastStatus string

const (
	StatusDraft PodcastStatus = "draft" // script generated, no audio yet
	StatusGenerating PodcastStatus = "generating" // TTS in progress
	StatusReady PodcastStatus = "ready" // audio synthesized
	StatusPublished PodcastStatus = "published"
	StatusFailed PodcastStatus = "failed" // TTS or pipeline error
)

// PodcastSegment is one contiguous piece of the episode with its position
// on the timeline (StartTime/EndTime in seconds from episode start).
type PodcastSegment struct {
	ID string `json:"id"`
	Type string `json:"type"` // intro|news|analysis|transition|outro
	Title string `json:"title"`
	Content string `json:"content"` // text spoken during this segment
	Duration int `json:"durationSeconds"`
	StartTime int `json:"startTime"`
	EndTime int `json:"endTime"`
	Sources []Source `json:"sources,omitempty"`
	Highlights []string `json:"highlights,omitempty"`
}

// Source attributes a piece of the episode to its original publication.
type Source struct {
	Title string `json:"title"`
	URL string `json:"url"`
	Publisher string `json:"publisher"`
	Date string `json:"date,omitempty"`
}

// VoiceConfig selects the TTS provider/voice and prosody settings.
type VoiceConfig struct {
	Provider string `json:"provider"`
	VoiceID string `json:"voiceId"`
	Style VoiceStyle `json:"style"`
	Speed float64 `json:"speed"` // 1.0 = normal rate
	Pitch float64 `json:"pitch"` // 1.0 = normal pitch
	Language string `json:"language"`
}
// PodcastGenerator turns news items into podcast scripts via an LLM and,
// optionally, into audio via a pluggable TTS client.
type PodcastGenerator struct {
	llm llm.Client
	ttsClient TTSClient // may be nil; GenerateAudio then returns an error
	httpClient *http.Client
	config GeneratorConfig
}

// GeneratorConfig holds episode-generation tunables; zero values are
// replaced with defaults by NewPodcastGenerator.
type GeneratorConfig struct {
	DefaultDuration int // seconds; defaults to 300
	MaxDuration int // seconds; defaults to 1800
	DefaultVoice VoiceConfig
	OutputDir string // NOTE(review): not read by any method in view — confirm before removing
}

// TTSClient synthesizes speech for a transcript with the given voice.
type TTSClient interface {
	GenerateSpeech(ctx context.Context, text string, config VoiceConfig) ([]byte, error)
}
// NewPodcastGenerator builds a generator around the given LLM and TTS
// clients, filling any zero-valued config fields with defaults
// (5-minute target, 30-minute cap, a Russian ElevenLabs voice).
func NewPodcastGenerator(llmClient llm.Client, ttsClient TTSClient, cfg GeneratorConfig) *PodcastGenerator {
	if cfg.DefaultDuration == 0 {
		cfg.DefaultDuration = 300
	}
	if cfg.MaxDuration == 0 {
		cfg.MaxDuration = 1800
	}
	if cfg.DefaultVoice.Provider == "" {
		cfg.DefaultVoice = VoiceConfig{
			Provider: "elevenlabs",
			VoiceID:  "21m00Tcm4TlvDq8ikWAM",
			Style:    VoiceNeutral,
			Speed:    1.0,
			Pitch:    1.0,
			Language: "ru",
		}
	}
	gen := &PodcastGenerator{
		llm:        llmClient,
		ttsClient:  ttsClient,
		httpClient: &http.Client{Timeout: 60 * time.Second},
		config:     cfg,
	}
	return gen
}
// GenerateOptions controls one episode-generation run. Zero values are
// defaulted by GenerateDailyPodcast (Date -> now, Duration -> config
// default, Locale -> "ru").
type GenerateOptions struct {
	Type PodcastType
	Topics []string
	NewsItems []NewsItem
	Date time.Time
	Duration int // target length in seconds
	Locale string
	VoiceConfig *VoiceConfig // overrides the generator default when non-nil
	IncludeIntro bool // NOTE(review): set by callers but not read in the visible code — confirm
	IncludeOutro bool // NOTE(review): same as IncludeIntro
	PersonalizeFor string
}

// NewsItem is one input story for a digest episode.
type NewsItem struct {
	Title string `json:"title"`
	Summary string `json:"summary"`
	URL string `json:"url"`
	Source string `json:"source"`
	PublishedAt string `json:"publishedAt"`
	Topics []string `json:"topics"`
	Importance int `json:"importance"`
}
// GenerateDailyPodcast generates a complete draft episode (script only,
// no audio) for the given options, defaulting Date to now, Duration to
// the configured default, and Locale to "ru".
func (g *PodcastGenerator) GenerateDailyPodcast(ctx context.Context, opts GenerateOptions) (*Podcast, error) {
	// Fill in defaults for anything the caller left unset.
	if opts.Date.IsZero() {
		opts.Date = time.Now()
	}
	if opts.Duration == 0 {
		opts.Duration = g.config.DefaultDuration
	}
	if opts.Locale == "" {
		opts.Locale = "ru"
	}
	script, err := g.generateScript(ctx, opts)
	if err != nil {
		return nil, fmt.Errorf("failed to generate script: %w", err)
	}
	voice := g.config.DefaultVoice
	if opts.VoiceConfig != nil {
		voice = *opts.VoiceConfig
	}
	episode := &Podcast{
		ID:          uuid.New().String(),
		Title:       script.Title,
		Description: script.Description,
		Type:        opts.Type,
		Date:        opts.Date,
		Duration:    opts.Duration,
		Transcript:  script.FullText,
		Segments:    script.Segments,
		Topics:      opts.Topics,
		Sources:     script.Sources,
		Status:      StatusDraft,
		GeneratedAt: time.Now(),
		Locale:      opts.Locale,
		VoiceConfig: voice,
	}
	return episode, nil
}
// PodcastScript is the intermediate result of script generation, before
// a Podcast record is assembled: episode text plus per-segment timing.
type PodcastScript struct {
	Title string
	Description string
	FullText string // all segment contents joined with blank lines
	Segments []PodcastSegment
	Sources []Source
}
// generateScript asks the LLM for a structured episode script (JSON with
// title/description/segments) and converts it into a PodcastScript with
// estimated per-segment timings. If the LLM response cannot be parsed as
// JSON, it falls back to a deterministic template script instead of
// failing the episode.
func (g *PodcastGenerator) generateScript(ctx context.Context, opts GenerateOptions) (*PodcastScript, error) {
	locale := opts.Locale
	langInstruction := ""
	if locale == "ru" {
		langInstruction = "Generate the entire script in Russian language. Use natural Russian speech patterns."
	}
	// Marshal of NewsItem (plain JSON-safe fields) cannot fail; error ignored.
	newsJSON, _ := json.Marshal(opts.NewsItems)
	prompt := fmt.Sprintf(`Create a podcast script for a daily news digest.
Date: %s
Duration target: %d seconds (approximately %d minutes)
Topics: %v
%s
News items to cover:
%s
Create an engaging podcast script with these requirements:
1. Start with a catchy introduction greeting the audience
2. Cover the most important news first
3. Transition smoothly between stories
4. Add brief analysis or context where appropriate
5. End with a summary and sign-off
The script should sound natural when read aloud - use conversational language, not formal news anchor style.
Respond in JSON format:
{
"title": "Podcast title for this episode",
"description": "Brief episode description",
"segments": [
{
"type": "intro|news|analysis|transition|outro",
"title": "Segment title",
"content": "Full text to be spoken",
"highlights": ["Key point 1", "Key point 2"],
"sources": [{"title": "Source title", "url": "url", "publisher": "publisher"}]
}
]
}`, opts.Date.Format("2006-01-02"), opts.Duration, opts.Duration/60, opts.Topics, langInstruction, string(newsJSON))
	result, err := g.llm.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: "user", Content: prompt}},
	})
	if err != nil {
		return nil, err
	}
	// extractJSON (defined elsewhere in this package) pulls the first JSON
	// object out of the possibly chatty LLM response.
	jsonStr := extractJSON(result)
	var parsed struct {
		Title string `json:"title"`
		Description string `json:"description"`
		Segments []struct {
			Type string `json:"type"`
			Title string `json:"title"`
			Content string `json:"content"`
			Highlights []string `json:"highlights"`
			Sources []struct {
				Title string `json:"title"`
				URL string `json:"url"`
				Publisher string `json:"publisher"`
			} `json:"sources"`
		} `json:"segments"`
	}
	if err := json.Unmarshal([]byte(jsonStr), &parsed); err != nil {
		// Malformed LLM output: degrade to the template-based script.
		return g.generateDefaultScript(opts)
	}
	script := &PodcastScript{
		Title: parsed.Title,
		Description: parsed.Description,
		Segments: make([]PodcastSegment, 0),
		Sources: make([]Source, 0),
	}
	var fullTextBuilder strings.Builder
	currentTime := 0
	// Rough speech-rate heuristic used to estimate segment durations.
	avgWordsPerSecond := 2.5
	for i, seg := range parsed.Segments {
		wordCount := len(strings.Fields(seg.Content))
		segDuration := int(float64(wordCount) / avgWordsPerSecond)
		if segDuration < 10 {
			// Floor so even tiny segments occupy a visible timeline slot.
			segDuration = 10
		}
		segment := PodcastSegment{
			ID: uuid.New().String(),
			Type: seg.Type,
			Title: seg.Title,
			Content: seg.Content,
			Duration: segDuration,
			StartTime: currentTime,
			EndTime: currentTime + segDuration,
			Highlights: seg.Highlights,
		}
		// Sources are recorded both per-segment and episode-wide.
		for _, src := range seg.Sources {
			source := Source{
				Title: src.Title,
				URL: src.URL,
				Publisher: src.Publisher,
			}
			segment.Sources = append(segment.Sources, source)
			script.Sources = append(script.Sources, source)
		}
		script.Segments = append(script.Segments, segment)
		fullTextBuilder.WriteString(seg.Content)
		if i < len(parsed.Segments)-1 {
			fullTextBuilder.WriteString("\n\n")
		}
		currentTime += segDuration
	}
	script.FullText = fullTextBuilder.String()
	return script, nil
}
// generateDefaultScript builds a fallback script used when the LLM response
// cannot be parsed: a fixed Russian intro and outro wrapped around a plain
// "Title. Summary" concatenation of the supplied news items.
func (g *PodcastGenerator) generateDefaultScript(opts GenerateOptions) (*PodcastScript, error) {
	date := opts.Date.Format("2 January 2006")
	intro := fmt.Sprintf("Добрый день! С вами GooSeek Daily — ваш ежедневный подкаст с главными новостями. Сегодня %s, и вот что происходит в мире.", date)
	outro := "На этом всё на сегодня. Спасибо, что слушаете GooSeek Daily! Подписывайтесь на наш подкаст и до встречи завтра."

	// One paragraph per news item, blank-line separated.
	parts := make([]string, 0, len(opts.NewsItems))
	for _, item := range opts.NewsItems {
		parts = append(parts, fmt.Sprintf("%s. %s", item.Title, item.Summary))
	}
	newsText := strings.Join(parts, "\n\n")

	script := &PodcastScript{
		Title:       fmt.Sprintf("GooSeek Daily — %s", date),
		Description: "Ежедневный подкаст с главными новостями",
		FullText:    fmt.Sprintf("%s\n\n%s\n\n%s", intro, newsText, outro),
		Segments: []PodcastSegment{
			// Fixed 15s bookends; the news segment takes the remaining time.
			{ID: uuid.New().String(), Type: "intro", Title: "Вступление", Content: intro, Duration: 15},
			{ID: uuid.New().String(), Type: "news", Title: "Новости", Content: newsText, Duration: opts.Duration - 30},
			{ID: uuid.New().String(), Type: "outro", Title: "Завершение", Content: outro, Duration: 15},
		},
	}
	return script, nil
}
// GenerateAudio synthesizes the podcast transcript through the configured TTS
// client, tracking the podcast status (generating -> ready/failed) as a side
// effect on the passed-in podcast.
func (g *PodcastGenerator) GenerateAudio(ctx context.Context, podcast *Podcast) ([]byte, error) {
	if g.ttsClient == nil {
		return nil, fmt.Errorf("TTS client not configured")
	}
	podcast.Status = StatusGenerating
	audio, err := g.ttsClient.GenerateSpeech(ctx, podcast.Transcript, podcast.VoiceConfig)
	if err == nil {
		podcast.Status = StatusReady
		return audio, nil
	}
	podcast.Status = StatusFailed
	return nil, fmt.Errorf("failed to generate audio: %w", err)
}
// GenerateWeeklySummary produces a ~15-minute weekly digest podcast from the
// given news items by delegating to GenerateDailyPodcast.
func (g *PodcastGenerator) GenerateWeeklySummary(ctx context.Context, weeklyNews []NewsItem, locale string) (*Podcast, error) {
	opts := GenerateOptions{
		Type:         PodcastWeekly,
		NewsItems:    weeklyNews,
		Duration:     900, // 15 minutes
		Locale:       locale,
		IncludeIntro: true,
		IncludeOutro: true,
	}
	return g.GenerateDailyPodcast(ctx, opts)
}

// GenerateTopicDeepDive produces a ~10-minute single-topic podcast built from
// the provided articles, again via GenerateDailyPodcast.
func (g *PodcastGenerator) GenerateTopicDeepDive(ctx context.Context, topic string, articles []NewsItem, locale string) (*Podcast, error) {
	opts := GenerateOptions{
		Type:         PodcastTopicDeep,
		Topics:       []string{topic},
		NewsItems:    articles,
		Duration:     600, // 10 minutes
		Locale:       locale,
		IncludeIntro: true,
		IncludeOutro: true,
	}
	return g.GenerateDailyPodcast(ctx, opts)
}
// extractJSON returns the first brace-balanced JSON object found in text, or
// "{}" when no opening brace exists or the braces never balance. Brace
// characters inside JSON strings are not treated specially; for LLM output
// that has been good enough in practice.
func extractJSON(text string) string {
	start := strings.IndexByte(text, '{')
	if start < 0 {
		return "{}"
	}
	depth := 0
	for i := start; i < len(text); i++ {
		switch text[i] {
		case '{':
			depth++
		case '}':
			depth--
			if depth == 0 {
				return text[start : i+1]
			}
		}
	}
	return "{}"
}
// ToJSON serializes the podcast to its JSON representation.
func (p *Podcast) ToJSON() ([]byte, error) {
	data, err := json.Marshal(p)
	if err != nil {
		return nil, err
	}
	return data, nil
}

// ParsePodcast decodes a podcast previously serialized with ToJSON.
func ParsePodcast(data []byte) (*Podcast, error) {
	podcast := new(Podcast)
	if err := json.Unmarshal(data, podcast); err != nil {
		return nil, err
	}
	return podcast, nil
}
// ElevenLabsTTS calls the ElevenLabs text-to-speech HTTP API.
type ElevenLabsTTS struct {
	apiKey     string
	httpClient *http.Client
	baseURL    string
}

// NewElevenLabsTTS builds an ElevenLabs client. The generous 120s timeout
// accommodates long synthesis requests.
func NewElevenLabsTTS(apiKey string) *ElevenLabsTTS {
	t := &ElevenLabsTTS{apiKey: apiKey}
	t.baseURL = "https://api.elevenlabs.io/v1"
	t.httpClient = &http.Client{Timeout: 120 * time.Second}
	return t
}
// GenerateSpeech synthesizes text via the ElevenLabs text-to-speech REST API
// and returns the raw audio bytes (MP3, per the Accept header below).
// config.VoiceID selects the voice; empty falls back to a default voice ID
// (presumably ElevenLabs' stock "Rachel" voice — TODO confirm).
func (t *ElevenLabsTTS) GenerateSpeech(ctx context.Context, text string, config VoiceConfig) ([]byte, error) {
	voiceID := config.VoiceID
	if voiceID == "" {
		voiceID = "21m00Tcm4TlvDq8ikWAM"
	}
	url := fmt.Sprintf("%s/text-to-speech/%s", t.baseURL, voiceID)
	// Multilingual model with mid-range stability/style settings.
	body := map[string]interface{}{
		"text":     text,
		"model_id": "eleven_multilingual_v2",
		"voice_settings": map[string]interface{}{
			"stability":         0.5,
			"similarity_boost":  0.75,
			"style":             0.5,
			"use_speaker_boost": true,
		},
	}
	// Marshal of a map of JSON-safe literals cannot fail, hence the ignored error.
	bodyJSON, _ := json.Marshal(body)
	req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(string(bodyJSON)))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("xi-api-key", t.apiKey)
	req.Header.Set("Accept", "audio/mpeg")
	resp, err := t.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("ElevenLabs API error: %d", resp.StatusCode)
	}
	// NOTE(review): this manual read loop breaks on ANY error, including a
	// mid-stream network failure, and then returns the partial audio with a
	// nil error. io.ReadAll(resp.Body) would surface non-EOF errors; adopting
	// it requires adding "io" to this file's import block (not visible in
	// this chunk).
	var audioData []byte
	buf := make([]byte, 32*1024)
	for {
		n, err := resp.Body.Read(buf)
		if n > 0 {
			audioData = append(audioData, buf[:n]...)
		}
		if err != nil {
			break
		}
	}
	return audioData, nil
}
// DummyTTS is a no-op TTS backend for environments without a speech API key.
type DummyTTS struct{}

// GenerateSpeech satisfies the TTS interface but always returns empty audio
// (a non-nil empty slice) and no error.
func (t *DummyTTS) GenerateSpeech(ctx context.Context, text string, config VoiceConfig) ([]byte, error) {
	return []byte{}, nil
}

View File

@@ -0,0 +1,50 @@
package prompts
import "strings"
// GetClassifierPrompt returns the system prompt for the query-classification
// step (standalone rewrite, skip-search detection, topic/type/engine hints).
// The model is instructed to emit a bare JSON object, which the caller
// extracts with a regex.
// NOTE(review): locale is currently unused here — presumably kept for
// signature parity with the other prompt builders; confirm before removing.
func GetClassifierPrompt(locale, detectedLang string) string {
	langInstruction := "Respond in the same language as the user's query."
	if detectedLang == "ru" {
		langInstruction = "The user is writing in Russian. Process accordingly."
	}
	// The raw literals below are the prompt body verbatim; langInstruction is
	// spliced between them. Do not add comments inside the literals — they
	// would become part of the prompt.
	return strings.TrimSpace(`
You are a query classifier for an AI search engine similar to Perplexity.
Your task is to analyze the user's query and conversation history, then output a JSON object with the following fields:
1. "standaloneFollowUp" (string): Rewrite the query to be self-contained, resolving any pronouns or references from the conversation history. If the query is already standalone, return it as-is.
2. "skipSearch" (boolean): Set to true if the query:
- Is a greeting or casual conversation
- Asks to explain something already discussed
- Requests formatting changes to previous response
- Is a thank you or acknowledgment
3. "topics" (array of strings): Key topics or entities mentioned in the query.
4. "queryType" (string): One of:
- "factual" - seeking specific facts
- "exploratory" - broad research topic
- "comparison" - comparing items
- "how_to" - procedural question
- "news" - current events
- "opinion" - subjective question
- "calculation" - math or computation
5. "engines" (array of strings): Suggested search engines based on query type.
` + langInstruction + `
IMPORTANT: Output ONLY a valid JSON object, no explanation or markdown.
Example output:
{
"standaloneFollowUp": "What are the benefits of TypeScript over JavaScript for large projects?",
"skipSearch": false,
"topics": ["TypeScript", "JavaScript", "programming"],
"queryType": "comparison",
"engines": ["google", "duckduckgo"]
}
`)
}

View File

@@ -0,0 +1,127 @@
package prompts
import (
"fmt"
"strings"
)
// ResearcherConfig parameterizes the research-agent system prompt.
type ResearcherConfig struct {
	AvailableActions string // pre-rendered tool/action descriptions (see GetAvailableActionsDescription)
	Mode             string // "speed", "balanced", or "quality"
	Iteration        int    // zero-based current iteration; rendered 1-based in the prompt
	MaxIterations    int
	Locale           string
	DetectedLanguage string // "ru" switches search-query language guidance
	IsArticleSummary bool   // true for "Summary: <url>" digest requests
}
// GetResearcherPrompt assembles the system prompt for the research agent from
// fixed sections plus mode, iteration, language, and task context taken from
// cfg. The exact wording is part of the model contract — edit with care.
func GetResearcherPrompt(cfg ResearcherConfig) string {
	var sb strings.Builder
	// Role and capabilities.
	sb.WriteString("You are a research agent for GooSeek, an AI search engine.\n\n")
	sb.WriteString("## Your Role\n\n")
	sb.WriteString("You gather information to answer user queries by:\n")
	sb.WriteString("1. Searching the web for relevant information\n")
	sb.WriteString("2. Scraping specific pages for detailed content\n")
	sb.WriteString("3. Deciding when you have enough information\n\n")
	sb.WriteString("## Available Actions\n\n")
	sb.WriteString(cfg.AvailableActions)
	sb.WriteString("\n\n")
	// Iteration counter is shown 1-based to the model.
	sb.WriteString("## Progress\n\n")
	sb.WriteString(fmt.Sprintf("Current iteration: %d / %d\n\n", cfg.Iteration+1, cfg.MaxIterations))
	// Mode-specific effort budget; unknown modes get no extra section.
	switch cfg.Mode {
	case "speed":
		sb.WriteString("## Speed Mode\n\n")
		sb.WriteString("- Perform ONE search and call done\n")
		sb.WriteString("- Do NOT scrape pages\n")
		sb.WriteString("- Use snippets from search results\n\n")
	case "balanced":
		sb.WriteString("## Balanced Mode\n\n")
		sb.WriteString("- Perform 1-3 searches\n")
		sb.WriteString("- Scrape top 3-5 relevant pages\n")
		sb.WriteString("- Balance depth vs. speed\n\n")
	case "quality":
		sb.WriteString("## Quality Mode\n\n")
		sb.WriteString("- Perform multiple searches with different queries\n")
		sb.WriteString("- Scrape 10-15 relevant pages\n")
		sb.WriteString("- Verify information across sources\n")
		sb.WriteString("- Be thorough and comprehensive\n\n")
	}
	// Extra guidance for "Summary: <url>" digest requests.
	if cfg.IsArticleSummary {
		sb.WriteString("## Article Summary Task (Perplexity Discover-style)\n\n")
		sb.WriteString("The user requested an article summary (Summary: <url>). This is a multi-source digest request.\n\n")
		sb.WriteString("**Your goals:**\n")
		sb.WriteString("1. The main article is already pre-scraped and will be in context\n")
		sb.WriteString("2. Search for 3-5 related sources that provide context\n")
		sb.WriteString("3. Look for: related news, background, analysis, reactions\n")
		sb.WriteString("4. Use news categories: `news`, `science` engines\n")
		sb.WriteString("5. Max 5 additional sources (article itself is [1])\n\n")
		sb.WriteString("**Search strategy:**\n")
		sb.WriteString("- Extract key entities/topics from article title\n")
		sb.WriteString("- Search for recent news on those topics\n")
		sb.WriteString("- Find expert opinions or analysis\n")
		sb.WriteString("- Look for official statements if relevant\n\n")
	}
	// For Russian users, search queries should be formulated in Russian.
	if cfg.DetectedLanguage == "ru" {
		sb.WriteString("## Language\n\n")
		sb.WriteString("Пользователь пишет на русском. Формулируй поисковые запросы на русском языке.\n\n")
	}
	sb.WriteString("## Instructions\n\n")
	sb.WriteString("1. Analyze the user's query and conversation history\n")
	sb.WriteString("2. Plan what information you need to gather\n")
	sb.WriteString("3. Execute actions to gather that information\n")
	sb.WriteString("4. Call 'done' when you have sufficient information\n\n")
	sb.WriteString("## Important Rules\n\n")
	sb.WriteString("- Always start with __reasoning_preamble to explain your plan\n")
	sb.WriteString("- Formulate specific, targeted search queries\n")
	sb.WriteString("- Avoid redundant searches\n")
	sb.WriteString("- Call 'done' when information is sufficient\n")
	sb.WriteString("- Don't exceed the iteration limit\n\n")
	sb.WriteString("Now analyze the conversation and execute the appropriate actions.")
	return sb.String()
}
// GetAvailableActionsDescription returns the fixed action/tool catalog that is
// embedded into the researcher prompt (via ResearcherConfig.AvailableActions).
// The literal below is prompt text verbatim; do not add comments inside it.
func GetAvailableActionsDescription() string {
	return strings.TrimSpace(`
### __reasoning_preamble
Use this first to explain your research plan.
Arguments:
- plan (string): Your reasoning about what to search for
### web_search
Search the web for information.
Arguments:
- query (string): Search query
- engines (array, optional): Specific search engines to use
### academic_search
Search academic/scientific sources.
Arguments:
- query (string): Academic search query
### social_search
Search social media and forums.
Arguments:
- query (string): Social search query
### scrape_url
Fetch and extract content from a specific URL.
Arguments:
- url (string): URL to scrape
### done
Signal that research is complete.
Arguments:
- reason (string): Why research is sufficient
`)
}

View File

@@ -0,0 +1,146 @@
package prompts
import (
"fmt"
"strings"
)
// WriterConfig parameterizes the answer-writer system prompt.
type WriterConfig struct {
	Context            string // formatted search results the answer must cite
	SystemInstructions string // user-provided custom instructions; "None" is treated as unset
	Mode               string // "speed", "balanced", or "quality"
	Locale             string
	MemoryContext      string // optional user-memory summary
	AnswerMode         string // topical mode, e.g. "academic", "news"; "" or "standard" adds nothing
	ResponsePrefs      *ResponsePrefs
	DetectedLanguage   string // "ru" forces a Russian-only answer
	IsArticleSummary   bool   // article-digest formatting rules
	LearningMode       bool   // teaching-style explanations
}

// ResponsePrefs carries optional per-user formatting preferences; empty
// fields are omitted from the prompt.
type ResponsePrefs struct {
	Format string
	Length string
	Tone   string
}
// GetWriterPrompt assembles the system prompt for the final answer writer:
// core citation/formatting rules, then optional sections (article summary,
// mode, answer mode, preferences, memory, custom instructions, learning
// mode), then the search-result context. The exact wording is part of the
// model contract — edit with care.
func GetWriterPrompt(cfg WriterConfig) string {
	var sb strings.Builder
	sb.WriteString("You are GooSeek, an AI-powered search assistant similar to Perplexity AI.\n\n")
	// Language override comes first so it dominates all later instructions.
	if cfg.DetectedLanguage == "ru" {
		sb.WriteString("ВАЖНО: Пользователь пишет на русском языке. Отвечай ТОЛЬКО на русском языке.\n\n")
	}
	sb.WriteString("## Core Instructions\n\n")
	sb.WriteString("1. **Always cite sources** using [number] format, e.g., [1], [2]. Citations must reference the search results provided.\n")
	sb.WriteString("2. **Be comprehensive** but concise. Provide thorough answers with key information.\n")
	sb.WriteString("3. **Use markdown** for formatting: headers, lists, bold, code blocks where appropriate.\n")
	sb.WriteString("4. **Be objective** and factual. Present information neutrally.\n")
	sb.WriteString("5. **Acknowledge limitations** if search results are insufficient.\n\n")
	// Digest-style structure for "Summary: <url>" requests.
	if cfg.IsArticleSummary {
		sb.WriteString("## Article Summary Mode (Perplexity-style Digest)\n\n")
		sb.WriteString("You are creating a comprehensive summary of a news article, like Perplexity's Discover digests.\n\n")
		sb.WriteString("**Structure your response as:**\n")
		sb.WriteString("1. **Headline summary** (1-2 sentences capturing the essence)\n")
		sb.WriteString("2. **Key points** with citations [1], [2], etc.\n")
		sb.WriteString("3. **Context and background** from related sources\n")
		sb.WriteString("4. **Analysis/implications** if relevant\n")
		sb.WriteString("5. **Related questions** the reader might have (as > quoted lines)\n\n")
		sb.WriteString("**Rules:**\n")
		sb.WriteString("- Always cite sources [1], [2], etc.\n")
		sb.WriteString("- First source [1] is usually the main article\n")
		sb.WriteString("- Add context from other sources [2], [3], etc.\n")
		sb.WriteString("- End with 2-3 follow-up questions prefixed with >\n")
		sb.WriteString("- Write in the user's language (Russian if they use Russian)\n\n")
	}
	// Depth/verbosity budget per mode; unknown modes add no section.
	switch cfg.Mode {
	case "speed":
		sb.WriteString("## Speed Mode\n\n")
		sb.WriteString("Provide a quick, focused answer. Be concise (2-3 paragraphs max).\n")
		sb.WriteString("Prioritize the most relevant information.\n\n")
	case "balanced":
		sb.WriteString("## Balanced Mode\n\n")
		sb.WriteString("Provide a well-rounded answer with moderate detail.\n")
		sb.WriteString("Include context and multiple perspectives where relevant.\n\n")
	case "quality":
		sb.WriteString("## Quality Mode\n\n")
		sb.WriteString("Provide a comprehensive, in-depth analysis.\n")
		sb.WriteString("Include detailed explanations, examples, and nuances.\n")
		sb.WriteString("Cover multiple aspects of the topic.\n\n")
	}
	if cfg.AnswerMode != "" && cfg.AnswerMode != "standard" {
		sb.WriteString(fmt.Sprintf("## Answer Mode: %s\n\n", cfg.AnswerMode))
		sb.WriteString(getAnswerModeInstructions(cfg.AnswerMode))
	}
	// Optional per-user formatting preferences; only non-empty fields render.
	if cfg.ResponsePrefs != nil {
		sb.WriteString("## Response Preferences\n\n")
		if cfg.ResponsePrefs.Format != "" {
			sb.WriteString(fmt.Sprintf("- Format: %s\n", cfg.ResponsePrefs.Format))
		}
		if cfg.ResponsePrefs.Length != "" {
			sb.WriteString(fmt.Sprintf("- Length: %s\n", cfg.ResponsePrefs.Length))
		}
		if cfg.ResponsePrefs.Tone != "" {
			sb.WriteString(fmt.Sprintf("- Tone: %s\n", cfg.ResponsePrefs.Tone))
		}
		sb.WriteString("\n")
	}
	if cfg.MemoryContext != "" {
		sb.WriteString("## User Context (from memory)\n\n")
		sb.WriteString(cfg.MemoryContext)
		sb.WriteString("\n\n")
	}
	// "None" is the sentinel some callers pass for "no custom instructions".
	if cfg.SystemInstructions != "" && cfg.SystemInstructions != "None" {
		sb.WriteString("## Custom Instructions\n\n")
		sb.WriteString(cfg.SystemInstructions)
		sb.WriteString("\n\n")
	}
	if cfg.LearningMode {
		sb.WriteString("## Learning Mode\n\n")
		sb.WriteString("The user is in learning mode. Explain concepts thoroughly.\n")
		sb.WriteString("Use analogies, examples, and break down complex topics.\n")
		sb.WriteString("Ask clarifying questions if the topic is ambiguous.\n\n")
	}
	sb.WriteString("## Search Results\n\n")
	sb.WriteString(cfg.Context)
	sb.WriteString("\n\n")
	sb.WriteString("## Citation Rules\n\n")
	sb.WriteString("- Use [1], [2], etc. to cite sources from the search results\n")
	sb.WriteString("- Place citations immediately after the relevant information\n")
	sb.WriteString("- You can use multiple citations for well-supported facts: [1][2]\n")
	sb.WriteString("- Do not cite widgets or generated content\n")
	sb.WriteString("- If no relevant source exists, don't make up citations\n\n")
	sb.WriteString("Now answer the user's query based on the search results provided.")
	return sb.String()
}
// answerModeInstructions maps an answer mode to the extra prompt text appended
// by GetWriterPrompt. Package-level so the map is built once instead of on
// every call (the previous version reallocated it per invocation).
var answerModeInstructions = map[string]string{
	"academic": "Focus on scholarly sources, research papers, and academic perspectives. Use formal language and cite peer-reviewed sources when available.\n\n",
	"writing":  "Help with writing tasks. Provide suggestions for structure, style, and content. Be creative and helpful.\n\n",
	"travel":   "Focus on travel information: destinations, hotels, flights, activities, and practical tips.\n\n",
	"finance":  "Provide financial information carefully. Include disclaimers about not being financial advice. Focus on factual data.\n\n",
	"health":   "Provide health information from reliable sources. Always recommend consulting healthcare professionals. Be cautious and accurate.\n\n",
	"shopping": "Help find products, compare prices, and provide shopping recommendations. Include product features and user reviews.\n\n",
	"news":     "Focus on current events and recent news. Provide multiple perspectives and fact-check information.\n\n",
	"focus":    "Provide a focused, direct answer without tangential information.\n\n",
}

// getAnswerModeInstructions returns the prompt addition for the given answer
// mode, or "" for unknown modes.
func getAnswerModeInstructions(mode string) string {
	return answerModeInstructions[mode]
}

View File

@@ -0,0 +1,215 @@
package search
import (
"context"
"regexp"
"strconv"
"strings"
"github.com/gooseek/backend/internal/types"
)
// MediaSearchOptions caps how many image and video results SearchMedia
// collects per leg.
type MediaSearchOptions struct {
	MaxImages int
	MaxVideos int
}

// MediaSearchResult bundles the merged media search output. Both slices are
// always non-nil (possibly empty), so JSON encodes them as [].
type MediaSearchResult struct {
	Images []types.ImageData `json:"images"`
	Videos []types.VideoData `json:"videos"`
}
// SearchMedia runs image and video searches concurrently and merges the
// results. Each leg is best-effort: a failed leg contributes an empty list
// rather than failing the call, so the returned error is always nil.
func (c *SearXNGClient) SearchMedia(ctx context.Context, query string, opts *MediaSearchOptions) (*MediaSearchResult, error) {
	if opts == nil {
		opts = &MediaSearchOptions{MaxImages: 8, MaxVideos: 6}
	}
	imageCh := make(chan []types.ImageData, 1)
	videoCh := make(chan []types.VideoData, 1)
	go func() {
		// Errors intentionally dropped: media results are optional extras.
		imgs, _ := c.searchImages(ctx, query, opts.MaxImages)
		imageCh <- imgs
	}()
	go func() {
		vids, _ := c.searchVideos(ctx, query, opts.MaxVideos)
		videoCh <- vids
	}()
	out := &MediaSearchResult{
		Images: <-imageCh,
		Videos: <-videoCh,
	}
	// Normalize nil (failed leg) to empty slices for JSON friendliness.
	if out.Images == nil {
		out.Images = make([]types.ImageData, 0)
	}
	if out.Videos == nil {
		out.Videos = make([]types.VideoData, 0)
	}
	return out, nil
}
// searchImages queries the "images" category and returns up to max
// deduplicated image entries, preferring full-size URLs over thumbnails.
func (c *SearXNGClient) searchImages(ctx context.Context, query string, max int) ([]types.ImageData, error) {
	resp, err := c.Search(ctx, query, &SearchOptions{
		Categories: []string{"images"},
		PageNo:     1,
	})
	if err != nil {
		return nil, err
	}
	// First non-empty of the candidate URL fields.
	pickSrc := func(candidates ...string) string {
		for _, cand := range candidates {
			if cand != "" {
				return cand
			}
		}
		return ""
	}
	images := make([]types.ImageData, 0, max)
	seen := make(map[string]bool)
	for _, item := range resp.Results {
		if len(images) >= max {
			break
		}
		src := pickSrc(item.ImgSrc, item.ThumbnailSrc, item.Thumbnail)
		if src == "" || seen[src] {
			continue
		}
		seen[src] = true
		images = append(images, types.ImageData{
			URL:       src,
			Title:     item.Title,
			Source:    extractDomain(item.URL),
			SourceURL: item.URL,
		})
	}
	return images, nil
}
// searchVideos queries the "videos" category and returns up to max
// deduplicated video entries with the hosting platform detected from the URL.
func (c *SearXNGClient) searchVideos(ctx context.Context, query string, max int) ([]types.VideoData, error) {
	resp, err := c.Search(ctx, query, &SearchOptions{
		Categories: []string{"videos"},
		PageNo:     1,
	})
	if err != nil {
		return nil, err
	}
	videos := make([]types.VideoData, 0, max)
	seen := make(map[string]struct{})
	for _, item := range resp.Results {
		if len(videos) >= max {
			break
		}
		if _, dup := seen[item.URL]; dup {
			continue
		}
		seen[item.URL] = struct{}{}
		videos = append(videos, types.VideoData{
			Title:     item.Title,
			URL:       item.URL,
			Thumbnail: item.Thumbnail,
			Duration:  toInt(item.Duration),
			Views:     toInt(item.Views),
			Author:    item.Author,
			Platform:  detectVideoPlatform(item.URL),
			EmbedURL:  item.IframeSrc,
		})
	}
	return videos, nil
}
// Host patterns for the video platforms the UI renders specially.
var (
	youtubePattern = regexp.MustCompile(`youtube\.com|youtu\.be`)
	rutubePattern  = regexp.MustCompile(`rutube\.ru`)
	vkPattern      = regexp.MustCompile(`vk\.com`)
	dzenPattern    = regexp.MustCompile(`dzen\.ru`)
)

// detectVideoPlatform classifies a video URL (case-insensitively) as
// "youtube", "rutube", "vk", "dzen", or "other".
func detectVideoPlatform(url string) string {
	lower := strings.ToLower(url)
	switch {
	case youtubePattern.MatchString(lower):
		return "youtube"
	case rutubePattern.MatchString(lower):
		return "rutube"
	case vkPattern.MatchString(lower):
		return "vk"
	case dzenPattern.MatchString(lower):
		return "dzen"
	default:
		return "other"
	}
}
// extractDomain reduces a URL to its bare host: scheme and "www." prefixes
// are stripped, and anything from the first "/" onward is dropped. Plain
// string surgery — no net/url parsing, so query-only or malformed inputs
// pass through mostly unchanged.
func extractDomain(rawURL string) string {
	for _, prefix := range []string{"https://", "http://", "www."} {
		rawURL = strings.TrimPrefix(rawURL, prefix)
	}
	if idx := strings.Index(rawURL, "/"); idx > 0 {
		return rawURL[:idx]
	}
	return rawURL
}
// toInt best-effort converts loosely-typed JSON values (int, int64, float64,
// numeric string) to int. Anything else — including nil and unparsable
// strings — yields 0. Floats truncate toward zero.
func toInt(v interface{}) int {
	switch val := v.(type) {
	case nil:
		return 0
	case int:
		return val
	case int64:
		return int(val)
	case float64:
		return int(val)
	case string:
		n, err := strconv.Atoi(val)
		if err != nil {
			return 0
		}
		return n
	}
	return 0
}

View File

@@ -0,0 +1,163 @@
package search
import (
"math"
"sort"
"strings"
"unicode"
"github.com/gooseek/backend/internal/types"
)
type RankedItem struct {
Chunk types.Chunk
Score float64
}
func RerankBM25(chunks []types.Chunk, query string, topK int) []types.Chunk {
if len(chunks) == 0 {
return chunks
}
queryTerms := tokenize(query)
if len(queryTerms) == 0 {
return chunks
}
df := make(map[string]int)
for _, chunk := range chunks {
seen := make(map[string]bool)
terms := tokenize(chunk.Content + " " + chunk.Metadata["title"])
for _, term := range terms {
if !seen[term] {
df[term]++
seen[term] = true
}
}
}
avgDocLen := 0.0
for _, chunk := range chunks {
avgDocLen += float64(len(tokenize(chunk.Content)))
}
avgDocLen /= float64(len(chunks))
k1 := 1.5
b := 0.75
n := float64(len(chunks))
ranked := make([]RankedItem, len(chunks))
for i, chunk := range chunks {
docTerms := tokenize(chunk.Content + " " + chunk.Metadata["title"])
docLen := float64(len(docTerms))
tf := make(map[string]int)
for _, term := range docTerms {
tf[term]++
}
score := 0.0
for _, qterm := range queryTerms {
if termFreq, ok := tf[qterm]; ok {
docFreq := float64(df[qterm])
idf := math.Log((n - docFreq + 0.5) / (docFreq + 0.5))
if idf < 0 {
idf = 0
}
tfNorm := float64(termFreq) * (k1 + 1) /
(float64(termFreq) + k1*(1-b+b*docLen/avgDocLen))
score += idf * tfNorm
}
}
if title, ok := chunk.Metadata["title"]; ok {
titleLower := strings.ToLower(title)
for _, qterm := range queryTerms {
if strings.Contains(titleLower, qterm) {
score += 2.0
}
}
}
ranked[i] = RankedItem{Chunk: chunk, Score: score}
}
sort.Slice(ranked, func(i, j int) bool {
return ranked[i].Score > ranked[j].Score
})
if topK > len(ranked) {
topK = len(ranked)
}
result := make([]types.Chunk, topK)
for i := 0; i < topK; i++ {
result[i] = ranked[i].Chunk
}
return result
}
func tokenize(text string) []string {
text = strings.ToLower(text)
var tokens []string
var current strings.Builder
for _, r := range text {
if unicode.IsLetter(r) || unicode.IsDigit(r) {
current.WriteRune(r)
} else {
if current.Len() >= 2 {
tokens = append(tokens, current.String())
}
current.Reset()
}
}
if current.Len() >= 2 {
tokens = append(tokens, current.String())
}
return tokens
}
func EstimateQueryComplexity(query string) float64 {
terms := tokenize(query)
complexity := float64(len(terms)) / 5.0
if strings.Contains(query, "?") {
complexity += 0.2
}
if strings.Contains(query, " и ") || strings.Contains(query, " или ") {
complexity += 0.3
}
if complexity > 1.0 {
complexity = 1.0
}
return complexity
}
func ComputeAdaptiveTopK(totalResults int, complexity float64, mode string) int {
baseK := 15
switch mode {
case "speed":
baseK = 10
case "balanced":
baseK = 20
case "quality":
baseK = 30
}
adaptiveK := int(float64(baseK) * (1 + complexity*0.5))
if adaptiveK > totalResults {
adaptiveK = totalResults
}
return adaptiveK
}

View File

@@ -0,0 +1,177 @@
package search
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/gooseek/backend/internal/types"
"github.com/gooseek/backend/pkg/config"
)
// SearXNGClient queries a primary SearXNG instance with ordered fallbacks.
type SearXNGClient struct {
	primaryURL   string
	fallbackURLs []string
	client       *http.Client
	timeout      time.Duration
}

// NewSearXNGClient builds a client from the service configuration; the
// HTTP client timeout mirrors cfg.SearchTimeout.
func NewSearXNGClient(cfg *config.Config) *SearXNGClient {
	c := &SearXNGClient{}
	c.primaryURL = cfg.SearXNGURL
	c.fallbackURLs = cfg.SearXNGFallbackURL
	c.timeout = cfg.SearchTimeout
	c.client = &http.Client{Timeout: cfg.SearchTimeout}
	return c
}

// SearchOptions narrows a SearXNG query; zero values are omitted from the
// request.
type SearchOptions struct {
	Engines    []string
	Categories []string
	PageNo     int
	Language   string
}
// Search tries each configured SearXNG instance in order (primary first,
// then fallbacks) and returns the first successful response. The last
// failure is wrapped into the error when every instance fails.
func (c *SearXNGClient) Search(ctx context.Context, query string, opts *SearchOptions) (*types.SearchResponse, error) {
	candidates := c.buildCandidates()
	if len(candidates) == 0 {
		return nil, fmt.Errorf("no SearXNG URLs configured")
	}
	var lastErr error
	for _, base := range candidates {
		resp, err := c.searchWithURL(ctx, base, query, opts)
		if err != nil {
			lastErr = err
			continue
		}
		return resp, nil
	}
	return nil, fmt.Errorf("all SearXNG instances failed: %w", lastErr)
}
// buildCandidates returns the ordered, deduplicated list of base URLs to try:
// the primary first, then the fallbacks. Entries without a scheme get one
// added — http for the primary, https for fallbacks (presumably the primary
// is an in-cluster address and fallbacks are public instances — confirm).
func (c *SearXNGClient) buildCandidates() []string {
	out := make([]string, 0, 1+len(c.fallbackURLs))
	if c.primaryURL != "" {
		primary := strings.TrimSuffix(c.primaryURL, "/")
		if !strings.HasPrefix(primary, "http") {
			primary = "http://" + primary
		}
		out = append(out, primary)
	}
	for _, raw := range c.fallbackURLs {
		fb := strings.TrimSpace(raw)
		if fb == "" {
			continue
		}
		fb = strings.TrimSuffix(fb, "/")
		if !strings.HasPrefix(fb, "http") {
			fb = "https://" + fb
		}
		if !contains(out, fb) {
			out = append(out, fb)
		}
	}
	return out
}
// searchWithURL performs one GET /search?format=json request against a single
// SearXNG instance and decodes the result list.
func (c *SearXNGClient) searchWithURL(ctx context.Context, baseURL, query string, opts *SearchOptions) (*types.SearchResponse, error) {
	q := url.Values{}
	q.Set("format", "json")
	q.Set("q", query)
	if opts != nil {
		if len(opts.Engines) > 0 {
			q.Set("engines", strings.Join(opts.Engines, ","))
		}
		if len(opts.Categories) > 0 {
			q.Set("categories", strings.Join(opts.Categories, ","))
		}
		if opts.PageNo > 0 {
			q.Set("pageno", fmt.Sprintf("%d", opts.PageNo))
		}
		if opts.Language != "" {
			q.Set("language", opts.Language)
		}
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/search?"+q.Encode(), nil)
	if err != nil {
		return nil, err
	}
	resp, err := c.client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("SearXNG returned status %d", resp.StatusCode)
	}
	var payload struct {
		Results     []types.SearchResult `json:"results"`
		Suggestions []string             `json:"suggestions"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
		return nil, err
	}
	return &types.SearchResponse{
		Results:     payload.Results,
		Suggestions: payload.Suggestions,
	}, nil
}
// URL patterns used to categorize search results for widget rendering.
var (
	productPattern   = regexp.MustCompile(`ozon\.ru/product|wildberries\.ru/catalog/\d|aliexpress\.(ru|com)/item|market\.yandex`)
	videoPattern     = regexp.MustCompile(`rutube\.ru/video|vk\.com/video|vk\.com/clip|youtube\.com/watch|youtu\.be|dzen\.ru/video`)
	vkProfilePattern = regexp.MustCompile(`vk\.com/[a-zA-Z0-9_.]+$`)
	tgProfilePattern = regexp.MustCompile(`t\.me/[a-zA-Z0-9_]+$`)
)

// CategorizeResult assigns a display category to a search result from its
// (lowercased) URL and a few metadata hints. Order matters: product beats
// video beats profile beats image, with "article" as the fallback.
func CategorizeResult(result *types.SearchResult) types.ContentCategory {
	lower := strings.ToLower(result.URL)
	switch {
	case productPattern.MatchString(lower):
		return types.CategoryProduct
	case videoPattern.MatchString(lower), result.IframeSrc != "", result.Category == "videos":
		return types.CategoryVideo
	case tgProfilePattern.MatchString(lower):
		return types.CategoryProfile
	// The !video guard is redundant after the video case above but kept to
	// mirror the original intent: vk video/clip URLs are never profiles.
	case vkProfilePattern.MatchString(lower) && !videoPattern.MatchString(lower):
		return types.CategoryProfile
	case result.ImgSrc != "" && result.Category == "images":
		return types.CategoryImage
	default:
		return types.CategoryArticle
	}
}
// contains reports whether item occurs in slice (nil-safe linear scan).
func contains(slice []string, item string) bool {
	for i := range slice {
		if slice[i] == item {
			return true
		}
	}
	return false
}

View File

@@ -0,0 +1,183 @@
package session
import (
"encoding/json"
"sync"
"github.com/gooseek/backend/internal/types"
"github.com/google/uuid"
)
// EventType discriminates the kinds of events a session can emit.
type EventType string

const (
	EventData  EventType = "data"  // streamed payload (blocks, chunks, patches)
	EventEnd   EventType = "end"   // stream finished normally
	EventError EventType = "error" // stream aborted with an error
)

// Event pairs an event type with its payload.
// NOTE(review): Emit delivers (EventType, data) pairs directly to
// subscribers; this struct appears unused within this chunk — confirm
// external callers before removing.
type Event struct {
	Type EventType   `json:"type"`
	Data interface{} `json:"data"`
}

// Subscriber receives every event emitted on a session.
type Subscriber func(event EventType, data interface{})

// Session is an in-memory pub/sub hub for one streaming answer: it stores
// emitted blocks by ID and fans events out to subscribers. All fields are
// guarded by mu.
type Session struct {
	id          string
	blocks      map[string]*types.Block
	subscribers []Subscriber
	mu          sync.RWMutex
	closed      bool
}
// NewSession creates an empty session with a fresh random ID.
func NewSession() *Session {
	s := &Session{}
	s.id = uuid.New().String()
	s.blocks = make(map[string]*types.Block)
	s.subscribers = make([]Subscriber, 0)
	return s
}

// ID returns the session's immutable identifier.
func (s *Session) ID() string {
	return s.id
}
// Subscribe registers fn to receive session events and returns an
// unsubscribe function.
//
// Bug fixed: the previous implementation spliced the subscriber slice at a
// captured index. Once any earlier subscriber had unsubscribed, later
// indices shifted and the closure removed the WRONG subscriber. Func values
// are not comparable in Go, so instead of splicing we overwrite the slot
// with a no-op — indices then stay stable for the session's lifetime (the
// slice only ever grows, which is acceptable for short-lived sessions).
func (s *Session) Subscribe(fn Subscriber) func() {
	s.mu.Lock()
	s.subscribers = append(s.subscribers, fn)
	idx := len(s.subscribers) - 1
	s.mu.Unlock()
	return func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		// subscribers may be nil after Close/RemoveAllListeners; calling
		// unsubscribe twice is also safe.
		if s.subscribers == nil || idx >= len(s.subscribers) {
			return
		}
		s.subscribers[idx] = func(EventType, interface{}) {}
	}
}
// Emit delivers an event to every current subscriber. The subscriber list is
// snapshotted under the read lock and callbacks run outside it, so a callback
// may safely re-enter the session. Closed sessions drop events silently.
func (s *Session) Emit(eventType EventType, data interface{}) {
	s.mu.RLock()
	closed := s.closed
	snapshot := append([]Subscriber(nil), s.subscribers...)
	s.mu.RUnlock()
	if closed {
		return
	}
	for _, notify := range snapshot {
		notify(eventType, data)
	}
}
// EmitBlock stores the block by ID and broadcasts it to subscribers.
func (s *Session) EmitBlock(block *types.Block) {
	s.mu.Lock()
	s.blocks[block.ID] = block
	s.mu.Unlock()
	payload := map[string]interface{}{
		"type":  "block",
		"block": block,
	}
	s.Emit(EventData, payload)
}

// UpdateBlock applies replace-style patches to a stored block and broadcasts
// the patch list. Unknown block IDs are silently ignored.
func (s *Session) UpdateBlock(blockID string, patches []Patch) {
	s.mu.Lock()
	block, found := s.blocks[blockID]
	if found {
		for _, p := range patches {
			applyPatch(block, p)
		}
	}
	s.mu.Unlock()
	if !found {
		return
	}
	s.Emit(EventData, map[string]interface{}{
		"type":    "updateBlock",
		"blockId": blockID,
		"patch":   patches,
	})
}
// EmitTextChunk streams an incremental text fragment for a block.
func (s *Session) EmitTextChunk(blockID, chunk string) {
	payload := map[string]interface{}{
		"type":    "textChunk",
		"blockId": blockID,
		"chunk":   chunk,
	}
	s.Emit(EventData, payload)
}

// EmitResearchComplete signals that the research phase has finished.
func (s *Session) EmitResearchComplete() {
	s.Emit(EventData, map[string]interface{}{"type": "researchComplete"})
}

// EmitEnd signals normal end of the message stream: a messageEnd data event
// followed by the terminal end event.
func (s *Session) EmitEnd() {
	s.Emit(EventData, map[string]interface{}{"type": "messageEnd"})
	s.Emit(EventEnd, nil)
}

// EmitError broadcasts err both as an in-band data event (for the UI) and as
// a terminal error event.
func (s *Session) EmitError(err error) {
	msg := err.Error()
	s.Emit(EventData, map[string]interface{}{
		"type": "error",
		"data": msg,
	})
	s.Emit(EventError, map[string]interface{}{"data": msg})
}
// GetBlock returns the stored block for id, or nil when absent.
func (s *Session) GetBlock(id string) *types.Block {
	s.mu.RLock()
	block := s.blocks[id]
	s.mu.RUnlock()
	return block
}

// Close marks the session finished and drops all subscribers; subsequent
// Emit calls become no-ops.
func (s *Session) Close() {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.closed = true
	s.subscribers = nil
}

// RemoveAllListeners detaches every subscriber without closing the session.
func (s *Session) RemoveAllListeners() {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.subscribers = nil
}
// Patch is a minimal JSON-Patch-like operation applied to a stored block.
// Only Op == "replace" is honored by applyPatch; other ops are ignored.
type Patch struct {
	Op    string      `json:"op"`
	Path  string      `json:"path"`
	Value interface{} `json:"value"`
}
// applyPatch mutates block in place for a "replace" patch on one of the two
// supported paths; everything else is a no-op. The type assertions expect
// in-process values (presumably patches are constructed by this service, not
// decoded from JSON — confirm callers; decoded values would arrive as
// []interface{} and be skipped).
func applyPatch(block *types.Block, patch Patch) {
	if patch.Op != "replace" {
		return
	}
	if patch.Path == "/data" {
		block.Data = patch.Value
		return
	}
	if patch.Path != "/data/subSteps" {
		return
	}
	rd, ok := block.Data.(types.ResearchData)
	if !ok {
		return
	}
	steps, ok := patch.Value.([]types.ResearchSubStep)
	if !ok {
		return
	}
	rd.SubSteps = steps
	block.Data = rd
}
// MarshalEvent JSON-encodes an event payload for transport (e.g. SSE frames).
func MarshalEvent(data interface{}) ([]byte, error) {
	return json.Marshal(data)
}

View File

@@ -0,0 +1,102 @@
package types
type BlockType string
const (
BlockTypeText BlockType = "text"
BlockTypeResearch BlockType = "research"
BlockTypeSource BlockType = "source"
BlockTypeWidget BlockType = "widget"
BlockTypeThinking BlockType = "thinking"
)
type Block struct {
ID string `json:"id"`
Type BlockType `json:"type"`
Data interface{} `json:"data"`
}
type TextBlock struct {
ID string `json:"id"`
Type string `json:"type"`
Data string `json:"data"`
}
type ResearchBlock struct {
ID string `json:"id"`
Type string `json:"type"`
Data ResearchData `json:"data"`
}
type ResearchData struct {
SubSteps []ResearchSubStep `json:"subSteps"`
}
type ResearchSubStep struct {
ID string `json:"id"`
Type string `json:"type"`
Reasoning string `json:"reasoning,omitempty"`
Searching []string `json:"searching,omitempty"`
Reading []Chunk `json:"reading,omitempty"`
}
type SourceBlock struct {
ID string `json:"id"`
Type string `json:"type"`
Data []Chunk `json:"data"`
}
type WidgetBlock struct {
ID string `json:"id"`
Type string `json:"type"`
Data WidgetData `json:"data"`
}
type WidgetData struct {
WidgetType string `json:"widgetType"`
Params interface{} `json:"params"`
}
type StreamEvent struct {
Type string `json:"type"`
Block *Block `json:"block,omitempty"`
BlockID string `json:"blockId,omitempty"`
Chunk string `json:"chunk,omitempty"`
Patch interface{} `json:"patch,omitempty"`
Data interface{} `json:"data,omitempty"`
}
func NewTextBlock(id, content string) *Block {
return &Block{
ID: id,
Type: BlockTypeText,
Data: content,
}
}
func NewResearchBlock(id string) *Block {
return &Block{
ID: id,
Type: BlockTypeResearch,
Data: ResearchData{SubSteps: []ResearchSubStep{}},
}
}
func NewSourceBlock(id string, chunks []Chunk) *Block {
return &Block{
ID: id,
Type: BlockTypeSource,
Data: chunks,
}
}
// NewWidgetBlock builds a widget Block of the given widgetType with
// its type-specific params payload.
func NewWidgetBlock(id, widgetType string, params interface{}) *Block {
	payload := WidgetData{WidgetType: widgetType, Params: params}
	return &Block{ID: id, Type: BlockTypeWidget, Data: payload}
}

View File

@@ -0,0 +1,75 @@
package types

// Chunk is one unit of retrieved content plus descriptive metadata
// (e.g. "title", "url" — see SearchResult.ToChunk for the keys used).
type Chunk struct {
	Content  string            `json:"content"`
	Metadata map[string]string `json:"metadata,omitempty"`
}

// SearchResult is one normalized hit returned by a search engine.
// Only Title and URL are always serialized; the remaining fields are
// engine- and category-specific.
type SearchResult struct {
	Title         string  `json:"title"`
	URL           string  `json:"url"`
	Content       string  `json:"content,omitempty"`
	Thumbnail     string  `json:"thumbnail,omitempty"`
	ImgSrc        string  `json:"img_src,omitempty"`
	ThumbnailSrc  string  `json:"thumbnail_src,omitempty"`
	IframeSrc     string  `json:"iframe_src,omitempty"`
	Author        string  `json:"author,omitempty"`
	PublishedDate string  `json:"publishedDate,omitempty"`
	Engine        string  `json:"engine,omitempty"`
	Category      string  `json:"category,omitempty"`
	Score         float64 `json:"score,omitempty"`
	Price         string  `json:"price,omitempty"`
	Currency      string  `json:"currency,omitempty"`
	// Duration and Views are interface{} — presumably because
	// upstream engines return either numbers or strings.
	// NOTE(review): confirm against the engine adapters.
	Duration interface{} `json:"duration,omitempty"`
	Views    interface{} `json:"views,omitempty"`
}

// SearchResponse bundles search hits with optional query suggestions.
type SearchResponse struct {
	Results     []SearchResult `json:"results"`
	Suggestions []string       `json:"suggestions,omitempty"`
}

// ContentCategory classifies what kind of page a search result points at.
type ContentCategory string

// Content categories recognized by the pipeline.
const (
	CategoryProduct ContentCategory = "product"
	CategoryVideo   ContentCategory = "video"
	CategoryProfile ContentCategory = "profile"
	CategoryPromo   ContentCategory = "promo"
	CategoryImage   ContentCategory = "image"
	CategoryArticle ContentCategory = "article"
)
// ToChunk converts the search result into a Chunk. Title and URL
// always go into the metadata; thumbnail, author and published date
// are included only when non-empty. The chunk content is the result
// body, falling back to the title when the body is empty.
func (r *SearchResult) ToChunk() Chunk {
	meta := map[string]string{
		"title": r.Title,
		"url":   r.URL,
	}
	optional := map[string]string{
		"thumbnail":     r.Thumbnail,
		"author":        r.Author,
		"publishedDate": r.PublishedDate,
	}
	for key, val := range optional {
		if val != "" {
			meta[key] = val
		}
	}
	body := r.Content
	if body == "" {
		body = r.Title
	}
	return Chunk{Content: body, Metadata: meta}
}
// SearchResultsToChunks maps every search result to its Chunk form,
// preserving input order.
func SearchResultsToChunks(results []SearchResult) []Chunk {
	out := make([]Chunk, 0, len(results))
	for i := range results {
		out = append(out, results[i].ToChunk())
	}
	return out
}

View File

@@ -0,0 +1,145 @@
package types

// WidgetType names a renderable widget kind understood by the web UI.
type WidgetType string

// Widget kinds the backend can emit.
const (
	WidgetWeather      WidgetType = "weather"
	WidgetCalculator   WidgetType = "calculator"
	WidgetProducts     WidgetType = "products"
	WidgetVideos       WidgetType = "videos"
	WidgetProfiles     WidgetType = "profiles"
	WidgetPromos       WidgetType = "promos"
	WidgetImageGallery WidgetType = "image_gallery"
	WidgetVideoEmbed   WidgetType = "video_embed"
	WidgetKnowledge    WidgetType = "knowledge_card"
)

// ProductData describes one product offer for the products widget.
type ProductData struct {
	Title    string  `json:"title"`
	URL      string  `json:"url"`
	Price    float64 `json:"price"`
	OldPrice float64 `json:"oldPrice,omitempty"` // pre-discount price, when known
	Currency string  `json:"currency"`
	// Discount looks like a percentage — NOTE(review): confirm with the producer.
	Discount    int     `json:"discount,omitempty"`
	Rating      float64 `json:"rating,omitempty"`
	ReviewCount int     `json:"reviewCount,omitempty"`
	ImageURL    string  `json:"imageUrl,omitempty"`
	Marketplace string  `json:"marketplace"`
	InStock     bool    `json:"inStock"`
	Badges      []Badge `json:"badges,omitempty"`
}

// VideoData describes one video for the videos widget.
type VideoData struct {
	Title     string `json:"title"`
	URL       string `json:"url"`
	Thumbnail string `json:"thumbnail,omitempty"`
	// Duration is presumably seconds — NOTE(review): confirm units.
	Duration int    `json:"duration,omitempty"`
	Views    int    `json:"views,omitempty"`
	Likes    int    `json:"likes,omitempty"`
	Author   string `json:"author,omitempty"`
	Platform string `json:"platform"`
	EmbedURL string `json:"embedUrl,omitempty"` // player URL for inline embedding
}

// ProfileData describes one user/social profile for the profiles widget.
type ProfileData struct {
	Name       string `json:"name"`
	Username   string `json:"username,omitempty"`
	URL        string `json:"url"`
	AvatarURL  string `json:"avatarUrl,omitempty"`
	Bio        string `json:"bio,omitempty"`
	Followers  int    `json:"followers,omitempty"`
	Following  int    `json:"following,omitempty"`
	Platform   string `json:"platform"`
	Verified   bool   `json:"verified"`
	IsOnline   bool   `json:"isOnline,omitempty"`
	LastOnline string `json:"lastOnline,omitempty"`
}

// PromoData describes one promo/coupon code for the promos widget.
type PromoData struct {
	Code        string `json:"code"`
	Description string `json:"description"`
	Discount    string `json:"discount"`
	Store       string `json:"store"`
	StoreURL    string `json:"storeUrl"`
	LogoURL     string `json:"logoUrl,omitempty"`
	ExpiresAt   string `json:"expiresAt,omitempty"`
	Conditions  string `json:"conditions,omitempty"`
	Verified    bool   `json:"verified"`
}

// ImageData describes one image for the image gallery widget.
type ImageData struct {
	URL       string `json:"url"`
	Title     string `json:"title,omitempty"`
	Source    string `json:"source,omitempty"`
	SourceURL string `json:"sourceUrl,omitempty"`
	Width     int    `json:"width,omitempty"`
	Height    int    `json:"height,omitempty"`
}

// Badge is a short labeled marker attached to a product card.
type Badge struct {
	Text  string `json:"text"`
	Type  string `json:"type"`
	Color string `json:"color,omitempty"`
}

// KnowledgeCardData is the payload of a knowledge card. Content's
// concrete shape presumably depends on Type (e.g. ComparisonTable,
// StatCard, Timeline) — NOTE(review): confirm against the renderer.
type KnowledgeCardData struct {
	Type    string      `json:"type"`
	Title   string      `json:"title,omitempty"`
	Content interface{} `json:"content"`
}

// ComparisonTable is tabular knowledge-card content.
type ComparisonTable struct {
	Headers []string   `json:"headers"`
	Rows    [][]string `json:"rows"`
}

// StatCard is a single labeled statistic, optionally with a change delta.
type StatCard struct {
	Label  string  `json:"label"`
	Value  string  `json:"value"`
	Change float64 `json:"change,omitempty"`
	Unit   string  `json:"unit,omitempty"`
}

// Timeline is chronological knowledge-card content.
type Timeline struct {
	Events []TimelineEvent `json:"events"`
}

// TimelineEvent is one dated entry of a Timeline.
type TimelineEvent struct {
	Date        string `json:"date"`
	Title       string `json:"title"`
	Description string `json:"description,omitempty"`
}

// WeatherParams is the params payload of a weather widget.
type WeatherParams struct {
	Location    string         `json:"location"`
	Current     WeatherCurrent `json:"current"`
	Forecast    []WeatherDay   `json:"forecast,omitempty"`
	LastUpdated string         `json:"lastUpdated,omitempty"`
}

// WeatherCurrent is the current-conditions part of a weather widget.
type WeatherCurrent struct {
	Temp      float64 `json:"temp"`
	FeelsLike float64 `json:"feelsLike"`
	// Humidity is presumably percent; wind speed units are not
	// established here — NOTE(review): confirm with the weather source.
	Humidity    int     `json:"humidity"`
	WindSpeed   float64 `json:"windSpeed"`
	Description string  `json:"description"`
	Icon        string  `json:"icon"`
}

// WeatherDay is one day of the weather forecast.
type WeatherDay struct {
	Date    string  `json:"date"`
	TempMax float64 `json:"tempMax"`
	TempMin float64 `json:"tempMin"`
	Icon    string  `json:"icon"`
}

// CalculatorParams is the params payload of a calculator widget: the
// evaluated expression, its numeric result, and optional intermediate steps.
type CalculatorParams struct {
	Expression string  `json:"expression"`
	Result     float64 `json:"result"`
	Steps      []Step  `json:"steps,omitempty"`
}

// Step is one intermediate step of a calculator evaluation.
type Step struct {
	Description string `json:"description"`
	Value       string `json:"value"`
}