package computer

import (
	"errors"
	"fmt"
	"sort"

	"github.com/gooseek/backend/internal/llm"
)

// RoutingRule describes how a Router picks a model for one TaskType.
//
// Preferred capabilities are resolved, in order, through the model registry;
// Fallback lists concrete model IDs tried after them. MaxCost and MaxLatency
// are stored with the rule but are not consulted by the routing code in this
// file.
type RoutingRule struct {
	TaskType TaskType
	// Preferred capabilities, tried in order.
	Preferred []llm.ModelCapability
	// Fallback model IDs, looked up with registry.GetByID, tried in order.
	Fallback []string
	MaxCost  float64
	// NOTE(review): unit of MaxLatency is not visible here (ms?) — confirm.
	MaxLatency int
}

// Router selects LLM clients for sub-tasks based on per-task-type rules and a
// cost budget. The rules map has no locking, so calling SetRule concurrently
// with the other methods is a data race.
type Router struct {
	registry *llm.ModelRegistry
	rules    map[TaskType]RoutingRule
}

// NewRouter returns a Router backed by registry and preloaded with a default
// RoutingRule for every built-in task type. Rules can be replaced later with
// SetRule.
func NewRouter(registry *llm.ModelRegistry) *Router {
	defaults := []RoutingRule{
		{
			TaskType:  TaskResearch,
			Preferred: []llm.ModelCapability{llm.CapSearch, llm.CapLongContext},
			Fallback:  []string{"gemini-1.5-pro", "gpt-4o"},
			MaxCost:   0.1,
		},
		{
			TaskType:  TaskCode,
			Preferred: []llm.ModelCapability{llm.CapCoding},
			Fallback:  []string{"claude-3-sonnet", "claude-3-opus", "gpt-4o"},
			MaxCost:   0.2,
		},
		{
			TaskType:  TaskAnalysis,
			Preferred: []llm.ModelCapability{llm.CapReasoning, llm.CapMath},
			Fallback:  []string{"claude-3-opus", "gpt-4o"},
			MaxCost:   0.15,
		},
		{
			TaskType:  TaskDesign,
			Preferred: []llm.ModelCapability{llm.CapReasoning, llm.CapCreative},
			Fallback:  []string{"claude-3-opus", "gpt-4o"},
			MaxCost:   0.15,
		},
		{
			TaskType:  TaskDeploy,
			Preferred: []llm.ModelCapability{llm.CapCoding, llm.CapFast},
			Fallback:  []string{"claude-3-sonnet", "gpt-4o-mini"},
			MaxCost:   0.05,
		},
		{
			TaskType:  TaskMonitor,
			Preferred: []llm.ModelCapability{llm.CapFast},
			Fallback:  []string{"gpt-4o-mini", "gemini-1.5-flash"},
			MaxCost:   0.02,
		},
		{
			TaskType:  TaskReport,
			Preferred: []llm.ModelCapability{llm.CapCreative, llm.CapLongContext},
			Fallback:  []string{"claude-3-opus", "gpt-4o"},
			MaxCost:   0.1,
		},
		{
			TaskType:  TaskCommunicate,
			Preferred: []llm.ModelCapability{llm.CapFast, llm.CapCreative},
			Fallback:  []string{"gpt-4o-mini", "gemini-1.5-flash"},
			MaxCost:   0.02,
		},
		{
			TaskType:  TaskTransform,
			Preferred: []llm.ModelCapability{llm.CapFast, llm.CapCoding},
			Fallback:  []string{"gpt-4o-mini", "claude-3-sonnet"},
			MaxCost:   0.03,
		},
		{
			TaskType:  TaskValidate,
			Preferred: []llm.ModelCapability{llm.CapReasoning},
			Fallback:  []string{"gpt-4o", "claude-3-sonnet"},
			MaxCost:   0.05,
		},
	}

	r := &Router{
		registry: registry,
		rules:    make(map[TaskType]RoutingRule, len(defaults)),
	}
	// Key each rule by its own TaskType so the map key and the rule can never disagree.
	for _, rule := range defaults {
		r.rules[rule.TaskType] = rule
	}
	return r
}

// Route picks a single model for task whose per-1K-token price (CostPer1K)
// is at most budget. Note that budget caps the per-1K price, not total spend.
//
// Candidates are tried in this order; the first that resolves without error
// and fits the budget wins:
//  1. task.ModelID, if set;
//  2. the best model for each of task.RequiredCaps;
//  3. the best model for each Preferred capability of the task type's rule;
//  4. each Fallback model ID of that rule;
//  5. every registered model, cheapest first.
//
// If nothing fits, the cheapest registered model is returned anyway, so the
// budget is only a soft limit at this last step. An error is returned when
// task is nil, the registry has no models, or the cheapest model's client
// cannot be obtained.
func (r *Router) Route(task *SubTask, budget float64) (llm.Client, llm.ModelSpec, error) {
	if task == nil {
		return nil, llm.ModelSpec{}, errors.New("nil task")
	}

	if task.ModelID != "" {
		if client, spec, err := r.registry.GetByID(task.ModelID); err == nil && spec.CostPer1K <= budget {
			return client, spec, nil
		}
	}

	for _, c := range task.RequiredCaps {
		if client, spec, err := r.registry.GetBest(c); err == nil && spec.CostPer1K <= budget {
			return client, spec, nil
		}
	}

	if rule, ok := r.rules[task.Type]; ok {
		for _, c := range rule.Preferred {
			if client, spec, err := r.registry.GetBest(c); err == nil && spec.CostPer1K <= budget {
				return client, spec, nil
			}
		}
		for _, id := range rule.Fallback {
			if client, spec, err := r.registry.GetByID(id); err == nil && spec.CostPer1K <= budget {
				return client, spec, nil
			}
		}
	}

	models := r.modelsByCost()
	if len(models) == 0 {
		return nil, llm.ModelSpec{}, errors.New("no models available")
	}

	for _, spec := range models {
		if spec.CostPer1K <= budget {
			if client, err := r.registry.GetClient(spec.ID); err == nil {
				return client, spec, nil
			}
		}
	}

	// Nothing fit the budget: fall back to the cheapest model regardless of cost.
	cheapest := models[0]
	client, err := r.registry.GetClient(cheapest.ID)
	if err != nil {
		return nil, llm.ModelSpec{}, err
	}
	return client, cheapest, nil
}

// RouteMultiple picks up to count distinct models for task, for running the
// same sub-task on several models (consensus).
//
// budget is split evenly: models chosen by capability or by fallback ID must
// have CostPer1K <= budget/count. Selection order:
//  1. models with each Preferred capability of the task type's rule (or
//     reasoning, coding and fast when no rule is registered);
//  2. the rule's Fallback model IDs;
//  3. any remaining registered model, cheapest first. This pass ignores the
//     per-model budget so that count can still be reached.
//
// Fewer than count models are returned if the registry runs out. An error is
// returned when task is nil, count is not positive, or no model could be
// selected at all.
func (r *Router) RouteMultiple(task *SubTask, count int, budget float64) ([]llm.Client, []llm.ModelSpec, error) {
	if task == nil {
		return nil, nil, errors.New("nil task")
	}
	if count <= 0 {
		// Guard before dividing: count <= 0 would yield an Inf/NaN per-model budget.
		return nil, nil, fmt.Errorf("count must be positive, got %d", count)
	}

	perModelBudget := budget / float64(count)

	var (
		clients []llm.Client
		specs   []llm.ModelSpec
	)
	used := make(map[string]bool)

	full := func() bool { return len(clients) >= count }
	take := func(client llm.Client, spec llm.ModelSpec) {
		clients = append(clients, client)
		specs = append(specs, spec)
		used[spec.ID] = true
	}

	rule, ok := r.rules[task.Type]
	if !ok {
		rule = RoutingRule{
			Preferred: []llm.ModelCapability{llm.CapReasoning, llm.CapCoding, llm.CapFast},
		}
	}

	for _, c := range rule.Preferred {
		if full() {
			break
		}
		for _, spec := range r.registry.GetAllWithCapability(c) {
			if full() {
				break
			}
			if used[spec.ID] || spec.CostPer1K > perModelBudget {
				continue
			}
			if client, err := r.registry.GetClient(spec.ID); err == nil {
				take(client, spec)
			}
		}
	}

	// Same fallback IDs Route uses, still honouring the per-model budget.
	for _, id := range rule.Fallback {
		if full() {
			break
		}
		client, spec, err := r.registry.GetByID(id)
		if err != nil || used[spec.ID] || spec.CostPer1K > perModelBudget {
			continue
		}
		take(client, spec)
	}

	if !full() {
		for _, spec := range r.modelsByCost() {
			if full() {
				break
			}
			if used[spec.ID] {
				continue
			}
			if client, err := r.registry.GetClient(spec.ID); err == nil {
				take(client, spec)
			}
		}
	}

	if len(clients) == 0 {
		return nil, nil, errors.New("no models available for consensus")
	}
	return clients, specs, nil
}

// modelsByCost returns a copy of every registered model ordered by CostPer1K
// ascending. Ties are broken by ID so that routing is deterministic. Sorting
// a copy keeps the registry's own slice from being reordered.
func (r *Router) modelsByCost() []llm.ModelSpec {
	all := r.registry.GetAll()
	models := make([]llm.ModelSpec, len(all))
	copy(models, all)
	sort.Slice(models, func(i, j int) bool {
		if models[i].CostPer1K != models[j].CostPer1K {
			return models[i].CostPer1K < models[j].CostPer1K
		}
		return models[i].ID < models[j].ID
	})
	return models
}

// SetRule installs or replaces the routing rule for taskType.
func (r *Router) SetRule(taskType TaskType, rule RoutingRule) {
	r.rules[taskType] = rule
}

// GetRule returns the routing rule for taskType and whether one is registered.
func (r *Router) GetRule(taskType TaskType) (RoutingRule, bool) {
	rule, ok := r.rules[taskType]
	return rule, ok
}

// EstimateCost estimates the cost of running task with the given token
// counts. It prices the tokens at CostPer1K of the model Route would select
// with a per-1K budget of 1.0, and returns 0.01 when routing fails.
func (r *Router) EstimateCost(task *SubTask, inputTokens, outputTokens int) float64 {
	const (
		estimateBudget = 1.0  // per-1K price cap used only to choose a model for the estimate
		fallbackCost   = 0.01 // returned when no model can be routed
	)
	_, spec, err := r.Route(task, estimateBudget)
	if err != nil {
		return fallbackCost
	}
	// Sum in float64 so huge token counts cannot overflow int.
	totalTokens := float64(inputTokens) + float64(outputTokens)
	return spec.CostPer1K * totalTokens / 1000.0
}