diff --git a/structs/storage/article.go b/structs/storage/article.go index 31a9f58..375ed37 100644 --- a/structs/storage/article.go +++ b/structs/storage/article.go @@ -13,7 +13,7 @@ type ArticleApi interface { Upsert(article v2.Article) error Update(article v2.Article, selects ...string) error Find(query *FindArticleQuery) (articles []v2.Article, err error) - List(query PageListQuery) (articles []v2.Article, total int64, err error) + List(query PageListQuery, scopes ...func(db *gorm.DB) *gorm.DB) (articles []v2.Article, total int64, err error) Get(query *GetArticleQuery) (article v2.Article, err error) AutoMigrate() error } @@ -111,6 +111,10 @@ func (f *FindArticleQuery) Scope(db *gorm.DB) *gorm.DB { if f.Available != nil { db = db.Where("available=?", *f.Available) } + return db +} + +func (f *FindArticleQuery) SortScope(db *gorm.DB) *gorm.DB { if f.RateSort == "descend" { db = db.Order("rate desc") } else if f.RateSort == "ascend" { @@ -118,18 +122,17 @@ func (f *FindArticleQuery) Scope(db *gorm.DB) *gorm.DB { } return db } - func (a *articleApi) Find(query *FindArticleQuery) (articles []v2.Article, err error) { err = a.db.Scopes(query.Scope).Find(&articles).Error return } -func (a *articleApi) List(query PageListQuery) (articles []v2.Article, total int64, err error) { +func (a *articleApi) List(query PageListQuery, scopes ...func(db *gorm.DB) *gorm.DB) (articles []v2.Article, total int64, err error) { err = a.db.Scopes(query.Scoper.Scope).Model(&v2.Article{}).Count(&total).Error if err != nil { return } - err = a.db.Scopes(query.Scope).Order("id").Find(&articles).Error + err = a.db.Scopes(query.Scope).Scopes(scopes...).Order("id").Find(&articles).Error return }