402 lines
12 KiB
Go
402 lines
12 KiB
Go
package yaml_db
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"reflect"
|
|
"time"
|
|
"trbot/database/db_struct"
|
|
"trbot/utils"
|
|
"trbot/utils/consts"
|
|
"trbot/utils/yaml"
|
|
|
|
"github.com/go-telegram/bot/models"
|
|
"github.com/rs/zerolog"
|
|
)
|
|
|
|
var Database DataBaseYaml
|
|
|
|
var YAMLDatabasePath = filepath.Join(consts.YAMLDataBaseDir, consts.YAMLFileName)
|
|
|
|
// 需要重构,错误信息不足
|
|
|
|
type DataBaseYaml struct {
|
|
// 如果运行中希望程序强制读取新数据,在 YAML 数据库文件的开头添加 FORCEOVERWRITE: true 即可
|
|
ForceOverwrite bool `yaml:"FORCEOVERWRITE,omitempty"`
|
|
|
|
UpdateTimestamp int64 `yaml:"UpdateTimestamp"`
|
|
Data struct {
|
|
ChatInfo []db_struct.ChatInfo `yaml:"ChatInfo"`
|
|
} `yaml:"Data"`
|
|
}
|
|
|
|
func InitializeDB(ctx context.Context) error {
|
|
if consts.YAMLDataBaseDir != "" {
|
|
err := ReadDatabase(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read yaml database: %s", err)
|
|
}
|
|
return nil
|
|
} else {
|
|
return fmt.Errorf("yaml database path is empty")
|
|
}
|
|
}
|
|
|
|
func SaveDatabase(ctx context.Context) error {
|
|
logger := zerolog.Ctx(ctx).
|
|
With().
|
|
Str("database", "yaml").
|
|
Str("funcName", "SaveDatabase").
|
|
Logger()
|
|
|
|
Database.UpdateTimestamp = time.Now().Unix()
|
|
err := yaml.SaveYAML(YAMLDatabasePath, &Database)
|
|
if err != nil {
|
|
logger.Error().
|
|
Err(err).
|
|
Str("path", YAMLDatabasePath).
|
|
Msg("Failed to save database")
|
|
return fmt.Errorf("failed to save database: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func ReadDatabase(ctx context.Context) error {
|
|
logger := zerolog.Ctx(ctx).
|
|
With().
|
|
Str("database", "yaml").
|
|
Str("funcName", "ReadDatabase").
|
|
Logger()
|
|
|
|
err := yaml.LoadYAML(YAMLDatabasePath, &Database)
|
|
if err != nil {
|
|
if os.IsNotExist(err) {
|
|
logger.Warn().
|
|
Err(err).
|
|
Str("path", YAMLDatabasePath).
|
|
Msg("Not found database file. Created new one")
|
|
// 如果是找不到文件,新建一个
|
|
err = yaml.SaveYAML(YAMLDatabasePath, &Database)
|
|
if err != nil {
|
|
logger.Error().
|
|
Err(err).
|
|
Str("path", YAMLDatabasePath).
|
|
Msg("Failed to create empty database file")
|
|
return fmt.Errorf("failed to create empty database file: %w", err)
|
|
}
|
|
} else {
|
|
logger.Error().
|
|
Err(err).
|
|
Str("path", YAMLDatabasePath).
|
|
Msg("Failed to read database file")
|
|
return fmt.Errorf("failed to read database file: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func ReadYamlDB(ctx context.Context, pathToFile string) (*DataBaseYaml, error) {
|
|
logger := zerolog.Ctx(ctx).
|
|
With().
|
|
Str("database", "yaml").
|
|
Str("funcName", "ReadYamlDB").
|
|
Logger()
|
|
|
|
var tempDatabase *DataBaseYaml
|
|
err := yaml.LoadYAML(pathToFile, &tempDatabase)
|
|
if err != nil {
|
|
if os.IsNotExist(err) {
|
|
logger.Warn().
|
|
Err(err).
|
|
Str("path", YAMLDatabasePath).
|
|
Msg("Not found database file. Created new one")
|
|
// 如果是找不到文件,新建一个
|
|
err = yaml.SaveYAML(YAMLDatabasePath, &tempDatabase)
|
|
if err != nil {
|
|
logger.Error().
|
|
Err(err).
|
|
Str("path", YAMLDatabasePath).
|
|
Msg("Failed to create empty database file")
|
|
return nil, fmt.Errorf("failed to create empty database file: %w", err)
|
|
}
|
|
} else {
|
|
logger.Error().
|
|
Err(err).
|
|
Str("path", YAMLDatabasePath).
|
|
Msg("Failed to read database file")
|
|
return nil, fmt.Errorf("failed to read database file: %w", err)
|
|
}
|
|
}
|
|
|
|
return tempDatabase, nil
|
|
}
|
|
|
|
// 路径 文件名 YAML 数据结构体
|
|
func SaveYamlDB(ctx context.Context, path, name string, tempDatabase interface{}) error {
|
|
logger := zerolog.Ctx(ctx).
|
|
With().
|
|
Str("database", "yaml").
|
|
Str("funcName", "SaveDatabase").
|
|
Logger()
|
|
|
|
err := yaml.SaveYAML(filepath.Join(path, name), &tempDatabase)
|
|
if err != nil {
|
|
logger.Error().
|
|
Err(err).
|
|
Str("path", YAMLDatabasePath).
|
|
Msg("Failed to save database")
|
|
return fmt.Errorf("failed to save database: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func AutoSaveDatabaseHandler(ctx context.Context) {
|
|
logger := zerolog.Ctx(ctx).
|
|
With().
|
|
Str("database", "yaml").
|
|
Str("funcName", "AutoSaveDatabaseHandler").
|
|
Logger()
|
|
|
|
// 先读取一下数据库文件
|
|
savedDatabase, err := ReadYamlDB(ctx, YAMLDatabasePath)
|
|
if err != nil {
|
|
logger.Error().
|
|
Err(err).
|
|
Str("path", YAMLDatabasePath).
|
|
Msg("Failed to read database file")
|
|
} else {
|
|
// 如果读取数据库文件时发现数据库为空,使用当前的数据重建数据库文件
|
|
if savedDatabase == nil {
|
|
logger.Warn().
|
|
Str("path", YAMLDatabasePath).
|
|
Msg("The database file is empty, recover database file using current data")
|
|
err = SaveYamlDB(ctx, consts.YAMLDataBaseDir, consts.YAMLFileName, Database)
|
|
if err != nil {
|
|
logger.Error().
|
|
Err(err).
|
|
Str("path", YAMLDatabasePath).
|
|
Msg("Failed to recover database file using current data")
|
|
} else {
|
|
logger.Warn().
|
|
Str("path", YAMLDatabasePath).
|
|
Msg("The database file is recovered using current data")
|
|
}
|
|
} else if reflect.DeepEqual(*savedDatabase, Database) {
|
|
// 没有修改就跳过保存
|
|
logger.Debug().Msg("Looks database no any change, skip autosave this time")
|
|
} else {
|
|
// 如果数据库文件中有设定专用的 `FORCEOVERWRITE: true` 覆写标记
|
|
// 无论任何修改,先保存程序中的数据,再读取新的数据替换掉当前的数据并保存
|
|
if savedDatabase.ForceOverwrite {
|
|
logger.Warn().
|
|
Str("path", YAMLDatabasePath).
|
|
Msg("Detected `FORCEOVERWRITE: true` in database file, save current database to another file first")
|
|
|
|
oldFileName := fmt.Sprintf("beforeOverwritten_%d_%s", time.Now().Unix(), consts.YAMLFileName)
|
|
oldFilePath := filepath.Join(consts.YAMLDataBaseDir, oldFileName)
|
|
|
|
err := SaveYamlDB(ctx, consts.YAMLDataBaseDir, oldFileName, savedDatabase)
|
|
if err != nil {
|
|
logger.Warn().
|
|
Err(err).
|
|
Str("oldPath", oldFilePath).
|
|
Msg("Failed to save the database before overwrite")
|
|
} else {
|
|
logger.Warn().
|
|
Err(err).
|
|
Str("oldPath", oldFilePath).
|
|
Msg("The database before overwrite is saved to another file")
|
|
}
|
|
Database = *savedDatabase
|
|
Database.ForceOverwrite = false // 移除强制覆盖标记
|
|
err = SaveYamlDB(ctx, consts.YAMLDataBaseDir, consts.YAMLFileName, Database)
|
|
if err != nil {
|
|
logger.Error().
|
|
Err(err).
|
|
Str("path", YAMLDatabasePath).
|
|
Msg("Failed to save the database after overwrite")
|
|
} else {
|
|
logger.Warn().
|
|
Str("path", YAMLDatabasePath).
|
|
Msg("Read new database file and save to the old file")
|
|
}
|
|
} else {
|
|
// 没有设定覆写标记,检测到本地的数据库更新时间比程序中的更新时间更晚
|
|
if savedDatabase.UpdateTimestamp >= Database.UpdateTimestamp {
|
|
logger.Warn().
|
|
Msg("The database file is newer than current data in the program")
|
|
// 如果只是更新时间有差别,更新一下时间,再保存就行
|
|
if reflect.DeepEqual(savedDatabase.Data, Database.Data) {
|
|
logger.Warn().
|
|
Msg("But current data and database is the same, updating UpdateTimestamp in the database only")
|
|
Database.UpdateTimestamp = time.Now().Unix()
|
|
err := SaveYamlDB(ctx, consts.YAMLDataBaseDir, consts.YAMLFileName, Database)
|
|
if err != nil {
|
|
logger.Error().
|
|
Err(err).
|
|
Msg("Failed to save database after updating UpdateTimestamp")
|
|
}
|
|
} else {
|
|
// 数据库文件与程序中的数据不同,提示不要在程序运行的时候乱动数据库文件
|
|
logger.Warn().
|
|
Str("notice", "Do not modify the database file while the program is running, If you want to overwrite the current database, please add the field `FORCEOVERWRITE: true` at the beginning of the file.").
|
|
Msg("The database file is different from the current database, saving modified file and recovering database file using current data in the program")
|
|
|
|
// 将新的数据文件改名另存为 `edited_时间戳_文件名`,再使用程序中的数据还原数据文件
|
|
editedFileName := fmt.Sprintf("edited_%d_%s", time.Now().Unix(), consts.YAMLFileName)
|
|
editedFilePath := filepath.Join(consts.YAMLDataBaseDir, editedFileName)
|
|
|
|
err := SaveYamlDB(ctx, consts.YAMLDataBaseDir, editedFileName, savedDatabase)
|
|
if err != nil {
|
|
logger.Error().
|
|
Err(err).
|
|
Str("editedPath", editedFilePath).
|
|
Msg("Failed to save modified database")
|
|
} else {
|
|
logger.Warn().
|
|
Str("editedPath", editedFilePath).
|
|
Msg("The modified database is saved to another file")
|
|
}
|
|
err = SaveYamlDB(ctx, consts.YAMLDataBaseDir, consts.YAMLFileName, Database)
|
|
if err != nil {
|
|
logger.Error().
|
|
Err(err).
|
|
Str("path", YAMLDatabasePath).
|
|
Msg("Failed to recover database file")
|
|
} else {
|
|
logger.Warn().
|
|
Str("path", YAMLDatabasePath).
|
|
Msg("The database file is recovered using current data in the program")
|
|
}
|
|
}
|
|
} else {
|
|
// 数据有更改,程序内的更新时间也比本地数据库晚,正常保存
|
|
// 无论如何都尽量不要手动修改数据库文件,如果必要也请在开头添加专用的 `FORCEOVERWRITE: true` 覆写标记,或停止程序后再修改
|
|
Database.UpdateTimestamp = time.Now().Unix()
|
|
err := SaveYamlDB(ctx, consts.YAMLDataBaseDir, consts.YAMLFileName, Database)
|
|
if err != nil {
|
|
logger.Error().
|
|
Err(err).
|
|
Msg("Failed to auto save database")
|
|
} else {
|
|
logger.Debug().
|
|
Str("path", YAMLDatabasePath).
|
|
Msg("The database is auto saved")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
// 初次添加群组时,获取必要信息
|
|
func InitChat(ctx context.Context, chat *models.Chat) error {
|
|
for _, data := range Database.Data.ChatInfo {
|
|
if data.ID == chat.ID {
|
|
return nil // 群组已存在,不重复添加
|
|
}
|
|
}
|
|
Database.Data.ChatInfo = append(Database.Data.ChatInfo, db_struct.ChatInfo{
|
|
ID: chat.ID,
|
|
ChatType: chat.Type,
|
|
ChatName: utils.ShowChatName(chat),
|
|
AddTime: time.Now().Format(time.RFC3339),
|
|
})
|
|
return SaveDatabase(ctx)
|
|
}
|
|
|
|
func InitUser(ctx context.Context, user *models.User) error {
|
|
for _, data := range Database.Data.ChatInfo {
|
|
if data.ID == user.ID {
|
|
return nil // 用户已存在,不重复添加
|
|
}
|
|
}
|
|
Database.Data.ChatInfo = append(Database.Data.ChatInfo, db_struct.ChatInfo{
|
|
ID: user.ID,
|
|
ChatType: models.ChatTypePrivate,
|
|
ChatName: utils.ShowUserName(user),
|
|
AddTime: time.Now().Format(time.RFC3339),
|
|
})
|
|
return SaveDatabase(ctx)
|
|
}
|
|
|
|
// 获取 ID 信息
|
|
func GetChatInfo(ctx context.Context, id int64) (*db_struct.ChatInfo, error) {
|
|
for Index, Data := range Database.Data.ChatInfo {
|
|
if Data.ID == id {
|
|
return &Database.Data.ChatInfo[Index], nil
|
|
}
|
|
}
|
|
return nil, fmt.Errorf("ChatInfo not found")
|
|
}
|
|
|
|
func IncrementalUsageCount(ctx context.Context, chatID int64, fieldName db_struct.ChatInfoField_UsageCount) error {
|
|
for Index, Data := range Database.Data.ChatInfo {
|
|
if Data.ID == chatID {
|
|
Database.UpdateTimestamp = time.Now().Unix() + 1
|
|
v := reflect.ValueOf(&Database.Data.ChatInfo[Index]).Elem()
|
|
for i := 0; i < v.NumField(); i++ {
|
|
if v.Type().Field(i).Name == string(fieldName) {
|
|
v.Field(i).SetInt(v.Field(i).Int() + 1)
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return fmt.Errorf("ChatInfo not found")
|
|
}
|
|
|
|
func RecordLatestData(ctx context.Context, chatID int64, fieldName db_struct.ChatInfoField_LatestData, value string) error {
|
|
for Index, Data := range Database.Data.ChatInfo {
|
|
if Data.ID == chatID {
|
|
Database.UpdateTimestamp = time.Now().Unix() + 1
|
|
v := reflect.ValueOf(&Database.Data.ChatInfo[Index]).Elem()
|
|
for i := 0; i < v.NumField(); i++ {
|
|
if v.Type().Field(i).Name == string(fieldName) {
|
|
v.Field(i).SetString(value)
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return fmt.Errorf("ChatInfo not found")
|
|
}
|
|
|
|
func UpdateOperationStatus(ctx context.Context, chatID int64, fieldName db_struct.ChatInfoField_Status, value bool) error {
|
|
for Index, Data := range Database.Data.ChatInfo {
|
|
if Data.ID == chatID {
|
|
Database.UpdateTimestamp = time.Now().Unix() + 1
|
|
v := reflect.ValueOf(&Database.Data.ChatInfo[Index]).Elem()
|
|
for i := 0; i < v.NumField(); i++ {
|
|
if v.Type().Field(i).Name == string(fieldName) {
|
|
v.Field(i).SetBool(value)
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return fmt.Errorf("ChatInfo not found")
|
|
}
|
|
|
|
func SetCustomFlag(ctx context.Context, chatID int64, fieldName db_struct.ChatInfoField_CustomFlag, value string) error {
|
|
for Index, Data := range Database.Data.ChatInfo {
|
|
if Data.ID == chatID {
|
|
Database.UpdateTimestamp = time.Now().Unix() + 1
|
|
v := reflect.ValueOf(&Database.Data.ChatInfo[Index]).Elem()
|
|
for i := 0; i < v.NumField(); i++ {
|
|
if v.Type().Field(i).Name == string(fieldName) {
|
|
v.Field(i).SetString(value)
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return fmt.Errorf("ChatInfo not found")
|
|
}
|