package storage

import (
	"strings"

	v2 "gitea.timerzz.com/kedaya_haitao/common/structs/v2"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

// ProviderArticleApi is the persistence interface for provider (supplier)
// articles and their price history.
type ProviderArticleApi interface {
	// Get returns the first article matching query.
	Get(query *GetProviderArticleQuery) (article v2.ProviderArticle, err error)
	// Find returns every article matching query.
	Find(query *GetProviderArticleQuery) (articles []*v2.ProviderArticle, err error)
	// List returns the articles selected by query (plus any extra scopes)
	// together with the total number of matching rows.
	List(query PageListQuery, scopes ...func(db *gorm.DB) *gorm.DB) (articles []v2.ProviderArticle, total int64, err error)
	// FindInBatches walks the articles matching query in batches, loading each
	// batch into results and invoking f after it.
	FindInBatches(query *GetProviderArticleQuery, results *[]v2.ProviderArticle, f func(tx *gorm.DB, batch int) error) error
	// ProviderPrice returns the price history rows of one provider article.
	ProviderPrice(providerArticleID uint) (history []v2.ProviderPrice, err error)

	// Upsert inserts article, or updates the existing row that has the same
	// provider_id and sku_id.
	Upsert(article v2.ProviderArticle) error
	// Update writes the selected fields of providerArticle to its row.
	Update(providerArticle v2.ProviderArticle, selects ...string) error
	// UpdateStatus writes only the status column of article.
	UpdateStatus(article v2.ProviderArticle) error
	// BatchUpdate writes the cost of each article; used when refreshing prices.
	BatchUpdate(articles []v2.ProviderArticle) error

	// AutoMigrate creates or migrates the tables backing this store.
	AutoMigrate() error
}

// providerArticleApi is the gorm-backed implementation of ProviderArticleApi.
type providerArticleApi struct {
	// db is the gorm handle every method builds its queries from.
	db *gorm.DB
}

// NewProviderArticleApi returns a ProviderArticleApi whose queries all run
// against db.
func NewProviderArticleApi(db *gorm.DB) ProviderArticleApi {
	api := new(providerArticleApi)
	api.db = db
	return api
}

// ******************GET

// GetProviderArticleQuery collects optional filters for provider-article
// lookups; Scope turns every non-zero field into a WHERE condition.
// The `query` tags presumably drive binding from HTTP query parameters — TODO confirm.
type GetProviderArticleQuery struct {
	ID              uint          `query:"id"`              // id = ?, skipped when 0
	Brand           v2.Brand      `query:"brand"`           // brand = ?, skipped when empty
	Pid             string        `query:"pid"`             // pid = ?, skipped when empty
	ProviderId      v2.ProviderId `query:"providerId"`      // provider_id = ?, skipped when empty
	SkuId           string        `query:"skuId"`           // sku_id = ?, skipped when empty
	WatchNotNull    bool          `query:"watchNotNull"`    // when true, adds Not("watch", nil)
	Watch           *bool         `query:"watch"`           // watch = ?, only when non-nil (so false is filterable)
	TraceAtsNotNull bool          `query:"traceAtsNotNull"` // when true, adds Not("trace_ats", nil)
	TraceAts        *bool         `query:"traceAts"`        // trace_ats = ?, only when non-nil
	Keyword         string        `query:"keyword"`         // case-insensitive substring match on sku_id
}

// NewGetProviderArticleQuery returns an empty query (no filters) ready for the
// chainable Set* methods.
func NewGetProviderArticleQuery() *GetProviderArticleQuery {
	return new(GetProviderArticleQuery)
}

// SetID sets the primary-key filter and returns g for chaining.
// Scope ignores an ID of 0.
func (g *GetProviderArticleQuery) SetID(id uint) *GetProviderArticleQuery {
	g.ID = id
	return g
}
// SetBrand filters on the brand column and returns g for chaining.
// An empty brand disables the filter.
func (g *GetProviderArticleQuery) SetBrand(brand v2.Brand) *GetProviderArticleQuery {
	g.Brand = brand
	return g
}
// SetPid filters on the pid column and returns g for chaining.
// An empty pid disables the filter.
func (g *GetProviderArticleQuery) SetPid(pid string) *GetProviderArticleQuery {
	g.Pid = pid
	return g
}

// SetProviderId filters on the provider_id column and returns g for chaining.
// An empty provider id disables the filter.
func (g *GetProviderArticleQuery) SetProviderId(providerId v2.ProviderId) *GetProviderArticleQuery {
	g.ProviderId = providerId
	return g
}

// SetSkuId filters on an exact sku_id and returns g for chaining.
// An empty value disables the filter (see SetKeyword for substring search).
func (g *GetProviderArticleQuery) SetSkuId(skuId string) *GetProviderArticleQuery {
	g.SkuId = skuId
	return g
}

// SetWatch filters on watch = value and returns g for chaining.
// The value is stored behind a pointer so Scope can distinguish an explicit
// false from "no filter".
func (g *GetProviderArticleQuery) SetWatch(watch bool) *GetProviderArticleQuery {
	want := watch
	g.Watch = &want
	return g
}

// SetWatchNotNull toggles the "watch is not NULL" condition and returns g for
// chaining; false leaves it out.
func (g *GetProviderArticleQuery) SetWatchNotNull(watchNotNull bool) *GetProviderArticleQuery {
	g.WatchNotNull = watchNotNull
	return g
}

// SetKeyword sets a case-insensitive substring search on sku_id and returns g
// for chaining. An empty keyword disables the search.
func (g *GetProviderArticleQuery) SetKeyword(keyword string) *GetProviderArticleQuery {
	g.Keyword = keyword
	return g
}

// SetTraceAts filters on trace_ats = value and returns g for chaining.
// The value is stored behind a pointer so an explicit false is still applied.
func (g *GetProviderArticleQuery) SetTraceAts(traceAts bool) *GetProviderArticleQuery {
	want := traceAts
	g.TraceAts = &want
	return g
}

// SetTraceAtsNotNull toggles the "trace_ats is not NULL" condition and
// returns g for chaining; false leaves it out.
func (g *GetProviderArticleQuery) SetTraceAtsNotNull(traceAtsNotNull bool) *GetProviderArticleQuery {
	g.TraceAtsNotNull = traceAtsNotNull
	return g
}

// Scope narrows db with one WHERE condition per filter set on g and returns
// the resulting query; it is meant to be passed to gorm's Scopes.
//
// Zero-valued fields add nothing: ID 0, empty strings, false *NotNull flags
// and nil Watch/TraceAts pointers are skipped. Watch and TraceAts are
// pointers so that an explicit false can still be filtered on.
// WatchNotNull/TraceAtsNotNull add a gorm Not(column, nil) condition.
//
// Keyword is a case-insensitive substring match on sku_id. The LIKE
// metacharacters % and _ and the escape character \ inside the keyword are
// escaped, so user input is matched literally rather than as a pattern.
//
// A nil receiver applies no filters (the same as an empty query) instead of
// panicking on the nil dereference.
func (g *GetProviderArticleQuery) Scope(db *gorm.DB) *gorm.DB {
	if g == nil {
		return db
	}
	if g.ID > 0 {
		db = db.Where("id=?", g.ID)
	}
	if g.Brand != "" {
		db = db.Where("brand=?", g.Brand)
	}
	if g.Pid != "" {
		db = db.Where("pid=?", g.Pid)
	}
	if g.ProviderId != "" {
		db = db.Where("provider_id=?", g.ProviderId)
	}
	if g.SkuId != "" {
		db = db.Where("sku_id=?", g.SkuId)
	}
	if g.WatchNotNull {
		db = db.Not("watch", nil)
	}
	if g.Watch != nil {
		db = db.Where("watch=?", *g.Watch)
	}
	if g.TraceAtsNotNull {
		db = db.Not("trace_ats", nil)
	}
	if g.TraceAts != nil {
		db = db.Where("trace_ats=?", *g.TraceAts)
	}
	if g.Keyword != "" {
		// ILIKE treats % and _ as wildcards and \ as its default escape
		// character; escape all three so the keyword is a literal substring.
		escaped := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(g.Keyword)
		db = db.Where("sku_id ilike ?", "%"+escaped+"%")
	}
	return db
}

// Get loads the first article matching query, preloading its
// CalculateProcess association. The error from gorm's First is returned as-is.
func (p *providerArticleApi) Get(query *GetProviderArticleQuery) (article v2.ProviderArticle, err error) {
	withProcess := p.db.Scopes(query.Scope).Preload("CalculateProcess")
	err = withProcess.First(&article).Error
	return article, err
}

// Upsert inserts article. When a row with the same (provider_id, sku_id)
// already exists, it instead updates only that row's cost, available,
// updated_at and ats columns. Any HistoryPrice records on article are then
// saved.
//
// Both writes run in one transaction, so a failure while saving the price
// history rolls back the article write instead of leaving it committed alone.
func (p *providerArticleApi) Upsert(article v2.ProviderArticle) error {
	return p.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Clauses(clause.OnConflict{
			Columns:   []clause.Column{{Name: "provider_id"}, {Name: "sku_id"}},
			DoUpdates: clause.AssignmentColumns([]string{"cost", "available", "updated_at", "ats"}),
		}).Create(&article).Error; err != nil {
			return err
		}
		// Skip the second write when there is no history to store.
		if len(article.HistoryPrice) == 0 {
			return nil
		}
		return tx.Save(&article.HistoryPrice).Error
	})
}
// UpdateStatus persists only the status column of article, on the row whose
// id equals article.ID.
// Because the id condition is explicit, a zero ID yields "id=0", which
// presumably matches no row.
func (p *providerArticleApi) UpdateStatus(article v2.ProviderArticle) error {
	byID := p.db.Where("id=?", article.ID)
	onlyStatus := byID.Select("status")
	return onlyStatus.Updates(&article).Error
}

// FindInBatches loads the articles matching query into results, 20 rows at a
// time, with CalculateProcess preloaded. It hands each batch to f, together
// with the transaction handle and batch number gorm supplies.
func (p *providerArticleApi) FindInBatches(query *GetProviderArticleQuery, results *[]v2.ProviderArticle, f func(tx *gorm.DB, batch int) error) error {
	const batchSize = 20
	return p.db.Scopes(query.Scope).
		Preload("CalculateProcess").
		FindInBatches(results, batchSize, f).
		Error
}

// List returns the articles selected by query and the total number of rows
// matching query's filters.
//
// total is counted with query.Scoper.Scope only. The rows are fetched with
// query.Scope, presumably filters plus pagination (TODO confirm in
// PageListQuery), followed by any extra scopes.
// NOTE(review): the extra scopes are not applied to the count; if a caller
// passes a filtering scope, total will disagree with the rows — confirm
// this is intended.
func (p *providerArticleApi) List(query PageListQuery, scopes ...func(db *gorm.DB) *gorm.DB) (articles []v2.ProviderArticle, total int64, err error) {
	counter := p.db.Scopes(query.Scoper.Scope).Model(&v2.ProviderArticle{})
	if err = counter.Count(&total).Error; err != nil {
		return articles, total, err
	}
	page := p.db.Scopes(query.Scope).Scopes(scopes...)
	err = page.Find(&articles).Error
	return articles, total, err
}

// Find returns every article matching query, without preloading associations.
func (p *providerArticleApi) Find(query *GetProviderArticleQuery) (articles []*v2.ProviderArticle, err error) {
	filtered := p.db.Scopes(query.Scope)
	err = filtered.Find(&articles).Error
	return articles, err
}

// BatchUpdate writes the cost of each article back to the database, keyed by
// primary key. It is used when updating prices.
//
// Only the id and cost columns are sent (Select("id", "cost")).
// NOTE(review): gorm's Save on a slice is an INSERT ... ON CONFLICT upsert, so
// an id that does not exist yet would be inserted with only id and cost —
// confirm callers only pass existing rows.
//
// A nil or empty slice is a no-op. gorm rejects empty slices with
// gorm.ErrEmptySlice ("empty slice found"), which callers would otherwise
// have to special-case.
func (p *providerArticleApi) BatchUpdate(articles []v2.ProviderArticle) error {
	if len(articles) == 0 {
		return nil
	}
	return p.db.Select("id", "cost").Save(&articles).Error
}

// AutoMigrate runs gorm auto-migration for the provider article, provider
// price and provider ATS models.
func (p *providerArticleApi) AutoMigrate() error {
	models := []interface{}{
		&v2.ProviderArticle{},
		&v2.ProviderPrice{},
		&v2.ProviderAts{},
	}
	return p.db.AutoMigrate(models...)
}

// ProviderPrice returns all price-history rows whose provider_article_id
// equals providerArticleID.
func (p *providerArticleApi) ProviderPrice(providerArticleID uint) (history []v2.ProviderPrice, err error) {
	byArticle := p.db.Where("provider_article_id = ?", providerArticleID)
	err = byArticle.Find(&history).Error
	return history, err
}

// Update writes the fields of providerArticle named in selects to the row
// identified by providerArticle (via Model); associations are always omitted.
// Because the update goes through Select, zero values in the chosen fields
// are written too.
// With no selects, only "Exclude" is updated (presumably a field of
// v2.ProviderArticle — confirm).
func (p *providerArticleApi) Update(providerArticle v2.ProviderArticle, selects ...string) error {
	columns := selects
	if len(columns) == 0 {
		columns = []string{"Exclude"}
	}
	return p.db.
		Model(&providerArticle).
		Select(columns).
		Omit(clause.Associations).
		Updates(&providerArticle).
		Error
}