package agent

import (
	"context"
	"encoding/json"
	"regexp"
	"strings"
	"unicode"
	"unicode/utf8"

	"github.com/gooseek/backend/internal/llm"
	"github.com/gooseek/backend/internal/prompts"
)

// classifierJSONPattern greedily matches from the first '{' to the last '}'
// in a model response. It is compiled once instead of on every classify call.
var classifierJSONPattern = regexp.MustCompile(`\{[\s\S]*\}`)

// ClassificationResult describes how a user query should be handled.
type ClassificationResult struct {
	// StandaloneFollowUp is the query to act on; both classifiers fall back
	// to the original user query when no rewrite is available.
	StandaloneFollowUp string `json:"standaloneFollowUp"`
	// SkipSearch is set by the classifiers when the query looks like it does
	// not need a search (e.g. short small talk in fastClassify).
	SkipSearch bool `json:"skipSearch"`
	// Topics and QueryType are only populated from the LLM's JSON answer.
	Topics    []string `json:"topics,omitempty"`
	QueryType string   `json:"queryType,omitempty"`
	// Engines lists search engine identifiers; fastClassify fills it via
	// detectEngines.
	Engines []string `json:"engines,omitempty"`
}

// classify asks the LLM to classify query in the context of history.
//
// The model is expected to answer with a JSON object matching
// ClassificationResult. If the response contains no JSON object, or the JSON
// cannot be decoded, a default result (original query, search enabled) is
// returned with a nil error. Only a failure of the LLM call itself is
// reported as an error.
func classify(ctx context.Context, client llm.Client, query string, history []llm.Message, locale, detectedLang string) (*ClassificationResult, error) {
	prompt := prompts.GetClassifierPrompt(locale, detectedLang)
	userContent := "\n" + formatHistory(history) + "\nUser: " + query + "\n"

	response, err := client.GenerateText(ctx, llm.StreamRequest{
		Messages: []llm.Message{
			{Role: llm.RoleSystem, Content: prompt},
			{Role: llm.RoleUser, Content: userContent},
		},
		Options: llm.StreamOptions{MaxTokens: 1024},
	})
	if err != nil {
		return nil, err
	}

	jsonMatch := classifierJSONPattern.FindString(response)
	if jsonMatch == "" {
		return &ClassificationResult{StandaloneFollowUp: query}, nil
	}

	var result ClassificationResult
	if err := json.Unmarshal([]byte(jsonMatch), &result); err != nil {
		// Deliberate best effort: a malformed answer degrades to the
		// default classification. A fresh value is returned so that no
		// partially decoded fields leak out.
		return &ClassificationResult{StandaloneFollowUp: query}, nil
	}
	if result.StandaloneFollowUp == "" {
		result.StandaloneFollowUp = query
	}
	return &result, nil
}

// fastClassify is a heuristic classifier that needs no LLM call.
//
// It marks short small-talk queries as not needing search, appends topics
// taken from the last assistant message when the query contains a pronoun
// (and history has at least two messages), and selects search engines with
// detectEngines.
func fastClassify(query string, history []llm.Message) *ClassificationResult {
	// Measured in runes so Cyrillic and Latin queries share the same limit.
	const maxSmallTalkRunes = 50

	queryLower := strings.ToLower(query)

	skipSearch := false
	if utf8.RuneCountInString(query) < maxSmallTalkRunes {
		skipPatterns := []string{
			"привет", "как дела", "спасибо", "пока",
			"hello", "hi", "thanks", "bye",
			"объясни", "расскажи подробнее", "что ты имеешь",
		}
		for _, p := range skipPatterns {
			// Whole-word matching: "hi" must not fire inside "this",
			// nor "пока" inside "покажи".
			if containsWholePhrase(queryLower, p) {
				skipSearch = true
				break
			}
		}
	}

	standalone := query
	if len(history) >= 2 {
		pronouns := []string{
			"это", "этот", "эта", "эти", "он", "она", "оно", "они",
			"it", "this", "that", "they", "them",
		}
		hasPronoun := false
		for _, p := range pronouns {
			// Whole-word matching also catches a pronoun at the end of the
			// query or before punctuation, and ignores word suffixes such
			// as the "он" in "телефон".
			if containsWholePhrase(queryLower, p) {
				hasPronoun = true
				break
			}
		}

		if hasPronoun {
			lastAssistant := ""
			for i := len(history) - 1; i >= 0; i-- {
				if history[i].Role == llm.RoleAssistant {
					lastAssistant = history[i].Content
					break
				}
			}
			if lastAssistant != "" {
				if topics := extractTopics(lastAssistant); len(topics) > 0 {
					standalone = query + " (контекст: " + strings.Join(topics, ", ") + ")"
				}
			}
		}
	}

	return &ClassificationResult{
		StandaloneFollowUp: standalone,
		SkipSearch:         skipSearch,
		Engines:            detectEngines(queryLower),
	}
}

// generateSearchQueries returns up to three search query variants: the
// original query, a five-word abbreviation for long queries, and the query
// with a leading question word removed.
func generateSearchQueries(query string) []string {
	const (
		longQueryRunes   = 100 // queries longer than this get a shortened variant
		shortVariantLen  = 5   // number of leading words kept in that variant
		minStrippedRunes = 10  // shorter remainders are too vague to search for
		maxQueries       = 3
	)

	queries := []string{query}

	if utf8.RuneCountInString(query) > longQueryRunes {
		if words := strings.Fields(query); len(words) > shortVariantLen {
			queries = append(queries, strings.Join(words[:shortVariantLen], " "))
		}
	}

	keywordPatterns := []string{
		"как", "что такое", "где", "когда", "почему", "кто",
		"how", "what is", "where", "when", "why", "who",
	}
	for _, p := range keywordPatterns {
		rest, ok := cutQuestionPrefix(query, p)
		if !ok {
			continue
		}
		if utf8.RuneCountInString(rest) > minStrippedRunes {
			queries = append(queries, rest)
		}
		break
	}

	if len(queries) > maxQueries {
		queries = queries[:maxQueries]
	}
	return queries
}

// detectEngines picks search engines based on keywords found in query.
// The general-purpose engines are always included. Matching is a
// case-sensitive substring test, so callers should pass a lowercased query.
func detectEngines(query string) []string {
	hasAny := func(keywords ...string) bool {
		for _, kw := range keywords {
			if strings.Contains(query, kw) {
				return true
			}
		}
		return false
	}

	engines := []string{"google", "duckduckgo"}
	if hasAny("новости", "news") {
		engines = append(engines, "google_news")
	}
	if hasAny("видео", "video") {
		engines = append(engines, "youtube")
	}
	if hasAny("товар", "купить", "цена", "price") {
		engines = append(engines, "google_shopping")
	}
	return engines
}

// extractTopics returns up to three capitalized words found among the first
// fifty whitespace-separated words of text. Surrounding punctuation is
// stripped, and a word qualifies when it is 6 to 19 characters long and
// starts with an uppercase Latin or Russian letter. The result is never nil.
func extractTopics(text string) []string {
	const (
		maxWords  = 50
		maxTopics = 3
		minRunes  = 6
		maxRunes  = 19
	)

	words := strings.Fields(text)
	if len(words) > maxWords {
		words = words[:maxWords]
	}

	topics := make([]string, 0, maxTopics)
	for _, w := range words {
		// Drop quotes, commas, markdown markers and similar noise so that
		// "Москва," or "**Go**" yield clean topics.
		w = strings.TrimFunc(w, func(r rune) bool { return !isWordRune(r) })

		// Lengths are counted in runes so Cyrillic and Latin words are
		// held to the same limits.
		if n := utf8.RuneCountInString(w); n < minRunes || n > maxRunes {
			continue
		}

		first, _ := utf8.DecodeRuneInString(w)
		// 'Ё' lies outside the contiguous 'А'..'Я' range.
		isUpper := (first >= 'A' && first <= 'Z') || (first >= 'А' && first <= 'Я') || first == 'Ё'
		if !isUpper {
			continue
		}

		topics = append(topics, w)
		if len(topics) >= maxTopics {
			break
		}
	}
	return topics
}

// formatHistory renders messages as "Role: content" lines for the classifier
// prompt. Assistant messages are labelled "Assistant" and every other role
// "User". Content longer than 500 bytes is truncated and suffixed with "...".
func formatHistory(messages []llm.Message) string {
	const maxContentBytes = 500

	var sb strings.Builder
	for _, m := range messages {
		role := "User"
		if m.Role == llm.RoleAssistant {
			role = "Assistant"
		}
		sb.WriteString(role)
		sb.WriteString(": ")
		sb.WriteString(truncateHistoryContent(m.Content, maxContentBytes))
		sb.WriteString("\n")
	}
	return sb.String()
}

// detectLanguage returns "ru" when text contains more Russian Cyrillic
// letters than basic Latin letters, and "en" otherwise (including ties and
// empty input).
func detectLanguage(text string) string {
	cyrillicCount, latinCount := 0, 0
	for _, r := range text {
		switch {
		// 'ё' and 'Ё' are not part of the contiguous 'а'..'я' / 'А'..'Я' ranges.
		case (r >= 'а' && r <= 'я') || (r >= 'А' && r <= 'Я') || r == 'ё' || r == 'Ё':
			cyrillicCount++
		case (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z'):
			latinCount++
		}
	}
	if cyrillicCount > latinCount {
		return "ru"
	}
	return "en"
}

// isWordRune reports whether r can be part of a word (a letter or a digit).
func isWordRune(r rune) bool {
	return unicode.IsLetter(r) || unicode.IsDigit(r)
}

// containsWholePhrase reports whether phrase occurs in text with no letter or
// digit directly before or after it, i.e. as whole words rather than as part
// of a longer word. The comparison is case-sensitive.
func containsWholePhrase(text, phrase string) bool {
	if phrase == "" {
		return false
	}
	for from := 0; from <= len(text)-len(phrase); {
		idx := strings.Index(text[from:], phrase)
		if idx < 0 {
			return false
		}
		start := from + idx
		end := start + len(phrase)

		boundedLeft := start == 0
		if !boundedLeft {
			prev, _ := utf8.DecodeLastRuneInString(text[:start])
			boundedLeft = !isWordRune(prev)
		}
		boundedRight := end == len(text)
		if !boundedRight {
			next, _ := utf8.DecodeRuneInString(text[end:])
			boundedRight = !isWordRune(next)
		}
		if boundedLeft && boundedRight {
			return true
		}

		// Resume the search one rune past the rejected match.
		_, size := utf8.DecodeRuneInString(text[start:])
		from = start + size
	}
	return false
}

// cutQuestionPrefix strips a leading question word or phrase from query.
// The prefix is matched case-insensitively and must end at a word boundary,
// so "how" matches "How to..." but not "however...". The remainder keeps the
// original casing and is trimmed of surrounding whitespace.
func cutQuestionPrefix(query, prefix string) (string, bool) {
	if len(query) < len(prefix) || !strings.EqualFold(query[:len(prefix)], prefix) {
		return "", false
	}
	rest := query[len(prefix):]
	if rest != "" {
		if next, _ := utf8.DecodeRuneInString(rest); isWordRune(next) {
			return "", false
		}
	}
	return strings.TrimSpace(rest), true
}

// truncateHistoryContent shortens s to at most limit bytes, backing up to the
// nearest rune boundary so a multi-byte character is never split, and appends
// "..." when anything was removed.
func truncateHistoryContent(s string, limit int) string {
	if len(s) <= limit {
		return s
	}
	cut := limit
	for cut > 0 && !utf8.RuneStart(s[cut]) {
		cut--
	}
	return s[:cut] + "..."
}