Skip to content
Snippets Groups Projects
Unverified Commit 09f1db3b authored by Huu Le's avatar Huu Le Committed by GitHub
Browse files

feat: Support uploading CSV files for FastAPI app (#109)

parent cb3be7d1
No related branches found
No related tags found
No related merge requests found
import os import os
import logging import logging
from pydantic import BaseModel, Field
from pydantic.alias_generators import to_camel from aiostream import stream
from typing import List, Any, Optional, Dict, Tuple
from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi import APIRouter, Depends, HTTPException, Request, status
from llama_index.core.chat_engine.types import BaseChatEngine from llama_index.core.chat_engine.types import BaseChatEngine
from llama_index.core.schema import NodeWithScore from llama_index.core.llms import MessageRole
from llama_index.core.llms import ChatMessage, MessageRole
from app.engine import get_chat_engine from app.engine import get_chat_engine
from app.api.routers.vercel_response import VercelStreamResponse from app.api.routers.vercel_response import VercelStreamResponse
from app.api.routers.messaging import EventCallbackHandler from app.api.routers.events import EventCallbackHandler
from aiostream import stream from app.api.routers.models import (
ChatData,
ChatConfig,
SourceNodes,
Result,
Message,
)
chat_router = r = APIRouter() chat_router = r = APIRouter()
logger = logging.getLogger("uvicorn") logger = logging.getLogger("uvicorn")
class _Message(BaseModel):
    """A single chat message: the sender's role and its text content."""

    role: MessageRole
    content: str
class _ChatData(BaseModel):
    """Request payload for the chat endpoints: the full message history."""

    messages: List[_Message]

    class Config:
        # Example payload shown in the generated OpenAPI docs.
        json_schema_extra = {
            "example": {
                "messages": [
                    {
                        "role": "user",
                        "content": "What standards for letters exist?",
                    }
                ]
            }
        }
class _SourceNodes(BaseModel):
    """Serializable view of a retrieved source node, with a resolved URL."""

    id: str
    metadata: Dict[str, Any]
    score: Optional[float]
    text: str
    url: Optional[str]

    @classmethod
    def from_source_node(cls, source_node: NodeWithScore):
        """Build one entry from a retrieved node, deriving a file URL if needed."""
        node = source_node.node
        node_metadata = node.metadata
        resolved_url = node_metadata.get("URL")
        if not resolved_url:
            # No explicit URL in the metadata: derive one from the file name
            # and the configured file-server prefix, when both are present.
            name = node_metadata.get("file_name")
            prefix = os.getenv("FILESERVER_URL_PREFIX")
            if not prefix:
                logger.warning(
                    "Warning: FILESERVER_URL_PREFIX not set in environment variables"
                )
            if name and prefix:
                resolved_url = f"{prefix}/data/{name}"
        return cls(
            id=node.node_id,
            metadata=node_metadata,
            score=source_node.score,
            text=node.text,  # type: ignore
            url=resolved_url,
        )

    @classmethod
    def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
        """Convert a whole result list in one pass."""
        return [cls.from_source_node(n) for n in source_nodes]
class _Result(BaseModel):
    """Non-streaming chat response: the assistant message plus its sources."""

    result: _Message
    nodes: List[_SourceNodes]
class _ChatConfig(BaseModel):
    """Response payload for the chat config endpoint."""

    starter_questions: Optional[List[str]] = Field(
        default=None,
        description="List of starter questions",
    )

    class Config:
        json_schema_extra = {
            "example": {
                "starterQuestions": [
                    "What standards for letters exist?",
                    "What are the requirements for a letter to be considered a letter?",
                ]
            }
        }
        # Serialize/accept camelCase field names (starterQuestions).
        alias_generator = to_camel
async def parse_chat_data(data: _ChatData) -> Tuple[str, List[ChatMessage]]:
    """Validate the chat payload and split it into (last user message, history).

    Args:
        data: incoming request body with the full message history.

    Returns:
        The content of the last (user) message, and the preceding messages
        converted to ``ChatMessage`` for the chat engine.

    Raises:
        HTTPException: 400 when no messages were sent, or when the last
            message is not from the user.
    """
    if len(data.messages) == 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No messages provided",
        )
    # Read the last message without pop() — the original mutated the
    # caller's request payload in place.
    last_message = data.messages[-1]
    if last_message.role != MessageRole.USER:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Last message must be from user",
        )
    # Convert the history (everything before the last message).
    messages = [
        ChatMessage(role=m.role, content=m.content) for m in data.messages[:-1]
    ]
    return last_message.content, messages
# streaming endpoint - delete if not needed # streaming endpoint - delete if not needed
@r.post("") @r.post("")
async def chat( async def chat(
request: Request, request: Request,
data: _ChatData, data: ChatData,
chat_engine: BaseChatEngine = Depends(get_chat_engine), chat_engine: BaseChatEngine = Depends(get_chat_engine),
): ):
last_message_content, messages = await parse_chat_data(data) last_message_content = data.get_last_message_content()
messages = data.get_history_messages()
event_handler = EventCallbackHandler() event_handler = EventCallbackHandler()
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
try: try:
response = await chat_engine.astream_chat(last_message_content, messages) response = await chat_engine.astream_chat(last_message_content, messages)
except Exception as e: except Exception as e:
logger.exception("Error in chat engine", exc_info=True)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error in chat engine: {e}", detail=f"Error in chat engine: {e}",
) ) from e
async def content_generator(): async def content_generator():
# Yield the additional data
if data.data is not None:
for data_response in data.get_additional_data_response():
yield VercelStreamResponse.convert_data(data_response)
# Yield the text response # Yield the text response
async def _text_generator(): async def _text_generator():
async for token in response.async_response_gen(): async for token in response.async_response_gen():
...@@ -167,7 +75,7 @@ async def chat( ...@@ -167,7 +75,7 @@ async def chat(
"type": "sources", "type": "sources",
"data": { "data": {
"nodes": [ "nodes": [
_SourceNodes.from_source_node(node).dict() SourceNodes.from_source_node(node).dict()
for node in response.source_nodes for node in response.source_nodes
] ]
}, },
...@@ -180,22 +88,23 @@ async def chat( ...@@ -180,22 +88,23 @@ async def chat(
# non-streaming endpoint - delete if not needed # non-streaming endpoint - delete if not needed
@r.post("/request") @r.post("/request")
async def chat_request( async def chat_request(
data: _ChatData, data: ChatData,
chat_engine: BaseChatEngine = Depends(get_chat_engine), chat_engine: BaseChatEngine = Depends(get_chat_engine),
) -> _Result: ) -> Result:
last_message_content, messages = await parse_chat_data(data) last_message_content = data.get_last_message_content()
messages = data.get_history_messages()
response = await chat_engine.achat(last_message_content, messages) response = await chat_engine.achat(last_message_content, messages)
return _Result( return Result(
result=_Message(role=MessageRole.ASSISTANT, content=response.response), result=Message(role=MessageRole.ASSISTANT, content=response.response),
nodes=_SourceNodes.from_source_nodes(response.source_nodes), nodes=SourceNodes.from_source_nodes(response.source_nodes),
) )
@r.get("/config") @r.get("/config")
async def chat_config() -> _ChatConfig: async def chat_config() -> ChatConfig:
starter_questions = None starter_questions = None
conversation_starters = os.getenv("CONVERSATION_STARTERS") conversation_starters = os.getenv("CONVERSATION_STARTERS")
if conversation_starters and conversation_starters.strip(): if conversation_starters and conversation_starters.strip():
starter_questions = conversation_starters.strip().split("\n") starter_questions = conversation_starters.strip().split("\n")
return _ChatConfig(starterQuestions=starter_questions) return ChatConfig(starterQuestions=starter_questions)
import os
import logging
from pydantic import BaseModel, Field, validator
from pydantic.alias_generators import to_camel
from typing import List, Any, Optional, Dict
from llama_index.core.schema import NodeWithScore
from llama_index.core.llms import ChatMessage, MessageRole
logger = logging.getLogger("uvicorn")
class Message(BaseModel):
    """A single chat message: the sender's role and its text content."""

    role: MessageRole
    content: str
class CsvFile(BaseModel):
    """An uploaded CSV file as sent by the client."""

    # Raw CSV text of the file.
    content: str
    filename: str
    # File size reported by the client (presumably bytes — confirm with caller).
    filesize: int
    id: str
    # MIME type, e.g. "text/csv".
    type: str
class DataParserOptions(BaseModel):
    """Optional extra data attached to a chat request (currently CSV uploads)."""

    csv_files: List[CsvFile] | None = Field(
        default=None,
        description="List of CSV files",
    )

    class Config:
        json_schema_extra = {
            "example": {
                "csvFiles": [
                    {
                        "content": "Name, Age\nAlice, 25\nBob, 30",
                        "filename": "example.csv",
                        "filesize": 123,
                        "id": "123",
                        "type": "text/csv",
                    }
                ]
            }
        }
        # Accept/emit camelCase field names (csvFiles).
        alias_generator = to_camel

    def to_raw_content(self) -> str:
        """Render the attached CSV files as fenced blocks for the LLM prompt.

        Returns:
            The prompt fragment, or an empty string when no files are
            attached (the original implicitly returned None here, which
            crashed the caller's string concatenation).
        """
        if not self.csv_files:
            return ""
        fences = [f"```csv\n{csv_file.content}\n```" for csv_file in self.csv_files]
        # Newline after the sentence so the first ``` fence starts on its
        # own line (the original glued them together).
        return "Use data from following CSV raw contents\n" + "\n".join(fences)

    def to_response_data(self) -> list[dict] | None:
        """Shape the attached files for the streaming response, or None if empty."""
        output = []
        if self.csv_files:
            output.append(
                {
                    "type": "csv",
                    "data": {
                        "csvFiles": [csv_file.dict() for csv_file in self.csv_files]
                    },
                }
            )
        return output if output else None
class ChatData(BaseModel):
    """Chat request payload: message history plus optional attached data."""

    data: DataParserOptions | None = Field(
        default=None,
    )
    messages: List[Message]

    class Config:
        json_schema_extra = {
            "example": {
                "messages": [
                    {
                        "role": "user",
                        "content": "What standards for letters exist?",
                    }
                ]
            }
        }

    @validator("messages")
    def messages_must_not_be_empty(cls, v):
        # Reject empty histories up front so the accessors below can safely
        # index messages[-1].
        if len(v) == 0:
            raise ValueError("Messages must not be empty")
        return v

    def get_last_message_content(self) -> str:
        """Content of the last message, with any attached-data prompt appended."""
        message_content = self.messages[-1].content
        if self.data:
            raw_content = self.data.to_raw_content()
            # Guard: to_raw_content yields nothing when no files are attached;
            # the original concatenated None here and raised TypeError.
            if raw_content:
                message_content += "\n" + raw_content
        return message_content

    def get_history_messages(self) -> List[ChatMessage]:
        """All messages except the last, converted to ChatMessage.

        (Annotation corrected: the original claimed List[Message] but
        returned ChatMessage instances.)
        """
        return [
            ChatMessage(role=message.role, content=message.content)
            for message in self.messages[:-1]
        ]

    def get_additional_data_response(self) -> list[dict] | None:
        """Extra data for the response stream, or None when nothing is attached.

        The original dereferenced self.data unconditionally and relied on the
        caller checking for None first.
        """
        if self.data is None:
            return None
        return self.data.to_response_data()

    def is_last_message_from_user(self) -> bool:
        """True when the most recent message was sent by the user."""
        return self.messages[-1].role == MessageRole.USER
class SourceNodes(BaseModel):
    """Serializable view of a retrieved source node, with a resolved URL."""

    id: str
    metadata: Dict[str, Any]
    score: Optional[float]
    text: str
    url: Optional[str]

    @classmethod
    def from_source_node(cls, source_node: NodeWithScore):
        """Build one entry from a retrieved node, deriving a file URL if needed."""
        node = source_node.node
        node_metadata = node.metadata
        resolved_url = node_metadata.get("URL")
        if not resolved_url:
            # No explicit URL in the metadata: fall back to a file-server
            # link built from the file name and the configured prefix.
            name = node_metadata.get("file_name")
            prefix = os.getenv("FILESERVER_URL_PREFIX")
            if not prefix:
                logger.warning(
                    "Warning: FILESERVER_URL_PREFIX not set in environment variables"
                )
            if name and prefix:
                resolved_url = f"{prefix}/data/{name}"
        return cls(
            id=node.node_id,
            metadata=node_metadata,
            score=source_node.score,
            text=node.text,  # type: ignore
            url=resolved_url,
        )

    @classmethod
    def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
        """Convert a whole result list in one pass."""
        return [cls.from_source_node(n) for n in source_nodes]
class Result(BaseModel):
    """Non-streaming chat response: the assistant message plus its sources."""

    result: Message
    nodes: List[SourceNodes]
class ChatConfig(BaseModel):
    """Response payload for the chat config endpoint."""

    starter_questions: Optional[List[str]] = Field(
        default=None,
        description="List of starter questions",
    )

    class Config:
        json_schema_extra = {
            "example": {
                "starterQuestions": [
                    "What standards for letters exist?",
                    "What are the requirements for a letter to be considered a letter?",
                ]
            }
        }
        # Serialize/accept camelCase field names (starterQuestions).
        alias_generator = to_camel
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment