feat: spaces redesign, model selector, auth fixes
Spaces: - Perplexity-like UI with collaboration features - Space detail page with threads/members tabs - Invite members via email, role management - New space creation with icon/color picker Model selector: - Added Ollama client for free Auto model - GooSeek 1.0 via Timeweb (tariff-based) - Frontend model dropdown in ChatInput Auth & Infrastructure: - Fixed auth-svc missing from Dockerfile.all - Removed duplicate ratelimit_tiered.go (conflict) - Added Redis to api-gateway for rate limiting - Fixed Next.js proxy for local development UI improvements: - Redesigned login button in sidebar (gradient) - Settings page with tabs (account/billing/prefs) - Auth pages visual refresh Made-with: Cursor
This commit is contained in:
@@ -38,6 +38,7 @@ func (r *Repository) RunMigrations(ctx context.Context) error {
|
||||
avatar TEXT,
|
||||
role VARCHAR(50) DEFAULT 'user',
|
||||
tier VARCHAR(50) DEFAULT 'free',
|
||||
balance DECIMAL(12,2) DEFAULT 0,
|
||||
email_verified BOOLEAN DEFAULT FALSE,
|
||||
provider VARCHAR(50) DEFAULT 'local',
|
||||
provider_id VARCHAR(255),
|
||||
@@ -69,6 +70,8 @@ func (r *Repository) RunMigrations(ctx context.Context) error {
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_password_reset_tokens ON password_reset_tokens(token)`,
|
||||
|
||||
`ALTER TABLE auth_users ADD COLUMN IF NOT EXISTS balance DECIMAL(12,2) DEFAULT 0`,
|
||||
}
|
||||
|
||||
for _, m := range migrations {
|
||||
@@ -125,18 +128,19 @@ func (r *Repository) CreateUser(ctx context.Context, email, password, name strin
|
||||
|
||||
func (r *Repository) GetUserByEmail(ctx context.Context, email string) (*User, error) {
|
||||
query := `
|
||||
SELECT id, email, password_hash, name, avatar, role, tier, email_verified,
|
||||
SELECT id, email, password_hash, name, avatar, role, tier, balance, email_verified,
|
||||
provider, provider_id, last_login_at, created_at, updated_at
|
||||
FROM auth_users WHERE email = $1
|
||||
`
|
||||
|
||||
user := &User{}
|
||||
var lastLogin, avatar, providerID sql.NullString
|
||||
var avatar, providerID sql.NullString
|
||||
var lastLoginTime sql.NullTime
|
||||
var balance sql.NullFloat64
|
||||
|
||||
err := r.db.QueryRowContext(ctx, query, email).Scan(
|
||||
&user.ID, &user.Email, &user.PasswordHash, &user.Name, &avatar,
|
||||
&user.Role, &user.Tier, &user.EmailVerified, &user.Provider,
|
||||
&user.Role, &user.Tier, &balance, &user.EmailVerified, &user.Provider,
|
||||
&providerID, &lastLoginTime, &user.CreatedAt, &user.UpdatedAt,
|
||||
)
|
||||
|
||||
@@ -156,14 +160,16 @@ func (r *Repository) GetUserByEmail(ctx context.Context, email string) (*User, e
|
||||
if lastLoginTime.Valid {
|
||||
user.LastLoginAt = lastLoginTime.Time
|
||||
}
|
||||
_ = lastLogin
|
||||
if balance.Valid {
|
||||
user.Balance = balance.Float64
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (r *Repository) GetUserByID(ctx context.Context, id string) (*User, error) {
|
||||
query := `
|
||||
SELECT id, email, password_hash, name, avatar, role, tier, email_verified,
|
||||
SELECT id, email, password_hash, name, avatar, role, tier, balance, email_verified,
|
||||
provider, provider_id, last_login_at, created_at, updated_at
|
||||
FROM auth_users WHERE id = $1
|
||||
`
|
||||
@@ -171,10 +177,11 @@ func (r *Repository) GetUserByID(ctx context.Context, id string) (*User, error)
|
||||
user := &User{}
|
||||
var avatar, providerID sql.NullString
|
||||
var lastLoginTime sql.NullTime
|
||||
var balance sql.NullFloat64
|
||||
|
||||
err := r.db.QueryRowContext(ctx, query, id).Scan(
|
||||
&user.ID, &user.Email, &user.PasswordHash, &user.Name, &avatar,
|
||||
&user.Role, &user.Tier, &user.EmailVerified, &user.Provider,
|
||||
&user.Role, &user.Tier, &balance, &user.EmailVerified, &user.Provider,
|
||||
&providerID, &lastLoginTime, &user.CreatedAt, &user.UpdatedAt,
|
||||
)
|
||||
|
||||
@@ -194,6 +201,9 @@ func (r *Repository) GetUserByID(ctx context.Context, id string) (*User, error)
|
||||
if lastLoginTime.Valid {
|
||||
user.LastLoginAt = lastLoginTime.Time
|
||||
}
|
||||
if balance.Valid {
|
||||
user.Balance = balance.Float64
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
@@ -254,6 +264,22 @@ func (r *Repository) UpdateRole(ctx context.Context, userID string, role UserRol
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) UpdateBalance(ctx context.Context, userID string, amount float64) error {
|
||||
_, err := r.db.ExecContext(ctx,
|
||||
"UPDATE auth_users SET balance = balance + $2, updated_at = NOW() WHERE id = $1",
|
||||
userID, amount,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) SetBalance(ctx context.Context, userID string, balance float64) error {
|
||||
_, err := r.db.ExecContext(ctx,
|
||||
"UPDATE auth_users SET balance = $2, updated_at = NOW() WHERE id = $1",
|
||||
userID, balance,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) CreateRefreshToken(ctx context.Context, userID, userAgent, ip string, duration time.Duration) (*RefreshToken, error) {
|
||||
token := generateSecureToken(32)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ type User struct {
|
||||
Avatar string `json:"avatar,omitempty"`
|
||||
Role string `json:"role"`
|
||||
Tier string `json:"tier"`
|
||||
Balance float64 `json:"balance"`
|
||||
EmailVerified bool `json:"emailVerified"`
|
||||
Provider string `json:"provider"`
|
||||
ProviderID string `json:"providerId,omitempty"`
|
||||
|
||||
@@ -22,6 +22,30 @@ type Space struct {
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
ThreadCount int `json:"threadCount,omitempty"`
|
||||
Members []*SpaceMember `json:"members,omitempty"`
|
||||
MemberCount int `json:"memberCount,omitempty"`
|
||||
}
|
||||
|
||||
type SpaceMember struct {
|
||||
ID string `json:"id"`
|
||||
SpaceID string `json:"spaceId"`
|
||||
UserID string `json:"userId"`
|
||||
Role string `json:"role"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Avatar string `json:"avatar,omitempty"`
|
||||
JoinedAt time.Time `json:"joinedAt"`
|
||||
}
|
||||
|
||||
type SpaceInvite struct {
|
||||
ID string `json:"id"`
|
||||
SpaceID string `json:"spaceId"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
InvitedBy string `json:"invitedBy"`
|
||||
Token string `json:"token"`
|
||||
ExpiresAt time.Time `json:"expiresAt"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
type SpaceRepository struct {
|
||||
@@ -50,6 +74,28 @@ func (r *SpaceRepository) RunMigrations(ctx context.Context) error {
|
||||
updated_at TIMESTAMPTZ DEFAULT NOW()
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_spaces_user ON spaces(user_id)`,
|
||||
`CREATE TABLE IF NOT EXISTS space_members (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
space_id UUID NOT NULL REFERENCES spaces(id) ON DELETE CASCADE,
|
||||
user_id UUID NOT NULL,
|
||||
role VARCHAR(20) DEFAULT 'member',
|
||||
joined_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
UNIQUE(space_id, user_id)
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_space_members_space ON space_members(space_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_space_members_user ON space_members(user_id)`,
|
||||
`CREATE TABLE IF NOT EXISTS space_invites (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
space_id UUID NOT NULL REFERENCES spaces(id) ON DELETE CASCADE,
|
||||
email VARCHAR(255) NOT NULL,
|
||||
role VARCHAR(20) DEFAULT 'member',
|
||||
invited_by UUID NOT NULL,
|
||||
token VARCHAR(64) NOT NULL UNIQUE,
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_space_invites_token ON space_invites(token)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_space_invites_email ON space_invites(email)`,
|
||||
}
|
||||
|
||||
for _, m := range migrations {
|
||||
@@ -175,3 +221,176 @@ func (r *SpaceRepository) Delete(ctx context.Context, id, userID string) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SpaceRepository) GetMembers(ctx context.Context, spaceID string) ([]*SpaceMember, error) {
|
||||
query := `
|
||||
SELECT sm.id, sm.space_id, sm.user_id, sm.role, sm.joined_at,
|
||||
COALESCE(u.email, '') as email,
|
||||
COALESCE(u.name, '') as name,
|
||||
COALESCE(u.avatar, '') as avatar
|
||||
FROM space_members sm
|
||||
LEFT JOIN auth_users u ON sm.user_id = u.id
|
||||
WHERE sm.space_id = $1
|
||||
ORDER BY sm.joined_at ASC
|
||||
`
|
||||
|
||||
rows, err := r.db.db.QueryContext(ctx, query, spaceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var members []*SpaceMember
|
||||
for rows.Next() {
|
||||
var m SpaceMember
|
||||
if err := rows.Scan(&m.ID, &m.SpaceID, &m.UserID, &m.Role, &m.JoinedAt, &m.Email, &m.Name, &m.Avatar); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
members = append(members, &m)
|
||||
}
|
||||
return members, nil
|
||||
}
|
||||
|
||||
func (r *SpaceRepository) AddMember(ctx context.Context, spaceID, userID, role string) error {
|
||||
query := `
|
||||
INSERT INTO space_members (space_id, user_id, role)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (space_id, user_id) DO NOTHING
|
||||
`
|
||||
_, err := r.db.db.ExecContext(ctx, query, spaceID, userID, role)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *SpaceRepository) RemoveMember(ctx context.Context, spaceID, userID string) error {
|
||||
result, err := r.db.db.ExecContext(ctx, "DELETE FROM space_members WHERE space_id = $1 AND user_id = $2", spaceID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SpaceRepository) UpdateMemberRole(ctx context.Context, spaceID, userID, role string) error {
|
||||
result, err := r.db.db.ExecContext(ctx, "UPDATE space_members SET role = $3 WHERE space_id = $1 AND user_id = $2", spaceID, userID, role)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SpaceRepository) IsMember(ctx context.Context, spaceID, userID string) (bool, string, error) {
|
||||
var role string
|
||||
err := r.db.db.QueryRowContext(ctx, "SELECT role FROM space_members WHERE space_id = $1 AND user_id = $2", spaceID, userID).Scan(&role)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, "", nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
return true, role, nil
|
||||
}
|
||||
|
||||
func (r *SpaceRepository) GetByMemberID(ctx context.Context, userID string) ([]*Space, error) {
|
||||
query := `
|
||||
SELECT DISTINCT s.id, s.user_id, s.name, s.description, s.icon, s.color,
|
||||
s.custom_instructions, s.default_focus_mode, s.default_model,
|
||||
s.is_public, s.settings, s.created_at, s.updated_at,
|
||||
(SELECT COUNT(*) FROM threads WHERE space_id = s.id) as thread_count,
|
||||
(SELECT COUNT(*) FROM space_members WHERE space_id = s.id) as member_count
|
||||
FROM spaces s
|
||||
LEFT JOIN space_members sm ON s.id = sm.space_id
|
||||
WHERE s.user_id = $1 OR sm.user_id = $1
|
||||
ORDER BY s.updated_at DESC
|
||||
`
|
||||
|
||||
rows, err := r.db.db.QueryContext(ctx, query, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var spaces []*Space
|
||||
for rows.Next() {
|
||||
var s Space
|
||||
var settingsJSON []byte
|
||||
|
||||
if err := rows.Scan(
|
||||
&s.ID, &s.UserID, &s.Name, &s.Description, &s.Icon, &s.Color,
|
||||
&s.CustomInstructions, &s.DefaultFocusMode, &s.DefaultModel,
|
||||
&s.IsPublic, &settingsJSON, &s.CreatedAt, &s.UpdatedAt, &s.ThreadCount, &s.MemberCount,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
json.Unmarshal(settingsJSON, &s.Settings)
|
||||
spaces = append(spaces, &s)
|
||||
}
|
||||
|
||||
return spaces, nil
|
||||
}
|
||||
|
||||
func (r *SpaceRepository) CreateInvite(ctx context.Context, invite *SpaceInvite) error {
|
||||
query := `
|
||||
INSERT INTO space_invites (space_id, email, role, invited_by, token, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
RETURNING id, created_at
|
||||
`
|
||||
return r.db.db.QueryRowContext(ctx, query,
|
||||
invite.SpaceID, invite.Email, invite.Role, invite.InvitedBy, invite.Token, invite.ExpiresAt,
|
||||
).Scan(&invite.ID, &invite.CreatedAt)
|
||||
}
|
||||
|
||||
func (r *SpaceRepository) GetInviteByToken(ctx context.Context, token string) (*SpaceInvite, error) {
|
||||
query := `
|
||||
SELECT id, space_id, email, role, invited_by, token, expires_at, created_at
|
||||
FROM space_invites
|
||||
WHERE token = $1 AND expires_at > NOW()
|
||||
`
|
||||
var inv SpaceInvite
|
||||
err := r.db.db.QueryRowContext(ctx, query, token).Scan(
|
||||
&inv.ID, &inv.SpaceID, &inv.Email, &inv.Role, &inv.InvitedBy, &inv.Token, &inv.ExpiresAt, &inv.CreatedAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &inv, nil
|
||||
}
|
||||
|
||||
func (r *SpaceRepository) DeleteInvite(ctx context.Context, id string) error {
|
||||
_, err := r.db.db.ExecContext(ctx, "DELETE FROM space_invites WHERE id = $1", id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *SpaceRepository) GetInvitesBySpace(ctx context.Context, spaceID string) ([]*SpaceInvite, error) {
|
||||
query := `
|
||||
SELECT id, space_id, email, role, invited_by, token, expires_at, created_at
|
||||
FROM space_invites
|
||||
WHERE space_id = $1 AND expires_at > NOW()
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
rows, err := r.db.db.QueryContext(ctx, query, spaceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var invites []*SpaceInvite
|
||||
for rows.Next() {
|
||||
var inv SpaceInvite
|
||||
if err := rows.Scan(&inv.ID, &inv.SpaceID, &inv.Email, &inv.Role, &inv.InvitedBy, &inv.Token, &inv.ExpiresAt, &inv.CreatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
invites = append(invites, &inv)
|
||||
}
|
||||
return invites, nil
|
||||
}
|
||||
|
||||
@@ -79,6 +79,11 @@ type ProviderConfig struct {
|
||||
|
||||
func NewClient(cfg ProviderConfig) (Client, error) {
|
||||
switch cfg.ProviderID {
|
||||
case "ollama":
|
||||
return NewOllamaClient(OllamaConfig{
|
||||
BaseURL: cfg.BaseURL,
|
||||
ModelKey: cfg.ModelKey,
|
||||
})
|
||||
case "timeweb":
|
||||
return NewTimewebClient(TimewebConfig{
|
||||
BaseURL: cfg.BaseURL,
|
||||
|
||||
233
backend/internal/llm/ollama.go
Normal file
233
backend/internal/llm/ollama.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type OllamaClient struct {
|
||||
baseClient
|
||||
httpClient *http.Client
|
||||
baseURL string
|
||||
}
|
||||
|
||||
type OllamaConfig struct {
|
||||
BaseURL string
|
||||
ModelKey string
|
||||
}
|
||||
|
||||
func NewOllamaClient(cfg OllamaConfig) (*OllamaClient, error) {
|
||||
baseURL := cfg.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = "http://ollama:11434"
|
||||
}
|
||||
|
||||
modelKey := cfg.ModelKey
|
||||
if modelKey == "" {
|
||||
modelKey = "llama3.2"
|
||||
}
|
||||
|
||||
return &OllamaClient{
|
||||
baseClient: baseClient{
|
||||
providerID: "ollama",
|
||||
modelKey: modelKey,
|
||||
},
|
||||
httpClient: &http.Client{
|
||||
Timeout: 300 * time.Second,
|
||||
},
|
||||
baseURL: baseURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type ollamaChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ollamaMessage `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
Options *ollamaOptions `json:"options,omitempty"`
|
||||
}
|
||||
|
||||
type ollamaMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ollamaOptions struct {
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
NumPredict int `json:"num_predict,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
}
|
||||
|
||||
type ollamaChatResponse struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Message ollamaMessage `json:"message"`
|
||||
Done bool `json:"done"`
|
||||
}
|
||||
|
||||
func (c *OllamaClient) StreamText(ctx context.Context, req StreamRequest) (<-chan StreamChunk, error) {
|
||||
messages := make([]ollamaMessage, 0, len(req.Messages))
|
||||
for _, m := range req.Messages {
|
||||
messages = append(messages, ollamaMessage{
|
||||
Role: string(m.Role),
|
||||
Content: m.Content,
|
||||
})
|
||||
}
|
||||
|
||||
chatReq := ollamaChatRequest{
|
||||
Model: c.modelKey,
|
||||
Messages: messages,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
if req.Options.MaxTokens > 0 || req.Options.Temperature > 0 || req.Options.TopP > 0 || len(req.Options.StopWords) > 0 {
|
||||
chatReq.Options = &ollamaOptions{}
|
||||
if req.Options.MaxTokens > 0 {
|
||||
chatReq.Options.NumPredict = req.Options.MaxTokens
|
||||
}
|
||||
if req.Options.Temperature > 0 {
|
||||
chatReq.Options.Temperature = req.Options.Temperature
|
||||
}
|
||||
if req.Options.TopP > 0 {
|
||||
chatReq.Options.TopP = req.Options.TopP
|
||||
}
|
||||
if len(req.Options.StopWords) > 0 {
|
||||
chatReq.Options.Stop = req.Options.StopWords
|
||||
}
|
||||
}
|
||||
|
||||
body, err := json.Marshal(chatReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/chat", c.baseURL)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("Ollama API error: status %d, body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
ch := make(chan StreamChunk, 100)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
defer resp.Body.Close()
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var streamResp ollamaChatResponse
|
||||
if err := json.Unmarshal([]byte(line), &streamResp); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if streamResp.Message.Content != "" {
|
||||
ch <- StreamChunk{ContentChunk: streamResp.Message.Content}
|
||||
}
|
||||
|
||||
if streamResp.Done {
|
||||
ch <- StreamChunk{FinishReason: "stop"}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (c *OllamaClient) GenerateText(ctx context.Context, req StreamRequest) (string, error) {
|
||||
messages := make([]ollamaMessage, 0, len(req.Messages))
|
||||
for _, m := range req.Messages {
|
||||
messages = append(messages, ollamaMessage{
|
||||
Role: string(m.Role),
|
||||
Content: m.Content,
|
||||
})
|
||||
}
|
||||
|
||||
chatReq := ollamaChatRequest{
|
||||
Model: c.modelKey,
|
||||
Messages: messages,
|
||||
Stream: false,
|
||||
}
|
||||
|
||||
if req.Options.MaxTokens > 0 || req.Options.Temperature > 0 || req.Options.TopP > 0 {
|
||||
chatReq.Options = &ollamaOptions{}
|
||||
if req.Options.MaxTokens > 0 {
|
||||
chatReq.Options.NumPredict = req.Options.MaxTokens
|
||||
}
|
||||
if req.Options.Temperature > 0 {
|
||||
chatReq.Options.Temperature = req.Options.Temperature
|
||||
}
|
||||
if req.Options.TopP > 0 {
|
||||
chatReq.Options.TopP = req.Options.TopP
|
||||
}
|
||||
}
|
||||
|
||||
body, err := json.Marshal(chatReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/chat", c.baseURL)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("Ollama API error: status %d, body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var chatResp ollamaChatResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
|
||||
return "", fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
if chatResp.Message.Content == "" {
|
||||
return "", errors.New("empty response from Ollama")
|
||||
}
|
||||
|
||||
return chatResp.Message.Content, nil
|
||||
}
|
||||
Reference in New Issue
Block a user