- Add Gitea Actions workflow for automated build & deploy
- Add K8s manifests: webui, travel-svc, medicine-svc, sandbox-svc
- Update kustomization for the localhost:5000 registry
- Add ingress for gooseek.ru and api.gooseek.ru
- Add learning cabinet with onboarding, courses, and sandbox integration
- Add medicine service with symptom analysis and doctor matching
- Add travel service with itinerary planning
- Add server setup scripts (NVIDIA/CUDA, K3s, Gitea runner)

Made-with: Cursor
963 lines
25 KiB
Go
963 lines
25 KiB
Go
package agent
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gooseek/backend/internal/llm"
|
|
"github.com/gooseek/backend/internal/prompts"
|
|
"github.com/gooseek/backend/internal/search"
|
|
"github.com/gooseek/backend/internal/session"
|
|
"github.com/gooseek/backend/internal/types"
|
|
"github.com/google/uuid"
|
|
"golang.org/x/sync/errgroup"
|
|
)
|
|
|
|
// Mode selects the latency/quality trade-off for answering a request.
type Mode string

const (
	// ModeSpeed uses heuristic classification, a small context budget, and a
	// 2048-token answer cap for the lowest latency.
	ModeSpeed Mode = "speed"
	// ModeBalanced runs the full pipeline with default budgets.
	ModeBalanced Mode = "balanced"
	// ModeQuality runs the full pipeline and additionally gates the
	// clarifying-questions and deep-research flows.
	ModeQuality Mode = "quality"
)
|
|
|
|
// OrchestratorConfig carries the per-request dependencies, feature flags,
// context strings, and downstream-service endpoints used to answer a single
// user turn.
type OrchestratorConfig struct {
	// Core dependencies.
	LLM          llm.Client            // LLM backend used for all generations
	SearchClient *search.SearXNGClient // SearXNG web/media search client

	Mode      Mode      // latency/quality trade-off (see Mode constants)
	FocusMode FocusMode // topical focus; auto-detected by RunOrchestrator when empty

	// Request-scoped inputs and context.
	Sources            []string       // caller-selected source restriction — TODO confirm exact semantics
	FileIDs            []string       // IDs of uploaded files — presumably resolved via FileSvcURL; verify
	FileContext        string         // extracted text of uploaded files, injected into the prompt
	CollectionID       string         // active collection ID — not read in this file; confirm consumer
	CollectionContext  string         // collection text injected into the prompt
	SystemInstructions string         // extra system-prompt instructions from the caller
	Locale             string         // UI locale (e.g. "ru"); steers language of generated widgets
	MemoryContext      string         // prior-conversation context passed to the writer prompt
	UserMemory         string         // persistent user preferences, injected into the prompt
	AnswerMode         string         // "travel", "learning", or "" for the generic pipeline
	ResponsePrefs      *ResponsePrefs // requested format/length/tone — consumer not visible here; verify

	// Feature flags.
	LearningMode       bool // route to the learning orchestrator
	EnableDeepResearch bool // allow deep research (only takes effect with ModeQuality)
	EnableClarifying   bool // allow the clarifying-questions short-circuit (only with ModeQuality)

	// Downstream service endpoints and credentials.
	DiscoverSvcURL      string // discover service; serves pre-generated article digests
	Crawl4AIURL         string // Crawl4AI scraping service used for article pre-scraping
	CollectionSvcURL    string // collection service — not used in this file
	FileSvcURL          string // file service — not used in this file
	TravelSvcURL        string // travel service — presumably used by the travel orchestrator; verify
	TravelPayoutsToken  string // TravelPayouts API token — TODO confirm usage
	TravelPayoutsMarker string // TravelPayouts affiliate marker — TODO confirm usage

	PhotoCache *PhotoCacheService // photo cache — presumably used by travel flows; verify
}
|
|
|
|
// DigestResponse is the payload returned by the discover service's
// pre-generated digest endpoint (see fetchPreGeneratedDigest). A digest is
// only considered usable when SummaryRu is non-empty and at least one
// citation is present.
type DigestResponse struct {
	SummaryRu    string           `json:"summaryRu"`    // Russian-language summary text
	Citations    []DigestCitation `json:"citations"`    // cited sources backing the summary
	FollowUp     []string         `json:"followUp"`     // suggested follow-up questions, appended as blockquotes
	SourcesCount int              `json:"sourcesCount"` // number of sources considered — TODO confirm
	ClusterTitle string           `json:"clusterTitle"` // topic-cluster title — TODO confirm
}
|
|
|
|
// DigestCitation is one cited source inside a DigestResponse.
type DigestCitation struct {
	Index  int    `json:"index"`  // citation index — presumably 1-based; verify against the service
	URL    string `json:"url"`    // source URL
	Title  string `json:"title"`  // source title
	Domain string `json:"domain"` // source domain (e.g. "example.com")
}
|
|
|
|
// PreScrapedArticle is article text fetched ahead of the writer phase for
// "Summary: <url>" requests, either via Crawl4AI or a direct HTTP fetch.
type PreScrapedArticle struct {
	Title   string // extracted page/markdown title (may be empty)
	Content string // plain text or markdown, truncated to 15000 characters
	URL     string // the original article URL
}
|
|
|
|
// ResponsePrefs captures the user's requested answer shape. All fields are
// optional free-form strings — accepted values are not visible in this file;
// TODO confirm against the consumer.
type ResponsePrefs struct {
	Format string `json:"format,omitempty"` // e.g. list vs prose — verify
	Length string `json:"length,omitempty"` // desired answer length — verify
	Tone   string `json:"tone,omitempty"`   // desired tone — verify
}
|
|
|
|
// OrchestratorInput is the full input for one orchestrator run: the prior
// conversation, the new user message, and the per-request configuration.
type OrchestratorInput struct {
	ChatHistory []llm.Message      // prior turns, replayed verbatim to the writer LLM
	FollowUp    string             // the new user message (or "Summary: <url>" for article digests)
	Config      OrchestratorConfig // per-request dependencies and flags
}
|
|
|
|
func RunOrchestrator(ctx context.Context, sess *session.Session, input OrchestratorInput) error {
|
|
if input.Config.AnswerMode == "travel" {
|
|
return RunTravelOrchestrator(ctx, sess, input)
|
|
}
|
|
|
|
if input.Config.AnswerMode == "learning" || input.Config.LearningMode {
|
|
return RunLearningOrchestrator(ctx, sess, input)
|
|
}
|
|
|
|
detectedLang := detectLanguage(input.FollowUp)
|
|
isArticleSummary := strings.HasPrefix(strings.TrimSpace(input.FollowUp), "Summary: ")
|
|
|
|
if input.Config.FocusMode == "" {
|
|
input.Config.FocusMode = DetectFocusMode(input.FollowUp)
|
|
}
|
|
|
|
if input.Config.EnableDeepResearch && input.Config.Mode == ModeQuality {
|
|
return runDeepResearchMode(ctx, sess, input, detectedLang)
|
|
}
|
|
|
|
if input.Config.Mode == ModeSpeed && !isArticleSummary {
|
|
return runSpeedMode(ctx, sess, input, detectedLang)
|
|
}
|
|
|
|
return runFullMode(ctx, sess, input, detectedLang, isArticleSummary)
|
|
}
|
|
|
|
// runDeepResearchMode handles ModeQuality requests with deep research
// enabled: it delegates the whole retrieval loop to a DeepResearcher and
// emits the resulting sources and follow-up questions to the session stream.
// The lang parameter is currently unused here — presumably language handling
// happens inside the researcher; TODO confirm.
func runDeepResearchMode(ctx context.Context, sess *session.Session, input OrchestratorInput, lang string) error {
	// Signal the UI that a research phase has started.
	sess.EmitBlock(types.NewResearchBlock(uuid.New().String()))

	researcher := NewDeepResearcher(DeepResearchConfig{
		LLM:          input.Config.LLM,
		SearchClient: input.Config.SearchClient,
		FocusMode:    input.Config.FocusMode,
		Locale:       input.Config.Locale,
		// Hard budgets for the research loop.
		MaxSearchQueries: 30,
		MaxSources:       100,
		MaxIterations:    5,
		Timeout:          5 * time.Minute,
	}, sess)

	result, err := researcher.Research(ctx, input.FollowUp)
	if err != nil {
		sess.EmitError(err)
		return err
	}

	sess.EmitBlock(types.NewSourceBlock(uuid.New().String(), result.Sources))

	// Surface follow-up queries as a related-questions widget, if any.
	if len(result.FollowUpQueries) > 0 {
		sess.EmitBlock(types.NewWidgetBlock(uuid.New().String(), "related_questions", map[string]interface{}{
			"questions": result.FollowUpQueries,
		}))
	}

	sess.EmitResearchComplete()
	sess.EmitEnd()

	return nil
}
|
|
|
|
func generateClarifyingQuestions(ctx context.Context, llmClient llm.Client, query string) ([]string, error) {
|
|
prompt := fmt.Sprintf(`Analyze this query and determine if clarifying questions would help provide a better answer.
|
|
|
|
Query: %s
|
|
|
|
If the query is:
|
|
- Clear and specific → respond with "CLEAR"
|
|
- Ambiguous or could benefit from clarification → provide 2-3 short clarifying questions
|
|
|
|
Format:
|
|
CLEAR
|
|
or
|
|
QUESTION: [question 1]
|
|
QUESTION: [question 2]
|
|
QUESTION: [question 3]`, query)
|
|
|
|
result, err := llmClient.GenerateText(ctx, llm.StreamRequest{
|
|
Messages: []llm.Message{{Role: "user", Content: prompt}},
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if strings.Contains(strings.ToUpper(result), "CLEAR") {
|
|
return nil, nil
|
|
}
|
|
|
|
var questions []string
|
|
for _, line := range strings.Split(result, "\n") {
|
|
line = strings.TrimSpace(line)
|
|
if strings.HasPrefix(line, "QUESTION:") {
|
|
q := strings.TrimSpace(strings.TrimPrefix(line, "QUESTION:"))
|
|
if q != "" {
|
|
questions = append(questions, q)
|
|
}
|
|
}
|
|
}
|
|
|
|
return questions, nil
|
|
}
|
|
|
|
func generateRelatedQuestions(ctx context.Context, llmClient llm.Client, query, answer string, locale string) []string {
|
|
langInstruction := ""
|
|
if locale == "ru" {
|
|
langInstruction = "Generate questions in Russian."
|
|
}
|
|
|
|
prompt := fmt.Sprintf(`Based on this query and answer, generate 3-4 related follow-up questions the user might want to explore.
|
|
|
|
Query: %s
|
|
|
|
Answer summary: %s
|
|
|
|
%s
|
|
Format: One question per line, no numbering or bullets.`, query, truncateForPrompt(answer, 500), langInstruction)
|
|
|
|
result, err := llmClient.GenerateText(ctx, llm.StreamRequest{
|
|
Messages: []llm.Message{{Role: "user", Content: prompt}},
|
|
})
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
var questions []string
|
|
for _, line := range strings.Split(result, "\n") {
|
|
line = strings.TrimSpace(line)
|
|
if line != "" && len(line) > 10 && strings.Contains(line, "?") {
|
|
line = strings.TrimLeft(line, "0123456789.-•* ")
|
|
if line != "" {
|
|
questions = append(questions, line)
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(questions) > 4 {
|
|
questions = questions[:4]
|
|
}
|
|
|
|
return questions
|
|
}
|
|
|
|
// truncateForPrompt caps s at roughly maxLen bytes for inclusion in an LLM
// prompt, appending "..." when anything was cut off.
//
// Bug fix: the previous version sliced at an arbitrary byte offset, which
// can split a multi-byte UTF-8 rune (common here, since answers are often in
// Russian) and feed invalid UTF-8 back into the LLM. The cut point now backs
// up to the nearest rune boundary.
func truncateForPrompt(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	cut := maxLen
	// Bytes of the form 0b10xxxxxx are UTF-8 continuation bytes; step back
	// until we land on the start of a rune.
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "..."
}
|
|
|
|
func buildEnhancedContext(input OrchestratorInput) string {
|
|
var ctx strings.Builder
|
|
|
|
if input.Config.UserMemory != "" {
|
|
ctx.WriteString("## User Preferences\n")
|
|
ctx.WriteString(input.Config.UserMemory)
|
|
ctx.WriteString("\n\n")
|
|
}
|
|
|
|
if input.Config.CollectionContext != "" {
|
|
ctx.WriteString("## Collection Context\n")
|
|
ctx.WriteString(input.Config.CollectionContext)
|
|
ctx.WriteString("\n\n")
|
|
}
|
|
|
|
if input.Config.FileContext != "" {
|
|
ctx.WriteString("## Uploaded Files Content\n")
|
|
ctx.WriteString(input.Config.FileContext)
|
|
ctx.WriteString("\n\n")
|
|
}
|
|
|
|
if input.Config.MemoryContext != "" {
|
|
ctx.WriteString("## Previous Context\n")
|
|
ctx.WriteString(input.Config.MemoryContext)
|
|
ctx.WriteString("\n\n")
|
|
}
|
|
|
|
return ctx.String()
|
|
}
|
|
|
|
func fetchPreGeneratedDigest(ctx context.Context, discoverURL, articleURL string) (*DigestResponse, error) {
|
|
if discoverURL == "" {
|
|
return nil, nil
|
|
}
|
|
|
|
reqURL := fmt.Sprintf("%s/api/v1/discover/digest?url=%s",
|
|
strings.TrimSuffix(discoverURL, "/"),
|
|
url.QueryEscape(articleURL))
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
client := &http.Client{Timeout: 3 * time.Second}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, nil
|
|
}
|
|
|
|
var digest DigestResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&digest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if digest.SummaryRu != "" && len(digest.Citations) > 0 {
|
|
return &digest, nil
|
|
}
|
|
|
|
return nil, nil
|
|
}
|
|
|
|
func preScrapeArticleURL(ctx context.Context, crawl4aiURL, articleURL string) (*PreScrapedArticle, error) {
|
|
if crawl4aiURL != "" {
|
|
article, err := scrapeWithCrawl4AI(ctx, crawl4aiURL, articleURL)
|
|
if err == nil && article != nil {
|
|
return article, nil
|
|
}
|
|
}
|
|
|
|
return scrapeDirectly(ctx, articleURL)
|
|
}
|
|
|
|
func scrapeWithCrawl4AI(ctx context.Context, crawl4aiURL, articleURL string) (*PreScrapedArticle, error) {
|
|
reqBody := fmt.Sprintf(`{
|
|
"urls": ["%s"],
|
|
"crawler_config": {
|
|
"type": "CrawlerRunConfig",
|
|
"params": {
|
|
"cache_mode": "default",
|
|
"page_timeout": 20000
|
|
}
|
|
}
|
|
}`, articleURL)
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", crawl4aiURL+"/crawl", strings.NewReader(reqBody))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
client := &http.Client{Timeout: 25 * time.Second}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("Crawl4AI returned status %d", resp.StatusCode)
|
|
}
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
markdown := extractMarkdownFromCrawl4AI(string(body))
|
|
title := extractTitleFromCrawl4AI(string(body))
|
|
|
|
if len(markdown) > 100 {
|
|
content := markdown
|
|
if len(content) > 15000 {
|
|
content = content[:15000]
|
|
}
|
|
return &PreScrapedArticle{
|
|
Title: title,
|
|
Content: content,
|
|
URL: articleURL,
|
|
}, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("insufficient content from Crawl4AI")
|
|
}
|
|
|
|
func scrapeDirectly(ctx context.Context, articleURL string) (*PreScrapedArticle, error) {
|
|
req, err := http.NewRequestWithContext(ctx, "GET", articleURL, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("User-Agent", "GooSeek-Agent/1.0")
|
|
req.Header.Set("Accept", "text/html")
|
|
|
|
client := &http.Client{Timeout: 10 * time.Second}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
|
}
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
html := string(body)
|
|
title := extractHTMLTitle(html)
|
|
content := extractTextContent(html)
|
|
|
|
if len(content) < 100 {
|
|
return nil, fmt.Errorf("insufficient content")
|
|
}
|
|
|
|
if len(content) > 15000 {
|
|
content = content[:15000]
|
|
}
|
|
|
|
return &PreScrapedArticle{
|
|
Title: title,
|
|
Content: content,
|
|
URL: articleURL,
|
|
}, nil
|
|
}
|
|
|
|
// Package-level regexes for the fallback HTML scraper, compiled once at
// startup.
var (
	// titleRegex captures the contents of the first <title> element.
	titleRegex = regexp.MustCompile(`<title[^>]*>([^<]+)</title>`)
	// scriptRegex and styleRegex strip whole <script>/<style> elements;
	// (?s) lets `.` match newlines so multi-line blocks are removed.
	scriptRegex = regexp.MustCompile(`(?s)<script[^>]*>.*?</script>`)
	styleRegex  = regexp.MustCompile(`(?s)<style[^>]*>.*?</style>`)
	// tagRegex removes any remaining tags; spaceRegex collapses whitespace
	// runs into single spaces.
	tagRegex   = regexp.MustCompile(`<[^>]+>`)
	spaceRegex = regexp.MustCompile(`\s+`)
)
|
|
|
|
func extractHTMLTitle(html string) string {
|
|
matches := titleRegex.FindStringSubmatch(html)
|
|
if len(matches) > 1 {
|
|
return strings.TrimSpace(matches[1])
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func extractTextContent(html string) string {
|
|
bodyStart := strings.Index(strings.ToLower(html), "<body")
|
|
bodyEnd := strings.Index(strings.ToLower(html), "</body>")
|
|
|
|
if bodyStart != -1 && bodyEnd != -1 && bodyEnd > bodyStart {
|
|
html = html[bodyStart:bodyEnd]
|
|
}
|
|
|
|
html = scriptRegex.ReplaceAllString(html, "")
|
|
html = styleRegex.ReplaceAllString(html, "")
|
|
html = tagRegex.ReplaceAllString(html, " ")
|
|
html = spaceRegex.ReplaceAllString(html, " ")
|
|
|
|
return strings.TrimSpace(html)
|
|
}
|
|
|
|
// extractMarkdownFromCrawl4AI pulls the first string-valued "raw_markdown"
// field out of a Crawl4AI /crawl JSON response, wherever it is nested.
// Returns "" when the response is not valid JSON or the key is absent.
//
// Bug fix: the previous ad-hoc scanner stopped at the first '"' after the
// key, so markdown containing an escaped quote was truncated, escape
// sequences (\n, \", …) were returned verbatim instead of decoded, and a
// null value produced garbage. The response is now parsed as JSON and
// searched recursively.
func extractMarkdownFromCrawl4AI(response string) string {
	var doc interface{}
	if err := json.Unmarshal([]byte(response), &doc); err != nil {
		return ""
	}
	// Recursive depth-first search for the key; closure avoids exporting a
	// helper name.
	var find func(v interface{}) string
	find = func(v interface{}) string {
		switch t := v.(type) {
		case map[string]interface{}:
			if s, ok := t["raw_markdown"].(string); ok {
				return s
			}
			for _, child := range t {
				if s := find(child); s != "" {
					return s
				}
			}
		case []interface{}:
			for _, child := range t {
				if s := find(child); s != "" {
					return s
				}
			}
		}
		return ""
	}
	return find(doc)
}
|
|
|
|
// extractTitleFromCrawl4AI pulls the first string-valued "title" field out
// of a Crawl4AI /crawl JSON response, wherever it is nested. Returns "" when
// the response is not valid JSON or the key is absent.
//
// Bug fix: like extractMarkdownFromCrawl4AI, the previous ad-hoc scanner
// broke on escaped quotes and null values and returned undecoded escape
// sequences; the response is now parsed as JSON and searched recursively.
func extractTitleFromCrawl4AI(response string) string {
	var doc interface{}
	if err := json.Unmarshal([]byte(response), &doc); err != nil {
		return ""
	}
	var find func(v interface{}) string
	find = func(v interface{}) string {
		switch t := v.(type) {
		case map[string]interface{}:
			if s, ok := t["title"].(string); ok {
				return s
			}
			for _, child := range t {
				if s := find(child); s != "" {
					return s
				}
			}
		case []interface{}:
			for _, child := range t {
				if s := find(child); s != "" {
					return s
				}
			}
		}
		return ""
	}
	return find(doc)
}
|
|
|
|
// runSpeedMode is the low-latency pipeline: heuristic (non-LLM)
// classification, parallel web and media search, BM25 reranking with an
// adaptive top-K, and a single writer call capped at 2048 tokens.
func runSpeedMode(ctx context.Context, sess *session.Session, input OrchestratorInput, detectedLang string) error {
	// Heuristic classification only — no LLM round-trip in speed mode.
	classification := fastClassify(input.FollowUp, input.ChatHistory)
	searchQuery := classification.StandaloneFollowUp
	if searchQuery == "" {
		searchQuery = input.FollowUp
	}
	queries := generateSearchQueries(searchQuery)

	researchBlockID := uuid.New().String()
	sess.EmitBlock(types.NewResearchBlock(researchBlockID))

	var searchResults []types.Chunk
	var mediaResult *search.MediaSearchResult

	g, gctx := errgroup.WithContext(ctx)

	// Web search and media search run concurrently; both are best-effort —
	// their errors are deliberately swallowed so the answer proceeds with
	// whatever arrived.
	g.Go(func() error {
		results, err := parallelSearch(gctx, input.Config.SearchClient, queries)
		if err != nil {
			return nil
		}
		searchResults = results
		return nil
	})

	g.Go(func() error {
		result, err := input.Config.SearchClient.SearchMedia(gctx, searchQuery, &search.MediaSearchOptions{
			MaxImages: 6,
			MaxVideos: 4,
		})
		if err != nil {
			return nil
		}
		mediaResult = result
		return nil
	})

	// Every goroutine returns nil, so the group error is ignorable.
	_ = g.Wait()

	if len(searchResults) > 0 {
		sess.EmitBlock(types.NewSourceBlock(uuid.New().String(), searchResults))
	}

	// Media widgets are emitted before the text answer starts streaming.
	if mediaResult != nil {
		if len(mediaResult.Images) > 0 {
			sess.EmitBlock(types.NewWidgetBlock(uuid.New().String(), "image_gallery", map[string]interface{}{
				"images": mediaResult.Images,
				"layout": "carousel",
			}))
		}
		if len(mediaResult.Videos) > 0 {
			sess.EmitBlock(types.NewWidgetBlock(uuid.New().String(), "videos", map[string]interface{}{
				"items": mediaResult.Videos,
				"title": "",
			}))
		}
	}

	sess.EmitResearchComplete()

	// Rerank with BM25, sizing K to the query complexity and result count.
	queryComplexity := search.EstimateQueryComplexity(searchQuery)
	adaptiveTopK := search.ComputeAdaptiveTopK(len(searchResults), queryComplexity, "speed")
	rankedResults := search.RerankBM25(searchResults, searchQuery, adaptiveTopK)

	// Small context budget (15 results × 250 chars) keeps latency down.
	finalContext := buildContext(rankedResults, 15, 250)

	writerPrompt := prompts.GetWriterPrompt(prompts.WriterConfig{
		Context:            finalContext,
		SystemInstructions: input.Config.SystemInstructions,
		Mode:               string(input.Config.Mode),
		Locale:             input.Config.Locale,
		MemoryContext:      input.Config.MemoryContext,
		AnswerMode:         input.Config.AnswerMode,
		DetectedLanguage:   detectedLang,
		IsArticleSummary:   false,
	})

	messages := []llm.Message{
		{Role: llm.RoleSystem, Content: writerPrompt},
	}
	messages = append(messages, input.ChatHistory...)
	messages = append(messages, llm.Message{Role: llm.RoleUser, Content: input.FollowUp})

	return streamResponse(ctx, sess, input.Config.LLM, messages, 2048, input.FollowUp, input.Config.Locale)
}
|
|
|
|
// runFullMode is the default pipeline: optional clarifying-questions
// short-circuit (quality mode), article digest/pre-scrape handling for
// "Summary: <url>" requests, LLM classification, parallel web + media
// search, relevance ranking, and a streamed writer answer.
func runFullMode(ctx context.Context, sess *session.Session, input OrchestratorInput, detectedLang string, isArticleSummary bool) error {
	// In quality mode, give the user a chance to disambiguate before doing
	// any expensive work. When questions are produced, the turn ends here.
	if input.Config.EnableClarifying && !isArticleSummary && input.Config.Mode == ModeQuality {
		clarifying, err := generateClarifyingQuestions(ctx, input.Config.LLM, input.FollowUp)
		if err == nil && len(clarifying) > 0 {
			sess.EmitBlock(types.NewWidgetBlock(uuid.New().String(), "clarifying", map[string]interface{}{
				"questions": clarifying,
				"query":     input.FollowUp,
			}))
			return nil
		}
	}

	// Prepend user memory / collection / file context to the memory context
	// consumed by the writer prompt.
	enhancedContext := buildEnhancedContext(input)
	if enhancedContext != "" {
		input.Config.MemoryContext = enhancedContext + input.Config.MemoryContext
	}

	var preScrapedArticle *PreScrapedArticle
	var articleURL string

	if isArticleSummary {
		articleURL = strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(input.FollowUp), "Summary: "))

		// Fetch the pre-generated digest (fast path) and pre-scrape the
		// article (slow path) concurrently; each has its own timeout.
		digestCtx, digestCancel := context.WithTimeout(ctx, 3*time.Second)
		scrapeCtx, scrapeCancel := context.WithTimeout(ctx, 25*time.Second)

		// Buffered so the goroutines never block on send.
		digestCh := make(chan *DigestResponse, 1)
		scrapeCh := make(chan *PreScrapedArticle, 1)

		go func() {
			defer digestCancel()
			digest, _ := fetchPreGeneratedDigest(digestCtx, input.Config.DiscoverSvcURL, articleURL)
			digestCh <- digest
		}()

		go func() {
			defer scrapeCancel()
			article, _ := preScrapeArticleURL(scrapeCtx, input.Config.Crawl4AIURL, articleURL)
			scrapeCh <- article
		}()

		digest := <-digestCh
		preScrapedArticle = <-scrapeCh

		// A usable digest short-circuits the whole pipeline: emit its
		// citations as sources and its summary as the final answer.
		if digest != nil {
			chunks := make([]types.Chunk, len(digest.Citations))
			for i, c := range digest.Citations {
				chunks[i] = types.Chunk{
					Content: c.Title,
					Metadata: map[string]string{
						"url":    c.URL,
						"title":  c.Title,
						"domain": c.Domain,
					},
				}
			}
			sess.EmitBlock(types.NewSourceBlock(uuid.New().String(), chunks))
			sess.EmitResearchComplete()

			// Append follow-up questions as markdown blockquotes.
			summaryText := digest.SummaryRu
			if len(digest.FollowUp) > 0 {
				summaryText += "\n\n---\n"
				for _, q := range digest.FollowUp {
					summaryText += "> " + q + "\n"
				}
			}
			sess.EmitBlock(types.NewTextBlock(uuid.New().String(), summaryText))
			sess.EmitEnd()
			return nil
		}
	}

	// LLM classification; on failure fall back to searching with the raw
	// follow-up verbatim.
	classification, err := classify(ctx, input.Config.LLM, input.FollowUp, input.ChatHistory, input.Config.Locale, detectedLang)
	if err != nil {
		classification = &ClassificationResult{
			StandaloneFollowUp: input.FollowUp,
			SkipSearch:         false,
		}
	}

	// Article summaries always need search, regardless of the classifier.
	if isArticleSummary && classification.SkipSearch {
		classification.SkipSearch = false
	}

	g, gctx := errgroup.WithContext(ctx)

	var searchResults []types.Chunk
	var mediaResult *search.MediaSearchResult

	mediaQuery := classification.StandaloneFollowUp
	if mediaQuery == "" {
		mediaQuery = input.FollowUp
	}

	// For summaries with a pre-scraped title, enrich both the follow-up and
	// the standalone query with that title so search hits the right topic.
	effectiveFollowUp := input.FollowUp
	if isArticleSummary && preScrapedArticle != nil && preScrapedArticle.Title != "" {
		effectiveFollowUp = fmt.Sprintf("Summary: %s\nArticle title: %s", preScrapedArticle.URL, preScrapedArticle.Title)
		if classification.StandaloneFollowUp != "" {
			classification.StandaloneFollowUp = preScrapedArticle.Title + " " + classification.StandaloneFollowUp
		} else {
			classification.StandaloneFollowUp = preScrapedArticle.Title
		}
	}

	// Web research (unless skipped) and media search run concurrently;
	// both are best-effort and swallow their own errors.
	if !classification.SkipSearch {
		g.Go(func() error {
			results, err := research(gctx, sess, input.Config.LLM, input.Config.SearchClient, ResearchInput{
				ChatHistory:      input.ChatHistory,
				FollowUp:         effectiveFollowUp,
				Classification:   classification,
				Mode:             input.Config.Mode,
				Sources:          input.Config.Sources,
				Locale:           input.Config.Locale,
				DetectedLang:     detectedLang,
				IsArticleSummary: isArticleSummary,
			})
			if err != nil {
				return nil
			}
			searchResults = results
			return nil
		})
	}

	if !isArticleSummary {
		g.Go(func() error {
			result, err := input.Config.SearchClient.SearchMedia(gctx, mediaQuery, &search.MediaSearchOptions{
				MaxImages: 8,
				MaxVideos: 6,
			})
			if err != nil {
				return nil
			}
			mediaResult = result
			return nil
		})
	}

	// All goroutines return nil, so the group error is ignorable.
	_ = g.Wait()

	// Make sure the scraped article itself is the first source chunk unless
	// search already surfaced the same URL.
	if isArticleSummary && preScrapedArticle != nil {
		alreadyHasURL := false
		for _, r := range searchResults {
			if strings.Contains(r.Metadata["url"], preScrapedArticle.URL) {
				alreadyHasURL = true
				break
			}
		}
		if !alreadyHasURL {
			prependChunk := types.Chunk{
				Content: preScrapedArticle.Content,
				Metadata: map[string]string{
					"url":   preScrapedArticle.URL,
					"title": preScrapedArticle.Title,
				},
			}
			searchResults = append([]types.Chunk{prependChunk}, searchResults...)
		}
	}

	if len(searchResults) > 0 {
		sess.EmitBlock(types.NewSourceBlock(uuid.New().String(), searchResults))
	}

	// Media widgets go out before the text answer starts streaming.
	if mediaResult != nil {
		if len(mediaResult.Images) > 0 {
			sess.EmitBlock(types.NewWidgetBlock(uuid.New().String(), "image_gallery", map[string]interface{}{
				"images": mediaResult.Images,
				"layout": "carousel",
			}))
		}
		if len(mediaResult.Videos) > 0 {
			sess.EmitBlock(types.NewWidgetBlock(uuid.New().String(), "videos", map[string]interface{}{
				"items": mediaResult.Videos,
				"title": "",
			}))
		}
	}

	sess.EmitResearchComplete()

	// Summaries get a larger context budget (the article body itself is one
	// of the chunks).
	maxResults := 25
	maxContent := 320
	if isArticleSummary {
		maxResults = 30
		maxContent = 2000
	}

	rankedResults := rankByRelevance(searchResults, input.FollowUp)
	if len(rankedResults) > maxResults {
		rankedResults = rankedResults[:maxResults]
	}

	finalContext := buildContext(rankedResults, maxResults, maxContent)

	writerPrompt := prompts.GetWriterPrompt(prompts.WriterConfig{
		Context:            finalContext,
		SystemInstructions: input.Config.SystemInstructions,
		Mode:               string(input.Config.Mode),
		Locale:             input.Config.Locale,
		MemoryContext:      input.Config.MemoryContext,
		AnswerMode:         input.Config.AnswerMode,
		DetectedLanguage:   detectedLang,
		IsArticleSummary:   isArticleSummary,
	})

	messages := []llm.Message{
		{Role: llm.RoleSystem, Content: writerPrompt},
	}
	messages = append(messages, input.ChatHistory...)
	messages = append(messages, llm.Message{Role: llm.RoleUser, Content: input.FollowUp})

	maxTokens := 4096
	return streamResponse(ctx, sess, input.Config.LLM, messages, maxTokens, input.FollowUp, input.Config.Locale)
}
|
|
|
|
// streamResponse streams the writer LLM's answer to the session as a text
// block (created lazily on the first non-empty chunk), replaces the block
// with the full accumulated text at the end, and kicks off best-effort
// related-question generation in the background.
func streamResponse(ctx context.Context, sess *session.Session, client llm.Client, messages []llm.Message, maxTokens int, query string, locale string) error {
	stream, err := client.StreamText(ctx, llm.StreamRequest{
		Messages: messages,
		Options:  llm.StreamOptions{MaxTokens: maxTokens},
	})
	if err != nil {
		return err
	}

	var responseBlockID string
	var accumulatedText string

	for chunk := range stream {
		// Skip leading empty chunks so we don't create an empty block.
		if chunk.ContentChunk == "" && responseBlockID == "" {
			continue
		}

		if responseBlockID == "" {
			// First content chunk: create the text block.
			responseBlockID = uuid.New().String()
			accumulatedText = chunk.ContentChunk
			sess.EmitBlock(types.NewTextBlock(responseBlockID, accumulatedText))
		} else if chunk.ContentChunk != "" {
			accumulatedText += chunk.ContentChunk
			sess.EmitTextChunk(responseBlockID, chunk.ContentChunk)
		}
	}

	// Replace the block data with the final text so clients that missed
	// chunks converge on the complete answer.
	if responseBlockID != "" {
		sess.UpdateBlock(responseBlockID, []session.Patch{
			{Op: "replace", Path: "/data", Value: accumulatedText},
		})
	}

	// Fire-and-forget: related questions are generated off the request
	// context so they survive its cancellation, with their own 5s budget.
	// NOTE(review): this goroutine may emit a block after EmitEnd below —
	// confirm the session tolerates late emissions.
	go func() {
		relatedCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()

		related := generateRelatedQuestions(relatedCtx, client, query, accumulatedText, locale)
		if len(related) > 0 {
			sess.EmitBlock(types.NewWidgetBlock(uuid.New().String(), "related_questions", map[string]interface{}{
				"questions": related,
			}))
		}
	}()

	sess.EmitEnd()
	return nil
}
|
|
|
|
// parallelSearch runs every query against SearXNG concurrently and merges
// the results, de-duplicating by URL (first occurrence wins). Individual
// query failures are swallowed (best-effort); the function itself never
// returns a non-nil error.
func parallelSearch(ctx context.Context, client *search.SearXNGClient, queries []string) ([]types.Chunk, error) {
	results := make([]types.Chunk, 0)
	seen := make(map[string]bool)

	g, gctx := errgroup.WithContext(ctx)
	// Buffered to len(queries) so no sender can block even if the reader
	// below is slow.
	resultsCh := make(chan []types.SearchResult, len(queries))

	for _, q := range queries {
		query := q // per-iteration copy for the closure
		g.Go(func() error {
			resp, err := client.Search(gctx, query, &search.SearchOptions{
				Categories: []string{"general", "news"},
				PageNo:     1,
			})
			if err != nil {
				// Best-effort: a failed query contributes no results.
				resultsCh <- nil
				return nil
			}
			resultsCh <- resp.Results
			return nil
		})
	}

	// Close the channel once all searches finish so the range below ends.
	go func() {
		g.Wait()
		close(resultsCh)
	}()

	for batch := range resultsCh {
		for _, r := range batch {
			if r.URL != "" && !seen[r.URL] {
				seen[r.URL] = true
				results = append(results, r.ToChunk())
			}
		}
	}

	return results, nil
}
|
|
|
|
func buildContext(chunks []types.Chunk, maxResults, maxContentLen int) string {
|
|
if len(chunks) > maxResults {
|
|
chunks = chunks[:maxResults]
|
|
}
|
|
|
|
var sb strings.Builder
|
|
sb.WriteString("<search_results note=\"These are the search results and assistant can cite these\">\n")
|
|
|
|
for i, chunk := range chunks {
|
|
content := chunk.Content
|
|
if len(content) > maxContentLen {
|
|
content = content[:maxContentLen] + "…"
|
|
}
|
|
title := chunk.Metadata["title"]
|
|
sb.WriteString("<result index=")
|
|
sb.WriteString(strings.ReplaceAll(title, "\"", "'"))
|
|
sb.WriteString("\" index=\"")
|
|
sb.WriteString(string(rune('0' + i + 1)))
|
|
sb.WriteString("\">")
|
|
sb.WriteString(content)
|
|
sb.WriteString("</result>\n")
|
|
}
|
|
|
|
sb.WriteString("</search_results>")
|
|
return sb.String()
|
|
}
|
|
|
|
func rankByRelevance(chunks []types.Chunk, query string) []types.Chunk {
|
|
if len(chunks) == 0 {
|
|
return chunks
|
|
}
|
|
|
|
terms := extractQueryTerms(query)
|
|
if len(terms) == 0 {
|
|
return chunks
|
|
}
|
|
|
|
type scored struct {
|
|
chunk types.Chunk
|
|
score int
|
|
}
|
|
|
|
scored_chunks := make([]scored, len(chunks))
|
|
for i, chunk := range chunks {
|
|
score := 0
|
|
content := strings.ToLower(chunk.Content)
|
|
title := strings.ToLower(chunk.Metadata["title"])
|
|
|
|
for term := range terms {
|
|
if strings.Contains(title, term) {
|
|
score += 3
|
|
}
|
|
if strings.Contains(content, term) {
|
|
score += 1
|
|
}
|
|
}
|
|
|
|
scored_chunks[i] = scored{chunk: chunk, score: score}
|
|
}
|
|
|
|
for i := 0; i < len(scored_chunks)-1; i++ {
|
|
for j := i + 1; j < len(scored_chunks); j++ {
|
|
if scored_chunks[j].score > scored_chunks[i].score {
|
|
scored_chunks[i], scored_chunks[j] = scored_chunks[j], scored_chunks[i]
|
|
}
|
|
}
|
|
}
|
|
|
|
result := make([]types.Chunk, len(scored_chunks))
|
|
for i, s := range scored_chunks {
|
|
result[i] = s.chunk
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
// extractQueryTerms tokenizes a query into a lowercase term set for
// relevance scoring. The "summary: " prefix (article-summary requests) is
// stripped; URLs (tokens starting with "http") and single-character tokens
// are excluded.
func extractQueryTerms(query string) map[string]bool {
	normalized := strings.TrimPrefix(strings.ToLower(query), "summary: ")

	terms := make(map[string]bool)
	for _, word := range strings.Fields(normalized) {
		if len(word) < 2 || strings.HasPrefix(word, "http") {
			continue
		}
		terms[word] = true
	}
	return terms
}
|