From 3d414883012dd9ecda9bf17e5f2ddf200501574c Mon Sep 17 00:00:00 2001
From: Thuc Pham <51660321+thucpn@users.noreply.github.com>
Date: Thu, 10 Oct 2024 18:37:55 +0700
Subject: [PATCH] feat: use selected llamacloud for multiagent (#359)

---
 .changeset/stupid-paws-push.md                |  5 ++++
 .../typescript/streaming/annotations.ts       | 23 +++++++++++++++
 .../typescript/express/chat.controller.ts     |  7 ++---
 .../multiagent/typescript/nextjs/route.ts     |  7 ++---
 .../multiagent/typescript/workflow/agents.ts  | 21 +++++++++-----
 .../multiagent/typescript/workflow/factory.ts | 29 ++++++++++---------
 .../typescript/workflow/single-agent.ts       |  2 +-
 .../multiagent/typescript/workflow/tools.ts   |  6 ++--
 8 files changed, 67 insertions(+), 33 deletions(-)
 create mode 100644 .changeset/stupid-paws-push.md

diff --git a/.changeset/stupid-paws-push.md b/.changeset/stupid-paws-push.md
new file mode 100644
index 00000000..716a0ef5
--- /dev/null
+++ b/.changeset/stupid-paws-push.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+feat: use selected llamacloud for multiagent
diff --git a/templates/components/llamaindex/typescript/streaming/annotations.ts b/templates/components/llamaindex/typescript/streaming/annotations.ts
index 13842c7a..10e6f52c 100644
--- a/templates/components/llamaindex/typescript/streaming/annotations.ts
+++ b/templates/components/llamaindex/typescript/streaming/annotations.ts
@@ -172,3 +172,26 @@ function getValidAnnotation(annotation: JSONValue): Annotation {
   }
   return { type: annotation.type, data: annotation.data };
 }
+
+// validate and get all annotations of a specific type or role from the frontend messages
+export function getAnnotations<
+  T extends Annotation["data"] = Annotation["data"],
+>(
+  messages: Message[],
+  options?: {
+    role?: Message["role"]; // message role
+    type?: Annotation["type"]; // annotation type
+  },
+): {
+  type: string;
+  data: T;
+}[] {
+  const messagesByRole = options?.role
+    ? messages.filter((msg) => msg.role === options?.role)
+    : messages;
+  const annotations = getAllAnnotations(messagesByRole);
+  const annotationsByType = options?.type
+    ? annotations.filter((a) => a.type === options.type)
+    : annotations;
+  return annotationsByType as { type: string; data: T }[];
+}
diff --git a/templates/components/multiagent/typescript/express/chat.controller.ts b/templates/components/multiagent/typescript/express/chat.controller.ts
index 46be6d78..8dfaf6c4 100644
--- a/templates/components/multiagent/typescript/express/chat.controller.ts
+++ b/templates/components/multiagent/typescript/express/chat.controller.ts
@@ -1,13 +1,13 @@
 import { StopEvent } from "@llamaindex/core/workflow";
 import { Message, streamToResponse } from "ai";
 import { Request, Response } from "express";
-import { ChatMessage, ChatResponseChunk } from "llamaindex";
+import { ChatResponseChunk } from "llamaindex";
 import { createWorkflow } from "./workflow/factory";
 import { toDataStream, workflowEventsToStreamData } from "./workflow/stream";
 
 export const chat = async (req: Request, res: Response) => {
   try {
-    const { messages }: { messages: Message[] } = req.body;
+    const { messages, data }: { messages: Message[]; data?: any } = req.body;
     const userMessage = messages.pop();
     if (!messages || !userMessage || userMessage.role !== "user") {
       return res.status(400).json({
@@ -16,8 +16,7 @@ export const chat = async (req: Request, res: Response) => {
       });
     }
 
-    const chatHistory = messages as ChatMessage[];
-    const agent = createWorkflow(chatHistory);
+    const agent = createWorkflow(messages, data);
     const result = agent.run<AsyncGenerator<ChatResponseChunk>>(
       userMessage.content,
     ) as unknown as Promise<StopEvent<AsyncGenerator<ChatResponseChunk>>>;
diff --git a/templates/components/multiagent/typescript/nextjs/route.ts b/templates/components/multiagent/typescript/nextjs/route.ts
index 04b40339..2f93e0f7 100644
--- a/templates/components/multiagent/typescript/nextjs/route.ts
+++ b/templates/components/multiagent/typescript/nextjs/route.ts
@@ -1,7 +1,7 @@
 import { initObservability } from "@/app/observability";
 import { StopEvent } from "@llamaindex/core/workflow";
 import { Message, StreamingTextResponse } from "ai";
-import { ChatMessage, ChatResponseChunk } from "llamaindex";
+import { ChatResponseChunk } from "llamaindex";
 import { NextRequest, NextResponse } from "next/server";
 import { initSettings } from "./engine/settings";
 import { createWorkflow } from "./workflow/factory";
@@ -16,7 +16,7 @@ export const dynamic = "force-dynamic";
 export async function POST(request: NextRequest) {
   try {
     const body = await request.json();
-    const { messages }: { messages: Message[] } = body;
+    const { messages, data }: { messages: Message[]; data?: any } = body;
     const userMessage = messages.pop();
     if (!messages || !userMessage || userMessage.role !== "user") {
       return NextResponse.json(
@@ -28,8 +28,7 @@ export async function POST(request: NextRequest) {
       );
     }
 
-    const chatHistory = messages as ChatMessage[];
-    const agent = createWorkflow(chatHistory);
+    const agent = createWorkflow(messages, data);
     // TODO: fix type in agent.run in LITS
     const result = agent.run<AsyncGenerator<ChatResponseChunk>>(
       userMessage.content,
diff --git a/templates/components/multiagent/typescript/workflow/agents.ts b/templates/components/multiagent/typescript/workflow/agents.ts
index 6af2bf94..71f3123c 100644
--- a/templates/components/multiagent/typescript/workflow/agents.ts
+++ b/templates/components/multiagent/typescript/workflow/agents.ts
@@ -1,14 +1,19 @@
 import { ChatMessage } from "llamaindex";
 import { FunctionCallingAgent } from "./single-agent";
-import { lookupTools } from "./tools";
+import { getQueryEngineTool, lookupTools } from "./tools";
 
-export const createResearcher = async (chatHistory: ChatMessage[]) => {
-  const tools = await lookupTools([
-    "query_index",
-    "wikipedia_tool",
-    "duckduckgo_search",
-    "image_generator",
-  ]);
+export const createResearcher = async (
+  chatHistory: ChatMessage[],
+  params?: any,
+) => {
+  const queryEngineTool = await getQueryEngineTool(params);
+  const tools = (
+    await lookupTools([
+      "wikipedia_tool",
+      "duckduckgo_search",
+      "image_generator",
+    ])
+  ).concat(queryEngineTool ? [queryEngineTool] : []);
 
   return new FunctionCallingAgent({
     name: "researcher",
diff --git a/templates/components/multiagent/typescript/workflow/factory.ts b/templates/components/multiagent/typescript/workflow/factory.ts
index 2aef2c25..0e341ca2 100644
--- a/templates/components/multiagent/typescript/workflow/factory.ts
+++ b/templates/components/multiagent/typescript/workflow/factory.ts
@@ -5,7 +5,9 @@ import {
   Workflow,
   WorkflowEvent,
 } from "@llamaindex/core/workflow";
+import { Message } from "ai";
 import { ChatMessage, ChatResponseChunk, Settings } from "llamaindex";
+import { getAnnotations } from "../llamaindex/streaming/annotations";
 import {
   createPublisher,
   createResearcher,
@@ -25,19 +27,15 @@ class WriteEvent extends WorkflowEvent<{
 class ReviewEvent extends WorkflowEvent<{ input: string }> {}
 class PublishEvent extends WorkflowEvent<{ input: string }> {}
 
-const prepareChatHistory = (chatHistory: ChatMessage[]) => {
+const prepareChatHistory = (chatHistory: Message[]): ChatMessage[] => {
   // By default, the chat history only contains the assistant and user messages
   // all the agents messages are stored in annotation data which is not visible to the LLM
 
   const MAX_AGENT_MESSAGES = 10;
-
-  // Construct a new agent message from agent messages
-  // Get annotations from assistant messages
-  const agentAnnotations = chatHistory
-    .filter((msg) => msg.role === "assistant")
-    .flatMap((msg) => msg.annotations || [])
-    .filter((annotation) => annotation.type === "agent")
-    .slice(-MAX_AGENT_MESSAGES);
+  const agentAnnotations = getAnnotations<{ agent: string; text: string }>(
+    chatHistory,
+    { role: "assistant", type: "agent" },
+  ).slice(-MAX_AGENT_MESSAGES);
 
   const agentMessages = agentAnnotations
     .map(
@@ -59,13 +57,13 @@ const prepareChatHistory = (chatHistory: ChatMessage[]) => {
       ...chatHistory.slice(0, -1),
       agentMessage,
       chatHistory.slice(-1)[0],
-    ];
+    ] as ChatMessage[];
   }
-  return chatHistory;
+  return chatHistory as ChatMessage[];
 };
 
-export const createWorkflow = (chatHistory: ChatMessage[]) => {
-  const chatHistoryWithAgentMessages = prepareChatHistory(chatHistory);
+export const createWorkflow = (messages: Message[], params?: any) => {
+  const chatHistoryWithAgentMessages = prepareChatHistory(messages);
   const runAgent = async (
     context: Context,
     agent: Workflow,
@@ -123,7 +121,10 @@ Decision (respond with either 'not_publish' or 'publish'):`;
   };
 
   const research = async (context: Context, ev: ResearchEvent) => {
-    const researcher = await createResearcher(chatHistoryWithAgentMessages);
+    const researcher = await createResearcher(
+      chatHistoryWithAgentMessages,
+      params,
+    );
     const researchRes = await runAgent(context, researcher, {
       message: ev.data.input,
     });
diff --git a/templates/components/multiagent/typescript/workflow/single-agent.ts b/templates/components/multiagent/typescript/workflow/single-agent.ts
index 568697df..5344f108 100644
--- a/templates/components/multiagent/typescript/workflow/single-agent.ts
+++ b/templates/components/multiagent/typescript/workflow/single-agent.ts
@@ -143,7 +143,7 @@ export class FunctionCallingAgent extends Workflow {
         fullResponse = chunk;
       }
 
-      if (fullResponse) {
+      if (fullResponse?.options && Object.keys(fullResponse.options).length) {
         memory.put({
           role: "assistant",
           content: "",
diff --git a/templates/components/multiagent/typescript/workflow/tools.ts b/templates/components/multiagent/typescript/workflow/tools.ts
index ac4e5fb9..012da6ae 100644
--- a/templates/components/multiagent/typescript/workflow/tools.ts
+++ b/templates/components/multiagent/typescript/workflow/tools.ts
@@ -4,8 +4,10 @@ import path from "path";
 import { getDataSource } from "../engine";
 import { createTools } from "../engine/tools/index";
 
-const getQueryEngineTool = async (): Promise<QueryEngineTool | null> => {
-  const index = await getDataSource();
+export const getQueryEngineTool = async (
+  params?: any,
+): Promise<QueryEngineTool | null> => {
+  const index = await getDataSource(params);
   if (!index) {
     return null;
   }
-- 
GitLab