191 lines
4.6 KiB
Go
191 lines
4.6 KiB
Go
package upscayl
|
|
|
|
import (
|
|
"archive/zip"
|
|
"bytes"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"trle5.xyz/upscayl-server/configs"
|
|
)
|
|
|
|
// models holds the download URLs of the model files that downloadModels
// fetches into configs.ModelsDir. Each file is stored under the last path
// element of its URL (upscayl-lite-4x.bin and upscayl-lite-4x.param).
var models = []string{
	"https://github.com/upscayl/upscayl/raw/refs/heads/main/resources/models/upscayl-lite-4x.bin",
	"https://github.com/upscayl/upscayl/raw/refs/heads/main/resources/models/upscayl-lite-4x.param",
}
|
|
|
|
func CheckUpscayl() {
|
|
file, err := os.Stat(configs.BinaryName)
|
|
if err != nil && os.IsNotExist(err) || file.Size() == 0 {
|
|
log.Println("Upscayl not found or is empty, downloading...")
|
|
err = downloadUpscayl()
|
|
if err != nil {
|
|
log.Fatalln("failed to download upscayl: ", err)
|
|
}
|
|
file, err = os.Stat(configs.BinaryName)
|
|
if err != nil {
|
|
log.Fatalln("failed to find upscayl: ", err)
|
|
}
|
|
}
|
|
|
|
_, err = os.Stat(configs.ModelsDir)
|
|
if os.IsNotExist(err) {
|
|
log.Println("Downloading models...")
|
|
err = downloadModels()
|
|
if err != nil {
|
|
log.Fatalln("failed to download models: ", err)
|
|
}
|
|
}
|
|
|
|
if file.Mode() != 0755 {
|
|
err = os.Chmod(configs.BinaryName, 0755)
|
|
if err != nil {
|
|
log.Fatalln("failed to make upscayl executable: ", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func getLatestRelease() (*release, error) {
|
|
client := &http.Client{Timeout: time.Second * 10}
|
|
resp, err := client.Get("https://api.github.com/repos/upscayl/upscayl-ncnn/releases/latest")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get upscayl latest release: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("failed to get upscayl latest release data with status: %s", resp.Status)
|
|
}
|
|
|
|
var release release
|
|
|
|
err = json.NewDecoder(resp.Body).Decode(&release)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decode upscayl release response: %w", err)
|
|
}
|
|
|
|
return &release, nil
|
|
}
|
|
|
|
func downloadUpscayl() error {
|
|
release, err := getLatestRelease()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get upscayl latest release: %w", err)
|
|
} else {
|
|
if release.ID == 0 {
|
|
return fmt.Errorf("no release found in upscayl response")
|
|
}
|
|
if release.TagName == "" {
|
|
return fmt.Errorf("no tag name found in upscayl release")
|
|
}
|
|
if len(release.Assets) == 0 {
|
|
return fmt.Errorf("no assets found in upscayl release")
|
|
}
|
|
}
|
|
|
|
var (
|
|
url string
|
|
sha string
|
|
)
|
|
|
|
for _, asset := range release.Assets {
|
|
if strings.HasSuffix(asset.Name, configs.Platform + ".zip") {
|
|
if asset.State == "uploaded" && asset.Size > 0 {
|
|
url = asset.BrowserDownloadURL
|
|
sha = strings.TrimPrefix(asset.Digest, "sha256:")
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if url == "" {
|
|
return fmt.Errorf("no valid upscayl release found for %s", configs.Platform)
|
|
}
|
|
|
|
client := &http.Client{Timeout: time.Second * 30}
|
|
resp, err := client.Get(url)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to download upscayl release zip: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return fmt.Errorf("download failed with status: %s", resp.Status)
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
|
|
_, err = io.Copy(&buf, resp.Body)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to copy upscayl release zip data: %w", err)
|
|
}
|
|
|
|
hash := sha256.Sum256(buf.Bytes())
|
|
|
|
if hex.EncodeToString(hash[:]) != sha {
|
|
return fmt.Errorf("downloaded upscayl release zip checksum does not match expected value")
|
|
}
|
|
|
|
zipReader, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read upscayl binary zip: %w", err)
|
|
}
|
|
|
|
fileInZip, err := zipReader.Open(fmt.Sprintf("%s-%s/%s", release.Name, configs.Platform, configs.BinaryName))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to open upscayl binary from zip: %w", err)
|
|
}
|
|
|
|
file, err := os.OpenFile(configs.BinaryName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0755)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create file: %w", err)
|
|
}
|
|
defer file.Close()
|
|
|
|
_, err = io.Copy(file, fileInZip)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to write file: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func downloadModels() error {
|
|
err := os.MkdirAll(configs.ModelsDir, 0755)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create models directory: %w", err)
|
|
}
|
|
|
|
client := &http.Client{Timeout: time.Second * 30}
|
|
|
|
for _, url := range models {
|
|
resp, err := client.Get(url)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to download upscayl model: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
file, err := os.OpenFile(filepath.Join(configs.ModelsDir, filepath.Base(url)), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0755)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create model file: %w", err)
|
|
}
|
|
defer file.Close()
|
|
|
|
_, err = io.Copy(file, resp.Body)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to write model file: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|