79 lines
1.4 KiB
Go
79 lines
1.4 KiB
Go
|
package subscribe
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"fmt"
|
||
|
"sync"
|
||
|
|
||
|
"github.com/redis/go-redis/v9"
|
||
|
)
|
||
|
|
||
|
type Server struct {
|
||
|
rdb *redis.Client
|
||
|
ctx context.Context
|
||
|
pubSub *redis.PubSub
|
||
|
// 给fm加锁
|
||
|
lock sync.RWMutex
|
||
|
fm map[string]MessageWorker
|
||
|
errHandle func(err error)
|
||
|
}
|
||
|
|
||
|
// MessageWorker handles one message delivered on a subscribed channel.
// Run calls it with the Server's context and the message payload; a non-nil
// returned error is handed to the Server's error handler.
type MessageWorker func(ctx context.Context, payload string) error
|
||
|
|
||
|
func NewServer(ctx context.Context, rdb *redis.Client) *Server {
|
||
|
return &Server{
|
||
|
ctx: ctx,
|
||
|
rdb: rdb,
|
||
|
fm: make(map[string]MessageWorker),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// 订阅
|
||
|
func (s *Server) Subscribe(channel string, f MessageWorker) error {
|
||
|
if f == nil {
|
||
|
return fmt.Errorf("message worker is nil")
|
||
|
}
|
||
|
if s.pubSub == nil {
|
||
|
s.pubSub = s.rdb.Subscribe(s.ctx, channel)
|
||
|
} else if err := s.pubSub.Subscribe(s.ctx, channel); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
s.lock.Lock()
|
||
|
defer s.lock.Unlock()
|
||
|
|
||
|
s.fm[channel] = f
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// 取消订阅
|
||
|
func (s *Server) Unsubscribe(channel string) error {
|
||
|
if s.pubSub == nil {
|
||
|
return nil
|
||
|
}
|
||
|
s.lock.Lock()
|
||
|
defer s.lock.Unlock()
|
||
|
|
||
|
delete(s.fm, channel)
|
||
|
return s.pubSub.Unsubscribe(s.ctx, channel)
|
||
|
}
|
||
|
|
||
|
func (s *Server) Run() {
|
||
|
ch := s.pubSub.Channel()
|
||
|
defer s.pubSub.Close()
|
||
|
for {
|
||
|
select {
|
||
|
case <-s.ctx.Done():
|
||
|
return
|
||
|
case msg := <-ch:
|
||
|
s.lock.RLock()
|
||
|
f := s.fm[msg.Channel]
|
||
|
s.lock.RUnlock()
|
||
|
if f != nil {
|
||
|
if err := f(s.ctx, msg.Payload); err != nil {
|
||
|
s.errHandle(err)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|