Skip to content
Snippets Groups Projects
engine.go 30.2 KiB
Newer Older
  • Learn to ignore specific revisions
  • zhongqiling's avatar
    zhongqiling committed
    //
    // Copyright 2023 Bytedance Ltd. and/or its affiliates
    //
    // Licensed under the Apache License, Version 2.0 (the "License");
    // you may not use this file except in compliance with the License.
    // You may obtain a copy of the License at
    //
    //     http://www.apache.org/licenses/LICENSE-2.0
    //
    // Unless required by applicable law or agreed to in writing, software
    // distributed under the License is distributed on an "AS IS" BASIS,
    // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    // See the License for the specific language governing permissions and
    // limitations under the License.
    
    package dddfirework
    
    import (
    	"context"
    	"encoding/json"
    	"errors"
    	"fmt"
    	stdlog "log"
    	"os"
    	"reflect"
    	"sort"
    	"strings"
    
    	"github.com/go-logr/logr"
    	"github.com/go-logr/stdr"
    	"github.com/rs/xid"
    )
    
    var (
    	// ErrBreak interrupts the process without reporting a failure:
    	// ResultErrOrBreak turns it into a Result with Break set.
    	ErrBreak = errors.New("break process")

    	// ErrEntityNotFound is returned when a queried entity has no stored
    	// record, or when a root to remove is not held by the container.
    	ErrEntityNotFound = errors.New("entity not found")

    	// ErrEntityRepeated is returned when a root is added to a container twice.
    	ErrEntityRepeated = errors.New("entity already added")
    )
    
    // defaultLogger is the logger used when no WithLogger option is given: a
    // standard-library logger writing to stderr with date, time and file:line
    // prefixes, exposed as a logr.Logger named "ddd_engine".
    var defaultLogger = stdr.New(
    	stdlog.New(os.Stderr, "", stdlog.LstdFlags|stdlog.Lshortfile),
    ).WithName("ddd_engine")
    
    type (
    	// ILock acquires and releases locks identified by string keys. Lock
    	// returns an opaque handle that has to be handed back to UnLock.
    	ILock interface {
    		Lock(ctx context.Context, key string) (keyLock interface{}, err error)
    		UnLock(ctx context.Context, keyLock interface{}) error
    	}

    	// IIDGenerator produces IDs for new entities that do not have one yet.
    	IIDGenerator interface {
    		NewID() (string, error)
    	}
    )
    
    // defaultIDGenerator is the IIDGenerator installed by NewEngine when no
    // WithIDGenerator option is given. It returns xid strings and never fails.
    type defaultIDGenerator struct{}

    // NewID returns the string form of a freshly generated xid.
    func (d *defaultIDGenerator) NewID() (string, error) {
    	return xid.New().String(), nil
    }
    
    // EntityContainer 负责维护领域内所有聚合根实体的实体
    type EntityContainer struct {
    	BaseEntity
    
    	roots   []IEntity // 保存聚合根实体
    	deleted []IEntity // 保存所有被删除实体
    }
    
    func (w *EntityContainer) GetChildren() map[string][]IEntity {
    	return map[string][]IEntity{"meta": w.roots}
    }
    
    
    func (w *EntityContainer) Has(root IEntity) bool {
    	for _, r := range w.roots {
    		if r == root {
    			return true
    		}
    	}
    	return false
    }
    
    
    zhongqiling's avatar
    zhongqiling committed
    func (w *EntityContainer) GetDeleted() []IEntity {
    	return w.deleted
    }
    
    func (w *EntityContainer) SetChildren(roots []IEntity) {
    	// EntityContainer 里面会有修改 roots 的操作,应当拷贝一个新的 slice,隔离输入的影响
    	w.roots = make([]IEntity, len(roots))
    	copy(w.roots, roots)
    }
    
    func (w *EntityContainer) Add(root IEntity) error {
    	for _, e := range w.roots {
    		if e == root {
    			return ErrEntityRepeated
    		}
    	}
    
    	w.roots = append(w.roots, root)
    	return nil
    }
    
    func (w *EntityContainer) Remove(root IEntity) error {
    	i := 0
    	for _, item := range w.roots {
    		if item != root {
    			w.roots[i] = item
    			i++
    		}
    	}
    	if i == len(w.roots) {
    		return ErrEntityNotFound
    	}
    	w.roots = w.roots[:i]
    	w.deleted = append(w.deleted, root)
    	return nil
    }
    
    // Recycle 回收所有被删除的实体
    func (w *EntityContainer) Recycle(e IEntity) {
    	w.deleted = append(w.deleted, e)
    }
    
    // ErrList aggregates several errors into one error value.
    type ErrList []error

    // Error joins the messages of all contained errors with ", ".
    // A nil entry is rendered as "<nil>" instead of causing a panic.
    func (e ErrList) Error() string {
    	msgs := make([]string, 0, len(e))
    	for _, err := range e {
    		if err == nil {
    			msgs = append(msgs, "<nil>")
    			continue
    		}
    		msgs = append(msgs, err.Error())
    	}
    	return strings.Join(msgs, ", ")
    }

    // Unwrap exposes the contained errors so that errors.Is and errors.As
    // (Go 1.20+) can match any of them, e.g. errors.Is(ErrList{ErrBreak}, ErrBreak).
    func (e ErrList) Unwrap() []error {
    	return e
    }
    
    // Result describes the outcome of a stage run.
    type Result struct {
    	// Error is the error that ended the run, if any.
    	Error   error
    	// Break is set when the run was stopped with ErrBreak (see ResultErrOrBreak).
    	Break   bool
    	// NOTE(review): presumably the persistence actions produced by the run — confirm.
    	Actions []*Action
    	// NOTE(review): presumably a value produced by the command for its caller — confirm.
    	Output  interface{}
    }
    
    // ResultErrors builds a Result whose Error aggregates the given errors as an
    // ErrList. nil entries are dropped. If no non-nil error remains, the Result
    // carries no error, so callers checking res.Error != nil do not see a
    // spurious failure caused by an empty ErrList.
    func ResultErrors(err ...error) *Result {
    	errs := make(ErrList, 0, len(err))
    	for _, e := range err {
    		if e != nil {
    			errs = append(errs, e)
    		}
    	}
    	if len(errs) == 0 {
    		return &Result{}
    	}
    	return &Result{Error: errs}
    }

    // ResultError builds a Result carrying err.
    func ResultError(err error) *Result {
    	return &Result{Error: err}
    }

    // ResultErrOrBreak turns ErrBreak (or an error wrapping it) into a Result
    // with Break set and no error; any other err becomes Result.Error.
    func ResultErrOrBreak(err error) *Result {
    	if errors.Is(err, ErrBreak) {
    		return &Result{Break: true}
    	}
    	return ResultError(err)
    }
    
    // DomainBuilder is handed to a command's Build step so it can load entities
    // through the current stage.
    type DomainBuilder struct {
    	stage *Stage
    }

    // Build loads parent and, optionally, its children. parent must have an ID.
    // Each child must be a pointer to an entity or a pointer to a slice of
    // entities (*IEntity or *[]IEntity style), as accepted by Stage.BuildEntity.
    func (b DomainBuilder) Build(ctx context.Context, parent IEntity, children ...interface{}) error {
    	st := b.stage
    	return st.BuildEntity(ctx, parent, children...)
    }
    
    // RootContainer is the aggregate-root container handed to a command's Act
    // step. Add and Remove register creations and deletions on the stage; instead
    // of being returned, failures are collected in errs.
    //
    // NOTE(review): Act receives RootContainer by value (see ActFunc), so errors
    // recorded by Add/Remove on that copy are not visible to whoever created the
    // container. Confirm whether errs should be shared, e.g. through a pointer.
    type RootContainer struct {
    	stage *Stage
    	errs  []error
    }

    // Add registers root as a new aggregate root. A failure, such as adding the
    // same root twice, is recorded in errs.
    func (c *RootContainer) Add(root IEntity) {
    	err := c.stage.meta.Add(root)
    	if err == nil {
    		return
    	}
    	c.errs = append(c.errs, err)
    }

    // Remove deletes the aggregate root. A failure, such as removing a root that
    // is not held, is recorded in errs.
    func (c *RootContainer) Remove(root IEntity) {
    	err := c.stage.meta.Remove(root)
    	if err == nil {
    		return
    	}
    	c.errs = append(c.errs, err)
    }
    
    
    // Repository gives a stage's main function access to aggregate roots:
    // loading them (Get, GetManual), registering new ones (Create), deleting
    // them (Delete) and persisting pending changes (Save).
    type Repository struct {
    	stage *Stage
    }

    // Get loads root (its ID must be set) and, optionally, the children passed as
    // *IEntity or *[]IEntity pointers. It then registers root with the stage and
    // snapshots it so later modifications can be diffed. Get fails if root
    // already has a snapshot.
    func (r *Repository) Get(ctx context.Context, root IEntity, children ...interface{}) error {
    	st := r.stage
    	if st.hasSnapshot(root) {
    		return fmt.Errorf("entity has added")
    	}

    	err := st.BuildEntity(ctx, root, children...)
    	if err == nil {
    		err = st.meta.Add(root)
    	}
    	if err != nil {
    		return err
    	}
    	return st.updateSnapshot()
    }
    
    // GetManual lets getter populate roots with custom loading logic. It then
    // registers every root with the stage and snapshots them. It fails if any
    // root already has a snapshot or is already registered.
    func (r *Repository) GetManual(ctx context.Context, getter func(ctx context.Context, root ...IEntity), roots ...IEntity) error {
    	getter(ctx, roots...)

    	st := r.stage
    	for i := range roots {
    		if st.hasSnapshot(roots[i]) {
    			return fmt.Errorf("entity has added")
    		}
    		if err := st.meta.Add(roots[i]); err != nil {
    			return err
    		}
    	}
    	return st.updateSnapshot()
    }
    
    // Create registers roots as new aggregate roots. A root that already has a
    // snapshot (i.e. was loaded) is rejected, as is one already registered.
    func (r *Repository) Create(roots ...IEntity) error {
    	st := r.stage
    	for i := 0; i < len(roots); i++ {
    		root := roots[i]
    		if st.hasSnapshot(root) {
    			return fmt.Errorf("root must be a new entity")
    		}
    		if err := st.meta.Add(root); err != nil {
    			return err
    		}
    	}
    	return nil
    }
    
    // Delete removes the given aggregate roots; each root's ID must be set.
    // Roots that the stage does not hold yet are first registered and
    // snapshotted, presumably so the later diff reports them as deleted rather
    // than ignoring them.
    func (r *Repository) Delete(roots ...IEntity) error {
    	var untracked []IEntity
    	for _, root := range roots {
    		if r.stage.meta.Has(root) {
    			continue
    		}
    		untracked = append(untracked, root)
    	}

    	if len(untracked) != 0 {
    		if err := r.Create(untracked...); err != nil {
    			return err
    		}
    		if err := r.stage.updateSnapshot(); err != nil {
    			return err
    		}
    	}

    	for _, root := range roots {
    		if err := r.stage.meta.Remove(root); err != nil {
    			return err
    		}
    	}
    	return nil
    }
    
    // Save persists the pending changes of the stage and refreshes its snapshot.
    func (r *Repository) Save(ctx context.Context) error {
    	stage := r.stage
    	return stage.commit(ctx)
    }
    
    
    zhongqiling's avatar
    zhongqiling committed
    type BuildFunc func(ctx context.Context, h DomainBuilder) (roots []IEntity, err error)
    type ActFunc func(ctx context.Context, container RootContainer, roots ...IEntity) error
    
    type MainFunc func(ctx context.Context, repo Repository) error
    
    zhongqiling's avatar
    zhongqiling committed
    type PostSaveFunc func(ctx context.Context, res *Result)
    
    // EventHandlerConstruct EventHandler 的构造函数,带一个入参和一个返回值,入参是与事件类型匹配的事件数据指针类型,返回值是 ICommand
    // 示例 func(evt *OrderCreatedEvent) *OnEventCreateCommand
    type EventHandlerConstruct interface{}
    
    type Options struct {
    	WithTransaction bool
    	RecursiveDelete bool         // 删除根实体是否递归删除所有子实体
    	EventPersist    EventPersist // 是否保存领域事件到 DB
    	Logger          logr.Logger
    	EventBus        IEventBus
    
    zhongqiling's avatar
    zhongqiling committed
    	IDGenerator     IIDGenerator
    	PostSaveHooks   []PostSaveFunc
    }
    
    // Option customizes Options. Options are accepted by NewEngine,
    // Stage.WithOption, Engine.Run and Engine.RunCommand.
    type Option interface {
    	ApplyToOptions(*Options)
    }

    // TransactionOption sets Options.WithTransaction.
    type TransactionOption bool

    // ApplyToOptions stores the option value in Options.WithTransaction.
    func (o TransactionOption) ApplyToOptions(opts *Options) {
    	opts.WithTransaction = bool(o)
    }

    const (
    	// WithTransaction enables transactional runs (the default).
    	WithTransaction = TransactionOption(true)

    	// WithoutTransaction disables transactional runs.
    	WithoutTransaction = TransactionOption(false)
    )

    // RecursiveDeleteOption sets Options.RecursiveDelete.
    type RecursiveDeleteOption bool

    // ApplyToOptions stores the option value in Options.RecursiveDelete.
    func (o RecursiveDeleteOption) ApplyToOptions(opts *Options) {
    	opts.RecursiveDelete = bool(o)
    }

    // WithRecursiveDelete makes root deletion also delete all children.
    const WithRecursiveDelete = RecursiveDeleteOption(true)

    // LoggerOption sets Options.Logger.
    type LoggerOption struct {
    	logger logr.Logger
    }

    // ApplyToOptions stores the wrapped logger in Options.Logger.
    func (o LoggerOption) ApplyToOptions(opts *Options) {
    	opts.Logger = o.logger
    }

    // WithLogger replaces the default logger.
    func WithLogger(logger logr.Logger) LoggerOption {
    	opt := LoggerOption{}
    	opt.logger = logger
    	return opt
    }

    // EventBusOption sets Options.EventBus.
    type EventBusOption struct {
    	eventBus IEventBus
    }

    // ApplyToOptions stores the wrapped bus in Options.EventBus.
    func (o EventBusOption) ApplyToOptions(opts *Options) {
    	opts.EventBus = o.eventBus
    }

    // WithEventBus replaces the default event bus.
    func WithEventBus(eventBus IEventBus) EventBusOption {
    	opt := EventBusOption{}
    	opt.eventBus = eventBus
    	return opt
    }
    
    
    type DTimerOption struct {
    	timer ITimer
    }
    
    func (t DTimerOption) ApplyToOptions(opts *Options) {
    	opts.Timer = t.timer
    }
    
    func WithTimer(timer ITimer) DTimerOption {
    	return DTimerOption{timer: timer}
    }
    
    
    zhongqiling's avatar
    zhongqiling committed
    type EventPersist func(event *DomainEvent) (IModel, error)
    
    type EventSaveOption EventPersist
    
    func (t EventSaveOption) ApplyToOptions(opts *Options) {
    	opts.EventPersist = EventPersist(t)
    }
    
    func WithEventPersist(f EventPersist) EventSaveOption {
    	return EventSaveOption(f)
    }
    
    type IDGeneratorOption struct {
    	idGen IIDGenerator
    }
    
    func (t IDGeneratorOption) ApplyToOptions(opts *Options) {
    	opts.IDGenerator = t.idGen
    }
    
    func WithIDGenerator(idGen IIDGenerator) IDGeneratorOption {
    	return IDGeneratorOption{idGen: idGen}
    }
    
    type PostSaveOption PostSaveFunc
    
    func (t PostSaveOption) ApplyToOptions(opts *Options) {
    	opts.PostSaveHooks = append(opts.PostSaveHooks, PostSaveFunc(t))
    }
    
    func WithPostSave(f PostSaveFunc) PostSaveOption {
    	return PostSaveOption(f)
    }
    
    type Engine struct {
    	locker      ILock
    	executor    IExecutor
    	idGenerator IIDGenerator
    	eventbus    IEventBus
    
    zhongqiling's avatar
    zhongqiling committed
    	logger      logr.Logger
    	options     Options
    }
    
    func NewEngine(l ILock, e IExecutor, opts ...Option) *Engine {
    	options := Options{
    		// 默认开启事务
    		WithTransaction: true,
    		Logger:          defaultLogger,
    		IDGenerator:     &defaultIDGenerator{},
    		EventBus:        &noEventBus{},
    
    zhongqiling's avatar
    zhongqiling committed
    	}
    	for _, opt := range opts {
    		opt.ApplyToOptions(&options)
    	}
    	eventBus := options.EventBus
    	eventBus.RegisterEventHandler(onEvent)
    	if txEB, ok := eventBus.(ITransactionEventBus); ok {
    		txEB.RegisterEventTXChecker(onTXChecker)
    	}
    
    	timer := options.Timer
    	timer.RegisterTimerHandler(onTimer)
    
    zhongqiling's avatar
    zhongqiling committed
    	return &Engine{
    		locker:      l,
    		executor:    e,
    		eventbus:    eventBus,
    
    zhongqiling's avatar
    zhongqiling committed
    		options:     options,
    		logger:      options.Logger,
    		idGenerator: options.IDGenerator,
    	}
    }
    
    func (e *Engine) NewStage() *Stage {
    	return &Stage{
    		locker:      e.locker,
    		executor:    e.executor,
    		eventBus:    e.eventbus,
    
    zhongqiling's avatar
    zhongqiling committed
    		idGenerator: e.idGenerator,
    		meta:        &EntityContainer{},
    
    		snapshot:    map[IEntity]*entitySnapshot{},
    
    zhongqiling's avatar
    zhongqiling committed
    		result:      &Result{},
    		options:     e.options,
    		logger:      e.logger,
    	}
    }
    
    func (e *Engine) Create(ctx context.Context, roots ...IEntity) *Result {
    
    	return e.NewStage().Main(func(ctx context.Context, repo Repository) error {
    		return repo.Create(roots...)
    
    zhongqiling's avatar
    zhongqiling committed
    	}).Save(ctx)
    }
    
    func (e *Engine) Delete(ctx context.Context, roots ...IEntity) *Result {
    
    	return e.NewStage().Main(func(ctx context.Context, repo Repository) error {
    		return repo.Delete(roots...)
    
    zhongqiling's avatar
    zhongqiling committed
    	}).Save(ctx)
    }
    
    
    // Deprecated: 请用 Run 方法代替
    
    zhongqiling's avatar
    zhongqiling committed
    func (e *Engine) RunCommand(ctx context.Context, c ICommand, opts ...Option) *Result {
    	return e.NewStage().WithOption(opts...).RunCommand(ctx, c)
    }
    
    
    // Run executes c on a new stage with opts applied and returns the result.
    // c may be a value implementing ICommand, a value implementing
    // ICommandMain, or a function of type
    // func(ctx context.Context, repo Repository) error.
    func (e *Engine) Run(ctx context.Context, c interface{}, opts ...Option) *Result {
    	stage := e.NewStage()
    	stage = stage.WithOption(opts...)
    	return stage.Run(ctx, c)
    }
    
    
    zhongqiling's avatar
    zhongqiling committed
    // RegisterEventHandler subscribes a command constructor to eventType.
    // construct must be a func with exactly one parameter and one result. The
    // parameter is a pointer to the event data type; the result implements
    // ICommand, e.g. func(evt *OrderCreatedEvent) *OnEventCreateCommand.
    // When an event arrives, its JSON payload is decoded into a new value of the
    // parameter type. If the pointed-to type equals domainEventType, the raw
    // *DomainEvent is passed through instead. The constructed command is then
    // run on a new stage.
    // RegisterEventHandler panics when construct does not have that shape.
    func (e *Engine) RegisterEventHandler(eventType EventType, construct EventHandlerConstruct) {
    	handlerType := reflect.TypeOf(construct)
    	// reflect.TypeOf(nil) is nil; calling Kind on it would panic obscurely.
    	if handlerType == nil || handlerType.Kind() != reflect.Func {
    		panic("construct must type of reflect.Func")
    	}
    	if handlerType.NumIn() != 1 || handlerType.NumOut() != 1 {
    		panic("construct num of arg or output must 1")
    	}

    	evtType := handlerType.In(0)
    	if evtType.Kind() != reflect.Ptr {
    		panic("event type must be pointer")
    	}
    	evtType = evtType.Elem() // the concrete event type the pointer refers to
    	outType := handlerType.Out(0)
    	if !outType.Implements(cmdType) {
    		panic("construct output must be type of ICommand")
    	}
    	constructFunc := reflect.ValueOf(construct)

    	RegisterEventHandler(eventType, func(ctx context.Context, evt *DomainEvent) error {
    		var bizEvt reflect.Value
    		if evtType == domainEventType {
    			bizEvt = reflect.ValueOf(evt)
    		} else {
    			bizEvt = reflect.New(evtType)
    			if err := json.Unmarshal(evt.Payload, bizEvt.Interface()); err != nil {
    				e.logger.Error(err, "unmarshal event failed")
    				return err
    			}
    		}

    		outputs := constructFunc.Call([]reflect.Value{bizEvt})
    		// A constructor whose result type is an interface may return nil;
    		// report that as an error instead of panicking in the type assertion.
    		cmd, ok := outputs[0].Interface().(ICommand)
    		if !ok {
    			err := fmt.Errorf("event handler construct returned no command")
    			e.logger.Error(err, "event handler exec failed")
    			return err
    		}
    		if res := e.Run(ctx, cmd); res.Error != nil {
    			e.logger.Error(res.Error, "event handler exec failed")
    			return res.Error
    		}
    		return nil
    	})
    }
    
    
    // RegisterCronTask schedules f on the cron expression cron under key. It
    // registers a handler for key's timer events and asks the engine's timer to
    // run the cron. It panics if no ITimer is configured, if a handler is
    // already registered for key, or if the timer rejects the schedule.
    func (e *Engine) RegisterCronTask(key EventType, cron string, f func(key, cron string)) {
    	switch {
    	case e.timer == nil:
    		panic("No ITimer specified")
    	case hasEventHandler(key):
    		panic("key has registered")
    	}

    	RegisterEventHandler(key, func(ctx context.Context, evt *TimerEvent) error {
    		f(evt.Key, evt.Cron)
    		return nil
    	})

    	err := e.timer.RunCron(string(key), cron, nil)
    	if err != nil {
    		panic(err)
    	}
    }
    
    // RegisterCronTaskOfCommand 注册定时触发的 ICommand
    func (e *Engine) RegisterCronTaskOfCommand(key EventType, cron string, f func(key, cron string) ICommand) {
    	if e.timer == nil {
    		panic("No ITimer specified")
    	}
    	if hasEventHandler(key) {
    		panic("key has registered")
    	}
    
    	e.RegisterEventHandler(key, func(evt *TimerEvent) ICommand {
    		return f(evt.Key, evt.Cron)
    	})
    	if err := e.timer.RunCron(string(key), cron, nil); err != nil {
    
    zhongqiling's avatar
    zhongqiling committed
    // Stage 取舞台的意思,表示单次运行
    type Stage struct {
    	lockKeys []string
    
    zhongqiling's avatar
    zhongqiling committed
    
    	locker      ILock
    	executor    IExecutor
    	eventBus    IEventBus
    
    zhongqiling's avatar
    zhongqiling committed
    	idGenerator IIDGenerator
    	logger      logr.Logger
    	options     Options
    
    	meta     *EntityContainer
    	snapshot entitySnapshotPool
    	result   *Result
    	eventCtx context.Context
    }
    
    func (e *Stage) WithOption(opts ...Option) *Stage {
    	for _, opt := range opts {
    		opt.ApplyToOptions(&e.options)
    	}
    
    	eventBus := e.options.EventBus
    	eventBus.RegisterEventHandler(onEvent)
    	if txEB, ok := eventBus.(ITransactionEventBus); ok {
    		txEB.RegisterEventTXChecker(onTXChecker)
    	}
    	e.eventBus = eventBus
    
    
    	timer := e.options.Timer
    	timer.RegisterTimerHandler(onTimer)
    	e.timer = timer
    
    zhongqiling's avatar
    zhongqiling committed
    	e.logger = e.options.Logger
    	e.idGenerator = e.options.IDGenerator
    	return e
    }
    
    // changeType2OP maps a diff change type to the persistence operation used
    // for it: new children are inserted, dirty children updated, and deleted or
    // cleared children removed. Any other change type yields OpUnknown.
    func changeType2OP(t changeType) OpType {
    	var op OpType = OpUnknown
    	switch t {
    	case newChildren:
    		op = OpInsert
    	case dirtyChildren:
    		op = OpUpdate
    	case deleteChildren, clearChildren:
    		op = OpDelete
    	}
    	return op
    }
    
    // BuildEntity loads parent and, optionally, its children.
    // parent must be non-nil and have an ID. Each child must be a pointer: a
    // pointer to a slice of entities is filled with all children of parent, and
    // a pointer entity is populated in place. A missing single child is not an
    // error; the entity is then left untouched.
    func (e *Stage) BuildEntity(ctx context.Context, parent IEntity, children ...interface{}) error {
    	if parent == nil || parent.GetID() == "" {
    		return fmt.Errorf("parent must has ID")
    	}

    	if err := e.buildEntity(ctx, parent, nil); err != nil {
    		return err
    	}

    	for _, item := range children {
    		itemType := reflect.TypeOf(item)
    		// reflect.TypeOf(nil) is nil, and calling Kind on it would panic.
    		if itemType == nil || itemType.Kind() != reflect.Ptr {
    			return fmt.Errorf("children must be pointer")
    		}
    		switch {
    		case itemType.Elem().Kind() == reflect.Slice:
    			if err := e.buildEntitySliceByParent(ctx, parent, item); err != nil {
    				return err
    			}
    		case itemType.Implements(entityType):
    			if err := e.buildEntity(ctx, item.(IEntity), parent); err != nil && !errors.Is(err, ErrEntityNotFound) {
    				return err
    			}
    		default:
    			return fmt.Errorf("children type must be IEntity")
    		}
    	}
    	return nil
    }
    
    // buildEntity loads entity from storage; parent itself is neither loaded
    // nor modified. The query model comes from
    // executor.Entity2Model(entity, parent, OpQuery), so entity's own ID or
    // parent's ID must be non-empty. parent may be nil when entity has an ID.
    // The first matching model populates entity. ErrEntityNotFound is returned
    // when nothing matches.
    func (e *Stage) buildEntity(ctx context.Context, entity, parent IEntity) error {
    	// At least one of entity / parent must carry an ID; guard against a nil
    	// parent, which callers pass when loading a root by its own ID.
    	if entity.GetID() == "" && (parent == nil || parent.GetID() == "") {
    		return fmt.Errorf("entity to build must has id")
    	}
    	po, err := e.executor.Entity2Model(entity, parent, OpQuery)
    	if err != nil {
    		return err
    	}
    	posPointer := reflect.New(reflect.SliceOf(reflect.TypeOf(po)))
    	if err := e.executor.Exec(ctx, &Action{
    		Op:          OpQuery,
    		Query:       po,
    		QueryResult: posPointer.Interface(),
    	}); err != nil {
    		return err
    	}
    	models := posPointer.Elem()
    	if models.Len() == 0 {
    		return ErrEntityNotFound
    	}
    	return e.executor.Model2Entity(models.Index(0).Interface().(IModel), entity)
    }
    
    // buildEntitySliceByParent queries all children of parent whose type is the
    // element type of the slice pointed to by children, and appends one
    // populated entity per model to that slice. The element type may be a
    // pointer type (*T) or a struct type (T) implementing IEntity. An interface
    // element type such as IEntity is rejected because it does not tell which
    // concrete entity to create. Nothing is appended when no model matches.
    func (e *Stage) buildEntitySliceByParent(ctx context.Context, parent IEntity, children interface{}) error {
    	childrenType := reflect.TypeOf(children)
    	if childrenType == nil || childrenType.Kind() != reflect.Ptr || childrenType.Elem().Kind() != reflect.Slice {
    		return fmt.Errorf("children must be pointer of slice")
    	}
    	target := reflect.ValueOf(children)
    	if target.IsNil() {
    		return fmt.Errorf("children must be pointer of slice")
    	}
    	elemType := childrenType.Elem().Elem()
    	if !elemType.Implements(entityType) {
    		return fmt.Errorf("element of children must implement IEntity")
    	}
    	// reflect.New on an interface type yields *IEntity, which is not an
    	// IEntity; the type assertions below would panic.
    	if elemType.Kind() == reflect.Interface {
    		return fmt.Errorf("element of children must be a concrete entity type")
    	}

    	// New entities are always created as *T. For a []T slice the pointed-to
    	// value is appended, because reflect.Append panics on a type mismatch.
    	ptrElem := elemType.Kind() == reflect.Ptr
    	baseType := elemType
    	if ptrElem {
    		baseType = elemType.Elem()
    	}

    	prototype := reflect.New(baseType).Interface().(IEntity)
    	po, err := e.executor.Entity2Model(prototype, parent, OpQuery)
    	if err != nil {
    		return err
    	}
    	posPointer := reflect.New(reflect.SliceOf(reflect.TypeOf(po)))

    	if err := e.executor.Exec(ctx, &Action{
    		Op:          OpQuery,
    		Query:       po,
    		QueryResult: posPointer.Interface(),
    	}); err != nil {
    		return err
    	}
    	models := posPointer.Elem()
    	if models.Len() == 0 {
    		return nil
    	}

    	slice := target.Elem()
    	result := slice
    	for i := 0; i < models.Len(); i++ {
    		newEntity := reflect.New(baseType)
    		if err := e.executor.Model2Entity(models.Index(i).Interface().(IModel), newEntity.Interface().(IEntity)); err != nil {
    			return err
    		}
    		if ptrElem {
    			result = reflect.Append(result, newEntity)
    		} else {
    			result = reflect.Append(result, newEntity.Elem())
    		}
    	}
    	slice.Set(result)
    	return nil
    }
    
    // Lock records keys as the lock keys of this run, replacing earlier ones,
    // and returns the stage for chaining. The variadic slice is stored as is.
    func (e *Stage) Lock(keys ...string) *Stage {
    	st := e
    	st.lockKeys = keys
    	return st
    }
    
    
    func (e *Stage) Main(f MainFunc) *Stage {
    	e.main = f
    
    zhongqiling's avatar
    zhongqiling committed
    	return e
    }
    
    
    // Deprecated: 请用 Run 方法代替
    func (e *Stage) RunCommand(ctx context.Context, c ICommand) *Result {
    	return e.Run(ctx, c)
    
    zhongqiling's avatar
    zhongqiling committed
    }
    
    
    func (e *Stage) runCommand(ctx context.Context, c ICommand) *Result {
    
    zhongqiling's avatar
    zhongqiling committed
    	if setter, ok := c.(IStageSetter); ok {
    		setter.SetStage(StageAgent{st: e})
    	}
    
    	keys, err := c.Init(ctx)
    	if err != nil {
    		return ResultErrOrBreak(err)
    	}
    
    	return e.WithOption(PostSaveOption(c.PostSave)).Lock(keys...).Main(func(ctx context.Context, repo Repository) error {
    		buildRoots, err := c.Build(ctx, DomainBuilder{stage: repo.stage})
    		if err != nil {
    			return err
    		}
    		for _, r := range buildRoots {
    			if r.GetID() == "" {
    				return fmt.Errorf("build entities must have ID, for create case, just use container.Add(**) at act func")
    			}
    		}
    		repo.stage.meta.SetChildren(buildRoots)
    
    		// 保存父子实体关系链
    		if err = repo.stage.flush(); err != nil {
    			return err
    		}
    
    		container := RootContainer{stage: repo.stage}
    		if err := c.Act(ctx, container, buildRoots...); err != nil {
    			return err
    		} else if len(container.errs) > 0 {
    			return ErrList(container.errs)
    		}
    		return nil
    	}).Save(ctx)
    }
    
    // Run executes cmd on this stage and returns the result. cmd may be:
    //   - an ICommand, run through Init, Build, Act and PostSave;
    //   - an ICommandMain, whose optional ICommandInit and ICommandPostSave
    //     implementations supply the lock keys and a post-save hook;
    //   - a func(ctx context.Context, repo Repository) error or a MainFunc,
    //     run directly as the stage's main function.
    //
    // Any other value makes Run panic.
    func (e *Stage) Run(ctx context.Context, cmd interface{}) *Result {
    	switch c := cmd.(type) {
    	case ICommand:
    		return e.runCommand(ctx, c)
    	case ICommandMain:
    		var keys []string
    		var options []Option
    		if cmdInit, ok := cmd.(ICommandInit); ok {
    			initKeys, err := cmdInit.Init(ctx)
    			if err != nil {
    				return ResultErrOrBreak(err)
    			}
    			keys = initKeys
    		}
    		if cmdPostSave, ok := cmd.(ICommandPostSave); ok {
    			options = append(options, PostSaveOption(cmdPostSave.PostSave))
    		}
    		return e.WithOption(options...).Lock(keys...).Main(c.Main).Save(ctx)
    	case func(ctx context.Context, repo Repository) error:
    		return e.Main(c).Save(ctx)
    	case MainFunc:
    		// A value of the named type MainFunc does not match the unnamed func
    		// type in the previous case, so it needs its own case.
    		return e.Main(c).Save(ctx)
    	}
    	panic("cmd is invalid")
    }
    
    // childrenSnapshot returns a copy of children in which every child list is
    // cloned. Later edits to the live lists therefore do not affect the snapshot.
    func childrenSnapshot(children map[string][]IEntity) map[string][]IEntity {
    	out := make(map[string][]IEntity, len(children))
    	for key, list := range children {
    		out[key] = append(make([]IEntity, 0, len(list)), list...)
    	}
    	return out
    }
    
    
    func (e *Stage) flush() error {
    
    zhongqiling's avatar
    zhongqiling committed
    	e.snapshot = entitySnapshotPool{}
    
    	return e.updateSnapshot()
    }
    
    // 更新快照,已有的不覆盖
    func (e *Stage) updateSnapshot() error {
    
    zhongqiling's avatar
    zhongqiling committed
    	return walk(e.meta, nil, func(entity, parent IEntity, children map[string][]IEntity) error {
    
    		if _, in := e.snapshot[entity]; in && entity != e.meta {
    
    zhongqiling's avatar
    zhongqiling committed
    			return nil
    		}
    
    zhongqiling's avatar
    zhongqiling committed
    		po, err := e.executor.Entity2Model(entity, parent, OpQuery)
    		if err != nil && !errors.Is(ErrEntityNotRegister, err) {
    			return err
    		}
    		e.snapshot[entity] = &entitySnapshot{
    			po:       po,
    			children: childrenSnapshot(children),
    		}
    		return nil
    	})
    }
    
    
    // hasSnapshot reports whether target already has a (non-nil) snapshot in this stage.
    func (e *Stage) hasSnapshot(target IEntity) bool {
    	snap, ok := e.snapshot[target]
    	return ok && snap != nil
    }
    
    
    zhongqiling's avatar
    zhongqiling committed
    // unDirty clears the dirty flag of every entity reachable from the root container.
    func (e *Stage) unDirty() {
    	reset := func(entity, parent IEntity, children map[string][]IEntity) error {
    		entity.UnDirty()
    		return nil
    	}
    	// reset never returns an error, so the result of walk is discarded.
    	_ = walk(e.meta, nil, reset)
    }
    
    // getEntityChanged diffs the current entity tree against the snapshot and
    // post-processes the diff. handleEntityMove handles entities that moved.
    // When Options.RecursiveDelete is set, recursiveDelete extends deletions
    // to child entities.
    func (e *Stage) getEntityChanged() ([]*entityChanged, error) {
    	diff := entityDiff(e.meta, e.snapshot)

    	moved, err := handleEntityMove(diff)
    	if err != nil {
    		return nil, err
    	}
    	if !e.options.RecursiveDelete {
    		return moved, nil
    	}
    	return recursiveDelete(moved), nil
    }
    
    // recycle records every entity removed by a deleteChildren change in the
    // container's deleted list, where collectEvents still looks for events.
    func (e *Stage) recycle(changed []*entityChanged) {
    	for _, c := range changed {
    		if c.changeType != deleteChildren {
    			continue
    		}
    		for _, child := range c.children {
    			e.meta.Recycle(child)
    		}
    	}
    }
    
    // putNewID assigns a generated ID to every newly added entity that does not
    // have one yet. It stops at the first generator error.
    func (e *Stage) putNewID(changes []*entityChanged) error {
    	for _, item := range changes {
    		if item.changeType != newChildren {
    			continue
    		}
    		for _, child := range item.children {
    			if child.GetID() != "" {
    				continue
    			}
    			id, err := e.idGenerator.NewID()
    			if err != nil {
    				return err
    			}
    			child.SetID(id)
    		}
    	}
    	return nil
    }
    
    // 相同操作,相同类型的PO,合并到一个 Action
    func (e *Stage) makeActions(changes []*entityChanged) ([]*Action, error) {
    	typeActions := make(map[OpType]map[reflect.Type]*Action, 3)
    	for _, item := range changes {
    		for _, entity := range item.children {
    			op := changeType2OP(item.changeType)
    			po, err := e.executor.Entity2Model(entity, item.parent, op)
    			if err != nil {
    				if errors.Is(err, ErrEntityNotRegister) {
    					e.logger.Info("entity not registered", "type", reflect.TypeOf(entity))
    					continue
    				}
    				return nil, err
    			}
    			poType := reflect.TypeOf(po)
    			if _, in := typeActions[op]; !in {
    				typeActions[op] = map[reflect.Type]*Action{}
    			}
    
    			if _, in := typeActions[op][poType]; in {
    				typeActions[op][poType].Models = append(typeActions[op][poType].Models, po)
    			} else {
    				typeActions[op][poType] = &Action{
    					Op:     op,
    					Models: []IModel{po},
    				}
    			}
    			if op == OpUpdate {
    				typeActions[op][poType].PrevModels = append(typeActions[op][poType].PrevModels, e.snapshot[entity].po)
    			}
    		}
    	}
    	actions := make([]*Action, 0)
    	for _, t := range []OpType{OpInsert, OpUpdate, OpDelete} {
    		for _, a := range typeActions[t] {
    			actions = append(actions, a)
    		}
    	}
    	return actions, nil
    }
    
    // collectEvents gathers the domain events raised by every entity reachable
    // from the root container and from every deleted entity (walked with its
    // children). Events are de-duplicated by ID and sorted by CreatedAt, then by
    // ID; the original note says IDs come from xid and strictly increase on a
    // single node.
    func (e *Stage) collectEvents() []*DomainEvent {
    	byID := make(map[string]*DomainEvent)
    	gather := func(entity, parent IEntity, children map[string][]IEntity) error {
    		for _, evt := range entity.GetEvents() {
    			byID[evt.ID] = evt
    		}
    		return nil
    	}

    	// gather never fails, so the results of walk are discarded.
    	_ = walk(e.meta, nil, gather)
    	// Deleted entities may still have raised events that must be sent.
    	for _, del := range e.meta.GetDeleted() {
    		_ = walk(del, nil, gather)
    	}

    	events := make([]*DomainEvent, 0, len(byID))
    	for _, evt := range byID {
    		events = append(events, evt)
    	}
    	sort.SliceStable(events, func(i, j int) bool {
    		a, b := events[i], events[j]
    		if !a.CreatedAt.Equal(b.CreatedAt) {
    			return a.CreatedAt.Before(b.CreatedAt)
    		}
    		return a.ID < b.ID
    	})
    	return events
    }
    
    // makeEventPersistAction converts events into models with
    // Options.EventPersist and wraps them in a single insert Action. The first
    // conversion error aborts. NOTE(review): EventPersist is called without a
    // nil check; presumably callers only use this when it is configured.
    func (e *Stage) makeEventPersistAction(events []*DomainEvent) (*Action, error) {
    	persist := e.options.EventPersist
    	models := make([]IModel, 0, len(events))
    	for _, evt := range events {
    		po, err := persist(evt)
    		if err != nil {
    			return nil, err
    		}
    		models = append(models, po)
    	}
    	return &Action{
    		Op:         OpInsert,
    		Models:     models,
    		PrevModels: []IModel{},
    	}, nil
    }
    
    func (e *Stage) dispatchEvents(ctx context.Context, events []*DomainEvent) (err error) {
    	if !e.options.WithTransaction {
    		e.logger.Info("engine not support transaction")
    		return e.eventBus.Dispatch(ctx, events...)
    	}
    
    	normalEvents, txEvents := make([]*DomainEvent, 0), make([]*DomainEvent, 0)
    	for _, evt := range events {
    		if evt.SendType == SendTypeTransaction {
    			txEvents = append(txEvents, evt)
    		} else {
    			normalEvents = append(normalEvents, evt)
    		}
    	}
    	if txEventBus, ok := e.eventBus.(ITransactionEventBus); ok {
    		if len(txEvents) > 0 {
    			e.eventCtx, err = txEventBus.DispatchBegin(ctx, txEvents...)
    			if err != nil {
    				return err
    			}
    		}
    
    		if len(normalEvents) > 0 {
    			if err = txEventBus.Dispatch(ctx, normalEvents...); err != nil {
    				return err
    			}
    		}
    		return
    	} else {
    		// 如果 eventbus 不支持事务,所有事件默认按照普通方式发送
    		return e.eventBus.Dispatch(ctx, events...)
    	}
    }
    
    func (e *Stage) commit(ctx context.Context) error {
    	if err := e.persist(ctx); err != nil {
    		return err
    	}
    
    	if err := e.flush(); err != nil {
    
    zhongqiling's avatar
    zhongqiling committed
    		return err
    	}
    
    	e.unDirty()
    	return nil
    }
    
    func (e *Stage) persist(ctx context.Context) error {
    	// 发现实体变更
    	changed, err := e.getEntityChanged()
    	if err != nil {
    		return err
    	}
    	// 执行保存前的 hook
    	for _, c := range changed {
    		for _, entity := range c.children {
    			if err := execHook(ctx, entity, c.changeType, true); err != nil {
    				return err
    			}
    		}
    	}
    
    	if err := e.putNewID(changed); err != nil {
    		return err
    	}
    	// 转换为持久化 Action 序列
    	actions, err := e.makeActions(changed)