From 2a4569808e9501b17fffe3d3f294cc471a60be39 Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Thu, 2 Nov 2023 12:51:43 +0700
Subject: [PATCH] fix: modify streaming fastapi to support vercel/ai

---
 templates/simple/fastapi/README-template.md   |  6 ++
 .../simple/fastapi/app/api/routers/chat.py    | 63 +++++++----------
 .../simple/fastapi/app/utils/__init__.py      |  0
 templates/simple/fastapi/app/utils/index.py   | 33 +++++++++
 templates/simple/fastapi/main.py              | 11 ++-
 .../streaming/fastapi/README-template.md      |  6 ++
 .../streaming/fastapi/app/api/routers/chat.py | 68 ++++++++-----------
 .../streaming/fastapi/app/utils/__init__.py   |  0
 .../streaming/fastapi/app/utils/index.py      | 33 +++++++++
 templates/streaming/fastapi/app/utils/json.py | 22 ++++++
 templates/streaming/fastapi/main.py           | 11 ++-
 templates/streaming/fastapi/pyproject.toml    |  1 -
 12 files changed, 167 insertions(+), 87 deletions(-)
 create mode 100644 templates/simple/fastapi/app/utils/__init__.py
 create mode 100644 templates/simple/fastapi/app/utils/index.py
 create mode 100644 templates/streaming/fastapi/app/utils/__init__.py
 create mode 100644 templates/streaming/fastapi/app/utils/index.py
 create mode 100644 templates/streaming/fastapi/app/utils/json.py

diff --git a/templates/simple/fastapi/README-template.md b/templates/simple/fastapi/README-template.md
index baa5fa63..f0b92bdf 100644
--- a/templates/simple/fastapi/README-template.md
+++ b/templates/simple/fastapi/README-template.md
@@ -27,6 +27,12 @@ You can start editing the API by modifying `app/api/routers/chat.py`. The endpoi
 
 Open [http://localhost:8000/docs](http://localhost:8000/docs) with your browser to see the Swagger UI of the API.
 
+The API allows CORS for all origins to simplify development. You can change this behavior by setting the `ENVIRONMENT` environment variable to `prod`:
+
+```
+ENVIRONMENT=prod uvicorn main:app
+```
+
 ## Learn More
 
 To learn more about LlamaIndex, take a look at the following resources:
diff --git a/templates/simple/fastapi/app/api/routers/chat.py b/templates/simple/fastapi/app/api/routers/chat.py
index bd6a38c5..2d20a6f6 100644
--- a/templates/simple/fastapi/app/api/routers/chat.py
+++ b/templates/simple/fastapi/app/api/routers/chat.py
@@ -1,54 +1,29 @@
-import logging
-import os
 from typing import List
+
+from app.utils.index import get_index
 from fastapi import APIRouter, Depends, HTTPException, status
-from llama_index import (
-    StorageContext,
-    load_index_from_storage,
-    SimpleDirectoryReader,
-    VectorStoreIndex,
-)
-from llama_index.llms.base import MessageRole
+from llama_index import VectorStoreIndex
+from llama_index.llms.base import MessageRole, ChatMessage
 from pydantic import BaseModel
 
-STORAGE_DIR = "./storage"  # directory to cache the generated index
-DATA_DIR = "./data"  # directory containing the documents to index
-
 chat_router = r = APIRouter()
 
 
-class Message(BaseModel):
+class _Message(BaseModel):
     role: MessageRole
     content: str
 
 
 class _ChatData(BaseModel):
-    messages: List[Message]
-
+    messages: List[_Message]
 
-def get_index():
-    logger = logging.getLogger("uvicorn")
-    # check if storage already exists
-    if not os.path.exists(STORAGE_DIR):
-        logger.info("Creating new index")
-        # load the documents and create the index
-        documents = SimpleDirectoryReader(DATA_DIR).load_data()
-        index = VectorStoreIndex.from_documents(documents)
-        # store it for later
-        index.storage_context.persist(STORAGE_DIR)
-        logger.info(f"Finished creating new index. Stored in {STORAGE_DIR}")
-    else:
-        # load the existing index
-        logger.info(f"Loading index from {STORAGE_DIR}...")
-        storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR)
-        index = load_index_from_storage(storage_context)
-        logger.info(f"Finished loading index from {STORAGE_DIR}")
-    return index
 
-
-@r.post("/")
-def chat(data: _ChatData, index: VectorStoreIndex = Depends(get_index)) -> Message:
-    # check preconditions
+@r.post("")
+async def chat(
+    data: _ChatData,
+    index: VectorStoreIndex = Depends(get_index),
+) -> _Message:
+    # check preconditions and get last message
     if len(data.messages) == 0:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
@@ -60,6 +35,16 @@ def chat(data: _ChatData, index: VectorStoreIndex = Depends(get_index)) -> Messa
             status_code=status.HTTP_400_BAD_REQUEST,
             detail="Last message must be from user",
         )
+    # convert messages coming from the request to type ChatMessage
+    messages = [
+        ChatMessage(
+            role=m.role,
+            content=m.content,
+        )
+        for m in data.messages
+    ]
+
+    # query chat engine
     chat_engine = index.as_chat_engine()
-    response = chat_engine.chat(lastMessage.content, data.messages)
-    return Message(role=MessageRole.ASSISTANT, content=response.response)
+    response = chat_engine.chat(lastMessage.content, messages)
+    return _Message(role=MessageRole.ASSISTANT, content=response.response)
diff --git a/templates/simple/fastapi/app/utils/__init__.py b/templates/simple/fastapi/app/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/templates/simple/fastapi/app/utils/index.py b/templates/simple/fastapi/app/utils/index.py
new file mode 100644
index 00000000..076ca766
--- /dev/null
+++ b/templates/simple/fastapi/app/utils/index.py
@@ -0,0 +1,33 @@
+import logging
+import os
+
+from llama_index import (
+    SimpleDirectoryReader,
+    StorageContext,
+    VectorStoreIndex,
+    load_index_from_storage,
+)
+
+
+STORAGE_DIR = "./storage"  # directory to cache the generated index
+DATA_DIR = "./data"  # directory containing the documents to index
+
+
+def get_index():
+    logger = logging.getLogger("uvicorn")
+    # check if storage already exists
+    if not os.path.exists(STORAGE_DIR):
+        logger.info("Creating new index")
+        # load the documents and create the index
+        documents = SimpleDirectoryReader(DATA_DIR).load_data()
+        index = VectorStoreIndex.from_documents(documents)
+        # store it for later
+        index.storage_context.persist(STORAGE_DIR)
+        logger.info(f"Finished creating new index. Stored in {STORAGE_DIR}")
+    else:
+        # load the existing index
+        logger.info(f"Loading index from {STORAGE_DIR}...")
+        storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR)
+        index = load_index_from_storage(storage_context)
+        logger.info(f"Finished loading index from {STORAGE_DIR}")
+    return index
diff --git a/templates/simple/fastapi/main.py b/templates/simple/fastapi/main.py
index e307354b..eaa9f0f2 100644
--- a/templates/simple/fastapi/main.py
+++ b/templates/simple/fastapi/main.py
@@ -1,3 +1,4 @@
+import logging
 import os
 import uvicorn
 from app.api.routers.chat import chat_router
@@ -6,11 +7,15 @@ from fastapi.middleware.cors import CORSMiddleware
 
 app = FastAPI()
 
-origin = os.getenv("CORS_ORIGIN")
-if origin:
+environment = os.getenv("ENVIRONMENT", "dev")  # Default to 'dev' if not set
+
+
+if environment == "dev":
+    logger = logging.getLogger("uvicorn")
+    logger.warning("Running in development mode - allowing CORS for all origins")
     app.add_middleware(
         CORSMiddleware,
-        allow_origins=[origin],
+        allow_origins=["*"],
         allow_credentials=True,
         allow_methods=["*"],
         allow_headers=["*"],
diff --git a/templates/streaming/fastapi/README-template.md b/templates/streaming/fastapi/README-template.md
index baa5fa63..f0b92bdf 100644
--- a/templates/streaming/fastapi/README-template.md
+++ b/templates/streaming/fastapi/README-template.md
@@ -27,6 +27,12 @@ You can start editing the API by modifying `app/api/routers/chat.py`. The endpoi
 
 Open [http://localhost:8000/docs](http://localhost:8000/docs) with your browser to see the Swagger UI of the API.
 
+The API allows CORS for all origins to simplify development. You can change this behavior by setting the `ENVIRONMENT` environment variable to `prod`:
+
+```
+ENVIRONMENT=prod uvicorn main:app
+```
+
 ## Learn More
 
 To learn more about LlamaIndex, take a look at the following resources:
diff --git a/templates/streaming/fastapi/app/api/routers/chat.py b/templates/streaming/fastapi/app/api/routers/chat.py
index bc9b5ed6..36b618e2 100644
--- a/templates/streaming/fastapi/app/api/routers/chat.py
+++ b/templates/streaming/fastapi/app/api/routers/chat.py
@@ -1,57 +1,35 @@
-import logging
-import os
 from typing import List
+
+from fastapi.responses import StreamingResponse
+
+from app.utils.json import json_to_model
+from app.utils.index import get_index
 from fastapi import APIRouter, Depends, HTTPException, Request, status
-from llama_index import (
-    StorageContext,
-    load_index_from_storage,
-    SimpleDirectoryReader,
-    VectorStoreIndex,
-)
-from llama_index.llms.base import MessageRole
+from llama_index import VectorStoreIndex
+from llama_index.llms.base import MessageRole, ChatMessage
 from pydantic import BaseModel
-from sse_starlette.sse import EventSourceResponse
-
-STORAGE_DIR = "./storage"  # directory to cache the generated index
-DATA_DIR = "./data"  # directory containing the documents to index
 
 chat_router = r = APIRouter()
 
 
-class Message(BaseModel):
+class _Message(BaseModel):
     role: MessageRole
     content: str
 
 
 class _ChatData(BaseModel):
-    messages: List[Message]
+    messages: List[_Message]
 
 
-def get_index():
-    logger = logging.getLogger("uvicorn")
-    # check if storage already exists
-    if not os.path.exists(STORAGE_DIR):
-        logger.info("Creating new index")
-        # load the documents and create the index
-        documents = SimpleDirectoryReader(DATA_DIR).load_data()
-        index = VectorStoreIndex.from_documents(documents)
-        # store it for later
-        index.storage_context.persist(STORAGE_DIR)
-        logger.info(f"Finished creating new index. Stored in {STORAGE_DIR}")
-    else:
-        # load the existing index
-        logger.info(f"Loading index from {STORAGE_DIR}...")
-        storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR)
-        index = load_index_from_storage(storage_context)
-        logger.info(f"Finished loading index from {STORAGE_DIR}")
-    return index
-
-
-@r.post("/")
+@r.post("")
 async def chat(
-    request: Request, data: _ChatData, index: VectorStoreIndex = Depends(get_index)
-) -> Message:
-    # check preconditions
+    request: Request,
+    # Note: To support clients sending a JSON object using content-type "text/plain",
+    # we need to use Depends(json_to_model(_ChatData)) here
+    data: _ChatData = Depends(json_to_model(_ChatData)),
+    index: VectorStoreIndex = Depends(get_index),
+):
+    # check preconditions and get last message
     if len(data.messages) == 0:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
@@ -63,10 +41,18 @@ async def chat(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail="Last message must be from user",
         )
+    # convert messages coming from the request to type ChatMessage
+    messages = [
+        ChatMessage(
+            role=m.role,
+            content=m.content,
+        )
+        for m in data.messages
+    ]
 
     # query chat engine
     chat_engine = index.as_chat_engine()
-    response = chat_engine.stream_chat(lastMessage.content, data.messages)
+    response = chat_engine.stream_chat(lastMessage.content, messages)
 
     # stream response
     async def event_generator():
@@ -76,4 +62,4 @@ async def chat(
                 break
             yield token
 
-    return EventSourceResponse(event_generator())
+    return StreamingResponse(event_generator(), media_type="text/plain")
diff --git a/templates/streaming/fastapi/app/utils/__init__.py b/templates/streaming/fastapi/app/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/templates/streaming/fastapi/app/utils/index.py b/templates/streaming/fastapi/app/utils/index.py
new file mode 100644
index 00000000..076ca766
--- /dev/null
+++ b/templates/streaming/fastapi/app/utils/index.py
@@ -0,0 +1,33 @@
+import logging
+import os
+
+from llama_index import (
+    SimpleDirectoryReader,
+    StorageContext,
+    VectorStoreIndex,
+    load_index_from_storage,
+)
+
+
+STORAGE_DIR = "./storage"  # directory to cache the generated index
+DATA_DIR = "./data"  # directory containing the documents to index
+
+
+def get_index():
+    logger = logging.getLogger("uvicorn")
+    # check if storage already exists
+    if not os.path.exists(STORAGE_DIR):
+        logger.info("Creating new index")
+        # load the documents and create the index
+        documents = SimpleDirectoryReader(DATA_DIR).load_data()
+        index = VectorStoreIndex.from_documents(documents)
+        # store it for later
+        index.storage_context.persist(STORAGE_DIR)
+        logger.info(f"Finished creating new index. Stored in {STORAGE_DIR}")
+    else:
+        # load the existing index
+        logger.info(f"Loading index from {STORAGE_DIR}...")
+        storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR)
+        index = load_index_from_storage(storage_context)
+        logger.info(f"Finished loading index from {STORAGE_DIR}")
+    return index
diff --git a/templates/streaming/fastapi/app/utils/json.py b/templates/streaming/fastapi/app/utils/json.py
new file mode 100644
index 00000000..d9a847f5
--- /dev/null
+++ b/templates/streaming/fastapi/app/utils/json.py
@@ -0,0 +1,22 @@
+import json
+from typing import TypeVar
+from fastapi import HTTPException, Request
+
+from pydantic import BaseModel, ValidationError
+
+
+T = TypeVar("T", bound=BaseModel)
+
+
+def json_to_model(cls: T):
+    async def get_json(request: Request) -> T:
+        body = await request.body()
+        try:
+            data_dict = json.loads(body.decode("utf-8"))
+            return cls(**data_dict)
+        except (json.JSONDecodeError, ValidationError) as e:
+            raise HTTPException(
+                status_code=400, detail=f"Could not decode JSON: {str(e)}"
+            )
+
+    return get_json
diff --git a/templates/streaming/fastapi/main.py b/templates/streaming/fastapi/main.py
index e307354b..eaa9f0f2 100644
--- a/templates/streaming/fastapi/main.py
+++ b/templates/streaming/fastapi/main.py
@@ -1,3 +1,4 @@
+import logging
 import os
 import uvicorn
 from app.api.routers.chat import chat_router
@@ -6,11 +7,15 @@ from fastapi.middleware.cors import CORSMiddleware
 
 app = FastAPI()
 
-origin = os.getenv("CORS_ORIGIN")
-if origin:
+environment = os.getenv("ENVIRONMENT", "dev")  # Default to 'dev' if not set
+
+
+if environment == "dev":
+    logger = logging.getLogger("uvicorn")
+    logger.warning("Running in development mode - allowing CORS for all origins")
     app.add_middleware(
         CORSMiddleware,
-        allow_origins=[origin],
+        allow_origins=["*"],
         allow_credentials=True,
         allow_methods=["*"],
         allow_headers=["*"],
diff --git a/templates/streaming/fastapi/pyproject.toml b/templates/streaming/fastapi/pyproject.toml
index 73d3cc51..ae0720ec 100644
--- a/templates/streaming/fastapi/pyproject.toml
+++ b/templates/streaming/fastapi/pyproject.toml
@@ -11,7 +11,6 @@ fastapi = "^0.104.1"
 uvicorn = { extras = ["standard"], version = "^0.23.2" }
 llama-index = "^0.8.56"
 pypdf = "^3.17.0"
-sse-starlette = "^1.6.5"
 
 
 [build-system]
-- 
GitLab