Skip to content
Snippets Groups Projects
Unverified Commit b59c7d72 authored by kirinzhong's avatar kirinzhong Committed by GitHub
Browse files

feat: 优化 Repository 接口设计 (#13)

1. Repository 接口名优化:Create -> Add; Delete -> Remove
2. Repostory Add 和 Remove 方法不直接返回错误,改为执行完 Main 后统一返回
3. 按照新的框架设计,修改了 example
parent 569d4354
No related branches found
No related tags found
No related merge requests found
......@@ -40,7 +40,7 @@ type ICommand interface {
// ICommandMain Command 的业务逻辑,对应 Build + Act 方法
type ICommandMain interface {
Main(ctx context.Context, repo Repository) (err error)
Main(ctx context.Context, repo *Repository) (err error)
}
// ICommandInit Command 的初始化方法,会在锁和事务之前执行,可进行数据校验,前置准备工作
......
......@@ -81,14 +81,13 @@ func TestCommandMain(t *testing.T) {
var id string
res := engine.NewStage().Main(func(ctx context.Context, repo Repository) error {
res := engine.NewStage().Main(func(ctx context.Context, repo *Repository) error {
// 创建
o := &order{
Title: "testCreate",
}
if err := repo.Create(o); err != nil {
return err
}
repo.Add(o)
if err := repo.Save(ctx); err != nil {
return err
}
......@@ -100,7 +99,7 @@ func TestCommandMain(t *testing.T) {
assert.NoError(t, res.Error)
assert.Contains(t, db.Data, id)
res = engine.NewStage().Main(func(ctx context.Context, repo Repository) error {
res = engine.NewStage().Main(func(ctx context.Context, repo *Repository) error {
o := &order{BaseEntity: NewBase(id)}
if err := repo.Get(ctx, o); err != nil {
return err
......@@ -126,9 +125,10 @@ func TestCommandMain(t *testing.T) {
po := db.Data[id]
assert.Equal(t, "update_2", po.Name)
res = engine.NewStage().Main(func(ctx context.Context, repo Repository) error {
res = engine.NewStage().Main(func(ctx context.Context, repo *Repository) error {
o := &order{BaseEntity: NewBase(id)}
return repo.Delete(o)
repo.Remove(o)
return nil
}).Save(ctx)
assert.NoError(t, res.Error)
......
......@@ -182,6 +182,15 @@ func (h *RootContainer) Remove(root IEntity) {
// Repository 聚合根实体仓库
type Repository struct {
stage *Stage
errs []error
}
func (r *Repository) appendError(e error) {
r.errs = append(r.errs, e)
}
func (r *Repository) getError() error {
return errors.Join(r.errs...)
}
// Get 查询并构建聚合根
......@@ -202,8 +211,8 @@ func (r *Repository) Get(ctx context.Context, root IEntity, children ...interfac
return r.stage.updateSnapshot()
}
// GetManual 自定义函数获取聚合根实体,并添加到快照
func (r *Repository) GetManual(ctx context.Context, getter func(ctx context.Context, root ...IEntity), roots ...IEntity) error {
// CustomGet 自定义函数获取聚合根实体,并添加到快照
func (r *Repository) CustomGet(ctx context.Context, getter func(ctx context.Context, root ...IEntity), roots ...IEntity) error {
getter(ctx, roots...)
for _, root := range roots {
......@@ -218,21 +227,22 @@ func (r *Repository) GetManual(ctx context.Context, getter func(ctx context.Cont
return r.stage.updateSnapshot()
}
// Create 创建聚合根
func (r *Repository) Create(roots ...IEntity) error {
// Add 添加聚合根到仓库
func (r *Repository) Add(roots ...IEntity) {
for _, root := range roots {
if r.stage.hasSnapshot(root) {
return fmt.Errorf("root must be a new entity")
r.appendError(fmt.Errorf("root must be a new entity"))
return
}
if err := r.stage.meta.Add(root); err != nil {
return err
r.appendError(err)
return
}
}
return nil
}
// Delete 删除聚合根,root.GetID 不能为空
func (r *Repository) Delete(roots ...IEntity) error {
// Remove 删除聚合根,root.GetID 不能为空
func (r *Repository) Remove(roots ...IEntity) {
toCreate := make([]IEntity, 0)
for _, root := range roots {
if !r.stage.meta.Has(root) {
......@@ -240,30 +250,32 @@ func (r *Repository) Delete(roots ...IEntity) error {
}
}
if len(toCreate) > 0 {
if err := r.Create(toCreate...); err != nil {
return err
}
r.Add(toCreate...)
if err := r.stage.updateSnapshot(); err != nil {
return err
r.appendError(err)
return
}
}
for _, root := range roots {
if err := r.stage.meta.Remove(root); err != nil {
return err
r.appendError(err)
return
}
}
return nil
}
// Save 执行一次保存,并刷新快照
func (r *Repository) Save(ctx context.Context) error {
if err := r.getError(); err != nil {
return err
}
return r.stage.commit(ctx)
}
type BuildFunc func(ctx context.Context, h DomainBuilder) (roots []IEntity, err error)
type ActFunc func(ctx context.Context, container RootContainer, roots ...IEntity) error
type MainFunc func(ctx context.Context, repo Repository) error
type MainFunc func(ctx context.Context, repo *Repository) error
type PostSaveFunc func(ctx context.Context, res *Result)
// EventHandlerConstruct EventHandler 的构造函数,带一个入参和一个返回值,入参是与事件类型匹配的事件数据指针类型,返回值是 ICommand
......@@ -427,14 +439,16 @@ func (e *Engine) NewStage() *Stage {
}
func (e *Engine) Create(ctx context.Context, roots ...IEntity) *Result {
return e.NewStage().Main(func(ctx context.Context, repo Repository) error {
return repo.Create(roots...)
return e.NewStage().Main(func(ctx context.Context, repo *Repository) error {
repo.Add(roots...)
return nil
}).Save(ctx)
}
func (e *Engine) Delete(ctx context.Context, roots ...IEntity) *Result {
return e.NewStage().Main(func(ctx context.Context, repo Repository) error {
return repo.Delete(roots...)
return e.NewStage().Main(func(ctx context.Context, repo *Repository) error {
repo.Remove(roots...)
return nil
}).Save(ctx)
}
......@@ -703,7 +717,7 @@ func (e *Stage) runCommand(ctx context.Context, c ICommand) *Result {
if err != nil {
return ResultErrOrBreak(err)
}
return e.WithOption(PostSaveOption(c.PostSave)).Lock(keys...).Main(func(ctx context.Context, repo Repository) error {
return e.WithOption(PostSaveOption(c.PostSave)).Lock(keys...).Main(func(ctx context.Context, repo *Repository) error {
buildRoots, err := c.Build(ctx, DomainBuilder{stage: repo.stage})
if err != nil {
return err
......@@ -752,7 +766,7 @@ func (e *Stage) Run(ctx context.Context, cmd interface{}) *Result {
options = append(options, PostSaveOption(cmdPostSave.PostSave))
}
return e.WithOption(options...).Lock(keys...).Main(c.Main).Save(ctx)
case func(ctx context.Context, repo Repository) error:
case func(ctx context.Context, repo *Repository) error:
return e.Main(c).Save(ctx)
}
panic("cmd is invalid")
......@@ -1078,9 +1092,13 @@ func (e *Stage) do(ctx context.Context) *Result {
// 创建聚合
var err error
if e.main != nil {
if err := e.main(ctx, Repository{stage: e}); err != nil {
repo := &Repository{stage: e}
if err := e.main(ctx, repo); err != nil {
return ResultErrOrBreak(err)
}
if err := repo.getError(); err != nil {
return ResultError(err)
}
}
err = e.persist(ctx)
......
......@@ -205,7 +205,7 @@ func TestRootContainer(t *testing.T) {
func TestBuildError(t *testing.T) {
ctx := context.Background()
res := NewEngine(nil, nil, WithoutTransaction).NewStage().Main(func(ctx context.Context, repo Repository) error {
res := NewEngine(nil, nil, WithoutTransaction).NewStage().Main(func(ctx context.Context, repo *Repository) error {
return fmt.Errorf("test")
}).Save(ctx)
assert.Error(t, res.Error)
......@@ -266,8 +266,8 @@ func TestEntityMove(t *testing.T) {
engine := NewEngine(testsuit.NewMemLock(), &MapExecutor{DB: &db})
engine.Create(ctx, testOrder1, testOrder2)
res := engine.NewStage().Main(func(ctx context.Context, repo Repository) error {
if err := repo.GetManual(ctx, func(ctx context.Context, root ...IEntity) {}, testOrder1, testOrder2); err != nil {
res := engine.NewStage().Main(func(ctx context.Context, repo *Repository) error {
if err := repo.CustomGet(ctx, func(ctx context.Context, root ...IEntity) {}, testOrder1, testOrder2); err != nil {
return err
}
testOrder1.Products = nil
......@@ -443,12 +443,13 @@ func TestEventPersist(t *testing.T) {
ID: event.ID,
Name: string(event.Type),
}, nil
})).Run(ctx, func(ctx context.Context, repo Repository) error {
})).Run(ctx, func(ctx context.Context, repo *Repository) error {
testOrder = &order{
Title: "order1",
}
testOrder.AddEvent(&testEvent{Data: "hello"})
return repo.Create(testOrder)
repo.Add(testOrder)
return nil
})
assert.NoError(t, res.Error)
assert.NotEmpty(t, testOrder.GetEvents())
......@@ -499,8 +500,8 @@ func BenchmarkUpdateOrders(b *testing.B) {
}
title := fmt.Sprintf("update %d", j)
newID := fmt.Sprintf("new%d", j)
res := NewEngine(nil, &MapExecutor{DB: &db}).Run(context.Background(), func(ctx context.Context, repo Repository) error {
if err := repo.GetManual(ctx, func(ctx context.Context, root ...IEntity) {}, testOrder); err != nil {
res := NewEngine(nil, &MapExecutor{DB: &db}).Run(context.Background(), func(ctx context.Context, repo *Repository) error {
if err := repo.CustomGet(ctx, func(ctx context.Context, root ...IEntity) {}, testOrder); err != nil {
return err
}
testOrder.Title = title
......
......@@ -266,11 +266,12 @@ func TestEngine(t *testing.T) {
eventBus.Start(ctx)
engine := dddfirework.NewEngine(nil, exec_mysql.NewExecutor(db), eventBus.Options()...)
res := engine.NewStage().Main(func(ctx context.Context, repo dddfirework.Repository) error {
res := engine.NewStage().Main(func(ctx context.Context, repo *dddfirework.Repository) error {
e := &testEntity{Name: "hello"}
e.AddEvent(&testEvent{EType: "test_engine", Data: e.Name})
e.AddEvent(&testEvent{EType: "test_engine_tx", Data: e.Name}, dddfirework.WithSendType(dddfirework.SendTypeTransaction))
return repo.Create(e)
repo.Add(e)
return nil
}).Save(ctx)
wg.Wait()
......@@ -312,10 +313,11 @@ func TestTXChecker(t *testing.T) {
eventBus.Start(ctx)
engine := dddfirework.NewEngine(nil, &mockExecutor{}, eventBus.Options()...)
res := engine.NewStage().Main(func(ctx context.Context, repo dddfirework.Repository) error {
res := engine.NewStage().Main(func(ctx context.Context, repo *dddfirework.Repository) error {
e := &testEntity{Name: "test_commit_failed"}
e.AddEvent(&testEvent{EType: "test_commit_failed", Data: e.Name}, dddfirework.WithSendType(dddfirework.SendTypeTransaction))
return repo.Create(e)
repo.Add(e)
return nil
}).Save(ctx)
time.Sleep(time.Second * 1)
......
......@@ -20,13 +20,10 @@ import (
ddd "github.com/bytedance/dddfirework"
"github.com/bytedance/dddfirework/example/biz/sale/domain"
"github.com/bytedance/dddfirework/example/biz/sale/infrastructure/repo"
"github.com/bytedance/dddfirework/example/common/dto/sale"
)
type AddCouponCommand struct {
ddd.Command
orderID string
coupon *sale.Coupon
}
......@@ -42,16 +39,11 @@ func (c *AddCouponCommand) Init(ctx context.Context) ([]string, error) {
return []string{c.orderID}, nil
}
func (c *AddCouponCommand) Build(ctx context.Context, builder ddd.DomainBuilder) (roots []ddd.IEntity, err error) {
order, err := repo.GetOrder(ctx, builder, c.orderID)
if err != nil {
return nil, err
func (c *AddCouponCommand) Main(ctx context.Context, repo *ddd.Repository) error {
order := &domain.Order{ID: c.orderID}
if err := repo.Get(ctx, order); err != nil {
return err
}
return []ddd.IEntity{order}, nil
}
func (c *AddCouponCommand) Act(ctx context.Context, container ddd.RootContainer, roots ...ddd.IEntity) error {
order := roots[0].(*domain.Order)
if err := order.AddCoupon(c.coupon.ID, c.coupon.Rule, c.coupon.Discount); err != nil {
return err
}
......
......@@ -20,13 +20,10 @@ import (
"github.com/bytedance/dddfirework"
"github.com/bytedance/dddfirework/example/biz/sale/domain"
"github.com/bytedance/dddfirework/example/biz/sale/infrastructure/repo"
"github.com/bytedance/dddfirework/example/common/dto/sale"
)
type AddSaleItemCommand struct {
dddfirework.Command
orderID string
item *sale.SaleItem
}
......@@ -39,16 +36,11 @@ func (c *AddSaleItemCommand) Init(ctx context.Context) ([]string, error) {
return []string{c.orderID}, nil
}
func (c *AddSaleItemCommand) Build(ctx context.Context, h dddfirework.DomainBuilder) (roots []dddfirework.IEntity, err error) {
order, err := repo.GetOrder(ctx, h, c.orderID)
if err != nil {
return nil, err
func (c *AddSaleItemCommand) Main(ctx context.Context, repo *dddfirework.Repository) error {
order := &domain.Order{ID: c.orderID}
if err := repo.Get(ctx, order); err != nil {
return err
}
return []dddfirework.IEntity{order}, nil
}
func (c *AddSaleItemCommand) Act(ctx context.Context, container dddfirework.RootContainer, roots ...dddfirework.IEntity) error {
order := roots[0].(*domain.Order)
order.AddSaleItem(c.item.Code, c.item.Name, c.item.Price, c.item.Count)
return nil
}
......@@ -28,8 +28,6 @@ type CreateOrderResult struct {
}
type CreateOrderCommand struct {
dddfirework.Command
userID string
items []*sale.SaleItem
coupons []*sale.Coupon
......@@ -45,26 +43,18 @@ func NewCreateOrderCommand(userID string, items []*sale.SaleItem, coupons []*sal
}
}
func (c *CreateOrderCommand) Act(ctx context.Context, container dddfirework.RootContainer, roots ...dddfirework.IEntity) error {
func (c *CreateOrderCommand) Main(ctx context.Context, repo *dddfirework.Repository) error {
order, err := domain.NewOrder(c.userID, c.items, c.coupons)
if err != nil {
return err
}
container.Add(order)
repo.Add(order)
//Commit 操作为可选项,目的是为了即刻获得待新建的订单 ID,构造返回值
if err := c.Commit(ctx); err != nil {
// 持久化后为新实体自动设置 ID
if err := repo.Save(ctx); err != nil {
return err
}
c.Output(order)
c.Result = &CreateOrderResult{OrderID: order.ID}
return nil
}
func (c *CreateOrderCommand) PostSave(ctx context.Context, res *dddfirework.Result) {
if res.Error == nil {
c.Result = &CreateOrderResult{
OrderID: res.Output.(*domain.Order).ID,
}
}
}
......@@ -23,8 +23,6 @@ import (
)
type DeleteOrderCommand struct {
dddfirework.Command
orderID string
}
......@@ -34,11 +32,7 @@ func NewDeleteOrderCommand(orderID string) *DeleteOrderCommand {
}
}
func (c *DeleteOrderCommand) Build(ctx context.Context, builder dddfirework.DomainBuilder) (roots []dddfirework.IEntity, err error) {
return []dddfirework.IEntity{&domain.Order{ID: c.orderID}}, nil
}
func (c *DeleteOrderCommand) Act(ctx context.Context, container dddfirework.RootContainer, roots ...dddfirework.IEntity) error {
container.Remove(roots[0])
func (c *DeleteOrderCommand) Main(ctx context.Context, repo *dddfirework.Repository) error {
repo.Remove(&domain.Order{ID: c.orderID})
return nil
}
......@@ -20,7 +20,6 @@ import (
"github.com/bytedance/dddfirework"
"github.com/bytedance/dddfirework/example/biz/sale/domain"
"github.com/bytedance/dddfirework/example/biz/sale/infrastructure/repo"
)
type UpdateOrderOpt struct {
......@@ -28,8 +27,6 @@ type UpdateOrderOpt struct {
}
type UpdateOrderCommand struct {
dddfirework.Command
orderID string
opt UpdateOrderOpt
}
......@@ -45,16 +42,11 @@ func (c *UpdateOrderCommand) Init(ctx context.Context) (lockIDs []string, err er
return []string{c.orderID}, nil
}
func (c *UpdateOrderCommand) Build(ctx context.Context, builder dddfirework.DomainBuilder) (roots []dddfirework.IEntity, err error) {
order, err := repo.GetOrder(ctx, builder, c.orderID)
if err != nil {
return nil, err
func (c *UpdateOrderCommand) Main(ctx context.Context, repo *dddfirework.Repository) error {
order := &domain.Order{ID: c.orderID}
if err := repo.Get(ctx, order); err != nil {
return err
}
return []dddfirework.IEntity{order}, nil
}
func (c *UpdateOrderCommand) Act(ctx context.Context, container dddfirework.RootContainer, roots ...dddfirework.IEntity) error {
order := roots[0].(*domain.Order)
order.Update(domain.UpdateOrderOpt{Remark: c.opt.Remark})
return nil
}
......@@ -39,7 +39,7 @@ func (s *SaleServiceImpl) CreateOrder(ctx context.Context, req *CreateOrderReque
cmd := command.NewCreateOrderCommand(
req.User, req.Items, req.Coupons,
)
res := s.engine.RunCommand(ctx, cmd)
res := s.engine.Run(ctx, cmd)
if res.Error != nil {
return nil, res.Error
}
......@@ -48,7 +48,7 @@ func (s *SaleServiceImpl) CreateOrder(ctx context.Context, req *CreateOrderReque
// UpdateOrder implements the SaleServiceImpl interface.
func (s *SaleServiceImpl) UpdateOrder(ctx context.Context, req *UpdateOrderRequest) (resp *UpdateOrderResponse, err error) {
if err := s.engine.RunCommand(ctx, command.NewUpdateOrderCommand(
if err := s.engine.Run(ctx, command.NewUpdateOrderCommand(
req.ID, command.UpdateOrderOpt{Remark: req.Remark},
)).Error; err != nil {
return nil, err
......@@ -58,7 +58,7 @@ func (s *SaleServiceImpl) UpdateOrder(ctx context.Context, req *UpdateOrderReque
// DeleteOrder implements the SaleServiceImpl interface.
func (s *SaleServiceImpl) DeleteOrder(ctx context.Context, req *DeleteOrderRequest) (resp *DeleteOrderResponse, err error) {
if err := s.engine.RunCommand(ctx, command.NewDeleteOrderCommand(
if err := s.engine.Run(ctx, command.NewDeleteOrderCommand(
req.ID,
)).Error; err != nil {
return nil, err
......@@ -100,7 +100,7 @@ func (s *SaleServiceImpl) GetOrderList(ctx context.Context, req *GetOrderListReq
// AddSaleItem implements the SaleServiceImpl interface.
func (s *SaleServiceImpl) AddSaleItem(ctx context.Context, req *AddSaleItemRequest) (resp *AddSaleItemResponse, err error) {
if err := s.engine.RunCommand(ctx, command.NewAddSaleItemCommand(
if err := s.engine.Run(ctx, command.NewAddSaleItemCommand(
req.OrderID, req.Item,
)).Error; err != nil {
return nil, err
......@@ -110,7 +110,7 @@ func (s *SaleServiceImpl) AddSaleItem(ctx context.Context, req *AddSaleItemReque
// AddCoupon implements the SaleServiceImpl interface.
func (s *SaleServiceImpl) AddCoupon(ctx context.Context, req *AddCouponRequest) (resp *AddCouponResponse, err error) {
if err := s.engine.RunCommand(ctx, command.NewAddCouponCommand(
if err := s.engine.Run(ctx, command.NewAddCouponCommand(
req.OrderID, req.Coupon,
)).Error; err != nil {
return nil, err
......
......@@ -341,14 +341,15 @@ func TestRollback(t *testing.T) {
ctx := context.Background()
engine := ddd.NewEngine(testsuit.NewMemLock(), NewExecutor(db))
res := engine.Run(ctx, func(ctx context.Context, repo ddd.Repository) error {
return repo.Create(&order{
res := engine.Run(ctx, func(ctx context.Context, repo *ddd.Repository) error {
repo.Add(&order{
ID: "testrollback",
Title: "testrollback",
}, &order{
ID: "testrollback",
Title: "testrollback",
})
return nil
})
assert.NotNil(t, res.Error)
assert.ErrorIs(t, db.First(&orderPO{}, "id = ?", "testrollback").Error, gorm.ErrRecordNotFound)
......@@ -405,7 +406,7 @@ func BenchmarkUpdate(b *testing.B) {
engine := ddd.NewEngine(nil, NewExecutor(db))
for i := 0; i < b.N; i++ {
if res := engine.NewStage().Main(func(ctx context.Context, repo ddd.Repository) error {
if res := engine.NewStage().Main(func(ctx context.Context, repo *ddd.Repository) error {
testOrder := &order{
ID: "order1",
User: &user{},
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment