watcher/pkg/pusher/controller.go
2024-04-10 17:36:56 +08:00

125 lines
2.7 KiB
Go

package pusher
import (
"context"
"fmt"
"github.com/golang/glog"
"gorm.io/gorm"
"gorm.io/gorm/schema"
"haitao_watcher/pkg/model"
"haitao_watcher/pkg/options"
)
// Pusher is a notification backend that can deliver a single message.
type Pusher interface {
// Push sends a message made of a title and a content body and reports
// any delivery error.
Push(title, content string) error
}
// Controller keeps the set of configured pushers and dispatches messages
// to them.
type Controller struct {
// m maps a pusher record ID to its Pusher implementation.
m map[uint]Pusher
// db is the gorm handle used for migrations and pusher record queries.
db *gorm.DB
// ctx is watched by Consume; when it is done the consumer loop exits.
ctx context.Context
}
// NewController builds a Controller bound to ctx and db, migrates the pusher
// tables and loads the stored pushers into memory.
//
// Pushers are loaded before the function returns. Loading them in a detached
// goroutine would write to the pusher map while a Consume goroutine may
// already be reading it, which is a data race on a plain Go map. A load
// failure is logged and the controller is still returned, with whatever
// pushers were registered before the failure.
func NewController(ctx context.Context, db *gorm.DB) *Controller {
	ctl := &Controller{
		m:   make(map[uint]Pusher),
		db:  db,
		ctx: ctx,
	}
	ctl.migrateTables()
	if err := ctl.initPushers(); err != nil {
		glog.Errorf("pusher init :%v", err)
	}
	return ctl
}
// Consume starts a background goroutine that reads messages from ch and
// forwards each one to every pusher listed in msg.ToPusher.
//
// The goroutine stops when the controller's context is done or when ch is
// closed. Pusher IDs that are not registered are skipped silently; a failed
// push is logged and does not stop delivery to the remaining pushers.
func (c *Controller) Consume(ch <-chan model.PushMsg) {
	go func() {
		for {
			select {
			case <-c.ctx.Done():
				return
			case msg, ok := <-ch:
				if !ok {
					// A closed channel is always ready to receive and yields
					// zero values; without this check the loop would spin.
					return
				}
				// NOTE(review): c.m is read without locking; this is only safe
				// if nothing writes to the map while this goroutine runs.
				for _, pusherID := range msg.ToPusher {
					pusher, found := c.m[pusherID]
					if !found {
						continue
					}
					if err := pusher.Push(msg.Title, msg.Content); err != nil {
						glog.Errorf("pusher %d err: %v", pusherID, err)
					}
				}
			}
		}
	}()
}
// migrateTables runs gorm's AutoMigrate for each table owned by this package.
// A migration failure is reported through glog.Fatalf together with the name
// of the table that failed.
func (c *Controller) migrateTables() {
	owned := []schema.Tabler{
		&model.Pusher[options.AnPushOption]{},
	}
	for i := range owned {
		tbl := owned[i]
		err := c.db.AutoMigrate(tbl)
		if err == nil {
			continue
		}
		glog.Fatalf("failed to migrate table %s: %v", tbl.TableName(), err)
	}
}
// initPushers reads every stored pusher record and registers an AnPush
// pusher for it in c.m, keyed by the record ID. It returns the query error,
// if any, without touching the map.
func (c *Controller) initPushers() error {
	var stored []model.Pusher[*options.AnPushOption]
	result := c.db.Find(&stored)
	if result.Error != nil {
		return result.Error
	}
	for i := range stored {
		record := stored[i]
		c.m[record.ID] = NewAnPush(record.Option)
	}
	return nil
}
// AddPusher stores a new pusher record inside a database transaction.
//
// It returns an error when opt is nil or when the insert fails. On success
// the generated record ID is available in opt.ID and is written to the log.
//
// NOTE(review): the new record is only persisted; it is not added to c.m, so
// Consume will presumably not route messages to it until the pushers are
// loaded again — confirm whether that is intended.
func (c *Controller) AddPusher(opt *model.Pusher[options.AnPushOption]) error {
	if opt == nil {
		return fmt.Errorf("pusher is nil")
	}
	return c.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Create(opt).Error; err != nil {
			return err
		}
		// Use the file's logger instead of printing debug output to stdout.
		glog.Infof("pusher created, id=%v", opt.ID)
		return nil
	})
}
// ListPusherInfoRequest carries the filter and pagination options accepted
// by Controller.List. The `query` tags are presumably consumed by the HTTP
// layer's query-string binder — confirm against the caller.
type ListPusherInfoRequest struct {
// Keyword, when non-empty, restricts results to rows whose name or remark
// contains it.
Keyword string `query:"keyword,omitempty"`
// Page is the 1-based page number; values below 1 are treated as 1.
Page int `query:"page"`
// Size is the page size; values below 1 are treated as 10.
Size int `query:"size"`
// All disables pagination and returns every matching row.
All bool `query:"all"`
}
// List returns pusher records, optionally filtered by keyword and paginated.
//
// When req.Keyword is non-empty, only rows whose name or remark contains it
// (SQL LIKE) are considered. When req.All is set, every matching row is
// returned and Total is the length of that list. Otherwise Total comes from
// a COUNT query and List holds one page ordered by created_at descending;
// Page values below 1 are treated as 1 and Size values below 1 as 10.
//
// Errors from the database are wrapped, so callers can still inspect the
// underlying cause with errors.Is / errors.As.
func (c *Controller) List(req ListPusherInfoRequest) (resp model.ListResponse[model.Pusher[*options.AnPushOption]], err error) {
	query := c.db.Model(&model.Pusher[*options.AnPushOption]{})
	if req.Keyword != "" {
		pattern := "%" + req.Keyword + "%"
		query = query.Where("name LIKE ? or remark LIKE ?", pattern, pattern)
	}
	// Freeze the conditions built so far: after Session every chained call
	// works on its own copy of the statement, so the count query and the
	// page query below cannot leak clauses into each other.
	query = query.Session(&gorm.Session{})

	if req.All {
		if err = query.Find(&resp.List).Error; err != nil {
			return resp, fmt.Errorf("查询列表失败:%w", err)
		}
		resp.Total = int64(len(resp.List))
		return resp, nil
	}

	// Count in the database instead of loading every row just to measure it.
	if err = query.Count(&resp.Total).Error; err != nil {
		return resp, fmt.Errorf("查询总数失败:%w", err)
	}

	page, size := req.Page, req.Size
	if page < 1 {
		page = 1
	}
	if size < 1 {
		size = 10
	}
	offset := (page - 1) * size

	// The page query also runs when Total is zero so that List ends up as an
	// empty slice rather than nil.
	if err = query.Order("created_at desc").Limit(size).Offset(offset).
		Find(&resp.List).Error; err != nil {
		return resp, fmt.Errorf("查询列表失败:%w", err)
	}
	return resp, nil
}