From d38eb3c40506c293e27cc4daacbcd7e00a4449ee Mon Sep 17 00:00:00 2001 From: leehuwuj <leehuwuj@gmail.com> Date: Tue, 25 Feb 2025 11:02:10 +0700 Subject: [PATCH] unify chat.py file --- .../multiagent/python/app/api/routers/chat.py | 57 ------------------- .../streaming/fastapi/app/api/routers/chat.py | 54 +++--------------- .../fastapi/app/workflows/__init__.py | 1 + .../{engine/engine.py => workflows/agent.py} | 3 +- 4 files changed, 11 insertions(+), 104 deletions(-) delete mode 100644 templates/components/multiagent/python/app/api/routers/chat.py create mode 100644 templates/types/streaming/fastapi/app/workflows/__init__.py rename templates/types/streaming/fastapi/app/{engine/engine.py => workflows/agent.py} (96%) diff --git a/templates/components/multiagent/python/app/api/routers/chat.py b/templates/components/multiagent/python/app/api/routers/chat.py deleted file mode 100644 index f46f43e1..00000000 --- a/templates/components/multiagent/python/app/api/routers/chat.py +++ /dev/null @@ -1,57 +0,0 @@ -import logging - -from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status - -from app.api.callbacks.llamacloud import LlamaCloudFileDownload -from app.api.callbacks.next_question import SuggestNextQuestions -from app.api.callbacks.stream_handler import StreamHandler -from app.api.callbacks.source_nodes import AddNodeUrl -from app.api.routers.models import ( - ChatData, -) -from app.engine.query_filter import generate_filters -from app.workflows import create_workflow - -chat_router = r = APIRouter() - -logger = logging.getLogger("uvicorn") - - -@r.post("") -async def chat( - request: Request, - data: ChatData, - background_tasks: BackgroundTasks, -): - try: - last_message_content = data.get_last_message_content() - messages = data.get_history_messages(include_agent_messages=True) - - doc_ids = data.get_chat_document_ids() - filters = generate_filters(doc_ids) - params = data.data or {} - - workflow = create_workflow( - params=params, - filters=filters, - ) - - handler = workflow.run( - user_msg=last_message_content, - chat_history=messages, - stream=True, - ) - return StreamHandler.from_default( - handler=handler, - callbacks=[ - LlamaCloudFileDownload.from_default(background_tasks), - SuggestNextQuestions.from_default(data), - AddNodeUrl.from_default(), - ], - ).vercel_stream() - except Exception as e: - logger.exception("Error in chat engine", exc_info=True) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error in chat engine: {e}", - ) from e diff --git a/templates/types/streaming/fastapi/app/api/routers/chat.py b/templates/types/streaming/fastapi/app/api/routers/chat.py index 5094adec..f46f43e1 100644 --- a/templates/types/streaming/fastapi/app/api/routers/chat.py +++ b/templates/types/streaming/fastapi/app/api/routers/chat.py @@ -1,28 +1,22 @@ -import json import logging from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status -from llama_index.core.agent.workflow import AgentOutput -from llama_index.core.llms import MessageRole from app.api.callbacks.llamacloud import LlamaCloudFileDownload from app.api.callbacks.next_question import SuggestNextQuestions -from app.api.callbacks.source_nodes import AddNodeUrl from app.api.callbacks.stream_handler import StreamHandler +from app.api.callbacks.source_nodes import AddNodeUrl from app.api.routers.models import ( ChatData, - Message, - Result, ) -from app.engine.engine import get_engine from app.engine.query_filter import generate_filters +from app.workflows import create_workflow chat_router = r = APIRouter() logger = logging.getLogger("uvicorn") -# streaming endpoint - delete if not needed @r.post("") async def chat( request: Request, @@ -31,16 +25,18 @@ async def chat( ): try: last_message_content = data.get_last_message_content() - messages = data.get_history_messages() + messages = data.get_history_messages(include_agent_messages=True) doc_ids = data.get_chat_document_ids() filters = generate_filters(doc_ids) params = data.data or {} - logger.info( - f"Creating chat engine with filters: {str(filters)}", + + workflow = create_workflow( + params=params, + filters=filters, ) - engine = get_engine(filters=filters, params=params) - handler = engine.run( + + handler = workflow.run( user_msg=last_message_content, chat_history=messages, stream=True, @@ -59,35 +55,3 @@ async def chat( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error in chat engine: {e}", ) from e - - -# non-streaming endpoint - delete if not needed -@r.post("/request") -async def chat_request( - data: ChatData, -) -> Result: - last_message_content = data.get_last_message_content() - messages = data.get_history_messages() - - doc_ids = data.get_chat_document_ids() - filters = generate_filters(doc_ids) - params = data.data or {} - logger.info( - f"Creating chat engine with filters: {str(filters)}", - ) - engine = get_engine(filters=filters, params=params) - - response = await engine.run( - user_msg=last_message_content, - chat_history=messages, - stream=False, - ) - output = response - if isinstance(output, AgentOutput): - content = output.response.content - else: - content = json.dumps(output) - - return Result( - result=Message(role=MessageRole.ASSISTANT, content=content), - ) diff --git a/templates/types/streaming/fastapi/app/workflows/__init__.py b/templates/types/streaming/fastapi/app/workflows/__init__.py new file mode 100644 index 00000000..f0172c6d --- /dev/null +++ b/templates/types/streaming/fastapi/app/workflows/__init__.py @@ -0,0 +1 @@ +from .agent import create_workflow diff --git a/templates/types/streaming/fastapi/app/engine/engine.py b/templates/types/streaming/fastapi/app/workflows/agent.py similarity index 96% rename from templates/types/streaming/fastapi/app/engine/engine.py rename to templates/types/streaming/fastapi/app/workflows/agent.py index c1fc25f2..6dcfd76e 100644 --- a/templates/types/streaming/fastapi/app/engine/engine.py +++ b/templates/types/streaming/fastapi/app/workflows/agent.py @@ -2,7 +2,6 @@ import os from typing import List from llama_index.core.agent.workflow import AgentWorkflow - from llama_index.core.settings import Settings from llama_index.core.tools import BaseTool @@ -11,7 +10,7 @@ from app.engine.tools import ToolFactory from app.engine.tools.query_engine import get_query_engine_tool -def get_engine(params=None, **kwargs): +def create_workflow(params=None, **kwargs): if params is None: params = {} system_prompt = os.getenv("SYSTEM_PROMPT") -- GitLab