From 723b41c23c6a0535a1e091d3bc0e9fc61ea64dd0 Mon Sep 17 00:00:00 2001 From: Alex Yang <himself65@outlook.com> Date: Fri, 18 Oct 2024 09:45:01 -0700 Subject: [PATCH] refactor: move tools into core module (#1316) --- packages/core/package.json | 33 +++++++--- packages/core/src/tools/function-tool.ts | 62 +++++++++++++++++++ packages/core/src/tools/index.ts | 1 + packages/core/tests/tools.test.ts | 35 +++++++++++ packages/core/tools/package.json | 8 +++ packages/core/vector-store/package.json | 8 +++ packages/llamaindex/src/tools/functionTool.ts | 31 ---------- packages/llamaindex/src/tools/index.ts | 2 +- pnpm-lock.yaml | 15 ++++- 9 files changed, 154 insertions(+), 41 deletions(-) create mode 100644 packages/core/src/tools/function-tool.ts create mode 100644 packages/core/src/tools/index.ts create mode 100644 packages/core/tests/tools.test.ts create mode 100644 packages/core/tools/package.json create mode 100644 packages/core/vector-store/package.json delete mode 100644 packages/llamaindex/src/tools/functionTool.ts diff --git a/packages/core/package.json b/packages/core/package.json index cf9dd048d..13614afcd 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -258,16 +258,30 @@ }, "./vector-store": { "require": { - "types": "./dist/vector-store/index.d.cts", - "default": "./dist/vector-store/index.cjs" + "types": "./vector-store/dist/index.d.cts", + "default": "./vector-store/dist/index.cjs" }, "import": { - "types": "./dist/vector-store/index.d.ts", - "default": "./dist/vector-store/index.js" + "types": "./vector-store/dist/index.d.ts", + "default": "./vector-store/dist/index.js" }, "default": { - "types": "./dist/vector-store/index.d.ts", - "default": "./dist/vector-store/index.js" + "types": "./vector-store/dist/index.d.ts", + "default": "./vector-store/dist/index.js" + } + }, + "./tools": { + "require": { + "types": "./tools/dist/index.d.cts", + "default": "./tools/dist/index.cjs" + }, + "import": { + "types": "./tools/dist/index.d.ts", + "default": "./tools/dist/index.js" + }, + "default": { + "types": "./tools/dist/index.d.ts", + "default": "./tools/dist/index.js" } } }, @@ -289,7 +303,9 @@ "./storage", "./response-synthesizers", "./chat-engine", - "./retriever" + "./retriever", + "./vector-store", + "./tools" ], "scripts": { "dev": "bunchee --watch", @@ -312,6 +328,7 @@ "@llamaindex/env": "workspace:*", "@types/node": "^22.5.1", "magic-bytes.js": "^1.10.0", - "zod": "^3.23.8" + "zod": "^3.23.8", + "zod-to-json-schema": "^3.23.3" } } diff --git a/packages/core/src/tools/function-tool.ts b/packages/core/src/tools/function-tool.ts new file mode 100644 index 000000000..25fbd452c --- /dev/null +++ b/packages/core/src/tools/function-tool.ts @@ -0,0 +1,62 @@ +import type { JSONSchemaType } from "ajv"; +import { z } from "zod"; +import { zodToJsonSchema } from "zod-to-json-schema"; +import type { JSONValue } from "../global"; +import type { BaseTool, ToolMetadata } from "../llms"; + +const kOriginalFn = Symbol("originalFn"); + +export class FunctionTool<T, R extends JSONValue | Promise<JSONValue>> + implements BaseTool<T> +{ + [kOriginalFn]?: (input: T) => R; + + #fn: (input: T) => R; + #metadata: ToolMetadata<JSONSchemaType<T>>; + // todo: for the future, we can use zod to validate the input parameters + #zodType: z.ZodType<T> | null = null; + constructor( + fn: (input: T) => R, + metadata: ToolMetadata<JSONSchemaType<T>>, + zodType?: z.ZodType<T>, + ) { + this.#fn = fn; + this.#metadata = metadata; + if (zodType) { + this.#zodType = zodType; + } + } + + static from<T>( + fn: (input: T) => JSONValue | Promise<JSONValue>, + schema: ToolMetadata<JSONSchemaType<T>>, + ): FunctionTool<T, JSONValue | Promise<JSONValue>>; + static from<T, R extends z.ZodType<T>>( + fn: (input: T) => JSONValue | Promise<JSONValue>, + schema: Omit<ToolMetadata, "parameters"> & { + parameters: R; + }, + ): FunctionTool<T, JSONValue>; + static from(fn: any, schema: any): any { + if (schema.parameter instanceof z.ZodSchema) { + const jsonSchema = zodToJsonSchema(schema.parameter); + return new FunctionTool( + fn, + { + ...schema, + parameters: jsonSchema, + }, + schema.parameter, + ); + } + return new FunctionTool(fn, schema); + } + + get metadata(): BaseTool<T>["metadata"] { + return this.#metadata as BaseTool<T>["metadata"]; + } + + call(input: T) { + return this.#fn.call(null, input); + } +} diff --git a/packages/core/src/tools/index.ts b/packages/core/src/tools/index.ts new file mode 100644 index 000000000..dd82f38d1 --- /dev/null +++ b/packages/core/src/tools/index.ts @@ -0,0 +1 @@ +export { FunctionTool } from "./function-tool"; diff --git a/packages/core/tests/tools.test.ts b/packages/core/tests/tools.test.ts new file mode 100644 index 000000000..009b1ca07 --- /dev/null +++ b/packages/core/tests/tools.test.ts @@ -0,0 +1,35 @@ +import { FunctionTool } from "@llamaindex/core/tools"; +import { describe, test } from "vitest"; +import { z } from "zod"; + +describe("FunctionTool", () => { + test("type system", () => { + FunctionTool.from((input: string) => input, { + name: "test", + description: "test", + }); + FunctionTool.from(({ input }: { input: string }) => input, { + name: "test", + description: "test", + parameters: { + type: "object", + properties: { + input: { + type: "string", + }, + }, + required: ["input"], + }, + }); + const inputSchema = z + .object({ + input: z.string(), + }) + .required(); + FunctionTool.from(({ input }: { input: string }) => input, { + name: "test", + description: "test", + parameters: inputSchema, + }); + }); +}); diff --git a/packages/core/tools/package.json b/packages/core/tools/package.json new file mode 100644 index 000000000..2fdf125e2 --- /dev/null +++ b/packages/core/tools/package.json @@ -0,0 +1,8 @@ +{ + "type": "module", + "main": "./dist/index.cjs", + "module": "./dist/index.js", + "types": "./dist/index.d.ts", + "exports": "./dist/index.js", + "private": true +} diff --git a/packages/core/vector-store/package.json b/packages/core/vector-store/package.json new file mode 100644 index 000000000..2fdf125e2 --- /dev/null +++ b/packages/core/vector-store/package.json @@ -0,0 +1,8 @@ +{ + "type": "module", + "main": "./dist/index.cjs", + "module": "./dist/index.js", + "types": "./dist/index.d.ts", + "exports": "./dist/index.js", + "private": true +} diff --git a/packages/llamaindex/src/tools/functionTool.ts b/packages/llamaindex/src/tools/functionTool.ts deleted file mode 100644 index 332abe608..000000000 --- a/packages/llamaindex/src/tools/functionTool.ts +++ /dev/null @@ -1,31 +0,0 @@ -import type { JSONValue } from "@llamaindex/core/global"; -import type { BaseTool, ToolMetadata } from "@llamaindex/core/llms"; -import type { JSONSchemaType } from "ajv"; - -export class FunctionTool<T, R extends JSONValue | Promise<JSONValue>> - implements BaseTool<T> -{ - constructor( - private readonly _fn: (input: T) => R, - private readonly _metadata: ToolMetadata<JSONSchemaType<T>>, - ) {} - - static from<T>( - fn: (input: T) => JSONValue | Promise<JSONValue>, - schema: ToolMetadata<JSONSchemaType<T>>, - ): FunctionTool<T, JSONValue | Promise<JSONValue>>; - static from<T, R extends JSONValue | Promise<JSONValue>>( - fn: (input: T) => R, - schema: ToolMetadata<JSONSchemaType<T>>, - ): FunctionTool<T, R> { - return new FunctionTool(fn, schema); - } - - get metadata(): BaseTool<T>["metadata"] { - return this._metadata as BaseTool<T>["metadata"]; - } - - call(input: T) { - return this._fn(input); - } -} diff --git a/packages/llamaindex/src/tools/index.ts b/packages/llamaindex/src/tools/index.ts index 82fa36ab2..31a1c2893 100644 --- a/packages/llamaindex/src/tools/index.ts +++ b/packages/llamaindex/src/tools/index.ts @@ -1,3 +1,3 @@ -export * from "./functionTool.js"; +export * from "@llamaindex/core/tools"; export * from "./QueryEngineTool.js"; export * from "./WikipediaTool.js"; diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index f3f66ce67..b15748d04 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -401,6 +401,9 @@ importers: zod: specifier: ^3.23.8 version: 3.23.8 + zod-to-json-schema: + specifier: ^3.23.3 + version: 3.23.3(zod@3.23.8) devDependencies: '@edge-runtime/vm': specifier: ^4.0.3 @@ -6564,6 +6567,7 @@ packages: eslint@8.57.0: resolution: {integrity: sha512-dZ6+mexnaTIbSBZWgou51U6OmzIhYM2VcNdtiTtI7qPNZm35Akpr0f6vtw3w1Kmn5PYo+tZVfh13WrhpS6oLqQ==} engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0} + deprecated: This version is no longer supported. Please see https://eslint.org/version-support for other options. hasBin: true eslint@9.10.0: @@ -11623,6 +11627,11 @@ packages: peerDependencies: zod: ^3.23.3 + zod-to-json-schema@3.23.3: + resolution: {integrity: sha512-TYWChTxKQbRJp5ST22o/Irt9KC5nj7CdBKYB/AosCRdj/wxEMvv4NNaj9XVUHDOIp53ZxArGhnw5HMZziPFjog==} + peerDependencies: + zod: ^3.23.3 + zod@3.23.8: resolution: {integrity: sha512-XBx9AXhXktjUqnepgTiE5flcKIYWi/rme0Eaj+5Y0lftuGBq+jyRu/md4WnuxqgP1ubdpNCsYEYPxrzVHD8d6g==} @@ -19572,7 +19581,7 @@ snapshots: eslint-import-resolver-node@0.3.9: dependencies: debug: 3.2.7 - is-core-module: 2.15.1 + is-core-module: '@nolyfill/is-core-module@1.0.39' resolve: 1.22.8 transitivePeerDependencies: - supports-color @@ -25943,6 +25952,10 @@ snapshots: dependencies: zod: 3.23.8 + zod-to-json-schema@3.23.3(zod@3.23.8): + dependencies: + zod: 3.23.8 + zod@3.23.8: {} zwitch@2.0.4: {} -- GitLab