181 lines
4.0 KiB
Go
181 lines
4.0 KiB
Go
package database
|
|
|
|
import (
	"context"
	"errors"
	"fmt"

	"github.com/go-telegram/bot/models"

	"trbot/database/db_struct"
)
|
|
|
|
// 需要给一些函数加上一个 success 返回值,有时部分数据库不可用,但数据成功保存到了其他数据库
|
|
|
|
func InitChat(ctx context.Context, chat *models.Chat) error {
|
|
var allErr error
|
|
for _, db := range DBBackends {
|
|
err := db.InitChat(ctx, chat)
|
|
if err != nil {
|
|
allErr = err
|
|
}
|
|
}
|
|
for _, db := range DBBackends_LowLevel {
|
|
err := db.InitChat(ctx, chat)
|
|
if err != nil {
|
|
allErr = fmt.Errorf("%s, %s", allErr, err)
|
|
}
|
|
}
|
|
return allErr
|
|
}
|
|
|
|
func InitUser(ctx context.Context, user *models.User) error {
|
|
var allErr error
|
|
for _, db := range DBBackends {
|
|
err := db.InitUser(ctx, user)
|
|
if err != nil {
|
|
allErr = err
|
|
}
|
|
}
|
|
for _, db := range DBBackends_LowLevel {
|
|
err := db.InitUser(ctx, user)
|
|
if err != nil {
|
|
allErr = fmt.Errorf("%s, %s", allErr, err)
|
|
}
|
|
}
|
|
return allErr
|
|
}
|
|
|
|
func GetChatInfo(ctx context.Context, chatID int64) (data *db_struct.ChatInfo, err error) {
|
|
// 优先从高优先级数据库获取数据
|
|
for _, db := range DBBackends {
|
|
data, err = db.GetChatInfo(ctx, chatID)
|
|
if err == nil {
|
|
return
|
|
}
|
|
}
|
|
for _, db := range DBBackends_LowLevel {
|
|
data, err = db.GetChatInfo(ctx, chatID)
|
|
if err == nil {
|
|
return
|
|
}
|
|
}
|
|
if err != nil {
|
|
return
|
|
}
|
|
return nil, fmt.Errorf("no database available")
|
|
}
|
|
|
|
func IncrementalUsageCount(ctx context.Context, chatID int64, fieldName db_struct.ChatInfoField_UsageCount) error {
|
|
var allErr error
|
|
for _, db := range DBBackends {
|
|
err := db.IncrementalUsageCount(ctx, chatID, fieldName)
|
|
if err != nil {
|
|
allErr = err
|
|
}
|
|
}
|
|
for _, db := range DBBackends_LowLevel {
|
|
err := db.IncrementalUsageCount(ctx, chatID, fieldName)
|
|
if err != nil {
|
|
allErr = fmt.Errorf("%s, %s", allErr, err)
|
|
}
|
|
}
|
|
return allErr
|
|
}
|
|
|
|
func RecordLatestData(ctx context.Context, chatID int64, fieldName db_struct.ChatInfoField_LatestData, data string) error {
|
|
var allErr error
|
|
for _, db := range DBBackends {
|
|
err := db.RecordLatestData(ctx, chatID, fieldName, data)
|
|
if err != nil {
|
|
allErr = err
|
|
}
|
|
}
|
|
for _, db := range DBBackends_LowLevel {
|
|
err := db.RecordLatestData(ctx, chatID, fieldName, data)
|
|
if err != nil {
|
|
allErr = fmt.Errorf("%s, %s", allErr, err)
|
|
}
|
|
}
|
|
return allErr
|
|
}
|
|
|
|
func UpdateOperationStatus(ctx context.Context, chatID int64, fieldName db_struct.ChatInfoField_Status, value bool) error {
|
|
var allErr error
|
|
for _, db := range DBBackends {
|
|
err := db.UpdateOperationStatus(ctx, chatID, fieldName, value)
|
|
if err != nil {
|
|
allErr = err
|
|
}
|
|
}
|
|
for _, db := range DBBackends_LowLevel {
|
|
err := db.UpdateOperationStatus(ctx, chatID, fieldName, value)
|
|
if err != nil {
|
|
allErr = fmt.Errorf("%s, %s", allErr, err)
|
|
}
|
|
}
|
|
return allErr
|
|
}
|
|
|
|
func SetCustomFlag(ctx context.Context, chatID int64, fieldName db_struct.ChatInfoField_CustomFlag, value string) error {
|
|
var allErr error
|
|
for _, db := range DBBackends {
|
|
err := db.SetCustomFlag(ctx, chatID, fieldName, value)
|
|
if err != nil {
|
|
allErr = err
|
|
}
|
|
}
|
|
for _, db := range DBBackends_LowLevel {
|
|
err := db.SetCustomFlag(ctx, chatID, fieldName, value)
|
|
if err != nil {
|
|
allErr = fmt.Errorf("%s, %s", allErr, err)
|
|
}
|
|
}
|
|
return allErr
|
|
}
|
|
|
|
func SaveDatabase(ctx context.Context) error {
|
|
var allErr error
|
|
for _, db := range DBBackends {
|
|
if db.SaveDatabase == nil {
|
|
continue
|
|
}
|
|
err := db.SaveDatabase(ctx)
|
|
if err != nil {
|
|
allErr = err
|
|
}
|
|
}
|
|
for _, db := range DBBackends_LowLevel {
|
|
if db.SaveDatabase == nil {
|
|
continue
|
|
}
|
|
err := db.SaveDatabase(ctx)
|
|
if err != nil {
|
|
allErr = fmt.Errorf("%s, %s", allErr, err)
|
|
}
|
|
}
|
|
return allErr
|
|
}
|
|
|
|
func ReadDatabase(ctx context.Context) error {
|
|
var allErr error
|
|
for _, db := range DBBackends {
|
|
if db.ReadDatabase == nil {
|
|
continue
|
|
}
|
|
err := db.ReadDatabase(ctx)
|
|
if err != nil {
|
|
allErr = err
|
|
}
|
|
}
|
|
for _, db := range DBBackends_LowLevel {
|
|
if db.ReadDatabase == nil {
|
|
continue
|
|
}
|
|
err := db.ReadDatabase(ctx)
|
|
if err != nil {
|
|
allErr = fmt.Errorf("%s, %s", allErr, err)
|
|
}
|
|
}
|
|
return allErr
|
|
}
|