common/pkg/subscribe/server.go
2024-08-29 20:38:48 +08:00

79 lines
1.4 KiB
Go

package subscribe
import (
	"context"
	"fmt"
	"log"
	"sync"

	"github.com/redis/go-redis/v9"
)
// Server dispatches Redis Pub/Sub messages to the MessageWorker registered
// for each channel. Build it with NewServer: the zero value has a nil fm map,
// so Subscribe on it would panic.
type Server struct {
// rdb is the Redis client used to open the PubSub.
rdb *redis.Client
// ctx is passed to every Redis call and worker invocation; Run returns once it is done.
ctx context.Context
// pubSub is created lazily (nil until first needed) and closed when Run returns.
pubSub *redis.PubSub
// lock guards fm.
lock sync.RWMutex
// fm maps a channel name to the worker that handles messages published on it.
fm map[string]MessageWorker
// errHandle receives errors returned by workers. NewServer does not set it.
errHandle func(err error)
}
// MessageWorker handles the payload of one message received on a subscribed
// channel. ctx is the Server's context. A non-nil error is reported to the
// Server's error handler; it does not stop dispatching.
type MessageWorker func(ctx context.Context, message string) error
// NewServer returns a Server that issues its Redis commands through rdb and
// uses ctx for every command and worker call. The handler map starts empty,
// no PubSub is opened yet, and the error handler is left unset.
func NewServer(ctx context.Context, rdb *redis.Client) *Server {
	srv := new(Server)
	srv.rdb = rdb
	srv.ctx = ctx
	srv.fm = map[string]MessageWorker{}
	return srv
}
// 订阅
// Subscribe registers f as the handler for messages published on channel and
// sends SUBSCRIBE for channel to Redis. Subscribing to a channel that already
// has a handler replaces that handler.
//
// It returns an error if f is nil or if the SUBSCRIBE command fails. On
// failure, this call does not register f.
func (s *Server) Subscribe(channel string, f MessageWorker) error {
	if f == nil {
		return fmt.Errorf("message worker is nil")
	}

	// Hold the lock for the whole operation. This serialises access to
	// s.pubSub with Unsubscribe/Run. It also makes Run's handler lookup wait
	// until f is stored, so a message that arrives right after SUBSCRIBE
	// is not dropped for lack of a handler.
	s.lock.Lock()
	defer s.lock.Unlock()

	if s.pubSub == nil {
		// Open the PubSub without channels and subscribe explicitly below:
		// rdb.Subscribe returns no error, so a failed SUBSCRIBE passed
		// through it would go unnoticed.
		s.pubSub = s.rdb.Subscribe(s.ctx)
	}
	if err := s.pubSub.Subscribe(s.ctx, channel); err != nil {
		return err
	}
	s.fm[channel] = f
	return nil
}
// 取消订阅
// Unsubscribe removes the handler registered for channel and sends
// UNSUBSCRIBE for it to Redis. It returns nil without doing anything when no
// PubSub has been opened yet.
func (s *Server) Unsubscribe(channel string) error {
	s.lock.Lock()
	defer s.lock.Unlock()

	// s.pubSub is assigned lazily by Subscribe, so check it under the lock
	// rather than racing with that assignment.
	if s.pubSub == nil {
		return nil
	}
	delete(s.fm, channel)
	return s.pubSub.Unsubscribe(s.ctx, channel)
}
// Run dispatches each message received on the subscribed channels to the
// MessageWorker registered for that channel. It blocks until s.ctx is
// cancelled or the PubSub's message channel is closed, and it closes the
// PubSub before returning.
//
// If nothing has been subscribed yet, Run opens an empty PubSub. Channels
// added later via Subscribe then reuse it and are delivered here.
//
// A worker error is passed to errHandle. When errHandle is nil, the error is
// logged instead.
func (s *Server) Run() {
	s.lock.Lock()
	if s.pubSub == nil {
		// No channels: this only creates the PubSub object.
		s.pubSub = s.rdb.Subscribe(s.ctx)
	}
	ps := s.pubSub
	s.lock.Unlock()

	ch := ps.Channel()
	defer ps.Close()

	for {
		select {
		case <-s.ctx.Done():
			return
		case msg, ok := <-ch:
			if !ok {
				// The PubSub was closed; msg would be nil.
				return
			}

			s.lock.RLock()
			worker := s.fm[msg.Channel]
			s.lock.RUnlock()
			if worker == nil {
				continue
			}

			if err := worker(s.ctx, msg.Payload); err != nil {
				if s.errHandle != nil {
					s.errHandle(err)
				} else {
					log.Printf("subscribe: worker for channel %q failed: %v", msg.Channel, err)
				}
			}
		}
	}
}