Commit 2a456980 authored by Marcus Schiesser

fix: modify streaming fastapi to support vercel/ai

parent 8cdea3cb
Showing with 167 additions and 87 deletions
README.md (non-streaming template):

@@ -27,6 +27,12 @@ You can start editing the API by modifying `app/api/routers/chat.py`. The endpoint
 Open [http://localhost:8000/docs](http://localhost:8000/docs) with your browser to see the Swagger UI of the API.
 
+The API allows CORS for all origins to simplify development. You can change this behavior by setting the `ENVIRONMENT` environment variable to `prod`:
+
+```
+ENVIRONMENT=prod uvicorn main:app
+```
+
 ## Learn More
 
 To learn more about LlamaIndex, take a look at the following resources:
app/api/routers/chat.py (non-streaming template):

-import logging
-import os
 from typing import List
 
+from app.utils.index import get_index
 from fastapi import APIRouter, Depends, HTTPException, status
-from llama_index import (
-    StorageContext,
-    load_index_from_storage,
-    SimpleDirectoryReader,
-    VectorStoreIndex,
-)
-from llama_index.llms.base import MessageRole
+from llama_index import VectorStoreIndex
+from llama_index.llms.base import MessageRole, ChatMessage
 from pydantic import BaseModel
 
-STORAGE_DIR = "./storage"  # directory to cache the generated index
-DATA_DIR = "./data"  # directory containing the documents to index
-
 chat_router = r = APIRouter()
 
 
-class Message(BaseModel):
+class _Message(BaseModel):
     role: MessageRole
     content: str
 
 
 class _ChatData(BaseModel):
-    messages: List[Message]
+    messages: List[_Message]
 
 
-def get_index():
-    logger = logging.getLogger("uvicorn")
-    # check if storage already exists
-    if not os.path.exists(STORAGE_DIR):
-        logger.info("Creating new index")
-        # load the documents and create the index
-        documents = SimpleDirectoryReader(DATA_DIR).load_data()
-        index = VectorStoreIndex.from_documents(documents)
-        # store it for later
-        index.storage_context.persist(STORAGE_DIR)
-        logger.info(f"Finished creating new index. Stored in {STORAGE_DIR}")
-    else:
-        # load the existing index
-        logger.info(f"Loading index from {STORAGE_DIR}...")
-        storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR)
-        index = load_index_from_storage(storage_context)
-        logger.info(f"Finished loading index from {STORAGE_DIR}")
-    return index
-
-
-@r.post("/")
-def chat(data: _ChatData, index: VectorStoreIndex = Depends(get_index)) -> Message:
-    # check preconditions
+@r.post("")
+async def chat(
+    data: _ChatData,
+    index: VectorStoreIndex = Depends(get_index),
+) -> _Message:
+    # check preconditions and get last message
     if len(data.messages) == 0:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
@@ -60,6 +35,16 @@ def chat(data: _ChatData, index: VectorStoreIndex = Depends(get_index)) -> Message:
             status_code=status.HTTP_400_BAD_REQUEST,
             detail="Last message must be from user",
         )
+    # convert messages coming from the request to type ChatMessage
+    messages = [
+        ChatMessage(
+            role=m.role,
+            content=m.content,
+        )
+        for m in data.messages
+    ]
+    # query chat engine
     chat_engine = index.as_chat_engine()
-    response = chat_engine.chat(lastMessage.content, data.messages)
-    return Message(role=MessageRole.ASSISTANT, content=response.response)
+    response = chat_engine.chat(lastMessage.content, messages)
+    return _Message(role=MessageRole.ASSISTANT, content=response.response)
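For reference, a minimal sketch of calling this endpoint from Python; the base URL and the `/api/chat` mount path are assumptions (they depend on how the router is included in the app), so adjust them to your setup:

```python
# Sketch: calling the non-streaming chat endpoint.
# Assumes the server runs on localhost:8000 (uvicorn's default) and the
# router is mounted at /api/chat -- both are assumptions, not from the diff.
import requests

response = requests.post(
    "http://localhost:8000/api/chat",
    json={"messages": [{"role": "user", "content": "What is LlamaIndex?"}]},
)
response.raise_for_status()
print(response.json())  # -> {"role": "assistant", "content": "..."}
```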
app/utils/index.py (new file):

import logging
import os

from llama_index import (
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)

STORAGE_DIR = "./storage"  # directory to cache the generated index
DATA_DIR = "./data"  # directory containing the documents to index


def get_index():
    logger = logging.getLogger("uvicorn")
    # check if storage already exists
    if not os.path.exists(STORAGE_DIR):
        logger.info("Creating new index")
        # load the documents and create the index
        documents = SimpleDirectoryReader(DATA_DIR).load_data()
        index = VectorStoreIndex.from_documents(documents)
        # store it for later
        index.storage_context.persist(STORAGE_DIR)
        logger.info(f"Finished creating new index. Stored in {STORAGE_DIR}")
    else:
        # load the existing index
        logger.info(f"Loading index from {STORAGE_DIR}...")
        storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR)
        index = load_index_from_storage(storage_context)
        logger.info(f"Finished loading index from {STORAGE_DIR}")
    return index
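Because the index is persisted to `STORAGE_DIR`, only the first call pays the indexing cost; later calls load it from disk. A quick way to exercise this outside FastAPI (a sketch, assuming documents exist in `./data` and an `OPENAI_API_KEY` is set):

```python
# Sketch: exercising get_index() directly, outside of FastAPI.
# Assumes ./data contains documents and OPENAI_API_KEY is set.
from app.utils.index import get_index

index = get_index()  # first call builds and persists the index
index = get_index()  # subsequent calls load it from ./storage
engine = index.as_query_engine()
print(engine.query("Summarize the indexed documents in one sentence."))
```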
main.py:

+import logging
 import os
 
 import uvicorn
 from app.api.routers.chat import chat_router
@@ -6,11 +7,15 @@ from fastapi.middleware.cors import CORSMiddleware
 app = FastAPI()
 
-origin = os.getenv("CORS_ORIGIN")
+environment = os.getenv("ENVIRONMENT", "dev")  # Default to 'dev' if not set
 
-if origin:
+if environment == "dev":
+    logger = logging.getLogger("uvicorn")
+    logger.warning("Running in development mode - allowing CORS for all origins")
     app.add_middleware(
         CORSMiddleware,
-        allow_origins=[origin],
+        allow_origins=["*"],
         allow_credentials=True,
         allow_methods=["*"],
         allow_headers=["*"],
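Note that in `prod` mode the middleware is simply not added, so browsers will block cross-origin requests. If specific origins are needed in production, one possible extension (not part of this commit; the `CORS_ORIGINS` variable is hypothetical) is:

```python
# Sketch (NOT part of this commit): allow-listing origins in prod via a
# hypothetical CORS_ORIGINS variable, e.g. CORS_ORIGINS=https://myapp.com
import os

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()
environment = os.getenv("ENVIRONMENT", "dev")

if environment == "dev":
    # development: allow everything, as in the diff above
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
else:
    # production: only the origins explicitly listed in CORS_ORIGINS
    origins = [o for o in os.getenv("CORS_ORIGINS", "").split(",") if o]
    if origins:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
```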
README.md (streaming template): same change as the README.md diff shown above.
app/api/routers/chat.py (streaming template):

-import logging
-import os
 from typing import List
 
+from fastapi.responses import StreamingResponse
+
+from app.utils.json import json_to_model
+from app.utils.index import get_index
 from fastapi import APIRouter, Depends, HTTPException, Request, status
-from llama_index import (
-    StorageContext,
-    load_index_from_storage,
-    SimpleDirectoryReader,
-    VectorStoreIndex,
-)
-from llama_index.llms.base import MessageRole
+from llama_index import VectorStoreIndex
+from llama_index.llms.base import MessageRole, ChatMessage
 from pydantic import BaseModel
-from sse_starlette.sse import EventSourceResponse
-
-STORAGE_DIR = "./storage"  # directory to cache the generated index
-DATA_DIR = "./data"  # directory containing the documents to index
 
 chat_router = r = APIRouter()
 
 
-class Message(BaseModel):
+class _Message(BaseModel):
     role: MessageRole
     content: str
 
 
 class _ChatData(BaseModel):
-    messages: List[Message]
+    messages: List[_Message]
 
 
-def get_index():
-    logger = logging.getLogger("uvicorn")
-    # check if storage already exists
-    if not os.path.exists(STORAGE_DIR):
-        logger.info("Creating new index")
-        # load the documents and create the index
-        documents = SimpleDirectoryReader(DATA_DIR).load_data()
-        index = VectorStoreIndex.from_documents(documents)
-        # store it for later
-        index.storage_context.persist(STORAGE_DIR)
-        logger.info(f"Finished creating new index. Stored in {STORAGE_DIR}")
-    else:
-        # load the existing index
-        logger.info(f"Loading index from {STORAGE_DIR}...")
-        storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR)
-        index = load_index_from_storage(storage_context)
-        logger.info(f"Finished loading index from {STORAGE_DIR}")
-    return index
-
-
-@r.post("/")
+@r.post("")
 async def chat(
-    request: Request, data: _ChatData, index: VectorStoreIndex = Depends(get_index)
-) -> Message:
-    # check preconditions
+    request: Request,
+    # Note: To support clients sending a JSON object using content-type "text/plain",
+    # we need to use Depends(json_to_model(_ChatData)) here
+    data: _ChatData = Depends(json_to_model(_ChatData)),
+    index: VectorStoreIndex = Depends(get_index),
+):
+    # check preconditions and get last message
     if len(data.messages) == 0:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
@@ -63,10 +41,18 @@ async def chat(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail="Last message must be from user",
         )
+    # convert messages coming from the request to type ChatMessage
+    messages = [
+        ChatMessage(
+            role=m.role,
+            content=m.content,
+        )
+        for m in data.messages
+    ]
     # query chat engine
     chat_engine = index.as_chat_engine()
-    response = chat_engine.stream_chat(lastMessage.content, data.messages)
+    response = chat_engine.stream_chat(lastMessage.content, messages)
 
     # stream response
     async def event_generator():
@@ -76,4 +62,4 @@ async def chat(
             break
         yield token
 
-    return EventSourceResponse(event_generator())
+    return StreamingResponse(event_generator(), media_type="text/plain")
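The switch from `EventSourceResponse` to a plain-text `StreamingResponse` is what makes the endpoint consumable by the Vercel AI SDK, which reads raw text chunks rather than SSE events. A minimal Python sketch of a streaming client (the base URL and `/api/chat` mount path are assumptions):

```python
# Sketch: consuming the plain-text token stream.
# Assumes the server runs on localhost:8000 and the router is mounted
# at /api/chat -- both are assumptions, not from the diff.
import requests

with requests.post(
    "http://localhost:8000/api/chat",
    json={"messages": [{"role": "user", "content": "What is LlamaIndex?"}]},
    stream=True,
) as response:
    response.raise_for_status()
    # tokens arrive as raw text chunks, not as SSE "data:" events
    for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)
```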
app/utils/index.py (streaming template, new file): identical to the app/utils/index.py shown above.
app/utils/json.py (new file):

import json
from typing import Type, TypeVar

from fastapi import HTTPException, Request
from pydantic import BaseModel, ValidationError

T = TypeVar("T", bound=BaseModel)


def json_to_model(cls: Type[T]):
    async def get_json(request: Request) -> T:
        body = await request.body()
        try:
            data_dict = json.loads(body.decode("utf-8"))
            return cls(**data_dict)
        except (json.JSONDecodeError, ValidationError) as e:
            raise HTTPException(
                status_code=400, detail=f"Could not decode JSON: {str(e)}"
            )

    return get_json
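Since `json_to_model` parses the raw request body itself instead of relying on FastAPI's built-in JSON handling, it accepts a JSON payload regardless of the `Content-Type` header. A small sketch demonstrating this with FastAPI's test client (the `/echo` route and `Payload` model are made up for illustration):

```python
# Sketch: json_to_model accepts JSON even when sent as text/plain.
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
from pydantic import BaseModel

from app.utils.json import json_to_model


class Payload(BaseModel):
    content: str


app = FastAPI()


@app.post("/echo")
async def echo(data: Payload = Depends(json_to_model(Payload))):
    return {"echo": data.content}


client = TestClient(app)
# a JSON body sent with content-type text/plain is still parsed into the model
r = client.post(
    "/echo",
    content='{"content": "hi"}',
    headers={"Content-Type": "text/plain"},
)
assert r.json() == {"echo": "hi"}
```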
main.py (streaming template): same change as the main.py diff shown above.
pyproject.toml (streaming template): with `EventSourceResponse` replaced by Starlette's built-in `StreamingResponse`, the `sse-starlette` dependency is no longer needed.

@@ -11,7 +11,6 @@ fastapi = "^0.104.1"
 uvicorn = { extras = ["standard"], version = "^0.23.2" }
 llama-index = "^0.8.56"
 pypdf = "^3.17.0"
-sse-starlette = "^1.6.5"
 
 [build-system]