package storage

import (
	"fmt"

	v2 "gitea.timerzz.com/kedaya_haitao/common/structs/v2"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

// ArticleApi is the persistence interface for v2.Article records.
type ArticleApi interface {
	Create(article *v2.Article) error
	Upsert(article v2.Article) error
	Update(article v2.Article, selects ...string) error
	Find(query *FindArticleQuery) (articles []v2.Article, err error)
	List(query PageListQuery) (articles []v2.Article, total int64, err error)
	Get(query *GetArticleQuery) (article v2.Article, err error)
	AutoMigrate() error
}

// articleApi is the GORM-backed implementation of ArticleApi.
type articleApi struct {
	db *gorm.DB
}

// NewArticleApi returns an ArticleApi that runs its queries on db.
func NewArticleApi(db *gorm.DB) ArticleApi {
	return &articleApi{
		db: db,
	}
}

// Create inserts article and returns the error reported by GORM, if any.
func (a *articleApi) Create(article *v2.Article) error {
	return a.db.Create(article).Error
}

// Upsert inserts the article, or updates it when a row with the same
// (pid, brand) pair already exists.
//
// On conflict only the columns listed in DoUpdates are overwritten. After the
// article itself is written, non-empty Providers and Sellers are saved as
// well. All statements run inside one transaction, so returning an error from
// any step rolls the whole upsert back.
//
// Errors from saving providers or sellers are wrapped with %w so callers can
// still match the underlying error with errors.Is / errors.As.
func (a *articleApi) Upsert(article v2.Article) error {
	return a.db.Transaction(func(tx *gorm.DB) error {
		onConflict := clause.OnConflict{
			Columns:   []clause.Column{{Name: "pid"}, {Name: "brand"}},
			DoUpdates: clause.AssignmentColumns([]string{"available", "updated_at", "cost_price", "sell_price", "rate", "remark", "exclude"}),
		}
		if err := tx.Clauses(onConflict).Create(&article).Error; err != nil {
			return err
		}

		if len(article.Providers) > 0 {
			if err := tx.Save(&article.Providers).Error; err != nil {
				return fmt.Errorf("failed to save providers: %w", err)
			}
		}

		if len(article.Sellers) > 0 {
			if err := tx.Save(&article.Sellers).Error; err != nil {
				return fmt.Errorf("failed to save sellers: %w", err)
			}
		}
		return nil
	})
}

// Update writes the selected fields of article back to its row.
//
// selects names the fields to update; when none are given it defaults to
// "Remark" and "Exclude". Associations are omitted from the update.
func (a *articleApi) Update(article v2.Article, selects ...string) error {
	if len(selects) == 0 {
		selects = []string{"Remark", "Exclude"}
	}
	return a.db.Model(&article).Select(selects).Omit(clause.Associations).Updates(article).Error
}

// ****************** Find and List ******************

// FindArticleQuery holds the optional filters used by Find (and applied via
// Scope). A field left at its zero value adds no condition. Available is a
// pointer so that an explicit false can be told apart from "not set".
type FindArticleQuery struct {
	ID        uint     `query:"id"`
	Keyword   string   `query:"keyword"`
	Brand     v2.Brand `query:"brand"`
	Pid       string   `query:"pid"`
	Available *bool    `query:"available"`
}

// NewFindArticleQuery returns an empty FindArticleQuery (no filters).
func NewFindArticleQuery() *FindArticleQuery {
	return &FindArticleQuery{}
}

// SetID sets the id filter and returns f for chaining.
func (f *FindArticleQuery) SetID(id uint) *FindArticleQuery {
	f.ID = id
	return f
}

// SetKeyword sets the keyword filter and returns f for chaining.
func (f *FindArticleQuery) SetKeyword(keyword string) *FindArticleQuery {
	f.Keyword = keyword
	return f
}

// SetBrand sets the brand filter and returns f for chaining.
func (f *FindArticleQuery) SetBrand(brand v2.Brand) *FindArticleQuery {
	f.Brand = brand
	return f
}

// SetPid sets the pid filter and returns f for chaining.
func (f *FindArticleQuery) SetPid(pid string) *FindArticleQuery {
	f.Pid = pid
	return f
}

// SetAvailable sets the availability filter and returns f for chaining.
func (f *FindArticleQuery) SetAvailable(available bool) *FindArticleQuery {
	f.Available = &available
	return f
}

// Scope adds a WHERE condition to db for every filter that is set and returns
// the resulting *gorm.DB. The keyword is matched as a substring, using ILIKE,
// against the name, english_name and remark columns.
func (f *FindArticleQuery) Scope(db *gorm.DB) *gorm.DB {
	if f.ID != 0 {
		db = db.Where("id=?", f.ID)
	}
	if f.Keyword != "" {
		// The same pattern is bound to all three placeholders.
		pattern := "%" + f.Keyword + "%"
		db = db.Where("(name ilike ? OR english_name ilike ? OR remark ilike ? )", pattern, pattern, pattern)
	}
	if f.Brand != "" {
		db = db.Where("brand=?", f.Brand)
	}
	if f.Pid != "" {
		db = db.Where("pid=?", f.Pid)
	}
	if f.Available != nil {
		db = db.Where("available=?", *f.Available)
	}
	return db
}

// Find returns every article matching the filters in query.
func (a *articleApi) Find(query *FindArticleQuery) (articles []v2.Article, err error) {
	err = a.db.Scopes(query.Scope).Find(&articles).Error
	return articles, err
}

// List returns the articles selected by query, ordered by id, together with
// the total row count.
//
// The count is taken with query.Scoper.Scope only, while the rows are fetched
// with query.Scope.
// NOTE(review): PageListQuery is declared elsewhere; presumably query.Scope
// adds paging on top of the Scoper's filter — confirm against its definition.
func (a *articleApi) List(query PageListQuery) (articles []v2.Article, total int64, err error) {
	err = a.db.Scopes(query.Scoper.Scope).Model(&v2.Article{}).Count(&total).Error
	if err != nil {
		return articles, total, err
	}
	err = a.db.Scopes(query.Scope).Order("id").Find(&articles).Error
	return articles, total, err
}

// ****************** Get ******************

// GetArticleQuery holds the optional filters used by Get. A field left at its
// zero value adds no condition. History additionally requests the
// HistoryPrice of providers and sellers.
type GetArticleQuery struct {
	ID      uint     `query:"id"`
	Brand   v2.Brand `query:"brand"`
	Pid     string   `query:"pid"`
	History bool     `query:"history"`
}

// SetID sets the id filter and returns g for chaining.
func (g *GetArticleQuery) SetID(id uint) *GetArticleQuery {
	g.ID = id
	return g
}

// SetBrand sets the brand filter and returns g for chaining.
func (g *GetArticleQuery) SetBrand(brand v2.Brand) *GetArticleQuery {
	g.Brand = brand
	return g
}

// SetPid sets the pid filter and returns g for chaining.
func (g *GetArticleQuery) SetPid(pid string) *GetArticleQuery {
	g.Pid = pid
	return g
}

// SetHistory toggles preloading of price history and returns g for chaining.
func (g *GetArticleQuery) SetHistory(history bool) *GetArticleQuery {
	g.History = history
	return g
}

// NewGetArticleQuery returns an empty GetArticleQuery (no filters).
func NewGetArticleQuery() *GetArticleQuery {
	return &GetArticleQuery{}
}

// Scope preloads Providers and Sellers (each with its CalculateProcess), adds
// a WHERE condition for every filter that is set, and, when History is true,
// also preloads Providers.HistoryPrice and Sellers.HistoryPrice.
func (g *GetArticleQuery) Scope(db *gorm.DB) *gorm.DB {
	db = db.Preload("Providers").
		Preload("Providers.CalculateProcess").
		Preload("Sellers").
		Preload("Sellers.CalculateProcess")
	if g.ID != 0 {
		db = db.Where("id=?", g.ID)
	}
	if g.Brand != "" {
		db = db.Where("brand=?", g.Brand)
	}
	if g.Pid != "" {
		db = db.Where("pid=?", g.Pid)
	}
	if g.History {
		db = db.Preload("Providers.HistoryPrice").Preload("Sellers.HistoryPrice")
	}
	return db
}

// Get loads the first article matching query, with the associations that
// query.Scope preloads. The error from GORM's First is returned as-is.
func (a *articleApi) Get(query *GetArticleQuery) (article v2.Article, err error) {
	err = a.db.Scopes(query.Scope).First(&article).Error
	return article, err
}

// AutoMigrate runs GORM's auto-migration for v2.Article.
func (a *articleApi) AutoMigrate() error {
	return a.db.AutoMigrate(&v2.Article{})
}