Files
trbot/database/yaml_db/yaml.go

244 lines
6.2 KiB
Go

package yaml_db
import (
	"context"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"sync"
	"time"

	"trbot/database/db_struct"
	"trbot/utils"
	"trbot/utils/configs"
	"trbot/utils/task"
	"trbot/utils/yaml"

	"github.com/go-telegram/bot/models"
	"github.com/reugn/go-quartz/job"
	"github.com/reugn/go-quartz/quartz"
	"github.com/rs/zerolog"
)
var (
	// YAMLDatabasePath is the location of the YAML database file, joined from
	// configs.YAMLDatabaseDir and configs.YAMLFileName.
	//
	// NOTE(review): this is evaluated once at package initialization; if the
	// config values are assigned later at runtime, this path will not follow
	// them — confirm how configs is populated.
	YAMLDatabasePath = filepath.Join(configs.YAMLDatabaseDir, configs.YAMLFileName)
)
// Initialize creates the YAML database, loads it from YAMLDatabasePath and
// schedules the "save_yaml_database" task (group "trbot"). That task calls
// AutoSaveDatabaseHandler every 10 minutes.
//
// It fails when configs.YAMLDatabaseDir is empty, when the database file
// cannot be read, or when the auto-save task cannot be scheduled.
func Initialize(ctx context.Context) (*DataBaseYaml, error) {
	if configs.YAMLDatabaseDir == "" {
		return nil, fmt.Errorf("yaml database path is empty")
	}

	db := new(DataBaseYaml)
	if err := db.ReadDatabase(ctx); err != nil {
		return nil, fmt.Errorf("failed to read yaml database: %w", err)
	}

	// The job closure shares the same *DataBaseYaml that is returned below.
	autoSaveJob := job.NewFunctionJobWithDesc(
		func(jobCtx context.Context) (int, error) {
			db.AutoSaveDatabaseHandler(jobCtx)
			return 0, nil
		},
		"Save yaml database every 10 minutes",
	)

	if err := task.ScheduleTask(ctx, task.Task{
		Name:    "save_yaml_database",
		Group:   "trbot",
		Job:     autoSaveJob,
		Trigger: quartz.NewSimpleTrigger(10 * time.Minute),
	}); err != nil {
		return nil, fmt.Errorf("failed to add auto save database task: %w", err)
	}

	return db, nil
}
// DataBaseYaml is an in-memory chat database that is persisted to, and
// loaded from, the YAML file at YAMLDatabasePath.
//
// NOTE(review): the YAML key order presumably follows the field declaration
// order below; keep the order stable to avoid noisy diffs in the file.
type DataBaseYaml struct {
	// rw guards the exported fields; the methods of this type take it
	// before reading or modifying them. Being unexported, it is not serialized.
	rw sync.RWMutex
	// To make the running program forcibly read new data, add
	// `FORCEOVERWRITE: true` at the top of the YAML database file.
	ForceOverwrite bool `yaml:"FORCEOVERWRITE,omitempty"`
	// UpdateTimestamp is a Unix time in seconds: set to "now" on every save
	// and to "now + 1" by the methods that modify chat data.
	UpdateTimestamp int64 `yaml:"UpdateTimestamp"`
	// Chats holds one entry per known chat or user, looked up by ID.
	Chats []db_struct.ChatInfo `yaml:"Chats"`
}
// Name returns the identifier of this database backend.
func (db *DataBaseYaml) Name() string {
	const backendName = "YAML"
	return backendName
}
// saveDatabaseNoLock stamps UpdateTimestamp with the current Unix time and
// writes the database to YAMLDatabasePath. A failure is logged and returned
// wrapped.
//
// It does not lock. Because it writes UpdateTimestamp, the caller should hold
// db.rw exclusively (Lock, not RLock) to avoid a data race.
func (db *DataBaseYaml) saveDatabaseNoLock(ctx context.Context) error {
	db.UpdateTimestamp = time.Now().Unix()

	// db is already a pointer; passing &db would needlessly hand the
	// encoder a **DataBaseYaml.
	if err := yaml.SaveYAML(YAMLDatabasePath, db); err != nil {
		zerolog.Ctx(ctx).Error().
			Err(err).
			Str("database", "yaml").
			Str(utils.GetCurrentFuncName()).
			Str("path", YAMLDatabasePath).
			Msg("Failed to save database")
		return fmt.Errorf("failed to save database: %w", err)
	}
	return nil
}
// readDatabaseNoLock decodes the YAML file at YAMLDatabasePath into db.
//
// If the file does not exist, the current in-memory state of db is written
// there as a new database file, and the call succeeds. Any other read error is
// logged and returned wrapped.
//
// It does not lock; the caller must hold db.rw for writing (as ReadDatabase
// does), since decoding modifies db's fields.
func (db *DataBaseYaml) readDatabaseNoLock(ctx context.Context) error {
	logger := zerolog.Ctx(ctx).
		With().
		Str("database", "yaml").
		Str(utils.GetCurrentFuncName()).
		Logger()

	// Pass db itself: it is already a pointer. A **DataBaseYaml could let the
	// decoder replace the local pointer (e.g. on a null document) instead of
	// filling in the struct.
	err := yaml.LoadYAML(YAMLDatabasePath, db)
	if err == nil {
		return nil
	}

	// errors.Is also recognizes not-exist errors that were wrapped with %w;
	// os.IsNotExist does not unwrap them.
	if !errors.Is(err, os.ErrNotExist) {
		logger.Error().
			Err(err).
			Str("path", YAMLDatabasePath).
			Msg("Failed to read database file")
		return fmt.Errorf("failed to read database file: %w", err)
	}

	// No database file yet: create one from the current in-memory state.
	if saveErr := yaml.SaveYAML(YAMLDatabasePath, db); saveErr != nil {
		logger.Error().
			Err(saveErr).
			Str("path", YAMLDatabasePath).
			Msg("Failed to create empty database file")
		return fmt.Errorf("failed to create empty database file: %w", saveErr)
	}

	// Only report creation once the file has actually been written.
	logger.Warn().
		Err(err).
		Str("path", YAMLDatabasePath).
		Msg("Not found database file. Created new one")
	return nil
}
// SaveDatabase writes the database to YAMLDatabasePath.
//
// It takes the write lock, not the read lock. Saving updates UpdateTimestamp,
// and writing a field while other goroutines hold the read lock (including a
// second concurrent SaveDatabase) is a data race.
func (db *DataBaseYaml) SaveDatabase(ctx context.Context) error {
	db.rw.Lock()
	defer db.rw.Unlock()
	return db.saveDatabaseNoLock(ctx)
}
// ReadDatabase (re)loads the database from YAMLDatabasePath. It holds the
// write lock for the whole load because decoding overwrites db's fields.
func (db *DataBaseYaml) ReadDatabase(ctx context.Context) error {
	db.rw.Lock()
	defer db.rw.Unlock()

	return db.readDatabaseNoLock(ctx)
}
// GetChatInfo returns a copy of the ChatInfo whose ID equals id, or an error
// if no such chat is stored.
//
// The copy is shallow: its map fields (UsageCount, LatestData, Status, Flag)
// still point at the database's own maps. Mutating them after this call
// returns bypasses db.rw.
func (db *DataBaseYaml) GetChatInfo(ctx context.Context, id int64) (*db_struct.ChatInfo, error) {
	db.rw.RLock()
	defer db.rw.RUnlock()

	for i := range db.Chats {
		if db.Chats[i].ID != id {
			continue
		}
		found := db.Chats[i]
		return &found, nil
	}
	return nil, fmt.Errorf("ChatInfo not found")
}
// InitChat records a chat the first time it is added. It stores the chat's
// ID, type, display name and the current time (RFC 3339), then saves the
// database. If a chat with the same ID already exists, it returns nil without
// changing anything.
func (db *DataBaseYaml) InitChat(ctx context.Context, chat *models.Chat) error {
	db.rw.Lock()
	defer db.rw.Unlock()

	for i := range db.Chats {
		if db.Chats[i].ID == chat.ID {
			// Already registered; don't add a duplicate entry.
			return nil
		}
	}

	entry := db_struct.ChatInfo{
		ID:       chat.ID,
		ChatType: chat.Type,
		ChatName: utils.ShowChatName(chat),
		AddTime:  time.Now().Format(time.RFC3339),
	}
	db.Chats = append(db.Chats, entry)

	return db.saveDatabaseNoLock(ctx)
}
// InitUser records a user the first time it is seen, as a private chat entry.
// It stores the user's ID, display name and the current time (RFC 3339), then
// saves the database. If an entry with the same ID already exists, it returns
// nil without changing anything.
func (db *DataBaseYaml) InitUser(ctx context.Context, user *models.User) error {
	db.rw.Lock()
	defer db.rw.Unlock()

	for i := range db.Chats {
		if db.Chats[i].ID == user.ID {
			// Already registered; don't add a duplicate entry.
			return nil
		}
	}

	entry := db_struct.ChatInfo{
		ID:       user.ID,
		ChatType: models.ChatTypePrivate,
		ChatName: utils.ShowUserName(user),
		AddTime:  time.Now().Format(time.RFC3339),
	}
	db.Chats = append(db.Chats, entry)

	return db.saveDatabaseNoLock(ctx)
}
// IncrementalUsageCount adds one to the fieldName counter of the chat with ID
// chatID, creating the counter map if needed. It does not save to disk. It
// returns an error if the chat is unknown.
func (db *DataBaseYaml) IncrementalUsageCount(ctx context.Context, chatID int64, fieldName db_struct.UsageCount) error {
	db.rw.Lock()
	defer db.rw.Unlock()

	for i := range db.Chats {
		chat := &db.Chats[i]
		if chat.ID != chatID {
			continue
		}
		// NOTE(review): the +1 presumably marks in-memory data as newer than
		// the last save for the auto-save handler — confirm.
		db.UpdateTimestamp = time.Now().Unix() + 1
		if chat.UsageCount == nil {
			chat.UsageCount = map[db_struct.UsageCount]int{}
		}
		// A missing key reads as 0, so a new counter starts at 1.
		chat.UsageCount[fieldName]++
		return nil
	}
	return fmt.Errorf("ChatInfo not found")
}
// RecordLatestData stores value under fieldName in the LatestData map of the
// chat with ID chatID, creating the map if needed. It does not save to disk.
// It returns an error if the chat is unknown.
func (db *DataBaseYaml) RecordLatestData(ctx context.Context, chatID int64, fieldName db_struct.LatestData, value string) error {
	db.rw.Lock()
	defer db.rw.Unlock()

	for i := range db.Chats {
		chat := &db.Chats[i]
		if chat.ID != chatID {
			continue
		}
		db.UpdateTimestamp = time.Now().Unix() + 1
		if chat.LatestData == nil {
			chat.LatestData = map[db_struct.LatestData]string{}
		}
		chat.LatestData[fieldName] = value
		return nil
	}
	return fmt.Errorf("ChatInfo not found")
}
// UpdateOperationStatus sets the fieldName flag in the Status map of the chat
// with ID chatID to value, creating the map if needed. It does not save to
// disk. It returns an error if the chat is unknown.
func (db *DataBaseYaml) UpdateOperationStatus(ctx context.Context, chatID int64, fieldName db_struct.Status, value bool) error {
	db.rw.Lock()
	defer db.rw.Unlock()

	for i := range db.Chats {
		chat := &db.Chats[i]
		if chat.ID != chatID {
			continue
		}
		db.UpdateTimestamp = time.Now().Unix() + 1
		if chat.Status == nil {
			chat.Status = map[db_struct.Status]bool{}
		}
		chat.Status[fieldName] = value
		return nil
	}
	return fmt.Errorf("ChatInfo not found")
}
// SetCustomFlag stores value under fieldName in the Flag map of the chat with
// ID chatID, creating the map if needed. It does not save to disk. It returns
// an error if the chat is unknown.
func (db *DataBaseYaml) SetCustomFlag(ctx context.Context, chatID int64, fieldName db_struct.Flag, value string) error {
	db.rw.Lock()
	defer db.rw.Unlock()

	for i := range db.Chats {
		chat := &db.Chats[i]
		if chat.ID != chatID {
			continue
		}
		db.UpdateTimestamp = time.Now().Unix() + 1
		if chat.Flag == nil {
			chat.Flag = map[db_struct.Flag]string{}
		}
		chat.Flag[fieldName] = value
		return nil
	}
	return fmt.Errorf("ChatInfo not found")
}