diff --git a/apps/docs/docs/modules/agent/openai.mdx b/apps/docs/docs/modules/agent/openai.mdx index 084cea1c9bcb8679c668357468f261947de23ccf..f0130a985e75ca89c79ab1fb8d3f7ced78d4e23e 100644 --- a/apps/docs/docs/modules/agent/openai.mdx +++ b/apps/docs/docs/modules/agent/openai.mdx @@ -82,7 +82,7 @@ const divideFunctionTool = new FunctionTool(divideNumbers, { Now we can create an OpenAIAgent with the function tools. ```ts -const worker = new OpenAIAgent({ +const agent = new OpenAIAgent({ tools: [sumFunctionTool, divideFunctionTool], verbose: true, }); @@ -93,7 +93,7 @@ const worker = new OpenAIAgent({ Now we can chat with the agent. ```ts -const response = await worker.chat({ +const response = await agent.chat({ message: "How much is 5 + 5? then divide by 2", }); diff --git a/apps/docs/docs/modules/agent/react_agent.mdx b/apps/docs/docs/modules/agent/react_agent.mdx new file mode 100644 index 0000000000000000000000000000000000000000..4c5e6cb9731b6199657985120134d536ce08ad40 --- /dev/null +++ b/apps/docs/docs/modules/agent/react_agent.mdx @@ -0,0 +1,203 @@ +# ReAct Agent + +The ReAct agent is an AI agent that can reason over the next action, construct an action command, execute the action, and repeat these steps in an iterative loop until the task is complete. + +In this notebook tutorial, we showcase how to write your ReAct agent using the `llamaindex` package. + +## Setup + +First, you need to install the `llamaindex` package. You can do this by running the following command in your terminal: + +```bash +pnpm i llamaindex +``` + +And then you can import the `OpenAIAgent` and `FunctionTool` from the `llamaindex` package. + +```ts +import { FunctionTool, OpenAIAgent } from "llamaindex"; +``` + +Then we can define a function to sum two numbers and another function to divide two numbers. 
+ +```ts +function sumNumbers({ a, b }: { a: number; b: number }): number { + return a + b; +} + +// Define a function to divide two numbers +function divideNumbers({ a, b }: { a: number; b: number }): number { + return a / b; +} +``` + +## Create a function tool + +Now we can create a function tool from the sum function and another function tool from the divide function. + +For the parameters of the sum function, we can define a JSON schema. + +### JSON Schema + +```ts +const sumJSON = { + type: "object", + properties: { + a: { + type: "number", + description: "The first number", + }, + b: { + type: "number", + description: "The second number", + }, + }, + required: ["a", "b"], +}; + +const divideJSON = { + type: "object", + properties: { + a: { + type: "number", + description: "The dividend a to divide", + }, + b: { + type: "number", + description: "The divisor b to divide by", + }, + }, + required: ["a", "b"], +}; + +const sumFunctionTool = new FunctionTool(sumNumbers, { + name: "sumNumbers", + description: "Use this function to sum two numbers", + parameters: sumJSON, +}); + +const divideFunctionTool = new FunctionTool(divideNumbers, { + name: "divideNumbers", + description: "Use this function to divide two numbers", + parameters: divideJSON, +}); +``` + +## Create a ReAct Agent + +Now we can create a ReActAgent with the function tools. + +```ts +const agent = new ReActAgent({ + tools: [sumFunctionTool, divideFunctionTool], + verbose: true, +}); +``` + +## Chat with the agent + +Now we can chat with the agent. + +```ts +const response = await agent.chat({ + message: "How much is 5 + 5? then divide by 2", +}); + +console.log(String(response)); +``` + +The output will be: + +```bash +Thought: I need to use a tool to help me answer the question. +Action: sumNumbers +Action Input: {"a":5,"b":5} + +Observation: 10 +Thought: I can answer without using any more tools. +Answer: The sum of 5 and 5 is 10, and when divided by 2, the result is 5. 
+ +The sum of 5 and 5 is 10, and when divided by 2, the result is 5. +``` + +## Full code + +```ts +import { FunctionTool, ReActAgent } from "llamaindex"; + +// Define a function to sum two numbers +function sumNumbers({ a, b }: { a: number; b: number }): number { + return a + b; +} + +// Define a function to divide two numbers +function divideNumbers({ a, b }: { a: number; b: number }): number { + return a / b; +} + +// Define the parameters of the sum function as a JSON schema +const sumJSON = { + type: "object", + properties: { + a: { + type: "number", + description: "The first number", + }, + b: { + type: "number", + description: "The second number", + }, + }, + required: ["a", "b"], +}; + +// Define the parameters of the divide function as a JSON schema +const divideJSON = { + type: "object", + properties: { + a: { + type: "number", + description: "The argument a to divide", + }, + b: { + type: "number", + description: "The argument b to divide", + }, + }, + required: ["a", "b"], +}; + +async function main() { + // Create a function tool from the sum function + const sumFunctionTool = new FunctionTool(sumNumbers, { + name: "sumNumbers", + description: "Use this function to sum two numbers", + parameters: sumJSON, + }); + + // Create a function tool from the divide function + const divideFunctionTool = new FunctionTool(divideNumbers, { + name: "divideNumbers", + description: "Use this function to divide two numbers", + parameters: divideJSON, + }); + + // Create a ReActAgent with the function tools + const agent = new ReActAgent({ + tools: [sumFunctionTool, divideFunctionTool], + verbose: true, + }); + + // Chat with the agent + const response = await agent.chat({ + message: "I want to sum 5 and 5 and then divide by 2", + }); + + // Print the response + console.log(String(response)); +} + +main().then(() => { + console.log("Done"); +}); +``` diff --git a/examples/agent/react_agent.ts b/examples/agent/react_agent.ts new file mode 100644 index 
0000000000000000000000000000000000000000..5cb1c550dba7bda5ad775688bbbd1e91bd357188 --- /dev/null +++ b/examples/agent/react_agent.ts @@ -0,0 +1,77 @@ +import { FunctionTool, ReActAgent } from "llamaindex"; + +// Define a function to sum two numbers +function sumNumbers({ a, b }: { a: number; b: number }): number { + return a + b; +} + +// Define a function to divide two numbers +function divideNumbers({ a, b }: { a: number; b: number }): number { + return a / b; +} + +// Define the parameters of the sum function as a JSON schema +const sumJSON = { + type: "object", + properties: { + a: { + type: "number", + description: "The first number", + }, + b: { + type: "number", + description: "The second number", + }, + }, + required: ["a", "b"], +}; + +const divideJSON = { + type: "object", + properties: { + a: { + type: "number", + description: "The dividend", + }, + b: { + type: "number", + description: "The divisor", + }, + }, + required: ["a", "b"], +}; + +async function main() { + // Create a function tool from the sum function + const functionTool = new FunctionTool(sumNumbers, { + name: "sumNumbers", + description: "Use this function to sum two numbers", + parameters: sumJSON, + }); + + // Create a function tool from the divide function + const functionTool2 = new FunctionTool(divideNumbers, { + name: "divideNumbers", + description: "Use this function to divide two numbers", + parameters: divideJSON, + }); + + // Create an OpenAIAgent with the function tools + const agent = new ReActAgent({ + tools: [functionTool, functionTool2], + verbose: true, + }); + + // Chat with the agent + const response = await agent.chat({ + message: "I want to sum 5 and 5 and divide by 2", + toolChoice: "auto", + }); + + // Print the response + console.log(String(response)); +} + +main().then(() => { + console.log("Done"); +}); diff --git a/packages/core/src/agent/index.ts b/packages/core/src/agent/index.ts index 
596146258fbfcd2782d3de29de100862ed882508..1ef9f464e53137db00a74a7dfcfaf359a6af8fad 100644 --- a/packages/core/src/agent/index.ts +++ b/packages/core/src/agent/index.ts @@ -1,2 +1,5 @@ export * from "./openai/base"; export * from "./openai/worker"; +export * from "./react/base"; +export * from "./react/worker"; +export * from "./types"; diff --git a/packages/core/src/agent/openai/worker.ts b/packages/core/src/agent/openai/worker.ts index 1119ca1c9e6bda6d2686d182375d915ae8321699..20176f51941fa14684b97127f24dcd58fa495463 100644 --- a/packages/core/src/agent/openai/worker.ts +++ b/packages/core/src/agent/openai/worker.ts @@ -95,7 +95,7 @@ export class OpenAIAgentWorker implements AgentWorker { public prefixMessages: ChatMessage[]; public callbackManager: CallbackManager | undefined; - private _getTools: (input: string) => BaseTool[]; + private _getTools: (input: string) => Promise<BaseTool[]>; /** * Initialize. @@ -118,12 +118,12 @@ export class OpenAIAgentWorker implements AgentWorker { if (tools.length > 0 && toolRetriever) { throw new Error("Cannot specify both tools and tool_retriever"); } else if (tools.length > 0) { - this._getTools = () => tools; + this._getTools = async () => tools; } else if (toolRetriever) { - // @ts-ignore - this._getTools = (message: string) => toolRetriever.retrieve(message); + this._getTools = async (message: string) => + toolRetriever.retrieve(message); } else { - this._getTools = () => []; + this._getTools = async () => []; } } @@ -298,7 +298,7 @@ export class OpenAIAgentWorker implements AgentWorker { * @param input: input * @returns: tools */ - getTools(input: string): BaseTool[] { + async getTools(input: string): Promise<BaseTool[]> { return this._getTools(input); } @@ -308,7 +308,7 @@ export class OpenAIAgentWorker implements AgentWorker { mode: ChatResponseMode = ChatResponseMode.WAIT, toolChoice: string | { [key: string]: any } = "auto", ): Promise<TaskStepOutput> { - const tools = this.getTools(task.input); + const tools = await 
this.getTools(task.input); if (step.input) { addUserStepToMemory(step, task.extraState.newMemory, this._verbose); diff --git a/packages/core/src/agent/react/base.ts b/packages/core/src/agent/react/base.ts new file mode 100644 index 0000000000000000000000000000000000000000..ecafcff382ebbdccb6652ecb1b14a9bd797dbe6c --- /dev/null +++ b/packages/core/src/agent/react/base.ts @@ -0,0 +1,54 @@ +import { CallbackManager } from "../../callbacks/CallbackManager"; +import { ChatMessage, LLM } from "../../llm"; +import { ObjectRetriever } from "../../objects/base"; +import { BaseTool } from "../../types"; +import { AgentRunner } from "../runner/base"; +import { ReActAgentWorker } from "./worker"; + +type ReActAgentParams = { + tools: BaseTool[]; + llm?: LLM; + memory?: any; + prefixMessages?: ChatMessage[]; + verbose?: boolean; + maxInteractions?: number; + defaultToolChoice?: string; + callbackManager?: CallbackManager; + toolRetriever?: ObjectRetriever<BaseTool>; +}; + +/** + * An agent that uses the ReAct loop (reason, act, observe) to answer queries with tools. + * + * @category ReAct + */ +export class ReActAgent extends AgentRunner { + constructor({ + tools, + llm, + memory, + prefixMessages, + verbose, + maxInteractions = 10, + defaultToolChoice = "auto", + callbackManager, + toolRetriever, + }: Partial<ReActAgentParams>) { + const stepEngine = new ReActAgentWorker({ + tools: tools ?? 
[], + callbackManager, + llm, + maxInteractions, + toolRetriever, + verbose, + }); + + super({ + agentWorker: stepEngine, + memory, + callbackManager, + defaultToolChoice, + chatHistory: prefixMessages, + }); + } +} diff --git a/packages/core/src/agent/react/formatter.ts b/packages/core/src/agent/react/formatter.ts new file mode 100644 index 0000000000000000000000000000000000000000..89a4be561fec9944d2d56fca5b772a1f36986dcf --- /dev/null +++ b/packages/core/src/agent/react/formatter.ts @@ -0,0 +1,83 @@ +import { ChatMessage } from "../../llm"; +import { BaseTool } from "../../types"; +import { getReactChatSystemHeader } from "./prompts"; +import { BaseReasoningStep, ObservationReasoningStep } from "./types"; + +function getReactToolDescriptions(tools: BaseTool[]): string[] { + const toolDescs: string[] = []; + for (const tool of tools) { + // @ts-ignore + const toolDesc = `> Tool Name: ${tool.metadata.name}\nTool Description: ${tool.metadata.description}\nTool Args: ${JSON.stringify(tool?.metadata?.parameters?.properties)}\n`; + toolDescs.push(toolDesc); + } + return toolDescs; +} + +export interface BaseAgentChatFormatter { + format( + tools: BaseTool[], + chatHistory: ChatMessage[], + currentReasoning?: BaseReasoningStep[], + ): ChatMessage[]; +} + +export class ReActChatFormatter implements BaseAgentChatFormatter { + systemHeader: string = ""; + context: string = "'"; + + constructor(init?: Partial<ReActChatFormatter>) { + Object.assign(this, init); + } + + format( + tools: BaseTool[], + chatHistory: ChatMessage[], + currentReasoning?: BaseReasoningStep[], + ): ChatMessage[] { + currentReasoning = currentReasoning ?? 
[]; + + const formatArgs = { + toolDesc: getReactToolDescriptions(tools).join("\n"), + toolNames: tools.map((tool) => tool.metadata.name).join(", "), + context: "", + }; + + if (this.context) { + formatArgs["context"] = this.context; + } + + const reasoningHistory = []; + + for (const reasoningStep of currentReasoning) { + let message: ChatMessage | undefined; + + if (reasoningStep instanceof ObservationReasoningStep) { + message = { + content: reasoningStep.getContent(), + role: "user", + }; + } else { + message = { + content: reasoningStep.getContent(), + role: "system", + }; + } + + reasoningHistory.push(message); + } + + const systemContent = getReactChatSystemHeader({ + toolDesc: formatArgs.toolDesc, + toolNames: formatArgs.toolNames, + }); + + return [ + { + content: systemContent, + role: "system", + }, + ...reasoningHistory, + ...chatHistory, + ]; + } +} diff --git a/packages/core/src/agent/react/outputParser.ts b/packages/core/src/agent/react/outputParser.ts new file mode 100644 index 0000000000000000000000000000000000000000..3ed1d7be3c9edeb76c241beb443b7dac6378d1de --- /dev/null +++ b/packages/core/src/agent/react/outputParser.ts @@ -0,0 +1,105 @@ +import { + ActionReasoningStep, + BaseOutputParser, + BaseReasoningStep, + ResponseReasoningStep, +} from "./types"; + +function extractJsonStr(text: string): string { + const pattern = /\{.*\}/s; + const match = text.match(pattern); + + if (!match) { + throw new Error(`Could not extract json string from output: ${text}`); + } + + return match[0]; +} + +function extractToolUse(inputText: string): [string, string, string] { + const pattern = + /\s*Thought: (.*?)\nAction: ([a-zA-Z0-9_]+).*?\nAction Input: .*?(\{.*?\})/s; + + const match = inputText.match(pattern); + + if (!match) { + throw new Error(`Could not extract tool use from input text: ${inputText}`); + } + + const thought = match[1].trim(); + const action = match[2].trim(); + const actionInput = match[3].trim(); + return [thought, action, actionInput]; 
+} + +function actionInputParser(jsonStr: string): object { + const processedString = jsonStr.replace(/(?<!\w)\'|\'(?!\w)/g, '"'); + const pattern = /"(\w+)":\s*"([^"]*)"/g; + const matches = [...processedString.matchAll(pattern)]; + return Object.fromEntries(matches); +} + +function extractFinalResponse(inputText: string): [string, string] { + const pattern = /\s*Thought:(.*?)Answer:(.*?)(?:$)/s; + + const match = inputText.match(pattern); + + if (!match) { + throw new Error( + `Could not extract final answer from input text: ${inputText}`, + ); + } + + const thought = match[1].trim(); + const answer = match[2].trim(); + return [thought, answer]; +} + +export class ReActOutputParser extends BaseOutputParser { + parse(output: string, isStreaming: boolean = false): BaseReasoningStep { + if (!output.includes("Thought:")) { + // NOTE: handle the case where the agent directly outputs the answer + // instead of following the thought-answer format + return new ResponseReasoningStep({ + thought: "(Implicit) I can answer without any more tools!", + response: output, + isStreaming, + }); + } + + if (output.includes("Answer:")) { + const [thought, answer] = extractFinalResponse(output); + return new ResponseReasoningStep({ + thought: thought, + response: answer, + isStreaming, + }); + } + + if (output.includes("Action:")) { + const [thought, action, action_input] = extractToolUse(output); + const json_str = extractJsonStr(action_input); + + // First we try json, if this fails we use ast + let actionInputDict; + + try { + actionInputDict = JSON.parse(json_str); + } catch (e) { + actionInputDict = actionInputParser(json_str); + } + + return new ActionReasoningStep({ + thought: thought, + action: action, + actionInput: actionInputDict, + }); + } + + throw new Error(`Could not parse output: ${output}`); + } + + format(output: string): string { + throw new Error("Not implemented"); + } +} diff --git a/packages/core/src/agent/react/prompts.ts 
b/packages/core/src/agent/react/prompts.ts new file mode 100644 index 0000000000000000000000000000000000000000..75e98a468abd69254bf308639ffec99d791b7629 --- /dev/null +++ b/packages/core/src/agent/react/prompts.ts @@ -0,0 +1,56 @@ +type ReactChatSystemHeaderParams = { + toolDesc: string; + toolNames: string; +}; + +export const getReactChatSystemHeader = ({ + toolDesc, + toolNames, +}: ReactChatSystemHeaderParams) => + `You are designed to help with a variety of tasks, from answering questions to providing summaries to other types of analyses. + +## Tools +You have access to a wide variety of tools. You are responsible for using +the tools in any sequence you deem appropriate to complete the task at hand. +This may require breaking the task into subtasks and using different tools +to complete each subtask. + +You have access to the following tools: +${toolDesc} + +## Output Format +To answer the question, please use the following format. + +""" +Thought: I need to use a tool to help me answer the question. +Action: tool name (one of ${toolNames}) if using a tool. +Action Input: the input to the tool, in a JSON format representing the kwargs (e.g. {{"input": "hello world", "num_beams": 5}}) +""" + +Please ALWAYS start with a Thought. + +Please use a valid JSON format for the Action Input. Do NOT do this {{'input': 'hello world', 'num_beams': 5}}. + +If this format is used, the user will respond in the following format: + +"""" +Observation: tool response +"""" + +You should keep repeating the above format until you have enough information +to answer the question without using any more tools. At that point, you MUST respond +in the one of the following two formats: + +"""" +Thought: I can answer without using any more tools. +Answer: [your answer here] +"""" + +"""" +Thought: I cannot answer the question with the provided tools. +Answer: Sorry, I cannot answer your query. 
+"""" + +## Current Conversation +Below is the current conversation consisting of interleaving human and assistant messages. +`; diff --git a/packages/core/src/agent/react/types.ts b/packages/core/src/agent/react/types.ts new file mode 100644 index 0000000000000000000000000000000000000000..04e7889e2bebda15a1d33ad831e9fb309ba7d068 --- /dev/null +++ b/packages/core/src/agent/react/types.ts @@ -0,0 +1,88 @@ +import { ChatMessage } from "../../llm"; + +export interface BaseReasoningStep { + getContent(): string; + isDone(): boolean; +} + +export class ObservationReasoningStep implements BaseReasoningStep { + observation: string; + + constructor(init?: Partial<ObservationReasoningStep>) { + this.observation = init?.observation ?? ""; + } + + getContent(): string { + return `Observation: ${this.observation}`; + } + + isDone(): boolean { + return false; + } +} + +export class ActionReasoningStep implements BaseReasoningStep { + thought: string; + action: string; + actionInput: Record<string, any>; + + constructor(init?: Partial<ActionReasoningStep>) { + this.thought = init?.thought ?? ""; + this.action = init?.action ?? ""; + this.actionInput = init?.actionInput ?? 
{}; + } + + getContent(): string { + return `Thought: ${this.thought}\nAction: ${this.action}\nAction Input: ${JSON.stringify(this.actionInput)}`; + } + + isDone(): boolean { + return false; + } +} + +export abstract class BaseOutputParser { + abstract parse(output: string, isStreaming?: boolean): BaseReasoningStep; + + format(output: string) { + return output; + } + + formatMessages(messages: ChatMessage[]): ChatMessage[] { + if (messages) { + if (messages[0].role === "system") { + messages[0].content = this.format(messages[0].content || ""); + } else { + messages[messages.length - 1].content = this.format( + messages[messages.length - 1].content || "", + ); + } + } + + return messages; + } +} + +export class ResponseReasoningStep implements BaseReasoningStep { + thought: string; + response: string; + isStreaming: boolean = false; + + constructor(init?: Partial<ResponseReasoningStep>) { + this.thought = init?.thought ?? ""; + this.response = init?.response ?? ""; + this.isStreaming = init?.isStreaming ?? 
false; + } + + getContent(): string { + if (this.isStreaming) { + return `Thought: ${this.thought}\nAnswer (Starts With): ${this.response} ...`; + } else { + return `Thought: ${this.thought}\nAnswer: ${this.response}`; + } + } + + isDone(): boolean { + return true; + } +} diff --git a/packages/core/src/agent/react/worker.ts b/packages/core/src/agent/react/worker.ts new file mode 100644 index 0000000000000000000000000000000000000000..cb0511ccc94523d328b9e20bf98b109e12334885 --- /dev/null +++ b/packages/core/src/agent/react/worker.ts @@ -0,0 +1,395 @@ +import { randomUUID } from "crypto"; +import { CallbackManager } from "../../callbacks/CallbackManager"; +import { AgentChatResponse } from "../../engines/chat"; +import { ChatResponse, LLM, OpenAI } from "../../llm"; +import { ChatMemoryBuffer } from "../../memory/ChatMemoryBuffer"; +import { ObjectRetriever } from "../../objects/base"; +import { ToolOutput } from "../../tools"; +import { BaseTool } from "../../types"; +import { AgentWorker, Task, TaskStep, TaskStepOutput } from "../types"; +import { ReActChatFormatter } from "./formatter"; +import { ReActOutputParser } from "./outputParser"; +import { + ActionReasoningStep, + BaseReasoningStep, + ObservationReasoningStep, + ResponseReasoningStep, +} from "./types"; + +type ReActAgentWorkerParams = { + tools: BaseTool[]; + llm?: LLM; + maxInteractions?: number; + reactChatFormatter?: ReActChatFormatter | undefined; + outputParser?: ReActOutputParser | undefined; + callbackManager?: CallbackManager | undefined; + verbose?: boolean | undefined; + toolRetriever?: ObjectRetriever<BaseTool> | undefined; +}; + +/** + * + * @param step + * @param memory + * @param currentReasoning + * @param verbose + */ +function addUserStepToReasoning( + step: TaskStep, + memory: ChatMemoryBuffer, + currentReasoning: BaseReasoningStep[], + verbose: boolean = false, +): void { + if (step.stepState.isFirst) { + memory.put({ + content: step.input, + role: "user", + }); + 
step.stepState.isFirst = false; + } else { + const reasoningStep = new ObservationReasoningStep({ + observation: step.input ?? undefined, + }); + currentReasoning.push(reasoningStep); + if (verbose) { + console.log(`Added user message to memory: ${step.input}`); + } + } +} + +/** + * ReAct agent worker. + */ +export class ReActAgentWorker implements AgentWorker { + llm: LLM; + verbose: boolean; + + maxInteractions: number = 10; + reactChatFormatter: ReActChatFormatter; + outputParser: ReActOutputParser; + + callbackManager: CallbackManager; + + _getTools: (message: string) => Promise<BaseTool[]>; + + constructor({ + tools, + llm, + maxInteractions, + reactChatFormatter, + outputParser, + callbackManager, + verbose, + toolRetriever, + }: ReActAgentWorkerParams) { + this.llm = llm ?? new OpenAI({ model: "gpt-3.5-turbo-1106" }); + this.callbackManager = callbackManager || new CallbackManager(); + + this.maxInteractions = maxInteractions ?? 10; + this.reactChatFormatter = reactChatFormatter ?? new ReActChatFormatter(); + this.outputParser = outputParser ?? new ReActOutputParser(); + this.verbose = verbose || false; + + if (tools.length > 0 && toolRetriever) { + throw new Error("Cannot specify both tools and tool_retriever"); + } else if (tools.length > 0) { + this._getTools = async () => tools; + } else if (toolRetriever) { + this._getTools = async (message: string) => + toolRetriever.retrieve(message); + } else { + this._getTools = async () => []; + } + } + + /** + * Initialize a task step. 
+ * @param task - task + * @param kwargs - keyword arguments + * @returns - task step + */ + initializeStep(task: Task, kwargs?: any): TaskStep { + const sources: ToolOutput[] = []; + const currentReasoning: BaseReasoningStep[] = []; + const newMemory = new ChatMemoryBuffer(); + + const taskState = { + sources, + currentReasoning, + newMemory, + }; + + task.extraState = { + ...task.extraState, + ...taskState, + }; + + return new TaskStep(task.taskId, randomUUID(), task.input, { + isFirst: true, + }); + } + + /** + * Extract reasoning step from chat response. + * @param output - chat response + * @param isStreaming - whether the chat response is streaming + * @returns - [message content, reasoning steps, is done] + */ + extractReasoningStep( + output: ChatResponse, + isStreaming: boolean, + ): [string, BaseReasoningStep[], boolean] { + if (!output.message.content) { + throw new Error("Got empty message."); + } + + const messageContent = output.message.content; + const currentReasoning: BaseReasoningStep[] = []; + + let reasoningStep; + + try { + reasoningStep = this.outputParser.parse( + messageContent, + isStreaming, + ) as ActionReasoningStep; + } catch (e) { + throw new Error(`Could not parse output: ${e}`); + } + + if (this.verbose) { + console.log(`${reasoningStep.getContent()}\n`); + } + + currentReasoning.push(reasoningStep); + + if (reasoningStep.isDone()) { + return [messageContent, currentReasoning, true]; + } + + const actionReasoningStep = new ActionReasoningStep({ + thought: reasoningStep.getContent(), + action: reasoningStep.action, + actionInput: reasoningStep.actionInput, + }); + + if (!(actionReasoningStep instanceof ActionReasoningStep)) { + throw new Error(`Expected ActionReasoningStep, got ${reasoningStep}`); + } + + return [messageContent, currentReasoning, false]; + } + + /** + * Process actions. 
+ * @param task - task + * @param tools - tools + * @param output - chat response + * @param isStreaming - whether the chat response is streaming + * @returns - [reasoning steps, is done] + */ + async _processActions( + task: Task, + tools: BaseTool[], + output: ChatResponse, + isStreaming: boolean = false, + ): Promise<[BaseReasoningStep[], boolean]> { + const toolsDict: Record<string, BaseTool> = {}; + + for (const tool of tools) { + toolsDict[tool.metadata.name] = tool; + } + + const [_, currentReasoning, isDone] = this.extractReasoningStep( + output, + isStreaming, + ); + + if (isDone) { + return [currentReasoning, true]; + } + + const reasoningStep = currentReasoning[ + currentReasoning.length - 1 + ] as ActionReasoningStep; + + const actionReasoningStep = new ActionReasoningStep({ + thought: reasoningStep.getContent(), + action: reasoningStep.action, + actionInput: reasoningStep.actionInput, + }); + + const tool = toolsDict[actionReasoningStep.action]; + + const toolOutput = await tool?.call?.(actionReasoningStep.actionInput); + + task.extraState.sources.push( + new ToolOutput( + toolOutput, + tool.metadata.name, + actionReasoningStep.actionInput, + toolOutput, + ), + ); + + const observationStep = new ObservationReasoningStep({ + observation: toolOutput, + }); + + currentReasoning.push(observationStep); + + if (this.verbose) { + console.log(`${observationStep.getContent()}`); + } + + return [currentReasoning, false]; + } + + /** + * Get response. 
+ * @param currentReasoning - current reasoning steps + * @param sources - tool outputs + * @returns - agent chat response + */ + _getResponse( + currentReasoning: BaseReasoningStep[], + sources: ToolOutput[], + ): AgentChatResponse { + if (currentReasoning.length === 0) { + throw new Error("No reasoning steps were taken."); + } else if (currentReasoning.length === this.maxInteractions) { + throw new Error("Reached max iterations."); + } + + const responseStep = currentReasoning[currentReasoning.length - 1]; + + let responseStr: string; + + if (responseStep instanceof ResponseReasoningStep) { + responseStr = responseStep.response; + } else { + responseStr = responseStep.getContent(); + } + + return new AgentChatResponse(responseStr, sources); + } + + /** + * Get task step response. + * @param agentResponse - agent chat response + * @param step - task step + * @param isDone - whether the task is done + * @returns - task step output + */ + _getTaskStepResponse( + agentResponse: AgentChatResponse, + step: TaskStep, + isDone: boolean, + ): TaskStepOutput { + let newSteps: TaskStep[] = []; + + if (isDone) { + newSteps = []; + } else { + newSteps = [step.getNextStep(randomUUID(), undefined)]; + } + + return new TaskStepOutput(agentResponse, step, newSteps, isDone); + } + + /** + * Run a task step. 
+ * @param step - task step + * @param task - task + * @param kwargs - keyword arguments + * @returns - task step output + */ + async _runStep( + step: TaskStep, + task: Task, + kwargs?: any, + ): Promise<TaskStepOutput> { + if (step.input) { + addUserStepToReasoning( + step, + task.extraState.newMemory, + task.extraState.currentReasoning, + this.verbose, + ); + } + + const tools = await this._getTools(task.input); + + const inputChat = this.reactChatFormatter.format( + tools, + [...task.memory.get(), ...task.extraState.newMemory.get()], + task.extraState.currentReasoning, + ); + + const chatResponse = await this.llm.chat({ + messages: inputChat, + }); + + const [reasoningSteps, isDone] = await this._processActions( + task, + tools, + chatResponse, + ); + + task.extraState.currentReasoning.push(...reasoningSteps); + + const agentResponse = this._getResponse( + task.extraState.currentReasoning, + task.extraState.sources, + ); + + if (isDone) { + task.extraState.newMemory.put({ + content: agentResponse.response, + role: "assistant", + }); + } + + return this._getTaskStepResponse(agentResponse, step, isDone); + } + + /** + * Run a task step. + * @param step - task step + * @param task - task + * @param kwargs - keyword arguments + * @returns - task step output + */ + async runStep( + step: TaskStep, + task: Task, + kwargs?: any, + ): Promise<TaskStepOutput> { + return await this._runStep(step, task); + } + + /** + * Run a task step. + * @param step - task step + * @param task - task + * @param kwargs - keyword arguments + * @returns - task step output + */ + streamStep( + step: TaskStep, + task: Task, + kwargs?: any, + ): Promise<TaskStepOutput> { + throw new Error("Method not implemented."); + } + + /** + * Finalize a task. 
+ * @param task - task + * @param kwargs - keyword arguments + */ + finalizeTask(task: Task, kwargs?: any): void { + task.memory.set(task.memory.get() + task.extraState.newMemory.get()); + task.extraState.newMemory.reset(); + } +} diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts index 29e975003e08971430cf0280e6994ee12265f9fe..777d24e6da903c54e42b3e4322c9f50b86a31384 100644 --- a/packages/core/src/llm/LLM.ts +++ b/packages/core/src/llm/LLM.ts @@ -228,7 +228,6 @@ export class OpenAI extends BaseLLM { params: LLMChatParamsNonStreaming | LLMChatParamsStreaming, ): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> { const { messages, parentEvent, stream, tools, toolChoice } = params; - let baseRequestParams: OpenAILLM.Chat.ChatCompletionCreateParams = { model: this.model, temperature: this.temperature, diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index a1c8d0362fea83afadb8b9901dbf038bc4791528..44a4750e511a22a44d6dc4678a95e132ba9de590 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -10213,6 +10213,55 @@ packages: wrap-ansi: 9.0.0 dev: true + /llamaindex@0.1.9(typescript@5.3.3): + resolution: {integrity: sha512-MAMGV5MXXcJ4rSV2kqCZENf7B+q1zTwPnHpnHJgEiEzP5+djNdLmbz/zaCmxpB8wgNNLUem9iJt53iwDBJ4ZBA==} + engines: {node: '>=18.0.0'} + dependencies: + '@anthropic-ai/sdk': 0.12.4 + '@datastax/astra-db-ts': 0.1.4 + '@mistralai/mistralai': 0.0.10 + '@notionhq/client': 2.2.14 + '@pinecone-database/pinecone': 1.1.3 + '@qdrant/js-client-rest': 1.7.0(typescript@5.3.3) + '@xenova/transformers': 2.14.1 + assemblyai: 4.2.1 + chromadb: 1.7.3(openai@4.26.0) + file-type: 18.7.0 + js-tiktoken: 1.0.8 + lodash: 4.17.21 + mammoth: 1.6.0 + md-utils-ts: 2.0.0 + mongodb: 6.3.0 + notion-md-crawler: 0.0.2 + openai: 4.26.0 + papaparse: 5.4.1 + pathe: 1.1.2 + pdf2json: 3.0.5 + pg: 8.11.3 + pgvector: 0.1.7 + portkey-ai: 0.1.16 + rake-modified: 1.0.8 + replicate: 0.25.2 + string-strip-html: 13.4.5 + wink-nlp: 1.14.3 + transitivePeerDependencies: + - 
'@aws-sdk/credential-providers' + - '@google/generative-ai' + - '@mongodb-js/zstd' + - bufferutil + - cohere-ai + - debug + - encoding + - gcp-metadata + - kerberos + - mongodb-client-encryption + - pg-native + - snappy + - socks + - typescript + - utf-8-validate + dev: false + /load-yaml-file@0.2.0: resolution: {integrity: sha512-OfCBkGEw4nN6JLtgRidPX6QxjBQGQf72q3si2uvqyFEMbycSFFHwAZeXx6cJgFM9wmLrf9zBwCP3Ivqa+LLZPw==} engines: {node: '>=6'}