diff --git a/packages/tools/.gitignore b/packages/tools/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..9b1960e711fc1719053c5908bb5ff3cc321789e5 --- /dev/null +++ b/packages/tools/.gitignore @@ -0,0 +1 @@ +output/ \ No newline at end of file diff --git a/packages/tools/package.json b/packages/tools/package.json index 7dc2e797e4dc32d85db791a4e3d2ff080dfe6c90..8eec15934819bab93fef2ac7cc0c8cf761e39410 100644 --- a/packages/tools/package.json +++ b/packages/tools/package.json @@ -27,7 +27,9 @@ }, "scripts": { "build": "bunchee", - "dev": "bunchee --watch" + "dev": "bunchee --watch", + "test": "vitest run", + "test:watch": "vitest watch" }, "devDependencies": { "bunchee": "6.4.0", diff --git a/packages/tools/src/helper.ts b/packages/tools/src/helper.ts index 7898708231f65388d3080339e9af3595df18b0a9..c62c2abd53f540b5f26118c48ab9363f0903b1d6 100644 --- a/packages/tools/src/helper.ts +++ b/packages/tools/src/helper.ts @@ -1,24 +1,35 @@ import fs from "node:fs"; import path from "node:path"; -export async function saveDocument(filepath: string, content: string | Buffer) { - if (path.isAbsolute(filepath)) { +export async function saveDocument(filePath: string, content: string | Buffer) { + if (path.isAbsolute(filePath)) { throw new Error("Absolute file paths are not allowed."); } - if (!process.env.FILESERVER_URL_PREFIX) { - throw new Error("FILESERVER_URL_PREFIX environment variable is not set."); - } - const dirPath = path.dirname(filepath); + const dirPath = path.dirname(filePath); await fs.promises.mkdir(dirPath, { recursive: true }); if (typeof content === "string") { - await fs.promises.writeFile(filepath, content, "utf-8"); + await fs.promises.writeFile(filePath, content, "utf-8"); } else { - await fs.promises.writeFile(filepath, content); + await fs.promises.writeFile(filePath, content); + } +} + +export function getFileUrl( + filePath: string, + options?: { + fileServerURLPrefix?: string | undefined; + }, +): string { + const fileServerURLPrefix = + options?.fileServerURLPrefix || process.env.FILESERVER_URL_PREFIX; + + if (!fileServerURLPrefix) { + throw new Error( + "To get a file URL, please provide a fileServerURLPrefix or set FILESERVER_URL_PREFIX environment variable.", + ); } - const fileurl = `${process.env.FILESERVER_URL_PREFIX}/${filepath}`; - console.log(`Saved document to ${filepath}. Reachable at URL: ${fileurl}`); - return fileurl; + return `${fileServerURLPrefix}/${filePath}`; } diff --git a/packages/tools/src/tools/code-generator.ts b/packages/tools/src/tools/code-generator.ts index 8e0302895f93cf5d4b4bb3f33ec5381518633f90..e98cf20f1d1a10cf2024b0e538ceecbd205ee30c 100644 --- a/packages/tools/src/tools/code-generator.ts +++ b/packages/tools/src/tools/code-generator.ts @@ -47,6 +47,11 @@ export type CodeArtifact = { files?: string[]; }; +export type CodeGeneratorToolOutput = { + isError: boolean; + artifact?: CodeArtifact; +}; + // Helper function async function generateArtifact( query: string, @@ -97,11 +102,7 @@ export const codeGenerator = () => { "A list of sandbox file paths. Include these files if the code requires them.", ), }), - execute: async ({ - requirement, - oldCode, - sandboxFiles, - }): Promise<JSONValue> => { + execute: async ({ requirement, oldCode, sandboxFiles }) => { try { const artifact = await generateArtifact( requirement, @@ -111,9 +112,14 @@ export const codeGenerator = () => { if (sandboxFiles) { artifact.files = sandboxFiles; } - return artifact as JSONValue; + return { + isError: false, + artifact, + } as JSONValue; } catch (error) { - return { isError: true }; + return { + isError: true, + }; } }, }); diff --git a/packages/tools/src/tools/document-generator.ts b/packages/tools/src/tools/document-generator.ts index cf00db07d19651a0142d88231e59cb0c02543747..7b52ada6cc5d4d581ade9f74f4673c31f2e1c79f 100644 --- a/packages/tools/src/tools/document-generator.ts +++ b/packages/tools/src/tools/document-generator.ts @@ -1,35 +1,8 @@ -import type { BaseTool, ToolMetadata } from "@llamaindex/core/llms"; -import { type JSONSchemaType } from "ajv"; +import { tool } from "@llamaindex/core/tools"; import { marked } from "marked"; import path from "node:path"; -import { saveDocument } from "../helper"; - -const OUTPUT_DIR = "output/tools"; - -type DocumentParameter = { - originalContent: string; - fileName: string; -}; - -const DEFAULT_METADATA: ToolMetadata<JSONSchemaType<DocumentParameter>> = { - name: "document_generator", - description: - "Generate HTML document from markdown content. Return a file url to the document", - parameters: { - type: "object", - properties: { - originalContent: { - type: "string", - description: "The original markdown content to convert.", - }, - fileName: { - type: "string", - description: "The name of the document file (without extension).", - }, - }, - required: ["originalContent", "fileName"], - }, -}; +import { z } from "zod"; +import { getFileUrl, saveDocument } from "../helper"; const COMMON_STYLES = ` body { @@ -103,39 +76,37 @@ const HTML_TEMPLATE = ` </html> `; -export interface DocumentGeneratorParams { - metadata?: ToolMetadata<JSONSchemaType<DocumentParameter>>; -} - -export class DocumentGenerator implements BaseTool<DocumentParameter> { - metadata: ToolMetadata<JSONSchemaType<DocumentParameter>>; - - constructor(params: DocumentGeneratorParams) { - this.metadata = params.metadata ?? DEFAULT_METADATA; - } - - private static async generateHtmlContent( - originalContent: string, - ): Promise<string> { - return await marked(originalContent); - } - - private static generateHtmlDocument(htmlContent: string): string { - return HTML_TEMPLATE.replace("{{content}}", htmlContent); - } - - async call(input: DocumentParameter): Promise<string> { - const { originalContent, fileName } = input; - - const htmlContent = - await DocumentGenerator.generateHtmlContent(originalContent); - const fileContent = DocumentGenerator.generateHtmlDocument(htmlContent); - - const filePath = path.join(OUTPUT_DIR, `${fileName}.html`); - - return `URL: ${await saveDocument(filePath, fileContent)}`; - } -} +export type DocumentGeneratorParams = { + /** Directory where generated documents will be saved */ + outputDir: string; + /** Prefix for the file server URL */ + fileServerURLPrefix?: string; +}; -export const documentGenerator = (params?: DocumentGeneratorParams) => - new DocumentGenerator(params ?? {}); +export const documentGenerator = (params: DocumentGeneratorParams) => { + return tool({ + name: "document_generator", + description: + "Generate HTML document from markdown content. Return a file url to the document", + parameters: z.object({ + originalContent: z + .string() + .describe("The original markdown content to convert"), + fileName: z + .string() + .describe("The name of the document file (without extension)"), + }), + execute: async ({ originalContent, fileName }): Promise<string> => { + const { outputDir, fileServerURLPrefix } = params; + + const htmlContent = await marked(originalContent); + const fileContent = HTML_TEMPLATE.replace("{{content}}", htmlContent); + + const filePath = path.join(outputDir, `${fileName}.html`); + await saveDocument(filePath, fileContent); + const fileUrl = getFileUrl(filePath, { fileServerURLPrefix }); + + return `URL: ${fileUrl}`; + }, + }); +}; diff --git a/packages/tools/src/tools/duckduckgo.ts b/packages/tools/src/tools/duckduckgo.ts index 628a3c7022a8979f4f100f0801c5a47dfc39d5e9..b2fc585d77ee2568695e76dd34680c84ebfae274 100644 --- a/packages/tools/src/tools/duckduckgo.ts +++ b/packages/tools/src/tools/duckduckgo.ts @@ -2,11 +2,11 @@ import { tool } from "@llamaindex/core/tools"; import { search } from "duck-duck-scrape"; import { z } from "zod"; -export type DuckDuckGoToolOutput = { +export type DuckDuckGoToolOutput = Array<{ title: string; description: string; url: string; -}[]; +}>; export const duckduckgo = () => { return tool({ diff --git a/packages/tools/src/tools/form-filling.ts b/packages/tools/src/tools/form-filling.ts index e1424f1d7833fb966e4c6dfe1eae67752657be8a..a8ab8e54eaa231f871285b3590a8752bf5a42133 100644 --- a/packages/tools/src/tools/form-filling.ts +++ b/packages/tools/src/tools/form-filling.ts @@ -1,14 +1,10 @@ import { Settings } from "@llamaindex/core/global"; -import type { BaseTool, ToolMetadata } from "@llamaindex/core/llms"; -import { type JSONSchemaType } from "ajv"; +import { tool } from "@llamaindex/core/tools"; import fs from "fs"; import Papa from "papaparse"; import path from "path"; -import { saveDocument } from "../helper"; - -type ExtractMissingCellsParameter = { - filePath: string; -}; +import { z } from "zod"; +import { getFileUrl, saveDocument } from "../helper"; export type MissingCell = { rowIndex: number; @@ -75,224 +71,164 @@ IMPORTANT: Column indices should be 0-based - Your answer: `; -const DEFAULT_METADATA: ToolMetadata< - JSONSchemaType<ExtractMissingCellsParameter> -> = { - name: "extract_missing_cells", - description: `Use this tool to extract missing cells in a CSV file and generate questions to fill them. This tool only works with local file path.`, - parameters: { - type: "object", - properties: { - filePath: { - type: "string", - description: "The local file path to the CSV file.", - }, - }, - required: ["filePath"], - }, -}; +export const extractMissingCells = () => { + return tool({ + name: "extract_missing_cells", + description: + "Use this tool to extract missing cells in a CSV file and generate questions to fill them. This tool only works with local file path.", + parameters: z.object({ + filePath: z.string().describe("The local file path to the CSV file."), + }), + execute: async ({ filePath }): Promise<MissingCell[]> => { + let tableContent: string[][]; + try { + tableContent = await readCsvFile(filePath); + } catch (error) { + throw new Error( + "Failed to read CSV file. Make sure that you are reading a local file path (not a sandbox path).", + ); + } -export interface ExtractMissingCellsParams { - metadata?: ToolMetadata<JSONSchemaType<ExtractMissingCellsParameter>>; -} + const prompt = CSV_EXTRACTION_PROMPT.replace( + "{table_content}", + formatToMarkdownTable(tableContent), + ); -export class ExtractMissingCellsTool - implements BaseTool<ExtractMissingCellsParameter> -{ - metadata: ToolMetadata<JSONSchemaType<ExtractMissingCellsParameter>>; - defaultExtractionPrompt: string; - - constructor(params: ExtractMissingCellsParams) { - this.metadata = params.metadata ?? DEFAULT_METADATA; - this.defaultExtractionPrompt = CSV_EXTRACTION_PROMPT; - } - - private readCsvFile(filePath: string): Promise<string[][]> { - return new Promise((resolve, reject) => { - fs.readFile(filePath, "utf8", (err, data) => { - if (err) { - reject(err); - return; - } + const llm = Settings.llm; + const response = await llm.complete({ prompt }); + const parsedResponse = JSON.parse(response.text) as { + missing_cells: MissingCell[]; + }; + if (!parsedResponse.missing_cells) { + throw new Error( + "The answer is not in the correct format. There should be a missing_cells array.", + ); + } + return parsedResponse.missing_cells; + }, + }); +}; - const parsedData = Papa.parse<string[]>(data, { - skipEmptyLines: false, - }); +export type FillMissingCellsParams = { + /** Directory where generated documents will be saved */ + outputDir: string; - if (parsedData.errors.length) { - reject(parsedData.errors); - return; - } + /** Prefix for the file server URL */ + fileServerURLPrefix?: string; +}; - // Ensure all rows have the same number of columns as the header - const maxColumns = parsedData.data[0]?.length ?? 0; - const paddedRows = parsedData.data.map((row) => { - return [...row, ...Array(maxColumns - row.length).fill("")]; - }); +export type FillMissingCellsToolOutput = { + isSuccess: boolean; + errorMessage?: string; + fileUrl?: string; +}; - resolve(paddedRows); +export const fillMissingCells = (params: FillMissingCellsParams) => { + return tool({ + name: "fill_missing_cells", + description: + "Use this tool to fill missing cells in a CSV file with provided answers. This tool only works with local file path.", + parameters: z.object({ + filePath: z.string().describe("The local file path to the CSV file."), + cells: z + .array( + z.object({ + rowIndex: z.number(), + columnIndex: z.number(), + answer: z.string(), + }), + ) + .describe("Array of cells to fill with their answers"), + }), + execute: async ({ filePath, cells }): Promise<string> => { + const { outputDir, fileServerURLPrefix } = params; + + // Read the CSV file + const fileContent = await fs.promises.readFile(filePath, "utf8"); + + // Parse CSV with PapaParse + const parseResult = Papa.parse<string[]>(fileContent, { + header: false, // Ensure the header is not treated as a separate object + skipEmptyLines: false, // Ensure empty lines are not skipped }); - }); - } - private formatToMarkdownTable(data: string[][]): string { - if (data.length === 0) return ""; + if (parseResult.errors.length) { + throw new Error( + "Failed to parse CSV file: " + parseResult.errors[0]?.message, + ); + } - const maxColumns = data[0]?.length ?? 0; + const rows = parseResult.data; + + // Fill the cells with answers + for (const cell of cells) { + // Adjust rowIndex to start from 1 for data rows + const adjustedRowIndex = cell.rowIndex + 1; + if ( + adjustedRowIndex < rows.length && + cell.columnIndex < (rows[adjustedRowIndex]?.length ?? 0) && + rows[adjustedRowIndex] + ) { + rows[adjustedRowIndex][cell.columnIndex] = cell.answer; + } + } - const headerRow = `| ${data[0]?.join(" | ") ?? ""} |`; - const separatorRow = `| ${Array(maxColumns).fill("---").join(" | ")} |`; + // Convert back to CSV format + const updatedContent = Papa.unparse(rows, { + delimiter: parseResult.meta.delimiter, + }); - const dataRows = data.slice(1).map((row) => { - return `| ${row.join(" | ")} |`; - }); + // Use the helper function to write the file + const parsedPath = path.parse(filePath); + const newFileName = `${parsedPath.name}-filled${parsedPath.ext}`; + const newFilePath = path.join(outputDir, newFileName); - return [headerRow, separatorRow, ...dataRows].join("\n"); - } - - async call(input: ExtractMissingCellsParameter): Promise<MissingCell[]> { - const { filePath } = input; - let tableContent: string[][]; - try { - tableContent = await this.readCsvFile(filePath); - } catch (error) { - throw new Error( - `Failed to read CSV file. Make sure that you are reading a local file path (not a sandbox path).`, - ); - } + await saveDocument(newFilePath, updatedContent); + const newFileUrl = getFileUrl(newFilePath, { fileServerURLPrefix }); - const prompt = this.defaultExtractionPrompt.replace( - "{table_content}", - this.formatToMarkdownTable(tableContent), - ); - - const llm = Settings.llm; - const response = await llm.complete({ - prompt, - }); - const rawAnswer = response.text; - const parsedResponse = JSON.parse(rawAnswer) as { - missing_cells: MissingCell[]; - }; - if (!parsedResponse.missing_cells) { - throw new Error( - "The answer is not in the correct format. There should be a missing_cells array.", + return ( + "Successfully filled missing cells in the CSV file. File URL to show to the user: " + + newFileUrl ); - } - const answer = parsedResponse.missing_cells; - - return answer; - } -} - -type FillMissingCellsParameter = { - filePath: string; - cells: { - rowIndex: number; - columnIndex: number; - answer: string; - }[]; -}; - -const FILL_CELLS_METADATA: ToolMetadata< - JSONSchemaType<FillMissingCellsParameter> -> = { - name: "fill_missing_cells", - description: `Use this tool to fill missing cells in a CSV file with provided answers. This tool only works with local file path.`, - parameters: { - type: "object", - properties: { - filePath: { - type: "string", - description: "The local file path to the CSV file.", - }, - cells: { - type: "array", - items: { - type: "object", - properties: { - rowIndex: { type: "number" }, - columnIndex: { type: "number" }, - answer: { type: "string" }, - }, - required: ["rowIndex", "columnIndex", "answer"], - }, - description: "Array of cells to fill with their answers", - }, }, - required: ["filePath", "cells"], - }, + }); }; -export interface FillMissingCellsParams { - metadata?: ToolMetadata<JSONSchemaType<FillMissingCellsParameter>>; -} +async function readCsvFile(filePath: string): Promise<string[][]> { + return new Promise((resolve, reject) => { + fs.readFile(filePath, "utf8", (err, data) => { + if (err) { + reject(err); + return; + } -export class FillMissingCellsTool - implements BaseTool<FillMissingCellsParameter> -{ - metadata: ToolMetadata<JSONSchemaType<FillMissingCellsParameter>>; - - constructor(params: FillMissingCellsParams = {}) { - this.metadata = params.metadata ?? FILL_CELLS_METADATA; - } - - async call(input: FillMissingCellsParameter): Promise<string> { - const { filePath, cells } = input; - - // Read the CSV file - const fileContent = await new Promise<string>((resolve, reject) => { - fs.readFile(filePath, "utf8", (err, data) => { - if (err) { - reject(err); - } else { - resolve(data); - } + const parsedData = Papa.parse<string[]>(data, { + skipEmptyLines: false, }); - }); - - // Parse CSV with PapaParse - const parseResult = Papa.parse<string[]>(fileContent, { - header: false, // Ensure the header is not treated as a separate object - skipEmptyLines: false, // Ensure empty lines are not skipped - }); - if (parseResult.errors.length) { - throw new Error( - "Failed to parse CSV file: " + parseResult.errors[0]?.message, - ); - } - - const rows = parseResult.data; - - // Fill the cells with answers - for (const cell of cells) { - // Adjust rowIndex to start from 1 for data rows - const adjustedRowIndex = cell.rowIndex + 1; - if ( - adjustedRowIndex < rows.length && - cell.columnIndex < (rows[adjustedRowIndex]?.length ?? 0) && - rows[adjustedRowIndex] - ) { - rows[adjustedRowIndex][cell.columnIndex] = cell.answer; + if (parsedData.errors.length) { + reject(parsedData.errors); + return; } - } - // Convert back to CSV format - const updatedContent = Papa.unparse(rows, { - delimiter: parseResult.meta.delimiter, + // Ensure all rows have the same number of columns as the header + const maxColumns = parsedData.data[0]?.length ?? 0; + const paddedRows = parsedData.data.map((row) => { + return [...row, ...Array(maxColumns - row.length).fill("")]; + }); + + resolve(paddedRows); }); + }); +} - // Use the helper function to write the file - const parsedPath = path.parse(filePath); - const newFileName = `${parsedPath.name}-filled${parsedPath.ext}`; - const newFilePath = path.join("output/tools", newFileName); +function formatToMarkdownTable(data: string[][]): string { + if (data.length === 0) return ""; - const newFileUrl = await saveDocument(newFilePath, updatedContent); + const maxColumns = data[0]?.length ?? 0; + const headerRow = `| ${data[0]?.join(" | ") ?? ""} |`; + const separatorRow = `| ${Array(maxColumns).fill("---").join(" | ")} |`; + const dataRows = data.slice(1).map((row) => `| ${row.join(" | ")} |`); - return ( - "Successfully filled missing cells in the CSV file. File URL to show to the user: " + - newFileUrl - ); - } + return [headerRow, separatorRow, ...dataRows].join("\n"); } diff --git a/packages/tools/src/tools/img-gen.ts b/packages/tools/src/tools/img-gen.ts index 2f7d32e35a5256ab8dea5efd1f79e95709ac710b..b46de53d338fde6d85017acc3db91459745fd3e7 100644 --- a/packages/tools/src/tools/img-gen.ts +++ b/packages/tools/src/tools/img-gen.ts @@ -1,11 +1,10 @@ import { tool } from "@llamaindex/core/tools"; import { FormData } from "formdata-node"; -import fs from "fs"; import got from "got"; -import crypto from "node:crypto"; -import path from "node:path"; +import path from "path"; import { Readable } from "stream"; import { z } from "zod"; +import { getFileUrl, saveDocument } from "../helper"; export type ImgGeneratorToolOutput = { isSuccess: boolean; @@ -14,83 +13,17 @@ export type ImgGeneratorToolOutput = { }; export type ImgGeneratorToolParams = { + /** Directory where generated images will be saved */ + outputDir: string; + /** STABILITY_API_KEY key is required to run image generator. Get it here: https://platform.stability.ai/account/keys */ + apiKey: string; + /** Output format of the generated image */ outputFormat?: string; - outputDir?: string; - apiKey?: string; - fileServerURLPrefix?: string; + /** Prefix for the file server URL */ + fileServerURLPrefix?: string | undefined; }; -// Constants -const IMG_OUTPUT_FORMAT = "webp"; -const IMG_OUTPUT_DIR = "output/tools"; -const IMG_GEN_API = - "https://api.stability.ai/v2beta/stable-image/generate/core"; - -// Helper functions -function checkRequiredEnvVars() { - if (!process.env.STABILITY_API_KEY) { - throw new Error( - "STABILITY_API_KEY key is required to run image generator. Get it here: https://platform.stability.ai/account/keys", - ); - } - if (!process.env.FILESERVER_URL_PREFIX) { - throw new Error( - "FILESERVER_URL_PREFIX is required to display file output after generation", - ); - } -} - -async function promptToImgBuffer( - prompt: string, - apiKey: string, -): Promise<Buffer> { - const form = new FormData(); - form.append("prompt", prompt); - form.append("output_format", IMG_OUTPUT_FORMAT); - - const buffer = await got - .post(IMG_GEN_API, { - body: form as unknown as Buffer | Readable | string, - headers: { - Authorization: `Bearer ${apiKey}`, - Accept: "image/*", - }, - }) - .buffer(); - - return buffer; -} - -function saveImage( - buffer: Buffer, - options: { - outputFormat?: string; - outputDir?: string; - fileServerURLPrefix?: string; - }, -): string { - const { - outputFormat = IMG_OUTPUT_FORMAT, - outputDir = IMG_OUTPUT_DIR, - fileServerURLPrefix = process.env.FILESERVER_URL_PREFIX, - } = options; - const filename = `${crypto.randomUUID()}.${outputFormat}`; - - // Create output directory if it doesn't exist - if (!fs.existsSync(outputDir)) { - fs.mkdirSync(outputDir, { recursive: true }); - } - - const outputPath = path.join(outputDir, filename); - fs.writeFileSync(outputPath, buffer); - - const url = `${fileServerURLPrefix}/${outputDir}/${filename}`; - console.log(`Saved image to ${outputPath}.\nURL: ${url}`); - - return url; -} - -export const imageGenerator = (params?: ImgGeneratorToolParams) => { +export const imageGenerator = (params: ImgGeneratorToolParams) => { return tool({ name: "image_generator", description: "Use this function to generate an image based on the prompt.", @@ -98,30 +31,15 @@ export const imageGenerator = (params?: ImgGeneratorToolParams) => { prompt: z.string().describe("The prompt to generate the image"), }), execute: async ({ prompt }): Promise<ImgGeneratorToolOutput> => { - const outputFormat = params?.outputFormat ?? IMG_OUTPUT_FORMAT; - const outputDir = params?.outputDir ?? IMG_OUTPUT_DIR; - const apiKey = params?.apiKey ?? process.env.STABILITY_API_KEY; - const fileServerURLPrefix = - params?.fileServerURLPrefix ?? process.env.FILESERVER_URL_PREFIX; - - if (!apiKey) { - throw new Error( - "STABILITY_API_KEY key is required to run image generator. Get it here: https://platform.stability.ai/account/keys", - ); - } - if (!fileServerURLPrefix) { - throw new Error( - "FILESERVER_URL_PREFIX is required to display file output after generation", - ); - } + const { outputDir, apiKey, fileServerURLPrefix } = params; + const outputFormat = params.outputFormat ?? "webp"; try { - const buffer = await promptToImgBuffer(prompt, apiKey); - const imageUrl = saveImage(buffer, { - outputFormat, - outputDir, - fileServerURLPrefix, - }); + const buffer = await promptToImgBuffer(prompt, apiKey, outputFormat); + const filename = `${crypto.randomUUID()}.${outputFormat}`; + const filePath = path.join(outputDir, filename); + await saveDocument(filePath, buffer); + const imageUrl = getFileUrl(filePath, { fileServerURLPrefix }); return { isSuccess: true, imageUrl }; } catch (error) { console.error(error); @@ -133,3 +51,26 @@ export const imageGenerator = (params?: ImgGeneratorToolParams) => { }, }); }; + +async function promptToImgBuffer( + prompt: string, + apiKey: string, + outputFormat: string, +): Promise<Buffer> { + const form = new FormData(); + form.append("prompt", prompt); + form.append("output_format", outputFormat); + + const apiUrl = "https://api.stability.ai/v2beta/stable-image/generate/core"; + const buffer = await got + .post(apiUrl, { + body: form as unknown as Buffer | Readable | string, + headers: { + Authorization: `Bearer ${apiKey}`, + Accept: "image/*", + }, + }) + .buffer(); + + return buffer; +} diff --git a/packages/tools/src/tools/interpreter.ts b/packages/tools/src/tools/interpreter.ts index 354683f1b3840e34107fc8b83f710ed8864bd96a..364cdf6f425dc65c9acab0126717b37b28c88f85 100644 --- a/packages/tools/src/tools/interpreter.ts +++ b/packages/tools/src/tools/interpreter.ts @@ -1,34 +1,11 @@ import { type Logs, Result, Sandbox } from "@e2b/code-interpreter"; -import type { JSONValue } from "@llamaindex/core/global"; -import { type BaseTool, type ToolMetadata } from "@llamaindex/core/llms"; -import type { JSONSchemaType } from "ajv"; +import { tool } from "@llamaindex/core/tools"; import fs from "fs"; -import crypto from "node:crypto"; import path from "node:path"; +import { z } from "zod"; +import { getFileUrl, saveDocument } from "../helper"; -export type InterpreterParameter = { - code: string; - sandboxFiles?: string[]; - retryCount?: number; -}; - -export type InterpreterToolParams = { - metadata?: ToolMetadata<JSONSchemaType<InterpreterParameter>>; - apiKey?: string | undefined; - fileServerURLPrefix?: string | undefined; - outputDir?: string | undefined; - uploadedFilesDir?: string | undefined; -}; - -export type InterpreterToolOutput = { - isError: boolean; - logs: Logs; - text?: string | undefined; - extraResult: InterpreterExtraResult[]; - retryCount?: number | undefined; -}; - -type InterpreterExtraType = +export type InterpreterExtraType = | "html" | "markdown" | "svg" @@ -46,212 +23,139 @@ export type InterpreterExtraResult = { url?: string; }; -const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<InterpreterParameter>> = { - name: "interpreter", - description: `Execute python code in a Jupyter notebook cell and return any result, stdout, stderr, display_data, and error. -If the code needs to use a file, ALWAYS pass the file path in the sandbox_files argument. -You have a maximum of 3 retries to get the code to run successfully. -`, - parameters: { - type: "object", - properties: { - code: { - type: "string", - description: "The python code to execute in a single cell.", - }, - sandboxFiles: { - type: "array", - description: - "List of local file paths to be used by the code. The tool will throw an error if a file is not found.", - items: { - type: "string", - }, - nullable: true, - }, - retryCount: { - type: "number", - description: "The number of times the tool has been retried", - default: 0, - nullable: true, - }, - }, - required: ["code"], - }, +export type InterpreterToolOutput = { + isError: boolean; + logs: Logs; + text?: string; + extraResult: InterpreterExtraResult[]; + retryCount?: number; }; -export class InterpreterTool implements BaseTool<InterpreterParameter> { - private outputDir: string; - private uploadedFilesDir: string; - private apiKey: string; - private fileServerURLPrefix: string; - metadata: ToolMetadata<JSONSchemaType<InterpreterParameter>>; - codeInterpreter?: Sandbox; - - constructor(params?: InterpreterToolParams) { - this.metadata = params?.metadata || DEFAULT_META_DATA; - this.apiKey = params?.apiKey || process.env.E2B_API_KEY!; - this.fileServerURLPrefix = - params?.fileServerURLPrefix || process.env.FILESERVER_URL_PREFIX!; - this.outputDir = params?.outputDir || "output/tools"; - this.uploadedFilesDir = params?.uploadedFilesDir || "output/uploaded"; - - if (!this.apiKey) { - throw new Error( - "E2B_API_KEY key is required to run code interpreter. Get it here: https://e2b.dev/docs/getting-started/api-key", - ); - } - if (!this.fileServerURLPrefix) { - throw new Error( - "FILESERVER_URL_PREFIX is required to display file output from sandbox", - ); - } - } - - public async initInterpreter(input: InterpreterParameter) { - if (!this.codeInterpreter) { - this.codeInterpreter = await Sandbox.create({ - apiKey: this.apiKey, - }); - // upload files to sandbox when it's initialized - if (input.sandboxFiles) { - console.log(`Uploading ${input.sandboxFiles.length} files to sandbox`); - try { - for (const filePath of input.sandboxFiles) { - const fileName = path.basename(filePath); - const localFilePath = path.join(this.uploadedFilesDir, fileName); - const content = fs.readFileSync(localFilePath); +export type InterpreterToolParams = { + /** E2B API key required for authentication. Get yours at https://e2b.dev/docs/legacy/getting-started/api-key */ + apiKey: string; + /** Directory where output files (charts, images, etc.) will be saved when code is executed */ + outputDir: string; + /** Local directory containing files that need to be uploaded to the sandbox environment before code execution */ + uploadedFilesDir: string; + /** Prefix for the file server URL */ + fileServerURLPrefix?: string; +}; - const arrayBuffer = new Uint8Array(content).buffer; - await this.codeInterpreter?.files.write(filePath, arrayBuffer); - } - } catch (error) { - console.error("Got error when uploading files to sandbox", error); - } +export const interpreter = (params: InterpreterToolParams) => { + const { apiKey, outputDir, uploadedFilesDir, fileServerURLPrefix } = params; + + return tool({ + name: "interpreter", + description: + "Execute python code in a Jupyter notebook cell and return any result, stdout, stderr, display_data, and error.", + parameters: z.object({ + code: z.string().describe("The python code to execute in a single cell"), + sandboxFiles: z + .array(z.string()) + .optional() + .describe("List of local file paths to be used by the code"), + retryCount: z + .number() + .default(0) + .optional() + .describe("The number of times the tool has been retried"), + }), + execute: async ({ code, sandboxFiles, retryCount = 0 }) => { + if (retryCount >= 3) { + return { + isError: true, + logs: { stdout: [], stderr: [] }, + text: "Max retries reached", + extraResult: [], + }; } - } - return this.codeInterpreter; - } - - public async codeInterpret( - input: InterpreterParameter, - ): Promise<InterpreterToolOutput> { - console.log( - `Sandbox files: ${input.sandboxFiles}. Retry count: ${input.retryCount}`, - ); + const interpreter = await Sandbox.create({ apiKey }); + await uploadFilesToSandbox(interpreter, uploadedFilesDir, sandboxFiles); + const exec = await interpreter.runCode(code); + const extraResult = await getExtraResult( + outputDir, + exec.results[0], + fileServerURLPrefix, + ); - if (input.retryCount && input.retryCount >= 3) { return { - isError: true, - logs: { - stdout: [], - stderr: [], - }, - text: "Max retries reached", - extraResult: [], - }; - } - - console.log( - `\n${"=".repeat(50)}\n> Running following AI-generated code:\n${input.code}\n${"=".repeat(50)}`, - ); - const interpreter = await this.initInterpreter(input); - const exec = await interpreter.runCode(input.code); - if (exec.error) console.error("[Code Interpreter error]", exec.error); - const extraResult = await this.getExtraResult(exec.results[0]); - const result: InterpreterToolOutput = { - isError: !!exec.error, - logs: exec.logs, - text: exec.text, - extraResult, - retryCount: input.retryCount ? input.retryCount + 1 : 1, - }; - return result; - } - - async call(input: InterpreterParameter) { - const result = await this.codeInterpret(input); - return result as JSONValue; - } + isError: !!exec.error, + logs: exec.logs, + text: exec.text, + extraResult, + retryCount: retryCount + 1, + } as InterpreterToolOutput; + }, + }); +}; - async close() { - await this.codeInterpreter?.kill(); +async function uploadFilesToSandbox( + codeInterpreter: Sandbox, + uploadedFilesDir: string, + sandboxFiles: string[] = [], +) { + try { + for (const filePath of sandboxFiles) { + const fileName = path.basename(filePath); + const localFilePath = path.join(uploadedFilesDir, fileName); + const content = fs.readFileSync(localFilePath); + const arrayBuffer = new Uint8Array(content).buffer; + await codeInterpreter.files.write(filePath, arrayBuffer); + } + } catch (error) { + console.error("Got error when uploading files to sandbox", error); } +} - private async getExtraResult( - res?: Result, - ): Promise<InterpreterExtraResult[]> { - if (!res) return []; - const output: InterpreterExtraResult[] = []; - - try { - const formats = res.formats(); // formats available for the result. Eg: ['png', ...] - const results = formats.map((f) => res[f as keyof Result]); // get base64 data for each format - - // save base64 data to file and return the url - for (let i = 0; i < formats.length; i++) { - const ext = formats[i]; - const data = results[i]; - switch (ext) { - case "png": - case "jpeg": - case "svg": - case "pdf": { - const { filename } = this.saveToDisk(data, ext); - output.push({ - type: ext as InterpreterExtraType, - filename, - url: this.getFileUrl(filename), - }); - break; - } - default: - output.push({ - type: ext as InterpreterExtraType, - content: data, - }); - break; +async function getExtraResult( + outputDir: string, + res?: Result, + fileServerURLPrefix?: string, +): Promise<InterpreterExtraResult[]> { + if (!res) return []; + const output: InterpreterExtraResult[] = []; + + try { + const formats = res.formats(); + const results = formats.map((f) => res[f as keyof Result]); + + for (let i = 0; i < formats.length; i++) { + const ext = formats[i]; + const data = results[i]; + switch (ext) { + case "png": + case "jpeg": + case "svg": + case "pdf": { + const { filename, filePath } = await saveToDisk(outputDir, data, ext); + const fileUrl = getFileUrl(filePath, { fileServerURLPrefix }); + output.push({ + type: ext as InterpreterExtraType, + filename, + url: fileUrl, + }); + break; } + default: + output.push({ + type: ext as InterpreterExtraType, + content: data, + }); + break; } - } catch (error) { - console.error("Error when parsing e2b response", error); } - - return output; - } - - // Consider saving to cloud storage instead but it may cost more for you - // See: https://e2b.dev/docs/sandbox/api/filesystem#write-to-file - private saveToDisk( - base64Data: string, - ext: string, - ): { - outputPath: string; - filename: string; - } { - const filename = `${crypto.randomUUID()}.${ext}`; // generate a unique filename - const buffer = Buffer.from(base64Data, "base64"); - const outputPath = this.getOutputPath(filename); - fs.writeFileSync(outputPath, buffer); - console.log(`Saved file to ${outputPath}`); - return { - outputPath, - filename, - }; - } - - private getOutputPath(filename: string): string { - // if outputDir doesn't exist, create it - if (!fs.existsSync(this.outputDir)) { - fs.mkdirSync(this.outputDir, { recursive: true }); - } - return path.join(this.outputDir, filename); - } - - private getFileUrl(filename: string): string { - return `${this.fileServerURLPrefix}/${this.outputDir}/${filename}`; + } catch (error) { + console.error("Error when parsing e2b response", error); } + return output; } -export const interpreter = (params?: InterpreterToolParams) => - new InterpreterTool(params); +async function saveToDisk(outputDir: string, base64Data: string, ext: string) { + const filename = `${crypto.randomUUID()}.${ext}`; + const buffer = Buffer.from(base64Data, "base64"); + const filePath = path.join(outputDir, filename); + await saveDocument(filePath, buffer); + return { filename, filePath }; +} diff --git a/packages/tools/src/tools/openapi-action.ts b/packages/tools/src/tools/openapi-action.ts index 38dfbf317faae356bdd514cc182c5080bd67bfc0..506094a62c1b644125359d2f6eae226f076d7f31 100644 --- a/packages/tools/src/tools/openapi-action.ts +++ b/packages/tools/src/tools/openapi-action.ts @@ -166,3 +166,12 @@ export class OpenAPIActionTool { return this.domainHeaders[domain] || {}; } } + +export const getOpenAPIActionTools = async (params: { + openapiUri: string; + domainHeaders: DomainHeaders; +}) => { + const { openapiUri, domainHeaders } = params; + const openAPIActionTool = new OpenAPIActionTool(openapiUri, domainHeaders); + return await openAPIActionTool.toToolFunctions(); +}; diff --git a/packages/tools/src/tools/weather.ts b/packages/tools/src/tools/weather.ts index 62be492d46414c2d04092d61d362a478d500a92c..1e07e6cd91020057891e9dcc2bdbd238cf23be6e 100644 --- a/packages/tools/src/tools/weather.ts +++ b/packages/tools/src/tools/weather.ts @@ -64,31 +64,59 @@ export const weather = () => { parameters: z.object({ location: z.string().describe("The location to get the weather"), }), - execute: async ({ - location, - }: { - location: string; - }): Promise<WeatherToolOutput> => { + execute: async ({ location }): Promise<WeatherToolOutput> => { return await getWeatherByLocation(location); }, }); }; -async function getWeatherByLocation(location: string) { +async function getWeatherByLocation( + location: string, +): Promise<WeatherToolOutput> { const { latitude, longitude } = await getGeoLocation(location); const timezone = Intl.DateTimeFormat().resolvedOptions().timeZone; - const apiUrl = `https://api.open-meteo.com/v1/forecast?latitude=${latitude}&longitude=${longitude}¤t=temperature_2m,weather_code&hourly=temperature_2m,weather_code&daily=weather_code&timezone=${timezone}`; + + const params = new URLSearchParams({ + latitude: latitude.toString(), + longitude: longitude.toString(), + current: "temperature_2m,weather_code", + hourly: "temperature_2m,weather_code", + daily: "weather_code", + timezone, + }); + + const apiUrl = `https://api.open-meteo.com/v1/forecast?${params}`; + const response = await fetch(apiUrl); - const data = (await response.json()) as WeatherToolOutput; - return data; + if (!response.ok) { + throw new Error(`Weather API request failed: ${response.statusText}`); + } + + return (await response.json()) as WeatherToolOutput; } async function getGeoLocation( location: string, ): Promise<{ latitude: number; longitude: number }> { - const apiUrl = `https://geocoding-api.open-meteo.com/v1/search?name=${location}&count=10&language=en&format=json`; + const params = new URLSearchParams({ + name: location, + count: "10", + language: "en", + format: "json", + }); + + const apiUrl = `https://geocoding-api.open-meteo.com/v1/search?${params}`; + const response = await fetch(apiUrl); + if (!response.ok) { + throw new Error(`Geocoding API request failed: ${response.statusText}`); + } + const data = await response.json(); + if (!data.results?.length) { + throw new Error(`No location found for: ${location}`); + } + const { latitude, longitude } = data.results[0]; return { latitude, longitude }; } diff --git a/packages/tools/src/tools/wiki.ts b/packages/tools/src/tools/wiki.ts index b07b7283884543f7efa6e718c10456d3b8638fb2..e5eff4718e2d654d5cde1d0ec44ab54c329ac551 100644 --- a/packages/tools/src/tools/wiki.ts +++ b/packages/tools/src/tools/wiki.ts @@ -18,12 +18,10 @@ export const wiki = () => { execute: async ({ query, lang }): Promise<WikiToolOutput> => { wikipedia.setLang(lang); const searchResult = await wikipedia.search(query); - const pageTitle = searchResult.results[0].title; + const pageTitle = searchResult?.results[0]?.title; if (!pageTitle) return { title: "No search results.", content: "" }; - const pageResult = await wikipedia.page(pageTitle, { - autoSuggest: false, - }); - return { title: pageTitle, content: await pageResult.content() }; + const result = await wikipedia.page(pageTitle, { autoSuggest: false }); + return { title: pageTitle, content: await result.content() }; }, }); }; diff --git a/packages/tools/tests/code-generator.test.ts b/packages/tools/tests/code-generator.test.ts new file mode 100644 index 0000000000000000000000000000000000000000..80fe5f52e6b7b2f2655caceb5d7dbbd20a5ca4a3 --- /dev/null +++ b/packages/tools/tests/code-generator.test.ts @@ -0,0 +1,37 @@ +import { Settings } from "@llamaindex/core/global"; +import { MockLLM } from "@llamaindex/core/utils"; +import { describe, expect, test } from "vitest"; +import { + codeGenerator, + type CodeGeneratorToolOutput, +} from "../src/tools/code-generator"; + +Settings.llm = new MockLLM({ + responseMessage: `{ + "commentary": "Creating a simple Next.js page with a hello world message", + "template": "nextjs-developer", + "title": "Hello World App", + "description": "A simple Next.js hello world application", + "additional_dependencies": [], + "has_additional_dependencies": false, + "install_dependencies_command": "", + "port": 3000, + "file_path": "pages/index.tsx", + "code": "export default function Home() { return <h1>Hello World</h1> }" + }`, +}); + +describe("Code Generator Tool", () => { + test("generates Next.js application code", async () => { + const generator = codeGenerator(); + const result = (await generator.call({ + requirement: "Create a simple Next.js hello world page", + })) as CodeGeneratorToolOutput; + + expect(result.isError).toBe(false); + expect(result.artifact).toBeDefined(); + expect(result.artifact?.template).toBe("nextjs-developer"); + expect(result.artifact?.file_path).toBe("pages/index.tsx"); + expect(result.artifact?.code).toContain("export default"); + }); +}); diff --git a/packages/tools/tests/document-generator.test.ts b/packages/tools/tests/document-generator.test.ts new file mode 100644 index 0000000000000000000000000000000000000000..94b5e0e247b8e65ab848b699a439bd73c06847e1 --- /dev/null +++ b/packages/tools/tests/document-generator.test.ts @@ -0,0 +1,24 @@ +import path from "path"; +import { describe, expect, test, vi } from "vitest"; +import { documentGenerator } from "../src/tools/document-generator"; + +// Mock the helper functions +vi.mock("../src/helper", () => ({ + saveDocument: vi.fn().mockResolvedValue(undefined), + getFileUrl: vi.fn().mockReturnValue("http://example.com/test-doc.html"), +})); + +describe("Document Generator Tool", () => { + test("converts markdown to html document", async () => { + const docGen = documentGenerator({ + outputDir: path.join(__dirname, "output"), + }); + + const result = await docGen.call({ + originalContent: "# Hello World\nThis is a test", + fileName: "test-doc", + }); + + expect(result).toBe("URL: http://example.com/test-doc.html"); + }); +}); diff --git a/packages/tools/tests/duckduckgo.test.ts b/packages/tools/tests/duckduckgo.test.ts new file mode 100644 index 0000000000000000000000000000000000000000..14e3c67a4d4179ff8f76d61c6f1ea8325a3153af --- /dev/null +++ b/packages/tools/tests/duckduckgo.test.ts @@ -0,0 +1,19 @@ +import { describe, expect, test } from "vitest"; +import { duckduckgo, type DuckDuckGoToolOutput } from "../src/tools/duckduckgo"; + +describe("DuckDuckGo Tool", () => { + test("performs search and returns results", async () => { + const searchTool = duckduckgo(); + const results = (await searchTool.call({ + query: "OpenAI ChatGPT", + maxResults: 3, + })) as DuckDuckGoToolOutput; + + expect(Array.isArray(results)).toBe(true); + expect(results.length).toBeLessThanOrEqual(3); + const firstResult = results[0]; + expect(firstResult).toHaveProperty("title"); + expect(firstResult).toHaveProperty("description"); + expect(firstResult).toHaveProperty("url"); + }); +}); diff --git a/packages/tools/tests/form-filling.test.ts b/packages/tools/tests/form-filling.test.ts new file mode 100644 index 0000000000000000000000000000000000000000..c6dfcf986809ddaa8bf3486cd041f066e699e1ae --- /dev/null +++ b/packages/tools/tests/form-filling.test.ts @@ -0,0 +1,44 @@ +import fs from "fs"; +import path from "path"; +import { describe, expect, test, vi } from "vitest"; +import { fillMissingCells } from "../src/tools/form-filling"; + +vi.mock("fs", () => ({ + default: { + readFile: vi.fn(), + promises: { + readFile: vi.fn(), + }, + }, +})); + +vi.mock("../src/helper", () => ({ + saveDocument: vi.fn(), + getFileUrl: vi.fn().mockReturnValue("http://example.com/filled.csv"), +})); + +describe("Form Filling Tools", () => { + test("fillMissingCells fills cells with provided answers", async () => { + // Mock CSV content + const mockCsvContent = "Name,Age,City\nJohn,,Paris\nMary,,"; + vi.mocked(fs.promises.readFile).mockResolvedValue(mockCsvContent); + + const filler = fillMissingCells({ + outputDir: path.join(__dirname, "output"), + }); + + const result = await filler.call({ + filePath: "test.csv", + cells: [ + { + rowIndex: 1, + columnIndex: 1, + answer: "25", + }, + ], + }); + + expect(result).toContain("Successfully filled missing cells"); + expect(result).toContain("http://example.com/filled.csv"); + }); +}); diff --git a/packages/tools/tests/img-gen.test.ts b/packages/tools/tests/img-gen.test.ts new file mode 100644 index 0000000000000000000000000000000000000000..0a74743b30367f9be7ec5bd2fc7333b2fb903b83 --- /dev/null +++ b/packages/tools/tests/img-gen.test.ts @@ -0,0 +1,31 @@ +import { describe, expect, test, vi } from "vitest"; +import { + imageGenerator, + type ImgGeneratorToolOutput, +} from "../src/tools/img-gen"; + +vi.mock("got", () => ({ + default: { + post: vi.fn().mockReturnValue({ + buffer: vi.fn().mockResolvedValue(Buffer.from("mock-image-data")), + }), + }, +})); + +describe("Image Generator Tool", () => { + test("generates image from prompt", async () => { + const imgTool = imageGenerator({ + apiKey: "mock-stability-key", + outputDir: "output", + fileServerURLPrefix: "http://localhost:3000", + }); + + const result = (await imgTool.call({ + prompt: "a cute cat playing with yarn", + })) as ImgGeneratorToolOutput; + + expect(result.isSuccess).toBe(true); + expect(result.imageUrl).toBeDefined(); + expect(result.errorMessage).toBeUndefined(); + }); +}); diff --git a/packages/tools/tests/interpreter.test.ts b/packages/tools/tests/interpreter.test.ts new file mode 100644 index 0000000000000000000000000000000000000000..7fbe4fd266cb778256063616e505b6e55b52dac5 --- /dev/null +++ b/packages/tools/tests/interpreter.test.ts @@ -0,0 +1,50 @@ +import { Sandbox } from "@e2b/code-interpreter"; +import path from "path"; +import { describe, expect, test, vi } from "vitest"; +import { + interpreter, + type InterpreterToolOutput, +} from "../src/tools/interpreter"; + +vi.mock("@e2b/code-interpreter", () => ({ + Sandbox: { + create: vi.fn().mockImplementation(() => ({ + runCode: vi.fn().mockResolvedValue({ + error: null, + logs: { + stdout: ["Hello, World!", "x = 2"], + stderr: [], + }, + text: "Hello, World!\nx = 2", + results: [ + { + formats: () => [], + }, + ], + }), + files: { + write: vi.fn().mockResolvedValue(undefined), + }, + })), + }, +})); + +describe("Code Interpreter Tool", () => { + test("executes simple python code", async () => { + const interpreterTool = interpreter({ + apiKey: "mock-api-key", + outputDir: path.join(__dirname, "output"), + uploadedFilesDir: path.join(__dirname, "files"), + }); + + const result = (await interpreterTool.call({ + code: "print('Hello, World!')\nx = 1 + 1\nprint(f'x = {x}')", + })) as InterpreterToolOutput; + + expect(Sandbox.create).toHaveBeenCalledWith({ apiKey: "mock-api-key" }); + expect(result.isError).toBe(false); + expect(result.logs.stdout).toEqual(["Hello, World!", "x = 2"]); + expect(result.retryCount).toBe(1); + expect(result.extraResult).toEqual([]); + }); +}); diff --git a/packages/tools/tests/openapi-action.test.ts b/packages/tools/tests/openapi-action.test.ts new file mode 100644 index 0000000000000000000000000000000000000000..5f2d2a86b75072654b8882d38ea1360754326485 --- /dev/null +++ b/packages/tools/tests/openapi-action.test.ts @@ -0,0 +1,70 @@ +import SwaggerParser from "@apidevtools/swagger-parser"; +import got from "got"; +import { describe, expect, test, vi } from "vitest"; +import { OpenAPIActionTool } from "../src/tools/openapi-action"; + +// Mock SwaggerParser and got +vi.mock("@apidevtools/swagger-parser", () => ({ + default: { + validate: vi.fn(), + }, +})); + +vi.mock("got", () => ({ + default: { + get: vi.fn(), + post: vi.fn(), + patch: vi.fn(), + }, +})); + +describe("OpenAPI Action Tool", () => { + test("loads and executes API requests", async () => { + // Mock swagger spec + vi.mocked(SwaggerParser.validate).mockResolvedValue({ + openapi: "3.0.0", + info: { + title: "Test API", + description: "Test API Description", + version: "1.0.0", + }, + paths: { + "/test": { + get: { + description: "Test endpoint", + responses: { + "200": { + description: "Successful response", + }, + }, + }, + }, + }, + }); + + // Mock API response + vi.mocked(got.get).mockReturnValue({ + json: () => Promise.resolve({ data: "test response" }), + } as ReturnType<typeof got.get>); + + const tool = new OpenAPIActionTool("https://api.test.com/openapi.json"); + const tools = await tool.toToolFunctions(); + + // Verify tools were created + expect(tools).toHaveLength(4); // load_spec, get, post, patch + + // Test GET request + const result = await tool.getRequest({ + url: "https://api.test.com/test", + params: { key: "value" }, + }); + + expect(got.get).toHaveBeenCalledWith( + "https://api.test.com/test", + expect.objectContaining({ + searchParams: { key: "value" }, + }), + ); + expect(result).toEqual({ data: "test response" }); + }); +}); diff --git a/packages/tools/tests/weather.test.ts b/packages/tools/tests/weather.test.ts new file mode 100644 index 0000000000000000000000000000000000000000..c57691ffd27c6031ad4d14a0c5c6d7d4d265ca91 --- /dev/null +++ b/packages/tools/tests/weather.test.ts @@ -0,0 +1,15 @@ +import { describe, expect, test } from "vitest"; +import { weather } from "../src/tools/weather"; + +describe("Weather Tool", () => { + test("weather tool returns data for valid location", async () => { + const weatherTool = weather(); + const result = await weatherTool.call({ + location: "London", + }); + + expect(result).toHaveProperty("current"); + expect(result).toHaveProperty("hourly"); + expect(result).toHaveProperty("daily"); + }); +}); diff --git a/packages/tools/tests/wiki.test.ts b/packages/tools/tests/wiki.test.ts new file mode 100644 index 0000000000000000000000000000000000000000..0ac22913680432b6014d0e166833ec994c929441 --- /dev/null +++ b/packages/tools/tests/wiki.test.ts @@ -0,0 +1,15 @@ +import { describe, expect, test } from "vitest"; +import { wiki } from "../src/tools/wiki"; + +describe("Wikipedia Tool", () => { + test("wiki tool returns content for valid query", async () => { + const wikipediaTool = wiki(); + const result = await wikipediaTool.call({ + query: "Albert Einstein", + lang: "en", + }); + + expect(result).toHaveProperty("title"); + expect(result).toHaveProperty("content"); + }); +}); diff --git a/packages/tools/tsconfig.json b/packages/tools/tsconfig.json index a93775d954ab510bbae6d3aaabe2c7204f557e2b..e27177e4846105c0c17eee66119ed8a35206ff82 100644 --- a/packages/tools/tsconfig.json +++ b/packages/tools/tsconfig.json @@ -10,6 +10,6 @@ "strict": true, "types": ["node"] }, - "include": ["./src"], + "include": ["./src", "tests"], "exclude": ["node_modules"] }