diff --git a/.changeset/old-vans-melt.md b/.changeset/old-vans-melt.md new file mode 100644 index 0000000000000000000000000000000000000000..ddeb276ec1a9cee0a9fa000469f809b0fe9ef8da --- /dev/null +++ b/.changeset/old-vans-melt.md @@ -0,0 +1,5 @@ +--- +"@llamaindex/core": patch +--- + +Allow arbitrary types in workflow's StartEvent and StopEvent diff --git a/packages/core/src/workflow/events.ts b/packages/core/src/workflow/events.ts index 2a8b6992cc3b7918a3d42b4e4134c9858c40f7c3..64f2975d3d8f12d22b414fa6d5b48e5f39a94675 100644 --- a/packages/core/src/workflow/events.ts +++ b/packages/core/src/workflow/events.ts @@ -14,5 +14,5 @@ export type EventTypes<T extends Record<string, any> = any> = new ( data: T, ) => WorkflowEvent<T>; -export class StartEvent extends WorkflowEvent<{ input: string }> {} -export class StopEvent extends WorkflowEvent<{ result: string }> {} +export class StartEvent<T = string> extends WorkflowEvent<{ input: T }> {} +export class StopEvent<T = string> extends WorkflowEvent<{ result: T }> {} diff --git a/packages/core/src/workflow/workflow.ts b/packages/core/src/workflow/workflow.ts index 307b015067f5434c1d6f3df93c6762b5bf7d451a..2cd99f309bfef6d7cc69a18cd5090aeee6f32cef 100644 --- a/packages/core/src/workflow/workflow.ts +++ b/packages/core/src/workflow/workflow.ts @@ -131,7 +131,7 @@ export class Workflow { } } - async run(event: StartEvent | string): Promise<StopEvent> { + async run<T = string>(event: StartEvent<T> | string): Promise<StopEvent> { // Validate the workflow before running if #validate is true if (this.#validate) { this.validate(); diff --git a/packages/core/tests/workflow.test.ts b/packages/core/tests/workflow.test.ts index 785a1872058847d24d38eec85eb76f715aa5423b..1b881c1c7d59e7b671f0f612d9f3e9b43551ccd6 100644 --- a/packages/core/tests/workflow.test.ts +++ b/packages/core/tests/workflow.test.ts @@ -140,4 +140,30 @@ describe("Workflow", () => { expect(result.data.result).toBe("Report generated"); expect(collectedEvents).toHaveLength(1); }); + + test("run workflow with object-based StartEvent and StopEvent", async () => { + const objectFlow = new Workflow({ verbose: true }); + + type Person = { name: string; age: number }; + + const processObject = vi.fn(async (_context, ev: StartEvent<Person>) => { + const { name, age } = ev.data.input; + return new StopEvent({ + result: { greeting: `Hello ${name}, you are ${age} years old!` }, + }); + }); + + objectFlow.addStep(StartEvent<Person>, processObject); + + const result = await objectFlow.run( + new StartEvent<Person>({ + input: { name: "Alice", age: 30 }, + }), + ); + + expect(processObject).toHaveBeenCalledTimes(1); + expect(result.data.result).toEqual({ + greeting: "Hello Alice, you are 30 years old!", + }); + }); });