Unverified commit 6edea6af authored by Huu Le, committed by GitHub

enhance workflow code for Python (#412)


* enhance workflow shared code

* fix streaming

* refactor code

* add missing helper

* update

* update form filling

* add filters

* simplify the code

* simplify the code

* simplify the code

* update form filling

* update e2e

* update function calling agent

* fix unneeded condition

* Create light-parrots-work.md

* revert change on using functioncallingagent

* update readme

* clean code

* extract call one tool function

* update for blog use case

* fix streaming

* fix e2e

* fix missing await

* improve tools code

* improve assertion code

* skip form filling test for TS framework

* update for tools helper

---------

Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de>
parent d79d1652
@@ -4,7 +4,8 @@ from app.api.routers.models import (
     ChatData,
 )
 from app.api.routers.vercel_response import VercelStreamResponse
-from app.engine.engine import get_chat_engine
+from app.engine.query_filter import generate_filters
+from app.workflows import create_workflow
 from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status
 
 chat_router = r = APIRouter()
@@ -22,19 +23,20 @@ async def chat(
         last_message_content = data.get_last_message_content()
         messages = data.get_history_messages(include_agent_messages=True)
 
         # The chat API supports passing private document filters and chat params
-        # but agent workflow does not support them yet
-        # ignore chat params and use all documents for now
-        # TODO: generate filters based on doc_ids
+        doc_ids = data.get_chat_document_ids()
+        filters = generate_filters(doc_ids)
         params = data.data or {}
-        engine = get_chat_engine(chat_history=messages, params=params)
-        event_handler = engine.run(input=last_message_content, streaming=True)
+        workflow = create_workflow(
+            chat_history=messages, params=params, filters=filters
+        )
+        event_handler = workflow.run(input=last_message_content, streaming=True)
 
         return VercelStreamResponse(
             request=request,
             chat_data=data,
             event_handler=event_handler,
-            events=engine.stream_events(),
+            events=workflow.stream_events(),
         )
     except Exception as e:
         logger.exception("Error in chat engine", exc_info=True)
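
The hunks above swap the old get_chat_engine entry point for create_workflow and derive document filters from the selected document IDs. As a rough sketch of how these pieces compose outside FastAPI (hedged: run_chat and its arguments are hypothetical, and awaiting the handler returned by workflow.run() is an assumption about the llama_index workflow API used here):

# Hypothetical driver, not part of the commit: wire filters into the workflow
# and stream its events the way VercelStreamResponse does.
from app.engine.query_filter import generate_filters
from app.workflows import create_workflow
from app.workflows.events import AgentRunEvent


async def run_chat(doc_ids, messages, params, user_input: str):
    filters = generate_filters(doc_ids)  # restrict retrieval to the selected documents
    workflow = create_workflow(chat_history=messages, params=params, filters=filters)
    handler = workflow.run(input=user_input, streaming=True)

    # Intermediate AgentRunEvents are streamed while the workflow runs.
    async for event in workflow.stream_events():
        if isinstance(event, AgentRunEvent):
            print(event.to_response())

    return await handler  # assumption: the run handler resolves to the final result

The rest of the diff adds the shared modules this router depends on: the agent event definitions, the FunctionCallingAgent workflow, and the tool-calling helpers.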
from enum import Enum
from typing import Optional

from llama_index.core.workflow import Event


class AgentRunEventType(Enum):
    TEXT = "text"
    PROGRESS = "progress"


class AgentRunEvent(Event):
    name: str
    msg: str
    event_type: AgentRunEventType = AgentRunEventType.TEXT
    data: Optional[dict] = None

    def to_response(self) -> dict:
        return {
            "type": "agent",
            "data": {
                "agent": self.name,
                "type": self.event_type.value,
                "text": self.msg,
                "data": self.data,
            },
        }
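
The events module above (app/workflows/events, per the imports later in this commit) defines the payload shape that the streaming response forwards to the client. A small hedged example, with illustrative values only, of what a progress event serializes to:

# Illustrative only: construct a progress event and inspect its wire format.
from app.workflows.events import AgentRunEvent, AgentRunEventType

event = AgentRunEvent(
    name="researcher",
    msg="Processed 1 of 3 documents",
    event_type=AgentRunEventType.PROGRESS,
    data={"id": "d6c0bbda", "total": 3, "current": 1},
)
assert event.to_response() == {
    "type": "agent",
    "data": {
        "agent": "researcher",
        "type": "progress",
        "text": "Processed 1 of 3 documents",
        "data": {"id": "d6c0bbda", "total": 3, "current": 1},
    },
}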
from typing import Any, List, Optional

from app.workflows.events import AgentRunEvent
from app.workflows.tools import ChatWithToolsResponse, call_tools, chat_with_tools
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.settings import Settings
from llama_index.core.tools.types import BaseTool
from llama_index.core.workflow import (
    Context,
    Event,
    StartEvent,
    StopEvent,
    Workflow,
    step,
)


class InputEvent(Event):
    input: list[ChatMessage]


class ToolCallEvent(Event):
    input: ChatWithToolsResponse


class FunctionCallingAgent(Workflow):
    """
    A simple workflow that requests the LLM with tools independently.
    You can share the previous chat history to provide context for the LLM.
    """

    def __init__(
        self,
        *args: Any,
        llm: FunctionCallingLLM | None = None,
        chat_history: Optional[List[ChatMessage]] = None,
        tools: List[BaseTool] | None = None,
        system_prompt: str | None = None,
        verbose: bool = False,
        timeout: float = 360.0,
        name: str,
        write_events: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, verbose=verbose, timeout=timeout, **kwargs)  # type: ignore
        self.tools = tools or []
        self.name = name
        self.write_events = write_events

        if llm is None:
            llm = Settings.llm
        self.llm = llm
        if not self.llm.metadata.is_function_calling_model:
            raise ValueError("The provided LLM must support function calling.")

        self.system_prompt = system_prompt
        self.memory = ChatMemoryBuffer.from_defaults(
            llm=self.llm, chat_history=chat_history
        )
        self.sources = []  # type: ignore

    @step()
    async def prepare_chat_history(self, ctx: Context, ev: StartEvent) -> InputEvent:
        # clear sources
        self.sources = []

        # set streaming
        ctx.data["streaming"] = getattr(ev, "streaming", False)

        # set system prompt
        if self.system_prompt is not None:
            system_msg = ChatMessage(role="system", content=self.system_prompt)
            self.memory.put(system_msg)

        # get user input
        user_input = ev.input
        user_msg = ChatMessage(role="user", content=user_input)
        self.memory.put(user_msg)

        if self.write_events:
            ctx.write_event_to_stream(
                AgentRunEvent(name=self.name, msg=f"Start to work on: {user_input}")
            )

        return InputEvent(input=self.memory.get())

    @step()
    async def handle_llm_input(
        self,
        ctx: Context,
        ev: InputEvent,
    ) -> ToolCallEvent | StopEvent:
        chat_history = ev.input

        response = await chat_with_tools(
            self.llm,
            self.tools,
            chat_history,
        )
        is_tool_call = response.has_tool_calls()
        if not is_tool_call:
            if ctx.data["streaming"]:
                return StopEvent(result=response)
            else:
                full_response = ""
                async for chunk in response.generator:
                    full_response += chunk.message.content
                return StopEvent(result=full_response)
        return ToolCallEvent(input=response)

    @step()
    async def handle_tool_calls(self, ctx: Context, ev: ToolCallEvent) -> InputEvent:
        tool_calls = ev.input.tool_calls
        tool_call_message = ev.input.tool_call_message

        self.memory.put(tool_call_message)

        tool_messages = await call_tools(ctx, self.name, self.tools, tool_calls)
        self.memory.put_messages(tool_messages)
        return InputEvent(input=self.memory.get())
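
The FunctionCallingAgent above loops between handle_llm_input and handle_tool_calls until the LLM answers without requesting a tool. A hedged sketch of running it standalone (the tool, prompt, and question are made up; it assumes the class is importable and that Settings.llm supports function calling):

# Hypothetical standalone usage of the FunctionCallingAgent defined above.
import asyncio

from llama_index.core.tools import FunctionTool


def multiply(a: float, b: float) -> float:
    """Multiply two numbers."""
    return a * b


async def main() -> None:
    agent = FunctionCallingAgent(
        name="calculator",
        tools=[FunctionTool.from_defaults(multiply)],
        system_prompt="You are a calculator. Use the tools to answer.",
        timeout=60.0,
    )
    # With streaming=False, handle_llm_input drains the generator and returns a string.
    result = await agent.run(input="What is 12.3 * 4?", streaming=False)
    print(result)


asyncio.run(main())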
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Callable, Optional

from app.workflows.events import AgentRunEvent, AgentRunEventType
from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.tools import (
    BaseTool,
    FunctionTool,
    ToolOutput,
    ToolSelection,
)
from llama_index.core.workflow import Context
from pydantic import BaseModel, ConfigDict

logger = logging.getLogger("uvicorn")


class ContextAwareTool(FunctionTool, ABC):
    @abstractmethod
    async def acall(self, ctx: Context, input: Any) -> ToolOutput:  # type: ignore
        pass


class ResponseGenerator(BaseModel):
    """
    A response generator from chat_with_tools.
    """

    generator: AsyncGenerator[ChatResponse | None, None]

    model_config = ConfigDict(arbitrary_types_allowed=True)


class ChatWithToolsResponse(BaseModel):
    """
    A tool call response from chat_with_tools.
    """

    tool_calls: Optional[list[ToolSelection]]
    tool_call_message: Optional[ChatMessage]
    generator: Optional[AsyncGenerator[ChatResponse | None, None]]

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def is_calling_different_tools(self) -> bool:
        tool_names = {tool_call.tool_name for tool_call in self.tool_calls}
        return len(tool_names) > 1

    def has_tool_calls(self) -> bool:
        return self.tool_calls is not None and len(self.tool_calls) > 0

    def tool_name(self) -> str:
        assert self.has_tool_calls()
        assert not self.is_calling_different_tools()
        return self.tool_calls[0].tool_name

    async def full_response(self) -> str:
        assert self.generator is not None
        full_response = ""
        async for chunk in self.generator:
            full_response += chunk.message.content
        return full_response


async def chat_with_tools(  # type: ignore
    llm: FunctionCallingLLM,
    tools: list[BaseTool],
    chat_history: list[ChatMessage],
) -> ChatWithToolsResponse:
    """
    Request the LLM to decide whether to call tools.
    This function doesn't change the memory.
    """
    generator = _tool_call_generator(llm, tools, chat_history)
    is_tool_call = await generator.__anext__()
    if is_tool_call:
        # Wait for the last chunk, which carries the full response
        full_response = None
        async for chunk in generator:
            full_response = chunk
        assert isinstance(full_response, ChatResponse)
        return ChatWithToolsResponse(
            tool_calls=llm.get_tool_calls_from_response(full_response),
            tool_call_message=full_response.message,
            generator=None,
        )
    else:
        return ChatWithToolsResponse(
            tool_calls=None,
            tool_call_message=None,
            generator=generator,
        )


async def call_tools(
    ctx: Context,
    agent_name: str,
    tools: list[BaseTool],
    tool_calls: list[ToolSelection],
    emit_agent_events: bool = True,
) -> list[ChatMessage]:
    if len(tool_calls) == 0:
        return []

    tools_by_name = {tool.metadata.get_name(): tool for tool in tools}
    if len(tool_calls) == 1:
        return [
            await call_tool(
                ctx,
                tools_by_name[tool_calls[0].tool_name],
                tool_calls[0],
                lambda msg: ctx.write_event_to_stream(
                    AgentRunEvent(
                        name=agent_name,
                        msg=msg,
                    )
                ),
            )
        ]
    # Multiple tool calls, show progress
    tool_msgs: list[ChatMessage] = []

    progress_id = str(uuid.uuid4())
    total_steps = len(tool_calls)
    if emit_agent_events:
        ctx.write_event_to_stream(
            AgentRunEvent(
                name=agent_name,
                msg=f"Making {total_steps} tool calls",
            )
        )
    for i, tool_call in enumerate(tool_calls):
        tool = tools_by_name.get(tool_call.tool_name)
        if not tool:
            tool_msgs.append(
                ChatMessage(
                    role=MessageRole.ASSISTANT,
                    content=f"Tool {tool_call.tool_name} does not exist",
                )
            )
            continue
        tool_msg = await call_tool(
            ctx,
            tool,
            tool_call,
            event_emitter=lambda msg: ctx.write_event_to_stream(
                AgentRunEvent(
                    name=agent_name,
                    msg=msg,
                    event_type=AgentRunEventType.PROGRESS,
                    data={
                        "id": progress_id,
                        "total": total_steps,
                        "current": i,
                    },
                )
            ),
        )
        tool_msgs.append(tool_msg)
    return tool_msgs


async def call_tool(
    ctx: Context,
    tool: BaseTool,
    tool_call: ToolSelection,
    event_emitter: Optional[Callable[[str], None]],
) -> ChatMessage:
    if event_emitter:
        event_emitter(
            f"Calling tool {tool_call.tool_name}, {str(tool_call.tool_kwargs)}"
        )
    try:
        if isinstance(tool, ContextAwareTool):
            if ctx is None:
                raise ValueError("Context is required for context aware tool")
            # inject the context when calling a context-aware tool
            response = await tool.acall(ctx=ctx, **tool_call.tool_kwargs)
        else:
            response = await tool.acall(**tool_call.tool_kwargs)  # type: ignore
        return ChatMessage(
            role=MessageRole.TOOL,
            content=str(response.raw_output),
            additional_kwargs={
                "tool_call_id": tool_call.tool_id,
                "name": tool.metadata.get_name(),
            },
        )
    except Exception as e:
        logger.error(f"Got error in tool {tool_call.tool_name}: {str(e)}")
        if event_emitter:
            event_emitter(f"Got error in tool {tool_call.tool_name}: {str(e)}")
        return ChatMessage(
            role=MessageRole.TOOL,
            content=f"Error: {str(e)}",
            additional_kwargs={
                "tool_call_id": tool_call.tool_id,
                "name": tool.metadata.get_name(),
            },
        )


async def _tool_call_generator(
    llm: FunctionCallingLLM,
    tools: list[BaseTool],
    chat_history: list[ChatMessage],
) -> AsyncGenerator[ChatResponse | bool, None]:
    response_stream = await llm.astream_chat_with_tools(
        tools,
        chat_history=chat_history,
        allow_parallel_tool_calls=False,
    )

    full_response = None
    yielded_indicator = False
    async for chunk in response_stream:
        if "tool_calls" not in chunk.message.additional_kwargs:
            # Yield a boolean to indicate whether the response is a tool call
            if not yielded_indicator:
                yield False
                yielded_indicator = True

            # if not a tool call, yield the chunks!
            yield chunk  # type: ignore
        elif not yielded_indicator:
            # Yield the indicator for a tool call
            yield True
            yielded_indicator = True

        full_response = chunk

    if full_response:
        yield full_response  # type: ignore
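
Taken together, chat_with_tools decides whether the model wants to use tools and call_tools executes them while emitting agent events. A hedged sketch of one round trip using only the helpers above (one_round is hypothetical; ctx is assumed to come from a running workflow step and Settings.llm from a function-calling model):

# Hypothetical single round trip with the helpers above.
from app.workflows.tools import call_tools, chat_with_tools
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.settings import Settings
from llama_index.core.tools import BaseTool
from llama_index.core.workflow import Context


async def one_round(
    ctx: Context, tools: list[BaseTool], chat_history: list[ChatMessage]
) -> list[ChatMessage]:
    response = await chat_with_tools(Settings.llm, tools, chat_history)
    if not response.has_tool_calls():
        # No tool call: drain the stream into a plain assistant message.
        return [
            ChatMessage(
                role=MessageRole.ASSISTANT, content=await response.full_response()
            )
        ]
    # Tool call: keep the assistant's tool-call message, then run the tools.
    tool_messages = await call_tools(ctx, "my_agent", tools, response.tool_calls)
    return [response.tool_call_message, *tool_messages]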