From c3d9b2f12daf8e459ff71deb2081f6dc04a65c5b Mon Sep 17 00:00:00 2001
From: thucpn <thucsh2@gmail.com>
Date: Wed, 19 Mar 2025 16:16:12 +0700
Subject: [PATCH] support agent workflow

---
 packages/server/src/handlers/chat.ts          | 18 +++++--------
 packages/server/src/types.ts                  | 10 ++++---
 .../src/utils/{stream.ts => workflow.ts}      | 26 +++++++++++++++++--
 3 files changed, 36 insertions(+), 18 deletions(-)
 rename packages/server/src/utils/{stream.ts => workflow.ts} (66%)

diff --git a/packages/server/src/handlers/chat.ts b/packages/server/src/handlers/chat.ts
index 8f2d1a7e9..d8b56308c 100644
--- a/packages/server/src/handlers/chat.ts
+++ b/packages/server/src/handlers/chat.ts
@@ -1,4 +1,4 @@
-import { LlamaIndexAdapter } from "ai";
+import { type Message } from "ai";
 import { IncomingMessage, ServerResponse } from "http";
 import { type ChatMessage } from "llamaindex";
 import type { ServerWorkflow } from "../types";
@@ -7,7 +7,7 @@ import {
   pipeResponse,
   sendJSONResponse,
 } from "../utils/request";
-import { createStreamFromWorkflowContext } from "../utils/stream";
+import { runWorkflow } from "../utils/workflow";
 
 export const handleChat = async (
   workflow: ServerWorkflow,
@@ -16,7 +16,7 @@ export const handleChat = async (
 ) => {
   try {
     const body = await parseRequestBody(req);
-    const { messages } = body as { messages: ChatMessage[] };
+    const { messages } = body as { messages: Message[] };
 
     const lastMessage = messages[messages.length - 1];
     if (lastMessage?.role !== "user") {
@@ -25,15 +25,9 @@ export const handleChat = async (
       });
     }
 
-    const userMessage = lastMessage.content;
-    const chatHistory = messages.slice(0, -1);
-
-    const context = workflow.run({ userMessage, chatHistory });
-    const { stream, dataStream } =
-      await createStreamFromWorkflowContext(context);
-    const streamResponse = LlamaIndexAdapter.toDataStreamResponse(stream, {
-      data: dataStream,
-    });
+    const userInput = lastMessage.content;
+    const chatHistory = messages.slice(0, -1) as ChatMessage[];
+    const streamResponse = await runWorkflow(workflow, userInput, chatHistory);
     pipeResponse(res, streamResponse);
   } catch (error) {
     console.error("Chat error:", error);
diff --git a/packages/server/src/types.ts b/packages/server/src/types.ts
index 9a41a7dce..c0d0744fc 100644
--- a/packages/server/src/types.ts
+++ b/packages/server/src/types.ts
@@ -1,13 +1,15 @@
 import {
+  AgentWorkflow,
   Workflow,
   type ChatMessage,
   type ChatResponseChunk,
-  type MessageContent,
 } from "llamaindex";
 
 export type AgentInput = {
-  userMessage: MessageContent;
-  chatHistory: ChatMessage[];
+  userInput: string; // the last message content from the user
+  chatHistory: ChatMessage[]; // the previous chat history (not including the last message)
 };
 
-export type ServerWorkflow = Workflow<null, AgentInput, ChatResponseChunk>;
+export type ServerWorkflow =
+  | Workflow<null, AgentInput, ChatResponseChunk>
+  | AgentWorkflow;
diff --git a/packages/server/src/utils/stream.ts b/packages/server/src/utils/workflow.ts
similarity index 66%
rename from packages/server/src/utils/stream.ts
rename to packages/server/src/utils/workflow.ts
index 056a5ac82..328f9f975 100644
--- a/packages/server/src/utils/stream.ts
+++ b/packages/server/src/utils/workflow.ts
@@ -1,14 +1,36 @@
-import { StreamData, type JSONValue } from "ai";
+import { LlamaIndexAdapter, StreamData, type JSONValue } from "ai";
 import {
+  AgentWorkflow,
   EngineResponse,
   StopEvent,
   WorkflowContext,
   WorkflowEvent,
+  type ChatMessage,
   type ChatResponseChunk,
 } from "llamaindex";
 import { ReadableStream } from "stream/web";
+import type { ServerWorkflow } from "../types";
 
-export async function createStreamFromWorkflowContext<Input, Output, Context>(
+export async function runWorkflow(
+  workflow: ServerWorkflow,
+  userInput: string,
+  chatHistory: ChatMessage[],
+) {
+  if (workflow instanceof AgentWorkflow) {
+    const context = workflow.run(userInput, { chatHistory });
+    const { stream, dataStream } = await createStreamFromWorkflowContext(
+      // eslint-disable-next-line @typescript-eslint/no-explicit-any
+      context as any,
+    );
+    return LlamaIndexAdapter.toDataStreamResponse(stream, { data: dataStream });
+  }
+
+  const context = workflow.run({ userInput, chatHistory });
+  const { stream, dataStream } = await createStreamFromWorkflowContext(context);
+  return LlamaIndexAdapter.toDataStreamResponse(stream, { data: dataStream });
+}
+
+async function createStreamFromWorkflowContext<Input, Output, Context>(
   context: WorkflowContext<Input, Output, Context>,
 ): Promise<{ stream: ReadableStream<EngineResponse>; dataStream: StreamData }> {
   const dataStream = new StreamData();
-- 
GitLab