From a285f8ba3abfa70eec6fa26936d0d60981fe5d32 Mon Sep 17 00:00:00 2001 From: Alex Yang <himself65@outlook.com> Date: Thu, 11 Apr 2024 21:26:14 -0500 Subject: [PATCH] feat: improve `ToolsFactory` type (#713) --- packages/core/src/tools/ToolsFactory.ts | 43 +++++++++++++++++++++++++ packages/core/tests/tools.test.ts | 32 ++++++++++++++++++ 2 files changed, 75 insertions(+) create mode 100644 packages/core/src/tools/ToolsFactory.ts create mode 100644 packages/core/tests/tools.test.ts diff --git a/packages/core/src/tools/ToolsFactory.ts b/packages/core/src/tools/ToolsFactory.ts new file mode 100644 index 000000000..eb650463f --- /dev/null +++ b/packages/core/src/tools/ToolsFactory.ts @@ -0,0 +1,43 @@ +import { WikipediaTool } from "./WikipediaTool.js"; + +export namespace ToolsFactory { + type ToolsMap = { + [Tools.Wikipedia]: typeof WikipediaTool; + }; + + export enum Tools { + Wikipedia = "wikipedia.WikipediaToolSpec", + } + + export async function createTool<Tool extends Tools>( + key: Tool, + ...params: ConstructorParameters<ToolsMap[Tool]> + ): Promise<InstanceType<ToolsMap[Tool]>> { + if (key === Tools.Wikipedia) { + return new WikipediaTool(...params) as InstanceType<ToolsMap[Tool]>; + } + + throw new Error( + `Sorry! Tool ${key} is not supported yet. Options: ${params}`, + ); + } + + export async function createTools<const Tool extends Tools>(record: { + [key in Tool]: ConstructorParameters<ToolsMap[Tool]>[1] extends any // backward compatibility for `create-llama` script // if parameters are an array, use them as is + ? ConstructorParameters<ToolsMap[Tool]>[0] + : ConstructorParameters<ToolsMap[Tool]>; + }): Promise<InstanceType<ToolsMap[Tool]>[]> { + const tools: InstanceType<ToolsMap[Tool]>[] = []; + for (const key in record) { + const params = record[key]; + tools.push( + await createTool( + key, + // @ts-expect-error + Array.isArray(params) ? params : [params], + ), + ); + } + return tools; + } +} diff --git a/packages/core/tests/tools.test.ts b/packages/core/tests/tools.test.ts new file mode 100644 index 000000000..a22c14595 --- /dev/null +++ b/packages/core/tests/tools.test.ts @@ -0,0 +1,32 @@ +import { ToolsFactory } from "llamaindex/tools/ToolsFactory"; +import { WikipediaTool } from "llamaindex/tools/WikipediaTool"; +import { assertType, describe, test } from "vitest"; + +describe("ToolsFactory", async () => { + test("createTool", async () => { + await ToolsFactory.createTool(ToolsFactory.Tools.Wikipedia, { + metadata: { + name: "wikipedia_tool", + description: "A tool that uses a query engine to search Wikipedia.", + }, + }); + }); + test("createTools", async () => { + await ToolsFactory.createTools({ + [ToolsFactory.Tools.Wikipedia]: { + metadata: { + name: "wikipedia_tool", + description: "A tool that uses a query engine to search Wikipedia.", + }, + }, + }); + }); + test("type", () => { + assertType< + ( + key: ToolsFactory.Tools.Wikipedia, + params: ConstructorParameters<typeof WikipediaTool>[0], + ) => Promise<WikipediaTool> + >(ToolsFactory.createTool<ToolsFactory.Tools.Wikipedia>); + }); +}); -- GitLab