package llm

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
)

// OllamaClient implements chat-based text generation against an Ollama
// server's /api/chat endpoint, in both streaming and blocking modes.
type OllamaClient struct {
	baseClient
	httpClient *http.Client
	baseURL    string
}

// OllamaConfig configures NewOllamaClient. Empty fields fall back to
// defaults: BaseURL -> "http://ollama:11434", ModelKey -> "llama3.2".
type OllamaConfig struct {
	BaseURL  string
	ModelKey string
}

// NewOllamaClient builds an OllamaClient from cfg, applying defaults for
// any empty fields. The error result is always nil today; it is kept so
// the constructor signature matches the other provider constructors.
func NewOllamaClient(cfg OllamaConfig) (*OllamaClient, error) {
	baseURL := cfg.BaseURL
	if baseURL == "" {
		baseURL = "http://ollama:11434"
	}
	modelKey := cfg.ModelKey
	if modelKey == "" {
		modelKey = "llama3.2"
	}
	return &OllamaClient{
		baseClient: baseClient{
			providerID: "ollama",
			modelKey:   modelKey,
		},
		httpClient: &http.Client{
			// Generous timeout: local model inference can take minutes.
			Timeout: 300 * time.Second,
		},
		baseURL: baseURL,
	}, nil
}

// ollamaChatRequest is the wire format for POST /api/chat.
type ollamaChatRequest struct {
	Model    string          `json:"model"`
	Messages []ollamaMessage `json:"messages"`
	Stream   bool            `json:"stream"`
	Options  *ollamaOptions  `json:"options,omitempty"`
}

// ollamaMessage is a single chat turn (role + content).
type ollamaMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// ollamaOptions carries optional sampling parameters; zero-valued fields
// are omitted from the JSON payload.
type ollamaOptions struct {
	Temperature float64  `json:"temperature,omitempty"`
	NumPredict  int      `json:"num_predict,omitempty"`
	TopP        float64  `json:"top_p,omitempty"`
	Stop        []string `json:"stop,omitempty"`
}

// ollamaChatResponse is one /api/chat response object: one JSON object
// per line when streaming, a single object otherwise.
type ollamaChatResponse struct {
	Model     string        `json:"model"`
	CreatedAt string        `json:"created_at"`
	Message   ollamaMessage `json:"message"`
	Done      bool          `json:"done"`
}

// buildOllamaMessages converts the provider-agnostic request messages to
// the Ollama wire format.
func buildOllamaMessages(req StreamRequest) []ollamaMessage {
	messages := make([]ollamaMessage, 0, len(req.Messages))
	for _, m := range req.Messages {
		messages = append(messages, ollamaMessage{
			Role:    string(m.Role),
			Content: m.Content,
		})
	}
	return messages
}

// buildOllamaOptions maps request options onto Ollama's options object.
// It returns nil when no option is set so the field is omitted entirely.
// Shared by StreamText and GenerateText; this fixes GenerateText, which
// previously dropped StopWords.
func buildOllamaOptions(req StreamRequest) *ollamaOptions {
	o := req.Options
	if o.MaxTokens <= 0 && o.Temperature <= 0 && o.TopP <= 0 && len(o.StopWords) == 0 {
		return nil
	}
	opts := &ollamaOptions{}
	if o.MaxTokens > 0 {
		opts.NumPredict = o.MaxTokens
	}
	if o.Temperature > 0 {
		opts.Temperature = o.Temperature
	}
	if o.TopP > 0 {
		opts.TopP = o.TopP
	}
	if len(o.StopWords) > 0 {
		opts.Stop = o.StopWords
	}
	return opts
}

// doChat marshals chatReq and POSTs it to /api/chat, returning the raw
// HTTP response. On success the caller owns resp.Body and must close it.
func (c *OllamaClient) doChat(ctx context.Context, chatReq ollamaChatRequest) (*http.Response, error) {
	body, err := json.Marshal(chatReq)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}
	url := fmt.Sprintf("%s/api/chat", c.baseURL)
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}
	return resp, nil
}

// StreamText sends a streaming chat request and returns a channel of
// content chunks. The channel is closed when the stream ends, the read
// fails mid-stream, or ctx is canceled. A terminal chunk with
// FinishReason "stop" is emitted when the server reports completion.
func (c *OllamaClient) StreamText(ctx context.Context, req StreamRequest) (<-chan StreamChunk, error) {
	chatReq := ollamaChatRequest{
		Model:    c.modelKey,
		Messages: buildOllamaMessages(req),
		Stream:   true,
		Options:  buildOllamaOptions(req),
	}
	resp, err := c.doChat(ctx, chatReq)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode != http.StatusOK {
		defer resp.Body.Close()
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("Ollama API error: status %d, body: %s", resp.StatusCode, string(body))
	}

	ch := make(chan StreamChunk, 100)
	go func() {
		defer close(ch)
		defer resp.Body.Close()

		// emit delivers a chunk unless ctx is canceled. A plain send here
		// would leak this goroutine if the consumer stops reading.
		emit := func(chunk StreamChunk) bool {
			select {
			case ch <- chunk:
				return true
			case <-ctx.Done():
				return false
			}
		}

		reader := bufio.NewReader(resp.Body)
		for {
			line, readErr := reader.ReadString('\n')
			// Process the payload before acting on the read error so a
			// final line with no trailing newline (readErr == io.EOF) is
			// not silently dropped.
			if trimmed := strings.TrimSpace(line); trimmed != "" {
				var streamResp ollamaChatResponse
				// Malformed lines are skipped, matching prior behavior.
				if err := json.Unmarshal([]byte(trimmed), &streamResp); err == nil {
					if streamResp.Message.Content != "" {
						if !emit(StreamChunk{ContentChunk: streamResp.Message.Content}) {
							return
						}
					}
					if streamResp.Done {
						emit(StreamChunk{FinishReason: "stop"})
						return
					}
				}
			}
			if readErr != nil {
				return
			}
		}
	}()
	return ch, nil
}

// GenerateText sends a blocking (non-streaming) chat request and returns
// the full completion text. It returns an error on non-200 status, on a
// malformed response body, or when the model produces no content.
func (c *OllamaClient) GenerateText(ctx context.Context, req StreamRequest) (string, error) {
	chatReq := ollamaChatRequest{
		Model:    c.modelKey,
		Messages: buildOllamaMessages(req),
		Stream:   false,
		Options:  buildOllamaOptions(req),
	}
	resp, err := c.doChat(ctx, chatReq)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return "", fmt.Errorf("Ollama API error: status %d, body: %s", resp.StatusCode, string(body))
	}
	var chatResp ollamaChatResponse
	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
		return "", fmt.Errorf("failed to decode response: %w", err)
	}
	if chatResp.Message.Content == "" {
		return "", errors.New("empty response from Ollama")
	}
	return chatResp.Message.Content, nil
}