feat: Go backend, enhanced search, new widgets, Docker deploy
Major changes: - Add Go backend (backend/) with microservices architecture - Enhanced master-agents-svc: reranker, content-classifier, stealth-crawler, proxy-manager, media-search, fastClassifier, language detection - New web-svc widgets: KnowledgeCard, ProductCard, ProfileCard, VideoCard, UnifiedCard, CardGallery, InlineImageGallery, SourcesPanel, RelatedQuestions - Improved discover-svc with discover-db integration - Docker deployment improvements (Caddyfile, vendor.sh, BUILD.md) - Library-svc: project_id schema migration - Remove deprecated finance-svc and travel-svc - Localization improvements across services Made-with: Cursor
This commit is contained in:
233
backend/internal/agent/classifier.go
Normal file
233
backend/internal/agent/classifier.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/gooseek/backend/internal/llm"
|
||||
"github.com/gooseek/backend/internal/prompts"
|
||||
)
|
||||
|
||||
// ClassificationResult is the outcome of query classification: the query
// rewritten as a standalone question plus search-routing hints.
type ClassificationResult struct {
	// StandaloneFollowUp is the query rewritten so it makes sense
	// without the preceding conversation history.
	StandaloneFollowUp string `json:"standaloneFollowUp"`
	// SkipSearch is true when the query can be answered without a web
	// search (greetings, small talk, clarification requests).
	SkipSearch bool `json:"skipSearch"`
	// Topics optionally extracted from the conversation context.
	Topics []string `json:"topics,omitempty"`
	// QueryType is an optional category label; it is filled from the
	// LLM's JSON response — exact values are not visible in this file.
	QueryType string `json:"queryType,omitempty"`
	// Engines lists search engines suggested for this query.
	Engines []string `json:"engines,omitempty"`
}
|
||||
|
||||
func classify(ctx context.Context, client llm.Client, query string, history []llm.Message, locale, detectedLang string) (*ClassificationResult, error) {
|
||||
prompt := prompts.GetClassifierPrompt(locale, detectedLang)
|
||||
|
||||
historyStr := formatHistory(history)
|
||||
userContent := "<conversation>\n" + historyStr + "\nUser: " + query + "\n</conversation>"
|
||||
|
||||
messages := []llm.Message{
|
||||
{Role: llm.RoleSystem, Content: prompt},
|
||||
{Role: llm.RoleUser, Content: userContent},
|
||||
}
|
||||
|
||||
response, err := client.GenerateText(ctx, llm.StreamRequest{
|
||||
Messages: messages,
|
||||
Options: llm.StreamOptions{MaxTokens: 1024},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
jsonMatch := regexp.MustCompile(`\{[\s\S]*\}`).FindString(response)
|
||||
if jsonMatch == "" {
|
||||
return &ClassificationResult{
|
||||
StandaloneFollowUp: query,
|
||||
SkipSearch: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var result ClassificationResult
|
||||
if err := json.Unmarshal([]byte(jsonMatch), &result); err != nil {
|
||||
return &ClassificationResult{
|
||||
StandaloneFollowUp: query,
|
||||
SkipSearch: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if result.StandaloneFollowUp == "" {
|
||||
result.StandaloneFollowUp = query
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func fastClassify(query string, history []llm.Message) *ClassificationResult {
|
||||
queryLower := strings.ToLower(query)
|
||||
|
||||
skipPatterns := []string{
|
||||
"привет", "как дела", "спасибо", "пока",
|
||||
"hello", "hi", "thanks", "bye",
|
||||
"объясни", "расскажи подробнее", "что ты имеешь",
|
||||
}
|
||||
|
||||
skipSearch := false
|
||||
for _, p := range skipPatterns {
|
||||
if strings.Contains(queryLower, p) && len(query) < 50 {
|
||||
skipSearch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
standalone := query
|
||||
|
||||
if len(history) > 0 {
|
||||
pronouns := []string{
|
||||
"это", "этот", "эта", "эти",
|
||||
"он", "она", "оно", "они",
|
||||
"it", "this", "that", "they", "them",
|
||||
}
|
||||
|
||||
hasPronouns := false
|
||||
for _, p := range pronouns {
|
||||
if strings.Contains(queryLower, p+" ") || strings.HasPrefix(queryLower, p+" ") {
|
||||
hasPronouns = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hasPronouns && len(history) >= 2 {
|
||||
lastAssistant := ""
|
||||
for i := len(history) - 1; i >= 0; i-- {
|
||||
if history[i].Role == llm.RoleAssistant {
|
||||
lastAssistant = history[i].Content
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if lastAssistant != "" {
|
||||
topics := extractTopics(lastAssistant)
|
||||
if len(topics) > 0 {
|
||||
standalone = query + " (контекст: " + strings.Join(topics, ", ") + ")"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
engines := detectEngines(queryLower)
|
||||
|
||||
return &ClassificationResult{
|
||||
StandaloneFollowUp: standalone,
|
||||
SkipSearch: skipSearch,
|
||||
Engines: engines,
|
||||
}
|
||||
}
|
||||
|
||||
// generateSearchQueries derives up to three search-engine queries from a
// user query: the query itself, a compact 5-word head for very long
// queries, and a question-word-stripped variant (lowercased).
func generateSearchQueries(query string) []string {
	queries := []string{query}
	// Lowercase once instead of re-lowercasing inside the pattern loop.
	queryLower := strings.ToLower(query)

	// Very long queries also get a compact 5-word head.
	if len(query) > 100 {
		words := strings.Fields(query)
		if len(words) > 5 {
			queries = append(queries, strings.Join(words[:5], " "))
		}
	}

	// Question-word prefixes (Russian and English); stripping them often
	// yields a better keyword query.
	keywordPatterns := []string{
		"как", "что такое", "где", "когда", "почему", "кто",
		"how", "what is", "where", "when", "why", "who",
	}

	for _, p := range keywordPatterns {
		if strings.HasPrefix(queryLower, p) {
			withoutPrefix := strings.TrimSpace(strings.TrimPrefix(queryLower, p))
			// Keep the stripped variant only if enough remains to search.
			if len(withoutPrefix) > 10 {
				queries = append(queries, withoutPrefix)
			}
			break
		}
	}

	if len(queries) > 3 {
		queries = queries[:3]
	}

	return queries
}
|
||||
|
||||
// detectEngines picks search engines for a (lowercased) query: the two
// general engines always, plus news/video/shopping engines when the
// query contains matching Russian or English keywords.
func detectEngines(query string) []string {
	engines := []string{"google", "duckduckgo"}

	containsAny := func(subs ...string) bool {
		for _, s := range subs {
			if strings.Contains(query, s) {
				return true
			}
		}
		return false
	}

	if containsAny("новости", "news") {
		engines = append(engines, "google_news")
	}
	if containsAny("видео", "video") {
		engines = append(engines, "youtube")
	}
	if containsAny("товар", "купить", "цена", "price") {
		engines = append(engines, "google_shopping")
	}

	return engines
}
|
||||
|
||||
// extractTopics pulls up to three capitalized, medium-length words from
// the first 50 words of text, as a cheap proxy for topic terms (proper
// nouns) in English or Russian.
func extractTopics(text string) []string {
	words := strings.Fields(text)
	if len(words) > 50 {
		words = words[:50]
	}

	topics := make([]string, 0, 3)
	for _, w := range words {
		r := []rune(w)
		// Measure length in runes, not bytes, so Cyrillic words are
		// judged by the same character bounds as Latin ones.
		if len(r) <= 5 || len(r) >= 20 {
			continue
		}
		first := r[0]
		// Capitalized Latin or Cyrillic first letter; 'Ё' sits outside
		// the 'А'..'Я' range and must be checked explicitly.
		if (first >= 'A' && first <= 'Z') || (first >= 'А' && first <= 'Я') || first == 'Ё' {
			topics = append(topics, w)
			if len(topics) >= 3 {
				break
			}
		}
	}

	return topics
}
|
||||
|
||||
func formatHistory(messages []llm.Message) string {
|
||||
var sb strings.Builder
|
||||
for _, m := range messages {
|
||||
role := "User"
|
||||
if m.Role == llm.RoleAssistant {
|
||||
role = "Assistant"
|
||||
}
|
||||
sb.WriteString(role)
|
||||
sb.WriteString(": ")
|
||||
content := m.Content
|
||||
if len(content) > 500 {
|
||||
content = content[:500] + "..."
|
||||
}
|
||||
sb.WriteString(content)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// detectLanguage returns "ru" when text contains more Cyrillic than
// Latin letters, otherwise "en".
func detectLanguage(text string) string {
	cyrillicCount := 0
	latinCount := 0

	for _, r := range text {
		switch {
		// 'ё'/'Ё' lie outside the contiguous 'а'..'я'/'А'..'Я' ranges
		// and must be counted explicitly.
		case (r >= 'а' && r <= 'я') || (r >= 'А' && r <= 'Я') || r == 'ё' || r == 'Ё':
			cyrillicCount++
		case (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z'):
			latinCount++
		}
	}

	if cyrillicCount > latinCount {
		return "ru"
	}
	return "en"
}
|
||||
543
backend/internal/agent/deep_research.go
Normal file
543
backend/internal/agent/deep_research.go
Normal file
@@ -0,0 +1,543 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gooseek/backend/internal/llm"
|
||||
"github.com/gooseek/backend/internal/search"
|
||||
"github.com/gooseek/backend/internal/session"
|
||||
"github.com/gooseek/backend/internal/types"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// DeepResearchConfig carries the dependencies and limits for one deep
// research run. Zero-valued limits are replaced with defaults by
// NewDeepResearcher.
type DeepResearchConfig struct {
	LLM          llm.Client            // model used for planning, synthesis and the report
	SearchClient *search.SearXNGClient // web search backend
	FocusMode    FocusMode             // search/answer specialization
	// Locale selects the report language; "ru" adds a Russian-language
	// instruction, anything else (including "") behaves as English.
	Locale           string
	MaxSearchQueries int           // total search budget across iterations (default 30)
	MaxSources       int           // stop once this many sources are collected (default 100)
	MaxIterations    int           // maximum research iterations (default 5)
	Timeout          time.Duration // overall deadline for Research (default 5 minutes)
}
|
||||
|
||||
// DeepResearchResult is the aggregate outcome of a research run.
type DeepResearchResult struct {
	FinalReport     string        // full report text as streamed to the client
	Sources         []types.Chunk // URL-deduplicated sources across all sub-queries
	SubQueries      []SubQuery    // sub-queries in their final states
	Insights        []string      // key findings extracted from the sources
	FollowUpQueries []string      // suggested next questions for the user
	TotalSearches   int           // number of search calls performed
	TotalSources    int           // equals len(Sources)
	Duration        time.Duration // wall time since the researcher was created
}
|
||||
|
||||
// SubQuery is one focused search task within a research run.
type SubQuery struct {
	Query   string // search-engine query text
	Purpose string // which aspect of the research it covers
	// Status is one of "pending", "searching", "complete", "failed"
	// (the values written elsewhere in this file).
	Status  string
	Results []types.Chunk // chunks collected for this sub-query
	// Insights is never written in this file — NOTE(review): confirm a
	// caller populates it before relying on it.
	Insights []string
}
|
||||
|
||||
// DeepResearcher holds the mutable state of one research run. mu guards
// allSources, seenURLs, subQueries, insights and searchCount, which are
// touched by the concurrent sub-query goroutines.
type DeepResearcher struct {
	cfg         DeepResearchConfig
	sess        *session.Session
	mu          sync.Mutex
	allSources  []types.Chunk   // deduplicated chunks across all sub-queries
	seenURLs    map[string]bool // URL dedup set
	subQueries  []SubQuery      // planned + follow-up sub-queries
	insights    []string        // last synthesized insights (fallback source)
	searchCount int             // searches performed so far, for budget tracking
	startTime   time.Time       // set at construction; used for Duration
}
|
||||
|
||||
func NewDeepResearcher(cfg DeepResearchConfig, sess *session.Session) *DeepResearcher {
|
||||
if cfg.MaxSearchQueries == 0 {
|
||||
cfg.MaxSearchQueries = 30
|
||||
}
|
||||
if cfg.MaxSources == 0 {
|
||||
cfg.MaxSources = 100
|
||||
}
|
||||
if cfg.MaxIterations == 0 {
|
||||
cfg.MaxIterations = 5
|
||||
}
|
||||
if cfg.Timeout == 0 {
|
||||
cfg.Timeout = 5 * time.Minute
|
||||
}
|
||||
|
||||
return &DeepResearcher{
|
||||
cfg: cfg,
|
||||
sess: sess,
|
||||
seenURLs: make(map[string]bool),
|
||||
allSources: make([]types.Chunk, 0),
|
||||
subQueries: make([]SubQuery, 0),
|
||||
insights: make([]string, 0),
|
||||
startTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Research runs the full pipeline for query: plan sub-queries, execute
// them in bounded iterations, synthesize insights, stream a final
// report, and suggest follow-ups. Progress is surfaced to the client
// through a "research" block on the session.
func (dr *DeepResearcher) Research(ctx context.Context, query string) (*DeepResearchResult, error) {
	ctx, cancel := context.WithTimeout(ctx, dr.cfg.Timeout)
	defer cancel()

	// Emit the progress block up front so the UI can render status
	// updates while work is in flight.
	researchBlockID := uuid.New().String()
	dr.sess.EmitBlock(&types.Block{
		ID:   researchBlockID,
		Type: types.BlockTypeResearch,
		Data: types.ResearchData{
			SubSteps: []types.ResearchSubStep{},
		},
	})

	subQueries, err := dr.planResearch(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("planning failed: %w", err)
	}

	dr.updateResearchStatus(researchBlockID, "researching", fmt.Sprintf("Executing %d sub-queries", len(subQueries)))

	// Iterate until the iteration cap, the search budget, or enough
	// sources; each pass may enqueue follow-up sub-queries.
	for i := 0; i < dr.cfg.MaxIterations && dr.searchCount < dr.cfg.MaxSearchQueries; i++ {
		if err := dr.executeIteration(ctx, i, researchBlockID); err != nil {
			// Iteration errors are tolerated unless the context itself
			// expired (timeout/cancel) — then stop iterating.
			if ctx.Err() != nil {
				break
			}
		}

		if dr.hasEnoughData() {
			break
		}

		newQueries, err := dr.generateFollowUpQueries(ctx, query)
		if err != nil || len(newQueries) == 0 {
			break
		}

		for _, q := range newQueries {
			dr.mu.Lock()
			dr.subQueries = append(dr.subQueries, SubQuery{
				Query:   q.Query,
				Purpose: q.Purpose,
				Status:  "pending",
			})
			dr.mu.Unlock()
		}
	}

	dr.updateResearchStatus(researchBlockID, "synthesizing", "Analyzing findings")

	insights, err := dr.synthesizeInsights(ctx, query)
	if err != nil {
		// Best-effort: fall back to whatever insights were stored earlier.
		insights = dr.insights
	}

	dr.updateResearchStatus(researchBlockID, "writing", "Generating report")

	report, err := dr.generateFinalReport(ctx, query, insights)
	if err != nil {
		return nil, fmt.Errorf("report generation failed: %w", err)
	}

	// Follow-up suggestions are optional; their error is ignored.
	followUp, _ := dr.generateFollowUpSuggestions(ctx, query, report)

	dr.updateResearchStatus(researchBlockID, "complete", "Research complete")

	// State is read without the lock here; all sub-query goroutines
	// have finished by this point (g.Wait in executeIteration).
	return &DeepResearchResult{
		FinalReport:     report,
		Sources:         dr.allSources,
		SubQueries:      dr.subQueries,
		Insights:        insights,
		FollowUpQueries: followUp,
		TotalSearches:   dr.searchCount,
		TotalSources:    len(dr.allSources),
		Duration:        time.Since(dr.startTime),
	}, nil
}
|
||||
|
||||
// planResearch asks the LLM to decompose the query into 3-5 focused
// sub-queries. On LLM failure or unparseable output it falls back to
// templated defaults, so it never returns an error in practice.
func (dr *DeepResearcher) planResearch(ctx context.Context, query string) ([]SubQuery, error) {
	prompt := fmt.Sprintf(`Analyze this research query and break it into 3-5 sub-queries for comprehensive research.

Query: %s

For each sub-query, specify:
1. The search query (optimized for search engines)
2. The purpose (what aspect it addresses)

Respond in this exact format:
QUERY: [search query]
PURPOSE: [what this addresses]

QUERY: [search query]
PURPOSE: [what this addresses]

...

Be specific and actionable. Focus on different aspects: definitions, current state, history, expert opinions, data/statistics, controversies, future trends.`, query)

	result, err := dr.cfg.LLM.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: "user", Content: prompt}},
	})
	if err != nil {
		// LLM unavailable: degrade to the default sub-query templates.
		return dr.generateDefaultSubQueries(query), nil
	}

	subQueries := dr.parseSubQueries(result)
	if len(subQueries) == 0 {
		subQueries = dr.generateDefaultSubQueries(query)
	}

	dr.mu.Lock()
	dr.subQueries = subQueries
	dr.mu.Unlock()

	return subQueries, nil
}
|
||||
|
||||
func (dr *DeepResearcher) parseSubQueries(text string) []SubQuery {
|
||||
var queries []SubQuery
|
||||
lines := strings.Split(text, "\n")
|
||||
|
||||
var currentQuery, currentPurpose string
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, "QUERY:") {
|
||||
if currentQuery != "" && currentPurpose != "" {
|
||||
queries = append(queries, SubQuery{
|
||||
Query: currentQuery,
|
||||
Purpose: currentPurpose,
|
||||
Status: "pending",
|
||||
})
|
||||
}
|
||||
currentQuery = strings.TrimSpace(strings.TrimPrefix(line, "QUERY:"))
|
||||
currentPurpose = ""
|
||||
} else if strings.HasPrefix(line, "PURPOSE:") {
|
||||
currentPurpose = strings.TrimSpace(strings.TrimPrefix(line, "PURPOSE:"))
|
||||
}
|
||||
}
|
||||
|
||||
if currentQuery != "" && currentPurpose != "" {
|
||||
queries = append(queries, SubQuery{
|
||||
Query: currentQuery,
|
||||
Purpose: currentPurpose,
|
||||
Status: "pending",
|
||||
})
|
||||
}
|
||||
|
||||
return queries
|
||||
}
|
||||
|
||||
func (dr *DeepResearcher) generateDefaultSubQueries(query string) []SubQuery {
|
||||
return []SubQuery{
|
||||
{Query: query, Purpose: "Main query", Status: "pending"},
|
||||
{Query: query + " definition explained", Purpose: "Definitions and basics", Status: "pending"},
|
||||
{Query: query + " latest news 2026", Purpose: "Current developments", Status: "pending"},
|
||||
{Query: query + " expert analysis", Purpose: "Expert opinions", Status: "pending"},
|
||||
{Query: query + " statistics data research", Purpose: "Data and evidence", Status: "pending"},
|
||||
}
|
||||
}
|
||||
|
||||
func (dr *DeepResearcher) executeIteration(ctx context.Context, iteration int, blockID string) error {
|
||||
dr.mu.Lock()
|
||||
pendingQueries := make([]int, 0)
|
||||
for i, sq := range dr.subQueries {
|
||||
if sq.Status == "pending" {
|
||||
pendingQueries = append(pendingQueries, i)
|
||||
}
|
||||
}
|
||||
dr.mu.Unlock()
|
||||
|
||||
if len(pendingQueries) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
batchSize := 3
|
||||
if len(pendingQueries) < batchSize {
|
||||
batchSize = len(pendingQueries)
|
||||
}
|
||||
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(batchSize)
|
||||
|
||||
for _, idx := range pendingQueries[:batchSize] {
|
||||
idx := idx
|
||||
g.Go(func() error {
|
||||
return dr.executeSubQuery(gctx, idx, blockID)
|
||||
})
|
||||
}
|
||||
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
func (dr *DeepResearcher) executeSubQuery(ctx context.Context, idx int, blockID string) error {
|
||||
dr.mu.Lock()
|
||||
if idx >= len(dr.subQueries) {
|
||||
dr.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
sq := &dr.subQueries[idx]
|
||||
sq.Status = "searching"
|
||||
query := sq.Query
|
||||
dr.searchCount++
|
||||
dr.mu.Unlock()
|
||||
|
||||
dr.updateResearchStatus(blockID, "researching", fmt.Sprintf("Searching: %s", truncate(query, 50)))
|
||||
|
||||
enhancedQuery := EnhanceQueryForFocusMode(query, dr.cfg.FocusMode)
|
||||
|
||||
results, err := dr.cfg.SearchClient.Search(ctx, enhancedQuery, &search.SearchOptions{
|
||||
Engines: dr.cfg.FocusMode.GetSearchEngines(),
|
||||
Categories: FocusModeConfigs[dr.cfg.FocusMode].Categories,
|
||||
PageNo: 1,
|
||||
})
|
||||
if err != nil {
|
||||
dr.mu.Lock()
|
||||
sq.Status = "failed"
|
||||
dr.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
chunks := make([]types.Chunk, 0)
|
||||
for _, r := range results.Results {
|
||||
dr.mu.Lock()
|
||||
if dr.seenURLs[r.URL] {
|
||||
dr.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
dr.seenURLs[r.URL] = true
|
||||
dr.mu.Unlock()
|
||||
|
||||
chunk := r.ToChunk()
|
||||
chunks = append(chunks, chunk)
|
||||
|
||||
if len(chunks) >= 10 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
dr.mu.Lock()
|
||||
sq.Results = chunks
|
||||
sq.Status = "complete"
|
||||
dr.allSources = append(dr.allSources, chunks...)
|
||||
dr.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateFollowUpQueries asks the LLM for 2-3 additional sub-queries
// filling gaps in the sources collected so far. Returns (nil, nil) when
// the remaining search budget is nearly exhausted.
func (dr *DeepResearcher) generateFollowUpQueries(ctx context.Context, originalQuery string) ([]SubQuery, error) {
	// searchCount is read without the lock; this method is only invoked
	// from the sequential Research loop, after executeIteration returns.
	if dr.searchCount >= dr.cfg.MaxSearchQueries-5 {
		return nil, nil
	}

	// Summarize up to 20 collected sources for the prompt.
	var sourceSummary strings.Builder
	dr.mu.Lock()
	for i, s := range dr.allSources {
		if i >= 20 {
			break
		}
		sourceSummary.WriteString(fmt.Sprintf("- %s: %s\n", s.Metadata["title"], truncate(s.Content, 100)))
	}
	dr.mu.Unlock()

	prompt := fmt.Sprintf(`Based on the original query and sources found so far, suggest 2-3 follow-up queries to deepen the research.

Original query: %s

Sources found so far:
%s

What aspects are missing? What would provide more comprehensive coverage?
Respond with queries in format:
QUERY: [query]
PURPOSE: [what gap it fills]`, originalQuery, sourceSummary.String())

	result, err := dr.cfg.LLM.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: "user", Content: prompt}},
	})
	if err != nil {
		return nil, err
	}

	return dr.parseSubQueries(result), nil
}
|
||||
|
||||
func (dr *DeepResearcher) synthesizeInsights(ctx context.Context, query string) ([]string, error) {
|
||||
var sourcesText strings.Builder
|
||||
dr.mu.Lock()
|
||||
for i, s := range dr.allSources {
|
||||
if i >= 30 {
|
||||
break
|
||||
}
|
||||
sourcesText.WriteString(fmt.Sprintf("[%d] %s\n%s\n\n", i+1, s.Metadata["title"], truncate(s.Content, 300)))
|
||||
}
|
||||
dr.mu.Unlock()
|
||||
|
||||
prompt := fmt.Sprintf(`Analyze these sources and extract 5-7 key insights for the query: %s
|
||||
|
||||
Sources:
|
||||
%s
|
||||
|
||||
Provide insights as bullet points, each starting with a key finding.
|
||||
Focus on: main conclusions, patterns, contradictions, expert consensus, data points.`, query, sourcesText.String())
|
||||
|
||||
result, err := dr.cfg.LLM.GenerateText(ctx, llm.StreamRequest{
|
||||
Messages: []llm.Message{{Role: "user", Content: prompt}},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
insights := make([]string, 0)
|
||||
for _, line := range strings.Split(result, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, "-") || strings.HasPrefix(line, "•") || strings.HasPrefix(line, "*") {
|
||||
insights = append(insights, strings.TrimPrefix(strings.TrimPrefix(strings.TrimPrefix(line, "-"), "•"), "*"))
|
||||
}
|
||||
}
|
||||
|
||||
dr.mu.Lock()
|
||||
dr.insights = insights
|
||||
dr.mu.Unlock()
|
||||
|
||||
return insights, nil
|
||||
}
|
||||
|
||||
// generateFinalReport streams the report from the LLM into a text block
// on the session (so the client sees it incrementally) and returns the
// accumulated text. Up to 50 sources are inlined into the prompt.
func (dr *DeepResearcher) generateFinalReport(ctx context.Context, query string, insights []string) (string, error) {
	var sourcesText strings.Builder
	// Snapshot the slice header under the lock; no concurrent appends
	// are expected once the research iterations have completed.
	dr.mu.Lock()
	sources := dr.allSources
	dr.mu.Unlock()

	for i, s := range sources {
		if i >= 50 {
			break
		}
		sourcesText.WriteString(fmt.Sprintf("[%d] %s (%s)\n%s\n\n", i+1, s.Metadata["title"], s.Metadata["url"], truncate(s.Content, 400)))
	}

	insightsText := strings.Join(insights, "\n- ")

	focusCfg := FocusModeConfigs[dr.cfg.FocusMode]
	locale := dr.cfg.Locale
	if locale == "" {
		locale = "en"
	}

	// Only "ru" changes the output language; all other locales get the
	// default (English) behavior.
	langInstruction := ""
	if locale == "ru" {
		langInstruction = "Write the report in Russian."
	}

	prompt := fmt.Sprintf(`%s

Write a comprehensive research report answering: %s

Key insights discovered:
- %s

Sources (cite using [1], [2], etc.):
%s

Structure your report with:
1. Executive Summary (2-3 sentences)
2. Key Findings (organized by theme)
3. Analysis and Discussion
4. Conclusions

%s
Use citations [1], [2], etc. throughout.
Be thorough but concise. Focus on actionable information.`, focusCfg.SystemPrompt, query, insightsText, sourcesText.String(), langInstruction)

	stream, err := dr.cfg.LLM.StreamText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: "user", Content: prompt}},
	})
	if err != nil {
		return "", err
	}

	// Mirror each streamed chunk into both the session text block and a
	// local buffer so the full report can be returned.
	var report strings.Builder
	textBlockID := uuid.New().String()
	dr.sess.EmitBlock(&types.Block{
		ID:   textBlockID,
		Type: types.BlockTypeText,
		Data: "",
	})

	for chunk := range stream {
		report.WriteString(chunk.ContentChunk)
		dr.sess.EmitTextChunk(textBlockID, chunk.ContentChunk)
	}

	return report.String(), nil
}
|
||||
|
||||
// generateFollowUpSuggestions asks the LLM for 3-4 follow-up questions
// based on the query and a truncated report, returning at most four
// cleaned one-line questions.
func (dr *DeepResearcher) generateFollowUpSuggestions(ctx context.Context, query, report string) ([]string, error) {
	prompt := fmt.Sprintf(`Based on this research query and report, suggest 3-4 follow-up questions the user might want to explore:

Query: %s

Report summary: %s

Provide follow-up questions that:
1. Go deeper into specific aspects
2. Explore related topics
3. Address practical applications
4. Consider alternative perspectives

Format as simple questions, one per line.`, query, truncate(report, 1000))

	result, err := dr.cfg.LLM.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: "user", Content: prompt}},
	})
	if err != nil {
		return nil, err
	}

	suggestions := make([]string, 0)
	for _, line := range strings.Split(result, "\n") {
		line = strings.TrimSpace(line)
		// Keep lines that look like questions or substantial sentences.
		if line != "" && (strings.Contains(line, "?") || len(line) > 20) {
			// Strip common list markers: "- ", "• ", "1. ", etc.
			line = strings.TrimPrefix(line, "- ")
			line = strings.TrimPrefix(line, "• ")
			line = strings.TrimLeft(line, "0123456789. ")
			if line != "" {
				suggestions = append(suggestions, line)
			}
		}
	}

	if len(suggestions) > 4 {
		suggestions = suggestions[:4]
	}

	return suggestions, nil
}
|
||||
|
||||
// updateResearchStatus patches the research block's status and message
// fields so the client can show live progress.
func (dr *DeepResearcher) updateResearchStatus(blockID, status, message string) {
	dr.sess.UpdateBlock(blockID, []session.Patch{
		{Op: "replace", Path: "/data/status", Value: status},
		{Op: "replace", Path: "/data/message", Value: message},
	})
}
|
||||
|
||||
func (dr *DeepResearcher) hasEnoughData() bool {
|
||||
dr.mu.Lock()
|
||||
defer dr.mu.Unlock()
|
||||
return len(dr.allSources) >= dr.cfg.MaxSources
|
||||
}
|
||||
|
||||
// truncate shortens s to at most maxLen characters, appending "..."
// when anything was cut. Truncation happens on rune boundaries so
// multi-byte UTF-8 text (e.g. Cyrillic) is never split mid-character.
func truncate(s string, maxLen int) string {
	if len(s) <= maxLen {
		// Fast path: byte length within the limit implies the rune
		// count is too.
		return s
	}
	r := []rune(s)
	if len(r) <= maxLen {
		return s
	}
	return string(r[:maxLen]) + "..."
}
|
||||
|
||||
func RunDeepResearch(ctx context.Context, sess *session.Session, query string, cfg DeepResearchConfig) (*DeepResearchResult, error) {
|
||||
researcher := NewDeepResearcher(cfg, sess)
|
||||
return researcher.Research(ctx, query)
|
||||
}
|
||||
293
backend/internal/agent/focus_modes.go
Normal file
293
backend/internal/agent/focus_modes.go
Normal file
@@ -0,0 +1,293 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// FocusMode selects a search/answer specialization for the agent; its
// string value is the client-facing mode name.
type FocusMode string

// Supported focus modes (see FocusModeConfigs for their behavior).
const (
	FocusModeAll      FocusMode = "all"      // general-purpose web search
	FocusModeAcademic FocusMode = "academic" // scholarly sources
	FocusModeWriting  FocusMode = "writing"  // writing assistance, minimal search
	FocusModeYouTube  FocusMode = "youtube"  // video search
	FocusModeReddit   FocusMode = "reddit"   // community discussions
	FocusModeCode     FocusMode = "code"     // programming help
	FocusModeNews     FocusMode = "news"     // current events
	FocusModeImages   FocusMode = "images"   // image search
	FocusModeMath     FocusMode = "math"     // step-by-step math
	FocusModeFinance  FocusMode = "finance"  // market/finance information
)
|
||||
|
||||
// FocusModeConfig describes how a focus mode searches and answers.
type FocusModeConfig struct {
	Mode              FocusMode
	Engines           []string // search engines to query
	Categories        []string // categories passed to the search client
	SystemPrompt      string   // system prompt given to the LLM
	SearchQueryPrefix string   // optional prefix prepended to search queries
	MaxSources        int      // cap on sources used for answering
	RequiresCitation  bool     // whether answers must cite sources
	AllowScraping     bool     // whether result pages may be fetched for content
}
|
||||
|
||||
// FocusModeConfigs maps every focus mode to its engines, categories,
// limits and system prompt. Resolve user-supplied mode names through
// GetFocusModeConfig, which falls back to FocusModeAll for unknown names.
var FocusModeConfigs = map[FocusMode]FocusModeConfig{
	FocusModeAll: {
		Mode:             FocusModeAll,
		Engines:          []string{"google", "bing", "duckduckgo"},
		Categories:       []string{"general"},
		MaxSources:       15,
		RequiresCitation: true,
		AllowScraping:    true,
		SystemPrompt: `You are a helpful AI assistant that provides comprehensive answers based on web search results.
Always cite your sources using [1], [2], etc. format.
Provide balanced, accurate information from multiple perspectives.`,
	},
	FocusModeAcademic: {
		Mode:              FocusModeAcademic,
		Engines:           []string{"google scholar", "arxiv", "pubmed", "semantic scholar"},
		Categories:        []string{"science"},
		SearchQueryPrefix: "research paper",
		MaxSources:        20,
		RequiresCitation:  true,
		AllowScraping:     true,
		SystemPrompt: `You are an academic research assistant specializing in scholarly sources.
Focus on peer-reviewed papers, academic journals, and reputable research institutions.
Always cite sources in academic format with [1], [2], etc.
Distinguish between primary research, meta-analyses, and review articles.
Mention publication dates, authors, and journals when available.
Be precise about confidence levels and note when findings are preliminary or contested.`,
	},
	// Writing mode is the only one that neither requires citations nor
	// scrapes pages — it leans on the LLM rather than search results.
	FocusModeWriting: {
		Mode:             FocusModeWriting,
		Engines:          []string{"google"},
		Categories:       []string{"general"},
		MaxSources:       5,
		RequiresCitation: false,
		AllowScraping:    false,
		SystemPrompt: `You are a creative writing assistant.
Help with drafting, editing, and improving written content.
Provide suggestions for style, tone, structure, and clarity.
Offer multiple variations when appropriate.
Focus on the user's voice and intent rather than web search results.`,
	},
	FocusModeYouTube: {
		Mode:              FocusModeYouTube,
		Engines:           []string{"youtube"},
		Categories:        []string{"videos"},
		SearchQueryPrefix: "site:youtube.com",
		MaxSources:        10,
		RequiresCitation:  true,
		AllowScraping:     false,
		SystemPrompt: `You are a video content assistant focused on YouTube.
Summarize video content, recommend relevant videos, and help find tutorials.
Mention video titles, channels, and approximate timestamps when relevant.
Note view counts and upload dates to indicate video popularity and relevance.`,
	},
	FocusModeReddit: {
		Mode:              FocusModeReddit,
		Engines:           []string{"reddit"},
		Categories:        []string{"social media"},
		SearchQueryPrefix: "site:reddit.com",
		MaxSources:        15,
		RequiresCitation:  true,
		AllowScraping:     true,
		SystemPrompt: `You are an assistant that specializes in Reddit discussions and community knowledge.
Focus on highly upvoted comments and posts from relevant subreddits.
Note the subreddit source, upvote counts, and community consensus.
Distinguish between personal opinions, experiences, and factual claims.
Be aware of potential biases in specific communities.`,
	},
	FocusModeCode: {
		Mode:              FocusModeCode,
		Engines:           []string{"google", "github", "stackoverflow"},
		Categories:        []string{"it"},
		SearchQueryPrefix: "",
		MaxSources:        10,
		RequiresCitation:  true,
		AllowScraping:     true,
		SystemPrompt: `You are a programming assistant focused on code, documentation, and technical solutions.
Provide working code examples with explanations.
Reference official documentation, Stack Overflow answers, and GitHub repositories.
Mention library versions and compatibility considerations.
Follow best practices and coding standards for the relevant language/framework.
Include error handling and edge cases in code examples.`,
	},
	FocusModeNews: {
		Mode:             FocusModeNews,
		Engines:          []string{"google news", "bing news"},
		Categories:       []string{"news"},
		MaxSources:       12,
		RequiresCitation: true,
		AllowScraping:    true,
		SystemPrompt: `You are a news assistant that provides current events information.
Focus on recent, verified news from reputable sources.
Distinguish between breaking news, analysis, and opinion pieces.
Note publication dates and source credibility.
Present multiple perspectives on controversial topics.`,
	},
	FocusModeImages: {
		Mode:             FocusModeImages,
		Engines:          []string{"google images", "bing images"},
		Categories:       []string{"images"},
		MaxSources:       20,
		RequiresCitation: true,
		AllowScraping:    false,
		SystemPrompt: `You are an image search assistant.
Help find relevant images, describe image sources, and provide context.
Note image sources, licenses, and quality when relevant.`,
	},
	FocusModeMath: {
		Mode:             FocusModeMath,
		Engines:          []string{"wolfram alpha", "google"},
		Categories:       []string{"science"},
		MaxSources:       5,
		RequiresCitation: true,
		AllowScraping:    false,
		SystemPrompt: `You are a mathematical problem-solving assistant.
Provide step-by-step solutions with clear explanations.
Use proper mathematical notation and formatting.
Show your work and explain the reasoning behind each step.
Mention relevant theorems, formulas, and mathematical concepts.
Verify your calculations and provide alternative solution methods when applicable.`,
	},
	FocusModeFinance: {
		Mode:              FocusModeFinance,
		Engines:           []string{"google", "google finance", "yahoo finance"},
		Categories:        []string{"news"},
		SearchQueryPrefix: "stock market finance",
		MaxSources:        10,
		RequiresCitation:  true,
		AllowScraping:     true,
		SystemPrompt: `You are a financial information assistant.
Provide accurate financial data, market analysis, and investment information.
Note that you cannot provide personalized financial advice.
Cite data sources and note when data may be delayed or historical.
Include relevant disclaimers about investment risks.
Reference SEC filings, analyst reports, and official company statements.`,
	},
}
|
||||
|
||||
func GetFocusModeConfig(mode string) FocusModeConfig {
|
||||
fm := FocusMode(strings.ToLower(mode))
|
||||
if cfg, ok := FocusModeConfigs[fm]; ok {
|
||||
return cfg
|
||||
}
|
||||
return FocusModeConfigs[FocusModeAll]
|
||||
}
|
||||
|
||||
func DetectFocusMode(query string) FocusMode {
|
||||
queryLower := strings.ToLower(query)
|
||||
|
||||
academicKeywords := []string{
|
||||
"research", "paper", "study", "journal", "scientific", "academic",
|
||||
"peer-reviewed", "citation", "исследование", "научн", "статья",
|
||||
"публикация", "диссертация",
|
||||
}
|
||||
for _, kw := range academicKeywords {
|
||||
if strings.Contains(queryLower, kw) {
|
||||
return FocusModeAcademic
|
||||
}
|
||||
}
|
||||
|
||||
codeKeywords := []string{
|
||||
"code", "programming", "function", "error", "bug", "api",
|
||||
"library", "framework", "syntax", "compile", "debug",
|
||||
"код", "программ", "функция", "ошибка", "библиотека",
|
||||
"golang", "python", "javascript", "typescript", "react", "vue",
|
||||
"docker", "kubernetes", "sql", "database", "git",
|
||||
}
|
||||
for _, kw := range codeKeywords {
|
||||
if strings.Contains(queryLower, kw) {
|
||||
return FocusModeCode
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(queryLower, "youtube") ||
|
||||
strings.Contains(queryLower, "video tutorial") ||
|
||||
strings.Contains(queryLower, "видео") {
|
||||
return FocusModeYouTube
|
||||
}
|
||||
|
||||
if strings.Contains(queryLower, "reddit") ||
|
||||
strings.Contains(queryLower, "subreddit") ||
|
||||
strings.Contains(queryLower, "/r/") {
|
||||
return FocusModeReddit
|
||||
}
|
||||
|
||||
mathKeywords := []string{
|
||||
"calculate", "solve", "equation", "integral", "derivative",
|
||||
"formula", "theorem", "proof", "вычисл", "решить", "уравнение",
|
||||
"интеграл", "производная", "формула", "теорема",
|
||||
}
|
||||
for _, kw := range mathKeywords {
|
||||
if strings.Contains(queryLower, kw) {
|
||||
return FocusModeMath
|
||||
}
|
||||
}
|
||||
|
||||
financeKeywords := []string{
|
||||
"stock", "market", "invest", "price", "trading", "finance",
|
||||
"акци", "рынок", "инвест", "биржа", "котировк", "финанс",
|
||||
"etf", "dividend", "portfolio",
|
||||
}
|
||||
for _, kw := range financeKeywords {
|
||||
if strings.Contains(queryLower, kw) {
|
||||
return FocusModeFinance
|
||||
}
|
||||
}
|
||||
|
||||
newsKeywords := []string{
|
||||
"news", "today", "latest", "breaking", "current events",
|
||||
"новост", "сегодня", "последн", "актуальн",
|
||||
}
|
||||
for _, kw := range newsKeywords {
|
||||
if strings.Contains(queryLower, kw) {
|
||||
return FocusModeNews
|
||||
}
|
||||
}
|
||||
|
||||
return FocusModeAll
|
||||
}
|
||||
|
||||
func (f FocusMode) GetSearchEngines() []string {
|
||||
if cfg, ok := FocusModeConfigs[f]; ok {
|
||||
return cfg.Engines
|
||||
}
|
||||
return FocusModeConfigs[FocusModeAll].Engines
|
||||
}
|
||||
|
||||
func (f FocusMode) GetSystemPrompt() string {
|
||||
if cfg, ok := FocusModeConfigs[f]; ok {
|
||||
return cfg.SystemPrompt
|
||||
}
|
||||
return FocusModeConfigs[FocusModeAll].SystemPrompt
|
||||
}
|
||||
|
||||
func (f FocusMode) GetMaxSources() int {
|
||||
if cfg, ok := FocusModeConfigs[f]; ok {
|
||||
return cfg.MaxSources
|
||||
}
|
||||
return 15
|
||||
}
|
||||
|
||||
func (f FocusMode) RequiresCitation() bool {
|
||||
if cfg, ok := FocusModeConfigs[f]; ok {
|
||||
return cfg.RequiresCitation
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (f FocusMode) AllowsScraping() bool {
|
||||
if cfg, ok := FocusModeConfigs[f]; ok {
|
||||
return cfg.AllowScraping
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func EnhanceQueryForFocusMode(query string, mode FocusMode) string {
|
||||
cfg := FocusModeConfigs[mode]
|
||||
if cfg.SearchQueryPrefix != "" {
|
||||
return cfg.SearchQueryPrefix + " " + query
|
||||
}
|
||||
return query
|
||||
}
|
||||
950
backend/internal/agent/orchestrator.go
Normal file
950
backend/internal/agent/orchestrator.go
Normal file
@@ -0,0 +1,950 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gooseek/backend/internal/llm"
|
||||
"github.com/gooseek/backend/internal/prompts"
|
||||
"github.com/gooseek/backend/internal/search"
|
||||
"github.com/gooseek/backend/internal/session"
|
||||
"github.com/gooseek/backend/internal/types"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// Mode selects the latency/quality trade-off for one orchestrator run.
type Mode string

const (
	ModeSpeed    Mode = "speed"    // heuristic classification, smaller context, fastest path
	ModeBalanced Mode = "balanced" // full pipeline with default settings
	ModeQuality  Mode = "quality"  // full pipeline; enables clarifying questions and deep research
)
|
||||
|
||||
// OrchestratorConfig carries the dependencies and per-request settings for
// a single orchestrator run.
type OrchestratorConfig struct {
	LLM          llm.Client            // chat-completion client used for all LLM calls
	SearchClient *search.SearXNGClient // web/media search backend
	Mode         Mode                  // speed/balanced/quality pipeline selector
	FocusMode    FocusMode             // topical focus; auto-detected from the query when empty

	Sources []string // caller-restricted source list forwarded to research
	FileIDs []string // IDs of user-uploaded files attached to the request

	// Optional context sections merged into the writer prompt.
	FileContext        string // extracted text of uploaded files
	CollectionID       string
	CollectionContext  string // context derived from the selected collection
	SystemInstructions string // user/system-level custom instructions
	Locale             string // UI locale (e.g. "ru"); influences prompt language
	MemoryContext      string // prior-conversation memory
	UserMemory         string // persistent user preferences
	AnswerMode         string // writer answer style selector

	ResponsePrefs *ResponsePrefs // optional formatting preferences
	LearningMode  bool

	// Feature flags.
	EnableDeepResearch bool // with ModeQuality, routes to the deep-research pipeline
	EnableClarifying   bool // with ModeQuality, may return clarifying questions instead of answering

	// Downstream service endpoints (empty to disable the integration).
	DiscoverSvcURL   string // discover-svc, used for pre-generated article digests
	Crawl4AIURL      string // Crawl4AI scraping service
	CollectionSvcURL string
	FileSvcURL       string
}
|
||||
|
||||
// DigestResponse is the discover-svc payload for a pre-generated article
// digest; used to short-circuit "Summary: <url>" requests.
type DigestResponse struct {
	SummaryRu    string           `json:"summaryRu"`    // Russian-language summary text
	Citations    []DigestCitation `json:"citations"`    // sources cited by the summary
	FollowUp     []string         `json:"followUp"`     // suggested follow-up questions
	SourcesCount int              `json:"sourcesCount"` // total sources considered
	ClusterTitle string           `json:"clusterTitle"` // title of the story cluster
}
|
||||
|
||||
// DigestCitation is one source reference inside a DigestResponse.
type DigestCitation struct {
	Index  int    `json:"index"`  // 1-based citation index used in the summary
	URL    string `json:"url"`    // source URL
	Title  string `json:"title"`  // source page title
	Domain string `json:"domain"` // source domain, for display
}
|
||||
|
||||
// PreScrapedArticle holds article text fetched ahead of the main pipeline
// (via Crawl4AI or a direct HTTP fetch) for "Summary: <url>" requests.
type PreScrapedArticle struct {
	Title   string // page title, may be empty
	Content string // plain-text/markdown body, capped by the scrapers at 15k chars
	URL     string // the article's URL
}
|
||||
|
||||
// ResponsePrefs captures optional user preferences for answer rendering.
type ResponsePrefs struct {
	Format string `json:"format,omitempty"` // e.g. bullet list vs prose
	Length string `json:"length,omitempty"` // desired answer length
	Tone   string `json:"tone,omitempty"`   // desired tone of voice
}
|
||||
|
||||
// OrchestratorInput bundles everything RunOrchestrator needs for one turn.
type OrchestratorInput struct {
	ChatHistory []llm.Message      // prior conversation, oldest first
	FollowUp    string             // the user's current message
	Config      OrchestratorConfig // dependencies and per-request settings
}
|
||||
|
||||
// RunOrchestrator is the entry point for answering one user turn. It routes
// to deep research (quality mode + flag), the fast speed-mode pipeline, or
// the full pipeline, emitting all output through sess.
func RunOrchestrator(ctx context.Context, sess *session.Session, input OrchestratorInput) error {
	detectedLang := detectLanguage(input.FollowUp)
	// "Summary: <url>" is a magic prefix requesting an article summary
	// rather than a normal search answer.
	isArticleSummary := strings.HasPrefix(strings.TrimSpace(input.FollowUp), "Summary: ")

	// Fall back to keyword detection when the caller did not pin a focus mode.
	if input.Config.FocusMode == "" {
		input.Config.FocusMode = DetectFocusMode(input.FollowUp)
	}

	if input.Config.EnableDeepResearch && input.Config.Mode == ModeQuality {
		return runDeepResearchMode(ctx, sess, input, detectedLang)
	}

	// Speed mode skips LLM classification entirely; article summaries
	// always need the full pipeline (digest/scrape handling lives there).
	if input.Config.Mode == ModeSpeed && !isArticleSummary {
		return runSpeedMode(ctx, sess, input, detectedLang)
	}

	return runFullMode(ctx, sess, input, detectedLang, isArticleSummary)
}
|
||||
|
||||
// runDeepResearchMode delegates the whole turn to the DeepResearcher with
// generous limits, then emits its sources and follow-up questions.
// The lang parameter is currently unused.
func runDeepResearchMode(ctx context.Context, sess *session.Session, input OrchestratorInput, lang string) error {
	sess.EmitBlock(types.NewResearchBlock(uuid.New().String()))

	researcher := NewDeepResearcher(DeepResearchConfig{
		LLM:              input.Config.LLM,
		SearchClient:     input.Config.SearchClient,
		FocusMode:        input.Config.FocusMode,
		Locale:           input.Config.Locale,
		MaxSearchQueries: 30,
		MaxSources:       100,
		MaxIterations:    5,
		Timeout:          5 * time.Minute,
	}, sess)

	result, err := researcher.Research(ctx, input.FollowUp)
	if err != nil {
		sess.EmitError(err)
		return err
	}

	sess.EmitBlock(types.NewSourceBlock(uuid.New().String(), result.Sources))

	if len(result.FollowUpQueries) > 0 {
		sess.EmitBlock(types.NewWidgetBlock(uuid.New().String(), "related_questions", map[string]interface{}{
			"questions": result.FollowUpQueries,
		}))
	}

	sess.EmitResearchComplete()
	sess.EmitEnd()

	return nil
}
|
||||
|
||||
func generateClarifyingQuestions(ctx context.Context, llmClient llm.Client, query string) ([]string, error) {
|
||||
prompt := fmt.Sprintf(`Analyze this query and determine if clarifying questions would help provide a better answer.
|
||||
|
||||
Query: %s
|
||||
|
||||
If the query is:
|
||||
- Clear and specific → respond with "CLEAR"
|
||||
- Ambiguous or could benefit from clarification → provide 2-3 short clarifying questions
|
||||
|
||||
Format:
|
||||
CLEAR
|
||||
or
|
||||
QUESTION: [question 1]
|
||||
QUESTION: [question 2]
|
||||
QUESTION: [question 3]`, query)
|
||||
|
||||
result, err := llmClient.GenerateText(ctx, llm.StreamRequest{
|
||||
Messages: []llm.Message{{Role: "user", Content: prompt}},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if strings.Contains(strings.ToUpper(result), "CLEAR") {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var questions []string
|
||||
for _, line := range strings.Split(result, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, "QUESTION:") {
|
||||
q := strings.TrimSpace(strings.TrimPrefix(line, "QUESTION:"))
|
||||
if q != "" {
|
||||
questions = append(questions, q)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return questions, nil
|
||||
}
|
||||
|
||||
// generateRelatedQuestions asks the LLM for up to 4 follow-up questions
// based on the query and (a truncated view of) the answer. Best-effort:
// returns nil on any LLM error.
func generateRelatedQuestions(ctx context.Context, llmClient llm.Client, query, answer string, locale string) []string {
	langInstruction := ""
	if locale == "ru" {
		langInstruction = "Generate questions in Russian."
	}

	prompt := fmt.Sprintf(`Based on this query and answer, generate 3-4 related follow-up questions the user might want to explore.

Query: %s

Answer summary: %s

%s
Format: One question per line, no numbering or bullets.`, query, truncateForPrompt(answer, 500), langInstruction)

	result, err := llmClient.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: "user", Content: prompt}},
	})
	if err != nil {
		return nil
	}

	var questions []string
	for _, line := range strings.Split(result, "\n") {
		line = strings.TrimSpace(line)
		// Keep only plausible questions: non-trivial length and containing
		// a question mark; then strip any numbering/bullet prefix the model
		// added despite the format instruction.
		if line != "" && len(line) > 10 && strings.Contains(line, "?") {
			line = strings.TrimLeft(line, "0123456789.-•* ")
			if line != "" {
				questions = append(questions, line)
			}
		}
	}

	// Cap at 4 questions regardless of how many the model produced.
	if len(questions) > 4 {
		questions = questions[:4]
	}

	return questions
}
|
||||
|
||||
// truncateForPrompt caps s at roughly maxLen bytes for inclusion in an LLM
// prompt, appending "..." when trimmed. The cut is pulled back to a UTF-8
// rune boundary so multi-byte text (e.g. Cyrillic answers) never yields an
// invalid byte sequence mid-rune.
func truncateForPrompt(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	cut := maxLen
	// Bytes of the form 0b10xxxxxx are UTF-8 continuation bytes; back up
	// until the cut lands on the first byte of a rune.
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "..."
}
|
||||
|
||||
func buildEnhancedContext(input OrchestratorInput) string {
|
||||
var ctx strings.Builder
|
||||
|
||||
if input.Config.UserMemory != "" {
|
||||
ctx.WriteString("## User Preferences\n")
|
||||
ctx.WriteString(input.Config.UserMemory)
|
||||
ctx.WriteString("\n\n")
|
||||
}
|
||||
|
||||
if input.Config.CollectionContext != "" {
|
||||
ctx.WriteString("## Collection Context\n")
|
||||
ctx.WriteString(input.Config.CollectionContext)
|
||||
ctx.WriteString("\n\n")
|
||||
}
|
||||
|
||||
if input.Config.FileContext != "" {
|
||||
ctx.WriteString("## Uploaded Files Content\n")
|
||||
ctx.WriteString(input.Config.FileContext)
|
||||
ctx.WriteString("\n\n")
|
||||
}
|
||||
|
||||
if input.Config.MemoryContext != "" {
|
||||
ctx.WriteString("## Previous Context\n")
|
||||
ctx.WriteString(input.Config.MemoryContext)
|
||||
ctx.WriteString("\n\n")
|
||||
}
|
||||
|
||||
return ctx.String()
|
||||
}
|
||||
|
||||
// fetchPreGeneratedDigest asks discover-svc for a cached digest of
// articleURL. Soft failures (no service configured, non-200 status, or an
// empty/incomplete digest) return (nil, nil) so the caller falls back to
// live scraping; only request/decoding failures return an error.
func fetchPreGeneratedDigest(ctx context.Context, discoverURL, articleURL string) (*DigestResponse, error) {
	if discoverURL == "" {
		return nil, nil
	}

	reqURL := fmt.Sprintf("%s/api/v1/discover/digest?url=%s",
		strings.TrimSuffix(discoverURL, "/"),
		url.QueryEscape(articleURL))

	req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
	if err != nil {
		return nil, err
	}

	// Short timeout: this is an opportunistic cache lookup on the hot path.
	client := &http.Client{Timeout: 3 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, nil
	}

	var digest DigestResponse
	if err := json.NewDecoder(resp.Body).Decode(&digest); err != nil {
		return nil, err
	}

	// Only a digest with both a summary and citations is usable downstream.
	if digest.SummaryRu != "" && len(digest.Citations) > 0 {
		return &digest, nil
	}

	return nil, nil
}
|
||||
|
||||
func preScrapeArticleURL(ctx context.Context, crawl4aiURL, articleURL string) (*PreScrapedArticle, error) {
|
||||
if crawl4aiURL != "" {
|
||||
article, err := scrapeWithCrawl4AI(ctx, crawl4aiURL, articleURL)
|
||||
if err == nil && article != nil {
|
||||
return article, nil
|
||||
}
|
||||
}
|
||||
|
||||
return scrapeDirectly(ctx, articleURL)
|
||||
}
|
||||
|
||||
func scrapeWithCrawl4AI(ctx context.Context, crawl4aiURL, articleURL string) (*PreScrapedArticle, error) {
|
||||
reqBody := fmt.Sprintf(`{
|
||||
"urls": ["%s"],
|
||||
"crawler_config": {
|
||||
"type": "CrawlerRunConfig",
|
||||
"params": {
|
||||
"cache_mode": "default",
|
||||
"page_timeout": 20000
|
||||
}
|
||||
}
|
||||
}`, articleURL)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", crawl4aiURL+"/crawl", strings.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 25 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("Crawl4AI returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
markdown := extractMarkdownFromCrawl4AI(string(body))
|
||||
title := extractTitleFromCrawl4AI(string(body))
|
||||
|
||||
if len(markdown) > 100 {
|
||||
content := markdown
|
||||
if len(content) > 15000 {
|
||||
content = content[:15000]
|
||||
}
|
||||
return &PreScrapedArticle{
|
||||
Title: title,
|
||||
Content: content,
|
||||
URL: articleURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("insufficient content from Crawl4AI")
|
||||
}
|
||||
|
||||
// scrapeDirectly fetches articleURL with a plain HTTP GET and reduces the
// HTML to title + plain text via the regex-based extractors. It errors on
// non-200 responses or when fewer than 100 chars of text were extracted;
// content is capped at 15k chars.
func scrapeDirectly(ctx context.Context, articleURL string) (*PreScrapedArticle, error) {
	req, err := http.NewRequestWithContext(ctx, "GET", articleURL, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("User-Agent", "GooSeek-Agent/1.0")
	req.Header.Set("Accept", "text/html")

	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	html := string(body)
	title := extractHTMLTitle(html)
	content := extractTextContent(html)

	// Too little text usually means a paywall, JS-only page, or error page.
	if len(content) < 100 {
		return nil, fmt.Errorf("insufficient content")
	}

	if len(content) > 15000 {
		content = content[:15000]
	}

	return &PreScrapedArticle{
		Title:   title,
		Content: content,
		URL:     articleURL,
	}, nil
}
|
||||
|
||||
// Pre-compiled patterns for the naive HTML-to-text fallback scraper.
var (
	titleRegex  = regexp.MustCompile(`<title[^>]*>([^<]+)</title>`)   // captures the first <title> text
	scriptRegex = regexp.MustCompile(`(?s)<script[^>]*>.*?</script>`) // strips <script> blocks (dot matches newlines)
	styleRegex  = regexp.MustCompile(`(?s)<style[^>]*>.*?</style>`)   // strips <style> blocks
	tagRegex    = regexp.MustCompile(`<[^>]+>`)                       // strips any remaining tags
	spaceRegex  = regexp.MustCompile(`\s+`)                           // collapses whitespace runs
)
|
||||
|
||||
func extractHTMLTitle(html string) string {
|
||||
matches := titleRegex.FindStringSubmatch(html)
|
||||
if len(matches) > 1 {
|
||||
return strings.TrimSpace(matches[1])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractTextContent(html string) string {
|
||||
bodyStart := strings.Index(strings.ToLower(html), "<body")
|
||||
bodyEnd := strings.Index(strings.ToLower(html), "</body>")
|
||||
|
||||
if bodyStart != -1 && bodyEnd != -1 && bodyEnd > bodyStart {
|
||||
html = html[bodyStart:bodyEnd]
|
||||
}
|
||||
|
||||
html = scriptRegex.ReplaceAllString(html, "")
|
||||
html = styleRegex.ReplaceAllString(html, "")
|
||||
html = tagRegex.ReplaceAllString(html, " ")
|
||||
html = spaceRegex.ReplaceAllString(html, " ")
|
||||
|
||||
return strings.TrimSpace(html)
|
||||
}
|
||||
|
||||
// jsonStringByKey decodes raw as JSON and depth-first searches it for the
// first string value stored under key, descending through objects and
// arrays. It returns "" when raw is not valid JSON or the key is absent.
// (Within a single object, lookup of key itself is checked before
// descending into children; sibling-map iteration order is unspecified,
// which is fine for the Crawl4AI payloads where each key occurs once.)
func jsonStringByKey(raw, key string) string {
	var root interface{}
	if err := json.Unmarshal([]byte(raw), &root); err != nil {
		return ""
	}

	var walk func(v interface{}) (string, bool)
	walk = func(v interface{}) (string, bool) {
		switch n := v.(type) {
		case map[string]interface{}:
			if s, ok := n[key].(string); ok {
				return s, true
			}
			for _, child := range n {
				if s, ok := walk(child); ok {
					return s, true
				}
			}
		case []interface{}:
			for _, child := range n {
				if s, ok := walk(child); ok {
					return s, true
				}
			}
		}
		return "", false
	}

	s, _ := walk(root)
	return s
}

// extractMarkdownFromCrawl4AI pulls the "raw_markdown" string out of a
// Crawl4AI JSON response. The previous hand-rolled scan stopped at the
// first '"' after the key, so any markdown containing an escaped quote was
// truncated and escape sequences were never decoded; a real JSON decode
// handles both.
func extractMarkdownFromCrawl4AI(response string) string {
	return jsonStringByKey(response, "raw_markdown")
}

// extractTitleFromCrawl4AI pulls the "title" string out of a Crawl4AI JSON
// response; "" when absent or the response is not valid JSON.
func extractTitleFromCrawl4AI(response string) string {
	return jsonStringByKey(response, "title")
}
|
||||
|
||||
// runSpeedMode answers with minimal latency: heuristic (non-LLM)
// classification, web and media search in parallel, BM25 reranking, and a
// 2048-token streamed answer. All search failures are tolerated — the
// writer simply gets less (or no) context.
func runSpeedMode(ctx context.Context, sess *session.Session, input OrchestratorInput, detectedLang string) error {
	classification := fastClassify(input.FollowUp, input.ChatHistory)
	searchQuery := classification.StandaloneFollowUp
	if searchQuery == "" {
		searchQuery = input.FollowUp
	}
	queries := generateSearchQueries(searchQuery)

	researchBlockID := uuid.New().String()
	sess.EmitBlock(types.NewResearchBlock(researchBlockID))

	var searchResults []types.Chunk
	var mediaResult *search.MediaSearchResult

	// Web search and media search run concurrently; both goroutines swallow
	// their errors (returning nil) so one failing leg never cancels the other.
	g, gctx := errgroup.WithContext(ctx)

	g.Go(func() error {
		results, err := parallelSearch(gctx, input.Config.SearchClient, queries)
		if err != nil {
			return nil
		}
		searchResults = results
		return nil
	})

	g.Go(func() error {
		result, err := input.Config.SearchClient.SearchMedia(gctx, searchQuery, &search.MediaSearchOptions{
			MaxImages: 6,
			MaxVideos: 4,
		})
		if err != nil {
			return nil
		}
		mediaResult = result
		return nil
	})

	// Both goroutines always return nil, so the error is ignorable.
	_ = g.Wait()

	if len(searchResults) > 0 {
		sess.EmitBlock(types.NewSourceBlock(uuid.New().String(), searchResults))
	}

	if mediaResult != nil {
		if len(mediaResult.Images) > 0 {
			sess.EmitBlock(types.NewWidgetBlock(uuid.New().String(), "image_gallery", map[string]interface{}{
				"images": mediaResult.Images,
				"layout": "carousel",
			}))
		}
		if len(mediaResult.Videos) > 0 {
			sess.EmitBlock(types.NewWidgetBlock(uuid.New().String(), "videos", map[string]interface{}{
				"items": mediaResult.Videos,
				"title": "",
			}))
		}
	}

	sess.EmitResearchComplete()

	// Rerank with BM25, sizing top-K adaptively to result count and query
	// complexity for the "speed" profile.
	queryComplexity := search.EstimateQueryComplexity(searchQuery)
	adaptiveTopK := search.ComputeAdaptiveTopK(len(searchResults), queryComplexity, "speed")
	rankedResults := search.RerankBM25(searchResults, searchQuery, adaptiveTopK)

	// Small context budget: at most 15 chunks, 250 chars each.
	finalContext := buildContext(rankedResults, 15, 250)

	writerPrompt := prompts.GetWriterPrompt(prompts.WriterConfig{
		Context:            finalContext,
		SystemInstructions: input.Config.SystemInstructions,
		Mode:               string(input.Config.Mode),
		Locale:             input.Config.Locale,
		MemoryContext:      input.Config.MemoryContext,
		AnswerMode:         input.Config.AnswerMode,
		DetectedLanguage:   detectedLang,
		IsArticleSummary:   false,
	})

	messages := []llm.Message{
		{Role: llm.RoleSystem, Content: writerPrompt},
	}
	messages = append(messages, input.ChatHistory...)
	messages = append(messages, llm.Message{Role: llm.RoleUser, Content: input.FollowUp})

	// Speed mode caps the answer at 2048 tokens.
	return streamResponse(ctx, sess, input.Config.LLM, messages, 2048, input.FollowUp, input.Config.Locale)
}
|
||||
|
||||
// runFullMode is the standard pipeline: optional clarifying questions
// (quality mode), article-summary fast path via a pre-generated digest,
// LLM classification, parallel research + media search, ranking, and a
// streamed 4096-token answer.
func runFullMode(ctx context.Context, sess *session.Session, input OrchestratorInput, detectedLang string, isArticleSummary bool) error {
	// In quality mode, an ambiguous query may short-circuit the turn into a
	// clarifying-questions widget; the user answers and re-asks.
	if input.Config.EnableClarifying && !isArticleSummary && input.Config.Mode == ModeQuality {
		clarifying, err := generateClarifyingQuestions(ctx, input.Config.LLM, input.FollowUp)
		if err == nil && len(clarifying) > 0 {
			sess.EmitBlock(types.NewWidgetBlock(uuid.New().String(), "clarifying", map[string]interface{}{
				"questions": clarifying,
				"query":     input.FollowUp,
			}))
			return nil
		}
	}

	// Prepend user/collection/file context to the memory section of the prompt.
	enhancedContext := buildEnhancedContext(input)
	if enhancedContext != "" {
		input.Config.MemoryContext = enhancedContext + input.Config.MemoryContext
	}

	var preScrapedArticle *PreScrapedArticle
	var articleURL string

	if isArticleSummary {
		articleURL = strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(input.FollowUp), "Summary: "))

		// Fetch the cached digest (fast) and scrape the article (slow) in
		// parallel; both channels are buffered so the goroutines never block.
		digestCtx, digestCancel := context.WithTimeout(ctx, 3*time.Second)
		scrapeCtx, scrapeCancel := context.WithTimeout(ctx, 25*time.Second)

		digestCh := make(chan *DigestResponse, 1)
		scrapeCh := make(chan *PreScrapedArticle, 1)

		go func() {
			defer digestCancel()
			digest, _ := fetchPreGeneratedDigest(digestCtx, input.Config.DiscoverSvcURL, articleURL)
			digestCh <- digest
		}()

		go func() {
			defer scrapeCancel()
			article, _ := preScrapeArticleURL(scrapeCtx, input.Config.Crawl4AIURL, articleURL)
			scrapeCh <- article
		}()

		digest := <-digestCh
		preScrapedArticle = <-scrapeCh

		// A usable digest ends the turn immediately: emit its citations as
		// sources, the Russian summary (plus follow-ups as quote lines), done.
		if digest != nil {
			chunks := make([]types.Chunk, len(digest.Citations))
			for i, c := range digest.Citations {
				chunks[i] = types.Chunk{
					Content: c.Title,
					Metadata: map[string]string{
						"url":    c.URL,
						"title":  c.Title,
						"domain": c.Domain,
					},
				}
			}
			sess.EmitBlock(types.NewSourceBlock(uuid.New().String(), chunks))
			sess.EmitResearchComplete()

			summaryText := digest.SummaryRu
			if len(digest.FollowUp) > 0 {
				summaryText += "\n\n---\n"
				for _, q := range digest.FollowUp {
					summaryText += "> " + q + "\n"
				}
			}
			sess.EmitBlock(types.NewTextBlock(uuid.New().String(), summaryText))
			sess.EmitEnd()
			return nil
		}
	}

	// LLM classification; on failure fall back to treating the raw message
	// as the standalone query with search enabled.
	classification, err := classify(ctx, input.Config.LLM, input.FollowUp, input.ChatHistory, input.Config.Locale, detectedLang)
	if err != nil {
		classification = &ClassificationResult{
			StandaloneFollowUp: input.FollowUp,
			SkipSearch:         false,
		}
	}

	// Article summaries always need search, whatever the classifier said.
	if isArticleSummary && classification.SkipSearch {
		classification.SkipSearch = false
	}

	g, gctx := errgroup.WithContext(ctx)

	var searchResults []types.Chunk
	var mediaResult *search.MediaSearchResult

	mediaQuery := classification.StandaloneFollowUp
	if mediaQuery == "" {
		mediaQuery = input.FollowUp
	}

	// For summaries where the scrape recovered a title, enrich both the
	// follow-up seen by research and the standalone query with that title.
	effectiveFollowUp := input.FollowUp
	if isArticleSummary && preScrapedArticle != nil && preScrapedArticle.Title != "" {
		effectiveFollowUp = fmt.Sprintf("Summary: %s\nArticle title: %s", preScrapedArticle.URL, preScrapedArticle.Title)
		if classification.StandaloneFollowUp != "" {
			classification.StandaloneFollowUp = preScrapedArticle.Title + " " + classification.StandaloneFollowUp
		} else {
			classification.StandaloneFollowUp = preScrapedArticle.Title
		}
	}

	// Research and media search run concurrently; both swallow errors so a
	// failing leg degrades the answer instead of aborting it.
	if !classification.SkipSearch {
		g.Go(func() error {
			results, err := research(gctx, sess, input.Config.LLM, input.Config.SearchClient, ResearchInput{
				ChatHistory:      input.ChatHistory,
				FollowUp:         effectiveFollowUp,
				Classification:   classification,
				Mode:             input.Config.Mode,
				Sources:          input.Config.Sources,
				Locale:           input.Config.Locale,
				DetectedLang:     detectedLang,
				IsArticleSummary: isArticleSummary,
			})
			if err != nil {
				return nil
			}
			searchResults = results
			return nil
		})
	}

	if !isArticleSummary {
		g.Go(func() error {
			result, err := input.Config.SearchClient.SearchMedia(gctx, mediaQuery, &search.MediaSearchOptions{
				MaxImages: 8,
				MaxVideos: 6,
			})
			if err != nil {
				return nil
			}
			mediaResult = result
			return nil
		})
	}

	// Goroutines always return nil, so the group error is ignorable.
	_ = g.Wait()

	// Make sure the scraped article itself leads the source list unless
	// research already surfaced the same URL.
	if isArticleSummary && preScrapedArticle != nil {
		alreadyHasURL := false
		for _, r := range searchResults {
			if strings.Contains(r.Metadata["url"], preScrapedArticle.URL) {
				alreadyHasURL = true
				break
			}
		}
		if !alreadyHasURL {
			prependChunk := types.Chunk{
				Content: preScrapedArticle.Content,
				Metadata: map[string]string{
					"url":   preScrapedArticle.URL,
					"title": preScrapedArticle.Title,
				},
			}
			searchResults = append([]types.Chunk{prependChunk}, searchResults...)
		}
	}

	if len(searchResults) > 0 {
		sess.EmitBlock(types.NewSourceBlock(uuid.New().String(), searchResults))
	}

	if mediaResult != nil {
		if len(mediaResult.Images) > 0 {
			sess.EmitBlock(types.NewWidgetBlock(uuid.New().String(), "image_gallery", map[string]interface{}{
				"images": mediaResult.Images,
				"layout": "carousel",
			}))
		}
		if len(mediaResult.Videos) > 0 {
			sess.EmitBlock(types.NewWidgetBlock(uuid.New().String(), "videos", map[string]interface{}{
				"items": mediaResult.Videos,
				"title": "",
			}))
		}
	}

	sess.EmitResearchComplete()

	// Summaries get more chunks and much longer per-chunk content (the
	// article body itself is a chunk).
	maxResults := 25
	maxContent := 320
	if isArticleSummary {
		maxResults = 30
		maxContent = 2000
	}

	rankedResults := rankByRelevance(searchResults, input.FollowUp)
	if len(rankedResults) > maxResults {
		rankedResults = rankedResults[:maxResults]
	}

	finalContext := buildContext(rankedResults, maxResults, maxContent)

	writerPrompt := prompts.GetWriterPrompt(prompts.WriterConfig{
		Context:            finalContext,
		SystemInstructions: input.Config.SystemInstructions,
		Mode:               string(input.Config.Mode),
		Locale:             input.Config.Locale,
		MemoryContext:      input.Config.MemoryContext,
		AnswerMode:         input.Config.AnswerMode,
		DetectedLanguage:   detectedLang,
		IsArticleSummary:   isArticleSummary,
	})

	messages := []llm.Message{
		{Role: llm.RoleSystem, Content: writerPrompt},
	}
	messages = append(messages, input.ChatHistory...)
	messages = append(messages, llm.Message{Role: llm.RoleUser, Content: input.FollowUp})

	maxTokens := 4096
	return streamResponse(ctx, sess, input.Config.LLM, messages, maxTokens, input.FollowUp, input.Config.Locale)
}
|
||||
|
||||
// streamResponse streams the writer LLM's answer into a single text block
// (created lazily on the first non-empty chunk), replaces the block's data
// with the full accumulated text at the end, and kicks off related-question
// generation in the background before emitting the end event.
func streamResponse(ctx context.Context, sess *session.Session, client llm.Client, messages []llm.Message, maxTokens int, query string, locale string) error {
	stream, err := client.StreamText(ctx, llm.StreamRequest{
		Messages: messages,
		Options:  llm.StreamOptions{MaxTokens: maxTokens},
	})
	if err != nil {
		return err
	}

	var responseBlockID string
	var accumulatedText string

	for chunk := range stream {
		// Skip leading empty chunks so the block is created with real text.
		if chunk.ContentChunk == "" && responseBlockID == "" {
			continue
		}

		if responseBlockID == "" {
			responseBlockID = uuid.New().String()
			accumulatedText = chunk.ContentChunk
			sess.EmitBlock(types.NewTextBlock(responseBlockID, accumulatedText))
		} else if chunk.ContentChunk != "" {
			accumulatedText += chunk.ContentChunk
			sess.EmitTextChunk(responseBlockID, chunk.ContentChunk)
		}
	}

	// Final patch replaces the block content with the complete text so
	// clients that missed chunks still converge.
	if responseBlockID != "" {
		sess.UpdateBlock(responseBlockID, []session.Patch{
			{Op: "replace", Path: "/data", Value: accumulatedText},
		})
	}

	// NOTE(review): this goroutine may emit the related_questions widget
	// AFTER EmitEnd below fires (it uses context.Background(), detached
	// from ctx). This assumes the session tolerates blocks arriving after
	// the end event — confirm against session semantics.
	go func() {
		relatedCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()

		related := generateRelatedQuestions(relatedCtx, client, query, accumulatedText, locale)
		if len(related) > 0 {
			sess.EmitBlock(types.NewWidgetBlock(uuid.New().String(), "related_questions", map[string]interface{}{
				"questions": related,
			}))
		}
	}()

	sess.EmitEnd()
	return nil
}
|
||||
|
||||
// parallelSearch fires all queries concurrently against SearXNG and merges
// the results, de-duplicating by URL (first occurrence wins; merge order
// follows goroutine completion order). Individual query failures are
// dropped; the returned error is always nil.
func parallelSearch(ctx context.Context, client *search.SearXNGClient, queries []string) ([]types.Chunk, error) {
	results := make([]types.Chunk, 0)
	seen := make(map[string]bool)

	g, gctx := errgroup.WithContext(ctx)
	// Buffered to len(queries) so every goroutine can send without blocking.
	resultsCh := make(chan []types.SearchResult, len(queries))

	for _, q := range queries {
		query := q // capture per-iteration copy for the closure
		g.Go(func() error {
			resp, err := client.Search(gctx, query, &search.SearchOptions{
				Categories: []string{"general", "news"},
				PageNo:     1,
			})
			if err != nil {
				// Failed query contributes an empty batch instead of an error.
				resultsCh <- nil
				return nil
			}
			resultsCh <- resp.Results
			return nil
		})
	}

	// Close the channel once all searches finish so the merge loop ends.
	go func() {
		g.Wait()
		close(resultsCh)
	}()

	for batch := range resultsCh {
		for _, r := range batch {
			if r.URL != "" && !seen[r.URL] {
				seen[r.URL] = true
				results = append(results, r.ToChunk())
			}
		}
	}

	return results, nil
}
|
||||
|
||||
func buildContext(chunks []types.Chunk, maxResults, maxContentLen int) string {
|
||||
if len(chunks) > maxResults {
|
||||
chunks = chunks[:maxResults]
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("<search_results note=\"These are the search results and assistant can cite these\">\n")
|
||||
|
||||
for i, chunk := range chunks {
|
||||
content := chunk.Content
|
||||
if len(content) > maxContentLen {
|
||||
content = content[:maxContentLen] + "…"
|
||||
}
|
||||
title := chunk.Metadata["title"]
|
||||
sb.WriteString("<result index=")
|
||||
sb.WriteString(strings.ReplaceAll(title, "\"", "'"))
|
||||
sb.WriteString("\" index=\"")
|
||||
sb.WriteString(string(rune('0' + i + 1)))
|
||||
sb.WriteString("\">")
|
||||
sb.WriteString(content)
|
||||
sb.WriteString("</result>\n")
|
||||
}
|
||||
|
||||
sb.WriteString("</search_results>")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func rankByRelevance(chunks []types.Chunk, query string) []types.Chunk {
|
||||
if len(chunks) == 0 {
|
||||
return chunks
|
||||
}
|
||||
|
||||
terms := extractQueryTerms(query)
|
||||
if len(terms) == 0 {
|
||||
return chunks
|
||||
}
|
||||
|
||||
type scored struct {
|
||||
chunk types.Chunk
|
||||
score int
|
||||
}
|
||||
|
||||
scored_chunks := make([]scored, len(chunks))
|
||||
for i, chunk := range chunks {
|
||||
score := 0
|
||||
content := strings.ToLower(chunk.Content)
|
||||
title := strings.ToLower(chunk.Metadata["title"])
|
||||
|
||||
for term := range terms {
|
||||
if strings.Contains(title, term) {
|
||||
score += 3
|
||||
}
|
||||
if strings.Contains(content, term) {
|
||||
score += 1
|
||||
}
|
||||
}
|
||||
|
||||
scored_chunks[i] = scored{chunk: chunk, score: score}
|
||||
}
|
||||
|
||||
for i := 0; i < len(scored_chunks)-1; i++ {
|
||||
for j := i + 1; j < len(scored_chunks); j++ {
|
||||
if scored_chunks[j].score > scored_chunks[i].score {
|
||||
scored_chunks[i], scored_chunks[j] = scored_chunks[j], scored_chunks[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := make([]types.Chunk, len(scored_chunks))
|
||||
for i, s := range scored_chunks {
|
||||
result[i] = s.chunk
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// extractQueryTerms turns a free-text query into a set of lowercase search
// terms. A leading "summary: " marker is stripped; single-character tokens
// and anything starting with "http" (bare URLs) are excluded.
func extractQueryTerms(query string) map[string]bool {
	normalized := strings.TrimPrefix(strings.ToLower(query), "summary: ")

	set := make(map[string]bool)
	for _, token := range strings.Fields(normalized) {
		if len(token) < 2 || strings.HasPrefix(token, "http") {
			continue
		}
		set[token] = true
	}

	return set
}
|
||||
128
backend/internal/agent/researcher.go
Normal file
128
backend/internal/agent/researcher.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/gooseek/backend/internal/llm"
|
||||
"github.com/gooseek/backend/internal/search"
|
||||
"github.com/gooseek/backend/internal/session"
|
||||
"github.com/gooseek/backend/internal/types"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ResearchInput bundles everything the research step needs for one turn.
type ResearchInput struct {
	ChatHistory      []llm.Message         // prior conversation messages
	FollowUp         string                // raw user follow-up; search-query fallback
	Classification   *ClassificationResult // classifier output; StandaloneFollowUp is preferred as the query
	Mode             Mode                  // speed/balanced/quality — controls research iteration count
	Sources          []string              // user-selected source kinds; mapped via categoriesToSearch
	Locale           string                // NOTE(review): not used by research() as visible here — confirm
	DetectedLang     string                // NOTE(review): not used by research() as visible here — confirm
	IsArticleSummary bool                  // NOTE(review): not used by research() as visible here — confirm
}
|
||||
|
||||
// research runs the iterative web-search phase of an agent turn.
//
// It emits a "research" UI block on sess, then for up to maxIterations
// rounds (1 for speed, 3 for balanced, 10 for quality): generates search
// queries, reports them as a sub-step patch on the block, queries SearXNG
// for each, and accumulates result chunks de-duplicated by URL. Individual
// search errors are skipped best-effort. Returns every collected chunk;
// the error result is always nil in this implementation.
//
// NOTE(review): queries are regenerated from the same searchQuery every
// iteration, so later rounds repeat the same searches unless
// generateSearchQueries is non-deterministic — verify intent.
// NOTE(review): llmClient is accepted but not used in this body.
func research(
	ctx context.Context,
	sess *session.Session,
	llmClient llm.Client,
	searchClient *search.SearXNGClient,
	input ResearchInput,
) ([]types.Chunk, error) {
	// Iteration budget scales with the requested quality mode.
	maxIterations := 1
	switch input.Mode {
	case ModeBalanced:
		maxIterations = 3
	case ModeQuality:
		maxIterations = 10
	}

	// Surface a research block in the UI stream; sub-steps are patched in below.
	researchBlockID := uuid.New().String()
	sess.EmitBlock(types.NewResearchBlock(researchBlockID))

	allResults := make([]types.Chunk, 0)
	seenURLs := make(map[string]bool) // URL de-duplication across iterations

	// Prefer the classifier's self-contained reformulation of the follow-up.
	searchQuery := input.Classification.StandaloneFollowUp
	if searchQuery == "" {
		searchQuery = input.FollowUp
	}

	for i := 0; i < maxIterations; i++ {
		queries := generateSearchQueries(searchQuery)

		// Show the user which queries this round is running.
		sess.UpdateBlock(researchBlockID, []session.Patch{
			{
				Op:   "replace",
				Path: "/data/subSteps",
				Value: []types.ResearchSubStep{
					{
						ID:        uuid.New().String(),
						Type:      "searching",
						Searching: queries,
					},
				},
			},
		})

		for _, q := range queries {
			resp, err := searchClient.Search(ctx, q, &search.SearchOptions{
				Categories: categoriesToSearch(input.Sources),
				PageNo:     1,
			})
			if err != nil {
				// Best-effort: a failed engine query is silently skipped.
				continue
			}

			for _, r := range resp.Results {
				if r.URL != "" && !seenURLs[r.URL] {
					seenURLs[r.URL] = true
					allResults = append(allResults, r.ToChunk())
				}
			}
		}

		// Early-exit thresholds per mode (speed: one pass; balanced: 20
		// results; any mode: 50 results).
		if input.Mode == ModeSpeed {
			break
		}

		if len(allResults) >= 20 && input.Mode == ModeBalanced {
			break
		}

		if len(allResults) >= 50 {
			break
		}
	}

	return allResults, nil
}
|
||||
|
||||
// categoriesToSearch maps user-facing source kinds to SearXNG category
// names, preserving input order (and duplicates). With no sources the
// default is ["general", "news"]; unknown-only inputs fall back to
// ["general"].
func categoriesToSearch(sources []string) []string {
	if len(sources) == 0 {
		return []string{"general", "news"}
	}

	// Source kind -> SearXNG category.
	translation := map[string]string{
		"web":         "general",
		"discussions": "social media",
		"academic":    "science",
		"news":        "news",
		"images":      "images",
		"videos":      "videos",
	}

	categories := make([]string, 0, len(sources))
	for _, src := range sources {
		if cat, ok := translation[src]; ok {
			categories = append(categories, cat)
		}
	}

	if len(categories) == 0 {
		return []string{"general"}
	}

	return categories
}
|
||||
587
backend/internal/computer/browser/browser.go
Normal file
587
backend/internal/computer/browser/browser.go
Normal file
@@ -0,0 +1,587 @@
|
||||
package browser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// PlaywrightBrowser is an HTTP client for a remote Playwright automation
// server. It tracks the sessions it has opened; the sessions map is
// guarded by mu.
type PlaywrightBrowser struct {
	cmd       *exec.Cmd    // NOTE(review): never assigned in the code visible here — confirm intended use (local server spawn?)
	serverURL string       // base URL of the Playwright HTTP server
	client    *http.Client // used for every command; carries Config.DefaultTimeout
	sessions  map[string]*BrowserSession
	mu        sync.RWMutex // guards sessions
	config    Config
}

// Config holds construction-time settings for PlaywrightBrowser.
// Zero values are replaced with defaults in NewPlaywrightBrowser.
type Config struct {
	PlaywrightServerURL string        // default "http://localhost:3050"
	DefaultTimeout      time.Duration // default 30s; applied to the HTTP client
	Headless            bool
	UserAgent           string // fallback UA when a session does not supply one
	ProxyURL            string // fallback proxy when a session does not supply one
	ScreenshotsDir      string // default "/tmp/gooseek-screenshots"
	RecordingsDir       string // default "/tmp/gooseek-recordings"
}

// BrowserSession describes one remote browser context/page pair.
type BrowserSession struct {
	ID          string
	ContextID   string // remote context identifier returned by the server
	PageID      string // remote page identifier returned by the server
	CreatedAt   time.Time
	LastAction  time.Time
	Screenshots []string // local file paths of screenshots taken in this session
	Recordings  []string
	Closed      bool
}

// ActionRequest is a generic command envelope addressed to one session.
type ActionRequest struct {
	SessionID string                 `json:"sessionId"`
	Action    string                 `json:"action"`
	Params    map[string]interface{} `json:"params"`
}

// ActionResponse is the generic result of a browser action.
type ActionResponse struct {
	Success    bool        `json:"success"`
	Data       interface{} `json:"data,omitempty"`
	Screenshot string      `json:"screenshot,omitempty"` // base64-encoded image, when requested
	Error      string      `json:"error,omitempty"`
	PageTitle  string      `json:"pageTitle,omitempty"`
	PageURL    string      `json:"pageUrl,omitempty"`
}
|
||||
|
||||
func NewPlaywrightBrowser(cfg Config) *PlaywrightBrowser {
|
||||
if cfg.DefaultTimeout == 0 {
|
||||
cfg.DefaultTimeout = 30 * time.Second
|
||||
}
|
||||
if cfg.PlaywrightServerURL == "" {
|
||||
cfg.PlaywrightServerURL = "http://localhost:3050"
|
||||
}
|
||||
if cfg.ScreenshotsDir == "" {
|
||||
cfg.ScreenshotsDir = "/tmp/gooseek-screenshots"
|
||||
}
|
||||
if cfg.RecordingsDir == "" {
|
||||
cfg.RecordingsDir = "/tmp/gooseek-recordings"
|
||||
}
|
||||
|
||||
os.MkdirAll(cfg.ScreenshotsDir, 0755)
|
||||
os.MkdirAll(cfg.RecordingsDir, 0755)
|
||||
|
||||
return &PlaywrightBrowser{
|
||||
serverURL: cfg.PlaywrightServerURL,
|
||||
client: &http.Client{
|
||||
Timeout: cfg.DefaultTimeout,
|
||||
},
|
||||
sessions: make(map[string]*BrowserSession),
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// NewSession opens a fresh browser context+page on the remote server and
// registers it in b.sessions.
//
// Option precedence: per-call opts override the browser-level Config
// fallbacks for UserAgent and ProxyURL. When RecordVideo is set, the
// server is asked to write recordings into config.RecordingsDir.
//
// NOTE(review): params["headless"] is taken from b.config.Headless;
// opts.Headless is ignored here — confirm whether per-session headless
// control was intended.
func (b *PlaywrightBrowser) NewSession(ctx context.Context, opts SessionOptions) (*BrowserSession, error) {
	// The session ID is chosen client-side and echoed to the server.
	sessionID := uuid.New().String()

	params := map[string]interface{}{
		"headless":  b.config.Headless,
		"sessionId": sessionID,
	}

	if opts.Viewport != nil {
		params["viewport"] = opts.Viewport
	}
	if opts.UserAgent != "" {
		params["userAgent"] = opts.UserAgent
	} else if b.config.UserAgent != "" {
		params["userAgent"] = b.config.UserAgent
	}
	if opts.ProxyURL != "" {
		params["proxy"] = opts.ProxyURL
	} else if b.config.ProxyURL != "" {
		params["proxy"] = b.config.ProxyURL
	}
	if opts.RecordVideo {
		params["recordVideo"] = map[string]interface{}{
			"dir": b.config.RecordingsDir,
		}
	}

	resp, err := b.sendCommand(ctx, "browser.newContext", params)
	if err != nil {
		return nil, fmt.Errorf("failed to create browser context: %w", err)
	}

	// Missing/invalid IDs in the reply degrade to empty strings.
	contextID, _ := resp["contextId"].(string)
	pageID, _ := resp["pageId"].(string)

	session := &BrowserSession{
		ID:         sessionID,
		ContextID:  contextID,
		PageID:     pageID,
		CreatedAt:  time.Now(),
		LastAction: time.Now(),
	}

	b.mu.Lock()
	b.sessions[sessionID] = session
	b.mu.Unlock()

	return session, nil
}
|
||||
|
||||
func (b *PlaywrightBrowser) CloseSession(ctx context.Context, sessionID string) error {
|
||||
b.mu.Lock()
|
||||
session, ok := b.sessions[sessionID]
|
||||
if !ok {
|
||||
b.mu.Unlock()
|
||||
return errors.New("session not found")
|
||||
}
|
||||
session.Closed = true
|
||||
delete(b.sessions, sessionID)
|
||||
b.mu.Unlock()
|
||||
|
||||
_, err := b.sendCommand(ctx, "browser.closeContext", map[string]interface{}{
|
||||
"sessionId": sessionID,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *PlaywrightBrowser) Navigate(ctx context.Context, sessionID, url string, opts NavigateOptions) (*ActionResponse, error) {
|
||||
params := map[string]interface{}{
|
||||
"sessionId": sessionID,
|
||||
"url": url,
|
||||
}
|
||||
if opts.Timeout > 0 {
|
||||
params["timeout"] = opts.Timeout
|
||||
}
|
||||
if opts.WaitUntil != "" {
|
||||
params["waitUntil"] = opts.WaitUntil
|
||||
}
|
||||
|
||||
resp, err := b.sendCommand(ctx, "page.goto", params)
|
||||
if err != nil {
|
||||
return &ActionResponse{Success: false, Error: err.Error()}, err
|
||||
}
|
||||
|
||||
result := &ActionResponse{
|
||||
Success: true,
|
||||
PageURL: getString(resp, "url"),
|
||||
PageTitle: getString(resp, "title"),
|
||||
}
|
||||
|
||||
if opts.Screenshot {
|
||||
screenshot, _ := b.Screenshot(ctx, sessionID, ScreenshotOptions{FullPage: false})
|
||||
if screenshot != nil {
|
||||
result.Screenshot = screenshot.Data
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (b *PlaywrightBrowser) Click(ctx context.Context, sessionID, selector string, opts ClickOptions) (*ActionResponse, error) {
|
||||
params := map[string]interface{}{
|
||||
"sessionId": sessionID,
|
||||
"selector": selector,
|
||||
}
|
||||
if opts.Button != "" {
|
||||
params["button"] = opts.Button
|
||||
}
|
||||
if opts.ClickCount > 0 {
|
||||
params["clickCount"] = opts.ClickCount
|
||||
}
|
||||
if opts.Timeout > 0 {
|
||||
params["timeout"] = opts.Timeout
|
||||
}
|
||||
if opts.Force {
|
||||
params["force"] = true
|
||||
}
|
||||
|
||||
_, err := b.sendCommand(ctx, "page.click", params)
|
||||
if err != nil {
|
||||
return &ActionResponse{Success: false, Error: err.Error()}, err
|
||||
}
|
||||
|
||||
result := &ActionResponse{Success: true}
|
||||
|
||||
if opts.WaitAfter > 0 {
|
||||
time.Sleep(time.Duration(opts.WaitAfter) * time.Millisecond)
|
||||
}
|
||||
|
||||
if opts.Screenshot {
|
||||
screenshot, _ := b.Screenshot(ctx, sessionID, ScreenshotOptions{FullPage: false})
|
||||
if screenshot != nil {
|
||||
result.Screenshot = screenshot.Data
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (b *PlaywrightBrowser) Type(ctx context.Context, sessionID, selector, text string, opts TypeOptions) (*ActionResponse, error) {
|
||||
params := map[string]interface{}{
|
||||
"sessionId": sessionID,
|
||||
"selector": selector,
|
||||
"text": text,
|
||||
}
|
||||
if opts.Delay > 0 {
|
||||
params["delay"] = opts.Delay
|
||||
}
|
||||
if opts.Timeout > 0 {
|
||||
params["timeout"] = opts.Timeout
|
||||
}
|
||||
if opts.Clear {
|
||||
b.sendCommand(ctx, "page.fill", map[string]interface{}{
|
||||
"sessionId": sessionID,
|
||||
"selector": selector,
|
||||
"value": "",
|
||||
})
|
||||
}
|
||||
|
||||
_, err := b.sendCommand(ctx, "page.type", params)
|
||||
if err != nil {
|
||||
return &ActionResponse{Success: false, Error: err.Error()}, err
|
||||
}
|
||||
|
||||
return &ActionResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
func (b *PlaywrightBrowser) Fill(ctx context.Context, sessionID, selector, value string) (*ActionResponse, error) {
|
||||
params := map[string]interface{}{
|
||||
"sessionId": sessionID,
|
||||
"selector": selector,
|
||||
"value": value,
|
||||
}
|
||||
|
||||
_, err := b.sendCommand(ctx, "page.fill", params)
|
||||
if err != nil {
|
||||
return &ActionResponse{Success: false, Error: err.Error()}, err
|
||||
}
|
||||
|
||||
return &ActionResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
func (b *PlaywrightBrowser) Screenshot(ctx context.Context, sessionID string, opts ScreenshotOptions) (*ScreenshotResult, error) {
|
||||
params := map[string]interface{}{
|
||||
"sessionId": sessionID,
|
||||
"fullPage": opts.FullPage,
|
||||
}
|
||||
if opts.Selector != "" {
|
||||
params["selector"] = opts.Selector
|
||||
}
|
||||
if opts.Quality > 0 {
|
||||
params["quality"] = opts.Quality
|
||||
}
|
||||
params["type"] = "png"
|
||||
if opts.Format != "" {
|
||||
params["type"] = opts.Format
|
||||
}
|
||||
|
||||
resp, err := b.sendCommand(ctx, "page.screenshot", params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, _ := resp["data"].(string)
|
||||
|
||||
filename := fmt.Sprintf("%s/%s-%d.png", b.config.ScreenshotsDir, sessionID, time.Now().UnixNano())
|
||||
if decoded, err := base64.StdEncoding.DecodeString(data); err == nil {
|
||||
os.WriteFile(filename, decoded, 0644)
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
if session, ok := b.sessions[sessionID]; ok {
|
||||
session.Screenshots = append(session.Screenshots, filename)
|
||||
}
|
||||
b.mu.Unlock()
|
||||
|
||||
return &ScreenshotResult{
|
||||
Data: data,
|
||||
Path: filename,
|
||||
MimeType: "image/png",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *PlaywrightBrowser) ExtractText(ctx context.Context, sessionID, selector string) (string, error) {
|
||||
params := map[string]interface{}{
|
||||
"sessionId": sessionID,
|
||||
"selector": selector,
|
||||
}
|
||||
|
||||
resp, err := b.sendCommand(ctx, "page.textContent", params)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return getString(resp, "text"), nil
|
||||
}
|
||||
|
||||
func (b *PlaywrightBrowser) ExtractHTML(ctx context.Context, sessionID, selector string) (string, error) {
|
||||
params := map[string]interface{}{
|
||||
"sessionId": sessionID,
|
||||
"selector": selector,
|
||||
}
|
||||
|
||||
resp, err := b.sendCommand(ctx, "page.innerHTML", params)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return getString(resp, "html"), nil
|
||||
}
|
||||
|
||||
func (b *PlaywrightBrowser) WaitForSelector(ctx context.Context, sessionID, selector string, opts WaitOptions) error {
|
||||
params := map[string]interface{}{
|
||||
"sessionId": sessionID,
|
||||
"selector": selector,
|
||||
}
|
||||
if opts.Timeout > 0 {
|
||||
params["timeout"] = opts.Timeout
|
||||
}
|
||||
if opts.State != "" {
|
||||
params["state"] = opts.State
|
||||
}
|
||||
|
||||
_, err := b.sendCommand(ctx, "page.waitForSelector", params)
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *PlaywrightBrowser) WaitForNavigation(ctx context.Context, sessionID string, opts WaitOptions) error {
|
||||
params := map[string]interface{}{
|
||||
"sessionId": sessionID,
|
||||
}
|
||||
if opts.Timeout > 0 {
|
||||
params["timeout"] = opts.Timeout
|
||||
}
|
||||
if opts.WaitUntil != "" {
|
||||
params["waitUntil"] = opts.WaitUntil
|
||||
}
|
||||
|
||||
_, err := b.sendCommand(ctx, "page.waitForNavigation", params)
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *PlaywrightBrowser) Scroll(ctx context.Context, sessionID string, opts ScrollOptions) (*ActionResponse, error) {
|
||||
script := fmt.Sprintf("window.scrollBy(%d, %d)", opts.X, opts.Y)
|
||||
if opts.Selector != "" {
|
||||
script = fmt.Sprintf(`document.querySelector('%s').scrollBy(%d, %d)`, opts.Selector, opts.X, opts.Y)
|
||||
}
|
||||
if opts.ToBottom {
|
||||
script = "window.scrollTo(0, document.body.scrollHeight)"
|
||||
}
|
||||
if opts.ToTop {
|
||||
script = "window.scrollTo(0, 0)"
|
||||
}
|
||||
|
||||
_, err := b.Evaluate(ctx, sessionID, script)
|
||||
if err != nil {
|
||||
return &ActionResponse{Success: false, Error: err.Error()}, err
|
||||
}
|
||||
|
||||
if opts.WaitAfter > 0 {
|
||||
time.Sleep(time.Duration(opts.WaitAfter) * time.Millisecond)
|
||||
}
|
||||
|
||||
return &ActionResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
func (b *PlaywrightBrowser) Evaluate(ctx context.Context, sessionID, script string) (interface{}, error) {
|
||||
params := map[string]interface{}{
|
||||
"sessionId": sessionID,
|
||||
"expression": script,
|
||||
}
|
||||
|
||||
resp, err := b.sendCommand(ctx, "page.evaluate", params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp["result"], nil
|
||||
}
|
||||
|
||||
func (b *PlaywrightBrowser) Select(ctx context.Context, sessionID, selector string, values []string) (*ActionResponse, error) {
|
||||
params := map[string]interface{}{
|
||||
"sessionId": sessionID,
|
||||
"selector": selector,
|
||||
"values": values,
|
||||
}
|
||||
|
||||
_, err := b.sendCommand(ctx, "page.selectOption", params)
|
||||
if err != nil {
|
||||
return &ActionResponse{Success: false, Error: err.Error()}, err
|
||||
}
|
||||
|
||||
return &ActionResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
func (b *PlaywrightBrowser) GetPageInfo(ctx context.Context, sessionID string) (*PageInfo, error) {
|
||||
params := map[string]interface{}{
|
||||
"sessionId": sessionID,
|
||||
}
|
||||
|
||||
resp, err := b.sendCommand(ctx, "page.info", params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &PageInfo{
|
||||
URL: getString(resp, "url"),
|
||||
Title: getString(resp, "title"),
|
||||
Content: getString(resp, "content"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *PlaywrightBrowser) PDF(ctx context.Context, sessionID string, opts PDFOptions) ([]byte, error) {
|
||||
params := map[string]interface{}{
|
||||
"sessionId": sessionID,
|
||||
}
|
||||
if opts.Format != "" {
|
||||
params["format"] = opts.Format
|
||||
}
|
||||
if opts.Landscape {
|
||||
params["landscape"] = true
|
||||
}
|
||||
if opts.PrintBackground {
|
||||
params["printBackground"] = true
|
||||
}
|
||||
|
||||
resp, err := b.sendCommand(ctx, "page.pdf", params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, _ := resp["data"].(string)
|
||||
return base64.StdEncoding.DecodeString(data)
|
||||
}
|
||||
|
||||
func (b *PlaywrightBrowser) sendCommand(ctx context.Context, method string, params map[string]interface{}) (map[string]interface{}, error) {
|
||||
body := map[string]interface{}{
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", b.serverURL+"/api/browser", strings.NewReader(string(jsonBody)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := b.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if errMsg, ok := result["error"].(string); ok && errMsg != "" {
|
||||
return result, errors.New(errMsg)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// getString returns m[key] when it holds a string, and "" for a missing
// key or a non-string value.
func getString(m map[string]interface{}, key string) string {
	v, ok := m[key].(string)
	if !ok {
		return ""
	}
	return v
}
|
||||
|
||||
// SessionOptions configures a new browser context; zero-valued fields fall
// back to the PlaywrightBrowser Config where a fallback exists.
type SessionOptions struct {
	Headless    bool
	Viewport    *Viewport
	UserAgent   string
	ProxyURL    string
	RecordVideo bool
	BlockAds    bool // NOTE(review): not referenced in the code visible here — confirm it is honored elsewhere
}

// Viewport is a page viewport size in pixels.
type Viewport struct {
	Width  int `json:"width"`
	Height int `json:"height"`
}

// NavigateOptions tunes Navigate. Screenshot attaches a post-navigation
// capture; Timeout/WaitUntil are passed through to the server.
type NavigateOptions struct {
	Timeout    int
	WaitUntil  string
	Screenshot bool
}

// ClickOptions tunes Click. WaitAfter is a post-click pause in ms.
type ClickOptions struct {
	Button     string
	ClickCount int
	Timeout    int
	Force      bool
	WaitAfter  int
	Screenshot bool
}

// TypeOptions tunes Type. Delay is applied per keystroke; Clear empties
// the field before typing.
type TypeOptions struct {
	Delay   int
	Timeout int
	Clear   bool
}

// ScreenshotOptions tunes Screenshot. Format defaults to "png" when empty.
type ScreenshotOptions struct {
	FullPage bool
	Selector string // capture only this element when set
	Format   string
	Quality  int
}

// ScreenshotResult is a captured image: base64 data plus the local file
// path it was saved to.
type ScreenshotResult struct {
	Data     string
	Path     string
	MimeType string
}

// WaitOptions tunes WaitForSelector (Timeout/State) and WaitForNavigation
// (Timeout/WaitUntil).
type WaitOptions struct {
	Timeout   int
	State     string
	WaitUntil string
}

// ScrollOptions tunes Scroll. ToTop overrides ToBottom, which overrides
// Selector/X/Y. WaitAfter is a post-scroll pause in ms.
type ScrollOptions struct {
	X         int
	Y         int
	Selector  string
	ToBottom  bool
	ToTop     bool
	WaitAfter int
}

// PageInfo is a snapshot of the current page.
type PageInfo struct {
	URL     string
	Title   string
	Content string
}

// PDFOptions tunes PDF rendering; fields map directly to server params.
type PDFOptions struct {
	Format          string
	Landscape       bool
	PrintBackground bool
}
|
||||
555
backend/internal/computer/browser/server.go
Normal file
555
backend/internal/computer/browser/server.go
Normal file
@@ -0,0 +1,555 @@
|
||||
package browser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||
"github.com/gofiber/fiber/v2/middleware/logger"
|
||||
)
|
||||
|
||||
// BrowserServer exposes the PlaywrightBrowser over HTTP (Fiber) and keeps
// per-session action logs for auditing. sessions is guarded by mu.
type BrowserServer struct {
	browser  *PlaywrightBrowser
	sessions map[string]*ManagedSession
	mu       sync.RWMutex // guards sessions
	config   ServerConfig
}

// ServerConfig tunes the HTTP server; zero values are defaulted in
// NewBrowserServer (port 3050, 20 sessions, 30m idle timeout, 5m cleanup).
type ServerConfig struct {
	Port            int
	MaxSessions     int           // hard cap on concurrently open sessions
	SessionTimeout  time.Duration // idle time before a session is reaped
	CleanupInterval time.Duration // how often the reaper runs
}

// ManagedSession wraps a BrowserSession with server-side bookkeeping.
type ManagedSession struct {
	*BrowserSession
	LastActive time.Time   // refreshed on every command touching the session
	Actions    []ActionLog // audit trail of executed commands
}

// ActionLog records one executed browser command for a session.
type ActionLog struct {
	Action    string    `json:"action"`
	Params    string    `json:"params"` // JSON-encoded request params
	Success   bool      `json:"success"`
	Error     string    `json:"error,omitempty"`
	Duration  int64     `json:"durationMs"`
	Timestamp time.Time `json:"timestamp"`
}

// BrowserRequest is the generic command envelope accepted by /api/browser.
type BrowserRequest struct {
	Method string                 `json:"method"`
	Params map[string]interface{} `json:"params"`
}
|
||||
|
||||
func NewBrowserServer(cfg ServerConfig) *BrowserServer {
|
||||
if cfg.Port == 0 {
|
||||
cfg.Port = 3050
|
||||
}
|
||||
if cfg.MaxSessions == 0 {
|
||||
cfg.MaxSessions = 20
|
||||
}
|
||||
if cfg.SessionTimeout == 0 {
|
||||
cfg.SessionTimeout = 30 * time.Minute
|
||||
}
|
||||
if cfg.CleanupInterval == 0 {
|
||||
cfg.CleanupInterval = 5 * time.Minute
|
||||
}
|
||||
|
||||
return &BrowserServer{
|
||||
browser: NewPlaywrightBrowser(Config{
|
||||
DefaultTimeout: 30 * time.Second,
|
||||
Headless: true,
|
||||
}),
|
||||
sessions: make(map[string]*ManagedSession),
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// Start launches the background session-cleanup loop and serves the HTTP
// API. It blocks until the listener fails; ctx governs only the cleanup
// goroutine, not the Fiber server itself.
func (s *BrowserServer) Start(ctx context.Context) error {
	go s.cleanupLoop(ctx) // reaps idle sessions (defined elsewhere in this file)

	app := fiber.New(fiber.Config{
		BodyLimit:    50 * 1024 * 1024, // screenshots/PDFs in request bodies can be large
		ReadTimeout:  2 * time.Minute,
		WriteTimeout: 2 * time.Minute,
	})

	app.Use(logger.New())
	app.Use(cors.New())

	// Liveness probe with the current session count.
	// NOTE(review): len(s.sessions) is read without s.mu here — flagged by
	// the race detector under concurrent writes; confirm acceptability.
	app.Get("/health", func(c *fiber.Ctx) error {
		return c.JSON(fiber.Map{"status": "ok", "sessions": len(s.sessions)})
	})

	// Generic RPC-style command endpoint used by the PlaywrightBrowser client.
	app.Post("/api/browser", s.handleBrowserCommand)

	// REST-style session management.
	app.Post("/api/session/new", s.handleNewSession)
	app.Delete("/api/session/:id", s.handleCloseSession)
	app.Get("/api/session/:id", s.handleGetSession)
	app.Get("/api/sessions", s.handleListSessions)

	app.Post("/api/action", s.handleAction)

	log.Printf("[BrowserServer] Starting on port %d", s.config.Port)
	return app.Listen(fmt.Sprintf(":%d", s.config.Port))
}
|
||||
|
||||
// handleBrowserCommand executes one generic browser command.
//
// It parses a {method, params} body, refreshes the owning session's
// LastActive stamp, runs the method with a 60-second budget, appends an
// ActionLog entry for auditing, and returns either the method's result map
// or {success: false, error}.
//
// NOTE(review): the execution context derives from context.Background(),
// not the request, so a client disconnect does not cancel browser work —
// confirm this is intended.
func (s *BrowserServer) handleBrowserCommand(c *fiber.Ctx) error {
	var req BrowserRequest
	if err := c.BodyParser(&req); err != nil {
		return c.Status(400).JSON(fiber.Map{"error": "Invalid request"})
	}

	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()

	sessionID, _ := req.Params["sessionId"].(string)

	// Touch the session so the cleanup loop does not reap it mid-command.
	s.mu.Lock()
	if session, ok := s.sessions[sessionID]; ok {
		session.LastActive = time.Now()
	}
	s.mu.Unlock()

	start := time.Now()
	result, err := s.executeMethod(ctx, req.Method, req.Params)

	// Record an audit entry for the session, success or failure.
	s.mu.Lock()
	if session, ok := s.sessions[sessionID]; ok {
		paramsJSON, _ := json.Marshal(req.Params)
		session.Actions = append(session.Actions, ActionLog{
			Action:    req.Method,
			Params:    string(paramsJSON),
			Success:   err == nil,
			Error:     errToString(err),
			Duration:  time.Since(start).Milliseconds(),
			Timestamp: time.Now(),
		})
	}
	s.mu.Unlock()

	// Method failures are reported with HTTP 200 and an in-band error field
	// (the client's sendCommand looks for result["error"]).
	if err != nil {
		return c.JSON(fiber.Map{
			"success": false,
			"error":   err.Error(),
		})
	}

	return c.JSON(result)
}
|
||||
|
||||
// executeMethod dispatches one named protocol method to the underlying
// PlaywrightBrowser and shapes its result into the wire map. Unknown
// methods return an error. Most cases follow the same pattern: coerce
// params with the getString/getInt/getBool helpers, call the browser,
// and echo success/result fields.
func (s *BrowserServer) executeMethod(ctx context.Context, method string, params map[string]interface{}) (map[string]interface{}, error) {
	sessionID, _ := params["sessionId"].(string)

	switch method {
	case "browser.newContext":
		// Build session options from loosely-typed params.
		opts := SessionOptions{
			Headless: getBool(params, "headless"),
		}
		if viewport, ok := params["viewport"].(map[string]interface{}); ok {
			opts.Viewport = &Viewport{
				Width:  getInt(viewport, "width"),
				Height: getInt(viewport, "height"),
			}
		}
		if ua, ok := params["userAgent"].(string); ok {
			opts.UserAgent = ua
		}
		if proxy, ok := params["proxy"].(string); ok {
			opts.ProxyURL = proxy
		}
		if rv, ok := params["recordVideo"].(map[string]interface{}); ok {
			// Only the presence of recordVideo matters here; the recording
			// directory is chosen by the browser client, not this payload.
			_ = rv
			opts.RecordVideo = true
		}

		session, err := s.browser.NewSession(ctx, opts)
		if err != nil {
			return nil, err
		}

		// Track the session server-side for auditing and cleanup.
		s.mu.Lock()
		s.sessions[session.ID] = &ManagedSession{
			BrowserSession: session,
			LastActive:     time.Now(),
			Actions:        make([]ActionLog, 0),
		}
		s.mu.Unlock()

		return map[string]interface{}{
			"sessionId": session.ID,
			"contextId": session.ContextID,
			"pageId":    session.PageID,
		}, nil

	case "browser.closeContext":
		// Remove from tracking even when the remote close fails.
		err := s.browser.CloseSession(ctx, sessionID)
		s.mu.Lock()
		delete(s.sessions, sessionID)
		s.mu.Unlock()
		return map[string]interface{}{"success": err == nil}, err

	case "page.goto":
		url, _ := params["url"].(string)
		opts := NavigateOptions{
			Timeout:   getInt(params, "timeout"),
			WaitUntil: getString(params, "waitUntil"),
		}
		result, err := s.browser.Navigate(ctx, sessionID, url, opts)
		if err != nil {
			return nil, err
		}
		return map[string]interface{}{
			"success": result.Success,
			"url":     result.PageURL,
			"title":   result.PageTitle,
		}, nil

	case "page.click":
		selector, _ := params["selector"].(string)
		opts := ClickOptions{
			Button:     getString(params, "button"),
			ClickCount: getInt(params, "clickCount"),
			Timeout:    getInt(params, "timeout"),
			Force:      getBool(params, "force"),
		}
		result, err := s.browser.Click(ctx, sessionID, selector, opts)
		if err != nil {
			return nil, err
		}
		return map[string]interface{}{
			"success":    result.Success,
			"screenshot": result.Screenshot,
		}, nil

	case "page.type":
		selector, _ := params["selector"].(string)
		text, _ := params["text"].(string)
		opts := TypeOptions{
			Delay:   getInt(params, "delay"),
			Timeout: getInt(params, "timeout"),
		}
		_, err := s.browser.Type(ctx, sessionID, selector, text, opts)
		return map[string]interface{}{"success": err == nil}, err

	case "page.fill":
		selector, _ := params["selector"].(string)
		value, _ := params["value"].(string)
		_, err := s.browser.Fill(ctx, sessionID, selector, value)
		return map[string]interface{}{"success": err == nil}, err

	case "page.screenshot":
		opts := ScreenshotOptions{
			FullPage: getBool(params, "fullPage"),
			Selector: getString(params, "selector"),
			Format:   getString(params, "type"),
			Quality:  getInt(params, "quality"),
		}
		result, err := s.browser.Screenshot(ctx, sessionID, opts)
		if err != nil {
			return nil, err
		}
		return map[string]interface{}{
			"data": result.Data,
			"path": result.Path,
		}, nil

	case "page.textContent":
		selector, _ := params["selector"].(string)
		text, err := s.browser.ExtractText(ctx, sessionID, selector)
		return map[string]interface{}{"text": text}, err

	case "page.innerHTML":
		selector, _ := params["selector"].(string)
		html, err := s.browser.ExtractHTML(ctx, sessionID, selector)
		return map[string]interface{}{"html": html}, err

	case "page.waitForSelector":
		selector, _ := params["selector"].(string)
		opts := WaitOptions{
			Timeout: getInt(params, "timeout"),
			State:   getString(params, "state"),
		}
		err := s.browser.WaitForSelector(ctx, sessionID, selector, opts)
		return map[string]interface{}{"success": err == nil}, err

	case "page.waitForNavigation":
		opts := WaitOptions{
			Timeout:   getInt(params, "timeout"),
			WaitUntil: getString(params, "waitUntil"),
		}
		err := s.browser.WaitForNavigation(ctx, sessionID, opts)
		return map[string]interface{}{"success": err == nil}, err

	case "page.evaluate":
		expression, _ := params["expression"].(string)
		result, err := s.browser.Evaluate(ctx, sessionID, expression)
		return map[string]interface{}{"result": result}, err

	case "page.selectOption":
		selector, _ := params["selector"].(string)
		values := getStringArray(params, "values")
		_, err := s.browser.Select(ctx, sessionID, selector, values)
		return map[string]interface{}{"success": err == nil}, err

	case "page.info":
		info, err := s.browser.GetPageInfo(ctx, sessionID)
		if err != nil {
			return nil, err
		}
		return map[string]interface{}{
			"url":     info.URL,
			"title":   info.Title,
			"content": info.Content,
		}, nil

	case "page.pdf":
		opts := PDFOptions{
			Format:          getString(params, "format"),
			Landscape:       getBool(params, "landscape"),
			PrintBackground: getBool(params, "printBackground"),
		}
		data, err := s.browser.PDF(ctx, sessionID, opts)
		if err != nil {
			return nil, err
		}
		// []byte is JSON-encoded as a base64 string on the wire.
		return map[string]interface{}{
			"data": data,
		}, nil

	default:
		return nil, fmt.Errorf("unknown method: %s", method)
	}
}
|
||||
|
||||
// handleNewSession creates a browser session via the REST API.
//
// Request body (all fields optional): {headless, viewport, userAgent,
// proxyUrl}. Replies 429 when MaxSessions is reached, 500 on creation
// failure, otherwise the new session/context/page identifiers.
//
// NOTE(review): the capacity check under RLock and the later insert under
// Lock are not atomic; concurrent requests can briefly exceed MaxSessions.
func (s *BrowserServer) handleNewSession(c *fiber.Ctx) error {
	var req struct {
		Headless  bool      `json:"headless"`
		Viewport  *Viewport `json:"viewport,omitempty"`
		UserAgent string    `json:"userAgent,omitempty"`
		ProxyURL  string    `json:"proxyUrl,omitempty"`
	}

	// An unparsable body degrades to a default headless session.
	if err := c.BodyParser(&req); err != nil {
		req.Headless = true
	}

	s.mu.RLock()
	if len(s.sessions) >= s.config.MaxSessions {
		s.mu.RUnlock()
		return c.Status(http.StatusTooManyRequests).JSON(fiber.Map{
			"error": "Maximum sessions limit reached",
		})
	}
	s.mu.RUnlock()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	session, err := s.browser.NewSession(ctx, SessionOptions{
		Headless:  req.Headless,
		Viewport:  req.Viewport,
		UserAgent: req.UserAgent,
		ProxyURL:  req.ProxyURL,
	})
	if err != nil {
		return c.Status(500).JSON(fiber.Map{"error": err.Error()})
	}

	// Register for auditing and idle cleanup.
	s.mu.Lock()
	s.sessions[session.ID] = &ManagedSession{
		BrowserSession: session,
		LastActive:     time.Now(),
		Actions:        make([]ActionLog, 0),
	}
	s.mu.Unlock()

	return c.JSON(fiber.Map{
		"sessionId": session.ID,
		"contextId": session.ContextID,
		"pageId":    session.PageID,
	})
}
|
||||
|
||||
func (s *BrowserServer) handleCloseSession(c *fiber.Ctx) error {
|
||||
sessionID := c.Params("id")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := s.browser.CloseSession(ctx, sessionID)
|
||||
if err != nil {
|
||||
return c.Status(404).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
delete(s.sessions, sessionID)
|
||||
s.mu.Unlock()
|
||||
|
||||
return c.JSON(fiber.Map{"success": true})
|
||||
}
|
||||
|
||||
func (s *BrowserServer) handleGetSession(c *fiber.Ctx) error {
|
||||
sessionID := c.Params("id")
|
||||
|
||||
s.mu.RLock()
|
||||
session, ok := s.sessions[sessionID]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return c.Status(404).JSON(fiber.Map{"error": "Session not found"})
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
"sessionId": session.ID,
|
||||
"createdAt": session.CreatedAt,
|
||||
"lastActive": session.LastActive,
|
||||
"screenshots": session.Screenshots,
|
||||
"actions": len(session.Actions),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BrowserServer) handleListSessions(c *fiber.Ctx) error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
sessions := make([]map[string]interface{}, 0, len(s.sessions))
|
||||
for _, session := range s.sessions {
|
||||
sessions = append(sessions, map[string]interface{}{
|
||||
"sessionId": session.ID,
|
||||
"createdAt": session.CreatedAt,
|
||||
"lastActive": session.LastActive,
|
||||
"actions": len(session.Actions),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{"sessions": sessions, "count": len(sessions)})
|
||||
}
|
||||
|
||||
// handleAction runs a single high-level browser action (navigate, click,
// type, fill, screenshot, or extract) against an existing session and
// returns the resulting ActionResponse as JSON.
func (s *BrowserServer) handleAction(c *fiber.Ctx) error {
	var req struct {
		SessionID  string `json:"sessionId"`
		Action     string `json:"action"`
		Selector   string `json:"selector,omitempty"`
		URL        string `json:"url,omitempty"`
		Value      string `json:"value,omitempty"`
		Screenshot bool   `json:"screenshot"`
	}

	if err := c.BodyParser(&req); err != nil {
		return c.Status(400).JSON(fiber.Map{"error": "Invalid request"})
	}

	// Actions can be slow (page loads, renders); allow up to a minute.
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()

	// Touch the session's activity timestamp so the cleanup loop does not
	// expire it mid-use. NOTE(review): an unknown SessionID is not rejected
	// here — the browser call below is relied on to fail instead; confirm
	// that is intended.
	s.mu.Lock()
	if session, ok := s.sessions[req.SessionID]; ok {
		session.LastActive = time.Now()
	}
	s.mu.Unlock()

	var result *ActionResponse
	var err error

	switch req.Action {
	case "navigate":
		result, err = s.browser.Navigate(ctx, req.SessionID, req.URL, NavigateOptions{Screenshot: req.Screenshot})
	case "click":
		result, err = s.browser.Click(ctx, req.SessionID, req.Selector, ClickOptions{Screenshot: req.Screenshot})
	case "type":
		result, err = s.browser.Type(ctx, req.SessionID, req.Selector, req.Value, TypeOptions{})
	case "fill":
		result, err = s.browser.Fill(ctx, req.SessionID, req.Selector, req.Value)
	case "screenshot":
		// Screenshot returns its own result type; wrap it in the common
		// ActionResponse shape.
		var screenshot *ScreenshotResult
		screenshot, err = s.browser.Screenshot(ctx, req.SessionID, ScreenshotOptions{})
		if err == nil {
			result = &ActionResponse{Success: true, Screenshot: screenshot.Data}
		}
	case "extract":
		var text string
		text, err = s.browser.ExtractText(ctx, req.SessionID, req.Selector)
		result = &ActionResponse{Success: err == nil, Data: text}
	default:
		return c.Status(400).JSON(fiber.Map{"error": "Unknown action: " + req.Action})
	}

	if err != nil {
		return c.Status(500).JSON(fiber.Map{"error": err.Error(), "success": false})
	}

	return c.JSON(result)
}
|
||||
|
||||
func (s *BrowserServer) cleanupLoop(ctx context.Context) {
|
||||
ticker := time.NewTicker(s.config.CleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.cleanupExpiredSessions()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BrowserServer) cleanupExpiredSessions() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for sessionID, session := range s.sessions {
|
||||
if now.Sub(session.LastActive) > s.config.SessionTimeout {
|
||||
log.Printf("[BrowserServer] Cleaning up expired session: %s", sessionID)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
s.browser.CloseSession(ctx, sessionID)
|
||||
cancel()
|
||||
delete(s.sessions, sessionID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// errToString renders err as its message string, mapping nil to "".
func errToString(err error) string {
	if err != nil {
		return err.Error()
	}
	return ""
}
|
||||
|
||||
// getBool fetches key from m as a bool, returning false when the key is
// absent or holds a non-bool value.
func getBool(m map[string]interface{}, key string) bool {
	v, ok := m[key].(bool)
	return ok && v
}
|
||||
|
||||
// getInt fetches key from m as an int. JSON-decoded numbers arrive as
// float64, so that case is handled first; native ints pass through, and
// anything else (including a missing key) yields 0.
func getInt(m map[string]interface{}, key string) int {
	switch v := m[key].(type) {
	case float64:
		return int(v)
	case int:
		return v
	default:
		return 0
	}
}
|
||||
|
||||
// getStringArray fetches key from m as a []string. JSON arrays decode as
// []interface{}; each element is converted to string, with non-string
// elements becoming "". Returns nil when the key is absent or not an
// array.
func getStringArray(m map[string]interface{}, key string) []string {
	raw, ok := m[key].([]interface{})
	if !ok {
		return nil
	}
	out := make([]string, len(raw))
	for idx := range raw {
		out[idx], _ = raw[idx].(string)
	}
	return out
}
|
||||
738
backend/internal/computer/computer.go
Normal file
738
backend/internal/computer/computer.go
Normal file
@@ -0,0 +1,738 @@
|
||||
package computer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gooseek/backend/internal/computer/connectors"
|
||||
"github.com/gooseek/backend/internal/llm"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ComputerConfig holds tunables for the Computer orchestrator: task
// parallelism, timeouts, cost budget, feature toggles, and long-running
// task housekeeping intervals.
type ComputerConfig struct {
	MaxParallelTasks    int           // max subtasks executed concurrently
	MaxSubTasks         int           // upper bound on subtasks per task plan
	TaskTimeout         time.Duration // overall wall-clock limit for a task
	SubTaskTimeout      time.Duration // per-subtask limit; also used as the sandbox timeout in NewComputer
	TotalBudget         float64       // default cost budget per task (USD)
	EnableSandbox       bool          // construct a SandboxManager and attach it to the executor
	EnableScheduling    bool          // construct a Scheduler for scheduled tasks
	EnableBrowser       bool          // presumably enables browser-backed subtasks — not read in this file
	SandboxImage        string        // container image used by the sandbox
	ArtifactStorageURL  string        // presumably where artifacts are persisted — not read in this file
	BrowserServerURL    string        // base URL of the browser service
	CheckpointStorePath string        // filesystem path for checkpoint persistence
	MaxConcurrentTasks  int           // cap on simultaneously running tasks
	HeartbeatInterval   time.Duration // default heartbeat emission interval
	CheckpointInterval  time.Duration // default periodic checkpoint interval
}
|
||||
|
||||
func DefaultConfig() ComputerConfig {
|
||||
return ComputerConfig{
|
||||
MaxParallelTasks: 10,
|
||||
MaxSubTasks: 100,
|
||||
TaskTimeout: 365 * 24 * time.Hour,
|
||||
SubTaskTimeout: 2 * time.Hour,
|
||||
TotalBudget: 100.0,
|
||||
EnableSandbox: true,
|
||||
EnableScheduling: true,
|
||||
EnableBrowser: true,
|
||||
SandboxImage: "gooseek/sandbox:latest",
|
||||
BrowserServerURL: "http://browser-svc:3050",
|
||||
CheckpointStorePath: "/data/checkpoints",
|
||||
MaxConcurrentTasks: 50,
|
||||
HeartbeatInterval: 30 * time.Second,
|
||||
CheckpointInterval: 15 * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
func GetDurationConfig(mode DurationMode) (maxDuration, checkpointFreq, heartbeatFreq time.Duration, maxIter int) {
|
||||
cfg, ok := DurationModeConfigs[mode]
|
||||
if !ok {
|
||||
cfg = DurationModeConfigs[DurationMedium]
|
||||
}
|
||||
return cfg.MaxDuration, cfg.CheckpointFreq, cfg.HeartbeatFreq, cfg.MaxIterations
|
||||
}
|
||||
|
||||
// Dependencies bundles the external collaborators a Computer needs: the
// LLM model registry and the persistence repositories.
type Dependencies struct {
	Registry     *llm.ModelRegistry
	TaskRepo     TaskRepository
	MemoryRepo   MemoryRepository
	ArtifactRepo ArtifactRepository
}
|
||||
|
||||
// TaskRepository persists ComputerTask records.
type TaskRepository interface {
	// Create stores a new task.
	Create(ctx context.Context, task *ComputerTask) error
	// Update persists changes to an existing task.
	Update(ctx context.Context, task *ComputerTask) error
	// GetByID fetches a task by its identifier.
	GetByID(ctx context.Context, id string) (*ComputerTask, error)
	// GetByUserID pages through a user's tasks.
	GetByUserID(ctx context.Context, userID string, limit, offset int) ([]ComputerTask, error)
	// GetScheduled returns tasks that have a schedule attached.
	GetScheduled(ctx context.Context) ([]ComputerTask, error)
	// Delete removes a task by ID.
	Delete(ctx context.Context, id string) error
}
|
||||
|
||||
// MemoryRepository persists per-user/per-task memory entries.
type MemoryRepository interface {
	// Store saves a memory entry.
	Store(ctx context.Context, entry *MemoryEntry) error
	// GetByUser returns up to limit entries for a user.
	GetByUser(ctx context.Context, userID string, limit int) ([]MemoryEntry, error)
	// GetByTask returns all entries recorded for a task.
	GetByTask(ctx context.Context, taskID string) ([]MemoryEntry, error)
	// Search finds a user's entries matching query, up to limit.
	Search(ctx context.Context, userID, query string, limit int) ([]MemoryEntry, error)
	// Delete removes an entry by ID.
	Delete(ctx context.Context, id string) error
}
|
||||
|
||||
// ArtifactRepository persists artifacts produced by task execution.
type ArtifactRepository interface {
	// Create stores a new artifact.
	Create(ctx context.Context, artifact *Artifact) error
	// GetByID fetches an artifact by its identifier.
	GetByID(ctx context.Context, id string) (*Artifact, error)
	// GetByTaskID returns all artifacts produced by a task.
	GetByTaskID(ctx context.Context, taskID string) ([]Artifact, error)
	// Delete removes an artifact by ID.
	Delete(ctx context.Context, id string) error
}
|
||||
|
||||
// Computer orchestrates long-running agent tasks: it plans a task into
// subtasks, routes and executes them in dependency-ordered waves
// (optionally inside a sandbox), persists progress through checkpoints,
// and streams TaskEvents to subscribers.
type Computer struct {
	cfg        ComputerConfig
	planner    *Planner                 // turns a query into a subtask plan
	router     *Router                  // routes subtasks via the model registry
	executor   *Executor                // runs subtask waves in parallel
	sandbox    *SandboxManager          // optional; only set when cfg.EnableSandbox
	memory     *MemoryStore             // per-user long-term memory
	scheduler  *Scheduler               // optional; only set when cfg.EnableScheduling
	connectors *connectors.ConnectorHub // external tool connectors
	registry   *llm.ModelRegistry
	taskRepo   TaskRepository
	eventBus   *EventBus                // per-task fan-out of TaskEvents
	mu         sync.RWMutex             // guards tasks
	tasks      map[string]*ComputerTask // in-memory cache of known tasks
}
|
||||
|
||||
func NewComputer(cfg ComputerConfig, deps Dependencies) *Computer {
|
||||
eventBus := NewEventBus()
|
||||
|
||||
c := &Computer{
|
||||
cfg: cfg,
|
||||
registry: deps.Registry,
|
||||
taskRepo: deps.TaskRepo,
|
||||
eventBus: eventBus,
|
||||
tasks: make(map[string]*ComputerTask),
|
||||
}
|
||||
|
||||
c.planner = NewPlanner(deps.Registry)
|
||||
c.router = NewRouter(deps.Registry)
|
||||
c.executor = NewExecutor(c.router, cfg.MaxParallelTasks)
|
||||
c.memory = NewMemoryStore(deps.MemoryRepo)
|
||||
c.connectors = connectors.NewConnectorHub()
|
||||
|
||||
if cfg.EnableSandbox {
|
||||
c.sandbox = NewSandboxManager(SandboxConfig{
|
||||
Image: cfg.SandboxImage,
|
||||
Timeout: cfg.SubTaskTimeout,
|
||||
})
|
||||
c.executor.SetSandbox(c.sandbox)
|
||||
}
|
||||
|
||||
if cfg.EnableScheduling {
|
||||
c.scheduler = NewScheduler(deps.TaskRepo, c)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// Execute is the main entry point: it creates (or resumes, when
// opts.ResumeFromID is set) a ComputerTask for query, persists it, emits a
// creation event, and runs it — inline, or on a background goroutine when
// opts.Async is set (in which case the pending task is returned
// immediately).
func (c *Computer) Execute(ctx context.Context, userID, query string, opts ExecuteOptions) (*ComputerTask, error) {
	// Resuming an existing checkpointed task bypasses task creation.
	if opts.ResumeFromID != "" {
		return c.resumeFromCheckpoint(ctx, opts.ResumeFromID, opts)
	}

	durationMode := opts.DurationMode
	if durationMode == "" {
		durationMode = DurationMedium
	}

	maxDuration, _, _, maxIter := GetDurationConfig(durationMode)

	task := &ComputerTask{
		ID:            uuid.New().String(),
		UserID:        userID,
		Query:         query,
		Status:        StatusPending,
		Memory:        make(map[string]interface{}),
		CreatedAt:     time.Now(),
		UpdatedAt:     time.Now(),
		DurationMode:  durationMode,
		MaxDuration:   maxDuration,
		MaxIterations: maxIter,
		Priority:      opts.Priority,
	}

	if opts.Priority == "" {
		task.Priority = PriorityNormal
	}

	if opts.ResourceLimits != nil {
		task.ResourceLimits = opts.ResourceLimits
	}

	// A schedule defers execution: the task is stored as scheduled rather
	// than run now.
	if opts.Schedule != nil {
		task.Schedule = opts.Schedule
		task.Status = StatusScheduled
	}

	// NOTE(review): opts.Context replaces (does not merge into) the fresh
	// Memory map created above — confirm that is intended.
	if opts.Context != nil {
		task.Memory = opts.Context
	}

	estimatedEnd := time.Now().Add(maxDuration)
	task.EstimatedEnd = &estimatedEnd

	if err := c.taskRepo.Create(ctx, task); err != nil {
		return nil, fmt.Errorf("failed to create task: %w", err)
	}

	// Cache the task for fast status/pause/cancel lookups.
	c.mu.Lock()
	c.tasks[task.ID] = task
	c.mu.Unlock()

	c.emitEvent(TaskEvent{
		Type:      EventTaskCreated,
		TaskID:    task.ID,
		Status:    task.Status,
		Message:   fmt.Sprintf("Task created (mode: %s, max duration: %v)", durationMode, maxDuration),
		Timestamp: time.Now(),
		Data: map[string]interface{}{
			"durationMode":  durationMode,
			"maxDuration":   maxDuration.String(),
			"maxIterations": maxIter,
		},
	})

	if opts.Async {
		// Detach from the request context so the task outlives the caller.
		go c.executeTaskWithCheckpoints(context.Background(), task, opts)
		return task, nil
	}

	return c.executeTaskWithCheckpoints(ctx, task, opts)
}
|
||||
|
||||
func (c *Computer) resumeFromCheckpoint(ctx context.Context, checkpointID string, opts ExecuteOptions) (*ComputerTask, error) {
|
||||
task, err := c.taskRepo.GetByID(ctx, checkpointID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("task not found: %w", err)
|
||||
}
|
||||
|
||||
if task.Checkpoint == nil {
|
||||
return nil, errors.New("no checkpoint found for this task")
|
||||
}
|
||||
|
||||
task.Status = StatusExecuting
|
||||
now := time.Now()
|
||||
task.ResumedAt = &now
|
||||
task.UpdatedAt = now
|
||||
|
||||
c.emitEvent(TaskEvent{
|
||||
Type: EventResumed,
|
||||
TaskID: task.ID,
|
||||
Status: task.Status,
|
||||
Message: fmt.Sprintf("Resumed from checkpoint (wave: %d, subtask: %d)", task.Checkpoint.WaveIndex, task.Checkpoint.SubTaskIndex),
|
||||
Progress: task.Checkpoint.Progress,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
|
||||
c.mu.Lock()
|
||||
c.tasks[task.ID] = task
|
||||
c.mu.Unlock()
|
||||
|
||||
if opts.Async {
|
||||
go c.executeTaskWithCheckpoints(context.Background(), task, opts)
|
||||
return task, nil
|
||||
}
|
||||
|
||||
return c.executeTaskWithCheckpoints(ctx, task, opts)
|
||||
}
|
||||
|
||||
// executeTask delegates to the checkpoint-aware execution path; every
// execution now goes through executeTaskWithCheckpoints.
func (c *Computer) executeTask(ctx context.Context, task *ComputerTask, opts ExecuteOptions) (*ComputerTask, error) {
	return c.executeTaskWithCheckpoints(ctx, task, opts)
}
|
||||
|
||||
// executeTaskWithCheckpoints runs a task end-to-end under its
// duration-mode limits: plan (unless already planned), execute subtask
// waves in planner-declared order, checkpoint on timeout/budget/error and
// periodically, emit heartbeat/progress events, and persist the final
// state. When task.Checkpoint is set, execution resumes at the saved wave
// with the checkpointed memory merged back in.
func (c *Computer) executeTaskWithCheckpoints(ctx context.Context, task *ComputerTask, opts ExecuteOptions) (*ComputerTask, error) {
	maxDuration, checkpointFreq, heartbeatFreq, _ := GetDurationConfig(task.DurationMode)

	// An explicit per-call timeout (seconds) overrides the mode default.
	if opts.Timeout > 0 {
		maxDuration = time.Duration(opts.Timeout) * time.Second
	}

	ctx, cancel := context.WithTimeout(ctx, maxDuration)
	defer cancel()

	// Budget precedence: task resource limits > per-call MaxCost > config default.
	budget := c.cfg.TotalBudget
	if opts.MaxCost > 0 {
		budget = opts.MaxCost
	}
	if task.ResourceLimits != nil && task.ResourceLimits.MaxTotalCost > 0 {
		budget = task.ResourceLimits.MaxTotalCost
	}

	// Resume support: restart at the checkpointed wave and merge the saved
	// memory back into the live task.
	startWave := 0
	if task.Checkpoint != nil {
		startWave = task.Checkpoint.WaveIndex
		for k, v := range task.Checkpoint.Memory {
			task.Memory[k] = v
		}
	}

	// Plan only when needed — a resumed task keeps its existing plan.
	if task.Plan == nil {
		task.Status = StatusPlanning
		task.UpdatedAt = time.Now()
		c.updateTask(ctx, task)

		c.emitEvent(TaskEvent{
			Type:      EventTaskStarted,
			TaskID:    task.ID,
			Status:    StatusPlanning,
			Message:   "Planning task execution",
			Timestamp: time.Now(),
		})

		// Planning context = user long-term memory overlaid with task
		// memory (task memory wins on key collisions).
		userMemory, _ := c.memory.GetUserContext(ctx, task.UserID)
		memoryContext := make(map[string]interface{})
		for k, v := range userMemory {
			memoryContext[k] = v
		}
		for k, v := range task.Memory {
			memoryContext[k] = v
		}

		plan, err := c.planner.Plan(ctx, task.Query, memoryContext)
		if err != nil {
			// Planning failure fails the whole task immediately.
			task.Status = StatusFailed
			task.Error = fmt.Sprintf("Planning failed: %v", err)
			task.UpdatedAt = time.Now()
			c.updateTask(ctx, task)
			c.emitEvent(TaskEvent{
				Type:      EventTaskFailed,
				TaskID:    task.ID,
				Status:    StatusFailed,
				Message:   task.Error,
				Timestamp: time.Now(),
			})
			return task, err
		}

		task.Plan = plan
		task.SubTasks = plan.SubTasks
	}

	task.Status = StatusLongRunning
	task.UpdatedAt = time.Now()
	c.updateTask(ctx, task)

	c.emitEvent(TaskEvent{
		Type:     EventTaskProgress,
		TaskID:   task.ID,
		Status:   StatusLongRunning,
		Progress: 10,
		Message:  fmt.Sprintf("Executing %d subtasks (long-running mode)", len(task.Plan.SubTasks)),
		Data: map[string]interface{}{
			"plan":           task.Plan,
			"durationMode":   task.DurationMode,
			"checkpointFreq": checkpointFreq.String(),
		},
		Timestamp: time.Now(),
	})

	heartbeatTicker := time.NewTicker(heartbeatFreq)
	defer heartbeatTicker.Stop()

	checkpointTicker := time.NewTicker(checkpointFreq)
	defer checkpointTicker.Stop()

	// Heartbeat goroutine: emits liveness events until ctx is cancelled
	// (the deferred cancel above also ends it when this function returns).
	// NOTE(review): it writes task.HeartbeatAt and reads task.Progress /
	// task.TotalCost concurrently with the wave loop below without holding
	// c.mu — a data race under `go test -race`; confirm and guard.
	go func() {
		for {
			select {
			case <-ctx.Done():
				return
			case <-heartbeatTicker.C:
				now := time.Now()
				task.HeartbeatAt = &now
				c.emitEvent(TaskEvent{
					Type:     EventHeartbeat,
					TaskID:   task.ID,
					Progress: task.Progress,
					Message:  fmt.Sprintf("Heartbeat: %d%% complete, cost: $%.4f", task.Progress, task.TotalCost),
					Data: map[string]interface{}{
						"runtime": time.Since(task.CreatedAt).String(),
						"cost":    task.TotalCost,
					},
					Timestamp: now,
				})
			}
		}
	}()

	// Waves run sequentially; subtasks within a wave run in parallel via
	// the executor.
	totalSubTasks := len(task.Plan.ExecutionOrder)
	for waveIdx := startWave; waveIdx < totalSubTasks; waveIdx++ {
		// Non-blocking poll: a timeout forces checkpoint-and-abort. The
		// periodic checkpoint only fires when a wave boundary happens to
		// coincide with a pending tick (the default case skips otherwise).
		select {
		case <-ctx.Done():
			c.saveCheckpoint(task, waveIdx, 0, "context_timeout")
			return task, ctx.Err()
		case <-checkpointTicker.C:
			c.saveCheckpoint(task, waveIdx, 0, "periodic")
		default:
		}

		// Budget exhaustion pauses (not fails) the task so it can resume
		// later from the checkpoint.
		if budget > 0 && task.TotalCost >= budget {
			c.saveCheckpoint(task, waveIdx, 0, "budget_exceeded")
			task.Status = StatusPaused
			task.Message = fmt.Sprintf("Paused: budget exceeded ($%.2f / $%.2f)", task.TotalCost, budget)
			c.updateTask(ctx, task)
			return task, nil
		}

		// Materialize this wave's subtasks from their IDs.
		wave := task.Plan.ExecutionOrder[waveIdx]
		waveTasks := make([]SubTask, 0)
		for _, subTaskID := range wave {
			for i := range task.SubTasks {
				if task.SubTasks[i].ID == subTaskID {
					waveTasks = append(waveTasks, task.SubTasks[i])
					break
				}
			}
		}

		// The remaining budget is passed down so the executor can cap
		// spend within the wave.
		results, err := c.executor.ExecuteGroup(ctx, waveTasks, budget-task.TotalCost)
		if err != nil {
			c.saveCheckpoint(task, waveIdx, 0, "execution_error")
			task.Status = StatusFailed
			task.Error = fmt.Sprintf("Execution failed at wave %d: %v", waveIdx, err)
			task.UpdatedAt = time.Now()
			c.updateTask(ctx, task)
			return task, err
		}

		// Copy each result back onto its subtask, accumulate cost, and
		// surface produced artifacts as events.
		for _, result := range results {
			for i := range task.SubTasks {
				if task.SubTasks[i].ID == result.SubTaskID {
					task.SubTasks[i].Output = result.Output
					task.SubTasks[i].Cost = result.Cost
					task.SubTasks[i].Status = StatusCompleted
					now := time.Now()
					task.SubTasks[i].CompletedAt = &now
					if result.Error != nil {
						// A per-subtask error fails only that subtask; the
						// overall wave loop continues.
						task.SubTasks[i].Status = StatusFailed
						task.SubTasks[i].Error = result.Error.Error()
					}
					break
				}
			}

			task.TotalCost += result.Cost
			task.TotalRuntime = time.Since(task.CreatedAt)

			for _, artifact := range result.Artifacts {
				task.Artifacts = append(task.Artifacts, artifact)
				c.emitEvent(TaskEvent{
					Type:      EventArtifact,
					TaskID:    task.ID,
					SubTaskID: result.SubTaskID,
					Data: map[string]interface{}{
						"artifact": artifact,
					},
					Timestamp: time.Now(),
				})
			}
		}

		// Progress runs from 10 to 90 across waves; planning used the
		// first 10 points and completion claims the rest.
		progress := 10 + int(float64(waveIdx+1)/float64(totalSubTasks)*80)
		task.Progress = progress
		task.Iterations = waveIdx + 1
		task.UpdatedAt = time.Now()
		c.updateTask(ctx, task)

		c.emitEvent(TaskEvent{
			Type:     EventIteration,
			TaskID:   task.ID,
			Progress: progress,
			Message:  fmt.Sprintf("Completed wave %d/%d (runtime: %v)", waveIdx+1, totalSubTasks, time.Since(task.CreatedAt).Round(time.Second)),
			Data: map[string]interface{}{
				"wave":      waveIdx + 1,
				"total":     totalSubTasks,
				"cost":      task.TotalCost,
				"runtime":   time.Since(task.CreatedAt).String(),
				"artifacts": len(task.Artifacts),
			},
			Timestamp: time.Now(),
		})
	}

	// All waves done: mark completed and persist the final state.
	task.Status = StatusCompleted
	task.Progress = 100
	now := time.Now()
	task.CompletedAt = &now
	task.UpdatedAt = now
	task.TotalRuntime = time.Since(task.CreatedAt)
	c.updateTask(ctx, task)

	c.emitEvent(TaskEvent{
		Type:     EventTaskCompleted,
		TaskID:   task.ID,
		Status:   StatusCompleted,
		Progress: 100,
		Message:  fmt.Sprintf("Task completed (runtime: %v, cost: $%.4f)", task.TotalRuntime.Round(time.Second), task.TotalCost),
		Data: map[string]interface{}{
			"artifacts":    task.Artifacts,
			"totalCost":    task.TotalCost,
			"totalRuntime": task.TotalRuntime.String(),
			"iterations":   task.Iterations,
		},
		Timestamp: time.Now(),
	})

	// Persist subtask outputs into user memory for future planning context.
	c.storeTaskResults(ctx, task)

	return task, nil
}
|
||||
|
||||
func (c *Computer) saveCheckpoint(task *ComputerTask, waveIdx, subTaskIdx int, reason string) {
|
||||
checkpoint := Checkpoint{
|
||||
ID: uuid.New().String(),
|
||||
TaskID: task.ID,
|
||||
WaveIndex: waveIdx,
|
||||
SubTaskIndex: subTaskIdx,
|
||||
State: make(map[string]interface{}),
|
||||
Progress: task.Progress,
|
||||
Memory: task.Memory,
|
||||
CreatedAt: time.Now(),
|
||||
RuntimeSoFar: time.Since(task.CreatedAt),
|
||||
CostSoFar: task.TotalCost,
|
||||
Reason: reason,
|
||||
}
|
||||
|
||||
for _, artifact := range task.Artifacts {
|
||||
checkpoint.Artifacts = append(checkpoint.Artifacts, artifact.ID)
|
||||
}
|
||||
|
||||
task.Checkpoint = &checkpoint
|
||||
task.Checkpoints = append(task.Checkpoints, checkpoint)
|
||||
task.UpdatedAt = time.Now()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
c.taskRepo.Update(ctx, task)
|
||||
|
||||
c.emitEvent(TaskEvent{
|
||||
Type: EventCheckpointSaved,
|
||||
TaskID: task.ID,
|
||||
Progress: task.Progress,
|
||||
Message: fmt.Sprintf("Checkpoint saved: %s (wave %d)", reason, waveIdx),
|
||||
Data: map[string]interface{}{
|
||||
"checkpointId": checkpoint.ID,
|
||||
"waveIndex": waveIdx,
|
||||
"subTaskIndex": subTaskIdx,
|
||||
"reason": reason,
|
||||
"runtime": checkpoint.RuntimeSoFar.String(),
|
||||
"cost": checkpoint.CostSoFar,
|
||||
},
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// Pause suspends a running task: it flips the status to paused, records a
// "user_paused" checkpoint, emits a paused event, and persists the task.
// Returns an error when the task is not currently running.
func (c *Computer) Pause(ctx context.Context, taskID string) error {
	c.mu.Lock()
	task, ok := c.tasks[taskID]
	if !ok {
		// Not cached: drop the lock around the repository call and
		// re-acquire it afterwards.
		// NOTE(review): the reloaded task is not inserted into c.tasks, so
		// subsequent cache lookups will still miss — confirm intended.
		c.mu.Unlock()
		var err error
		task, err = c.taskRepo.GetByID(ctx, taskID)
		if err != nil {
			return err
		}
		c.mu.Lock()
	}

	// Only actively running tasks can be paused.
	if task.Status != StatusExecuting && task.Status != StatusLongRunning {
		c.mu.Unlock()
		return errors.New("task is not running")
	}

	now := time.Now()
	task.Status = StatusPaused
	task.PausedAt = &now
	task.UpdatedAt = now
	c.mu.Unlock()

	// Checkpoint at the current iteration so the task can be resumed.
	c.saveCheckpoint(task, task.Iterations, 0, "user_paused")

	c.emitEvent(TaskEvent{
		Type:      EventPaused,
		TaskID:    taskID,
		Status:    StatusPaused,
		Progress:  task.Progress,
		Message:   "Task paused by user",
		Timestamp: now,
	})

	return c.taskRepo.Update(ctx, task)
}
|
||||
|
||||
func (c *Computer) Resume(ctx context.Context, taskID string, userInput string) error {
|
||||
c.mu.RLock()
|
||||
task, ok := c.tasks[taskID]
|
||||
c.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
var err error
|
||||
task, err = c.taskRepo.GetByID(ctx, taskID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("task not found: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if task.Status != StatusWaiting {
|
||||
return errors.New("task is not waiting for user input")
|
||||
}
|
||||
|
||||
task.Memory["user_input"] = userInput
|
||||
task.Status = StatusExecuting
|
||||
task.UpdatedAt = time.Now()
|
||||
|
||||
go c.executeTask(context.Background(), task, ExecuteOptions{Async: true})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Computer) Cancel(ctx context.Context, taskID string) error {
|
||||
c.mu.Lock()
|
||||
task, ok := c.tasks[taskID]
|
||||
if ok {
|
||||
task.Status = StatusCancelled
|
||||
task.UpdatedAt = time.Now()
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
task, err := c.taskRepo.GetByID(ctx, taskID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("task not found: %w", err)
|
||||
}
|
||||
task.Status = StatusCancelled
|
||||
task.UpdatedAt = time.Now()
|
||||
return c.taskRepo.Update(ctx, task)
|
||||
}
|
||||
|
||||
c.emitEvent(TaskEvent{
|
||||
Type: EventTaskFailed,
|
||||
TaskID: taskID,
|
||||
Status: StatusCancelled,
|
||||
Message: "Task cancelled by user",
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
|
||||
return c.taskRepo.Update(ctx, task)
|
||||
}
|
||||
|
||||
func (c *Computer) GetStatus(ctx context.Context, taskID string) (*ComputerTask, error) {
|
||||
c.mu.RLock()
|
||||
task, ok := c.tasks[taskID]
|
||||
c.mu.RUnlock()
|
||||
|
||||
if ok {
|
||||
return task, nil
|
||||
}
|
||||
|
||||
return c.taskRepo.GetByID(ctx, taskID)
|
||||
}
|
||||
|
||||
// GetUserTasks pages through a user's tasks straight from the repository.
func (c *Computer) GetUserTasks(ctx context.Context, userID string, limit, offset int) ([]ComputerTask, error) {
	return c.taskRepo.GetByUserID(ctx, userID, limit, offset)
}
|
||||
|
||||
// Stream subscribes the caller to the task's event feed. The error return
// is currently always nil.
func (c *Computer) Stream(ctx context.Context, taskID string) (<-chan TaskEvent, error) {
	return c.eventBus.Subscribe(taskID), nil
}
|
||||
|
||||
// updateTask refreshes the in-memory cache and best-effort persists the
// task; the repository error is deliberately ignored so execution
// continues even if persistence is temporarily unavailable.
func (c *Computer) updateTask(ctx context.Context, task *ComputerTask) {
	c.mu.Lock()
	c.tasks[task.ID] = task
	c.mu.Unlock()

	_ = c.taskRepo.Update(ctx, task)
}
|
||||
|
||||
// emitEvent publishes the event on the bus, keyed by its task ID.
func (c *Computer) emitEvent(event TaskEvent) {
	c.eventBus.Publish(event.TaskID, event)
}
|
||||
|
||||
func (c *Computer) storeTaskResults(ctx context.Context, task *ComputerTask) {
|
||||
for _, st := range task.SubTasks {
|
||||
if st.Output != nil {
|
||||
outputJSON, _ := json.Marshal(st.Output)
|
||||
entry := &MemoryEntry{
|
||||
ID: uuid.New().String(),
|
||||
UserID: task.UserID,
|
||||
TaskID: task.ID,
|
||||
Key: fmt.Sprintf("subtask_%s_result", st.ID),
|
||||
Value: string(outputJSON),
|
||||
Type: MemoryTypeResult,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
_ = c.memory.Store(ctx, task.UserID, entry)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StartScheduler starts the scheduler; a no-op when scheduling was not
// enabled in the config.
func (c *Computer) StartScheduler(ctx context.Context) {
	if c.scheduler != nil {
		c.scheduler.Start(ctx)
	}
}
|
||||
|
||||
// StopScheduler stops the scheduler; a no-op when scheduling was not
// enabled in the config.
func (c *Computer) StopScheduler() {
	if c.scheduler != nil {
		c.scheduler.Stop()
	}
}
|
||||
|
||||
// EventBus is a per-task fan-out of TaskEvents with buffered, lossy
// delivery (see Publish).
type EventBus struct {
	subscribers map[string][]chan TaskEvent // taskID -> subscriber channels
	mu          sync.RWMutex                // guards subscribers
}
|
||||
|
||||
// NewEventBus returns an empty, ready-to-use event bus.
func NewEventBus() *EventBus {
	return &EventBus{
		subscribers: make(map[string][]chan TaskEvent),
	}
}
|
||||
|
||||
// Subscribe registers a new buffered (100-event) channel for taskID and
// returns its receive side. Callers should eventually Unsubscribe to
// release and close the channel.
func (eb *EventBus) Subscribe(taskID string) <-chan TaskEvent {
	eb.mu.Lock()
	defer eb.mu.Unlock()

	ch := make(chan TaskEvent, 100)
	eb.subscribers[taskID] = append(eb.subscribers[taskID], ch)
	return ch
}
|
||||
|
||||
// Unsubscribe removes the given channel from taskID's subscriber list and
// closes it; a channel not found in the list is a no-op.
func (eb *EventBus) Unsubscribe(taskID string, ch <-chan TaskEvent) {
	eb.mu.Lock()
	defer eb.mu.Unlock()

	subs := eb.subscribers[taskID]
	for i, sub := range subs {
		if sub == ch {
			// Splice out in place; closing signals the receiver to stop.
			eb.subscribers[taskID] = append(subs[:i], subs[i+1:]...)
			close(sub)
			break
		}
	}
}
|
||||
|
||||
// Publish delivers event to every subscriber of taskID without blocking:
// when a subscriber's buffer is full, the event is dropped for that
// subscriber (deliberately lossy so a slow consumer cannot stall the
// publisher).
func (eb *EventBus) Publish(taskID string, event TaskEvent) {
	eb.mu.RLock()
	subs := eb.subscribers[taskID]
	eb.mu.RUnlock()

	for _, ch := range subs {
		select {
		case ch <- event:
		default:
		}
	}
}
|
||||
104
backend/internal/computer/connectors/connector.go
Normal file
104
backend/internal/computer/connectors/connector.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package connectors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Connector is an external tool integration: it advertises the actions it
// supports and executes them on demand.
type Connector interface {
	// ID returns the stable identifier used to register/look up the connector.
	ID() string
	// Name returns a human-readable display name.
	Name() string
	// Description returns a short description of what the connector does.
	Description() string
	// Execute runs the named action with the given parameters.
	Execute(ctx context.Context, action string, params map[string]interface{}) (interface{}, error)
	// GetActions lists the actions this connector supports.
	GetActions() []Action
	// Validate checks params before execution.
	Validate(params map[string]interface{}) error
}
|
||||
|
||||
// Action describes one operation a connector can perform, including a
// JSON-schema-style description of its parameters.
type Action struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description"`
	Schema      map[string]interface{} `json:"schema"`   // parameter schema (JSON-schema-like shape)
	Required    []string               `json:"required"` // names of mandatory parameters
}
|
||||
|
||||
// ConnectorHub is a concurrency-safe registry of connectors keyed by ID.
type ConnectorHub struct {
	connectors map[string]Connector
	mu         sync.RWMutex // guards connectors
}
|
||||
|
||||
// NewConnectorHub returns an empty, ready-to-use hub.
func NewConnectorHub() *ConnectorHub {
	return &ConnectorHub{
		connectors: make(map[string]Connector),
	}
}
|
||||
|
||||
// Register adds (or replaces) a connector under its own ID.
func (h *ConnectorHub) Register(connector Connector) {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.connectors[connector.ID()] = connector
}
|
||||
|
||||
// Unregister removes the connector registered under id; unknown IDs are a
// no-op.
func (h *ConnectorHub) Unregister(id string) {
	h.mu.Lock()
	defer h.mu.Unlock()
	delete(h.connectors, id)
}
|
||||
|
||||
func (h *ConnectorHub) Get(id string) (Connector, error) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
connector, ok := h.connectors[id]
|
||||
if !ok {
|
||||
return nil, errors.New("connector not found: " + id)
|
||||
}
|
||||
return connector, nil
|
||||
}
|
||||
|
||||
func (h *ConnectorHub) List() []Connector {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
result := make([]Connector, 0, len(h.connectors))
|
||||
for _, c := range h.connectors {
|
||||
result = append(result, c)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (h *ConnectorHub) Execute(ctx context.Context, connectorID, action string, params map[string]interface{}) (interface{}, error) {
|
||||
connector, err := h.Get(connectorID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := connector.Validate(params); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return connector.Execute(ctx, action, params)
|
||||
}
|
||||
|
||||
// ConnectorInfo is a serializable snapshot of a connector's identity and
// supported actions, suitable for listing over an API.
type ConnectorInfo struct {
	ID          string   `json:"id"`
	Name        string   `json:"name"`
	Description string   `json:"description"`
	Actions     []Action `json:"actions"`
}
|
||||
|
||||
func (h *ConnectorHub) GetInfo() []ConnectorInfo {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
result := make([]ConnectorInfo, 0, len(h.connectors))
|
||||
for _, c := range h.connectors {
|
||||
result = append(result, ConnectorInfo{
|
||||
ID: c.ID(),
|
||||
Name: c.Name(),
|
||||
Description: c.Description(),
|
||||
Actions: c.GetActions(),
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
215
backend/internal/computer/connectors/email.go
Normal file
215
backend/internal/computer/connectors/email.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package connectors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/smtp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// EmailConfig holds SMTP connection and sender settings for EmailConnector.
type EmailConfig struct {
	SMTPHost    string // SMTP server hostname
	SMTPPort    int    // SMTP server port
	Username    string // auth username; auth is skipped when either credential is empty
	Password    string // auth password; auth is skipped when either credential is empty
	FromAddress string // envelope sender and From-header address
	FromName    string // optional display name for the From header
	UseTLS      bool   // when true, connect with an implicit TLS dial instead of plain SMTP
	AllowHTML   bool   // when false, the "html" param is ignored and plain text is sent
}

// EmailConnector sends email over SMTP using the settings in EmailConfig.
type EmailConnector struct {
	cfg EmailConfig
}

// NewEmailConnector creates an EmailConnector with the given configuration.
func NewEmailConnector(cfg EmailConfig) *EmailConnector {
	return &EmailConnector{cfg: cfg}
}

// ID returns the stable connector identifier used for registry lookups.
func (e *EmailConnector) ID() string {
	return "email"
}

// Name returns the human-readable connector name.
func (e *EmailConnector) Name() string {
	return "Email"
}

// Description returns a short summary of the connector's capability.
func (e *EmailConnector) Description() string {
	return "Send emails via SMTP"
}
|
||||
|
||||
// GetActions describes the single "send" action and its parameter schema.
func (e *EmailConnector) GetActions() []Action {
	return []Action{
		{
			Name:        "send",
			Description: "Send an email",
			Schema: map[string]interface{}{
				"type": "object",
				"properties": map[string]interface{}{
					"to":      map[string]interface{}{"type": "string", "description": "Recipient email address"},
					"subject": map[string]interface{}{"type": "string", "description": "Email subject"},
					"body":    map[string]interface{}{"type": "string", "description": "Email body"},
					"html":    map[string]interface{}{"type": "boolean", "description": "Whether body is HTML"},
					"cc":      map[string]interface{}{"type": "string", "description": "CC recipients (comma-separated)"},
					"bcc":     map[string]interface{}{"type": "string", "description": "BCC recipients (comma-separated)"},
				},
			},
			Required: []string{"to", "subject", "body"},
		},
	}
}
|
||||
|
||||
func (e *EmailConnector) Validate(params map[string]interface{}) error {
|
||||
if _, ok := params["to"]; !ok {
|
||||
return errors.New("'to' is required")
|
||||
}
|
||||
if _, ok := params["subject"]; !ok {
|
||||
return errors.New("'subject' is required")
|
||||
}
|
||||
if _, ok := params["body"]; !ok {
|
||||
return errors.New("'body' is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *EmailConnector) Execute(ctx context.Context, action string, params map[string]interface{}) (interface{}, error) {
|
||||
switch action {
|
||||
case "send":
|
||||
return e.send(ctx, params)
|
||||
default:
|
||||
return nil, errors.New("unknown action: " + action)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EmailConnector) send(ctx context.Context, params map[string]interface{}) (interface{}, error) {
|
||||
to := params["to"].(string)
|
||||
subject := params["subject"].(string)
|
||||
body := params["body"].(string)
|
||||
|
||||
isHTML := false
|
||||
if html, ok := params["html"].(bool); ok {
|
||||
isHTML = html && e.cfg.AllowHTML
|
||||
}
|
||||
|
||||
var cc, bcc []string
|
||||
if ccStr, ok := params["cc"].(string); ok && ccStr != "" {
|
||||
cc = strings.Split(ccStr, ",")
|
||||
for i := range cc {
|
||||
cc[i] = strings.TrimSpace(cc[i])
|
||||
}
|
||||
}
|
||||
if bccStr, ok := params["bcc"].(string); ok && bccStr != "" {
|
||||
bcc = strings.Split(bccStr, ",")
|
||||
for i := range bcc {
|
||||
bcc[i] = strings.TrimSpace(bcc[i])
|
||||
}
|
||||
}
|
||||
|
||||
from := e.cfg.FromAddress
|
||||
if e.cfg.FromName != "" {
|
||||
from = fmt.Sprintf("%s <%s>", e.cfg.FromName, e.cfg.FromAddress)
|
||||
}
|
||||
|
||||
var msg strings.Builder
|
||||
msg.WriteString(fmt.Sprintf("From: %s\r\n", from))
|
||||
msg.WriteString(fmt.Sprintf("To: %s\r\n", to))
|
||||
if len(cc) > 0 {
|
||||
msg.WriteString(fmt.Sprintf("Cc: %s\r\n", strings.Join(cc, ", ")))
|
||||
}
|
||||
msg.WriteString(fmt.Sprintf("Subject: %s\r\n", subject))
|
||||
msg.WriteString("MIME-Version: 1.0\r\n")
|
||||
|
||||
if isHTML {
|
||||
msg.WriteString("Content-Type: text/html; charset=\"UTF-8\"\r\n")
|
||||
} else {
|
||||
msg.WriteString("Content-Type: text/plain; charset=\"UTF-8\"\r\n")
|
||||
}
|
||||
|
||||
msg.WriteString("\r\n")
|
||||
msg.WriteString(body)
|
||||
|
||||
recipients := []string{to}
|
||||
recipients = append(recipients, cc...)
|
||||
recipients = append(recipients, bcc...)
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", e.cfg.SMTPHost, e.cfg.SMTPPort)
|
||||
|
||||
var auth smtp.Auth
|
||||
if e.cfg.Username != "" && e.cfg.Password != "" {
|
||||
auth = smtp.PlainAuth("", e.cfg.Username, e.cfg.Password, e.cfg.SMTPHost)
|
||||
}
|
||||
|
||||
var err error
|
||||
if e.cfg.UseTLS {
|
||||
err = e.sendWithTLS(addr, auth, e.cfg.FromAddress, recipients, []byte(msg.String()))
|
||||
} else {
|
||||
err = smtp.SendMail(addr, auth, e.cfg.FromAddress, recipients, []byte(msg.String()))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return map[string]interface{}{
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
}, err
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"success": true,
|
||||
"to": to,
|
||||
"subject": subject,
|
||||
"recipients": len(recipients),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// sendWithTLS delivers msg over an implicit-TLS SMTP connection (typical for
// port 465), performing the AUTH/MAIL/RCPT/DATA sequence manually because
// smtp.SendMail upgrades via STARTTLS rather than dialing TLS directly.
func (e *EmailConnector) sendWithTLS(addr string, auth smtp.Auth, from string, to []string, msg []byte) error {
	tlsConfig := &tls.Config{
		// ServerName must match the certificate presented by the SMTP host.
		ServerName: e.cfg.SMTPHost,
	}

	conn, err := tls.Dial("tcp", addr, tlsConfig)
	if err != nil {
		return err
	}
	defer conn.Close()

	client, err := smtp.NewClient(conn, e.cfg.SMTPHost)
	if err != nil {
		return err
	}
	// Safety net for early returns; after a successful Quit this deferred
	// Close fails on the already-closed session and its error is ignored.
	defer client.Close()

	if auth != nil {
		if err := client.Auth(auth); err != nil {
			return err
		}
	}

	if err := client.Mail(from); err != nil {
		return err
	}

	for _, recipient := range to {
		if err := client.Rcpt(recipient); err != nil {
			return err
		}
	}

	w, err := client.Data()
	if err != nil {
		return err
	}

	_, err = w.Write(msg)
	if err != nil {
		return err
	}

	// Closing the data writer completes the DATA command; the server's
	// acceptance/rejection of the message surfaces here.
	err = w.Close()
	if err != nil {
		return err
	}

	return client.Quit()
}
|
||||
432
backend/internal/computer/connectors/storage.go
Normal file
432
backend/internal/computer/connectors/storage.go
Normal file
@@ -0,0 +1,432 @@
|
||||
package connectors
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
)
|
||||
|
||||
// StorageConfig holds connection settings for an S3-compatible object store
// (MinIO, AWS S3, etc.).
type StorageConfig struct {
	Endpoint        string // host[:port] of the S3-compatible endpoint
	AccessKeyID     string
	SecretAccessKey string
	BucketName      string // bucket all operations act on
	UseSSL          bool   // connect over HTTPS when true
	Region          string
	PublicURL       string // optional base URL used to build public object URLs
}

// StorageConnector stores and retrieves objects in a single bucket of an
// S3-compatible service via the MinIO client.
type StorageConnector struct {
	cfg    StorageConfig
	client *minio.Client
}

// NewStorageConnector constructs the MinIO client and wraps it in a
// connector. It does not check that the bucket exists (see EnsureBucket).
func NewStorageConnector(cfg StorageConfig) (*StorageConnector, error) {
	client, err := minio.New(cfg.Endpoint, &minio.Options{
		Creds:  credentials.NewStaticV4(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
		Secure: cfg.UseSSL,
		Region: cfg.Region,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to create storage client: %w", err)
	}

	return &StorageConnector{
		cfg:    cfg,
		client: client,
	}, nil
}
|
||||
|
||||
// ID returns the stable connector identifier used for registry lookups.
func (s *StorageConnector) ID() string {
	return "storage"
}

// Name returns the human-readable connector name.
func (s *StorageConnector) Name() string {
	return "Storage"
}

// Description returns a short summary of the connector's capability.
func (s *StorageConnector) Description() string {
	return "Store and retrieve files from S3-compatible storage"
}
|
||||
|
||||
// GetActions describes the five object-store actions and their parameter
// schemas: upload, download, delete, list, and get_url.
func (s *StorageConnector) GetActions() []Action {
	return []Action{
		{
			Name:        "upload",
			Description: "Upload a file",
			Schema: map[string]interface{}{
				"type": "object",
				"properties": map[string]interface{}{
					"path":         map[string]interface{}{"type": "string", "description": "Storage path/key"},
					"content":      map[string]interface{}{"type": "string", "description": "File content (base64 or text)"},
					"content_type": map[string]interface{}{"type": "string", "description": "MIME type"},
					// NOTE(review): "public" is declared here but the upload
					// handler does not read it — confirm intended behavior.
					"public": map[string]interface{}{"type": "boolean", "description": "Make file publicly accessible"},
				},
			},
			Required: []string{"path", "content"},
		},
		{
			Name:        "download",
			Description: "Download a file",
			Schema: map[string]interface{}{
				"type": "object",
				"properties": map[string]interface{}{
					"path": map[string]interface{}{"type": "string", "description": "Storage path/key"},
				},
			},
			Required: []string{"path"},
		},
		{
			Name:        "delete",
			Description: "Delete a file",
			Schema: map[string]interface{}{
				"type": "object",
				"properties": map[string]interface{}{
					"path": map[string]interface{}{"type": "string", "description": "Storage path/key"},
				},
			},
			Required: []string{"path"},
		},
		{
			// "list" is the only action with no required parameters.
			Name:        "list",
			Description: "List files in a directory",
			Schema: map[string]interface{}{
				"type": "object",
				"properties": map[string]interface{}{
					"prefix": map[string]interface{}{"type": "string", "description": "Path prefix"},
					"limit":  map[string]interface{}{"type": "integer", "description": "Max results"},
				},
			},
		},
		{
			Name:        "get_url",
			Description: "Get a presigned URL for a file",
			Schema: map[string]interface{}{
				"type": "object",
				"properties": map[string]interface{}{
					"path":    map[string]interface{}{"type": "string", "description": "Storage path/key"},
					"expires": map[string]interface{}{"type": "integer", "description": "URL expiry in seconds"},
				},
			},
			Required: []string{"path"},
		},
	}
}
|
||||
|
||||
// Validate is a no-op: required parameters vary per action, so checks are
// left to the individual handlers.
// NOTE(review): the handlers assert params["path"].(string) without an ok
// check and will panic on a non-string value — consider validating types
// here or in the handlers.
func (s *StorageConnector) Validate(params map[string]interface{}) error {
	return nil
}
|
||||
|
||||
func (s *StorageConnector) Execute(ctx context.Context, action string, params map[string]interface{}) (interface{}, error) {
|
||||
switch action {
|
||||
case "upload":
|
||||
return s.upload(ctx, params)
|
||||
case "download":
|
||||
return s.download(ctx, params)
|
||||
case "delete":
|
||||
return s.deleteFile(ctx, params)
|
||||
case "list":
|
||||
return s.list(ctx, params)
|
||||
case "get_url":
|
||||
return s.getURL(ctx, params)
|
||||
default:
|
||||
return nil, errors.New("unknown action: " + action)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StorageConnector) upload(ctx context.Context, params map[string]interface{}) (interface{}, error) {
|
||||
path := params["path"].(string)
|
||||
content := params["content"].(string)
|
||||
|
||||
contentType := "application/octet-stream"
|
||||
if ct, ok := params["content_type"].(string); ok {
|
||||
contentType = ct
|
||||
}
|
||||
|
||||
if contentType == "" {
|
||||
contentType = s.detectContentType(path)
|
||||
}
|
||||
|
||||
reader := bytes.NewReader([]byte(content))
|
||||
size := int64(len(content))
|
||||
|
||||
info, err := s.client.PutObject(ctx, s.cfg.BucketName, path, reader, size, minio.PutObjectOptions{
|
||||
ContentType: contentType,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upload failed: %w", err)
|
||||
}
|
||||
|
||||
url := ""
|
||||
if s.cfg.PublicURL != "" {
|
||||
url = fmt.Sprintf("%s/%s/%s", strings.TrimSuffix(s.cfg.PublicURL, "/"), s.cfg.BucketName, path)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"success": true,
|
||||
"path": path,
|
||||
"size": info.Size,
|
||||
"etag": info.ETag,
|
||||
"url": url,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *StorageConnector) UploadBytes(ctx context.Context, path string, content []byte, contentType string) (string, error) {
|
||||
if contentType == "" {
|
||||
contentType = s.detectContentType(path)
|
||||
}
|
||||
|
||||
reader := bytes.NewReader(content)
|
||||
size := int64(len(content))
|
||||
|
||||
_, err := s.client.PutObject(ctx, s.cfg.BucketName, path, reader, size, minio.PutObjectOptions{
|
||||
ContentType: contentType,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if s.cfg.PublicURL != "" {
|
||||
return fmt.Sprintf("%s/%s/%s", strings.TrimSuffix(s.cfg.PublicURL, "/"), s.cfg.BucketName, path), nil
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func (s *StorageConnector) download(ctx context.Context, params map[string]interface{}) (interface{}, error) {
|
||||
path := params["path"].(string)
|
||||
|
||||
obj, err := s.client.GetObject(ctx, s.cfg.BucketName, path, minio.GetObjectOptions{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("download failed: %w", err)
|
||||
}
|
||||
defer obj.Close()
|
||||
|
||||
content, err := io.ReadAll(obj)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read failed: %w", err)
|
||||
}
|
||||
|
||||
stat, _ := obj.Stat()
|
||||
|
||||
return map[string]interface{}{
|
||||
"success": true,
|
||||
"path": path,
|
||||
"content": string(content),
|
||||
"size": len(content),
|
||||
"content_type": stat.ContentType,
|
||||
"modified": stat.LastModified,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *StorageConnector) DownloadBytes(ctx context.Context, path string) ([]byte, error) {
|
||||
obj, err := s.client.GetObject(ctx, s.cfg.BucketName, path, minio.GetObjectOptions{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer obj.Close()
|
||||
|
||||
return io.ReadAll(obj)
|
||||
}
|
||||
|
||||
func (s *StorageConnector) deleteFile(ctx context.Context, params map[string]interface{}) (interface{}, error) {
|
||||
path := params["path"].(string)
|
||||
|
||||
err := s.client.RemoveObject(ctx, s.cfg.BucketName, path, minio.RemoveObjectOptions{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("delete failed: %w", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"success": true,
|
||||
"path": path,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *StorageConnector) list(ctx context.Context, params map[string]interface{}) (interface{}, error) {
|
||||
prefix := ""
|
||||
if p, ok := params["prefix"].(string); ok {
|
||||
prefix = p
|
||||
}
|
||||
|
||||
limit := 100
|
||||
if l, ok := params["limit"].(float64); ok {
|
||||
limit = int(l)
|
||||
}
|
||||
|
||||
objects := s.client.ListObjects(ctx, s.cfg.BucketName, minio.ListObjectsOptions{
|
||||
Prefix: prefix,
|
||||
Recursive: true,
|
||||
})
|
||||
|
||||
var files []map[string]interface{}
|
||||
count := 0
|
||||
|
||||
for obj := range objects {
|
||||
if obj.Err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
files = append(files, map[string]interface{}{
|
||||
"path": obj.Key,
|
||||
"size": obj.Size,
|
||||
"modified": obj.LastModified,
|
||||
"etag": obj.ETag,
|
||||
})
|
||||
|
||||
count++
|
||||
if count >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"success": true,
|
||||
"files": files,
|
||||
"count": len(files),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *StorageConnector) getURL(ctx context.Context, params map[string]interface{}) (interface{}, error) {
|
||||
path := params["path"].(string)
|
||||
|
||||
expires := 3600
|
||||
if e, ok := params["expires"].(float64); ok {
|
||||
expires = int(e)
|
||||
}
|
||||
|
||||
url, err := s.client.PresignedGetObject(ctx, s.cfg.BucketName, path, time.Duration(expires)*time.Second, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate URL: %w", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"success": true,
|
||||
"url": url.String(),
|
||||
"expires": expires,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *StorageConnector) GetPublicURL(path string) string {
|
||||
if s.cfg.PublicURL != "" {
|
||||
return fmt.Sprintf("%s/%s/%s", strings.TrimSuffix(s.cfg.PublicURL, "/"), s.cfg.BucketName, path)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *StorageConnector) detectContentType(path string) string {
|
||||
ext := strings.ToLower(filepath.Ext(path))
|
||||
|
||||
contentTypes := map[string]string{
|
||||
".html": "text/html",
|
||||
".css": "text/css",
|
||||
".js": "application/javascript",
|
||||
".json": "application/json",
|
||||
".xml": "application/xml",
|
||||
".pdf": "application/pdf",
|
||||
".zip": "application/zip",
|
||||
".png": "image/png",
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".gif": "image/gif",
|
||||
".svg": "image/svg+xml",
|
||||
".mp4": "video/mp4",
|
||||
".mp3": "audio/mpeg",
|
||||
".txt": "text/plain",
|
||||
".md": "text/markdown",
|
||||
".csv": "text/csv",
|
||||
".py": "text/x-python",
|
||||
".go": "text/x-go",
|
||||
".rs": "text/x-rust",
|
||||
}
|
||||
|
||||
if ct, ok := contentTypes[ext]; ok {
|
||||
return ct
|
||||
}
|
||||
|
||||
return "application/octet-stream"
|
||||
}
|
||||
|
||||
func (s *StorageConnector) EnsureBucket(ctx context.Context) error {
|
||||
exists, err := s.client.BucketExists(ctx, s.cfg.BucketName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !exists {
|
||||
return s.client.MakeBucket(ctx, s.cfg.BucketName, minio.MakeBucketOptions{
|
||||
Region: s.cfg.Region,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewLocalStorageConnector creates a connector rooted at basePath; all
// object paths handled by Execute are resolved relative to it.
func NewLocalStorageConnector(basePath string) *LocalStorageConnector {
	return &LocalStorageConnector{basePath: basePath}
}

// LocalStorageConnector stores files on the local filesystem under basePath.
type LocalStorageConnector struct {
	basePath string // root directory for all stored files
}
|
||||
|
||||
// ID returns the stable connector identifier used for registry lookups.
func (l *LocalStorageConnector) ID() string {
	return "local_storage"
}

// Name returns the human-readable connector name.
func (l *LocalStorageConnector) Name() string {
	return "Local Storage"
}

// Description returns a short summary of the connector's capability.
func (l *LocalStorageConnector) Description() string {
	return "Store files on local filesystem"
}

// GetActions lists the supported actions; parameter schemas are omitted for
// this connector.
// NOTE(review): "list" is advertised here but Execute does not implement it.
func (l *LocalStorageConnector) GetActions() []Action {
	return []Action{
		{Name: "upload", Description: "Upload a file"},
		{Name: "download", Description: "Download a file"},
		{Name: "delete", Description: "Delete a file"},
		{Name: "list", Description: "List files"},
	}
}

// Validate performs no checks; Execute reads parameters directly.
func (l *LocalStorageConnector) Validate(params map[string]interface{}) error {
	return nil
}
|
||||
|
||||
func (l *LocalStorageConnector) Execute(ctx context.Context, action string, params map[string]interface{}) (interface{}, error) {
|
||||
switch action {
|
||||
case "upload":
|
||||
path := params["path"].(string)
|
||||
content := params["content"].(string)
|
||||
fullPath := filepath.Join(l.basePath, path)
|
||||
os.MkdirAll(filepath.Dir(fullPath), 0755)
|
||||
err := os.WriteFile(fullPath, []byte(content), 0644)
|
||||
return map[string]interface{}{"success": err == nil, "path": path}, err
|
||||
|
||||
case "download":
|
||||
path := params["path"].(string)
|
||||
fullPath := filepath.Join(l.basePath, path)
|
||||
content, err := os.ReadFile(fullPath)
|
||||
return map[string]interface{}{"success": err == nil, "content": string(content)}, err
|
||||
|
||||
case "delete":
|
||||
path := params["path"].(string)
|
||||
fullPath := filepath.Join(l.basePath, path)
|
||||
err := os.Remove(fullPath)
|
||||
return map[string]interface{}{"success": err == nil}, err
|
||||
|
||||
default:
|
||||
return nil, errors.New("unknown action")
|
||||
}
|
||||
}
|
||||
263
backend/internal/computer/connectors/telegram.go
Normal file
263
backend/internal/computer/connectors/telegram.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package connectors
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TelegramConfig configures the Telegram Bot API connector.
type TelegramConfig struct {
	BotToken string        // bot token issued by BotFather
	Timeout  time.Duration // HTTP client timeout; defaults to 30s when zero
}

// TelegramConnector sends messages and files through the Telegram Bot API
// over a timeout-bounded HTTP client.
type TelegramConnector struct {
	cfg    TelegramConfig
	client *http.Client
}

// NewTelegramConnector creates a connector, defaulting the HTTP timeout to
// 30 seconds when cfg.Timeout is zero.
func NewTelegramConnector(cfg TelegramConfig) *TelegramConnector {
	timeout := cfg.Timeout
	if timeout == 0 {
		timeout = 30 * time.Second
	}

	return &TelegramConnector{
		cfg: cfg,
		client: &http.Client{
			Timeout: timeout,
		},
	}
}
|
||||
|
||||
// ID returns the stable connector identifier used for registry lookups.
func (t *TelegramConnector) ID() string {
	return "telegram"
}

// Name returns the human-readable connector name.
func (t *TelegramConnector) Name() string {
	return "Telegram"
}

// Description returns a short summary of the connector's capability.
func (t *TelegramConnector) Description() string {
	return "Send messages via Telegram Bot API"
}
|
||||
|
||||
// GetActions describes the three messaging actions (send_message,
// send_document, send_photo) and their parameter schemas; all require
// "chat_id".
func (t *TelegramConnector) GetActions() []Action {
	return []Action{
		{
			Name:        "send_message",
			Description: "Send a text message",
			Schema: map[string]interface{}{
				"type": "object",
				"properties": map[string]interface{}{
					"chat_id":    map[string]interface{}{"type": "string", "description": "Chat ID or @username"},
					"text":       map[string]interface{}{"type": "string", "description": "Message text"},
					"parse_mode": map[string]interface{}{"type": "string", "enum": []string{"HTML", "Markdown", "MarkdownV2"}},
				},
			},
			Required: []string{"chat_id", "text"},
		},
		{
			Name:        "send_document",
			Description: "Send a document/file",
			Schema: map[string]interface{}{
				"type": "object",
				"properties": map[string]interface{}{
					"chat_id":  map[string]interface{}{"type": "string", "description": "Chat ID"},
					"document": map[string]interface{}{"type": "string", "description": "File path or URL"},
					"caption":  map[string]interface{}{"type": "string", "description": "Document caption"},
				},
			},
			Required: []string{"chat_id", "document"},
		},
		{
			Name:        "send_photo",
			Description: "Send a photo",
			Schema: map[string]interface{}{
				"type": "object",
				"properties": map[string]interface{}{
					"chat_id": map[string]interface{}{"type": "string", "description": "Chat ID"},
					"photo":   map[string]interface{}{"type": "string", "description": "Photo URL or file_id"},
					"caption": map[string]interface{}{"type": "string", "description": "Photo caption"},
				},
			},
			Required: []string{"chat_id", "photo"},
		},
	}
}
|
||||
|
||||
// Validate checks that the common required parameter "chat_id" is present.
// Per-action parameters (text/document/photo) are read by the handlers.
func (t *TelegramConnector) Validate(params map[string]interface{}) error {
	if _, ok := params["chat_id"]; !ok {
		return errors.New("'chat_id' is required")
	}
	return nil
}
|
||||
|
||||
func (t *TelegramConnector) Execute(ctx context.Context, action string, params map[string]interface{}) (interface{}, error) {
|
||||
switch action {
|
||||
case "send_message":
|
||||
return t.sendMessage(ctx, params)
|
||||
case "send_document":
|
||||
return t.sendDocument(ctx, params)
|
||||
case "send_photo":
|
||||
return t.sendPhoto(ctx, params)
|
||||
default:
|
||||
return nil, errors.New("unknown action: " + action)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TelegramConnector) sendMessage(ctx context.Context, params map[string]interface{}) (interface{}, error) {
|
||||
chatID := params["chat_id"].(string)
|
||||
text := params["text"].(string)
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"chat_id": chatID,
|
||||
"text": text,
|
||||
}
|
||||
|
||||
if parseMode, ok := params["parse_mode"].(string); ok {
|
||||
payload["parse_mode"] = parseMode
|
||||
}
|
||||
|
||||
return t.apiCall(ctx, "sendMessage", payload)
|
||||
}
|
||||
|
||||
func (t *TelegramConnector) sendDocument(ctx context.Context, params map[string]interface{}) (interface{}, error) {
|
||||
chatID := params["chat_id"].(string)
|
||||
document := params["document"].(string)
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"chat_id": chatID,
|
||||
"document": document,
|
||||
}
|
||||
|
||||
if caption, ok := params["caption"].(string); ok {
|
||||
payload["caption"] = caption
|
||||
}
|
||||
|
||||
return t.apiCall(ctx, "sendDocument", payload)
|
||||
}
|
||||
|
||||
func (t *TelegramConnector) sendPhoto(ctx context.Context, params map[string]interface{}) (interface{}, error) {
|
||||
chatID := params["chat_id"].(string)
|
||||
photo := params["photo"].(string)
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"chat_id": chatID,
|
||||
"photo": photo,
|
||||
}
|
||||
|
||||
if caption, ok := params["caption"].(string); ok {
|
||||
payload["caption"] = caption
|
||||
}
|
||||
|
||||
return t.apiCall(ctx, "sendPhoto", payload)
|
||||
}
|
||||
|
||||
func (t *TelegramConnector) apiCall(ctx context.Context, method string, payload map[string]interface{}) (interface{}, error) {
|
||||
url := fmt.Sprintf("https://api.telegram.org/bot%s/%s", t.cfg.BotToken, method)
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := t.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ok, exists := result["ok"].(bool); exists && !ok {
|
||||
desc := "unknown error"
|
||||
if d, exists := result["description"].(string); exists {
|
||||
desc = d
|
||||
}
|
||||
return result, errors.New("Telegram API error: " + desc)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (t *TelegramConnector) SendFileFromBytes(ctx context.Context, chatID string, filename string, content []byte, caption string) (interface{}, error) {
|
||||
url := fmt.Sprintf("https://api.telegram.org/bot%s/sendDocument", t.cfg.BotToken)
|
||||
|
||||
var b bytes.Buffer
|
||||
w := multipart.NewWriter(&b)
|
||||
|
||||
w.WriteField("chat_id", chatID)
|
||||
|
||||
if caption != "" {
|
||||
w.WriteField("caption", caption)
|
||||
}
|
||||
|
||||
fw, err := w.CreateFormFile("document", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fw.Write(content)
|
||||
|
||||
w.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, &b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
|
||||
resp, err := t.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (t *TelegramConnector) GetChatID(chatIDOrUsername interface{}) string {
|
||||
switch v := chatIDOrUsername.(type) {
|
||||
case string:
|
||||
return v
|
||||
case int:
|
||||
return strconv.Itoa(v)
|
||||
case int64:
|
||||
return strconv.FormatInt(v, 10)
|
||||
case float64:
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
275
backend/internal/computer/connectors/webhook.go
Normal file
275
backend/internal/computer/connectors/webhook.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package connectors
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// WebhookConfig configures outbound webhook delivery.
type WebhookConfig struct {
	Timeout    time.Duration // per-request HTTP timeout; defaults to 30s when zero
	MaxRetries int           // additional attempts after the first try; defaults to 3 when zero
	RetryDelay time.Duration // base delay between retries (scaled by attempt); defaults to 1s when zero
	// DefaultSecret is a fallback HMAC signing secret. Its use is not shown
	// in this chunk — presumably applied when a request omits "secret";
	// verify against doRequest/signing code.
	DefaultSecret string
}

// WebhookConnector delivers HTTP requests to external endpoints using the
// timeout and retry behavior from WebhookConfig.
type WebhookConnector struct {
	cfg    WebhookConfig
	client *http.Client
}
|
||||
|
||||
func NewWebhookConnector(cfg WebhookConfig) *WebhookConnector {
|
||||
timeout := cfg.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
if cfg.MaxRetries == 0 {
|
||||
cfg.MaxRetries = 3
|
||||
}
|
||||
if cfg.RetryDelay == 0 {
|
||||
cfg.RetryDelay = time.Second
|
||||
}
|
||||
|
||||
return &WebhookConnector{
|
||||
cfg: cfg,
|
||||
client: &http.Client{
|
||||
Timeout: timeout,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the stable connector identifier used for registry lookups.
func (w *WebhookConnector) ID() string {
	return "webhook"
}

// Name returns the human-readable connector name.
func (w *WebhookConnector) Name() string {
	return "Webhook"
}

// Description returns a short summary of the connector's capability.
func (w *WebhookConnector) Description() string {
	return "Send HTTP webhooks to external services"
}
|
||||
|
||||
// GetActions describes the advertised HTTP actions (post, get, put) and
// their parameter schemas; all require "url".
// NOTE(review): Execute also accepts "delete" and "patch", which are not
// listed here — confirm whether they should be advertised.
func (w *WebhookConnector) GetActions() []Action {
	return []Action{
		{
			Name:        "post",
			Description: "Send POST request",
			Schema: map[string]interface{}{
				"type": "object",
				"properties": map[string]interface{}{
					"url":     map[string]interface{}{"type": "string", "description": "Webhook URL"},
					"body":    map[string]interface{}{"type": "object", "description": "Request body (JSON)"},
					"headers": map[string]interface{}{"type": "object", "description": "Custom headers"},
					"secret":  map[string]interface{}{"type": "string", "description": "HMAC secret for signing"},
				},
			},
			Required: []string{"url"},
		},
		{
			Name:        "get",
			Description: "Send GET request",
			Schema: map[string]interface{}{
				"type": "object",
				"properties": map[string]interface{}{
					"url":     map[string]interface{}{"type": "string", "description": "Request URL"},
					"params":  map[string]interface{}{"type": "object", "description": "Query parameters"},
					"headers": map[string]interface{}{"type": "object", "description": "Custom headers"},
				},
			},
			Required: []string{"url"},
		},
		{
			Name:        "put",
			Description: "Send PUT request",
			Schema: map[string]interface{}{
				"type": "object",
				"properties": map[string]interface{}{
					"url":     map[string]interface{}{"type": "string", "description": "Request URL"},
					"body":    map[string]interface{}{"type": "object", "description": "Request body (JSON)"},
					"headers": map[string]interface{}{"type": "object", "description": "Custom headers"},
				},
			},
			Required: []string{"url"},
		},
	}
}
|
||||
|
||||
func (w *WebhookConnector) Validate(params map[string]interface{}) error {
|
||||
urlStr, ok := params["url"].(string)
|
||||
if !ok {
|
||||
return errors.New("'url' is required")
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
|
||||
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||||
return errors.New("URL must use http or https scheme")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *WebhookConnector) Execute(ctx context.Context, action string, params map[string]interface{}) (interface{}, error) {
|
||||
switch action {
|
||||
case "post":
|
||||
return w.doRequest(ctx, "POST", params)
|
||||
case "get":
|
||||
return w.doRequest(ctx, "GET", params)
|
||||
case "put":
|
||||
return w.doRequest(ctx, "PUT", params)
|
||||
case "delete":
|
||||
return w.doRequest(ctx, "DELETE", params)
|
||||
case "patch":
|
||||
return w.doRequest(ctx, "PATCH", params)
|
||||
default:
|
||||
return nil, errors.New("unknown action: " + action)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WebhookConnector) doRequest(ctx context.Context, method string, params map[string]interface{}) (interface{}, error) {
|
||||
urlStr := params["url"].(string)
|
||||
|
||||
if method == "GET" {
|
||||
if queryParams, ok := params["params"].(map[string]interface{}); ok {
|
||||
parsedURL, _ := url.Parse(urlStr)
|
||||
q := parsedURL.Query()
|
||||
for k, v := range queryParams {
|
||||
q.Set(k, fmt.Sprintf("%v", v))
|
||||
}
|
||||
parsedURL.RawQuery = q.Encode()
|
||||
urlStr = parsedURL.String()
|
||||
}
|
||||
}
|
||||
|
||||
var bodyReader io.Reader
|
||||
var bodyBytes []byte
|
||||
|
||||
if body, ok := params["body"]; ok && method != "GET" {
|
||||
var err error
|
||||
bodyBytes, err = json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal body: %w", err)
|
||||
}
|
||||
bodyReader = bytes.NewReader(bodyBytes)
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt <= w.cfg.MaxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
time.Sleep(w.cfg.RetryDelay * time.Duration(attempt))
|
||||
if bodyBytes != nil {
|
||||
bodyReader = bytes.NewReader(bodyBytes)
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, urlStr, bodyReader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", "GooSeek-Computer/1.0")
|
||||
|
||||
if headers, ok := params["headers"].(map[string]interface{}); ok {
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, fmt.Sprintf("%v", v))
|
||||
}
|
||||
}
|
||||
|
||||
if bodyBytes != nil {
|
||||
secret := w.cfg.DefaultSecret
|
||||
if s, ok := params["secret"].(string); ok {
|
||||
secret = s
|
||||
}
|
||||
if secret != "" {
|
||||
signature := w.signPayload(bodyBytes, secret)
|
||||
req.Header.Set("X-Signature-256", "sha256="+signature)
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := w.client.Do(req)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
result := map[string]interface{}{
|
||||
"status_code": resp.StatusCode,
|
||||
"headers": w.headersToMap(resp.Header),
|
||||
}
|
||||
|
||||
var jsonBody interface{}
|
||||
if err := json.Unmarshal(respBody, &jsonBody); err == nil {
|
||||
result["body"] = jsonBody
|
||||
} else {
|
||||
result["body"] = string(respBody)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
result["success"] = true
|
||||
return result, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 500 {
|
||||
lastErr = fmt.Errorf("server error: %d", resp.StatusCode)
|
||||
continue
|
||||
}
|
||||
|
||||
result["success"] = false
|
||||
return result, nil
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"success": false,
|
||||
"error": lastErr.Error(),
|
||||
}, lastErr
|
||||
}
|
||||
|
||||
// signPayload computes the hex-encoded HMAC-SHA256 of payload keyed by
// secret, used for the X-Signature-256 request header (GitHub-style
// webhook signing).
func (w *WebhookConnector) signPayload(payload []byte, secret string) string {
	mac := hmac.New(sha256.New, []byte(secret))
	mac.Write(payload)
	return hex.EncodeToString(mac.Sum(nil))
}
|
||||
|
||||
func (w *WebhookConnector) headersToMap(headers http.Header) map[string]string {
|
||||
result := make(map[string]string)
|
||||
for k, v := range headers {
|
||||
result[k] = strings.Join(v, ", ")
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// PostJSON is a convenience wrapper that POSTs data (JSON-encoded by
// doRequest) to webhookURL.
func (w *WebhookConnector) PostJSON(ctx context.Context, webhookURL string, data interface{}) (interface{}, error) {
	return w.Execute(ctx, "post", map[string]interface{}{
		"url":  webhookURL,
		"body": data,
	})
}
|
||||
|
||||
// GetJSON is a convenience wrapper that issues a GET to webhookURL with the
// given query parameters.
func (w *WebhookConnector) GetJSON(ctx context.Context, webhookURL string, params map[string]interface{}) (interface{}, error) {
	return w.Execute(ctx, "get", map[string]interface{}{
		"url":    webhookURL,
		"params": params,
	})
}
|
||||
574
backend/internal/computer/executor.go
Normal file
574
backend/internal/computer/executor.go
Normal file
@@ -0,0 +1,574 @@
|
||||
package computer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gooseek/backend/internal/llm"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// Executor runs planned SubTasks against LLM clients selected by the
// Router, optionally using a SandboxManager for real code execution.
type Executor struct {
	router     *Router         // picks the model/client for each task
	sandbox    *SandboxManager // optional; enables sandboxed deploy tasks
	maxWorkers int             // parallelism cap for ExecuteGroup
}
|
||||
|
||||
func NewExecutor(router *Router, maxWorkers int) *Executor {
|
||||
if maxWorkers <= 0 {
|
||||
maxWorkers = 5
|
||||
}
|
||||
return &Executor{
|
||||
router: router,
|
||||
maxWorkers: maxWorkers,
|
||||
}
|
||||
}
|
||||
|
||||
// SetSandbox attaches a sandbox manager, enabling sandboxed execution for
// deploy-type tasks.
func (e *Executor) SetSandbox(sandbox *SandboxManager) {
	e.sandbox = sandbox
}
|
||||
|
||||
func (e *Executor) ExecuteGroup(ctx context.Context, tasks []SubTask, budget float64) ([]ExecutionResult, error) {
|
||||
results := make([]ExecutionResult, len(tasks))
|
||||
var mu sync.Mutex
|
||||
|
||||
perTaskBudget := budget / float64(len(tasks))
|
||||
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(e.maxWorkers)
|
||||
|
||||
for i, task := range tasks {
|
||||
i, task := i, task
|
||||
g.Go(func() error {
|
||||
result, err := e.ExecuteTask(gctx, &task, perTaskBudget)
|
||||
mu.Lock()
|
||||
if err != nil {
|
||||
results[i] = ExecutionResult{
|
||||
TaskID: task.ID,
|
||||
SubTaskID: task.ID,
|
||||
Error: err,
|
||||
}
|
||||
} else {
|
||||
results[i] = *result
|
||||
}
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return results, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// ExecuteTask routes a single subtask to the best-suited model, dispatches
// it to the type-specific executor, and stamps timing and estimated cost on
// the result.
func (e *Executor) ExecuteTask(ctx context.Context, task *SubTask, budget float64) (*ExecutionResult, error) {
	startTime := time.Now()

	client, spec, err := e.router.Route(task, budget)
	if err != nil {
		return nil, fmt.Errorf("routing failed: %w", err)
	}

	task.ModelID = spec.ID
	now := time.Now()
	task.StartedAt = &now

	var result *ExecutionResult

	// Dispatch on the task type; unknown types fall back to a generic prompt.
	switch task.Type {
	case TaskResearch:
		result, err = e.executeResearch(ctx, client, task)
	case TaskCode:
		result, err = e.executeCode(ctx, client, task)
	case TaskAnalysis:
		result, err = e.executeAnalysis(ctx, client, task)
	case TaskDesign:
		result, err = e.executeDesign(ctx, client, task)
	case TaskDeploy:
		result, err = e.executeDeploy(ctx, client, task)
	case TaskReport:
		result, err = e.executeReport(ctx, client, task)
	case TaskCommunicate:
		result, err = e.executeCommunicate(ctx, client, task)
	case TaskTransform:
		result, err = e.executeTransform(ctx, client, task)
	case TaskValidate:
		result, err = e.executeValidate(ctx, client, task)
	default:
		result, err = e.executeGeneric(ctx, client, task)
	}

	if err != nil {
		return nil, err
	}

	result.Duration = time.Since(startTime)
	// NOTE(review): cost uses fixed 1000-in/500-out token estimates, not
	// actual usage — confirm whether real token counts should be used here.
	result.Cost = e.router.EstimateCost(task, 1000, 500)

	return result, nil
}
|
||||
|
||||
// executeResearch prompts the model to complete a research-style task and
// parses its JSON answer (findings/sources/summary) into the result output.
func (e *Executor) executeResearch(ctx context.Context, client llm.Client, task *SubTask) (*ExecutionResult, error) {
	prompt := fmt.Sprintf(`You are a research assistant. Complete this research task:

Task: %s

Additional context: %v

Provide a comprehensive research result with:
1. Key findings
2. Sources/references
3. Summary

Respond in JSON:
{
"findings": ["finding 1", "finding 2"],
"sources": ["source 1", "source 2"],
"summary": "...",
"data": {}
}`, task.Description, task.Input)

	response, err := client.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: llm.RoleUser, Content: prompt}},
		Options:  llm.StreamOptions{MaxTokens: 4096},
	})
	if err != nil {
		return nil, err
	}

	output := parseJSONOutput(response)

	return &ExecutionResult{
		TaskID:    task.ID,
		SubTaskID: task.ID,
		Output:    output,
	}, nil
}
|
||||
|
||||
// executeCode prompts the model to write code for the task, parses the JSON
// answer, and wraps the generated code in a code Artifact (filename taken
// from the response, defaulting to "main.py").
func (e *Executor) executeCode(ctx context.Context, client llm.Client, task *SubTask) (*ExecutionResult, error) {
	inputContext := ""
	if task.Input != nil {
		inputJSON, _ := json.Marshal(task.Input)
		inputContext = fmt.Sprintf("\n\nContext from previous tasks:\n%s", string(inputJSON))
	}

	prompt := fmt.Sprintf(`You are an expert programmer. Complete this coding task:

Task: %s%s

Requirements:
1. Write clean, production-ready code
2. Include error handling
3. Add necessary imports
4. Follow best practices

Respond in JSON:
{
"language": "python",
"code": "...",
"filename": "main.py",
"dependencies": ["package1", "package2"],
"explanation": "..."
}`, task.Description, inputContext)

	response, err := client.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: llm.RoleUser, Content: prompt}},
		Options:  llm.StreamOptions{MaxTokens: 8192},
	})
	if err != nil {
		return nil, err
	}

	output := parseJSONOutput(response)

	var artifacts []Artifact
	if code, ok := output["code"].(string); ok {
		filename := "main.py"
		if fn, ok := output["filename"].(string); ok {
			filename = fn
		}
		artifacts = append(artifacts, Artifact{
			ID:        uuid.New().String(),
			TaskID:    task.ID,
			Type:      ArtifactTypeCode,
			Name:      filename,
			Content:   []byte(code),
			Size:      int64(len(code)),
			CreatedAt: time.Now(),
		})
	}

	return &ExecutionResult{
		TaskID:    task.ID,
		SubTaskID: task.ID,
		Output:    output,
		Artifacts: artifacts,
	}, nil
}
|
||||
|
||||
// executeAnalysis prompts the model to analyze the task's input data and
// parses the JSON answer (insights/patterns/recommendations) into the
// result output.
func (e *Executor) executeAnalysis(ctx context.Context, client llm.Client, task *SubTask) (*ExecutionResult, error) {
	inputJSON, _ := json.Marshal(task.Input)

	prompt := fmt.Sprintf(`You are a data analyst. Analyze this data/information:

Task: %s

Input data:
%s

Provide:
1. Key insights
2. Patterns observed
3. Recommendations
4. Visualizations needed (describe)

Respond in JSON:
{
"insights": ["insight 1", "insight 2"],
"patterns": ["pattern 1"],
"recommendations": ["rec 1"],
"visualizations": ["chart type 1"],
"summary": "..."
}`, task.Description, string(inputJSON))

	response, err := client.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: llm.RoleUser, Content: prompt}},
		Options:  llm.StreamOptions{MaxTokens: 4096},
	})
	if err != nil {
		return nil, err
	}

	output := parseJSONOutput(response)

	return &ExecutionResult{
		TaskID:    task.ID,
		SubTaskID: task.ID,
		Output:    output,
	}, nil
}
|
||||
|
||||
// executeDesign prompts the model to produce an architecture design
// (components, data flow, technologies, plan) and parses the JSON answer
// into the result output.
func (e *Executor) executeDesign(ctx context.Context, client llm.Client, task *SubTask) (*ExecutionResult, error) {
	inputJSON, _ := json.Marshal(task.Input)

	prompt := fmt.Sprintf(`You are a software architect. Design a solution:

Task: %s

Context:
%s

Provide:
1. Architecture overview
2. Components and their responsibilities
3. Data flow
4. Technology recommendations
5. Implementation plan

Respond in JSON:
{
"architecture": "...",
"components": [{"name": "...", "responsibility": "..."}],
"dataFlow": "...",
"technologies": ["tech1", "tech2"],
"implementationSteps": ["step1", "step2"],
"diagram": "mermaid diagram code"
}`, task.Description, string(inputJSON))

	response, err := client.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: llm.RoleUser, Content: prompt}},
		Options:  llm.StreamOptions{MaxTokens: 4096},
	})
	if err != nil {
		return nil, err
	}

	output := parseJSONOutput(response)

	return &ExecutionResult{
		TaskID:    task.ID,
		SubTaskID: task.ID,
		Output:    output,
	}, nil
}
|
||||
|
||||
// executeDeploy runs code from the task input inside a fresh sandbox and
// captures stdout/stderr/exit code plus any produced files as artifacts.
//
// Falls back to executeGeneric when no sandbox manager is configured or the
// task input carries no "code" string. The sandbox is always destroyed on
// return. Note: the execution language is hardcoded to "python".
func (e *Executor) executeDeploy(ctx context.Context, client llm.Client, task *SubTask) (*ExecutionResult, error) {
	if e.sandbox == nil {
		return e.executeGeneric(ctx, client, task)
	}

	var code string
	if task.Input != nil {
		if c, ok := task.Input["code"].(string); ok {
			code = c
		}
	}

	if code == "" {
		return e.executeGeneric(ctx, client, task)
	}

	sandbox, err := e.sandbox.Create(ctx, task.ID)
	if err != nil {
		return nil, fmt.Errorf("failed to create sandbox: %w", err)
	}
	defer e.sandbox.Destroy(ctx, sandbox)

	result, err := e.sandbox.Execute(ctx, sandbox, code, "python")
	if err != nil {
		return nil, fmt.Errorf("sandbox execution failed: %w", err)
	}

	output := map[string]interface{}{
		"stdout":   result.Stdout,
		"stderr":   result.Stderr,
		"exitCode": result.ExitCode,
		"duration": result.Duration.String(),
	}

	var artifacts []Artifact
	for name, content := range result.Files {
		artifacts = append(artifacts, Artifact{
			ID:        uuid.New().String(),
			TaskID:    task.ID,
			Type:      ArtifactTypeFile,
			Name:      name,
			Content:   content,
			Size:      int64(len(content)),
			CreatedAt: time.Now(),
		})
	}

	return &ExecutionResult{
		TaskID:    task.ID,
		SubTaskID: task.ID,
		Output:    output,
		Artifacts: artifacts,
	}, nil
}
|
||||
|
||||
// executeReport prompts the model to write a markdown report from the
// task's input and wraps the full response both in the output map and as a
// "report.md" artifact.
func (e *Executor) executeReport(ctx context.Context, client llm.Client, task *SubTask) (*ExecutionResult, error) {
	inputJSON, _ := json.Marshal(task.Input)

	prompt := fmt.Sprintf(`You are a report writer. Generate a comprehensive report:

Task: %s

Data/Context:
%s

Create a well-structured report with:
1. Executive Summary
2. Key Findings
3. Detailed Analysis
4. Conclusions
5. Recommendations

Use markdown formatting.`, task.Description, string(inputJSON))

	response, err := client.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: llm.RoleUser, Content: prompt}},
		Options:  llm.StreamOptions{MaxTokens: 8192},
	})
	if err != nil {
		return nil, err
	}

	output := map[string]interface{}{
		"report":    response,
		"format":    "markdown",
		"wordCount": len(strings.Fields(response)),
	}

	artifacts := []Artifact{
		{
			ID:        uuid.New().String(),
			TaskID:    task.ID,
			Type:      ArtifactTypeReport,
			Name:      "report.md",
			Content:   []byte(response),
			MimeType:  "text/markdown",
			Size:      int64(len(response)),
			CreatedAt: time.Now(),
		},
	}

	return &ExecutionResult{
		TaskID:    task.ID,
		SubTaskID: task.ID,
		Output:    output,
		Artifacts: artifacts,
	}, nil
}
|
||||
|
||||
// executeCommunicate prompts the model to draft a message/notification and
// parses the JSON answer. The message is only prepared (output status
// "prepared"), never actually sent from here.
func (e *Executor) executeCommunicate(ctx context.Context, client llm.Client, task *SubTask) (*ExecutionResult, error) {
	inputJSON, _ := json.Marshal(task.Input)

	prompt := fmt.Sprintf(`Generate a message/notification:

Task: %s

Context:
%s

Create an appropriate message. Respond in JSON:
{
"subject": "...",
"body": "...",
"format": "text|html",
"priority": "low|normal|high"
}`, task.Description, string(inputJSON))

	response, err := client.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: llm.RoleUser, Content: prompt}},
		Options:  llm.StreamOptions{MaxTokens: 2048},
	})
	if err != nil {
		return nil, err
	}

	output := parseJSONOutput(response)
	output["status"] = "prepared"

	return &ExecutionResult{
		TaskID:    task.ID,
		SubTaskID: task.ID,
		Output:    output,
	}, nil
}
|
||||
|
||||
// executeTransform prompts the model to transform the task's input data and
// parses the JSON answer into the result output.
func (e *Executor) executeTransform(ctx context.Context, client llm.Client, task *SubTask) (*ExecutionResult, error) {
	inputJSON, _ := json.Marshal(task.Input)

	prompt := fmt.Sprintf(`Transform data as requested:

Task: %s

Input data:
%s

Perform the transformation and return the result in JSON:
{
"transformed": ...,
"format": "...",
"changes": ["change 1", "change 2"]
}`, task.Description, string(inputJSON))

	response, err := client.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: llm.RoleUser, Content: prompt}},
		Options:  llm.StreamOptions{MaxTokens: 4096},
	})
	if err != nil {
		return nil, err
	}

	output := parseJSONOutput(response)

	return &ExecutionResult{
		TaskID:    task.ID,
		SubTaskID: task.ID,
		Output:    output,
	}, nil
}
|
||||
|
||||
// executeValidate prompts the model to check the task's input for
// correctness/completeness/consistency/quality and parses the JSON verdict
// (valid/score/issues/suggestions) into the result output.
func (e *Executor) executeValidate(ctx context.Context, client llm.Client, task *SubTask) (*ExecutionResult, error) {
	inputJSON, _ := json.Marshal(task.Input)

	prompt := fmt.Sprintf(`Validate the following:

Task: %s

Data to validate:
%s

Check for:
1. Correctness
2. Completeness
3. Consistency
4. Quality

Respond in JSON:
{
"valid": true|false,
"score": 0-100,
"issues": ["issue 1", "issue 2"],
"suggestions": ["suggestion 1"],
"summary": "..."
}`, task.Description, string(inputJSON))

	response, err := client.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: llm.RoleUser, Content: prompt}},
		Options:  llm.StreamOptions{MaxTokens: 2048},
	})
	if err != nil {
		return nil, err
	}

	output := parseJSONOutput(response)

	return &ExecutionResult{
		TaskID:    task.ID,
		SubTaskID: task.ID,
		Output:    output,
	}, nil
}
|
||||
|
||||
// executeGeneric is the fallback executor for unknown task types: it sends
// a generic prompt and, if the response yields no parseable JSON, stores
// the raw text under "result".
func (e *Executor) executeGeneric(ctx context.Context, client llm.Client, task *SubTask) (*ExecutionResult, error) {
	inputJSON, _ := json.Marshal(task.Input)

	prompt := fmt.Sprintf(`Complete this task:

Task type: %s
Description: %s

Context:
%s

Provide a comprehensive result in JSON format.`, task.Type, task.Description, string(inputJSON))

	response, err := client.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: llm.RoleUser, Content: prompt}},
		Options:  llm.StreamOptions{MaxTokens: 4096},
	})
	if err != nil {
		return nil, err
	}

	output := parseJSONOutput(response)
	if len(output) == 0 {
		output = map[string]interface{}{
			"result": response,
		}
	}

	return &ExecutionResult{
		TaskID:    task.ID,
		SubTaskID: task.ID,
		Output:    output,
	}, nil
}
|
||||
|
||||
// parseJSONOutput extracts the outermost first-"{"-to-last-"}" span from an
// LLM response and unmarshals it into a map. When no such span exists or
// decoding fails, the whole response is preserved under the "raw" key so
// callers never lose the model's text.
func parseJSONOutput(response string) map[string]interface{} {
	start := strings.Index(response, "{")
	end := strings.LastIndex(response, "}")
	if start == -1 || end == -1 || end <= start {
		return map[string]interface{}{"raw": response}
	}

	var parsed map[string]interface{}
	if err := json.Unmarshal([]byte(response[start:end+1]), &parsed); err != nil {
		return map[string]interface{}{"raw": response}
	}
	return parsed
}
|
||||
377
backend/internal/computer/memory.go
Normal file
377
backend/internal/computer/memory.go
Normal file
@@ -0,0 +1,377 @@
|
||||
package computer
|
||||
|
||||
import (
	"context"
	"encoding/json"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/google/uuid"
)
|
||||
|
||||
// MemoryStore persists per-user memory entries through an optional
// repository while keeping a bounded in-process cache for fast recall.
type MemoryStore struct {
	repo  MemoryRepository         // durable backend; may be nil (cache-only mode)
	cache map[string][]MemoryEntry // userID -> recent entries, capped in Store
	mu    sync.RWMutex             // guards cache
}
|
||||
|
||||
func NewMemoryStore(repo MemoryRepository) *MemoryStore {
|
||||
return &MemoryStore{
|
||||
repo: repo,
|
||||
cache: make(map[string][]MemoryEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MemoryStore) Store(ctx context.Context, userID string, entry *MemoryEntry) error {
|
||||
if entry.ID == "" {
|
||||
entry.ID = uuid.New().String()
|
||||
}
|
||||
entry.UserID = userID
|
||||
if entry.CreatedAt.IsZero() {
|
||||
entry.CreatedAt = time.Now()
|
||||
}
|
||||
|
||||
if m.repo != nil {
|
||||
if err := m.repo.Store(ctx, entry); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.cache[userID] = append(m.cache[userID], *entry)
|
||||
if len(m.cache[userID]) > 1000 {
|
||||
m.cache[userID] = m.cache[userID][len(m.cache[userID])-500:]
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) StoreResult(ctx context.Context, userID, taskID, key string, value interface{}) error {
|
||||
valueJSON, _ := json.Marshal(value)
|
||||
|
||||
entry := &MemoryEntry{
|
||||
UserID: userID,
|
||||
TaskID: taskID,
|
||||
Key: key,
|
||||
Value: string(valueJSON),
|
||||
Type: MemoryTypeResult,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
return m.Store(ctx, userID, entry)
|
||||
}
|
||||
|
||||
func (m *MemoryStore) StoreFact(ctx context.Context, userID, key string, value interface{}, tags []string) error {
|
||||
entry := &MemoryEntry{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Value: value,
|
||||
Type: MemoryTypeFact,
|
||||
Tags: tags,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
return m.Store(ctx, userID, entry)
|
||||
}
|
||||
|
||||
func (m *MemoryStore) StorePreference(ctx context.Context, userID, key string, value interface{}) error {
|
||||
entry := &MemoryEntry{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Value: value,
|
||||
Type: MemoryTypePreference,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
return m.Store(ctx, userID, entry)
|
||||
}
|
||||
|
||||
func (m *MemoryStore) StoreContext(ctx context.Context, userID, taskID, key string, value interface{}, ttl time.Duration) error {
|
||||
expiresAt := time.Now().Add(ttl)
|
||||
|
||||
entry := &MemoryEntry{
|
||||
UserID: userID,
|
||||
TaskID: taskID,
|
||||
Key: key,
|
||||
Value: value,
|
||||
Type: MemoryTypeContext,
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresAt: &expiresAt,
|
||||
}
|
||||
|
||||
return m.Store(ctx, userID, entry)
|
||||
}
|
||||
|
||||
func (m *MemoryStore) Recall(ctx context.Context, userID string, query string, limit int) ([]MemoryEntry, error) {
|
||||
if m.repo != nil {
|
||||
entries, err := m.repo.Search(ctx, userID, query, limit)
|
||||
if err == nil && len(entries) > 0 {
|
||||
return entries, nil
|
||||
}
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
cached := m.cache[userID]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if len(cached) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
queryLower := strings.ToLower(query)
|
||||
queryTerms := strings.Fields(queryLower)
|
||||
|
||||
type scored struct {
|
||||
entry MemoryEntry
|
||||
score int
|
||||
}
|
||||
|
||||
var results []scored
|
||||
now := time.Now()
|
||||
|
||||
for _, entry := range cached {
|
||||
if entry.ExpiresAt != nil && entry.ExpiresAt.Before(now) {
|
||||
continue
|
||||
}
|
||||
|
||||
score := 0
|
||||
|
||||
keyLower := strings.ToLower(entry.Key)
|
||||
for _, term := range queryTerms {
|
||||
if strings.Contains(keyLower, term) {
|
||||
score += 3
|
||||
}
|
||||
}
|
||||
|
||||
if valueStr, ok := entry.Value.(string); ok {
|
||||
valueLower := strings.ToLower(valueStr)
|
||||
for _, term := range queryTerms {
|
||||
if strings.Contains(valueLower, term) {
|
||||
score += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, tag := range entry.Tags {
|
||||
tagLower := strings.ToLower(tag)
|
||||
for _, term := range queryTerms {
|
||||
if strings.Contains(tagLower, term) {
|
||||
score += 2
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if score > 0 {
|
||||
results = append(results, scored{entry: entry, score: score})
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < len(results)-1; i++ {
|
||||
for j := i + 1; j < len(results); j++ {
|
||||
if results[j].score > results[i].score {
|
||||
results[i], results[j] = results[j], results[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(results) > limit {
|
||||
results = results[:limit]
|
||||
}
|
||||
|
||||
entries := make([]MemoryEntry, len(results))
|
||||
for i, r := range results {
|
||||
entries[i] = r.entry
|
||||
}
|
||||
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetByUser(ctx context.Context, userID string, limit int) ([]MemoryEntry, error) {
|
||||
if m.repo != nil {
|
||||
return m.repo.GetByUser(ctx, userID, limit)
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
cached := m.cache[userID]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if len(cached) > limit {
|
||||
return cached[len(cached)-limit:], nil
|
||||
}
|
||||
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetByTask(ctx context.Context, taskID string) ([]MemoryEntry, error) {
|
||||
if m.repo != nil {
|
||||
return m.repo.GetByTask(ctx, taskID)
|
||||
}
|
||||
|
||||
var result []MemoryEntry
|
||||
|
||||
m.mu.RLock()
|
||||
for _, entries := range m.cache {
|
||||
for _, e := range entries {
|
||||
if e.TaskID == taskID {
|
||||
result = append(result, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetTaskContext(ctx context.Context, taskID string) (map[string]interface{}, error) {
|
||||
entries, err := m.GetByTask(ctx, taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
context := make(map[string]interface{})
|
||||
for _, e := range entries {
|
||||
context[e.Key] = e.Value
|
||||
}
|
||||
|
||||
return context, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetUserContext(ctx context.Context, userID string) (map[string]interface{}, error) {
|
||||
entries, err := m.GetByUser(ctx, userID, 100)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
context := make(map[string]interface{})
|
||||
for _, e := range entries {
|
||||
if e.Type == MemoryTypePreference || e.Type == MemoryTypeFact {
|
||||
context[e.Key] = e.Value
|
||||
}
|
||||
}
|
||||
|
||||
return context, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetPreferences(ctx context.Context, userID string) (map[string]interface{}, error) {
|
||||
entries, err := m.GetByUser(ctx, userID, 100)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
prefs := make(map[string]interface{})
|
||||
for _, e := range entries {
|
||||
if e.Type == MemoryTypePreference {
|
||||
prefs[e.Key] = e.Value
|
||||
}
|
||||
}
|
||||
|
||||
return prefs, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetFacts(ctx context.Context, userID string) ([]MemoryEntry, error) {
|
||||
entries, err := m.GetByUser(ctx, userID, 100)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var facts []MemoryEntry
|
||||
for _, e := range entries {
|
||||
if e.Type == MemoryTypeFact {
|
||||
facts = append(facts, e)
|
||||
}
|
||||
}
|
||||
|
||||
return facts, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) Delete(ctx context.Context, id string) error {
|
||||
if m.repo != nil {
|
||||
return m.repo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
for userID, entries := range m.cache {
|
||||
for i, e := range entries {
|
||||
if e.ID == id {
|
||||
m.cache[userID] = append(entries[:i], entries[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear drops all cached entries for userID.
//
// NOTE(review): only the in-process cache is cleared; entries already
// persisted through the repository are left untouched — confirm whether
// this asymmetry with Delete (which delegates to the repo) is intentional.
func (m *MemoryStore) Clear(ctx context.Context, userID string) error {
	m.mu.Lock()
	delete(m.cache, userID)
	m.mu.Unlock()

	return nil
}
|
||||
|
||||
func (m *MemoryStore) ClearTask(ctx context.Context, taskID string) error {
|
||||
m.mu.Lock()
|
||||
for userID, entries := range m.cache {
|
||||
var filtered []MemoryEntry
|
||||
for _, e := range entries {
|
||||
if e.TaskID != taskID {
|
||||
filtered = append(filtered, e)
|
||||
}
|
||||
}
|
||||
m.cache[userID] = filtered
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) Cleanup(ctx context.Context) error {
|
||||
now := time.Now()
|
||||
|
||||
m.mu.Lock()
|
||||
for userID, entries := range m.cache {
|
||||
var valid []MemoryEntry
|
||||
for _, e := range entries {
|
||||
if e.ExpiresAt == nil || e.ExpiresAt.After(now) {
|
||||
valid = append(valid, e)
|
||||
}
|
||||
}
|
||||
m.cache[userID] = valid
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) Stats(userID string) map[string]int {
|
||||
m.mu.RLock()
|
||||
entries := m.cache[userID]
|
||||
m.mu.RUnlock()
|
||||
|
||||
stats := map[string]int{
|
||||
"total": len(entries),
|
||||
"facts": 0,
|
||||
"preferences": 0,
|
||||
"context": 0,
|
||||
"results": 0,
|
||||
}
|
||||
|
||||
for _, e := range entries {
|
||||
switch e.Type {
|
||||
case MemoryTypeFact:
|
||||
stats["facts"]++
|
||||
case MemoryTypePreference:
|
||||
stats["preferences"]++
|
||||
case MemoryTypeContext:
|
||||
stats["context"]++
|
||||
case MemoryTypeResult:
|
||||
stats["results"]++
|
||||
}
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
371
backend/internal/computer/planner.go
Normal file
371
backend/internal/computer/planner.go
Normal file
@@ -0,0 +1,371 @@
|
||||
package computer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/gooseek/backend/internal/llm"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Planner turns a user query (plus remembered user context) into a TaskPlan
// of subtasks, using a reasoning-capable model from the registry.
type Planner struct {
	registry *llm.ModelRegistry // source of planning-capable LLM clients
}
|
||||
|
||||
func NewPlanner(registry *llm.ModelRegistry) *Planner {
|
||||
return &Planner{
|
||||
registry: registry,
|
||||
}
|
||||
}
|
||||
|
||||
// Plan asks an LLM to decompose query into a TaskPlan of subtasks.
//
// memory, when non-empty, is serialized to JSON and appended to the prompt
// as user context. Model selection prefers a reasoning-capable model and
// falls back to a coding-capable one; if neither exists an error is
// returned. Any LLM or parsing failure after that degrades to a heuristic
// default plan with a nil error, so callers always receive a usable plan
// once a client is available.
func (p *Planner) Plan(ctx context.Context, query string, memory map[string]interface{}) (*TaskPlan, error) {
	client, _, err := p.registry.GetBest(llm.CapReasoning)
	if err != nil {
		client, _, err = p.registry.GetBest(llm.CapCoding)
		if err != nil {
			return nil, fmt.Errorf("no suitable model for planning: %w", err)
		}
	}

	// Inline the user's memory/context into the prompt, if any. The
	// marshal error is deliberately ignored; the map is assumed to be
	// JSON-serializable — TODO confirm upstream guarantees this.
	memoryContext := ""
	if len(memory) > 0 {
		memoryJSON, _ := json.Marshal(memory)
		memoryContext = fmt.Sprintf("\n\nUser context and memory:\n%s", string(memoryJSON))
	}

	prompt := fmt.Sprintf(`You are a task planning AI. Analyze this query and create an execution plan.

Query: %s%s

Break this into subtasks. Each subtask should be:
1. Atomic - one clear action
2. Independent where possible (for parallel execution)
3. Have clear dependencies when needed

Available task types:
- research: Search web, gather information
- code: Write/generate code
- analysis: Analyze data, extract insights
- design: Design architecture, create plans
- deploy: Deploy applications, run code
- monitor: Set up monitoring, tracking
- report: Generate reports, summaries
- communicate: Send emails, messages
- transform: Convert data formats
- validate: Check, verify results

For each subtask specify:
- type: one of the task types above
- description: what to do
- dependencies: list of subtask IDs this depends on (empty if none)
- capabilities: required AI capabilities (reasoning, coding, search, creative, fast, long_context, vision, math)

Respond in JSON format:
{
  "summary": "Brief summary of the plan",
  "subtasks": [
    {
      "id": "1",
      "type": "research",
      "description": "Search for...",
      "dependencies": [],
      "capabilities": ["search"]
    },
    {
      "id": "2",
      "type": "code",
      "description": "Write code to...",
      "dependencies": ["1"],
      "capabilities": ["coding"]
    }
  ],
  "estimatedCost": 0.05,
  "estimatedTimeSeconds": 120
}

Create 3-10 subtasks. Be specific and actionable.`, query, memoryContext)

	messages := []llm.Message{
		{Role: llm.RoleUser, Content: prompt},
	}

	response, err := client.GenerateText(ctx, llm.StreamRequest{
		Messages: messages,
		Options:  llm.StreamOptions{MaxTokens: 4096},
	})
	if err != nil {
		// LLM failure: degrade gracefully instead of erroring out.
		return p.createDefaultPlan(query), nil
	}

	plan, err := p.parsePlanResponse(response)
	if err != nil {
		// Unparseable response: same graceful degradation.
		return p.createDefaultPlan(query), nil
	}

	plan.Query = query
	plan.ExecutionOrder = p.calculateExecutionOrder(plan.SubTasks)

	return plan, nil
}
|
||||
|
||||
func (p *Planner) parsePlanResponse(response string) (*TaskPlan, error) {
|
||||
jsonRegex := regexp.MustCompile(`\{[\s\S]*\}`)
|
||||
jsonMatch := jsonRegex.FindString(response)
|
||||
if jsonMatch == "" {
|
||||
return nil, fmt.Errorf("no JSON found in response")
|
||||
}
|
||||
|
||||
var rawPlan struct {
|
||||
Summary string `json:"summary"`
|
||||
EstimatedCost float64 `json:"estimatedCost"`
|
||||
EstimatedTimeSeconds int `json:"estimatedTimeSeconds"`
|
||||
SubTasks []struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Dependencies []string `json:"dependencies"`
|
||||
Capabilities []string `json:"capabilities"`
|
||||
} `json:"subtasks"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(jsonMatch), &rawPlan); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse plan JSON: %w", err)
|
||||
}
|
||||
|
||||
plan := &TaskPlan{
|
||||
Summary: rawPlan.Summary,
|
||||
EstimatedCost: rawPlan.EstimatedCost,
|
||||
EstimatedTime: rawPlan.EstimatedTimeSeconds,
|
||||
SubTasks: make([]SubTask, len(rawPlan.SubTasks)),
|
||||
}
|
||||
|
||||
for i, st := range rawPlan.SubTasks {
|
||||
caps := make([]llm.ModelCapability, len(st.Capabilities))
|
||||
for j, c := range st.Capabilities {
|
||||
caps[j] = llm.ModelCapability(c)
|
||||
}
|
||||
|
||||
plan.SubTasks[i] = SubTask{
|
||||
ID: st.ID,
|
||||
Type: TaskType(st.Type),
|
||||
Description: st.Description,
|
||||
Dependencies: st.Dependencies,
|
||||
RequiredCaps: caps,
|
||||
Status: StatusPending,
|
||||
MaxRetries: 3,
|
||||
}
|
||||
}
|
||||
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
func (p *Planner) calculateExecutionOrder(subTasks []SubTask) [][]string {
|
||||
taskMap := make(map[string]*SubTask)
|
||||
for i := range subTasks {
|
||||
taskMap[subTasks[i].ID] = &subTasks[i]
|
||||
}
|
||||
|
||||
inDegree := make(map[string]int)
|
||||
for _, st := range subTasks {
|
||||
if _, ok := inDegree[st.ID]; !ok {
|
||||
inDegree[st.ID] = 0
|
||||
}
|
||||
for _, dep := range st.Dependencies {
|
||||
inDegree[st.ID]++
|
||||
if _, ok := inDegree[dep]; !ok {
|
||||
inDegree[dep] = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var order [][]string
|
||||
completed := make(map[string]bool)
|
||||
|
||||
for len(completed) < len(subTasks) {
|
||||
var wave []string
|
||||
|
||||
for _, st := range subTasks {
|
||||
if completed[st.ID] {
|
||||
continue
|
||||
}
|
||||
|
||||
canExecute := true
|
||||
for _, dep := range st.Dependencies {
|
||||
if !completed[dep] {
|
||||
canExecute = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if canExecute {
|
||||
wave = append(wave, st.ID)
|
||||
}
|
||||
}
|
||||
|
||||
if len(wave) == 0 {
|
||||
for _, st := range subTasks {
|
||||
if !completed[st.ID] {
|
||||
wave = append(wave, st.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, id := range wave {
|
||||
completed[id] = true
|
||||
}
|
||||
|
||||
order = append(order, wave)
|
||||
}
|
||||
|
||||
return order
|
||||
}
|
||||
|
||||
// createDefaultPlan builds a heuristic fallback plan used when the LLM
// planner is unavailable or returns an unparseable response.
//
// It always starts with one research subtask, then appends extra stages
// based on simple keyword matching (English and Russian) in the query:
// design+code for coding queries, analysis+report for reporting queries,
// and a final notification step for messaging queries. Cost and time are
// flat per-subtask estimates.
func (p *Planner) createDefaultPlan(query string) *TaskPlan {
	queryLower := strings.ToLower(query)

	// Every plan begins with an information-gathering step.
	subTasks := []SubTask{
		{
			ID:           uuid.New().String(),
			Type:         TaskResearch,
			Description:  "Research and gather information about: " + query,
			Dependencies: []string{},
			RequiredCaps: []llm.ModelCapability{llm.CapSearch},
			Status:       StatusPending,
			MaxRetries:   3,
		},
	}

	// Coding keywords: chain design -> code after the research step.
	if strings.Contains(queryLower, "код") || strings.Contains(queryLower, "code") ||
		strings.Contains(queryLower, "приложение") || strings.Contains(queryLower, "app") ||
		strings.Contains(queryLower, "скрипт") || strings.Contains(queryLower, "script") {
		subTasks = append(subTasks, SubTask{
			ID:           uuid.New().String(),
			Type:         TaskDesign,
			Description:  "Design architecture and structure",
			Dependencies: []string{subTasks[0].ID},
			RequiredCaps: []llm.ModelCapability{llm.CapReasoning},
			Status:       StatusPending,
			MaxRetries:   3,
		})
		subTasks = append(subTasks, SubTask{
			ID:           uuid.New().String(),
			Type:         TaskCode,
			// Depends on the design task appended just above (index 1).
			Description:  "Generate code implementation",
			Dependencies: []string{subTasks[1].ID},
			RequiredCaps: []llm.ModelCapability{llm.CapCoding},
			Status:       StatusPending,
			MaxRetries:   3,
		})
	}

	// Reporting keywords: analysis of the research, then a report that
	// depends on whatever stage is currently last.
	if strings.Contains(queryLower, "отчёт") || strings.Contains(queryLower, "report") ||
		strings.Contains(queryLower, "анализ") || strings.Contains(queryLower, "analysis") {
		subTasks = append(subTasks, SubTask{
			ID:           uuid.New().String(),
			Type:         TaskAnalysis,
			Description:  "Analyze gathered information",
			Dependencies: []string{subTasks[0].ID},
			RequiredCaps: []llm.ModelCapability{llm.CapReasoning},
			Status:       StatusPending,
			MaxRetries:   3,
		})
		subTasks = append(subTasks, SubTask{
			ID:           uuid.New().String(),
			Type:         TaskReport,
			Description:  "Generate comprehensive report",
			Dependencies: []string{subTasks[len(subTasks)-1].ID},
			RequiredCaps: []llm.ModelCapability{llm.CapCreative},
			Status:       StatusPending,
			MaxRetries:   3,
		})
	}

	// Messaging keywords: a final notification step after the last stage.
	if strings.Contains(queryLower, "email") || strings.Contains(queryLower, "письмо") ||
		strings.Contains(queryLower, "telegram") || strings.Contains(queryLower, "отправ") {
		subTasks = append(subTasks, SubTask{
			ID:           uuid.New().String(),
			Type:         TaskCommunicate,
			Description:  "Send notification/message",
			Dependencies: []string{subTasks[len(subTasks)-1].ID},
			RequiredCaps: []llm.ModelCapability{llm.CapFast},
			Status:       StatusPending,
			MaxRetries:   3,
		})
	}

	plan := &TaskPlan{
		Query:         query,
		Summary:       "Auto-generated plan for: " + query,
		SubTasks:      subTasks,
		EstimatedCost: float64(len(subTasks)) * 0.01, // flat $0.01 per subtask
		EstimatedTime: len(subTasks) * 30,            // flat 30s per subtask
	}

	plan.ExecutionOrder = p.calculateExecutionOrder(subTasks)

	return plan
}
|
||||
|
||||
// Replan adjusts an in-flight plan given new context/feedback: completed
// subtasks are preserved, while pending and failed ones may be modified,
// removed, or augmented by the LLM.
//
// Replanning is strictly best-effort: on any failure (no reasoning model,
// LLM error, unparseable response) the ORIGINAL plan is returned unchanged
// with a nil error.
func (p *Planner) Replan(ctx context.Context, plan *TaskPlan, newContext string) (*TaskPlan, error) {
	completedTasks := make([]SubTask, 0)
	pendingTasks := make([]SubTask, 0)

	// Tasks in other states (e.g. currently running) are intentionally
	// excluded from both lists.
	for _, st := range plan.SubTasks {
		if st.Status == StatusCompleted {
			completedTasks = append(completedTasks, st)
		} else if st.Status == StatusPending || st.Status == StatusFailed {
			pendingTasks = append(pendingTasks, st)
		}
	}

	// Marshal errors ignored — SubTask is assumed JSON-serializable;
	// TODO confirm if fields change.
	completedJSON, _ := json.Marshal(completedTasks)
	pendingJSON, _ := json.Marshal(pendingTasks)

	client, _, err := p.registry.GetBest(llm.CapReasoning)
	if err != nil {
		// No model: keep the current plan rather than failing the task.
		return plan, nil
	}

	prompt := fmt.Sprintf(`You need to replan a task based on new context.

Original query: %s

Completed subtasks:
%s

Pending subtasks:
%s

New context/feedback:
%s

Adjust the plan. Keep completed tasks, modify or remove pending tasks as needed.
Add new subtasks if the new context requires it.

Respond in the same JSON format as before.`, plan.Query, string(completedJSON), string(pendingJSON), newContext)

	messages := []llm.Message{
		{Role: llm.RoleUser, Content: prompt},
	}

	response, err := client.GenerateText(ctx, llm.StreamRequest{
		Messages: messages,
		Options:  llm.StreamOptions{MaxTokens: 4096},
	})
	if err != nil {
		return plan, nil
	}

	newPlan, err := p.parsePlanResponse(response)
	if err != nil {
		return plan, nil
	}

	newPlan.Query = plan.Query
	newPlan.ExecutionOrder = p.calculateExecutionOrder(newPlan.SubTasks)

	return newPlan, nil
}
|
||||
244
backend/internal/computer/router.go
Normal file
244
backend/internal/computer/router.go
Normal file
@@ -0,0 +1,244 @@
|
||||
package computer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sort"
|
||||
|
||||
"github.com/gooseek/backend/internal/llm"
|
||||
)
|
||||
|
||||
// RoutingRule describes how to pick a model for one task type: preferred
// capabilities are tried first, then explicit fallback model IDs.
type RoutingRule struct {
	TaskType   TaskType              // task type this rule applies to
	Preferred  []llm.ModelCapability // capabilities to try, in order
	Fallback   []string              // model IDs to try when no capability match fits
	MaxCost    float64               // advisory cost ceiling; NOTE(review): not read by Route/RouteMultiple, which use the caller-supplied budget — confirm intent
	MaxLatency int                   // advisory latency ceiling; NOTE(review): currently unused in this file — confirm intent
}
|
||||
|
||||
// Router selects LLM clients for subtasks, combining per-task-type routing
// rules with budget-constrained fallbacks against the model registry.
type Router struct {
	registry *llm.ModelRegistry       // available models and clients
	rules    map[TaskType]RoutingRule // per-task-type selection rules
}
|
||||
|
||||
func NewRouter(registry *llm.ModelRegistry) *Router {
|
||||
r := &Router{
|
||||
registry: registry,
|
||||
rules: make(map[TaskType]RoutingRule),
|
||||
}
|
||||
|
||||
r.rules[TaskResearch] = RoutingRule{
|
||||
TaskType: TaskResearch,
|
||||
Preferred: []llm.ModelCapability{llm.CapSearch, llm.CapLongContext},
|
||||
Fallback: []string{"gemini-1.5-pro", "gpt-4o"},
|
||||
MaxCost: 0.1,
|
||||
}
|
||||
|
||||
r.rules[TaskCode] = RoutingRule{
|
||||
TaskType: TaskCode,
|
||||
Preferred: []llm.ModelCapability{llm.CapCoding},
|
||||
Fallback: []string{"claude-3-sonnet", "claude-3-opus", "gpt-4o"},
|
||||
MaxCost: 0.2,
|
||||
}
|
||||
|
||||
r.rules[TaskAnalysis] = RoutingRule{
|
||||
TaskType: TaskAnalysis,
|
||||
Preferred: []llm.ModelCapability{llm.CapReasoning, llm.CapMath},
|
||||
Fallback: []string{"claude-3-opus", "gpt-4o"},
|
||||
MaxCost: 0.15,
|
||||
}
|
||||
|
||||
r.rules[TaskDesign] = RoutingRule{
|
||||
TaskType: TaskDesign,
|
||||
Preferred: []llm.ModelCapability{llm.CapReasoning, llm.CapCreative},
|
||||
Fallback: []string{"claude-3-opus", "gpt-4o"},
|
||||
MaxCost: 0.15,
|
||||
}
|
||||
|
||||
r.rules[TaskDeploy] = RoutingRule{
|
||||
TaskType: TaskDeploy,
|
||||
Preferred: []llm.ModelCapability{llm.CapCoding, llm.CapFast},
|
||||
Fallback: []string{"claude-3-sonnet", "gpt-4o-mini"},
|
||||
MaxCost: 0.05,
|
||||
}
|
||||
|
||||
r.rules[TaskMonitor] = RoutingRule{
|
||||
TaskType: TaskMonitor,
|
||||
Preferred: []llm.ModelCapability{llm.CapFast},
|
||||
Fallback: []string{"gpt-4o-mini", "gemini-1.5-flash"},
|
||||
MaxCost: 0.02,
|
||||
}
|
||||
|
||||
r.rules[TaskReport] = RoutingRule{
|
||||
TaskType: TaskReport,
|
||||
Preferred: []llm.ModelCapability{llm.CapCreative, llm.CapLongContext},
|
||||
Fallback: []string{"claude-3-opus", "gpt-4o"},
|
||||
MaxCost: 0.1,
|
||||
}
|
||||
|
||||
r.rules[TaskCommunicate] = RoutingRule{
|
||||
TaskType: TaskCommunicate,
|
||||
Preferred: []llm.ModelCapability{llm.CapFast, llm.CapCreative},
|
||||
Fallback: []string{"gpt-4o-mini", "gemini-1.5-flash"},
|
||||
MaxCost: 0.02,
|
||||
}
|
||||
|
||||
r.rules[TaskTransform] = RoutingRule{
|
||||
TaskType: TaskTransform,
|
||||
Preferred: []llm.ModelCapability{llm.CapFast, llm.CapCoding},
|
||||
Fallback: []string{"gpt-4o-mini", "claude-3-sonnet"},
|
||||
MaxCost: 0.03,
|
||||
}
|
||||
|
||||
r.rules[TaskValidate] = RoutingRule{
|
||||
TaskType: TaskValidate,
|
||||
Preferred: []llm.ModelCapability{llm.CapReasoning},
|
||||
Fallback: []string{"gpt-4o", "claude-3-sonnet"},
|
||||
MaxCost: 0.05,
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Router) Route(task *SubTask, budget float64) (llm.Client, llm.ModelSpec, error) {
|
||||
if task.ModelID != "" {
|
||||
client, spec, err := r.registry.GetByID(task.ModelID)
|
||||
if err == nil && spec.CostPer1K <= budget {
|
||||
return client, spec, nil
|
||||
}
|
||||
}
|
||||
|
||||
if len(task.RequiredCaps) > 0 {
|
||||
for _, cap := range task.RequiredCaps {
|
||||
client, spec, err := r.registry.GetBest(cap)
|
||||
if err == nil && spec.CostPer1K <= budget {
|
||||
return client, spec, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rule, ok := r.rules[task.Type]
|
||||
if ok {
|
||||
for _, cap := range rule.Preferred {
|
||||
client, spec, err := r.registry.GetBest(cap)
|
||||
if err == nil && spec.CostPer1K <= budget {
|
||||
return client, spec, nil
|
||||
}
|
||||
}
|
||||
|
||||
for _, modelID := range rule.Fallback {
|
||||
client, spec, err := r.registry.GetByID(modelID)
|
||||
if err == nil && spec.CostPer1K <= budget {
|
||||
return client, spec, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
models := r.registry.GetAll()
|
||||
if len(models) == 0 {
|
||||
return nil, llm.ModelSpec{}, errors.New("no models available")
|
||||
}
|
||||
|
||||
sort.Slice(models, func(i, j int) bool {
|
||||
return models[i].CostPer1K < models[j].CostPer1K
|
||||
})
|
||||
|
||||
for _, spec := range models {
|
||||
if spec.CostPer1K <= budget {
|
||||
client, err := r.registry.GetClient(spec.ID)
|
||||
if err == nil {
|
||||
return client, spec, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
client, err := r.registry.GetClient(models[0].ID)
|
||||
if err != nil {
|
||||
return nil, llm.ModelSpec{}, err
|
||||
}
|
||||
return client, models[0], nil
|
||||
}
|
||||
|
||||
func (r *Router) RouteMultiple(task *SubTask, count int, budget float64) ([]llm.Client, []llm.ModelSpec, error) {
|
||||
var clients []llm.Client
|
||||
var specs []llm.ModelSpec
|
||||
|
||||
usedModels := make(map[string]bool)
|
||||
perModelBudget := budget / float64(count)
|
||||
|
||||
rule, ok := r.rules[task.Type]
|
||||
if !ok {
|
||||
rule = RoutingRule{
|
||||
Preferred: []llm.ModelCapability{llm.CapReasoning, llm.CapCoding, llm.CapFast},
|
||||
}
|
||||
}
|
||||
|
||||
for _, cap := range rule.Preferred {
|
||||
if len(clients) >= count {
|
||||
break
|
||||
}
|
||||
|
||||
models := r.registry.GetAllWithCapability(cap)
|
||||
for _, spec := range models {
|
||||
if len(clients) >= count {
|
||||
break
|
||||
}
|
||||
if usedModels[spec.ID] {
|
||||
continue
|
||||
}
|
||||
if spec.CostPer1K > perModelBudget {
|
||||
continue
|
||||
}
|
||||
|
||||
client, err := r.registry.GetClient(spec.ID)
|
||||
if err == nil {
|
||||
clients = append(clients, client)
|
||||
specs = append(specs, spec)
|
||||
usedModels[spec.ID] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(clients) < count {
|
||||
models := r.registry.GetAll()
|
||||
for _, spec := range models {
|
||||
if len(clients) >= count {
|
||||
break
|
||||
}
|
||||
if usedModels[spec.ID] {
|
||||
continue
|
||||
}
|
||||
|
||||
client, err := r.registry.GetClient(spec.ID)
|
||||
if err == nil {
|
||||
clients = append(clients, client)
|
||||
specs = append(specs, spec)
|
||||
usedModels[spec.ID] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(clients) == 0 {
|
||||
return nil, nil, errors.New("no models available for consensus")
|
||||
}
|
||||
|
||||
return clients, specs, nil
|
||||
}
|
||||
|
||||
// SetRule installs or replaces the routing rule for a task type.
// NOTE(review): the rules map is not synchronized — confirm rules are only
// mutated before concurrent Route calls begin.
func (r *Router) SetRule(taskType TaskType, rule RoutingRule) {
	r.rules[taskType] = rule
}
|
||||
|
||||
// GetRule returns the routing rule for a task type and whether one exists.
func (r *Router) GetRule(taskType TaskType) (RoutingRule, bool) {
	rule, ok := r.rules[taskType]
	return rule, ok
}
|
||||
|
||||
func (r *Router) EstimateCost(task *SubTask, inputTokens, outputTokens int) float64 {
|
||||
_, spec, err := r.Route(task, 1.0)
|
||||
if err != nil {
|
||||
return 0.01
|
||||
}
|
||||
|
||||
totalTokens := inputTokens + outputTokens
|
||||
return spec.CostPer1K * float64(totalTokens) / 1000.0
|
||||
}
|
||||
431
backend/internal/computer/sandbox.go
Normal file
431
backend/internal/computer/sandbox.go
Normal file
@@ -0,0 +1,431 @@
|
||||
package computer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// SandboxConfig controls how code-execution sandboxes are provisioned.
type SandboxConfig struct {
	Image        string        // Docker image to run (Docker mode only)
	Timeout      time.Duration // wall-clock limit per execution/command
	MemoryLimit  string        // Docker --memory value, e.g. "512m"
	CPULimit     string        // Docker --cpus value, e.g. "1.0"
	NetworkMode  string        // NOTE(review): not read anywhere in this file; AllowNetwork drives --network — confirm intent
	WorkDir      string        // working directory inside the container
	MaxFileSize  int64         // max bytes per written/collected file
	AllowNetwork bool          // when false, containers are created with --network none
}
|
||||
|
||||
func DefaultSandboxConfig() SandboxConfig {
|
||||
return SandboxConfig{
|
||||
Image: "gooseek/sandbox:latest",
|
||||
Timeout: 5 * time.Minute,
|
||||
MemoryLimit: "512m",
|
||||
CPULimit: "1.0",
|
||||
NetworkMode: "none",
|
||||
WorkDir: "/workspace",
|
||||
MaxFileSize: 10 * 1024 * 1024,
|
||||
AllowNetwork: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Sandbox is one isolated execution environment: a Docker container plus a
// mounted host scratch directory, or just the scratch directory when Docker
// is unavailable.
type Sandbox struct {
	ID          string    // short random identifier (8 chars)
	ContainerID string    // Docker container ID; empty in local mode
	WorkDir     string    // host-side scratch directory
	Status      string    // lifecycle state, e.g. "creating", "running"
	TaskID      string    // owning task
	CreatedAt   time.Time // creation timestamp
}
|
||||
|
||||
// SandboxManager creates, tracks, and destroys sandboxes. It transparently
// degrades to direct host execution when the docker CLI is unavailable.
type SandboxManager struct {
	cfg       SandboxConfig       // effective (defaulted) configuration
	sandboxes map[string]*Sandbox // live sandboxes keyed by ID
	mu        sync.RWMutex        // guards sandboxes
	useDocker bool                // probed once at construction
}
|
||||
|
||||
func NewSandboxManager(cfg SandboxConfig) *SandboxManager {
|
||||
if cfg.Timeout == 0 {
|
||||
cfg.Timeout = 5 * time.Minute
|
||||
}
|
||||
if cfg.MemoryLimit == "" {
|
||||
cfg.MemoryLimit = "512m"
|
||||
}
|
||||
if cfg.WorkDir == "" {
|
||||
cfg.WorkDir = "/workspace"
|
||||
}
|
||||
|
||||
useDocker := isDockerAvailable()
|
||||
|
||||
return &SandboxManager{
|
||||
cfg: cfg,
|
||||
sandboxes: make(map[string]*Sandbox),
|
||||
useDocker: useDocker,
|
||||
}
|
||||
}
|
||||
|
||||
// isDockerAvailable reports whether the docker CLI is present and responds
// to "docker version".
func isDockerAvailable() bool {
	err := exec.Command("docker", "version").Run()
	return err == nil
}
|
||||
|
||||
func (sm *SandboxManager) Create(ctx context.Context, taskID string) (*Sandbox, error) {
|
||||
sandboxID := uuid.New().String()[:8]
|
||||
|
||||
sandbox := &Sandbox{
|
||||
ID: sandboxID,
|
||||
TaskID: taskID,
|
||||
Status: "creating",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if sm.useDocker {
|
||||
workDir, err := os.MkdirTemp("", fmt.Sprintf("sandbox-%s-", sandboxID))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create temp dir: %w", err)
|
||||
}
|
||||
sandbox.WorkDir = workDir
|
||||
|
||||
args := []string{
|
||||
"create",
|
||||
"--name", fmt.Sprintf("gooseek-sandbox-%s", sandboxID),
|
||||
"-v", fmt.Sprintf("%s:%s", workDir, sm.cfg.WorkDir),
|
||||
"-w", sm.cfg.WorkDir,
|
||||
"--memory", sm.cfg.MemoryLimit,
|
||||
"--cpus", sm.cfg.CPULimit,
|
||||
}
|
||||
|
||||
if !sm.cfg.AllowNetwork {
|
||||
args = append(args, "--network", "none")
|
||||
}
|
||||
|
||||
args = append(args, sm.cfg.Image, "tail", "-f", "/dev/null")
|
||||
|
||||
cmd := exec.CommandContext(ctx, "docker", args...)
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
os.RemoveAll(workDir)
|
||||
return nil, fmt.Errorf("failed to create container: %w - %s", err, string(output))
|
||||
}
|
||||
|
||||
sandbox.ContainerID = strings.TrimSpace(string(output))
|
||||
|
||||
startCmd := exec.CommandContext(ctx, "docker", "start", sandbox.ContainerID)
|
||||
if err := startCmd.Run(); err != nil {
|
||||
sm.cleanupContainer(sandbox)
|
||||
return nil, fmt.Errorf("failed to start container: %w", err)
|
||||
}
|
||||
} else {
|
||||
workDir, err := os.MkdirTemp("", fmt.Sprintf("sandbox-%s-", sandboxID))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create temp dir: %w", err)
|
||||
}
|
||||
sandbox.WorkDir = workDir
|
||||
}
|
||||
|
||||
sandbox.Status = "running"
|
||||
|
||||
sm.mu.Lock()
|
||||
sm.sandboxes[sandboxID] = sandbox
|
||||
sm.mu.Unlock()
|
||||
|
||||
return sandbox, nil
|
||||
}
|
||||
|
||||
func (sm *SandboxManager) Execute(ctx context.Context, sandbox *Sandbox, code string, lang string) (*SandboxResult, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, sm.cfg.Timeout)
|
||||
defer cancel()
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
filename, err := sm.writeCodeFile(sandbox, code, lang)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
var stdout, stderr bytes.Buffer
|
||||
|
||||
if sm.useDocker {
|
||||
runCmd := sm.getRunCommand(lang, filename)
|
||||
cmd = exec.CommandContext(ctx, "docker", "exec", sandbox.ContainerID, "sh", "-c", runCmd)
|
||||
} else {
|
||||
cmd = sm.getLocalCommand(ctx, lang, filepath.Join(sandbox.WorkDir, filename))
|
||||
}
|
||||
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err = cmd.Run()
|
||||
exitCode := 0
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
exitCode = exitErr.ExitCode()
|
||||
} else if ctx.Err() == context.DeadlineExceeded {
|
||||
return &SandboxResult{
|
||||
Stderr: "Execution timeout exceeded",
|
||||
ExitCode: -1,
|
||||
Duration: time.Since(startTime),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
files, _ := sm.collectOutputFiles(sandbox)
|
||||
|
||||
return &SandboxResult{
|
||||
Stdout: stdout.String(),
|
||||
Stderr: stderr.String(),
|
||||
ExitCode: exitCode,
|
||||
Files: files,
|
||||
Duration: time.Since(startTime),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (sm *SandboxManager) RunCommand(ctx context.Context, sandbox *Sandbox, command string) (*SandboxResult, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, sm.cfg.Timeout)
|
||||
defer cancel()
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
var cmd *exec.Cmd
|
||||
var stdout, stderr bytes.Buffer
|
||||
|
||||
if sm.useDocker {
|
||||
cmd = exec.CommandContext(ctx, "docker", "exec", sandbox.ContainerID, "sh", "-c", command)
|
||||
} else {
|
||||
cmd = exec.CommandContext(ctx, "sh", "-c", command)
|
||||
cmd.Dir = sandbox.WorkDir
|
||||
}
|
||||
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err := cmd.Run()
|
||||
exitCode := 0
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
exitCode = exitErr.ExitCode()
|
||||
}
|
||||
}
|
||||
|
||||
return &SandboxResult{
|
||||
Stdout: stdout.String(),
|
||||
Stderr: stderr.String(),
|
||||
ExitCode: exitCode,
|
||||
Duration: time.Since(startTime),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (sm *SandboxManager) WriteFile(ctx context.Context, sandbox *Sandbox, path string, content []byte) error {
|
||||
if int64(len(content)) > sm.cfg.MaxFileSize {
|
||||
return fmt.Errorf("file size exceeds limit: %d > %d", len(content), sm.cfg.MaxFileSize)
|
||||
}
|
||||
|
||||
fullPath := filepath.Join(sandbox.WorkDir, path)
|
||||
dir := filepath.Dir(fullPath)
|
||||
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %w", err)
|
||||
}
|
||||
|
||||
return os.WriteFile(fullPath, content, 0644)
|
||||
}
|
||||
|
||||
func (sm *SandboxManager) ReadFile(ctx context.Context, sandbox *Sandbox, path string) ([]byte, error) {
|
||||
fullPath := filepath.Join(sandbox.WorkDir, path)
|
||||
return os.ReadFile(fullPath)
|
||||
}
|
||||
|
||||
func (sm *SandboxManager) Destroy(ctx context.Context, sandbox *Sandbox) error {
|
||||
sm.mu.Lock()
|
||||
delete(sm.sandboxes, sandbox.ID)
|
||||
sm.mu.Unlock()
|
||||
|
||||
if sm.useDocker && sandbox.ContainerID != "" {
|
||||
sm.cleanupContainer(sandbox)
|
||||
}
|
||||
|
||||
if sandbox.WorkDir != "" {
|
||||
os.RemoveAll(sandbox.WorkDir)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *SandboxManager) cleanupContainer(sandbox *Sandbox) {
|
||||
exec.Command("docker", "stop", sandbox.ContainerID).Run()
|
||||
exec.Command("docker", "rm", "-f", sandbox.ContainerID).Run()
|
||||
}
|
||||
|
||||
func (sm *SandboxManager) writeCodeFile(sandbox *Sandbox, code string, lang string) (string, error) {
|
||||
var filename string
|
||||
switch lang {
|
||||
case "python", "py":
|
||||
filename = "main.py"
|
||||
case "javascript", "js", "node":
|
||||
filename = "main.js"
|
||||
case "typescript", "ts":
|
||||
filename = "main.ts"
|
||||
case "go", "golang":
|
||||
filename = "main.go"
|
||||
case "bash", "sh", "shell":
|
||||
filename = "script.sh"
|
||||
case "ruby", "rb":
|
||||
filename = "main.rb"
|
||||
default:
|
||||
filename = "main.txt"
|
||||
}
|
||||
|
||||
fullPath := filepath.Join(sandbox.WorkDir, filename)
|
||||
if err := os.WriteFile(fullPath, []byte(code), 0755); err != nil {
|
||||
return "", fmt.Errorf("failed to write code file: %w", err)
|
||||
}
|
||||
|
||||
return filename, nil
|
||||
}
|
||||
|
||||
func (sm *SandboxManager) getRunCommand(lang, filename string) string {
|
||||
switch lang {
|
||||
case "python", "py":
|
||||
return fmt.Sprintf("python3 %s/%s", sm.cfg.WorkDir, filename)
|
||||
case "javascript", "js", "node":
|
||||
return fmt.Sprintf("node %s/%s", sm.cfg.WorkDir, filename)
|
||||
case "typescript", "ts":
|
||||
return fmt.Sprintf("npx ts-node %s/%s", sm.cfg.WorkDir, filename)
|
||||
case "go", "golang":
|
||||
return fmt.Sprintf("go run %s/%s", sm.cfg.WorkDir, filename)
|
||||
case "bash", "sh", "shell":
|
||||
return fmt.Sprintf("bash %s/%s", sm.cfg.WorkDir, filename)
|
||||
case "ruby", "rb":
|
||||
return fmt.Sprintf("ruby %s/%s", sm.cfg.WorkDir, filename)
|
||||
default:
|
||||
return fmt.Sprintf("cat %s/%s", sm.cfg.WorkDir, filename)
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *SandboxManager) getLocalCommand(ctx context.Context, lang, filepath string) *exec.Cmd {
|
||||
switch lang {
|
||||
case "python", "py":
|
||||
return exec.CommandContext(ctx, "python3", filepath)
|
||||
case "javascript", "js", "node":
|
||||
return exec.CommandContext(ctx, "node", filepath)
|
||||
case "go", "golang":
|
||||
return exec.CommandContext(ctx, "go", "run", filepath)
|
||||
case "bash", "sh", "shell":
|
||||
return exec.CommandContext(ctx, "bash", filepath)
|
||||
case "ruby", "rb":
|
||||
return exec.CommandContext(ctx, "ruby", filepath)
|
||||
default:
|
||||
return exec.CommandContext(ctx, "cat", filepath)
|
||||
}
|
||||
}
|
||||
|
||||
// collectOutputFiles walks the sandbox scratch directory and returns the
// contents of files produced by an execution, keyed by path relative to the
// workdir.
//
// The walk callback silently skips (returns nil for):
//   - traversal errors and unreadable files (best-effort collection),
//   - directories,
//   - the injected source files ("main.*" / "script.*", the names used by
//     writeCodeFile),
//   - files larger than cfg.MaxFileSize.
func (sm *SandboxManager) collectOutputFiles(sandbox *Sandbox) (map[string][]byte, error) {
	files := make(map[string][]byte)

	err := filepath.Walk(sandbox.WorkDir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return nil // ignore traversal errors; keep collecting
		}
		if info.IsDir() {
			return nil
		}

		relPath, err := filepath.Rel(sandbox.WorkDir, path)
		if err != nil {
			return nil
		}

		// Don't echo back the code file we wrote in.
		if strings.HasPrefix(relPath, "main.") || strings.HasPrefix(relPath, "script.") {
			return nil
		}

		if info.Size() > sm.cfg.MaxFileSize {
			return nil
		}

		content, err := os.ReadFile(path)
		if err != nil {
			return nil
		}

		files[relPath] = content
		return nil
	})

	return files, err
}
|
||||
|
||||
func (sm *SandboxManager) ListSandboxes() []*Sandbox {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
result := make([]*Sandbox, 0, len(sm.sandboxes))
|
||||
for _, s := range sm.sandboxes {
|
||||
result = append(result, s)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (sm *SandboxManager) GetSandbox(id string) (*Sandbox, bool) {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
s, ok := sm.sandboxes[id]
|
||||
return s, ok
|
||||
}
|
||||
|
||||
func (sm *SandboxManager) CopyToContainer(ctx context.Context, sandbox *Sandbox, src string, dst string) error {
|
||||
if !sm.useDocker {
|
||||
srcData, err := os.ReadFile(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sm.WriteFile(ctx, sandbox, dst, srcData)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "docker", "cp", src, fmt.Sprintf("%s:%s", sandbox.ContainerID, dst))
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func (sm *SandboxManager) CopyFromContainer(ctx context.Context, sandbox *Sandbox, src string, dst string) error {
|
||||
if !sm.useDocker {
|
||||
srcPath := filepath.Join(sandbox.WorkDir, src)
|
||||
srcData, err := os.ReadFile(srcPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(dst, srcData, 0644)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "docker", "cp", fmt.Sprintf("%s:%s", sandbox.ContainerID, src), dst)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func (sm *SandboxManager) StreamLogs(ctx context.Context, sandbox *Sandbox) (io.ReadCloser, error) {
|
||||
if !sm.useDocker {
|
||||
return nil, fmt.Errorf("streaming not supported without Docker")
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "docker", "logs", "-f", sandbox.ContainerID)
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return stdout, nil
|
||||
}
|
||||
386
backend/internal/computer/scheduler.go
Normal file
386
backend/internal/computer/scheduler.go
Normal file
@@ -0,0 +1,386 @@
|
||||
package computer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/robfig/cron/v3"
|
||||
)
|
||||
|
||||
// Scheduler runs persisted ComputerTasks on cron or interval schedules,
// periodically reloading the schedule set from the task repository.
type Scheduler struct {
	taskRepo TaskRepository          // source of scheduled tasks
	computer *Computer               // executor for due tasks
	cron     *cron.Cron              // cron engine (second-level resolution)
	jobs     map[string]cron.EntryID // task ID -> registered cron entry
	running  map[string]bool         // task IDs currently executing
	mu       sync.RWMutex            // guards jobs and running
	stopCh   chan struct{}           // closed by Stop to end background polling
}
|
||||
|
||||
func NewScheduler(taskRepo TaskRepository, computer *Computer) *Scheduler {
|
||||
return &Scheduler{
|
||||
taskRepo: taskRepo,
|
||||
computer: computer,
|
||||
cron: cron.New(cron.WithSeconds()),
|
||||
jobs: make(map[string]cron.EntryID),
|
||||
running: make(map[string]bool),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Scheduler) Start(ctx context.Context) {
|
||||
s.cron.Start()
|
||||
|
||||
go s.pollScheduledTasks(ctx)
|
||||
|
||||
log.Println("[Scheduler] Started")
|
||||
}
|
||||
|
||||
func (s *Scheduler) Stop() {
|
||||
close(s.stopCh)
|
||||
s.cron.Stop()
|
||||
log.Println("[Scheduler] Stopped")
|
||||
}
|
||||
|
||||
func (s *Scheduler) pollScheduledTasks(ctx context.Context) {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
s.loadScheduledTasks(ctx)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-s.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.checkAndExecute(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Scheduler) loadScheduledTasks(ctx context.Context) {
|
||||
tasks, err := s.taskRepo.GetScheduled(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] Failed to load scheduled tasks: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, task := range tasks {
|
||||
if task.Schedule != nil && task.Schedule.Enabled {
|
||||
s.scheduleTask(&task)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[Scheduler] Loaded %d scheduled tasks", len(tasks))
|
||||
}
|
||||
|
||||
func (s *Scheduler) scheduleTask(task *ComputerTask) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if oldID, exists := s.jobs[task.ID]; exists {
|
||||
s.cron.Remove(oldID)
|
||||
}
|
||||
|
||||
if task.Schedule == nil || !task.Schedule.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
var entryID cron.EntryID
|
||||
var err error
|
||||
|
||||
switch task.Schedule.Type {
|
||||
case "cron":
|
||||
if task.Schedule.CronExpr == "" {
|
||||
return nil
|
||||
}
|
||||
entryID, err = s.cron.AddFunc(task.Schedule.CronExpr, func() {
|
||||
s.executeScheduledTask(task.ID)
|
||||
})
|
||||
|
||||
case "interval":
|
||||
if task.Schedule.Interval <= 0 {
|
||||
return nil
|
||||
}
|
||||
cronExpr := s.intervalToCron(task.Schedule.Interval)
|
||||
entryID, err = s.cron.AddFunc(cronExpr, func() {
|
||||
s.executeScheduledTask(task.ID)
|
||||
})
|
||||
|
||||
case "once":
|
||||
go func() {
|
||||
if task.Schedule.NextRun.After(time.Now()) {
|
||||
time.Sleep(time.Until(task.Schedule.NextRun))
|
||||
}
|
||||
s.executeScheduledTask(task.ID)
|
||||
}()
|
||||
return nil
|
||||
|
||||
case "daily":
|
||||
entryID, err = s.cron.AddFunc("0 0 9 * * *", func() {
|
||||
s.executeScheduledTask(task.ID)
|
||||
})
|
||||
|
||||
case "hourly":
|
||||
entryID, err = s.cron.AddFunc("0 0 * * * *", func() {
|
||||
s.executeScheduledTask(task.ID)
|
||||
})
|
||||
|
||||
case "weekly":
|
||||
entryID, err = s.cron.AddFunc("0 0 9 * * 1", func() {
|
||||
s.executeScheduledTask(task.ID)
|
||||
})
|
||||
|
||||
case "monthly":
|
||||
entryID, err = s.cron.AddFunc("0 0 9 1 * *", func() {
|
||||
s.executeScheduledTask(task.ID)
|
||||
})
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] Failed to schedule task %s: %v", task.ID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
s.jobs[task.ID] = entryID
|
||||
log.Printf("[Scheduler] Scheduled task %s with type %s", task.ID, task.Schedule.Type)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Scheduler) intervalToCron(seconds int) string {
|
||||
if seconds < 60 {
|
||||
return "*/30 * * * * *"
|
||||
}
|
||||
if seconds < 3600 {
|
||||
minutes := seconds / 60
|
||||
return "0 */" + itoa(minutes) + " * * * *"
|
||||
}
|
||||
if seconds < 86400 {
|
||||
hours := seconds / 3600
|
||||
return "0 0 */" + itoa(hours) + " * * *"
|
||||
}
|
||||
return "0 0 0 * * *"
|
||||
}
|
||||
|
||||
// itoa converts an integer to its decimal string representation.
//
// The previous implementation only handled single digits and returned
// "" for any value >= 10, which silently produced invalid cron
// expressions in intervalToCron; delegate to strconv.Itoa instead.
func itoa(i int) string {
	return strconv.Itoa(i)
}
|
||||
|
||||
// executeScheduledTask runs a single scheduled invocation of a task:
// it guards against overlapping runs, enforces the schedule's expiry
// and max-run limits, executes the task synchronously via the Computer,
// then persists run-count and next-run bookkeeping. Each invocation
// gets a fresh 30-minute timeout context.
func (s *Scheduler) executeScheduledTask(taskID string) {
	// Overlap guard: if a previous run of this task is still in flight,
	// skip this trigger instead of piling runs up.
	s.mu.Lock()
	if s.running[taskID] {
		s.mu.Unlock()
		log.Printf("[Scheduler] Task %s is already running, skipping", taskID)
		return
	}
	s.running[taskID] = true
	s.mu.Unlock()

	defer func() {
		s.mu.Lock()
		delete(s.running, taskID)
		s.mu.Unlock()
	}()

	// Detached from any caller context: a scheduled run is bounded only
	// by this 30-minute budget.
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
	defer cancel()

	task, err := s.taskRepo.GetByID(ctx, taskID)
	if err != nil {
		log.Printf("[Scheduler] Failed to get task %s: %v", taskID, err)
		return
	}

	if task.Schedule != nil {
		// Hard stop once the schedule's expiry time has passed.
		if task.Schedule.ExpiresAt != nil && time.Now().After(*task.Schedule.ExpiresAt) {
			log.Printf("[Scheduler] Task %s has expired, removing", taskID)
			s.Cancel(taskID)
			return
		}

		// Hard stop once the configured run budget is exhausted.
		if task.Schedule.MaxRuns > 0 && task.Schedule.RunCount >= task.Schedule.MaxRuns {
			log.Printf("[Scheduler] Task %s reached max runs (%d), removing", taskID, task.Schedule.MaxRuns)
			s.Cancel(taskID)
			return
		}
	}

	log.Printf("[Scheduler] Executing scheduled task %s (run #%d)", taskID, task.RunCount+1)

	// Synchronous execution; the task's persisted Memory map is
	// forwarded as the run's context.
	_, err = s.computer.Execute(ctx, task.UserID, task.Query, ExecuteOptions{
		Async:   false,
		Context: task.Memory,
	})

	if err != nil {
		log.Printf("[Scheduler] Task %s execution failed: %v", taskID, err)
	} else {
		log.Printf("[Scheduler] Task %s completed successfully", taskID)
	}

	// Bookkeeping runs for success and failure alike: the attempt
	// counts against the run budget and the next run is computed
	// either way.
	task.RunCount++
	if task.Schedule != nil {
		task.Schedule.RunCount = task.RunCount
		task.Schedule.NextRun = s.calculateNextRun(task.Schedule)
		task.NextRunAt = &task.Schedule.NextRun
	}
	task.UpdatedAt = time.Now()

	if err := s.taskRepo.Update(ctx, task); err != nil {
		log.Printf("[Scheduler] Failed to update task %s: %v", taskID, err)
	}
}
|
||||
|
||||
func (s *Scheduler) calculateNextRun(schedule *Schedule) time.Time {
|
||||
switch schedule.Type {
|
||||
case "interval":
|
||||
return time.Now().Add(time.Duration(schedule.Interval) * time.Second)
|
||||
case "hourly":
|
||||
return time.Now().Add(time.Hour).Truncate(time.Hour)
|
||||
case "daily":
|
||||
next := time.Now().Add(24 * time.Hour)
|
||||
return time.Date(next.Year(), next.Month(), next.Day(), 9, 0, 0, 0, next.Location())
|
||||
case "weekly":
|
||||
next := time.Now().Add(7 * 24 * time.Hour)
|
||||
return time.Date(next.Year(), next.Month(), next.Day(), 9, 0, 0, 0, next.Location())
|
||||
case "monthly":
|
||||
next := time.Now().AddDate(0, 1, 0)
|
||||
return time.Date(next.Year(), next.Month(), 1, 9, 0, 0, 0, next.Location())
|
||||
default:
|
||||
return time.Now().Add(time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Scheduler) checkAndExecute(ctx context.Context) {
|
||||
tasks, err := s.taskRepo.GetScheduled(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
for _, task := range tasks {
|
||||
if task.NextRunAt != nil && task.NextRunAt.Before(now) {
|
||||
if task.Schedule != nil && task.Schedule.Enabled {
|
||||
go s.executeScheduledTask(task.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Scheduler) Schedule(taskID string, schedule Schedule) error {
|
||||
ctx := context.Background()
|
||||
task, err := s.taskRepo.GetByID(ctx, taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
task.Schedule = &schedule
|
||||
task.Schedule.Enabled = true
|
||||
task.Schedule.NextRun = s.calculateNextRun(&schedule)
|
||||
task.NextRunAt = &task.Schedule.NextRun
|
||||
task.Status = StatusScheduled
|
||||
task.UpdatedAt = time.Now()
|
||||
|
||||
if err := s.taskRepo.Update(ctx, task); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.scheduleTask(task)
|
||||
}
|
||||
|
||||
func (s *Scheduler) Cancel(taskID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if entryID, exists := s.jobs[taskID]; exists {
|
||||
s.cron.Remove(entryID)
|
||||
delete(s.jobs, taskID)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
task, err := s.taskRepo.GetByID(ctx, taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if task.Schedule != nil {
|
||||
task.Schedule.Enabled = false
|
||||
}
|
||||
task.Status = StatusCancelled
|
||||
task.UpdatedAt = time.Now()
|
||||
|
||||
return s.taskRepo.Update(ctx, task)
|
||||
}
|
||||
|
||||
func (s *Scheduler) Pause(taskID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if entryID, exists := s.jobs[taskID]; exists {
|
||||
s.cron.Remove(entryID)
|
||||
delete(s.jobs, taskID)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
task, err := s.taskRepo.GetByID(ctx, taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if task.Schedule != nil {
|
||||
task.Schedule.Enabled = false
|
||||
}
|
||||
task.UpdatedAt = time.Now()
|
||||
|
||||
return s.taskRepo.Update(ctx, task)
|
||||
}
|
||||
|
||||
func (s *Scheduler) Resume(taskID string) error {
|
||||
ctx := context.Background()
|
||||
task, err := s.taskRepo.GetByID(ctx, taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if task.Schedule != nil {
|
||||
task.Schedule.Enabled = true
|
||||
task.Schedule.NextRun = s.calculateNextRun(task.Schedule)
|
||||
task.NextRunAt = &task.Schedule.NextRun
|
||||
}
|
||||
task.Status = StatusScheduled
|
||||
task.UpdatedAt = time.Now()
|
||||
|
||||
if err := s.taskRepo.Update(ctx, task); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.scheduleTask(task)
|
||||
}
|
||||
|
||||
func (s *Scheduler) GetScheduledTasks() []string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
result := make([]string, 0, len(s.jobs))
|
||||
for taskID := range s.jobs {
|
||||
result = append(result, taskID)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *Scheduler) IsRunning(taskID string) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.running[taskID]
|
||||
}
|
||||
376
backend/internal/computer/types.go
Normal file
376
backend/internal/computer/types.go
Normal file
@@ -0,0 +1,376 @@
|
||||
package computer
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gooseek/backend/internal/llm"
|
||||
)
|
||||
|
||||
// TaskStatus is the lifecycle state of a ComputerTask or SubTask.
type TaskStatus string

const (
	StatusPending     TaskStatus = "pending"
	StatusPlanning    TaskStatus = "planning"
	StatusExecuting   TaskStatus = "executing"
	StatusWaiting     TaskStatus = "waiting_user" // blocked on user input
	StatusCompleted   TaskStatus = "completed"
	StatusFailed      TaskStatus = "failed"
	StatusCancelled   TaskStatus = "cancelled"
	StatusScheduled   TaskStatus = "scheduled"
	StatusPaused      TaskStatus = "paused"
	StatusCheckpoint  TaskStatus = "checkpoint"
	StatusLongRunning TaskStatus = "long_running"
)

// TaskType categorizes the kind of work a task or sub-task performs.
type TaskType string

const (
	TaskResearch    TaskType = "research"
	TaskCode        TaskType = "code"
	TaskAnalysis    TaskType = "analysis"
	TaskDesign      TaskType = "design"
	TaskDeploy      TaskType = "deploy"
	TaskMonitor     TaskType = "monitor"
	TaskReport      TaskType = "report"
	TaskCommunicate TaskType = "communicate"
	TaskSchedule    TaskType = "schedule"
	TaskTransform   TaskType = "transform"
	TaskValidate    TaskType = "validate"
)
|
||||
|
||||
// ComputerTask is the persisted unit of work the computer agent
// executes: the user's query plus its plan, sub-tasks, artifacts,
// scheduling state, and long-running-execution bookkeeping
// (checkpoints, iterations, runtime, heartbeat).
type ComputerTask struct {
	ID        string                 `json:"id"`
	UserID    string                 `json:"userId"`
	Query     string                 `json:"query"` // original user request
	Status    TaskStatus             `json:"status"`
	Plan      *TaskPlan              `json:"plan,omitempty"`
	SubTasks  []SubTask              `json:"subTasks,omitempty"`
	Artifacts []Artifact             `json:"artifacts,omitempty"`
	Memory    map[string]interface{} `json:"memory,omitempty"` // free-form context, forwarded to scheduled runs
	Progress  int                    `json:"progress"`         // presumably 0-100 — TODO confirm
	Message   string                 `json:"message,omitempty"`
	Error     string                 `json:"error,omitempty"`
	Schedule  *Schedule              `json:"schedule,omitempty"`
	NextRunAt *time.Time             `json:"nextRunAt,omitempty"` // polled by the scheduler for due tasks
	RunCount  int                    `json:"runCount"`            // completed scheduled invocations
	TotalCost float64                `json:"totalCost"`
	CreatedAt time.Time              `json:"createdAt"`
	UpdatedAt time.Time              `json:"updatedAt"`
	CompletedAt *time.Time           `json:"completedAt,omitempty"`

	// Long-running execution bookkeeping.
	DurationMode   DurationMode    `json:"durationMode"`
	Checkpoint     *Checkpoint     `json:"checkpoint,omitempty"` // latest checkpoint
	Checkpoints    []Checkpoint    `json:"checkpoints,omitempty"`
	MaxDuration    time.Duration   `json:"maxDuration"`
	EstimatedEnd   *time.Time      `json:"estimatedEnd,omitempty"`
	Iterations     int             `json:"iterations"`
	MaxIterations  int             `json:"maxIterations"`
	PausedAt       *time.Time      `json:"pausedAt,omitempty"`
	ResumedAt      *time.Time      `json:"resumedAt,omitempty"`
	TotalRuntime   time.Duration   `json:"totalRuntime"`
	HeartbeatAt    *time.Time      `json:"heartbeatAt,omitempty"`
	Priority       TaskPriority    `json:"priority"`
	ResourceLimits *ResourceLimits `json:"resourceLimits,omitempty"`
}

// DurationMode selects a runtime budget preset for a task; see
// DurationModeConfigs for the concrete limits per mode.
type DurationMode string

const (
	DurationShort     DurationMode = "short"
	DurationMedium    DurationMode = "medium"
	DurationLong      DurationMode = "long"
	DurationExtended  DurationMode = "extended"
	DurationUnlimited DurationMode = "unlimited"
)

// TaskPriority is the scheduling priority of a task.
type TaskPriority string

const (
	PriorityLow      TaskPriority = "low"
	PriorityNormal   TaskPriority = "normal"
	PriorityHigh     TaskPriority = "high"
	PriorityCritical TaskPriority = "critical"
)
|
||||
|
||||
// Checkpoint is a resumable snapshot of a long-running task: position
// in the sub-task/wave sequence, accumulated state and memory, and the
// cost/runtime spent so far.
type Checkpoint struct {
	ID           string                 `json:"id"`
	TaskID       string                 `json:"taskId"`
	SubTaskIndex int                    `json:"subTaskIndex"`
	WaveIndex    int                    `json:"waveIndex"`
	State        map[string]interface{} `json:"state"`
	Progress     int                    `json:"progress"`
	Artifacts    []string               `json:"artifacts"` // presumably artifact IDs — TODO confirm
	Memory       map[string]interface{} `json:"memory"`
	CreatedAt    time.Time              `json:"createdAt"`
	RuntimeSoFar time.Duration          `json:"runtimeSoFar"`
	CostSoFar    float64                `json:"costSoFar"`
	Reason       string                 `json:"reason"` // why the checkpoint was taken
}

// ResourceLimits caps the resources a task may consume while running.
type ResourceLimits struct {
	MaxCPU          float64 `json:"maxCpu"`
	MaxMemoryMB     int     `json:"maxMemoryMb"`
	MaxDiskMB       int     `json:"maxDiskMb"`
	MaxNetworkMbps  int     `json:"maxNetworkMbps"`
	MaxCostPerHour  float64 `json:"maxCostPerHour"`
	MaxTotalCost    float64 `json:"maxTotalCost"`
	MaxConcurrent   int     `json:"maxConcurrent"`
	IdleTimeoutMins int     `json:"idleTimeoutMins"`
}

// DurationModeConfigs maps each DurationMode to its runtime budget:
// max wall-clock duration, checkpoint frequency, heartbeat frequency,
// and iteration cap (DurationUnlimited uses 0 — presumably "no cap";
// TODO confirm against the executor's handling of MaxIterations).
var DurationModeConfigs = map[DurationMode]struct {
	MaxDuration    time.Duration
	CheckpointFreq time.Duration
	HeartbeatFreq  time.Duration
	MaxIterations  int
}{
	DurationShort:     {30 * time.Minute, 5 * time.Minute, 30 * time.Second, 10},
	DurationMedium:    {4 * time.Hour, 15 * time.Minute, time.Minute, 50},
	DurationLong:      {24 * time.Hour, 30 * time.Minute, 2 * time.Minute, 200},
	DurationExtended:  {7 * 24 * time.Hour, time.Hour, 5 * time.Minute, 1000},
	DurationUnlimited: {365 * 24 * time.Hour, 4 * time.Hour, 10 * time.Minute, 0},
}
|
||||
|
||||
// SubTask is one step of a task's plan, with its dependency edges,
// model routing hints, IO payloads, and retry/progress state.
type SubTask struct {
	ID           string                 `json:"id"`
	Type         TaskType               `json:"type"`
	Description  string                 `json:"description"`
	Dependencies []string               `json:"dependencies,omitempty"` // IDs of sub-tasks that must finish first
	ModelID      string                 `json:"modelId,omitempty"`      // explicit model override
	RequiredCaps []llm.ModelCapability  `json:"requiredCaps,omitempty"` // capabilities the chosen model must have
	Input        map[string]interface{} `json:"input,omitempty"`
	Output       map[string]interface{} `json:"output,omitempty"`
	Status       TaskStatus             `json:"status"`
	Progress     int                    `json:"progress"`
	Error        string                 `json:"error,omitempty"`
	Cost         float64                `json:"cost"`
	StartedAt    *time.Time             `json:"startedAt,omitempty"`
	CompletedAt  *time.Time             `json:"completedAt,omitempty"`
	Retries      int                    `json:"retries"`
	MaxRetries   int                    `json:"maxRetries"`
}

// TaskPlan is the planner's decomposition of a query: the sub-tasks and
// their execution order (ExecutionOrder groups sub-task IDs into waves
// that can run in parallel — TODO confirm against the executor).
type TaskPlan struct {
	Query          string     `json:"query"`
	Summary        string     `json:"summary"`
	SubTasks       []SubTask  `json:"subTasks"`
	ExecutionOrder [][]string `json:"executionOrder"`
	EstimatedCost  float64    `json:"estimatedCost"`
	EstimatedTime  int        `json:"estimatedTimeSeconds"`
}

// Artifact is a file or other output produced by a task. Content is
// kept out of JSON serialization; consumers use URL/Size/MimeType.
type Artifact struct {
	ID       string                 `json:"id"`
	TaskID   string                 `json:"taskId"`
	Type     string                 `json:"type"` // one of the ArtifactType* constants
	Name     string                 `json:"name"`
	Content  []byte                 `json:"-"` // raw bytes, never serialized
	URL      string                 `json:"url,omitempty"`
	Size     int64                  `json:"size"`
	MimeType string                 `json:"mimeType,omitempty"`
	Metadata map[string]interface{} `json:"metadata,omitempty"`
	CreatedAt time.Time             `json:"createdAt"`
}
|
||||
|
||||
// Schedule describes when and how often a task should run. Type is one
// of the Schedule* constants; the scheduler currently acts on "cron",
// "interval", "once", "daily", "hourly", "weekly", and "monthly".
type Schedule struct {
	Type     string     `json:"type"`
	CronExpr string     `json:"cronExpr,omitempty"`        // six-field expression, used when Type == "cron"
	Interval int        `json:"intervalSeconds,omitempty"` // used when Type == "interval"
	NextRun  time.Time  `json:"nextRun"`
	MaxRuns  int        `json:"maxRuns"` // 0 = unlimited (see executeScheduledTask)
	RunCount int        `json:"runCount"`
	ExpiresAt *time.Time `json:"expiresAt,omitempty"` // task is cancelled once passed
	Enabled  bool       `json:"enabled"`

	// Extended scheduling knobs.
	DurationMode DurationMode  `json:"durationMode,omitempty"`
	RetryOnFail  bool          `json:"retryOnFail"`
	MaxRetries   int           `json:"maxRetries"`
	RetryDelay   time.Duration `json:"retryDelay"`
	Timezone     string        `json:"timezone,omitempty"`
	WindowStart  string        `json:"windowStart,omitempty"` // presumably a time-of-day bound — TODO confirm
	WindowEnd    string        `json:"windowEnd,omitempty"`
	Conditions   []Condition   `json:"conditions,omitempty"`
}

// Condition is a predicate attached to a schedule (field/operator/value
// comparison plus free-form params).
type Condition struct {
	Type     string                 `json:"type"`
	Field    string                 `json:"field"`
	Operator string                 `json:"operator"`
	Value    interface{}            `json:"value"`
	Params   map[string]interface{} `json:"params,omitempty"`
}

// Valid Schedule.Type values.
const (
	ScheduleOnce        = "once"
	ScheduleInterval    = "interval"
	ScheduleCron        = "cron"
	ScheduleHourly      = "hourly"
	ScheduleDaily       = "daily"
	ScheduleWeekly      = "weekly"
	ScheduleMonthly     = "monthly"
	ScheduleQuarterly   = "quarterly"
	ScheduleYearly      = "yearly"
	ScheduleContinuous  = "continuous"
	ScheduleOnCondition = "on_condition"
)
|
||||
|
||||
// TaskEvent is a progress/status notification emitted while a task
// runs; Type is one of the Event* constants.
type TaskEvent struct {
	Type      string                 `json:"type"`
	TaskID    string                 `json:"taskId"`
	SubTaskID string                 `json:"subTaskId,omitempty"`
	Status    TaskStatus             `json:"status,omitempty"`
	Progress  int                    `json:"progress,omitempty"`
	Message   string                 `json:"message,omitempty"`
	Data      map[string]interface{} `json:"data,omitempty"`
	Timestamp time.Time              `json:"timestamp"`
}

// ExecuteOptions configures a single Computer.Execute call: sync vs
// async, budgets, sandbox/browser toggles, scheduling, and
// long-running-task knobs.
type ExecuteOptions struct {
	Async         bool                   `json:"async"`
	MaxCost       float64                `json:"maxCost"`
	Timeout       int                    `json:"timeoutSeconds"`
	EnableSandbox bool                   `json:"enableSandbox"`
	Schedule      *Schedule              `json:"schedule,omitempty"`
	Context       map[string]interface{} `json:"context,omitempty"` // extra run context (e.g. task Memory)

	DurationMode   DurationMode    `json:"durationMode,omitempty"`
	Priority       TaskPriority    `json:"priority,omitempty"`
	ResourceLimits *ResourceLimits `json:"resourceLimits,omitempty"`
	ResumeFromID   string          `json:"resumeFromId,omitempty"` // presumably a checkpoint or task ID to resume — TODO confirm
	EnableBrowser  bool            `json:"enableBrowser"`
	BrowserOptions *BrowserOptions `json:"browserOptions,omitempty"`
	NotifyOnEvents []string        `json:"notifyOnEvents,omitempty"` // event types to forward to WebhookURL
	WebhookURL     string          `json:"webhookUrl,omitempty"`
	Tags           []string        `json:"tags,omitempty"`
}

// BrowserOptions configures the automation browser for a task run.
type BrowserOptions struct {
	Headless      bool      `json:"headless"`
	UserAgent     string    `json:"userAgent,omitempty"`
	Viewport      *Viewport `json:"viewport,omitempty"`
	ProxyURL      string    `json:"proxyUrl,omitempty"`
	Timeout       int       `json:"timeout"`
	Screenshots   bool      `json:"screenshots"`
	RecordVideo   bool      `json:"recordVideo"`
	BlockAds      bool      `json:"blockAds"`
	AcceptCookies bool      `json:"acceptCookies"`
}

// Viewport is the browser window size in pixels.
type Viewport struct {
	Width  int `json:"width"`
	Height int `json:"height"`
}
|
||||
|
||||
// ExecutionResult is the outcome of running a task or sub-task:
// produced output and artifacts plus duration/cost accounting.
type ExecutionResult struct {
	TaskID    string
	SubTaskID string
	Output    map[string]interface{}
	Artifacts []Artifact
	Duration  time.Duration
	Cost      float64
	Error     error
}

// SandboxResult captures a sandboxed command run: stdio, exit code,
// files produced, and wall-clock duration.
type SandboxResult struct {
	Stdout   string
	Stderr   string
	ExitCode int
	Files    map[string][]byte
	Duration time.Duration
}

// MemoryEntry is a persisted key/value memory item scoped to a user
// (and optionally a task), with optional expiry. Type is one of the
// MemoryType* constants.
type MemoryEntry struct {
	ID        string      `json:"id"`
	UserID    string      `json:"userId"`
	TaskID    string      `json:"taskId,omitempty"`
	Key       string      `json:"key"`
	Value     interface{} `json:"value"`
	Type      string      `json:"type"`
	Tags      []string    `json:"tags,omitempty"`
	CreatedAt time.Time   `json:"createdAt"`
	ExpiresAt *time.Time  `json:"expiresAt,omitempty"`
}

// Task lifecycle and streaming event types carried in TaskEvent.Type.
const (
	EventTaskCreated     = "task_created"
	EventTaskStarted     = "task_started"
	EventTaskProgress    = "task_progress"
	EventTaskCompleted   = "task_completed"
	EventTaskFailed      = "task_failed"
	EventSubTaskStart    = "subtask_start"
	EventSubTaskDone     = "subtask_done"
	EventSubTaskFail     = "subtask_fail"
	EventArtifact        = "artifact"
	EventMessage         = "message"
	EventUserInput       = "user_input_required"
	EventCheckpoint      = "checkpoint"
	EventCheckpointSaved = "checkpoint_saved"
	EventResumed         = "resumed"
	EventPaused          = "paused"
	EventHeartbeat       = "heartbeat"
	EventIteration       = "iteration"
	EventBrowserAction   = "browser_action"
	EventScreenshot      = "screenshot"
	EventResourceAlert   = "resource_alert"
	EventScheduleUpdate  = "schedule_update"
)
|
||||
|
||||
// BrowserAction is one step of a browser automation sequence; Result is
// filled in after execution.
type BrowserAction struct {
	ID         string                 `json:"id"`
	Type       BrowserActionType      `json:"type"`
	Selector   string                 `json:"selector,omitempty"` // CSS/element selector for targeted actions
	URL        string                 `json:"url,omitempty"`
	Value      string                 `json:"value,omitempty"` // e.g. text to type or option to select
	Options    map[string]interface{} `json:"options,omitempty"`
	Screenshot bool                   `json:"screenshot"` // capture a screenshot after the action
	WaitAfter  int                    `json:"waitAfterMs"`
	Timeout    int                    `json:"timeoutMs"`
	Result     *BrowserActionResult   `json:"result,omitempty"`
}

// BrowserActionType enumerates the supported browser operations.
type BrowserActionType string

const (
	BrowserNavigate     BrowserActionType = "navigate"
	BrowserClick        BrowserActionType = "click"
	BrowserType         BrowserActionType = "type"
	BrowserScroll       BrowserActionType = "scroll"
	BrowserScreenshot   BrowserActionType = "screenshot"
	BrowserWait         BrowserActionType = "wait"
	BrowserWaitSelector BrowserActionType = "wait_selector"
	BrowserExtract      BrowserActionType = "extract"
	BrowserEval         BrowserActionType = "eval"
	BrowserSelect       BrowserActionType = "select"
	BrowserUpload       BrowserActionType = "upload"
	BrowserDownload     BrowserActionType = "download"
	BrowserPDF          BrowserActionType = "pdf"
	BrowserClose        BrowserActionType = "close"
)

// BrowserActionResult is the outcome of a single browser action.
type BrowserActionResult struct {
	Success      bool                `json:"success"`
	Data         interface{}         `json:"data,omitempty"`
	Screenshot   string              `json:"screenshot,omitempty"` // presumably base64 or a URL — TODO confirm
	Error        string              `json:"error,omitempty"`
	Duration     time.Duration       `json:"duration"`
	PageTitle    string              `json:"pageTitle,omitempty"`
	PageURL      string              `json:"pageUrl,omitempty"`
	Cookies      []map[string]string `json:"cookies,omitempty"`
	LocalStorage map[string]string   `json:"localStorage,omitempty"`
)

// Artifact.Type values.
const (
	ArtifactTypeFile       = "file"
	ArtifactTypeCode       = "code"
	ArtifactTypeReport     = "report"
	ArtifactTypeDeployment = "deployment"
	ArtifactTypeImage      = "image"
	ArtifactTypeData       = "data"
)

// MemoryEntry.Type values.
const (
	MemoryTypeFact       = "fact"
	MemoryTypePreference = "preference"
	MemoryTypeContext    = "context"
	MemoryTypeResult     = "result"
)
|
||||
97
backend/internal/db/article_summary_repo.go
Normal file
97
backend/internal/db/article_summary_repo.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ArticleSummary is a cached list of event summaries extracted from an
// article URL, with an expiry used for cache invalidation.
type ArticleSummary struct {
	ID        int64     `json:"id"`
	URLHash   string    `json:"urlHash"` // hex SHA-256 of the normalized URL (see hashURL)
	URL       string    `json:"url"`     // original, un-normalized URL
	Events    []string  `json:"events"`
	CreatedAt time.Time `json:"createdAt"`
	ExpiresAt time.Time `json:"expiresAt"`
}

// ArticleSummaryRepository caches article summaries in Postgres, keyed
// by a normalized SHA-256 hash of the URL.
type ArticleSummaryRepository struct {
	db *PostgresDB
}

// NewArticleSummaryRepository returns a repository backed by the given
// database.
func NewArticleSummaryRepository(db *PostgresDB) *ArticleSummaryRepository {
	return &ArticleSummaryRepository{db: db}
}
|
||||
|
||||
func (r *ArticleSummaryRepository) hashURL(url string) string {
|
||||
normalized := strings.TrimSpace(url)
|
||||
normalized = strings.TrimSuffix(normalized, "/")
|
||||
normalized = strings.TrimPrefix(normalized, "https://")
|
||||
normalized = strings.TrimPrefix(normalized, "http://")
|
||||
normalized = strings.TrimPrefix(normalized, "www.")
|
||||
|
||||
hash := sha256.Sum256([]byte(normalized))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
func (r *ArticleSummaryRepository) GetByURL(ctx context.Context, url string) (*ArticleSummary, error) {
|
||||
urlHash := r.hashURL(url)
|
||||
|
||||
query := `
|
||||
SELECT id, url_hash, url, events, created_at, expires_at
|
||||
FROM article_summaries
|
||||
WHERE url_hash = $1 AND expires_at > NOW()
|
||||
`
|
||||
|
||||
var a ArticleSummary
|
||||
var eventsJSON []byte
|
||||
|
||||
err := r.db.db.QueryRowContext(ctx, query, urlHash).Scan(
|
||||
&a.ID, &a.URLHash, &a.URL, &eventsJSON, &a.CreatedAt, &a.ExpiresAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
json.Unmarshal(eventsJSON, &a.Events)
|
||||
return &a, nil
|
||||
}
|
||||
|
||||
func (r *ArticleSummaryRepository) Save(ctx context.Context, url string, events []string, ttl time.Duration) error {
|
||||
urlHash := r.hashURL(url)
|
||||
eventsJSON, _ := json.Marshal(events)
|
||||
expiresAt := time.Now().Add(ttl)
|
||||
|
||||
query := `
|
||||
INSERT INTO article_summaries (url_hash, url, events, expires_at)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (url_hash)
|
||||
DO UPDATE SET
|
||||
events = EXCLUDED.events,
|
||||
expires_at = EXCLUDED.expires_at
|
||||
`
|
||||
|
||||
_, err := r.db.db.ExecContext(ctx, query, urlHash, url, eventsJSON, expiresAt)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ArticleSummaryRepository) Delete(ctx context.Context, url string) error {
|
||||
urlHash := r.hashURL(url)
|
||||
_, err := r.db.db.ExecContext(ctx, "DELETE FROM article_summaries WHERE url_hash = $1", urlHash)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ArticleSummaryRepository) CleanupExpired(ctx context.Context) (int64, error) {
|
||||
result, err := r.db.db.ExecContext(ctx, "DELETE FROM article_summaries WHERE expires_at < NOW()")
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
204
backend/internal/db/collection_repo.go
Normal file
204
backend/internal/db/collection_repo.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Collection groups saved items (searches, notes, URLs, files) for a
// user. When ContextEnabled is set, its items can be folded into a
// context string via GetCollectionContext.
type Collection struct {
	ID             string           `json:"id"`
	UserID         string           `json:"userId"`
	Name           string           `json:"name"`
	Description    string           `json:"description"`
	IsPublic       bool             `json:"isPublic"`
	ContextEnabled bool             `json:"contextEnabled"`
	CreatedAt      time.Time        `json:"createdAt"`
	UpdatedAt      time.Time        `json:"updatedAt"`
	Items          []CollectionItem `json:"items,omitempty"`     // loaded on demand via GetItems
	ItemCount      int              `json:"itemCount,omitempty"` // computed by subquery in the Get* methods
}

// CollectionItem is a single entry in a collection. ItemType selects
// how the entry is interpreted ("search", "note", "url", and "file"
// are the types handled by GetCollectionContext).
type CollectionItem struct {
	ID           string                 `json:"id"`
	CollectionID string                 `json:"collectionId"`
	ItemType     string                 `json:"itemType"`
	Title        string                 `json:"title"`
	Content      string                 `json:"content"`
	URL          string                 `json:"url"`
	Metadata     map[string]interface{} `json:"metadata"`
	CreatedAt    time.Time              `json:"createdAt"`
	SortOrder    int                    `json:"sortOrder"` // position within the collection, assigned by AddItem
}

// CollectionRepository provides CRUD access to collections and their
// items, backed by Postgres.
type CollectionRepository struct {
	db *PostgresDB
}

// NewCollectionRepository returns a repository backed by the given
// database.
func NewCollectionRepository(db *PostgresDB) *CollectionRepository {
	return &CollectionRepository{db: db}
}
|
||||
|
||||
func (r *CollectionRepository) Create(ctx context.Context, c *Collection) error {
|
||||
query := `
|
||||
INSERT INTO collections (user_id, name, description, is_public, context_enabled)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
return r.db.db.QueryRowContext(ctx, query,
|
||||
c.UserID, c.Name, c.Description, c.IsPublic, c.ContextEnabled,
|
||||
).Scan(&c.ID, &c.CreatedAt, &c.UpdatedAt)
|
||||
}
|
||||
|
||||
func (r *CollectionRepository) GetByID(ctx context.Context, id string) (*Collection, error) {
|
||||
query := `
|
||||
SELECT id, user_id, name, description, is_public, context_enabled, created_at, updated_at,
|
||||
(SELECT COUNT(*) FROM collection_items WHERE collection_id = collections.id) as item_count
|
||||
FROM collections
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
var c Collection
|
||||
err := r.db.db.QueryRowContext(ctx, query, id).Scan(
|
||||
&c.ID, &c.UserID, &c.Name, &c.Description, &c.IsPublic,
|
||||
&c.ContextEnabled, &c.CreatedAt, &c.UpdatedAt, &c.ItemCount,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
func (r *CollectionRepository) GetByUserID(ctx context.Context, userID string, limit, offset int) ([]*Collection, error) {
|
||||
query := `
|
||||
SELECT id, user_id, name, description, is_public, context_enabled, created_at, updated_at,
|
||||
(SELECT COUNT(*) FROM collection_items WHERE collection_id = collections.id) as item_count
|
||||
FROM collections
|
||||
WHERE user_id = $1
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT $2 OFFSET $3
|
||||
`
|
||||
|
||||
rows, err := r.db.db.QueryContext(ctx, query, userID, limit, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var collections []*Collection
|
||||
for rows.Next() {
|
||||
var c Collection
|
||||
if err := rows.Scan(
|
||||
&c.ID, &c.UserID, &c.Name, &c.Description, &c.IsPublic,
|
||||
&c.ContextEnabled, &c.CreatedAt, &c.UpdatedAt, &c.ItemCount,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
collections = append(collections, &c)
|
||||
}
|
||||
|
||||
return collections, nil
|
||||
}
|
||||
|
||||
func (r *CollectionRepository) Update(ctx context.Context, c *Collection) error {
|
||||
query := `
|
||||
UPDATE collections
|
||||
SET name = $2, description = $3, is_public = $4, context_enabled = $5, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
_, err := r.db.db.ExecContext(ctx, query,
|
||||
c.ID, c.Name, c.Description, c.IsPublic, c.ContextEnabled,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *CollectionRepository) Delete(ctx context.Context, id string) error {
|
||||
_, err := r.db.db.ExecContext(ctx, "DELETE FROM collections WHERE id = $1", id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *CollectionRepository) AddItem(ctx context.Context, item *CollectionItem) error {
|
||||
metadataJSON, _ := json.Marshal(item.Metadata)
|
||||
|
||||
query := `
|
||||
INSERT INTO collection_items (collection_id, item_type, title, content, url, metadata, sort_order)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, COALESCE((SELECT MAX(sort_order) + 1 FROM collection_items WHERE collection_id = $1), 0))
|
||||
RETURNING id, created_at, sort_order
|
||||
`
|
||||
return r.db.db.QueryRowContext(ctx, query,
|
||||
item.CollectionID, item.ItemType, item.Title, item.Content, item.URL, metadataJSON,
|
||||
).Scan(&item.ID, &item.CreatedAt, &item.SortOrder)
|
||||
}
|
||||
|
||||
func (r *CollectionRepository) GetItems(ctx context.Context, collectionID string) ([]CollectionItem, error) {
|
||||
query := `
|
||||
SELECT id, collection_id, item_type, title, content, url, metadata, created_at, sort_order
|
||||
FROM collection_items
|
||||
WHERE collection_id = $1
|
||||
ORDER BY sort_order ASC
|
||||
`
|
||||
|
||||
rows, err := r.db.db.QueryContext(ctx, query, collectionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var items []CollectionItem
|
||||
for rows.Next() {
|
||||
var item CollectionItem
|
||||
var metadataJSON []byte
|
||||
|
||||
if err := rows.Scan(
|
||||
&item.ID, &item.CollectionID, &item.ItemType, &item.Title,
|
||||
&item.Content, &item.URL, &metadataJSON, &item.CreatedAt, &item.SortOrder,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
json.Unmarshal(metadataJSON, &item.Metadata)
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (r *CollectionRepository) RemoveItem(ctx context.Context, itemID string) error {
|
||||
_, err := r.db.db.ExecContext(ctx, "DELETE FROM collection_items WHERE id = $1", itemID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *CollectionRepository) GetCollectionContext(ctx context.Context, collectionID string) (string, error) {
|
||||
items, err := r.GetItems(ctx, collectionID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var context string
|
||||
for _, item := range items {
|
||||
switch item.ItemType {
|
||||
case "search":
|
||||
context += "Previous search: " + item.Title + "\n"
|
||||
if item.Content != "" {
|
||||
context += "Summary: " + item.Content + "\n"
|
||||
}
|
||||
case "note":
|
||||
context += "User note: " + item.Content + "\n"
|
||||
case "url":
|
||||
context += "Saved URL: " + item.URL + " - " + item.Title + "\n"
|
||||
case "file":
|
||||
context += "Uploaded file: " + item.Title + "\n"
|
||||
if item.Content != "" {
|
||||
context += "Content: " + item.Content + "\n"
|
||||
}
|
||||
}
|
||||
context += "\n"
|
||||
}
|
||||
|
||||
return context, nil
|
||||
}
|
||||
322
backend/internal/db/computer_artifact_repo.go
Normal file
322
backend/internal/db/computer_artifact_repo.go
Normal file
@@ -0,0 +1,322 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/gooseek/backend/internal/computer"
|
||||
)
|
||||
|
||||
// ComputerArtifactRepo provides CRUD access to the computer_artifacts table,
// which stores files, screenshots and other outputs produced by computer-use
// tasks.
type ComputerArtifactRepo struct {
	db *sql.DB // shared database handle; not owned or closed by the repo
}
|
||||
|
||||
func NewComputerArtifactRepo(db *sql.DB) *ComputerArtifactRepo {
|
||||
return &ComputerArtifactRepo{db: db}
|
||||
}
|
||||
|
||||
func (r *ComputerArtifactRepo) Migrate() error {
|
||||
query := `
|
||||
CREATE TABLE IF NOT EXISTS computer_artifacts (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
task_id UUID NOT NULL,
|
||||
type VARCHAR(50) NOT NULL,
|
||||
name VARCHAR(255),
|
||||
content BYTEA,
|
||||
url TEXT,
|
||||
size BIGINT DEFAULT 0,
|
||||
mime_type VARCHAR(100),
|
||||
metadata JSONB,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_artifacts_task_id ON computer_artifacts(task_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_artifacts_type ON computer_artifacts(type);
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_artifacts_created ON computer_artifacts(created_at DESC);
|
||||
`
|
||||
|
||||
_, err := r.db.Exec(query)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ComputerArtifactRepo) Create(ctx context.Context, artifact *computer.Artifact) error {
|
||||
metadataJSON, _ := json.Marshal(artifact.Metadata)
|
||||
|
||||
query := `
|
||||
INSERT INTO computer_artifacts (id, task_id, type, name, content, url, size, mime_type, metadata, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
`
|
||||
|
||||
_, err := r.db.ExecContext(ctx, query,
|
||||
artifact.ID,
|
||||
artifact.TaskID,
|
||||
artifact.Type,
|
||||
artifact.Name,
|
||||
artifact.Content,
|
||||
artifact.URL,
|
||||
artifact.Size,
|
||||
artifact.MimeType,
|
||||
metadataJSON,
|
||||
artifact.CreatedAt,
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// GetByID loads a single artifact by id, including its binary content
// (BYTEA). Returns sql.ErrNoRows when the id does not exist.
func (r *ComputerArtifactRepo) GetByID(ctx context.Context, id string) (*computer.Artifact, error) {
	query := `
	SELECT id, task_id, type, name, content, url, size, mime_type, metadata, created_at
	FROM computer_artifacts
	WHERE id = $1
	`

	var artifact computer.Artifact
	var content []byte
	// url and mime_type are nullable columns, hence sql.NullString.
	var url, mimeType sql.NullString
	var metadataJSON []byte

	// Scan order must match the SELECT column list exactly.
	err := r.db.QueryRowContext(ctx, query, id).Scan(
		&artifact.ID,
		&artifact.TaskID,
		&artifact.Type,
		&artifact.Name,
		&content,
		&url,
		&artifact.Size,
		&mimeType,
		&metadataJSON,
		&artifact.CreatedAt,
	)

	if err != nil {
		return nil, err
	}

	artifact.Content = content
	if url.Valid {
		artifact.URL = url.String
	}
	if mimeType.Valid {
		artifact.MimeType = mimeType.String
	}
	if len(metadataJSON) > 0 {
		// Best-effort decode: malformed metadata leaves the field zero-valued.
		json.Unmarshal(metadataJSON, &artifact.Metadata)
	}

	return &artifact, nil
}
|
||||
|
||||
func (r *ComputerArtifactRepo) GetByTaskID(ctx context.Context, taskID string) ([]computer.Artifact, error) {
|
||||
query := `
|
||||
SELECT id, task_id, type, name, url, size, mime_type, metadata, created_at
|
||||
FROM computer_artifacts
|
||||
WHERE task_id = $1
|
||||
ORDER BY created_at ASC
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var artifacts []computer.Artifact
|
||||
|
||||
for rows.Next() {
|
||||
var artifact computer.Artifact
|
||||
var url, mimeType sql.NullString
|
||||
var metadataJSON []byte
|
||||
|
||||
err := rows.Scan(
|
||||
&artifact.ID,
|
||||
&artifact.TaskID,
|
||||
&artifact.Type,
|
||||
&artifact.Name,
|
||||
&url,
|
||||
&artifact.Size,
|
||||
&mimeType,
|
||||
&metadataJSON,
|
||||
&artifact.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if url.Valid {
|
||||
artifact.URL = url.String
|
||||
}
|
||||
if mimeType.Valid {
|
||||
artifact.MimeType = mimeType.String
|
||||
}
|
||||
if len(metadataJSON) > 0 {
|
||||
json.Unmarshal(metadataJSON, &artifact.Metadata)
|
||||
}
|
||||
|
||||
artifacts = append(artifacts, artifact)
|
||||
}
|
||||
|
||||
return artifacts, nil
|
||||
}
|
||||
|
||||
func (r *ComputerArtifactRepo) GetByType(ctx context.Context, taskID, artifactType string) ([]computer.Artifact, error) {
|
||||
query := `
|
||||
SELECT id, task_id, type, name, url, size, mime_type, metadata, created_at
|
||||
FROM computer_artifacts
|
||||
WHERE task_id = $1 AND type = $2
|
||||
ORDER BY created_at ASC
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, taskID, artifactType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var artifacts []computer.Artifact
|
||||
|
||||
for rows.Next() {
|
||||
var artifact computer.Artifact
|
||||
var url, mimeType sql.NullString
|
||||
var metadataJSON []byte
|
||||
|
||||
err := rows.Scan(
|
||||
&artifact.ID,
|
||||
&artifact.TaskID,
|
||||
&artifact.Type,
|
||||
&artifact.Name,
|
||||
&url,
|
||||
&artifact.Size,
|
||||
&mimeType,
|
||||
&metadataJSON,
|
||||
&artifact.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if url.Valid {
|
||||
artifact.URL = url.String
|
||||
}
|
||||
if mimeType.Valid {
|
||||
artifact.MimeType = mimeType.String
|
||||
}
|
||||
if len(metadataJSON) > 0 {
|
||||
json.Unmarshal(metadataJSON, &artifact.Metadata)
|
||||
}
|
||||
|
||||
artifacts = append(artifacts, artifact)
|
||||
}
|
||||
|
||||
return artifacts, nil
|
||||
}
|
||||
|
||||
func (r *ComputerArtifactRepo) GetContent(ctx context.Context, id string) ([]byte, error) {
|
||||
query := `SELECT content FROM computer_artifacts WHERE id = $1`
|
||||
var content []byte
|
||||
err := r.db.QueryRowContext(ctx, query, id).Scan(&content)
|
||||
return content, err
|
||||
}
|
||||
|
||||
func (r *ComputerArtifactRepo) UpdateURL(ctx context.Context, id, url string) error {
|
||||
query := `UPDATE computer_artifacts SET url = $1 WHERE id = $2`
|
||||
_, err := r.db.ExecContext(ctx, query, url, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ComputerArtifactRepo) Delete(ctx context.Context, id string) error {
|
||||
query := `DELETE FROM computer_artifacts WHERE id = $1`
|
||||
_, err := r.db.ExecContext(ctx, query, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ComputerArtifactRepo) DeleteByTaskID(ctx context.Context, taskID string) error {
|
||||
query := `DELETE FROM computer_artifacts WHERE task_id = $1`
|
||||
_, err := r.db.ExecContext(ctx, query, taskID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ComputerArtifactRepo) DeleteOlderThan(ctx context.Context, days int) (int64, error) {
|
||||
query := `
|
||||
DELETE FROM computer_artifacts
|
||||
WHERE created_at < NOW() - INTERVAL '1 day' * $1
|
||||
`
|
||||
result, err := r.db.ExecContext(ctx, query, days)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
func (r *ComputerArtifactRepo) GetTotalSize(ctx context.Context, taskID string) (int64, error) {
|
||||
query := `SELECT COALESCE(SUM(size), 0) FROM computer_artifacts WHERE task_id = $1`
|
||||
var size int64
|
||||
err := r.db.QueryRowContext(ctx, query, taskID).Scan(&size)
|
||||
return size, err
|
||||
}
|
||||
|
||||
func (r *ComputerArtifactRepo) Count(ctx context.Context, taskID string) (int64, error) {
|
||||
query := `SELECT COUNT(*) FROM computer_artifacts WHERE task_id = $1`
|
||||
var count int64
|
||||
err := r.db.QueryRowContext(ctx, query, taskID).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
// ArtifactSummary is a lightweight listing view of an artifact: everything
// except the (potentially large) binary content column.
type ArtifactSummary struct {
	ID string `json:"id"`
	TaskID string `json:"taskId"`
	Type string `json:"type"`
	Name string `json:"name"`
	URL string `json:"url"`
	Size int64 `json:"size"`
	MimeType string `json:"mimeType"`
	CreatedAt time.Time `json:"createdAt"`
}
|
||||
|
||||
func (r *ComputerArtifactRepo) GetSummaries(ctx context.Context, taskID string) ([]ArtifactSummary, error) {
|
||||
query := `
|
||||
SELECT id, task_id, type, name, url, size, mime_type, created_at
|
||||
FROM computer_artifacts
|
||||
WHERE task_id = $1
|
||||
ORDER BY created_at ASC
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var summaries []ArtifactSummary
|
||||
|
||||
for rows.Next() {
|
||||
var s ArtifactSummary
|
||||
var url, mimeType sql.NullString
|
||||
|
||||
err := rows.Scan(
|
||||
&s.ID,
|
||||
&s.TaskID,
|
||||
&s.Type,
|
||||
&s.Name,
|
||||
&url,
|
||||
&s.Size,
|
||||
&mimeType,
|
||||
&s.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if url.Valid {
|
||||
s.URL = url.String
|
||||
}
|
||||
if mimeType.Valid {
|
||||
s.MimeType = mimeType.String
|
||||
}
|
||||
|
||||
summaries = append(summaries, s)
|
||||
}
|
||||
|
||||
return summaries, nil
|
||||
}
|
||||
306
backend/internal/db/computer_memory_repo.go
Normal file
306
backend/internal/db/computer_memory_repo.go
Normal file
@@ -0,0 +1,306 @@
|
||||
package db
|
||||
|
||||
import (
	"context"
	"database/sql"
	"encoding/json"
	"strconv"
	"strings"
	"time"

	"github.com/gooseek/backend/internal/computer"
)
|
||||
|
||||
// ComputerMemoryRepo stores key/value memory entries (JSONB values with
// optional expiry and tags) produced by computer-use tasks, scoped per user.
type ComputerMemoryRepo struct {
	db *sql.DB // shared database handle; not owned or closed by the repo
}
|
||||
|
||||
func NewComputerMemoryRepo(db *sql.DB) *ComputerMemoryRepo {
|
||||
return &ComputerMemoryRepo{db: db}
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) Migrate() error {
|
||||
query := `
|
||||
CREATE TABLE IF NOT EXISTS computer_memory (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL,
|
||||
task_id UUID,
|
||||
key VARCHAR(255) NOT NULL,
|
||||
value JSONB NOT NULL,
|
||||
type VARCHAR(50),
|
||||
tags TEXT[],
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
expires_at TIMESTAMPTZ
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_memory_user_id ON computer_memory(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_memory_task_id ON computer_memory(task_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_memory_type ON computer_memory(type);
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_memory_expires ON computer_memory(expires_at) WHERE expires_at IS NOT NULL;
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_memory_key ON computer_memory(key);
|
||||
`
|
||||
|
||||
_, err := r.db.Exec(query)
|
||||
return err
|
||||
}
|
||||
|
||||
// Store upserts a memory entry keyed by its id. On id conflict the value,
// type, tags and expiry are refreshed; user_id, task_id, key and created_at
// are left as originally inserted.
func (r *ComputerMemoryRepo) Store(ctx context.Context, entry *computer.MemoryEntry) error {
	valueJSON, err := json.Marshal(entry.Value)
	if err != nil {
		return err
	}

	query := `
	INSERT INTO computer_memory (id, user_id, task_id, key, value, type, tags, created_at, expires_at)
	VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
	ON CONFLICT (id) DO UPDATE SET
		value = EXCLUDED.value,
		type = EXCLUDED.type,
		tags = EXCLUDED.tags,
		expires_at = EXCLUDED.expires_at
	`

	// An empty TaskID is stored as NULL (nil interface) rather than as an
	// empty string, which would not be a valid UUID.
	var taskID interface{}
	if entry.TaskID != "" {
		taskID = entry.TaskID
	}

	// NOTE(review): entry.Tags is passed as a plain []string for the TEXT[]
	// column; drivers such as lib/pq require wrapping (pq.Array) for array
	// binds — confirm the driver in use accepts []string directly.
	_, err = r.db.ExecContext(ctx, query,
		entry.ID,
		entry.UserID,
		taskID,
		entry.Key,
		valueJSON,
		entry.Type,
		entry.Tags,
		entry.CreatedAt,
		entry.ExpiresAt,
	)

	return err
}
|
||||
|
||||
func (r *ComputerMemoryRepo) GetByUser(ctx context.Context, userID string, limit int) ([]computer.MemoryEntry, error) {
|
||||
query := `
|
||||
SELECT id, user_id, task_id, key, value, type, tags, created_at, expires_at
|
||||
FROM computer_memory
|
||||
WHERE user_id = $1
|
||||
AND (expires_at IS NULL OR expires_at > NOW())
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $2
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, userID, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return r.scanEntries(rows)
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) GetByTask(ctx context.Context, taskID string) ([]computer.MemoryEntry, error) {
|
||||
query := `
|
||||
SELECT id, user_id, task_id, key, value, type, tags, created_at, expires_at
|
||||
FROM computer_memory
|
||||
WHERE task_id = $1
|
||||
AND (expires_at IS NULL OR expires_at > NOW())
|
||||
ORDER BY created_at ASC
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return r.scanEntries(rows)
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) Search(ctx context.Context, userID, query string, limit int) ([]computer.MemoryEntry, error) {
|
||||
searchTerms := strings.Fields(strings.ToLower(query))
|
||||
if len(searchTerms) == 0 {
|
||||
return r.GetByUser(ctx, userID, limit)
|
||||
}
|
||||
|
||||
likePatterns := make([]string, len(searchTerms))
|
||||
args := make([]interface{}, len(searchTerms)+2)
|
||||
args[0] = userID
|
||||
|
||||
for i, term := range searchTerms {
|
||||
likePatterns[i] = "%" + term + "%"
|
||||
args[i+1] = likePatterns[i]
|
||||
}
|
||||
args[len(args)-1] = limit
|
||||
|
||||
var conditions []string
|
||||
for i := range searchTerms {
|
||||
conditions = append(conditions, "(LOWER(key) LIKE $"+string(rune('2'+i))+" OR LOWER(value::text) LIKE $"+string(rune('2'+i))+")")
|
||||
}
|
||||
|
||||
sqlQuery := `
|
||||
SELECT id, user_id, task_id, key, value, type, tags, created_at, expires_at
|
||||
FROM computer_memory
|
||||
WHERE user_id = $1
|
||||
AND (expires_at IS NULL OR expires_at > NOW())
|
||||
AND (` + strings.Join(conditions, " OR ") + `)
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $` + string(rune('2'+len(searchTerms)))
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, sqlQuery, args...)
|
||||
if err != nil {
|
||||
return r.GetByUser(ctx, userID, limit)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return r.scanEntries(rows)
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) GetByType(ctx context.Context, userID, memType string, limit int) ([]computer.MemoryEntry, error) {
|
||||
query := `
|
||||
SELECT id, user_id, task_id, key, value, type, tags, created_at, expires_at
|
||||
FROM computer_memory
|
||||
WHERE user_id = $1 AND type = $2
|
||||
AND (expires_at IS NULL OR expires_at > NOW())
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $3
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, userID, memType, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return r.scanEntries(rows)
|
||||
}
|
||||
|
||||
// GetByKey returns the most recent unexpired entry with the given key for a
// user. Returns sql.ErrNoRows when no matching entry exists.
func (r *ComputerMemoryRepo) GetByKey(ctx context.Context, userID, key string) (*computer.MemoryEntry, error) {
	query := `
	SELECT id, user_id, task_id, key, value, type, tags, created_at, expires_at
	FROM computer_memory
	WHERE user_id = $1 AND key = $2
	AND (expires_at IS NULL OR expires_at > NOW())
	ORDER BY created_at DESC
	LIMIT 1
	`

	var entry computer.MemoryEntry
	var valueJSON []byte
	// task_id and expires_at are nullable columns.
	var taskID sql.NullString
	var expiresAt sql.NullTime
	var tags []string

	// Scan order must match the SELECT column list exactly.
	// NOTE(review): scanning a TEXT[] column into a plain []string requires
	// driver array support (e.g. pq.Array with lib/pq) — confirm the driver
	// in use accepts &tags directly.
	err := r.db.QueryRowContext(ctx, query, userID, key).Scan(
		&entry.ID,
		&entry.UserID,
		&taskID,
		&entry.Key,
		&valueJSON,
		&entry.Type,
		&tags,
		&entry.CreatedAt,
		&expiresAt,
	)

	if err != nil {
		return nil, err
	}

	if taskID.Valid {
		entry.TaskID = taskID.String
	}
	if expiresAt.Valid {
		entry.ExpiresAt = &expiresAt.Time
	}
	entry.Tags = tags

	// Best-effort decode: malformed JSON leaves Value zero-valued.
	json.Unmarshal(valueJSON, &entry.Value)

	return &entry, nil
}
|
||||
|
||||
func (r *ComputerMemoryRepo) Delete(ctx context.Context, id string) error {
|
||||
query := `DELETE FROM computer_memory WHERE id = $1`
|
||||
_, err := r.db.ExecContext(ctx, query, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) DeleteByUser(ctx context.Context, userID string) error {
|
||||
query := `DELETE FROM computer_memory WHERE user_id = $1`
|
||||
_, err := r.db.ExecContext(ctx, query, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) DeleteByTask(ctx context.Context, taskID string) error {
|
||||
query := `DELETE FROM computer_memory WHERE task_id = $1`
|
||||
_, err := r.db.ExecContext(ctx, query, taskID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) DeleteExpired(ctx context.Context) (int64, error) {
|
||||
query := `DELETE FROM computer_memory WHERE expires_at IS NOT NULL AND expires_at < NOW()`
|
||||
result, err := r.db.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) scanEntries(rows *sql.Rows) ([]computer.MemoryEntry, error) {
|
||||
var entries []computer.MemoryEntry
|
||||
|
||||
for rows.Next() {
|
||||
var entry computer.MemoryEntry
|
||||
var valueJSON []byte
|
||||
var taskID sql.NullString
|
||||
var expiresAt sql.NullTime
|
||||
var tags []string
|
||||
|
||||
err := rows.Scan(
|
||||
&entry.ID,
|
||||
&entry.UserID,
|
||||
&taskID,
|
||||
&entry.Key,
|
||||
&valueJSON,
|
||||
&entry.Type,
|
||||
&tags,
|
||||
&entry.CreatedAt,
|
||||
&expiresAt,
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if taskID.Valid {
|
||||
entry.TaskID = taskID.String
|
||||
}
|
||||
if expiresAt.Valid {
|
||||
entry.ExpiresAt = &expiresAt.Time
|
||||
}
|
||||
entry.Tags = tags
|
||||
|
||||
json.Unmarshal(valueJSON, &entry.Value)
|
||||
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) Count(ctx context.Context, userID string) (int64, error) {
|
||||
query := `
|
||||
SELECT COUNT(*)
|
||||
FROM computer_memory
|
||||
WHERE user_id = $1
|
||||
AND (expires_at IS NULL OR expires_at > NOW())
|
||||
`
|
||||
var count int64
|
||||
err := r.db.QueryRowContext(ctx, query, userID).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *ComputerMemoryRepo) UpdateExpiry(ctx context.Context, id string, expiresAt time.Time) error {
|
||||
query := `UPDATE computer_memory SET expires_at = $1 WHERE id = $2`
|
||||
_, err := r.db.ExecContext(ctx, query, expiresAt, id)
|
||||
return err
|
||||
}
|
||||
411
backend/internal/db/computer_task_repo.go
Normal file
411
backend/internal/db/computer_task_repo.go
Normal file
@@ -0,0 +1,411 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gooseek/backend/internal/computer"
|
||||
)
|
||||
|
||||
// ComputerTaskRepo persists computer-use tasks, including their plan,
// sub-tasks, artifacts, memory and schedule, all serialized as JSONB.
type ComputerTaskRepo struct {
	db *sql.DB // shared database handle; not owned or closed by the repo
}
|
||||
|
||||
func NewComputerTaskRepo(db *sql.DB) *ComputerTaskRepo {
|
||||
return &ComputerTaskRepo{db: db}
|
||||
}
|
||||
|
||||
func (r *ComputerTaskRepo) Migrate() error {
|
||||
query := `
|
||||
CREATE TABLE IF NOT EXISTS computer_tasks (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL,
|
||||
query TEXT NOT NULL,
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'pending',
|
||||
plan JSONB,
|
||||
sub_tasks JSONB,
|
||||
artifacts JSONB,
|
||||
memory JSONB,
|
||||
progress INT DEFAULT 0,
|
||||
message TEXT,
|
||||
error TEXT,
|
||||
schedule JSONB,
|
||||
next_run_at TIMESTAMPTZ,
|
||||
run_count INT DEFAULT 0,
|
||||
total_cost DECIMAL(10,6) DEFAULT 0,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
completed_at TIMESTAMPTZ
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_tasks_user_id ON computer_tasks(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_tasks_status ON computer_tasks(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_tasks_next_run ON computer_tasks(next_run_at) WHERE next_run_at IS NOT NULL;
|
||||
CREATE INDEX IF NOT EXISTS idx_computer_tasks_created ON computer_tasks(created_at DESC);
|
||||
`
|
||||
|
||||
_, err := r.db.Exec(query)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ComputerTaskRepo) Create(ctx context.Context, task *computer.ComputerTask) error {
|
||||
planJSON, _ := json.Marshal(task.Plan)
|
||||
subTasksJSON, _ := json.Marshal(task.SubTasks)
|
||||
artifactsJSON, _ := json.Marshal(task.Artifacts)
|
||||
memoryJSON, _ := json.Marshal(task.Memory)
|
||||
scheduleJSON, _ := json.Marshal(task.Schedule)
|
||||
|
||||
query := `
|
||||
INSERT INTO computer_tasks (
|
||||
id, user_id, query, status, plan, sub_tasks, artifacts, memory,
|
||||
progress, message, error, schedule, next_run_at, run_count, total_cost,
|
||||
created_at, updated_at, completed_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18
|
||||
)
|
||||
`
|
||||
|
||||
_, err := r.db.ExecContext(ctx, query,
|
||||
task.ID,
|
||||
task.UserID,
|
||||
task.Query,
|
||||
task.Status,
|
||||
planJSON,
|
||||
subTasksJSON,
|
||||
artifactsJSON,
|
||||
memoryJSON,
|
||||
task.Progress,
|
||||
task.Message,
|
||||
task.Error,
|
||||
scheduleJSON,
|
||||
task.NextRunAt,
|
||||
task.RunCount,
|
||||
task.TotalCost,
|
||||
task.CreatedAt,
|
||||
task.UpdatedAt,
|
||||
task.CompletedAt,
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Update rewrites all mutable columns of a task row (selected by task.ID) and
// stamps updated_at with the current time server-side of this process.
// NOTE(review): the json.Marshal errors are discarded; a serialization
// failure silently stores empty JSON — consider propagating them as in a
// future cleanup.
func (r *ComputerTaskRepo) Update(ctx context.Context, task *computer.ComputerTask) error {
	planJSON, _ := json.Marshal(task.Plan)
	subTasksJSON, _ := json.Marshal(task.SubTasks)
	artifactsJSON, _ := json.Marshal(task.Artifacts)
	memoryJSON, _ := json.Marshal(task.Memory)
	scheduleJSON, _ := json.Marshal(task.Schedule)

	query := `
	UPDATE computer_tasks SET
		status = $1,
		plan = $2,
		sub_tasks = $3,
		artifacts = $4,
		memory = $5,
		progress = $6,
		message = $7,
		error = $8,
		schedule = $9,
		next_run_at = $10,
		run_count = $11,
		total_cost = $12,
		updated_at = $13,
		completed_at = $14
	WHERE id = $15
	`

	// Positional arguments must stay in exact correspondence with $1..$15.
	_, err := r.db.ExecContext(ctx, query,
		task.Status,
		planJSON,
		subTasksJSON,
		artifactsJSON,
		memoryJSON,
		task.Progress,
		task.Message,
		task.Error,
		scheduleJSON,
		task.NextRunAt,
		task.RunCount,
		task.TotalCost,
		time.Now(),
		task.CompletedAt,
		task.ID,
	)

	return err
}
|
||||
|
||||
// GetByID loads one task with all of its JSONB payloads decoded. Returns
// sql.ErrNoRows when the id does not exist.
func (r *ComputerTaskRepo) GetByID(ctx context.Context, id string) (*computer.ComputerTask, error) {
	query := `
	SELECT id, user_id, query, status, plan, sub_tasks, artifacts, memory,
		progress, message, error, schedule, next_run_at, run_count, total_cost,
		created_at, updated_at, completed_at
	FROM computer_tasks
	WHERE id = $1
	`

	var task computer.ComputerTask
	var planJSON, subTasksJSON, artifactsJSON, memoryJSON, scheduleJSON []byte
	// message/error and next_run_at/completed_at are nullable columns.
	var message, errStr sql.NullString
	var nextRunAt, completedAt sql.NullTime

	// Scan order must match the 18-column SELECT list exactly.
	err := r.db.QueryRowContext(ctx, query, id).Scan(
		&task.ID,
		&task.UserID,
		&task.Query,
		&task.Status,
		&planJSON,
		&subTasksJSON,
		&artifactsJSON,
		&memoryJSON,
		&task.Progress,
		&message,
		&errStr,
		&scheduleJSON,
		&nextRunAt,
		&task.RunCount,
		&task.TotalCost,
		&task.CreatedAt,
		&task.UpdatedAt,
		&completedAt,
	)

	if err != nil {
		return nil, err
	}

	// Best-effort decodes: malformed JSON leaves the field zero-valued.
	if len(planJSON) > 0 {
		json.Unmarshal(planJSON, &task.Plan)
	}
	if len(subTasksJSON) > 0 {
		json.Unmarshal(subTasksJSON, &task.SubTasks)
	}
	if len(artifactsJSON) > 0 {
		json.Unmarshal(artifactsJSON, &task.Artifacts)
	}
	if len(memoryJSON) > 0 {
		json.Unmarshal(memoryJSON, &task.Memory)
	}
	if len(scheduleJSON) > 0 {
		json.Unmarshal(scheduleJSON, &task.Schedule)
	}

	if message.Valid {
		task.Message = message.String
	}
	if errStr.Valid {
		task.Error = errStr.String
	}
	if nextRunAt.Valid {
		task.NextRunAt = &nextRunAt.Time
	}
	if completedAt.Valid {
		task.CompletedAt = &completedAt.Time
	}

	return &task, nil
}
|
||||
|
||||
// GetByUserID pages through a user's tasks, newest first, decoding all JSONB
// payloads for each row.
// NOTE(review): rows that fail to scan are silently skipped (continue) and
// rows.Err() is not checked — consider surfacing these in a future cleanup.
func (r *ComputerTaskRepo) GetByUserID(ctx context.Context, userID string, limit, offset int) ([]computer.ComputerTask, error) {
	query := `
	SELECT id, user_id, query, status, plan, sub_tasks, artifacts, memory,
		progress, message, error, schedule, next_run_at, run_count, total_cost,
		created_at, updated_at, completed_at
	FROM computer_tasks
	WHERE user_id = $1
	ORDER BY created_at DESC
	LIMIT $2 OFFSET $3
	`

	rows, err := r.db.QueryContext(ctx, query, userID, limit, offset)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var tasks []computer.ComputerTask

	for rows.Next() {
		var task computer.ComputerTask
		var planJSON, subTasksJSON, artifactsJSON, memoryJSON, scheduleJSON []byte
		// Nullable columns.
		var message, errStr sql.NullString
		var nextRunAt, completedAt sql.NullTime

		// Scan order must match the 18-column SELECT list exactly.
		err := rows.Scan(
			&task.ID,
			&task.UserID,
			&task.Query,
			&task.Status,
			&planJSON,
			&subTasksJSON,
			&artifactsJSON,
			&memoryJSON,
			&task.Progress,
			&message,
			&errStr,
			&scheduleJSON,
			&nextRunAt,
			&task.RunCount,
			&task.TotalCost,
			&task.CreatedAt,
			&task.UpdatedAt,
			&completedAt,
		)
		if err != nil {
			continue
		}

		// Best-effort decodes: malformed JSON leaves the field zero-valued.
		if len(planJSON) > 0 {
			json.Unmarshal(planJSON, &task.Plan)
		}
		if len(subTasksJSON) > 0 {
			json.Unmarshal(subTasksJSON, &task.SubTasks)
		}
		if len(artifactsJSON) > 0 {
			json.Unmarshal(artifactsJSON, &task.Artifacts)
		}
		if len(memoryJSON) > 0 {
			json.Unmarshal(memoryJSON, &task.Memory)
		}
		if len(scheduleJSON) > 0 {
			json.Unmarshal(scheduleJSON, &task.Schedule)
		}

		if message.Valid {
			task.Message = message.String
		}
		if errStr.Valid {
			task.Error = errStr.String
		}
		if nextRunAt.Valid {
			task.NextRunAt = &nextRunAt.Time
		}
		if completedAt.Valid {
			task.CompletedAt = &completedAt.Time
		}

		tasks = append(tasks, task)
	}

	return tasks, nil
}
|
||||
|
||||
// GetScheduled returns every task in status 'scheduled' that has a schedule,
// ordered by the soonest next_run_at first — the worklist for the scheduler.
// NOTE(review): rows that fail to scan are silently skipped (continue) and
// rows.Err() is not checked — consider surfacing these in a future cleanup.
func (r *ComputerTaskRepo) GetScheduled(ctx context.Context) ([]computer.ComputerTask, error) {
	query := `
	SELECT id, user_id, query, status, plan, sub_tasks, artifacts, memory,
		progress, message, error, schedule, next_run_at, run_count, total_cost,
		created_at, updated_at, completed_at
	FROM computer_tasks
	WHERE status = 'scheduled' AND schedule IS NOT NULL
	ORDER BY next_run_at ASC
	`

	rows, err := r.db.QueryContext(ctx, query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var tasks []computer.ComputerTask

	for rows.Next() {
		var task computer.ComputerTask
		var planJSON, subTasksJSON, artifactsJSON, memoryJSON, scheduleJSON []byte
		// Nullable columns.
		var message, errStr sql.NullString
		var nextRunAt, completedAt sql.NullTime

		// Scan order must match the 18-column SELECT list exactly.
		err := rows.Scan(
			&task.ID,
			&task.UserID,
			&task.Query,
			&task.Status,
			&planJSON,
			&subTasksJSON,
			&artifactsJSON,
			&memoryJSON,
			&task.Progress,
			&message,
			&errStr,
			&scheduleJSON,
			&nextRunAt,
			&task.RunCount,
			&task.TotalCost,
			&task.CreatedAt,
			&task.UpdatedAt,
			&completedAt,
		)
		if err != nil {
			continue
		}

		// Best-effort decodes: malformed JSON leaves the field zero-valued.
		if len(planJSON) > 0 {
			json.Unmarshal(planJSON, &task.Plan)
		}
		if len(subTasksJSON) > 0 {
			json.Unmarshal(subTasksJSON, &task.SubTasks)
		}
		if len(artifactsJSON) > 0 {
			json.Unmarshal(artifactsJSON, &task.Artifacts)
		}
		if len(memoryJSON) > 0 {
			json.Unmarshal(memoryJSON, &task.Memory)
		}
		if len(scheduleJSON) > 0 {
			json.Unmarshal(scheduleJSON, &task.Schedule)
		}

		if message.Valid {
			task.Message = message.String
		}
		if errStr.Valid {
			task.Error = errStr.String
		}
		if nextRunAt.Valid {
			task.NextRunAt = &nextRunAt.Time
		}
		if completedAt.Valid {
			task.CompletedAt = &completedAt.Time
		}

		tasks = append(tasks, task)
	}

	return tasks, nil
}
|
||||
|
||||
func (r *ComputerTaskRepo) Delete(ctx context.Context, id string) error {
|
||||
query := `DELETE FROM computer_tasks WHERE id = $1`
|
||||
_, err := r.db.ExecContext(ctx, query, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ComputerTaskRepo) DeleteOlderThan(ctx context.Context, days int) (int64, error) {
|
||||
query := `
|
||||
DELETE FROM computer_tasks
|
||||
WHERE created_at < NOW() - INTERVAL '%d days'
|
||||
AND status IN ('completed', 'failed', 'cancelled')
|
||||
`
|
||||
result, err := r.db.ExecContext(ctx, fmt.Sprintf(query, days))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
func (r *ComputerTaskRepo) CountByUser(ctx context.Context, userID string) (int64, error) {
|
||||
query := `SELECT COUNT(*) FROM computer_tasks WHERE user_id = $1`
|
||||
var count int64
|
||||
err := r.db.QueryRowContext(ctx, query, userID).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *ComputerTaskRepo) CountByStatus(ctx context.Context, status string) (int64, error) {
|
||||
query := `SELECT COUNT(*) FROM computer_tasks WHERE status = $1`
|
||||
var count int64
|
||||
err := r.db.QueryRowContext(ctx, query, status).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
177
backend/internal/db/digest_repo.go
Normal file
177
backend/internal/db/digest_repo.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DigestCitation is one source reference attached to a Digest, stored as an
// element of the digests.citations JSONB array.
type DigestCitation struct {
	Index  int    `json:"index"`  // 1-based position used by inline citation markers — confirm with renderer
	URL    string `json:"url"`    // full source URL
	Title  string `json:"title"`  // source page title
	Domain string `json:"domain"` // source host, for display
}
|
||||
|
||||
// Digest is one clustered news summary row from the digests table,
// uniquely keyed by (topic, region, cluster_title).
type Digest struct {
	ID           int64            `json:"id"`
	Topic        string           `json:"topic"`
	Region       string           `json:"region"`
	ClusterTitle string           `json:"clusterTitle"`
	// SummaryRu holds the summary text; the "Ru" suffix suggests Russian —
	// confirm against the digest producer.
	SummaryRu        string           `json:"summaryRu"`
	Citations        []DigestCitation `json:"citations"`    // decoded from the citations JSONB column
	SourcesCount     int              `json:"sourcesCount"` // number of sources behind the cluster
	FollowUp         []string         `json:"followUp"`     // suggested follow-up questions
	Thumbnail        string           `json:"thumbnail"`
	ShortDescription string           `json:"shortDescription"`
	MainURL          string           `json:"mainUrl"` // primary article URL for the cluster
	CreatedAt        time.Time        `json:"createdAt"`
	UpdatedAt        time.Time        `json:"updatedAt"`
}
|
||||
|
||||
// DigestRepository provides CRUD access to the digests table.
type DigestRepository struct {
	db *PostgresDB // shared connection pool wrapper
}
|
||||
|
||||
func NewDigestRepository(db *PostgresDB) *DigestRepository {
|
||||
return &DigestRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *DigestRepository) GetByTopicRegionTitle(ctx context.Context, topic, region, title string) (*Digest, error) {
|
||||
query := `
|
||||
SELECT id, topic, region, cluster_title, summary_ru, citations, sources_count,
|
||||
follow_up, thumbnail, short_description, main_url, created_at, updated_at
|
||||
FROM digests
|
||||
WHERE topic = $1 AND region = $2 AND cluster_title = $3
|
||||
`
|
||||
|
||||
var d Digest
|
||||
var citationsJSON, followUpJSON []byte
|
||||
|
||||
err := r.db.db.QueryRowContext(ctx, query, topic, region, title).Scan(
|
||||
&d.ID, &d.Topic, &d.Region, &d.ClusterTitle, &d.SummaryRu,
|
||||
&citationsJSON, &d.SourcesCount, &followUpJSON,
|
||||
&d.Thumbnail, &d.ShortDescription, &d.MainURL,
|
||||
&d.CreatedAt, &d.UpdatedAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
json.Unmarshal(citationsJSON, &d.Citations)
|
||||
json.Unmarshal(followUpJSON, &d.FollowUp)
|
||||
|
||||
return &d, nil
|
||||
}
|
||||
|
||||
func (r *DigestRepository) GetByURL(ctx context.Context, url string) (*Digest, error) {
|
||||
query := `
|
||||
SELECT id, topic, region, cluster_title, summary_ru, citations, sources_count,
|
||||
follow_up, thumbnail, short_description, main_url, created_at, updated_at
|
||||
FROM digests
|
||||
WHERE main_url = $1
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
var d Digest
|
||||
var citationsJSON, followUpJSON []byte
|
||||
|
||||
err := r.db.db.QueryRowContext(ctx, query, url).Scan(
|
||||
&d.ID, &d.Topic, &d.Region, &d.ClusterTitle, &d.SummaryRu,
|
||||
&citationsJSON, &d.SourcesCount, &followUpJSON,
|
||||
&d.Thumbnail, &d.ShortDescription, &d.MainURL,
|
||||
&d.CreatedAt, &d.UpdatedAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
json.Unmarshal(citationsJSON, &d.Citations)
|
||||
json.Unmarshal(followUpJSON, &d.FollowUp)
|
||||
|
||||
return &d, nil
|
||||
}
|
||||
|
||||
func (r *DigestRepository) GetByTopicRegion(ctx context.Context, topic, region string, limit int) ([]*Digest, error) {
|
||||
query := `
|
||||
SELECT id, topic, region, cluster_title, summary_ru, citations, sources_count,
|
||||
follow_up, thumbnail, short_description, main_url, created_at, updated_at
|
||||
FROM digests
|
||||
WHERE topic = $1 AND region = $2
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $3
|
||||
`
|
||||
|
||||
rows, err := r.db.db.QueryContext(ctx, query, topic, region, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var digests []*Digest
|
||||
for rows.Next() {
|
||||
var d Digest
|
||||
var citationsJSON, followUpJSON []byte
|
||||
|
||||
if err := rows.Scan(
|
||||
&d.ID, &d.Topic, &d.Region, &d.ClusterTitle, &d.SummaryRu,
|
||||
&citationsJSON, &d.SourcesCount, &followUpJSON,
|
||||
&d.Thumbnail, &d.ShortDescription, &d.MainURL,
|
||||
&d.CreatedAt, &d.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
json.Unmarshal(citationsJSON, &d.Citations)
|
||||
json.Unmarshal(followUpJSON, &d.FollowUp)
|
||||
digests = append(digests, &d)
|
||||
}
|
||||
|
||||
return digests, nil
|
||||
}
|
||||
|
||||
func (r *DigestRepository) Upsert(ctx context.Context, d *Digest) error {
|
||||
citationsJSON, _ := json.Marshal(d.Citations)
|
||||
followUpJSON, _ := json.Marshal(d.FollowUp)
|
||||
|
||||
query := `
|
||||
INSERT INTO digests (topic, region, cluster_title, summary_ru, citations, sources_count,
|
||||
follow_up, thumbnail, short_description, main_url)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
ON CONFLICT (topic, region, cluster_title)
|
||||
DO UPDATE SET
|
||||
summary_ru = EXCLUDED.summary_ru,
|
||||
citations = EXCLUDED.citations,
|
||||
sources_count = EXCLUDED.sources_count,
|
||||
follow_up = EXCLUDED.follow_up,
|
||||
thumbnail = EXCLUDED.thumbnail,
|
||||
short_description = EXCLUDED.short_description,
|
||||
main_url = EXCLUDED.main_url,
|
||||
updated_at = NOW()
|
||||
`
|
||||
|
||||
_, err := r.db.db.ExecContext(ctx, query,
|
||||
d.Topic, d.Region, d.ClusterTitle, d.SummaryRu,
|
||||
citationsJSON, d.SourcesCount, followUpJSON,
|
||||
d.Thumbnail, d.ShortDescription, d.MainURL,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *DigestRepository) DeleteByTopicRegion(ctx context.Context, topic, region string) (int64, error) {
|
||||
result, err := r.db.db.ExecContext(ctx,
|
||||
"DELETE FROM digests WHERE topic = $1 AND region = $2",
|
||||
topic, region,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
149
backend/internal/db/file_repo.go
Normal file
149
backend/internal/db/file_repo.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package db
|
||||
|
||||
import (
	"context"
	"database/sql"
	"encoding/json"
	"strconv"
	"strings"
	"time"
)
|
||||
|
||||
// UploadedFile is one row of the uploaded_files table: a user-uploaded file
// plus any text extracted from it for retrieval.
type UploadedFile struct {
	ID          string `json:"id"`
	UserID      string `json:"userId"`
	Filename    string `json:"filename"`
	FileType    string `json:"fileType"` // presumably a MIME type or extension — confirm with upload handler
	FileSize    int64  `json:"fileSize"` // size in bytes
	StoragePath string `json:"storagePath"`
	// ExtractedText is filled in asynchronously via UpdateExtractedText.
	ExtractedText string                 `json:"extractedText,omitempty"`
	Metadata      map[string]interface{} `json:"metadata"` // decoded from the metadata JSONB column
	CreatedAt     time.Time              `json:"createdAt"`
}
|
||||
|
||||
// FileRepository provides CRUD access to the uploaded_files table.
type FileRepository struct {
	db *PostgresDB // shared connection pool wrapper
}
|
||||
|
||||
func NewFileRepository(db *PostgresDB) *FileRepository {
|
||||
return &FileRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *FileRepository) Create(ctx context.Context, f *UploadedFile) error {
|
||||
metadataJSON, _ := json.Marshal(f.Metadata)
|
||||
|
||||
query := `
|
||||
INSERT INTO uploaded_files (user_id, filename, file_type, file_size, storage_path, extracted_text, metadata)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
RETURNING id, created_at
|
||||
`
|
||||
return r.db.db.QueryRowContext(ctx, query,
|
||||
f.UserID, f.Filename, f.FileType, f.FileSize, f.StoragePath, f.ExtractedText, metadataJSON,
|
||||
).Scan(&f.ID, &f.CreatedAt)
|
||||
}
|
||||
|
||||
func (r *FileRepository) GetByID(ctx context.Context, id string) (*UploadedFile, error) {
|
||||
query := `
|
||||
SELECT id, user_id, filename, file_type, file_size, storage_path, extracted_text, metadata, created_at
|
||||
FROM uploaded_files
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
var f UploadedFile
|
||||
var metadataJSON []byte
|
||||
|
||||
err := r.db.db.QueryRowContext(ctx, query, id).Scan(
|
||||
&f.ID, &f.UserID, &f.Filename, &f.FileType, &f.FileSize,
|
||||
&f.StoragePath, &f.ExtractedText, &metadataJSON, &f.CreatedAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
json.Unmarshal(metadataJSON, &f.Metadata)
|
||||
return &f, nil
|
||||
}
|
||||
|
||||
func (r *FileRepository) GetByUserID(ctx context.Context, userID string, limit, offset int) ([]*UploadedFile, error) {
|
||||
query := `
|
||||
SELECT id, user_id, filename, file_type, file_size, storage_path, extracted_text, metadata, created_at
|
||||
FROM uploaded_files
|
||||
WHERE user_id = $1
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $2 OFFSET $3
|
||||
`
|
||||
|
||||
rows, err := r.db.db.QueryContext(ctx, query, userID, limit, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var files []*UploadedFile
|
||||
for rows.Next() {
|
||||
var f UploadedFile
|
||||
var metadataJSON []byte
|
||||
|
||||
if err := rows.Scan(
|
||||
&f.ID, &f.UserID, &f.Filename, &f.FileType, &f.FileSize,
|
||||
&f.StoragePath, &f.ExtractedText, &metadataJSON, &f.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
json.Unmarshal(metadataJSON, &f.Metadata)
|
||||
files = append(files, &f)
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (r *FileRepository) UpdateExtractedText(ctx context.Context, id, text string) error {
|
||||
_, err := r.db.db.ExecContext(ctx,
|
||||
"UPDATE uploaded_files SET extracted_text = $2 WHERE id = $1",
|
||||
id, text,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *FileRepository) Delete(ctx context.Context, id string) error {
|
||||
_, err := r.db.db.ExecContext(ctx, "DELETE FROM uploaded_files WHERE id = $1", id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *FileRepository) GetByIDs(ctx context.Context, ids []string) ([]*UploadedFile, error) {
|
||||
if len(ids) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT id, user_id, filename, file_type, file_size, storage_path, extracted_text, metadata, created_at
|
||||
FROM uploaded_files
|
||||
WHERE id = ANY($1)
|
||||
`
|
||||
|
||||
rows, err := r.db.db.QueryContext(ctx, query, ids)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var files []*UploadedFile
|
||||
for rows.Next() {
|
||||
var f UploadedFile
|
||||
var metadataJSON []byte
|
||||
|
||||
if err := rows.Scan(
|
||||
&f.ID, &f.UserID, &f.Filename, &f.FileType, &f.FileSize,
|
||||
&f.StoragePath, &f.ExtractedText, &metadataJSON, &f.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
json.Unmarshal(metadataJSON, &f.Metadata)
|
||||
files = append(files, &f)
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
||||
170
backend/internal/db/memory_repo.go
Normal file
170
backend/internal/db/memory_repo.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package db
|
||||
|
||||
import (
	"context"
	"encoding/json"
	"strconv"
	"time"
)
|
||||
|
||||
// UserMemory is one row of the user_memories table: a durable fact,
// preference, instruction, or interest remembered about a user, unique per
// (user_id, memory_type, key).
type UserMemory struct {
	ID         string `json:"id"`
	UserID     string `json:"userId"`
	MemoryType string `json:"memoryType"` // e.g. "preference", "fact", "instruction", "interest" (see GetContextForUser)
	Key        string `json:"key"`
	Value      string `json:"value"`
	Metadata   map[string]interface{} `json:"metadata"`   // decoded from the metadata JSONB column
	Importance int                    `json:"importance"` // DB default 5; higher sorts first in retrieval
	LastUsed   time.Time              `json:"lastUsed"`   // bumped by IncrementUseCount
	UseCount   int                    `json:"useCount"`
	CreatedAt  time.Time              `json:"createdAt"`
	UpdatedAt  time.Time              `json:"updatedAt"`
}
|
||||
|
||||
// MemoryRepository provides CRUD access to the user_memories table.
type MemoryRepository struct {
	db *PostgresDB // shared connection pool wrapper
}
|
||||
|
||||
func NewMemoryRepository(db *PostgresDB) *MemoryRepository {
|
||||
return &MemoryRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) RunMigrations(ctx context.Context) error {
|
||||
migrations := []string{
|
||||
`CREATE TABLE IF NOT EXISTS user_memories (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL,
|
||||
memory_type VARCHAR(50) NOT NULL,
|
||||
key VARCHAR(255) NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
metadata JSONB DEFAULT '{}',
|
||||
importance INT DEFAULT 5,
|
||||
last_used TIMESTAMPTZ DEFAULT NOW(),
|
||||
use_count INT DEFAULT 0,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
UNIQUE(user_id, memory_type, key)
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_user_memories_user ON user_memories(user_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_user_memories_type ON user_memories(user_id, memory_type)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_user_memories_importance ON user_memories(user_id, importance DESC)`,
|
||||
}
|
||||
|
||||
for _, m := range migrations {
|
||||
if _, err := r.db.db.ExecContext(ctx, m); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) Save(ctx context.Context, mem *UserMemory) error {
|
||||
metadataJSON, _ := json.Marshal(mem.Metadata)
|
||||
|
||||
query := `
|
||||
INSERT INTO user_memories (user_id, memory_type, key, value, metadata, importance)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (user_id, memory_type, key)
|
||||
DO UPDATE SET
|
||||
value = EXCLUDED.value,
|
||||
metadata = EXCLUDED.metadata,
|
||||
importance = EXCLUDED.importance,
|
||||
updated_at = NOW()
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
return r.db.db.QueryRowContext(ctx, query,
|
||||
mem.UserID, mem.MemoryType, mem.Key, mem.Value, metadataJSON, mem.Importance,
|
||||
).Scan(&mem.ID, &mem.CreatedAt, &mem.UpdatedAt)
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) GetByUserID(ctx context.Context, userID string, memoryType string, limit int) ([]*UserMemory, error) {
|
||||
query := `
|
||||
SELECT id, user_id, memory_type, key, value, metadata, importance, last_used, use_count, created_at, updated_at
|
||||
FROM user_memories
|
||||
WHERE user_id = $1
|
||||
`
|
||||
args := []interface{}{userID}
|
||||
|
||||
if memoryType != "" {
|
||||
query += " AND memory_type = $2"
|
||||
args = append(args, memoryType)
|
||||
}
|
||||
|
||||
query += " ORDER BY importance DESC, last_used DESC"
|
||||
|
||||
if limit > 0 {
|
||||
query += " LIMIT $" + string(rune('0'+len(args)+1))
|
||||
args = append(args, limit)
|
||||
}
|
||||
|
||||
rows, err := r.db.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var memories []*UserMemory
|
||||
for rows.Next() {
|
||||
var mem UserMemory
|
||||
var metadataJSON []byte
|
||||
|
||||
if err := rows.Scan(
|
||||
&mem.ID, &mem.UserID, &mem.MemoryType, &mem.Key, &mem.Value,
|
||||
&metadataJSON, &mem.Importance, &mem.LastUsed, &mem.UseCount,
|
||||
&mem.CreatedAt, &mem.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
json.Unmarshal(metadataJSON, &mem.Metadata)
|
||||
memories = append(memories, &mem)
|
||||
}
|
||||
|
||||
return memories, nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) GetContextForUser(ctx context.Context, userID string) (string, error) {
|
||||
memories, err := r.GetByUserID(ctx, userID, "", 20)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var context string
|
||||
for _, mem := range memories {
|
||||
switch mem.MemoryType {
|
||||
case "preference":
|
||||
context += "User preference: " + mem.Key + " = " + mem.Value + "\n"
|
||||
case "fact":
|
||||
context += "Known fact about user: " + mem.Value + "\n"
|
||||
case "instruction":
|
||||
context += "User instruction: " + mem.Value + "\n"
|
||||
case "interest":
|
||||
context += "User interest: " + mem.Value + "\n"
|
||||
default:
|
||||
context += mem.Key + ": " + mem.Value + "\n"
|
||||
}
|
||||
}
|
||||
|
||||
return context, nil
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) IncrementUseCount(ctx context.Context, id string) error {
|
||||
_, err := r.db.db.ExecContext(ctx,
|
||||
"UPDATE user_memories SET use_count = use_count + 1, last_used = NOW() WHERE id = $1",
|
||||
id,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) Delete(ctx context.Context, id string) error {
|
||||
_, err := r.db.db.ExecContext(ctx, "DELETE FROM user_memories WHERE id = $1", id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MemoryRepository) DeleteByUserID(ctx context.Context, userID string) error {
|
||||
_, err := r.db.db.ExecContext(ctx, "DELETE FROM user_memories WHERE user_id = $1", userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// ExtractMemoriesFromConversation is an unimplemented stub: it always
// returns (nil, nil) and ignores all of its arguments. The llmClient
// parameter is typed interface{} — presumably a future implementation will
// prompt an LLM to mine memories from the conversation and answer; confirm
// before relying on this function.
func ExtractMemoriesFromConversation(ctx context.Context, llmClient interface{}, conversation, answer string) ([]UserMemory, error) {
	return nil, nil
}
|
||||
219
backend/internal/db/page_repo.go
Normal file
219
backend/internal/db/page_repo.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/gooseek/backend/internal/pages"
|
||||
)
|
||||
|
||||
// PageRepository provides CRUD access to the pages table
// (pages.Page values generated from research threads).
type PageRepository struct {
	db *PostgresDB // shared connection pool wrapper
}
||||
|
||||
func NewPageRepository(db *PostgresDB) *PageRepository {
|
||||
return &PageRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *PageRepository) RunMigrations(ctx context.Context) error {
|
||||
migrations := []string{
|
||||
`CREATE TABLE IF NOT EXISTS pages (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL,
|
||||
thread_id UUID REFERENCES threads(id) ON DELETE SET NULL,
|
||||
title VARCHAR(500) NOT NULL,
|
||||
subtitle TEXT,
|
||||
sections JSONB NOT NULL DEFAULT '[]',
|
||||
sources JSONB NOT NULL DEFAULT '[]',
|
||||
thumbnail TEXT,
|
||||
is_public BOOLEAN DEFAULT FALSE,
|
||||
share_id VARCHAR(100) UNIQUE,
|
||||
view_count INT DEFAULT 0,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ DEFAULT NOW()
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_pages_user ON pages(user_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_pages_share ON pages(share_id)`,
|
||||
}
|
||||
|
||||
for _, m := range migrations {
|
||||
if _, err := r.db.db.ExecContext(ctx, m); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *PageRepository) Create(ctx context.Context, p *pages.Page) error {
|
||||
sectionsJSON, _ := json.Marshal(p.Sections)
|
||||
sourcesJSON, _ := json.Marshal(p.Sources)
|
||||
|
||||
query := `
|
||||
INSERT INTO pages (user_id, thread_id, title, subtitle, sections, sources, thumbnail, is_public)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
|
||||
var threadID *string
|
||||
if p.ThreadID != "" {
|
||||
threadID = &p.ThreadID
|
||||
}
|
||||
|
||||
return r.db.db.QueryRowContext(ctx, query,
|
||||
p.UserID, threadID, p.Title, p.Subtitle, sectionsJSON, sourcesJSON, p.Thumbnail, p.IsPublic,
|
||||
).Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt)
|
||||
}
|
||||
|
||||
func (r *PageRepository) GetByID(ctx context.Context, id string) (*pages.Page, error) {
|
||||
query := `
|
||||
SELECT id, user_id, thread_id, title, subtitle, sections, sources, thumbnail, is_public, share_id, view_count, created_at, updated_at
|
||||
FROM pages
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
var p pages.Page
|
||||
var sectionsJSON, sourcesJSON []byte
|
||||
var threadID, shareID sql.NullString
|
||||
|
||||
err := r.db.db.QueryRowContext(ctx, query, id).Scan(
|
||||
&p.ID, &p.UserID, &threadID, &p.Title, &p.Subtitle,
|
||||
§ionsJSON, &sourcesJSON, &p.Thumbnail,
|
||||
&p.IsPublic, &shareID, &p.ViewCount, &p.CreatedAt, &p.UpdatedAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
json.Unmarshal(sectionsJSON, &p.Sections)
|
||||
json.Unmarshal(sourcesJSON, &p.Sources)
|
||||
|
||||
if threadID.Valid {
|
||||
p.ThreadID = threadID.String
|
||||
}
|
||||
if shareID.Valid {
|
||||
p.ShareID = shareID.String
|
||||
}
|
||||
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
func (r *PageRepository) GetByShareID(ctx context.Context, shareID string) (*pages.Page, error) {
|
||||
query := `
|
||||
SELECT id, user_id, thread_id, title, subtitle, sections, sources, thumbnail, is_public, share_id, view_count, created_at, updated_at
|
||||
FROM pages
|
||||
WHERE share_id = $1 AND is_public = true
|
||||
`
|
||||
|
||||
var p pages.Page
|
||||
var sectionsJSON, sourcesJSON []byte
|
||||
var threadID, shareIDVal sql.NullString
|
||||
|
||||
err := r.db.db.QueryRowContext(ctx, query, shareID).Scan(
|
||||
&p.ID, &p.UserID, &threadID, &p.Title, &p.Subtitle,
|
||||
§ionsJSON, &sourcesJSON, &p.Thumbnail,
|
||||
&p.IsPublic, &shareIDVal, &p.ViewCount, &p.CreatedAt, &p.UpdatedAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
json.Unmarshal(sectionsJSON, &p.Sections)
|
||||
json.Unmarshal(sourcesJSON, &p.Sources)
|
||||
|
||||
if threadID.Valid {
|
||||
p.ThreadID = threadID.String
|
||||
}
|
||||
if shareIDVal.Valid {
|
||||
p.ShareID = shareIDVal.String
|
||||
}
|
||||
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
func (r *PageRepository) GetByUserID(ctx context.Context, userID string, limit, offset int) ([]*pages.Page, error) {
|
||||
query := `
|
||||
SELECT id, user_id, thread_id, title, subtitle, sections, sources, thumbnail, is_public, share_id, view_count, created_at, updated_at
|
||||
FROM pages
|
||||
WHERE user_id = $1
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT $2 OFFSET $3
|
||||
`
|
||||
|
||||
rows, err := r.db.db.QueryContext(ctx, query, userID, limit, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var pagesList []*pages.Page
|
||||
for rows.Next() {
|
||||
var p pages.Page
|
||||
var sectionsJSON, sourcesJSON []byte
|
||||
var threadID, shareID sql.NullString
|
||||
|
||||
if err := rows.Scan(
|
||||
&p.ID, &p.UserID, &threadID, &p.Title, &p.Subtitle,
|
||||
§ionsJSON, &sourcesJSON, &p.Thumbnail,
|
||||
&p.IsPublic, &shareID, &p.ViewCount, &p.CreatedAt, &p.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
json.Unmarshal(sectionsJSON, &p.Sections)
|
||||
json.Unmarshal(sourcesJSON, &p.Sources)
|
||||
|
||||
if threadID.Valid {
|
||||
p.ThreadID = threadID.String
|
||||
}
|
||||
if shareID.Valid {
|
||||
p.ShareID = shareID.String
|
||||
}
|
||||
|
||||
pagesList = append(pagesList, &p)
|
||||
}
|
||||
|
||||
return pagesList, nil
|
||||
}
|
||||
|
||||
func (r *PageRepository) Update(ctx context.Context, p *pages.Page) error {
|
||||
sectionsJSON, _ := json.Marshal(p.Sections)
|
||||
sourcesJSON, _ := json.Marshal(p.Sources)
|
||||
|
||||
query := `
|
||||
UPDATE pages
|
||||
SET title = $2, subtitle = $3, sections = $4, sources = $5, thumbnail = $6, is_public = $7, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
_, err := r.db.db.ExecContext(ctx, query,
|
||||
p.ID, p.Title, p.Subtitle, sectionsJSON, sourcesJSON, p.Thumbnail, p.IsPublic,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *PageRepository) SetShareID(ctx context.Context, pageID, shareID string) error {
|
||||
_, err := r.db.db.ExecContext(ctx,
|
||||
"UPDATE pages SET share_id = $2, is_public = true WHERE id = $1",
|
||||
pageID, shareID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *PageRepository) IncrementViewCount(ctx context.Context, id string) error {
|
||||
_, err := r.db.db.ExecContext(ctx,
|
||||
"UPDATE pages SET view_count = view_count + 1 WHERE id = $1",
|
||||
id,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *PageRepository) Delete(ctx context.Context, id string) error {
|
||||
_, err := r.db.db.ExecContext(ctx, "DELETE FROM pages WHERE id = $1", id)
|
||||
return err
|
||||
}
|
||||
134
backend/internal/db/postgres.go
Normal file
134
backend/internal/db/postgres.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// PostgresDB wraps the shared *sql.DB connection pool used by every
// repository in this package.
type PostgresDB struct {
	db *sql.DB // pool configured by NewPostgresDB; exposed via DB()
}
|
||||
|
||||
func NewPostgresDB(databaseURL string) (*PostgresDB, error) {
|
||||
db, err := sql.Open("postgres", databaseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
db.SetMaxOpenConns(25)
|
||||
db.SetMaxIdleConns(5)
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
return &PostgresDB{db: db}, nil
|
||||
}
|
||||
|
||||
func (p *PostgresDB) Close() error {
|
||||
return p.db.Close()
|
||||
}
|
||||
|
||||
func (p *PostgresDB) DB() *sql.DB {
|
||||
return p.db
|
||||
}
|
||||
|
||||
// RunMigrations creates the core application tables (digests,
// article_summaries, collections, collection_items, uploaded_files,
// research_sessions) and their indexes if they do not already exist.
// Every statement is idempotent (IF NOT EXISTS); they run in declaration
// order and the first failure aborts the rest. Repository-specific tables
// (user_memories, pages, spaces) are created by each repository's own
// RunMigrations, not here.
func (p *PostgresDB) RunMigrations(ctx context.Context) error {
	migrations := []string{
		// Clustered news digests, unique per (topic, region, cluster_title);
		// see DigestRepository.
		`CREATE TABLE IF NOT EXISTS digests (
			id SERIAL PRIMARY KEY,
			topic VARCHAR(100) NOT NULL,
			region VARCHAR(50) NOT NULL,
			cluster_title VARCHAR(500) NOT NULL,
			summary_ru TEXT NOT NULL,
			citations JSONB DEFAULT '[]',
			sources_count INT DEFAULT 0,
			follow_up JSONB DEFAULT '[]',
			thumbnail TEXT,
			short_description TEXT,
			main_url TEXT,
			created_at TIMESTAMPTZ DEFAULT NOW(),
			updated_at TIMESTAMPTZ DEFAULT NOW(),
			UNIQUE(topic, region, cluster_title)
		)`,
		`CREATE INDEX IF NOT EXISTS idx_digests_topic_region ON digests(topic, region)`,
		`CREATE INDEX IF NOT EXISTS idx_digests_main_url ON digests(main_url)`,
		// Per-URL article summary cache with a 7-day TTL via expires_at.
		`CREATE TABLE IF NOT EXISTS article_summaries (
			id SERIAL PRIMARY KEY,
			url_hash VARCHAR(64) NOT NULL UNIQUE,
			url TEXT NOT NULL,
			events JSONB NOT NULL DEFAULT '[]',
			created_at TIMESTAMPTZ DEFAULT NOW(),
			expires_at TIMESTAMPTZ DEFAULT NOW() + INTERVAL '7 days'
		)`,
		`CREATE INDEX IF NOT EXISTS idx_article_summaries_url_hash ON article_summaries(url_hash)`,
		`CREATE INDEX IF NOT EXISTS idx_article_summaries_expires ON article_summaries(expires_at)`,
		// User-curated collections and their items (items cascade on delete).
		`CREATE TABLE IF NOT EXISTS collections (
			id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
			user_id UUID NOT NULL,
			name VARCHAR(255) NOT NULL,
			description TEXT,
			is_public BOOLEAN DEFAULT FALSE,
			context_enabled BOOLEAN DEFAULT TRUE,
			created_at TIMESTAMPTZ DEFAULT NOW(),
			updated_at TIMESTAMPTZ DEFAULT NOW()
		)`,
		`CREATE INDEX IF NOT EXISTS idx_collections_user ON collections(user_id)`,
		`CREATE TABLE IF NOT EXISTS collection_items (
			id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
			collection_id UUID NOT NULL REFERENCES collections(id) ON DELETE CASCADE,
			item_type VARCHAR(50) NOT NULL,
			title VARCHAR(500),
			content TEXT,
			url TEXT,
			metadata JSONB DEFAULT '{}',
			created_at TIMESTAMPTZ DEFAULT NOW(),
			sort_order INT DEFAULT 0
		)`,
		`CREATE INDEX IF NOT EXISTS idx_collection_items_collection ON collection_items(collection_id)`,
		// User file uploads; see FileRepository.
		`CREATE TABLE IF NOT EXISTS uploaded_files (
			id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
			user_id UUID NOT NULL,
			filename VARCHAR(500) NOT NULL,
			file_type VARCHAR(100) NOT NULL,
			file_size BIGINT NOT NULL,
			storage_path TEXT NOT NULL,
			extracted_text TEXT,
			metadata JSONB DEFAULT '{}',
			created_at TIMESTAMPTZ DEFAULT NOW()
		)`,
		`CREATE INDEX IF NOT EXISTS idx_uploaded_files_user ON uploaded_files(user_id)`,
		// Research sessions, optionally attached to a collection
		// (FK nulled out when the collection is deleted).
		`CREATE TABLE IF NOT EXISTS research_sessions (
			id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
			user_id UUID,
			collection_id UUID REFERENCES collections(id) ON DELETE SET NULL,
			query TEXT NOT NULL,
			focus_mode VARCHAR(50) DEFAULT 'all',
			optimization_mode VARCHAR(50) DEFAULT 'balanced',
			sources JSONB DEFAULT '[]',
			response_blocks JSONB DEFAULT '[]',
			final_answer TEXT,
			citations JSONB DEFAULT '[]',
			created_at TIMESTAMPTZ DEFAULT NOW(),
			completed_at TIMESTAMPTZ
		)`,
		`CREATE INDEX IF NOT EXISTS idx_research_sessions_user ON research_sessions(user_id)`,
		`CREATE INDEX IF NOT EXISTS idx_research_sessions_collection ON research_sessions(collection_id)`,
	}

	for _, migration := range migrations {
		if _, err := p.db.ExecContext(ctx, migration); err != nil {
			return fmt.Errorf("migration failed: %w", err)
		}
	}

	return nil
}
|
||||
163
backend/internal/db/space_repo.go
Normal file
163
backend/internal/db/space_repo.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Space is one row of the spaces table: a user-defined workspace carrying
// default settings (focus mode, model, custom instructions) applied to the
// threads created inside it.
type Space struct {
	ID                 string `json:"id"`
	UserID             string `json:"userId"`
	Name               string `json:"name"`
	Description        string `json:"description"`
	Icon               string `json:"icon"`
	Color              string `json:"color"`
	CustomInstructions string `json:"customInstructions"` // free-text prompt addendum for the space
	DefaultFocusMode   string `json:"defaultFocusMode"`   // DB default 'all'
	DefaultModel       string `json:"defaultModel"`
	IsPublic           bool   `json:"isPublic"`
	Settings           map[string]interface{} `json:"settings"` // decoded from the settings JSONB column
	CreatedAt          time.Time              `json:"createdAt"`
	UpdatedAt          time.Time              `json:"updatedAt"`
	// ThreadCount is computed on read via a COUNT(*) subquery over threads;
	// it is not a stored column.
	ThreadCount int `json:"threadCount,omitempty"`
}
|
||||
|
||||
// SpaceRepository provides CRUD access to the spaces table.
type SpaceRepository struct {
	db *PostgresDB // shared connection pool wrapper
}
|
||||
|
||||
func NewSpaceRepository(db *PostgresDB) *SpaceRepository {
|
||||
return &SpaceRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *SpaceRepository) RunMigrations(ctx context.Context) error {
|
||||
migrations := []string{
|
||||
`CREATE TABLE IF NOT EXISTS spaces (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
description TEXT,
|
||||
icon VARCHAR(50),
|
||||
color VARCHAR(20),
|
||||
custom_instructions TEXT,
|
||||
default_focus_mode VARCHAR(50) DEFAULT 'all',
|
||||
default_model VARCHAR(100),
|
||||
is_public BOOLEAN DEFAULT FALSE,
|
||||
settings JSONB DEFAULT '{}',
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ DEFAULT NOW()
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_spaces_user ON spaces(user_id)`,
|
||||
}
|
||||
|
||||
for _, m := range migrations {
|
||||
if _, err := r.db.db.ExecContext(ctx, m); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SpaceRepository) Create(ctx context.Context, s *Space) error {
|
||||
settingsJSON, _ := json.Marshal(s.Settings)
|
||||
|
||||
query := `
|
||||
INSERT INTO spaces (user_id, name, description, icon, color, custom_instructions, default_focus_mode, default_model, is_public, settings)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
return r.db.db.QueryRowContext(ctx, query,
|
||||
s.UserID, s.Name, s.Description, s.Icon, s.Color,
|
||||
s.CustomInstructions, s.DefaultFocusMode, s.DefaultModel,
|
||||
s.IsPublic, settingsJSON,
|
||||
).Scan(&s.ID, &s.CreatedAt, &s.UpdatedAt)
|
||||
}
|
||||
|
||||
// GetByID loads one space by primary key, including a computed thread count.
// It returns (nil, nil) when no row matches — callers must nil-check rather
// than rely on an error.
func (r *SpaceRepository) GetByID(ctx context.Context, id string) (*Space, error) {
	// Column order here must stay in lockstep with the Scan call below.
	query := `
		SELECT id, user_id, name, description, icon, color, custom_instructions,
			default_focus_mode, default_model, is_public, settings, created_at, updated_at,
			(SELECT COUNT(*) FROM threads WHERE space_id = spaces.id) as thread_count
		FROM spaces
		WHERE id = $1
	`

	var s Space
	var settingsJSON []byte

	err := r.db.db.QueryRowContext(ctx, query, id).Scan(
		&s.ID, &s.UserID, &s.Name, &s.Description, &s.Icon, &s.Color,
		&s.CustomInstructions, &s.DefaultFocusMode, &s.DefaultModel,
		&s.IsPublic, &settingsJSON, &s.CreatedAt, &s.UpdatedAt, &s.ThreadCount,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}

	// Best-effort decode: on malformed JSON, Settings is simply left nil.
	json.Unmarshal(settingsJSON, &s.Settings)
	return &s, nil
}
|
||||
|
||||
func (r *SpaceRepository) GetByUserID(ctx context.Context, userID string) ([]*Space, error) {
|
||||
query := `
|
||||
SELECT id, user_id, name, description, icon, color, custom_instructions,
|
||||
default_focus_mode, default_model, is_public, settings, created_at, updated_at,
|
||||
(SELECT COUNT(*) FROM threads WHERE space_id = spaces.id) as thread_count
|
||||
FROM spaces
|
||||
WHERE user_id = $1
|
||||
ORDER BY updated_at DESC
|
||||
`
|
||||
|
||||
rows, err := r.db.db.QueryContext(ctx, query, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var spaces []*Space
|
||||
for rows.Next() {
|
||||
var s Space
|
||||
var settingsJSON []byte
|
||||
|
||||
if err := rows.Scan(
|
||||
&s.ID, &s.UserID, &s.Name, &s.Description, &s.Icon, &s.Color,
|
||||
&s.CustomInstructions, &s.DefaultFocusMode, &s.DefaultModel,
|
||||
&s.IsPublic, &settingsJSON, &s.CreatedAt, &s.UpdatedAt, &s.ThreadCount,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
json.Unmarshal(settingsJSON, &s.Settings)
|
||||
spaces = append(spaces, &s)
|
||||
}
|
||||
|
||||
return spaces, nil
|
||||
}
|
||||
|
||||
func (r *SpaceRepository) Update(ctx context.Context, s *Space) error {
|
||||
settingsJSON, _ := json.Marshal(s.Settings)
|
||||
|
||||
query := `
|
||||
UPDATE spaces
|
||||
SET name = $2, description = $3, icon = $4, color = $5,
|
||||
custom_instructions = $6, default_focus_mode = $7, default_model = $8,
|
||||
is_public = $9, settings = $10, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
_, err := r.db.db.ExecContext(ctx, query,
|
||||
s.ID, s.Name, s.Description, s.Icon, s.Color,
|
||||
s.CustomInstructions, s.DefaultFocusMode, s.DefaultModel,
|
||||
s.IsPublic, settingsJSON,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *SpaceRepository) Delete(ctx context.Context, id string) error {
|
||||
_, err := r.db.db.ExecContext(ctx, "DELETE FROM spaces WHERE id = $1", id)
|
||||
return err
|
||||
}
|
||||
270
backend/internal/db/thread_repo.go
Normal file
270
backend/internal/db/thread_repo.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Thread is a single conversation, optionally attached to a Space and
// optionally shared publicly via ShareID.
type Thread struct {
	ID        string  `json:"id"`
	UserID    string  `json:"userId"`
	SpaceID   *string `json:"spaceId,omitempty"` // nil when not in a space
	Title     string  `json:"title"`
	FocusMode string  `json:"focusMode"`
	IsPublic  bool    `json:"isPublic"`
	ShareID   *string `json:"shareId,omitempty"` // set only for shared threads
	CreatedAt time.Time `json:"createdAt"`
	UpdatedAt time.Time `json:"updatedAt"`
	// Messages is populated only by explicit message queries; list queries
	// return MessageCount (a subselect) instead.
	Messages     []ThreadMessage `json:"messages,omitempty"`
	MessageCount int             `json:"messageCount,omitempty"`
}

// ThreadMessage is one user or assistant turn within a thread, together
// with the sources, widgets and follow-up questions produced for it.
type ThreadMessage struct {
	ID       string `json:"id"`
	ThreadID string `json:"threadId"`
	Role     string `json:"role"`
	Content  string `json:"content"`
	// Sources, Widgets and RelatedQuestions are persisted as JSONB columns.
	Sources          []ThreadSource           `json:"sources,omitempty"`
	Widgets          []map[string]interface{} `json:"widgets,omitempty"`
	RelatedQuestions []string                 `json:"relatedQuestions,omitempty"`
	Model            string                   `json:"model,omitempty"`
	TokensUsed       int                      `json:"tokensUsed,omitempty"`
	CreatedAt        time.Time                `json:"createdAt"`
}

// ThreadSource is a single cited reference attached to a message; Index is
// the citation number shown in the rendered answer.
type ThreadSource struct {
	Index   int    `json:"index"`
	URL     string `json:"url"`
	Title   string `json:"title"`
	Domain  string `json:"domain"`
	Snippet string `json:"snippet,omitempty"`
}
|
||||
|
||||
// ThreadRepository provides CRUD access to threads and their messages.
type ThreadRepository struct {
	db *PostgresDB
}

// NewThreadRepository wires a ThreadRepository to the given database handle.
func NewThreadRepository(db *PostgresDB) *ThreadRepository {
	return &ThreadRepository{db: db}
}
|
||||
|
||||
// RunMigrations creates the threads and thread_messages tables plus their
// indexes. All statements are idempotent (IF NOT EXISTS) and run in order;
// the first failure aborts the rest. Note the FK choices: deleting a space
// detaches its threads (SET NULL), while deleting a thread cascades to its
// messages.
func (r *ThreadRepository) RunMigrations(ctx context.Context) error {
	migrations := []string{
		`CREATE TABLE IF NOT EXISTS threads (
			id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
			user_id UUID NOT NULL,
			space_id UUID REFERENCES spaces(id) ON DELETE SET NULL,
			title VARCHAR(500) NOT NULL DEFAULT 'New Thread',
			focus_mode VARCHAR(50) DEFAULT 'all',
			is_public BOOLEAN DEFAULT FALSE,
			share_id VARCHAR(100) UNIQUE,
			created_at TIMESTAMPTZ DEFAULT NOW(),
			updated_at TIMESTAMPTZ DEFAULT NOW()
		)`,
		`CREATE INDEX IF NOT EXISTS idx_threads_user ON threads(user_id)`,
		`CREATE INDEX IF NOT EXISTS idx_threads_space ON threads(space_id)`,
		`CREATE INDEX IF NOT EXISTS idx_threads_share ON threads(share_id)`,
		`CREATE TABLE IF NOT EXISTS thread_messages (
			id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
			thread_id UUID NOT NULL REFERENCES threads(id) ON DELETE CASCADE,
			role VARCHAR(20) NOT NULL,
			content TEXT NOT NULL,
			sources JSONB DEFAULT '[]',
			widgets JSONB DEFAULT '[]',
			related_questions JSONB DEFAULT '[]',
			model VARCHAR(100),
			tokens_used INT DEFAULT 0,
			created_at TIMESTAMPTZ DEFAULT NOW()
		)`,
		`CREATE INDEX IF NOT EXISTS idx_thread_messages_thread ON thread_messages(thread_id)`,
	}

	for _, m := range migrations {
		if _, err := r.db.db.ExecContext(ctx, m); err != nil {
			return err
		}
	}
	return nil
}
|
||||
|
||||
// Create inserts a new thread and populates t.ID, t.CreatedAt and
// t.UpdatedAt from the database-generated values. t.SpaceID may be nil for
// a thread outside any space.
func (r *ThreadRepository) Create(ctx context.Context, t *Thread) error {
	query := `
		INSERT INTO threads (user_id, space_id, title, focus_mode, is_public)
		VALUES ($1, $2, $3, $4, $5)
		RETURNING id, created_at, updated_at
	`
	return r.db.db.QueryRowContext(ctx, query,
		t.UserID, t.SpaceID, t.Title, t.FocusMode, t.IsPublic,
	).Scan(&t.ID, &t.CreatedAt, &t.UpdatedAt)
}
|
||||
|
||||
// GetByID loads one thread by primary key, including a computed message
// count (messages themselves are fetched separately via GetMessages).
// Returns (nil, nil) when no row matches.
func (r *ThreadRepository) GetByID(ctx context.Context, id string) (*Thread, error) {
	// Column order must stay in lockstep with the Scan call below.
	query := `
		SELECT id, user_id, space_id, title, focus_mode, is_public, share_id, created_at, updated_at,
			(SELECT COUNT(*) FROM thread_messages WHERE thread_id = threads.id) as message_count
		FROM threads
		WHERE id = $1
	`

	var t Thread
	err := r.db.db.QueryRowContext(ctx, query, id).Scan(
		&t.ID, &t.UserID, &t.SpaceID, &t.Title, &t.FocusMode,
		&t.IsPublic, &t.ShareID, &t.CreatedAt, &t.UpdatedAt, &t.MessageCount,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}

	return &t, nil
}
|
||||
|
||||
// GetByShareID resolves a public share link to its thread. Only threads
// explicitly marked is_public are visible through this path; a private or
// unknown share id yields (nil, nil).
func (r *ThreadRepository) GetByShareID(ctx context.Context, shareID string) (*Thread, error) {
	query := `
		SELECT id, user_id, space_id, title, focus_mode, is_public, share_id, created_at, updated_at
		FROM threads
		WHERE share_id = $1 AND is_public = true
	`

	var t Thread
	err := r.db.db.QueryRowContext(ctx, query, shareID).Scan(
		&t.ID, &t.UserID, &t.SpaceID, &t.Title, &t.FocusMode,
		&t.IsPublic, &t.ShareID, &t.CreatedAt, &t.UpdatedAt,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}

	return &t, nil
}
|
||||
|
||||
func (r *ThreadRepository) GetByUserID(ctx context.Context, userID string, limit, offset int) ([]*Thread, error) {
|
||||
query := `
|
||||
SELECT id, user_id, space_id, title, focus_mode, is_public, share_id, created_at, updated_at,
|
||||
(SELECT COUNT(*) FROM thread_messages WHERE thread_id = threads.id) as message_count
|
||||
FROM threads
|
||||
WHERE user_id = $1
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT $2 OFFSET $3
|
||||
`
|
||||
|
||||
rows, err := r.db.db.QueryContext(ctx, query, userID, limit, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var threads []*Thread
|
||||
for rows.Next() {
|
||||
var t Thread
|
||||
if err := rows.Scan(
|
||||
&t.ID, &t.UserID, &t.SpaceID, &t.Title, &t.FocusMode,
|
||||
&t.IsPublic, &t.ShareID, &t.CreatedAt, &t.UpdatedAt, &t.MessageCount,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
threads = append(threads, &t)
|
||||
}
|
||||
|
||||
return threads, nil
|
||||
}
|
||||
|
||||
// Update rewrites the mutable thread fields (title, focus mode, public
// flag) and bumps updated_at. share_id is managed only via SetShareID.
func (r *ThreadRepository) Update(ctx context.Context, t *Thread) error {
	query := `
		UPDATE threads
		SET title = $2, focus_mode = $3, is_public = $4, updated_at = NOW()
		WHERE id = $1
	`
	_, err := r.db.db.ExecContext(ctx, query, t.ID, t.Title, t.FocusMode, t.IsPublic)
	return err
}
|
||||
|
||||
// SetShareID publishes a thread: it records the share token and flips the
// thread public in one statement, making it reachable via GetByShareID.
func (r *ThreadRepository) SetShareID(ctx context.Context, threadID, shareID string) error {
	_, err := r.db.db.ExecContext(ctx,
		"UPDATE threads SET share_id = $2, is_public = true WHERE id = $1",
		threadID, shareID,
	)
	return err
}

// Delete permanently removes a thread; its messages are removed by the
// ON DELETE CASCADE constraint on thread_messages.thread_id.
func (r *ThreadRepository) Delete(ctx context.Context, id string) error {
	_, err := r.db.db.ExecContext(ctx, "DELETE FROM threads WHERE id = $1", id)
	return err
}
}
|
||||
|
||||
// AddMessage appends one message to a thread and populates msg.ID and
// msg.CreatedAt. Sources, widgets and related questions are serialized to
// JSONB; marshal errors are deliberately ignored (nil slices encode as
// "null"/empty and the columns have defaults).
func (r *ThreadRepository) AddMessage(ctx context.Context, msg *ThreadMessage) error {
	sourcesJSON, _ := json.Marshal(msg.Sources)
	widgetsJSON, _ := json.Marshal(msg.Widgets)
	relatedJSON, _ := json.Marshal(msg.RelatedQuestions)

	query := `
		INSERT INTO thread_messages (thread_id, role, content, sources, widgets, related_questions, model, tokens_used)
		VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
		RETURNING id, created_at
	`
	err := r.db.db.QueryRowContext(ctx, query,
		msg.ThreadID, msg.Role, msg.Content, sourcesJSON, widgetsJSON, relatedJSON, msg.Model, msg.TokensUsed,
	).Scan(&msg.ID, &msg.CreatedAt)

	if err == nil {
		// Best-effort bump of the parent thread's recency; an error here is
		// intentionally not propagated since the message itself was stored.
		r.db.db.ExecContext(ctx, "UPDATE threads SET updated_at = NOW() WHERE id = $1", msg.ThreadID)
	}

	return err
}
|
||||
|
||||
func (r *ThreadRepository) GetMessages(ctx context.Context, threadID string, limit, offset int) ([]ThreadMessage, error) {
|
||||
query := `
|
||||
SELECT id, thread_id, role, content, sources, widgets, related_questions, model, tokens_used, created_at
|
||||
FROM thread_messages
|
||||
WHERE thread_id = $1
|
||||
ORDER BY created_at ASC
|
||||
LIMIT $2 OFFSET $3
|
||||
`
|
||||
|
||||
rows, err := r.db.db.QueryContext(ctx, query, threadID, limit, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var messages []ThreadMessage
|
||||
for rows.Next() {
|
||||
var msg ThreadMessage
|
||||
var sourcesJSON, widgetsJSON, relatedJSON []byte
|
||||
|
||||
if err := rows.Scan(
|
||||
&msg.ID, &msg.ThreadID, &msg.Role, &msg.Content,
|
||||
&sourcesJSON, &widgetsJSON, &relatedJSON,
|
||||
&msg.Model, &msg.TokensUsed, &msg.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
json.Unmarshal(sourcesJSON, &msg.Sources)
|
||||
json.Unmarshal(widgetsJSON, &msg.Widgets)
|
||||
json.Unmarshal(relatedJSON, &msg.RelatedQuestions)
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (r *ThreadRepository) GenerateTitle(ctx context.Context, threadID, firstMessage string) error {
|
||||
title := firstMessage
|
||||
if len(title) > 100 {
|
||||
title = title[:97] + "..."
|
||||
}
|
||||
_, err := r.db.db.ExecContext(ctx,
|
||||
"UPDATE threads SET title = $2 WHERE id = $1",
|
||||
threadID, title,
|
||||
)
|
||||
return err
|
||||
}
|
||||
323
backend/internal/db/user_interests_repo.go
Normal file
323
backend/internal/db/user_interests_repo.go
Normal file
@@ -0,0 +1,323 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// UserInterestsData is the raw persistence form of a user's interest
// profile. The JSON-valued fields are kept as json.RawMessage so the
// repository can round-trip JSONB columns without committing to a schema;
// decoding into typed structures happens in the discover layer.
type UserInterestsData struct {
	UserID           string          `json:"userId"`
	Topics           json.RawMessage `json:"topics"`   // object: topic -> weight
	Sources          json.RawMessage `json:"sources"`  // object: source -> weight
	Keywords         json.RawMessage `json:"keywords"` // object: keyword -> weight
	ViewHistory      json.RawMessage `json:"viewHistory"`
	SavedArticles    json.RawMessage `json:"savedArticles"`
	BlockedSources   json.RawMessage `json:"blockedSources"`
	BlockedTopics    json.RawMessage `json:"blockedTopics"`
	PreferredLang    string          `json:"preferredLang"`
	Region           string          `json:"region"`
	ReadingLevel     string          `json:"readingLevel"`
	Notifications    json.RawMessage `json:"notifications"`
	CustomCategories json.RawMessage `json:"customCategories"`
	CreatedAt        time.Time       `json:"createdAt"`
	UpdatedAt        time.Time       `json:"updatedAt"`
}
|
||||
|
||||
// UserInterestsRepository persists user interest profiles in the
// user_interests table (one row per user, keyed by user_id).
type UserInterestsRepository struct {
	db *PostgresDB
}

// NewUserInterestsRepository wires the repository to the given database handle.
func NewUserInterestsRepository(db *PostgresDB) *UserInterestsRepository {
	return &UserInterestsRepository{db: db}
}
|
||||
|
||||
// createTable creates the user_interests table and its indexes. Statements
// are idempotent (IF NOT EXISTS). Defaults encode the product's primary
// audience ('ru' / 'russia'); per-user values override them on first save.
func (r *UserInterestsRepository) createTable(ctx context.Context) error {
	query := `
	CREATE TABLE IF NOT EXISTS user_interests (
		user_id VARCHAR(255) PRIMARY KEY,
		topics JSONB DEFAULT '{}',
		sources JSONB DEFAULT '{}',
		keywords JSONB DEFAULT '{}',
		view_history JSONB DEFAULT '[]',
		saved_articles JSONB DEFAULT '[]',
		blocked_sources JSONB DEFAULT '[]',
		blocked_topics JSONB DEFAULT '[]',
		preferred_lang VARCHAR(10) DEFAULT 'ru',
		region VARCHAR(50) DEFAULT 'russia',
		reading_level VARCHAR(20) DEFAULT 'general',
		notifications JSONB DEFAULT '{}',
		custom_categories JSONB DEFAULT '[]',
		created_at TIMESTAMPTZ DEFAULT NOW(),
		updated_at TIMESTAMPTZ DEFAULT NOW()
	);

	CREATE INDEX IF NOT EXISTS idx_user_interests_updated ON user_interests(updated_at);
	CREATE INDEX IF NOT EXISTS idx_user_interests_region ON user_interests(region);
	`

	_, err := r.db.DB().ExecContext(ctx, query)
	return err
}
|
||||
|
||||
// Get loads the full interest profile for userID. It returns (nil, nil)
// when the user has no profile yet — callers must nil-check the result.
func (r *UserInterestsRepository) Get(ctx context.Context, userID string) (*UserInterestsData, error) {
	// Column order must stay in lockstep with the Scan call below.
	query := `
		SELECT user_id, topics, sources, keywords, view_history, saved_articles,
			blocked_sources, blocked_topics, preferred_lang, region, reading_level,
			notifications, custom_categories, created_at, updated_at
		FROM user_interests
		WHERE user_id = $1
	`

	row := r.db.DB().QueryRowContext(ctx, query, userID)

	var data UserInterestsData
	err := row.Scan(
		&data.UserID, &data.Topics, &data.Sources, &data.Keywords,
		&data.ViewHistory, &data.SavedArticles, &data.BlockedSources,
		&data.BlockedTopics, &data.PreferredLang, &data.Region,
		&data.ReadingLevel, &data.Notifications, &data.CustomCategories,
		&data.CreatedAt, &data.UpdatedAt,
	)

	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}

	return &data, nil
}
|
||||
|
||||
func (r *UserInterestsRepository) Save(ctx context.Context, data *UserInterestsData) error {
|
||||
query := `
|
||||
INSERT INTO user_interests (
|
||||
user_id, topics, sources, keywords, view_history, saved_articles,
|
||||
blocked_sources, blocked_topics, preferred_lang, region, reading_level,
|
||||
notifications, custom_categories, created_at, updated_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
|
||||
ON CONFLICT (user_id) DO UPDATE SET
|
||||
topics = EXCLUDED.topics,
|
||||
sources = EXCLUDED.sources,
|
||||
keywords = EXCLUDED.keywords,
|
||||
view_history = EXCLUDED.view_history,
|
||||
saved_articles = EXCLUDED.saved_articles,
|
||||
blocked_sources = EXCLUDED.blocked_sources,
|
||||
blocked_topics = EXCLUDED.blocked_topics,
|
||||
preferred_lang = EXCLUDED.preferred_lang,
|
||||
region = EXCLUDED.region,
|
||||
reading_level = EXCLUDED.reading_level,
|
||||
notifications = EXCLUDED.notifications,
|
||||
custom_categories = EXCLUDED.custom_categories,
|
||||
updated_at = NOW()
|
||||
`
|
||||
|
||||
now := time.Now()
|
||||
if data.CreatedAt.IsZero() {
|
||||
data.CreatedAt = now
|
||||
}
|
||||
data.UpdatedAt = now
|
||||
|
||||
if data.Topics == nil {
|
||||
data.Topics = json.RawMessage("{}")
|
||||
}
|
||||
if data.Sources == nil {
|
||||
data.Sources = json.RawMessage("{}")
|
||||
}
|
||||
if data.Keywords == nil {
|
||||
data.Keywords = json.RawMessage("{}")
|
||||
}
|
||||
if data.ViewHistory == nil {
|
||||
data.ViewHistory = json.RawMessage("[]")
|
||||
}
|
||||
if data.SavedArticles == nil {
|
||||
data.SavedArticles = json.RawMessage("[]")
|
||||
}
|
||||
if data.BlockedSources == nil {
|
||||
data.BlockedSources = json.RawMessage("[]")
|
||||
}
|
||||
if data.BlockedTopics == nil {
|
||||
data.BlockedTopics = json.RawMessage("[]")
|
||||
}
|
||||
if data.Notifications == nil {
|
||||
data.Notifications = json.RawMessage("{}")
|
||||
}
|
||||
if data.CustomCategories == nil {
|
||||
data.CustomCategories = json.RawMessage("[]")
|
||||
}
|
||||
|
||||
_, err := r.db.DB().ExecContext(ctx, query,
|
||||
data.UserID, data.Topics, data.Sources, data.Keywords,
|
||||
data.ViewHistory, data.SavedArticles, data.BlockedSources,
|
||||
data.BlockedTopics, data.PreferredLang, data.Region,
|
||||
data.ReadingLevel, data.Notifications, data.CustomCategories,
|
||||
data.CreatedAt, data.UpdatedAt,
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *UserInterestsRepository) Delete(ctx context.Context, userID string) error {
|
||||
query := `DELETE FROM user_interests WHERE user_id = $1`
|
||||
_, err := r.db.DB().ExecContext(ctx, query, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// AddViewEvent prepends a view event to the user's view_history, keeping at
// most 500 entries, and creates the profile row on first use (upsert via a
// rows-affected check rather than ON CONFLICT).
//
// NOTE(review): the truncation branch uses `view_history[0:499]` — jsonb
// slice subscripting. PostgreSQL jsonb subscripting does not document an
// array-slice form, so this statement may fail at runtime once a history
// reaches 500 entries; verify against the deployed PostgreSQL version.
func (r *UserInterestsRepository) AddViewEvent(ctx context.Context, userID string, event json.RawMessage) error {
	query := `
		UPDATE user_interests
		SET view_history = CASE
			WHEN jsonb_array_length(view_history) >= 500
			THEN jsonb_build_array($2) || view_history[0:499]
			ELSE jsonb_build_array($2) || view_history
		END,
		updated_at = NOW()
		WHERE user_id = $1
	`

	result, err := r.db.DB().ExecContext(ctx, query, userID, event)
	if err != nil {
		return err
	}

	// No row updated means the user has no profile yet: create one seeded
	// with this single event (other columns take their table defaults).
	rowsAffected, _ := result.RowsAffected()
	if rowsAffected == 0 {
		insertQuery := `
			INSERT INTO user_interests (user_id, view_history, updated_at)
			VALUES ($1, jsonb_build_array($2), NOW())
		`
		_, err = r.db.DB().ExecContext(ctx, insertQuery, userID, event)
	}

	return err
}
|
||||
|
||||
// UpdateTopicScore adds delta to the stored weight of one topic inside the
// topics JSONB object (missing topics start from 0). If the user has no
// profile row yet, one is created seeded with just this topic.
func (r *UserInterestsRepository) UpdateTopicScore(ctx context.Context, userID, topic string, delta float64) error {
	// `topics || jsonb_build_object(...)` overwrites the single key with
	// its incremented value while leaving the rest of the object intact.
	query := `
		UPDATE user_interests
		SET topics = topics || jsonb_build_object($2, COALESCE((topics->>$2)::float, 0) + $3),
			updated_at = NOW()
		WHERE user_id = $1
	`

	result, err := r.db.DB().ExecContext(ctx, query, userID, topic, delta)
	if err != nil {
		return err
	}

	// Zero rows affected: first interaction for this user — create the row.
	rowsAffected, _ := result.RowsAffected()
	if rowsAffected == 0 {
		insertQuery := `
			INSERT INTO user_interests (user_id, topics, updated_at)
			VALUES ($1, jsonb_build_object($2, $3), NOW())
		`
		_, err = r.db.DB().ExecContext(ctx, insertQuery, userID, topic, delta)
	}

	return err
}
|
||||
|
||||
// SaveArticle appends articleURL to the user's saved_articles array if it
// is not already present (the `?` operator tests jsonb array membership),
// making the call idempotent. No row is created for unknown users.
func (r *UserInterestsRepository) SaveArticle(ctx context.Context, userID, articleURL string) error {
	query := `
		UPDATE user_interests
		SET saved_articles = CASE
			WHEN NOT saved_articles ? $2
			THEN saved_articles || jsonb_build_array($2)
			ELSE saved_articles
		END,
		updated_at = NOW()
		WHERE user_id = $1
	`

	_, err := r.db.DB().ExecContext(ctx, query, userID, articleURL)
	return err
}

// UnsaveArticle removes articleURL from saved_articles (jsonb `-` deletes
// a matching string element). Removing an absent URL is a no-op.
func (r *UserInterestsRepository) UnsaveArticle(ctx context.Context, userID, articleURL string) error {
	query := `
		UPDATE user_interests
		SET saved_articles = saved_articles - $2,
			updated_at = NOW()
		WHERE user_id = $1
	`

	_, err := r.db.DB().ExecContext(ctx, query, userID, articleURL)
	return err
}
|
||||
|
||||
// BlockSource adds source to the user's blocked_sources array if not
// already present (idempotent; `?` tests jsonb membership). Blocked sources
// are filtered out of personalized feeds by the discover layer.
func (r *UserInterestsRepository) BlockSource(ctx context.Context, userID, source string) error {
	query := `
		UPDATE user_interests
		SET blocked_sources = CASE
			WHEN NOT blocked_sources ? $2
			THEN blocked_sources || jsonb_build_array($2)
			ELSE blocked_sources
		END,
		updated_at = NOW()
		WHERE user_id = $1
	`

	_, err := r.db.DB().ExecContext(ctx, query, userID, source)
	return err
}

// UnblockSource removes source from blocked_sources (jsonb `-` deletes a
// matching string element). Removing an absent source is a no-op.
func (r *UserInterestsRepository) UnblockSource(ctx context.Context, userID, source string) error {
	query := `
		UPDATE user_interests
		SET blocked_sources = blocked_sources - $2,
			updated_at = NOW()
		WHERE user_id = $1
	`

	_, err := r.db.DB().ExecContext(ctx, query, userID, source)
	return err
}
|
||||
|
||||
func (r *UserInterestsRepository) GetTopUsers(ctx context.Context, limit int) ([]string, error) {
|
||||
query := `
|
||||
SELECT user_id FROM user_interests
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT $1
|
||||
`
|
||||
|
||||
rows, err := r.db.DB().QueryContext(ctx, query, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var userIDs []string
|
||||
for rows.Next() {
|
||||
var userID string
|
||||
if err := rows.Scan(&userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userIDs = append(userIDs, userID)
|
||||
}
|
||||
|
||||
return userIDs, rows.Err()
|
||||
}
|
||||
|
||||
func (r *UserInterestsRepository) DecayAllInterests(ctx context.Context, decayFactor float64) error {
|
||||
query := `
|
||||
UPDATE user_interests
|
||||
SET topics = (
|
||||
SELECT jsonb_object_agg(key, (value::text::float * $1))
|
||||
FROM jsonb_each(topics) WHERE (value::text::float * $1) > 0.01
|
||||
),
|
||||
sources = (
|
||||
SELECT jsonb_object_agg(key, (value::text::float * $1))
|
||||
FROM jsonb_each(sources) WHERE (value::text::float * $1) > 0.01
|
||||
),
|
||||
keywords = (
|
||||
SELECT jsonb_object_agg(key, (value::text::float * $1))
|
||||
FROM jsonb_each(keywords) WHERE (value::text::float * $1) > 0.01
|
||||
),
|
||||
updated_at = NOW()
|
||||
`
|
||||
|
||||
_, err := r.db.DB().ExecContext(ctx, query, decayFactor)
|
||||
return err
|
||||
}
|
||||
691
backend/internal/discover/personalization.go
Normal file
691
backend/internal/discover/personalization.go
Normal file
@@ -0,0 +1,691 @@
|
||||
package discover
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// UserInterests is the in-memory interest profile the personalization
// engine reads and mutates. It is the typed counterpart of the JSONB row
// persisted by UserInterestsRepository.
type UserInterests struct {
	UserID string `json:"userId"`
	// Topics/Sources/Keywords map a name to an accumulated weight; weights
	// grow with engagement (RecordView) and shrink via decay.
	Topics           map[string]float64 `json:"topics"`
	Sources          map[string]float64 `json:"sources"`
	Keywords         map[string]float64 `json:"keywords"`
	ViewHistory      []ViewEvent        `json:"viewHistory"` // newest first, capped at 500
	SavedArticles    []string           `json:"savedArticles"`
	BlockedSources   []string           `json:"blockedSources"`
	BlockedTopics    []string           `json:"blockedTopics"`
	PreferredLang    string             `json:"preferredLang"`
	Region           string             `json:"region"`
	ReadingLevel     string             `json:"readingLevel"`
	Notifications    NotificationPrefs  `json:"notifications"`
	LastUpdated      time.Time          `json:"lastUpdated"`
	CustomCategories []CustomCategory   `json:"customCategories,omitempty"`
}

// ViewEvent captures one article view and the interaction signals used to
// compute its Engagement score.
type ViewEvent struct {
	ArticleID string   `json:"articleId"`
	URL       string   `json:"url"`
	Topic     string   `json:"topic"`
	Source    string   `json:"source"`
	Keywords  []string `json:"keywords"`
	TimeSpent int      `json:"timeSpentSeconds"`
	Completed bool     `json:"completed"`
	Saved     bool     `json:"saved"`
	Shared    bool     `json:"shared"`
	Timestamp time.Time `json:"timestamp"`
	// Engagement is derived from the fields above by the engine
	// (calculateEngagement); callers need not set it.
	Engagement float64 `json:"engagement"`
}

// NotificationPrefs holds per-user notification settings.
type NotificationPrefs struct {
	Enabled      bool     `json:"enabled"`
	DailyDigest  bool     `json:"dailyDigest"`
	DigestTime   string   `json:"digestTime"`
	BreakingNews bool     `json:"breakingNews"`
	TopicAlerts  []string `json:"topicAlerts"`
	Frequency    string   `json:"frequency"`
}

// CustomCategory is a user-defined feed section matched by keywords and/or
// sources, weighted relative to built-in topics.
type CustomCategory struct {
	ID       string   `json:"id"`
	Name     string   `json:"name"`
	Keywords []string `json:"keywords"`
	Sources  []string `json:"sources"`
	Weight   float64  `json:"weight"`
}

// PersonalizedFeed is the ranked result returned to the client, with items
// both as a flat list and grouped into categories.
type PersonalizedFeed struct {
	UserID     string         `json:"userId"`
	Items      []FeedItem     `json:"items"`
	Categories []FeedCategory `json:"categories"`
	TrendingIn []string       `json:"trendingIn"` // the user's top topics
	UpdatedAt  time.Time      `json:"updatedAt"`
	NextUpdate time.Time      `json:"nextUpdate"` // suggested client refresh time
}

// FeedItem is one ranked article in a personalized feed, annotated with the
// relevance score and a human-readable recommendation reason.
type FeedItem struct {
	ID             string    `json:"id"`
	URL            string    `json:"url"`
	Title          string    `json:"title"`
	Summary        string    `json:"summary"`
	Thumbnail      string    `json:"thumbnail"`
	Source         string    `json:"source"`
	SourceLogo     string    `json:"sourceLogo"`
	Topic          string    `json:"topic"`
	Keywords       []string  `json:"keywords"`
	PublishedAt    time.Time `json:"publishedAt"`
	RelevanceScore float64   `json:"relevanceScore"` // computed per user
	Reason         string    `json:"reason"`         // why this was recommended
	SourcesCount   int       `json:"sourcesCount"`
	ReadTime       int       `json:"readTimeMinutes"`
	HasDigest      bool      `json:"hasDigest"`
	IsBreaking     bool      `json:"isBreaking"`
	IsTrending     bool      `json:"isTrending"`
	IsSaved        bool      `json:"isSaved"` // per-user flags set at ranking time
	IsRead         bool      `json:"isRead"`
}

// FeedCategory is a named, iconed group of feed items; IsCustom marks
// user-defined categories.
type FeedCategory struct {
	ID       string     `json:"id"`
	Name     string     `json:"name"`
	Icon     string     `json:"icon"`
	Color    string     `json:"color"`
	Items    []FeedItem `json:"items"`
	IsCustom bool       `json:"isCustom"`
}
|
||||
|
||||
// PersonalizationEngine ranks content into per-user feeds using an interest
// store for profiles and a content repository for candidate items.
type PersonalizationEngine struct {
	userStore   UserInterestStore
	contentRepo ContentRepository
	mu          sync.RWMutex // guards engine-level shared state
	config      PersonalizationConfig
}

// PersonalizationConfig holds the ranking knobs. The five *Weight fields
// combine into the relevance score; TrendingBoost/BreakingBoost are
// multipliers applied on top.
type PersonalizationConfig struct {
	MaxFeedItems      int     // hard cap on items per generated feed
	DecayFactor       float64 // multiplier applied when decaying interests
	RecencyWeight     float64
	EngagementWeight  float64
	TopicMatchWeight  float64
	SourceTrustWeight float64
	DiversityFactor   float64
	TrendingBoost     float64
	BreakingBoost     float64
}

// UserInterestStore abstracts persistence of interest profiles
// (consumer-side interface; backed by UserInterestsRepository in prod).
type UserInterestStore interface {
	Get(ctx context.Context, userID string) (*UserInterests, error)
	Save(ctx context.Context, interests *UserInterests) error
	Delete(ctx context.Context, userID string) error
}

// ContentRepository abstracts the candidate-item sources the engine draws
// from when assembling a feed.
type ContentRepository interface {
	GetLatestContent(ctx context.Context, topics []string, limit int) ([]FeedItem, error)
	GetTrending(ctx context.Context, region string, limit int) ([]FeedItem, error)
	GetByKeywords(ctx context.Context, keywords []string, limit int) ([]FeedItem, error)
}
|
||||
|
||||
func DefaultConfig() PersonalizationConfig {
|
||||
return PersonalizationConfig{
|
||||
MaxFeedItems: 50,
|
||||
DecayFactor: 0.95,
|
||||
RecencyWeight: 0.25,
|
||||
EngagementWeight: 0.20,
|
||||
TopicMatchWeight: 0.30,
|
||||
SourceTrustWeight: 0.15,
|
||||
DiversityFactor: 0.10,
|
||||
TrendingBoost: 1.5,
|
||||
BreakingBoost: 2.0,
|
||||
}
|
||||
}
|
||||
|
||||
// NewPersonalizationEngine builds an engine from its two dependencies and a
// ranking configuration (typically DefaultConfig()).
func NewPersonalizationEngine(userStore UserInterestStore, contentRepo ContentRepository, cfg PersonalizationConfig) *PersonalizationEngine {
	return &PersonalizationEngine{
		userStore:   userStore,
		contentRepo: contentRepo,
		config:      cfg,
	}
}
|
||||
|
||||
// GenerateForYouFeed assembles the personalized "For You" feed for userID:
//
//  1. Load the interest profile; on a store error, fall back to a fresh
//     default profile (ru/russia).
//  2. Fetch candidates concurrently from three sources: latest content for
//     the user's top-5 topics, regional trending (marked IsTrending), and
//     matches for the top-10 keywords. Per-source errors are deliberately
//     ignored — a partially filled feed is preferred over none.
//  3. Deduplicate, drop blocked sources/topics, score and annotate each
//     item, sort by relevance, enforce diversity, cap at MaxFeedItems, and
//     group into categories.
//
// NOTE(review): if the store returns (nil, nil) — as the SQL repository
// does for an unknown user — `interests` is nil here and the accesses below
// panic; confirm the UserInterestStore adapter maps "not found" to an error.
func (e *PersonalizationEngine) GenerateForYouFeed(ctx context.Context, userID string) (*PersonalizedFeed, error) {
	interests, err := e.userStore.Get(ctx, userID)
	if err != nil {
		interests = &UserInterests{
			UserID:        userID,
			Topics:        make(map[string]float64),
			Sources:       make(map[string]float64),
			Keywords:      make(map[string]float64),
			PreferredLang: "ru",
			Region:        "russia",
		}
	}

	// allItems is appended to from several goroutines; mu guards it.
	var allItems []FeedItem
	var mu sync.Mutex
	var wg sync.WaitGroup

	topTopics := e.getTopInterests(interests.Topics, 5)
	wg.Add(1)
	go func() {
		defer wg.Done()
		items, _ := e.contentRepo.GetLatestContent(ctx, topTopics, 30)
		mu.Lock()
		allItems = append(allItems, items...)
		mu.Unlock()
	}()

	wg.Add(1)
	go func() {
		defer wg.Done()
		items, _ := e.contentRepo.GetTrending(ctx, interests.Region, 20)
		for i := range items {
			items[i].IsTrending = true
		}
		mu.Lock()
		allItems = append(allItems, items...)
		mu.Unlock()
	}()

	topKeywords := e.getTopKeywords(interests.Keywords, 10)
	if len(topKeywords) > 0 {
		wg.Add(1)
		go func() {
			defer wg.Done()
			items, _ := e.contentRepo.GetByKeywords(ctx, topKeywords, 15)
			mu.Lock()
			allItems = append(allItems, items...)
			mu.Unlock()
		}()
	}

	wg.Wait()

	allItems = e.deduplicateItems(allItems)
	allItems = e.filterBlockedContent(allItems, interests)

	// Score and annotate each surviving candidate against the profile.
	for i := range allItems {
		allItems[i].RelevanceScore = e.calculateRelevance(allItems[i], interests)
		allItems[i].Reason = e.explainRecommendation(allItems[i], interests)
		allItems[i].IsRead = e.isArticleRead(allItems[i].URL, interests)
		allItems[i].IsSaved = e.isArticleSaved(allItems[i].URL, interests)
	}

	sort.Slice(allItems, func(i, j int) bool {
		return allItems[i].RelevanceScore > allItems[j].RelevanceScore
	})

	allItems = e.applyDiversity(allItems)

	if len(allItems) > e.config.MaxFeedItems {
		allItems = allItems[:e.config.MaxFeedItems]
	}

	categories := e.groupByCategory(allItems, interests)

	return &PersonalizedFeed{
		UserID:     userID,
		Items:      allItems,
		Categories: categories,
		TrendingIn: topTopics,
		UpdatedAt:  time.Now(),
		NextUpdate: time.Now().Add(15 * time.Minute),
	}, nil
}
|
||||
|
||||
func (e *PersonalizationEngine) RecordView(ctx context.Context, userID string, event ViewEvent) error {
|
||||
interests, err := e.userStore.Get(ctx, userID)
|
||||
if err != nil {
|
||||
interests = &UserInterests{
|
||||
UserID: userID,
|
||||
Topics: make(map[string]float64),
|
||||
Sources: make(map[string]float64),
|
||||
Keywords: make(map[string]float64),
|
||||
}
|
||||
}
|
||||
|
||||
event.Engagement = e.calculateEngagement(event)
|
||||
|
||||
interests.ViewHistory = append([]ViewEvent{event}, interests.ViewHistory...)
|
||||
if len(interests.ViewHistory) > 500 {
|
||||
interests.ViewHistory = interests.ViewHistory[:500]
|
||||
}
|
||||
|
||||
topicWeight := event.Engagement * 0.1
|
||||
interests.Topics[event.Topic] += topicWeight
|
||||
|
||||
sourceWeight := event.Engagement * 0.05
|
||||
interests.Sources[event.Source] += sourceWeight
|
||||
|
||||
keywordWeight := event.Engagement * 0.02
|
||||
for _, kw := range event.Keywords {
|
||||
interests.Keywords[kw] += keywordWeight
|
||||
}
|
||||
|
||||
if event.Saved {
|
||||
interests.SavedArticles = append(interests.SavedArticles, event.URL)
|
||||
}
|
||||
|
||||
interests.LastUpdated = time.Now()
|
||||
|
||||
e.decayInterests(interests)
|
||||
|
||||
return e.userStore.Save(ctx, interests)
|
||||
}
|
||||
|
||||
func (e *PersonalizationEngine) UpdateTopicPreference(ctx context.Context, userID, topic string, weight float64) error {
|
||||
interests, err := e.userStore.Get(ctx, userID)
|
||||
if err != nil {
|
||||
interests = &UserInterests{
|
||||
UserID: userID,
|
||||
Topics: make(map[string]float64),
|
||||
Sources: make(map[string]float64),
|
||||
Keywords: make(map[string]float64),
|
||||
}
|
||||
}
|
||||
|
||||
interests.Topics[topic] = weight
|
||||
interests.LastUpdated = time.Now()
|
||||
|
||||
return e.userStore.Save(ctx, interests)
|
||||
}
|
||||
|
||||
func (e *PersonalizationEngine) BlockSource(ctx context.Context, userID, source string) error {
|
||||
interests, err := e.userStore.Get(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, blocked := range interests.BlockedSources {
|
||||
if blocked == source {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
interests.BlockedSources = append(interests.BlockedSources, source)
|
||||
interests.LastUpdated = time.Now()
|
||||
|
||||
return e.userStore.Save(ctx, interests)
|
||||
}
|
||||
|
||||
func (e *PersonalizationEngine) BlockTopic(ctx context.Context, userID, topic string) error {
|
||||
interests, err := e.userStore.Get(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, blocked := range interests.BlockedTopics {
|
||||
if blocked == topic {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
interests.BlockedTopics = append(interests.BlockedTopics, topic)
|
||||
delete(interests.Topics, topic)
|
||||
interests.LastUpdated = time.Now()
|
||||
|
||||
return e.userStore.Save(ctx, interests)
|
||||
}
|
||||
|
||||
func (e *PersonalizationEngine) AddCustomCategory(ctx context.Context, userID string, category CustomCategory) error {
|
||||
interests, err := e.userStore.Get(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
interests.CustomCategories = append(interests.CustomCategories, category)
|
||||
interests.LastUpdated = time.Now()
|
||||
|
||||
return e.userStore.Save(ctx, interests)
|
||||
}
|
||||
|
||||
func (e *PersonalizationEngine) GetUserTopics(ctx context.Context, userID string) (map[string]float64, error) {
|
||||
interests, err := e.userStore.Get(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return interests.Topics, nil
|
||||
}
|
||||
|
||||
func (e *PersonalizationEngine) calculateRelevance(item FeedItem, interests *UserInterests) float64 {
|
||||
score := 0.0
|
||||
|
||||
if topicScore, ok := interests.Topics[item.Topic]; ok {
|
||||
score += topicScore * e.config.TopicMatchWeight
|
||||
}
|
||||
|
||||
if sourceScore, ok := interests.Sources[item.Source]; ok {
|
||||
score += sourceScore * e.config.SourceTrustWeight
|
||||
}
|
||||
|
||||
keywordScore := 0.0
|
||||
for _, kw := range item.Keywords {
|
||||
if kwScore, ok := interests.Keywords[strings.ToLower(kw)]; ok {
|
||||
keywordScore += kwScore
|
||||
}
|
||||
}
|
||||
score += keywordScore * 0.1
|
||||
|
||||
hoursSincePublish := time.Since(item.PublishedAt).Hours()
|
||||
recencyScore := math.Max(0, 1.0-hoursSincePublish/168.0)
|
||||
score += recencyScore * e.config.RecencyWeight
|
||||
|
||||
if item.IsTrending {
|
||||
score *= e.config.TrendingBoost
|
||||
}
|
||||
|
||||
if item.IsBreaking {
|
||||
score *= e.config.BreakingBoost
|
||||
}
|
||||
|
||||
return score
|
||||
}
|
||||
|
||||
func (e *PersonalizationEngine) calculateEngagement(event ViewEvent) float64 {
|
||||
engagement := 0.0
|
||||
|
||||
if event.TimeSpent > 0 {
|
||||
readTimeScore := math.Min(1.0, float64(event.TimeSpent)/300.0)
|
||||
engagement += readTimeScore * 0.4
|
||||
}
|
||||
|
||||
if event.Completed {
|
||||
engagement += 0.3
|
||||
}
|
||||
|
||||
if event.Saved {
|
||||
engagement += 0.2
|
||||
}
|
||||
|
||||
if event.Shared {
|
||||
engagement += 0.1
|
||||
}
|
||||
|
||||
return engagement
|
||||
}
|
||||
|
||||
func (e *PersonalizationEngine) explainRecommendation(item FeedItem, interests *UserInterests) string {
|
||||
if item.IsBreaking {
|
||||
return "Срочная новость"
|
||||
}
|
||||
|
||||
if item.IsTrending {
|
||||
return "Популярно сейчас"
|
||||
}
|
||||
|
||||
if topicScore, ok := interests.Topics[item.Topic]; ok && topicScore > 0.5 {
|
||||
return fmt.Sprintf("Из вашей категории: %s", item.Topic)
|
||||
}
|
||||
|
||||
if sourceScore, ok := interests.Sources[item.Source]; ok && sourceScore > 0.3 {
|
||||
return fmt.Sprintf("Из источника, который вы читаете: %s", item.Source)
|
||||
}
|
||||
|
||||
for _, kw := range item.Keywords {
|
||||
if kwScore, ok := interests.Keywords[strings.ToLower(kw)]; ok && kwScore > 0.2 {
|
||||
return fmt.Sprintf("По вашему интересу: %s", kw)
|
||||
}
|
||||
}
|
||||
|
||||
return "Рекомендуем для вас"
|
||||
}
|
||||
|
||||
func (e *PersonalizationEngine) getTopInterests(interests map[string]float64, limit int) []string {
|
||||
type kv struct {
|
||||
Key string
|
||||
Value float64
|
||||
}
|
||||
|
||||
var sorted []kv
|
||||
for k, v := range interests {
|
||||
sorted = append(sorted, kv{k, v})
|
||||
}
|
||||
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return sorted[i].Value > sorted[j].Value
|
||||
})
|
||||
|
||||
result := make([]string, 0, limit)
|
||||
for i, item := range sorted {
|
||||
if i >= limit {
|
||||
break
|
||||
}
|
||||
result = append(result, item.Key)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// getTopKeywords returns up to limit keyword keys ordered by descending
// weight; it delegates to getTopInterests, which operates on any
// string-to-weight map.
func (e *PersonalizationEngine) getTopKeywords(keywords map[string]float64, limit int) []string {
	return e.getTopInterests(keywords, limit)
}
|
||||
|
||||
func (e *PersonalizationEngine) deduplicateItems(items []FeedItem) []FeedItem {
|
||||
seen := make(map[string]bool)
|
||||
result := make([]FeedItem, 0, len(items))
|
||||
|
||||
for _, item := range items {
|
||||
if !seen[item.URL] {
|
||||
seen[item.URL] = true
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (e *PersonalizationEngine) filterBlockedContent(items []FeedItem, interests *UserInterests) []FeedItem {
|
||||
blockedSources := make(map[string]bool)
|
||||
for _, s := range interests.BlockedSources {
|
||||
blockedSources[strings.ToLower(s)] = true
|
||||
}
|
||||
|
||||
blockedTopics := make(map[string]bool)
|
||||
for _, t := range interests.BlockedTopics {
|
||||
blockedTopics[strings.ToLower(t)] = true
|
||||
}
|
||||
|
||||
result := make([]FeedItem, 0, len(items))
|
||||
for _, item := range items {
|
||||
if blockedSources[strings.ToLower(item.Source)] {
|
||||
continue
|
||||
}
|
||||
if blockedTopics[strings.ToLower(item.Topic)] {
|
||||
continue
|
||||
}
|
||||
result = append(result, item)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// applyDiversity re-shapes a relevance-sorted feed so that no single
// topic or source dominates: items over a per-topic/per-source cap are
// deferred, then re-appended (in original order) after the diverse
// items, up to MaxFeedItems.
func (e *PersonalizationEngine) applyDiversity(items []FeedItem) []FeedItem {
	// Small feeds don't need diversity shaping.
	if len(items) <= 10 {
		return items
	}

	topicCounts := make(map[string]int)
	sourceCounts := make(map[string]int)
	// Caps scale with feed size: at most 1/5 per topic and 1/4 per
	// source, but never below 3 of each.
	maxPerTopic := len(items) / 5
	maxPerSource := len(items) / 4

	if maxPerTopic < 3 {
		maxPerTopic = 3
	}
	if maxPerSource < 3 {
		maxPerSource = 3
	}

	result := make([]FeedItem, 0, len(items))
	deferred := make([]FeedItem, 0)

	// First pass: keep items while their topic/source is under its cap;
	// anything over a cap is set aside, preserving relative order.
	for _, item := range items {
		if topicCounts[item.Topic] >= maxPerTopic || sourceCounts[item.Source] >= maxPerSource {
			deferred = append(deferred, item)
			continue
		}

		topicCounts[item.Topic]++
		sourceCounts[item.Source]++
		result = append(result, item)
	}

	// Second pass: backfill with the deferred overflow until the overall
	// feed cap is reached.
	for _, item := range deferred {
		if len(result) >= e.config.MaxFeedItems {
			break
		}
		result = append(result, item)
	}

	return result
}
|
||||
|
||||
// groupByCategory buckets feed items by topic into FeedCategory values
// (with per-topic icon/color metadata), appends any matching custom
// categories, and sorts the result by the user's topic weights.
//
// NOTE(review): the topic buckets are built by ranging over a map, so
// the relative order of categories with equal topic weight is not
// deterministic between calls.
func (e *PersonalizationEngine) groupByCategory(items []FeedItem, interests *UserInterests) []FeedCategory {
	categoryMap := make(map[string][]FeedItem)

	for _, item := range items {
		categoryMap[item.Topic] = append(categoryMap[item.Topic], item)
	}

	categories := make([]FeedCategory, 0, len(categoryMap))
	// Display metadata for the well-known topics (lowercase keys).
	categoryMeta := map[string]struct {
		Icon  string
		Color string
	}{
		"tech":          {"💻", "#3B82F6"},
		"finance":       {"💰", "#10B981"},
		"sports":        {"⚽", "#F59E0B"},
		"politics":      {"🏛️", "#6366F1"},
		"science":       {"🔬", "#8B5CF6"},
		"health":        {"🏥", "#EC4899"},
		"entertainment": {"🎬", "#F97316"},
		"world":         {"🌍", "#14B8A6"},
		"business":      {"📊", "#6B7280"},
		"culture":       {"🎭", "#A855F7"},
	}

	for topic, topicItems := range categoryMap {
		// Topics with a single item don't get their own category card.
		if len(topicItems) < 2 {
			continue
		}

		meta, ok := categoryMeta[strings.ToLower(topic)]
		if !ok {
			// Generic newspaper icon for unknown topics.
			meta = struct {
				Icon  string
				Color string
			}{"📰", "#6B7280"}
		}

		categories = append(categories, FeedCategory{
			ID:    topic,
			Name:  topic,
			Icon:  meta.Icon,
			Color: meta.Color,
			Items: topicItems,
		})
	}

	// User-defined categories: an item belongs if any custom keyword
	// matches its title, summary, or keyword list.
	for _, custom := range interests.CustomCategories {
		customItems := make([]FeedItem, 0)
		for _, item := range items {
			for _, kw := range custom.Keywords {
				if containsKeyword(item, kw) {
					customItems = append(customItems, item)
					break
				}
			}
		}
		if len(customItems) > 0 {
			categories = append(categories, FeedCategory{
				ID:       custom.ID,
				Name:     custom.Name,
				Icon:     "⭐",
				Color:    "#FBBF24",
				Items:    customItems,
				IsCustom: true,
			})
		}
	}

	// Strongest user interests first (unknown IDs get weight 0).
	sort.Slice(categories, func(i, j int) bool {
		iScore := interests.Topics[categories[i].ID]
		jScore := interests.Topics[categories[j].ID]
		return iScore > jScore
	})

	return categories
}
|
||||
|
||||
func (e *PersonalizationEngine) decayInterests(interests *UserInterests) {
|
||||
for k := range interests.Topics {
|
||||
interests.Topics[k] *= e.config.DecayFactor
|
||||
if interests.Topics[k] < 0.01 {
|
||||
delete(interests.Topics, k)
|
||||
}
|
||||
}
|
||||
|
||||
for k := range interests.Sources {
|
||||
interests.Sources[k] *= e.config.DecayFactor
|
||||
if interests.Sources[k] < 0.01 {
|
||||
delete(interests.Sources, k)
|
||||
}
|
||||
}
|
||||
|
||||
for k := range interests.Keywords {
|
||||
interests.Keywords[k] *= e.config.DecayFactor
|
||||
if interests.Keywords[k] < 0.01 {
|
||||
delete(interests.Keywords, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *PersonalizationEngine) isArticleRead(url string, interests *UserInterests) bool {
|
||||
for _, event := range interests.ViewHistory {
|
||||
if event.URL == url {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (e *PersonalizationEngine) isArticleSaved(url string, interests *UserInterests) bool {
|
||||
for _, saved := range interests.SavedArticles {
|
||||
if saved == url {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func containsKeyword(item FeedItem, keyword string) bool {
|
||||
kw := strings.ToLower(keyword)
|
||||
if strings.Contains(strings.ToLower(item.Title), kw) {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(strings.ToLower(item.Summary), kw) {
|
||||
return true
|
||||
}
|
||||
for _, itemKw := range item.Keywords {
|
||||
if strings.ToLower(itemKw) == kw {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ToJSON serializes the full interest profile to JSON, suitable for
// round-tripping through ParseUserInterests.
func (u *UserInterests) ToJSON() ([]byte, error) {
	return json.Marshal(u)
}
|
||||
|
||||
func ParseUserInterests(data []byte) (*UserInterests, error) {
|
||||
var interests UserInterests
|
||||
if err := json.Unmarshal(data, &interests); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &interests, nil
|
||||
}
|
||||
343
backend/internal/files/analyzer.go
Normal file
343
backend/internal/files/analyzer.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package files
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/gooseek/backend/internal/llm"
|
||||
"github.com/ledongthuc/pdf"
|
||||
)
|
||||
|
||||
// FileAnalyzer extracts text from uploaded files (PDF, image, plain
// text) and summarizes them through an LLM.
type FileAnalyzer struct {
	llmClient   llm.Client // used for summaries and image description
	storagePath string     // directory where uploaded files are stored
}

// AnalysisResult is the outcome of analyzing a single file.
type AnalysisResult struct {
	FileType      string                 `json:"fileType"`      // "pdf", "image", or "text"
	ExtractedText string                 `json:"extractedText"` // raw extracted (possibly truncated) text
	Summary       string                 `json:"summary"`       // LLM summary; "" when summarization failed
	KeyPoints     []string               `json:"keyPoints"`     // bullet points parsed from the LLM reply
	Metadata      map[string]interface{} `json:"metadata"`      // type-specific extras (page count, byte size, ...)
}
|
||||
|
||||
// NewFileAnalyzer builds a FileAnalyzer, defaulting the storage
// directory to /tmp/gooseek-files and ensuring it exists.
func NewFileAnalyzer(llmClient llm.Client, storagePath string) *FileAnalyzer {
	if storagePath == "" {
		storagePath = "/tmp/gooseek-files"
	}
	// NOTE(review): the MkdirAll error is ignored here; a later SaveFile
	// will surface the failure when it tries to create a file inside the
	// missing directory.
	os.MkdirAll(storagePath, 0755)

	return &FileAnalyzer{
		llmClient:   llmClient,
		storagePath: storagePath,
	}
}
|
||||
|
||||
func (fa *FileAnalyzer) AnalyzeFile(ctx context.Context, filePath string, fileType string) (*AnalysisResult, error) {
|
||||
switch {
|
||||
case strings.HasPrefix(fileType, "application/pdf"):
|
||||
return fa.analyzePDF(ctx, filePath)
|
||||
case strings.HasPrefix(fileType, "image/"):
|
||||
return fa.analyzeImage(ctx, filePath, fileType)
|
||||
case strings.HasPrefix(fileType, "text/"):
|
||||
return fa.analyzeText(ctx, filePath)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported file type: %s", fileType)
|
||||
}
|
||||
}
|
||||
|
||||
func (fa *FileAnalyzer) analyzePDF(ctx context.Context, filePath string) (*AnalysisResult, error) {
|
||||
text, metadata, err := extractPDFContent(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to extract PDF content: %w", err)
|
||||
}
|
||||
|
||||
if len(text) > 50000 {
|
||||
text = text[:50000] + "\n\n[Content truncated...]"
|
||||
}
|
||||
|
||||
summary, keyPoints, err := fa.generateSummary(ctx, text, "PDF document")
|
||||
if err != nil {
|
||||
summary = ""
|
||||
keyPoints = nil
|
||||
}
|
||||
|
||||
return &AnalysisResult{
|
||||
FileType: "pdf",
|
||||
ExtractedText: text,
|
||||
Summary: summary,
|
||||
KeyPoints: keyPoints,
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func extractPDFContent(filePath string) (string, map[string]interface{}, error) {
|
||||
f, r, err := pdf.Open(filePath)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var textBuilder strings.Builder
|
||||
numPages := r.NumPage()
|
||||
|
||||
for i := 1; i <= numPages; i++ {
|
||||
p := r.Page(i)
|
||||
if p.V.IsNull() {
|
||||
continue
|
||||
}
|
||||
|
||||
text, err := p.GetPlainText(nil)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
textBuilder.WriteString(text)
|
||||
textBuilder.WriteString("\n\n")
|
||||
|
||||
if textBuilder.Len() > 100000 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
metadata := map[string]interface{}{
|
||||
"numPages": numPages,
|
||||
}
|
||||
|
||||
return textBuilder.String(), metadata, nil
|
||||
}
|
||||
|
||||
func (fa *FileAnalyzer) analyzeImage(ctx context.Context, filePath string, mimeType string) (*AnalysisResult, error) {
|
||||
imageData, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read image: %w", err)
|
||||
}
|
||||
|
||||
base64Image := base64.StdEncoding.EncodeToString(imageData)
|
||||
|
||||
description, err := fa.describeImage(ctx, base64Image, mimeType)
|
||||
if err != nil {
|
||||
description = "Image analysis unavailable"
|
||||
}
|
||||
|
||||
metadata := map[string]interface{}{
|
||||
"size": len(imageData),
|
||||
}
|
||||
|
||||
return &AnalysisResult{
|
||||
FileType: "image",
|
||||
ExtractedText: description,
|
||||
Summary: description,
|
||||
KeyPoints: extractKeyPointsFromDescription(description),
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// describeImage sends the base64-encoded image to the vision-capable
// LLM with a fixed analysis prompt and returns the model's free-text
// description (used verbatim downstream as text/summary).
func (fa *FileAnalyzer) describeImage(ctx context.Context, base64Image, mimeType string) (string, error) {
	// Single-shot prompt covering description, OCR, structure, and
	// chart/diagram interpretation.
	prompt := `Analyze this image and provide:
1. A detailed description of what's shown
2. Any text visible in the image (OCR)
3. Key elements and their relationships
4. Any data, charts, or diagrams and their meaning

Be thorough but concise.`

	messages := []llm.Message{
		{
			Role:    "user",
			Content: prompt,
			Images: []llm.ImageContent{
				{
					Type:     mimeType,
					Data:     base64Image,
					IsBase64: true,
				},
			},
		},
	}

	result, err := fa.llmClient.GenerateText(ctx, llm.StreamRequest{
		Messages: messages,
	})
	if err != nil {
		return "", err
	}

	return result, nil
}
|
||||
|
||||
func (fa *FileAnalyzer) analyzeText(ctx context.Context, filePath string) (*AnalysisResult, error) {
|
||||
content, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
|
||||
text := string(content)
|
||||
if len(text) > 50000 {
|
||||
text = text[:50000] + "\n\n[Content truncated...]"
|
||||
}
|
||||
|
||||
summary, keyPoints, err := fa.generateSummary(ctx, text, "text document")
|
||||
if err != nil {
|
||||
summary = ""
|
||||
keyPoints = nil
|
||||
}
|
||||
|
||||
return &AnalysisResult{
|
||||
FileType: "text",
|
||||
ExtractedText: text,
|
||||
Summary: summary,
|
||||
KeyPoints: keyPoints,
|
||||
Metadata: map[string]interface{}{
|
||||
"size": len(content),
|
||||
"lineCount": strings.Count(text, "\n") + 1,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// generateSummary asks the LLM for a summary plus key points of the
// given document text. Very short texts (<100 chars) are returned as
// their own summary without an LLM call; long texts are truncated to
// 15k characters before prompting.
func (fa *FileAnalyzer) generateSummary(ctx context.Context, text, docType string) (string, []string, error) {
	if len(text) < 100 {
		// Too short to be worth a model round-trip.
		return text, nil, nil
	}

	truncatedText := text
	if len(text) > 15000 {
		truncatedText = text[:15000] + "\n\n[Content truncated for analysis...]"
	}

	// The model is asked for a fixed SUMMARY / KEY POINTS layout that
	// parseSummaryResponse knows how to split.
	prompt := fmt.Sprintf(`Analyze this %s and provide:

1. A concise summary (2-3 paragraphs)
2. 5-7 key points as bullet points

Document content:
%s

Format your response as:
SUMMARY:
[your summary here]

KEY POINTS:
- [point 1]
- [point 2]
...`, docType, truncatedText)

	result, err := fa.llmClient.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{
			{Role: llm.RoleUser, Content: prompt},
		},
	})
	if err != nil {
		return "", nil, err
	}

	// Parsing degrades gracefully when the model ignores the layout.
	summary, keyPoints := parseSummaryResponse(result)
	return summary, keyPoints, nil
}
|
||||
|
||||
// parseSummaryResponse splits an LLM reply of the form
// "SUMMARY: ... KEY POINTS: - ..." into the summary text and a list of
// bullet points. A reply without a KEY POINTS section is treated as
// being entirely summary.
func parseSummaryResponse(response string) (string, []string) {
	parts := strings.Split(response, "KEY POINTS:")
	if len(parts) < 2 {
		return response, nil
	}

	summary := strings.TrimSpace(strings.TrimPrefix(parts[0], "SUMMARY:"))

	var keyPoints []string
	for _, rawLine := range strings.Split(parts[1], "\n") {
		line := strings.TrimSpace(rawLine)
		isBullet := strings.HasPrefix(line, "-") || strings.HasPrefix(line, "•") || strings.HasPrefix(line, "*")
		if !isBullet {
			continue
		}
		// Strip one leading bullet marker of each kind, then whitespace.
		point := line
		for _, marker := range []string{"-", "•", "*"} {
			point = strings.TrimPrefix(point, marker)
		}
		if point = strings.TrimSpace(point); point != "" {
			keyPoints = append(keyPoints, point)
		}
	}

	return summary, keyPoints
}
|
||||
|
||||
// extractKeyPointsFromDescription derives bullet points from a prose
// image description: each of the first five '.'-separated sentences
// that is longer than 20 characters becomes a key point (with the
// period restored).
//
// Fix: break out of the loop once the sentence-index cutoff is reached
// instead of uselessly scanning the rest of a potentially long
// description. Output is unchanged.
func extractKeyPointsFromDescription(description string) []string {
	var points []string
	sentences := strings.Split(description, ".")

	for i, s := range sentences {
		if i >= 5 {
			break // only the first five sentences are considered
		}
		s = strings.TrimSpace(s)
		if len(s) > 20 {
			points = append(points, s+".")
		}
	}

	return points
}
|
||||
|
||||
// extMimeTypes maps known file extensions (lowercase) to MIME types.
var extMimeTypes = map[string]string{
	".pdf":  "application/pdf",
	".png":  "image/png",
	".jpg":  "image/jpeg",
	".jpeg": "image/jpeg",
	".gif":  "image/gif",
	".webp": "image/webp",
	".txt":  "text/plain",
	".md":   "text/markdown",
	".csv":  "text/csv",
	".json": "application/json",
}

// DetectMimeType guesses a file's MIME type from its extension, falling
// back to content sniffing (net/http) for unknown extensions.
func DetectMimeType(filename string, content []byte) string {
	if mime, ok := extMimeTypes[strings.ToLower(filepath.Ext(filename))]; ok {
		return mime
	}
	// Sniff at most the first 512 bytes — all DetectContentType reads.
	return http.DetectContentType(content[:min(512, len(content))])
}
|
||||
|
||||
// min returns the smaller of two ints.
//
// NOTE(review): Go 1.21 added a built-in min; this package-level copy
// shadows it and is kept for compatibility with older toolchains.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|
||||
|
||||
// SaveFile stores an uploaded file under the analyzer's storage
// directory and returns the destination path and the number of bytes
// written. filepath.Base strips any directory components from the
// client-supplied name, preventing writes outside storagePath.
func (fa *FileAnalyzer) SaveFile(filename string, content io.Reader) (string, int64, error) {
	safeName := filepath.Base(filename)
	destPath := filepath.Join(fa.storagePath, safeName)

	file, err := os.Create(destPath)
	if err != nil {
		return "", 0, err
	}
	defer file.Close()

	// NOTE(review): the upload is duplicated into an in-memory buffer
	// that is never read afterwards — io.Copy(file, content) alone would
	// avoid holding the whole file in RAM. (Removing it here would leave
	// the file's "bytes" import unused, so flagging rather than fixing.)
	var buf bytes.Buffer
	size, err := io.Copy(io.MultiWriter(file, &buf), content)
	if err != nil {
		return "", 0, err
	}

	return destPath, size, nil
}
|
||||
|
||||
func (fa *FileAnalyzer) DeleteFile(filePath string) error {
|
||||
if !strings.HasPrefix(filePath, fa.storagePath) {
|
||||
return fmt.Errorf("invalid file path")
|
||||
}
|
||||
return os.Remove(filePath)
|
||||
}
|
||||
537
backend/internal/finance/heatmap.go
Normal file
537
backend/internal/finance/heatmap.go
Normal file
@@ -0,0 +1,537 @@
|
||||
package finance
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HeatmapService serves market heatmaps with an in-memory TTL cache.
type HeatmapService struct {
	cache      map[string]*CachedHeatmap // keyed by "market:timeRange"
	mu         sync.RWMutex              // guards cache
	httpClient *http.Client
	config     HeatmapConfig
}

// HeatmapConfig configures the data provider endpoint and cache timing.
type HeatmapConfig struct {
	DataProviderURL string
	CacheTTL        time.Duration // how long a computed heatmap stays fresh
	RefreshInterval time.Duration
}

// CachedHeatmap is a cache entry together with its expiry instant.
type CachedHeatmap struct {
	Data      *MarketHeatmap
	ExpiresAt time.Time
}

// MarketHeatmap is the complete payload rendered by the heatmap widget:
// per-sector aggregates, individual tickers, summary statistics, and
// the color scale used for display.
type MarketHeatmap struct {
	ID         string        `json:"id"`
	Title      string        `json:"title"`
	Type       HeatmapType   `json:"type"`
	Market     string        `json:"market"`
	Sectors    []Sector      `json:"sectors"`
	Tickers    []TickerData  `json:"tickers"`
	Summary    MarketSummary `json:"summary"`
	UpdatedAt  time.Time     `json:"updatedAt"`
	TimeRange  string        `json:"timeRange"`
	Colorscale Colorscale    `json:"colorscale"`
}

// HeatmapType selects the visualization style for a heatmap.
type HeatmapType string

// Supported heatmap visualizations.
const (
	HeatmapTreemap     HeatmapType = "treemap"
	HeatmapGrid        HeatmapType = "grid"
	HeatmapBubble      HeatmapType = "bubble"
	HeatmapSectorChart HeatmapType = "sector_chart"
)

// Sector aggregates the tickers of one market sector.
type Sector struct {
	ID          string       `json:"id"`
	Name        string       `json:"name"`
	Change      float64      `json:"change"`    // average % change across the sector's tickers
	MarketCap   float64      `json:"marketCap"` // summed market cap
	Volume      float64      `json:"volume"`    // summed volume
	TickerCount int          `json:"tickerCount"`
	TopGainers  []TickerData `json:"topGainers,omitempty"`
	TopLosers   []TickerData `json:"topLosers,omitempty"`
	Color       string       `json:"color"`
	Weight      float64      `json:"weight"` // display weight (market cap)
}

// TickerData is a single instrument's quote plus derived display fields.
type TickerData struct {
	Symbol        string  `json:"symbol"`
	Name          string  `json:"name"`
	Price         float64 `json:"price"`
	Change        float64 `json:"change"`
	ChangePercent float64 `json:"changePercent"`
	Volume        float64 `json:"volume"`
	MarketCap     float64 `json:"marketCap"`
	Sector        string  `json:"sector"`
	Industry      string  `json:"industry"`
	Color         string  `json:"color"` // derived from ChangePercent
	Size          float64 `json:"size"`  // display size, log-scaled from market cap
	PrevClose     float64 `json:"prevClose,omitempty"`
	DayHigh       float64 `json:"dayHigh,omitempty"`
	DayLow        float64 `json:"dayLow,omitempty"`
	Week52High    float64 `json:"week52High,omitempty"`
	Week52Low     float64 `json:"week52Low,omitempty"`
	PE            float64 `json:"pe,omitempty"`
	EPS           float64 `json:"eps,omitempty"`
	Dividend      float64 `json:"dividend,omitempty"`
	DividendYield float64 `json:"dividendYield,omitempty"`
}

// MarketSummary holds aggregate breadth and sentiment statistics for a
// heatmap's ticker set.
type MarketSummary struct {
	TotalMarketCap  float64     `json:"totalMarketCap"`
	TotalVolume     float64     `json:"totalVolume"`
	AdvancingCount  int         `json:"advancingCount"`
	DecliningCount  int         `json:"decliningCount"`
	UnchangedCount  int         `json:"unchangedCount"`
	AverageChange   float64     `json:"averageChange"`
	TopGainer       *TickerData `json:"topGainer,omitempty"`
	TopLoser        *TickerData `json:"topLoser,omitempty"`
	MostActive      *TickerData `json:"mostActive,omitempty"`
	MarketSentiment string      `json:"marketSentiment"`
	VIX             float64     `json:"vix,omitempty"`
	FearGreedIndex  int         `json:"fearGreedIndex,omitempty"`
}

// Colorscale maps percentage changes onto display colors.
type Colorscale struct {
	Min        float64   `json:"min"`
	Max        float64   `json:"max"`
	MidPoint   float64   `json:"midPoint"`
	Colors     []string  `json:"colors"`
	Thresholds []float64 `json:"thresholds"` // band boundaries between adjacent colors
}
|
||||
|
||||
// DefaultColorscale is the red→gray→green scale used when a heatmap
// does not specify its own: deep red toward -10%, neutral gray around
// 0, deep green toward +10%. Thresholds mark the boundaries between
// adjacent color bands.
var DefaultColorscale = Colorscale{
	Min:      -10,
	Max:      10,
	MidPoint: 0,
	Colors: []string{
		"#ef4444", // strong loss
		"#f87171",
		"#fca5a5",
		"#fecaca",
		"#e5e7eb", // neutral
		"#bbf7d0",
		"#86efac",
		"#4ade80",
		"#22c55e", // strong gain
	},
	Thresholds: []float64{-5, -3, -2, -1, 1, 2, 3, 5},
}
|
||||
|
||||
func NewHeatmapService(cfg HeatmapConfig) *HeatmapService {
|
||||
if cfg.CacheTTL == 0 {
|
||||
cfg.CacheTTL = 5 * time.Minute
|
||||
}
|
||||
if cfg.RefreshInterval == 0 {
|
||||
cfg.RefreshInterval = time.Minute
|
||||
}
|
||||
|
||||
return &HeatmapService{
|
||||
cache: make(map[string]*CachedHeatmap),
|
||||
httpClient: &http.Client{Timeout: 30 * time.Second},
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// GetMarketHeatmap returns the heatmap for a market/time-range pair,
// serving an unexpired cache entry when available and otherwise
// fetching fresh data and caching it for CacheTTL.
//
// NOTE(review): concurrent callers that miss the cache will each fetch
// and then overwrite the entry (cache stampede). Harmless while
// fetchMarketData is a local mock; worth a singleflight if it becomes a
// remote call.
func (s *HeatmapService) GetMarketHeatmap(ctx context.Context, market string, timeRange string) (*MarketHeatmap, error) {
	cacheKey := fmt.Sprintf("%s:%s", market, timeRange)

	// Fast path: fresh cache entry under a read lock.
	s.mu.RLock()
	if cached, ok := s.cache[cacheKey]; ok && time.Now().Before(cached.ExpiresAt) {
		s.mu.RUnlock()
		return cached.Data, nil
	}
	s.mu.RUnlock()

	heatmap, err := s.fetchMarketData(ctx, market, timeRange)
	if err != nil {
		return nil, err
	}

	// Store the fresh result with its expiry.
	s.mu.Lock()
	s.cache[cacheKey] = &CachedHeatmap{
		Data:      heatmap,
		ExpiresAt: time.Now().Add(s.config.CacheTTL),
	}
	s.mu.Unlock()

	return heatmap, nil
}
|
||||
|
||||
func (s *HeatmapService) GetSectorHeatmap(ctx context.Context, market, sector, timeRange string) (*MarketHeatmap, error) {
|
||||
heatmap, err := s.GetMarketHeatmap(ctx, market, timeRange)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
filteredTickers := make([]TickerData, 0)
|
||||
for _, t := range heatmap.Tickers {
|
||||
if strings.EqualFold(t.Sector, sector) {
|
||||
filteredTickers = append(filteredTickers, t)
|
||||
}
|
||||
}
|
||||
|
||||
sectorHeatmap := &MarketHeatmap{
|
||||
ID: fmt.Sprintf("%s-%s", market, sector),
|
||||
Title: fmt.Sprintf("%s - %s", market, sector),
|
||||
Type: HeatmapTreemap,
|
||||
Market: market,
|
||||
Tickers: filteredTickers,
|
||||
TimeRange: timeRange,
|
||||
UpdatedAt: time.Now(),
|
||||
Colorscale: DefaultColorscale,
|
||||
}
|
||||
|
||||
sectorHeatmap.Summary = s.calculateSummary(filteredTickers)
|
||||
|
||||
return sectorHeatmap, nil
|
||||
}
|
||||
|
||||
// fetchMarketData is the data-acquisition seam for heatmaps. It
// currently returns locally generated mock data; the service's
// httpClient and DataProviderURL are not used here — presumably
// reserved for a real provider integration (TODO confirm).
func (s *HeatmapService) fetchMarketData(ctx context.Context, market, timeRange string) (*MarketHeatmap, error) {
	heatmap := s.generateMockMarketData(market)
	heatmap.TimeRange = timeRange

	return heatmap, nil
}
|
||||
|
||||
// generateMockMarketData builds a synthetic heatmap for the given
// market from a fixed set of well-known US tickers grouped into five
// sectors, with randomized prices/changes/caps. Used as placeholder
// data until a real provider is wired into fetchMarketData.
func (s *HeatmapService) generateMockMarketData(market string) *MarketHeatmap {
	// Static sector/ticker universe for the mock.
	sectors := []struct {
		name    string
		tickers []struct{ symbol, name string }
	}{
		{"Technology", []struct{ symbol, name string }{
			{"AAPL", "Apple Inc."},
			{"MSFT", "Microsoft Corp."},
			{"GOOGL", "Alphabet Inc."},
			{"AMZN", "Amazon.com Inc."},
			{"META", "Meta Platforms"},
			{"NVDA", "NVIDIA Corp."},
			{"TSLA", "Tesla Inc."},
		}},
		{"Healthcare", []struct{ symbol, name string }{
			{"JNJ", "Johnson & Johnson"},
			{"UNH", "UnitedHealth Group"},
			{"PFE", "Pfizer Inc."},
			{"MRK", "Merck & Co."},
			{"ABBV", "AbbVie Inc."},
		}},
		{"Finance", []struct{ symbol, name string }{
			{"JPM", "JPMorgan Chase"},
			{"BAC", "Bank of America"},
			{"WFC", "Wells Fargo"},
			{"GS", "Goldman Sachs"},
			{"MS", "Morgan Stanley"},
		}},
		{"Energy", []struct{ symbol, name string }{
			{"XOM", "Exxon Mobil"},
			{"CVX", "Chevron Corp."},
			{"COP", "ConocoPhillips"},
			{"SLB", "Schlumberger"},
		}},
		{"Consumer", []struct{ symbol, name string }{
			{"WMT", "Walmart Inc."},
			{"PG", "Procter & Gamble"},
			{"KO", "Coca-Cola Co."},
			{"PEP", "PepsiCo Inc."},
			{"COST", "Costco Wholesale"},
		}},
	}

	allTickers := make([]TickerData, 0)
	allSectors := make([]Sector, 0)

	for _, sec := range sectors {
		sectorTickers := make([]TickerData, 0)
		sectorChange := 0.0

		// Randomize each ticker's quote within plausible bounds.
		for _, t := range sec.tickers {
			change := (randomFloat(-5, 5))           // % change
			price := randomFloat(50, 500)            // price in USD
			marketCap := randomFloat(50e9, 3000e9)   // $50B..$3T
			volume := randomFloat(1e6, 100e6)        // shares

			ticker := TickerData{
				Symbol:        t.symbol,
				Name:          t.name,
				Price:         price,
				Change:        price * change / 100,
				ChangePercent: change,
				Volume:        volume,
				MarketCap:     marketCap,
				Sector:        sec.name,
				Color:         getColorForChange(change),
				// Log-scale display size so mega-caps don't dwarf everything.
				Size: math.Log10(marketCap) * 10,
			}

			sectorTickers = append(sectorTickers, ticker)
			sectorChange += change
		}

		// Average % change across the sector.
		if len(sectorTickers) > 0 {
			sectorChange /= float64(len(sectorTickers))
		}

		// Sort by change so the extremes are at each end.
		sort.Slice(sectorTickers, func(i, j int) bool {
			return sectorTickers[i].ChangePercent > sectorTickers[j].ChangePercent
		})

		var topGainers, topLosers []TickerData
		if len(sectorTickers) >= 2 {
			topGainers = sectorTickers[:2]
			topLosers = sectorTickers[len(sectorTickers)-2:]
		}

		// Sector aggregates.
		sectorMarketCap := 0.0
		sectorVolume := 0.0
		for _, t := range sectorTickers {
			sectorMarketCap += t.MarketCap
			sectorVolume += t.Volume
		}

		sector := Sector{
			ID:          strings.ToLower(strings.ReplaceAll(sec.name, " ", "_")),
			Name:        sec.name,
			Change:      sectorChange,
			MarketCap:   sectorMarketCap,
			Volume:      sectorVolume,
			TickerCount: len(sectorTickers),
			TopGainers:  topGainers,
			TopLosers:   topLosers,
			Color:       getColorForChange(sectorChange),
			Weight:      sectorMarketCap,
		}

		allSectors = append(allSectors, sector)
		allTickers = append(allTickers, sectorTickers...)
	}

	// Largest companies first for the treemap layout.
	sort.Slice(allTickers, func(i, j int) bool {
		return allTickers[i].MarketCap > allTickers[j].MarketCap
	})

	return &MarketHeatmap{
		ID:         market,
		Title:      getMarketTitle(market),
		Type:       HeatmapTreemap,
		Market:     market,
		Sectors:    allSectors,
		Tickers:    allTickers,
		Summary:    *s.calculateSummaryPtr(allTickers),
		UpdatedAt:  time.Now(),
		Colorscale: DefaultColorscale,
	}
}
|
||||
|
||||
func (s *HeatmapService) calculateSummary(tickers []TickerData) MarketSummary {
|
||||
return *s.calculateSummaryPtr(tickers)
|
||||
}
|
||||
|
||||
// calculateSummaryPtr aggregates totals, breadth counts, extremes, and a
// coarse sentiment label over tickers.
//
// NOTE: TopGainer, TopLoser, and MostActive point INTO the caller's slice
// (taken via &tickers[i]); they stay valid only while that slice is alive
// and unmodified.
func (s *HeatmapService) calculateSummaryPtr(tickers []TickerData) *MarketSummary {
	summary := &MarketSummary{}

	var totalChange float64
	var topGainer, topLoser, mostActive *TickerData

	// Single pass: accumulate totals and track extremes by pointer to avoid
	// copying TickerData values.
	for i := range tickers {
		t := &tickers[i]
		summary.TotalMarketCap += t.MarketCap
		summary.TotalVolume += t.Volume
		totalChange += t.ChangePercent

		// Breadth: exactly-zero change counts as unchanged.
		if t.ChangePercent > 0 {
			summary.AdvancingCount++
		} else if t.ChangePercent < 0 {
			summary.DecliningCount++
		} else {
			summary.UnchangedCount++
		}

		if topGainer == nil || t.ChangePercent > topGainer.ChangePercent {
			topGainer = t
		}
		if topLoser == nil || t.ChangePercent < topLoser.ChangePercent {
			topLoser = t
		}
		if mostActive == nil || t.Volume > mostActive.Volume {
			mostActive = t
		}
	}

	// Guard against division by zero for an empty input; AverageChange stays 0.
	if len(tickers) > 0 {
		summary.AverageChange = totalChange / float64(len(tickers))
	}

	summary.TopGainer = topGainer
	summary.TopLoser = topLoser
	summary.MostActive = mostActive

	// Sentiment thresholds: mean change beyond +/-1% flips the label.
	if summary.AverageChange > 1 {
		summary.MarketSentiment = "bullish"
	} else if summary.AverageChange < -1 {
		summary.MarketSentiment = "bearish"
	} else {
		summary.MarketSentiment = "neutral"
	}

	return summary
}
|
||||
|
||||
func (s *HeatmapService) GenerateTreemapData(heatmap *MarketHeatmap) interface{} {
|
||||
children := make([]map[string]interface{}, 0)
|
||||
|
||||
for _, sector := range heatmap.Sectors {
|
||||
sectorChildren := make([]map[string]interface{}, 0)
|
||||
|
||||
for _, ticker := range heatmap.Tickers {
|
||||
if ticker.Sector == sector.Name {
|
||||
sectorChildren = append(sectorChildren, map[string]interface{}{
|
||||
"name": ticker.Symbol,
|
||||
"value": ticker.MarketCap,
|
||||
"change": ticker.ChangePercent,
|
||||
"color": ticker.Color,
|
||||
"data": ticker,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
children = append(children, map[string]interface{}{
|
||||
"name": sector.Name,
|
||||
"children": sectorChildren,
|
||||
"change": sector.Change,
|
||||
"color": sector.Color,
|
||||
})
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"name": heatmap.Market,
|
||||
"children": children,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HeatmapService) GenerateGridData(heatmap *MarketHeatmap, rows, cols int) [][]TickerData {
|
||||
grid := make([][]TickerData, rows)
|
||||
for i := range grid {
|
||||
grid[i] = make([]TickerData, cols)
|
||||
}
|
||||
|
||||
idx := 0
|
||||
for i := 0; i < rows && idx < len(heatmap.Tickers); i++ {
|
||||
for j := 0; j < cols && idx < len(heatmap.Tickers); j++ {
|
||||
grid[i][j] = heatmap.Tickers[idx]
|
||||
idx++
|
||||
}
|
||||
}
|
||||
|
||||
return grid
|
||||
}
|
||||
|
||||
func (s *HeatmapService) GetTopMovers(ctx context.Context, market string, count int) (*TopMovers, error) {
|
||||
heatmap, err := s.GetMarketHeatmap(ctx, market, "1d")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tickers := make([]TickerData, len(heatmap.Tickers))
|
||||
copy(tickers, heatmap.Tickers)
|
||||
|
||||
sort.Slice(tickers, func(i, j int) bool {
|
||||
return tickers[i].ChangePercent > tickers[j].ChangePercent
|
||||
})
|
||||
|
||||
gainers := tickers
|
||||
if len(gainers) > count {
|
||||
gainers = gainers[:count]
|
||||
}
|
||||
|
||||
sort.Slice(tickers, func(i, j int) bool {
|
||||
return tickers[i].ChangePercent < tickers[j].ChangePercent
|
||||
})
|
||||
|
||||
losers := tickers
|
||||
if len(losers) > count {
|
||||
losers = losers[:count]
|
||||
}
|
||||
|
||||
sort.Slice(tickers, func(i, j int) bool {
|
||||
return tickers[i].Volume > tickers[j].Volume
|
||||
})
|
||||
|
||||
active := tickers
|
||||
if len(active) > count {
|
||||
active = active[:count]
|
||||
}
|
||||
|
||||
return &TopMovers{
|
||||
Gainers: gainers,
|
||||
Losers: losers,
|
||||
MostActive: active,
|
||||
UpdatedAt: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TopMovers is the response payload of GetTopMovers: the extremes of a
// market snapshot by percentage change and by traded volume.
type TopMovers struct {
	Gainers    []TickerData `json:"gainers"`    // highest ChangePercent first
	Losers     []TickerData `json:"losers"`     // lowest ChangePercent first
	MostActive []TickerData `json:"mostActive"` // highest Volume first
	UpdatedAt  time.Time    `json:"updatedAt"`  // when the snapshot was taken
}
|
||||
|
||||
// getColorForChange maps a percentage change to a hex color: greens for
// non-negative moves, reds for losses, with more saturated shades for
// larger absolute moves. Thresholds are at +/-1, +/-3, and +/-5 percent.
func getColorForChange(change float64) string {
	switch {
	case change >= 5:
		return "#22c55e"
	case change >= 3:
		return "#4ade80"
	case change >= 1:
		return "#86efac"
	case change >= 0:
		return "#bbf7d0"
	case change >= -1:
		return "#fecaca"
	case change >= -3:
		return "#fca5a5"
	case change >= -5:
		return "#f87171"
	default:
		return "#ef4444"
	}
}
|
||||
|
||||
// getMarketTitle returns a human-readable display title for a market
// identifier (case-insensitive). Unknown identifiers are returned verbatim.
func getMarketTitle(market string) string {
	switch strings.ToLower(market) {
	case "sp500":
		return "S&P 500"
	case "nasdaq":
		return "NASDAQ"
	case "dow":
		return "Dow Jones"
	case "moex":
		return "MOEX"
	case "crypto":
		return "Cryptocurrency"
	case "forex":
		return "Forex"
	case "commodities":
		return "Commodities"
	default:
		return market
	}
}
|
||||
|
||||
// rng is the state of a xorshift64 pseudo-random generator, seeded from the
// wall clock at package init.
// NOTE(review): this package-level state is mutated by randomFloat without
// any synchronization — concurrent callers would race on it. Confirm whether
// heatmap generation runs from multiple goroutines; if so, guard with
// sync/atomic or switch to math/rand.
var rng uint64 = uint64(time.Now().UnixNano())

// randomFloat returns a pseudo-random float64 in [min, max) by advancing the
// xorshift64 state. Not cryptographically secure and not goroutine-safe.
func randomFloat(min, max float64) float64 {
	rng ^= rng << 13
	rng ^= rng >> 17
	rng ^= rng << 5
	// Scale the 64-bit state into [0, 1) by dividing by 2^64.
	f := float64(rng) / float64(1<<64)
	return min + f*(max-min)
}
|
||||
|
||||
// ToJSON serializes the heatmap with encoding/json.
func (h *MarketHeatmap) ToJSON() ([]byte, error) {
	return json.Marshal(h)
}
|
||||
759
backend/internal/labs/generator.go
Normal file
759
backend/internal/labs/generator.go
Normal file
@@ -0,0 +1,759 @@
|
||||
package labs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gooseek/backend/internal/llm"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Generator builds Labs reports and individual visualizations, using an LLM
// client to plan which visualization types fit a given query and dataset.
type Generator struct {
	llm llm.Client // used by GenerateReport to plan the report layout
}
|
||||
|
||||
// NewGenerator returns a Generator backed by the given LLM client.
func NewGenerator(llmClient llm.Client) *Generator {
	return &Generator{llm: llmClient}
}
|
||||
|
||||
// GenerateOptions configures a GenerateReport call.
type GenerateOptions struct {
	Query string      // user query the report should answer
	Data  interface{} // raw data to visualize (map, slice, or scalar)
	// PreferredTypes, Locale, and MaxVisualizations are accepted but not
	// consulted by GenerateReport in this revision.
	// NOTE(review): confirm whether they are wired up elsewhere or pending.
	PreferredTypes    []VisualizationType
	Theme             string // copied through to Report.Theme
	Locale            string
	MaxVisualizations int
}
|
||||
|
||||
// GenerateReport asks the LLM to analyze opts.Query and opts.Data, plan a
// set of sections/visualizations as JSON, and then materializes that plan
// into a Report. If the LLM response cannot be parsed as JSON, it falls back
// to createDefaultReport (a single markdown dump of the data).
func (g *Generator) GenerateReport(ctx context.Context, opts GenerateOptions) (*Report, error) {
	// Single-shot planning prompt; %v renders opts.Data as a Go-syntax dump,
	// which the model is expected to read loosely.
	analysisPrompt := fmt.Sprintf(`Analyze this data and query to determine the best visualizations.

Query: %s

Data: %v

Determine:
1. What visualizations would best represent this data?
2. How should the data be structured for each visualization?
3. What insights can be highlighted?

Respond in JSON format:
{
"title": "Report title",
"sections": [
{
"title": "Section title",
"visualizations": [
{
"type": "chart_type",
"title": "Viz title",
"dataMapping": { "how to map the data" },
"insight": "Key insight"
}
]
}
]
}

Available visualization types: bar_chart, line_chart, pie_chart, donut_chart, table, stat_cards, kpi, comparison, timeline, progress, heatmap, code_block, markdown, collapsible, tabs, accordion`, opts.Query, opts.Data)

	result, err := g.llm.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: "user", Content: analysisPrompt}},
	})
	if err != nil {
		return nil, err
	}

	// Anonymous struct mirroring the JSON schema requested in the prompt.
	var analysis struct {
		Title    string `json:"title"`
		Sections []struct {
			Title          string `json:"title"`
			Visualizations []struct {
				Type        string                 `json:"type"`
				Title       string                 `json:"title"`
				DataMapping map[string]interface{} `json:"dataMapping"`
				Insight     string                 `json:"insight"`
			} `json:"visualizations"`
		} `json:"sections"`
	}

	// extractJSON strips any prose/fences around the model's JSON; a parse
	// failure degrades to the default report rather than erroring out.
	jsonStr := extractJSON(result)
	if err := json.Unmarshal([]byte(jsonStr), &analysis); err != nil {
		return g.createDefaultReport(opts)
	}

	report := &Report{
		ID:        uuid.New().String(),
		Title:     analysis.Title,
		CreatedAt: time.Now(),
		UpdatedAt: time.Now(),
		Theme:     opts.Theme,
		Sections:  make([]ReportSection, 0),
	}

	// Materialize each planned visualization; builders that cannot handle
	// the data return nil and are dropped, and empty sections are dropped too.
	for _, sec := range analysis.Sections {
		section := ReportSection{
			ID:             uuid.New().String(),
			Title:          sec.Title,
			Visualizations: make([]Visualization, 0),
		}

		for _, viz := range sec.Visualizations {
			visualization := g.createVisualization(VisualizationType(viz.Type), viz.Title, opts.Data, viz.DataMapping)
			if visualization != nil {
				section.Visualizations = append(section.Visualizations, *visualization)
			}
		}

		if len(section.Visualizations) > 0 {
			report.Sections = append(report.Sections, section)
		}
	}

	return report, nil
}
|
||||
|
||||
// createDefaultReport builds the fallback report used when LLM planning
// fails: one section containing a markdown (JSON) dump of the raw data.
// Titles are Russian UI strings ("Анализ данных" = "Data analysis",
// "Обзор" = "Overview") and are part of the rendered output.
func (g *Generator) createDefaultReport(opts GenerateOptions) (*Report, error) {
	report := &Report{
		ID:        uuid.New().String(),
		Title:     "Анализ данных",
		CreatedAt: time.Now(),
		UpdatedAt: time.Now(),
		Sections: []ReportSection{
			{
				ID:    uuid.New().String(),
				Title: "Обзор",
				Visualizations: []Visualization{
					g.CreateMarkdown("", formatDataAsMarkdown(opts.Data)),
				},
			},
		},
	}
	return report, nil
}
|
||||
|
||||
// createVisualization dispatches to the per-type builder for vizType.
// Unknown types (and explicit markdown without content) degrade to a
// markdown rendering of the raw data, so a planned section is never lost
// silently.
func (g *Generator) createVisualization(vizType VisualizationType, title string, data interface{}, mapping map[string]interface{}) *Visualization {
	switch vizType {
	case VizBarChart, VizLineChart, VizAreaChart:
		return g.createChartVisualization(vizType, title, data, mapping)
	case VizPieChart, VizDonutChart:
		return g.createPieVisualization(vizType, title, data, mapping)
	case VizTable:
		return g.createTableVisualization(title, data, mapping)
	case VizStatCards:
		return g.createStatCardsVisualization(title, data, mapping)
	case VizKPI:
		return g.createKPIVisualization(title, data, mapping)
	case VizTimeline:
		return g.createTimelineVisualization(title, data, mapping)
	case VizComparison:
		return g.createComparisonVisualization(title, data, mapping)
	case VizProgress:
		return g.createProgressVisualization(title, data, mapping)
	case VizMarkdown:
		content := extractStringFromData(data, mapping, "content")
		viz := g.CreateMarkdown(title, content)
		return &viz
	default:
		// Fallback: render whatever the data is as a fenced JSON block.
		viz := g.CreateMarkdown(title, formatDataAsMarkdown(data))
		return &viz
	}
}
|
||||
|
||||
func (g *Generator) createChartVisualization(vizType VisualizationType, title string, data interface{}, mapping map[string]interface{}) *Visualization {
|
||||
chartData := &ChartData{
|
||||
Labels: make([]string, 0),
|
||||
Datasets: make([]ChartDataset, 0),
|
||||
}
|
||||
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
labels := make([]string, 0)
|
||||
values := make([]float64, 0)
|
||||
|
||||
for k, v := range dataMap {
|
||||
labels = append(labels, k)
|
||||
values = append(values, toFloat64(v))
|
||||
}
|
||||
|
||||
chartData.Labels = labels
|
||||
chartData.Datasets = append(chartData.Datasets, ChartDataset{
|
||||
Label: title,
|
||||
Data: values,
|
||||
})
|
||||
}
|
||||
|
||||
if dataSlice, ok := data.([]interface{}); ok {
|
||||
for _, item := range dataSlice {
|
||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||
if label, ok := itemMap["label"].(string); ok {
|
||||
chartData.Labels = append(chartData.Labels, label)
|
||||
}
|
||||
if value, ok := itemMap["value"]; ok {
|
||||
if len(chartData.Datasets) == 0 {
|
||||
chartData.Datasets = append(chartData.Datasets, ChartDataset{Label: title, Data: []float64{}})
|
||||
}
|
||||
chartData.Datasets[0].Data = append(chartData.Datasets[0].Data, toFloat64(value))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &Visualization{
|
||||
ID: uuid.New().String(),
|
||||
Type: vizType,
|
||||
Title: title,
|
||||
Data: chartData,
|
||||
Config: VisualizationConfig{
|
||||
ShowLegend: true,
|
||||
ShowTooltip: true,
|
||||
ShowGrid: true,
|
||||
Animated: true,
|
||||
},
|
||||
Responsive: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Generator) createPieVisualization(vizType VisualizationType, title string, data interface{}, mapping map[string]interface{}) *Visualization {
|
||||
chartData := &ChartData{
|
||||
Labels: make([]string, 0),
|
||||
Datasets: make([]ChartDataset, 0),
|
||||
}
|
||||
|
||||
dataset := ChartDataset{Label: title, Data: []float64{}}
|
||||
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
for k, v := range dataMap {
|
||||
chartData.Labels = append(chartData.Labels, k)
|
||||
dataset.Data = append(dataset.Data, toFloat64(v))
|
||||
}
|
||||
}
|
||||
|
||||
chartData.Datasets = append(chartData.Datasets, dataset)
|
||||
|
||||
return &Visualization{
|
||||
ID: uuid.New().String(),
|
||||
Type: vizType,
|
||||
Title: title,
|
||||
Data: chartData,
|
||||
Config: VisualizationConfig{
|
||||
ShowLegend: true,
|
||||
ShowTooltip: true,
|
||||
ShowValues: true,
|
||||
Animated: true,
|
||||
},
|
||||
Style: VisualizationStyle{
|
||||
Height: "300px",
|
||||
},
|
||||
Responsive: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Generator) createTableVisualization(title string, data interface{}, mapping map[string]interface{}) *Visualization {
|
||||
tableData := &TableData{
|
||||
Columns: make([]TableColumn, 0),
|
||||
Rows: make([]TableRow, 0),
|
||||
}
|
||||
|
||||
if dataSlice, ok := data.([]interface{}); ok && len(dataSlice) > 0 {
|
||||
if firstRow, ok := dataSlice[0].(map[string]interface{}); ok {
|
||||
for key := range firstRow {
|
||||
tableData.Columns = append(tableData.Columns, TableColumn{
|
||||
Key: key,
|
||||
Label: formatColumnLabel(key),
|
||||
Sortable: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, item := range dataSlice {
|
||||
if rowMap, ok := item.(map[string]interface{}); ok {
|
||||
tableData.Rows = append(tableData.Rows, TableRow(rowMap))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
tableData.Columns = []TableColumn{
|
||||
{Key: "key", Label: "Параметр", Sortable: true},
|
||||
{Key: "value", Label: "Значение", Sortable: true},
|
||||
}
|
||||
|
||||
for k, v := range dataMap {
|
||||
tableData.Rows = append(tableData.Rows, TableRow{
|
||||
"key": k,
|
||||
"value": v,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
tableData.Summary = &TableSummary{
|
||||
TotalRows: len(tableData.Rows),
|
||||
}
|
||||
|
||||
return &Visualization{
|
||||
ID: uuid.New().String(),
|
||||
Type: VizTable,
|
||||
Title: title,
|
||||
Data: tableData,
|
||||
Config: VisualizationConfig{
|
||||
Sortable: true,
|
||||
Searchable: true,
|
||||
Paginated: len(tableData.Rows) > 10,
|
||||
PageSize: 10,
|
||||
},
|
||||
Responsive: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Generator) createStatCardsVisualization(title string, data interface{}, mapping map[string]interface{}) *Visualization {
|
||||
cardsData := &StatCardsData{
|
||||
Cards: make([]StatCard, 0),
|
||||
}
|
||||
|
||||
colors := []string{"#3B82F6", "#10B981", "#F59E0B", "#EF4444", "#8B5CF6", "#EC4899"}
|
||||
colorIdx := 0
|
||||
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
for k, v := range dataMap {
|
||||
card := StatCard{
|
||||
ID: uuid.New().String(),
|
||||
Title: formatColumnLabel(k),
|
||||
Value: v,
|
||||
Color: colors[colorIdx%len(colors)],
|
||||
}
|
||||
cardsData.Cards = append(cardsData.Cards, card)
|
||||
colorIdx++
|
||||
}
|
||||
}
|
||||
|
||||
return &Visualization{
|
||||
ID: uuid.New().String(),
|
||||
Type: VizStatCards,
|
||||
Title: title,
|
||||
Data: cardsData,
|
||||
Config: VisualizationConfig{
|
||||
Animated: true,
|
||||
},
|
||||
Responsive: true,
|
||||
}
|
||||
}
|
||||
|
||||
// createKPIVisualization builds a single-number KPI widget. By default the
// whole data value is the KPI; a map input may override via the "value",
// "change", "target", and "unit" keys. The mapping argument is currently
// unused.
func (g *Generator) createKPIVisualization(title string, data interface{}, mapping map[string]interface{}) *Visualization {
	kpiData := &KPIData{
		Value: data,
	}

	if dataMap, ok := data.(map[string]interface{}); ok {
		if v, ok := dataMap["value"]; ok {
			kpiData.Value = v
		}
		if v, ok := dataMap["change"].(float64); ok {
			kpiData.Change = v
			// Zero change is classified as "increase" here.
			if v >= 0 {
				kpiData.ChangeType = "increase"
			} else {
				kpiData.ChangeType = "decrease"
			}
		}
		if v, ok := dataMap["target"]; ok {
			kpiData.Target = v
		}
		if v, ok := dataMap["unit"].(string); ok {
			kpiData.Unit = v
		}
	}

	return &Visualization{
		ID:    uuid.New().String(),
		Type:  VizKPI,
		Title: title,
		Data:  kpiData,
		Config: VisualizationConfig{
			Animated:   true,
			ShowValues: true,
		},
		Style: VisualizationStyle{
			MinHeight: "150px",
		},
		Responsive: true,
	}
}
|
||||
|
||||
// createTimelineVisualization builds a chronological timeline from a slice
// of event maps with "date" (RFC3339), "title", and "description" keys.
// Events are sorted ascending by date. The mapping argument is currently
// unused.
func (g *Generator) createTimelineVisualization(title string, data interface{}, mapping map[string]interface{}) *Visualization {
	timelineData := &TimelineData{
		Events: make([]TimelineEvent, 0),
	}

	if dataSlice, ok := data.([]interface{}); ok {
		for _, item := range dataSlice {
			if itemMap, ok := item.(map[string]interface{}); ok {
				event := TimelineEvent{
					ID: uuid.New().String(),
				}

				if v, ok := itemMap["date"].(string); ok {
					// NOTE(review): parse errors are silently discarded, leaving
					// the zero time.Time — such events sort to the front.
					// Confirm whether malformed dates should drop the event
					// instead.
					event.Date, _ = time.Parse(time.RFC3339, v)
				}
				if v, ok := itemMap["title"].(string); ok {
					event.Title = v
				}
				if v, ok := itemMap["description"].(string); ok {
					event.Description = v
				}

				timelineData.Events = append(timelineData.Events, event)
			}
		}
	}

	sort.Slice(timelineData.Events, func(i, j int) bool {
		return timelineData.Events[i].Date.Before(timelineData.Events[j].Date)
	})

	return &Visualization{
		ID:    uuid.New().String(),
		Type:  VizTimeline,
		Title: title,
		Data:  timelineData,
		Config: VisualizationConfig{
			Animated: true,
		},
		Responsive: true,
	}
}
|
||||
|
||||
func (g *Generator) createComparisonVisualization(title string, data interface{}, mapping map[string]interface{}) *Visualization {
|
||||
compData := &ComparisonData{
|
||||
Items: make([]ComparisonItem, 0),
|
||||
Categories: make([]string, 0),
|
||||
}
|
||||
|
||||
if dataSlice, ok := data.([]interface{}); ok && len(dataSlice) > 0 {
|
||||
if firstItem, ok := dataSlice[0].(map[string]interface{}); ok {
|
||||
for k := range firstItem {
|
||||
if k != "name" && k != "id" && k != "image" {
|
||||
compData.Categories = append(compData.Categories, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, item := range dataSlice {
|
||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||
compItem := ComparisonItem{
|
||||
ID: uuid.New().String(),
|
||||
Values: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
if v, ok := itemMap["name"].(string); ok {
|
||||
compItem.Name = v
|
||||
}
|
||||
if v, ok := itemMap["image"].(string); ok {
|
||||
compItem.Image = v
|
||||
}
|
||||
|
||||
for _, cat := range compData.Categories {
|
||||
if v, ok := itemMap[cat]; ok {
|
||||
compItem.Values[cat] = v
|
||||
}
|
||||
}
|
||||
|
||||
compData.Items = append(compData.Items, compItem)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &Visualization{
|
||||
ID: uuid.New().String(),
|
||||
Type: VizComparison,
|
||||
Title: title,
|
||||
Data: compData,
|
||||
Config: VisualizationConfig{
|
||||
ShowLabels: true,
|
||||
},
|
||||
Responsive: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Generator) createProgressVisualization(title string, data interface{}, mapping map[string]interface{}) *Visualization {
|
||||
progressData := &ProgressData{
|
||||
Current: 0,
|
||||
Total: 100,
|
||||
ShowValue: true,
|
||||
Animated: true,
|
||||
}
|
||||
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
if v, ok := dataMap["current"]; ok {
|
||||
progressData.Current = toFloat64(v)
|
||||
}
|
||||
if v, ok := dataMap["total"]; ok {
|
||||
progressData.Total = toFloat64(v)
|
||||
}
|
||||
if v, ok := dataMap["label"].(string); ok {
|
||||
progressData.Label = v
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := data.(float64); ok {
|
||||
progressData.Current = v
|
||||
}
|
||||
|
||||
return &Visualization{
|
||||
ID: uuid.New().String(),
|
||||
Type: VizProgress,
|
||||
Title: title,
|
||||
Data: progressData,
|
||||
Config: VisualizationConfig{
|
||||
Animated: true,
|
||||
ShowValues: true,
|
||||
},
|
||||
Responsive: true,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateBarChart builds a single-dataset bar chart with the given labels
// and values; the dataset is labeled with the chart title.
func (g *Generator) CreateBarChart(title string, labels []string, values []float64) Visualization {
	return Visualization{
		ID:    uuid.New().String(),
		Type:  VizBarChart,
		Title: title,
		Data: &ChartData{
			Labels: labels,
			Datasets: []ChartDataset{
				{Label: title, Data: values},
			},
		},
		Config: VisualizationConfig{
			ShowLegend:  true,
			ShowTooltip: true,
			Animated:    true,
		},
		Responsive: true,
	}
}
|
||||
|
||||
// CreateLineChart builds a multi-dataset line chart over shared x-axis labels.
func (g *Generator) CreateLineChart(title string, labels []string, datasets []ChartDataset) Visualization {
	return Visualization{
		ID:    uuid.New().String(),
		Type:  VizLineChart,
		Title: title,
		Data: &ChartData{
			Labels:   labels,
			Datasets: datasets,
		},
		Config: VisualizationConfig{
			ShowLegend:  true,
			ShowTooltip: true,
			ShowGrid:    true,
			Animated:    true,
		},
		Responsive: true,
	}
}
|
||||
|
||||
// CreatePieChart builds a single-dataset pie chart with one slice per label.
func (g *Generator) CreatePieChart(title string, labels []string, values []float64) Visualization {
	return Visualization{
		ID:    uuid.New().String(),
		Type:  VizPieChart,
		Title: title,
		Data: &ChartData{
			Labels: labels,
			Datasets: []ChartDataset{
				{Label: title, Data: values},
			},
		},
		Config: VisualizationConfig{
			ShowLegend:  true,
			ShowTooltip: true,
			ShowValues:  true,
		},
		Responsive: true,
	}
}
|
||||
|
||||
// CreateTable builds a sortable, searchable table; pagination kicks in
// above 10 rows.
func (g *Generator) CreateTable(title string, columns []TableColumn, rows []TableRow) Visualization {
	return Visualization{
		ID:    uuid.New().String(),
		Type:  VizTable,
		Title: title,
		Data: &TableData{
			Columns: columns,
			Rows:    rows,
			Summary: &TableSummary{TotalRows: len(rows)},
		},
		Config: VisualizationConfig{
			Sortable:   true,
			Searchable: true,
			Paginated:  len(rows) > 10,
			PageSize:   10,
		},
		Responsive: true,
	}
}
|
||||
|
||||
// CreateStatCards builds an animated grid of the given stat cards.
func (g *Generator) CreateStatCards(title string, cards []StatCard) Visualization {
	return Visualization{
		ID:    uuid.New().String(),
		Type:  VizStatCards,
		Title: title,
		Data:  &StatCardsData{Cards: cards},
		Config: VisualizationConfig{
			Animated: true,
		},
		Responsive: true,
	}
}
|
||||
|
||||
// CreateKPI builds a KPI widget; the change direction is classified as
// increase/decrease/neutral from the sign of change (0 is neutral here,
// unlike createKPIVisualization where 0 counts as increase).
func (g *Generator) CreateKPI(title string, value interface{}, change float64, unit string) Visualization {
	changeType := "neutral"
	if change > 0 {
		changeType = "increase"
	} else if change < 0 {
		changeType = "decrease"
	}

	return Visualization{
		ID:    uuid.New().String(),
		Type:  VizKPI,
		Title: title,
		Data: &KPIData{
			Value:      value,
			Change:     change,
			ChangeType: changeType,
			Unit:       unit,
		},
		Config: VisualizationConfig{
			Animated: true,
		},
		Responsive: true,
	}
}
|
||||
|
||||
// CreateMarkdown builds a markdown block with the given raw content.
func (g *Generator) CreateMarkdown(title string, content string) Visualization {
	return Visualization{
		ID:         uuid.New().String(),
		Type:       VizMarkdown,
		Title:      title,
		Data:       &MarkdownData{Content: content},
		Responsive: true,
	}
}
|
||||
|
||||
// CreateCodeBlock builds a syntax-highlighted code block with line numbers
// and a copy button enabled.
func (g *Generator) CreateCodeBlock(title, code, language string) Visualization {
	return Visualization{
		ID:    uuid.New().String(),
		Type:  VizCodeBlock,
		Title: title,
		Data: &CodeBlockData{
			Code:        code,
			Language:    language,
			ShowLineNum: true,
			Copyable:    true,
		},
		Responsive: true,
	}
}
|
||||
|
||||
// CreateTabs builds a tabbed container from the given tab items.
func (g *Generator) CreateTabs(title string, tabs []TabItem) Visualization {
	return Visualization{
		ID:         uuid.New().String(),
		Type:       VizTabs,
		Title:      title,
		Data:       &TabsData{Tabs: tabs},
		Responsive: true,
	}
}
|
||||
|
||||
// CreateAccordion builds an animated accordion from the given items.
func (g *Generator) CreateAccordion(title string, items []AccordionItem) Visualization {
	return Visualization{
		ID:    uuid.New().String(),
		Type:  VizAccordion,
		Title: title,
		Data:  &AccordionData{Items: items},
		Config: VisualizationConfig{
			Animated: true,
		},
		Responsive: true,
	}
}
|
||||
|
||||
// CreateHeatmap builds a 2D heatmap; values is indexed [y][x] against the
// given axis labels.
func (g *Generator) CreateHeatmap(title string, xLabels, yLabels []string, values [][]float64) Visualization {
	return Visualization{
		ID:    uuid.New().String(),
		Type:  VizHeatmap,
		Title: title,
		Data: &HeatmapData{
			XLabels: xLabels,
			YLabels: yLabels,
			Values:  values,
		},
		Config: VisualizationConfig{
			ShowTooltip: true,
			ShowLabels:  true,
		},
		Responsive: true,
	}
}
|
||||
|
||||
// jsonObjectRe matches the first-'{'-to-last-'}' span of a string, with '.'
// matching newlines. Compiled once at package init instead of on every call
// (regexp.Compile on a hot path is wasted work).
var jsonObjectRe = regexp.MustCompile(`(?s)\{.*\}`)

// extractJSON returns the widest {...} span in text, or "{}" when no object
// is present. The greedy match is deliberate: it strips any prose or code
// fences an LLM wraps around its JSON answer while keeping nested objects
// intact.
func extractJSON(text string) string {
	if match := jsonObjectRe.FindString(text); match != "" {
		return match
	}
	return "{}"
}
|
||||
|
||||
func toFloat64(v interface{}) float64 {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
return val
|
||||
case float32:
|
||||
return float64(val)
|
||||
case int:
|
||||
return float64(val)
|
||||
case int64:
|
||||
return float64(val)
|
||||
case string:
|
||||
f, _ := strconv.ParseFloat(val, 64)
|
||||
return f
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// formatColumnLabel turns a snake_case / kebab-case key into a
// space-separated Title Case label (e.g. "user_name" -> "User Name").
func formatColumnLabel(key string) string {
	key = strings.ReplaceAll(key, "_", " ")
	key = strings.ReplaceAll(key, "-", " ")

	words := strings.Fields(key)
	for i, word := range words {
		// Capitalize by rune, not byte: the previous word[0]/word[1:] byte
		// slicing corrupted multi-byte UTF-8 first characters (this service
		// uses Cyrillic keys elsewhere, e.g. "Параметр").
		runes := []rune(word)
		if len(runes) == 0 {
			continue // unreachable: strings.Fields never yields empty words
		}
		words[i] = strings.ToUpper(string(runes[0])) + strings.ToLower(string(runes[1:]))
	}

	return strings.Join(words, " ")
}
|
||||
|
||||
// extractStringFromData returns the string stored under key when data is a
// map holding one there; any other shape (or non-string value) falls back
// to fmt-formatting the entire data value. The mapping argument is accepted
// for signature parity with the visualization builders but is not consulted.
func extractStringFromData(data interface{}, mapping map[string]interface{}, key string) string {
	dataMap, isMap := data.(map[string]interface{})
	if isMap {
		if s, isString := dataMap[key].(string); isString {
			return s
		}
	}
	return fmt.Sprintf("%v", data)
}
|
||||
|
||||
// formatDataAsMarkdown renders data as a fenced ```json block using
// indented JSON; values that cannot be marshaled (channels, funcs, NaN)
// fall back to plain fmt formatting.
func formatDataAsMarkdown(data interface{}) string {
	pretty, err := json.MarshalIndent(data, "", " ")
	if err != nil {
		return fmt.Sprintf("%v", data)
	}

	var b strings.Builder
	b.WriteString("```json\n")
	b.Write(pretty)
	b.WriteString("\n```")
	return b.String()
}
|
||||
335
backend/internal/labs/types.go
Normal file
335
backend/internal/labs/types.go
Normal file
@@ -0,0 +1,335 @@
|
||||
package labs
|
||||
|
||||
import "time"
|
||||
|
||||
// VisualizationType identifies which frontend renderer a Visualization
// targets. String values are the wire-format names the web client matches on.
type VisualizationType string

const (
	// Charts.
	VizBarChart    VisualizationType = "bar_chart"
	VizLineChart   VisualizationType = "line_chart"
	VizPieChart    VisualizationType = "pie_chart"
	VizDonutChart  VisualizationType = "donut_chart"
	VizAreaChart   VisualizationType = "area_chart"
	VizScatterPlot VisualizationType = "scatter_plot"
	VizHeatmap     VisualizationType = "heatmap"
	VizTreemap     VisualizationType = "treemap"
	VizGauge       VisualizationType = "gauge"
	VizRadar       VisualizationType = "radar"
	VizSankey      VisualizationType = "sankey"
	// Data displays.
	VizTable      VisualizationType = "table"
	VizTimeline   VisualizationType = "timeline"
	VizKPI        VisualizationType = "kpi"
	VizProgress   VisualizationType = "progress"
	VizComparison VisualizationType = "comparison"
	VizStatCards  VisualizationType = "stat_cards"
	VizMap        VisualizationType = "map"
	VizFlowChart  VisualizationType = "flow_chart"
	VizOrgChart   VisualizationType = "org_chart"
	// Content and layout containers.
	VizCodeBlock   VisualizationType = "code_block"
	VizMarkdown    VisualizationType = "markdown"
	VizCollapsible VisualizationType = "collapsible"
	VizTabs        VisualizationType = "tabs"
	VizAccordion   VisualizationType = "accordion"
	VizStepper     VisualizationType = "stepper"
	VizForm        VisualizationType = "form"
)
|
||||
|
||||
// Visualization is one renderable widget inside a report. Data holds the
// type-specific payload (*ChartData, *TableData, etc.) chosen to match Type.
type Visualization struct {
	ID          string                `json:"id"`
	Type        VisualizationType     `json:"type"`
	Title       string                `json:"title,omitempty"`
	Description string                `json:"description,omitempty"`
	Data        interface{}           `json:"data"` // payload shape depends on Type
	Config      VisualizationConfig   `json:"config,omitempty"`
	Style       VisualizationStyle    `json:"style,omitempty"`
	Actions     []VisualizationAction `json:"actions,omitempty"`
	Responsive  bool                  `json:"responsive"`
}
|
||||
|
||||
// VisualizationConfig carries behavior toggles for a widget; not every
// field applies to every VisualizationType (e.g. pagination is table-only).
type VisualizationConfig struct {
	// Display toggles.
	ShowLegend  bool `json:"showLegend,omitempty"`
	ShowGrid    bool `json:"showGrid,omitempty"`
	ShowTooltip bool `json:"showTooltip,omitempty"`
	ShowLabels  bool `json:"showLabels,omitempty"`
	ShowValues  bool `json:"showValues,omitempty"`
	Animated    bool `json:"animated,omitempty"`
	Stacked     bool `json:"stacked,omitempty"`
	Horizontal  bool `json:"horizontal,omitempty"`
	// Table interaction.
	Sortable        bool `json:"sortable,omitempty"`
	Filterable      bool `json:"filterable,omitempty"`
	Searchable      bool `json:"searchable,omitempty"`
	Paginated       bool `json:"paginated,omitempty"`
	PageSize        int  `json:"pageSize,omitempty"`
	Expandable      bool `json:"expandable,omitempty"`
	DefaultExpanded bool `json:"defaultExpanded,omitempty"`
	// Axis and formatting hints.
	XAxisLabel     string   `json:"xAxisLabel,omitempty"`
	YAxisLabel     string   `json:"yAxisLabel,omitempty"`
	Colors         []string `json:"colors,omitempty"`
	DateFormat     string   `json:"dateFormat,omitempty"`
	NumberFormat   string   `json:"numberFormat,omitempty"`
	CurrencySymbol string   `json:"currencySymbol,omitempty"`
}
|
||||
|
||||
type VisualizationStyle struct {
|
||||
Width string `json:"width,omitempty"`
|
||||
Height string `json:"height,omitempty"`
|
||||
MinHeight string `json:"minHeight,omitempty"`
|
||||
MaxHeight string `json:"maxHeight,omitempty"`
|
||||
Padding string `json:"padding,omitempty"`
|
||||
Margin string `json:"margin,omitempty"`
|
||||
BorderRadius string `json:"borderRadius,omitempty"`
|
||||
Background string `json:"background,omitempty"`
|
||||
Shadow string `json:"shadow,omitempty"`
|
||||
FontFamily string `json:"fontFamily,omitempty"`
|
||||
FontSize string `json:"fontSize,omitempty"`
|
||||
TextColor string `json:"textColor,omitempty"`
|
||||
AccentColor string `json:"accentColor,omitempty"`
|
||||
GridColor string `json:"gridColor,omitempty"`
|
||||
}
|
||||
|
||||
type VisualizationAction struct {
|
||||
ID string `json:"id"`
|
||||
Label string `json:"label"`
|
||||
Icon string `json:"icon,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Handler string `json:"handler,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
}
|
||||
|
||||
// ChartData is the payload for chart visualizations: one label per x-axis
// position and one or more datasets plotted against those labels.
type ChartData struct {
	Labels   []string       `json:"labels"`
	Datasets []ChartDataset `json:"datasets"`
}

// ChartDataset is a single plotted series; Data is index-aligned with
// ChartData.Labels.
type ChartDataset struct {
	Label           string    `json:"label"`
	Data            []float64 `json:"data"`
	BackgroundColor string    `json:"backgroundColor,omitempty"`
	BorderColor     string    `json:"borderColor,omitempty"`
	Fill            bool      `json:"fill,omitempty"`
}

// TableData is the payload for table visualizations.
type TableData struct {
	Columns []TableColumn `json:"columns"`
	Rows    []TableRow    `json:"rows"`
	Summary *TableSummary `json:"summary,omitempty"`
}

// TableColumn describes one column; Key selects the cell value out of each
// TableRow map.
type TableColumn struct {
	Key       string `json:"key"`
	Label     string `json:"label"`
	Type      string `json:"type,omitempty"`
	Width     string `json:"width,omitempty"`
	Sortable  bool   `json:"sortable,omitempty"`
	Align     string `json:"align,omitempty"`
	Format    string `json:"format,omitempty"`
	Highlight bool   `json:"highlight,omitempty"`
}

// TableRow maps column keys to cell values.
type TableRow map[string]interface{}

// TableSummary carries footer totals and aggregates for a table.
type TableSummary struct {
	TotalRows    int                    `json:"totalRows"`
	Aggregations map[string]interface{} `json:"aggregations,omitempty"`
}
|
||||
|
||||
// TimelineData is the payload for timeline visualizations.
type TimelineData struct {
	Events []TimelineEvent `json:"events"`
}

// TimelineEvent is one dated entry on a timeline.
type TimelineEvent struct {
	ID          string    `json:"id"`
	Date        time.Time `json:"date"`
	Title       string    `json:"title"`
	Description string    `json:"description,omitempty"`
	Icon        string    `json:"icon,omitempty"`
	Color       string    `json:"color,omitempty"`
	Link        string    `json:"link,omitempty"`
}

// KPIData is the payload for a single key-performance-indicator widget.
// Value/PrevValue/Target are interface{} so both numeric and formatted
// string values can be carried.
type KPIData struct {
	Value       interface{} `json:"value"`
	PrevValue   interface{} `json:"prevValue,omitempty"`
	Change      float64     `json:"change,omitempty"`
	ChangeType  string      `json:"changeType,omitempty"`
	Unit        string      `json:"unit,omitempty"`
	Prefix      string      `json:"prefix,omitempty"`
	Suffix      string      `json:"suffix,omitempty"`
	Target      interface{} `json:"target,omitempty"`
	Trend       []float64   `json:"trend,omitempty"` // series for an inline trend line
	Icon        string      `json:"icon,omitempty"`
	Color       string      `json:"color,omitempty"`
	Description string      `json:"description,omitempty"`
}

// StatCardsData is the payload for a grid of stat cards.
type StatCardsData struct {
	Cards []StatCard `json:"cards"`
}

// StatCard is one tile in a stat-card grid, optionally with a sparkline.
type StatCard struct {
	ID          string      `json:"id"`
	Title       string      `json:"title"`
	Value       interface{} `json:"value"`
	Change      float64     `json:"change,omitempty"`
	ChangeLabel string      `json:"changeLabel,omitempty"`
	Icon        string      `json:"icon,omitempty"`
	Color       string      `json:"color,omitempty"`
	Sparkline   []float64   `json:"sparkline,omitempty"`
}
|
||||
|
||||
// ComparisonData is the payload for side-by-side comparison visualizations:
// each item supplies a value per category.
type ComparisonData struct {
	Items      []ComparisonItem `json:"items"`
	Categories []string         `json:"categories"`
}

// ComparisonItem is one compared entity; Values is keyed by the category
// names in ComparisonData.Categories.
type ComparisonItem struct {
	ID     string                 `json:"id"`
	Name   string                 `json:"name"`
	Image  string                 `json:"image,omitempty"`
	Values map[string]interface{} `json:"values"`
}

// ProgressData is the payload for a progress-bar widget (Current out of Total).
type ProgressData struct {
	Current   float64 `json:"current"`
	Total     float64 `json:"total"`
	Label     string  `json:"label,omitempty"`
	Color     string  `json:"color,omitempty"`
	ShowValue bool    `json:"showValue,omitempty"`
	Animated  bool    `json:"animated,omitempty"`
}

// HeatmapData is the payload for a heatmap: Values[y][x] aligned with
// YLabels and XLabels; Min/Max optionally fix the color scale.
type HeatmapData struct {
	XLabels []string    `json:"xLabels"`
	YLabels []string    `json:"yLabels"`
	Values  [][]float64 `json:"values"`
	Min     float64     `json:"min,omitempty"`
	Max     float64     `json:"max,omitempty"`
}

// MapData is the payload for a map widget. Center is a coordinate pair
// (presumably [lat, lng] — confirm against the frontend renderer).
type MapData struct {
	Center  []float64   `json:"center"`
	Zoom    int         `json:"zoom"`
	Markers []MapMarker `json:"markers,omitempty"`
	Regions []MapRegion `json:"regions,omitempty"`
}

// MapMarker is a point marker with an optional popup.
type MapMarker struct {
	ID       string    `json:"id"`
	Position []float64 `json:"position"`
	Label    string    `json:"label,omitempty"`
	Icon     string    `json:"icon,omitempty"`
	Color    string    `json:"color,omitempty"`
	Popup    string    `json:"popup,omitempty"`
}

// MapRegion is a shaded named region carrying a value (choropleth-style).
type MapRegion struct {
	ID    string  `json:"id"`
	Name  string  `json:"name"`
	Value float64 `json:"value"`
	Color string  `json:"color,omitempty"`
}
|
||||
|
||||
// CollapsibleData is the payload for a collapsible section; Children allows
// nesting further visualizations inside it.
type CollapsibleData struct {
	Title       string          `json:"title"`
	Content     interface{}     `json:"content"`
	DefaultOpen bool            `json:"defaultOpen,omitempty"`
	Icon        string          `json:"icon,omitempty"`
	Children    []Visualization `json:"children,omitempty"`
}

// TabsData is the payload for a tabbed container.
type TabsData struct {
	Tabs []TabItem `json:"tabs"`
}

// TabItem is one tab; Content holds free-form content and Children holds
// nested visualizations.
type TabItem struct {
	ID       string          `json:"id"`
	Label    string          `json:"label"`
	Icon     string          `json:"icon,omitempty"`
	Content  interface{}     `json:"content"`
	Children []Visualization `json:"children,omitempty"`
}

// AccordionData is the payload for an accordion container.
type AccordionData struct {
	Items []AccordionItem `json:"items"`
}

// AccordionItem is one expandable accordion entry.
type AccordionItem struct {
	ID      string      `json:"id"`
	Title   string      `json:"title"`
	Content interface{} `json:"content"`
	Icon    string      `json:"icon,omitempty"`
	Open    bool        `json:"open,omitempty"`
}

// StepperData is the payload for a multi-step (wizard-style) widget.
type StepperData struct {
	Steps       []StepperStep `json:"steps"`
	CurrentStep int           `json:"currentStep"`
	Orientation string        `json:"orientation,omitempty"`
}

// StepperStep is one step of a stepper widget.
type StepperStep struct {
	ID          string      `json:"id"`
	Label       string      `json:"label"`
	Description string      `json:"description,omitempty"`
	Content     interface{} `json:"content,omitempty"`
	Status      string      `json:"status,omitempty"`
	Icon        string      `json:"icon,omitempty"`
}
|
||||
|
||||
// FormData is the payload for an inline form widget.
type FormData struct {
	Fields      []FormField `json:"fields"`
	SubmitLabel string      `json:"submitLabel,omitempty"`
	Layout      string      `json:"layout,omitempty"`
}

// FormField describes one input; Options applies to select-style fields and
// Validation carries a client-side rule (format not defined here — confirm
// against the frontend).
type FormField struct {
	ID          string       `json:"id"`
	Type        string       `json:"type"`
	Label       string       `json:"label"`
	Placeholder string       `json:"placeholder,omitempty"`
	Value       interface{}  `json:"value,omitempty"`
	Options     []FormOption `json:"options,omitempty"`
	Required    bool         `json:"required,omitempty"`
	Validation  string       `json:"validation,omitempty"`
}

// FormOption is one choice in a select-style form field.
type FormOption struct {
	Value string `json:"value"`
	Label string `json:"label"`
}

// CodeBlockData is the payload for a syntax-highlighted code widget;
// Highlight lists line numbers to emphasize.
type CodeBlockData struct {
	Code        string `json:"code"`
	Language    string `json:"language"`
	Filename    string `json:"filename,omitempty"`
	Highlight   []int  `json:"highlight,omitempty"`
	ShowLineNum bool   `json:"showLineNum,omitempty"`
	Copyable    bool   `json:"copyable,omitempty"`
}

// MarkdownData is the payload for a rendered-markdown widget.
type MarkdownData struct {
	Content string `json:"content"`
}
|
||||
|
||||
// Report is a shareable document composed of sections of visualizations.
type Report struct {
	ID          string          `json:"id"`
	Title       string          `json:"title"`
	Description string          `json:"description,omitempty"`
	Sections    []ReportSection `json:"sections"`
	CreatedAt   time.Time       `json:"createdAt"`
	UpdatedAt   time.Time       `json:"updatedAt"`
	Author      string          `json:"author,omitempty"`
	Tags        []string        `json:"tags,omitempty"`
	IsPublic    bool            `json:"isPublic"`
	Theme       string          `json:"theme,omitempty"`
	CustomCSS   string          `json:"customCss,omitempty"`
}

// ReportSection groups visualizations under an optional heading; Layout and
// Columns hint how the frontend arranges them.
type ReportSection struct {
	ID             string          `json:"id"`
	Title          string          `json:"title,omitempty"`
	Description    string          `json:"description,omitempty"`
	Visualizations []Visualization `json:"visualizations"`
	Layout         string          `json:"layout,omitempty"`
	Columns        int             `json:"columns,omitempty"`
}
|
||||
701
backend/internal/learning/stepper.go
Normal file
701
backend/internal/learning/stepper.go
Normal file
@@ -0,0 +1,701 @@
|
||||
package learning
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/gooseek/backend/internal/llm"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// LearningMode selects how a generated lesson is delivered.
type LearningMode string

const (
	ModeExplain     LearningMode = "explain"     // plain explanation (default in GenerateLesson)
	ModeGuided      LearningMode = "guided"
	ModeInteractive LearningMode = "interactive"
	ModePractice    LearningMode = "practice"
	ModeQuiz        LearningMode = "quiz"
)

// DifficultyLevel grades lesson content from beginner to expert.
type DifficultyLevel string

const (
	DifficultyBeginner     DifficultyLevel = "beginner" // default in GenerateLesson
	DifficultyIntermediate DifficultyLevel = "intermediate"
	DifficultyAdvanced     DifficultyLevel = "advanced"
	DifficultyExpert       DifficultyLevel = "expert"
)
|
||||
|
||||
// StepByStepLesson is a complete generated lesson: ordered steps plus the
// learner's progress through them.
type StepByStepLesson struct {
	ID            string          `json:"id"`
	Title         string          `json:"title"`
	Description   string          `json:"description"`
	Topic         string          `json:"topic"`
	Difficulty    DifficultyLevel `json:"difficulty"`
	Mode          LearningMode    `json:"mode"`
	Steps         []LearningStep  `json:"steps"`
	Prerequisites []string        `json:"prerequisites,omitempty"`
	LearningGoals []string        `json:"learningGoals"`
	EstimatedTime int             `json:"estimatedTimeMinutes"` // minutes
	Progress      LessonProgress  `json:"progress"`
	CreatedAt     time.Time       `json:"createdAt"`
	UpdatedAt     time.Time       `json:"updatedAt"`
}

// LearningStep is one unit of a lesson. Optional sections (Interaction,
// Practice, Quiz, …) are populated according to Type.
type LearningStep struct {
	ID          string                 `json:"id"`
	Number      int                    `json:"number"` // 1-based position in the lesson
	Title       string                 `json:"title"`
	Type        StepType               `json:"type"`
	Content     StepContent            `json:"content"`
	Interaction *StepInteraction       `json:"interaction,omitempty"`
	Hints       []string               `json:"hints,omitempty"`
	Examples    []Example              `json:"examples,omitempty"`
	Practice    *PracticeExercise      `json:"practice,omitempty"`
	Quiz        *QuizQuestion          `json:"quiz,omitempty"`
	Duration    int                    `json:"durationSeconds,omitempty"`
	Status      StepStatus             `json:"status"`
	Metadata    map[string]interface{} `json:"metadata,omitempty"`
}
|
||||
|
||||
// StepType categorizes what a LearningStep contains.
type StepType string

const (
	StepExplanation   StepType = "explanation"
	StepVisualization StepType = "visualization"
	StepCode          StepType = "code"
	StepInteractive   StepType = "interactive"
	StepPractice      StepType = "practice"
	StepQuiz          StepType = "quiz"
	StepSummary       StepType = "summary"
	StepCheckpoint    StepType = "checkpoint"
)

// StepStatus tracks a learner's state on a step. Steps start locked except
// the first, and unlock sequentially (see CompleteStep).
type StepStatus string

const (
	StatusLocked     StepStatus = "locked"
	StatusAvailable  StepStatus = "available"
	StatusInProgress StepStatus = "in_progress"
	StatusCompleted  StepStatus = "completed"
	StatusSkipped    StepStatus = "skipped"
)
|
||||
|
||||
// StepContent is the renderable body of a step; any combination of the
// optional media/code/visualization fields may be set alongside Text.
type StepContent struct {
	Text          string                `json:"text"`
	Markdown      string                `json:"markdown,omitempty"`
	HTML          string                `json:"html,omitempty"`
	Code          *CodeContent          `json:"code,omitempty"`
	Visualization *VisualizationContent `json:"visualization,omitempty"`
	Media         *MediaContent         `json:"media,omitempty"`
	Formula       string                `json:"formula,omitempty"` // LaTeX, per the generation prompt
	Highlights    []TextHighlight       `json:"highlights,omitempty"`
}

// CodeContent is an embedded code snippet, optionally runnable/editable in
// the frontend.
type CodeContent struct {
	Language    string           `json:"language"`
	Code        string           `json:"code"`
	Filename    string           `json:"filename,omitempty"`
	Runnable    bool             `json:"runnable"`
	Editable    bool             `json:"editable"`
	Highlights  []int            `json:"highlights,omitempty"` // line numbers to emphasize
	Annotations []CodeAnnotation `json:"annotations,omitempty"`
}

// CodeAnnotation attaches a note to a specific line of a code snippet.
type CodeAnnotation struct {
	Line int    `json:"line"`
	Text string `json:"text"`
	Type string `json:"type"`
}

// VisualizationContent embeds a visualization payload inside a step.
type VisualizationContent struct {
	Type   string                 `json:"type"`
	Data   interface{}            `json:"data"`
	Config map[string]interface{} `json:"config,omitempty"`
}

// MediaContent embeds external media (image/video/audio) by URL.
type MediaContent struct {
	Type     string `json:"type"`
	URL      string `json:"url"`
	Caption  string `json:"caption,omitempty"`
	Duration int    `json:"duration,omitempty"`
}

// TextHighlight marks a [Start, End) span of step text for emphasis.
type TextHighlight struct {
	Start int    `json:"start"`
	End   int    `json:"end"`
	Text  string `json:"text"`
	Type  string `json:"type"`
	Note  string `json:"note,omitempty"`
}

// StepInteraction describes a prompt the learner must respond to within a
// step, with optional answer validation and canned feedback.
type StepInteraction struct {
	Type       string      `json:"type"`
	Prompt     string      `json:"prompt"`
	Options    []Option    `json:"options,omitempty"`
	Validation *Validation `json:"validation,omitempty"`
	Feedback   *Feedback   `json:"feedback,omitempty"`
}

// Option is one selectable answer choice (shared by interactions and quizzes).
type Option struct {
	ID        string `json:"id"`
	Text      string `json:"text"`
	IsCorrect bool   `json:"isCorrect,omitempty"`
	Feedback  string `json:"feedback,omitempty"`
}

// Validation describes how a free-form answer is checked (regex pattern,
// exact expected value, or required keywords, depending on Type).
type Validation struct {
	Type     string   `json:"type"`
	Pattern  string   `json:"pattern,omitempty"`
	Expected string   `json:"expected,omitempty"`
	Keywords []string `json:"keywords,omitempty"`
}

// Feedback holds the messages shown after an answer is evaluated.
type Feedback struct {
	Correct   string `json:"correct"`
	Incorrect string `json:"incorrect"`
	Partial   string `json:"partial,omitempty"`
}
|
||||
|
||||
// Example is a worked example attached to a step.
type Example struct {
	Title       string `json:"title"`
	Description string `json:"description"`
	Input       string `json:"input,omitempty"`
	Output      string `json:"output,omitempty"`
	Code        string `json:"code,omitempty"`
	Language    string `json:"language,omitempty"`
}

// PracticeExercise is a coding task with optional starter/solution code and
// test cases; produced by GeneratePracticeExercise.
type PracticeExercise struct {
	Prompt       string     `json:"prompt"`
	Instructions string     `json:"instructions"`
	Starter      string     `json:"starter,omitempty"`
	Solution     string     `json:"solution,omitempty"`
	TestCases    []TestCase `json:"testCases,omitempty"`
	Hints        []string   `json:"hints,omitempty"`
}

// TestCase pairs an input with its expected output; Hidden cases are not
// shown to the learner.
type TestCase struct {
	Input    string `json:"input"`
	Expected string `json:"expected"`
	Hidden   bool   `json:"hidden,omitempty"`
}

// QuizQuestion is one quiz item. CorrectIndex is an alternative to marking
// Options with IsCorrect; generation code in this file uses IsCorrect.
type QuizQuestion struct {
	Question     string   `json:"question"`
	Type         string   `json:"type"`
	Options      []Option `json:"options,omitempty"`
	CorrectIndex []int    `json:"correctIndex,omitempty"`
	Explanation  string   `json:"explanation"`
	Points       int      `json:"points"`
}

// LessonProgress tracks a learner's position, score, and timing within a
// lesson; mutated by CompleteStep.
type LessonProgress struct {
	CurrentStep    int        `json:"currentStep"`
	CompletedSteps []int      `json:"completedSteps"` // step indices, in completion order
	Score          int        `json:"score"`
	MaxScore       int        `json:"maxScore"`
	TimeSpent      int        `json:"timeSpentSeconds"`
	StartedAt      time.Time  `json:"startedAt"`
	LastAccessed   time.Time  `json:"lastAccessed"`
	Completed      bool       `json:"completed"`
	CompletedAt    *time.Time `json:"completedAt,omitempty"`
}
|
||||
|
||||
// LearningGenerator produces lessons, quizzes, and practice exercises by
// prompting an LLM and parsing its JSON replies.
type LearningGenerator struct {
	llm llm.Client
}

// NewLearningGenerator wraps the given LLM client.
func NewLearningGenerator(llmClient llm.Client) *LearningGenerator {
	return &LearningGenerator{llm: llmClient}
}

// GenerateLessonOptions configures GenerateLesson. Zero values are defaulted
// inside GenerateLesson (MaxSteps 10, beginner difficulty, explain mode).
type GenerateLessonOptions struct {
	Topic       string
	Query       string
	Difficulty  DifficultyLevel
	Mode        LearningMode
	MaxSteps    int
	Locale      string // only "ru" changes output language; others use the model default
	IncludeCode bool
	IncludeQuiz bool
}
|
||||
|
||||
// GenerateLesson asks the LLM for a structured step-by-step lesson and maps
// the JSON reply onto a StepByStepLesson. Unset options are defaulted
// (10 steps, beginner, explain mode). If the reply cannot be parsed as JSON
// the method falls back to createDefaultLesson instead of returning an error.
func (g *LearningGenerator) GenerateLesson(ctx context.Context, opts GenerateLessonOptions) (*StepByStepLesson, error) {
	if opts.MaxSteps == 0 {
		opts.MaxSteps = 10
	}
	if opts.Difficulty == "" {
		opts.Difficulty = DifficultyBeginner
	}
	if opts.Mode == "" {
		opts.Mode = ModeExplain
	}

	// Only Russian gets an explicit language instruction; other locales fall
	// through to the model's default output language.
	langInstruction := ""
	if opts.Locale == "ru" {
		langInstruction = "Generate all content in Russian language."
	}

	prompt := fmt.Sprintf(`Create a step-by-step educational lesson on the following topic.

Topic: %s
Query: %s
Difficulty: %s
Mode: %s
Max Steps: %d
Include Code Examples: %v
Include Quiz: %v
%s

Generate a structured lesson with these requirements:
1. Break down the concept into clear, digestible steps
2. Each step should build on the previous one
3. Include explanations, examples, and visualizations where helpful
4. For code topics, include runnable code snippets
5. Add practice exercises for interactive learning
6. Include quiz questions to test understanding

Respond in this JSON format:
{
"title": "Lesson title",
"description": "Brief description",
"learningGoals": ["Goal 1", "Goal 2"],
"estimatedTimeMinutes": 15,
"steps": [
{
"title": "Step title",
"type": "explanation|code|interactive|practice|quiz|summary",
"content": {
"text": "Main explanation",
"markdown": "## Formatted content",
"code": {"language": "python", "code": "example", "runnable": true},
"formula": "optional LaTeX formula"
},
"hints": ["Hint 1"],
"examples": [{"title": "Example", "description": "...", "code": "..."}],
"quiz": {
"question": "...",
"type": "multiple_choice",
"options": [{"id": "a", "text": "Option A", "isCorrect": false}],
"explanation": "..."
}
}
]
}`, opts.Topic, opts.Query, opts.Difficulty, opts.Mode, opts.MaxSteps, opts.IncludeCode, opts.IncludeQuiz, langInstruction)

	result, err := g.llm.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: "user", Content: prompt}},
	})
	if err != nil {
		return nil, err
	}

	// Strip any prose or markdown fencing the model wrapped around the JSON.
	jsonStr := extractJSON(result)

	// Anonymous struct mirroring the JSON schema requested in the prompt.
	var parsed struct {
		Title                string   `json:"title"`
		Description          string   `json:"description"`
		LearningGoals        []string `json:"learningGoals"`
		EstimatedTimeMinutes int      `json:"estimatedTimeMinutes"`
		Steps                []struct {
			Title   string `json:"title"`
			Type    string `json:"type"`
			Content struct {
				Text     string `json:"text"`
				Markdown string `json:"markdown"`
				Code     *struct {
					Language string `json:"language"`
					Code     string `json:"code"`
					Runnable bool   `json:"runnable"`
				} `json:"code"`
				Formula string `json:"formula"`
			} `json:"content"`
			Hints    []string `json:"hints"`
			Examples []struct {
				Title       string `json:"title"`
				Description string `json:"description"`
				Code        string `json:"code"`
			} `json:"examples"`
			Quiz *struct {
				Question string `json:"question"`
				Type     string `json:"type"`
				Options  []struct {
					ID        string `json:"id"`
					Text      string `json:"text"`
					IsCorrect bool   `json:"isCorrect"`
				} `json:"options"`
				Explanation string `json:"explanation"`
			} `json:"quiz"`
		} `json:"steps"`
	}

	if err := json.Unmarshal([]byte(jsonStr), &parsed); err != nil {
		// Unparseable reply: degrade to a minimal one-step lesson rather
		// than failing the request.
		return g.createDefaultLesson(opts)
	}

	lesson := &StepByStepLesson{
		ID:            uuid.New().String(),
		Title:         parsed.Title,
		Description:   parsed.Description,
		Topic:         opts.Topic,
		Difficulty:    opts.Difficulty,
		Mode:          opts.Mode,
		LearningGoals: parsed.LearningGoals,
		EstimatedTime: parsed.EstimatedTimeMinutes,
		Steps:         make([]LearningStep, 0),
		Progress: LessonProgress{
			CurrentStep:    0,
			CompletedSteps: []int{},
		},
		CreatedAt: time.Now(),
		UpdatedAt: time.Now(),
	}

	for i, s := range parsed.Steps {
		step := LearningStep{
			ID:     uuid.New().String(),
			Number: i + 1,
			Title:  s.Title,
			Type:   StepType(s.Type),
			Content: StepContent{
				Text:     s.Content.Text,
				Markdown: s.Content.Markdown,
				Formula:  s.Content.Formula,
			},
			Hints:  s.Hints,
			Status: StatusAvailable,
		}

		// Only the first step starts unlocked; later steps open as earlier
		// ones are completed (see CompleteStep).
		if i > 0 {
			step.Status = StatusLocked
		}

		if s.Content.Code != nil {
			step.Content.Code = &CodeContent{
				Language: s.Content.Code.Language,
				Code:     s.Content.Code.Code,
				Runnable: s.Content.Code.Runnable,
				Editable: true, // generated snippets are always user-editable
			}
		}

		for _, ex := range s.Examples {
			step.Examples = append(step.Examples, Example{
				Title:       ex.Title,
				Description: ex.Description,
				Code:        ex.Code,
			})
		}

		if s.Quiz != nil {
			quiz := &QuizQuestion{
				Question:    s.Quiz.Question,
				Type:        s.Quiz.Type,
				Explanation: s.Quiz.Explanation,
				Points:      10, // flat 10 points per quiz, matching MaxScore below
			}
			for _, opt := range s.Quiz.Options {
				quiz.Options = append(quiz.Options, Option{
					ID:        opt.ID,
					Text:      opt.Text,
					IsCorrect: opt.IsCorrect,
				})
			}
			step.Quiz = quiz
		}

		lesson.Steps = append(lesson.Steps, step)
	}

	// 10 points per step, matching the 10 awarded by CompleteStep.
	lesson.Progress.MaxScore = len(lesson.Steps) * 10

	return lesson, nil
}
|
||||
|
||||
func (g *LearningGenerator) createDefaultLesson(opts GenerateLessonOptions) (*StepByStepLesson, error) {
|
||||
return &StepByStepLesson{
|
||||
ID: uuid.New().String(),
|
||||
Title: fmt.Sprintf("Learn: %s", opts.Topic),
|
||||
Description: opts.Query,
|
||||
Topic: opts.Topic,
|
||||
Difficulty: opts.Difficulty,
|
||||
Mode: opts.Mode,
|
||||
LearningGoals: []string{"Understand the basics"},
|
||||
EstimatedTime: 10,
|
||||
Steps: []LearningStep{
|
||||
{
|
||||
ID: uuid.New().String(),
|
||||
Number: 1,
|
||||
Title: "Introduction",
|
||||
Type: StepExplanation,
|
||||
Content: StepContent{
|
||||
Text: opts.Query,
|
||||
Markdown: fmt.Sprintf("# %s\n\n%s", opts.Topic, opts.Query),
|
||||
},
|
||||
Status: StatusAvailable,
|
||||
},
|
||||
},
|
||||
Progress: LessonProgress{
|
||||
CurrentStep: 0,
|
||||
CompletedSteps: []int{},
|
||||
MaxScore: 10,
|
||||
},
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (g *LearningGenerator) GenerateExplanation(ctx context.Context, topic string, difficulty DifficultyLevel, locale string) (*LearningStep, error) {
|
||||
langInstruction := ""
|
||||
if locale == "ru" {
|
||||
langInstruction = "Respond in Russian."
|
||||
}
|
||||
|
||||
prompt := fmt.Sprintf(`Explain this topic step by step for a %s level learner.
|
||||
Topic: %s
|
||||
%s
|
||||
|
||||
Format your response with clear sections:
|
||||
1. Start with a simple definition
|
||||
2. Explain key concepts
|
||||
3. Provide a real-world analogy
|
||||
4. Give a concrete example
|
||||
|
||||
Use markdown formatting.`, difficulty, topic, langInstruction)
|
||||
|
||||
result, err := g.llm.GenerateText(ctx, llm.StreamRequest{
|
||||
Messages: []llm.Message{{Role: "user", Content: prompt}},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &LearningStep{
|
||||
ID: uuid.New().String(),
|
||||
Number: 1,
|
||||
Title: topic,
|
||||
Type: StepExplanation,
|
||||
Content: StepContent{
|
||||
Markdown: result,
|
||||
},
|
||||
Status: StatusAvailable,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GenerateQuiz asks the LLM for numQuestions multiple-choice questions on
// topic and parses the JSON reply into QuizQuestion values, each worth a
// flat 10 points. Unlike GenerateLesson there is no fallback: a reply that
// does not unmarshal is returned as an error.
func (g *LearningGenerator) GenerateQuiz(ctx context.Context, topic string, numQuestions int, difficulty DifficultyLevel, locale string) ([]QuizQuestion, error) {
	// Only Russian gets an explicit language instruction.
	langInstruction := ""
	if locale == "ru" {
		langInstruction = "Generate all questions and answers in Russian."
	}

	prompt := fmt.Sprintf(`Generate %d multiple choice quiz questions about: %s
Difficulty level: %s
%s

Respond in JSON format:
{
"questions": [
{
"question": "Question text",
"options": [
{"id": "a", "text": "Option A", "isCorrect": false},
{"id": "b", "text": "Option B", "isCorrect": true},
{"id": "c", "text": "Option C", "isCorrect": false},
{"id": "d", "text": "Option D", "isCorrect": false}
],
"explanation": "Why the correct answer is correct"
}
]
}`, numQuestions, topic, difficulty, langInstruction)

	result, err := g.llm.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: "user", Content: prompt}},
	})
	if err != nil {
		return nil, err
	}

	// Strip any prose or markdown fencing around the JSON object.
	jsonStr := extractJSON(result)

	// Anonymous struct mirroring the schema requested in the prompt.
	var parsed struct {
		Questions []struct {
			Question string `json:"question"`
			Options  []struct {
				ID        string `json:"id"`
				Text      string `json:"text"`
				IsCorrect bool   `json:"isCorrect"`
			} `json:"options"`
			Explanation string `json:"explanation"`
		} `json:"questions"`
	}

	if err := json.Unmarshal([]byte(jsonStr), &parsed); err != nil {
		return nil, err
	}

	questions := make([]QuizQuestion, 0)
	for _, q := range parsed.Questions {
		quiz := QuizQuestion{
			Question:    q.Question,
			Type:        "multiple_choice",
			Explanation: q.Explanation,
			Points:      10, // flat 10 points per question
		}
		for _, opt := range q.Options {
			quiz.Options = append(quiz.Options, Option{
				ID:        opt.ID,
				Text:      opt.Text,
				IsCorrect: opt.IsCorrect,
			})
		}
		questions = append(questions, quiz)
	}

	return questions, nil
}
|
||||
|
||||
func (g *LearningGenerator) GeneratePracticeExercise(ctx context.Context, topic, language string, difficulty DifficultyLevel, locale string) (*PracticeExercise, error) {
|
||||
langInstruction := ""
|
||||
if locale == "ru" {
|
||||
langInstruction = "Write instructions and explanations in Russian."
|
||||
}
|
||||
|
||||
prompt := fmt.Sprintf(`Create a coding practice exercise for: %s
|
||||
Programming language: %s
|
||||
Difficulty: %s
|
||||
%s
|
||||
|
||||
Generate:
|
||||
1. A clear problem statement
|
||||
2. Step-by-step instructions
|
||||
3. Starter code template
|
||||
4. Solution code
|
||||
5. Test cases
|
||||
|
||||
Respond in JSON:
|
||||
{
|
||||
"prompt": "Problem statement",
|
||||
"instructions": "Step-by-step instructions",
|
||||
"starter": "// Starter code",
|
||||
"solution": "// Solution code",
|
||||
"testCases": [
|
||||
{"input": "test input", "expected": "expected output"}
|
||||
],
|
||||
"hints": ["Hint 1", "Hint 2"]
|
||||
}`, topic, language, difficulty, langInstruction)
|
||||
|
||||
result, err := g.llm.GenerateText(ctx, llm.StreamRequest{
|
||||
Messages: []llm.Message{{Role: "user", Content: prompt}},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
jsonStr := extractJSON(result)
|
||||
|
||||
var exercise PracticeExercise
|
||||
if err := json.Unmarshal([]byte(jsonStr), &exercise); err != nil {
|
||||
return &PracticeExercise{
|
||||
Prompt: topic,
|
||||
Instructions: "Practice this concept",
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &exercise, nil
|
||||
}
|
||||
|
||||
func (l *StepByStepLesson) CompleteStep(stepIndex int) {
|
||||
if stepIndex < 0 || stepIndex >= len(l.Steps) {
|
||||
return
|
||||
}
|
||||
|
||||
l.Steps[stepIndex].Status = StatusCompleted
|
||||
|
||||
alreadyCompleted := false
|
||||
for _, idx := range l.Progress.CompletedSteps {
|
||||
if idx == stepIndex {
|
||||
alreadyCompleted = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !alreadyCompleted {
|
||||
l.Progress.CompletedSteps = append(l.Progress.CompletedSteps, stepIndex)
|
||||
l.Progress.Score += 10
|
||||
}
|
||||
|
||||
if stepIndex+1 < len(l.Steps) {
|
||||
l.Steps[stepIndex+1].Status = StatusAvailable
|
||||
l.Progress.CurrentStep = stepIndex + 1
|
||||
}
|
||||
|
||||
if len(l.Progress.CompletedSteps) == len(l.Steps) {
|
||||
l.Progress.Completed = true
|
||||
now := time.Now()
|
||||
l.Progress.CompletedAt = &now
|
||||
}
|
||||
|
||||
l.UpdatedAt = time.Now()
|
||||
l.Progress.LastAccessed = time.Now()
|
||||
}
|
||||
|
||||
func (l *StepByStepLesson) SubmitQuizAnswer(stepIndex int, selectedOptions []string) (bool, string) {
|
||||
if stepIndex < 0 || stepIndex >= len(l.Steps) {
|
||||
return false, "Invalid step"
|
||||
}
|
||||
|
||||
step := &l.Steps[stepIndex]
|
||||
if step.Quiz == nil {
|
||||
return false, "No quiz in this step"
|
||||
}
|
||||
|
||||
correctCount := 0
|
||||
totalCorrect := 0
|
||||
|
||||
for _, opt := range step.Quiz.Options {
|
||||
if opt.IsCorrect {
|
||||
totalCorrect++
|
||||
}
|
||||
}
|
||||
|
||||
for _, selected := range selectedOptions {
|
||||
for _, opt := range step.Quiz.Options {
|
||||
if opt.ID == selected && opt.IsCorrect {
|
||||
correctCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
isCorrect := correctCount == totalCorrect && len(selectedOptions) == totalCorrect
|
||||
|
||||
if isCorrect {
|
||||
return true, step.Quiz.Explanation
|
||||
}
|
||||
|
||||
return false, step.Quiz.Explanation
|
||||
}
|
||||
|
||||
// jsonObjectRe matches the span from the first '{' to the LAST '}' in a
// text; (?s) lets '.' cross newlines and the greedy '.*' keeps nested braces
// intact. Compiled once at package scope so callers on the hot LLM-response
// path do not pay a regexp compile per invocation.
var jsonObjectRe = regexp.MustCompile(`(?s)\{.*\}`)

// extractJSON pulls the brace-delimited span out of an LLM reply, which
// often wraps JSON in prose or markdown fences. If no such span exists it
// returns the empty JSON object "{}" so callers can unmarshal
// unconditionally.
func extractJSON(text string) string {
	if match := jsonObjectRe.FindString(text); match != "" {
		return match
	}
	return "{}"
}
|
||||
|
||||
// ToJSON serializes the lesson for storage or transport; ParseLesson is the
// inverse.
func (l *StepByStepLesson) ToJSON() ([]byte, error) {
	return json.Marshal(l)
}
|
||||
|
||||
func ParseLesson(data []byte) (*StepByStepLesson, error) {
|
||||
var lesson StepByStepLesson
|
||||
if err := json.Unmarshal(data, &lesson); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &lesson, nil
|
||||
}
|
||||
182
backend/internal/llm/anthropic.go
Normal file
182
backend/internal/llm/anthropic.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// AnthropicClient talks to the Anthropic Messages API. It embeds baseClient
// for provider/model identity and keeps its own HTTP client and credentials.
type AnthropicClient struct {
	baseClient
	apiKey  string
	baseURL string // no trailing slash; request paths are appended verbatim
	client  *http.Client
}
|
||||
|
||||
func NewAnthropicClient(cfg ProviderConfig) (*AnthropicClient, error) {
|
||||
baseURL := cfg.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.anthropic.com"
|
||||
}
|
||||
|
||||
return &AnthropicClient{
|
||||
baseClient: baseClient{
|
||||
providerID: cfg.ProviderID,
|
||||
modelKey: cfg.ModelKey,
|
||||
},
|
||||
apiKey: cfg.APIKey,
|
||||
baseURL: strings.TrimSuffix(baseURL, "/"),
|
||||
client: &http.Client{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// anthropicRequest is the JSON body for POST /v1/messages. Note that the
// Anthropic API takes the system prompt as a top-level field rather than as
// a message role (see StreamText, which splits it out).
type anthropicRequest struct {
	Model     string             `json:"model"`
	Messages  []anthropicMessage `json:"messages"`
	System    string             `json:"system,omitempty"`
	MaxTokens int                `json:"max_tokens"`
	Stream    bool               `json:"stream"`
	Tools     []anthropicTool    `json:"tools,omitempty"`
}

// anthropicMessage is one conversation turn in the wire format.
type anthropicMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// anthropicTool describes a callable tool in Anthropic's wire schema;
// InputSchema carries the tool's parameter schema as-is.
type anthropicTool struct {
	Name        string      `json:"name"`
	Description string      `json:"description"`
	InputSchema interface{} `json:"input_schema"`
}

// anthropicStreamEvent is one server-sent event of a streaming response.
// Delta carries incremental text; ContentBlock carries the opening block of
// a content_block_start event.
type anthropicStreamEvent struct {
	Type  string `json:"type"`
	Index int    `json:"index,omitempty"`
	Delta struct {
		Type string `json:"type,omitempty"`
		Text string `json:"text,omitempty"`
	} `json:"delta,omitempty"`
	ContentBlock struct {
		Type string `json:"type"`
		Text string `json:"text,omitempty"`
	} `json:"content_block,omitempty"`
}
|
||||
|
||||
func (c *AnthropicClient) StreamText(ctx context.Context, req StreamRequest) (<-chan StreamChunk, error) {
|
||||
var systemPrompt string
|
||||
messages := make([]anthropicMessage, 0)
|
||||
|
||||
for _, m := range req.Messages {
|
||||
if m.Role == RoleSystem {
|
||||
systemPrompt = m.Content
|
||||
continue
|
||||
}
|
||||
role := string(m.Role)
|
||||
if role == "tool" {
|
||||
role = "user"
|
||||
}
|
||||
messages = append(messages, anthropicMessage{
|
||||
Role: role,
|
||||
Content: m.Content,
|
||||
})
|
||||
}
|
||||
|
||||
maxTokens := req.Options.MaxTokens
|
||||
if maxTokens == 0 {
|
||||
maxTokens = 4096
|
||||
}
|
||||
|
||||
anthropicReq := anthropicRequest{
|
||||
Model: c.modelKey,
|
||||
Messages: messages,
|
||||
System: systemPrompt,
|
||||
MaxTokens: maxTokens,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
if len(req.Tools) > 0 {
|
||||
anthropicReq.Tools = make([]anthropicTool, len(req.Tools))
|
||||
for i, t := range req.Tools {
|
||||
anthropicReq.Tools[i] = anthropicTool{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
InputSchema: t.Schema,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
body, err := json.Marshal(anthropicReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+"/v1/messages", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("x-api-key", c.apiKey)
|
||||
httpReq.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
resp, err := c.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("anthropic API error: %d - %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
ch := make(chan StreamChunk, 100)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
defer resp.Body.Close()
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
return
|
||||
}
|
||||
|
||||
var event anthropicStreamEvent
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case "content_block_delta":
|
||||
if event.Delta.Text != "" {
|
||||
ch <- StreamChunk{ContentChunk: event.Delta.Text}
|
||||
}
|
||||
case "message_stop":
|
||||
ch <- StreamChunk{FinishReason: "stop"}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (c *AnthropicClient) GenerateText(ctx context.Context, req StreamRequest) (string, error) {
|
||||
ch, err := c.StreamText(ctx, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return readAllChunks(ch), nil
|
||||
}
|
||||
145
backend/internal/llm/client.go
Normal file
145
backend/internal/llm/client.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package llm
|
||||
|
||||
import (
	"context"
	"fmt"
	"io"
	"strings"
)
|
||||
|
||||
// Role identifies who authored a chat message.
type Role string

const (
	RoleSystem Role = "system" // system / instruction prompt
	RoleUser Role = "user" // end-user input
	RoleAssistant Role = "assistant" // model output
	RoleTool Role = "tool" // tool/function execution result
)

// Message is one turn of a conversation, in a provider-neutral shape.
type Message struct {
	Role Role `json:"role"`
	Content string `json:"content"`
	ToolCalls []ToolCall `json:"tool_calls,omitempty"` // calls requested by the assistant
	ToolCallID string `json:"tool_call_id,omitempty"` // links a tool result to its originating call
	Name string `json:"name,omitempty"`
	Images []ImageContent `json:"images,omitempty"` // optional multimodal attachments
}

// ImageContent is an image attachment, carried either by URL or inline.
type ImageContent struct {
	Type string `json:"type"`
	URL string `json:"url,omitempty"`
	Data string `json:"data,omitempty"` // inline payload; base64 when IsBase64 is set
	IsBase64 bool `json:"isBase64,omitempty"`
}

// ToolCall is a single function invocation requested by the model.
type ToolCall struct {
	ID string `json:"id"`
	Name string `json:"name"`
	Arguments map[string]interface{} `json:"arguments"`
}

// Tool declares a callable function the model may invoke.
type Tool struct {
	Name string `json:"name"`
	Description string `json:"description"`
	Schema interface{} `json:"schema"` // JSON-schema-style parameter description
}

// StreamOptions carries provider-neutral sampling and limit settings.
// Zero values mean "use the provider default".
type StreamOptions struct {
	MaxTokens int `json:"max_tokens,omitempty"`
	Temperature float64 `json:"temperature,omitempty"`
	TopP float64 `json:"top_p,omitempty"`
	StopWords []string `json:"stop,omitempty"`
}

// StreamChunk is one increment of a streamed response. Typically
// exactly one field is set per chunk.
type StreamChunk struct {
	ContentChunk string `json:"content_chunk,omitempty"`
	ToolCallChunk []ToolCall `json:"tool_call_chunk,omitempty"`
	FinishReason string `json:"finish_reason,omitempty"`
}

// StreamRequest is the provider-neutral input to a completion call.
type StreamRequest struct {
	Messages []Message `json:"messages"`
	Tools []Tool `json:"tools,omitempty"`
	Options StreamOptions `json:"options,omitempty"`
}

// Client abstracts one LLM provider. StreamText returns a channel that
// is closed when the stream ends; GenerateText is the blocking variant.
type Client interface {
	StreamText(ctx context.Context, req StreamRequest) (<-chan StreamChunk, error)
	GenerateText(ctx context.Context, req StreamRequest) (string, error)
	GetProviderID() string
	GetModelKey() string
}

// ProviderConfig selects and configures a provider for NewClient.
type ProviderConfig struct {
	ProviderID string `json:"providerId"`
	ModelKey string `json:"key"`
	APIKey string `json:"apiKey,omitempty"`
	BaseURL string `json:"baseUrl,omitempty"`
	AgentAccessID string `json:"agentAccessId,omitempty"` // Timeweb-specific agent endpoint ID
}
|
||||
|
||||
func NewClient(cfg ProviderConfig) (Client, error) {
|
||||
switch cfg.ProviderID {
|
||||
case "timeweb":
|
||||
return NewTimewebClient(TimewebConfig{
|
||||
BaseURL: cfg.BaseURL,
|
||||
AgentAccessID: cfg.AgentAccessID,
|
||||
APIKey: cfg.APIKey,
|
||||
ModelKey: cfg.ModelKey,
|
||||
ProxySource: "gooseek",
|
||||
})
|
||||
case "openai":
|
||||
return NewOpenAIClient(cfg)
|
||||
case "anthropic":
|
||||
return NewAnthropicClient(cfg)
|
||||
case "gemini", "google":
|
||||
return NewGeminiClient(cfg)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown provider: %s", cfg.ProviderID)
|
||||
}
|
||||
}
|
||||
|
||||
// baseClient carries the identity shared by every provider client and
// implements the GetProviderID/GetModelKey half of the Client interface.
type baseClient struct {
	providerID string
	modelKey string
}

// GetProviderID returns the provider identifier this client was built with.
func (c *baseClient) GetProviderID() string {
	return c.providerID
}

// GetModelKey returns the provider-side model name/key.
func (c *baseClient) GetModelKey() string {
	return c.modelKey
}
|
||||
|
||||
func readAllChunks(ch <-chan StreamChunk) string {
|
||||
var result string
|
||||
for chunk := range ch {
|
||||
result += chunk.ContentChunk
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type streamReader struct {
|
||||
ch <-chan StreamChunk
|
||||
buffer []byte
|
||||
}
|
||||
|
||||
func (r *streamReader) Read(p []byte) (n int, err error) {
|
||||
if len(r.buffer) > 0 {
|
||||
n = copy(p, r.buffer)
|
||||
r.buffer = r.buffer[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
chunk, ok := <-r.ch
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
data := []byte(chunk.ContentChunk)
|
||||
n = copy(p, data)
|
||||
if n < len(data) {
|
||||
r.buffer = data[n:]
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
193
backend/internal/llm/gemini.go
Normal file
193
backend/internal/llm/gemini.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GeminiClient talks to the Gemini REST API (v1beta by default).
type GeminiClient struct {
	baseClient
	apiKey string
	baseURL string // normalized without a trailing slash
	client *http.Client
}
|
||||
|
||||
func NewGeminiClient(cfg ProviderConfig) (*GeminiClient, error) {
|
||||
baseURL := cfg.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com/v1beta"
|
||||
}
|
||||
|
||||
return &GeminiClient{
|
||||
baseClient: baseClient{
|
||||
providerID: cfg.ProviderID,
|
||||
modelKey: cfg.ModelKey,
|
||||
},
|
||||
apiKey: cfg.APIKey,
|
||||
baseURL: strings.TrimSuffix(baseURL, "/"),
|
||||
client: &http.Client{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// geminiRequest is the generateContent / streamGenerateContent body.
type geminiRequest struct {
	Contents []geminiContent `json:"contents"`
	SystemInstruction *geminiContent `json:"systemInstruction,omitempty"`
	GenerationConfig geminiGenerationConfig `json:"generationConfig,omitempty"`
	Tools []geminiTool `json:"tools,omitempty"`
}

// geminiContent is one conversation turn ("user" or "model") or the
// system instruction (which carries no role).
type geminiContent struct {
	Role string `json:"role,omitempty"`
	Parts []geminiPart `json:"parts"`
}

// geminiPart is a text fragment of a content entry; this client only
// sends and reads text parts.
type geminiPart struct {
	Text string `json:"text,omitempty"`
}

// geminiGenerationConfig holds sampling and output limits. Zero values
// are omitted so Gemini applies its own defaults.
type geminiGenerationConfig struct {
	MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
	Temperature float64 `json:"temperature,omitempty"`
	TopP float64 `json:"topP,omitempty"`
}

// geminiTool groups function declarations for the request.
type geminiTool struct {
	FunctionDeclarations []geminiFunctionDecl `json:"functionDeclarations,omitempty"`
}

// geminiFunctionDecl declares one callable function and its schema.
type geminiFunctionDecl struct {
	Name string `json:"name"`
	Description string `json:"description"`
	Parameters interface{} `json:"parameters"`
}

// geminiStreamResponse is one SSE payload; StreamText consumes only
// the first candidate.
type geminiStreamResponse struct {
	Candidates []struct {
		Content struct {
			Parts []struct {
				Text string `json:"text"`
			} `json:"parts"`
		} `json:"content"`
		FinishReason string `json:"finishReason,omitempty"`
	} `json:"candidates"`
}
|
||||
|
||||
func (c *GeminiClient) StreamText(ctx context.Context, req StreamRequest) (<-chan StreamChunk, error) {
|
||||
contents := make([]geminiContent, 0)
|
||||
var systemInstruction *geminiContent
|
||||
|
||||
for _, m := range req.Messages {
|
||||
if m.Role == RoleSystem {
|
||||
systemInstruction = &geminiContent{
|
||||
Parts: []geminiPart{{Text: m.Content}},
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
role := "user"
|
||||
if m.Role == RoleAssistant {
|
||||
role = "model"
|
||||
}
|
||||
|
||||
contents = append(contents, geminiContent{
|
||||
Role: role,
|
||||
Parts: []geminiPart{{Text: m.Content}},
|
||||
})
|
||||
}
|
||||
|
||||
geminiReq := geminiRequest{
|
||||
Contents: contents,
|
||||
SystemInstruction: systemInstruction,
|
||||
GenerationConfig: geminiGenerationConfig{
|
||||
MaxOutputTokens: req.Options.MaxTokens,
|
||||
Temperature: req.Options.Temperature,
|
||||
TopP: req.Options.TopP,
|
||||
},
|
||||
}
|
||||
|
||||
if len(req.Tools) > 0 {
|
||||
decls := make([]geminiFunctionDecl, len(req.Tools))
|
||||
for i, t := range req.Tools {
|
||||
decls[i] = geminiFunctionDecl{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
Parameters: t.Schema,
|
||||
}
|
||||
}
|
||||
geminiReq.Tools = []geminiTool{{FunctionDeclarations: decls}}
|
||||
}
|
||||
|
||||
body, err := json.Marshal(geminiReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s&alt=sse",
|
||||
c.baseURL, c.modelKey, c.apiKey)
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("gemini API error: %d - %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
ch := make(chan StreamChunk, 100)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
defer resp.Body.Close()
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
|
||||
var response geminiStreamResponse
|
||||
if err := json.Unmarshal([]byte(data), &response); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(response.Candidates) > 0 {
|
||||
candidate := response.Candidates[0]
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.Text != "" {
|
||||
ch <- StreamChunk{ContentChunk: part.Text}
|
||||
}
|
||||
}
|
||||
if candidate.FinishReason != "" {
|
||||
ch <- StreamChunk{FinishReason: candidate.FinishReason}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (c *GeminiClient) GenerateText(ctx context.Context, req StreamRequest) (string, error) {
|
||||
ch, err := c.StreamText(ctx, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return readAllChunks(ch), nil
|
||||
}
|
||||
166
backend/internal/llm/openai.go
Normal file
166
backend/internal/llm/openai.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
// OpenAIClient wraps the go-openai SDK client; baseClient supplies
// provider/model identification.
type OpenAIClient struct {
	baseClient
	client *openai.Client
}
|
||||
|
||||
func NewOpenAIClient(cfg ProviderConfig) (*OpenAIClient, error) {
|
||||
config := openai.DefaultConfig(cfg.APIKey)
|
||||
if cfg.BaseURL != "" {
|
||||
config.BaseURL = cfg.BaseURL
|
||||
}
|
||||
|
||||
return &OpenAIClient{
|
||||
baseClient: baseClient{
|
||||
providerID: cfg.ProviderID,
|
||||
modelKey: cfg.ModelKey,
|
||||
},
|
||||
client: openai.NewClientWithConfig(config),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) StreamText(ctx context.Context, req StreamRequest) (<-chan StreamChunk, error) {
|
||||
messages := make([]openai.ChatCompletionMessage, 0, len(req.Messages))
|
||||
for _, m := range req.Messages {
|
||||
msg := openai.ChatCompletionMessage{
|
||||
Role: string(m.Role),
|
||||
Content: m.Content,
|
||||
}
|
||||
if m.Name != "" {
|
||||
msg.Name = m.Name
|
||||
}
|
||||
if m.ToolCallID != "" {
|
||||
msg.ToolCallID = m.ToolCallID
|
||||
}
|
||||
if len(m.ToolCalls) > 0 {
|
||||
msg.ToolCalls = make([]openai.ToolCall, len(m.ToolCalls))
|
||||
for i, tc := range m.ToolCalls {
|
||||
args, _ := json.Marshal(tc.Arguments)
|
||||
msg.ToolCalls[i] = openai.ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: openai.ToolTypeFunction,
|
||||
Function: openai.FunctionCall{
|
||||
Name: tc.Name,
|
||||
Arguments: string(args),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
chatReq := openai.ChatCompletionRequest{
|
||||
Model: c.modelKey,
|
||||
Messages: messages,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
if req.Options.MaxTokens > 0 {
|
||||
chatReq.MaxTokens = req.Options.MaxTokens
|
||||
}
|
||||
if req.Options.Temperature > 0 {
|
||||
chatReq.Temperature = float32(req.Options.Temperature)
|
||||
}
|
||||
if req.Options.TopP > 0 {
|
||||
chatReq.TopP = float32(req.Options.TopP)
|
||||
}
|
||||
|
||||
if len(req.Tools) > 0 {
|
||||
chatReq.Tools = make([]openai.Tool, len(req.Tools))
|
||||
for i, t := range req.Tools {
|
||||
chatReq.Tools[i] = openai.Tool{
|
||||
Type: openai.ToolTypeFunction,
|
||||
Function: &openai.FunctionDefinition{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
Parameters: t.Schema,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stream, err := c.client.CreateChatCompletionStream(ctx, chatReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ch := make(chan StreamChunk, 100)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
defer stream.Close()
|
||||
|
||||
toolCalls := make(map[int]*ToolCall)
|
||||
|
||||
for {
|
||||
response, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
if len(toolCalls) > 0 {
|
||||
calls := make([]ToolCall, 0, len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
calls = append(calls, *tc)
|
||||
}
|
||||
ch <- StreamChunk{ToolCallChunk: calls}
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if len(response.Choices) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
delta := response.Choices[0].Delta
|
||||
|
||||
if delta.Content != "" {
|
||||
ch <- StreamChunk{ContentChunk: delta.Content}
|
||||
}
|
||||
|
||||
for _, tc := range delta.ToolCalls {
|
||||
idx := *tc.Index
|
||||
if _, ok := toolCalls[idx]; !ok {
|
||||
toolCalls[idx] = &ToolCall{
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: make(map[string]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
if tc.Function.Arguments != "" {
|
||||
existing := toolCalls[idx]
|
||||
var args map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err == nil {
|
||||
for k, v := range args {
|
||||
existing.Arguments[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if response.Choices[0].FinishReason != "" {
|
||||
ch <- StreamChunk{FinishReason: string(response.Choices[0].FinishReason)}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) GenerateText(ctx context.Context, req StreamRequest) (string, error) {
|
||||
ch, err := c.StreamText(ctx, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return readAllChunks(ch), nil
|
||||
}
|
||||
229
backend/internal/llm/registry.go
Normal file
229
backend/internal/llm/registry.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sort"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ModelCapability tags what a model is good at; used by the registry
// to route requests (see ModelRegistry.GetBest).
type ModelCapability string

const (
	CapReasoning ModelCapability = "reasoning"
	CapCoding ModelCapability = "coding"
	CapSearch ModelCapability = "search"
	CapCreative ModelCapability = "creative"
	CapFast ModelCapability = "fast"
	CapLongContext ModelCapability = "long_context"
	CapVision ModelCapability = "vision"
	CapMath ModelCapability = "math"
	CapVideo ModelCapability = "video"
	CapImage ModelCapability = "image"
)

// ModelSpec describes one model entry in the registry.
type ModelSpec struct {
	ID string // registry key
	Provider string // provider ID understood by NewClient
	Model string // provider-side model name
	Capabilities []ModelCapability
	CostPer1K float64 // cost per 1K tokens; tie-breaker in GetBest
	MaxContext int // context window, in tokens
	Priority int // lower is preferred (see GetBest)
	MaxTokens int // maximum completion tokens
	Description string
}
|
||||
|
||||
func (m ModelSpec) HasCapability(cap ModelCapability) bool {
|
||||
for _, c := range m.Capabilities {
|
||||
if c == cap {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ModelRegistry is a concurrency-safe catalog mapping model IDs to
// their specs and ready-to-use clients. models and clients are kept in
// lockstep by Register/Unregister.
type ModelRegistry struct {
	models map[string]ModelSpec
	clients map[string]Client
	mu sync.RWMutex
}

// NewModelRegistry returns an empty registry ready for Register calls.
func NewModelRegistry() *ModelRegistry {
	return &ModelRegistry{
		models: make(map[string]ModelSpec),
		clients: make(map[string]Client),
	}
}
|
||||
|
||||
// Register adds (or replaces) a model spec and its client under spec.ID.
func (r *ModelRegistry) Register(spec ModelSpec, client Client) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.models[spec.ID] = spec
	r.clients[spec.ID] = client
}

// Unregister removes the spec and client registered under id, if any.
func (r *ModelRegistry) Unregister(id string) {
	r.mu.Lock()
	defer r.mu.Unlock()
	delete(r.models, id)
	delete(r.clients, id)
}
|
||||
|
||||
func (r *ModelRegistry) GetByID(id string) (Client, ModelSpec, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
spec, ok := r.models[id]
|
||||
if !ok {
|
||||
return nil, ModelSpec{}, errors.New("model not found: " + id)
|
||||
}
|
||||
|
||||
client, ok := r.clients[id]
|
||||
if !ok {
|
||||
return nil, ModelSpec{}, errors.New("client not found: " + id)
|
||||
}
|
||||
|
||||
return client, spec, nil
|
||||
}
|
||||
|
||||
func (r *ModelRegistry) GetBest(cap ModelCapability) (Client, ModelSpec, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
var candidates []ModelSpec
|
||||
for _, spec := range r.models {
|
||||
if spec.HasCapability(cap) {
|
||||
candidates = append(candidates, spec)
|
||||
}
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return nil, ModelSpec{}, errors.New("no model found with capability: " + string(cap))
|
||||
}
|
||||
|
||||
sort.Slice(candidates, func(i, j int) bool {
|
||||
if candidates[i].Priority != candidates[j].Priority {
|
||||
return candidates[i].Priority < candidates[j].Priority
|
||||
}
|
||||
return candidates[i].CostPer1K < candidates[j].CostPer1K
|
||||
})
|
||||
|
||||
best := candidates[0]
|
||||
client := r.clients[best.ID]
|
||||
return client, best, nil
|
||||
}
|
||||
|
||||
func (r *ModelRegistry) GetAllWithCapability(cap ModelCapability) []ModelSpec {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
var result []ModelSpec
|
||||
for _, spec := range r.models {
|
||||
if spec.HasCapability(cap) {
|
||||
result = append(result, spec)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(result, func(i, j int) bool {
|
||||
return result[i].Priority < result[j].Priority
|
||||
})
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *ModelRegistry) GetAll() []ModelSpec {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
result := make([]ModelSpec, 0, len(r.models))
|
||||
for _, spec := range r.models {
|
||||
result = append(result, spec)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *ModelRegistry) GetClient(id string) (Client, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
client, ok := r.clients[id]
|
||||
if !ok {
|
||||
return nil, errors.New("client not found: " + id)
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// Count returns the number of registered model specs.
func (r *ModelRegistry) Count() int {
	r.mu.RLock()
	defer r.mu.RUnlock()
	return len(r.models)
}
|
||||
|
||||
// DefaultModels is the built-in model catalog used to seed a registry.
// Within a capability, Priority 1 entries are preferred and ties break
// on CostPer1K (see ModelRegistry.GetBest). Costs are per 1K tokens as
// hard-coded here — verify against current provider pricing before
// relying on them.
var DefaultModels = []ModelSpec{
	{
		ID: "gpt-4o",
		Provider: "openai",
		Model: "gpt-4o",
		Capabilities: []ModelCapability{CapSearch, CapFast, CapVision, CapCoding, CapCreative},
		CostPer1K: 0.005,
		MaxContext: 128000,
		MaxTokens: 16384,
		Priority: 1,
		Description: "GPT-4o: fast multimodal model with search",
	},
	{
		ID: "gpt-4o-mini",
		Provider: "openai",
		Model: "gpt-4o-mini",
		Capabilities: []ModelCapability{CapFast, CapCoding},
		CostPer1K: 0.00015,
		MaxContext: 128000,
		MaxTokens: 16384,
		Priority: 2,
		Description: "GPT-4o Mini: cost-effective for simple tasks",
	},
	{
		ID: "claude-3-opus",
		Provider: "anthropic",
		Model: "claude-3-opus-20240229",
		Capabilities: []ModelCapability{CapReasoning, CapCoding, CapCreative, CapLongContext},
		CostPer1K: 0.015,
		MaxContext: 200000,
		MaxTokens: 4096,
		Priority: 1,
		Description: "Claude 3 Opus: best for complex reasoning and coding",
	},
	{
		// NOTE(review): the ID says 3 but the Model is 3.5 Sonnet — confirm
		// this mismatch is intentional (e.g. kept for config compatibility).
		ID: "claude-3-sonnet",
		Provider: "anthropic",
		Model: "claude-3-5-sonnet-20241022",
		Capabilities: []ModelCapability{CapCoding, CapCreative, CapFast},
		CostPer1K: 0.003,
		MaxContext: 200000,
		MaxTokens: 8192,
		Priority: 1,
		Description: "Claude 3.5 Sonnet: balanced speed and quality",
	},
	{
		ID: "gemini-1.5-pro",
		Provider: "gemini",
		Model: "gemini-1.5-pro",
		Capabilities: []ModelCapability{CapLongContext, CapSearch, CapVision, CapMath},
		CostPer1K: 0.00125,
		MaxContext: 2000000,
		MaxTokens: 8192,
		Priority: 1,
		Description: "Gemini 1.5 Pro: best for long context and research",
	},
	{
		ID: "gemini-1.5-flash",
		Provider: "gemini",
		Model: "gemini-1.5-flash",
		Capabilities: []ModelCapability{CapFast, CapVision},
		CostPer1K: 0.000075,
		MaxContext: 1000000,
		MaxTokens: 8192,
		Priority: 2,
		Description: "Gemini 1.5 Flash: fastest for lightweight tasks",
	},
}
|
||||
402
backend/internal/llm/timeweb.go
Normal file
402
backend/internal/llm/timeweb.go
Normal file
@@ -0,0 +1,402 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TimewebClient calls the Timeweb Cloud AI agent endpoint, which
// exposes an OpenAI-compatible chat-completions API per agent.
type TimewebClient struct {
	baseClient
	httpClient *http.Client
	baseURL string
	agentAccessID string // agent identifier embedded in the request path
	apiKey string
	proxySource string // forwarded as the x-proxy-source header
}

// TimewebConfig configures NewTimewebClient. AgentAccessID and APIKey
// are required; BaseURL and ProxySource fall back to defaults.
type TimewebConfig struct {
	ProviderID string
	ModelKey string
	BaseURL string
	AgentAccessID string
	APIKey string
	ProxySource string
}
|
||||
|
||||
func NewTimewebClient(cfg TimewebConfig) (*TimewebClient, error) {
|
||||
if cfg.AgentAccessID == "" {
|
||||
return nil, errors.New("agent_access_id is required for Timeweb")
|
||||
}
|
||||
if cfg.APIKey == "" {
|
||||
return nil, errors.New("api_key is required for Timeweb")
|
||||
}
|
||||
|
||||
baseURL := cfg.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.timeweb.cloud"
|
||||
}
|
||||
|
||||
proxySource := cfg.ProxySource
|
||||
if proxySource == "" {
|
||||
proxySource = "gooseek"
|
||||
}
|
||||
|
||||
return &TimewebClient{
|
||||
baseClient: baseClient{
|
||||
providerID: cfg.ProviderID,
|
||||
modelKey: cfg.ModelKey,
|
||||
},
|
||||
httpClient: &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
},
|
||||
baseURL: baseURL,
|
||||
agentAccessID: cfg.AgentAccessID,
|
||||
apiKey: cfg.APIKey,
|
||||
proxySource: proxySource,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// timewebChatRequest is the OpenAI-compatible request body.
type timewebChatRequest struct {
	Model string `json:"model,omitempty"`
	Messages []timewebMessage `json:"messages"`
	Stream bool `json:"stream,omitempty"`
	Temperature float64 `json:"temperature,omitempty"`
	MaxTokens int `json:"max_tokens,omitempty"`
	TopP float64 `json:"top_p,omitempty"`
	Tools []timewebTool `json:"tools,omitempty"`
	Stop []string `json:"stop,omitempty"`
}

// timewebMessage mirrors an OpenAI chat message.
type timewebMessage struct {
	Role string `json:"role"`
	Content interface{} `json:"content"`
	Name string `json:"name,omitempty"`
	ToolCalls []timewebToolCall `json:"tool_calls,omitempty"`
	ToolCallID string `json:"tool_call_id,omitempty"`
}

// timewebTool wraps a function declaration; Type is always "function".
type timewebTool struct {
	Type string `json:"type"`
	Function timewebFunction `json:"function"`
}

// timewebFunction declares one callable function and its parameter schema.
type timewebFunction struct {
	Name string `json:"name"`
	Description string `json:"description"`
	Parameters interface{} `json:"parameters"`
}

// timewebToolCall is a tool invocation; Arguments is a JSON string,
// matching the OpenAI wire format.
type timewebToolCall struct {
	ID string `json:"id"`
	Type string `json:"type"`
	Function struct {
		Name string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}

// timewebChatResponse is the non-streaming completion response.
type timewebChatResponse struct {
	ID string `json:"id"`
	Object string `json:"object"`
	Created int64 `json:"created"`
	Model string `json:"model"`
	Choices []struct {
		Index int `json:"index"`
		Message struct {
			Role string `json:"role"`
			Content string `json:"content"`
			ToolCalls []timewebToolCall `json:"tool_calls,omitempty"`
		} `json:"message"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
	Usage struct {
		PromptTokens int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens int `json:"total_tokens"`
	} `json:"usage"`
}

// timewebStreamResponse is one SSE "data:" payload of a streamed completion.
type timewebStreamResponse struct {
	ID string `json:"id"`
	Object string `json:"object"`
	Created int64 `json:"created"`
	Model string `json:"model"`
	Choices []struct {
		Index int `json:"index"`
		Delta struct {
			Role string `json:"role,omitempty"`
			Content string `json:"content,omitempty"`
			ToolCalls []timewebToolCall `json:"tool_calls,omitempty"`
		} `json:"delta"`
		FinishReason string `json:"finish_reason,omitempty"`
	} `json:"choices"`
}
|
||||
|
||||
func (c *TimewebClient) StreamText(ctx context.Context, req StreamRequest) (<-chan StreamChunk, error) {
|
||||
messages := make([]timewebMessage, 0, len(req.Messages))
|
||||
for _, m := range req.Messages {
|
||||
msg := timewebMessage{
|
||||
Role: string(m.Role),
|
||||
Content: m.Content,
|
||||
}
|
||||
if m.Name != "" {
|
||||
msg.Name = m.Name
|
||||
}
|
||||
if m.ToolCallID != "" {
|
||||
msg.ToolCallID = m.ToolCallID
|
||||
}
|
||||
if len(m.ToolCalls) > 0 {
|
||||
msg.ToolCalls = make([]timewebToolCall, len(m.ToolCalls))
|
||||
for i, tc := range m.ToolCalls {
|
||||
args, _ := json.Marshal(tc.Arguments)
|
||||
msg.ToolCalls[i] = timewebToolCall{
|
||||
ID: tc.ID,
|
||||
Type: "function",
|
||||
}
|
||||
msg.ToolCalls[i].Function.Name = tc.Name
|
||||
msg.ToolCalls[i].Function.Arguments = string(args)
|
||||
}
|
||||
}
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
chatReq := timewebChatRequest{
|
||||
Model: c.modelKey,
|
||||
Messages: messages,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
if req.Options.MaxTokens > 0 {
|
||||
chatReq.MaxTokens = req.Options.MaxTokens
|
||||
}
|
||||
if req.Options.Temperature > 0 {
|
||||
chatReq.Temperature = req.Options.Temperature
|
||||
}
|
||||
if req.Options.TopP > 0 {
|
||||
chatReq.TopP = req.Options.TopP
|
||||
}
|
||||
if len(req.Options.StopWords) > 0 {
|
||||
chatReq.Stop = req.Options.StopWords
|
||||
}
|
||||
|
||||
if len(req.Tools) > 0 {
|
||||
chatReq.Tools = make([]timewebTool, len(req.Tools))
|
||||
for i, t := range req.Tools {
|
||||
chatReq.Tools[i] = timewebTool{
|
||||
Type: "function",
|
||||
Function: timewebFunction{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
Parameters: t.Schema,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
body, err := json.Marshal(chatReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/v1/cloud-ai/agents/%s/v1/chat/completions", c.baseURL, c.agentAccessID)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||
httpReq.Header.Set("x-proxy-source", c.proxySource)
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("Timeweb API error: status %d, body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
ch := make(chan StreamChunk, 100)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
defer resp.Body.Close()
|
||||
|
||||
toolCalls := make(map[int]*ToolCall)
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
return
|
||||
}
|
||||
if len(toolCalls) > 0 {
|
||||
calls := make([]ToolCall, 0, len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
calls = append(calls, *tc)
|
||||
}
|
||||
ch <- StreamChunk{ToolCallChunk: calls}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
if len(toolCalls) > 0 {
|
||||
calls := make([]ToolCall, 0, len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
calls = append(calls, *tc)
|
||||
}
|
||||
ch <- StreamChunk{ToolCallChunk: calls}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var streamResp timewebStreamResponse
|
||||
if err := json.Unmarshal([]byte(data), &streamResp); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(streamResp.Choices) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
delta := streamResp.Choices[0].Delta
|
||||
|
||||
if delta.Content != "" {
|
||||
ch <- StreamChunk{ContentChunk: delta.Content}
|
||||
}
|
||||
|
||||
for _, tc := range delta.ToolCalls {
|
||||
idx := 0
|
||||
if _, ok := toolCalls[idx]; !ok {
|
||||
toolCalls[idx] = &ToolCall{
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: make(map[string]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
if tc.Function.Arguments != "" {
|
||||
existing := toolCalls[idx]
|
||||
var args map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err == nil {
|
||||
for k, v := range args {
|
||||
existing.Arguments[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if streamResp.Choices[0].FinishReason != "" {
|
||||
ch <- StreamChunk{FinishReason: streamResp.Choices[0].FinishReason}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (c *TimewebClient) GenerateText(ctx context.Context, req StreamRequest) (string, error) {
|
||||
messages := make([]timewebMessage, 0, len(req.Messages))
|
||||
for _, m := range req.Messages {
|
||||
msg := timewebMessage{
|
||||
Role: string(m.Role),
|
||||
Content: m.Content,
|
||||
}
|
||||
if m.Name != "" {
|
||||
msg.Name = m.Name
|
||||
}
|
||||
if m.ToolCallID != "" {
|
||||
msg.ToolCallID = m.ToolCallID
|
||||
}
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
chatReq := timewebChatRequest{
|
||||
Model: c.modelKey,
|
||||
Messages: messages,
|
||||
Stream: false,
|
||||
}
|
||||
|
||||
if req.Options.MaxTokens > 0 {
|
||||
chatReq.MaxTokens = req.Options.MaxTokens
|
||||
}
|
||||
if req.Options.Temperature > 0 {
|
||||
chatReq.Temperature = req.Options.Temperature
|
||||
}
|
||||
if req.Options.TopP > 0 {
|
||||
chatReq.TopP = req.Options.TopP
|
||||
}
|
||||
|
||||
if len(req.Tools) > 0 {
|
||||
chatReq.Tools = make([]timewebTool, len(req.Tools))
|
||||
for i, t := range req.Tools {
|
||||
chatReq.Tools[i] = timewebTool{
|
||||
Type: "function",
|
||||
Function: timewebFunction{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
Parameters: t.Schema,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
body, err := json.Marshal(chatReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/v1/cloud-ai/agents/%s/v1/chat/completions", c.baseURL, c.agentAccessID)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||
httpReq.Header.Set("x-proxy-source", c.proxySource)
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("Timeweb API error: status %d, body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var chatResp timewebChatResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
|
||||
return "", fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
if len(chatResp.Choices) == 0 {
|
||||
return "", errors.New("no choices in response")
|
||||
}
|
||||
|
||||
return chatResp.Choices[0].Message.Content, nil
|
||||
}
|
||||
318
backend/internal/pages/generator.go
Normal file
318
backend/internal/pages/generator.go
Normal file
@@ -0,0 +1,318 @@
|
||||
package pages
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gooseek/backend/internal/llm"
|
||||
"github.com/gooseek/backend/internal/types"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Page is a shareable, article-style document generated from a search
// thread's answer and sources.
type Page struct {
	ID        string        `json:"id"`
	UserID    string        `json:"userId"` // not populated by GenerateFromThread; set by the caller
	ThreadID  string        `json:"threadId,omitempty"`
	Title     string        `json:"title"`
	Subtitle  string        `json:"subtitle,omitempty"`
	Sections  []PageSection `json:"sections"`
	Sections  []PageSection `json:"sections"`
	Sources   []PageSource  `json:"sources"`
	Thumbnail string        `json:"thumbnail,omitempty"`
	IsPublic  bool          `json:"isPublic"`
	ShareID   string        `json:"shareId,omitempty"` // presumably set on publish — not populated in this file
	ViewCount int           `json:"viewCount"`
	CreatedAt time.Time     `json:"createdAt"`
	UpdatedAt time.Time     `json:"updatedAt"`
}
|
||||
|
||||
// PageSection is one titled block of article content within a Page.
type PageSection struct {
	ID       string `json:"id"`
	Type     string `json:"type"` // currently always "text" (see parseStructure)
	Title    string `json:"title,omitempty"`
	Content  string `json:"content"`
	ImageURL string `json:"imageUrl,omitempty"` // not populated in this file
	Order    int    `json:"order"` // 1-based position within the page
}

// PageSource is a citation entry attached to a generated Page.
type PageSource struct {
	Index   int    `json:"index"` // 1-based citation number matching [n] markers in the text
	URL     string `json:"url"`
	Title   string `json:"title"`
	Domain  string `json:"domain"` // derived from URL via extractDomain
	Favicon string `json:"favicon,omitempty"` // not populated in this file
}
|
||||
|
||||
// PageGeneratorConfig configures how pages are generated.
type PageGeneratorConfig struct {
	LLMClient llm.Client // turns the raw answer into a structured article
	Locale    string     // "ru" makes the article prompt request Russian output
	Style     string     // writing-style hint; defaults to "informative" when empty
	Audience  string     // target-audience hint; defaults to "general" when empty
}

// PageGenerator turns finished search answers into structured Pages.
type PageGenerator struct {
	cfg PageGeneratorConfig
}
|
||||
|
||||
func NewPageGenerator(cfg PageGeneratorConfig) *PageGenerator {
|
||||
return &PageGenerator{cfg: cfg}
|
||||
}
|
||||
|
||||
// GenerateFromThread turns a finished search-thread answer into a structured
// Page: the LLM rewrites the answer as a titled, sectioned article, and up to
// 20 of the supplied source chunks are attached as citations.
func (g *PageGenerator) GenerateFromThread(ctx context.Context, query string, answer string, sources []types.Chunk) (*Page, error) {
	structurePrompt := g.buildStructurePrompt(query, answer, sources)

	structure, err := g.cfg.LLMClient.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{
			{Role: "user", Content: structurePrompt},
		},
	})
	if err != nil {
		return nil, fmt.Errorf("failed to generate structure: %w", err)
	}

	page := g.parseStructure(structure)
	page.ID = uuid.New().String()
	page.CreatedAt = time.Now()
	page.UpdatedAt = time.Now()

	// Attach at most 20 citations. Metadata keys "url" and "title" are
	// assumed to be populated by the search pipeline — TODO confirm; empty
	// strings pass through unchanged.
	for i, src := range sources {
		if i >= 20 {
			break
		}
		url := src.Metadata["url"]
		title := src.Metadata["title"]
		page.Sources = append(page.Sources, PageSource{
			Index:  i + 1,
			URL:    url,
			Title:  title,
			Domain: extractDomain(url),
		})
	}

	return page, nil
}
|
||||
|
||||
// buildStructurePrompt assembles the LLM prompt that converts a research
// answer into the TITLE:/SUBTITLE:/SECTION: format consumed by
// parseStructure. At most 15 sources are included, each truncated to 300
// bytes; the answer itself is truncated to 2000 bytes.
func (g *PageGenerator) buildStructurePrompt(query, answer string, sources []types.Chunk) string {
	var sourcesText strings.Builder
	for i, s := range sources {
		if i >= 15 {
			break
		}
		sourcesText.WriteString(fmt.Sprintf("[%d] %s\n%s\n\n", i+1, s.Metadata["title"], truncate(s.Content, 300)))
	}

	// Only Russian gets an explicit language instruction; other locales rely
	// on the model mirroring the input language.
	langInstr := ""
	if g.cfg.Locale == "ru" {
		langInstr = "Write in Russian."
	}

	style := g.cfg.Style
	if style == "" {
		style = "informative"
	}

	audience := g.cfg.Audience
	if audience == "" {
		audience = "general"
	}

	// The format below is parsed line-by-line by parseStructure; keep the
	// TITLE:/SUBTITLE:/SECTION: markers in sync with that function.
	return fmt.Sprintf(`Create a well-structured article from this research.

Topic: %s

Research findings:
%s

Sources:
%s

%s

Style: %s
Target audience: %s

Generate the article in this exact format:

TITLE: [compelling title]
SUBTITLE: [brief subtitle]

SECTION: Introduction
[2-3 paragraphs introducing the topic]

SECTION: [Topic Name 1]
[detailed content with citations [1], [2], etc.]

SECTION: [Topic Name 2]
[detailed content with citations]

SECTION: [Topic Name 3]
[detailed content with citations]

SECTION: Conclusion
[summary and key takeaways]

SECTION: Key Points
- [bullet point 1]
- [bullet point 2]
- [bullet point 3]

Requirements:
- Use citations [1], [2], etc. throughout
- Make it comprehensive but readable
- Include specific facts and data
- Keep sections focused and well-organized`, query, truncate(answer, 2000), sourcesText.String(), langInstr, style, audience)
}
|
||||
|
||||
func (g *PageGenerator) parseStructure(text string) *Page {
|
||||
page := &Page{
|
||||
Sections: make([]PageSection, 0),
|
||||
}
|
||||
|
||||
lines := strings.Split(text, "\n")
|
||||
var currentSection *PageSection
|
||||
var contentBuilder strings.Builder
|
||||
order := 0
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
if strings.HasPrefix(line, "TITLE:") {
|
||||
page.Title = strings.TrimSpace(strings.TrimPrefix(line, "TITLE:"))
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "SUBTITLE:") {
|
||||
page.Subtitle = strings.TrimSpace(strings.TrimPrefix(line, "SUBTITLE:"))
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "SECTION:") {
|
||||
if currentSection != nil {
|
||||
currentSection.Content = strings.TrimSpace(contentBuilder.String())
|
||||
page.Sections = append(page.Sections, *currentSection)
|
||||
contentBuilder.Reset()
|
||||
}
|
||||
|
||||
order++
|
||||
currentSection = &PageSection{
|
||||
ID: uuid.New().String(),
|
||||
Type: "text",
|
||||
Title: strings.TrimSpace(strings.TrimPrefix(line, "SECTION:")),
|
||||
Order: order,
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if currentSection != nil {
|
||||
contentBuilder.WriteString(line)
|
||||
contentBuilder.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
if currentSection != nil {
|
||||
currentSection.Content = strings.TrimSpace(contentBuilder.String())
|
||||
page.Sections = append(page.Sections, *currentSection)
|
||||
}
|
||||
|
||||
return page
|
||||
}
|
||||
|
||||
func (g *PageGenerator) ExportToMarkdown(page *Page) string {
|
||||
var md strings.Builder
|
||||
|
||||
md.WriteString("# " + page.Title + "\n\n")
|
||||
if page.Subtitle != "" {
|
||||
md.WriteString("*" + page.Subtitle + "*\n\n")
|
||||
}
|
||||
|
||||
for _, section := range page.Sections {
|
||||
md.WriteString("## " + section.Title + "\n\n")
|
||||
md.WriteString(section.Content + "\n\n")
|
||||
}
|
||||
|
||||
md.WriteString("---\n\n## Sources\n\n")
|
||||
for _, src := range page.Sources {
|
||||
md.WriteString(fmt.Sprintf("%d. [%s](%s)\n", src.Index, src.Title, src.URL))
|
||||
}
|
||||
|
||||
return md.String()
|
||||
}
|
||||
|
||||
func (g *PageGenerator) ExportToHTML(page *Page) string {
|
||||
var html strings.Builder
|
||||
|
||||
html.WriteString("<!DOCTYPE html>\n<html>\n<head>\n")
|
||||
html.WriteString(fmt.Sprintf("<title>%s</title>\n", page.Title))
|
||||
html.WriteString("<style>\n")
|
||||
html.WriteString(`body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; line-height: 1.6; }
|
||||
h1 { color: #1a1a1a; border-bottom: 2px solid #007bff; padding-bottom: 10px; }
|
||||
h2 { color: #333; margin-top: 30px; }
|
||||
.subtitle { color: #666; font-style: italic; margin-bottom: 30px; }
|
||||
.sources { background: #f5f5f5; padding: 20px; border-radius: 8px; margin-top: 40px; }
|
||||
.sources a { color: #007bff; text-decoration: none; }
|
||||
.sources a:hover { text-decoration: underline; }
|
||||
`)
|
||||
html.WriteString("</style>\n</head>\n<body>\n")
|
||||
|
||||
html.WriteString(fmt.Sprintf("<h1>%s</h1>\n", page.Title))
|
||||
if page.Subtitle != "" {
|
||||
html.WriteString(fmt.Sprintf("<p class=\"subtitle\">%s</p>\n", page.Subtitle))
|
||||
}
|
||||
|
||||
for _, section := range page.Sections {
|
||||
html.WriteString(fmt.Sprintf("<h2>%s</h2>\n", section.Title))
|
||||
paragraphs := strings.Split(section.Content, "\n\n")
|
||||
for _, p := range paragraphs {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
if strings.HasPrefix(p, "- ") {
|
||||
html.WriteString("<ul>\n")
|
||||
for _, item := range strings.Split(p, "\n") {
|
||||
item = strings.TrimPrefix(item, "- ")
|
||||
html.WriteString(fmt.Sprintf("<li>%s</li>\n", item))
|
||||
}
|
||||
html.WriteString("</ul>\n")
|
||||
} else {
|
||||
html.WriteString(fmt.Sprintf("<p>%s</p>\n", p))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
html.WriteString("<div class=\"sources\">\n<h3>Sources</h3>\n<ol>\n")
|
||||
for _, src := range page.Sources {
|
||||
html.WriteString(fmt.Sprintf("<li><a href=\"%s\" target=\"_blank\">%s</a> (%s)</li>\n", src.URL, src.Title, src.Domain))
|
||||
}
|
||||
html.WriteString("</ol>\n</div>\n")
|
||||
|
||||
html.WriteString("</body>\n</html>")
|
||||
|
||||
return html.String()
|
||||
}
|
||||
|
||||
// ToJSON serializes the page as indented JSON, suitable for API responses
// and storage.
func (g *PageGenerator) ToJSON(page *Page) (string, error) {
	data, err := json.MarshalIndent(page, "", "  ")
	if err != nil {
		return "", err
	}
	return string(data), nil
}
|
||||
|
||||
// truncate shortens s to at most maxLen bytes, appending "..." when the
// string was cut. The cut position is moved back to a UTF-8 rune boundary so
// multi-byte characters (e.g. the Russian text this service handles) are
// never split mid-rune, which previously produced invalid UTF-8.
func truncate(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	cut := maxLen
	// Bytes of the form 0b10xxxxxx are UTF-8 continuation bytes; back up
	// until cut lands on the start of a rune.
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "..."
}
|
||||
|
||||
// extractDomain strips the scheme, a leading "www.", and everything from the
// first path, query, or fragment separator, returning a bare host for
// display. Previously only "/" terminated the host, so URLs like
// "example.com?q=1" leaked the query string into the domain.
func extractDomain(url string) string {
	url = strings.TrimPrefix(url, "https://")
	url = strings.TrimPrefix(url, "http://")
	url = strings.TrimPrefix(url, "www.")
	if idx := strings.IndexAny(url, "/?#"); idx > 0 {
		return url[:idx]
	}
	return url
}
|
||||
507
backend/internal/podcast/generator.go
Normal file
507
backend/internal/podcast/generator.go
Normal file
@@ -0,0 +1,507 @@
|
||||
package podcast
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gooseek/backend/internal/llm"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// PodcastType identifies the editorial format of an episode.
type PodcastType string

const (
	PodcastDaily     PodcastType = "daily"
	PodcastWeekly    PodcastType = "weekly"
	PodcastTopicDeep PodcastType = "topic_deep"
	PodcastBreaking  PodcastType = "breaking"
)

// VoiceStyle is a requested narration style for TTS rendering.
type VoiceStyle string

const (
	VoiceNeutral      VoiceStyle = "neutral"
	VoiceEnthusiastic VoiceStyle = "enthusiastic"
	VoiceProfessional VoiceStyle = "professional"
	VoiceCasual       VoiceStyle = "casual"
	VoiceStorytelling VoiceStyle = "storytelling"
)
|
||||
|
||||
// Podcast is a fully assembled episode: metadata, the spoken transcript,
// its timed segments, and the sources the content was drawn from.
type Podcast struct {
	ID          string           `json:"id"`
	Title       string           `json:"title"`
	Description string           `json:"description"`
	Type        PodcastType      `json:"type"`
	Date        time.Time        `json:"date"`
	Duration    int              `json:"durationSeconds"` // requested target length (see GenerateDailyPodcast)
	AudioURL    string           `json:"audioUrl,omitempty"` // not populated in this file
	Transcript  string           `json:"transcript"` // full spoken text, all segments joined
	Segments    []PodcastSegment `json:"segments"`
	Topics      []string         `json:"topics"`
	Sources     []Source         `json:"sources"`
	Thumbnail   string           `json:"thumbnail,omitempty"`
	Status      PodcastStatus    `json:"status"`
	GeneratedAt time.Time        `json:"generatedAt"`
	PublishedAt *time.Time       `json:"publishedAt,omitempty"`
	Locale      string           `json:"locale"`
	VoiceConfig VoiceConfig      `json:"voiceConfig"`
}
|
||||
|
||||
// PodcastStatus tracks an episode through its generation lifecycle.
type PodcastStatus string

const (
	StatusDraft      PodcastStatus = "draft"      // script generated, audio not yet rendered
	StatusGenerating PodcastStatus = "generating" // TTS rendering in progress
	StatusReady      PodcastStatus = "ready"      // audio rendered successfully
	StatusPublished  PodcastStatus = "published"
	StatusFailed     PodcastStatus = "failed" // TTS rendering failed
)

// PodcastSegment is one timed portion of an episode (intro, a news story,
// an analysis block, a transition, or the outro).
type PodcastSegment struct {
	ID         string   `json:"id"`
	Type       string   `json:"type"` // "intro" | "news" | "analysis" | "transition" | "outro"
	Title      string   `json:"title"`
	Content    string   `json:"content"` // full text to be spoken
	Duration   int      `json:"durationSeconds"`
	StartTime  int      `json:"startTime"` // offset from episode start, in seconds
	EndTime    int      `json:"endTime"`
	Sources    []Source `json:"sources,omitempty"`
	Highlights []string `json:"highlights,omitempty"`
}

// Source is a citation for content covered in a segment or episode.
type Source struct {
	Title     string `json:"title"`
	URL       string `json:"url"`
	Publisher string `json:"publisher"`
	Date      string `json:"date,omitempty"`
}

// VoiceConfig selects and tunes the TTS voice used to render an episode.
type VoiceConfig struct {
	Provider string     `json:"provider"` // e.g. "elevenlabs"
	VoiceID  string     `json:"voiceId"`
	Style    VoiceStyle `json:"style"`
	Speed    float64    `json:"speed"` // 1.0 = normal rate; not forwarded to ElevenLabs in this file
	Pitch    float64    `json:"pitch"` // 1.0 = normal pitch; not forwarded to ElevenLabs in this file
	Language string     `json:"language"`
}
|
||||
|
||||
// PodcastGenerator creates podcast scripts via an LLM and optionally renders
// them to audio through a TTSClient.
type PodcastGenerator struct {
	llm        llm.Client
	ttsClient  TTSClient
	httpClient *http.Client // initialized in NewPodcastGenerator; not used elsewhere in this file
	config     GeneratorConfig
}

// GeneratorConfig holds defaults applied when GenerateOptions fields are zero.
type GeneratorConfig struct {
	DefaultDuration int // seconds; defaults to 300 in NewPodcastGenerator
	MaxDuration     int // seconds; defaults to 1800 (not enforced in this file — TODO confirm intended use)
	DefaultVoice    VoiceConfig
	OutputDir       string // not used in this file
}

// TTSClient converts text to audio bytes (implementations: ElevenLabsTTS,
// DummyTTS).
type TTSClient interface {
	GenerateSpeech(ctx context.Context, text string, config VoiceConfig) ([]byte, error)
}
|
||||
|
||||
func NewPodcastGenerator(llmClient llm.Client, ttsClient TTSClient, cfg GeneratorConfig) *PodcastGenerator {
|
||||
if cfg.DefaultDuration == 0 {
|
||||
cfg.DefaultDuration = 300
|
||||
}
|
||||
if cfg.MaxDuration == 0 {
|
||||
cfg.MaxDuration = 1800
|
||||
}
|
||||
if cfg.DefaultVoice.Provider == "" {
|
||||
cfg.DefaultVoice = VoiceConfig{
|
||||
Provider: "elevenlabs",
|
||||
VoiceID: "21m00Tcm4TlvDq8ikWAM",
|
||||
Style: VoiceNeutral,
|
||||
Speed: 1.0,
|
||||
Pitch: 1.0,
|
||||
Language: "ru",
|
||||
}
|
||||
}
|
||||
|
||||
return &PodcastGenerator{
|
||||
llm: llmClient,
|
||||
ttsClient: ttsClient,
|
||||
httpClient: &http.Client{Timeout: 60 * time.Second},
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateOptions are per-episode inputs to podcast generation; zero values
// fall back to defaults in GenerateDailyPodcast / GeneratorConfig.
type GenerateOptions struct {
	Type           PodcastType
	Topics         []string
	NewsItems      []NewsItem
	Date           time.Time // episode date; defaults to time.Now()
	Duration       int       // target length in seconds; defaults to GeneratorConfig.DefaultDuration
	Locale         string    // defaults to "ru"
	VoiceConfig    *VoiceConfig // overrides the configured default voice when non-nil
	IncludeIntro   bool // not consumed in this file — TODO confirm downstream use
	IncludeOutro   bool // not consumed in this file — TODO confirm downstream use
	PersonalizeFor string // not consumed in this file
}

// NewsItem is one news story fed into the script-generation prompt.
type NewsItem struct {
	Title       string   `json:"title"`
	Summary     string   `json:"summary"`
	URL         string   `json:"url"`
	Source      string   `json:"source"`
	PublishedAt string   `json:"publishedAt"`
	Topics      []string `json:"topics"`
	Importance  int      `json:"importance"`
}
|
||||
|
||||
// GenerateDailyPodcast produces a podcast episode (script plus metadata) from
// the supplied news items. Audio is NOT rendered here — call GenerateAudio on
// the returned Podcast. Zero-valued options are defaulted: Date=now,
// Duration=config default, Locale="ru".
func (g *PodcastGenerator) GenerateDailyPodcast(ctx context.Context, opts GenerateOptions) (*Podcast, error) {
	if opts.Date.IsZero() {
		opts.Date = time.Now()
	}
	if opts.Duration == 0 {
		opts.Duration = g.config.DefaultDuration
	}
	if opts.Locale == "" {
		opts.Locale = "ru"
	}

	script, err := g.generateScript(ctx, opts)
	if err != nil {
		return nil, fmt.Errorf("failed to generate script: %w", err)
	}

	podcast := &Podcast{
		ID:          uuid.New().String(),
		Title:       script.Title,
		Description: script.Description,
		Type:        opts.Type,
		Date:        opts.Date,
		// NOTE(review): Duration records the requested target, not the sum of
		// generated segment durations — confirm which one callers expect.
		Duration:    opts.Duration,
		Transcript:  script.FullText,
		Segments:    script.Segments,
		Topics:      opts.Topics,
		Sources:     script.Sources,
		Status:      StatusDraft,
		GeneratedAt: time.Now(),
		Locale:      opts.Locale,
		VoiceConfig: g.config.DefaultVoice,
	}

	// A caller-supplied voice overrides the configured default.
	if opts.VoiceConfig != nil {
		podcast.VoiceConfig = *opts.VoiceConfig
	}

	return podcast, nil
}
|
||||
|
||||
// PodcastScript is the intermediate result of script generation, before a
// Podcast record is assembled around it.
type PodcastScript struct {
	Title       string
	Description string
	FullText    string // all segment contents joined with blank lines
	Segments    []PodcastSegment
	Sources     []Source // union of all segment sources
}
|
||||
|
||||
// generateScript asks the LLM for a segmented podcast script in JSON form and
// converts it into a PodcastScript with estimated per-segment timings. If the
// LLM output cannot be parsed as JSON, a hard-coded fallback script is used
// instead (generateDefaultScript).
func (g *PodcastGenerator) generateScript(ctx context.Context, opts GenerateOptions) (*PodcastScript, error) {
	locale := opts.Locale
	langInstruction := ""
	if locale == "ru" {
		langInstruction = "Generate the entire script in Russian language. Use natural Russian speech patterns."
	}

	// Marshal error ignored: NewsItem contains only strings/slices/ints,
	// which always marshal cleanly.
	newsJSON, _ := json.Marshal(opts.NewsItems)

	prompt := fmt.Sprintf(`Create a podcast script for a daily news digest.

Date: %s
Duration target: %d seconds (approximately %d minutes)
Topics: %v
%s

News items to cover:
%s

Create an engaging podcast script with these requirements:
1. Start with a catchy introduction greeting the audience
2. Cover the most important news first
3. Transition smoothly between stories
4. Add brief analysis or context where appropriate
5. End with a summary and sign-off

The script should sound natural when read aloud - use conversational language, not formal news anchor style.

Respond in JSON format:
{
  "title": "Podcast title for this episode",
  "description": "Brief episode description",
  "segments": [
    {
      "type": "intro|news|analysis|transition|outro",
      "title": "Segment title",
      "content": "Full text to be spoken",
      "highlights": ["Key point 1", "Key point 2"],
      "sources": [{"title": "Source title", "url": "url", "publisher": "publisher"}]
    }
  ]
}`, opts.Date.Format("2006-01-02"), opts.Duration, opts.Duration/60, opts.Topics, langInstruction, string(newsJSON))

	result, err := g.llm.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{{Role: "user", Content: prompt}},
	})
	if err != nil {
		return nil, err
	}

	// The model may wrap the JSON in prose/markdown; extract the first
	// balanced object.
	jsonStr := extractJSON(result)

	var parsed struct {
		Title       string `json:"title"`
		Description string `json:"description"`
		Segments    []struct {
			Type       string   `json:"type"`
			Title      string   `json:"title"`
			Content    string   `json:"content"`
			Highlights []string `json:"highlights"`
			Sources    []struct {
				Title     string `json:"title"`
				URL       string `json:"url"`
				Publisher string `json:"publisher"`
			} `json:"sources"`
		} `json:"segments"`
	}

	// Malformed LLM output degrades to the deterministic fallback script
	// rather than failing the whole episode.
	if err := json.Unmarshal([]byte(jsonStr), &parsed); err != nil {
		return g.generateDefaultScript(opts)
	}

	script := &PodcastScript{
		Title:       parsed.Title,
		Description: parsed.Description,
		Segments:    make([]PodcastSegment, 0),
		Sources:     make([]Source, 0),
	}

	var fullTextBuilder strings.Builder
	currentTime := 0
	// Rough speaking-rate estimate used to derive segment timings from word
	// counts — presumably tuned for the default narration speed; TODO confirm
	// it suits Russian-language output.
	avgWordsPerSecond := 2.5

	for i, seg := range parsed.Segments {
		wordCount := len(strings.Fields(seg.Content))
		segDuration := int(float64(wordCount) / avgWordsPerSecond)
		// Floor each segment at 10 seconds so empty/short segments still get
		// a sensible timeline slot.
		if segDuration < 10 {
			segDuration = 10
		}

		segment := PodcastSegment{
			ID:         uuid.New().String(),
			Type:       seg.Type,
			Title:      seg.Title,
			Content:    seg.Content,
			Duration:   segDuration,
			StartTime:  currentTime,
			EndTime:    currentTime + segDuration,
			Highlights: seg.Highlights,
		}

		for _, src := range seg.Sources {
			source := Source{
				Title:     src.Title,
				URL:       src.URL,
				Publisher: src.Publisher,
			}
			segment.Sources = append(segment.Sources, source)
			script.Sources = append(script.Sources, source)
		}

		script.Segments = append(script.Segments, segment)

		fullTextBuilder.WriteString(seg.Content)
		if i < len(parsed.Segments)-1 {
			fullTextBuilder.WriteString("\n\n")
		}

		currentTime += segDuration
	}

	script.FullText = fullTextBuilder.String()

	return script, nil
}
|
||||
|
||||
func (g *PodcastGenerator) generateDefaultScript(opts GenerateOptions) (*PodcastScript, error) {
|
||||
date := opts.Date.Format("2 January 2006")
|
||||
|
||||
intro := fmt.Sprintf("Добрый день! С вами GooSeek Daily — ваш ежедневный подкаст с главными новостями. Сегодня %s, и вот что происходит в мире.", date)
|
||||
|
||||
var newsContent strings.Builder
|
||||
for i, news := range opts.NewsItems {
|
||||
if i > 0 {
|
||||
newsContent.WriteString("\n\n")
|
||||
}
|
||||
newsContent.WriteString(fmt.Sprintf("%s. %s", news.Title, news.Summary))
|
||||
}
|
||||
|
||||
outro := "На этом всё на сегодня. Спасибо, что слушаете GooSeek Daily! Подписывайтесь на наш подкаст и до встречи завтра."
|
||||
|
||||
return &PodcastScript{
|
||||
Title: fmt.Sprintf("GooSeek Daily — %s", date),
|
||||
Description: "Ежедневный подкаст с главными новостями",
|
||||
FullText: fmt.Sprintf("%s\n\n%s\n\n%s", intro, newsContent.String(), outro),
|
||||
Segments: []PodcastSegment{
|
||||
{ID: uuid.New().String(), Type: "intro", Title: "Вступление", Content: intro, Duration: 15},
|
||||
{ID: uuid.New().String(), Type: "news", Title: "Новости", Content: newsContent.String(), Duration: opts.Duration - 30},
|
||||
{ID: uuid.New().String(), Type: "outro", Title: "Завершение", Content: outro, Duration: 15},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (g *PodcastGenerator) GenerateAudio(ctx context.Context, podcast *Podcast) ([]byte, error) {
|
||||
if g.ttsClient == nil {
|
||||
return nil, fmt.Errorf("TTS client not configured")
|
||||
}
|
||||
|
||||
podcast.Status = StatusGenerating
|
||||
|
||||
audioData, err := g.ttsClient.GenerateSpeech(ctx, podcast.Transcript, podcast.VoiceConfig)
|
||||
if err != nil {
|
||||
podcast.Status = StatusFailed
|
||||
return nil, fmt.Errorf("failed to generate audio: %w", err)
|
||||
}
|
||||
|
||||
podcast.Status = StatusReady
|
||||
|
||||
return audioData, nil
|
||||
}
|
||||
|
||||
func (g *PodcastGenerator) GenerateWeeklySummary(ctx context.Context, weeklyNews []NewsItem, locale string) (*Podcast, error) {
|
||||
return g.GenerateDailyPodcast(ctx, GenerateOptions{
|
||||
Type: PodcastWeekly,
|
||||
NewsItems: weeklyNews,
|
||||
Duration: 900,
|
||||
Locale: locale,
|
||||
IncludeIntro: true,
|
||||
IncludeOutro: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (g *PodcastGenerator) GenerateTopicDeepDive(ctx context.Context, topic string, articles []NewsItem, locale string) (*Podcast, error) {
|
||||
return g.GenerateDailyPodcast(ctx, GenerateOptions{
|
||||
Type: PodcastTopicDeep,
|
||||
Topics: []string{topic},
|
||||
NewsItems: articles,
|
||||
Duration: 600,
|
||||
Locale: locale,
|
||||
IncludeIntro: true,
|
||||
IncludeOutro: true,
|
||||
})
|
||||
}
|
||||
|
||||
// extractJSON returns the first balanced top-level JSON object found in text,
// or "{}" when none exists. Braces inside JSON string literals (including
// escaped quotes) are ignored, so LLM output such as {"a": "b{c"} no longer
// breaks the depth counting as it did with the naive scan.
func extractJSON(text string) string {
	start := strings.Index(text, "{")
	if start == -1 {
		return "{}"
	}

	depth := 0
	inString := false
	escaped := false
	for i := start; i < len(text); i++ {
		c := text[i]
		if inString {
			switch {
			case escaped:
				escaped = false
			case c == '\\':
				escaped = true
			case c == '"':
				inString = false
			}
			continue
		}
		switch c {
		case '"':
			inString = true
		case '{':
			depth++
		case '}':
			depth--
			if depth == 0 {
				return text[start : i+1]
			}
		}
	}

	return "{}"
}
|
||||
|
||||
// ToJSON serializes the podcast for storage or API transport.
func (p *Podcast) ToJSON() ([]byte, error) {
	return json.Marshal(p)
}
|
||||
|
||||
func ParsePodcast(data []byte) (*Podcast, error) {
|
||||
var podcast Podcast
|
||||
if err := json.Unmarshal(data, &podcast); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &podcast, nil
|
||||
}
|
||||
|
||||
// ElevenLabsTTS implements TTSClient on top of the ElevenLabs HTTP API.
type ElevenLabsTTS struct {
	apiKey     string
	httpClient *http.Client
	baseURL    string
}

// NewElevenLabsTTS creates a client for the public ElevenLabs v1 API.
// The 120-second timeout accommodates long transcripts, which synthesize
// slowly.
func NewElevenLabsTTS(apiKey string) *ElevenLabsTTS {
	return &ElevenLabsTTS{
		apiKey:     apiKey,
		httpClient: &http.Client{Timeout: 120 * time.Second},
		baseURL:    "https://api.elevenlabs.io/v1",
	}
}
|
||||
|
||||
// GenerateSpeech synthesizes text to MP3 audio via the ElevenLabs
// text-to-speech endpoint. Only config.VoiceID is forwarded to the API
// (falling back to a hard-coded default voice when empty); Speed, Pitch,
// Style, and Language from VoiceConfig are not sent — voice settings are
// fixed below.
func (t *ElevenLabsTTS) GenerateSpeech(ctx context.Context, text string, config VoiceConfig) ([]byte, error) {
	voiceID := config.VoiceID
	if voiceID == "" {
		voiceID = "21m00Tcm4TlvDq8ikWAM" // fallback voice ID
	}

	url := fmt.Sprintf("%s/text-to-speech/%s", t.baseURL, voiceID)

	body := map[string]interface{}{
		"text":     text,
		"model_id": "eleven_multilingual_v2",
		"voice_settings": map[string]interface{}{
			"stability":         0.5,
			"similarity_boost":  0.75,
			"style":             0.5,
			"use_speaker_boost": true,
		},
	}

	// Marshal error ignored: the map contains only JSON-safe primitives.
	bodyJSON, _ := json.Marshal(body)

	req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(string(bodyJSON)))
	if err != nil {
		return nil, err
	}

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("xi-api-key", t.apiKey)
	req.Header.Set("Accept", "audio/mpeg")

	resp, err := t.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("ElevenLabs API error: %d", resp.StatusCode)
	}

	// NOTE(review): this manual loop re-implements io.ReadAll but breaks on
	// ANY read error, so a mid-stream network failure returns truncated audio
	// as success. Recommend replacing with io.ReadAll and propagating non-EOF
	// errors (requires adding "io" to this file's imports).
	var audioData []byte
	buf := make([]byte, 32*1024)
	for {
		n, err := resp.Body.Read(buf)
		if n > 0 {
			audioData = append(audioData, buf[:n]...)
		}
		if err != nil {
			break
		}
	}

	return audioData, nil
}
|
||||
|
||||
// DummyTTS is a no-op TTSClient for tests or deployments without a real TTS
// provider; it always succeeds with an empty (non-nil) audio payload.
type DummyTTS struct{}

// GenerateSpeech implements TTSClient by returning empty audio data.
func (t *DummyTTS) GenerateSpeech(ctx context.Context, text string, config VoiceConfig) ([]byte, error) {
	return []byte{}, nil
}
|
||||
50
backend/internal/prompts/classifier.go
Normal file
50
backend/internal/prompts/classifier.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package prompts
|
||||
|
||||
import "strings"
|
||||
|
||||
// GetClassifierPrompt returns the system prompt for the query-classification
// step: the classifier rewrites follow-ups to be standalone, decides whether
// web search can be skipped, and tags topics, query type, and suggested
// engines. Its output must be a bare JSON object (see ClassificationResult).
//
// NOTE(review): the locale parameter is currently unused — only detectedLang
// affects the prompt; confirm whether locale-specific behavior was intended.
func GetClassifierPrompt(locale, detectedLang string) string {
	// Default: mirror the user's language; Russian gets an explicit hint.
	langInstruction := "Respond in the same language as the user's query."
	if detectedLang == "ru" {
		langInstruction = "The user is writing in Russian. Process accordingly."
	}

	// langInstruction is spliced between the field descriptions and the
	// output-format requirements; TrimSpace drops the literal's surrounding
	// newlines before the prompt reaches the LLM.
	return strings.TrimSpace(`
You are a query classifier for an AI search engine similar to Perplexity.

Your task is to analyze the user's query and conversation history, then output a JSON object with the following fields:

1. "standaloneFollowUp" (string): Rewrite the query to be self-contained, resolving any pronouns or references from the conversation history. If the query is already standalone, return it as-is.

2. "skipSearch" (boolean): Set to true if the query:
- Is a greeting or casual conversation
- Asks to explain something already discussed
- Requests formatting changes to previous response
- Is a thank you or acknowledgment

3. "topics" (array of strings): Key topics or entities mentioned in the query.

4. "queryType" (string): One of:
- "factual" - seeking specific facts
- "exploratory" - broad research topic
- "comparison" - comparing items
- "how_to" - procedural question
- "news" - current events
- "opinion" - subjective question
- "calculation" - math or computation

5. "engines" (array of strings): Suggested search engines based on query type.

` + langInstruction + `

IMPORTANT: Output ONLY a valid JSON object, no explanation or markdown.

Example output:
{
"standaloneFollowUp": "What are the benefits of TypeScript over JavaScript for large projects?",
"skipSearch": false,
"topics": ["TypeScript", "JavaScript", "programming"],
"queryType": "comparison",
"engines": ["google", "duckduckgo"]
}
`)
}
|
||||
127
backend/internal/prompts/researcher.go
Normal file
127
backend/internal/prompts/researcher.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package prompts
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ResearcherConfig struct {
|
||||
AvailableActions string
|
||||
Mode string
|
||||
Iteration int
|
||||
MaxIterations int
|
||||
Locale string
|
||||
DetectedLanguage string
|
||||
IsArticleSummary bool
|
||||
}
|
||||
|
||||
// GetResearcherPrompt builds the system prompt for the research agent:
// role, available actions, iteration progress, mode-specific strategy,
// optional article-summary task, language hint, and closing rules.
func GetResearcherPrompt(cfg ResearcherConfig) string {
	var sb strings.Builder

	sb.WriteString("You are a research agent for GooSeek, an AI search engine.\n\n")

	sb.WriteString("## Your Role\n\n")
	sb.WriteString("You gather information to answer user queries by:\n")
	sb.WriteString("1. Searching the web for relevant information\n")
	sb.WriteString("2. Scraping specific pages for detailed content\n")
	sb.WriteString("3. Deciding when you have enough information\n\n")

	// Tool descriptions are pre-rendered by the caller (see
	// GetAvailableActionsDescription) and inserted verbatim.
	sb.WriteString("## Available Actions\n\n")
	sb.WriteString(cfg.AvailableActions)
	sb.WriteString("\n\n")

	sb.WriteString("## Progress\n\n")
	// Iteration is zero-based internally; shown 1-based to the model.
	sb.WriteString(fmt.Sprintf("Current iteration: %d / %d\n\n", cfg.Iteration+1, cfg.MaxIterations))

	// Mode-specific research strategy; an unknown mode adds no section.
	switch cfg.Mode {
	case "speed":
		sb.WriteString("## Speed Mode\n\n")
		sb.WriteString("- Perform ONE search and call done\n")
		sb.WriteString("- Do NOT scrape pages\n")
		sb.WriteString("- Use snippets from search results\n\n")
	case "balanced":
		sb.WriteString("## Balanced Mode\n\n")
		sb.WriteString("- Perform 1-3 searches\n")
		sb.WriteString("- Scrape top 3-5 relevant pages\n")
		sb.WriteString("- Balance depth vs. speed\n\n")
	case "quality":
		sb.WriteString("## Quality Mode\n\n")
		sb.WriteString("- Perform multiple searches with different queries\n")
		sb.WriteString("- Scrape 10-15 relevant pages\n")
		sb.WriteString("- Verify information across sources\n")
		sb.WriteString("- Be thorough and comprehensive\n\n")
	}

	if cfg.IsArticleSummary {
		sb.WriteString("## Article Summary Task (Perplexity Discover-style)\n\n")
		sb.WriteString("The user requested an article summary (Summary: <url>). This is a multi-source digest request.\n\n")
		sb.WriteString("**Your goals:**\n")
		sb.WriteString("1. The main article is already pre-scraped and will be in context\n")
		sb.WriteString("2. Search for 3-5 related sources that provide context\n")
		sb.WriteString("3. Look for: related news, background, analysis, reactions\n")
		sb.WriteString("4. Use news categories: `news`, `science` engines\n")
		sb.WriteString("5. Max 5 additional sources (article itself is [1])\n\n")
		sb.WriteString("**Search strategy:**\n")
		sb.WriteString("- Extract key entities/topics from article title\n")
		sb.WriteString("- Search for recent news on those topics\n")
		sb.WriteString("- Find expert opinions or analysis\n")
		sb.WriteString("- Look for official statements if relevant\n\n")
	}

	// Only Russian gets an explicit language section at present.
	if cfg.DetectedLanguage == "ru" {
		sb.WriteString("## Language\n\n")
		sb.WriteString("Пользователь пишет на русском. Формулируй поисковые запросы на русском языке.\n\n")
	}

	sb.WriteString("## Instructions\n\n")
	sb.WriteString("1. Analyze the user's query and conversation history\n")
	sb.WriteString("2. Plan what information you need to gather\n")
	sb.WriteString("3. Execute actions to gather that information\n")
	sb.WriteString("4. Call 'done' when you have sufficient information\n\n")

	sb.WriteString("## Important Rules\n\n")
	sb.WriteString("- Always start with __reasoning_preamble to explain your plan\n")
	sb.WriteString("- Formulate specific, targeted search queries\n")
	sb.WriteString("- Avoid redundant searches\n")
	sb.WriteString("- Call 'done' when information is sufficient\n")
	sb.WriteString("- Don't exceed the iteration limit\n\n")

	sb.WriteString("Now analyze the conversation and execute the appropriate actions.")

	return sb.String()
}
|
||||
|
||||
// GetAvailableActionsDescription returns the static action/tool list
// (names, purpose, and argument schemas) injected into the researcher
// prompt via ResearcherConfig.AvailableActions.
func GetAvailableActionsDescription() string {
	return strings.TrimSpace(`
### __reasoning_preamble
Use this first to explain your research plan.
Arguments:
- plan (string): Your reasoning about what to search for

### web_search
Search the web for information.
Arguments:
- query (string): Search query
- engines (array, optional): Specific search engines to use

### academic_search
Search academic/scientific sources.
Arguments:
- query (string): Academic search query

### social_search
Search social media and forums.
Arguments:
- query (string): Social search query

### scrape_url
Fetch and extract content from a specific URL.
Arguments:
- url (string): URL to scrape

### done
Signal that research is complete.
Arguments:
- reason (string): Why research is sufficient
`)
}
|
||||
146
backend/internal/prompts/writer.go
Normal file
146
backend/internal/prompts/writer.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package prompts
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// WriterConfig parameterizes the answer-writer system prompt.
type WriterConfig struct {
	Context            string         // formatted search results the answer must cite
	SystemInstructions string         // user custom instructions; "" or "None" disables
	Mode               string         // "speed", "balanced", or "quality"
	Locale             string         // UI locale; not read by GetWriterPrompt — TODO confirm intended use
	MemoryContext      string         // long-term user memory snippet, optional
	AnswerMode         string         // topical mode, see getAnswerModeInstructions
	ResponsePrefs      *ResponsePrefs // optional formatting preferences
	DetectedLanguage   string         // "ru" forces Russian-only answers
	IsArticleSummary   bool           // enables the article-digest section
	LearningMode       bool           // enables the tutoring-style section
}

// ResponsePrefs captures per-user answer formatting preferences; empty
// fields are omitted from the prompt.
type ResponsePrefs struct {
	Format string
	Length string
	Tone   string
}
|
||||
|
||||
// GetWriterPrompt builds the system prompt for the answer writer:
// core citation/formatting rules, optional article-summary structure,
// mode-specific depth, answer-mode guidance, user preferences, memory,
// custom instructions, learning mode, the search results themselves,
// and closing citation rules.
func GetWriterPrompt(cfg WriterConfig) string {
	var sb strings.Builder

	sb.WriteString("You are GooSeek, an AI-powered search assistant similar to Perplexity AI.\n\n")

	// Only Russian gets an explicit language directive at present.
	if cfg.DetectedLanguage == "ru" {
		sb.WriteString("ВАЖНО: Пользователь пишет на русском языке. Отвечай ТОЛЬКО на русском языке.\n\n")
	}

	sb.WriteString("## Core Instructions\n\n")
	sb.WriteString("1. **Always cite sources** using [number] format, e.g., [1], [2]. Citations must reference the search results provided.\n")
	sb.WriteString("2. **Be comprehensive** but concise. Provide thorough answers with key information.\n")
	sb.WriteString("3. **Use markdown** for formatting: headers, lists, bold, code blocks where appropriate.\n")
	sb.WriteString("4. **Be objective** and factual. Present information neutrally.\n")
	sb.WriteString("5. **Acknowledge limitations** if search results are insufficient.\n\n")

	if cfg.IsArticleSummary {
		sb.WriteString("## Article Summary Mode (Perplexity-style Digest)\n\n")
		sb.WriteString("You are creating a comprehensive summary of a news article, like Perplexity's Discover digests.\n\n")
		sb.WriteString("**Structure your response as:**\n")
		sb.WriteString("1. **Headline summary** (1-2 sentences capturing the essence)\n")
		sb.WriteString("2. **Key points** with citations [1], [2], etc.\n")
		sb.WriteString("3. **Context and background** from related sources\n")
		sb.WriteString("4. **Analysis/implications** if relevant\n")
		sb.WriteString("5. **Related questions** the reader might have (as > quoted lines)\n\n")
		sb.WriteString("**Rules:**\n")
		sb.WriteString("- Always cite sources [1], [2], etc.\n")
		sb.WriteString("- First source [1] is usually the main article\n")
		sb.WriteString("- Add context from other sources [2], [3], etc.\n")
		sb.WriteString("- End with 2-3 follow-up questions prefixed with >\n")
		sb.WriteString("- Write in the user's language (Russian if they use Russian)\n\n")
	}

	// Mode-specific answer depth; an unknown mode adds no section.
	switch cfg.Mode {
	case "speed":
		sb.WriteString("## Speed Mode\n\n")
		sb.WriteString("Provide a quick, focused answer. Be concise (2-3 paragraphs max).\n")
		sb.WriteString("Prioritize the most relevant information.\n\n")
	case "balanced":
		sb.WriteString("## Balanced Mode\n\n")
		sb.WriteString("Provide a well-rounded answer with moderate detail.\n")
		sb.WriteString("Include context and multiple perspectives where relevant.\n\n")
	case "quality":
		sb.WriteString("## Quality Mode\n\n")
		sb.WriteString("Provide a comprehensive, in-depth analysis.\n")
		sb.WriteString("Include detailed explanations, examples, and nuances.\n")
		sb.WriteString("Cover multiple aspects of the topic.\n\n")
	}

	// "standard" is the implicit default and adds nothing.
	if cfg.AnswerMode != "" && cfg.AnswerMode != "standard" {
		sb.WriteString(fmt.Sprintf("## Answer Mode: %s\n\n", cfg.AnswerMode))
		sb.WriteString(getAnswerModeInstructions(cfg.AnswerMode))
	}

	if cfg.ResponsePrefs != nil {
		sb.WriteString("## Response Preferences\n\n")
		if cfg.ResponsePrefs.Format != "" {
			sb.WriteString(fmt.Sprintf("- Format: %s\n", cfg.ResponsePrefs.Format))
		}
		if cfg.ResponsePrefs.Length != "" {
			sb.WriteString(fmt.Sprintf("- Length: %s\n", cfg.ResponsePrefs.Length))
		}
		if cfg.ResponsePrefs.Tone != "" {
			sb.WriteString(fmt.Sprintf("- Tone: %s\n", cfg.ResponsePrefs.Tone))
		}
		sb.WriteString("\n")
	}

	if cfg.MemoryContext != "" {
		sb.WriteString("## User Context (from memory)\n\n")
		sb.WriteString(cfg.MemoryContext)
		sb.WriteString("\n\n")
	}

	// "None" is a sentinel presumably produced upstream for "no
	// instructions" — TODO confirm against callers.
	if cfg.SystemInstructions != "" && cfg.SystemInstructions != "None" {
		sb.WriteString("## Custom Instructions\n\n")
		sb.WriteString(cfg.SystemInstructions)
		sb.WriteString("\n\n")
	}

	if cfg.LearningMode {
		sb.WriteString("## Learning Mode\n\n")
		sb.WriteString("The user is in learning mode. Explain concepts thoroughly.\n")
		sb.WriteString("Use analogies, examples, and break down complex topics.\n")
		sb.WriteString("Ask clarifying questions if the topic is ambiguous.\n\n")
	}

	sb.WriteString("## Search Results\n\n")
	sb.WriteString(cfg.Context)
	sb.WriteString("\n\n")

	sb.WriteString("## Citation Rules\n\n")
	sb.WriteString("- Use [1], [2], etc. to cite sources from the search results\n")
	sb.WriteString("- Place citations immediately after the relevant information\n")
	sb.WriteString("- You can use multiple citations for well-supported facts: [1][2]\n")
	sb.WriteString("- Do not cite widgets or generated content\n")
	sb.WriteString("- If no relevant source exists, don't make up citations\n\n")

	sb.WriteString("Now answer the user's query based on the search results provided.")

	return sb.String()
}
|
||||
|
||||
// answerModeInstructions maps a topical answer mode to extra prompt
// guidance. Hoisted to package scope so the table is built once instead
// of being re-allocated on every call.
var answerModeInstructions = map[string]string{
	"academic": "Focus on scholarly sources, research papers, and academic perspectives. Use formal language and cite peer-reviewed sources when available.\n\n",
	"writing":  "Help with writing tasks. Provide suggestions for structure, style, and content. Be creative and helpful.\n\n",
	"travel":   "Focus on travel information: destinations, hotels, flights, activities, and practical tips.\n\n",
	"finance":  "Provide financial information carefully. Include disclaimers about not being financial advice. Focus on factual data.\n\n",
	"health":   "Provide health information from reliable sources. Always recommend consulting healthcare professionals. Be cautious and accurate.\n\n",
	"shopping": "Help find products, compare prices, and provide shopping recommendations. Include product features and user reviews.\n\n",
	"news":     "Focus on current events and recent news. Provide multiple perspectives and fact-check information.\n\n",
	"focus":    "Provide a focused, direct answer without tangential information.\n\n",
}

// getAnswerModeInstructions returns mode-specific prompt instructions,
// or "" for unknown modes (map lookup yields the zero value).
func getAnswerModeInstructions(mode string) string {
	return answerModeInstructions[mode]
}
|
||||
215
backend/internal/search/media.go
Normal file
215
backend/internal/search/media.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gooseek/backend/internal/types"
|
||||
)
|
||||
|
||||
// MediaSearchOptions caps how many image and video results to collect.
type MediaSearchOptions struct {
	MaxImages int
	MaxVideos int
}

// MediaSearchResult bundles the media found for a query. SearchMedia
// guarantees both slices are non-nil so they serialize as JSON arrays,
// never null.
type MediaSearchResult struct {
	Images []types.ImageData `json:"images"`
	Videos []types.VideoData `json:"videos"`
}
|
||||
|
||||
func (c *SearXNGClient) SearchMedia(ctx context.Context, query string, opts *MediaSearchOptions) (*MediaSearchResult, error) {
|
||||
if opts == nil {
|
||||
opts = &MediaSearchOptions{MaxImages: 8, MaxVideos: 6}
|
||||
}
|
||||
|
||||
result := &MediaSearchResult{
|
||||
Images: make([]types.ImageData, 0),
|
||||
Videos: make([]types.VideoData, 0),
|
||||
}
|
||||
|
||||
imageCh := make(chan []types.ImageData, 1)
|
||||
videoCh := make(chan []types.VideoData, 1)
|
||||
errCh := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
images, err := c.searchImages(ctx, query, opts.MaxImages)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
imageCh <- nil
|
||||
return
|
||||
}
|
||||
errCh <- nil
|
||||
imageCh <- images
|
||||
}()
|
||||
|
||||
go func() {
|
||||
videos, err := c.searchVideos(ctx, query, opts.MaxVideos)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
videoCh <- nil
|
||||
return
|
||||
}
|
||||
errCh <- nil
|
||||
videoCh <- videos
|
||||
}()
|
||||
|
||||
<-errCh
|
||||
<-errCh
|
||||
result.Images = <-imageCh
|
||||
result.Videos = <-videoCh
|
||||
|
||||
if result.Images == nil {
|
||||
result.Images = make([]types.ImageData, 0)
|
||||
}
|
||||
if result.Videos == nil {
|
||||
result.Videos = make([]types.VideoData, 0)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *SearXNGClient) searchImages(ctx context.Context, query string, max int) ([]types.ImageData, error) {
|
||||
resp, err := c.Search(ctx, query, &SearchOptions{
|
||||
Categories: []string{"images"},
|
||||
PageNo: 1,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
images := make([]types.ImageData, 0, max)
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, r := range resp.Results {
|
||||
if len(images) >= max {
|
||||
break
|
||||
}
|
||||
|
||||
imgURL := r.ImgSrc
|
||||
if imgURL == "" {
|
||||
imgURL = r.ThumbnailSrc
|
||||
}
|
||||
if imgURL == "" {
|
||||
imgURL = r.Thumbnail
|
||||
}
|
||||
if imgURL == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if seen[imgURL] {
|
||||
continue
|
||||
}
|
||||
seen[imgURL] = true
|
||||
|
||||
images = append(images, types.ImageData{
|
||||
URL: imgURL,
|
||||
Title: r.Title,
|
||||
Source: extractDomain(r.URL),
|
||||
SourceURL: r.URL,
|
||||
})
|
||||
}
|
||||
|
||||
return images, nil
|
||||
}
|
||||
|
||||
func (c *SearXNGClient) searchVideos(ctx context.Context, query string, max int) ([]types.VideoData, error) {
|
||||
resp, err := c.Search(ctx, query, &SearchOptions{
|
||||
Categories: []string{"videos"},
|
||||
PageNo: 1,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
videos := make([]types.VideoData, 0, max)
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, r := range resp.Results {
|
||||
if len(videos) >= max {
|
||||
break
|
||||
}
|
||||
|
||||
if seen[r.URL] {
|
||||
continue
|
||||
}
|
||||
seen[r.URL] = true
|
||||
|
||||
platform := detectVideoPlatform(r.URL)
|
||||
|
||||
video := types.VideoData{
|
||||
Title: r.Title,
|
||||
URL: r.URL,
|
||||
Thumbnail: r.Thumbnail,
|
||||
Duration: toInt(r.Duration),
|
||||
Views: toInt(r.Views),
|
||||
Author: r.Author,
|
||||
Platform: platform,
|
||||
EmbedURL: r.IframeSrc,
|
||||
}
|
||||
|
||||
videos = append(videos, video)
|
||||
}
|
||||
|
||||
return videos, nil
|
||||
}
|
||||
|
||||
// Host patterns identifying the video platform a URL belongs to.
// Compiled once at package scope.
var (
	youtubePattern = regexp.MustCompile(`youtube\.com|youtu\.be`)
	rutubePattern  = regexp.MustCompile(`rutube\.ru`)
	vkPattern      = regexp.MustCompile(`vk\.com`)
	dzenPattern    = regexp.MustCompile(`dzen\.ru`)
)
|
||||
|
||||
// detectVideoPlatform identifies the hosting platform of a video URL
// ("youtube", "rutube", "vk", "dzen", or "other"). The checks are plain
// case-insensitive substring matches; the original used regexes whose
// patterns contained nothing but literal hosts, so strings.Contains is
// behaviorally identical and avoids the regexp machinery entirely.
func detectVideoPlatform(url string) string {
	u := strings.ToLower(url)

	switch {
	case strings.Contains(u, "youtube.com"), strings.Contains(u, "youtu.be"):
		return "youtube"
	case strings.Contains(u, "rutube.ru"):
		return "rutube"
	case strings.Contains(u, "vk.com"):
		return "vk"
	case strings.Contains(u, "dzen.ru"):
		return "dzen"
	default:
		return "other"
	}
}
|
||||
|
||||
// extractDomain reduces a URL to its bare host: scheme and "www." are
// stripped, and everything from the first "/" onward is dropped.
func extractDomain(rawURL string) string {
	host := rawURL
	for _, prefix := range []string{"https://", "http://", "www."} {
		host = strings.TrimPrefix(host, prefix)
	}
	if slash := strings.Index(host, "/"); slash > 0 {
		host = host[:slash]
	}
	return host
}
|
||||
|
||||
// toInt coerces loosely-typed JSON values (int, int64, float64, numeric
// string) to int, returning 0 for nil, non-numeric strings, and any
// other type.
func toInt(v interface{}) int {
	switch n := v.(type) {
	case int:
		return n
	case int64:
		return int(n)
	case float64:
		return int(n)
	case string:
		parsed, err := strconv.Atoi(n)
		if err != nil {
			return 0
		}
		return parsed
	default:
		// Covers nil and every unexpected type.
		return 0
	}
}
|
||||
163
backend/internal/search/reranker.go
Normal file
163
backend/internal/search/reranker.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sort"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/gooseek/backend/internal/types"
|
||||
)
|
||||
|
||||
// RankedItem pairs a chunk with its BM25 relevance score.
type RankedItem struct {
	Chunk types.Chunk
	Score float64
}
|
||||
|
||||
func RerankBM25(chunks []types.Chunk, query string, topK int) []types.Chunk {
|
||||
if len(chunks) == 0 {
|
||||
return chunks
|
||||
}
|
||||
|
||||
queryTerms := tokenize(query)
|
||||
if len(queryTerms) == 0 {
|
||||
return chunks
|
||||
}
|
||||
|
||||
df := make(map[string]int)
|
||||
for _, chunk := range chunks {
|
||||
seen := make(map[string]bool)
|
||||
terms := tokenize(chunk.Content + " " + chunk.Metadata["title"])
|
||||
for _, term := range terms {
|
||||
if !seen[term] {
|
||||
df[term]++
|
||||
seen[term] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
avgDocLen := 0.0
|
||||
for _, chunk := range chunks {
|
||||
avgDocLen += float64(len(tokenize(chunk.Content)))
|
||||
}
|
||||
avgDocLen /= float64(len(chunks))
|
||||
|
||||
k1 := 1.5
|
||||
b := 0.75
|
||||
n := float64(len(chunks))
|
||||
|
||||
ranked := make([]RankedItem, len(chunks))
|
||||
for i, chunk := range chunks {
|
||||
docTerms := tokenize(chunk.Content + " " + chunk.Metadata["title"])
|
||||
docLen := float64(len(docTerms))
|
||||
|
||||
tf := make(map[string]int)
|
||||
for _, term := range docTerms {
|
||||
tf[term]++
|
||||
}
|
||||
|
||||
score := 0.0
|
||||
for _, qterm := range queryTerms {
|
||||
if termFreq, ok := tf[qterm]; ok {
|
||||
docFreq := float64(df[qterm])
|
||||
idf := math.Log((n - docFreq + 0.5) / (docFreq + 0.5))
|
||||
if idf < 0 {
|
||||
idf = 0
|
||||
}
|
||||
|
||||
tfNorm := float64(termFreq) * (k1 + 1) /
|
||||
(float64(termFreq) + k1*(1-b+b*docLen/avgDocLen))
|
||||
|
||||
score += idf * tfNorm
|
||||
}
|
||||
}
|
||||
|
||||
if title, ok := chunk.Metadata["title"]; ok {
|
||||
titleLower := strings.ToLower(title)
|
||||
for _, qterm := range queryTerms {
|
||||
if strings.Contains(titleLower, qterm) {
|
||||
score += 2.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ranked[i] = RankedItem{Chunk: chunk, Score: score}
|
||||
}
|
||||
|
||||
sort.Slice(ranked, func(i, j int) bool {
|
||||
return ranked[i].Score > ranked[j].Score
|
||||
})
|
||||
|
||||
if topK > len(ranked) {
|
||||
topK = len(ranked)
|
||||
}
|
||||
|
||||
result := make([]types.Chunk, topK)
|
||||
for i := 0; i < topK; i++ {
|
||||
result[i] = ranked[i].Chunk
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// tokenize lowercases text and splits it into runs of letters/digits,
// keeping only tokens of at least two BYTES (so a single Cyrillic
// letter, being two bytes in UTF-8, is kept while a single ASCII letter
// is dropped — this matches the original Builder.Len() semantics).
func tokenize(text string) []string {
	var tokens []string
	var cur strings.Builder

	emit := func() {
		if cur.Len() >= 2 { // byte length, intentionally
			tokens = append(tokens, cur.String())
		}
		cur.Reset()
	}

	for _, r := range strings.ToLower(text) {
		if unicode.IsLetter(r) || unicode.IsDigit(r) {
			cur.WriteRune(r)
			continue
		}
		emit()
	}
	emit() // flush trailing token

	return tokens
}
|
||||
|
||||
func EstimateQueryComplexity(query string) float64 {
|
||||
terms := tokenize(query)
|
||||
complexity := float64(len(terms)) / 5.0
|
||||
|
||||
if strings.Contains(query, "?") {
|
||||
complexity += 0.2
|
||||
}
|
||||
if strings.Contains(query, " и ") || strings.Contains(query, " или ") {
|
||||
complexity += 0.3
|
||||
}
|
||||
|
||||
if complexity > 1.0 {
|
||||
complexity = 1.0
|
||||
}
|
||||
return complexity
|
||||
}
|
||||
|
||||
// ComputeAdaptiveTopK picks how many reranked chunks to keep: a per-mode
// base (default 15) scaled up to 1.5x by query complexity, capped at the
// number of available results.
func ComputeAdaptiveTopK(totalResults int, complexity float64, mode string) int {
	var baseK int
	switch mode {
	case "speed":
		baseK = 10
	case "balanced":
		baseK = 20
	case "quality":
		baseK = 30
	default:
		baseK = 15
	}

	k := int(float64(baseK) * (1 + 0.5*complexity))
	if k > totalResults {
		return totalResults
	}
	return k
}
|
||||
177
backend/internal/search/searxng.go
Normal file
177
backend/internal/search/searxng.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gooseek/backend/internal/types"
|
||||
"github.com/gooseek/backend/pkg/config"
|
||||
)
|
||||
|
||||
// SearXNGClient queries a SearXNG metasearch instance, trying the
// primary URL first and then each fallback in order.
type SearXNGClient struct {
	primaryURL   string        // preferred instance; defaulted to http:// when schemeless
	fallbackURLs []string      // alternates; defaulted to https:// when schemeless
	client       *http.Client  // shared client with the configured timeout
	timeout      time.Duration // mirrors client.Timeout
}
|
||||
|
||||
func NewSearXNGClient(cfg *config.Config) *SearXNGClient {
|
||||
return &SearXNGClient{
|
||||
primaryURL: cfg.SearXNGURL,
|
||||
fallbackURLs: cfg.SearXNGFallbackURL,
|
||||
client: &http.Client{Timeout: cfg.SearchTimeout},
|
||||
timeout: cfg.SearchTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// SearchOptions narrows a SearXNG query; zero-valued fields are omitted
// from the request parameters.
type SearchOptions struct {
	Engines    []string // specific engines, comma-joined into "engines"
	Categories []string // e.g. "images", "videos"
	PageNo     int      // result page, sent only when > 0
	Language   string   // result-language hint, e.g. "ru"
}
|
||||
|
||||
func (c *SearXNGClient) Search(ctx context.Context, query string, opts *SearchOptions) (*types.SearchResponse, error) {
|
||||
candidates := c.buildCandidates()
|
||||
if len(candidates) == 0 {
|
||||
return nil, fmt.Errorf("no SearXNG URLs configured")
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, baseURL := range candidates {
|
||||
result, err := c.searchWithURL(ctx, baseURL, query, opts)
|
||||
if err == nil {
|
||||
return result, nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("all SearXNG instances failed: %w", lastErr)
|
||||
}
|
||||
|
||||
func (c *SearXNGClient) buildCandidates() []string {
|
||||
candidates := make([]string, 0)
|
||||
|
||||
if c.primaryURL != "" {
|
||||
u := strings.TrimSuffix(c.primaryURL, "/")
|
||||
if !strings.HasPrefix(u, "http") {
|
||||
u = "http://" + u
|
||||
}
|
||||
candidates = append(candidates, u)
|
||||
}
|
||||
|
||||
for _, fb := range c.fallbackURLs {
|
||||
u := strings.TrimSpace(fb)
|
||||
if u == "" {
|
||||
continue
|
||||
}
|
||||
u = strings.TrimSuffix(u, "/")
|
||||
if !strings.HasPrefix(u, "http") {
|
||||
u = "https://" + u
|
||||
}
|
||||
if !contains(candidates, u) {
|
||||
candidates = append(candidates, u)
|
||||
}
|
||||
}
|
||||
|
||||
return candidates
|
||||
}
|
||||
|
||||
func (c *SearXNGClient) searchWithURL(ctx context.Context, baseURL, query string, opts *SearchOptions) (*types.SearchResponse, error) {
|
||||
params := url.Values{}
|
||||
params.Set("format", "json")
|
||||
params.Set("q", query)
|
||||
|
||||
if opts != nil {
|
||||
if len(opts.Engines) > 0 {
|
||||
params.Set("engines", strings.Join(opts.Engines, ","))
|
||||
}
|
||||
if len(opts.Categories) > 0 {
|
||||
params.Set("categories", strings.Join(opts.Categories, ","))
|
||||
}
|
||||
if opts.PageNo > 0 {
|
||||
params.Set("pageno", fmt.Sprintf("%d", opts.PageNo))
|
||||
}
|
||||
if opts.Language != "" {
|
||||
params.Set("language", opts.Language)
|
||||
}
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/search?%s", baseURL, params.Encode())
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("SearXNG returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Results []types.SearchResult `json:"results"`
|
||||
Suggestions []string `json:"suggestions"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &types.SearchResponse{
|
||||
Results: result.Results,
|
||||
Suggestions: result.Suggestions,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// URL-shape heuristics for categorizing search results: marketplace
// product pages, video pages, and VK/Telegram profile links (the
// profile patterns are anchored at end-of-string so deeper paths don't
// match). Compiled once at package scope.
var (
	productPattern   = regexp.MustCompile(`ozon\.ru/product|wildberries\.ru/catalog/\d|aliexpress\.(ru|com)/item|market\.yandex`)
	videoPattern     = regexp.MustCompile(`rutube\.ru/video|vk\.com/video|vk\.com/clip|youtube\.com/watch|youtu\.be|dzen\.ru/video`)
	vkProfilePattern = regexp.MustCompile(`vk\.com/[a-zA-Z0-9_.]+$`)
	tgProfilePattern = regexp.MustCompile(`t\.me/[a-zA-Z0-9_]+$`)
)
|
||||
|
||||
func CategorizeResult(result *types.SearchResult) types.ContentCategory {
|
||||
urlLower := strings.ToLower(result.URL)
|
||||
|
||||
if productPattern.MatchString(urlLower) {
|
||||
return types.CategoryProduct
|
||||
}
|
||||
|
||||
if videoPattern.MatchString(urlLower) || result.IframeSrc != "" || result.Category == "videos" {
|
||||
return types.CategoryVideo
|
||||
}
|
||||
|
||||
if tgProfilePattern.MatchString(urlLower) {
|
||||
return types.CategoryProfile
|
||||
}
|
||||
if vkProfilePattern.MatchString(urlLower) && !videoPattern.MatchString(urlLower) {
|
||||
return types.CategoryProfile
|
||||
}
|
||||
|
||||
if result.ImgSrc != "" && result.Category == "images" {
|
||||
return types.CategoryImage
|
||||
}
|
||||
|
||||
return types.CategoryArticle
|
||||
}
|
||||
|
||||
// contains reports whether item occurs in slice.
func contains(slice []string, item string) bool {
	for i := range slice {
		if slice[i] == item {
			return true
		}
	}
	return false
}
|
||||
183
backend/internal/session/manager.go
Normal file
183
backend/internal/session/manager.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
|
||||
"github.com/gooseek/backend/internal/types"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// EventType identifies the class of a session event.
type EventType string

const (
	EventData  EventType = "data"  // payload-carrying event
	EventEnd   EventType = "end"   // stream finished normally
	EventError EventType = "error" // stream aborted with an error
)

// Event pairs an event type with its payload.
type Event struct {
	Type EventType   `json:"type"`
	Data interface{} `json:"data"`
}

// Subscriber receives every event emitted on a session.
type Subscriber func(event EventType, data interface{})

// Session is an in-memory pub/sub hub for one streaming answer: it
// stores the blocks produced so far and fans events out to subscribers.
type Session struct {
	id          string
	blocks      map[string]*types.Block
	subscribers []Subscriber
	mu          sync.RWMutex // guards blocks, subscribers, closed
	closed      bool         // set by Close; Emit becomes a no-op
}
|
||||
|
||||
// NewSession creates an empty session with a random UUID.
func NewSession() *Session {
	return &Session{
		id:          uuid.New().String(),
		blocks:      make(map[string]*types.Block),
		subscribers: make([]Subscriber, 0),
	}
}

// ID returns the session's unique identifier.
func (s *Session) ID() string {
	return s.id
}
|
||||
|
||||
func (s *Session) Subscribe(fn Subscriber) func() {
|
||||
s.mu.Lock()
|
||||
s.subscribers = append(s.subscribers, fn)
|
||||
idx := len(s.subscribers) - 1
|
||||
s.mu.Unlock()
|
||||
|
||||
return func() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if idx < len(s.subscribers) {
|
||||
s.subscribers = append(s.subscribers[:idx], s.subscribers[idx+1:]...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Emit delivers an event to every current subscriber. It is a no-op
// once the session is closed. The subscriber list is snapshotted under
// the read lock and callbacks run outside it, so a callback may safely
// re-enter the session (e.g. call Emit or unsubscribe).
func (s *Session) Emit(eventType EventType, data interface{}) {
	s.mu.RLock()
	if s.closed {
		s.mu.RUnlock()
		return
	}
	// Copy so we never invoke callbacks while holding the lock.
	subs := make([]Subscriber, len(s.subscribers))
	copy(subs, s.subscribers)
	s.mu.RUnlock()

	for _, sub := range subs {
		sub(eventType, data)
	}
}
|
||||
|
||||
func (s *Session) EmitBlock(block *types.Block) {
|
||||
s.mu.Lock()
|
||||
s.blocks[block.ID] = block
|
||||
s.mu.Unlock()
|
||||
|
||||
s.Emit(EventData, map[string]interface{}{
|
||||
"type": "block",
|
||||
"block": block,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateBlock applies JSON-Patch-style replace operations to a stored
// block in place, then broadcasts an "updateBlock" event carrying the
// raw patches. Unknown block IDs are silently ignored.
func (s *Session) UpdateBlock(blockID string, patches []Patch) {
	s.mu.Lock()
	block, ok := s.blocks[blockID]
	if !ok {
		s.mu.Unlock()
		return
	}

	// Mutate under the lock; the event is emitted after releasing it so
	// subscriber callbacks can re-enter the session.
	for _, patch := range patches {
		applyPatch(block, patch)
	}
	s.mu.Unlock()

	s.Emit(EventData, map[string]interface{}{
		"type":    "updateBlock",
		"blockId": blockID,
		"patch":   patches,
	})
}
|
||||
|
||||
func (s *Session) EmitTextChunk(blockID, chunk string) {
|
||||
s.Emit(EventData, map[string]interface{}{
|
||||
"type": "textChunk",
|
||||
"blockId": blockID,
|
||||
"chunk": chunk,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Session) EmitResearchComplete() {
|
||||
s.Emit(EventData, map[string]interface{}{
|
||||
"type": "researchComplete",
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Session) EmitEnd() {
|
||||
s.Emit(EventData, map[string]interface{}{
|
||||
"type": "messageEnd",
|
||||
})
|
||||
s.Emit(EventEnd, nil)
|
||||
}
|
||||
|
||||
func (s *Session) EmitError(err error) {
|
||||
s.Emit(EventData, map[string]interface{}{
|
||||
"type": "error",
|
||||
"data": err.Error(),
|
||||
})
|
||||
s.Emit(EventError, map[string]interface{}{
|
||||
"data": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Session) GetBlock(id string) *types.Block {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.blocks[id]
|
||||
}
|
||||
|
||||
func (s *Session) Close() {
|
||||
s.mu.Lock()
|
||||
s.closed = true
|
||||
s.subscribers = nil
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Session) RemoveAllListeners() {
|
||||
s.mu.Lock()
|
||||
s.subscribers = nil
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// Patch is a minimal JSON-Patch-like operation. Only op "replace" is
// honored by applyPatch; Path selects what to replace with Value.
type Patch struct {
	Op    string      `json:"op"`
	Path  string      `json:"path"`
	Value interface{} `json:"value"`
}
|
||||
|
||||
// applyPatch mutates block in place for a single "replace" patch. Only
// two paths are recognized: "/data" replaces the whole payload, and
// "/data/subSteps" replaces the sub-steps of a research block. Any
// other op or path is ignored.
func applyPatch(block *types.Block, patch Patch) {
	if patch.Op != "replace" {
		return
	}

	switch patch.Path {
	case "/data":
		block.Data = patch.Value
	case "/data/subSteps":
		// NOTE(review): these assertions only succeed for values built
		// in-process with the concrete types; a Value arriving as
		// decoded JSON ([]interface{}) would be dropped silently —
		// confirm all callers pass typed values.
		if rd, ok := block.Data.(types.ResearchData); ok {
			if steps, ok := patch.Value.([]types.ResearchSubStep); ok {
				rd.SubSteps = steps
				block.Data = rd
			}
		}
	}
}
|
||||
|
||||
// MarshalEvent serializes an event payload to JSON for the wire.
func MarshalEvent(data interface{}) ([]byte, error) {
	return json.Marshal(data)
}
|
||||
102
backend/internal/types/blocks.go
Normal file
102
backend/internal/types/blocks.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package types
|
||||
|
||||
type BlockType string
|
||||
|
||||
const (
|
||||
BlockTypeText BlockType = "text"
|
||||
BlockTypeResearch BlockType = "research"
|
||||
BlockTypeSource BlockType = "source"
|
||||
BlockTypeWidget BlockType = "widget"
|
||||
BlockTypeThinking BlockType = "thinking"
|
||||
)
|
||||
|
||||
type Block struct {
|
||||
ID string `json:"id"`
|
||||
Type BlockType `json:"type"`
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
|
||||
// TextBlock is a typed view of a text Block; Data holds the raw text.
// NOTE(review): Type is a plain string here, not BlockType — presumably
// for wire compatibility; confirm before unifying.
type TextBlock struct {
	ID   string `json:"id"`
	Type string `json:"type"`
	Data string `json:"data"`
}

// ResearchBlock is a typed view of a research Block.
type ResearchBlock struct {
	ID   string       `json:"id"`
	Type string       `json:"type"`
	Data ResearchData `json:"data"`
}

// ResearchData carries the ordered list of research sub-steps.
type ResearchData struct {
	SubSteps []ResearchSubStep `json:"subSteps"`
}

// ResearchSubStep is one stage of a research run; exactly which of the
// optional fields is populated depends on the step's Type.
type ResearchSubStep struct {
	ID        string   `json:"id"`
	Type      string   `json:"type"`
	Reasoning string   `json:"reasoning,omitempty"` // model reasoning text, if any
	Searching []string `json:"searching,omitempty"` // queries being searched
	Reading   []Chunk  `json:"reading,omitempty"`   // chunks currently being read
}
|
||||
|
||||
// SourceBlock is a typed view of a source Block; Data lists the
// retrieved chunks backing the answer.
type SourceBlock struct {
	ID   string  `json:"id"`
	Type string  `json:"type"`
	Data []Chunk `json:"data"`
}

// WidgetBlock is a typed view of a widget Block.
type WidgetBlock struct {
	ID   string     `json:"id"`
	Type string     `json:"type"`
	Data WidgetData `json:"data"`
}

// WidgetData names a widget and carries its widget-specific parameters.
type WidgetData struct {
	WidgetType string      `json:"widgetType"` // one of the WidgetType constants
	Params     interface{} `json:"params"`     // shape depends on WidgetType
}
|
||||
|
||||
// StreamEvent is the envelope sent over the event stream. Which fields
// are set depends on Type (a full block, a chunk append, a patch, or
// auxiliary data).
type StreamEvent struct {
	Type    string      `json:"type"`
	Block   *Block      `json:"block,omitempty"`   // full block payload
	BlockID string      `json:"blockId,omitempty"` // target block for chunk/patch events
	Chunk   string      `json:"chunk,omitempty"`   // appended text fragment
	Patch   interface{} `json:"patch,omitempty"`   // in-place block mutation
	Data    interface{} `json:"data,omitempty"`    // event-specific extra payload
}
|
||||
|
||||
func NewTextBlock(id, content string) *Block {
|
||||
return &Block{
|
||||
ID: id,
|
||||
Type: BlockTypeText,
|
||||
Data: content,
|
||||
}
|
||||
}
|
||||
|
||||
func NewResearchBlock(id string) *Block {
|
||||
return &Block{
|
||||
ID: id,
|
||||
Type: BlockTypeResearch,
|
||||
Data: ResearchData{SubSteps: []ResearchSubStep{}},
|
||||
}
|
||||
}
|
||||
|
||||
func NewSourceBlock(id string, chunks []Chunk) *Block {
|
||||
return &Block{
|
||||
ID: id,
|
||||
Type: BlockTypeSource,
|
||||
Data: chunks,
|
||||
}
|
||||
}
|
||||
|
||||
func NewWidgetBlock(id, widgetType string, params interface{}) *Block {
|
||||
return &Block{
|
||||
ID: id,
|
||||
Type: BlockTypeWidget,
|
||||
Data: WidgetData{
|
||||
WidgetType: widgetType,
|
||||
Params: params,
|
||||
},
|
||||
}
|
||||
}
|
||||
75
backend/internal/types/chunks.go
Normal file
75
backend/internal/types/chunks.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package types
|
||||
|
||||
// Chunk is one unit of retrieved content plus optional string metadata
// (e.g. "title", "url", "thumbnail" — see SearchResult.ToChunk).
type Chunk struct {
	Content  string            `json:"content"`
	Metadata map[string]string `json:"metadata,omitempty"`
}

// SearchResult is a single hit from a search backend. Most fields are
// optional and depend on the engine and content category.
type SearchResult struct {
	Title         string  `json:"title"`
	URL           string  `json:"url"`
	Content       string  `json:"content,omitempty"`
	Thumbnail     string  `json:"thumbnail,omitempty"`
	ImgSrc        string  `json:"img_src,omitempty"`
	ThumbnailSrc  string  `json:"thumbnail_src,omitempty"`
	IframeSrc     string  `json:"iframe_src,omitempty"`
	Author        string  `json:"author,omitempty"`
	PublishedDate string  `json:"publishedDate,omitempty"`
	Engine        string  `json:"engine,omitempty"`
	Category      string  `json:"category,omitempty"`
	Score         float64 `json:"score,omitempty"`
	Price         string  `json:"price,omitempty"`
	Currency      string  `json:"currency,omitempty"`
	// Duration and Views are interface{} — presumably because engines
	// return either numbers or formatted strings; TODO confirm upstream.
	Duration interface{} `json:"duration,omitempty"`
	Views    interface{} `json:"views,omitempty"`
}

// SearchResponse bundles search hits with optional query suggestions.
type SearchResponse struct {
	Results     []SearchResult `json:"results"`
	Suggestions []string       `json:"suggestions,omitempty"`
}

// ContentCategory classifies what kind of page/content a result is.
type ContentCategory string

// Known content categories.
const (
	CategoryProduct ContentCategory = "product"
	CategoryVideo   ContentCategory = "video"
	CategoryProfile ContentCategory = "profile"
	CategoryPromo   ContentCategory = "promo"
	CategoryImage   ContentCategory = "image"
	CategoryArticle ContentCategory = "article"
)
|
||||
|
||||
func (r *SearchResult) ToChunk() Chunk {
|
||||
metadata := map[string]string{
|
||||
"title": r.Title,
|
||||
"url": r.URL,
|
||||
}
|
||||
if r.Thumbnail != "" {
|
||||
metadata["thumbnail"] = r.Thumbnail
|
||||
}
|
||||
if r.Author != "" {
|
||||
metadata["author"] = r.Author
|
||||
}
|
||||
if r.PublishedDate != "" {
|
||||
metadata["publishedDate"] = r.PublishedDate
|
||||
}
|
||||
|
||||
content := r.Content
|
||||
if content == "" {
|
||||
content = r.Title
|
||||
}
|
||||
|
||||
return Chunk{
|
||||
Content: content,
|
||||
Metadata: metadata,
|
||||
}
|
||||
}
|
||||
|
||||
func SearchResultsToChunks(results []SearchResult) []Chunk {
|
||||
chunks := make([]Chunk, 0, len(results))
|
||||
for _, r := range results {
|
||||
chunks = append(chunks, r.ToChunk())
|
||||
}
|
||||
return chunks
|
||||
}
|
||||
145
backend/internal/types/widgets.go
Normal file
145
backend/internal/types/widgets.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package types
|
||||
|
||||
// WidgetType names a rich UI widget rendered by the frontend.
type WidgetType string

// Known widget kinds; each corresponds to a params shape defined below
// or to a slice of the matching *Data structs.
const (
	WidgetWeather      WidgetType = "weather"
	WidgetCalculator   WidgetType = "calculator"
	WidgetProducts     WidgetType = "products"
	WidgetVideos       WidgetType = "videos"
	WidgetProfiles     WidgetType = "profiles"
	WidgetPromos       WidgetType = "promos"
	WidgetImageGallery WidgetType = "image_gallery"
	WidgetVideoEmbed   WidgetType = "video_embed"
	WidgetKnowledge    WidgetType = "knowledge_card"
)
|
||||
|
||||
// ProductData describes one product card (WidgetProducts).
type ProductData struct {
	Title       string  `json:"title"`
	URL         string  `json:"url"`
	Price       float64 `json:"price"`
	OldPrice    float64 `json:"oldPrice,omitempty"` // pre-discount price, if any
	Currency    string  `json:"currency"`
	Discount    int     `json:"discount,omitempty"` // presumably percent — TODO confirm with renderer
	Rating      float64 `json:"rating,omitempty"`
	ReviewCount int     `json:"reviewCount,omitempty"`
	ImageURL    string  `json:"imageUrl,omitempty"`
	Marketplace string  `json:"marketplace"`
	InStock     bool    `json:"inStock"`
	Badges      []Badge `json:"badges,omitempty"`
}

// VideoData describes one video card (WidgetVideos / WidgetVideoEmbed).
type VideoData struct {
	Title     string `json:"title"`
	URL       string `json:"url"`
	Thumbnail string `json:"thumbnail,omitempty"`
	Duration  int    `json:"duration,omitempty"` // presumably seconds — TODO confirm
	Views     int    `json:"views,omitempty"`
	Likes     int    `json:"likes,omitempty"`
	Author    string `json:"author,omitempty"`
	Platform  string `json:"platform"`
	EmbedURL  string `json:"embedUrl,omitempty"` // iframe-embeddable URL, when available
}

// ProfileData describes one social/creator profile card (WidgetProfiles).
type ProfileData struct {
	Name       string `json:"name"`
	Username   string `json:"username,omitempty"`
	URL        string `json:"url"`
	AvatarURL  string `json:"avatarUrl,omitempty"`
	Bio        string `json:"bio,omitempty"`
	Followers  int    `json:"followers,omitempty"`
	Following  int    `json:"following,omitempty"`
	Platform   string `json:"platform"`
	Verified   bool   `json:"verified"`
	IsOnline   bool   `json:"isOnline,omitempty"`
	LastOnline string `json:"lastOnline,omitempty"`
}

// PromoData describes one promo-code card (WidgetPromos).
type PromoData struct {
	Code        string `json:"code"`
	Description string `json:"description"`
	Discount    string `json:"discount"`
	Store       string `json:"store"`
	StoreURL    string `json:"storeUrl"`
	LogoURL     string `json:"logoUrl,omitempty"`
	ExpiresAt   string `json:"expiresAt,omitempty"`
	Conditions  string `json:"conditions,omitempty"`
	Verified    bool   `json:"verified"`
}
|
||||
|
||||
// ImageData describes one image in an image gallery (WidgetImageGallery).
type ImageData struct {
	URL       string `json:"url"`
	Title     string `json:"title,omitempty"`
	Source    string `json:"source,omitempty"`    // name of the originating site
	SourceURL string `json:"sourceUrl,omitempty"` // page the image was found on
	Width     int    `json:"width,omitempty"`
	Height    int    `json:"height,omitempty"`
}

// Badge is a small label rendered on a card (e.g. on products).
type Badge struct {
	Text  string `json:"text"`
	Type  string `json:"type"`
	Color string `json:"color,omitempty"`
}

// KnowledgeCardData is the payload of a knowledge card (WidgetKnowledge).
// Content's concrete shape depends on Type — presumably one of
// ComparisonTable, []StatCard or Timeline below; TODO confirm renderer.
type KnowledgeCardData struct {
	Type    string      `json:"type"`
	Title   string      `json:"title,omitempty"`
	Content interface{} `json:"content"`
}

// ComparisonTable is tabular knowledge-card content.
type ComparisonTable struct {
	Headers []string   `json:"headers"`
	Rows    [][]string `json:"rows"`
}

// StatCard is a single labeled statistic.
type StatCard struct {
	Label  string  `json:"label"`
	Value  string  `json:"value"`
	Change float64 `json:"change,omitempty"` // relative change; units unspecified here
	Unit   string  `json:"unit,omitempty"`
}
|
||||
|
||||
// Timeline is chronological knowledge-card content.
type Timeline struct {
	Events []TimelineEvent `json:"events"`
}

// TimelineEvent is one dated entry on a Timeline.
type TimelineEvent struct {
	Date        string `json:"date"` // free-form date string; format not enforced here
	Title       string `json:"title"`
	Description string `json:"description,omitempty"`
}
|
||||
|
||||
// WeatherParams is the params payload for the weather widget.
type WeatherParams struct {
	Location    string         `json:"location"`
	Current     WeatherCurrent `json:"current"`
	Forecast    []WeatherDay   `json:"forecast,omitempty"`
	LastUpdated string         `json:"lastUpdated,omitempty"`
}

// WeatherCurrent holds present conditions. Temperature units are not
// encoded here — presumably Celsius; confirm against the weather source.
type WeatherCurrent struct {
	Temp        float64 `json:"temp"`
	FeelsLike   float64 `json:"feelsLike"`
	Humidity    int     `json:"humidity"` // percent, per the name — TODO confirm
	WindSpeed   float64 `json:"windSpeed"`
	Description string  `json:"description"`
	Icon        string  `json:"icon"`
}

// WeatherDay is one day of forecast data.
type WeatherDay struct {
	Date    string  `json:"date"`
	TempMax float64 `json:"tempMax"`
	TempMin float64 `json:"tempMin"`
	Icon    string  `json:"icon"`
}
|
||||
|
||||
// CalculatorParams is the params payload for the calculator widget:
// the evaluated expression, its result, and optional worked steps.
type CalculatorParams struct {
	Expression string  `json:"expression"`
	Result     float64 `json:"result"`
	Steps      []Step  `json:"steps,omitempty"`
}

// Step is one intermediate step of a calculation.
type Step struct {
	Description string `json:"description"`
	Value       string `json:"value"`
}
|
||||
Reference in New Issue
Block a user