package llm

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"sort"
	"strings"

	"github.com/sashabaranov/go-openai"
)

// OpenAIClient is an LLM client backed by the go-openai chat completions API.
type OpenAIClient struct {
	baseClient
	client *openai.Client
}

// NewOpenAIClient builds an OpenAIClient from cfg. cfg.APIKey authenticates
// requests and a non-empty cfg.BaseURL replaces go-openai's default endpoint.
// The returned error is currently always nil.
func NewOpenAIClient(cfg ProviderConfig) (*OpenAIClient, error) {
	config := openai.DefaultConfig(cfg.APIKey)
	if cfg.BaseURL != "" {
		config.BaseURL = cfg.BaseURL
	}
	return &OpenAIClient{
		baseClient: baseClient{
			providerID: cfg.ProviderID,
			modelKey:   cfg.ModelKey,
		},
		client: openai.NewClientWithConfig(config),
	}, nil
}

// StreamText starts a streaming chat completion for req and returns a channel
// of chunks. Content deltas are forwarded as they arrive, a FinishReason chunk
// is sent whenever the model reports one, and tool calls are reassembled from
// their streamed fragments and delivered together in a single ToolCallChunk
// once the stream reaches EOF. The channel is closed when the stream ends,
// fails, or ctx is cancelled.
//
// An error is returned only if the request cannot be built (e.g. tool-call
// arguments that cannot be JSON-encoded) or the stream cannot be opened.
func (c *OpenAIClient) StreamText(ctx context.Context, req StreamRequest) (<-chan StreamChunk, error) {
	chatReq, err := buildOpenAIChatRequest(c.modelKey, req)
	if err != nil {
		return nil, err
	}

	stream, err := c.client.CreateChatCompletionStream(ctx, chatReq)
	if err != nil {
		return nil, err
	}

	ch := make(chan StreamChunk, 100)
	go pumpOpenAIStream(ctx, stream, ch)
	return ch, nil
}

// GenerateText runs StreamText and collapses the resulting chunks into a
// single string via readAllChunks.
func (c *OpenAIClient) GenerateText(ctx context.Context, req StreamRequest) (string, error) {
	ch, err := c.StreamText(ctx, req)
	if err != nil {
		return "", err
	}
	return readAllChunks(ch), nil
}

// buildOpenAIChatRequest translates req into a streaming go-openai chat
// completion request for model.
func buildOpenAIChatRequest(model string, req StreamRequest) (openai.ChatCompletionRequest, error) {
	messages := make([]openai.ChatCompletionMessage, 0, len(req.Messages))
	for _, m := range req.Messages {
		// Empty Name/ToolCallID are the zero value, so copying them
		// unconditionally is equivalent to guarding each with != "".
		msg := openai.ChatCompletionMessage{
			Role:       string(m.Role),
			Content:    m.Content,
			Name:       m.Name,
			ToolCallID: m.ToolCallID,
		}
		if len(m.ToolCalls) > 0 {
			msg.ToolCalls = make([]openai.ToolCall, 0, len(m.ToolCalls))
			for _, tc := range m.ToolCalls {
				// The API expects arguments as a JSON-encoded string.
				args, err := json.Marshal(tc.Arguments)
				if err != nil {
					return openai.ChatCompletionRequest{}, fmt.Errorf("encoding arguments of tool call %q: %w", tc.ID, err)
				}
				msg.ToolCalls = append(msg.ToolCalls, openai.ToolCall{
					ID:   tc.ID,
					Type: openai.ToolTypeFunction,
					Function: openai.FunctionCall{
						Name:      tc.Name,
						Arguments: string(args),
					},
				})
			}
		}
		messages = append(messages, msg)
	}

	chatReq := openai.ChatCompletionRequest{
		Model:    model,
		Messages: messages,
		Stream:   true,
	}
	// Only positive values are forwarded; zero means "use the server default".
	if req.Options.MaxTokens > 0 {
		chatReq.MaxTokens = req.Options.MaxTokens
	}
	if req.Options.Temperature > 0 {
		chatReq.Temperature = float32(req.Options.Temperature)
	}
	if req.Options.TopP > 0 {
		chatReq.TopP = float32(req.Options.TopP)
	}
	if len(req.Tools) > 0 {
		chatReq.Tools = make([]openai.Tool, len(req.Tools))
		for i, t := range req.Tools {
			chatReq.Tools[i] = openai.Tool{
				Type: openai.ToolTypeFunction,
				Function: &openai.FunctionDefinition{
					Name:        t.Name,
					Description: t.Description,
					Parameters:  t.Schema,
				},
			}
		}
	}
	return chatReq, nil
}

// pumpOpenAIStream forwards events from stream into ch until the stream hits
// EOF, returns an error, or ctx is cancelled; it then closes stream and ch.
func pumpOpenAIStream(ctx context.Context, stream *openai.ChatCompletionStream, ch chan<- StreamChunk) {
	defer close(ch)
	defer stream.Close()

	// send delivers chunk unless ctx is cancelled first, so a consumer that
	// stops reading (and cancels ctx) cannot strand this goroutine and the
	// underlying HTTP stream on a full channel.
	send := func(chunk StreamChunk) bool {
		select {
		case ch <- chunk:
			return true
		case <-ctx.Done():
			return false
		}
	}

	toolCalls := newOpenAIToolCallAccumulator()
	for {
		response, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			if calls := toolCalls.calls(); len(calls) > 0 {
				send(StreamChunk{ToolCallChunk: calls})
			}
			return
		}
		if err != nil {
			// NOTE(review): StreamChunk, as used in this file, has no error
			// field, so a mid-stream failure looks like a normal end to the
			// consumer. Surfacing it requires extending StreamChunk.
			return
		}
		if len(response.Choices) == 0 {
			continue
		}

		choice := response.Choices[0]
		if choice.Delta.Content != "" && !send(StreamChunk{ContentChunk: choice.Delta.Content}) {
			return
		}
		for _, tc := range choice.Delta.ToolCalls {
			toolCalls.add(tc)
		}
		if choice.FinishReason != "" && !send(StreamChunk{FinishReason: string(choice.FinishReason)}) {
			return
		}
	}
}

// openAIToolCallFragment holds the pieces of one tool call seen so far.
type openAIToolCallFragment struct {
	id   string
	name string
	args strings.Builder // raw JSON argument text, concatenated across deltas
}

// openAIToolCallAccumulator reassembles tool calls whose ID, name and
// argument text arrive split across many streaming deltas.
type openAIToolCallAccumulator struct {
	byIndex map[int]*openAIToolCallFragment
	byID    map[string]int
	last    int // index most recently written to
	next    int // one past the largest index seen
}

// newOpenAIToolCallAccumulator returns an empty, ready-to-use accumulator.
func newOpenAIToolCallAccumulator() *openAIToolCallAccumulator {
	return &openAIToolCallAccumulator{
		byIndex: make(map[int]*openAIToolCallFragment),
		byID:    make(map[string]int),
	}
}

// resolveIndex picks the slot a delta belongs to. The stream index is used
// when present; otherwise the delta is matched by ID, a new ID opens a new
// slot, and an ID-less delta continues the most recently touched call.
func (a *openAIToolCallAccumulator) resolveIndex(tc openai.ToolCall) int {
	if tc.Index != nil {
		return *tc.Index
	}
	if tc.ID != "" {
		if idx, ok := a.byID[tc.ID]; ok {
			return idx
		}
		return a.next
	}
	return a.last
}

// add merges one streamed tool-call delta into its slot.
func (a *openAIToolCallAccumulator) add(tc openai.ToolCall) {
	idx := a.resolveIndex(tc)
	frag, ok := a.byIndex[idx]
	if !ok {
		frag = &openAIToolCallFragment{}
		a.byIndex[idx] = frag
	}
	// ID and name normally arrive once, on the first delta for a call.
	if frag.id == "" && tc.ID != "" {
		frag.id = tc.ID
		a.byID[tc.ID] = idx
	}
	if frag.name == "" {
		frag.name = tc.Function.Name
	}
	// Argument text is a JSON document split at arbitrary byte positions, so
	// it can only be decoded after all fragments have been concatenated.
	frag.args.WriteString(tc.Function.Arguments)

	a.last = idx
	if idx >= a.next {
		a.next = idx + 1
	}
}

// calls returns the assembled tool calls ordered by stream index, or nil if
// none were seen. Arguments that are empty or not a JSON object decode to an
// empty (non-nil) map.
func (a *openAIToolCallAccumulator) calls() []ToolCall {
	if len(a.byIndex) == 0 {
		return nil
	}
	indices := make([]int, 0, len(a.byIndex))
	for idx := range a.byIndex {
		indices = append(indices, idx)
	}
	sort.Ints(indices)

	calls := make([]ToolCall, 0, len(indices))
	for _, idx := range indices {
		frag := a.byIndex[idx]
		args := make(map[string]interface{})
		if raw := frag.args.String(); raw != "" {
			// Decode into a fresh map: "null" would otherwise nil out args.
			// Malformed arguments are tolerated (left empty), as before.
			var parsed map[string]interface{}
			if err := json.Unmarshal([]byte(raw), &parsed); err == nil && parsed != nil {
				args = parsed
			}
		}
		calls = append(calls, ToolCall{
			ID:        frag.id,
			Name:      frag.name,
			Arguments: args,
		})
	}
	return calls
}