diff --git a/.changeset/big-trainers-turn.md b/.changeset/big-trainers-turn.md new file mode 100644 index 0000000000000000000000000000000000000000..28ae78a0cd08372c1e3e327d4b102386da906939 --- /dev/null +++ b/.changeset/big-trainers-turn.md @@ -0,0 +1,5 @@ +--- +"create-llama": patch +--- + +Add OpenAPI tool for Typescript diff --git a/helpers/tools.ts b/helpers/tools.ts index 52b2a83aecb63af07f0b64490bab68b0495f3b19..dcfa5d06ed20277ef82be6aaca33a15ab48877ad 100644 --- a/helpers/tools.ts +++ b/helpers/tools.ts @@ -137,7 +137,7 @@ export const supportedTools: Tool[] = [ config: { openapi_uri: "The URL or file path of the OpenAPI schema", }, - supportedFrameworks: ["fastapi"], + supportedFrameworks: ["fastapi", "express", "nextjs"], type: ToolType.LOCAL, envVars: [ { diff --git a/templates/components/engines/typescript/agent/tools/index.ts b/templates/components/engines/typescript/agent/tools/index.ts index faff9b5dd2518904f0b9b01916fb71bf277878f7..cacdb60e5de630c9a45792da209b61072bd6b4a4 100644 --- a/templates/components/engines/typescript/agent/tools/index.ts +++ b/templates/components/engines/typescript/agent/tools/index.ts @@ -1,9 +1,10 @@ import { BaseToolWithCall } from "llamaindex"; import { ToolsFactory } from "llamaindex/tools/ToolsFactory"; import { InterpreterTool, InterpreterToolParams } from "./interpreter"; +import { OpenAPIActionToolSpec } from "./openapi-action"; import { WeatherTool, WeatherToolParams } from "./weather"; -type ToolCreator = (config: unknown) => BaseToolWithCall; +type ToolCreator = (config: unknown) => BaseToolWithCall[]; export async function createTools(toolConfig: { local: Record<string, unknown>; @@ -18,10 +19,21 @@ export async function createTools(toolConfig: { const toolFactory: Record<string, ToolCreator> = { weather: (config: unknown) => { - return new WeatherTool(config as WeatherToolParams); + return [new WeatherTool(config as WeatherToolParams)]; }, interpreter: (config: unknown) => { - return new InterpreterTool(config as InterpreterToolParams); + return [new InterpreterTool(config as InterpreterToolParams)]; + }, + "openapi_action.OpenAPIActionToolSpec": (config: unknown) => { + const { openapi_uri, domain_headers } = config as { + openapi_uri: string; + domain_headers: Record<string, Record<string, string>>; + }; + const openAPIActionTool = new OpenAPIActionToolSpec( + openapi_uri, + domain_headers, + ); + return openAPIActionTool.toToolFunctions(); }, }; @@ -34,7 +46,7 @@ function createLocalTools( if (key in toolFactory) { const toolConfig = localConfig[key]; const tool = toolFactory[key](toolConfig); - tools.push(tool); + tools.push(...tool); } }); diff --git a/templates/components/engines/typescript/agent/tools/openapi-action.ts b/templates/components/engines/typescript/agent/tools/openapi-action.ts new file mode 100644 index 0000000000000000000000000000000000000000..e4c044bd4bc8bf515fc9a2e1ed96dd7b0bd9af0b --- /dev/null +++ b/templates/components/engines/typescript/agent/tools/openapi-action.ts @@ -0,0 +1,185 @@ +import SwaggerParser from "@apidevtools/swagger-parser"; +import got from "got"; +import { FunctionTool, JSONValue } from "llamaindex"; + +interface DomainHeaders { + [key: string]: { [header: string]: string }; +} + +export class OpenAPIActionToolSpec { + private readonly INVALID_URL_PROMPT = + "This url did not include a hostname or scheme. Please determine the complete URL and try again."; + + private readonly LOAD_OPENAPI_SPEC = { + name: "load_openapi_spec", + description: "Use this function to load spec first before making requests.", + } as const; + + private readonly GET_REQUEST_SPEC = { + name: "get_request", + description: "Use this to GET content from an url.", + parameters: { + type: "object", + properties: { + url: { + type: "string", + description: "The url to make the get request against", + }, + params: { + type: "object", + description: "the parameters to provide with the get request", + }, + }, + required: ["url"], + }, + } as const; + + private readonly POST_REQUEST_SPEC = { + name: "post_request", + description: "Use this to POST content to an url.", + parameters: { + type: "object", + properties: { + url: { + type: "string", + description: "The url to make the get request against", + }, + data: { + type: "object", + description: "the key-value pairs to provide with the get request", + }, + }, + required: ["url"], + }, + } as const; + + private readonly PATCH_REQUEST_SPEC = { + name: "patch_request", + description: "Use this to PATCH content to an url.", + parameters: { + type: "object", + properties: { + url: { + type: "string", + description: "The url to make the get request against", + }, + data: { + type: "object", + description: "the key-value pairs to provide with the get request", + }, + }, + required: ["url"], + }, + } as const; + + constructor( + public openapi_uri: string, + public domainHeaders: DomainHeaders = {}, + ) {} + + async loadOpenapiSpec(): Promise<any> { + try { + const api = (await SwaggerParser.validate(this.openapi_uri)) as any; + return { + servers: api.servers, + description: api.info.description, + endpoints: api.paths, + }; + } catch (err) { + return err; + } + } + + async loadOpenapiSpecFromUrl({ url }: { url: string }): Promise<any> { + try { + const api = (await SwaggerParser.validate(url)) as any; + return { + servers: api.servers, + description: api.info.description, + endpoints: api.paths, + }; + } catch (err) { + return err; + } + } + + async getRequest(input: { url: string; params: object }): Promise<JSONValue> { + if (!this.validUrl(input.url)) { + return this.INVALID_URL_PROMPT; + } + try { + const data = await got + .get(input.url, { + headers: this.getHeadersForUrl(input.url), + searchParams: input.params as URLSearchParams, + }) + .json(); + return data as JSONValue; + } catch (error) { + return error as JSONValue; + } + } + + async postRequest(input: { url: string; data: object }): Promise<JSONValue> { + if (!this.validUrl(input.url)) { + return this.INVALID_URL_PROMPT; + } + try { + const res = await got.post(input.url, { + headers: this.getHeadersForUrl(input.url), + json: input.data, + }); + return res.body as JSONValue; + } catch (error) { + return error as JSONValue; + } + } + + async patchRequest(input: { url: string; data: object }): Promise<JSONValue> { + if (!this.validUrl(input.url)) { + return this.INVALID_URL_PROMPT; + } + try { + const res = await got.patch(input.url, { + headers: this.getHeadersForUrl(input.url), + json: input.data, + }); + return res.body as JSONValue; + } catch (error) { + return error as JSONValue; + } + } + + public toToolFunctions = () => { + return [ + FunctionTool.from(() => this.loadOpenapiSpec(), this.LOAD_OPENAPI_SPEC), + FunctionTool.from( + (input: { url: string; params: object }) => this.getRequest(input), + this.GET_REQUEST_SPEC, + ), + FunctionTool.from( + (input: { url: string; data: object }) => this.postRequest(input), + this.POST_REQUEST_SPEC, + ), + FunctionTool.from( + (input: { url: string; data: object }) => this.patchRequest(input), + this.PATCH_REQUEST_SPEC, + ), + ]; + }; + + private validUrl(url: string): boolean { + const parsed = new URL(url); + return !!parsed.protocol && !!parsed.hostname; + } + + private getDomain(url: string): string { + const parsed = new URL(url); + return parsed.hostname; + } + + private getHeadersForUrl(url: string): { [header: string]: string } { + const domain = this.getDomain(url); + return this.domainHeaders[domain] || {}; + } +} diff --git a/templates/types/streaming/express/package.json b/templates/types/streaming/express/package.json index 667ecc47ef53862c2b6393f343c52d437fef1100..312f708b81a847b49cf72827cc8177d741268d8b 100644 --- a/templates/types/streaming/express/package.json +++ b/templates/types/streaming/express/package.json @@ -17,7 +17,9 @@ "llamaindex": "0.3.16", "pdf2json": "3.0.5", "ajv": "^8.12.0", - "@e2b/code-interpreter": "^0.0.5" + "@e2b/code-interpreter": "^0.0.5", + "got": "10.7.0", + "@apidevtools/swagger-parser": "^10.1.0" }, "devDependencies": { "@types/cors": "^2.8.16", diff --git a/templates/types/streaming/nextjs/package.json b/templates/types/streaming/nextjs/package.json index 7c9c548c395d8576fab54543d4fa43a702bd6faa..69515e6a7274178cbbf6c6f42466ccc55de4e0d7 100644 --- a/templates/types/streaming/nextjs/package.json +++ b/templates/types/streaming/nextjs/package.json @@ -36,7 +36,9 @@ "vaul": "^0.9.1", "@llamaindex/pdf-viewer": "^1.1.1", "@e2b/code-interpreter": "^0.0.5", - "uuid": "^9.0.1" + "uuid": "^9.0.1", + "got": "10.7.0", + "@apidevtools/swagger-parser": "^10.1.0" }, "devDependencies": { "@types/node": "^20.10.3",