From 48c19c6e6231bc1c79aa93b238f2faf7ea5346cd Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Wed, 12 Jun 2024 15:28:59 +0200
Subject: [PATCH] fix: impove OpenAPI tool for TS

---
 .../engines/typescript/agent/tools/index.ts   |  29 ++-
 .../typescript/agent/tools/openapi-action.ts  | 177 ++++++++----------
 2 files changed, 92 insertions(+), 114 deletions(-)

diff --git a/templates/components/engines/typescript/agent/tools/index.ts b/templates/components/engines/typescript/agent/tools/index.ts
index cacdb60e..7f823eca 100644
--- a/templates/components/engines/typescript/agent/tools/index.ts
+++ b/templates/components/engines/typescript/agent/tools/index.ts
@@ -1,54 +1,53 @@
 import { BaseToolWithCall } from "llamaindex";
 import { ToolsFactory } from "llamaindex/tools/ToolsFactory";
 import { InterpreterTool, InterpreterToolParams } from "./interpreter";
-import { OpenAPIActionToolSpec } from "./openapi-action";
+import { OpenAPIActionTool } from "./openapi-action";
 import { WeatherTool, WeatherToolParams } from "./weather";
 
-type ToolCreator = (config: unknown) => BaseToolWithCall[];
+type ToolCreator = (config: unknown) => Promise<BaseToolWithCall[]>;
 
 export async function createTools(toolConfig: {
   local: Record<string, unknown>;
   llamahub: any;
 }): Promise<BaseToolWithCall[]> {
   // add local tools from the 'tools' folder (if configured)
-  const tools = createLocalTools(toolConfig.local);
+  const tools = await createLocalTools(toolConfig.local);
   // add tools from LlamaIndexTS (if configured)
   tools.push(...(await ToolsFactory.createTools(toolConfig.llamahub)));
   return tools;
 }
 
 const toolFactory: Record<string, ToolCreator> = {
-  weather: (config: unknown) => {
+  weather: async (config: unknown) => {
     return [new WeatherTool(config as WeatherToolParams)];
   },
-  interpreter: (config: unknown) => {
+  interpreter: async (config: unknown) => {
     return [new InterpreterTool(config as InterpreterToolParams)];
   },
-  "openapi_action.OpenAPIActionToolSpec": (config: unknown) => {
+  "openapi_action.OpenAPIActionToolSpec": async (config: unknown) => {
     const { openapi_uri, domain_headers } = config as {
       openapi_uri: string;
       domain_headers: Record<string, Record<string, string>>;
     };
-    const openAPIActionTool = new OpenAPIActionToolSpec(
+    const openAPIActionTool = new OpenAPIActionTool(
       openapi_uri,
       domain_headers,
     );
-    return openAPIActionTool.toToolFunctions();
+    return await openAPIActionTool.toToolFunctions();
   },
 };
 
-function createLocalTools(
+async function createLocalTools(
   localConfig: Record<string, unknown>,
-): BaseToolWithCall[] {
+): Promise<BaseToolWithCall[]> {
   const tools: BaseToolWithCall[] = [];
 
-  Object.keys(localConfig).forEach((key) => {
+  for (const [key, toolConfig] of Object.entries(localConfig)) {
     if (key in toolFactory) {
-      const toolConfig = localConfig[key];
-      const tool = toolFactory[key](toolConfig);
-      tools.push(...tool);
+      const newTools = await toolFactory[key](toolConfig);
+      tools.push(...newTools);
     }
-  });
+  }
 
   return tools;
 }
diff --git a/templates/components/engines/typescript/agent/tools/openapi-action.ts b/templates/components/engines/typescript/agent/tools/openapi-action.ts
index e4c044bd..74bb5bd8 100644
--- a/templates/components/engines/typescript/agent/tools/openapi-action.ts
+++ b/templates/components/engines/typescript/agent/tools/openapi-action.ts
@@ -1,109 +1,78 @@
 import SwaggerParser from "@apidevtools/swagger-parser";
+import { JSONSchemaType } from "ajv";
 import got from "got";
-import { FunctionTool, JSONValue } from "llamaindex";
+import { FunctionTool, JSONValue, ToolMetadata } from "llamaindex";
 
 interface DomainHeaders {
   [key: string]: { [header: string]: string };
 }
 
-export class OpenAPIActionToolSpec {
+type Input = {
+  url: string;
+  params: object;
+};
+
+type APIInfo = {
+  description: string;
+  title: string;
+};
+
+export class OpenAPIActionTool {
+  // cache the loaded specs by URL
+  private static specs: Record<string, any> = {};
+
   private readonly INVALID_URL_PROMPT =
     "This url did not include a hostname or scheme. Please determine the complete URL and try again.";
 
-  private readonly LOAD_OPENAPI_SPEC = {
-    name: "load_openapi_spec",
-    description: "Use this function to load spec first before making requests.",
-  } as const;
-
-  private readonly GET_REQUEST_SPEC = {
-    name: "get_request",
-    description: "Use this to GET content from an url.",
-    parameters: {
-      type: "object",
-      properties: {
-        url: {
-          type: "string",
-          description: "The url to make the get request against",
-        },
-        params: {
-          type: "object",
-          description: "the parameters to provide with the get request",
-        },
-      },
-      required: ["url"],
-    },
-  } as const;
-
-  private readonly POST_REQUEST_SPEC = {
-    name: "post_request",
-    description: "Use this to POST content to an url.",
-    parameters: {
-      type: "object",
-      properties: {
-        url: {
-          type: "string",
-          description: "The url to make the get request against",
-        },
-        data: {
-          type: "object",
-          description: "the key-value pairs to provide with the get request",
-        },
-      },
-      required: ["url"],
-    },
-  } as const;
-
-  private readonly PATCH_REQUEST_SPEC = {
-    name: "patch_request",
-    description: "Use this to PATCH content to an url.",
-    parameters: {
-      type: "object",
-      properties: {
-        url: {
-          type: "string",
-          description: "The url to make the get request against",
-        },
-        data: {
-          type: "object",
-          description: "the key-value pairs to provide with the get request",
+  private createLoadSpecMetaData = (info: APIInfo) => {
+    return {
+      name: "load_openapi_spec",
+      description: `Use this to retrieve the OpenAPI spec for the API named ${info.title} with the following description: ${info.description}. Call it before making any requests to the API.`,
+    };
+  };
+
+  private readonly createMethodCallMetaData = (
+    method: "POST" | "PATCH" | "GET",
+    info: APIInfo,
+  ) => {
+    return {
+      name: `${method.toLowerCase()}_request`,
+      description: `Use this to call the ${method} method on the API named ${info.title}`,
+      parameters: {
+        type: "object",
+        properties: {
+          url: {
+            type: "string",
+            description: `The url to make the ${method} request against`,
+          },
+          params: {
+            type: "object",
+            description:
+              method === "GET"
+                ? "the URL parameters to provide with the get request"
+                : `the key-value pairs to provide with the ${method} request`,
+          },
         },
+        required: ["url"],
       },
-      required: ["url"],
-    },
-  } as const;
+    } as ToolMetadata<JSONSchemaType<Input>>;
+  };
 
   constructor(
     public openapi_uri: string,
     public domainHeaders: DomainHeaders = {},
   ) {}
 
-  async loadOpenapiSpec(): Promise<any> {
-    try {
-      const api = (await SwaggerParser.validate(this.openapi_uri)) as any;
-      return {
-        servers: api.servers,
-        description: api.info.description,
-        endpoints: api.paths,
-      };
-    } catch (err) {
-      return err;
-    }
+  async loadOpenapiSpec(url: string): Promise<any> {
+    const api = await SwaggerParser.validate(url);
+    return {
+      servers: "servers" in api ? api.servers : "",
+      info: { description: api.info.description, title: api.info.title },
+      endpoints: api.paths,
+    };
   }
 
-  async loadOpenapiSpecFromUrl({ url }: { url: string }): Promise<any> {
-    try {
-      const api = (await SwaggerParser.validate(url)) as any;
-      return {
-        servers: api.servers,
-        description: api.info.description,
-        endpoints: api.paths,
-      };
-    } catch (err) {
-      return err;
-    }
-  }
-
-  async getRequest(input: { url: string; params: object }): Promise<JSONValue> {
+  async getRequest(input: Input): Promise<JSONValue> {
     if (!this.validUrl(input.url)) {
       return this.INVALID_URL_PROMPT;
     }
@@ -120,14 +89,14 @@ export class OpenAPIActionToolSpec {
     }
   }
 
-  async postRequest(input: { url: string; data: object }): Promise<JSONValue> {
+  async postRequest(input: Input): Promise<JSONValue> {
     if (!this.validUrl(input.url)) {
       return this.INVALID_URL_PROMPT;
     }
     try {
       const res = await got.post(input.url, {
         headers: this.getHeadersForUrl(input.url),
-        json: input.data,
+        json: input.params,
       });
       return res.body as JSONValue;
     } catch (error) {
@@ -135,14 +104,14 @@ export class OpenAPIActionToolSpec {
     }
   }
 
-  async patchRequest(input: { url: string; data: object }): Promise<JSONValue> {
+  async patchRequest(input: Input): Promise<JSONValue> {
     if (!this.validUrl(input.url)) {
       return this.INVALID_URL_PROMPT;
     }
     try {
       const res = await got.patch(input.url, {
         headers: this.getHeadersForUrl(input.url),
-        json: input.data,
+        json: input.params,
       });
       return res.body as JSONValue;
     } catch (error) {
@@ -150,23 +119,33 @@ export class OpenAPIActionToolSpec {
     }
   }
 
-  public toToolFunctions = () => {
+  public async toToolFunctions() {
+    if (!OpenAPIActionTool.specs[this.openapi_uri]) {
+      console.log(`Loading spec for URL: ${this.openapi_uri}`);
+      const spec = await this.loadOpenapiSpec(this.openapi_uri);
+      OpenAPIActionTool.specs[this.openapi_uri] = spec;
+    }
+    const spec = OpenAPIActionTool.specs[this.openapi_uri];
+    // TODO: read endpoints with parameters from spec and create one tool for each endpoint
+    // For now, we just create a tool for each HTTP method which does not work well for passing parameters
     return [
-      FunctionTool.from(() => this.loadOpenapiSpec(), this.LOAD_OPENAPI_SPEC),
+      FunctionTool.from(() => {
+        return spec;
+      }, this.createLoadSpecMetaData(spec.info)),
       FunctionTool.from(
-        (input: { url: string; params: object }) => this.getRequest(input),
-        this.GET_REQUEST_SPEC,
+        this.getRequest.bind(this),
+        this.createMethodCallMetaData("GET", spec.info),
       ),
       FunctionTool.from(
-        (input: { url: string; data: object }) => this.postRequest(input),
-        this.POST_REQUEST_SPEC,
+        this.postRequest.bind(this),
+        this.createMethodCallMetaData("POST", spec.info),
       ),
       FunctionTool.from(
-        (input: { url: string; data: object }) => this.patchRequest(input),
-        this.PATCH_REQUEST_SPEC,
+        this.patchRequest.bind(this),
+        this.createMethodCallMetaData("PATCH", spec.info),
       ),
     ];
-  };
+  }
 
   private validUrl(url: string): boolean {
     const parsed = new URL(url);
-- 
GitLab