diff --git a/examples/agent/openai.ts b/examples/agent/openai.ts index 3e19011995ba117db71e5adcd0c892b38560005b..3bc65c2b37c1b5d472f1f6ea4ba010bb4e81c077 100644 --- a/examples/agent/openai.ts +++ b/examples/agent/openai.ts @@ -1,13 +1,13 @@ import { FunctionTool, OpenAIAgent } from "llamaindex"; // Define a function to sum two numbers -function sumNumbers({ a, b }: { a: number; b: number }): number { - return a + b; +function sumNumbers({ a, b }: { a: number; b: number }) { + return `${a + b}`; } // Define a function to divide two numbers -function divideNumbers({ a, b }: { a: number; b: number }): number { - return a / b; +function divideNumbers({ a, b }: { a: number; b: number }) { + return `${a / b}`; } // Define the parameters of the sum function as a JSON schema @@ -24,7 +24,7 @@ const sumJSON = { }, }, required: ["a", "b"], -}; +} as const; const divideJSON = { type: "object", @@ -39,7 +39,7 @@ const divideJSON = { }, }, required: ["a", "b"], -}; +} as const; async function main() { // Create a function tool from the sum function diff --git a/examples/agent/react_agent.ts b/examples/agent/react_agent.ts index a177b3835bbe26f19ef6783259fd12ad26d200f1..d86fe47073dfb9c78679343b06d1a10fce2d3393 100644 --- a/examples/agent/react_agent.ts +++ b/examples/agent/react_agent.ts @@ -1,13 +1,13 @@ import { Anthropic, FunctionTool, ReActAgent } from "llamaindex"; // Define a function to sum two numbers -function sumNumbers({ a, b }: { a: number; b: number }): number { - return a + b; +function sumNumbers({ a, b }: { a: number; b: number }) { + return `${a + b}`; } // Define a function to divide two numbers -function divideNumbers({ a, b }: { a: number; b: number }): number { - return a / b; +function divideNumbers({ a, b }: { a: number; b: number }) { + return `${a / b}`; } // Define the parameters of the sum function as a JSON schema @@ -24,7 +24,7 @@ const sumJSON = { }, }, required: ["a", "b"], -}; +} as const; const divideJSON = { type: "object", @@ -39,7 +39,7 @@ const divideJSON = { }, }, required: ["a", "b"], -}; +} as const; async function main() { // Create a function tool from the sum function diff --git a/examples/agent/step_wise_openai.ts b/examples/agent/step_wise_openai.ts index f5c50db78ef5b42d5e758c62040612f3ead5262e..08b20509bac64e1ba45a443942458b0c6561df85 100644 --- a/examples/agent/step_wise_openai.ts +++ b/examples/agent/step_wise_openai.ts @@ -1,13 +1,13 @@ import { FunctionTool, OpenAIAgent } from "llamaindex"; // Define a function to sum two numbers -function sumNumbers({ a, b }: { a: number; b: number }): number { - return a + b; +function sumNumbers({ a, b }: { a: number; b: number }) { + return `${a + b}`; } // Define a function to divide two numbers -function divideNumbers({ a, b }: { a: number; b: number }): number { - return a / b; +function divideNumbers({ a, b }: { a: number; b: number }) { + return `${a / b}`; } // Define the parameters of the sum function as a JSON schema @@ -24,22 +24,22 @@ const sumJSON = { }, }, required: ["a", "b"], -}; +} as const; const divideJSON = { type: "object", properties: { a: { type: "number", - description: "The dividend a to divide", + description: "The dividend", }, b: { type: "number", - description: "The divisor b to divide by", + description: "The divisor", }, }, required: ["a", "b"], -}; +} as const; async function main() { // Create a function tool from the sum function diff --git a/examples/agent/step_wise_react.ts b/examples/agent/step_wise_react.ts index 81f2f87b8568cfde8ac198cc2783d203870b2795..185e528212419c8981458fc0007ff1465a9a69b1 100644 --- a/examples/agent/step_wise_react.ts +++ b/examples/agent/step_wise_react.ts @@ -1,13 +1,13 @@ import { FunctionTool, ReActAgent } from "llamaindex"; // Define a function to sum two numbers -function sumNumbers({ a, b }: { a: number; b: number }): number { - return a + b; +function sumNumbers({ a, b }: { a: number; b: number }) { + return `${a + b}`; } // Define a function to divide two numbers -function divideNumbers({ a, b }: { a: number; b: number }): number { - return a / b; +function divideNumbers({ a, b }: { a: number; b: number }) { + return `${a / b}`; } // Define the parameters of the sum function as a JSON schema @@ -24,7 +24,7 @@ const sumJSON = { }, }, required: ["a", "b"], -}; +} as const; const divideJSON = { type: "object", @@ -39,7 +39,7 @@ const divideJSON = { }, }, required: ["a", "b"], -}; +} as const; async function main() { // Create a function tool from the sum function diff --git a/examples/agent/stream_openai_agent.ts b/examples/agent/stream_openai_agent.ts index 3d942ad3ddc5253c276f2b503665e2a538aa6b6b..0a1d89dfb42273612377a58cf5e0b2cee2274c99 100644 --- a/examples/agent/stream_openai_agent.ts +++ b/examples/agent/stream_openai_agent.ts @@ -1,13 +1,13 @@ import { FunctionTool, OpenAIAgent } from "llamaindex"; // Define a function to sum two numbers -function sumNumbers({ a, b }: { a: number; b: number }): number { - return a + b; +function sumNumbers({ a, b }: { a: number; b: number }) { + return `${a + b}`; } // Define a function to divide two numbers -function divideNumbers({ a, b }: { a: number; b: number }): number { - return a / b; +function divideNumbers({ a, b }: { a: number; b: number }) { + return `${a / b}`; } // Define the parameters of the sum function as a JSON schema @@ -24,7 +24,7 @@ const sumJSON = { }, }, required: ["a", "b"], -}; +} as const; const divideJSON = { type: "object", @@ -39,18 +39,18 @@ const divideJSON = { }, }, required: ["a", "b"], -}; +} as const; async function main() { // Create a function tool from the sum function - const functionTool = new FunctionTool(sumNumbers, { + const functionTool = FunctionTool.from(sumNumbers, { name: "sumNumbers", description: "Use this function to sum two numbers", parameters: sumJSON, }); // Create a function tool from the divide function - const functionTool2 = new FunctionTool(divideNumbers, { + const functionTool2 = FunctionTool.from(divideNumbers, { name: "divideNumbers", description: "Use this function to divide two numbers", parameters: divideJSON, diff --git a/examples/toolsStream.ts b/examples/toolsStream.ts index 6aced425b094fe525828d8c60ca6f8b573597673..3b1e86619d2525152c9f5de09de922cd3ee250cf 100644 --- a/examples/toolsStream.ts +++ b/examples/toolsStream.ts @@ -2,7 +2,6 @@ import { OpenAI } from "llamaindex"; async function main() { const llm = new OpenAI({ model: "gpt-4-turbo" }); - const args: Parameters<typeof llm.chat>[0] = { additionalChatOptions: { tool_choice: "auto", diff --git a/packages/core/e2e/node/openai.e2e.ts b/packages/core/e2e/node/openai.e2e.ts index 883ab0c5caf38b13ee0b988bdc7055f30245a58e..452d27be6e1c4d07f1f06bc4ad5496b5864cb72d 100644 --- a/packages/core/e2e/node/openai.e2e.ts +++ b/packages/core/e2e/node/openai.e2e.ts @@ -1,7 +1,7 @@ -/* eslint-disable @typescript-eslint/no-floating-promises */ import { consola } from "consola"; import { Document, + FunctionTool, OpenAI, OpenAIAgent, QueryEngineTool, @@ -22,7 +22,7 @@ beforeEach(async () => { llm = Settings.llm; }); -test("llm", async (t) => { +await test("llm", async (t) => { await mockLLMEvent(t, "llm"); await t.test("llm.chat", async () => { const response = await llm.chat({ @@ -54,7 +54,7 @@ test("llm", async (t) => { }); }); -test("gpt-4-turbo", async (t) => { +await test("gpt-4-turbo", async (t) => { const llm = new OpenAI({ model: "gpt-4-turbo" }); Settings.llm = llm; await mockLLMEvent(t, "gpt-4-turbo"); @@ -89,7 +89,7 @@ test("gpt-4-turbo", async (t) => { }); }); -test("agent", async (t) => { +await test("agent", async (t) => { await mockLLMEvent(t, "agent"); await t.test("chat", async () => { const agent = new OpenAIAgent({ @@ -119,9 +119,80 @@ test("agent", async (t) => { ok(typeof result.response === "string"); ok(result.response.includes("35")); }); + + await t.test("async function", async () => { + const uniqueId = "123456789"; + const showUniqueId = FunctionTool.from<{ + firstName: string; + lastName: string; + }>( + async ({ firstName, lastName }) => { + ok(typeof firstName === "string"); + ok(typeof lastName === "string"); + const fullName = firstName + lastName; + ok(fullName.toLowerCase().includes("alex")); + ok(fullName.toLowerCase().includes("yang")); + return uniqueId; + }, + { + name: "unique_id", + description: "show user unique id", + parameters: { + type: "object", + properties: { + firstName: { type: "string" }, + lastName: { type: "string" }, + }, + required: ["firstName", "lastName"], + }, + }, + ); + const agent = new OpenAIAgent({ + tools: [showUniqueId], + }); + const { response } = await agent.chat({ + message: "My name is Alex Yang. What is my unique id?", + }); + consola.debug("response:", response); + ok(response.includes(uniqueId)); + }); + + await t.test("sum numbers", async () => { + function sumNumbers({ a, b }: { a: number; b: number }): string { + return `${a + b}`; + } + const sumFunctionTool = new FunctionTool(sumNumbers, { + name: "sumNumbers", + description: "Use this function to sum two numbers", + parameters: { + type: "object", + properties: { + a: { + type: "number", + description: "The first number", + }, + b: { + type: "number", + description: "The second number", + }, + }, + required: ["a", "b"], + }, + }); + + const openaiAgent = new OpenAIAgent({ + tools: [sumFunctionTool], + }); + + const response = await openaiAgent.chat({ + message: "how much is 1 + 1?", + }); + + ok(response.response.includes("2")); + }); }); -test("queryEngine", async (t) => { +await test("queryEngine", async (t) => { await mockLLMEvent(t, "queryEngine_subquestion"); await t.test("subquestion", async () => { const document = new Document({ diff --git a/packages/core/e2e/node/snapshot/agent.snap b/packages/core/e2e/node/snapshot/agent.snap index 4571ee4965514b005f6eb5806d6da90d00e4cf93..7ef3898f043090025ab04be9f4f79a3b893b4dc1 100644 --- a/packages/core/e2e/node/snapshot/agent.snap +++ b/packages/core/e2e/node/snapshot/agent.snap @@ -41,6 +41,90 @@ } } ] + }, + { + "id": "HIDDEN", + "messages": [ + { + "content": "My name is Alex Yang. What is my unique id?", + "role": "user" + } + ] + }, + { + "id": "HIDDEN", + "messages": [ + { + "content": "My name is Alex Yang. What is my unique id?", + "role": "user" + }, + { + "content": "", + "role": "assistant", + "options": { + "toolCalls": [ + { + "id": "HIDDEN", + "type": "function", + "function": { + "name": "unique_id", + "arguments": "{\"firstName\":\"Alex\",\"lastName\":\"Yang\"}" + } + } + ] + } + }, + { + "content": "123456789", + "role": "tool", + "options": { + "name": "unique_id", + "tool_call_id": "HIDDEN" + } + } + ] + }, + { + "id": "HIDDEN", + "messages": [ + { + "content": "how much is 1 + 1?", + "role": "user" + } + ] + }, + { + "id": "HIDDEN", + "messages": [ + { + "content": "how much is 1 + 1?", + "role": "user" + }, + { + "content": "", + "role": "assistant", + "options": { + "toolCalls": [ + { + "id": "HIDDEN", + "type": "function", + "function": { + "name": "sumNumbers", + "arguments": "{\"a\":1,\"b\":1}" + } + } + ] + } + }, + { + "content": "2", + "role": "tool", + "options": { + "name": "sumNumbers", + "tool_call_id": "HIDDEN" + } + } + ] } ], "llmEventEnd": [ @@ -130,6 +214,180 @@ "options": {} } } + }, + { + "id": "HIDDEN", + "response": { + "raw": { + "id": "HIDDEN", + "object": "chat.completion", + "created": 114514, + "model": "gpt-3.5-turbo-0125", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "HIDDEN", + "type": "function", + "function": { + "name": "unique_id", + "arguments": "{\"firstName\":\"Alex\",\"lastName\":\"Yang\"}" + } + } + ] + }, + "logprobs": null, + "finish_reason": "tool_calls" + } + ], + "usage": { + "prompt_tokens": 59, + "completion_tokens": 18, + "total_tokens": 77 + }, + "system_fingerprint": "HIDDEN" + }, + "message": { + "content": "", + "role": "assistant", + "options": { + "toolCalls": [ + { + "id": "HIDDEN", + "type": "function", + "function": { + "name": "unique_id", + "arguments": "{\"firstName\":\"Alex\",\"lastName\":\"Yang\"}" + } + } + ] + } + } + } + }, + { + "id": "HIDDEN", + "response": { + "raw": { + "id": "HIDDEN", + "object": "chat.completion", + "created": 114514, + "model": "gpt-3.5-turbo-0125", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Your unique id is 123456789." + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 88, + "completion_tokens": 10, + "total_tokens": 98 + }, + "system_fingerprint": "HIDDEN" + }, + "message": { + "content": "Your unique id is 123456789.", + "role": "assistant", + "options": {} + } + } + }, + { + "id": "HIDDEN", + "response": { + "raw": { + "id": "HIDDEN", + "object": "chat.completion", + "created": 114514, + "model": "gpt-3.5-turbo-0125", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "HIDDEN", + "type": "function", + "function": { + "name": "sumNumbers", + "arguments": "{\"a\":1,\"b\":1}" + } + } + ] + }, + "logprobs": null, + "finish_reason": "tool_calls" + } + ], + "usage": { + "prompt_tokens": 70, + "completion_tokens": 18, + "total_tokens": 88 + }, + "system_fingerprint": "HIDDEN" + }, + "message": { + "content": "", + "role": "assistant", + "options": { + "toolCalls": [ + { + "id": "HIDDEN", + "type": "function", + "function": { + "name": "sumNumbers", + "arguments": "{\"a\":1,\"b\":1}" + } + } + ] + } + } + } + }, + { + "id": "HIDDEN", + "response": { + "raw": { + "id": "HIDDEN", + "object": "chat.completion", + "created": 114514, + "model": "gpt-3.5-turbo-0125", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "1 + 1 is equal to 2." + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 97, + "completion_tokens": 11, + "total_tokens": 108 + }, + "system_fingerprint": "HIDDEN" + }, + "message": { + "content": "1 + 1 is equal to 2.", + "role": "assistant", + "options": {} + } + } } ], "llmEventStream": [] diff --git a/packages/core/e2e/node/snapshot/gpt-4-turbo.snap b/packages/core/e2e/node/snapshot/gpt-4-turbo.snap index 960c40f77176da8f254f268db6fb126033088e4b..5a33bd8d1d191884a02cdf17b0502b29a0080b2c 100644 --- a/packages/core/e2e/node/snapshot/gpt-4-turbo.snap +++ b/packages/core/e2e/node/snapshot/gpt-4-turbo.snap @@ -111,7 +111,7 @@ "index": 0, "message": { "role": "assistant", - "content": "The weather in San Jose is currently 45 degrees and sunny." + "content": "The weather in San Jose is 45 degrees and sunny." }, "logprobs": null, "finish_reason": "stop" @@ -119,13 +119,13 @@ ], "usage": { "prompt_tokens": 78, - "completion_tokens": 14, - "total_tokens": 92 + "completion_tokens": 13, + "total_tokens": 91 }, "system_fingerprint": "HIDDEN" }, "message": { - "content": "The weather in San Jose is currently 45 degrees and sunny.", + "content": "The weather in San Jose is 45 degrees and sunny.", "role": "assistant", "options": {} } diff --git a/packages/core/package.json b/packages/core/package.json index ae358f0e0856d323e44beae47f92bab06913eca1..31b0fca854359285aff3622805a61fd05db0d3ce 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -21,6 +21,7 @@ "@types/pg": "^8.11.5", "@xenova/transformers": "^2.16.1", "@zilliz/milvus2-sdk-node": "^2.3.5", + "ajv": "^8.12.0", "assemblyai": "^4.3.4", "chromadb": "~1.7.3", "cohere-ai": "^7.9.3", diff --git a/packages/core/src/agent/openai/worker.ts b/packages/core/src/agent/openai/worker.ts index 8a229ffac264cdf3624ec17e532537ad94cf02ea..326ffea6cf88c9aafe0678fe8debd9960963e5ae 100644 --- a/packages/core/src/agent/openai/worker.ts +++ b/packages/core/src/agent/openai/worker.ts @@ -46,7 +46,7 @@ async function callFunction( // Call tool // Use default error message - const output = await callToolWithErrorHandling(tool, argumentDict, null); + const output = await callToolWithErrorHandling(tool, argumentDict); if (Settings.debug) { console.log(`Got output ${output}`); diff --git a/packages/core/src/agent/react/worker.ts b/packages/core/src/agent/react/worker.ts index 96af95787e470838a71dad7fc43928cc42ac5baa..7987d33c087d390dc4325296c0005ab3fc6e326d 100644 --- a/packages/core/src/agent/react/worker.ts +++ b/packages/core/src/agent/react/worker.ts @@ -194,7 +194,7 @@ export class ReActAgentWorker implements AgentWorker<ChatParams> { const tool = toolsDict[actionReasoningStep.action]; - const toolOutput = await tool?.call?.(actionReasoningStep.actionInput); + const toolOutput = await tool.call!(actionReasoningStep.actionInput); task.extraState.sources.push( new ToolOutput( diff --git a/packages/core/src/tools/QueryEngineTool.ts b/packages/core/src/tools/QueryEngineTool.ts index 60fa03755880a6d47fdc8829924b41b5f6e8933c..9af842c2957c33664c0c0e8e6d8f84d42a3c8087 100644 --- a/packages/core/src/tools/QueryEngineTool.ts +++ b/packages/core/src/tools/QueryEngineTool.ts @@ -1,18 +1,11 @@ +import type { JSONSchemaType } from "ajv"; import type { BaseQueryEngine, BaseTool, ToolMetadata } from "../types.js"; -export type QueryEngineToolParams = { - queryEngine: BaseQueryEngine; - metadata: ToolMetadata; -}; - -type QueryEngineCallParams = { - query: string; -}; - const DEFAULT_NAME = "query_engine_tool"; const DEFAULT_DESCRIPTION = "Useful for running a natural language query against a knowledge base and get back a natural language response."; -const DEFAULT_PARAMETERS = { + +const DEFAULT_PARAMETERS: JSONSchemaType<QueryEngineParam> = { type: "object", properties: { query: { @@ -23,9 +16,18 @@ const DEFAULT_PARAMETERS = { required: ["query"], }; -export class QueryEngineTool implements BaseTool { +export type QueryEngineToolParams = { + queryEngine: BaseQueryEngine; + metadata: ToolMetadata<JSONSchemaType<QueryEngineParam>>; +}; + +export type QueryEngineParam = { + query: string; +}; + +export class QueryEngineTool implements BaseTool<QueryEngineParam> { private queryEngine: BaseQueryEngine; - metadata: ToolMetadata; + metadata: ToolMetadata<JSONSchemaType<QueryEngineParam>>; constructor({ queryEngine, metadata }: QueryEngineToolParams) { this.queryEngine = queryEngine; @@ -36,18 +38,8 @@ export class QueryEngineTool implements BaseTool { }; } - async call(...args: QueryEngineCallParams[]): Promise<any> { - let queryStr: string; - - if (args && args.length > 0) { - queryStr = String(args[0].query); - } else { - throw new Error( - "Cannot call query engine without specifying `input` parameter.", - ); - } - - const response = await this.queryEngine.query({ query: queryStr }); + async call({ query }: QueryEngineParam) { + const response = await this.queryEngine.query({ query }); return response.response; } diff --git a/packages/core/src/tools/ToolFactory.ts b/packages/core/src/tools/ToolFactory.ts deleted file mode 100644 index d40b0d9500b3f802d39208f58e95737c23cdd0c0..0000000000000000000000000000000000000000 --- a/packages/core/src/tools/ToolFactory.ts +++ /dev/null @@ -1,33 +0,0 @@ -import type { BaseTool } from "../types.js"; -import { WikipediaTool } from "./WikipediaTool.js"; - -enum Tools { - Wikipedia = "wikipedia.WikipediaToolSpec", -} - -type ToolConfig = { [key in Tools]: Record<string, any> }; - -export class ToolFactory { - private static async createTool( - key: Tools, - options: Record<string, any>, - ): Promise<BaseTool> { - if (key === Tools.Wikipedia) { - const tool = new WikipediaTool(); - return tool; - } - - throw new Error( - `Sorry! Tool ${key} is not supported yet. Options: ${options}`, - ); - } - - public static async createTools(config: ToolConfig): Promise<BaseTool[]> { - const tools: BaseTool[] = []; - for (const [key, value] of Object.entries(config as ToolConfig)) { - const tool = await ToolFactory.createTool(key as Tools, value); - tools.push(tool); - } - return tools; - } -} diff --git a/packages/core/src/tools/WikipediaTool.ts b/packages/core/src/tools/WikipediaTool.ts index 2764f2726458168dd3db1dc5f4c57795c812f3ea..6a1ff6904ea25430e7548fea0937bef79bf1de08 100644 --- a/packages/core/src/tools/WikipediaTool.ts +++ b/packages/core/src/tools/WikipediaTool.ts @@ -1,16 +1,17 @@ +import type { JSONSchemaType } from "ajv"; import { default as wiki } from "wikipedia"; import type { BaseTool, ToolMetadata } from "../types.js"; -export type WikipediaToolParams = { - metadata?: ToolMetadata; -}; - -type WikipediaCallParams = { +type WikipediaParameter = { query: string; lang?: string; }; -const DEFAULT_META_DATA: ToolMetadata = { +export type WikipediaToolParams = { + metadata?: ToolMetadata<JSONSchemaType<WikipediaParameter>>; +}; + +const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<WikipediaParameter>> = { name: "wikipedia_tool", description: "A tool that uses a query engine to search Wikipedia.", parameters: { @@ -20,14 +21,19 @@ const DEFAULT_META_DATA: ToolMetadata = { type: "string", description: "The query to search for", }, + lang: { + type: "string", + description: "The language to search in", + nullable: true, + }, }, required: ["query"], }, }; -export class WikipediaTool implements BaseTool { +export class WikipediaTool implements BaseTool<WikipediaParameter> { private readonly DEFAULT_LANG = "en"; - metadata: ToolMetadata; + metadata: ToolMetadata<JSONSchemaType<WikipediaParameter>>; constructor(params?: WikipediaToolParams) { this.metadata = params?.metadata || DEFAULT_META_DATA; @@ -46,7 +52,7 @@ export class WikipediaTool implements BaseTool { async call({ query, lang = this.DEFAULT_LANG, - }: WikipediaCallParams): Promise<string> { + }: WikipediaParameter): Promise<string> { const searchResult = await wiki.default.search(query); if (searchResult.results.length === 0) return "No search results."; return await this.loadData(searchResult.results[0].title, lang); diff --git a/packages/core/src/tools/functionTool.ts b/packages/core/src/tools/functionTool.ts index d59b07589ec8bf8d20491250ae7b6159b4cb00cc..cbd8a28e3bfdff4c9f0853b1050a31b90c75d363 100644 --- a/packages/core/src/tools/functionTool.ts +++ b/packages/core/src/tools/functionTool.ts @@ -1,32 +1,30 @@ +import type { JSONSchemaType } from "ajv"; import type { BaseTool, ToolMetadata } from "../types.js"; -type Metadata = { - name: string; - description: string; - parameters: ToolMetadata["parameters"]; -}; +export class FunctionTool<T, R extends string | Promise<string>> + implements BaseTool<T> +{ + constructor( + private readonly _fn: (input: T) => R, + private readonly _metadata: ToolMetadata<JSONSchemaType<T>>, + ) {} -export class FunctionTool<T = any> implements BaseTool { - private _fn: (...args: any[]) => any; - private _metadata: ToolMetadata; - - constructor(fn: (...args: any[]) => any, metadata: Metadata) { - this._fn = fn; - this._metadata = metadata as ToolMetadata; - } - - static fromDefaults<T = any>( - fn: (...args: any[]) => any, - metadata?: Metadata, - ): FunctionTool<T> { - return new FunctionTool(fn, metadata!); + static from<T>( + fn: (input: T) => string | Promise<string>, + schema: ToolMetadata<JSONSchemaType<T>>, + ): FunctionTool<T, string | Promise<string>>; + static from<T, R extends string | Promise<string>>( + fn: (input: T) => R, + schema: ToolMetadata<JSONSchemaType<T>>, + ): FunctionTool<T, R> { + return new FunctionTool(fn, schema); } - get metadata(): ToolMetadata { - return this._metadata; + get metadata(): BaseTool<T>["metadata"] { + return this._metadata as BaseTool<T>["metadata"]; } - async call(...args: any[]): Promise<any> { - return this._fn(...args); + call(input: T) { + return this._fn(input); } } diff --git a/packages/core/src/tools/index.ts b/packages/core/src/tools/index.ts index c9d38eef9943bb3bdfe0fc45d2b977492ce2a8a4..2b52c32edf8fbc07dd31882a66ace92095400d61 100644 --- a/packages/core/src/tools/index.ts +++ b/packages/core/src/tools/index.ts @@ -1,5 +1,4 @@ export * from "./QueryEngineTool.js"; -export * from "./ToolFactory.js"; export * from "./WikipediaTool.js"; export * from "./functionTool.js"; export * from "./types.js"; diff --git a/packages/core/src/tools/utils.ts b/packages/core/src/tools/utils.ts index aaba6b01ac3a7963227be599b7b8f7a4cce8035d..3c19c20a4dd19102c4925dd1f2e73ee05c8caf4a 100644 --- a/packages/core/src/tools/utils.ts +++ b/packages/core/src/tools/utils.ts @@ -1,30 +1,24 @@ import type { BaseTool } from "../types.js"; import { ToolOutput } from "./types.js"; -/** - * Call tool with error handling. - * @param tool: tool - * @param inputDict: input dict - * @param errorMessage: error message - * @param raiseError: raise error - * @returns: tool output - */ export async function callToolWithErrorHandling( tool: BaseTool, inputDict: { [key: string]: any }, - errorMessage: string | null = null, - raiseError: boolean = false, ): Promise<ToolOutput> { + if (!tool.call) { + return new ToolOutput( + "Error: Tool does not have a call function.", + tool.metadata.name, + { kwargs: inputDict }, + null, + ); + } try { - const value = await tool.call?.(inputDict); + const value = await tool.call(inputDict); return new ToolOutput(value, tool.metadata.name, inputDict, value); } catch (e) { - if (raiseError) { - throw e; - } - errorMessage = errorMessage || `Error: ${e}`; return new ToolOutput( - errorMessage, + `Error: ${e}`, tool.metadata.name, { kwargs: inputDict }, e, diff --git a/packages/core/src/types.ts b/packages/core/src/types.ts index 6adc075448d14cb4a077752cbdeefb87f5fd218d..f1b0073e3ad32eeb73de3427b98d7fea47afd44b 100644 --- a/packages/core/src/types.ts +++ b/packages/core/src/types.ts @@ -1,6 +1,7 @@ /** * Top level types to avoid circular dependencies */ +import { type JSONSchemaType } from "ajv"; import type { Response } from "./Response.js"; /** @@ -30,14 +31,46 @@ export interface BaseQueryEngine { query(params: QueryEngineParamsNonStreaming): Promise<Response>; } +type Known = + | { [key: string]: Known } + | [Known, ...Known[]] + | Known[] + | number + | string + | boolean + | null; + +export type ToolMetadata< + Parameters extends Record<string, unknown> = Record<string, unknown>, +> = { + description: string; + name: string; + /** + * OpenAI uses JSON Schema to describe the parameters that a tool can take. + * @link https://json-schema.org/understanding-json-schema + */ + parameters?: Parameters; +}; + /** * Simple Tool interface. Likely to change. */ -export interface BaseTool { - call?: (...args: any[]) => any; - metadata: ToolMetadata; +export interface BaseTool<Input = any> { + /** + * This could be undefined if the implementation is not provided, + * which might be the case when communicating with a llm. + * + * @return string - the output of the tool, should be string in any case for LLM input. + */ + call?: (input: Input) => string | Promise<string>; + metadata: // if user input any, we cannot check the schema + Input extends Known ? ToolMetadata<JSONSchemaType<Input>> : ToolMetadata; } +export type ToolWithCall<Input = unknown> = Omit<BaseTool<Input>, "call"> & { + call: NonNullable<Pick<BaseTool<Input>, "call">["call"]>; +}; + /** * An OutputParser is used to extract structured data from the raw output of the LLM. */ @@ -55,19 +88,6 @@ export interface StructuredOutput<T> { parsedOutput: T; } -export type ToolParameters = { - type: string | "object"; - properties: Record<string, { type: string; description?: string }>; - required?: string[]; -}; - -export interface ToolMetadata { - description: string; - name: string; - parameters?: ToolParameters; - argsKwargs?: Record<string, any>; -} - export type ToolMetadataOnlyDescription = Pick<ToolMetadata, "description">; export class QueryBundle { diff --git a/packages/core/tests/agent/OpenAIAgent.test.ts b/packages/core/tests/agent/OpenAIAgent.test.ts deleted file mode 100644 index 0189e0158ff27a5e2fb3fbea586213a6dcadb9cc..0000000000000000000000000000000000000000 --- a/packages/core/tests/agent/OpenAIAgent.test.ts +++ /dev/null @@ -1,61 +0,0 @@ -import { Settings } from "llamaindex"; -import { OpenAIAgent } from "llamaindex/agent/index"; -import { OpenAI } from "llamaindex/llm/index"; -import { FunctionTool } from "llamaindex/tools/index"; -import { beforeEach, describe, expect, it } from "vitest"; -import { mockLlmToolCallGeneration } from "../utility/mockOpenAI.js"; - -// Define a function to sum two numbers -function sumNumbers({ a, b }: { a: number; b: number }): number { - return a + b; -} - -const sumJSON = { - type: "object", - properties: { - a: { - type: "number", - description: "The first number", - }, - b: { - type: "number", - description: "The second number", - }, - }, - required: ["a", "b"], -}; - -describe("OpenAIAgent", () => { - let openaiAgent: OpenAIAgent; - - beforeEach(() => { - const languageModel = new OpenAI({ - model: "gpt-3.5-turbo", - }); - - Settings.llm = languageModel; - - mockLlmToolCallGeneration({ - languageModel, - }); - - const sumFunctionTool = new FunctionTool(sumNumbers, { - name: "sumNumbers", - description: "Use this function to sum two numbers", - parameters: sumJSON, - }); - - openaiAgent = new OpenAIAgent({ - tools: [sumFunctionTool], - llm: languageModel, - }); - }); - - it("should be able to chat with agent", async () => { - const response = await openaiAgent.chat({ - message: "how much is 1 + 1?", - }); - - expect(String(response)).toEqual("The sum is 2"); - }); -}); diff --git a/packages/core/tests/objects/ObjectIndex.test.ts b/packages/core/tests/objects/ObjectIndex.test.ts index f71fd2ef2dc8e3084a84622ed24d03865ae7d2c9..a68f44c440fd13a4b5f5cbb8d1c0bac1a4cba21c 100644 --- a/packages/core/tests/objects/ObjectIndex.test.ts +++ b/packages/core/tests/objects/ObjectIndex.test.ts @@ -17,7 +17,7 @@ describe("ObjectIndex", () => { }); test("test_object_with_tools", async () => { - const tool1 = new FunctionTool((x: any) => x, { + const tool1 = new FunctionTool(({ x }: { x: string }) => x, { name: "test_tool", description: "test tool", parameters: { @@ -27,10 +27,11 @@ describe("ObjectIndex", () => { type: "string", }, }, + required: ["x"], }, }); - const tool2 = new FunctionTool((x: any) => x, { + const tool2 = new FunctionTool(({ x }: { x: string }) => x, { name: "test_tool_2", description: "test tool 2", parameters: { @@ -40,6 +41,7 @@ describe("ObjectIndex", () => { type: "string", }, }, + required: ["x"], }, }); @@ -62,7 +64,7 @@ describe("ObjectIndex", () => { }); test("add a new object", async () => { - const tool1 = new FunctionTool((x: any) => x, { + const tool1 = new FunctionTool(({ x }: { x: string }) => x, { name: "test_tool", description: "test tool", parameters: { @@ -72,10 +74,11 @@ describe("ObjectIndex", () => { type: "string", }, }, + required: ["x"], }, }); - const tool2 = new FunctionTool((x: any) => x, { + const tool2 = new FunctionTool(({ x }: { x: string }) => x, { name: "test_tool_2", description: "test tool 2", parameters: { @@ -85,6 +88,7 @@ describe("ObjectIndex", () => { type: "string", }, }, + required: ["x"], }, }); diff --git a/packages/core/tests/tools/Tools.test.ts b/packages/core/tests/tools/Tools.test.ts deleted file mode 100644 index e2edfbe610322d89ae32a48ab5126831e7748150..0000000000000000000000000000000000000000 --- a/packages/core/tests/tools/Tools.test.ts +++ /dev/null @@ -1,46 +0,0 @@ -import { FunctionTool, ToolOutput } from "llamaindex/tools/index"; -import { callToolWithErrorHandling } from "llamaindex/tools/utils"; -import { describe, expect, it } from "vitest"; - -function sumNumbers({ a, b }: { a: number; b: number }): number { - return a + b; -} - -const sumJSON = { - type: "object", - properties: { - a: { - type: "number", - description: "The first number", - }, - b: { - type: "number", - description: "The second number", - }, - }, - required: ["a", "b"], -}; - -describe("Tools", () => { - it("should be able to call a tool with a common JSON", async () => { - const tool = new FunctionTool(sumNumbers, { - name: "sumNumbers", - description: "Use this function to sum two numbers", - parameters: sumJSON, - }); - - const response = await callToolWithErrorHandling(tool, { - a: 1, - b: 2, - }); - - expect(response).toEqual( - new ToolOutput( - response.content, - tool.metadata.name, - { a: 1, b: 2 }, - response.content, - ), - ); - }); -}); diff --git a/packages/edge/package.json b/packages/edge/package.json index e0679df58e1e92ae4a2a0047eaae465e1abedf6b..c5331a4c3f8e4ea1cf072f89806dbfdc45a2a313 100644 --- a/packages/edge/package.json +++ b/packages/edge/package.json @@ -20,6 +20,7 @@ "@types/pg": "^8.11.5", "@xenova/transformers": "^2.16.1", "@zilliz/milvus2-sdk-node": "^2.3.5", + "ajv": "^8.12.0", "assemblyai": "^4.3.4", "chromadb": "~1.7.3", "cohere-ai": "^7.9.3", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 891ebda9c2573f041314a6362748dc3aed4377d4..8d46d13bdf15a208006f6f0dc48b3c25015e2fea 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -223,6 +223,9 @@ importers: '@zilliz/milvus2-sdk-node': specifier: ^2.3.5 version: 2.3.5 + ajv: + specifier: ^8.12.0 + version: 8.12.0 assemblyai: specifier: ^4.3.4 version: 4.3.4 @@ -383,6 +386,9 @@ importers: '@zilliz/milvus2-sdk-node': specifier: ^2.3.5 version: 2.3.5 + ajv: + specifier: ^8.12.0 + version: 8.12.0 assemblyai: specifier: ^4.3.4 version: 4.3.4