diff --git a/packages/create-llama/questions.ts b/packages/create-llama/questions.ts
index 109ac9fd8a27a4315949306041620e5f025d6b27..42fd144c0102fc40bf7e5481b6694f27787af827 100644
--- a/packages/create-llama/questions.ts
+++ b/packages/create-llama/questions.ts
@@ -189,7 +189,11 @@ export const askQuestions = async (
     }
   }
 
-  if (program.framework === "express" || program.framework === "nextjs") {
+  if (
+    program.framework === "express" ||
+    program.framework === "nextjs" ||
+    program.framework === "fastapi"
+  ) {
     if (!program.model) {
       if (ciInfo.isCI) {
         program.model = getPrefOrDefault("model");
@@ -218,7 +222,11 @@ export const askQuestions = async (
     }
   }
 
-  if (program.framework === "express" || program.framework === "nextjs") {
+  if (
+    program.framework === "express" ||
+    program.framework === "nextjs" ||
+    program.framework === "fastapi"
+  ) {
     if (!program.engine) {
       if (ciInfo.isCI) {
         program.engine = getPrefOrDefault("engine");
@@ -243,7 +251,11 @@ export const askQuestions = async (
       preferences.engine = engine;
     }
   }
-  if (program.engine !== "simple" && !program.vectorDb) {
+  if (
+    program.engine !== "simple" &&
+    !program.vectorDb &&
+    program.framework !== "fastapi"
+  ) {
     if (ciInfo.isCI) {
       program.vectorDb = getPrefOrDefault("vectorDb");
     } else {
diff --git a/packages/create-llama/templates/types/streaming/fastapi/app/engine/constants.py b/packages/create-llama/templates/components/vectordbs/python/none/constants.py
similarity index 100%
rename from packages/create-llama/templates/types/streaming/fastapi/app/engine/constants.py
rename to packages/create-llama/templates/components/vectordbs/python/none/constants.py
diff --git a/packages/create-llama/templates/types/streaming/fastapi/app/engine/context.py b/packages/create-llama/templates/components/vectordbs/python/none/context.py
similarity index 100%
rename from packages/create-llama/templates/types/streaming/fastapi/app/engine/context.py
rename to packages/create-llama/templates/components/vectordbs/python/none/context.py
diff --git a/packages/create-llama/templates/types/streaming/fastapi/app/engine/generate.py b/packages/create-llama/templates/components/vectordbs/python/none/generate.py
similarity index 100%
rename from packages/create-llama/templates/types/streaming/fastapi/app/engine/generate.py
rename to packages/create-llama/templates/components/vectordbs/python/none/generate.py
diff --git a/packages/create-llama/templates/components/vectordbs/python/none/index.py b/packages/create-llama/templates/components/vectordbs/python/none/index.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f7d36030c9485b359e2c23c855fd6e2ddc90fef
--- /dev/null
+++ b/packages/create-llama/templates/components/vectordbs/python/none/index.py
@@ -0,0 +1,25 @@
+import logging
+import os
+from llama_index import (
+    StorageContext,
+    load_index_from_storage,
+)
+
+from app.engine.constants import STORAGE_DIR
+from app.engine.context import create_service_context
+
+
+def get_chat_engine():
+    service_context = create_service_context()
+    # check if storage already exists
+    if not os.path.exists(STORAGE_DIR):
+        raise Exception(
+            "StorageContext is empty - call 'npm run generate' to generate the storage first"
+        )
+    logger = logging.getLogger("uvicorn")
+    # load the existing index
+    logger.info(f"Loading index from {STORAGE_DIR}...")
+    storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR)
+    index = load_index_from_storage(storage_context, service_context=service_context)
+    logger.info(f"Finished loading index from {STORAGE_DIR}")
+    return index.as_chat_engine()
diff --git a/packages/create-llama/templates/index.ts b/packages/create-llama/templates/index.ts
index 2757f002c1c44e140d35b7b38bb5ae9591606a26..c9af6c180f8edf23ec8db05774ed9cefa858e785 100644
--- a/packages/create-llama/templates/index.ts
+++ b/packages/create-llama/templates/index.ts
@@ -311,7 +311,8 @@ const installPythonTemplate = async ({
   root,
   template,
   framework,
-}: Pick<InstallTemplateArgs, "root" | "framework" | "template">) => {
+  engine,
+}: Pick<InstallTemplateArgs, "root" | "framework" | "template" | "engine">) => {
   console.log("\nInitializing Python project with template:", template, "\n");
   const templatePath = path.join(__dirname, "types", template, framework);
   await copy("**", root, {
@@ -334,6 +335,15 @@ const installPythonTemplate = async ({
     },
   });
 
+  if (engine === "context") {
+    const compPath = path.join(__dirname, "components");
+    const VectorDBPath = path.join(compPath, "vectordbs", "python", "none");
+    await copy("**", path.join(root, "app", "engine"), {
+      parents: true,
+      cwd: VectorDBPath,
+    });
+  }
+
   console.log(
     "\nPython project, dependencies won't be installed automatically.\n",
   );
diff --git a/packages/create-llama/templates/types/simple/fastapi/app/api/routers/chat.py b/packages/create-llama/templates/types/simple/fastapi/app/api/routers/chat.py
index e728db0dc3d07dfc13b09468235c94a903d3ce7d..b90820550292a0c59f6c22e62f0303b450d6678f 100644
--- a/packages/create-llama/templates/types/simple/fastapi/app/api/routers/chat.py
+++ b/packages/create-llama/templates/types/simple/fastapi/app/api/routers/chat.py
@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
 from llama_index import VectorStoreIndex
 from llama_index.llms.base import MessageRole, ChatMessage
 from pydantic import BaseModel
-from app.context import get_index
+from app.engine.index import get_chat_engine
 
 chat_router = r = APIRouter()
 
@@ -25,7 +25,7 @@ class _Result(BaseModel):
 @r.post("")
 async def chat(
     data: _ChatData,
-    index: VectorStoreIndex = Depends(get_index),
+    chat_engine: VectorStoreIndex = Depends(get_chat_engine),
 ) -> _Result:
     # check preconditions and get last message
     if len(data.messages) == 0:
@@ -49,7 +49,6 @@ async def chat(
     ]
 
     # query chat engine
-    chat_engine = index.as_chat_engine()
     response = chat_engine.chat(lastMessage.content, messages)
     return _Result(
         result=_Message(role=MessageRole.ASSISTANT, content=response.response)
diff --git a/packages/create-llama/templates/types/simple/fastapi/app/context.py b/packages/create-llama/templates/types/simple/fastapi/app/engine/index.py
similarity index 94%
rename from packages/create-llama/templates/types/simple/fastapi/app/context.py
rename to packages/create-llama/templates/types/simple/fastapi/app/engine/index.py
index 48ca79a90c57bcd14f06f32642c903fa814703b1..8a5bfa5c583d98fe43f070e4e002a4fc34401454 100644
--- a/packages/create-llama/templates/types/simple/fastapi/app/context.py
+++ b/packages/create-llama/templates/types/simple/fastapi/app/engine/index.py
@@ -38,3 +38,7 @@ def get_index():
     index = load_index_from_storage(storage_context,service_context=service_context)
     logger.info(f"Finished loading index from {STORAGE_DIR}")
     return index
+
+def get_chat_engine():
+    index = get_index()
+    return index.as_chat_engine()
diff --git a/packages/create-llama/templates/types/streaming/fastapi/app/context.py b/packages/create-llama/templates/types/streaming/fastapi/app/context.py
deleted file mode 100644
index ae00de217c8741e080c981cc3fed21f24fe19961..0000000000000000000000000000000000000000
--- a/packages/create-llama/templates/types/streaming/fastapi/app/context.py
+++ /dev/null
@@ -1,11 +0,0 @@
-import os
-
-from llama_index import ServiceContext
-from llama_index.llms import OpenAI
-
-
-def create_base_context():
-    model = os.getenv("MODEL", "gpt-3.5-turbo")
-    return ServiceContext.from_defaults(
-        llm=OpenAI(model=model),
-    )
diff --git a/packages/create-llama/templates/types/streaming/fastapi/app/engine/__init__.py b/packages/create-llama/templates/types/streaming/fastapi/app/engine/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/packages/create-llama/templates/types/streaming/fastapi/app/engine/index.py b/packages/create-llama/templates/types/streaming/fastapi/app/engine/index.py
index 8f7d36030c9485b359e2c23c855fd6e2ddc90fef..8a5bfa5c583d98fe43f070e4e002a4fc34401454 100644
--- a/packages/create-llama/templates/types/streaming/fastapi/app/engine/index.py
+++ b/packages/create-llama/templates/types/streaming/fastapi/app/engine/index.py
@@ -1,25 +1,44 @@
-import logging
 import os
+import logging
+
 from llama_index import (
+    SimpleDirectoryReader,
     StorageContext,
+    VectorStoreIndex,
     load_index_from_storage,
+    ServiceContext,
 )
+from llama_index.llms import OpenAI
 
-from app.engine.constants import STORAGE_DIR
-from app.engine.context import create_service_context
+STORAGE_DIR = "./storage"  # directory to cache the generated index
+DATA_DIR = "./data"  # directory containing the documents to index
 
+def create_base_context():
+    model = os.getenv("MODEL", "gpt-3.5-turbo")
+    return ServiceContext.from_defaults(
+        llm=OpenAI(model=model),
+    )
 
-def get_chat_engine():
-    service_context = create_service_context()
+def get_index():
+    service_context = create_base_context()
+    logger = logging.getLogger("uvicorn")
     # check if storage already exists
     if not os.path.exists(STORAGE_DIR):
-        raise Exception(
-            "StorageContext is empty - call 'npm run generate' to generate the storage first"
-        )
-    logger = logging.getLogger("uvicorn")
-    # load the existing index
-    logger.info(f"Loading index from {STORAGE_DIR}...")
-    storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR)
-    index = load_index_from_storage(storage_context, service_context=service_context)
-    logger.info(f"Finished loading index from {STORAGE_DIR}")
+        logger.info("Creating new index")
+        # load the documents and create the index
+        documents = SimpleDirectoryReader(DATA_DIR).load_data()
+        index = VectorStoreIndex.from_documents(documents,service_context=service_context)
+        # store it for later
+        index.storage_context.persist(STORAGE_DIR)
+        logger.info(f"Finished creating new index. Stored in {STORAGE_DIR}")
+    else:
+        # load the existing index
+        logger.info(f"Loading index from {STORAGE_DIR}...")
+        storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR)
+        index = load_index_from_storage(storage_context,service_context=service_context)
+        logger.info(f"Finished loading index from {STORAGE_DIR}")
+    return index
+
+def get_chat_engine():
+    index = get_index()
     return index.as_chat_engine()
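
Taken together, these changes give every generated FastAPI project a single get_chat_engine entry point in app/engine/index.py, which the chat router consumes through FastAPI's dependency injection instead of building the chat engine inline. Below is a minimal sketch of how that dependency resolves at request time; the route is illustrative only (it is not the template's full handler) and assumes a generated project whose index data is available:

from fastapi import Depends, FastAPI
from app.engine.index import get_chat_engine

app = FastAPI()

@app.post("/api/chat")
async def chat(chat_engine=Depends(get_chat_engine)):
    # FastAPI calls get_chat_engine() per request. With the "context" engine
    # this loads the index persisted by 'npm run generate' (raising if it is
    # missing); with the simple engine it builds or loads the index on the
    # fly. Either way the route receives index.as_chat_engine().
    response = chat_engine.chat("What do the indexed documents say?")
    return {"result": response.response}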