Skip to content
Snippets Groups Projects
Commit 6ba50233 authored by leehuwuj's avatar leehuwuj
Browse files

migrate form_filling to AgentWorkflow

parent 22e4be93
No related branches found
No related tags found
Loading
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional, Type
from llama_index.core import Settings
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.tools import FunctionTool, QueryEngineTool, ToolSelection
from llama_index.core.workflow import (
Context,
Event,
StartEvent,
StopEvent,
Workflow,
step,
from llama_index.core.agent.workflow import (
AgentWorkflow,
FunctionAgent,
ReActAgent,
)
from llama_index.core.llms import LLM
from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from app.engine.tools.query_engine import get_query_engine_tool
from app.workflows.events import AgentRunEvent
from app.workflows.tools import (
call_tools,
chat_with_tools,
)
def create_workflow(
params: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Workflow:
) -> AgentWorkflow:
# Create query engine tool
index_config = IndexConfig(**params)
index = get_index(index_config)
......@@ -40,197 +29,60 @@ def create_workflow(
extractor_tool = configured_tools.get("extract_questions") # type: ignore
filling_tool = configured_tools.get("fill_form") # type: ignore
workflow = FormFillingWorkflow(
query_engine_tool=query_engine_tool,
extractor_tool=extractor_tool, # type: ignore
filling_tool=filling_tool, # type: ignore
if extractor_tool is None or filling_tool is None:
raise ValueError("Extractor and filling tools are required.")
agent_cls = _get_agent_cls_from_llm(Settings.llm)
extractor_agent = agent_cls(
name="extractor",
description="An agent that extracts missing cells from CSV files and generates questions to fill them.",
tools=[extractor_tool],
system_prompt="""
You are a helpful assistant who extracts missing cells from CSV files.
Only extract missing cells from CSV files and generate questions to fill them.
Always handoff the task to the `researcher` agent after extracting the questions.
""",
llm=Settings.llm,
can_handoff_to=["researcher"],
)
return workflow
class InputEvent(Event):
input: List[ChatMessage]
response: bool = False
class ExtractMissingCellsEvent(Event):
tool_calls: list[ToolSelection]
class FindAnswersEvent(Event):
tool_calls: list[ToolSelection]
class FillEvent(Event):
tool_calls: list[ToolSelection]
class FormFillingWorkflow(Workflow):
"""
A predefined workflow for filling missing cells in a CSV file.
Required tools:
- query_engine: A query engine to query for the answers to the questions.
- extract_question: Extract missing cells in a CSV file and generate questions to fill them.
- answer_question: Query for the answers to the questions.
Flow:
1. Extract missing cells in a CSV file and generate questions to fill them.
2. Query for the answers to the questions.
3. Fill the missing cells with the answers.
"""
_default_system_prompt = """
You are a helpful assistant who helps fill missing cells in a CSV file.
Only extract missing cells from CSV files.
Only use provided data - never make up any information yourself. Fill N/A if an answer is not found.
If there is no query engine tool or the gathered information has many N/A values indicating the questions don't match the data, respond with a warning and ask the user to upload a different file or connect to a knowledge base.
"""
stream: bool = True
def __init__(
self,
query_engine_tool: Optional[QueryEngineTool],
extractor_tool: FunctionTool,
filling_tool: FunctionTool,
llm: Optional[FunctionCallingLLM] = None,
timeout: int = 360,
system_prompt: Optional[str] = None,
):
super().__init__(timeout=timeout)
self.system_prompt = system_prompt or self._default_system_prompt
self.query_engine_tool = query_engine_tool
self.extractor_tool = extractor_tool
self.filling_tool = filling_tool
if self.extractor_tool is None or self.filling_tool is None:
raise ValueError("Extractor and filling tools are required.")
self.tools = [self.extractor_tool, self.filling_tool]
if self.query_engine_tool is not None:
self.tools.append(self.query_engine_tool) # type: ignore
self.llm: FunctionCallingLLM = llm or Settings.llm
if not isinstance(self.llm, FunctionCallingLLM):
raise ValueError("FormFillingWorkflow only supports FunctionCallingLLM.")
self.memory = ChatMemoryBuffer.from_defaults(llm=self.llm)
@step()
async def start(self, ctx: Context, ev: StartEvent) -> InputEvent:
self.stream = ev.get("stream", True)
user_msg = ev.get("user_msg", "")
chat_history = ev.get("chat_history", [])
if chat_history:
self.memory.put_messages(chat_history)
self.memory.put(ChatMessage(role=MessageRole.USER, content=user_msg))
if self.system_prompt:
system_msg = ChatMessage(
role=MessageRole.SYSTEM, content=self.system_prompt
)
self.memory.put(system_msg)
researcher_agent = agent_cls(
name="researcher",
description="An agent that finds answers to questions about missing cells.",
tools=[query_engine_tool] if query_engine_tool else [],
system_prompt="""
You are a researcher who finds answers to questions about missing cells.
Only use provided data - never make up any information yourself. Use N/A if an answer is not found.
Always handoff the task to the `processor` agent after finding the answers.
""",
llm=Settings.llm,
can_handoff_to=["processor"],
)
return InputEvent(input=self.memory.get())
processor_agent = agent_cls(
name="processor",
description="An agent that fills missing cells with found answers.",
tools=[filling_tool],
system_prompt="""
You are a processor who fills missing cells with found answers.
Fill N/A for any missing answers.
After filling the cells, tell the user about the results or any issues encountered.
""",
llm=Settings.llm,
)
@step()
async def handle_llm_input( # type: ignore
self,
ctx: Context,
ev: InputEvent,
) -> ExtractMissingCellsEvent | FillEvent | StopEvent:
"""
Handle an LLM input and decide the next step.
"""
chat_history: list[ChatMessage] = ev.input
response = await chat_with_tools(
self.llm,
self.tools,
chat_history,
)
if not response.has_tool_calls():
if self.stream:
return StopEvent(result=response.generator)
else:
return StopEvent(result=await response.full_response())
# calling different tools at the same time is not supported at the moment
# add an error message to tell the AI to process step by step
if response.is_calling_different_tools():
self.memory.put(
ChatMessage(
role=MessageRole.ASSISTANT,
content="Cannot call different tools at the same time. Try calling one tool at a time.",
)
)
return InputEvent(input=self.memory.get())
self.memory.put(response.tool_call_message)
match response.tool_name():
case self.extractor_tool.metadata.name:
return ExtractMissingCellsEvent(tool_calls=response.tool_calls)
case self.query_engine_tool.metadata.name:
return FindAnswersEvent(tool_calls=response.tool_calls)
case self.filling_tool.metadata.name:
return FillEvent(tool_calls=response.tool_calls)
case _:
raise ValueError(f"Unknown tool: {response.tool_name()}")
workflow = AgentWorkflow(
agents=[extractor_agent, researcher_agent, processor_agent],
root_agent="extractor",
verbose=True,
)
@step()
async def extract_missing_cells(
self, ctx: Context, ev: ExtractMissingCellsEvent
) -> InputEvent | FindAnswersEvent:
"""
Extract missing cells in a CSV file and generate questions to fill them.
"""
ctx.write_event_to_stream(
AgentRunEvent(
name="Extractor",
msg="Extracting missing cells",
)
)
# Call the extract questions tool
tool_messages = await call_tools(
agent_name="Extractor",
tools=[self.extractor_tool],
ctx=ctx,
tool_calls=ev.tool_calls,
)
self.memory.put_messages(tool_messages)
return InputEvent(input=self.memory.get())
return workflow
@step()
async def find_answers(self, ctx: Context, ev: FindAnswersEvent) -> InputEvent:
"""
Call answer questions tool to query for the answers to the questions.
"""
ctx.write_event_to_stream(
AgentRunEvent(
name="Researcher",
msg="Finding answers for missing cells",
)
)
tool_messages = await call_tools(
ctx=ctx,
agent_name="Researcher",
tools=[self.query_engine_tool],
tool_calls=ev.tool_calls,
)
self.memory.put_messages(tool_messages)
return InputEvent(input=self.memory.get())
@step()
async def fill_cells(self, ctx: Context, ev: FillEvent) -> InputEvent:
"""
Call fill cells tool to fill the missing cells with the answers.
"""
ctx.write_event_to_stream(
AgentRunEvent(
name="Processor",
msg="Filling missing cells",
)
)
tool_messages = await call_tools(
agent_name="Processor",
tools=[self.filling_tool],
ctx=ctx,
tool_calls=ev.tool_calls,
)
self.memory.put_messages(tool_messages)
return InputEvent(input=self.memory.get())
def _get_agent_cls_from_llm(llm: LLM) -> Type[FunctionAgent | ReActAgent]:
if llm.metadata.is_function_calling_model:
return FunctionAgent
else:
return ReActAgent
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment