From 904c53cf769981e220d65229d9e29ed4957d7669 Mon Sep 17 00:00:00 2001 From: "Huu Le (Lee)" <39040748+leehuwuj@users.noreply.github.com> Date: Tue, 6 Feb 2024 17:34:03 +0700 Subject: [PATCH] feat: Add support for llamahub tools (#517) Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de> --- create-app.ts | 16 +++++- e2e/utils.ts | 2 + helpers/python.ts | 33 ++++++++++-- helpers/tools.ts | 32 ++++++++++++ helpers/types.ts | 1 + index.ts | 28 +++++++++++ questions.ts | 32 +++++++++++- .../engines/python/agent/__init__.py | 50 +++++++++++++++++++ .../components/engines/python/agent/tools.py | 33 ++++++++++++ .../engines/python/chat/__init__.py | 7 +++ .../vectordbs/python/mongo/index.py | 4 +- .../components/vectordbs/python/none/index.py | 10 ++-- .../components/vectordbs/python/pg/index.py | 4 +- .../simple/fastapi/app/api/routers/chat.py | 2 +- .../simple/fastapi/app/engine/__init__.py | 7 +++ .../types/simple/fastapi/app/engine/index.py | 7 --- templates/types/simple/fastapi/pyproject.toml | 2 + .../streaming/fastapi/app/api/routers/chat.py | 2 +- .../streaming/fastapi/app/engine/__init__.py | 7 +++ .../streaming/fastapi/app/engine/index.py | 7 --- .../types/streaming/fastapi/pyproject.toml | 2 + 21 files changed, 258 insertions(+), 30 deletions(-) create mode 100644 helpers/tools.ts create mode 100644 templates/components/engines/python/agent/__init__.py create mode 100644 templates/components/engines/python/agent/tools.py create mode 100644 templates/components/engines/python/chat/__init__.py delete mode 100644 templates/types/simple/fastapi/app/engine/index.py delete mode 100644 templates/types/streaming/fastapi/app/engine/index.py diff --git a/create-app.ts b/create-app.ts index 4c1eaf4d..835c02f1 100644 --- a/create-app.ts +++ b/create-app.ts @@ -1,6 +1,6 @@ /* eslint-disable import/no-extraneous-dependencies */ import path from "path"; -import { green } from "picocolors"; +import { green, yellow } from "picocolors"; import { tryGitInit } from 
"./helpers/git"; import { isFolderEmpty } from "./helpers/is-folder-empty"; import { getOnline } from "./helpers/is-online"; @@ -12,6 +12,7 @@ import terminalLink from "terminal-link"; import type { InstallTemplateArgs } from "./helpers"; import { installTemplate } from "./helpers"; import { templatesDir } from "./helpers/dir"; +import { toolsRequireConfig } from "./helpers/tools"; export type InstallAppArgs = Omit< InstallTemplateArgs, @@ -38,6 +39,7 @@ export async function createApp({ externalPort, postInstallAction, dataSource, + tools, }: InstallAppArgs): Promise<void> { const root = path.resolve(appPath); @@ -82,6 +84,7 @@ export async function createApp({ externalPort, postInstallAction, dataSource, + tools, }; if (frontend) { @@ -114,6 +117,17 @@ export async function createApp({ console.log(); } + if (toolsRequireConfig(tools)) { + console.log( + yellow( + `You have selected tools that require configuration. Please configure them in the ${terminalLink( + "tools_config.json", + `file://${root}/tools_config.json`, + )} file.`, + ), + ); + } + console.log(""); console.log(`${green("Success!")} Created ${appName} at ${appPath}`); console.log( diff --git a/e2e/utils.ts b/e2e/utils.ts index b06776da..2ec35e40 100644 --- a/e2e/utils.ts +++ b/e2e/utils.ts @@ -117,6 +117,8 @@ export async function runCreateLlama( externalPort, "--post-install-action", postInstallAction, + "--tools", + "none", ].join(" "); console.log(`running command '${command}' in ${cwd}`); let appProcess = exec(command, { diff --git a/helpers/python.ts b/helpers/python.ts index 8b2ec6d9..6fbc48b5 100644 --- a/helpers/python.ts +++ b/helpers/python.ts @@ -6,6 +6,7 @@ import terminalLink from "terminal-link"; import { copy } from "./copy"; import { templatesDir } from "./dir"; import { isPoetryAvailable, tryPoetryInstall } from "./poetry"; +import { getToolConfig } from "./tools"; import { InstallTemplateArgs, TemplateVectorDB } from "./types"; interface Dependency { @@ -128,6 +129,7 @@ export 
const installPythonTemplate = async ({ engine, vectorDb, dataSource, + tools, postInstallAction, }: Pick< InstallTemplateArgs, @@ -137,6 +139,7 @@ export const installPythonTemplate = async ({ | "engine" | "vectorDb" | "dataSource" + | "tools" | "postInstallAction" >) => { console.log("\nInitializing Python project with template:", template, "\n"); @@ -162,20 +165,44 @@ export const installPythonTemplate = async ({ }); if (engine === "context") { + const enginePath = path.join(root, "app", "engine"); const compPath = path.join(templatesDir, "components"); - let vectorDbDirName = vectorDb ?? "none"; + + const vectorDbDirName = vectorDb ?? "none"; const VectorDBPath = path.join( compPath, "vectordbs", "python", vectorDbDirName, ); - const enginePath = path.join(root, "app", "engine"); - await copy("**", path.join(root, "app", "engine"), { + await copy("**", enginePath, { parents: true, cwd: VectorDBPath, }); + // Copy engine code + if (tools !== undefined && tools.length > 0) { + await copy("**", enginePath, { + parents: true, + cwd: path.join(compPath, "engines", "python", "agent"), + }); + // Write tools_config.json + const configContent: Record<string, any> = {}; + tools.forEach((tool) => { + configContent[tool] = getToolConfig(tool) ?? 
{}; + }); + const configFilePath = path.join(root, "tools_config.json"); + await fs.writeFile( + configFilePath, + JSON.stringify(configContent, null, 2), + ); + } else { + await copy("**", enginePath, { + parents: true, + cwd: path.join(compPath, "engines", "python", "chat"), + }); + } + const dataSourceType = dataSource?.type; if (dataSourceType !== undefined && dataSourceType !== "none") { let loaderPath = diff --git a/helpers/tools.ts b/helpers/tools.ts new file mode 100644 index 00000000..13b8ff4b --- /dev/null +++ b/helpers/tools.ts @@ -0,0 +1,32 @@ +export type Tool = { + display: string; + name: string; + config?: Record<string, any>; +}; + +export const supportedTools: Tool[] = [ + { + display: "Google Search (configuration required)", + name: "google_search", + config: { + engine: "Your search engine id", + key: "Your search api key", + num: 2, + }, + }, + { + display: "Wikipedia", + name: "wikipedia", + }, +]; + +export const getToolConfig = (name: string) => { + return supportedTools.find((tool) => tool.name === name)?.config; +}; + +export const toolsRequireConfig = (tools?: string[]): boolean => { + if (tools) { + return tools.some((tool) => getToolConfig(tool)); + } + return false; +}; diff --git a/helpers/types.ts b/helpers/types.ts index 191e028f..5e4a9f6e 100644 --- a/helpers/types.ts +++ b/helpers/types.ts @@ -41,4 +41,5 @@ export interface InstallTemplateArgs { vectorDb?: TemplateVectorDB; externalPort?: number; postInstallAction?: TemplatePostInstallAction; + tools?: string[]; } diff --git a/index.ts b/index.ts index 601742c2..6fa19cc5 100644 --- a/index.ts +++ b/index.ts @@ -11,6 +11,7 @@ import { createApp } from "./create-app"; import { getPkgManager } from "./helpers/get-pkg-manager"; import { isFolderEmpty } from "./helpers/is-folder-empty"; import { runApp } from "./helpers/run-app"; +import { supportedTools } from "./helpers/tools"; import { validateNpmName } from "./helpers/validate-pkg"; import packageJson from "./package.json"; import 
{ QuestionArgs, askQuestions, onPromptState } from "./questions"; @@ -146,6 +147,13 @@ const program = new Commander.Command(packageJson.name) ` Select which vector database you would like to use, such as 'none', 'pg' or 'mongo'. The default option is not to use a vector database and use the local filesystem instead ('none'). +`, ) .option( "--tools <tools>", ` + + Specify the tools you want to use by providing a comma-separated list. For example, 'google_search,wikipedia'. Use 'none' to not use any tools. `, ) .allowUnknownOption() @@ -156,6 +164,25 @@ if (process.argv.includes("--no-frontend")) { if (process.argv.includes("--no-eslint")) { program.eslint = false; } +if (process.argv.includes("--tools")) { + if (program.tools === "none") { + program.tools = []; + } else { + program.tools = program.tools.split(","); + // Check if tools are available + const toolsName = supportedTools.map((tool) => tool.name); + program.tools.forEach((tool: string) => { + if (!toolsName.includes(tool)) { + console.error( + `Error: Tool '${tool}' is not supported. Supported tools are: ${toolsName.join( + ", ", + )}`, + ); + process.exit(1); + } + }); + } +} const packageManager = !!program.useNpm ? 
"npm" @@ -256,6 +283,7 @@ async function run(): Promise<void> { externalPort: program.externalPort, postInstallAction: program.postInstallAction, dataSource: program.dataSource, + tools: program.tools, }); conf.set("preferences", preferences); diff --git a/questions.ts b/questions.ts index 0e671184..d6f29b8b 100644 --- a/questions.ts +++ b/questions.ts @@ -10,6 +10,7 @@ import { COMMUNITY_OWNER, COMMUNITY_REPO } from "./helpers/constant"; import { templatesDir } from "./helpers/dir"; import { getAvailableLlamapackOptions } from "./helpers/llama-pack"; import { getRepoRootFolders } from "./helpers/repo"; +import { supportedTools, toolsRequireConfig } from "./helpers/tools"; export type QuestionArgs = Omit< InstallAppArgs, @@ -70,6 +71,7 @@ const defaults: QuestionArgs = { type: "none", config: {}, }, + tools: [], }; const handlers = { @@ -214,7 +216,12 @@ export const askQuestions = async ( const hasOpenAiKey = program.openAiKey || process.env["OPENAI_API_KEY"]; const hasVectorDb = program.vectorDb && program.vectorDb !== "none"; - if (!hasVectorDb && hasOpenAiKey) { + // Can run the app if all tools do not require configuration + if ( + !hasVectorDb && + hasOpenAiKey && + !toolsRequireConfig(program.tools) + ) { actionChoices.push({ title: "Generate code, install dependencies, and run the app (~2 min)", @@ -563,6 +570,29 @@ export const askQuestions = async ( } } + if ( + !program.tools && + program.framework === "fastapi" && + program.engine === "context" + ) { + if (ciInfo.isCI) { + program.tools = getPrefOrDefault("tools"); + } else { + const toolChoices = supportedTools.map((tool) => ({ + title: tool.display, + value: tool.name, + })); + const { tools } = await prompts({ + type: "multiselect", + name: "tools", + message: "Which tools would you like to use?", + choices: toolChoices, + }); + program.tools = tools; + preferences.tools = tools; + } + } + if (!program.openAiKey) { const { key } = await prompts( { diff --git 
a/templates/components/engines/python/agent/__init__.py b/templates/components/engines/python/agent/__init__.py new file mode 100644 index 00000000..f1b62b87 --- /dev/null +++ b/templates/components/engines/python/agent/__init__.py @@ -0,0 +1,50 @@ +import os + +from typing import Any, Optional +from llama_index.llms import LLM +from llama_index.agent import AgentRunner + +from app.engine.tools import ToolFactory +from app.engine.index import get_index +from llama_index.agent import ReActAgent +from llama_index.tools.query_engine import QueryEngineTool + + +def create_agent_from_llm( + llm: Optional[LLM] = None, + **kwargs: Any, +) -> AgentRunner: + from llama_index.agent import OpenAIAgent, ReActAgent + from llama_index.llms.openai import OpenAI + from llama_index.llms.openai_utils import is_function_calling_model + + if isinstance(llm, OpenAI) and is_function_calling_model(llm.model): + return OpenAIAgent.from_tools( + llm=llm, + **kwargs, + ) + else: + return ReActAgent.from_tools( + llm=llm, + **kwargs, + ) + + +def get_chat_engine(): + tools = [] + + # Add query tool + index = get_index() + llm = index.service_context.llm + query_engine = index.as_query_engine(similarity_top_k=5) + query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine) + tools.append(query_engine_tool) + + # Add additional tools + tools += ToolFactory.from_env() + + return create_agent_from_llm( + llm=llm, + tools=tools, + verbose=True, + ) diff --git a/templates/components/engines/python/agent/tools.py b/templates/components/engines/python/agent/tools.py new file mode 100644 index 00000000..9fb9d488 --- /dev/null +++ b/templates/components/engines/python/agent/tools.py @@ -0,0 +1,33 @@ +import json +import importlib + +from llama_index.tools.tool_spec.base import BaseToolSpec +from llama_index.tools.function_tool import FunctionTool + + +class ToolFactory: + + @staticmethod + def create_tool(tool_name: str, **kwargs) -> list[FunctionTool]: + try: + module_name = 
f"llama_hub.tools.{tool_name}.base" + module = importlib.import_module(module_name) + tool_cls_name = tool_name.title().replace("_", "") + "ToolSpec" + tool_class = getattr(module, tool_cls_name) + tool_spec: BaseToolSpec = tool_class(**kwargs) + return tool_spec.to_tool_list() + except (ImportError, AttributeError) as e: + raise ValueError(f"Unsupported tool: {tool_name}") from e + except TypeError as e: + raise ValueError( + f"Could not create tool: {tool_name}. With config: {kwargs}" + ) from e + + @staticmethod + def from_env() -> list[FunctionTool]: + tools = [] + with open("tools_config.json", "r") as f: + tool_configs = json.load(f) + for name, config in tool_configs.items(): + tools += ToolFactory.create_tool(name, **config) + return tools diff --git a/templates/components/engines/python/chat/__init__.py b/templates/components/engines/python/chat/__init__.py new file mode 100644 index 00000000..18a6039b --- /dev/null +++ b/templates/components/engines/python/chat/__init__.py @@ -0,0 +1,7 @@ +from app.engine.index import get_index + + +def get_chat_engine(): + return get_index().as_chat_engine( + similarity_top_k=5, chat_mode="condense_plus_context" + ) diff --git a/templates/components/vectordbs/python/mongo/index.py b/templates/components/vectordbs/python/mongo/index.py index a80590b5..173e7b57 100644 --- a/templates/components/vectordbs/python/mongo/index.py +++ b/templates/components/vectordbs/python/mongo/index.py @@ -9,7 +9,7 @@ from llama_index.vector_stores import MongoDBAtlasVectorSearch from app.engine.context import create_service_context -def get_chat_engine(): +def get_index(): service_context = create_service_context() logger = logging.getLogger("uvicorn") logger.info("Connecting to index from MongoDB...") @@ -20,4 +20,4 @@ def get_chat_engine(): ) index = VectorStoreIndex.from_vector_store(store, service_context) logger.info("Finished connecting to index from MongoDB.") - return index.as_chat_engine(similarity_top_k=5, 
chat_mode="condense_plus_context") + return index diff --git a/templates/components/vectordbs/python/none/index.py b/templates/components/vectordbs/python/none/index.py index 4404c66e..8e16975b 100644 --- a/templates/components/vectordbs/python/none/index.py +++ b/templates/components/vectordbs/python/none/index.py @@ -1,15 +1,15 @@ import logging import os + +from app.engine.constants import STORAGE_DIR +from app.engine.context import create_service_context from llama_index import ( StorageContext, load_index_from_storage, ) -from app.engine.constants import STORAGE_DIR -from app.engine.context import create_service_context - -def get_chat_engine(): +def get_index(): service_context = create_service_context() # check if storage already exists if not os.path.exists(STORAGE_DIR): @@ -22,4 +22,4 @@ def get_chat_engine(): storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR) index = load_index_from_storage(storage_context, service_context=service_context) logger.info(f"Finished loading index from {STORAGE_DIR}") - return index.as_chat_engine(similarity_top_k=5, chat_mode="condense_plus_context") + return index diff --git a/templates/components/vectordbs/python/pg/index.py b/templates/components/vectordbs/python/pg/index.py index 510c21a5..368fb432 100644 --- a/templates/components/vectordbs/python/pg/index.py +++ b/templates/components/vectordbs/python/pg/index.py @@ -6,11 +6,11 @@ from app.engine.context import create_service_context from app.engine.utils import init_pg_vector_store_from_env -def get_chat_engine(): +def get_index(): service_context = create_service_context() logger = logging.getLogger("uvicorn") logger.info("Connecting to index from PGVector...") store = init_pg_vector_store_from_env() index = VectorStoreIndex.from_vector_store(store, service_context) logger.info("Finished connecting to index from PGVector.") - return index.as_chat_engine(similarity_top_k=5, chat_mode="condense_plus_context") + return index diff --git 
a/templates/types/simple/fastapi/app/api/routers/chat.py b/templates/types/simple/fastapi/app/api/routers/chat.py index d4b000fd..09efdcaa 100644 --- a/templates/types/simple/fastapi/app/api/routers/chat.py +++ b/templates/types/simple/fastapi/app/api/routers/chat.py @@ -5,7 +5,7 @@ from llama_index.chat_engine.types import BaseChatEngine from llama_index.llms.base import ChatMessage from llama_index.llms.types import MessageRole from pydantic import BaseModel -from app.engine.index import get_chat_engine +from app.engine import get_chat_engine chat_router = r = APIRouter() diff --git a/templates/types/simple/fastapi/app/engine/__init__.py b/templates/types/simple/fastapi/app/engine/__init__.py index e69de29b..663b595a 100644 --- a/templates/types/simple/fastapi/app/engine/__init__.py +++ b/templates/types/simple/fastapi/app/engine/__init__.py @@ -0,0 +1,7 @@ +from llama_index.chat_engine import SimpleChatEngine + +from app.context import create_base_context + + +def get_chat_engine(): + return SimpleChatEngine.from_defaults(service_context=create_base_context()) diff --git a/templates/types/simple/fastapi/app/engine/index.py b/templates/types/simple/fastapi/app/engine/index.py deleted file mode 100644 index 663b595a..00000000 --- a/templates/types/simple/fastapi/app/engine/index.py +++ /dev/null @@ -1,7 +0,0 @@ -from llama_index.chat_engine import SimpleChatEngine - -from app.context import create_base_context - - -def get_chat_engine(): - return SimpleChatEngine.from_defaults(service_context=create_base_context()) diff --git a/templates/types/simple/fastapi/pyproject.toml b/templates/types/simple/fastapi/pyproject.toml index d1952f4a..42f5faf0 100644 --- a/templates/types/simple/fastapi/pyproject.toml +++ b/templates/types/simple/fastapi/pyproject.toml @@ -13,6 +13,8 @@ llama-index = "^0.9.19" pypdf = "^3.17.0" python-dotenv = "^1.0.0" docx2txt = "^0.8" +llama-hub = "^0.0.77" +wikipedia = "^1.4.0" [build-system] requires = ["poetry-core"] diff --git 
a/templates/types/streaming/fastapi/app/api/routers/chat.py b/templates/types/streaming/fastapi/app/api/routers/chat.py index 26fd480d..0afe14e4 100644 --- a/templates/types/streaming/fastapi/app/api/routers/chat.py +++ b/templates/types/streaming/fastapi/app/api/routers/chat.py @@ -3,7 +3,7 @@ from typing import List from fastapi.responses import StreamingResponse from llama_index.chat_engine.types import BaseChatEngine -from app.engine.index import get_chat_engine +from app.engine import get_chat_engine from fastapi import APIRouter, Depends, HTTPException, Request, status from llama_index.llms.base import ChatMessage from llama_index.llms.types import MessageRole diff --git a/templates/types/streaming/fastapi/app/engine/__init__.py b/templates/types/streaming/fastapi/app/engine/__init__.py index e69de29b..663b595a 100644 --- a/templates/types/streaming/fastapi/app/engine/__init__.py +++ b/templates/types/streaming/fastapi/app/engine/__init__.py @@ -0,0 +1,7 @@ +from llama_index.chat_engine import SimpleChatEngine + +from app.context import create_base_context + + +def get_chat_engine(): + return SimpleChatEngine.from_defaults(service_context=create_base_context()) diff --git a/templates/types/streaming/fastapi/app/engine/index.py b/templates/types/streaming/fastapi/app/engine/index.py deleted file mode 100644 index 663b595a..00000000 --- a/templates/types/streaming/fastapi/app/engine/index.py +++ /dev/null @@ -1,7 +0,0 @@ -from llama_index.chat_engine import SimpleChatEngine - -from app.context import create_base_context - - -def get_chat_engine(): - return SimpleChatEngine.from_defaults(service_context=create_base_context()) diff --git a/templates/types/streaming/fastapi/pyproject.toml b/templates/types/streaming/fastapi/pyproject.toml index d1952f4a..42f5faf0 100644 --- a/templates/types/streaming/fastapi/pyproject.toml +++ b/templates/types/streaming/fastapi/pyproject.toml @@ -13,6 +13,8 @@ llama-index = "^0.9.19" pypdf = "^3.17.0" python-dotenv = "^1.0.0" 
docx2txt = "^0.8" +llama-hub = "^0.0.77" +wikipedia = "^1.4.0" [build-system] requires = ["poetry-core"] -- GitLab