From 869cf7576276096645351b2d8a5d6b4167f247a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=9F=E9=B9=8F=E9=A3=9E?= <jiangpengfei.jiangpf@bytedance.com> Date: Wed, 6 Dec 2023 19:49:33 +0800 Subject: [PATCH] fix: fix engine run --- engine.go | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/engine.go b/engine.go index 66cb708..8e7fea5 100644 --- a/engine.go +++ b/engine.go @@ -277,7 +277,11 @@ type ActFunc func(ctx context.Context, container RootContainer, roots ...IEntity type MainFunc func(ctx context.Context, repo *Repository) error type PostSaveFunc func(ctx context.Context, res *Result) -// EventHandlerConstruct EventHandler 的构造函数,带一个入参和一个返回值,入参是与事件类型匹配的事件数据指针类型,返回值是 ICommand +// EventHandlerConstruct EventHandler 的构造函数,带一个入参和一个返回值,入参是与事件类型匹配的事件数据指针类型, +// 返回值支持三种: +// - ICommand interface +// - ICommandMain interface +// - MainFunc type // 示例 func(evt *OrderCreatedEvent) *OnEventCreateCommand type EventHandlerConstruct interface{} @@ -478,10 +482,6 @@ func (e *Engine) RegisterEventHandler(eventType EventType, construct EventHandle panic("event type must be pointer") } evtType = evtType.Elem() // event type 引用实际类型 - outType := handlerType.Out(0) - if !outType.Implements(cmdType) { - panic("construct output must be type of ICommand") - } constructFunc := reflect.ValueOf(construct) RegisterEventHandler(eventType, func(ctx context.Context, evt *DomainEvent) error { @@ -497,7 +497,8 @@ func (e *Engine) RegisterEventHandler(eventType EventType, construct EventHandle } outputs := constructFunc.Call([]reflect.Value{bizEvt}) - if res := e.RunCommand(ctx, outputs[0].Interface().(ICommand)); res.Error != nil { + + if res := e.Run(ctx, outputs[0].Interface()); res.Error != nil { e.logger.Error(res.Error, "event handler exec failed") return res.Error } @@ -767,8 +768,11 @@ func (e *Stage) Run(ctx context.Context, cmd interface{}) *Result { return e.WithOption(options...).Lock(keys...).Main(c.Main).Save(ctx) case func(ctx context.Context, repo *Repository) error: return e.Main(c).Save(ctx) + case MainFunc: + return e.Main(c).Save(ctx) + default: + panic(fmt.Sprintf("cmd type %T is invalid", c)) } - panic("cmd is invalid") } func childrenSnapshot(children map[string][]IEntity) map[string][]IEntity { -- GitLab