From 904c53cf769981e220d65229d9e29ed4957d7669 Mon Sep 17 00:00:00 2001 From: "Huu Le (Lee)" <39040748+leehuwuj@users.noreply.github.com> Date: Tue, 6 Feb 2024 17:34:03 +0700 Subject: [PATCH] feat: Add support for llamahub tools (#517) Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de> --- create-app.ts | 16 +++++- e2e/utils.ts | 2 + helpers/python.ts | 33 ++++++++++-- helpers/tools.ts | 32 ++++++++++++ helpers/types.ts | 1 + index.ts | 28 +++++++++++ questions.ts | 32 +++++++++++- .../engines/python/agent/__init__.py | 50 +++++++++++++++++++ .../components/engines/python/agent/tools.py | 33 ++++++++++++ .../engines/python/chat/__init__.py | 7 +++ .../vectordbs/python/mongo/index.py | 4 +- .../components/vectordbs/python/none/index.py | 10 ++-- .../components/vectordbs/python/pg/index.py | 4 +- .../simple/fastapi/app/api/routers/chat.py | 2 +- .../simple/fastapi/app/engine/__init__.py | 7 +++ .../types/simple/fastapi/app/engine/index.py | 7 --- templates/types/simple/fastapi/pyproject.toml | 2 + .../streaming/fastapi/app/api/routers/chat.py | 2 +- .../streaming/fastapi/app/engine/__init__.py | 7 +++ .../streaming/fastapi/app/engine/index.py | 7 --- .../types/streaming/fastapi/pyproject.toml | 2 + 21 files changed, 258 insertions(+), 30 deletions(-) create mode 100644 helpers/tools.ts create mode 100644 templates/components/engines/python/agent/__init__.py create mode 100644 templates/components/engines/python/agent/tools.py create mode 100644 templates/components/engines/python/chat/__init__.py delete mode 100644 templates/types/simple/fastapi/app/engine/index.py delete mode 100644 templates/types/streaming/fastapi/app/engine/index.py diff --git a/create-app.ts b/create-app.ts index 4c1eaf4d..835c02f1 100644 --- a/create-app.ts +++ b/create-app.ts @@ -1,6 +1,6 @@ /* eslint-disable import/no-extraneous-dependencies */ import path from "path"; -import { green } from "picocolors"; +import { green, yellow } from "picocolors"; import { tryGitInit } from 
"./helpers/git"; import { isFolderEmpty } from "./helpers/is-folder-empty"; import { getOnline } from "./helpers/is-online"; @@ -12,6 +12,7 @@ import terminalLink from "terminal-link"; import type { InstallTemplateArgs } from "./helpers"; import { installTemplate } from "./helpers"; import { templatesDir } from "./helpers/dir"; +import { toolsRequireConfig } from "./helpers/tools"; export type InstallAppArgs = Omit< InstallTemplateArgs, @@ -38,6 +39,7 @@ export async function createApp({ externalPort, postInstallAction, dataSource, + tools, }: InstallAppArgs): Promise<void> { const root = path.resolve(appPath); @@ -82,6 +84,7 @@ export async function createApp({ externalPort, postInstallAction, dataSource, + tools, }; if (frontend) { @@ -114,6 +117,17 @@ export async function createApp({ console.log(); } + if (toolsRequireConfig(tools)) { + console.log( + yellow( + `You have selected tools that require configuration. Please configure them in the ${terminalLink( + "tools_config.json", + `file://${root}/tools_config.json`, + )} file.`, + ), + ); + } + console.log(""); console.log(`${green("Success!")} Created ${appName} at ${appPath}`); console.log( diff --git a/e2e/utils.ts b/e2e/utils.ts index b06776da..2ec35e40 100644 --- a/e2e/utils.ts +++ b/e2e/utils.ts @@ -117,6 +117,8 @@ export async function runCreateLlama( externalPort, "--post-install-action", postInstallAction, + "--tools", + "none", ].join(" "); console.log(`running command '${command}' in ${cwd}`); let appProcess = exec(command, { diff --git a/helpers/python.ts b/helpers/python.ts index 8b2ec6d9..6fbc48b5 100644 --- a/helpers/python.ts +++ b/helpers/python.ts @@ -6,6 +6,7 @@ import terminalLink from "terminal-link"; import { copy } from "./copy"; import { templatesDir } from "./dir"; import { isPoetryAvailable, tryPoetryInstall } from "./poetry"; +import { getToolConfig } from "./tools"; import { InstallTemplateArgs, TemplateVectorDB } from "./types"; interface Dependency { @@ -128,6 +129,7 @@ export 
const installPythonTemplate = async ({ engine, vectorDb, dataSource, + tools, postInstallAction, }: Pick< InstallTemplateArgs, @@ -137,6 +139,7 @@ export const installPythonTemplate = async ({ | "engine" | "vectorDb" | "dataSource" + | "tools" | "postInstallAction" >) => { console.log("\nInitializing Python project with template:", template, "\n"); @@ -162,20 +165,44 @@ export const installPythonTemplate = async ({ }); if (engine === "context") { + const enginePath = path.join(root, "app", "engine"); const compPath = path.join(templatesDir, "components"); - let vectorDbDirName = vectorDb ?? "none"; + + const vectorDbDirName = vectorDb ?? "none"; const VectorDBPath = path.join( compPath, "vectordbs", "python", vectorDbDirName, ); - const enginePath = path.join(root, "app", "engine"); - await copy("**", path.join(root, "app", "engine"), { + await copy("**", enginePath, { parents: true, cwd: VectorDBPath, }); + // Copy engine code + if (tools !== undefined && tools.length > 0) { + await copy("**", enginePath, { + parents: true, + cwd: path.join(compPath, "engines", "python", "agent"), + }); + // Write tools_config.json + const configContent: Record<string, any> = {}; + tools.forEach((tool) => { + configContent[tool] = getToolConfig(tool) ?? 
{}; + }); + const configFilePath = path.join(root, "tools_config.json"); + await fs.writeFile( + configFilePath, + JSON.stringify(configContent, null, 2), + ); + } else { + await copy("**", enginePath, { + parents: true, + cwd: path.join(compPath, "engines", "python", "chat"), + }); + } + const dataSourceType = dataSource?.type; if (dataSourceType !== undefined && dataSourceType !== "none") { let loaderPath = diff --git a/helpers/tools.ts b/helpers/tools.ts new file mode 100644 index 00000000..13b8ff4b --- /dev/null +++ b/helpers/tools.ts @@ -0,0 +1,32 @@ +export type Tool = { + display: string; + name: string; + config?: Record<string, any>; +}; + +export const supportedTools: Tool[] = [ + { + display: "Google Search (configuration required)", + name: "google_search", + config: { + engine: "Your search engine id", + key: "Your search api key", + num: 2, + }, + }, + { + display: "Wikipedia", + name: "wikipedia", + }, +]; + +export const getToolConfig = (name: string) => { + return supportedTools.find((tool) => tool.name === name)?.config; +}; + +export const toolsRequireConfig = (tools?: string[]): boolean => { + if (tools) { + return tools.some((tool) => getToolConfig(tool)); + } + return false; +}; diff --git a/helpers/types.ts b/helpers/types.ts index 191e028f..5e4a9f6e 100644 --- a/helpers/types.ts +++ b/helpers/types.ts @@ -41,4 +41,5 @@ export interface InstallTemplateArgs { vectorDb?: TemplateVectorDB; externalPort?: number; postInstallAction?: TemplatePostInstallAction; + tools?: string[]; } diff --git a/index.ts b/index.ts index 601742c2..6fa19cc5 100644 --- a/index.ts +++ b/index.ts @@ -11,6 +11,7 @@ import { createApp } from "./create-app"; import { getPkgManager } from "./helpers/get-pkg-manager"; import { isFolderEmpty } from "./helpers/is-folder-empty"; import { runApp } from "./helpers/run-app"; +import { supportedTools } from "./helpers/tools"; import { validateNpmName } from "./helpers/validate-pkg"; import packageJson from "./package.json"; import 
{ QuestionArgs, askQuestions, onPromptState } from "./questions"; @@ -146,6 +147,13 @@ const program = new Commander.Command(packageJson.name) ` Select which vector database you would like to use, such as 'none', 'pg' or 'mongo'. The default option is not to use a vector database and use the local filesystem instead ('none'). +`, ) .option( "--tools <tools>", ` + + Specify the tools you want to use by providing a comma-separated list. For example, 'google_search,wikipedia'. Use 'none' to not use any tools. `, ) .allowUnknownOption() @@ -156,6 +164,25 @@ if (process.argv.includes("--no-frontend")) { if (process.argv.includes("--no-eslint")) { program.eslint = false; } +if (process.argv.includes("--tools")) { + if (program.tools === "none") { + program.tools = []; + } else { + program.tools = program.tools.split(","); + // Check if tools are available + const toolsName = supportedTools.map((tool) => tool.name); + program.tools.forEach((tool: string) => { + if (!toolsName.includes(tool)) { + console.error( + `Error: Tool '${tool}' is not supported. Supported tools are: ${toolsName.join( + ", ", + )}`, + ); + process.exit(1); + } + }); + } +} const packageManager = !!program.useNpm ? 
"npm" @@ -256,6 +283,7 @@ async function run(): Promise<void> { externalPort: program.externalPort, postInstallAction: program.postInstallAction, dataSource: program.dataSource, + tools: program.tools, }); conf.set("preferences", preferences); diff --git a/questions.ts b/questions.ts index 0e671184..d6f29b8b 100644 --- a/questions.ts +++ b/questions.ts @@ -10,6 +10,7 @@ import { COMMUNITY_OWNER, COMMUNITY_REPO } from "./helpers/constant"; import { templatesDir } from "./helpers/dir"; import { getAvailableLlamapackOptions } from "./helpers/llama-pack"; import { getRepoRootFolders } from "./helpers/repo"; +import { supportedTools, toolsRequireConfig } from "./helpers/tools"; export type QuestionArgs = Omit< InstallAppArgs, @@ -70,6 +71,7 @@ const defaults: QuestionArgs = { type: "none", config: {}, }, + tools: [], }; const handlers = { @@ -214,7 +216,12 @@ export const askQuestions = async ( const hasOpenAiKey = program.openAiKey || process.env["OPENAI_API_KEY"]; const hasVectorDb = program.vectorDb && program.vectorDb !== "none"; - if (!hasVectorDb && hasOpenAiKey) { + // Can run the app if all tools do not require configuration + if ( + !hasVectorDb && + hasOpenAiKey && + !toolsRequireConfig(program.tools) + ) { actionChoices.push({ title: "Generate code, install dependencies, and run the app (~2 min)", @@ -563,6 +570,29 @@ export const askQuestions = async ( } } + if ( + !program.tools && + program.framework === "fastapi" && + program.engine === "context" + ) { + if (ciInfo.isCI) { + program.tools = getPrefOrDefault("tools"); + } else { + const toolChoices = supportedTools.map((tool) => ({ + title: tool.display, + value: tool.name, + })); + const { tools } = await prompts({ + type: "multiselect", + name: "tools", + message: "Which tools would you like to use?", + choices: toolChoices, + }); + program.tools = tools; + preferences.tools = tools; + } + } + if (!program.openAiKey) { const { key } = await prompts( { diff --git 
a/templates/components/engines/python/agent/__init__.py b/templates/components/engines/python/agent/__init__.py new file mode 100644 index 00000000..f1b62b87 --- /dev/null +++ b/templates/components/engines/python/agent/__init__.py @@ -0,0 +1,50 @@ +import os + +from typing import Any, Optional +from llama_index.llms import LLM +from llama_index.agent import AgentRunner + +from app.engine.tools import ToolFactory +from app.engine.index import get_index +from llama_index.agent import ReActAgent +from llama_index.tools.query_engine import QueryEngineTool + + +def create_agent_from_llm( + llm: Optional[LLM] = None, + **kwargs: Any, +) -> AgentRunner: + from llama_index.agent import OpenAIAgent, ReActAgent + from llama_index.llms.openai import OpenAI + from llama_index.llms.openai_utils import is_function_calling_model + + if isinstance(llm, OpenAI) and is_function_calling_model(llm.model): + return OpenAIAgent.from_tools( + llm=llm, + **kwargs, + ) + else: + return ReActAgent.from_tools( + llm=llm, + **kwargs, + ) + + +def get_chat_engine(): + tools = [] + + # Add query tool + index = get_index() + llm = index.service_context.llm + query_engine = index.as_query_engine(similarity_top_k=5) + query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine) + tools.append(query_engine_tool) + + # Add additional tools + tools += ToolFactory.from_env() + + return create_agent_from_llm( + llm=llm, + tools=tools, + verbose=True, + ) diff --git a/templates/components/engines/python/agent/tools.py b/templates/components/engines/python/agent/tools.py new file mode 100644 index 00000000..9fb9d488 --- /dev/null +++ b/templates/components/engines/python/agent/tools.py @@ -0,0 +1,33 @@ +import json +import importlib + +from llama_index.tools.tool_spec.base import BaseToolSpec +from llama_index.tools.function_tool import FunctionTool + + +class ToolFactory: + + @staticmethod + def create_tool(tool_name: str, **kwargs) -> list[FunctionTool]: + try: + module_name = 
f"llama_hub.tools.{tool_name}.base" + module = importlib.import_module(module_name) + tool_cls_name = tool_name.title().replace("_", "") + "ToolSpec" + tool_class = getattr(module, tool_cls_name) + tool_spec: BaseToolSpec = tool_class(**kwargs) + return tool_spec.to_tool_list() + except (ImportError, AttributeError) as e: + raise ValueError(f"Unsupported tool: {tool_name}") from e + except TypeError as e: + raise ValueError( + f"Could not create tool: {tool_name}. With config: {kwargs}" + ) from e + + @staticmethod + def from_env() -> list[FunctionTool]: + tools = [] + with open("tools_config.json", "r") as f: + tool_configs = json.load(f) + for name, config in tool_configs.items(): + tools += ToolFactory.create_tool(name, **config) + return tools diff --git a/templates/components/engines/python/chat/__init__.py b/templates/components/engines/python/chat/__init__.py new file mode 100644 index 00000000..18a6039b --- /dev/null +++ b/templates/components/engines/python/chat/__init__.py @@ -0,0 +1,7 @@ +from app.engine.index import get_index + + +def get_chat_engine(): + return get_index().as_chat_engine( + similarity_top_k=5, chat_mode="condense_plus_context" + ) diff --git a/templates/components/vectordbs/python/mongo/index.py b/templates/components/vectordbs/python/mongo/index.py index a80590b5..173e7b57 100644 --- a/templates/components/vectordbs/python/mongo/index.py +++ b/templates/components/vectordbs/python/mongo/index.py @@ -9,7 +9,7 @@ from llama_index.vector_stores import MongoDBAtlasVectorSearch from app.engine.context import create_service_context -def get_chat_engine(): +def get_index(): service_context = create_service_context() logger = logging.getLogger("uvicorn") logger.info("Connecting to index from MongoDB...") @@ -20,4 +20,4 @@ def get_chat_engine(): ) index = VectorStoreIndex.from_vector_store(store, service_context) logger.info("Finished connecting to index from MongoDB.") - return index.as_chat_engine(similarity_top_k=5, 
chat_mode="condense_plus_context") + return index diff --git a/templates/components/vectordbs/python/none/index.py b/templates/components/vectordbs/python/none/index.py index 4404c66e..8e16975b 100644 --- a/templates/components/vectordbs/python/none/index.py +++ b/templates/components/vectordbs/python/none/index.py @@ -1,15 +1,15 @@ import logging import os + +from app.engine.constants import STORAGE_DIR +from app.engine.context import create_service_context from llama_index import ( StorageContext, load_index_from_storage, ) -from app.engine.constants import STORAGE_DIR -from app.engine.context import create_service_context - -def get_chat_engine(): +def get_index(): service_context = create_service_context() # check if storage already exists if not os.path.exists(STORAGE_DIR): @@ -22,4 +22,4 @@ def get_chat_engine(): storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR) index = load_index_from_storage(storage_context, service_context=service_context) logger.info(f"Finished loading index from {STORAGE_DIR}") - return index.as_chat_engine(similarity_top_k=5, chat_mode="condense_plus_context") + return index diff --git a/templates/components/vectordbs/python/pg/index.py b/templates/components/vectordbs/python/pg/index.py index 510c21a5..368fb432 100644 --- a/templates/components/vectordbs/python/pg/index.py +++ b/templates/components/vectordbs/python/pg/index.py @@ -6,11 +6,11 @@ from app.engine.context import create_service_context from app.engine.utils import init_pg_vector_store_from_env -def get_chat_engine(): +def get_index(): service_context = create_service_context() logger = logging.getLogger("uvicorn") logger.info("Connecting to index from PGVector...") store = init_pg_vector_store_from_env() index = VectorStoreIndex.from_vector_store(store, service_context) logger.info("Finished connecting to index from PGVector.") - return index.as_chat_engine(similarity_top_k=5, chat_mode="condense_plus_context") + return index diff --git 
a/templates/types/simple/fastapi/app/api/routers/chat.py b/templates/types/simple/fastapi/app/api/routers/chat.py index d4b000fd..09efdcaa 100644 --- a/templates/types/simple/fastapi/app/api/routers/chat.py +++ b/templates/types/simple/fastapi/app/api/routers/chat.py @@ -5,7 +5,7 @@ from llama_index.chat_engine.types import BaseChatEngine from llama_index.llms.base import ChatMessage from llama_index.llms.types import MessageRole from pydantic import BaseModel -from app.engine.index import get_chat_engine +from app.engine import get_chat_engine chat_router = r = APIRouter() diff --git a/templates/types/simple/fastapi/app/engine/__init__.py b/templates/types/simple/fastapi/app/engine/__init__.py index e69de29b..663b595a 100644 --- a/templates/types/simple/fastapi/app/engine/__init__.py +++ b/templates/types/simple/fastapi/app/engine/__init__.py @@ -0,0 +1,7 @@ +from llama_index.chat_engine import SimpleChatEngine + +from app.context import create_base_context + + +def get_chat_engine(): + return SimpleChatEngine.from_defaults(service_context=create_base_context()) diff --git a/templates/types/simple/fastapi/app/engine/index.py b/templates/types/simple/fastapi/app/engine/index.py deleted file mode 100644 index 663b595a..00000000 --- a/templates/types/simple/fastapi/app/engine/index.py +++ /dev/null @@ -1,7 +0,0 @@ -from llama_index.chat_engine import SimpleChatEngine - -from app.context import create_base_context - - -def get_chat_engine(): - return SimpleChatEngine.from_defaults(service_context=create_base_context()) diff --git a/templates/types/simple/fastapi/pyproject.toml b/templates/types/simple/fastapi/pyproject.toml index d1952f4a..42f5faf0 100644 --- a/templates/types/simple/fastapi/pyproject.toml +++ b/templates/types/simple/fastapi/pyproject.toml @@ -13,6 +13,8 @@ llama-index = "^0.9.19" pypdf = "^3.17.0" python-dotenv = "^1.0.0" docx2txt = "^0.8" +llama-hub = "^0.0.77" +wikipedia = "^1.4.0" [build-system] requires = ["poetry-core"] diff --git 
a/templates/types/streaming/fastapi/app/api/routers/chat.py b/templates/types/streaming/fastapi/app/api/routers/chat.py index 26fd480d..0afe14e4 100644 --- a/templates/types/streaming/fastapi/app/api/routers/chat.py +++ b/templates/types/streaming/fastapi/app/api/routers/chat.py @@ -3,7 +3,7 @@ from typing import List from fastapi.responses import StreamingResponse from llama_index.chat_engine.types import BaseChatEngine -from app.engine.index import get_chat_engine +from app.engine import get_chat_engine from fastapi import APIRouter, Depends, HTTPException, Request, status from llama_index.llms.base import ChatMessage from llama_index.llms.types import MessageRole diff --git a/templates/types/streaming/fastapi/app/engine/__init__.py b/templates/types/streaming/fastapi/app/engine/__init__.py index e69de29b..663b595a 100644 --- a/templates/types/streaming/fastapi/app/engine/__init__.py +++ b/templates/types/streaming/fastapi/app/engine/__init__.py @@ -0,0 +1,7 @@ +from llama_index.chat_engine import SimpleChatEngine + +from app.context import create_base_context + + +def get_chat_engine(): + return SimpleChatEngine.from_defaults(service_context=create_base_context()) diff --git a/templates/types/streaming/fastapi/app/engine/index.py b/templates/types/streaming/fastapi/app/engine/index.py deleted file mode 100644 index 663b595a..00000000 --- a/templates/types/streaming/fastapi/app/engine/index.py +++ /dev/null @@ -1,7 +0,0 @@ -from llama_index.chat_engine import SimpleChatEngine - -from app.context import create_base_context - - -def get_chat_engine(): - return SimpleChatEngine.from_defaults(service_context=create_base_context()) diff --git a/templates/types/streaming/fastapi/pyproject.toml b/templates/types/streaming/fastapi/pyproject.toml index d1952f4a..42f5faf0 100644 --- a/templates/types/streaming/fastapi/pyproject.toml +++ b/templates/types/streaming/fastapi/pyproject.toml @@ -13,6 +13,8 @@ llama-index = "^0.9.19" pypdf = "^3.17.0" python-dotenv = "^1.0.0" 
docx2txt = "^0.8" +llama-hub = "^0.0.77" +wikipedia = "^1.4.0" [build-system] requires = ["poetry-core"] -- GitLab