Commit b4f07672 authored by leehuwuj

stg

parent 9e723c3a
Showing 92 additions and 259 deletions
app/engine/engine.py:

 import os
 from typing import List

-from llama_index.core.agent import AgentRunner
-from llama_index.core.callbacks import CallbackManager
+from llama_index.core.agent.workflow import AgentWorkflow
+# from llama_index.core.agent import AgentRunner
 from llama_index.core.settings import Settings
 from llama_index.core.tools import BaseTool
@@ -11,13 +12,14 @@
 from app.engine.tools import ToolFactory
 from app.engine.tools.query_engine import get_query_engine_tool

-def get_chat_engine(params=None, event_handlers=None, **kwargs):
+def get_engine(params=None, **kwargs):
+    if params is None:
+        params = {}
     system_prompt = os.getenv("SYSTEM_PROMPT")
     tools: List[BaseTool] = []
-    callback_manager = CallbackManager(handlers=event_handlers or [])

     # Add query tool if index exists
-    index_config = IndexConfig(callback_manager=callback_manager, **(params or {}))
+    index_config = IndexConfig(**params)
     index = get_index(index_config)
     if index is not None:
         query_engine_tool = get_query_engine_tool(index, **kwargs)
@@ -27,10 +29,8 @@ def get_chat_engine(params=None, event_handlers=None, **kwargs):
     configured_tools: List[BaseTool] = ToolFactory.from_env()
     tools.extend(configured_tools)

-    return AgentRunner.from_llm(
+    return AgentWorkflow.from_tools_or_functions(
+        tools_or_functions=tools,  # type: ignore
         llm=Settings.llm,
-        tools=tools,
         system_prompt=system_prompt,
-        callback_manager=callback_manager,
-        verbose=True,
     )
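Note (not part of the commit): a minimal sketch of driving the new get_engine() factory directly, assuming Settings.llm is configured and the template's app package is importable. The prompt string is invented; get_engine, AgentStream, and the handler's event stream are taken from the code in this commit.

import asyncio

from llama_index.core.agent.workflow.workflow_events import AgentStream

from app.engine.engine import get_engine


async def main():
    engine = get_engine()  # AgentWorkflow built from the configured tools
    handler = engine.run(user_msg="What is in the index?")
    # The workflow emits AgentStream deltas while the run is in progress
    async for event in handler.stream_events():
        if isinstance(event, AgentStream):
            print(event.delta, end="", flush=True)
    await handler  # wait for the run to finish (StopEvent result)


asyncio.run(main())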
New file: VercelStreamResponse streaming adapter (shown in full):

import asyncio
import json
import logging
from typing import AsyncGenerator

from fastapi.responses import StreamingResponse
from llama_index.core.agent.workflow.workflow_events import AgentStream
from llama_index.core.workflow import StopEvent

from app.api.callbacks.stream_handler import StreamHandler

logger = logging.getLogger("uvicorn")


class VercelStreamResponse(StreamingResponse):
    """
    Converts preprocessed events into Vercel-compatible streaming response format.
    """

    TEXT_PREFIX = "0:"
    DATA_PREFIX = "8:"
    ERROR_PREFIX = "3:"

    def __init__(
        self,
        stream_handler: StreamHandler,
        *args,
        **kwargs,
    ):
        self.handler = stream_handler
        super().__init__(content=self.content_generator())

    async def content_generator(self):
        """Generate Vercel-formatted content from preprocessed events."""
        stream_started = False
        try:
            async for event in self.handler.stream_events():
                if not stream_started:
                    # Start the stream with an empty message
                    stream_started = True
                    yield self.convert_text("")

                # Handle different types of events
                if isinstance(event, (AgentStream, StopEvent)):
                    async for chunk in self._stream_text(event):
                        await self.handler.accumulate_text(chunk)
                        yield self.convert_text(chunk)
                elif isinstance(event, dict):
                    yield self.convert_data(event)
                elif hasattr(event, "to_response"):
                    event_response = event.to_response()
                    yield self.convert_data(event_response)
                else:
                    yield self.convert_data(event.model_dump())
        except asyncio.CancelledError:
            logger.warning("Client cancelled the request!")
            await self.handler.cancel_run()
        except Exception as e:
            logger.error(f"Error in stream response: {e}")
            yield self.convert_error(str(e))
            await self.handler.cancel_run()

    async def _stream_text(
        self, event: AgentStream | StopEvent
    ) -> AsyncGenerator[str, None]:
        """
        Accept stream text from either AgentStream or StopEvent with string or AsyncGenerator result
        """
        if isinstance(event, AgentStream):
            # Yield the raw delta; content_generator applies convert_text once
            yield event.delta
        elif isinstance(event, StopEvent):
            if isinstance(event.result, str):
                yield event.result
            elif isinstance(event.result, AsyncGenerator):
                async for chunk in event.result:
                    if isinstance(chunk, str):
                        yield chunk
                    elif hasattr(chunk, "delta"):
                        yield chunk.delta

    @classmethod
    def convert_text(cls, token: str) -> str:
        """Convert text event to Vercel format."""
        # Escape newlines and double quotes to avoid breaking the stream
        token = json.dumps(token)
        return f"{cls.TEXT_PREFIX}{token}\n"

    @classmethod
    def convert_data(cls, data: dict) -> str:
        """Convert data event to Vercel format."""
        data_str = json.dumps(data)
        return f"{cls.DATA_PREFIX}[{data_str}]\n"

    @classmethod
    def convert_error(cls, error: str) -> str:
        """Convert error event to Vercel format."""
        error_str = json.dumps(error)
        return f"{cls.ERROR_PREFIX}{error_str}\n"
Chat router:

@@ -3,15 +3,16 @@ import logging
 from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status
 from llama_index.core.llms import MessageRole

-from app.api.routers.events import EventCallbackHandler
+from app.api.callbacks.llamacloud import LlamaCloudFileDownload
+from app.api.callbacks.next_question import SuggestNextQuestions
+from app.api.callbacks.stream_handler import StreamHandler
 from app.api.routers.models import (
     ChatData,
     Message,
     Result,
     SourceNodes,
 )
-from app.api.routers.vercel_response import VercelStreamResponse
-from app.engine.engine import get_chat_engine
+from app.engine.engine import get_engine
 from app.engine.query_filter import generate_filters

 chat_router = r = APIRouter()
@@ -36,15 +37,19 @@ async def chat(
         logger.info(
             f"Creating chat engine with filters: {str(filters)}",
         )
-        event_handler = EventCallbackHandler()
-        chat_engine = get_chat_engine(
-            filters=filters, params=params, event_handlers=[event_handler]
-        )
-        response = chat_engine.astream_chat(last_message_content, messages)
-        return VercelStreamResponse(
-            request, event_handler, response, data, background_tasks
+        engine = get_engine(filters=filters, params=params)
+        handler = engine.run(
+            user_msg=last_message_content,
+            chat_history=messages,
+            stream=True,
         )
+        return StreamHandler.from_default(
+            handler=handler,
+            callbacks=[
+                LlamaCloudFileDownload.from_default(background_tasks),
+                SuggestNextQuestions.from_default(data),
+            ],
+        ).vercel_stream()
     except Exception as e:
         logger.exception("Error in chat engine", exc_info=True)
         raise HTTPException(
@@ -53,6 +58,7 @@ async def chat(
         ) from e


+# TODO: Update non-streaming endpoint
 # non-streaming endpoint - delete if not needed
 @r.post("/request")
 async def chat_request(
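A hypothetical client call against the streaming endpoint, assuming the router is mounted at /api/chat and the request body follows the template's ChatData shape; both the path and the payload are assumptions, not shown in this diff.

import asyncio

import httpx


async def main():
    payload = {"messages": [{"role": "user", "content": "Summarize the indexed documents"}]}
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        async with client.stream("POST", "/api/chat", json=payload) as resp:
            async for line in resp.aiter_lines():
                # Lines follow the Vercel protocol: 0:"..." text deltas,
                # 8:[...] data events, 3:"..." errors.
                print(line)


asyncio.run(main())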
app/api/routers/events.py:

@@ -99,6 +99,8 @@ class CallbackEvent(BaseModel):
         return None


+# TODO: Add an adapter for workflow events
+# and remove callback handler
 class EventCallbackHandler(BaseCallbackHandler):
     _aqueue: asyncio.Queue
     is_done: bool = False
app/api/routers/vercel_response.py:

+import asyncio
 import json
 import logging
-from typing import Awaitable, List
+from typing import AsyncGenerator

-from aiostream import stream
-from fastapi import BackgroundTasks, Request
 from fastapi.responses import StreamingResponse
-from llama_index.core.chat_engine.types import StreamingAgentChatResponse
-from llama_index.core.schema import NodeWithScore
+from llama_index.core.agent.workflow.workflow_events import AgentStream
+from llama_index.core.workflow import StopEvent

-from app.api.routers.events import EventCallbackHandler
-from app.api.routers.models import ChatData, Message, SourceNodes
-from app.api.services.suggestion import NextQuestionSuggestion
+from app.api.callbacks.stream_handler import StreamHandler

 logger = logging.getLogger("uvicorn")


 class VercelStreamResponse(StreamingResponse):
     """
-    Class to convert the response from the chat engine to the streaming format expected by Vercel
+    Converts preprocessed events into Vercel-compatible streaming response format.
     """

     TEXT_PREFIX = "0:"
@@ -26,152 +23,79 @@ class VercelStreamResponse(StreamingResponse):
     def __init__(
         self,
-        request: Request,
-        event_handler: EventCallbackHandler,
-        response: Awaitable[StreamingAgentChatResponse],
-        chat_data: ChatData,
-        background_tasks: BackgroundTasks,
+        stream_handler: StreamHandler,
+        *args,
+        **kwargs,
     ):
-        content = VercelStreamResponse.content_generator(
-            request, event_handler, response, chat_data, background_tasks
-        )
-        super().__init__(content=content)
+        self.handler = stream_handler
+        super().__init__(content=self.content_generator())

-    @classmethod
-    async def content_generator(
-        cls,
-        request: Request,
-        event_handler: EventCallbackHandler,
-        response: Awaitable[StreamingAgentChatResponse],
-        chat_data: ChatData,
-        background_tasks: BackgroundTasks,
-    ):
-        chat_response_generator = cls._chat_response_generator(
-            response, background_tasks, event_handler, chat_data
-        )
-        event_generator = cls._event_generator(event_handler)
-
-        # Merge the chat response generator and the event generator
-        combine = stream.merge(chat_response_generator, event_generator)
-        is_stream_started = False
+    async def content_generator(self):
+        """Generate Vercel-formatted content from preprocessed events."""
+        stream_started = False
         try:
-            async with combine.stream() as streamer:
-                async for output in streamer:
-                    if await request.is_disconnected():
-                        break
-
-                    if not is_stream_started:
-                        is_stream_started = True
-                        # Stream a blank message to start displaying the response in the UI
-                        yield cls.convert_text("")
-
-                    yield output
-        except Exception:
-            logger.exception("Error in stream response")
-            yield cls.convert_error(
-                "An unexpected error occurred while processing your request, preventing the creation of a final answer. Please try again."
-            )
-        finally:
-            # Ensure event handler is marked as done even if connection breaks
-            event_handler.is_done = True
+            async for event in self.handler.stream_events():
+                if not stream_started:
+                    # Start the stream with an empty message
+                    stream_started = True
+                    yield self.convert_text("")
+
+                # Handle different types of events
+                if isinstance(event, (AgentStream, StopEvent)):
+                    async for chunk in self._stream_text(event):
+                        await self.handler.accumulate_text(chunk)
+                        yield self.convert_text(chunk)
+                elif isinstance(event, dict):
+                    yield self.convert_data(event)
+                elif hasattr(event, "to_response"):
+                    event_response = event.to_response()
+                    yield self.convert_data(event_response)
+                else:
+                    yield self.convert_data(
+                        {"type": "agent", "data": event.model_dump()}
+                    )
+        except asyncio.CancelledError:
+            logger.warning("Client cancelled the request!")
+            await self.handler.cancel_run()
+        except Exception as e:
+            logger.error(f"Error in stream response: {e}")
+            yield self.convert_error(str(e))
+            await self.handler.cancel_run()

-    @classmethod
-    async def _event_generator(cls, event_handler: EventCallbackHandler):
-        """
-        Yield the events from the event handler
-        """
-        async for event in event_handler.async_event_gen():
-            event_response = event.to_response()
-            if event_response is not None:
-                yield cls.convert_data(event_response)
-
-    @classmethod
-    async def _chat_response_generator(
-        cls,
-        response: Awaitable[StreamingAgentChatResponse],
-        background_tasks: BackgroundTasks,
-        event_handler: EventCallbackHandler,
-        chat_data: ChatData,
-    ):
+    async def _stream_text(
+        self, event: AgentStream | StopEvent
+    ) -> AsyncGenerator[str, None]:
         """
-        Yield the text response and source nodes from the chat engine
+        Accept stream text from either AgentStream or StopEvent with string or AsyncGenerator result
         """
-        # Wait for the response from the chat engine
-        result = await response
-
-        # Once we got a source node, start a background task to download the files (if needed)
-        cls._process_response_nodes(result.source_nodes, background_tasks)
-
-        # Yield the source nodes
-        yield cls.convert_data(
-            {
-                "type": "sources",
-                "data": {
-                    "nodes": [
-                        SourceNodes.from_source_node(node).model_dump()
-                        for node in result.source_nodes
-                    ]
-                },
-            }
-        )
-
-        final_response = ""
-        async for token in result.async_response_gen():
-            final_response += token
-            yield cls.convert_text(token)
-
-        # Generate next questions if next question prompt is configured
-        question_data = await cls._generate_next_questions(
-            chat_data.messages, final_response
-        )
-        if question_data:
-            yield cls.convert_data(question_data)
-
-        # the text_generator is the leading stream, once it's finished, also finish the event stream
-        event_handler.is_done = True
+        if isinstance(event, AgentStream):
+            yield event.delta
+        elif isinstance(event, StopEvent):
+            if isinstance(event.result, str):
+                yield event.result
+            elif isinstance(event.result, AsyncGenerator):
+                async for chunk in event.result:
+                    if isinstance(chunk, str):
+                        yield chunk
+                    elif hasattr(chunk, "delta"):
+                        yield chunk.delta

     @classmethod
-    def convert_text(cls, token: str):
+    def convert_text(cls, token: str) -> str:
+        """Convert text event to Vercel format."""
         # Escape newlines and double quotes to avoid breaking the stream
         token = json.dumps(token)
         return f"{cls.TEXT_PREFIX}{token}\n"

     @classmethod
-    def convert_data(cls, data: dict):
+    def convert_data(cls, data: dict) -> str:
+        """Convert data event to Vercel format."""
         data_str = json.dumps(data)
         return f"{cls.DATA_PREFIX}[{data_str}]\n"

     @classmethod
-    def convert_error(cls, error: str):
+    def convert_error(cls, error: str) -> str:
+        """Convert error event to Vercel format."""
         error_str = json.dumps(error)
         return f"{cls.ERROR_PREFIX}{error_str}\n"
-
-    @staticmethod
-    def _process_response_nodes(
-        source_nodes: List[NodeWithScore],
-        background_tasks: BackgroundTasks,
-    ):
-        try:
-            # Start background tasks to download documents from LlamaCloud if needed
-            from app.engine.service import LLamaCloudFileService  # type: ignore
-
-            LLamaCloudFileService.download_files_from_nodes(
-                source_nodes, background_tasks
-            )
-        except ImportError:
-            logger.debug(
-                "LlamaCloud is not configured. Skipping post processing of nodes"
-            )
-            pass
-
-    @staticmethod
-    async def _generate_next_questions(chat_history: List[Message], response: str):
-        questions = await NextQuestionSuggestion.suggest_next_questions(
-            chat_history, response
-        )
-        if questions:
-            return {
-                "type": "suggested_questions",
-                "data": questions,
-            }
-        return None
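app/api/callbacks/stream_handler.py is referenced throughout this commit but not shown. The sketch below is only an interface guess inferred from the calls above (from_default, stream_events, accumulate_text, cancel_run, vercel_stream); the real implementation may differ.

from typing import Any, AsyncGenerator, List, Optional

from app.api.routers.vercel_response import VercelStreamResponse


class StreamHandler:
    """Hypothetical wrapper around the AgentWorkflow run handler."""

    def __init__(self, handler: Any, callbacks: Optional[List[Any]] = None):
        self.handler = handler  # WorkflowHandler returned by engine.run(...)
        self.callbacks = callbacks or []  # e.g. LlamaCloudFileDownload, SuggestNextQuestions
        self.accumulated_text = ""

    @classmethod
    def from_default(cls, handler: Any, callbacks: Optional[List[Any]] = None) -> "StreamHandler":
        return cls(handler=handler, callbacks=callbacks)

    def vercel_stream(self) -> VercelStreamResponse:
        # Hand the wrapped event stream to the Vercel-format response defined above
        return VercelStreamResponse(stream_handler=self)

    async def stream_events(self) -> AsyncGenerator[Any, None]:
        # Forward workflow events; callbacks would hook in here in the real code
        async for event in self.handler.stream_events():
            yield event

    async def accumulate_text(self, text: str) -> None:
        # Collect the streamed answer text, e.g. for SuggestNextQuestions
        self.accumulated_text += text

    async def cancel_run(self) -> None:
        await self.handler.cancel_run()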