Skip to content
Snippets Groups Projects
Commit d38eb3c4 authored by leehuwuj's avatar leehuwuj
Browse files

unify chat.py file

parent 087a45e9
No related branches found
No related tags found
No related merge requests found
import logging
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status
from app.api.callbacks.llamacloud import LlamaCloudFileDownload
from app.api.callbacks.next_question import SuggestNextQuestions
from app.api.callbacks.stream_handler import StreamHandler
from app.api.callbacks.source_nodes import AddNodeUrl
from app.api.routers.models import (
ChatData,
)
from app.engine.query_filter import generate_filters
from app.workflows import create_workflow
# Router for the chat endpoints; `r` is the short alias used by the
# route decorators below, `chat_router` is the name the app imports.
r = APIRouter()
chat_router = r

# Reuse uvicorn's logger so messages follow the server's log configuration.
logger = logging.getLogger("uvicorn")
@r.post("")
async def chat(
    request: Request,
    data: ChatData,
    background_tasks: BackgroundTasks,
):
    """Streaming chat endpoint.

    Builds a workflow from the request's parameters and document filters,
    runs it over the chat history, and returns a streaming response in the
    Vercel data-stream protocol. Any failure is logged and surfaced as a
    500 HTTPException with the original exception chained as its cause.
    """
    try:
        last_message_content = data.get_last_message_content()
        # Include intermediate agent messages so the workflow sees the
        # full conversational context, not just user/assistant turns.
        messages = data.get_history_messages(include_agent_messages=True)

        # Restrict retrieval to the documents attached to this chat, if any.
        doc_ids = data.get_chat_document_ids()
        filters = generate_filters(doc_ids)
        params = data.data or {}

        workflow = create_workflow(
            params=params,
            filters=filters,
        )

        handler = workflow.run(
            user_msg=last_message_content,
            chat_history=messages,
            stream=True,
        )

        return StreamHandler.from_default(
            handler=handler,
            callbacks=[
                LlamaCloudFileDownload.from_default(background_tasks),
                SuggestNextQuestions.from_default(data),
                AddNodeUrl.from_default(),
            ],
        ).vercel_stream()
    except Exception as e:
        # logger.exception already records the active traceback, so the
        # previous exc_info=True argument was redundant.
        logger.exception("Error in chat engine")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error in chat engine: {e}",
        ) from e
import json
import logging
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status
from llama_index.core.agent.workflow import AgentOutput
from llama_index.core.llms import MessageRole
from app.api.callbacks.llamacloud import LlamaCloudFileDownload
from app.api.callbacks.next_question import SuggestNextQuestions
from app.api.callbacks.source_nodes import AddNodeUrl
from app.api.callbacks.stream_handler import StreamHandler
from app.api.callbacks.source_nodes import AddNodeUrl
from app.api.routers.models import (
ChatData,
Message,
Result,
)
from app.engine.engine import get_engine
from app.engine.query_filter import generate_filters
from app.workflows import create_workflow
chat_router = r = APIRouter()
logger = logging.getLogger("uvicorn")
# streaming endpoint - delete if not needed
@r.post("")
async def chat(
request: Request,
......@@ -31,16 +25,18 @@ async def chat(
):
try:
last_message_content = data.get_last_message_content()
messages = data.get_history_messages()
messages = data.get_history_messages(include_agent_messages=True)
doc_ids = data.get_chat_document_ids()
filters = generate_filters(doc_ids)
params = data.data or {}
logger.info(
f"Creating chat engine with filters: {str(filters)}",
workflow = create_workflow(
params=params,
filters=filters,
)
engine = get_engine(filters=filters, params=params)
handler = engine.run(
handler = workflow.run(
user_msg=last_message_content,
chat_history=messages,
stream=True,
......@@ -59,35 +55,3 @@ async def chat(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error in chat engine: {e}",
) from e
# non-streaming endpoint - delete if not needed
@r.post("/request")
async def chat_request(
    data: ChatData,
) -> Result:
    """Non-streaming chat endpoint.

    Runs the engine once over the chat history and returns the assistant's
    final message as a single Result payload (no streaming).
    """
    last_message_content = data.get_last_message_content()
    messages = data.get_history_messages()

    # Restrict retrieval to the documents attached to this chat, if any.
    doc_ids = data.get_chat_document_ids()
    filters = generate_filters(doc_ids)
    params = data.data or {}
    # Lazy %-formatting: the message is only built if INFO is enabled.
    logger.info("Creating chat engine with filters: %s", filters)

    engine = get_engine(filters=filters, params=params)

    # Await the full result directly; the intermediate `response` alias
    # in the original added nothing.
    output = await engine.run(
        user_msg=last_message_content,
        chat_history=messages,
        stream=False,
    )
    if isinstance(output, AgentOutput):
        # Agent runs carry the final text on their response object.
        content = output.response.content
    else:
        # Fallback: serialize whatever structured output the engine produced.
        content = json.dumps(output)

    return Result(
        result=Message(role=MessageRole.ASSISTANT, content=content),
    )
from .agent import create_workflow
......@@ -2,7 +2,6 @@ import os
from typing import List
from llama_index.core.agent.workflow import AgentWorkflow
from llama_index.core.settings import Settings
from llama_index.core.tools import BaseTool
......@@ -11,7 +10,7 @@ from app.engine.tools import ToolFactory
from app.engine.tools.query_engine import get_query_engine_tool
def get_engine(params=None, **kwargs):
def create_workflow(params=None, **kwargs):
if params is None:
params = {}
system_prompt = os.getenv("SYSTEM_PROMPT")
......
0% Loading… or try again.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment