diff --git a/templates/types/streaming/fastapi/app/api/routers/chat.py b/templates/types/streaming/fastapi/app/api/routers/chat.py
index 39bf6b7981f23bc7e288c51688af2b700d77ac94..d6373cc9201b97081a41fc28e3125388f6534ff2 100644
--- a/templates/types/streaming/fastapi/app/api/routers/chat.py
+++ b/templates/types/streaming/fastapi/app/api/routers/chat.py
@@ -1,145 +1,53 @@
 import os
 import logging
-from pydantic import BaseModel, Field
-from pydantic.alias_generators import to_camel
-from typing import List, Any, Optional, Dict, Tuple
+
+from aiostream import stream
 from fastapi import APIRouter, Depends, HTTPException, Request, status
 from llama_index.core.chat_engine.types import BaseChatEngine
-from llama_index.core.schema import NodeWithScore
-from llama_index.core.llms import ChatMessage, MessageRole
+from llama_index.core.llms import MessageRole
 
 from app.engine import get_chat_engine
 from app.api.routers.vercel_response import VercelStreamResponse
-from app.api.routers.messaging import EventCallbackHandler
-from aiostream import stream
+from app.api.routers.events import EventCallbackHandler
+from app.api.routers.models import (
+    ChatData,
+    ChatConfig,
+    SourceNodes,
+    Result,
+    Message,
+)
 
 chat_router = r = APIRouter()
 
 logger = logging.getLogger("uvicorn")
 
 
-class _Message(BaseModel):
-    role: MessageRole
-    content: str
-
-
-class _ChatData(BaseModel):
-    messages: List[_Message]
-
-    class Config:
-        json_schema_extra = {
-            "example": {
-                "messages": [
-                    {
-                        "role": "user",
-                        "content": "What standards for letters exist?",
-                    }
-                ]
-            }
-        }
-
-
-class _SourceNodes(BaseModel):
-    id: str
-    metadata: Dict[str, Any]
-    score: Optional[float]
-    text: str
-    url: Optional[str]
-
-    @classmethod
-    def from_source_node(cls, source_node: NodeWithScore):
-        metadata = source_node.node.metadata
-        url = metadata.get("URL")
-
-        if not url:
-            file_name = metadata.get("file_name")
-            url_prefix = os.getenv("FILESERVER_URL_PREFIX")
-            if not url_prefix:
-                logger.warning(
-                    "Warning: FILESERVER_URL_PREFIX not set in environment variables"
-                )
-            if file_name and url_prefix:
-                url = f"{url_prefix}/data/{file_name}"
-
-        return cls(
-            id=source_node.node.node_id,
-            metadata=metadata,
-            score=source_node.score,
-            text=source_node.node.text,  # type: ignore
-            url=url,
-        )
-
-    @classmethod
-    def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
-        return [cls.from_source_node(node) for node in source_nodes]
-
-
-class _Result(BaseModel):
-    result: _Message
-    nodes: List[_SourceNodes]
-
-
-class _ChatConfig(BaseModel):
-    starter_questions: Optional[List[str]] = Field(
-        default=None,
-        description="List of starter questions",
-    )
-
-    class Config:
-        json_schema_extra = {
-            "example": {
-                "starterQuestions": [
-                    "What standards for letters exist?",
-                    "What are the requirements for a letter to be considered a letter?",
-                ]
-            }
-        }
-        alias_generator = to_camel
-
-
-async def parse_chat_data(data: _ChatData) -> Tuple[str, List[ChatMessage]]:
-    # check preconditions and get last message
-    if len(data.messages) == 0:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail="No messages provided",
-        )
-    last_message = data.messages.pop()
-    if last_message.role != MessageRole.USER:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail="Last message must be from user",
-        )
-    # convert messages coming from the request to type ChatMessage
-    messages = [
-        ChatMessage(
-            role=m.role,
-            content=m.content,
-        )
-        for m in data.messages
-    ]
-    return last_message.content, messages
-
-
 # streaming endpoint - delete if not needed
 @r.post("")
 async def chat(
     request: Request,
-    data: _ChatData,
+    data: ChatData,
     chat_engine: BaseChatEngine = Depends(get_chat_engine),
 ):
-    last_message_content, messages = await parse_chat_data(data)
+    last_message_content = data.get_last_message_content()
+    messages = data.get_history_messages()
     event_handler = EventCallbackHandler()
     chat_engine.callback_manager.handlers.append(event_handler)  # type: ignore
     try:
         response = await chat_engine.astream_chat(last_message_content, messages)
     except Exception as e:
+        logger.exception("Error in chat engine", exc_info=True)
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail=f"Error in chat engine: {e}",
-        )
+        ) from e
 
     async def content_generator():
+        # Yield the additional data
+        if data.data is not None:
+            for data_response in data.get_additional_data_response():
+                yield VercelStreamResponse.convert_data(data_response)
+
         # Yield the text response
         async def _text_generator():
             async for token in response.async_response_gen():
@@ -167,7 +75,7 @@ async def chat(
                 "type": "sources",
                 "data": {
                     "nodes": [
-                        _SourceNodes.from_source_node(node).dict()
+                        SourceNodes.from_source_node(node).dict()
                         for node in response.source_nodes
                     ]
                 },
@@ -180,22 +88,23 @@ async def chat(
 # non-streaming endpoint - delete if not needed
 @r.post("/request")
 async def chat_request(
-    data: _ChatData,
+    data: ChatData,
     chat_engine: BaseChatEngine = Depends(get_chat_engine),
-) -> _Result:
-    last_message_content, messages = await parse_chat_data(data)
+) -> Result:
+    last_message_content = data.get_last_message_content()
+    messages = data.get_history_messages()
 
     response = await chat_engine.achat(last_message_content, messages)
-    return _Result(
-        result=_Message(role=MessageRole.ASSISTANT, content=response.response),
-        nodes=_SourceNodes.from_source_nodes(response.source_nodes),
+    return Result(
+        result=Message(role=MessageRole.ASSISTANT, content=response.response),
+        nodes=SourceNodes.from_source_nodes(response.source_nodes),
     )
 
 
 @r.get("/config")
-async def chat_config() -> _ChatConfig:
+async def chat_config() -> ChatConfig:
     starter_questions = None
     conversation_starters = os.getenv("CONVERSATION_STARTERS")
     if conversation_starters and conversation_starters.strip():
         starter_questions = conversation_starters.strip().split("\n")
-    return _ChatConfig(starterQuestions=starter_questions)
+    return ChatConfig(starterQuestions=starter_questions)
diff --git a/templates/types/streaming/fastapi/app/api/routers/messaging.py b/templates/types/streaming/fastapi/app/api/routers/events.py
similarity index 100%
rename from templates/types/streaming/fastapi/app/api/routers/messaging.py
rename to templates/types/streaming/fastapi/app/api/routers/events.py
diff --git a/templates/types/streaming/fastapi/app/api/routers/models.py b/templates/types/streaming/fastapi/app/api/routers/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..b64e86f0be0ed9d737b602d61769a36fc7d0cd45
--- /dev/null
+++ b/templates/types/streaming/fastapi/app/api/routers/models.py
@@ -0,0 +1,175 @@
+import os
+import logging
+from pydantic import BaseModel, Field, validator
+from pydantic.alias_generators import to_camel
+from typing import List, Any, Optional, Dict
+from llama_index.core.schema import NodeWithScore
+from llama_index.core.llms import ChatMessage, MessageRole
+
+
+logger = logging.getLogger("uvicorn")
+
+
+class Message(BaseModel):
+    role: MessageRole
+    content: str
+
+
+class CsvFile(BaseModel):
+    content: str
+    filename: str
+    filesize: int
+    id: str
+    type: str
+
+
+class DataParserOptions(BaseModel):
+    csv_files: List[CsvFile] | None = Field(
+        default=None,
+        description="List of CSV files",
+    )
+
+    class Config:
+        json_schema_extra = {
+            "example": {
+                "csvFiles": [
+                    {
+                        "content": "Name, Age\nAlice, 25\nBob, 30",
+                        "filename": "example.csv",
+                        "filesize": 123,
+                        "id": "123",
+                        "type": "text/csv",
+                    }
+                ]
+            }
+        }
+        alias_generator = to_camel
+
+    def to_raw_content(self) -> str:
+        if self.csv_files is not None and len(self.csv_files) > 0:
+            return "Use data from the following CSV raw contents:\n" + "\n".join(
+                [f"```csv\n{csv_file.content}\n```" for csv_file in self.csv_files]
+            )
+
+    def to_response_data(self) -> list[dict] | None:
+        output = []
+        if self.csv_files is not None and len(self.csv_files) > 0:
+            output.append(
+                {
+                    "type": "csv",
+                    "data": {
+                        "csvFiles": [csv_file.dict() for csv_file in self.csv_files]
+                    },
+                }
+            )
+        return output if len(output) > 0 else None
+
+
+class ChatData(BaseModel):
+    data: DataParserOptions | None = Field(
+        default=None,
+    )
+    messages: List[Message]
+
+    class Config:
+        json_schema_extra = {
+            "example": {
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": "What standards for letters exist?",
+                    }
+                ]
+            }
+        }
+
+    @validator("messages")
+    def messages_must_not_be_empty(cls, v):
+        if len(v) == 0:
+            raise ValueError("Messages must not be empty")
+        return v
+
+    def get_last_message_content(self) -> str:
+        """
+        Get the content of the last message along with the data content if available
+        """
+        message_content = self.messages[-1].content
+        if self.data:
+            message_content += "\n" + self.data.to_raw_content()
+        return message_content
+
+    def get_history_messages(self) -> List[ChatMessage]:
+        """
+        Get the history messages
+        """
+        return [
+            ChatMessage(role=message.role, content=message.content)
+            for message in self.messages[:-1]
+        ]
+
+    def get_additional_data_response(self) -> list[dict] | None:
+        """
+        Get the additional data
+        """
+        return self.data.to_response_data()
+
+    def is_last_message_from_user(self) -> bool:
+        return self.messages[-1].role == MessageRole.USER
+
+
+class SourceNodes(BaseModel):
+    id: str
+    metadata: Dict[str, Any]
+    score: Optional[float]
+    text: str
+    url: Optional[str]
+
+    @classmethod
+    def from_source_node(cls, source_node: NodeWithScore):
+        metadata = source_node.node.metadata
+        url = metadata.get("URL")
+
+        if not url:
+            file_name = metadata.get("file_name")
+            url_prefix = os.getenv("FILESERVER_URL_PREFIX")
+            if not url_prefix:
+                logger.warning(
+                    "Warning: FILESERVER_URL_PREFIX not set in environment variables"
+                )
+            if file_name and url_prefix:
+                url = f"{url_prefix}/data/{file_name}"
+
+        return cls(
+            id=source_node.node.node_id,
+            metadata=metadata,
+            score=source_node.score,
+            text=source_node.node.text,  # type: ignore
+            url=url,
+        )
+
+    @classmethod
+    def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
+        return [cls.from_source_node(node) for node in source_nodes]
+
+
+class Result(BaseModel):
+    result: Message
+    nodes: List[SourceNodes]
+
+
+class ChatConfig(BaseModel):
+    starter_questions: Optional[List[str]] = Field(
+        default=None,
+        description="List of starter questions",
+    )
+
+    class Config:
+        json_schema_extra = {
+            "example": {
+                "starterQuestions": [
+                    "What standards for letters exist?",
+                    "What are the requirements for a letter to be considered a letter?",
+                ]
+            }
+        }
+        alias_generator = to_camel
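
For illustration only (not part of the diff): a minimal sketch of how the refactored `ChatData` model is meant to be consumed, assuming the new `app.api.routers.models` module above is importable. The payload values are made up; `DataParserOptions` is constructed with the camelCase key `csvFiles`, matching the model's `alias_generator`.

```python
from llama_index.core.llms import MessageRole

from app.api.routers.models import ChatData, CsvFile, DataParserOptions, Message

# Build the payload the chat endpoints now expect: message history plus optional CSV data.
chat_data = ChatData(
    messages=[
        Message(role=MessageRole.USER, content="Summarize the attached CSV."),
    ],
    data=DataParserOptions(
        csvFiles=[
            CsvFile(
                content="Name, Age\nAlice, 25\nBob, 30",
                filename="example.csv",
                filesize=27,
                id="123",
                type="text/csv",
            )
        ]
    ),
)

# These helpers replace the old parse_chat_data():
prompt = chat_data.get_last_message_content()      # last user message plus CSV raw content
history = chat_data.get_history_messages()         # earlier messages as ChatMessage objects
extra = chat_data.get_additional_data_response()   # CSV blocks echoed back to the client
```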