From a285f8ba3abfa70eec6fa26936d0d60981fe5d32 Mon Sep 17 00:00:00 2001
From: Alex Yang <himself65@outlook.com>
Date: Thu, 11 Apr 2024 21:26:14 -0500
Subject: [PATCH] feat: improve `ToolsFactory` type (#713)

---
 packages/core/src/tools/ToolsFactory.ts | 43 +++++++++++++++++++++++++
 packages/core/tests/tools.test.ts       | 32 ++++++++++++++++++
 2 files changed, 75 insertions(+)
 create mode 100644 packages/core/src/tools/ToolsFactory.ts
 create mode 100644 packages/core/tests/tools.test.ts

diff --git a/packages/core/src/tools/ToolsFactory.ts b/packages/core/src/tools/ToolsFactory.ts
new file mode 100644
index 000000000..eb650463f
--- /dev/null
+++ b/packages/core/src/tools/ToolsFactory.ts
@@ -0,0 +1,43 @@
+import { WikipediaTool } from "./WikipediaTool.js";
+
+export namespace ToolsFactory {
+  type ToolsMap = {
+    [Tools.Wikipedia]: typeof WikipediaTool;
+  };
+
+  export enum Tools {
+    Wikipedia = "wikipedia.WikipediaToolSpec",
+  }
+
+  export async function createTool<Tool extends Tools>(
+    key: Tool,
+    ...params: ConstructorParameters<ToolsMap[Tool]>
+  ): Promise<InstanceType<ToolsMap[Tool]>> {
+    if (key === Tools.Wikipedia) {
+      return new WikipediaTool(...params) as InstanceType<ToolsMap[Tool]>;
+    }
+
+    throw new Error(
+      `Sorry! Tool ${key} is not supported yet. Options: ${params}`,
+    );
+  }
+
+  export async function createTools<const Tool extends Tools>(record: {
+    [key in Tool]: ConstructorParameters<ToolsMap[Tool]>[1] extends any // backward compatibility for `create-llama` script // if parameters are an array, use them as is
+      ? ConstructorParameters<ToolsMap[Tool]>[0]
+      : ConstructorParameters<ToolsMap[Tool]>;
+  }): Promise<InstanceType<ToolsMap[Tool]>[]> {
+    const tools: InstanceType<ToolsMap[Tool]>[] = [];
+    for (const key in record) {
+      const params = record[key];
+      tools.push(
+        await createTool(
+          key,
+          // @ts-expect-error
+          Array.isArray(params) ? params : [params],
+        ),
+      );
+    }
+    return tools;
+  }
+}
diff --git a/packages/core/tests/tools.test.ts b/packages/core/tests/tools.test.ts
new file mode 100644
index 000000000..a22c14595
--- /dev/null
+++ b/packages/core/tests/tools.test.ts
@@ -0,0 +1,32 @@
+import { ToolsFactory } from "llamaindex/tools/ToolsFactory";
+import { WikipediaTool } from "llamaindex/tools/WikipediaTool";
+import { assertType, describe, test } from "vitest";
+
+describe("ToolsFactory", async () => {
+  test("createTool", async () => {
+    await ToolsFactory.createTool(ToolsFactory.Tools.Wikipedia, {
+      metadata: {
+        name: "wikipedia_tool",
+        description: "A tool that uses a query engine to search Wikipedia.",
+      },
+    });
+  });
+  test("createTools", async () => {
+    await ToolsFactory.createTools({
+      [ToolsFactory.Tools.Wikipedia]: {
+        metadata: {
+          name: "wikipedia_tool",
+          description: "A tool that uses a query engine to search Wikipedia.",
+        },
+      },
+    });
+  });
+  test("type", () => {
+    assertType<
+      (
+        key: ToolsFactory.Tools.Wikipedia,
+        params: ConstructorParameters<typeof WikipediaTool>[0],
+      ) => Promise<WikipediaTool>
+    >(ToolsFactory.createTool<ToolsFactory.Tools.Wikipedia>);
+  });
+});
-- 
GitLab