package storage import ( v2 "gitea.timerzz.com/kedaya_haitao/common/structs/v2" "gorm.io/gorm" "gorm.io/gorm/clause" ) // ProviderArticleApi 管理供应商商品的接口 type ProviderArticleApi interface { Get(query *GetProviderArticleQuery) (article v2.ProviderArticle, err error) Upsert(article v2.ProviderArticle) error Update(providerArticle v2.ProviderArticle, selects ...string) error BatchUpdate(articles []v2.ProviderArticle) error AutoMigrate() error Find(query *GetProviderArticleQuery) (articles []v2.ProviderArticle, err error) List(query PageListQuery, scopes ...func(db *gorm.DB) *gorm.DB) (articles []v2.ProviderArticle, total int64, err error) FindInBatches(query *GetProviderArticleQuery, results *[]v2.ProviderArticle, f func(tx *gorm.DB, batch int) error) error ProviderPrice(providerArticleID uint) (history []v2.ProviderPrice, err error) UpdateStatus(article v2.ProviderArticle) error } type providerArticleApi struct { db *gorm.DB } func NewProviderArticleApi(db *gorm.DB) ProviderArticleApi { return &providerArticleApi{db: db} } // ******************GET type GetProviderArticleQuery struct { ID uint `query:"id"` Brand v2.Brand `query:"brand"` Pid string `query:"pid"` ProviderId v2.ProviderId `query:"providerId"` SkuId string `query:"skuId"` WatchNotNull bool `query:"watchNotNull"` Watch *bool `query:"watch"` } func NewGetProviderArticleQuery() *GetProviderArticleQuery { return &GetProviderArticleQuery{} } func (g *GetProviderArticleQuery) SetID(id uint) *GetProviderArticleQuery { g.ID = id return g } func (g *GetProviderArticleQuery) SetBrand(brand v2.Brand) *GetProviderArticleQuery { g.Brand = brand return g } func (g *GetProviderArticleQuery) SetPid(pid string) *GetProviderArticleQuery { g.Pid = pid return g } func (g *GetProviderArticleQuery) SetProviderId(providerId v2.ProviderId) *GetProviderArticleQuery { g.ProviderId = providerId return g } func (g *GetProviderArticleQuery) SetSkuId(skuId string) *GetProviderArticleQuery { g.SkuId = skuId return g } func (g *GetProviderArticleQuery) SetWatch(watch bool) *GetProviderArticleQuery { g.Watch = &watch return g } func (g *GetProviderArticleQuery) SetWatchNotNull(watchNotNull bool) *GetProviderArticleQuery { g.WatchNotNull = watchNotNull return g } func (g *GetProviderArticleQuery) Scope(db *gorm.DB) *gorm.DB { if g.ID > 0 { db = db.Where("id=?", g.ID) } if g.Brand != "" { db = db.Where("brand=?", g.Brand) } if g.Pid != "" { db = db.Where("pid=?", g.Pid) } if g.ProviderId != "" { db = db.Where("provider_id=?", g.ProviderId) } if g.SkuId != "" { db = db.Where("sku_id=?", g.SkuId) } if g.WatchNotNull { db = db.Not("watch", nil) } if g.Watch != nil { db = db.Where("watch=?", *g.Watch) } return db } func (p *providerArticleApi) Get(query *GetProviderArticleQuery) (article v2.ProviderArticle, err error) { err = p.db.Scopes(query.Scope).Preload("CalculateProcess").First(&article).Error return } func (p *providerArticleApi) Upsert(article v2.ProviderArticle) error { if err := p.db.Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "provider_id"}, {Name: "sku_id"}}, DoUpdates: clause.AssignmentColumns([]string{"cost", "available", "updated_at", "ast"}), }).Create(&article).Error; err != nil { return err } if len(article.HistoryPrice) > 0 { return p.db.Save(&article.HistoryPrice).Error } return nil } func (p *providerArticleApi) UpdateStatus(article v2.ProviderArticle) error { return p.db.Where("id=?", article.ID).Select("status").Updates(&article).Error } func (p *providerArticleApi) FindInBatches(query *GetProviderArticleQuery, results *[]v2.ProviderArticle, f func(tx *gorm.DB, batch int) error) error { err := p.db.Scopes(query.Scope).Preload("CalculateProcess").FindInBatches(results, 20, f).Error return err } func (p *providerArticleApi) List(query PageListQuery, scopes ...func(db *gorm.DB) *gorm.DB) (articles []v2.ProviderArticle, total int64, err error) { err = p.db.Scopes(query.Scoper.Scope).Model(&v2.ProviderArticle{}).Count(&total).Error if err != nil { return } err = p.db.Scopes(query.Scope).Scopes(scopes...).Find(&articles).Error return } func (p *providerArticleApi) Find(query *GetProviderArticleQuery) (articles []v2.ProviderArticle, err error) { err = p.db.Scopes(query.Scope).Find(&articles).Error return } // 批量更新,更新价格时用到 func (p *providerArticleApi) BatchUpdate(articles []v2.ProviderArticle) error { return p.db.Select("id", "cost").Save(&articles).Error } func (p *providerArticleApi) AutoMigrate() error { return p.db.AutoMigrate(&v2.ProviderArticle{}, &v2.ProviderPrice{}) } func (p *providerArticleApi) ProviderPrice(providerArticleID uint) (history []v2.ProviderPrice, err error) { err = p.db.Find(&history, "provider_article_id = ?", providerArticleID).Error return } func (p *providerArticleApi) Update(providerArticle v2.ProviderArticle, selects ...string) error { if len(selects) == 0 { selects = []string{"Exclude"} } return p.db.Model(&providerArticle).Select(selects).Omit(clause.Associations).Updates(&providerArticle).Error }