Unverified commit 9e723c3a authored by Huu Le, committed by GitHub

Standardize the code of workflow use cases (#495)

parent d5da55b9
Showing 304 additions and 177 deletions
---
"create-llama": patch
---
Standardize the code of the workflow use case (Python)
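
This patch standardizes how the use-case workflows are started: run() now takes user_msg, chat_history, and stream instead of input/streaming, and each workflow seeds its own memory inside its start step. A minimal sketch of the new calling convention (not part of the diff), assuming the create_workflow factory from this patch; the params, filters, and message content are placeholders and may differ per use case:

# Minimal sketch of the standardized invocation (illustrative only).
import asyncio

from llama_index.core.base.llms.types import ChatMessage, MessageRole

from app.workflows import create_workflow


async def main():
    workflow = create_workflow(params={})  # placeholder args; some use cases also take filters
    handler = workflow.run(
        user_msg="What changed in the latest report?",
        chat_history=[ChatMessage(role=MessageRole.USER, content="Hello")],
        stream=True,
    )
    async for event in handler.stream_events():
        pass  # AgentRunEvent, AgentStream, source events, ...
    print(await handler)  # the StopEvent result


if __name__ == "__main__":
    asyncio.run(main())
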
@@ -32,7 +32,6 @@ logger.setLevel(logging.INFO)
def create_workflow(
chat_history: Optional[List[ChatMessage]] = None,
params: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Workflow:
@@ -45,7 +44,6 @@ def create_workflow(
return DeepResearchWorkflow(
index=index,
chat_history=chat_history,
timeout=120.0,
)
@@ -73,19 +71,13 @@ class DeepResearchWorkflow(Workflow):
def __init__(
self,
index: BaseIndex,
chat_history: Optional[List[ChatMessage]] = None,
stream: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self.index = index
self.context_nodes = []
self.stream = stream
self.chat_history = chat_history
self.memory = SimpleComposableMemory.from_defaults(
primary_memory=ChatMemoryBuffer.from_defaults(
chat_history=chat_history,
),
primary_memory=ChatMemoryBuffer.from_defaults(),
)
@step
@@ -93,8 +85,15 @@ class DeepResearchWorkflow(Workflow):
"""
Initiate the workflow: memory, tools, agent
"""
self.stream = ev.get("stream", True)
self.user_request = ev.get("user_msg")
chat_history = ev.get("chat_history")
if chat_history is not None:
self.memory.put_messages(chat_history)
await ctx.set("total_questions", 0)
self.user_request = ev.get("input")
# Add user message to memory
self.memory.put_messages(
messages=[
ChatMessage(
@@ -319,7 +318,6 @@ class DeepResearchWorkflow(Workflow):
"""
Report the answers
"""
logger.info("Writing the report")
res = await write_report(
memory=self.memory,
user_request=self.user_request,
......
from typing import Any, Dict, List, Optional
from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from app.engine.tools.query_engine import get_query_engine_tool
from app.workflows.events import AgentRunEvent
from app.workflows.tools import (
call_tools,
chat_with_tools,
)
from llama_index.core import Settings
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.llms.function_calling import FunctionCallingLLM
@@ -22,9 +14,17 @@ from llama_index.core.workflow import (
step,
)
from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from app.engine.tools.query_engine import get_query_engine_tool
from app.workflows.events import AgentRunEvent
from app.workflows.tools import (
call_tools,
chat_with_tools,
)
def create_workflow(
chat_history: Optional[List[ChatMessage]] = None,
params: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Workflow:
@@ -45,7 +45,6 @@ def create_workflow(
query_engine_tool=query_engine_tool,
code_interpreter_tool=code_interpreter_tool,
document_generator_tool=document_generator_tool,
chat_history=chat_history,
)
@@ -91,6 +90,7 @@ class FinancialReportWorkflow(Workflow):
Use the appropriate tools for the user request, and always rely on the information returned by the tools; don't make up anything yourself.
For the query engine tool, you should break down the user request into a list of queries and call the tool with the queries.
"""
stream: bool = True
def __init__(
self,
@@ -99,12 +99,10 @@ class FinancialReportWorkflow(Workflow):
document_generator_tool: FunctionTool,
llm: Optional[FunctionCallingLLM] = None,
timeout: int = 360,
chat_history: Optional[List[ChatMessage]] = None,
system_prompt: Optional[str] = None,
):
super().__init__(timeout=timeout)
self.system_prompt = system_prompt or self._default_system_prompt
self.chat_history = chat_history or []
self.query_engine_tool = query_engine_tool
self.code_interpreter_tool = code_interpreter_tool
self.document_generator_tool = document_generator_tool
@@ -122,13 +120,19 @@ class FinancialReportWorkflow(Workflow):
]
self.llm: FunctionCallingLLM = llm or Settings.llm
assert isinstance(self.llm, FunctionCallingLLM)
self.memory = ChatMemoryBuffer.from_defaults(
llm=self.llm, chat_history=self.chat_history
)
self.memory = ChatMemoryBuffer.from_defaults(llm=self.llm)
@step()
async def prepare_chat_history(self, ctx: Context, ev: StartEvent) -> InputEvent:
ctx.data["input"] = ev.input
self.stream = ev.get("stream", True)
user_msg = ev.get("user_msg")
chat_history = ev.get("chat_history")
if chat_history is not None:
self.memory.put_messages(chat_history)
# Add user message to memory
self.memory.put(ChatMessage(role=MessageRole.USER, content=user_msg))
if self.system_prompt:
system_msg = ChatMessage(
@@ -136,9 +140,6 @@ class FinancialReportWorkflow(Workflow):
)
self.memory.put(system_msg)
# Add user input to memory
self.memory.put(ChatMessage(role=MessageRole.USER, content=ev.input))
return InputEvent(input=self.memory.get())
@step()
@@ -160,8 +161,10 @@ class FinancialReportWorkflow(Workflow):
chat_history,
)
if not response.has_tool_calls():
# If no tool call, return the response generator
return StopEvent(result=response.generator)
if self.stream:
return StopEvent(result=response.generator)
else:
return StopEvent(result=await response.full_response())
# calling different tools at the same time is not supported at the moment
# add an error message to tell the AI to process step by step
if response.is_calling_different_tools():
......
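
The stream flag introduced above changes what the final StopEvent carries: a token generator (response.generator) when streaming, or the fully awaited text (await response.full_response()) otherwise. A hedged sketch of consuming both shapes; the tool instances and the user message are placeholders, and the same pattern applies to FormFillingWorkflow below:

# Illustrative only; run inside an async function. Tool construction is
# elided (see get_query_engine_tool and ToolFactory in the imports above).
workflow = FinancialReportWorkflow(
    query_engine_tool=query_engine_tool,
    code_interpreter_tool=code_interpreter_tool,
    document_generator_tool=document_generator_tool,
)

# stream=False: awaiting the handler returns the complete response text.
text = await workflow.run(user_msg="Compare revenue year over year", stream=False)

# stream=True (the default): the result is an async generator of chunks.
handler = workflow.run(user_msg="Compare revenue year over year", stream=True)
response_gen = await handler
async for chunk in response_gen:
    print(chunk.delta, end="")
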
@@ -25,7 +25,6 @@ from app.workflows.tools import (
def create_workflow(
chat_history: Optional[List[ChatMessage]] = None,
params: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Workflow:
@@ -45,7 +44,6 @@ def create_workflow(
query_engine_tool=query_engine_tool,
extractor_tool=extractor_tool, # type: ignore
filling_tool=filling_tool, # type: ignore
chat_history=chat_history,
)
return workflow
@@ -88,6 +86,7 @@ class FormFillingWorkflow(Workflow):
Only use provided data - never make up any information yourself. Fill N/A if an answer is not found.
If there is no query engine tool or the gathered information has many N/A values indicating the questions don't match the data, respond with a warning and ask the user to upload a different file or connect to a knowledge base.
"""
stream: bool = True
def __init__(
self,
@@ -96,12 +95,10 @@ class FormFillingWorkflow(Workflow):
filling_tool: FunctionTool,
llm: Optional[FunctionCallingLLM] = None,
timeout: int = 360,
chat_history: Optional[List[ChatMessage]] = None,
system_prompt: Optional[str] = None,
):
super().__init__(timeout=timeout)
self.system_prompt = system_prompt or self._default_system_prompt
self.chat_history = chat_history or []
self.query_engine_tool = query_engine_tool
self.extractor_tool = extractor_tool
self.filling_tool = filling_tool
@@ -113,13 +110,18 @@ class FormFillingWorkflow(Workflow):
self.llm: FunctionCallingLLM = llm or Settings.llm
if not isinstance(self.llm, FunctionCallingLLM):
raise ValueError("FormFillingWorkflow only supports FunctionCallingLLM.")
self.memory = ChatMemoryBuffer.from_defaults(
llm=self.llm, chat_history=self.chat_history
)
self.memory = ChatMemoryBuffer.from_defaults(llm=self.llm)
@step()
async def start(self, ctx: Context, ev: StartEvent) -> InputEvent:
ctx.data["input"] = ev.input
self.stream = ev.get("stream", True)
user_msg = ev.get("user_msg", "")
chat_history = ev.get("chat_history", [])
if chat_history:
self.memory.put_messages(chat_history)
self.memory.put(ChatMessage(role=MessageRole.USER, content=user_msg))
if self.system_prompt:
system_msg = ChatMessage(
@@ -127,12 +129,7 @@ class FormFillingWorkflow(Workflow):
)
self.memory.put(system_msg)
user_input = ev.input
user_msg = ChatMessage(role=MessageRole.USER, content=user_input)
self.memory.put(user_msg)
chat_history = self.memory.get()
return InputEvent(input=chat_history)
return InputEvent(input=self.memory.get())
@step()
async def handle_llm_input( # type: ignore
@@ -150,7 +147,10 @@ class FormFillingWorkflow(Workflow):
chat_history,
)
if not response.has_tool_calls():
return StopEvent(result=response.generator)
if self.stream:
return StopEvent(result=response.generator)
else:
return StopEvent(result=await response.full_response())
# calling different tools at the same time is not supported at the moment
# add an error message to tell the AI to process step by step
if response.is_calling_different_tools():
......
import logging
from abc import ABC, abstractmethod
from typing import Any
logger = logging.getLogger("uvicorn")
class EventCallback(ABC):
"""
Base class for event callbacks during event streaming.
"""
async def run(self, event: Any) -> Any:
"""
Called for each event in the stream.
Default behavior: pass through the event unchanged.
"""
return event
async def on_complete(self, final_response: str) -> Any:
"""
Called when the stream is complete.
Default behavior: return None.
"""
return None
@abstractmethod
def from_default(self, *args, **kwargs) -> "EventCallback":
"""
Create a new instance of the processor from default values.
"""
pass
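
For orientation, a hypothetical callback that follows the same contract as the concrete callbacks below: pass events through in run(), optionally emit a final payload in on_complete(), and expose a from_default constructor. The class name and logging behavior are illustrative and not part of the patch:

import logging
from typing import Any

from app.api.callbacks.base import EventCallback

logger = logging.getLogger("uvicorn")


class LogEvents(EventCallback):
    """Hypothetical callback: log every event and report the response length."""

    async def run(self, event: Any) -> Any:
        logger.debug("event: %s", type(event).__name__)
        return event  # pass the event through unchanged

    async def on_complete(self, final_response: str) -> Any:
        if not final_response:
            return None
        return {"type": "response_length", "data": len(final_response)}

    @classmethod
    def from_default(cls) -> "LogEvents":
        return cls()
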
import logging
from typing import Any, List
from fastapi import BackgroundTasks
from llama_index.core.schema import NodeWithScore
from app.api.callbacks.base import EventCallback
logger = logging.getLogger("uvicorn")
class LlamaCloudFileDownload(EventCallback):
"""
Processor for handling LlamaCloud file downloads from source nodes.
Only works if the LlamaCloud service code is available.
"""
def __init__(self, background_tasks: BackgroundTasks):
self.background_tasks = background_tasks
async def run(self, event: Any) -> Any:
if hasattr(event, "to_response"):
event_response = event.to_response()
if event_response.get("type") == "sources" and hasattr(event, "nodes"):
await self._process_response_nodes(event.nodes)
return event
async def _process_response_nodes(self, source_nodes: List[NodeWithScore]):
try:
from app.engine.service import LLamaCloudFileService # type: ignore
LLamaCloudFileService.download_files_from_nodes(
source_nodes, self.background_tasks
)
except ImportError:
pass
@classmethod
def from_default(
cls, background_tasks: BackgroundTasks
) -> "LlamaCloudFileDownload":
return cls(background_tasks=background_tasks)
import logging
from typing import Any
from app.api.callbacks.base import EventCallback
from app.api.routers.models import ChatData
from app.api.services.suggestion import NextQuestionSuggestion
logger = logging.getLogger("uvicorn")
class SuggestNextQuestions(EventCallback):
"""Processor for generating next question suggestions."""
def __init__(self, chat_data: ChatData):
self.chat_data = chat_data
self.accumulated_text = ""
async def on_complete(self, final_response: str) -> Any:
if final_response == "":
return None
questions = await NextQuestionSuggestion.suggest_next_questions(
self.chat_data.messages, final_response
)
if questions:
return {
"type": "suggested_questions",
"data": questions,
}
return None
@classmethod
def from_default(cls, chat_data: ChatData) -> "SuggestNextQuestions":
return cls(chat_data=chat_data)
import logging
from typing import List, Optional
from llama_index.core.workflow.handler import WorkflowHandler
from app.api.callbacks.base import EventCallback
logger = logging.getLogger("uvicorn")
class StreamHandler:
"""
Streams events from a workflow handler through a chain of callbacks.
"""
def __init__(
self,
workflow_handler: WorkflowHandler,
callbacks: Optional[List[EventCallback]] = None,
):
self.workflow_handler = workflow_handler
self.callbacks = callbacks or []
self.accumulated_text = ""
def vercel_stream(self):
"""Create a streaming response with Vercel format."""
from app.api.routers.vercel_response import VercelStreamResponse
return VercelStreamResponse(stream_handler=self)
async def cancel_run(self):
"""Cancel the workflow handler."""
await self.workflow_handler.cancel_run()
async def stream_events(self):
"""Stream events through the processor chain."""
try:
async for event in self.workflow_handler.stream_events():
# Process the event through each processor
for callback in self.callbacks:
event = await callback.run(event)
yield event
# After all events are processed, call on_complete for each callback
for callback in self.callbacks:
result = await callback.on_complete(self.accumulated_text)
if result:
yield result
except Exception as e:
# Make sure to cancel the workflow on error
await self.workflow_handler.cancel_run()
raise e
async def accumulate_text(self, text: str):
"""Accumulate text from the workflow handler."""
self.accumulated_text += text
@classmethod
def from_default(
cls,
handler: WorkflowHandler,
callbacks: Optional[List[EventCallback]] = None,
) -> "StreamHandler":
"""Create a new instance with the given workflow handler and callbacks."""
return cls(workflow_handler=handler, callbacks=callbacks)
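
The chat router below wires this up behind vercel_stream(); for clarity, a hedged sketch of consuming the callback chain directly. The workflow and callback list are placeholders:

# Illustrative only: run a workflow through StreamHandler without the
# Vercel response wrapper. `workflow` is assumed to come from create_workflow.
handler = workflow.run(user_msg="Hello", chat_history=[], stream=True)
stream_handler = StreamHandler.from_default(
    handler=handler,
    callbacks=[],  # e.g. LlamaCloudFileDownload, SuggestNextQuestions
)
async for event in stream_handler.stream_events():
    # Each event has already passed through every callback's run();
    # dicts yielded at the end may come from on_complete().
    print(event)
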
@@ -2,10 +2,12 @@ import logging
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status
from app.api.callbacks.llamacloud import LlamaCloudFileDownload
from app.api.callbacks.next_question import SuggestNextQuestions
from app.api.callbacks.stream_handler import StreamHandler
from app.api.routers.models import (
ChatData,
)
from app.api.routers.vercel_response import VercelStreamResponse
from app.engine.query_filter import generate_filters
from app.workflows import create_workflow
@@ -29,19 +31,22 @@ async def chat(
params = data.data or {}
workflow = create_workflow(
chat_history=messages,
params=params,
filters=filters,
)
event_handler = workflow.run(input=last_message_content, streaming=True)
return VercelStreamResponse(
request=request,
chat_data=data,
background_tasks=background_tasks,
event_handler=event_handler,
events=workflow.stream_events(),
handler = workflow.run(
user_msg=last_message_content,
chat_history=messages,
stream=True,
)
return StreamHandler.from_default(
handler=handler,
callbacks=[
LlamaCloudFileDownload.from_default(background_tasks),
SuggestNextQuestions.from_default(data),
],
).vercel_stream()
except Exception as e:
logger.exception("Error in chat engine", exc_info=True)
raise HTTPException(
......
import asyncio
import json
import logging
from typing import AsyncGenerator, Awaitable, List
from typing import AsyncGenerator
from aiostream import stream
from fastapi import BackgroundTasks, Request
from fastapi.responses import StreamingResponse
from llama_index.core.schema import NodeWithScore
from llama_index.core.agent.workflow.workflow_events import AgentStream
from llama_index.core.workflow import StopEvent
from app.api.routers.models import ChatData, Message
from app.api.services.suggestion import NextQuestionSuggestion
from app.api.callbacks.stream_handler import StreamHandler
logger = logging.getLogger("uvicorn")
class VercelStreamResponse(StreamingResponse):
"""
Base class to convert the response from the chat engine to the streaming format expected by Vercel
Converts preprocessed events into Vercel-compatible streaming response format.
"""
TEXT_PREFIX = "0:"
@@ -25,136 +23,77 @@ class VercelStreamResponse(StreamingResponse):
def __init__(
self,
request: Request,
chat_data: ChatData,
background_tasks: BackgroundTasks,
stream_handler: StreamHandler,
*args,
**kwargs,
):
self.request = request
self.chat_data = chat_data
self.background_tasks = background_tasks
content = self.content_generator(*args, **kwargs)
super().__init__(content=content)
self.handler = stream_handler
super().__init__(content=self.content_generator())
async def content_generator(self, event_handler, events):
stream = self._create_stream(
self.request, self.chat_data, event_handler, events
)
is_stream_started = False
async def content_generator(self):
"""Generate Vercel-formatted content from preprocessed events."""
stream_started = False
try:
async with stream.stream() as streamer:
async for output in streamer:
if not is_stream_started:
is_stream_started = True
# Stream a blank message to start the stream
yield self.convert_text("")
async for event in self.handler.stream_events():
if not stream_started:
# Start the stream with an empty message
stream_started = True
yield self.convert_text("")
# Handle different types of events
if isinstance(event, (AgentStream, StopEvent)):
async for chunk in self._stream_text(event):
await self.handler.accumulate_text(chunk)
yield self.convert_text(chunk)
elif isinstance(event, dict):
yield self.convert_data(event)
elif hasattr(event, "to_response"):
event_response = event.to_response()
yield self.convert_data(event_response)
else:
yield self.convert_data(event.model_dump())
yield output
except asyncio.CancelledError:
logger.warning("Workflow has been cancelled!")
logger.warning("Client cancelled the request!")
await self.handler.cancel_run()
except Exception as e:
logger.error(
f"Unexpected error in content_generator: {str(e)}", exc_info=True
)
yield self.convert_error(
"An unexpected error occurred while processing your request, preventing the creation of a final answer. Please try again."
)
finally:
await event_handler.cancel_run()
logger.info("The stream has been stopped!")
def _create_stream(
self,
request: Request,
chat_data: ChatData,
event_handler: Awaitable,
events: AsyncGenerator,
verbose: bool = True,
):
# Yield the text response
async def _chat_response_generator():
result = await event_handler
final_response = ""
if isinstance(result, AsyncGenerator):
async for token in result:
final_response += str(token.delta)
yield self.convert_text(token.delta)
else:
if hasattr(result, "response"):
content = result.response.message.content
if content:
for token in content:
final_response += str(token)
yield self.convert_text(token)
else:
final_response += str(result)
yield self.convert_text(result)
# Generate next questions if next question prompt is configured
question_data = await self._generate_next_questions(
chat_data.messages, final_response
)
if question_data:
yield self.convert_data(question_data)
# Yield the events from the event handler
async def _event_generator():
async for event in events:
event_response = event.to_response()
if verbose:
logger.debug(event_response)
if event_response is not None:
yield self.convert_data(event_response)
if event_response.get("type") == "sources":
self._process_response_nodes(event.nodes, self.background_tasks)
combine = stream.merge(_chat_response_generator(), _event_generator())
return combine
@staticmethod
def _process_response_nodes(
source_nodes: List[NodeWithScore],
background_tasks: BackgroundTasks,
):
try:
# Start background tasks to download documents from LlamaCloud if needed
from app.engine.service import LLamaCloudFileService # type: ignore
LLamaCloudFileService.download_files_from_nodes(
source_nodes, background_tasks
)
except ImportError:
logger.debug(
"LlamaCloud is not configured. Skipping post processing of nodes"
)
pass
logger.error(f"Error in stream response: {e}")
yield self.convert_error(str(e))
await self.handler.cancel_run()
async def _stream_text(
self, event: AgentStream | StopEvent
) -> AsyncGenerator[str, None]:
"""
Yield text chunks from an AgentStream event, or from a StopEvent whose result is a string or an AsyncGenerator.
"""
if isinstance(event, AgentStream):
yield event.delta  # yield the raw delta; the caller wraps it with convert_text
elif isinstance(event, StopEvent):
if isinstance(event.result, str):
yield event.result
elif isinstance(event.result, AsyncGenerator):
async for chunk in event.result:
if isinstance(chunk, str):
yield chunk
elif hasattr(chunk, "delta"):
yield chunk.delta
@classmethod
def convert_text(cls, token: str):
def convert_text(cls, token: str) -> str:
"""Convert text event to Vercel format."""
# Escape newlines and double quotes to avoid breaking the stream
token = json.dumps(token)
return f"{cls.TEXT_PREFIX}{token}\n"
@classmethod
def convert_data(cls, data: dict):
def convert_data(cls, data: dict) -> str:
"""Convert data event to Vercel format."""
data_str = json.dumps(data)
return f"{cls.DATA_PREFIX}[{data_str}]\n"
@classmethod
def convert_error(cls, error: str):
def convert_error(cls, error: str) -> str:
"""Convert error event to Vercel format."""
error_str = json.dumps(error)
return f"{cls.ERROR_PREFIX}{error_str}\n"
@staticmethod
async def _generate_next_questions(chat_history: List[Message], response: str):
questions = await NextQuestionSuggestion.suggest_next_questions(
chat_history, response
)
if questions:
return {
"type": "suggested_questions",
"data": questions,
}
return None
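
For reference, the wire format the converters above emit is one prefixed, JSON-encoded line per event; TEXT_PREFIX is "0:" as defined above, while the concrete DATA_PREFIX/ERROR_PREFIX values fall outside this hunk, so they are referenced symbolically here:

# Illustrative only: what the converters produce on the wire.
line = VercelStreamResponse.convert_text('Hello "world"\n')
# line == '0:"Hello \\"world\\"\\n"\n'  (the token is JSON-escaped after the "0:" prefix)

line = VercelStreamResponse.convert_data({"type": "sources", "data": {"nodes": []}})
# line == DATA_PREFIX + '[{"type": "sources", "data": {"nodes": []}}]' + '\n'
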
@@ -3,7 +3,6 @@ import uuid
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Callable, Optional
from app.workflows.events import AgentRunEvent, AgentRunEventType
from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.tools import (
@@ -15,6 +14,8 @@ from llama_index.core.tools import (
from llama_index.core.workflow import Context
from pydantic import BaseModel, ConfigDict
from app.workflows.events import AgentRunEvent, AgentRunEventType
logger = logging.getLogger("uvicorn")
@@ -51,7 +52,9 @@ class ChatWithToolsResponse(BaseModel):
assert self.generator is not None
full_response = ""
async for chunk in self.generator:
full_response += chunk.message.content
content = chunk.message.content
if content:
full_response += content
return full_response
......
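
The None-guard above matters because some providers emit chunks whose message.content is None. A hedged sketch of how the report workflows in this patch typically branch on the response object; the exact chat_with_tools arguments and the self.tools attribute are assumptions, not shown in these hunks:

# Illustrative only: handling ChatWithToolsResponse inside a workflow step.
response = await chat_with_tools(self.llm, self.tools, chat_history)

if not response.has_tool_calls():
    # No tool call: finish with the generator or the fully awaited text.
    if self.stream:
        return StopEvent(result=response.generator)
    return StopEvent(result=await response.full_response())

if response.is_calling_different_tools():
    # Calling different tools at once is not supported; nudge the LLM to
    # proceed step by step (see the workflow steps above).
    ...
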