package subscribe

import (
	"context"
	"fmt"
	"sync"

	"github.com/redis/go-redis/v9"
)

// Server dispatches Redis Pub/Sub messages to per-channel MessageWorker
// handlers. Build one with NewServer, register handlers with Subscribe and
// start dispatching with Run.
//
// Fields:
//   - rdb: Redis client used to create the PubSub.
//   - ctx: context passed to Redis calls and to every MessageWorker; Run
//     returns once it is done.
//   - pubSub: shared subscription, created lazily on first use and closed
//     when Run returns.
//   - fm: channel name -> MessageWorker registered for that channel.
//   - errHandle: receives errors returned by MessageWorkers during Run.
type Server struct {
	rdb    *redis.Client
	ctx    context.Context
	pubSub *redis.PubSub
	// lock guards concurrent access to fm.
	lock      sync.RWMutex
	fm        map[string]MessageWorker
	errHandle func(err error)
}

// MessageWorker handles one Pub/Sub message. Run calls it with the Server's
// context and the message payload; a non-nil error is passed to the Server's
// error handler.
type MessageWorker func(ctx context.Context, payload string) error

// NewServer returns a Server that uses rdb for Pub/Sub and ctx for every
// Redis call and MessageWorker invocation. It starts with an empty handler
// table; the PubSub itself is not created here.
func NewServer(ctx context.Context, rdb *redis.Client) *Server {
	s := new(Server)
	s.rdb = rdb
	s.ctx = ctx
	s.fm = map[string]MessageWorker{}
	return s
}

// Subscribe registers f as the MessageWorker for channel and subscribes the
// shared PubSub to that channel, creating the PubSub on first use.
// Subscribing an already-registered channel replaces its worker.
//
// The worker is registered before SUBSCRIBE is sent, so a message that
// arrives right after the subscription takes effect is not dropped by Run.
// If SUBSCRIBE fails, the previous registration (if any) is restored and the
// error is returned. It returns an error if f is nil.
func (s *Server) Subscribe(channel string, f MessageWorker) error {
	if f == nil {
		return fmt.Errorf("message worker is nil")
	}

	s.lock.Lock()
	prev, hadPrev := s.fm[channel]
	s.fm[channel] = f
	if s.pubSub == nil {
		// Create an empty subscription and issue SUBSCRIBE explicitly below:
		// rdb.Subscribe(ctx, channel) discards the error of its initial
		// SUBSCRIBE, which would hide failures on the first call.
		s.pubSub = s.rdb.Subscribe(s.ctx)
	}
	ps := s.pubSub
	s.lock.Unlock()

	// Network I/O happens outside the lock so Run can keep dispatching.
	if err := ps.Subscribe(s.ctx, channel); err != nil {
		s.lock.Lock()
		if hadPrev {
			s.fm[channel] = prev
		} else {
			delete(s.fm, channel)
		}
		s.lock.Unlock()
		return err
	}
	return nil
}

// Unsubscribe removes the MessageWorker registered for channel and
// unsubscribes the PubSub from it. If no PubSub has been created yet there is
// nothing to unsubscribe and it returns nil.
func (s *Server) Unsubscribe(channel string) error {
	s.lock.Lock()
	ps := s.pubSub
	if ps == nil {
		s.lock.Unlock()
		return nil
	}
	// Drop the worker first so Run ignores messages still in flight.
	delete(s.fm, channel)
	s.lock.Unlock()

	// Send UNSUBSCRIBE without holding the lock so Run is not blocked on I/O.
	return ps.Unsubscribe(s.ctx, channel)
}

// SetErrorHandle sets f as the callback that receives errors returned by
// MessageWorkers, replacing any previous callback. The handler is stored
// under s.lock, because Run may be reading it from another goroutine. It
// returns s so the call can be chained after NewServer.
func (s *Server) SetErrorHandle(f func(err error)) *Server {
	s.lock.Lock()
	defer s.lock.Unlock()

	s.errHandle = f
	return s
}

// Run dispatches messages from the subscription to the MessageWorker
// registered for each message's channel. It blocks until s.ctx is done or the
// message channel is closed, and closes the PubSub before returning.
//
// If no PubSub exists yet, Run creates an empty subscription, so channels
// added later with Subscribe are delivered through the same PubSub.
//
// Messages for channels without a registered worker are skipped. Errors
// returned by a worker go to the handler set with SetErrorHandle and are
// dropped when no handler has been set.
func (s *Server) Run() {
	s.lock.Lock()
	if s.pubSub == nil {
		// go-redis allows omitting channels to create an empty subscription.
		s.pubSub = s.rdb.Subscribe(s.ctx)
	}
	ps := s.pubSub
	s.lock.Unlock()

	ch := ps.Channel()
	defer ps.Close()

	for {
		select {
		case <-s.ctx.Done():
			return
		case msg, ok := <-ch:
			if !ok {
				// go-redis closes this channel once the PubSub is closed; msg
				// would be nil here.
				return
			}
			s.lock.RLock()
			worker, onErr := s.fm[msg.Channel], s.errHandle
			s.lock.RUnlock()
			if worker == nil {
				continue
			}
			if err := worker(s.ctx, msg.Payload); err != nil && onErr != nil {
				onErr(err)
			}
		}
	}
}