From 3d414883012dd9ecda9bf17e5f2ddf200501574c Mon Sep 17 00:00:00 2001 From: Thuc Pham <51660321+thucpn@users.noreply.github.com> Date: Thu, 10 Oct 2024 18:37:55 +0700 Subject: [PATCH] feat: use selected llamacloud for multiagent (#359) --- .changeset/stupid-paws-push.md | 5 ++++ .../typescript/streaming/annotations.ts | 23 +++++++++++++++ .../typescript/express/chat.controller.ts | 7 ++--- .../multiagent/typescript/nextjs/route.ts | 7 ++--- .../multiagent/typescript/workflow/agents.ts | 21 +++++++++----- .../multiagent/typescript/workflow/factory.ts | 29 ++++++++++--------- .../typescript/workflow/single-agent.ts | 2 +- .../multiagent/typescript/workflow/tools.ts | 6 ++-- 8 files changed, 67 insertions(+), 33 deletions(-) create mode 100644 .changeset/stupid-paws-push.md diff --git a/.changeset/stupid-paws-push.md b/.changeset/stupid-paws-push.md new file mode 100644 index 00000000..716a0ef5 --- /dev/null +++ b/.changeset/stupid-paws-push.md @@ -0,0 +1,5 @@ +--- +"create-llama": patch +--- + +feat: use selected llamacloud for multiagent diff --git a/templates/components/llamaindex/typescript/streaming/annotations.ts b/templates/components/llamaindex/typescript/streaming/annotations.ts index 13842c7a..10e6f52c 100644 --- a/templates/components/llamaindex/typescript/streaming/annotations.ts +++ b/templates/components/llamaindex/typescript/streaming/annotations.ts @@ -172,3 +172,26 @@ function getValidAnnotation(annotation: JSONValue): Annotation { } return { type: annotation.type, data: annotation.data }; } + +// validate and get all annotations of a specific type or role from the frontend messages +export function getAnnotations< + T extends Annotation["data"] = Annotation["data"], +>( + messages: Message[], + options?: { + role?: Message["role"]; // message role + type?: Annotation["type"]; // annotation type + }, +): { + type: string; + data: T; +}[] { + const messagesByRole = options?.role + ? messages.filter((msg) => msg.role === options?.role) + : messages; + const annotations = getAllAnnotations(messagesByRole); + const annotationsByType = options?.type + ? annotations.filter((a) => a.type === options.type) + : annotations; + return annotationsByType as { type: string; data: T }[]; +} diff --git a/templates/components/multiagent/typescript/express/chat.controller.ts b/templates/components/multiagent/typescript/express/chat.controller.ts index 46be6d78..8dfaf6c4 100644 --- a/templates/components/multiagent/typescript/express/chat.controller.ts +++ b/templates/components/multiagent/typescript/express/chat.controller.ts @@ -1,13 +1,13 @@ import { StopEvent } from "@llamaindex/core/workflow"; import { Message, streamToResponse } from "ai"; import { Request, Response } from "express"; -import { ChatMessage, ChatResponseChunk } from "llamaindex"; +import { ChatResponseChunk } from "llamaindex"; import { createWorkflow } from "./workflow/factory"; import { toDataStream, workflowEventsToStreamData } from "./workflow/stream"; export const chat = async (req: Request, res: Response) => { try { - const { messages }: { messages: Message[] } = req.body; + const { messages, data }: { messages: Message[]; data?: any } = req.body; const userMessage = messages.pop(); if (!messages || !userMessage || userMessage.role !== "user") { return res.status(400).json({ @@ -16,8 +16,7 @@ export const chat = async (req: Request, res: Response) => { }); } - const chatHistory = messages as ChatMessage[]; - const agent = createWorkflow(chatHistory); + const agent = createWorkflow(messages, data); const result = agent.run<AsyncGenerator<ChatResponseChunk>>( userMessage.content, ) as unknown as Promise<StopEvent<AsyncGenerator<ChatResponseChunk>>>; diff --git a/templates/components/multiagent/typescript/nextjs/route.ts b/templates/components/multiagent/typescript/nextjs/route.ts index 04b40339..2f93e0f7 100644 --- a/templates/components/multiagent/typescript/nextjs/route.ts +++ b/templates/components/multiagent/typescript/nextjs/route.ts @@ -1,7 +1,7 @@ import { initObservability } from "@/app/observability"; import { StopEvent } from "@llamaindex/core/workflow"; import { Message, StreamingTextResponse } from "ai"; -import { ChatMessage, ChatResponseChunk } from "llamaindex"; +import { ChatResponseChunk } from "llamaindex"; import { NextRequest, NextResponse } from "next/server"; import { initSettings } from "./engine/settings"; import { createWorkflow } from "./workflow/factory"; @@ -16,7 +16,7 @@ export const dynamic = "force-dynamic"; export async function POST(request: NextRequest) { try { const body = await request.json(); - const { messages }: { messages: Message[] } = body; + const { messages, data }: { messages: Message[]; data?: any } = body; const userMessage = messages.pop(); if (!messages || !userMessage || userMessage.role !== "user") { return NextResponse.json( @@ -28,8 +28,7 @@ export async function POST(request: NextRequest) { ); } - const chatHistory = messages as ChatMessage[]; - const agent = createWorkflow(chatHistory); + const agent = createWorkflow(messages, data); // TODO: fix type in agent.run in LITS const result = agent.run<AsyncGenerator<ChatResponseChunk>>( userMessage.content, diff --git a/templates/components/multiagent/typescript/workflow/agents.ts b/templates/components/multiagent/typescript/workflow/agents.ts index 6af2bf94..71f3123c 100644 --- a/templates/components/multiagent/typescript/workflow/agents.ts +++ b/templates/components/multiagent/typescript/workflow/agents.ts @@ -1,14 +1,19 @@ import { ChatMessage } from "llamaindex"; import { FunctionCallingAgent } from "./single-agent"; -import { lookupTools } from "./tools"; +import { getQueryEngineTool, lookupTools } from "./tools"; -export const createResearcher = async (chatHistory: ChatMessage[]) => { - const tools = await lookupTools([ - "query_index", - "wikipedia_tool", - "duckduckgo_search", - "image_generator", - ]); +export const createResearcher = async ( + chatHistory: ChatMessage[], + params?: any, +) => { + const queryEngineTool = await getQueryEngineTool(params); + const tools = ( + await lookupTools([ + "wikipedia_tool", + "duckduckgo_search", + "image_generator", + ]) + ).concat(queryEngineTool ? [queryEngineTool] : []); return new FunctionCallingAgent({ name: "researcher", diff --git a/templates/components/multiagent/typescript/workflow/factory.ts b/templates/components/multiagent/typescript/workflow/factory.ts index 2aef2c25..0e341ca2 100644 --- a/templates/components/multiagent/typescript/workflow/factory.ts +++ b/templates/components/multiagent/typescript/workflow/factory.ts @@ -5,7 +5,9 @@ import { Workflow, WorkflowEvent, } from "@llamaindex/core/workflow"; +import { Message } from "ai"; import { ChatMessage, ChatResponseChunk, Settings } from "llamaindex"; +import { getAnnotations } from "../llamaindex/streaming/annotations"; import { createPublisher, createResearcher, @@ -25,19 +27,15 @@ class WriteEvent extends WorkflowEvent<{ class ReviewEvent extends WorkflowEvent<{ input: string }> {} class PublishEvent extends WorkflowEvent<{ input: string }> {} -const prepareChatHistory = (chatHistory: ChatMessage[]) => { +const prepareChatHistory = (chatHistory: Message[]): ChatMessage[] => { // By default, the chat history only contains the assistant and user messages // all the agents messages are stored in annotation data which is not visible to the LLM const MAX_AGENT_MESSAGES = 10; - - // Construct a new agent message from agent messages - // Get annotations from assistant messages - const agentAnnotations = chatHistory - .filter((msg) => msg.role === "assistant") - .flatMap((msg) => msg.annotations || []) - .filter((annotation) => annotation.type === "agent") - .slice(-MAX_AGENT_MESSAGES); + const agentAnnotations = getAnnotations<{ agent: string; text: string }>( + chatHistory, + { role: "assistant", type: "agent" }, + ).slice(-MAX_AGENT_MESSAGES); const agentMessages = agentAnnotations .map( @@ -59,13 +57,13 @@ const prepareChatHistory = (chatHistory: ChatMessage[]) => { ...chatHistory.slice(0, -1), agentMessage, chatHistory.slice(-1)[0], - ]; + ] as ChatMessage[]; } - return chatHistory; + return chatHistory as ChatMessage[]; }; -export const createWorkflow = (chatHistory: ChatMessage[]) => { - const chatHistoryWithAgentMessages = prepareChatHistory(chatHistory); +export const createWorkflow = (messages: Message[], params?: any) => { + const chatHistoryWithAgentMessages = prepareChatHistory(messages); const runAgent = async ( context: Context, agent: Workflow, @@ -123,7 +121,10 @@ Decision (respond with either 'not_publish' or 'publish'):`; }; const research = async (context: Context, ev: ResearchEvent) => { - const researcher = await createResearcher(chatHistoryWithAgentMessages); + const researcher = await createResearcher( + chatHistoryWithAgentMessages, + params, + ); const researchRes = await runAgent(context, researcher, { message: ev.data.input, }); diff --git a/templates/components/multiagent/typescript/workflow/single-agent.ts b/templates/components/multiagent/typescript/workflow/single-agent.ts index 568697df..5344f108 100644 --- a/templates/components/multiagent/typescript/workflow/single-agent.ts +++ b/templates/components/multiagent/typescript/workflow/single-agent.ts @@ -143,7 +143,7 @@ export class FunctionCallingAgent extends Workflow { fullResponse = chunk; } - if (fullResponse) { + if (fullResponse?.options && Object.keys(fullResponse.options).length) { memory.put({ role: "assistant", content: "", diff --git a/templates/components/multiagent/typescript/workflow/tools.ts b/templates/components/multiagent/typescript/workflow/tools.ts index ac4e5fb9..012da6ae 100644 --- a/templates/components/multiagent/typescript/workflow/tools.ts +++ b/templates/components/multiagent/typescript/workflow/tools.ts @@ -4,8 +4,10 @@ import path from "path"; import { getDataSource } from "../engine"; import { createTools } from "../engine/tools/index"; -const getQueryEngineTool = async (): Promise<QueryEngineTool | null> => { - const index = await getDataSource(); +export const getQueryEngineTool = async ( + params?: any, +): Promise<QueryEngineTool | null> => { + const index = await getDataSource(params); if (!index) { return null; } -- GitLab