package db import ( "errors" "gitlab.com/kedaya_mp/user/biz/model/common" "gorm.io/gorm" ) type Scoper interface { Scope(db *gorm.DB) *gorm.DB } type UserStore interface { Create(*common.User) (*common.User, error) Find(Scoper, ...string) (*common.User, error) Exist(Scoper) (bool, error) Update(*common.User, ...string) error AutoMigrate() error } type userStore struct { db *gorm.DB } func NewUserStore(db *gorm.DB) UserStore { return &userStore{db: db} } func (s *userStore) AutoMigrate() error { return s.db.AutoMigrate(&common.User{}) } type FindUserQuery struct { ID int64 Phone string WxUnionID string } func (f *FindUserQuery) Scope(db *gorm.DB) *gorm.DB { if f.ID != 0 { db = db.Where("id = ?", f.ID) } if f.Phone != "" { db = db.Where("phone = ?", f.Phone) } if f.WxUnionID != "" { db = db.Where("wx_union_id = ?", f.WxUnionID) } return db } // Create 创建用户 func (s *userStore) Create(u *common.User) (*common.User, error) { if err := s.db.Create(u).Error; err != nil { return nil, err } return u, nil } // Find 查找用户 func (s *userStore) Find(scoper Scoper, selects ...string) (*common.User, error) { var dbUser common.User db := s.db if len(selects) > 0 { db = db.Select(selects) } if err := scoper.Scope(db).First(&dbUser).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } return nil, err } return &dbUser, nil } // Exist 检查用户是否存在 func (s *userStore) Exist(scoper Scoper) (bool, error) { var count int64 if err := scoper.Scope(s.db).Model(&common.User{}).Count(&count).Error; err != nil { return false, err } return count > 0, nil } // Update 更新用户信息 func (s *userStore) Update(u *common.User, selects ...string) error { return s.db.Transaction(func(tx *gorm.DB) error { if len(selects) > 0 { tx = tx.Select(selects) } return tx.Where("id =?", u.ID).Updates(u).Error }) }