From 48c19c6e6231bc1c79aa93b238f2faf7ea5346cd Mon Sep 17 00:00:00 2001 From: Marcus Schiesser <mail@marcusschiesser.de> Date: Wed, 12 Jun 2024 15:28:59 +0200 Subject: [PATCH] fix: impove OpenAPI tool for TS --- .../engines/typescript/agent/tools/index.ts | 29 ++- .../typescript/agent/tools/openapi-action.ts | 177 ++++++++---------- 2 files changed, 92 insertions(+), 114 deletions(-) diff --git a/templates/components/engines/typescript/agent/tools/index.ts b/templates/components/engines/typescript/agent/tools/index.ts index cacdb60e..7f823eca 100644 --- a/templates/components/engines/typescript/agent/tools/index.ts +++ b/templates/components/engines/typescript/agent/tools/index.ts @@ -1,54 +1,53 @@ import { BaseToolWithCall } from "llamaindex"; import { ToolsFactory } from "llamaindex/tools/ToolsFactory"; import { InterpreterTool, InterpreterToolParams } from "./interpreter"; -import { OpenAPIActionToolSpec } from "./openapi-action"; +import { OpenAPIActionTool } from "./openapi-action"; import { WeatherTool, WeatherToolParams } from "./weather"; -type ToolCreator = (config: unknown) => BaseToolWithCall[]; +type ToolCreator = (config: unknown) => Promise<BaseToolWithCall[]>; export async function createTools(toolConfig: { local: Record<string, unknown>; llamahub: any; }): Promise<BaseToolWithCall[]> { // add local tools from the 'tools' folder (if configured) - const tools = createLocalTools(toolConfig.local); + const tools = await createLocalTools(toolConfig.local); // add tools from LlamaIndexTS (if configured) tools.push(...(await ToolsFactory.createTools(toolConfig.llamahub))); return tools; } const toolFactory: Record<string, ToolCreator> = { - weather: (config: unknown) => { + weather: async (config: unknown) => { return [new WeatherTool(config as WeatherToolParams)]; }, - interpreter: (config: unknown) => { + interpreter: async (config: unknown) => { return [new InterpreterTool(config as InterpreterToolParams)]; }, - "openapi_action.OpenAPIActionToolSpec": (config: unknown) => { + "openapi_action.OpenAPIActionToolSpec": async (config: unknown) => { const { openapi_uri, domain_headers } = config as { openapi_uri: string; domain_headers: Record<string, Record<string, string>>; }; - const openAPIActionTool = new OpenAPIActionToolSpec( + const openAPIActionTool = new OpenAPIActionTool( openapi_uri, domain_headers, ); - return openAPIActionTool.toToolFunctions(); + return await openAPIActionTool.toToolFunctions(); }, }; -function createLocalTools( +async function createLocalTools( localConfig: Record<string, unknown>, -): BaseToolWithCall[] { +): Promise<BaseToolWithCall[]> { const tools: BaseToolWithCall[] = []; - Object.keys(localConfig).forEach((key) => { + for (const [key, toolConfig] of Object.entries(localConfig)) { if (key in toolFactory) { - const toolConfig = localConfig[key]; - const tool = toolFactory[key](toolConfig); - tools.push(...tool); + const newTools = await toolFactory[key](toolConfig); + tools.push(...newTools); } - }); + } return tools; } diff --git a/templates/components/engines/typescript/agent/tools/openapi-action.ts b/templates/components/engines/typescript/agent/tools/openapi-action.ts index e4c044bd..74bb5bd8 100644 --- a/templates/components/engines/typescript/agent/tools/openapi-action.ts +++ b/templates/components/engines/typescript/agent/tools/openapi-action.ts @@ -1,109 +1,78 @@ import SwaggerParser from "@apidevtools/swagger-parser"; +import { JSONSchemaType } from "ajv"; import got from "got"; -import { FunctionTool, JSONValue } from "llamaindex"; +import { FunctionTool, JSONValue, ToolMetadata } from "llamaindex"; interface DomainHeaders { [key: string]: { [header: string]: string }; } -export class OpenAPIActionToolSpec { +type Input = { + url: string; + params: object; +}; + +type APIInfo = { + description: string; + title: string; +}; + +export class OpenAPIActionTool { + // cache the loaded specs by URL + private static specs: Record<string, any> = {}; + private readonly INVALID_URL_PROMPT = "This url did not include a hostname or scheme. Please determine the complete URL and try again."; - private readonly LOAD_OPENAPI_SPEC = { - name: "load_openapi_spec", - description: "Use this function to load spec first before making requests.", - } as const; - - private readonly GET_REQUEST_SPEC = { - name: "get_request", - description: "Use this to GET content from an url.", - parameters: { - type: "object", - properties: { - url: { - type: "string", - description: "The url to make the get request against", - }, - params: { - type: "object", - description: "the parameters to provide with the get request", - }, - }, - required: ["url"], - }, - } as const; - - private readonly POST_REQUEST_SPEC = { - name: "post_request", - description: "Use this to POST content to an url.", - parameters: { - type: "object", - properties: { - url: { - type: "string", - description: "The url to make the get request against", - }, - data: { - type: "object", - description: "the key-value pairs to provide with the get request", - }, - }, - required: ["url"], - }, - } as const; - - private readonly PATCH_REQUEST_SPEC = { - name: "patch_request", - description: "Use this to PATCH content to an url.", - parameters: { - type: "object", - properties: { - url: { - type: "string", - description: "The url to make the get request against", - }, - data: { - type: "object", - description: "the key-value pairs to provide with the get request", + private createLoadSpecMetaData = (info: APIInfo) => { + return { + name: "load_openapi_spec", + description: `Use this to retrieve the OpenAPI spec for the API named ${info.title} with the following description: ${info.description}. Call it before making any requests to the API.`, + }; + }; + + private readonly createMethodCallMetaData = ( + method: "POST" | "PATCH" | "GET", + info: APIInfo, + ) => { + return { + name: `${method.toLowerCase()}_request`, + description: `Use this to call the ${method} method on the API named ${info.title}`, + parameters: { + type: "object", + properties: { + url: { + type: "string", + description: `The url to make the ${method} request against`, + }, + params: { + type: "object", + description: + method === "GET" + ? "the URL parameters to provide with the get request" + : `the key-value pairs to provide with the ${method} request`, + }, }, + required: ["url"], }, - required: ["url"], - }, - } as const; + } as ToolMetadata<JSONSchemaType<Input>>; + }; constructor( public openapi_uri: string, public domainHeaders: DomainHeaders = {}, ) {} - async loadOpenapiSpec(): Promise<any> { - try { - const api = (await SwaggerParser.validate(this.openapi_uri)) as any; - return { - servers: api.servers, - description: api.info.description, - endpoints: api.paths, - }; - } catch (err) { - return err; - } + async loadOpenapiSpec(url: string): Promise<any> { + const api = await SwaggerParser.validate(url); + return { + servers: "servers" in api ? api.servers : "", + info: { description: api.info.description, title: api.info.title }, + endpoints: api.paths, + }; } - async loadOpenapiSpecFromUrl({ url }: { url: string }): Promise<any> { - try { - const api = (await SwaggerParser.validate(url)) as any; - return { - servers: api.servers, - description: api.info.description, - endpoints: api.paths, - }; - } catch (err) { - return err; - } - } - - async getRequest(input: { url: string; params: object }): Promise<JSONValue> { + async getRequest(input: Input): Promise<JSONValue> { if (!this.validUrl(input.url)) { return this.INVALID_URL_PROMPT; } @@ -120,14 +89,14 @@ export class OpenAPIActionToolSpec { } } - async postRequest(input: { url: string; data: object }): Promise<JSONValue> { + async postRequest(input: Input): Promise<JSONValue> { if (!this.validUrl(input.url)) { return this.INVALID_URL_PROMPT; } try { const res = await got.post(input.url, { headers: this.getHeadersForUrl(input.url), - json: input.data, + json: input.params, }); return res.body as JSONValue; } catch (error) { @@ -135,14 +104,14 @@ export class OpenAPIActionToolSpec { } } - async patchRequest(input: { url: string; data: object }): Promise<JSONValue> { + async patchRequest(input: Input): Promise<JSONValue> { if (!this.validUrl(input.url)) { return this.INVALID_URL_PROMPT; } try { const res = await got.patch(input.url, { headers: this.getHeadersForUrl(input.url), - json: input.data, + json: input.params, }); return res.body as JSONValue; } catch (error) { @@ -150,23 +119,33 @@ export class OpenAPIActionToolSpec { } } - public toToolFunctions = () => { + public async toToolFunctions() { + if (!OpenAPIActionTool.specs[this.openapi_uri]) { + console.log(`Loading spec for URL: ${this.openapi_uri}`); + const spec = await this.loadOpenapiSpec(this.openapi_uri); + OpenAPIActionTool.specs[this.openapi_uri] = spec; + } + const spec = OpenAPIActionTool.specs[this.openapi_uri]; + // TODO: read endpoints with parameters from spec and create one tool for each endpoint + // For now, we just create a tool for each HTTP method which does not work well for passing parameters return [ - FunctionTool.from(() => this.loadOpenapiSpec(), this.LOAD_OPENAPI_SPEC), + FunctionTool.from(() => { + return spec; + }, this.createLoadSpecMetaData(spec.info)), FunctionTool.from( - (input: { url: string; params: object }) => this.getRequest(input), - this.GET_REQUEST_SPEC, + this.getRequest.bind(this), + this.createMethodCallMetaData("GET", spec.info), ), FunctionTool.from( - (input: { url: string; data: object }) => this.postRequest(input), - this.POST_REQUEST_SPEC, + this.postRequest.bind(this), + this.createMethodCallMetaData("POST", spec.info), ), FunctionTool.from( - (input: { url: string; data: object }) => this.patchRequest(input), - this.PATCH_REQUEST_SPEC, + this.patchRequest.bind(this), + this.createMethodCallMetaData("PATCH", spec.info), ), ]; - }; + } private validUrl(url: string): boolean { const parsed = new URL(url); -- GitLab