From 723b41c23c6a0535a1e091d3bc0e9fc61ea64dd0 Mon Sep 17 00:00:00 2001
From: Alex Yang <himself65@outlook.com>
Date: Fri, 18 Oct 2024 09:45:01 -0700
Subject: [PATCH] refactor: move tools into core module (#1316)

---
 packages/core/package.json                    | 33 +++++++---
 packages/core/src/tools/function-tool.ts      | 62 +++++++++++++++++++
 packages/core/src/tools/index.ts              |  1 +
 packages/core/tests/tools.test.ts             | 35 +++++++++++
 packages/core/tools/package.json              |  8 +++
 packages/core/vector-store/package.json       |  8 +++
 packages/llamaindex/src/tools/functionTool.ts | 31 ----------
 packages/llamaindex/src/tools/index.ts        |  2 +-
 pnpm-lock.yaml                                | 15 ++++-
 9 files changed, 154 insertions(+), 41 deletions(-)
 create mode 100644 packages/core/src/tools/function-tool.ts
 create mode 100644 packages/core/src/tools/index.ts
 create mode 100644 packages/core/tests/tools.test.ts
 create mode 100644 packages/core/tools/package.json
 create mode 100644 packages/core/vector-store/package.json
 delete mode 100644 packages/llamaindex/src/tools/functionTool.ts

diff --git a/packages/core/package.json b/packages/core/package.json
index cf9dd048d..13614afcd 100644
--- a/packages/core/package.json
+++ b/packages/core/package.json
@@ -258,16 +258,30 @@
     },
     "./vector-store": {
       "require": {
-        "types": "./dist/vector-store/index.d.cts",
-        "default": "./dist/vector-store/index.cjs"
+        "types": "./vector-store/dist/index.d.cts",
+        "default": "./vector-store/dist/index.cjs"
       },
       "import": {
-        "types": "./dist/vector-store/index.d.ts",
-        "default": "./dist/vector-store/index.js"
+        "types": "./vector-store/dist/index.d.ts",
+        "default": "./vector-store/dist/index.js"
       },
       "default": {
-        "types": "./dist/vector-store/index.d.ts",
-        "default": "./dist/vector-store/index.js"
+        "types": "./vector-store/dist/index.d.ts",
+        "default": "./vector-store/dist/index.js"
+      }
+    },
+    "./tools": {
+      "require": {
+        "types": "./tools/dist/index.d.cts",
+        "default": "./tools/dist/index.cjs"
+      },
+      "import": {
+        "types": "./tools/dist/index.d.ts",
+        "default": "./tools/dist/index.js"
+      },
+      "default": {
+        "types": "./tools/dist/index.d.ts",
+        "default": "./tools/dist/index.js"
       }
     }
   },
@@ -289,7 +303,9 @@
     "./storage",
     "./response-synthesizers",
     "./chat-engine",
-    "./retriever"
+    "./retriever",
+    "./vector-store",
+    "./tools"
   ],
   "scripts": {
     "dev": "bunchee --watch",
@@ -312,6 +328,7 @@
     "@llamaindex/env": "workspace:*",
     "@types/node": "^22.5.1",
     "magic-bytes.js": "^1.10.0",
-    "zod": "^3.23.8"
+    "zod": "^3.23.8",
+    "zod-to-json-schema": "^3.23.3"
   }
 }
diff --git a/packages/core/src/tools/function-tool.ts b/packages/core/src/tools/function-tool.ts
new file mode 100644
index 000000000..25fbd452c
--- /dev/null
+++ b/packages/core/src/tools/function-tool.ts
@@ -0,0 +1,62 @@
+import type { JSONSchemaType } from "ajv";
+import { z } from "zod";
+import { zodToJsonSchema } from "zod-to-json-schema";
+import type { JSONValue } from "../global";
+import type { BaseTool, ToolMetadata } from "../llms";
+
+const kOriginalFn = Symbol("originalFn");
+
+export class FunctionTool<T, R extends JSONValue | Promise<JSONValue>>
+  implements BaseTool<T>
+{
+  [kOriginalFn]?: (input: T) => R;
+
+  #fn: (input: T) => R;
+  #metadata: ToolMetadata<JSONSchemaType<T>>;
+  // todo: for the future, we can use zod to validate the input parameters
+  #zodType: z.ZodType<T> | null = null;
+  constructor(
+    fn: (input: T) => R,
+    metadata: ToolMetadata<JSONSchemaType<T>>,
+    zodType?: z.ZodType<T>,
+  ) {
+    this.#fn = fn;
+    this.#metadata = metadata;
+    if (zodType) {
+      this.#zodType = zodType;
+    }
+  }
+
+  static from<T>(
+    fn: (input: T) => JSONValue | Promise<JSONValue>,
+    schema: ToolMetadata<JSONSchemaType<T>>,
+  ): FunctionTool<T, JSONValue | Promise<JSONValue>>;
+  static from<T, R extends z.ZodType<T>>(
+    fn: (input: T) => JSONValue | Promise<JSONValue>,
+    schema: Omit<ToolMetadata, "parameters"> & {
+      parameters: R;
+    },
+  ): FunctionTool<T, JSONValue>;
+  static from(fn: any, schema: any): any {
+    if (schema.parameter instanceof z.ZodSchema) {
+      const jsonSchema = zodToJsonSchema(schema.parameter);
+      return new FunctionTool(
+        fn,
+        {
+          ...schema,
+          parameters: jsonSchema,
+        },
+        schema.parameter,
+      );
+    }
+    return new FunctionTool(fn, schema);
+  }
+
+  get metadata(): BaseTool<T>["metadata"] {
+    return this.#metadata as BaseTool<T>["metadata"];
+  }
+
+  call(input: T) {
+    return this.#fn.call(null, input);
+  }
+}
diff --git a/packages/core/src/tools/index.ts b/packages/core/src/tools/index.ts
new file mode 100644
index 000000000..dd82f38d1
--- /dev/null
+++ b/packages/core/src/tools/index.ts
@@ -0,0 +1 @@
+export { FunctionTool } from "./function-tool";
diff --git a/packages/core/tests/tools.test.ts b/packages/core/tests/tools.test.ts
new file mode 100644
index 000000000..009b1ca07
--- /dev/null
+++ b/packages/core/tests/tools.test.ts
@@ -0,0 +1,35 @@
+import { FunctionTool } from "@llamaindex/core/tools";
+import { describe, test } from "vitest";
+import { z } from "zod";
+
+describe("FunctionTool", () => {
+  test("type system", () => {
+    FunctionTool.from((input: string) => input, {
+      name: "test",
+      description: "test",
+    });
+    FunctionTool.from(({ input }: { input: string }) => input, {
+      name: "test",
+      description: "test",
+      parameters: {
+        type: "object",
+        properties: {
+          input: {
+            type: "string",
+          },
+        },
+        required: ["input"],
+      },
+    });
+    const inputSchema = z
+      .object({
+        input: z.string(),
+      })
+      .required();
+    FunctionTool.from(({ input }: { input: string }) => input, {
+      name: "test",
+      description: "test",
+      parameters: inputSchema,
+    });
+  });
+});
diff --git a/packages/core/tools/package.json b/packages/core/tools/package.json
new file mode 100644
index 000000000..2fdf125e2
--- /dev/null
+++ b/packages/core/tools/package.json
@@ -0,0 +1,8 @@
+{
+  "type": "module",
+  "main": "./dist/index.cjs",
+  "module": "./dist/index.js",
+  "types": "./dist/index.d.ts",
+  "exports": "./dist/index.js",
+  "private": true
+}
diff --git a/packages/core/vector-store/package.json b/packages/core/vector-store/package.json
new file mode 100644
index 000000000..2fdf125e2
--- /dev/null
+++ b/packages/core/vector-store/package.json
@@ -0,0 +1,8 @@
+{
+  "type": "module",
+  "main": "./dist/index.cjs",
+  "module": "./dist/index.js",
+  "types": "./dist/index.d.ts",
+  "exports": "./dist/index.js",
+  "private": true
+}
diff --git a/packages/llamaindex/src/tools/functionTool.ts b/packages/llamaindex/src/tools/functionTool.ts
deleted file mode 100644
index 332abe608..000000000
--- a/packages/llamaindex/src/tools/functionTool.ts
+++ /dev/null
@@ -1,31 +0,0 @@
-import type { JSONValue } from "@llamaindex/core/global";
-import type { BaseTool, ToolMetadata } from "@llamaindex/core/llms";
-import type { JSONSchemaType } from "ajv";
-
-export class FunctionTool<T, R extends JSONValue | Promise<JSONValue>>
-  implements BaseTool<T>
-{
-  constructor(
-    private readonly _fn: (input: T) => R,
-    private readonly _metadata: ToolMetadata<JSONSchemaType<T>>,
-  ) {}
-
-  static from<T>(
-    fn: (input: T) => JSONValue | Promise<JSONValue>,
-    schema: ToolMetadata<JSONSchemaType<T>>,
-  ): FunctionTool<T, JSONValue | Promise<JSONValue>>;
-  static from<T, R extends JSONValue | Promise<JSONValue>>(
-    fn: (input: T) => R,
-    schema: ToolMetadata<JSONSchemaType<T>>,
-  ): FunctionTool<T, R> {
-    return new FunctionTool(fn, schema);
-  }
-
-  get metadata(): BaseTool<T>["metadata"] {
-    return this._metadata as BaseTool<T>["metadata"];
-  }
-
-  call(input: T) {
-    return this._fn(input);
-  }
-}
diff --git a/packages/llamaindex/src/tools/index.ts b/packages/llamaindex/src/tools/index.ts
index 82fa36ab2..31a1c2893 100644
--- a/packages/llamaindex/src/tools/index.ts
+++ b/packages/llamaindex/src/tools/index.ts
@@ -1,3 +1,3 @@
-export * from "./functionTool.js";
+export * from "@llamaindex/core/tools";
 export * from "./QueryEngineTool.js";
 export * from "./WikipediaTool.js";
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index f3f66ce67..b15748d04 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -401,6 +401,9 @@ importers:
       zod:
         specifier: ^3.23.8
         version: 3.23.8
+      zod-to-json-schema:
+        specifier: ^3.23.3
+        version: 3.23.3(zod@3.23.8)
     devDependencies:
       '@edge-runtime/vm':
         specifier: ^4.0.3
@@ -6564,6 +6567,7 @@ packages:
   eslint@8.57.0:
     resolution: {integrity: sha512-dZ6+mexnaTIbSBZWgou51U6OmzIhYM2VcNdtiTtI7qPNZm35Akpr0f6vtw3w1Kmn5PYo+tZVfh13WrhpS6oLqQ==}
     engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0}
+    deprecated: This version is no longer supported. Please see https://eslint.org/version-support for other options.
     hasBin: true
 
   eslint@9.10.0:
@@ -11623,6 +11627,11 @@ packages:
     peerDependencies:
       zod: ^3.23.3
 
+  zod-to-json-schema@3.23.3:
+    resolution: {integrity: sha512-TYWChTxKQbRJp5ST22o/Irt9KC5nj7CdBKYB/AosCRdj/wxEMvv4NNaj9XVUHDOIp53ZxArGhnw5HMZziPFjog==}
+    peerDependencies:
+      zod: ^3.23.3
+
   zod@3.23.8:
     resolution: {integrity: sha512-XBx9AXhXktjUqnepgTiE5flcKIYWi/rme0Eaj+5Y0lftuGBq+jyRu/md4WnuxqgP1ubdpNCsYEYPxrzVHD8d6g==}
 
@@ -19572,7 +19581,7 @@ snapshots:
   eslint-import-resolver-node@0.3.9:
     dependencies:
       debug: 3.2.7
-      is-core-module: 2.15.1
+      is-core-module: '@nolyfill/is-core-module@1.0.39'
       resolve: 1.22.8
     transitivePeerDependencies:
       - supports-color
@@ -25943,6 +25952,10 @@ snapshots:
     dependencies:
       zod: 3.23.8
 
+  zod-to-json-schema@3.23.3(zod@3.23.8):
+    dependencies:
+      zod: 3.23.8
+
   zod@3.23.8: {}
 
   zwitch@2.0.4: {}
-- 
GitLab