From 125a06abff14030c006dfec2d11c7822881a758b Mon Sep 17 00:00:00 2001
From: Logan <logan.markewich@live.com>
Date: Mon, 17 Feb 2025 10:57:49 -0600
Subject: [PATCH] update from scratch agent examples (#17843)

---
 .../workflow/function_calling_agent.ipynb     | 150 +++++++++++++--
 docs/docs/examples/workflow/react_agent.ipynb | 173 ++++++++++++++----
 2 files changed, 269 insertions(+), 54 deletions(-)

diff --git a/docs/docs/examples/workflow/function_calling_agent.ipynb b/docs/docs/examples/workflow/function_calling_agent.ipynb
index bf433819ee..4a3c82495b 100644
--- a/docs/docs/examples/workflow/function_calling_agent.ipynb
+++ b/docs/docs/examples/workflow/function_calling_agent.ipynb
@@ -75,8 +75,9 @@
     "\n",
     "To handle these steps, we need to define a few events:\n",
     "1. An event to handle new messages and prepare the chat history\n",
-    "2. An event to trigger tool calls\n",
-    "3. An event to handle the results of tool calls\n",
+    "2. An event to handle streaming responses\n",
+    "3. An event to trigger tool calls\n",
+    "4. An event to handle the results of tool calls\n",
     "\n",
     "The other steps will use the built-in `StartEvent` and `StopEvent` events."
    ]
@@ -96,6 +97,10 @@
     "    input: list[ChatMessage]\n",
     "\n",
     "\n",
+    "class StreamEvent(Event):\n",
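+    "    # the latest text delta from the streaming LLM response\n",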
+    "    delta: str\n",
+    "\n",
+    "\n",
     "class ToolCallEvent(Event):\n",
     "    tool_calls: list[ToolSelection]\n",
     "\n",
@@ -126,7 +131,14 @@
     "from llama_index.core.llms.function_calling import FunctionCallingLLM\n",
     "from llama_index.core.memory import ChatMemoryBuffer\n",
     "from llama_index.core.tools.types import BaseTool\n",
-    "from llama_index.core.workflow import Workflow, StartEvent, StopEvent, step\n",
+    "from llama_index.core.workflow import (\n",
+    "    Context,\n",
+    "    Workflow,\n",
+    "    StartEvent,\n",
+    "    StopEvent,\n",
+    "    step,\n",
+    ")\n",
+    "from llama_index.llms.openai import OpenAI\n",
     "\n",
     "\n",
     "class FunctionCallingAgent(Workflow):\n",
@@ -143,51 +155,71 @@
     "        self.llm = llm or OpenAI()\n",
     "        assert self.llm.metadata.is_function_calling_model\n",
     "\n",
-    "        self.memory = ChatMemoryBuffer.from_defaults(llm=llm)\n",
-    "        self.sources = []\n",
-    "\n",
     "    @step\n",
-    "    async def prepare_chat_history(self, ev: StartEvent) -> InputEvent:\n",
+    "    async def prepare_chat_history(\n",
+    "        self, ctx: Context, ev: StartEvent\n",
+    "    ) -> InputEvent:\n",
     "        # clear sources\n",
-    "        self.sources = []\n",
+    "        await ctx.set(\"sources\", [])\n",
+    "\n",
+    "        # check if memory is setup\n",
+    "        memory = await ctx.get(\"memory\", default=None)\n",
+    "        if not memory:\n",
+    "            memory = ChatMemoryBuffer.from_defaults(llm=self.llm)\n",
     "\n",
     "        # get user input\n",
     "        user_input = ev.input\n",
     "        user_msg = ChatMessage(role=\"user\", content=user_input)\n",
-    "        self.memory.put(user_msg)\n",
+    "        memory.put(user_msg)\n",
     "\n",
     "        # get chat history\n",
-    "        chat_history = self.memory.get()\n",
+    "        chat_history = memory.get()\n",
+    "\n",
+    "        # update context\n",
+    "        await ctx.set(\"memory\", memory)\n",
+    "\n",
     "        return InputEvent(input=chat_history)\n",
     "\n",
     "    @step\n",
     "    async def handle_llm_input(\n",
-    "        self, ev: InputEvent\n",
+    "        self, ctx: Context, ev: InputEvent\n",
     "    ) -> ToolCallEvent | StopEvent:\n",
     "        chat_history = ev.input\n",
     "\n",
-    "        response = await self.llm.achat_with_tools(\n",
+    "        # stream the response\n",
+    "        response_stream = await self.llm.astream_chat_with_tools(\n",
     "            self.tools, chat_history=chat_history\n",
     "        )\n",
-    "        self.memory.put(response.message)\n",
+    "        async for response in response_stream:\n",
+    "            ctx.write_event_to_stream(StreamEvent(delta=response.delta or \"\"))\n",
+    "\n",
+    "        # save the final response, which should have all content\n",
+    "        memory = await ctx.get(\"memory\")\n",
+    "        memory.put(response.message)\n",
+    "        await ctx.set(\"memory\", memory)\n",
     "\n",
+    "        # get tool calls\n",
     "        tool_calls = self.llm.get_tool_calls_from_response(\n",
     "            response, error_on_no_tool_call=False\n",
     "        )\n",
     "\n",
     "        if not tool_calls:\n",
+    "            sources = await ctx.get(\"sources\", default=[])\n",
     "            return StopEvent(\n",
-    "                result={\"response\": response, \"sources\": [*self.sources]}\n",
+    "                result={\"response\": response, \"sources\": [*sources]}\n",
     "            )\n",
     "        else:\n",
     "            return ToolCallEvent(tool_calls=tool_calls)\n",
     "\n",
     "    @step\n",
-    "    async def handle_tool_calls(self, ev: ToolCallEvent) -> InputEvent:\n",
+    "    async def handle_tool_calls(\n",
+    "        self, ctx: Context, ev: ToolCallEvent\n",
+    "    ) -> InputEvent:\n",
     "        tool_calls = ev.tool_calls\n",
     "        tools_by_name = {tool.metadata.get_name(): tool for tool in self.tools}\n",
     "\n",
     "        tool_msgs = []\n",
+    "        sources = await ctx.get(\"sources\", default=[])\n",
     "\n",
     "        # call tools -- safely!\n",
     "        for tool_call in tool_calls:\n",
@@ -208,7 +240,7 @@
     "\n",
     "            try:\n",
     "                tool_output = tool(**tool_call.tool_kwargs)\n",
-    "                self.sources.append(tool_output)\n",
+    "                sources.append(tool_output)\n",
     "                tool_msgs.append(\n",
     "                    ChatMessage(\n",
     "                        role=\"tool\",\n",
@@ -225,10 +257,15 @@
     "                    )\n",
     "                )\n",
     "\n",
+    "        # update memory\n",
+    "        memory = await ctx.get(\"memory\")\n",
     "        for msg in tool_msgs:\n",
-    "            self.memory.put(msg)\n",
+    "            memory.put(msg)\n",
     "\n",
-    "        chat_history = self.memory.get()\n",
+    "        await ctx.set(\"sources\", sources)\n",
+    "        await ctx.set(\"memory\", memory)\n",
+    "\n",
+    "        chat_history = memory.get()\n",
     "        return InputEvent(input=chat_history)"
    ]
   },
@@ -345,6 +382,15 @@
     "ret = await agent.run(input=\"What is (2123 + 2321) * 312?\")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Chat History\n",
+    "\n",
+    "By default, the workflow creates a fresh `Context` for each run, so chat history is not preserved between runs. However, we can pass our own `Context` to the workflow to preserve chat history."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -354,13 +400,79 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "assistant: The result of \\((2123 + 2321) \\times 312\\) is \\(1,386,528\\).\n"
+      "Running step prepare_chat_history\n",
+      "Step prepare_chat_history produced event InputEvent\n",
+      "Running step handle_llm_input\n",
+      "Step handle_llm_input produced event StopEvent\n",
+      "assistant: Hello, Logan! How can I assist you today?\n",
+      "Running step prepare_chat_history\n",
+      "Step prepare_chat_history produced event InputEvent\n",
+      "Running step handle_llm_input\n",
+      "Step handle_llm_input produced event StopEvent\n",
+      "assistant: Your name is Logan.\n"
      ]
     }
    ],
    "source": [
+    "from llama_index.core.workflow import Context\n",
+    "\n",
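+    "# create a context to carry state (like chat history) across runs\n",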
+    "ctx = Context(agent)\n",
+    "\n",
+    "ret = await agent.run(input=\"Hello! My name is Logan.\", ctx=ctx)\n",
+    "print(ret[\"response\"])\n",
+    "\n",
+    "ret = await agent.run(input=\"What is my name?\", ctx=ctx)\n",
     "print(ret[\"response\"])"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Streaming\n",
+    "\n",
+    "Using the `handler` returned from the `.run()` method, we can also iterate over events as they are streamed.\n",
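+    "\n",
+    "Any event written with `ctx.write_event_to_stream()`, like the `StreamEvent` emitted in `handle_llm_input`, will be surfaced here as the workflow runs."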
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Once upon a time in a quaint little village, there lived a curious cat named Whiskers. Whiskers was no ordinary cat; he had a beautiful coat of orange and white fur that shimmered in the sunlight, and his emerald green eyes sparkled with mischief.\n",
+      "\n",
+      "Every day, Whiskers would explore the village, visiting the bakery for a whiff of freshly baked bread and the flower shop to sniff the colorful blooms. The villagers adored him, often leaving out little treats for their favorite feline.\n",
+      "\n",
+      "One sunny afternoon, while wandering near the edge of the village, Whiskers stumbled upon a hidden path that led into the woods. His curiosity piqued, he decided to follow the path, which was lined with tall trees and vibrant wildflowers. As he ventured deeper, he heard a soft, melodic sound that seemed to beckon him.\n",
+      "\n",
+      "Following the enchanting music, Whiskers soon found himself in a clearing where a group of woodland creatures had gathered. They were having a grand celebration, complete with dancing, singing, and a feast of berries and nuts. The animals welcomed Whiskers with open paws, inviting him to join their festivities.\n",
+      "\n",
+      "Whiskers, delighted by the warmth and joy of his new friends, danced and played until the sun began to set. As the sky turned shades of pink and orange, he realized it was time to return home. The woodland creatures gifted him a small, sparkling acorn as a token of their friendship.\n",
+      "\n",
+      "From that day on, Whiskers would often visit the clearing, sharing stories of the village and enjoying the company of his woodland friends. He learned that adventure and friendship could be found in the most unexpected places, and he cherished every moment spent in the magical woods.\n",
+      "\n",
+      "And so, Whiskers continued to live his life filled with curiosity, laughter, and the warmth of friendship, reminding everyone that sometimes, the best adventures are just a whisker away."
+     ]
+    }
+   ],
+   "source": [
+    "agent = FunctionCallingAgent(\n",
+    "    llm=OpenAI(model=\"gpt-4o-mini\"), tools=tools, timeout=120, verbose=False\n",
+    ")\n",
+    "\n",
+    "handler = agent.run(input=\"Hello! Write me a short story about a cat.\")\n",
+    "\n",
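+    "# print each StreamEvent's delta as it arrives\n",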
+    "async for event in handler.stream_events():\n",
+    "    if isinstance(event, StreamEvent):\n",
+    "        print(event.delta, end=\"\", flush=True)\n",
+    "\n",
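+    "# wait for the workflow to finish and get the final result\n",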
+    "response = await handler\n",
+    "# print(response[\"response\"])"
+   ]
   }
  ],
  "metadata": {
diff --git a/docs/docs/examples/workflow/react_agent.ipynb b/docs/docs/examples/workflow/react_agent.ipynb
index 09fd5d64dc..8cbcf5b37f 100644
--- a/docs/docs/examples/workflow/react_agent.ipynb
+++ b/docs/docs/examples/workflow/react_agent.ipynb
@@ -30,7 +30,7 @@
    "source": [
     "import os\n",
     "\n",
-    "os.environ[\"OPENAI_API_KEY\"] = \"sk-proj--...\""
+    "os.environ[\"OPENAI_API_KEY\"] = \"sk-proj-...\""
    ]
   },
   {
@@ -115,9 +115,10 @@
     "\n",
     "To handle these steps, we need to define a few events:\n",
     "1. An event to handle new messages and prepare the chat history\n",
-    "2. An event to prompt the LLM with the react prompt\n",
-    "3. An event to trigger tool calls, if any\n",
-    "4. An event to handle the results of tool calls, if any\n",
+    "2. An event to stream the LLM response\n",
+    "3. An event to prompt the LLM with the react prompt\n",
+    "4. An event to trigger tool calls, if any\n",
+    "5. An event to handle the results of tool calls, if any\n",
     "\n",
     "The other steps will use the built-in `StartEvent` and `StopEvent` events.\n",
     "\n",
@@ -143,6 +144,10 @@
     "    input: list[ChatMessage]\n",
     "\n",
     "\n",
+    "class StreamEvent(Event):\n",
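+    "    # the latest text delta from the streaming LLM response\n",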
+    "    delta: str\n",
+    "\n",
+    "\n",
     "class ToolCallEvent(Event):\n",
     "    tool_calls: list[ToolSelection]\n",
     "\n",
@@ -199,29 +204,33 @@
     "    ) -> None:\n",
     "        super().__init__(*args, **kwargs)\n",
     "        self.tools = tools or []\n",
-    "\n",
     "        self.llm = llm or OpenAI()\n",
-    "\n",
-    "        self.memory = ChatMemoryBuffer.from_defaults(llm=llm)\n",
     "        self.formatter = ReActChatFormatter.from_defaults(\n",
     "            context=extra_context or \"\"\n",
     "        )\n",
     "        self.output_parser = ReActOutputParser()\n",
-    "        self.sources = []\n",
     "\n",
     "    @step\n",
     "    async def new_user_msg(self, ctx: Context, ev: StartEvent) -> PrepEvent:\n",
     "        # clear sources\n",
-    "        self.sources = []\n",
+    "        await ctx.set(\"sources\", [])\n",
+    "\n",
+    "        # init memory if needed\n",
+    "        memory = await ctx.get(\"memory\", default=None)\n",
+    "        if not memory:\n",
+    "            memory = ChatMemoryBuffer.from_defaults(llm=self.llm)\n",
     "\n",
     "        # get user input\n",
     "        user_input = ev.input\n",
     "        user_msg = ChatMessage(role=\"user\", content=user_input)\n",
-    "        self.memory.put(user_msg)\n",
+    "        memory.put(user_msg)\n",
     "\n",
     "        # clear current reasoning\n",
     "        await ctx.set(\"current_reasoning\", [])\n",
     "\n",
+    "        # set memory\n",
+    "        await ctx.set(\"memory\", memory)\n",
+    "\n",
     "        return PrepEvent()\n",
     "\n",
     "    @step\n",
@@ -229,8 +238,11 @@
     "        self, ctx: Context, ev: PrepEvent\n",
     "    ) -> InputEvent:\n",
     "        # get chat history\n",
-    "        chat_history = self.memory.get()\n",
+    "        memory = await ctx.get(\"memory\")\n",
+    "        chat_history = memory.get()\n",
     "        current_reasoning = await ctx.get(\"current_reasoning\", default=[])\n",
+    "\n",
+    "        # format the prompt with react instructions\n",
     "        llm_input = self.formatter.format(\n",
     "            self.tools, chat_history, current_reasoning=current_reasoning\n",
     "        )\n",
@@ -241,27 +253,33 @@
     "        self, ctx: Context, ev: InputEvent\n",
     "    ) -> ToolCallEvent | StopEvent:\n",
     "        chat_history = ev.input\n",
+    "        current_reasoning = await ctx.get(\"current_reasoning\", default=[])\n",
+    "        memory = await ctx.get(\"memory\")\n",
     "\n",
-    "        response = await self.llm.achat(chat_history)\n",
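+    "        # stream the response, writing each delta to the event stream\n",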
+    "        response_gen = await self.llm.astream_chat(chat_history)\n",
+    "        async for response in response_gen:\n",
+    "            ctx.write_event_to_stream(StreamEvent(delta=response.delta or \"\"))\n",
     "\n",
     "        try:\n",
     "            reasoning_step = self.output_parser.parse(response.message.content)\n",
-    "            (await ctx.get(\"current_reasoning\", default=[])).append(\n",
-    "                reasoning_step\n",
-    "            )\n",
+    "            current_reasoning.append(reasoning_step)\n",
+    "\n",
     "            if reasoning_step.is_done:\n",
-    "                self.memory.put(\n",
+    "                memory.put(\n",
     "                    ChatMessage(\n",
     "                        role=\"assistant\", content=reasoning_step.response\n",
     "                    )\n",
     "                )\n",
+    "                await ctx.set(\"memory\", memory)\n",
+    "                await ctx.set(\"current_reasoning\", current_reasoning)\n",
+    "\n",
+    "                sources = await ctx.get(\"sources\", default=[])\n",
+    "\n",
     "                return StopEvent(\n",
     "                    result={\n",
     "                        \"response\": reasoning_step.response,\n",
-    "                        \"sources\": [*self.sources],\n",
-    "                        \"reasoning\": await ctx.get(\n",
-    "                            \"current_reasoning\", default=[]\n",
-    "                        ),\n",
+    "                        \"sources\": [*sources],\n",
+    "                        \"reasoning\": current_reasoning,\n",
     "                    }\n",
     "                )\n",
     "            elif isinstance(reasoning_step, ActionReasoningStep):\n",
@@ -277,11 +295,12 @@
     "                    ]\n",
     "                )\n",
     "        except Exception as e:\n",
-    "            (await ctx.get(\"current_reasoning\", default=[])).append(\n",
+    "            current_reasoning.append(\n",
     "                ObservationReasoningStep(\n",
     "                    observation=f\"There was an error in parsing my reasoning: {e}\"\n",
     "                )\n",
     "            )\n",
+    "            await ctx.set(\"current_reasoning\", current_reasoning)\n",
     "\n",
     "        # if no tool calls or final response, iterate again\n",
     "        return PrepEvent()\n",
@@ -292,12 +311,14 @@
     "    ) -> PrepEvent:\n",
     "        tool_calls = ev.tool_calls\n",
     "        tools_by_name = {tool.metadata.get_name(): tool for tool in self.tools}\n",
+    "        current_reasoning = await ctx.get(\"current_reasoning\", default=[])\n",
+    "        sources = await ctx.get(\"sources\", default=[])\n",
     "\n",
     "        # call tools -- safely!\n",
     "        for tool_call in tool_calls:\n",
     "            tool = tools_by_name.get(tool_call.tool_name)\n",
     "            if not tool:\n",
-    "                (await ctx.get(\"current_reasoning\", default=[])).append(\n",
+    "                current_reasoning.append(\n",
     "                    ObservationReasoningStep(\n",
     "                        observation=f\"Tool {tool_call.tool_name} does not exist\"\n",
     "                    )\n",
@@ -306,17 +327,21 @@
     "\n",
     "            try:\n",
     "                tool_output = tool(**tool_call.tool_kwargs)\n",
-    "                self.sources.append(tool_output)\n",
-    "                (await ctx.get(\"current_reasoning\", default=[])).append(\n",
+    "                sources.append(tool_output)\n",
+    "                current_reasoning.append(\n",
     "                    ObservationReasoningStep(observation=tool_output.content)\n",
     "                )\n",
     "            except Exception as e:\n",
-    "                (await ctx.get(\"current_reasoning\", default=[])).append(\n",
+    "                current_reasoning.append(\n",
     "                    ObservationReasoningStep(\n",
     "                        observation=f\"Error calling tool {tool.metadata.get_name()}: {e}\"\n",
     "                    )\n",
     "                )\n",
     "\n",
+    "        # save new state in context\n",
+    "        await ctx.set(\"sources\", sources)\n",
+    "        await ctx.set(\"current_reasoning\", current_reasoning)\n",
+    "\n",
     "        # prep the next iteration\n",
     "        return PrepEvent()"
    ]
@@ -388,7 +413,7 @@
     "]\n",
     "\n",
     "agent = ReActAgent(\n",
-    "    llm=OpenAI(model=\"gpt-4o-mini\"), tools=tools, timeout=120, verbose=True\n",
+    "    llm=OpenAI(model=\"gpt-4o\"), tools=tools, timeout=120, verbose=True\n",
     ")\n",
     "\n",
     "ret = await agent.run(input=\"Hello!\")"
@@ -403,8 +428,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Hello! How can I assist you today?\n",
-      "```\n"
+      "Hello! How can I assist you today?\n"
      ]
     }
    ],
@@ -425,13 +449,7 @@
       "Step new_user_msg produced event PrepEvent\n",
       "Running step prepare_chat_history\n",
       "Step prepare_chat_history produced event InputEvent\n",
-      "Running step handle_llm_input\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
+      "Running step handle_llm_input\n",
       "Step handle_llm_input produced event ToolCallEvent\n",
       "Running step handle_tool_calls\n",
       "Step handle_tool_calls produced event PrepEvent\n",
@@ -468,6 +486,91 @@
    "source": [
     "print(ret[\"response\"])"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Chat History\n",
+    "\n",
+    "By default, the workflow creates a fresh `Context` for each run, so chat history is not preserved between runs. However, we can pass our own `Context` to the workflow to preserve chat history."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Running step new_user_msg\n",
+      "Step new_user_msg produced event PrepEvent\n",
+      "Running step prepare_chat_history\n",
+      "Step prepare_chat_history produced event InputEvent\n",
+      "Running step handle_llm_input\n",
+      "Step handle_llm_input produced event StopEvent\n",
+      "Hello, Logan! How can I assist you today?\n",
+      "Running step new_user_msg\n",
+      "Step new_user_msg produced event PrepEvent\n",
+      "Running step prepare_chat_history\n",
+      "Step prepare_chat_history produced event InputEvent\n",
+      "Running step handle_llm_input\n",
+      "Step handle_llm_input produced event StopEvent\n",
+      "Your name is Logan.\n"
+     ]
+    }
+   ],
+   "source": [
+    "from llama_index.core.workflow import Context\n",
+    "\n",
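+    "# create a context to carry state (like chat history) across runs\n",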
+    "ctx = Context(agent)\n",
+    "\n",
+    "ret = await agent.run(input=\"Hello! My name is Logan\", ctx=ctx)\n",
+    "print(ret[\"response\"])\n",
+    "\n",
+    "ret = await agent.run(input=\"What is my name?\", ctx=ctx)\n",
+    "print(ret[\"response\"])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Streaming\n",
+    "\n",
+    "We can also access the streaming response from the LLM using the `handler` object returned from the `.run()` method.\n",
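+    "\n",
+    "Since `handle_llm_input` writes a `StreamEvent` for each delta, we can print the response as it is generated."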
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Thought: The current language of the user is: English. I cannot use a tool to help me answer the question.\n",
+      "Answer: Why don't scientists trust atoms? Because they make up everything!"
+     ]
+    }
+   ],
+   "source": [
+    "agent = ReActAgent(\n",
+    "    llm=OpenAI(model=\"gpt-4o\"), tools=tools, timeout=120, verbose=False\n",
+    ")\n",
+    "\n",
+    "handler = agent.run(input=\"Hello! Tell me a joke.\")\n",
+    "\n",
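+    "# print each StreamEvent's delta as it arrives\n",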
+    "async for event in handler.stream_events():\n",
+    "    if isinstance(event, StreamEvent):\n",
+    "        print(event.delta, end=\"\", flush=True)\n",
+    "\n",
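+    "# wait for the workflow to finish and get the final result\n",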
+    "response = await handler\n",
+    "# print(response[\"response\"])"
+   ]
   }
  ],
  "metadata": {
-- 
GitLab