Files
upscayl-server/task/action.go

158 lines
3.8 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package task
import (
"context"
"fmt"
"log"
"time"
"github.com/dustin/go-humanize"
"trle5.xyz/upscayl-server/upscayl"
)
// RunTasks is the worker loop that processes queued tasks one at a time, in
// FIFO order, until ctx is canceled.
//
// The head of taskList.Tasks is marked StatusWorking and stays in the queue
// while it is processed. The upscale itself runs without holding the lock, so
// other goroutines can add, query or cancel tasks meanwhile. When processing
// ends, the task is removed from Tasks and appended to OldTasks with its final
// status (StatusDone or StatusError). When the queue is empty the loop blocks
// on taskList.cond until it is signaled or ctx is canceled.
func RunTasks(ctx context.Context) {
	// Wake the loop when ctx is canceled so it cannot stay blocked forever in
	// cond.Wait() on an empty queue. The lock must be held while broadcasting:
	// without it the broadcast can land after the loop has checked ctx.Err()
	// but before it has started waiting, and the wakeup would be lost.
	go func() {
		<-ctx.Done()
		taskList.rw.Lock()
		taskList.cond.Broadcast()
		taskList.rw.Unlock()
	}()

	for {
		if ctx.Err() != nil {
			log.Println("Context canceled, exiting task loop")
			return
		}

		taskList.rw.Lock()
		// Wait must sit inside a loop to guard against spurious wakeups. It
		// releases the lock while suspended and reacquires it when woken.
		for len(taskList.Tasks) == 0 && ctx.Err() == nil {
			taskList.cond.Wait()
		}
		// We may have been woken because ctx was canceled, not for new work.
		if ctx.Err() != nil {
			taskList.rw.Unlock()
			log.Println("Context canceled, exiting task loop")
			return
		}
		// Take the head of the queue (FIFO). It stays queued, marked as in
		// progress, until processing finishes.
		taskList.Tasks[0].Status = StatusWorking
		task := taskList.Tasks[0]
		taskList.rw.Unlock()

		// Run outside the lock so goroutines adding or querying tasks are not
		// blocked for the duration of the upscale.
		log.Printf("Processing task [%s] use [%s] model", task.ID, task.Params.Model)
		image, err := upscayl.Run(ctx, task.Params)
		if err != nil {
			log.Printf("Error processing task [%s]: %v", task.ID, err)
			task.Status = StatusError
			task.Error = err.Error()
		} else {
			log.Printf("Task [%s] completed in %d seconds, size %s", task.ID, image.Second, humanize.IBytes(uint64(image.Size)))
			task.Elapsed = image.Second
			task.Status = StatusDone
			task.Size = image.Size
			task.Format = image.Format
			task.URL = fmt.Sprintf("/output/%s.%s", task.ID, image.Format)
		}

		taskList.rw.Lock()
		// Remove the finished task by ID instead of blindly dropping the head.
		// The queue may have been modified while the lock was released (for
		// example, the task was canceled). Slicing an empty queue would panic,
		// and dropping the head could discard a different, still-pending task.
		idx := -1
		for i := range taskList.Tasks {
			if taskList.Tasks[i].ID == task.ID {
				idx = i
				break
			}
		}
		if idx >= 0 {
			taskList.Tasks = append(taskList.Tasks[:idx], taskList.Tasks[idx+1:]...)
			taskList.OldTasks = append(taskList.OldTasks, task)
		} else {
			log.Printf("Task [%s] was removed from the queue while processing, discarding result", task.ID)
		}
		taskList.rw.Unlock()
	}
}
// TaskCount reports how many tasks are currently in the queue. This includes
// the task RunTasks is processing, which stays queued until it finishes.
func TaskCount() int {
	taskList.rw.RLock()
	queued := len(taskList.Tasks)
	taskList.rw.RUnlock()
	return queued
}
// Add validates params and registers a new task with the given ID and input hash.
//
// Some finished task in OldTasks may already have completed (StatusDone) for the
// same hash with equivalent params, as judged by upscayl.SameParams. In that
// case the new task is recorded directly as done. It reuses that task's output
// size, format and URL instead of being processed again.
//
// Otherwise the task is appended to the pending queue as StatusPending, and
// taskList.cond is signaled to wake a waiting worker.
//
// It returns a pointer to the new task (a copy, independent of the stored entry).
// If params fail validation, it returns an error wrapping that failure.
func Add(taskID string, hash string, params upscayl.Params) (*Task, error) {
	if err := params.Validate(); err != nil {
		return nil, fmt.Errorf("invalid params: %w", err)
	}
	taskList.rw.Lock()
	defer taskList.rw.Unlock()
	task := Task{
		ID:     taskID,
		Status: StatusPending,
		Submit: time.Now(),
		Hash:   hash,
		Params: params,
	}

	// Track the match explicitly. Using the old task's ID as a "found" flag
	// misfires when that ID is empty: the new task would already be marked
	// done, yet still be pushed onto the pending queue.
	var reused *Task
	for i := range taskList.OldTasks {
		old := &taskList.OldTasks[i]
		if old.Status == StatusDone && old.Hash == task.Hash && upscayl.SameParams(old.Params, task.Params) {
			reused = old
			break
		}
	}

	if reused != nil {
		// Read everything from reused before appending to OldTasks, since the
		// append may reallocate the slice that reused points into.
		task.Status = StatusDone
		task.Size = reused.Size
		task.Format = reused.Format
		task.URL = reused.URL
		log.Printf("Task [%s] has same file and params as old task [%s], using old task", taskID, reused.ID)
		taskList.OldTasks = append(taskList.OldTasks, task)
		return &task, nil
	}

	taskList.Tasks = append(taskList.Tasks, task)
	taskList.cond.Signal()
	return &task, nil
}
// Get looks up a task by ID. It searches the queue (pending and in-progress
// tasks) first, then the history of finished tasks.
//
// The returned pointer refers to a copy of the stored Task, so callers cannot
// mutate the task list through it. If no task has that ID, Get returns an error.
func Get(taskID string) (*Task, error) {
	taskList.rw.RLock()
	defer taskList.rw.RUnlock()

	for _, group := range [][]Task{taskList.Tasks, taskList.OldTasks} {
		for i := range group {
			if group[i].ID != taskID {
				continue
			}
			found := group[i]
			return &found, nil
		}
	}
	return nil, fmt.Errorf("task not found")
}
// Cancel removes a task that is still waiting in the queue.
//
// It returns an error in three cases:
//   - the task is currently being processed (StatusWorking); it cannot be stopped from here;
//   - the task is already in the finished history;
//   - no task with that ID exists.
func Cancel(taskID string) error {
	taskList.rw.Lock()
	defer taskList.rw.Unlock()
	for i, task := range taskList.Tasks {
		if task.ID != taskID {
			continue
		}
		// The worker already holds its own copy of an in-progress task, so
		// removing it here would not stop the work. The worker also expects
		// to remove that entry from the queue itself once processing ends.
		if task.Status == StatusWorking {
			return fmt.Errorf("task [%s] is currently being processed and cannot be canceled", taskID)
		}
		taskList.Tasks = append(taskList.Tasks[:i], taskList.Tasks[i+1:]...)
		return nil
	}
	for _, task := range taskList.OldTasks {
		if task.ID == taskID {
			return fmt.Errorf("task [%s] is already completed with status [%s]", taskID, task.Status)
		}
	}
	return fmt.Errorf("task not found")
}