From a088e391160ee947787198a06096ac9fdaaa675f Mon Sep 17 00:00:00 2001
From: Jerry Liu <jerryjliu98@gmail.com>
Date: Mon, 25 Mar 2024 13:00:20 -0700
Subject: [PATCH] add function calling LLM base class (#12222)

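Adds a `FunctionCallingLLM` base class in
`llama_index/core/llms/function_calling.py`. The tool-calling hooks
(`chat_with_tools`, `achat_with_tools`, `get_tool_calls_from_response`)
move off of `LLM` onto the new class, which also overrides
`predict_and_call` / `apredict_and_call` to drive the model's native
function calling API. The previously private
`_get_tool_calls_from_response` is renamed to the public
`get_tool_calls_from_response` (callers in `FunctionCallingAgentWorker`
are updated), and `MistralAI` now subclasses `FunctionCallingLLM`
instead of carrying its own copy of the predict-and-call logic.
`LLM.predict_and_call` keeps its ReAct-based, text-prompting fallback
for models without native function calling.

Rough usage sketch (the `FunctionTool` helper, the model name, and the
environment variable below are illustrative and not introduced by this
patch):

    from llama_index.core.tools import FunctionTool
    from llama_index.llms.mistralai import MistralAI

    def add(a: int, b: int) -> int:
        """Add two integers."""
        return a + b

    add_tool = FunctionTool.from_defaults(fn=add)

    # Assumes MISTRAL_API_KEY is set and the chosen model supports
    # function calling.
    llm = MistralAI(model="mistral-large-latest")

    # MistralAI is now a FunctionCallingLLM, so this goes through the
    # native function calling API; other LLMs fall back to the
    # ReAct-based implementation on the LLM base class.
    response = llm.predict_and_call([add_tool], user_msg="What is 2 + 3?")
    print(str(response))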
---
 .../core/agent/function_calling/step.py       |   4 +-
 .../llama_index/core/llms/function_calling.py | 165 ++++++++++++++++++
 llama-index-core/llama_index/core/llms/llm.py |  45 +----
 .../llama_index/llms/mistralai/base.py        | 109 +-----------
 4 files changed, 177 insertions(+), 146 deletions(-)
 create mode 100644 llama-index-core/llama_index/core/llms/function_calling.py

diff --git a/llama-index-core/llama_index/core/agent/function_calling/step.py b/llama-index-core/llama_index/core/agent/function_calling/step.py
index f6d1444b42..69726c7631 100644
--- a/llama-index-core/llama_index/core/agent/function_calling/step.py
+++ b/llama-index-core/llama_index/core/agent/function_calling/step.py
@@ -246,7 +246,7 @@ class FunctionCallingAgentWorker(BaseAgentWorker):
             verbose=self._verbose,
             allow_parallel_tool_calls=self.allow_parallel_tool_calls,
         )
-        tool_calls = self._llm._get_tool_calls_from_response(
+        tool_calls = self._llm.get_tool_calls_from_response(
             response, error_on_no_tool_call=False
         )
         if not self.allow_parallel_tool_calls and len(tool_calls) > 1:
@@ -314,7 +314,7 @@ class FunctionCallingAgentWorker(BaseAgentWorker):
             verbose=self._verbose,
             allow_parallel_tool_calls=self.allow_parallel_tool_calls,
         )
-        tool_calls = self._llm._get_tool_calls_from_response(
+        tool_calls = self._llm.get_tool_calls_from_response(
             response, error_on_no_tool_call=False
         )
         if not self.allow_parallel_tool_calls and len(tool_calls) > 1:
diff --git a/llama-index-core/llama_index/core/llms/function_calling.py b/llama-index-core/llama_index/core/llms/function_calling.py
new file mode 100644
index 0000000000..c7616c4757
--- /dev/null
+++ b/llama-index-core/llama_index/core/llms/function_calling.py
@@ -0,0 +1,165 @@
+from typing import (
+    Any,
+    List,
+    Optional,
+    Union,
+    TYPE_CHECKING,
+)
+import asyncio
+
+from llama_index.core.base.llms.types import (
+    ChatMessage,
+)
+from llama_index.core.llms.llm import LLM
+
+from llama_index.core.base.llms.types import (
+    ChatMessage,
+    ChatResponse,
+)
+from llama_index.core.llms.llm import ToolSelection
+
+if TYPE_CHECKING:
+    from llama_index.core.chat_engine.types import AgentChatResponse
+    from llama_index.core.tools.types import BaseTool
+
+
+class FunctionCallingLLM(LLM):
+    """
+    Function calling LLMs are LLMs that support function calling.
+    They support an expanded range of capabilities.
+
+    """
+
+    def chat_with_tools(
+        self,
+        tools: List["BaseTool"],
+        user_msg: Optional[Union[str, ChatMessage]] = None,
+        chat_history: Optional[List[ChatMessage]] = None,
+        verbose: bool = False,
+        allow_parallel_tool_calls: bool = False,
+        **kwargs: Any,
+    ) -> ChatResponse:
+        """Predict and call the tool."""
+        raise NotImplementedError("chat_with_tools is not supported by default.")
+
+    async def achat_with_tools(
+        self,
+        tools: List["BaseTool"],
+        user_msg: Optional[Union[str, ChatMessage]] = None,
+        chat_history: Optional[List[ChatMessage]] = None,
+        verbose: bool = False,
+        allow_parallel_tool_calls: bool = False,
+        **kwargs: Any,
+    ) -> ChatResponse:
+        """Predict and call the tool."""
+        raise NotImplementedError("achat_with_tools is not supported by default.")
+
+    def get_tool_calls_from_response(
+        self,
+        response: "ChatResponse",
+        error_on_no_tool_call: bool = True,
+        **kwargs: Any,
+    ) -> List[ToolSelection]:
+        """Predict and call the tool."""
+        raise NotImplementedError(
+            "get_tool_calls_from_response is not supported by default."
+        )
+
+    def predict_and_call(
+        self,
+        tools: List["BaseTool"],
+        user_msg: Optional[Union[str, ChatMessage]] = None,
+        chat_history: Optional[List[ChatMessage]] = None,
+        verbose: bool = False,
+        allow_parallel_tool_calls: bool = False,
+        **kwargs: Any,
+    ) -> "AgentChatResponse":
+        """Predict and call the tool."""
+        from llama_index.core.chat_engine.types import AgentChatResponse
+        from llama_index.core.tools.calling import (
+            call_tool_with_selection,
+        )
+
+        if not self.metadata.is_function_calling_model:
+            return super().predict_and_call(
+                tools,
+                user_msg=user_msg,
+                chat_history=chat_history,
+                verbose=verbose,
+                **kwargs,
+            )
+
+        response = self.chat_with_tools(
+            tools,
+            user_msg,
+            chat_history=chat_history,
+            verbose=verbose,
+            allow_parallel_tool_calls=allow_parallel_tool_calls,
+            **kwargs,
+        )
+        tool_calls = self.get_tool_calls_from_response(response)
+        tool_outputs = [
+            call_tool_with_selection(tool_call, tools, verbose=verbose)
+            for tool_call in tool_calls
+        ]
+        if allow_parallel_tool_calls:
+            output_text = "\n\n".join(
+                [tool_output.content for tool_output in tool_outputs]
+            )
+            return AgentChatResponse(response=output_text, sources=tool_outputs)
+        else:
+            if len(tool_outputs) > 1:
+                raise ValueError("Expected exactly one tool call; got multiple.")
+            return AgentChatResponse(
+                response=tool_outputs[0].content, sources=tool_outputs
+            )
+
+    async def apredict_and_call(
+        self,
+        tools: List["BaseTool"],
+        user_msg: Optional[Union[str, ChatMessage]] = None,
+        chat_history: Optional[List[ChatMessage]] = None,
+        verbose: bool = False,
+        allow_parallel_tool_calls: bool = False,
+        **kwargs: Any,
+    ) -> "AgentChatResponse":
+        """Predict and call the tool."""
+        from llama_index.core.tools.calling import (
+            acall_tool_with_selection,
+        )
+        from llama_index.core.chat_engine.types import AgentChatResponse
+
+        if not self.metadata.is_function_calling_model:
+            return await super().apredict_and_call(
+                tools,
+                user_msg=user_msg,
+                chat_history=chat_history,
+                verbose=verbose,
+                **kwargs,
+            )
+
+        response = await self.achat_with_tools(
+            tools,
+            user_msg,
+            chat_history=chat_history,
+            verbose=verbose,
+            allow_parallel_tool_calls=allow_parallel_tool_calls,
+            **kwargs,
+        )
+        tool_calls = self.get_tool_calls_from_response(response)
+        tool_tasks = [
+            acall_tool_with_selection(tool_call, tools, verbose=verbose)
+            for tool_call in tool_calls
+        ]
+        tool_outputs = await asyncio.gather(*tool_tasks)
+        if allow_parallel_tool_calls:
+            output_text = "\n\n".join(
+                [tool_output.content for tool_output in tool_outputs]
+            )
+            return AgentChatResponse(response=output_text, sources=tool_outputs)
+        else:
+            if len(tool_outputs) > 1:
+                raise ValueError("Expected exactly one tool call; got multiple.")
+            return AgentChatResponse(
+                response=tool_outputs[0].content, sources=tool_outputs
+            )
diff --git a/llama-index-core/llama_index/core/llms/llm.py b/llama-index-core/llama_index/core/llms/llm.py
index b745c0b009..7f3e8ae623 100644
--- a/llama-index-core/llama_index/core/llms/llm.py
+++ b/llama-index-core/llama_index/core/llms/llm.py
@@ -56,7 +56,6 @@ from llama_index.core.instrumentation.events.llm import (
 import llama_index.core.instrumentation as instrument
 from llama_index.core.base.llms.types import (
     ChatMessage,
-    ChatResponse,
 )
 
 dispatcher = instrument.get_dispatcher(__name__)
@@ -545,43 +544,6 @@ class LLM(BaseLLM):
 
         return stream_tokens
 
-    # -- Tool Calling --
-
-    def chat_with_tools(
-        self,
-        tools: List["BaseTool"],
-        user_msg: Optional[Union[str, ChatMessage]] = None,
-        chat_history: Optional[List[ChatMessage]] = None,
-        verbose: bool = False,
-        allow_parallel_tool_calls: bool = False,
-        **kwargs: Any,
-    ) -> ChatResponse:
-        """Predict and call the tool."""
-        raise NotImplementedError("predict_tool is not supported by default.")
-
-    async def achat_with_tools(
-        self,
-        tools: List["BaseTool"],
-        user_msg: Optional[Union[str, ChatMessage]] = None,
-        chat_history: Optional[List[ChatMessage]] = None,
-        verbose: bool = False,
-        allow_parallel_tool_calls: bool = False,
-        **kwargs: Any,
-    ) -> ChatResponse:
-        """Predict and call the tool."""
-        raise NotImplementedError("predict_tool is not supported by default.")
-
-    def _get_tool_calls_from_response(
-        self,
-        response: "AgentChatResponse",
-        error_on_no_tool_call: bool = True,
-        **kwargs: Any,
-    ) -> List[ToolSelection]:
-        """Predict and call the tool."""
-        raise NotImplementedError(
-            "_get_tool_calls_from_response is not supported by default."
-        )
-
     def predict_and_call(
         self,
         tools: List["BaseTool"],
@@ -590,7 +552,12 @@ class LLM(BaseLLM):
         verbose: bool = False,
         **kwargs: Any,
     ) -> "AgentChatResponse":
-        """Predict and call the tool."""
+        """Predict and call the tool.
+
+        By default uses a ReAct agent to do tool calling (through text prompting),
+        but function calling LLMs will implement this differently.
+
+        """
         from llama_index.core.agent.react import ReActAgentWorker
         from llama_index.core.agent.types import Task
         from llama_index.core.memory import ChatMemoryBuffer
diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py
index 35168eb559..047da4d912 100644
--- a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py
@@ -26,13 +26,13 @@ from llama_index.core.base.llms.generic_utils import (
     get_from_param_or_env,
     stream_chat_to_completion_decorator,
 )
-from llama_index.core.llms.llm import LLM, ToolSelection
+from llama_index.core.llms.llm import ToolSelection
 from llama_index.core.types import BaseOutputParser, PydanticProgramMode
+from llama_index.core.llms.function_calling import FunctionCallingLLM
 from llama_index.llms.mistralai.utils import (
     is_mistralai_function_calling_model,
     mistralai_modelname_to_contextsize,
 )
-import asyncio
 
 from mistralai.async_client import MistralAsyncClient
 from mistralai.client import MistralClient
@@ -70,7 +70,7 @@ def force_single_tool_call(response: ChatResponse) -> None:
         response.message.additional_kwargs["tool_calls"] = [tool_calls[0]]
 
 
-class MistralAI(LLM):
+class MistralAI(FunctionCallingLLM):
     """MistralAI LLM.
 
     Examples:
@@ -283,56 +283,6 @@ class MistralAI(LLM):
         stream_complete_fn = stream_chat_to_completion_decorator(self.stream_chat)
         return stream_complete_fn(prompt, **kwargs)
 
-    def predict_and_call(
-        self,
-        tools: List["BaseTool"],
-        user_msg: Optional[Union[str, ChatMessage]] = None,
-        chat_history: Optional[List[ChatMessage]] = None,
-        verbose: bool = False,
-        allow_parallel_tool_calls: bool = False,
-        **kwargs: Any,
-    ) -> "AgentChatResponse":
-        from llama_index.core.chat_engine.types import AgentChatResponse
-        from llama_index.core.tools.calling import (
-            call_tool_with_selection,
-        )
-
-        if not self.metadata.is_function_calling_model:
-            return super().predict_and_call(
-                tools,
-                user_msg=user_msg,
-                chat_history=chat_history,
-                verbose=verbose,
-                **kwargs,
-            )
-
-        response = self.chat_with_tools(
-            tools,
-            user_msg,
-            chat_history=chat_history,
-            verbose=verbose,
-            allow_parallel_tool_calls=allow_parallel_tool_calls,
-            **kwargs,
-        )
-        tool_calls = self._get_tool_calls_from_response(response)
-        tool_outputs = [
-            call_tool_with_selection(tool_call, tools, verbose=verbose)
-            for tool_call in tool_calls
-        ]
-        if allow_parallel_tool_calls:
-            output_text = "\n\n".join(
-                [tool_output.content for tool_output in tool_outputs]
-            )
-            return AgentChatResponse(response=output_text, sources=tool_outputs)
-        else:
-            if len(tool_outputs) > 1:
-                raise ValueError(
-                    "Can't have multiple tool outputs if `allow_parallel_tool_calls` is True."
-                )
-            return AgentChatResponse(
-                response=tool_outputs[0].content, sources=tool_outputs
-            )
-
     @llm_chat_callback()
     async def achat(
         self, messages: Sequence[ChatMessage], **kwargs: Any
@@ -395,57 +345,6 @@ class MistralAI(LLM):
         astream_complete_fn = astream_chat_to_completion_decorator(self.astream_chat)
         return await astream_complete_fn(prompt, **kwargs)
 
-    async def apredict_and_call(
-        self,
-        tools: List["BaseTool"],
-        user_msg: Optional[Union[str, ChatMessage]] = None,
-        chat_history: Optional[List[ChatMessage]] = None,
-        verbose: bool = False,
-        allow_parallel_tool_calls: bool = False,
-        **kwargs: Any,
-    ) -> "AgentChatResponse":
-        from llama_index.core.tools.calling import (
-            acall_tool_with_selection,
-        )
-        from llama_index.core.chat_engine.types import AgentChatResponse
-
-        if not self.metadata.is_function_calling_model:
-            return await super().apredict_and_call(
-                tools,
-                user_msg=user_msg,
-                chat_history=chat_history,
-                verbose=verbose,
-                **kwargs,
-            )
-
-        response = await self.achat_with_tools(
-            tools,
-            user_msg,
-            chat_history=chat_history,
-            verbose=verbose,
-            allow_parallel_tool_calls=allow_parallel_tool_calls,
-            **kwargs,
-        )
-        tool_calls = self._get_tool_calls_from_response(response)
-        tool_tasks = [
-            acall_tool_with_selection(tool_call, tools, verbose=verbose)
-            for tool_call in tool_calls
-        ]
-        tool_outputs = await asyncio.gather(*tool_tasks)
-        if allow_parallel_tool_calls:
-            output_text = "\n\n".join(
-                [tool_output.content for tool_output in tool_outputs]
-            )
-            return AgentChatResponse(response=output_text, sources=tool_outputs)
-        else:
-            if len(tool_outputs) > 1:
-                raise ValueError(
-                    "Can't have multiple tool outputs if `allow_parallel_tool_calls` is True."
-                )
-            return AgentChatResponse(
-                response=tool_outputs[0].content, sources=tool_outputs
-            )
-
     def chat_with_tools(
         self,
         tools: List["BaseTool"],
@@ -504,7 +403,7 @@ class MistralAI(LLM):
             force_single_tool_call(response)
         return response
 
-    def _get_tool_calls_from_response(
+    def get_tool_calls_from_response(
         self,
         response: "AgentChatResponse",
         error_on_no_tool_call: bool = True,
-- 
GitLab