From 09d4e36200347b06ef2af4596b9c6e67f9cea9bf Mon Sep 17 00:00:00 2001
From: thucpn <thucsh2@gmail.com>
Date: Mon, 25 Dec 2023 15:48:53 +0700
Subject: [PATCH] feat: create chat engine folder for python
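
Move the FastAPI chat engine code into a reusable components folder and
expose a single get_chat_engine() entry point:

- questions.ts: also ask for the model and the engine when the framework
  is "fastapi"; the vector database question stays skipped for fastapi
  for now
- templates/index.ts: when the "context" engine is selected, copy the
  files from components/vectordbs/python/none into the generated
  project's app/engine folder
- simple template: move app/context.py to app/engine/index.py, add
  get_chat_engine(), and inject the chat engine into the chat router
  instead of building it from the index inside the route handler
- streaming template: inline the service context and the index loading
  into app/engine/index.py so the default template keeps working without
  the component files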

---
 packages/create-llama/questions.ts            | 18 +++++--
 .../vectordbs/python/none}/constants.py       |  0
 .../vectordbs/python/none}/context.py         |  0
 .../vectordbs/python/none}/generate.py        |  0
 .../components/vectordbs/python/none/index.py | 25 ++++++++++
 packages/create-llama/templates/index.ts      | 12 ++++-
 .../simple/fastapi/app/api/routers/chat.py    |  7 ++--
 .../app/{context.py => engine/index.py}       |  4 ++
 .../types/streaming/fastapi/app/context.py    | 11 -----
 .../streaming/fastapi/app/engine/__init__.py  |  0
 .../streaming/fastapi/app/engine/index.py     | 47 +++++++++++++------
 11 files changed, 91 insertions(+), 33 deletions(-)
 rename packages/create-llama/templates/{types/streaming/fastapi/app/engine => components/vectordbs/python/none}/constants.py (100%)
 rename packages/create-llama/templates/{types/streaming/fastapi/app/engine => components/vectordbs/python/none}/context.py (100%)
 rename packages/create-llama/templates/{types/streaming/fastapi/app/engine => components/vectordbs/python/none}/generate.py (100%)
 create mode 100644 packages/create-llama/templates/components/vectordbs/python/none/index.py
 rename packages/create-llama/templates/types/simple/fastapi/app/{context.py => engine/index.py} (94%)
 delete mode 100644 packages/create-llama/templates/types/streaming/fastapi/app/context.py
 delete mode 100644 packages/create-llama/templates/types/streaming/fastapi/app/engine/__init__.py

diff --git a/packages/create-llama/questions.ts b/packages/create-llama/questions.ts
index 109ac9fd8..42fd144c0 100644
--- a/packages/create-llama/questions.ts
+++ b/packages/create-llama/questions.ts
@@ -189,7 +189,11 @@ export const askQuestions = async (
     }
   }
 
-  if (program.framework === "express" || program.framework === "nextjs") {
+  if (
+    program.framework === "express" ||
+    program.framework === "nextjs" ||
+    program.framework === "fastapi"
+  ) {
     if (!program.model) {
       if (ciInfo.isCI) {
         program.model = getPrefOrDefault("model");
@@ -218,7 +222,11 @@ export const askQuestions = async (
     }
   }
 
-  if (program.framework === "express" || program.framework === "nextjs") {
+  if (
+    program.framework === "express" ||
+    program.framework === "nextjs" ||
+    program.framework === "fastapi"
+  ) {
     if (!program.engine) {
       if (ciInfo.isCI) {
         program.engine = getPrefOrDefault("engine");
@@ -243,7 +251,11 @@ export const askQuestions = async (
         preferences.engine = engine;
       }
     }
-    if (program.engine !== "simple" && !program.vectorDb) {
+    if (
+      program.engine !== "simple" &&
+      !program.vectorDb &&
+      program.framework !== "fastapi"
+    ) {
       if (ciInfo.isCI) {
         program.vectorDb = getPrefOrDefault("vectorDb");
       } else {
diff --git a/packages/create-llama/templates/types/streaming/fastapi/app/engine/constants.py b/packages/create-llama/templates/components/vectordbs/python/none/constants.py
similarity index 100%
rename from packages/create-llama/templates/types/streaming/fastapi/app/engine/constants.py
rename to packages/create-llama/templates/components/vectordbs/python/none/constants.py
diff --git a/packages/create-llama/templates/types/streaming/fastapi/app/engine/context.py b/packages/create-llama/templates/components/vectordbs/python/none/context.py
similarity index 100%
rename from packages/create-llama/templates/types/streaming/fastapi/app/engine/context.py
rename to packages/create-llama/templates/components/vectordbs/python/none/context.py
diff --git a/packages/create-llama/templates/types/streaming/fastapi/app/engine/generate.py b/packages/create-llama/templates/components/vectordbs/python/none/generate.py
similarity index 100%
rename from packages/create-llama/templates/types/streaming/fastapi/app/engine/generate.py
rename to packages/create-llama/templates/components/vectordbs/python/none/generate.py
diff --git a/packages/create-llama/templates/components/vectordbs/python/none/index.py b/packages/create-llama/templates/components/vectordbs/python/none/index.py
new file mode 100644
index 000000000..8f7d36030
--- /dev/null
+++ b/packages/create-llama/templates/components/vectordbs/python/none/index.py
@@ -0,0 +1,25 @@
+import logging
+import os
+from llama_index import (
+    StorageContext,
+    load_index_from_storage,
+)
+
+from app.engine.constants import STORAGE_DIR
+from app.engine.context import create_service_context
+
+
+def get_chat_engine():
+    service_context = create_service_context()
+    # check if storage already exists
+    if not os.path.exists(STORAGE_DIR):
+        raise Exception(
+            "StorageContext is empty - call 'npm run generate' to generate the storage first"
+        )
+    logger = logging.getLogger("uvicorn")
+    # load the existing index
+    logger.info(f"Loading index from {STORAGE_DIR}...")
+    storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR)
+    index = load_index_from_storage(storage_context, service_context=service_context)
+    logger.info(f"Finished loading index from {STORAGE_DIR}")
+    return index.as_chat_engine()
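
Note: the router consumes get_chat_engine() through FastAPI's dependency
injection. A minimal sketch of the expected wiring, mirroring the chat.py
changes further down in this patch (not part of the diff itself):

    from fastapi import APIRouter, Depends
    from llama_index.chat_engine.types import BaseChatEngine

    from app.engine.index import get_chat_engine

    r = APIRouter()

    @r.post("")
    async def chat(chat_engine: BaseChatEngine = Depends(get_chat_engine)):
        # chat() returns a response object whose .response holds the text
        response = chat_engine.chat("Hello")
        return {"result": response.response}
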
diff --git a/packages/create-llama/templates/index.ts b/packages/create-llama/templates/index.ts
index 2757f002c..c9af6c180 100644
--- a/packages/create-llama/templates/index.ts
+++ b/packages/create-llama/templates/index.ts
@@ -311,7 +311,8 @@ const installPythonTemplate = async ({
   root,
   template,
   framework,
-}: Pick<InstallTemplateArgs, "root" | "framework" | "template">) => {
+  engine,
+}: Pick<InstallTemplateArgs, "root" | "framework" | "template" | "engine">) => {
   console.log("\nInitializing Python project with template:", template, "\n");
   const templatePath = path.join(__dirname, "types", template, framework);
   await copy("**", root, {
@@ -334,6 +335,15 @@ const installPythonTemplate = async ({
     },
   });
 
+  if (engine === "context") {
+    const compPath = path.join(__dirname, "components");
+    const vectorDbPath = path.join(compPath, "vectordbs", "python", "none");
+    await copy("**", path.join(root, "app", "engine"), {
+      parents: true,
+      cwd: vectorDbPath,
+    });
+  }
+
   console.log(
     "\nPython project, dependencies won't be installed automatically.\n",
   );
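
Note: with the "context" engine, the copy step above should leave the
generated FastAPI project with roughly this app/engine layout (file names
taken from the renames in this patch; the role comments are assumptions):

    app/engine/
        constants.py   # STORAGE_DIR
        context.py     # create_service_context()
        generate.py    # builds and persists the index
        index.py       # get_chat_engine()
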
diff --git a/packages/create-llama/templates/types/simple/fastapi/app/api/routers/chat.py b/packages/create-llama/templates/types/simple/fastapi/app/api/routers/chat.py
index e728db0dc..b90820550 100644
--- a/packages/create-llama/templates/types/simple/fastapi/app/api/routers/chat.py
+++ b/packages/create-llama/templates/types/simple/fastapi/app/api/routers/chat.py
@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
-from llama_index import VectorStoreIndex
+from llama_index.chat_engine.types import BaseChatEngine
 from llama_index.llms.base import MessageRole, ChatMessage
 from pydantic import BaseModel
-from app.context import get_index
+from app.engine.index import get_chat_engine
 
 chat_router = r = APIRouter()
 
@@ -25,7 +25,7 @@ class _Result(BaseModel):
 @r.post("")
 async def chat(
     data: _ChatData,
-    index: VectorStoreIndex = Depends(get_index),
+    chat_engine: BaseChatEngine = Depends(get_chat_engine),
 ) -> _Result:
     # check preconditions and get last message
     if len(data.messages) == 0:
@@ -49,7 +49,6 @@ async def chat(
     ]
 
     # query chat engine
-    chat_engine = index.as_chat_engine()
     response = chat_engine.chat(lastMessage.content, messages)
     return _Result(
         result=_Message(role=MessageRole.ASSISTANT, content=response.response)
diff --git a/packages/create-llama/templates/types/simple/fastapi/app/context.py b/packages/create-llama/templates/types/simple/fastapi/app/engine/index.py
similarity index 94%
rename from packages/create-llama/templates/types/simple/fastapi/app/context.py
rename to packages/create-llama/templates/types/simple/fastapi/app/engine/index.py
index 48ca79a90..8a5bfa5c5 100644
--- a/packages/create-llama/templates/types/simple/fastapi/app/context.py
+++ b/packages/create-llama/templates/types/simple/fastapi/app/engine/index.py
@@ -38,3 +38,7 @@ def get_index():
         index = load_index_from_storage(storage_context,service_context=service_context)
         logger.info(f"Finished loading index from {STORAGE_DIR}")
     return index
+
+def get_chat_engine():
+    index = get_index()
+    return index.as_chat_engine()
diff --git a/packages/create-llama/templates/types/streaming/fastapi/app/context.py b/packages/create-llama/templates/types/streaming/fastapi/app/context.py
deleted file mode 100644
index ae00de217..000000000
--- a/packages/create-llama/templates/types/streaming/fastapi/app/context.py
+++ /dev/null
@@ -1,11 +0,0 @@
-import os
-
-from llama_index import ServiceContext
-from llama_index.llms import OpenAI
-
-
-def create_base_context():
-    model = os.getenv("MODEL", "gpt-3.5-turbo")
-    return ServiceContext.from_defaults(
-        llm=OpenAI(model=model),
-    )
diff --git a/packages/create-llama/templates/types/streaming/fastapi/app/engine/__init__.py b/packages/create-llama/templates/types/streaming/fastapi/app/engine/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/packages/create-llama/templates/types/streaming/fastapi/app/engine/index.py b/packages/create-llama/templates/types/streaming/fastapi/app/engine/index.py
index 8f7d36030..8a5bfa5c5 100644
--- a/packages/create-llama/templates/types/streaming/fastapi/app/engine/index.py
+++ b/packages/create-llama/templates/types/streaming/fastapi/app/engine/index.py
@@ -1,25 +1,44 @@
-import logging
 import os
+import logging
+
 from llama_index import (
+    SimpleDirectoryReader,
     StorageContext,
+    VectorStoreIndex,
     load_index_from_storage,
+    ServiceContext,
 )
+from llama_index.llms import OpenAI
 
-from app.engine.constants import STORAGE_DIR
-from app.engine.context import create_service_context
+STORAGE_DIR = "./storage"  # directory to cache the generated index
+DATA_DIR = "./data"  # directory containing the documents to index
 
+def create_base_context():
+    model = os.getenv("MODEL", "gpt-3.5-turbo")
+    return ServiceContext.from_defaults(
+        llm=OpenAI(model=model),
+    )
 
-def get_chat_engine():
-    service_context = create_service_context()
+def get_index():
+    service_context = create_base_context()
+    logger = logging.getLogger("uvicorn")
     # check if storage already exists
     if not os.path.exists(STORAGE_DIR):
-        raise Exception(
-            "StorageContext is empty - call 'npm run generate' to generate the storage first"
-        )
-    logger = logging.getLogger("uvicorn")
-    # load the existing index
-    logger.info(f"Loading index from {STORAGE_DIR}...")
-    storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR)
-    index = load_index_from_storage(storage_context, service_context=service_context)
-    logger.info(f"Finished loading index from {STORAGE_DIR}")
+        logger.info("Creating new index")
+        # load the documents and create the index
+        documents = SimpleDirectoryReader(DATA_DIR).load_data()
+        index = VectorStoreIndex.from_documents(documents, service_context=service_context)
+        # store it for later
+        index.storage_context.persist(STORAGE_DIR)
+        logger.info(f"Finished creating new index. Stored in {STORAGE_DIR}")
+    else:
+        # load the existing index
+        logger.info(f"Loading index from {STORAGE_DIR}...")
+        storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR)
+        index = load_index_from_storage(storage_context, service_context=service_context)
+        logger.info(f"Finished loading index from {STORAGE_DIR}")
+    return index
+
+def get_chat_engine():
+    index = get_index()
     return index.as_chat_engine()
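
Note: a quick way to smoke-test this engine outside the server is a
minimal sketch like the following, assuming ./data contains at least one
document and OPENAI_API_KEY is set in the environment:

    from app.engine.index import get_chat_engine

    # First run builds ./storage from ./data; later runs load the cache.
    chat_engine = get_chat_engine()
    print(chat_engine.chat("What is in my documents?").response)
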
-- 
GitLab