From 869cf7576276096645351b2d8a5d6b4167f247a5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B1=9F=E9=B9=8F=E9=A3=9E?=
 <jiangpengfei.jiangpf@bytedance.com>
Date: Wed, 6 Dec 2023 19:49:33 +0800
Subject: [PATCH] fix: fix engine run

---
 engine.go | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/engine.go b/engine.go
index 66cb708..8e7fea5 100644
--- a/engine.go
+++ b/engine.go
@@ -277,7 +277,11 @@ type ActFunc func(ctx context.Context, container RootContainer, roots ...IEntity
 type MainFunc func(ctx context.Context, repo *Repository) error
 type PostSaveFunc func(ctx context.Context, res *Result)
 
-// EventHandlerConstruct EventHandler 的构造函数,带一个入参和一个返回值,入参是与事件类型匹配的事件数据指针类型,返回值是 ICommand
+// EventHandlerConstruct EventHandler 的构造函数,带一个入参和一个返回值,入参是与事件类型匹配的事件数据指针类型,
+// 返回值支持三种:
+// - ICommand interface
+// - ICommandMain interface
+// - MainFunc type
 // 示例 func(evt *OrderCreatedEvent) *OnEventCreateCommand
 type EventHandlerConstruct interface{}
 
@@ -478,10 +482,6 @@ func (e *Engine) RegisterEventHandler(eventType EventType, construct EventHandle
 		panic("event type must be pointer")
 	}
 	evtType = evtType.Elem() // event type 引用实际类型
-	outType := handlerType.Out(0)
-	if !outType.Implements(cmdType) {
-		panic("construct output must be type of ICommand")
-	}
 	constructFunc := reflect.ValueOf(construct)
 
 	RegisterEventHandler(eventType, func(ctx context.Context, evt *DomainEvent) error {
@@ -497,7 +497,8 @@ func (e *Engine) RegisterEventHandler(eventType EventType, construct EventHandle
 		}
 
 		outputs := constructFunc.Call([]reflect.Value{bizEvt})
-		if res := e.RunCommand(ctx, outputs[0].Interface().(ICommand)); res.Error != nil {
+
+		if res := e.Run(ctx, outputs[0].Interface()); res.Error != nil {
 			e.logger.Error(res.Error, "event handler exec failed")
 			return res.Error
 		}
@@ -767,8 +768,11 @@ func (e *Stage) Run(ctx context.Context, cmd interface{}) *Result {
 		return e.WithOption(options...).Lock(keys...).Main(c.Main).Save(ctx)
 	case func(ctx context.Context, repo *Repository) error:
 		return e.Main(c).Save(ctx)
+	case MainFunc:
+		return e.Main(c).Save(ctx)
+	default:
+		panic(fmt.Sprintf("cmd type %T is invalid", c))
 	}
-	panic("cmd is invalid")
 }
 
 func childrenSnapshot(children map[string][]IEntity) map[string][]IEntity {
-- 
GitLab