Skip to content
Snippets Groups Projects
Unverified Commit 968feb32 authored by Alex Yang's avatar Alex Yang Committed by GitHub
Browse files

feat: better input type for function tool with `zod` (#1464)

parent 43f6f56c
No related branches found
No related tags found
No related merge requests found
......@@ -167,6 +167,7 @@ For questions about more specific sections, please use the vector_tool.`,
const mockCall = t.mock.fn(({ query }: { query: string }) => {
return originalCall({ query });
});
// @ts-expect-error what?
queryEngineTools[1]!.call = mockCall;
const toolMapping = SimpleToolNodeMapping.fromObjects(queryEngineTools);
......
......@@ -4,18 +4,12 @@ import { zodToJsonSchema } from "zod-to-json-schema";
import type { JSONValue } from "../global";
import type { BaseTool, ToolMetadata } from "../llms";
const kOriginalFn = Symbol("originalFn");
export class FunctionTool<T, R extends JSONValue | Promise<JSONValue>>
implements BaseTool<T>
{
[kOriginalFn]?: (input: T) => R;
#fn: (input: T) => R;
#metadata: ToolMetadata<JSONSchemaType<T>>;
// todo: for the future, we can use zod to validate the input parameters
// eslint-disable-next-line no-unused-private-class-members
#zodType: z.ZodType<T> | null = null;
readonly #metadata: ToolMetadata<JSONSchemaType<T>>;
readonly #zodType: z.ZodType<T> | null = null;
constructor(
fn: (input: T) => R,
metadata: ToolMetadata<JSONSchemaType<T>>,
......@@ -32,6 +26,12 @@ export class FunctionTool<T, R extends JSONValue | Promise<JSONValue>>
fn: (input: T) => JSONValue | Promise<JSONValue>,
schema: ToolMetadata<JSONSchemaType<T>>,
): FunctionTool<T, JSONValue | Promise<JSONValue>>;
static from<R extends z.ZodType>(
fn: (input: z.infer<R>) => JSONValue | Promise<JSONValue>,
schema: Omit<ToolMetadata, "parameters"> & {
parameters: R;
},
): FunctionTool<z.infer<R>, JSONValue | Promise<JSONValue>>;
static from<T, R extends z.ZodType<T>>(
fn: (input: T) => JSONValue | Promise<JSONValue>,
schema: Omit<ToolMetadata, "parameters"> & {
......@@ -40,15 +40,15 @@ export class FunctionTool<T, R extends JSONValue | Promise<JSONValue>>
): FunctionTool<T, JSONValue>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
static from(fn: any, schema: any): any {
if (schema.parameter instanceof z.ZodSchema) {
const jsonSchema = zodToJsonSchema(schema.parameter);
if (schema.parameters instanceof z.ZodSchema) {
const jsonSchema = zodToJsonSchema(schema.parameters);
return new FunctionTool(
fn,
{
...schema,
parameters: jsonSchema,
},
schema.parameter,
schema.parameters,
);
}
return new FunctionTool(fn, schema);
......@@ -58,7 +58,15 @@ export class FunctionTool<T, R extends JSONValue | Promise<JSONValue>>
return this.#metadata as BaseTool<T>["metadata"];
}
call(input: T) {
call = (input: T) => {
if (this.#zodType) {
const result = this.#zodType.safeParse(input);
if (result.success) {
return this.#fn.call(null, result.data);
} else {
console.warn(result.error.errors);
}
}
return this.#fn.call(null, input);
}
};
}
This diff is collapsed.
......@@ -26,6 +26,7 @@
"react-dom": "^18.3.1",
"tree-sitter": "^0.22.0",
"tree-sitter-javascript": "^0.23.0",
"tree-sitter-typescript": "^0.23.0"
"tree-sitter-typescript": "^0.23.0",
"zod": "^3.23.8"
}
}
import { FunctionTool } from "@llamaindex/core/tools";
import { describe, expect, test } from "vitest";
import { z } from "zod";
describe("function-tool", () => {
test("zod type check", () => {
const inputSchema = z.object({
name: z.string(),
age: z.coerce.number(),
});
const tool = FunctionTool.from(
(input) => {
if (typeof input.age !== "number") {
throw new Error("Age should be a number");
}
return "Hello " + input.name + " " + input.age;
},
{
name: "get-user",
description: "Get user by name and age",
parameters: inputSchema,
},
);
{
const response = tool.call({ name: "John", age: 30 });
expect(response).toBe("Hello John 30");
}
{
// @ts-expect-error age should be a number
const response = tool.call({ name: "John", age: "30" });
expect(response).toBe("Hello John 30");
}
});
});
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment