From 70ccb4ae659d86e016d5cf765db1262a1f809201 Mon Sep 17 00:00:00 2001 From: Marcus Schiesser <mail@marcusschiesser.de> Date: Mon, 16 Sep 2024 16:31:08 +0700 Subject: [PATCH] feat: allow arbitrary types in workflow's StartEvent and StopEvent (#1210) --- .changeset/old-vans-melt.md | 5 +++++ packages/core/src/workflow/events.ts | 4 ++-- packages/core/src/workflow/workflow.ts | 2 +- packages/core/tests/workflow.test.ts | 26 ++++++++++++++++++++++++++ 4 files changed, 34 insertions(+), 3 deletions(-) create mode 100644 .changeset/old-vans-melt.md diff --git a/.changeset/old-vans-melt.md b/.changeset/old-vans-melt.md new file mode 100644 index 000000000..ddeb276ec --- /dev/null +++ b/.changeset/old-vans-melt.md @@ -0,0 +1,5 @@ +--- +"@llamaindex/core": patch +--- + +Allow arbitrary types in workflow's StartEvent and StopEvent diff --git a/packages/core/src/workflow/events.ts b/packages/core/src/workflow/events.ts index 2a8b6992c..64f2975d3 100644 --- a/packages/core/src/workflow/events.ts +++ b/packages/core/src/workflow/events.ts @@ -14,5 +14,5 @@ export type EventTypes<T extends Record<string, any> = any> = new ( data: T, ) => WorkflowEvent<T>; -export class StartEvent extends WorkflowEvent<{ input: string }> {} -export class StopEvent extends WorkflowEvent<{ result: string }> {} +export class StartEvent<T = string> extends WorkflowEvent<{ input: T }> {} +export class StopEvent<T = string> extends WorkflowEvent<{ result: T }> {} diff --git a/packages/core/src/workflow/workflow.ts b/packages/core/src/workflow/workflow.ts index 307b01506..2cd99f309 100644 --- a/packages/core/src/workflow/workflow.ts +++ b/packages/core/src/workflow/workflow.ts @@ -131,7 +131,7 @@ export class Workflow { } } - async run(event: StartEvent | string): Promise<StopEvent> { + async run<T = string>(event: StartEvent<T> | string): Promise<StopEvent> { // Validate the workflow before running if #validate is true if (this.#validate) { this.validate(); diff --git a/packages/core/tests/workflow.test.ts b/packages/core/tests/workflow.test.ts index 785a18720..1b881c1c7 100644 --- a/packages/core/tests/workflow.test.ts +++ b/packages/core/tests/workflow.test.ts @@ -140,4 +140,30 @@ describe("Workflow", () => { expect(result.data.result).toBe("Report generated"); expect(collectedEvents).toHaveLength(1); }); + + test("run workflow with object-based StartEvent and StopEvent", async () => { + const objectFlow = new Workflow({ verbose: true }); + + type Person = { name: string; age: number }; + + const processObject = vi.fn(async (_context, ev: StartEvent<Person>) => { + const { name, age } = ev.data.input; + return new StopEvent({ + result: { greeting: `Hello ${name}, you are ${age} years old!` }, + }); + }); + + objectFlow.addStep(StartEvent<Person>, processObject); + + const result = await objectFlow.run( + new StartEvent<Person>({ + input: { name: "Alice", age: 30 }, + }), + ); + + expect(processObject).toHaveBeenCalledTimes(1); + expect(result.data.result).toEqual({ + greeting: "Hello Alice, you are 30 years old!", + }); + }); }); -- GitLab