From 70ccb4ae659d86e016d5cf765db1262a1f809201 Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Mon, 16 Sep 2024 16:31:08 +0700
Subject: [PATCH] feat: allow arbitrary types in workflow's StartEvent and
 StopEvent (#1210)

---
 .changeset/old-vans-melt.md            |  5 +++++
 packages/core/src/workflow/events.ts   |  4 ++--
 packages/core/src/workflow/workflow.ts |  2 +-
 packages/core/tests/workflow.test.ts   | 26 ++++++++++++++++++++++++++
 4 files changed, 34 insertions(+), 3 deletions(-)
 create mode 100644 .changeset/old-vans-melt.md

diff --git a/.changeset/old-vans-melt.md b/.changeset/old-vans-melt.md
new file mode 100644
index 000000000..ddeb276ec
--- /dev/null
+++ b/.changeset/old-vans-melt.md
@@ -0,0 +1,5 @@
+---
+"@llamaindex/core": patch
+---
+
+Allow arbitrary types in workflow's StartEvent and StopEvent
diff --git a/packages/core/src/workflow/events.ts b/packages/core/src/workflow/events.ts
index 2a8b6992c..64f2975d3 100644
--- a/packages/core/src/workflow/events.ts
+++ b/packages/core/src/workflow/events.ts
@@ -14,5 +14,5 @@ export type EventTypes<T extends Record<string, any> = any> = new (
   data: T,
 ) => WorkflowEvent<T>;
 
-export class StartEvent extends WorkflowEvent<{ input: string }> {}
-export class StopEvent extends WorkflowEvent<{ result: string }> {}
+export class StartEvent<T = string> extends WorkflowEvent<{ input: T }> {}
+export class StopEvent<T = string> extends WorkflowEvent<{ result: T }> {}
diff --git a/packages/core/src/workflow/workflow.ts b/packages/core/src/workflow/workflow.ts
index 307b01506..2cd99f309 100644
--- a/packages/core/src/workflow/workflow.ts
+++ b/packages/core/src/workflow/workflow.ts
@@ -131,7 +131,7 @@ export class Workflow {
     }
   }
 
-  async run(event: StartEvent | string): Promise<StopEvent> {
+  async run<T = string>(event: StartEvent<T> | string): Promise<StopEvent> {
     // Validate the workflow before running if #validate is true
     if (this.#validate) {
       this.validate();
diff --git a/packages/core/tests/workflow.test.ts b/packages/core/tests/workflow.test.ts
index 785a18720..1b881c1c7 100644
--- a/packages/core/tests/workflow.test.ts
+++ b/packages/core/tests/workflow.test.ts
@@ -140,4 +140,30 @@ describe("Workflow", () => {
     expect(result.data.result).toBe("Report generated");
     expect(collectedEvents).toHaveLength(1);
   });
+
+  test("run workflow with object-based StartEvent and StopEvent", async () => {
+    const objectFlow = new Workflow({ verbose: true });
+
+    type Person = { name: string; age: number };
+
+    const processObject = vi.fn(async (_context, ev: StartEvent<Person>) => {
+      const { name, age } = ev.data.input;
+      return new StopEvent({
+        result: { greeting: `Hello ${name}, you are ${age} years old!` },
+      });
+    });
+
+    objectFlow.addStep(StartEvent<Person>, processObject);
+
+    const result = await objectFlow.run(
+      new StartEvent<Person>({
+        input: { name: "Alice", age: 30 },
+      }),
+    );
+
+    expect(processObject).toHaveBeenCalledTimes(1);
+    expect(result.data.result).toEqual({
+      greeting: "Hello Alice, you are 30 years old!",
+    });
+  });
 });
-- 
GitLab