Files
trbot/database/yaml_db/yaml.go

402 lines
12 KiB
Go

package yaml_db
import (
"context"
"fmt"
"os"
"path/filepath"
"reflect"
"time"
"trbot/database/db_struct"
"trbot/utils"
"trbot/utils/configs"
"trbot/utils/yaml"
"github.com/go-telegram/bot/models"
"github.com/rs/zerolog"
)
// Package-level state of the YAML-backed database.
var (
	// Database is the in-memory copy of the YAML database; the functions in
	// this file read from it and write to it directly.
	Database DataBaseYaml

	// YAMLDatabasePath is the location of the database file, built from
	// configs.YAMLDatabaseDir and configs.YAMLFileName.
	YAMLDatabasePath = filepath.Join(configs.YAMLDatabaseDir, configs.YAMLFileName)
)
// DataBaseYaml is the top-level layout of the YAML database file.
//
// TODO: needs refactoring; the error information is insufficient.
type DataBaseYaml struct {
// ForceOverwrite: to make the running program load new data from the file,
// add `FORCEOVERWRITE: true` at the beginning of the YAML database file.
ForceOverwrite bool `yaml:"FORCEOVERWRITE,omitempty"`
// UpdateTimestamp is a Unix timestamp in seconds (set from time.Now().Unix()
// when the database is saved or modified).
UpdateTimestamp int64 `yaml:"UpdateTimestamp"`
// Data groups the stored records.
Data struct {
// ChatInfo is the list of stored chat/user entries.
ChatInfo []db_struct.ChatInfo `yaml:"ChatInfo"`
} `yaml:"Data"`
}
// InitializeDB loads the YAML database into the package-level Database.
//
// It returns an error when configs.YAMLDatabaseDir is empty or when
// ReadDatabase fails. The ReadDatabase error is wrapped with %w so callers
// can still inspect the underlying cause with errors.Is / errors.As.
func InitializeDB(ctx context.Context) error {
	if configs.YAMLDatabaseDir == "" {
		return fmt.Errorf("yaml database path is empty")
	}

	if err := ReadDatabase(ctx); err != nil {
		return fmt.Errorf("failed to read yaml database: %w", err)
	}
	return nil
}
// SaveDatabase stamps Database with the current Unix time and writes it to
// YAMLDatabasePath. A write failure is logged and returned wrapped.
func SaveDatabase(ctx context.Context) error {
	logger := zerolog.Ctx(ctx).With().
		Str("database", "yaml").
		Str(utils.GetCurrentFuncName()).
		Logger()

	// Refresh the timestamp before serialising so the file records this save.
	Database.UpdateTimestamp = time.Now().Unix()

	if saveErr := yaml.SaveYAML(YAMLDatabasePath, &Database); saveErr != nil {
		logger.Error().
			Err(saveErr).
			Str("path", YAMLDatabasePath).
			Msg("Failed to save database")
		return fmt.Errorf("failed to save database: %w", saveErr)
	}

	return nil
}
// ReadDatabase loads the file at YAMLDatabasePath into Database.
//
// When the file does not exist, a new one is written from the current
// (in-memory) Database instead of failing. Any other load error, or a failure
// to create the new file, is logged and returned wrapped.
func ReadDatabase(ctx context.Context) error {
	logger := zerolog.Ctx(ctx).With().
		Str("database", "yaml").
		Str(utils.GetCurrentFuncName()).
		Logger()

	loadErr := yaml.LoadYAML(YAMLDatabasePath, &Database)
	if loadErr == nil {
		return nil
	}

	// Anything other than "file is missing" is a real read failure.
	if !os.IsNotExist(loadErr) {
		logger.Error().
			Err(loadErr).
			Str("path", YAMLDatabasePath).
			Msg("Failed to read database file")
		return fmt.Errorf("failed to read database file: %w", loadErr)
	}

	// The file is missing: create a new one.
	logger.Warn().
		Err(loadErr).
		Str("path", YAMLDatabasePath).
		Msg("Not found database file. Created new one")

	if createErr := yaml.SaveYAML(YAMLDatabasePath, &Database); createErr != nil {
		logger.Error().
			Err(createErr).
			Str("path", YAMLDatabasePath).
			Msg("Failed to create empty database file")
		return fmt.Errorf("failed to create empty database file: %w", createErr)
	}

	return nil
}
// ReadYamlDB loads the YAML database stored at pathToFile and returns it
// without touching the package-level Database.
//
// The result is decoded through a nil *DataBaseYaml, so a nil database with a
// nil error is a possible outcome; AutoSaveDatabaseHandler treats that as
// "the file is empty". When pathToFile does not exist, a new file is created
// at pathToFile and the (nil) result is returned with a nil error.
//
// Any other load error, or a failure to create the new file, is logged and
// returned wrapped, with a nil database.
func ReadYamlDB(ctx context.Context, pathToFile string) (*DataBaseYaml, error) {
	logger := zerolog.Ctx(ctx).
		With().
		Str("database", "yaml").
		Str(utils.GetCurrentFuncName()).
		Logger()

	var tempDatabase *DataBaseYaml

	err := yaml.LoadYAML(pathToFile, &tempDatabase)
	if err == nil {
		return tempDatabase, nil
	}

	if !os.IsNotExist(err) {
		logger.Error().
			Err(err).
			Str("path", pathToFile).
			Msg("Failed to read database file")
		return nil, fmt.Errorf("failed to read database file: %w", err)
	}

	// The file is missing: create a new one at the path that was requested
	// (not at the global YAMLDatabasePath), and report that same path.
	logger.Warn().
		Err(err).
		Str("path", pathToFile).
		Msg("Not found database file. Created new one")

	if err = yaml.SaveYAML(pathToFile, &tempDatabase); err != nil {
		logger.Error().
			Err(err).
			Str("path", pathToFile).
			Msg("Failed to create empty database file")
		return nil, fmt.Errorf("failed to create empty database file: %w", err)
	}

	return tempDatabase, nil
}
// SaveYamlDB writes tempDatabase as YAML to the file `name` inside the
// directory `path`.
//
// Parameters: path is the directory, name is the file name, tempDatabase is
// the YAML data structure to store. A write failure is logged together with
// the file that was actually being written, and returned wrapped.
func SaveYamlDB(ctx context.Context, path, name string, tempDatabase interface{}) error {
	logger := zerolog.Ctx(ctx).
		With().
		Str("database", "yaml").
		Str(utils.GetCurrentFuncName()).
		Logger()

	// Build the target once so the log entry names the same file we write to.
	target := filepath.Join(path, name)

	if err := yaml.SaveYAML(target, &tempDatabase); err != nil {
		logger.Error().
			Err(err).
			Str("path", target).
			Msg("Failed to save database")
		return fmt.Errorf("failed to save database: %w", err)
	}

	return nil
}
// AutoSaveDatabaseHandler reconciles the in-memory Database with the file at
// YAMLDatabasePath and saves whichever side should win. Failures are only
// logged; nothing is returned.
//
// Cases, checked in this order:
//   - the file cannot be read: log the error and do nothing else;
//   - the file is empty (ReadYamlDB returned nil): rewrite it from Database;
//   - the file and Database are deeply equal: skip saving;
//   - the file carries `FORCEOVERWRITE: true`: back up the in-memory data to
//     `beforeOverwritten_<unix>_<file name>`, adopt the file's content, clear
//     the flag and save;
//   - the file's UpdateTimestamp is >= the in-memory one: when only the
//     timestamp differs, refresh it and save; otherwise keep the edited file
//     as `edited_<unix>_<file name>` and restore the file from Database;
//   - otherwise: refresh UpdateTimestamp and save normally.
func AutoSaveDatabaseHandler(ctx context.Context) {
	logger := zerolog.Ctx(ctx).
		With().
		Str("database", "yaml").
		Str(utils.GetCurrentFuncName()).
		Logger()

	// Read the database file first.
	savedDatabase, err := ReadYamlDB(ctx, YAMLDatabasePath)
	if err != nil {
		logger.Error().
			Err(err).
			Str("path", YAMLDatabasePath).
			Msg("Failed to read database file")
		return
	}

	// The database file turned out to be empty: rebuild it from current data.
	if savedDatabase == nil {
		logger.Warn().
			Str("path", YAMLDatabasePath).
			Msg("The database file is empty, recover database file using current data")
		if err := SaveYamlDB(ctx, configs.YAMLDatabaseDir, configs.YAMLFileName, Database); err != nil {
			logger.Error().
				Err(err).
				Str("path", YAMLDatabasePath).
				Msg("Failed to recover database file using current data")
		} else {
			logger.Warn().
				Str("path", YAMLDatabasePath).
				Msg("The database file is recovered using current data")
		}
		return
	}

	// Nothing changed, so skip saving.
	if reflect.DeepEqual(*savedDatabase, Database) {
		logger.Debug().Msg("Looks database no any change, skip autosave this time")
		return
	}

	// The file carries the dedicated `FORCEOVERWRITE: true` marker: whatever
	// changed, save the program's data first, then replace the current data
	// with the file's content and save it.
	if savedDatabase.ForceOverwrite {
		logger.Warn().
			Str("path", YAMLDatabasePath).
			Msg("Detected `FORCEOVERWRITE: true` in database file, save current database to another file first")

		oldFileName := fmt.Sprintf("beforeOverwritten_%d_%s", time.Now().Unix(), configs.YAMLFileName)
		oldFilePath := filepath.Join(configs.YAMLDatabaseDir, oldFileName)

		// Back up the in-memory Database (the data about to be replaced),
		// not the content that was just read from the file.
		if err := SaveYamlDB(ctx, configs.YAMLDatabaseDir, oldFileName, Database); err != nil {
			logger.Warn().
				Err(err).
				Str("oldPath", oldFilePath).
				Msg("Failed to save the database before overwrite")
		} else {
			logger.Warn().
				Str("oldPath", oldFilePath).
				Msg("The database before overwrite is saved to another file")
		}

		Database = *savedDatabase
		Database.ForceOverwrite = false // drop the force-overwrite marker
		if err := SaveYamlDB(ctx, configs.YAMLDatabaseDir, configs.YAMLFileName, Database); err != nil {
			logger.Error().
				Err(err).
				Str("path", YAMLDatabasePath).
				Msg("Failed to save the database after overwrite")
		} else {
			logger.Warn().
				Str("path", YAMLDatabasePath).
				Msg("Read new database file and save to the old file")
		}
		return
	}

	// Data changed and the program's update time is later than the file's:
	// the normal save path. Avoid editing the database file by hand; if you
	// must, add the `FORCEOVERWRITE: true` marker at the top, or stop the
	// program first.
	if savedDatabase.UpdateTimestamp < Database.UpdateTimestamp {
		Database.UpdateTimestamp = time.Now().Unix()
		if err := SaveYamlDB(ctx, configs.YAMLDatabaseDir, configs.YAMLFileName, Database); err != nil {
			logger.Error().
				Err(err).
				Msg("Failed to auto save database")
		} else {
			logger.Debug().
				Str("path", YAMLDatabasePath).
				Msg("The database is auto saved")
		}
		return
	}

	// No overwrite marker, and the file's update time is not older than the
	// one held by the program.
	logger.Warn().
		Msg("The database file is newer than current data in the program")

	// Only the timestamp differs: refresh it and save.
	if reflect.DeepEqual(savedDatabase.Data, Database.Data) {
		logger.Warn().
			Msg("But current data and database is the same, updating UpdateTimestamp in the database only")
		Database.UpdateTimestamp = time.Now().Unix()
		if err := SaveYamlDB(ctx, configs.YAMLDatabaseDir, configs.YAMLFileName, Database); err != nil {
			logger.Error().
				Err(err).
				Msg("Failed to save database after updating UpdateTimestamp")
		}
		return
	}

	// The file's data differs from the program's: warn against editing the
	// database file while the program is running.
	logger.Warn().
		Str("notice", "Do not modify the database file while the program is running, If you want to overwrite the current database, please add the field `FORCEOVERWRITE: true` at the beginning of the file.").
		Msg("The database file is different from the current database, saving modified file and recovering database file using current data in the program")

	// Keep the edited content as `edited_<timestamp>_<file name>`, then
	// restore the database file from the program's data.
	editedFileName := fmt.Sprintf("edited_%d_%s", time.Now().Unix(), configs.YAMLFileName)
	editedFilePath := filepath.Join(configs.YAMLDatabaseDir, editedFileName)
	if err := SaveYamlDB(ctx, configs.YAMLDatabaseDir, editedFileName, savedDatabase); err != nil {
		logger.Error().
			Err(err).
			Str("editedPath", editedFilePath).
			Msg("Failed to save modified database")
	} else {
		logger.Warn().
			Str("editedPath", editedFilePath).
			Msg("The modified database is saved to another file")
	}

	if err := SaveYamlDB(ctx, configs.YAMLDatabaseDir, configs.YAMLFileName, Database); err != nil {
		logger.Error().
			Err(err).
			Str("path", YAMLDatabasePath).
			Msg("Failed to recover database file")
	} else {
		logger.Warn().
			Str("path", YAMLDatabasePath).
			Msg("The database file is recovered using current data in the program")
	}
}
// InitChat records the basic information of a chat the first time it is
// added. A chat whose ID is already stored is left untouched and nil is
// returned; otherwise a new entry is appended and the database is saved.
func InitChat(ctx context.Context, chat *models.Chat) error {
	for i := range Database.Data.ChatInfo {
		if Database.Data.ChatInfo[i].ID == chat.ID {
			// The chat already exists; do not add it again.
			return nil
		}
	}

	entry := db_struct.ChatInfo{
		ID:       chat.ID,
		ChatType: chat.Type,
		ChatName: utils.ShowChatName(chat),
		AddTime:  time.Now().Format(time.RFC3339),
	}
	Database.Data.ChatInfo = append(Database.Data.ChatInfo, entry)

	return SaveDatabase(ctx)
}
// InitUser records a user as a private chat entry the first time the user is
// seen. A user whose ID is already stored is left untouched and nil is
// returned; otherwise a new entry is appended and the database is saved.
func InitUser(ctx context.Context, user *models.User) error {
	for i := range Database.Data.ChatInfo {
		if Database.Data.ChatInfo[i].ID == user.ID {
			// The user already exists; do not add it again.
			return nil
		}
	}

	entry := db_struct.ChatInfo{
		ID:       user.ID,
		ChatType: models.ChatTypePrivate,
		ChatName: utils.ShowUserName(user),
		AddTime:  time.Now().Format(time.RFC3339),
	}
	Database.Data.ChatInfo = append(Database.Data.ChatInfo, entry)

	return SaveDatabase(ctx)
}
// GetChatInfo looks up the stored entry whose ID equals id.
//
// The returned pointer refers to the element inside Database.Data.ChatInfo
// itself, not to a copy. An error is returned when no entry matches.
func GetChatInfo(ctx context.Context, id int64) (*db_struct.ChatInfo, error) {
	for i := range Database.Data.ChatInfo {
		if Database.Data.ChatInfo[i].ID != id {
			continue
		}
		return &Database.Data.ChatInfo[i], nil
	}

	return nil, fmt.Errorf("ChatInfo not found")
}
// IncrementalUsageCount adds one to the integer field named fieldName of the
// stored entry whose ID equals chatID.
//
// It returns an error when no entry has that ID, when the entry has no
// settable field with that name, or when the field is not a signed integer.
// Database.UpdateTimestamp is only advanced when the counter really changed.
func IncrementalUsageCount(ctx context.Context, chatID int64, fieldName db_struct.ChatInfoField_UsageCount) error {
	for i := range Database.Data.ChatInfo {
		if Database.Data.ChatInfo[i].ID != chatID {
			continue
		}

		// Resolve the field by name on the stored element (through a pointer,
		// so the value is addressable and can be modified in place).
		field := reflect.ValueOf(&Database.Data.ChatInfo[i]).Elem().FieldByName(string(fieldName))
		if !field.IsValid() || !field.CanSet() {
			return fmt.Errorf("ChatInfo field %q not found or not settable", string(fieldName))
		}

		switch field.Kind() {
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			field.SetInt(field.Int() + 1)
		default:
			// SetInt would panic on any other kind; report it instead.
			return fmt.Errorf("ChatInfo field %q is not an integer (kind %s)", string(fieldName), field.Kind())
		}

		// One second ahead of now — presumably so the in-memory timestamp is
		// strictly later than the one in the file, which the auto-save
		// handler compares with >=.
		Database.UpdateTimestamp = time.Now().Unix() + 1
		return nil
	}

	return fmt.Errorf("ChatInfo not found")
}
// RecordLatestData stores value in the string field named fieldName of the
// stored entry whose ID equals chatID.
//
// It returns an error when no entry has that ID, when the entry has no
// settable field with that name, or when the field is not a string.
// Database.UpdateTimestamp is only advanced when the field was really written.
func RecordLatestData(ctx context.Context, chatID int64, fieldName db_struct.ChatInfoField_LatestData, value string) error {
	for i := range Database.Data.ChatInfo {
		if Database.Data.ChatInfo[i].ID != chatID {
			continue
		}

		// Resolve the field by name on the stored element (through a pointer,
		// so the value is addressable and can be modified in place).
		field := reflect.ValueOf(&Database.Data.ChatInfo[i]).Elem().FieldByName(string(fieldName))
		if !field.IsValid() || !field.CanSet() {
			return fmt.Errorf("ChatInfo field %q not found or not settable", string(fieldName))
		}
		if field.Kind() != reflect.String {
			// SetString would panic on any other kind; report it instead.
			return fmt.Errorf("ChatInfo field %q is not a string (kind %s)", string(fieldName), field.Kind())
		}

		field.SetString(value)

		// One second ahead of now — presumably so the in-memory timestamp is
		// strictly later than the one in the file, which the auto-save
		// handler compares with >=.
		Database.UpdateTimestamp = time.Now().Unix() + 1
		return nil
	}

	return fmt.Errorf("ChatInfo not found")
}
// UpdateOperationStatus stores value in the boolean field named fieldName of
// the stored entry whose ID equals chatID.
//
// It returns an error when no entry has that ID, when the entry has no
// settable field with that name, or when the field is not a bool.
// Database.UpdateTimestamp is only advanced when the field was really written.
func UpdateOperationStatus(ctx context.Context, chatID int64, fieldName db_struct.ChatInfoField_Status, value bool) error {
	for i := range Database.Data.ChatInfo {
		if Database.Data.ChatInfo[i].ID != chatID {
			continue
		}

		// Resolve the field by name on the stored element (through a pointer,
		// so the value is addressable and can be modified in place).
		field := reflect.ValueOf(&Database.Data.ChatInfo[i]).Elem().FieldByName(string(fieldName))
		if !field.IsValid() || !field.CanSet() {
			return fmt.Errorf("ChatInfo field %q not found or not settable", string(fieldName))
		}
		if field.Kind() != reflect.Bool {
			// SetBool would panic on any other kind; report it instead.
			return fmt.Errorf("ChatInfo field %q is not a bool (kind %s)", string(fieldName), field.Kind())
		}

		field.SetBool(value)

		// One second ahead of now — presumably so the in-memory timestamp is
		// strictly later than the one in the file, which the auto-save
		// handler compares with >=.
		Database.UpdateTimestamp = time.Now().Unix() + 1
		return nil
	}

	return fmt.Errorf("ChatInfo not found")
}
// SetCustomFlag stores value in the string field named fieldName of the
// stored entry whose ID equals chatID.
//
// It returns an error when no entry has that ID, when the entry has no
// settable field with that name, or when the field is not a string.
// Database.UpdateTimestamp is only advanced when the field was really written.
func SetCustomFlag(ctx context.Context, chatID int64, fieldName db_struct.ChatInfoField_CustomFlag, value string) error {
	for i := range Database.Data.ChatInfo {
		if Database.Data.ChatInfo[i].ID != chatID {
			continue
		}

		// Resolve the field by name on the stored element (through a pointer,
		// so the value is addressable and can be modified in place).
		field := reflect.ValueOf(&Database.Data.ChatInfo[i]).Elem().FieldByName(string(fieldName))
		if !field.IsValid() || !field.CanSet() {
			return fmt.Errorf("ChatInfo field %q not found or not settable", string(fieldName))
		}
		if field.Kind() != reflect.String {
			// SetString would panic on any other kind; report it instead.
			return fmt.Errorf("ChatInfo field %q is not a string (kind %s)", string(fieldName), field.Kind())
		}

		field.SetString(value)

		// One second ahead of now — presumably so the in-memory timestamp is
		// strictly later than the one in the file, which the auto-save
		// handler compares with >=.
		Database.UpdateTimestamp = time.Now().Unix() + 1
		return nil
	}

	return fmt.Errorf("ChatInfo not found")
}