Skip to content
Snippets Groups Projects
Commit 904c53cf authored by Huu Le (Lee)'s avatar Huu Le (Lee) Committed by GitHub
Browse files

feat: Add support for llamahub tools (#517)

parent 340c49d4
No related branches found
No related tags found
No related merge requests found
Showing
with 256 additions and 30 deletions
/* eslint-disable import/no-extraneous-dependencies */ /* eslint-disable import/no-extraneous-dependencies */
import path from "path"; import path from "path";
import { green } from "picocolors"; import { green, yellow } from "picocolors";
import { tryGitInit } from "./helpers/git"; import { tryGitInit } from "./helpers/git";
import { isFolderEmpty } from "./helpers/is-folder-empty"; import { isFolderEmpty } from "./helpers/is-folder-empty";
import { getOnline } from "./helpers/is-online"; import { getOnline } from "./helpers/is-online";
...@@ -12,6 +12,7 @@ import terminalLink from "terminal-link"; ...@@ -12,6 +12,7 @@ import terminalLink from "terminal-link";
import type { InstallTemplateArgs } from "./helpers"; import type { InstallTemplateArgs } from "./helpers";
import { installTemplate } from "./helpers"; import { installTemplate } from "./helpers";
import { templatesDir } from "./helpers/dir"; import { templatesDir } from "./helpers/dir";
import { toolsRequireConfig } from "./helpers/tools";
export type InstallAppArgs = Omit< export type InstallAppArgs = Omit<
InstallTemplateArgs, InstallTemplateArgs,
...@@ -38,6 +39,7 @@ export async function createApp({ ...@@ -38,6 +39,7 @@ export async function createApp({
externalPort, externalPort,
postInstallAction, postInstallAction,
dataSource, dataSource,
tools,
}: InstallAppArgs): Promise<void> { }: InstallAppArgs): Promise<void> {
const root = path.resolve(appPath); const root = path.resolve(appPath);
...@@ -82,6 +84,7 @@ export async function createApp({ ...@@ -82,6 +84,7 @@ export async function createApp({
externalPort, externalPort,
postInstallAction, postInstallAction,
dataSource, dataSource,
tools,
}; };
if (frontend) { if (frontend) {
...@@ -114,6 +117,17 @@ export async function createApp({ ...@@ -114,6 +117,17 @@ export async function createApp({
console.log(); console.log();
} }
if (toolsRequireConfig(tools)) {
console.log(
yellow(
`You have selected tools that require configuration. Please configure them in the ${terminalLink(
"tools_config.json",
`file://${root}/tools_config.json`,
)} file.`,
),
);
}
console.log("");
console.log(`${green("Success!")} Created ${appName} at ${appPath}`); console.log(`${green("Success!")} Created ${appName} at ${appPath}`);
console.log( console.log(
......
...@@ -117,6 +117,8 @@ export async function runCreateLlama( ...@@ -117,6 +117,8 @@ export async function runCreateLlama(
externalPort, externalPort,
"--post-install-action", "--post-install-action",
postInstallAction, postInstallAction,
"--tools",
"none",
].join(" "); ].join(" ");
console.log(`running command '${command}' in ${cwd}`); console.log(`running command '${command}' in ${cwd}`);
let appProcess = exec(command, { let appProcess = exec(command, {
......
...@@ -6,6 +6,7 @@ import terminalLink from "terminal-link"; ...@@ -6,6 +6,7 @@ import terminalLink from "terminal-link";
import { copy } from "./copy"; import { copy } from "./copy";
import { templatesDir } from "./dir"; import { templatesDir } from "./dir";
import { isPoetryAvailable, tryPoetryInstall } from "./poetry"; import { isPoetryAvailable, tryPoetryInstall } from "./poetry";
import { getToolConfig } from "./tools";
import { InstallTemplateArgs, TemplateVectorDB } from "./types"; import { InstallTemplateArgs, TemplateVectorDB } from "./types";
interface Dependency { interface Dependency {
...@@ -128,6 +129,7 @@ export const installPythonTemplate = async ({ ...@@ -128,6 +129,7 @@ export const installPythonTemplate = async ({
engine, engine,
vectorDb, vectorDb,
dataSource, dataSource,
tools,
postInstallAction, postInstallAction,
}: Pick< }: Pick<
InstallTemplateArgs, InstallTemplateArgs,
...@@ -137,6 +139,7 @@ export const installPythonTemplate = async ({ ...@@ -137,6 +139,7 @@ export const installPythonTemplate = async ({
| "engine" | "engine"
| "vectorDb" | "vectorDb"
| "dataSource" | "dataSource"
| "tools"
| "postInstallAction" | "postInstallAction"
>) => { >) => {
console.log("\nInitializing Python project with template:", template, "\n"); console.log("\nInitializing Python project with template:", template, "\n");
...@@ -162,20 +165,44 @@ export const installPythonTemplate = async ({ ...@@ -162,20 +165,44 @@ export const installPythonTemplate = async ({
}); });
if (engine === "context") { if (engine === "context") {
const enginePath = path.join(root, "app", "engine");
const compPath = path.join(templatesDir, "components"); const compPath = path.join(templatesDir, "components");
let vectorDbDirName = vectorDb ?? "none";
const vectorDbDirName = vectorDb ?? "none";
const VectorDBPath = path.join( const VectorDBPath = path.join(
compPath, compPath,
"vectordbs", "vectordbs",
"python", "python",
vectorDbDirName, vectorDbDirName,
); );
const enginePath = path.join(root, "app", "engine"); await copy("**", enginePath, {
await copy("**", path.join(root, "app", "engine"), {
parents: true, parents: true,
cwd: VectorDBPath, cwd: VectorDBPath,
}); });
// Copy engine code
if (tools !== undefined && tools.length > 0) {
await copy("**", enginePath, {
parents: true,
cwd: path.join(compPath, "engines", "python", "agent"),
});
// Write tools_config.json
const configContent: Record<string, any> = {};
tools.forEach((tool) => {
configContent[tool] = getToolConfig(tool) ?? {};
});
const configFilePath = path.join(root, "tools_config.json");
await fs.writeFile(
configFilePath,
JSON.stringify(configContent, null, 2),
);
} else {
await copy("**", enginePath, {
parents: true,
cwd: path.join(compPath, "engines", "python", "chat"),
});
}
const dataSourceType = dataSource?.type; const dataSourceType = dataSource?.type;
if (dataSourceType !== undefined && dataSourceType !== "none") { if (dataSourceType !== undefined && dataSourceType !== "none") {
let loaderPath = let loaderPath =
......
/**
 * A LlamaHub tool that can be selected during project creation.
 */
export type Tool = {
  display: string; // human-readable label shown in the CLI prompt
  name: string; // LlamaHub tool name (maps to llama_hub.tools.<name>)
  // Default config written to tools_config.json; its presence marks the
  // tool as "requires configuration".
  config?: Record<string, any>;
};

export const supportedTools: Tool[] = [
  {
    display: "Google Search (configuration required)",
    name: "google_search",
    config: {
      engine: "Your search engine id",
      key: "Your search api key",
      num: 2,
    },
  },
  {
    display: "Wikipedia",
    name: "wikipedia",
  },
];

/**
 * Look up the default configuration for a tool by name.
 *
 * @param name - the LlamaHub tool name, e.g. "google_search"
 * @returns the tool's config object, or undefined when the tool is unknown
 *          or needs no configuration.
 */
export const getToolConfig = (
  name: string,
): Record<string, any> | undefined => {
  return supportedTools.find((tool) => tool.name === name)?.config;
};

/**
 * Whether any of the given tools needs user-provided configuration.
 * Returns false for an undefined or empty list.
 */
export const toolsRequireConfig = (tools?: string[]): boolean => {
  return tools?.some((tool) => getToolConfig(tool) !== undefined) ?? false;
};
...@@ -41,4 +41,5 @@ export interface InstallTemplateArgs { ...@@ -41,4 +41,5 @@ export interface InstallTemplateArgs {
vectorDb?: TemplateVectorDB; vectorDb?: TemplateVectorDB;
externalPort?: number; externalPort?: number;
postInstallAction?: TemplatePostInstallAction; postInstallAction?: TemplatePostInstallAction;
tools?: string[];
} }
...@@ -11,6 +11,7 @@ import { createApp } from "./create-app"; ...@@ -11,6 +11,7 @@ import { createApp } from "./create-app";
import { getPkgManager } from "./helpers/get-pkg-manager"; import { getPkgManager } from "./helpers/get-pkg-manager";
import { isFolderEmpty } from "./helpers/is-folder-empty"; import { isFolderEmpty } from "./helpers/is-folder-empty";
import { runApp } from "./helpers/run-app"; import { runApp } from "./helpers/run-app";
import { supportedTools } from "./helpers/tools";
import { validateNpmName } from "./helpers/validate-pkg"; import { validateNpmName } from "./helpers/validate-pkg";
import packageJson from "./package.json"; import packageJson from "./package.json";
import { QuestionArgs, askQuestions, onPromptState } from "./questions"; import { QuestionArgs, askQuestions, onPromptState } from "./questions";
...@@ -146,6 +147,13 @@ const program = new Commander.Command(packageJson.name) ...@@ -146,6 +147,13 @@ const program = new Commander.Command(packageJson.name)
` `
Select which vector database you would like to use, such as 'none', 'pg' or 'mongo'. The default option is not to use a vector database and use the local filesystem instead ('none'). Select which vector database you would like to use, such as 'none', 'pg' or 'mongo'. The default option is not to use a vector database and use the local filesystem instead ('none').
`,
)
.option(
"--tools <tools>",
`
Specify the tools you want to use by providing a comma-separated list. For example, 'google_search,wikipedia'. Use 'none' to not using any tools.
`, `,
) )
.allowUnknownOption() .allowUnknownOption()
...@@ -156,6 +164,25 @@ if (process.argv.includes("--no-frontend")) { ...@@ -156,6 +164,25 @@ if (process.argv.includes("--no-frontend")) {
if (process.argv.includes("--no-eslint")) { if (process.argv.includes("--no-eslint")) {
program.eslint = false; program.eslint = false;
} }
// Parse the --tools flag into a validated list of tool names.
if (process.argv.includes("--tools")) {
  if (program.tools === "none") {
    // Explicit opt-out: install no tools.
    program.tools = [];
  } else {
    // Accept a comma-separated list; tolerate stray whitespace and empty
    // entries (e.g. "--tools 'google_search, wikipedia'" or a trailing comma)
    // so they don't falsely fail validation below.
    program.tools = program.tools
      .split(",")
      .map((tool: string) => tool.trim())
      .filter((tool: string) => tool.length > 0);
    // Reject any tool name we don't know about, listing the valid options.
    const toolsName = supportedTools.map((tool) => tool.name);
    program.tools.forEach((tool: string) => {
      if (!toolsName.includes(tool)) {
        console.error(
          `Error: Tool '${tool}' is not supported. Supported tools are: ${toolsName.join(
            ", ",
          )}`,
        );
        process.exit(1);
      }
    });
  }
}
const packageManager = !!program.useNpm const packageManager = !!program.useNpm
? "npm" ? "npm"
...@@ -256,6 +283,7 @@ async function run(): Promise<void> { ...@@ -256,6 +283,7 @@ async function run(): Promise<void> {
externalPort: program.externalPort, externalPort: program.externalPort,
postInstallAction: program.postInstallAction, postInstallAction: program.postInstallAction,
dataSource: program.dataSource, dataSource: program.dataSource,
tools: program.tools,
}); });
conf.set("preferences", preferences); conf.set("preferences", preferences);
......
...@@ -10,6 +10,7 @@ import { COMMUNITY_OWNER, COMMUNITY_REPO } from "./helpers/constant"; ...@@ -10,6 +10,7 @@ import { COMMUNITY_OWNER, COMMUNITY_REPO } from "./helpers/constant";
import { templatesDir } from "./helpers/dir"; import { templatesDir } from "./helpers/dir";
import { getAvailableLlamapackOptions } from "./helpers/llama-pack"; import { getAvailableLlamapackOptions } from "./helpers/llama-pack";
import { getRepoRootFolders } from "./helpers/repo"; import { getRepoRootFolders } from "./helpers/repo";
import { supportedTools, toolsRequireConfig } from "./helpers/tools";
export type QuestionArgs = Omit< export type QuestionArgs = Omit<
InstallAppArgs, InstallAppArgs,
...@@ -70,6 +71,7 @@ const defaults: QuestionArgs = { ...@@ -70,6 +71,7 @@ const defaults: QuestionArgs = {
type: "none", type: "none",
config: {}, config: {},
}, },
tools: [],
}; };
const handlers = { const handlers = {
...@@ -214,7 +216,12 @@ export const askQuestions = async ( ...@@ -214,7 +216,12 @@ export const askQuestions = async (
const hasOpenAiKey = program.openAiKey || process.env["OPENAI_API_KEY"]; const hasOpenAiKey = program.openAiKey || process.env["OPENAI_API_KEY"];
const hasVectorDb = program.vectorDb && program.vectorDb !== "none"; const hasVectorDb = program.vectorDb && program.vectorDb !== "none";
if (!hasVectorDb && hasOpenAiKey) { // Can run the app if all tools do not require configuration
if (
!hasVectorDb &&
hasOpenAiKey &&
!toolsRequireConfig(program.tools)
) {
actionChoices.push({ actionChoices.push({
title: title:
"Generate code, install dependencies, and run the app (~2 min)", "Generate code, install dependencies, and run the app (~2 min)",
...@@ -563,6 +570,29 @@ export const askQuestions = async ( ...@@ -563,6 +570,29 @@ export const askQuestions = async (
} }
} }
if (
!program.tools &&
program.framework === "fastapi" &&
program.engine === "context"
) {
if (ciInfo.isCI) {
program.tools = getPrefOrDefault("tools");
} else {
const toolChoices = supportedTools.map((tool) => ({
title: tool.display,
value: tool.name,
}));
const { tools } = await prompts({
type: "multiselect",
name: "tools",
message: "Which tools would you like to use?",
choices: toolChoices,
});
program.tools = tools;
preferences.tools = tools;
}
}
if (!program.openAiKey) { if (!program.openAiKey) {
const { key } = await prompts( const { key } = await prompts(
{ {
......
import os
from typing import Any, Optional
from llama_index.llms import LLM
from llama_index.agent import AgentRunner
from app.engine.tools import ToolFactory
from app.engine.index import get_index
from llama_index.agent import ReActAgent
from llama_index.tools.query_engine import QueryEngineTool
def create_agent_from_llm(
    llm: Optional[LLM] = None,
    **kwargs: Any,
) -> AgentRunner:
    """Build an agent runner for ``llm``.

    OpenAI models that support function calling get an ``OpenAIAgent``;
    every other LLM falls back to a ``ReActAgent``. Extra keyword
    arguments (e.g. ``tools``, ``verbose``) are forwarded unchanged.
    """
    from llama_index.agent import OpenAIAgent, ReActAgent
    from llama_index.llms.openai import OpenAI
    from llama_index.llms.openai_utils import is_function_calling_model

    use_openai_agent = isinstance(llm, OpenAI) and is_function_calling_model(
        llm.model
    )
    agent_cls = OpenAIAgent if use_openai_agent else ReActAgent
    return agent_cls.from_tools(llm=llm, **kwargs)
def get_chat_engine():
    """Assemble the agent-based chat engine.

    Combines a query-engine tool over the project's vector index with any
    extra LlamaHub tools declared in tools_config.json, then wraps them in
    an agent driven by the index's LLM.
    """
    index = get_index()
    llm = index.service_context.llm
    # Tool that answers questions against the vector index.
    query_tool = QueryEngineTool.from_defaults(
        query_engine=index.as_query_engine(similarity_top_k=5)
    )
    # Any additional tools the user configured at project creation time.
    configured_tools = ToolFactory.from_env()
    return create_agent_from_llm(
        llm=llm,
        tools=[query_tool, *configured_tools],
        verbose=True,
    )
import json
import importlib
from llama_index.tools.tool_spec.base import BaseToolSpec
from llama_index.tools.function_tool import FunctionTool
class ToolFactory:
    """Instantiates LlamaHub tools from tool names and keyword configuration."""

    @staticmethod
    def create_tool(tool_name: str, **kwargs) -> list[FunctionTool]:
        """Import the LlamaHub tool spec for ``tool_name`` and return its tools.

        The spec class is resolved by convention: module
        ``llama_hub.tools.<tool_name>.base`` must define a class named
        ``<ToolName>ToolSpec`` (snake_case converted to CamelCase).
        ``kwargs`` are passed straight to the spec's constructor.

        Raises:
            ValueError: if the tool module or class cannot be found
                (unsupported tool), or the constructor rejects ``kwargs``.
        """
        try:
            module_name = f"llama_hub.tools.{tool_name}.base"
            module = importlib.import_module(module_name)
            # e.g. "google_search" -> "GoogleSearchToolSpec"
            tool_cls_name = tool_name.title().replace("_", "") + "ToolSpec"
            tool_class = getattr(module, tool_cls_name)
            tool_spec: BaseToolSpec = tool_class(**kwargs)
            return tool_spec.to_tool_list()
        except (ImportError, AttributeError) as e:
            raise ValueError(f"Unsupported tool: {tool_name}") from e
        except TypeError as e:
            raise ValueError(
                f"Could not create tool: {tool_name}. With config: {kwargs}"
            ) from e

    @staticmethod
    def from_env() -> list[FunctionTool]:
        """Build every tool listed in tools_config.json.

        The file is read relative to the current working directory; each
        top-level key is a tool name, its value the kwargs dict for that
        tool's spec constructor.
        # NOTE(review): no existence check — presumably the file is always
        # written when this engine template is installed; confirm.
        """
        tools = []
        with open("tools_config.json", "r") as f:
            tool_configs = json.load(f)
            for name, config in tool_configs.items():
                tools += ToolFactory.create_tool(name, **config)
        return tools
from app.engine.index import get_index
def get_chat_engine():
    """Return a condense-plus-context chat engine over the vector index."""
    index = get_index()
    return index.as_chat_engine(
        chat_mode="condense_plus_context", similarity_top_k=5
    )
...@@ -9,7 +9,7 @@ from llama_index.vector_stores import MongoDBAtlasVectorSearch ...@@ -9,7 +9,7 @@ from llama_index.vector_stores import MongoDBAtlasVectorSearch
from app.engine.context import create_service_context from app.engine.context import create_service_context
def get_chat_engine(): def get_index():
service_context = create_service_context() service_context = create_service_context()
logger = logging.getLogger("uvicorn") logger = logging.getLogger("uvicorn")
logger.info("Connecting to index from MongoDB...") logger.info("Connecting to index from MongoDB...")
...@@ -20,4 +20,4 @@ def get_chat_engine(): ...@@ -20,4 +20,4 @@ def get_chat_engine():
) )
index = VectorStoreIndex.from_vector_store(store, service_context) index = VectorStoreIndex.from_vector_store(store, service_context)
logger.info("Finished connecting to index from MongoDB.") logger.info("Finished connecting to index from MongoDB.")
return index.as_chat_engine(similarity_top_k=5, chat_mode="condense_plus_context") return index
import logging import logging
import os import os
from app.engine.constants import STORAGE_DIR
from app.engine.context import create_service_context
from llama_index import ( from llama_index import (
StorageContext, StorageContext,
load_index_from_storage, load_index_from_storage,
) )
from app.engine.constants import STORAGE_DIR
from app.engine.context import create_service_context
def get_chat_engine(): def get_index():
service_context = create_service_context() service_context = create_service_context()
# check if storage already exists # check if storage already exists
if not os.path.exists(STORAGE_DIR): if not os.path.exists(STORAGE_DIR):
...@@ -22,4 +22,4 @@ def get_chat_engine(): ...@@ -22,4 +22,4 @@ def get_chat_engine():
storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR) storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR)
index = load_index_from_storage(storage_context, service_context=service_context) index = load_index_from_storage(storage_context, service_context=service_context)
logger.info(f"Finished loading index from {STORAGE_DIR}") logger.info(f"Finished loading index from {STORAGE_DIR}")
return index.as_chat_engine(similarity_top_k=5, chat_mode="condense_plus_context") return index
...@@ -6,11 +6,11 @@ from app.engine.context import create_service_context ...@@ -6,11 +6,11 @@ from app.engine.context import create_service_context
from app.engine.utils import init_pg_vector_store_from_env from app.engine.utils import init_pg_vector_store_from_env
def get_chat_engine(): def get_index():
service_context = create_service_context() service_context = create_service_context()
logger = logging.getLogger("uvicorn") logger = logging.getLogger("uvicorn")
logger.info("Connecting to index from PGVector...") logger.info("Connecting to index from PGVector...")
store = init_pg_vector_store_from_env() store = init_pg_vector_store_from_env()
index = VectorStoreIndex.from_vector_store(store, service_context) index = VectorStoreIndex.from_vector_store(store, service_context)
logger.info("Finished connecting to index from PGVector.") logger.info("Finished connecting to index from PGVector.")
return index.as_chat_engine(similarity_top_k=5, chat_mode="condense_plus_context") return index
...@@ -5,7 +5,7 @@ from llama_index.chat_engine.types import BaseChatEngine ...@@ -5,7 +5,7 @@ from llama_index.chat_engine.types import BaseChatEngine
from llama_index.llms.base import ChatMessage from llama_index.llms.base import ChatMessage
from llama_index.llms.types import MessageRole from llama_index.llms.types import MessageRole
from pydantic import BaseModel from pydantic import BaseModel
from app.engine.index import get_chat_engine from app.engine import get_chat_engine
chat_router = r = APIRouter() chat_router = r = APIRouter()
......
from llama_index.chat_engine import SimpleChatEngine
from app.context import create_base_context
def get_chat_engine():
    """Chat engine for the no-context template: a plain SimpleChatEngine."""
    context = create_base_context()
    return SimpleChatEngine.from_defaults(service_context=context)
from llama_index.chat_engine import SimpleChatEngine
from app.context import create_base_context
def get_chat_engine():
    """Build the simplest possible chat engine — no index, base context only."""
    service_context = create_base_context()
    return SimpleChatEngine.from_defaults(service_context=service_context)
...@@ -13,6 +13,8 @@ llama-index = "^0.9.19" ...@@ -13,6 +13,8 @@ llama-index = "^0.9.19"
pypdf = "^3.17.0" pypdf = "^3.17.0"
python-dotenv = "^1.0.0" python-dotenv = "^1.0.0"
docx2txt = "^0.8" docx2txt = "^0.8"
llama-hub = "^0.0.77"
wikipedia = "^1.4.0"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]
......
...@@ -3,7 +3,7 @@ from typing import List ...@@ -3,7 +3,7 @@ from typing import List
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from llama_index.chat_engine.types import BaseChatEngine from llama_index.chat_engine.types import BaseChatEngine
from app.engine.index import get_chat_engine from app.engine import get_chat_engine
from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi import APIRouter, Depends, HTTPException, Request, status
from llama_index.llms.base import ChatMessage from llama_index.llms.base import ChatMessage
from llama_index.llms.types import MessageRole from llama_index.llms.types import MessageRole
......
from llama_index.chat_engine import SimpleChatEngine
from app.context import create_base_context
def get_chat_engine():
    # Simple template: no index, chat directly against the base LLM context.
    return SimpleChatEngine.from_defaults(service_context=create_base_context())
from llama_index.chat_engine import SimpleChatEngine
from app.context import create_base_context
def get_chat_engine():
    # Simple template: no index, chat directly against the base LLM context.
    return SimpleChatEngine.from_defaults(service_context=create_base_context())
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment