diff --git a/templates/simple/fastapi/README-template.md b/templates/simple/fastapi/README-template.md
index baa5fa63fcb1c07f8d74af3aa2eabd7bb493fda2..f0b92bdfce648a374bd2723fee4ceec69605db69 100644
--- a/templates/simple/fastapi/README-template.md
+++ b/templates/simple/fastapi/README-template.md
@@ -27,6 +27,12 @@ You can start editing the API by modifying `app/api/routers/chat.py`. The endpoi
 
 Open [http://localhost:8000/docs](http://localhost:8000/docs) with your browser to see the Swagger UI of the API.
 
+The API allows CORS for all origins to simplify development. You can change this behavior by setting the `ENVIRONMENT` environment variable to `prod`:
+
+```
+ENVIRONMENT=prod uvicorn main:app
+```
+
 ## Learn More
 
 To learn more about LlamaIndex, take a look at the following resources:
diff --git a/templates/simple/fastapi/app/api/routers/chat.py b/templates/simple/fastapi/app/api/routers/chat.py
index bd6a38c515df2f612ec2665e201329f950da80be..2d20a6f65aed631a6d2e2e6163a031456463a7f6 100644
--- a/templates/simple/fastapi/app/api/routers/chat.py
+++ b/templates/simple/fastapi/app/api/routers/chat.py
@@ -1,54 +1,29 @@
-import logging
-import os
 from typing import List
+
+from app.utils.index import get_index
 from fastapi import APIRouter, Depends, HTTPException, status
-from llama_index import (
-    StorageContext,
-    load_index_from_storage,
-    SimpleDirectoryReader,
-    VectorStoreIndex,
-)
-from llama_index.llms.base import MessageRole
+from llama_index import VectorStoreIndex
+from llama_index.llms.base import MessageRole, ChatMessage
 from pydantic import BaseModel
 
-STORAGE_DIR = "./storage"  # directory to cache the generated index
-DATA_DIR = "./data"  # directory containing the documents to index
-
 chat_router = r = APIRouter()
 
 
-class Message(BaseModel):
+class _Message(BaseModel):
     role: MessageRole
     content: str
 
 
 class _ChatData(BaseModel):
-    messages: List[Message]
-
+    messages: List[_Message]
 
-def get_index():
-    logger = logging.getLogger("uvicorn")
-    # check if storage already exists
-    if not os.path.exists(STORAGE_DIR):
-        logger.info("Creating new index")
-        # load the documents and create the index
-        documents = SimpleDirectoryReader(DATA_DIR).load_data()
-        index = VectorStoreIndex.from_documents(documents)
-        # store it for later
-        index.storage_context.persist(STORAGE_DIR)
-        logger.info(f"Finished creating new index. Stored in {STORAGE_DIR}")
-    else:
-        # load the existing index
-        logger.info(f"Loading index from {STORAGE_DIR}...")
-        storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR)
-        index = load_index_from_storage(storage_context)
-        logger.info(f"Finished loading index from {STORAGE_DIR}")
-    return index
-
-
-@r.post("/")
-def chat(data: _ChatData, index: VectorStoreIndex = Depends(get_index)) -> Message:
-    # check preconditions
+@r.post("")
+async def chat(
+    data: _ChatData,
+    index: VectorStoreIndex = Depends(get_index),
+) -> _Message:
+    # check preconditions and get last message
     if len(data.messages) == 0:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
@@ -60,6 +35,16 @@ def chat(data: _ChatData, index: VectorStoreIndex = Depends(get_index)) -> Messa
             status_code=status.HTTP_400_BAD_REQUEST,
             detail="Last message must be from user",
         )
+    # convert messages coming from the request to type ChatMessage
+    messages = [
+        ChatMessage(
+            role=m.role,
+            content=m.content,
+        )
+        for m in data.messages
+    ]
+
+    # query chat engine
     chat_engine = index.as_chat_engine()
-    response = chat_engine.chat(lastMessage.content, data.messages)
-    return Message(role=MessageRole.ASSISTANT, content=response.response)
+    response = chat_engine.chat(lastMessage.content, messages)
+    return _Message(role=MessageRole.ASSISTANT, content=response.response)
diff --git a/templates/simple/fastapi/app/utils/__init__.py b/templates/simple/fastapi/app/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/templates/simple/fastapi/app/utils/index.py b/templates/simple/fastapi/app/utils/index.py
new file mode 100644
index 0000000000000000000000000000000000000000..076ca76631a6e0f752a420f0bc3f90286029796a
--- /dev/null
+++ b/templates/simple/fastapi/app/utils/index.py
@@ -0,0 +1,33 @@
+import logging
+import os
+
+from llama_index import (
+    SimpleDirectoryReader,
+    StorageContext,
+    VectorStoreIndex,
+    load_index_from_storage,
+)
+
+
+STORAGE_DIR = "./storage"  # directory to cache the generated index
+DATA_DIR = "./data"  # directory containing the documents to index
+
+
+def get_index():
+    logger = logging.getLogger("uvicorn")
+    # check if storage already exists
+    if not os.path.exists(STORAGE_DIR):
+        logger.info("Creating new index")
+        # load the documents and create the index
+        documents = SimpleDirectoryReader(DATA_DIR).load_data()
+        index = VectorStoreIndex.from_documents(documents)
+        # store it for later
+        index.storage_context.persist(STORAGE_DIR)
+        logger.info(f"Finished creating new index. Stored in {STORAGE_DIR}")
+    else:
+        # load the existing index
+        logger.info(f"Loading index from {STORAGE_DIR}...")
+        storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR)
+        index = load_index_from_storage(storage_context)
+        logger.info(f"Finished loading index from {STORAGE_DIR}")
+    return index
diff --git a/templates/simple/fastapi/main.py b/templates/simple/fastapi/main.py
index e307354bc3b935a52d96c0e138a937969df4d4cf..eaa9f0f259ad932e4527e0a9dcab03d4bb160fc4 100644
--- a/templates/simple/fastapi/main.py
+++ b/templates/simple/fastapi/main.py
@@ -1,3 +1,4 @@
+import logging
 import os
 import uvicorn
 from app.api.routers.chat import chat_router
@@ -6,11 +7,15 @@ from fastapi.middleware.cors import CORSMiddleware
 
 app = FastAPI()
 
-origin = os.getenv("CORS_ORIGIN")
-if origin:
+environment = os.getenv("ENVIRONMENT", "dev")  # Default to 'development' if not set
+
+
+if environment == "dev":
+    logger = logging.getLogger("uvicorn")
+    logger.warning("Running in development mode - allowing CORS for all origins")
     app.add_middleware(
         CORSMiddleware,
-        allow_origins=[origin],
+        allow_origins=["*"],
         allow_credentials=True,
         allow_methods=["*"],
         allow_headers=["*"],
diff --git a/templates/streaming/fastapi/README-template.md b/templates/streaming/fastapi/README-template.md
index baa5fa63fcb1c07f8d74af3aa2eabd7bb493fda2..f0b92bdfce648a374bd2723fee4ceec69605db69 100644
--- a/templates/streaming/fastapi/README-template.md
+++ b/templates/streaming/fastapi/README-template.md
@@ -27,6 +27,12 @@ You can start editing the API by modifying `app/api/routers/chat.py`. The endpoi
 
 Open [http://localhost:8000/docs](http://localhost:8000/docs) with your browser to see the Swagger UI of the API.
 
+The API allows CORS for all origins to simplify development. You can change this behavior by setting the `ENVIRONMENT` environment variable to `prod`:
+
+```
+ENVIRONMENT=prod uvicorn main:app
+```
+
 ## Learn More
 
 To learn more about LlamaIndex, take a look at the following resources:
diff --git a/templates/streaming/fastapi/app/api/routers/chat.py b/templates/streaming/fastapi/app/api/routers/chat.py
index bc9b5ed651efff1a91f659b8130004b0b90be6a9..36b618e232a33fc882b3d4b42b885c29c64edd4e 100644
--- a/templates/streaming/fastapi/app/api/routers/chat.py
+++ b/templates/streaming/fastapi/app/api/routers/chat.py
@@ -1,57 +1,35 @@
-import logging
-import os
 from typing import List
+
+from fastapi.responses import StreamingResponse
+
+from app.utils.json import json_to_model
+from app.utils.index import get_index
 from fastapi import APIRouter, Depends, HTTPException, Request, status
-from llama_index import (
-    StorageContext,
-    load_index_from_storage,
-    SimpleDirectoryReader,
-    VectorStoreIndex,
-)
-from llama_index.llms.base import MessageRole
+from llama_index import VectorStoreIndex
+from llama_index.llms.base import MessageRole, ChatMessage
 from pydantic import BaseModel
-from sse_starlette.sse import EventSourceResponse
-
-STORAGE_DIR = "./storage"  # directory to cache the generated index
-DATA_DIR = "./data"  # directory containing the documents to index
 
 chat_router = r = APIRouter()
 
 
-class Message(BaseModel):
+class _Message(BaseModel):
     role: MessageRole
     content: str
 
 
 class _ChatData(BaseModel):
-    messages: List[Message]
+    messages: List[_Message]
 
 
-def get_index():
-    logger = logging.getLogger("uvicorn")
-    # check if storage already exists
-    if not os.path.exists(STORAGE_DIR):
-        logger.info("Creating new index")
-        # load the documents and create the index
-        documents = SimpleDirectoryReader(DATA_DIR).load_data()
-        index = VectorStoreIndex.from_documents(documents)
-        # store it for later
-        index.storage_context.persist(STORAGE_DIR)
-        logger.info(f"Finished creating new index. Stored in {STORAGE_DIR}")
-    else:
-        # load the existing index
-        logger.info(f"Loading index from {STORAGE_DIR}...")
-        storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR)
-        index = load_index_from_storage(storage_context)
-        logger.info(f"Finished loading index from {STORAGE_DIR}")
-    return index
-
-
-@r.post("/")
+@r.post("")
 async def chat(
-    request: Request, data: _ChatData, index: VectorStoreIndex = Depends(get_index)
-) -> Message:
-    # check preconditions
+    request: Request,
+    # Note: To support clients sending a JSON object using content-type "text/plain",
+    # we need to use Depends(json_to_model(_ChatData)) here
+    data: _ChatData = Depends(json_to_model(_ChatData)),
+    index: VectorStoreIndex = Depends(get_index),
+):
+    # check preconditions and get last message
    if len(data.messages) == 0:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
@@ -63,10 +41,18 @@ async def chat(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail="Last message must be from user",
         )
+    # convert messages coming from the request to type ChatMessage
+    messages = [
+        ChatMessage(
+            role=m.role,
+            content=m.content,
+        )
+        for m in data.messages
+    ]
 
     # query chat engine
     chat_engine = index.as_chat_engine()
-    response = chat_engine.stream_chat(lastMessage.content, data.messages)
+    response = chat_engine.stream_chat(lastMessage.content, messages)
 
     # stream response
     async def event_generator():
@@ -76,4 +62,4 @@ async def chat(
                 break
             yield token
 
-    return EventSourceResponse(event_generator())
+    return StreamingResponse(event_generator(), media_type="text/plain")
diff --git a/templates/streaming/fastapi/app/utils/__init__.py b/templates/streaming/fastapi/app/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/templates/streaming/fastapi/app/utils/index.py b/templates/streaming/fastapi/app/utils/index.py
new file mode 100644
index 0000000000000000000000000000000000000000..076ca76631a6e0f752a420f0bc3f90286029796a
--- /dev/null
+++ b/templates/streaming/fastapi/app/utils/index.py
@@ -0,0 +1,33 @@
+import logging
+import os
+
+from llama_index import (
+    SimpleDirectoryReader,
+    StorageContext,
+    VectorStoreIndex,
+    load_index_from_storage,
+)
+
+
+STORAGE_DIR = "./storage"  # directory to cache the generated index
+DATA_DIR = "./data"  # directory containing the documents to index
+
+
+def get_index():
+    logger = logging.getLogger("uvicorn")
+    # check if storage already exists
+    if not os.path.exists(STORAGE_DIR):
+        logger.info("Creating new index")
+        # load the documents and create the index
+        documents = SimpleDirectoryReader(DATA_DIR).load_data()
+        index = VectorStoreIndex.from_documents(documents)
+        # store it for later
+        index.storage_context.persist(STORAGE_DIR)
+        logger.info(f"Finished creating new index. Stored in {STORAGE_DIR}")
+    else:
+        # load the existing index
+        logger.info(f"Loading index from {STORAGE_DIR}...")
+        storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR)
+        index = load_index_from_storage(storage_context)
+        logger.info(f"Finished loading index from {STORAGE_DIR}")
+    return index
diff --git a/templates/streaming/fastapi/app/utils/json.py b/templates/streaming/fastapi/app/utils/json.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9a847f53e107f665389f11ec005795e0fb8c5b3
--- /dev/null
+++ b/templates/streaming/fastapi/app/utils/json.py
@@ -0,0 +1,22 @@
+import json
+from typing import TypeVar
+from fastapi import HTTPException, Request
+
+from pydantic import BaseModel, ValidationError
+
+
+T = TypeVar("T", bound=BaseModel)
+
+
+def json_to_model(cls: T):
+    async def get_json(request: Request) -> T:
+        body = await request.body()
+        try:
+            data_dict = json.loads(body.decode("utf-8"))
+            return cls(**data_dict)
+        except (json.JSONDecodeError, ValidationError) as e:
+            raise HTTPException(
+                status_code=400, detail=f"Could not decode JSON: {str(e)}"
+            )
+
+    return get_json
diff --git a/templates/streaming/fastapi/main.py b/templates/streaming/fastapi/main.py
index e307354bc3b935a52d96c0e138a937969df4d4cf..eaa9f0f259ad932e4527e0a9dcab03d4bb160fc4 100644
--- a/templates/streaming/fastapi/main.py
+++ b/templates/streaming/fastapi/main.py
@@ -1,3 +1,4 @@
+import logging
 import os
 import uvicorn
 from app.api.routers.chat import chat_router
@@ -6,11 +7,15 @@ from fastapi.middleware.cors import CORSMiddleware
 
 app = FastAPI()
 
-origin = os.getenv("CORS_ORIGIN")
-if origin:
+environment = os.getenv("ENVIRONMENT", "dev")  # Default to 'development' if not set
+
+
+if environment == "dev":
+    logger = logging.getLogger("uvicorn")
+    logger.warning("Running in development mode - allowing CORS for all origins")
     app.add_middleware(
         CORSMiddleware,
-        allow_origins=[origin],
+        allow_origins=["*"],
         allow_credentials=True,
         allow_methods=["*"],
         allow_headers=["*"],
diff --git a/templates/streaming/fastapi/pyproject.toml b/templates/streaming/fastapi/pyproject.toml
index 73d3cc51070a81f764999100cc7d7ec1df8f36c8..ae0720ecd43f221f2f8ee9758d6e4d14bb8fb554 100644
--- a/templates/streaming/fastapi/pyproject.toml
+++ b/templates/streaming/fastapi/pyproject.toml
@@ -11,7 +11,6 @@
 fastapi = "^0.104.1"
 uvicorn = { extras = ["standard"], version = "^0.23.2" }
 llama-index = "^0.8.56"
 pypdf = "^3.17.0"
-sse-starlette = "^1.6.5"
 
 [build-system]
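
A quick way to exercise the new streaming endpoint is a small client like the minimal sketch below. It is illustrative only, not part of the patch: it assumes the chat router is mounted at `/api/chat` on a local uvicorn server (the `include_router` call is not shown in this diff), so adjust the URL to match your `main.py`. Because the endpoint now returns a plain-text `StreamingResponse` instead of SSE events, tokens can be read directly off the response body:

```
# Hypothetical test client for the streaming chat endpoint (not part of the patch).
# Assumes the router is mounted at /api/chat and the server runs on localhost:8000.
import requests

payload = {"messages": [{"role": "user", "content": "Hello!"}]}

with requests.post(
    "http://localhost:8000/api/chat",
    json=payload,
    stream=True,
) as response:
    response.raise_for_status()
    # The endpoint streams raw text tokens (media_type="text/plain"),
    # so iterate over the body as it arrives instead of parsing SSE events.
    for token in response.iter_content(chunk_size=None, decode_unicode=True):
        print(token, end="", flush=True)
```

Note that the `json_to_model` dependency also accepts the same JSON body sent with content-type `text/plain`, which helps clients (for example, `fetch` with a plain string body) that do not set `application/json`.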