Skip to content
Snippets Groups Projects
Unverified Commit 04ddebcd authored by Huu Le's avatar Huu Le Committed by GitHub
Browse files

feat: Add publisher agent, merge code with streaming template (#324)


---------
Co-authored-by: default avatarMarcus Schiesser <mail@marcusschiesser.de>
parent 3e8057a8
No related branches found
No related tags found
No related merge requests found
Showing
with 570 additions and 65 deletions
---
"create-llama": patch
---
Add publisher agent to multi-agents for generating documents (PDF and HTML)
---
"create-llama": patch
---
Allow tool selection for multi-agents (Python and TS)
...@@ -33,6 +33,7 @@ if ( ...@@ -33,6 +33,7 @@ if (
const toolOptions = [ const toolOptions = [
"wikipedia.WikipediaToolSpec", "wikipedia.WikipediaToolSpec",
"google.GoogleSearchToolSpec", "google.GoogleSearchToolSpec",
"document_generator",
]; ];
const dataSources = [ const dataSources = [
......
...@@ -109,7 +109,9 @@ export async function runCreateLlama({ ...@@ -109,7 +109,9 @@ export async function runCreateLlama({
if (appType) { if (appType) {
commandArgs.push(appType); commandArgs.push(appType);
} }
if (!useLlamaParse) { if (useLlamaParse) {
commandArgs.push("--use-llama-parse");
} else {
commandArgs.push("--no-llama-parse"); commandArgs.push("--no-llama-parse");
} }
......
...@@ -426,34 +426,35 @@ const getToolEnvs = (tools?: Tool[]): EnvVar[] => { ...@@ -426,34 +426,35 @@ const getToolEnvs = (tools?: Tool[]): EnvVar[] => {
const getSystemPromptEnv = ( const getSystemPromptEnv = (
tools?: Tool[], tools?: Tool[],
dataSources?: TemplateDataSource[], dataSources?: TemplateDataSource[],
framework?: TemplateFramework, template?: TemplateType,
): EnvVar[] => { ): EnvVar[] => {
const defaultSystemPrompt = const defaultSystemPrompt =
"You are a helpful assistant who helps users with their questions."; "You are a helpful assistant who helps users with their questions.";
const systemPromptEnv: EnvVar[] = [];
// build tool system prompt by merging all tool system prompts // build tool system prompt by merging all tool system prompts
let toolSystemPrompt = ""; // multiagent template doesn't need system prompt
tools?.forEach((tool) => { if (template !== "multiagent") {
const toolSystemPromptEnv = tool.envVars?.find( let toolSystemPrompt = "";
(env) => env.name === TOOL_SYSTEM_PROMPT_ENV_VAR, tools?.forEach((tool) => {
); const toolSystemPromptEnv = tool.envVars?.find(
if (toolSystemPromptEnv) { (env) => env.name === TOOL_SYSTEM_PROMPT_ENV_VAR,
toolSystemPrompt += toolSystemPromptEnv.value + "\n"; );
} if (toolSystemPromptEnv) {
}); toolSystemPrompt += toolSystemPromptEnv.value + "\n";
}
});
const systemPrompt = toolSystemPrompt const systemPrompt = toolSystemPrompt
? `\"${toolSystemPrompt}\"` ? `\"${toolSystemPrompt}\"`
: defaultSystemPrompt; : defaultSystemPrompt;
const systemPromptEnv = [ systemPromptEnv.push({
{
name: "SYSTEM_PROMPT", name: "SYSTEM_PROMPT",
description: "The system prompt for the AI model.", description: "The system prompt for the AI model.",
value: systemPrompt, value: systemPrompt,
}, });
]; }
if (tools?.length == 0 && (dataSources?.length ?? 0 > 0)) { if (tools?.length == 0 && (dataSources?.length ?? 0 > 0)) {
const citationPrompt = `'You have provided information from a knowledge base that has been passed to you in nodes of information. const citationPrompt = `'You have provided information from a knowledge base that has been passed to you in nodes of information.
Each node has useful metadata such as node ID, file name, page, etc. Each node has useful metadata such as node ID, file name, page, etc.
...@@ -559,7 +560,7 @@ export const createBackendEnvFile = async ( ...@@ -559,7 +560,7 @@ export const createBackendEnvFile = async (
...getToolEnvs(opts.tools), ...getToolEnvs(opts.tools),
...getTemplateEnvs(opts.template), ...getTemplateEnvs(opts.template),
...getObservabilityEnvs(opts.observability), ...getObservabilityEnvs(opts.observability),
...getSystemPromptEnv(opts.tools, opts.dataSources, opts.framework), ...getSystemPromptEnv(opts.tools, opts.dataSources, opts.template),
]; ];
// Render and write env file // Render and write env file
const content = renderEnvVar(envVars); const content = renderEnvVar(envVars);
......
...@@ -364,7 +364,12 @@ export const installPythonTemplate = async ({ ...@@ -364,7 +364,12 @@ export const installPythonTemplate = async ({
| "modelConfig" | "modelConfig"
>) => { >) => {
console.log("\nInitializing Python project with template:", template, "\n"); console.log("\nInitializing Python project with template:", template, "\n");
const templatePath = path.join(templatesDir, "types", template, framework); let templatePath;
if (template === "extractor") {
templatePath = path.join(templatesDir, "types", "extractor", framework);
} else {
templatePath = path.join(templatesDir, "types", "streaming", framework);
}
await copy("**", root, { await copy("**", root, {
parents: true, parents: true,
cwd: templatePath, cwd: templatePath,
...@@ -401,23 +406,42 @@ export const installPythonTemplate = async ({ ...@@ -401,23 +406,42 @@ export const installPythonTemplate = async ({
cwd: path.join(compPath, "services", "python"), cwd: path.join(compPath, "services", "python"),
}); });
} }
// Copy engine code
if (template === "streaming") { if (template === "streaming" || template === "multiagent") {
// For the streaming template only:
// Select and copy engine code based on data sources and tools // Select and copy engine code based on data sources and tools
let engine; let engine;
if (dataSources.length > 0 && (!tools || tools.length === 0)) { // Multiagent always uses agent engine
console.log("\nNo tools selected - use optimized context chat engine\n"); if (template === "multiagent") {
engine = "chat";
} else {
engine = "agent"; engine = "agent";
} else {
// For streaming, use chat engine by default
// Unless tools are selected, in which case use agent engine
if (dataSources.length > 0 && (!tools || tools.length === 0)) {
console.log(
"\nNo tools selected - use optimized context chat engine\n",
);
engine = "chat";
} else {
engine = "agent";
}
} }
// Copy engine code
await copy("**", enginePath, { await copy("**", enginePath, {
parents: true, parents: true,
cwd: path.join(compPath, "engines", "python", engine), cwd: path.join(compPath, "engines", "python", engine),
}); });
} }
if (template === "multiagent") {
// Copy multi-agent code
await copy("**", path.join(root), {
parents: true,
cwd: path.join(compPath, "multiagent", "python"),
rename: assetRelocator,
});
}
console.log("Adding additional dependencies"); console.log("Adding additional dependencies");
const addOnDependencies = getAdditionalDependencies( const addOnDependencies = getAdditionalDependencies(
......
...@@ -110,6 +110,29 @@ For better results, you can specify the region parameter to get results from a s ...@@ -110,6 +110,29 @@ For better results, you can specify the region parameter to get results from a s
}, },
], ],
}, },
{
display: "Document generator",
name: "document_generator",
supportedFrameworks: ["fastapi", "nextjs", "express"],
dependencies: [
{
name: "xhtml2pdf",
version: "^0.2.14",
},
{
name: "markdown",
version: "^3.7",
},
],
type: ToolType.LOCAL,
envVars: [
{
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
description: "System prompt for document generator tool.",
value: `If user request for a report or a post, use document generator tool to create a file and reply with the link to the file.`,
},
],
},
{ {
display: "Code Interpreter", display: "Code Interpreter",
name: "interpreter", name: "interpreter",
......
...@@ -157,7 +157,10 @@ export const installTSTemplate = async ({ ...@@ -157,7 +157,10 @@ export const installTSTemplate = async ({
// Select and copy engine code based on data sources and tools // Select and copy engine code based on data sources and tools
let engine; let engine;
tools = tools ?? []; tools = tools ?? [];
if (dataSources.length > 0 && tools.length === 0) { // multiagent template always uses agent engine
if (template === "multiagent") {
engine = "agent";
} else if (dataSources.length > 0 && tools.length === 0) {
console.log("\nNo tools selected - use optimized context chat engine\n"); console.log("\nNo tools selected - use optimized context chat engine\n");
engine = "chat"; engine = "chat";
} else { } else {
......
...@@ -141,12 +141,10 @@ export const getDataSourceChoices = ( ...@@ -141,12 +141,10 @@ export const getDataSourceChoices = (
}); });
} }
if (selectedDataSource === undefined || selectedDataSource.length === 0) { if (selectedDataSource === undefined || selectedDataSource.length === 0) {
if (template !== "multiagent") { choices.push({
choices.push({ title: "No datasource",
title: "No datasource", value: "none",
value: "none", });
});
}
choices.push({ choices.push({
title: title:
process.platform !== "linux" process.platform !== "linux"
...@@ -734,8 +732,10 @@ export const askQuestions = async ( ...@@ -734,8 +732,10 @@ export const askQuestions = async (
} }
} }
if (!program.tools && program.template === "streaming") { if (
// TODO: allow to select tools also for multi-agent framework !program.tools &&
(program.template === "streaming" || program.template === "multiagent")
) {
if (ciInfo.isCI) { if (ciInfo.isCI) {
program.tools = getPrefOrDefault("tools"); program.tools = getPrefOrDefault("tools");
} else { } else {
......
...@@ -8,7 +8,7 @@ from llama_index.core.settings import Settings ...@@ -8,7 +8,7 @@ from llama_index.core.settings import Settings
from llama_index.core.tools.query_engine import QueryEngineTool from llama_index.core.tools.query_engine import QueryEngineTool
def get_chat_engine(filters=None, params=None, event_handlers=None): def get_chat_engine(filters=None, params=None, event_handlers=None, **kwargs):
system_prompt = os.getenv("SYSTEM_PROMPT") system_prompt = os.getenv("SYSTEM_PROMPT")
top_k = int(os.getenv("TOP_K", 0)) top_k = int(os.getenv("TOP_K", 0))
tools = [] tools = []
......
import importlib
import os import os
import yaml import yaml
import importlib
from llama_index.core.tools.tool_spec.base import BaseToolSpec
from llama_index.core.tools.function_tool import FunctionTool from llama_index.core.tools.function_tool import FunctionTool
from llama_index.core.tools.tool_spec.base import BaseToolSpec
class ToolType: class ToolType:
...@@ -40,14 +41,26 @@ class ToolFactory: ...@@ -40,14 +41,26 @@ class ToolFactory:
raise ValueError(f"Failed to load tool {tool_name}: {e}") raise ValueError(f"Failed to load tool {tool_name}: {e}")
@staticmethod @staticmethod
def from_env() -> list[FunctionTool]: def from_env(
tools = [] map_result: bool = False,
) -> list[FunctionTool] | dict[str, FunctionTool]:
"""
Load tools from the configured file.
Params:
- use_map: if True, return map of tool name and the tool itself
"""
if map_result:
tools = {}
else:
tools = []
if os.path.exists("config/tools.yaml"): if os.path.exists("config/tools.yaml"):
with open("config/tools.yaml", "r") as f: with open("config/tools.yaml", "r") as f:
tool_configs = yaml.safe_load(f) tool_configs = yaml.safe_load(f)
for tool_type, config_entries in tool_configs.items(): for tool_type, config_entries in tool_configs.items():
for tool_name, config in config_entries.items(): for tool_name, config in config_entries.items():
tools.extend( tool = ToolFactory.load_tools(tool_type, tool_name, config)
ToolFactory.load_tools(tool_type, tool_name, config) if map_result:
) tools[tool_name] = tool
else:
tools.extend(tool)
return tools return tools
import logging
import os
import re
from enum import Enum
from io import BytesIO
from llama_index.core.tools.function_tool import FunctionTool
OUTPUT_DIR = "output/tools"
class DocumentType(Enum):
PDF = "pdf"
HTML = "html"
COMMON_STYLES = """
body {
font-family: Arial, sans-serif;
line-height: 1.3;
color: #333;
}
h1, h2, h3, h4, h5, h6 {
margin-top: 1em;
margin-bottom: 0.5em;
}
p {
margin-bottom: 0.7em;
}
code {
background-color: #f4f4f4;
padding: 2px 4px;
border-radius: 4px;
}
pre {
background-color: #f4f4f4;
padding: 10px;
border-radius: 4px;
overflow-x: auto;
}
table {
border-collapse: collapse;
width: 100%;
margin-bottom: 1em;
}
th, td {
border: 1px solid #ddd;
padding: 8px;
text-align: left;
}
th {
background-color: #f2f2f2;
font-weight: bold;
}
"""
HTML_SPECIFIC_STYLES = """
body {
max-width: 800px;
margin: 0 auto;
padding: 20px;
}
"""
PDF_SPECIFIC_STYLES = """
@page {
size: letter;
margin: 2cm;
}
body {
font-size: 11pt;
}
h1 { font-size: 18pt; }
h2 { font-size: 16pt; }
h3 { font-size: 14pt; }
h4, h5, h6 { font-size: 12pt; }
pre, code {
font-family: Courier, monospace;
font-size: 0.9em;
}
"""
HTML_TEMPLATE = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
{common_styles}
{specific_styles}
</style>
</head>
<body>
{content}
</body>
</html>
"""
class DocumentGenerator:
@classmethod
def _generate_html_content(cls, original_content: str) -> str:
"""
Generate HTML content from the original markdown content.
"""
try:
import markdown
except ImportError:
raise ImportError(
"Failed to import required modules. Please install markdown."
)
# Convert markdown to HTML with fenced code and table extensions
html_content = markdown.markdown(
original_content, extensions=["fenced_code", "tables"]
)
return html_content
@classmethod
def _generate_pdf(cls, html_content: str) -> BytesIO:
"""
Generate a PDF from the HTML content.
"""
try:
from xhtml2pdf import pisa
except ImportError:
raise ImportError(
"Failed to import required modules. Please install xhtml2pdf."
)
pdf_html = HTML_TEMPLATE.format(
common_styles=COMMON_STYLES,
specific_styles=PDF_SPECIFIC_STYLES,
content=html_content,
)
buffer = BytesIO()
pdf = pisa.pisaDocument(
BytesIO(pdf_html.encode("UTF-8")), buffer, encoding="UTF-8"
)
if pdf.err:
logging.error(f"PDF generation failed: {pdf.err}")
raise ValueError("PDF generation failed")
buffer.seek(0)
return buffer
@classmethod
def _generate_html(cls, html_content: str) -> str:
"""
Generate a complete HTML document with the given HTML content.
"""
return HTML_TEMPLATE.format(
common_styles=COMMON_STYLES,
specific_styles=HTML_SPECIFIC_STYLES,
content=html_content,
)
@classmethod
def generate_document(
cls, original_content: str, document_type: str, file_name: str
) -> str:
"""
To generate document as PDF or HTML file.
Parameters:
original_content: str (markdown style)
document_type: str (pdf or html) specify the type of the file format based on the use case
file_name: str (name of the document file) must be a valid file name, no extensions needed
Returns:
str (URL to the document file): A file URL ready to serve.
"""
try:
document_type = DocumentType(document_type.lower())
except ValueError:
raise ValueError(
f"Invalid document type: {document_type}. Must be 'pdf' or 'html'."
)
# Always generate html content first
html_content = cls._generate_html_content(original_content)
# Based on the type of document, generate the corresponding file
if document_type == DocumentType.PDF:
content = cls._generate_pdf(html_content)
file_extension = "pdf"
elif document_type == DocumentType.HTML:
content = BytesIO(cls._generate_html(html_content).encode("utf-8"))
file_extension = "html"
else:
raise ValueError(f"Unexpected document type: {document_type}")
file_name = cls._validate_file_name(file_name)
file_path = os.path.join(OUTPUT_DIR, f"{file_name}.{file_extension}")
cls._write_to_file(content, file_path)
file_url = f"{os.getenv('FILESERVER_URL_PREFIX')}/{file_path}"
return file_url
@staticmethod
def _write_to_file(content: BytesIO, file_path: str):
"""
Write the content to a file.
"""
try:
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "wb") as file:
file.write(content.getvalue())
except Exception as e:
raise e
@staticmethod
def _validate_file_name(file_name: str) -> str:
"""
Validate the file name.
"""
# Don't allow directory traversal
if os.path.isabs(file_name):
raise ValueError("File name is not allowed.")
# Don't allow special characters
if re.match(r"^[a-zA-Z0-9_.-]+$", file_name):
return file_name
else:
raise ValueError("File name is not allowed to contain special characters.")
def get_tools(**kwargs):
return [FunctionTool.from_defaults(DocumentGenerator.generate_document)]
...@@ -32,5 +32,37 @@ def duckduckgo_search( ...@@ -32,5 +32,37 @@ def duckduckgo_search(
return results return results
def duckduckgo_image_search(
query: str,
region: str = "wt-wt",
max_results: int = 10,
):
"""
Use this function to search for images in DuckDuckGo.
Args:
query (str): The query to search in DuckDuckGo.
region Optional(str): The region to be used for the search in [country-language] convention, ex us-en, uk-en, ru-ru, etc...
max_results Optional(int): The maximum number of results to be returned. Default is 10.
"""
try:
from duckduckgo_search import DDGS
except ImportError:
raise ImportError(
"duckduckgo_search package is required to use this function."
"Please install it by running: `poetry add duckduckgo_search` or `pip install duckduckgo_search`"
)
params = {
"keywords": query,
"region": region,
"max_results": max_results,
}
with DDGS() as ddg:
results = list(ddg.images(**params))
return results
def get_tools(**kwargs): def get_tools(**kwargs):
return [FunctionTool.from_defaults(duckduckgo_search)] return [
FunctionTool.from_defaults(duckduckgo_search),
FunctionTool.from_defaults(duckduckgo_image_search),
]
import logging
import os import os
import uuid import uuid
import logging
import requests
from typing import Optional from typing import Optional
from pydantic import BaseModel, Field
import requests
from llama_index.core.tools import FunctionTool from llama_index.core.tools import FunctionTool
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -26,7 +27,7 @@ class ImageGeneratorToolOutput(BaseModel): ...@@ -26,7 +27,7 @@ class ImageGeneratorToolOutput(BaseModel):
class ImageGeneratorTool: class ImageGeneratorTool:
_IMG_OUTPUT_FORMAT = "webp" _IMG_OUTPUT_FORMAT = "webp"
_IMG_OUTPUT_DIR = "output/tool" _IMG_OUTPUT_DIR = "output/tools"
_IMG_GEN_API = "https://api.stability.ai/v2beta/stable-image/generate/core" _IMG_GEN_API = "https://api.stability.ai/v2beta/stable-image/generate/core"
def __init__(self, api_key: str = None): def __init__(self, api_key: str = None):
......
import os
import logging
import base64 import base64
import logging
import os
import uuid import uuid
from pydantic import BaseModel from typing import Dict, List, Optional
from typing import List, Dict, Optional
from llama_index.core.tools import FunctionTool
from e2b_code_interpreter import CodeInterpreter from e2b_code_interpreter import CodeInterpreter
from e2b_code_interpreter.models import Logs from e2b_code_interpreter.models import Logs
from llama_index.core.tools import FunctionTool
from pydantic import BaseModel
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -26,7 +26,7 @@ class E2BToolOutput(BaseModel): ...@@ -26,7 +26,7 @@ class E2BToolOutput(BaseModel):
class E2BCodeInterpreter: class E2BCodeInterpreter:
output_dir = "output/tool" output_dir = "output/tools"
def __init__(self, api_key: str = None): def __init__(self, api_key: str = None):
if api_key is None: if api_key is None:
......
...@@ -9,7 +9,7 @@ from llama_index.core.memory import ChatMemoryBuffer ...@@ -9,7 +9,7 @@ from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.settings import Settings from llama_index.core.settings import Settings
def get_chat_engine(filters=None, params=None, event_handlers=None): def get_chat_engine(filters=None, params=None, event_handlers=None, **kwargs):
system_prompt = os.getenv("SYSTEM_PROMPT") system_prompt = os.getenv("SYSTEM_PROMPT")
citation_prompt = os.getenv("SYSTEM_CITATION_PROMPT", None) citation_prompt = os.getenv("SYSTEM_CITATION_PROMPT", None)
top_k = int(os.getenv("TOP_K", 0)) top_k = int(os.getenv("TOP_K", 0))
......
import { JSONSchemaType } from "ajv";
import { BaseTool, ToolMetadata } from "llamaindex";
import { marked } from "marked";
import path from "node:path";
import { saveDocument } from "../../llamaindex/documents/helper";
const OUTPUT_DIR = "output/tools";
type DocumentParameter = {
originalContent: string;
fileName: string;
};
const DEFAULT_METADATA: ToolMetadata<JSONSchemaType<DocumentParameter>> = {
name: "document_generator",
description:
"Generate HTML document from markdown content. Return a file url to the document",
parameters: {
type: "object",
properties: {
originalContent: {
type: "string",
description: "The original markdown content to convert.",
},
fileName: {
type: "string",
description: "The name of the document file (without extension).",
},
},
required: ["originalContent", "fileName"],
},
};
const COMMON_STYLES = `
body {
font-family: Arial, sans-serif;
line-height: 1.3;
color: #333;
}
h1, h2, h3, h4, h5, h6 {
margin-top: 1em;
margin-bottom: 0.5em;
}
p {
margin-bottom: 0.7em;
}
code {
background-color: #f4f4f4;
padding: 2px 4px;
border-radius: 4px;
}
pre {
background-color: #f4f4f4;
padding: 10px;
border-radius: 4px;
overflow-x: auto;
}
table {
border-collapse: collapse;
width: 100%;
margin-bottom: 1em;
}
th, td {
border: 1px solid #ddd;
padding: 8px;
text-align: left;
}
th {
background-color: #f2f2f2;
font-weight: bold;
}
img {
max-width: 90%;
height: auto;
display: block;
margin: 1em auto;
border-radius: 10px;
}
`;
const HTML_SPECIFIC_STYLES = `
body {
max-width: 800px;
margin: 0 auto;
padding: 20px;
}
`;
const HTML_TEMPLATE = `
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
${COMMON_STYLES}
${HTML_SPECIFIC_STYLES}
</style>
</head>
<body>
{{content}}
</body>
</html>
`;
export interface DocumentGeneratorParams {
metadata?: ToolMetadata<JSONSchemaType<DocumentParameter>>;
}
export class DocumentGenerator implements BaseTool<DocumentParameter> {
metadata: ToolMetadata<JSONSchemaType<DocumentParameter>>;
constructor(params: DocumentGeneratorParams) {
this.metadata = params.metadata ?? DEFAULT_METADATA;
}
private static async generateHtmlContent(
originalContent: string,
): Promise<string> {
return await marked(originalContent);
}
private static generateHtmlDocument(htmlContent: string): string {
return HTML_TEMPLATE.replace("{{content}}", htmlContent);
}
async call(input: DocumentParameter): Promise<string> {
const { originalContent, fileName } = input;
const htmlContent =
await DocumentGenerator.generateHtmlContent(originalContent);
const fileContent = DocumentGenerator.generateHtmlDocument(htmlContent);
const filePath = path.join(OUTPUT_DIR, `${fileName}.html`);
return `URL: ${await saveDocument(filePath, fileContent)}`;
}
}
export function getTools(): BaseTool[] {
return [new DocumentGenerator({})];
}
...@@ -5,15 +5,19 @@ import { BaseTool, ToolMetadata } from "llamaindex"; ...@@ -5,15 +5,19 @@ import { BaseTool, ToolMetadata } from "llamaindex";
export type DuckDuckGoParameter = { export type DuckDuckGoParameter = {
query: string; query: string;
region?: string; region?: string;
maxResults?: number;
}; };
export type DuckDuckGoToolParams = { export type DuckDuckGoToolParams = {
metadata?: ToolMetadata<JSONSchemaType<DuckDuckGoParameter>>; metadata?: ToolMetadata<JSONSchemaType<DuckDuckGoParameter>>;
}; };
const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<DuckDuckGoParameter>> = { const DEFAULT_SEARCH_METADATA: ToolMetadata<
name: "duckduckgo", JSONSchemaType<DuckDuckGoParameter>
description: "Use this function to search for any query in DuckDuckGo.", > = {
name: "duckduckgo_search",
description:
"Use this function to search for information (only text) in the internet using DuckDuckGo.",
parameters: { parameters: {
type: "object", type: "object",
properties: { properties: {
...@@ -27,6 +31,12 @@ const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<DuckDuckGoParameter>> = { ...@@ -27,6 +31,12 @@ const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<DuckDuckGoParameter>> = {
"Optional, The region to be used for the search in [country-language] convention, ex us-en, uk-en, ru-ru, etc...", "Optional, The region to be used for the search in [country-language] convention, ex us-en, uk-en, ru-ru, etc...",
nullable: true, nullable: true,
}, },
maxResults: {
type: "number",
description:
"Optional, The maximum number of results to be returned. Default is 10.",
nullable: true,
},
}, },
required: ["query"], required: ["query"],
}, },
...@@ -42,15 +52,18 @@ export class DuckDuckGoSearchTool implements BaseTool<DuckDuckGoParameter> { ...@@ -42,15 +52,18 @@ export class DuckDuckGoSearchTool implements BaseTool<DuckDuckGoParameter> {
metadata: ToolMetadata<JSONSchemaType<DuckDuckGoParameter>>; metadata: ToolMetadata<JSONSchemaType<DuckDuckGoParameter>>;
constructor(params: DuckDuckGoToolParams) { constructor(params: DuckDuckGoToolParams) {
this.metadata = params.metadata ?? DEFAULT_META_DATA; this.metadata = params.metadata ?? DEFAULT_SEARCH_METADATA;
} }
async call(input: DuckDuckGoParameter) { async call(input: DuckDuckGoParameter) {
const { query, region } = input; const { query, region, maxResults = 10 } = input;
const options = region ? { region } : {}; const options = region ? { region } : {};
// Temporarily sleep to reduce overloading the DuckDuckGo
await new Promise((resolve) => setTimeout(resolve, 1000));
const searchResults = await search(query, options); const searchResults = await search(query, options);
return searchResults.results.map((result) => { return searchResults.results.slice(0, maxResults).map((result) => {
return { return {
title: result.title, title: result.title,
description: result.description, description: result.description,
...@@ -59,3 +72,7 @@ export class DuckDuckGoSearchTool implements BaseTool<DuckDuckGoParameter> { ...@@ -59,3 +72,7 @@ export class DuckDuckGoSearchTool implements BaseTool<DuckDuckGoParameter> {
}); });
} }
} }
export function getTools() {
return [new DuckDuckGoSearchTool({})];
}
...@@ -37,7 +37,7 @@ const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<ImgGeneratorParameter>> = { ...@@ -37,7 +37,7 @@ const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<ImgGeneratorParameter>> = {
export class ImgGeneratorTool implements BaseTool<ImgGeneratorParameter> { export class ImgGeneratorTool implements BaseTool<ImgGeneratorParameter> {
readonly IMG_OUTPUT_FORMAT = "webp"; readonly IMG_OUTPUT_FORMAT = "webp";
readonly IMG_OUTPUT_DIR = "output/tool"; readonly IMG_OUTPUT_DIR = "output/tools";
readonly IMG_GEN_API = readonly IMG_GEN_API =
"https://api.stability.ai/v2beta/stable-image/generate/core"; "https://api.stability.ai/v2beta/stable-image/generate/core";
......
import { BaseToolWithCall } from "llamaindex"; import { BaseToolWithCall } from "llamaindex";
import { ToolsFactory } from "llamaindex/tools/ToolsFactory"; import { ToolsFactory } from "llamaindex/tools/ToolsFactory";
import {
DocumentGenerator,
DocumentGeneratorParams,
} from "./document-generator";
import { DuckDuckGoSearchTool, DuckDuckGoToolParams } from "./duckduckgo"; import { DuckDuckGoSearchTool, DuckDuckGoToolParams } from "./duckduckgo";
import { ImgGeneratorTool, ImgGeneratorToolParams } from "./img-gen"; import { ImgGeneratorTool, ImgGeneratorToolParams } from "./img-gen";
import { InterpreterTool, InterpreterToolParams } from "./interpreter"; import { InterpreterTool, InterpreterToolParams } from "./interpreter";
...@@ -43,6 +47,9 @@ const toolFactory: Record<string, ToolCreator> = { ...@@ -43,6 +47,9 @@ const toolFactory: Record<string, ToolCreator> = {
img_gen: async (config: unknown) => { img_gen: async (config: unknown) => {
return [new ImgGeneratorTool(config as ImgGeneratorToolParams)]; return [new ImgGeneratorTool(config as ImgGeneratorToolParams)];
}, },
document_generator: async (config: unknown) => {
return [new DocumentGenerator(config as DocumentGeneratorParams)];
},
}; };
async function createLocalTools( async function createLocalTools(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment