Skip to content
Snippets Groups Projects
Commit 48c19c6e authored by Marcus Schiesser's avatar Marcus Schiesser
Browse files

fix: impove OpenAPI tool for TS

parent d75c08e7
No related branches found
No related tags found
No related merge requests found
import { BaseToolWithCall } from "llamaindex"; import { BaseToolWithCall } from "llamaindex";
import { ToolsFactory } from "llamaindex/tools/ToolsFactory"; import { ToolsFactory } from "llamaindex/tools/ToolsFactory";
import { InterpreterTool, InterpreterToolParams } from "./interpreter"; import { InterpreterTool, InterpreterToolParams } from "./interpreter";
import { OpenAPIActionToolSpec } from "./openapi-action"; import { OpenAPIActionTool } from "./openapi-action";
import { WeatherTool, WeatherToolParams } from "./weather"; import { WeatherTool, WeatherToolParams } from "./weather";
type ToolCreator = (config: unknown) => BaseToolWithCall[]; type ToolCreator = (config: unknown) => Promise<BaseToolWithCall[]>;
export async function createTools(toolConfig: { export async function createTools(toolConfig: {
local: Record<string, unknown>; local: Record<string, unknown>;
llamahub: any; llamahub: any;
}): Promise<BaseToolWithCall[]> { }): Promise<BaseToolWithCall[]> {
// add local tools from the 'tools' folder (if configured) // add local tools from the 'tools' folder (if configured)
const tools = createLocalTools(toolConfig.local); const tools = await createLocalTools(toolConfig.local);
// add tools from LlamaIndexTS (if configured) // add tools from LlamaIndexTS (if configured)
tools.push(...(await ToolsFactory.createTools(toolConfig.llamahub))); tools.push(...(await ToolsFactory.createTools(toolConfig.llamahub)));
return tools; return tools;
} }
const toolFactory: Record<string, ToolCreator> = { const toolFactory: Record<string, ToolCreator> = {
weather: (config: unknown) => { weather: async (config: unknown) => {
return [new WeatherTool(config as WeatherToolParams)]; return [new WeatherTool(config as WeatherToolParams)];
}, },
interpreter: (config: unknown) => { interpreter: async (config: unknown) => {
return [new InterpreterTool(config as InterpreterToolParams)]; return [new InterpreterTool(config as InterpreterToolParams)];
}, },
"openapi_action.OpenAPIActionToolSpec": (config: unknown) => { "openapi_action.OpenAPIActionToolSpec": async (config: unknown) => {
const { openapi_uri, domain_headers } = config as { const { openapi_uri, domain_headers } = config as {
openapi_uri: string; openapi_uri: string;
domain_headers: Record<string, Record<string, string>>; domain_headers: Record<string, Record<string, string>>;
}; };
const openAPIActionTool = new OpenAPIActionToolSpec( const openAPIActionTool = new OpenAPIActionTool(
openapi_uri, openapi_uri,
domain_headers, domain_headers,
); );
return openAPIActionTool.toToolFunctions(); return await openAPIActionTool.toToolFunctions();
}, },
}; };
function createLocalTools( async function createLocalTools(
localConfig: Record<string, unknown>, localConfig: Record<string, unknown>,
): BaseToolWithCall[] { ): Promise<BaseToolWithCall[]> {
const tools: BaseToolWithCall[] = []; const tools: BaseToolWithCall[] = [];
Object.keys(localConfig).forEach((key) => { for (const [key, toolConfig] of Object.entries(localConfig)) {
if (key in toolFactory) { if (key in toolFactory) {
const toolConfig = localConfig[key]; const newTools = await toolFactory[key](toolConfig);
const tool = toolFactory[key](toolConfig); tools.push(...newTools);
tools.push(...tool);
} }
}); }
return tools; return tools;
} }
import SwaggerParser from "@apidevtools/swagger-parser"; import SwaggerParser from "@apidevtools/swagger-parser";
import { JSONSchemaType } from "ajv";
import got from "got"; import got from "got";
import { FunctionTool, JSONValue } from "llamaindex"; import { FunctionTool, JSONValue, ToolMetadata } from "llamaindex";
interface DomainHeaders { interface DomainHeaders {
[key: string]: { [header: string]: string }; [key: string]: { [header: string]: string };
} }
export class OpenAPIActionToolSpec { type Input = {
url: string;
params: object;
};
type APIInfo = {
description: string;
title: string;
};
export class OpenAPIActionTool {
// cache the loaded specs by URL
private static specs: Record<string, any> = {};
private readonly INVALID_URL_PROMPT = private readonly INVALID_URL_PROMPT =
"This url did not include a hostname or scheme. Please determine the complete URL and try again."; "This url did not include a hostname or scheme. Please determine the complete URL and try again.";
private readonly LOAD_OPENAPI_SPEC = { private createLoadSpecMetaData = (info: APIInfo) => {
name: "load_openapi_spec", return {
description: "Use this function to load spec first before making requests.", name: "load_openapi_spec",
} as const; description: `Use this to retrieve the OpenAPI spec for the API named ${info.title} with the following description: ${info.description}. Call it before making any requests to the API.`,
};
private readonly GET_REQUEST_SPEC = { };
name: "get_request",
description: "Use this to GET content from an url.", private readonly createMethodCallMetaData = (
parameters: { method: "POST" | "PATCH" | "GET",
type: "object", info: APIInfo,
properties: { ) => {
url: { return {
type: "string", name: `${method.toLowerCase()}_request`,
description: "The url to make the get request against", description: `Use this to call the ${method} method on the API named ${info.title}`,
}, parameters: {
params: { type: "object",
type: "object", properties: {
description: "the parameters to provide with the get request", url: {
}, type: "string",
}, description: `The url to make the ${method} request against`,
required: ["url"], },
}, params: {
} as const; type: "object",
description:
private readonly POST_REQUEST_SPEC = { method === "GET"
name: "post_request", ? "the URL parameters to provide with the get request"
description: "Use this to POST content to an url.", : `the key-value pairs to provide with the ${method} request`,
parameters: { },
type: "object",
properties: {
url: {
type: "string",
description: "The url to make the get request against",
},
data: {
type: "object",
description: "the key-value pairs to provide with the get request",
},
},
required: ["url"],
},
} as const;
private readonly PATCH_REQUEST_SPEC = {
name: "patch_request",
description: "Use this to PATCH content to an url.",
parameters: {
type: "object",
properties: {
url: {
type: "string",
description: "The url to make the get request against",
},
data: {
type: "object",
description: "the key-value pairs to provide with the get request",
}, },
required: ["url"],
}, },
required: ["url"], } as ToolMetadata<JSONSchemaType<Input>>;
}, };
} as const;
constructor( constructor(
public openapi_uri: string, public openapi_uri: string,
public domainHeaders: DomainHeaders = {}, public domainHeaders: DomainHeaders = {},
) {} ) {}
async loadOpenapiSpec(): Promise<any> { async loadOpenapiSpec(url: string): Promise<any> {
try { const api = await SwaggerParser.validate(url);
const api = (await SwaggerParser.validate(this.openapi_uri)) as any; return {
return { servers: "servers" in api ? api.servers : "",
servers: api.servers, info: { description: api.info.description, title: api.info.title },
description: api.info.description, endpoints: api.paths,
endpoints: api.paths, };
};
} catch (err) {
return err;
}
} }
async loadOpenapiSpecFromUrl({ url }: { url: string }): Promise<any> { async getRequest(input: Input): Promise<JSONValue> {
try {
const api = (await SwaggerParser.validate(url)) as any;
return {
servers: api.servers,
description: api.info.description,
endpoints: api.paths,
};
} catch (err) {
return err;
}
}
async getRequest(input: { url: string; params: object }): Promise<JSONValue> {
if (!this.validUrl(input.url)) { if (!this.validUrl(input.url)) {
return this.INVALID_URL_PROMPT; return this.INVALID_URL_PROMPT;
} }
...@@ -120,14 +89,14 @@ export class OpenAPIActionToolSpec { ...@@ -120,14 +89,14 @@ export class OpenAPIActionToolSpec {
} }
} }
async postRequest(input: { url: string; data: object }): Promise<JSONValue> { async postRequest(input: Input): Promise<JSONValue> {
if (!this.validUrl(input.url)) { if (!this.validUrl(input.url)) {
return this.INVALID_URL_PROMPT; return this.INVALID_URL_PROMPT;
} }
try { try {
const res = await got.post(input.url, { const res = await got.post(input.url, {
headers: this.getHeadersForUrl(input.url), headers: this.getHeadersForUrl(input.url),
json: input.data, json: input.params,
}); });
return res.body as JSONValue; return res.body as JSONValue;
} catch (error) { } catch (error) {
...@@ -135,14 +104,14 @@ export class OpenAPIActionToolSpec { ...@@ -135,14 +104,14 @@ export class OpenAPIActionToolSpec {
} }
} }
async patchRequest(input: { url: string; data: object }): Promise<JSONValue> { async patchRequest(input: Input): Promise<JSONValue> {
if (!this.validUrl(input.url)) { if (!this.validUrl(input.url)) {
return this.INVALID_URL_PROMPT; return this.INVALID_URL_PROMPT;
} }
try { try {
const res = await got.patch(input.url, { const res = await got.patch(input.url, {
headers: this.getHeadersForUrl(input.url), headers: this.getHeadersForUrl(input.url),
json: input.data, json: input.params,
}); });
return res.body as JSONValue; return res.body as JSONValue;
} catch (error) { } catch (error) {
...@@ -150,23 +119,33 @@ export class OpenAPIActionToolSpec { ...@@ -150,23 +119,33 @@ export class OpenAPIActionToolSpec {
} }
} }
public toToolFunctions = () => { public async toToolFunctions() {
if (!OpenAPIActionTool.specs[this.openapi_uri]) {
console.log(`Loading spec for URL: ${this.openapi_uri}`);
const spec = await this.loadOpenapiSpec(this.openapi_uri);
OpenAPIActionTool.specs[this.openapi_uri] = spec;
}
const spec = OpenAPIActionTool.specs[this.openapi_uri];
// TODO: read endpoints with parameters from spec and create one tool for each endpoint
// For now, we just create a tool for each HTTP method which does not work well for passing parameters
return [ return [
FunctionTool.from(() => this.loadOpenapiSpec(), this.LOAD_OPENAPI_SPEC), FunctionTool.from(() => {
return spec;
}, this.createLoadSpecMetaData(spec.info)),
FunctionTool.from( FunctionTool.from(
(input: { url: string; params: object }) => this.getRequest(input), this.getRequest.bind(this),
this.GET_REQUEST_SPEC, this.createMethodCallMetaData("GET", spec.info),
), ),
FunctionTool.from( FunctionTool.from(
(input: { url: string; data: object }) => this.postRequest(input), this.postRequest.bind(this),
this.POST_REQUEST_SPEC, this.createMethodCallMetaData("POST", spec.info),
), ),
FunctionTool.from( FunctionTool.from(
(input: { url: string; data: object }) => this.patchRequest(input), this.patchRequest.bind(this),
this.PATCH_REQUEST_SPEC, this.createMethodCallMetaData("PATCH", spec.info),
), ),
]; ];
}; }
private validUrl(url: string): boolean { private validUrl(url: string): boolean {
const parsed = new URL(url); const parsed = new URL(url);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment