package computer

import (
	"context"
	"log"
	"strconv"
	"sync"
	"time"

	"github.com/robfig/cron/v3"
)

// Scheduler runs ComputerTasks on recurring or one-shot schedules. Recurring
// schedules are registered with a cron instance; one-shot ("once") schedules
// use a timer. A background poller additionally executes enabled tasks whose
// NextRunAt has already passed.
type Scheduler struct {
	taskRepo TaskRepository
	computer *Computer
	cron     *cron.Cron
	// parser is the same parser the cron instance uses, kept so that
	// calculateNextRun agrees with the times the cron jobs actually fire.
	parser cron.Parser
	// jobs maps task ID to its registered cron entry.
	jobs map[string]cron.EntryID
	// timers maps task ID to the pending timer of a "once" schedule.
	timers map[string]*time.Timer
	// running marks task IDs that are currently executing.
	running map[string]bool
	// mu guards jobs, timers and running.
	mu       sync.RWMutex
	stopCh   chan struct{}
	stopOnce sync.Once
}

// NewScheduler builds a Scheduler whose cron specs carry a leading seconds
// field and may use descriptors such as "@every 90s". Call Start to begin
// scheduling.
func NewScheduler(taskRepo TaskRepository, computer *Computer) *Scheduler {
	// This is the parser cron.WithSeconds() installs; it is built explicitly
	// so the scheduler can reuse it to compute next-run times.
	parser := cron.NewParser(
		cron.Second | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor,
	)
	return &Scheduler{
		taskRepo: taskRepo,
		computer: computer,
		cron:     cron.New(cron.WithParser(parser)),
		parser:   parser,
		jobs:     make(map[string]cron.EntryID),
		timers:   make(map[string]*time.Timer),
		running:  make(map[string]bool),
		stopCh:   make(chan struct{}),
	}
}

// Start launches the cron runner and the background poller. The poller exits
// when ctx is cancelled or Stop is called.
func (s *Scheduler) Start(ctx context.Context) {
	s.cron.Start()
	go s.pollScheduledTasks(ctx)
	log.Println("[Scheduler] Started")
}

// Stop halts the poller, cancels pending one-shot timers and stops the cron
// runner. It is safe to call more than once.
func (s *Scheduler) Stop() {
	// Closing an already-closed channel panics, so guard the close.
	s.stopOnce.Do(func() { close(s.stopCh) })

	s.mu.Lock()
	for taskID, timer := range s.timers {
		timer.Stop()
		delete(s.timers, taskID)
	}
	s.mu.Unlock()

	s.cron.Stop()
	log.Println("[Scheduler] Stopped")
}

// pollScheduledTasks registers the persisted schedules once, then checks
// every 30 seconds for overdue tasks until ctx is done or Stop is called.
func (s *Scheduler) pollScheduledTasks(ctx context.Context) {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	s.loadScheduledTasks(ctx)

	for {
		select {
		case <-ctx.Done():
			return
		case <-s.stopCh:
			return
		case <-ticker.C:
			s.checkAndExecute(ctx)
		}
	}
}

// loadScheduledTasks fetches the scheduled tasks from the repository and
// registers every one that has an enabled schedule.
func (s *Scheduler) loadScheduledTasks(ctx context.Context) {
	tasks, err := s.taskRepo.GetScheduled(ctx)
	if err != nil {
		log.Printf("[Scheduler] Failed to load scheduled tasks: %v", err)
		return
	}

	// Index the slice instead of taking the address of the range variable,
	// which is shared between iterations on toolchains older than Go 1.22.
	for i := range tasks {
		task := &tasks[i]
		if task.Schedule != nil && task.Schedule.Enabled {
			// scheduleTask logs its own failures; one bad task must not
			// prevent the remaining tasks from being registered.
			_ = s.scheduleTask(task)
		}
	}

	log.Printf("[Scheduler] Loaded %d scheduled tasks", len(tasks))
}

// scheduleTask replaces any existing registration for task with one that
// matches its current schedule. It returns nil without registering anything
// when the schedule is missing, disabled or has no usable spec, and returns
// the cron error when the spec cannot be parsed.
func (s *Scheduler) scheduleTask(task *ComputerTask) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Capture the ID by value so the job closures do not retain the caller's
	// task pointer.
	taskID := task.ID
	s.removeLocked(taskID)

	if task.Schedule == nil || !task.Schedule.Enabled {
		return nil
	}

	if task.Schedule.Type == "once" {
		// A non-positive delay makes the timer fire immediately, so an
		// overdue one-shot task runs right away.
		var timer *time.Timer
		timer = time.AfterFunc(time.Until(task.Schedule.NextRun), func() {
			s.mu.Lock()
			// Only drop the entry if it is still this timer; the task may
			// have been rescheduled in the meantime.
			if s.timers[taskID] == timer {
				delete(s.timers, taskID)
			}
			s.mu.Unlock()
			s.executeScheduledTask(taskID)
		})
		s.timers[taskID] = timer
		return nil
	}

	spec := s.cronSpec(task.Schedule)
	if spec == "" {
		return nil
	}

	entryID, err := s.cron.AddFunc(spec, func() {
		s.executeScheduledTask(taskID)
	})
	if err != nil {
		log.Printf("[Scheduler] Failed to schedule task %s: %v", taskID, err)
		return err
	}

	s.jobs[taskID] = entryID
	log.Printf("[Scheduler] Scheduled task %s with type %s", taskID, task.Schedule.Type)
	return nil
}

// removeLocked drops the cron entry and pending one-shot timer registered
// for taskID, if any. The caller must hold s.mu for writing.
func (s *Scheduler) removeLocked(taskID string) {
	if entryID, ok := s.jobs[taskID]; ok {
		s.cron.Remove(entryID)
		delete(s.jobs, taskID)
	}
	if timer, ok := s.timers[taskID]; ok {
		timer.Stop()
		delete(s.timers, taskID)
	}
}

// cronSpec returns the cron spec (with a leading seconds field) used for a
// recurring schedule, or "" when the schedule type is not cron-driven or
// lacks the data needed to build a spec.
func (s *Scheduler) cronSpec(schedule *Schedule) string {
	switch schedule.Type {
	case "cron":
		return schedule.CronExpr
	case "interval":
		if schedule.Interval <= 0 {
			return ""
		}
		return s.intervalToCron(schedule.Interval)
	case "hourly":
		return "0 0 * * * *"
	case "daily":
		return "0 0 9 * * *"
	case "weekly":
		return "0 0 9 * * 1"
	case "monthly":
		return "0 0 9 1 * *"
	default:
		return ""
	}
}

// intervalToCron converts an interval in seconds into a cron spec. It emits
// an "@every" descriptor so the interval is honoured exactly instead of being
// rounded to a whole number of minutes or hours. Values below one second are
// raised to one second.
func (s *Scheduler) intervalToCron(seconds int) string {
	if seconds < 1 {
		seconds = 1
	}
	return "@every " + strconv.Itoa(seconds) + "s"
}

// itoa formats i in base 10. It is retained for any other callers in the
// package.
func itoa(i int) string {
	return strconv.Itoa(i)
}

// executeScheduledTask runs the task identified by taskID once and persists
// its updated run count and next-run time. It skips the run when the task is
// already executing or its schedule is disabled, and cancels the task instead
// of running it when the schedule has expired or reached its maximum runs.
func (s *Scheduler) executeScheduledTask(taskID string) {
	s.mu.Lock()
	if s.running[taskID] {
		s.mu.Unlock()
		log.Printf("[Scheduler] Task %s is already running, skipping", taskID)
		return
	}
	s.running[taskID] = true
	s.mu.Unlock()

	defer func() {
		s.mu.Lock()
		delete(s.running, taskID)
		s.mu.Unlock()
	}()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
	defer cancel()

	task, err := s.taskRepo.GetByID(ctx, taskID)
	if err != nil {
		log.Printf("[Scheduler] Failed to get task %s: %v", taskID, err)
		return
	}

	if task.Schedule != nil {
		// The cron job, the one-shot timer and the poller can all trigger a
		// run; re-check the freshly loaded state so a task that was paused or
		// cancelled in the meantime does not execute.
		if !task.Schedule.Enabled {
			log.Printf("[Scheduler] Task %s schedule is disabled, skipping", taskID)
			return
		}

		if task.Schedule.ExpiresAt != nil && time.Now().After(*task.Schedule.ExpiresAt) {
			log.Printf("[Scheduler] Task %s has expired, removing", taskID)
			if err := s.Cancel(taskID); err != nil {
				log.Printf("[Scheduler] Failed to cancel task %s: %v", taskID, err)
			}
			return
		}

		if task.Schedule.MaxRuns > 0 && task.Schedule.RunCount >= task.Schedule.MaxRuns {
			log.Printf("[Scheduler] Task %s reached max runs (%d), removing", taskID, task.Schedule.MaxRuns)
			if err := s.Cancel(taskID); err != nil {
				log.Printf("[Scheduler] Failed to cancel task %s: %v", taskID, err)
			}
			return
		}
	}

	log.Printf("[Scheduler] Executing scheduled task %s (run #%d)", taskID, task.RunCount+1)

	_, err = s.computer.Execute(ctx, task.UserID, task.Query, ExecuteOptions{
		Async:   false,
		Context: task.Memory,
	})
	if err != nil {
		log.Printf("[Scheduler] Task %s execution failed: %v", taskID, err)
	} else {
		log.Printf("[Scheduler] Task %s completed successfully", taskID)
	}

	task.RunCount++
	if task.Schedule != nil {
		task.Schedule.RunCount = task.RunCount
		if task.Schedule.Type == "once" {
			// A one-shot schedule is finished after its single run; leaving
			// it enabled with a next-run time would make the poller run it
			// again.
			task.Schedule.Enabled = false
			task.NextRunAt = nil
		} else {
			task.Schedule.NextRun = s.calculateNextRun(task.Schedule)
			task.NextRunAt = &task.Schedule.NextRun
		}
	}
	task.UpdatedAt = time.Now()

	// Persist with a fresh context: if the run consumed the execution
	// deadline, reusing ctx would fail the update and lose the run count.
	persistCtx, cancelPersist := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancelPersist()
	if err := s.taskRepo.Update(persistCtx, task); err != nil {
		log.Printf("[Scheduler] Failed to update task %s: %v", taskID, err)
	}
}

// calculateNextRun returns the next time schedule is due. Recurring schedules
// are evaluated with the same spec and parser the cron job uses. A "once"
// schedule keeps its requested NextRun when one is set. When no time can be
// derived the result defaults to one hour from now.
func (s *Scheduler) calculateNextRun(schedule *Schedule) time.Time {
	now := time.Now()

	if schedule.Type == "once" && !schedule.NextRun.IsZero() {
		return schedule.NextRun
	}

	if spec := s.cronSpec(schedule); spec != "" {
		if parsed, err := s.parser.Parse(spec); err == nil {
			// Next reports the zero time when the spec never matches.
			if next := parsed.Next(now); !next.IsZero() {
				return next
			}
		}
	}

	return now.Add(time.Hour)
}

// checkAndExecute starts a run for every enabled task whose NextRunAt lies in
// the past. Each run gets its own goroutine; executeScheduledTask rejects
// tasks that are already running.
func (s *Scheduler) checkAndExecute(ctx context.Context) {
	tasks, err := s.taskRepo.GetScheduled(ctx)
	if err != nil {
		log.Printf("[Scheduler] Failed to poll scheduled tasks: %v", err)
		return
	}

	now := time.Now()
	for i := range tasks {
		task := &tasks[i]
		if task.NextRunAt == nil || !task.NextRunAt.Before(now) {
			continue
		}
		if task.Schedule != nil && task.Schedule.Enabled {
			go s.executeScheduledTask(task.ID)
		}
	}
}

// Schedule attaches schedule to the task identified by taskID, enables it,
// persists the task with status StatusScheduled and registers it with the
// scheduler. It returns any repository or registration error.
func (s *Scheduler) Schedule(taskID string, schedule Schedule) error {
	ctx := context.Background()

	task, err := s.taskRepo.GetByID(ctx, taskID)
	if err != nil {
		return err
	}

	task.Schedule = &schedule
	task.Schedule.Enabled = true
	task.Schedule.NextRun = s.calculateNextRun(task.Schedule)
	task.NextRunAt = &task.Schedule.NextRun
	task.Status = StatusScheduled
	task.UpdatedAt = time.Now()

	if err := s.taskRepo.Update(ctx, task); err != nil {
		return err
	}

	return s.scheduleTask(task)
}

// Cancel unregisters the task, disables its schedule and persists it with
// status StatusCancelled. The task is unregistered even when the repository
// calls fail.
func (s *Scheduler) Cancel(taskID string) error {
	// Hold the lock only for the in-memory bookkeeping, not across
	// repository I/O.
	s.mu.Lock()
	s.removeLocked(taskID)
	s.mu.Unlock()

	ctx := context.Background()
	task, err := s.taskRepo.GetByID(ctx, taskID)
	if err != nil {
		return err
	}

	if task.Schedule != nil {
		task.Schedule.Enabled = false
	}
	task.Status = StatusCancelled
	task.UpdatedAt = time.Now()

	return s.taskRepo.Update(ctx, task)
}

// Pause unregisters the task and persists its schedule as disabled, leaving
// the task status unchanged. Use Resume to re-enable it.
func (s *Scheduler) Pause(taskID string) error {
	// Hold the lock only for the in-memory bookkeeping, not across
	// repository I/O.
	s.mu.Lock()
	s.removeLocked(taskID)
	s.mu.Unlock()

	ctx := context.Background()
	task, err := s.taskRepo.GetByID(ctx, taskID)
	if err != nil {
		return err
	}

	if task.Schedule != nil {
		task.Schedule.Enabled = false
	}
	task.UpdatedAt = time.Now()

	return s.taskRepo.Update(ctx, task)
}

// Resume re-enables the task's schedule, recomputes its next run, persists
// the task with status StatusScheduled and registers it again.
func (s *Scheduler) Resume(taskID string) error {
	ctx := context.Background()

	task, err := s.taskRepo.GetByID(ctx, taskID)
	if err != nil {
		return err
	}

	if task.Schedule != nil {
		task.Schedule.Enabled = true
		task.Schedule.NextRun = s.calculateNextRun(task.Schedule)
		task.NextRunAt = &task.Schedule.NextRun
	}
	task.Status = StatusScheduled
	task.UpdatedAt = time.Now()

	if err := s.taskRepo.Update(ctx, task); err != nil {
		return err
	}

	return s.scheduleTask(task)
}

// GetScheduledTasks returns the IDs of the tasks that currently have a cron
// entry registered, in no particular order. Pending one-shot timers are not
// included.
func (s *Scheduler) GetScheduledTasks() []string {
	s.mu.RLock()
	defer s.mu.RUnlock()

	result := make([]string, 0, len(s.jobs))
	for taskID := range s.jobs {
		result = append(result, taskID)
	}
	return result
}

// IsRunning reports whether the task identified by taskID is executing right
// now.
func (s *Scheduler) IsRunning(taskID string) bool {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.running[taskID]
}