Commit b4f07672 authored by leehuwuj

stg

parent 9e723c3a
Showing with 92 additions and 259 deletions
 import os
 from typing import List
-from llama_index.core.agent import AgentRunner
-from llama_index.core.callbacks import CallbackManager
+from llama_index.core.agent.workflow import AgentWorkflow
+# from llama_index.core.agent import AgentRunner
 from llama_index.core.settings import Settings
 from llama_index.core.tools import BaseTool
@@ -11,13 +12,14 @@ from app.engine.tools import ToolFactory
 from app.engine.tools.query_engine import get_query_engine_tool
 
-def get_chat_engine(params=None, event_handlers=None, **kwargs):
+def get_engine(params=None, **kwargs):
     if params is None:
         params = {}
     system_prompt = os.getenv("SYSTEM_PROMPT")
     tools: List[BaseTool] = []
-    callback_manager = CallbackManager(handlers=event_handlers or [])
     # Add query tool if index exists
-    index_config = IndexConfig(callback_manager=callback_manager, **(params or {}))
+    index_config = IndexConfig(**params)
     index = get_index(index_config)
     if index is not None:
         query_engine_tool = get_query_engine_tool(index, **kwargs)
@@ -27,10 +29,8 @@ def get_chat_engine(params=None, event_handlers=None, **kwargs):
     configured_tools: List[BaseTool] = ToolFactory.from_env()
     tools.extend(configured_tools)
 
-    return AgentRunner.from_llm(
+    return AgentWorkflow.from_tools_or_functions(
+        tools_or_functions=tools,  # type: ignore
         llm=Settings.llm,
-        tools=tools,
         system_prompt=system_prompt,
-        callback_manager=callback_manager,
         verbose=True,
     )
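For reference, the new factory assembles to roughly the following once the removals above are applied. This is a sketch: the import path for IndexConfig and get_index is inferred from usage and is not shown in these hunks.

# Sketch of the resulting engine.py, assembled from the added lines above.
import os
from typing import List

from llama_index.core.agent.workflow import AgentWorkflow
from llama_index.core.settings import Settings
from llama_index.core.tools import BaseTool

from app.engine.index import IndexConfig, get_index  # assumed import path
from app.engine.tools import ToolFactory
from app.engine.tools.query_engine import get_query_engine_tool


def get_engine(params=None, **kwargs):
    if params is None:
        params = {}
    system_prompt = os.getenv("SYSTEM_PROMPT")
    tools: List[BaseTool] = []

    # Add query tool if index exists
    index_config = IndexConfig(**params)
    index = get_index(index_config)
    if index is not None:
        tools.append(get_query_engine_tool(index, **kwargs))

    # Add tools configured via environment variables
    tools.extend(ToolFactory.from_env())

    # AgentWorkflow replaces AgentRunner; no CallbackManager is needed,
    # since consumers now stream workflow events instead of callback events.
    return AgentWorkflow.from_tools_or_functions(
        tools_or_functions=tools,  # type: ignore
        llm=Settings.llm,
        system_prompt=system_prompt,
        verbose=True,
    )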
import asyncio
import json
import logging
from typing import AsyncGenerator

from fastapi.responses import StreamingResponse
from llama_index.core.agent.workflow.workflow_events import AgentStream
from llama_index.core.workflow import StopEvent

from app.api.callbacks.stream_handler import StreamHandler

logger = logging.getLogger("uvicorn")


class VercelStreamResponse(StreamingResponse):
    """
    Converts preprocessed events into Vercel-compatible streaming response format.
    """

    TEXT_PREFIX = "0:"
    DATA_PREFIX = "8:"
    ERROR_PREFIX = "3:"

    def __init__(
        self,
        stream_handler: StreamHandler,
        *args,
        **kwargs,
    ):
        self.handler = stream_handler
        super().__init__(content=self.content_generator())

    async def content_generator(self):
        """Generate Vercel-formatted content from preprocessed events."""
        stream_started = False
        try:
            async for event in self.handler.stream_events():
                if not stream_started:
                    # Start the stream with an empty message
                    stream_started = True
                    yield self.convert_text("")

                # Handle different types of events
                if isinstance(event, (AgentStream, StopEvent)):
                    async for chunk in self._stream_text(event):
                        await self.handler.accumulate_text(chunk)
                        yield self.convert_text(chunk)
                elif isinstance(event, dict):
                    yield self.convert_data(event)
                elif hasattr(event, "to_response"):
                    event_response = event.to_response()
                    yield self.convert_data(event_response)
                else:
                    yield self.convert_data(event.model_dump())
        except asyncio.CancelledError:
            logger.warning("Client cancelled the request!")
            await self.handler.cancel_run()
        except Exception as e:
            logger.error(f"Error in stream response: {e}")
            yield self.convert_error(str(e))
            await self.handler.cancel_run()

    async def _stream_text(
        self, event: AgentStream | StopEvent
    ) -> AsyncGenerator[str, None]:
        """
        Accept stream text from either AgentStream or StopEvent with string or AsyncGenerator result
        """
        if isinstance(event, AgentStream):
            yield event.delta  # raw text; the caller wraps it with convert_text
        elif isinstance(event, StopEvent):
            if isinstance(event.result, str):
                yield event.result
            elif isinstance(event.result, AsyncGenerator):
                async for chunk in event.result:
                    if isinstance(chunk, str):
                        yield chunk
                    elif hasattr(chunk, "delta"):
                        yield chunk.delta

    @classmethod
    def convert_text(cls, token: str) -> str:
        """Convert text event to Vercel format."""
        # Escape newlines and double quotes to avoid breaking the stream
        token = json.dumps(token)
        return f"{cls.TEXT_PREFIX}{token}\n"

    @classmethod
    def convert_data(cls, data: dict) -> str:
        """Convert data event to Vercel format."""
        data_str = json.dumps(data)
        return f"{cls.DATA_PREFIX}[{data_str}]\n"

    @classmethod
    def convert_error(cls, error: str) -> str:
        """Convert error event to Vercel format."""
        error_str = json.dumps(error)
        return f"{cls.ERROR_PREFIX}{error_str}\n"
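Each frame this class emits is a prefix, a JSON payload, and a newline. A quick illustration of what the three converters put on the wire (values are illustrative, outputs shown as comments):

# Illustrative: wire format produced by the converter classmethods.
VercelStreamResponse.convert_text('Hello "world"')
# -> 0:"Hello \"world\""
VercelStreamResponse.convert_data({"type": "sources", "data": {"nodes": []}})
# -> 8:[{"type": "sources", "data": {"nodes": []}}]
VercelStreamResponse.convert_error("index not found")
# -> 3:"index not found"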
@@ -3,15 +3,16 @@ import logging
 from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status
 from llama_index.core.llms import MessageRole
-from app.api.routers.events import EventCallbackHandler
+from app.api.callbacks.llamacloud import LlamaCloudFileDownload
+from app.api.callbacks.next_question import SuggestNextQuestions
+from app.api.callbacks.stream_handler import StreamHandler
 from app.api.routers.models import (
     ChatData,
     Message,
     Result,
     SourceNodes,
 )
-from app.api.routers.vercel_response import VercelStreamResponse
-from app.engine.engine import get_chat_engine
+from app.engine.engine import get_engine
 from app.engine.query_filter import generate_filters
 
 chat_router = r = APIRouter()
@@ -36,15 +37,19 @@ async def chat(
         logger.info(
             f"Creating chat engine with filters: {str(filters)}",
         )
-        event_handler = EventCallbackHandler()
-        chat_engine = get_chat_engine(
-            filters=filters, params=params, event_handlers=[event_handler]
-        )
-        response = chat_engine.astream_chat(last_message_content, messages)
-        return VercelStreamResponse(
-            request, event_handler, response, data, background_tasks
+        engine = get_engine(filters=filters, params=params)
+        handler = engine.run(
+            user_msg=last_message_content,
+            chat_history=messages,
+            stream=True,
         )
+        return StreamHandler.from_default(
+            handler=handler,
+            callbacks=[
+                LlamaCloudFileDownload.from_default(background_tasks),
+                SuggestNextQuestions.from_default(data),
+            ],
+        ).vercel_stream()
     except Exception as e:
         logger.exception("Error in chat engine", exc_info=True)
         raise HTTPException(
@@ -53,6 +58,7 @@
         ) from e
 
 
+# TODO: Update non-streaming endpoint
 # non-streaming endpoint - delete if not needed
 @r.post("/request")
 async def chat_request(
......
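The router now delegates to StreamHandler from app/api/callbacks/stream_handler.py, which is not part of this excerpt. From the call sites above and in vercel_response.py, its surface can be inferred roughly as the sketch below; this is an assumption-laden reconstruction, not the actual implementation, and the event_callback hook name is hypothetical.

# Rough sketch of the StreamHandler surface implied by the call sites;
# the real class lives in app/api/callbacks/stream_handler.py and may differ.
from typing import Any, AsyncGenerator, List


class StreamHandler:
    def __init__(self, handler: Any, callbacks: List[Any]):
        self.handler = handler        # AgentWorkflow run handler
        self.callbacks = callbacks    # post-processors such as SuggestNextQuestions
        self.accumulated_text = ""

    @classmethod
    def from_default(cls, handler: Any, callbacks: List[Any]) -> "StreamHandler":
        return cls(handler=handler, callbacks=callbacks)

    def vercel_stream(self):
        # Wrap this handler in the VercelStreamResponse shown earlier.
        from app.api.routers.vercel_response import VercelStreamResponse

        return VercelStreamResponse(stream_handler=self)

    async def stream_events(self) -> AsyncGenerator[Any, None]:
        # Forward workflow events, letting each callback observe or
        # transform them (hypothetical hook name: event_callback).
        async for event in self.handler.stream_events():
            for callback in self.callbacks:
                event = await callback.event_callback(event)
            yield event

    async def accumulate_text(self, text: str) -> None:
        self.accumulated_text += text

    async def cancel_run(self) -> None:
        await self.handler.cancel_run()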
@@ -99,6 +99,8 @@ class CallbackEvent(BaseModel):
         return None
 
 
+# TODO: Add an adapter for workflow events
+# and remove callback handler
 class EventCallbackHandler(BaseCallbackHandler):
     _aqueue: asyncio.Queue
     is_done: bool = False
......
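The TODO points at replacing this callback handler with a direct adapter from workflow events. One possible shape, mirroring the dispatch already used in content_generator above, purely illustrative:

# Purely illustrative adapter matching the event dispatch in
# content_generator: workflow event -> response dict, or None to skip.
from typing import Any, Optional


def workflow_event_to_response(event: Any) -> Optional[dict]:
    if isinstance(event, dict):
        return event
    if hasattr(event, "to_response"):
        return event.to_response()
    if hasattr(event, "model_dump"):
        return {"type": "agent", "data": event.model_dump()}
    return None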
 import asyncio
 import json
 import logging
-from typing import Awaitable, List
+from typing import AsyncGenerator
 
-from aiostream import stream
-from fastapi import BackgroundTasks, Request
 from fastapi.responses import StreamingResponse
-from llama_index.core.chat_engine.types import StreamingAgentChatResponse
-from llama_index.core.schema import NodeWithScore
+from llama_index.core.agent.workflow.workflow_events import AgentStream
+from llama_index.core.workflow import StopEvent
 
-from app.api.routers.events import EventCallbackHandler
-from app.api.routers.models import ChatData, Message, SourceNodes
-from app.api.services.suggestion import NextQuestionSuggestion
+from app.api.callbacks.stream_handler import StreamHandler
 
 logger = logging.getLogger("uvicorn")
 
 
 class VercelStreamResponse(StreamingResponse):
     """
-    Class to convert the response from the chat engine to the streaming format expected by Vercel
+    Converts preprocessed events into Vercel-compatible streaming response format.
     """
 
     TEXT_PREFIX = "0:"
@@ -26,152 +23,79 @@ class VercelStreamResponse(StreamingResponse):
     def __init__(
         self,
-        request: Request,
-        event_handler: EventCallbackHandler,
-        response: Awaitable[StreamingAgentChatResponse],
-        chat_data: ChatData,
-        background_tasks: BackgroundTasks,
+        stream_handler: StreamHandler,
         *args,
         **kwargs,
     ):
-        content = VercelStreamResponse.content_generator(
-            request, event_handler, response, chat_data, background_tasks
-        )
-        super().__init__(content=content)
+        self.handler = stream_handler
+        super().__init__(content=self.content_generator())
 
-    @classmethod
-    async def content_generator(
-        cls,
-        request: Request,
-        event_handler: EventCallbackHandler,
-        response: Awaitable[StreamingAgentChatResponse],
-        chat_data: ChatData,
-        background_tasks: BackgroundTasks,
-    ):
-        chat_response_generator = cls._chat_response_generator(
-            response, background_tasks, event_handler, chat_data
-        )
-        event_generator = cls._event_generator(event_handler)
-
-        # Merge the chat response generator and the event generator
-        combine = stream.merge(chat_response_generator, event_generator)
-        is_stream_started = False
+    async def content_generator(self):
+        """Generate Vercel-formatted content from preprocessed events."""
+        stream_started = False
         try:
-            async with combine.stream() as streamer:
-                async for output in streamer:
-                    if await request.is_disconnected():
-                        break
-
-                    if not is_stream_started:
-                        is_stream_started = True
-                        # Stream a blank message to start displaying the response in the UI
-                        yield cls.convert_text("")
-
-                    yield output
-        except Exception:
-            logger.exception("Error in stream response")
-            yield cls.convert_error(
-                "An unexpected error occurred while processing your request, preventing the creation of a final answer. Please try again."
-            )
-        finally:
-            # Ensure event handler is marked as done even if connection breaks
-            event_handler.is_done = True
-
-    @classmethod
-    async def _event_generator(cls, event_handler: EventCallbackHandler):
-        """
-        Yield the events from the event handler
-        """
-        async for event in event_handler.async_event_gen():
-            event_response = event.to_response()
-            if event_response is not None:
-                yield cls.convert_data(event_response)
-
-    @classmethod
-    async def _chat_response_generator(
-        cls,
-        response: Awaitable[StreamingAgentChatResponse],
-        background_tasks: BackgroundTasks,
-        event_handler: EventCallbackHandler,
-        chat_data: ChatData,
-    ):
+            async for event in self.handler.stream_events():
+                if not stream_started:
+                    # Start the stream with an empty message
+                    stream_started = True
+                    yield self.convert_text("")
+
+                # Handle different types of events
+                if isinstance(event, (AgentStream, StopEvent)):
+                    async for chunk in self._stream_text(event):
+                        await self.handler.accumulate_text(chunk)
+                        yield self.convert_text(chunk)
+                elif isinstance(event, dict):
+                    yield self.convert_data(event)
+                elif hasattr(event, "to_response"):
+                    event_response = event.to_response()
+                    yield self.convert_data(event_response)
+                else:
+                    yield self.convert_data(
+                        {"type": "agent", "data": event.model_dump()}
+                    )
+        except asyncio.CancelledError:
+            logger.warning("Client cancelled the request!")
+            await self.handler.cancel_run()
+        except Exception as e:
+            logger.error(f"Error in stream response: {e}")
+            yield self.convert_error(str(e))
+            await self.handler.cancel_run()
+
+    async def _stream_text(
+        self, event: AgentStream | StopEvent
+    ) -> AsyncGenerator[str, None]:
         """
-        Yield the text response and source nodes from the chat engine
+        Accept stream text from either AgentStream or StopEvent with string or AsyncGenerator result
         """
-        # Wait for the response from the chat engine
-        result = await response
-
-        # Once we got a source node, start a background task to download the files (if needed)
-        cls._process_response_nodes(result.source_nodes, background_tasks)
-
-        # Yield the source nodes
-        yield cls.convert_data(
-            {
-                "type": "sources",
-                "data": {
-                    "nodes": [
-                        SourceNodes.from_source_node(node).model_dump()
-                        for node in result.source_nodes
-                    ]
-                },
-            }
-        )
-
-        final_response = ""
-        async for token in result.async_response_gen():
-            final_response += token
-            yield cls.convert_text(token)
-
-        # Generate next questions if next question prompt is configured
-        question_data = await cls._generate_next_questions(
-            chat_data.messages, final_response
-        )
-        if question_data:
-            yield cls.convert_data(question_data)
-
-        # the text_generator is the leading stream, once it's finished, also finish the event stream
-        event_handler.is_done = True
+        if isinstance(event, AgentStream):
+            yield event.delta
+        elif isinstance(event, StopEvent):
+            if isinstance(event.result, str):
+                yield event.result
+            elif isinstance(event.result, AsyncGenerator):
+                async for chunk in event.result:
+                    if isinstance(chunk, str):
+                        yield chunk
+                    elif hasattr(chunk, "delta"):
+                        yield chunk.delta
 
     @classmethod
-    def convert_text(cls, token: str):
+    def convert_text(cls, token: str) -> str:
+        """Convert text event to Vercel format."""
         # Escape newlines and double quotes to avoid breaking the stream
         token = json.dumps(token)
         return f"{cls.TEXT_PREFIX}{token}\n"
 
     @classmethod
-    def convert_data(cls, data: dict):
+    def convert_data(cls, data: dict) -> str:
+        """Convert data event to Vercel format."""
         data_str = json.dumps(data)
         return f"{cls.DATA_PREFIX}[{data_str}]\n"
 
     @classmethod
-    def convert_error(cls, error: str):
+    def convert_error(cls, error: str) -> str:
+        """Convert error event to Vercel format."""
         error_str = json.dumps(error)
         return f"{cls.ERROR_PREFIX}{error_str}\n"
-
-    @staticmethod
-    def _process_response_nodes(
-        source_nodes: List[NodeWithScore],
-        background_tasks: BackgroundTasks,
-    ):
-        try:
-            # Start background tasks to download documents from LlamaCloud if needed
-            from app.engine.service import LLamaCloudFileService  # type: ignore
-
-            LLamaCloudFileService.download_files_from_nodes(
-                source_nodes, background_tasks
-            )
-        except ImportError:
-            logger.debug(
-                "LlamaCloud is not configured. Skipping post processing of nodes"
-            )
-            pass
-
-    @staticmethod
-    async def _generate_next_questions(chat_history: List[Message], response: str):
-        questions = await NextQuestionSuggestion.suggest_next_questions(
-            chat_history, response
-        )
-        if questions:
-            return {
-                "type": "suggested_questions",
-                "data": questions,
-            }
-        return None
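With this change the class loses its two post-processing helpers: LlamaCloud file downloads move behind the LlamaCloudFileDownload callback and next-question suggestion behind SuggestNextQuestions, both wired up in the router above. The payload shapes fed through convert_data are unchanged from the removed code, for example:

# Payload shapes carried over from the removed helpers (values illustrative).
sources_event = {
    "type": "sources",
    "data": {"nodes": []},  # items: SourceNodes.from_source_node(n).model_dump()
}
questions_event = {
    "type": "suggested_questions",
    "data": ["What is indexed?", "How are sources ranked?"],
}
# Both go out as data frames, e.g. 8:[{"type": "suggested_questions", ...}]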