From a088e391160ee947787198a06096ac9fdaaa675f Mon Sep 17 00:00:00 2001
From: Jerry Liu <jerryjliu98@gmail.com>
Date: Mon, 25 Mar 2024 13:00:20 -0700
Subject: [PATCH] add function calling LLM base class (#12222)

---
 .../core/agent/function_calling/step.py       |   4 +-
 .../llama_index/core/llms/function_calling.py | 165 ++++++++++++++++++
 llama-index-core/llama_index/core/llms/llm.py |  45 +----
 .../llama_index/llms/mistralai/base.py        | 109 +-----------
 4 files changed, 177 insertions(+), 146 deletions(-)
 create mode 100644 llama-index-core/llama_index/core/llms/function_calling.py

diff --git a/llama-index-core/llama_index/core/agent/function_calling/step.py b/llama-index-core/llama_index/core/agent/function_calling/step.py
index f6d1444b42..69726c7631 100644
--- a/llama-index-core/llama_index/core/agent/function_calling/step.py
+++ b/llama-index-core/llama_index/core/agent/function_calling/step.py
@@ -246,7 +246,7 @@ class FunctionCallingAgentWorker(BaseAgentWorker):
             verbose=self._verbose,
             allow_parallel_tool_calls=self.allow_parallel_tool_calls,
         )
-        tool_calls = self._llm._get_tool_calls_from_response(
+        tool_calls = self._llm.get_tool_calls_from_response(
             response, error_on_no_tool_call=False
         )
         if not self.allow_parallel_tool_calls and len(tool_calls) > 1:
@@ -314,7 +314,7 @@ class FunctionCallingAgentWorker(BaseAgentWorker):
             verbose=self._verbose,
             allow_parallel_tool_calls=self.allow_parallel_tool_calls,
         )
-        tool_calls = self._llm._get_tool_calls_from_response(
+        tool_calls = self._llm.get_tool_calls_from_response(
             response, error_on_no_tool_call=False
         )
         if not self.allow_parallel_tool_calls and len(tool_calls) > 1:
diff --git a/llama-index-core/llama_index/core/llms/function_calling.py b/llama-index-core/llama_index/core/llms/function_calling.py
new file mode 100644
index 0000000000..c7616c4757
--- /dev/null
+++ b/llama-index-core/llama_index/core/llms/function_calling.py
@@ -0,0 +1,165 @@
+from typing import (
+    Any,
+    List,
+    Optional,
+    Union,
+    TYPE_CHECKING,
+)
+import asyncio
+
+from llama_index.core.base.llms.types import (
+    ChatMessage,
+)
+from llama_index.core.llms.llm import LLM
+
+from llama_index.core.base.llms.types import (
+    ChatMessage,
+    ChatResponse,
+)
+from llama_index.core.llms.llm import ToolSelection
+
+if TYPE_CHECKING:
+    from llama_index.core.chat_engine.types import AgentChatResponse
+    from llama_index.core.tools.types import BaseTool
+
+
+class FunctionCallingLLM(LLM):
+    """
+    Function calling LLMs are LLMs that support function calling.
+    They support an expanded range of capabilities.
+
+    """
+
+    def chat_with_tools(
+        self,
+        tools: List["BaseTool"],
+        user_msg: Optional[Union[str, ChatMessage]] = None,
+        chat_history: Optional[List[ChatMessage]] = None,
+        verbose: bool = False,
+        allow_parallel_tool_calls: bool = False,
+        **kwargs: Any,
+    ) -> ChatResponse:
+        """Chat with the model, exposing the given tools for it to call."""
+        raise NotImplementedError("chat_with_tools is not supported by default.")
+
+    async def achat_with_tools(
+        self,
+        tools: List["BaseTool"],
+        user_msg: Optional[Union[str, ChatMessage]] = None,
+        chat_history: Optional[List[ChatMessage]] = None,
+        verbose: bool = False,
+        allow_parallel_tool_calls: bool = False,
+        **kwargs: Any,
+    ) -> ChatResponse:
+        """Async version of chat_with_tools."""
+        raise NotImplementedError("achat_with_tools is not supported by default.")
+
+    def get_tool_calls_from_response(
+        self,
+        response: "AgentChatResponse",
+        error_on_no_tool_call: bool = True,
+        **kwargs: Any,
+    ) -> List[ToolSelection]:
+        """Extract tool calls (if any) from a chat response."""
+        raise NotImplementedError(
+            "get_tool_calls_from_response is not supported by default."
+        )
+
+    def predict_and_call(
+        self,
+        tools: List["BaseTool"],
+        user_msg: Optional[Union[str, ChatMessage]] = None,
+        chat_history: Optional[List[ChatMessage]] = None,
+        verbose: bool = False,
+        allow_parallel_tool_calls: bool = False,
+        **kwargs: Any,
+    ) -> "AgentChatResponse":
+        """Predict and call the tool."""
+        from llama_index.core.chat_engine.types import AgentChatResponse
+        from llama_index.core.tools.calling import (
+            call_tool_with_selection,
+        )
+
+        if not self.metadata.is_function_calling_model:
+            return super().predict_and_call(
+                tools,
+                user_msg=user_msg,
+                chat_history=chat_history,
+                verbose=verbose,
+                **kwargs,
+            )
+
+        response = self.chat_with_tools(
+            tools,
+            user_msg,
+            chat_history=chat_history,
+            verbose=verbose,
+            allow_parallel_tool_calls=allow_parallel_tool_calls,
+            **kwargs,
+        )
+        tool_calls = self.get_tool_calls_from_response(response)
+        tool_outputs = [
+            call_tool_with_selection(tool_call, tools, verbose=verbose)
+            for tool_call in tool_calls
+        ]
+        if allow_parallel_tool_calls:
+            output_text = "\n\n".join(
+                [tool_output.content for tool_output in tool_outputs]
+            )
+            return AgentChatResponse(response=output_text, sources=tool_outputs)
+        else:
+            if len(tool_outputs) > 1:
+                raise ValueError("Expected a single tool call, got multiple.")
+            return AgentChatResponse(
+                response=tool_outputs[0].content, sources=tool_outputs
+            )
+
+    async def apredict_and_call(
+        self,
+        tools: List["BaseTool"],
+        user_msg: Optional[Union[str, ChatMessage]] = None,
+        chat_history: Optional[List[ChatMessage]] = None,
+        verbose: bool = False,
+        allow_parallel_tool_calls: bool = False,
+        **kwargs: Any,
+    ) -> "AgentChatResponse":
+        """Predict and call the tool."""
+        from llama_index.core.tools.calling import (
+            acall_tool_with_selection,
+        )
+        from llama_index.core.chat_engine.types import AgentChatResponse
+
+        if not self.metadata.is_function_calling_model:
+            return await super().apredict_and_call(
+                tools,
+                user_msg=user_msg,
+                chat_history=chat_history,
+                verbose=verbose,
+                **kwargs,
+            )
+
+        response = await self.achat_with_tools(
+            tools,
+            user_msg,
+            chat_history=chat_history,
+            verbose=verbose,
+            allow_parallel_tool_calls=allow_parallel_tool_calls,
+            **kwargs,
+        )
+        tool_calls = self.get_tool_calls_from_response(response)
+        tool_tasks = [
+            acall_tool_with_selection(tool_call, tools, verbose=verbose)
+            for tool_call in tool_calls
+        ]
+        tool_outputs = await asyncio.gather(*tool_tasks)
+        if allow_parallel_tool_calls:
+            output_text = "\n\n".join(
+                [tool_output.content for tool_output in tool_outputs]
+            )
+            return AgentChatResponse(response=output_text, sources=tool_outputs)
+        else:
+            if len(tool_outputs) > 1:
+                raise ValueError("Expected a single tool call, got multiple.")
+            return AgentChatResponse(
+                response=tool_outputs[0].content, sources=tool_outputs
+            )
diff --git a/llama-index-core/llama_index/core/llms/llm.py b/llama-index-core/llama_index/core/llms/llm.py
index b745c0b009..7f3e8ae623 100644
--- a/llama-index-core/llama_index/core/llms/llm.py
+++ b/llama-index-core/llama_index/core/llms/llm.py
@@ -56,7 +56,6 @@ from llama_index.core.instrumentation.events.llm import (
 import llama_index.core.instrumentation as instrument
 from llama_index.core.base.llms.types import (
     ChatMessage,
-    ChatResponse,
 )
 
 dispatcher = instrument.get_dispatcher(__name__)
@@ -545,43 +544,6 @@ class LLM(BaseLLM):
 
         return stream_tokens
 
-    # -- Tool Calling --
-
-    def chat_with_tools(
-        self,
-        tools: List["BaseTool"],
-        user_msg: Optional[Union[str, ChatMessage]] = None,
-        chat_history: Optional[List[ChatMessage]] = None,
-        verbose: bool = False,
-        allow_parallel_tool_calls: bool = False,
-        **kwargs: Any,
-    ) -> ChatResponse:
-        """Predict and call the tool."""
-        raise NotImplementedError("predict_tool is not supported by default.")
-
-    async def achat_with_tools(
-        self,
-        tools: List["BaseTool"],
-        user_msg: Optional[Union[str, ChatMessage]] = None,
-        chat_history: Optional[List[ChatMessage]] = None,
-        verbose: bool = False,
-        allow_parallel_tool_calls: bool = False,
-        **kwargs: Any,
-    ) -> ChatResponse:
-        """Predict and call the tool."""
-        raise NotImplementedError("predict_tool is not supported by default.")
-
-    def _get_tool_calls_from_response(
-        self,
-        response: "AgentChatResponse",
-        error_on_no_tool_call: bool = True,
-        **kwargs: Any,
-    ) -> List[ToolSelection]:
-        """Predict and call the tool."""
-        raise NotImplementedError(
-            "_get_tool_calls_from_response is not supported by default."
-        )
-
     def predict_and_call(
         self,
         tools: List["BaseTool"],
@@ -590,7 +552,12 @@ class LLM(BaseLLM):
         verbose: bool = False,
         **kwargs: Any,
     ) -> "AgentChatResponse":
-        """Predict and call the tool."""
+        """Predict and call the tool.
+
+        By default uses a ReAct agent to do tool calling (through text prompting),
+        but function calling LLMs will implement this differently.
+
+        """
         from llama_index.core.agent.react import ReActAgentWorker
         from llama_index.core.agent.types import Task
         from llama_index.core.memory import ChatMemoryBuffer
diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py
index 35168eb559..047da4d912 100644
--- a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py
@@ -26,13 +26,13 @@ from llama_index.core.base.llms.generic_utils import (
     get_from_param_or_env,
     stream_chat_to_completion_decorator,
 )
-from llama_index.core.llms.llm import LLM, ToolSelection
+from llama_index.core.llms.llm import ToolSelection
 from llama_index.core.types import BaseOutputParser, PydanticProgramMode
+from llama_index.core.llms.function_calling import FunctionCallingLLM
 from llama_index.llms.mistralai.utils import (
     is_mistralai_function_calling_model,
     mistralai_modelname_to_contextsize,
 )
-import asyncio
 
 from mistralai.async_client import MistralAsyncClient
 from mistralai.client import MistralClient
@@ -70,7 +70,7 @@ def force_single_tool_call(response: ChatResponse) -> None:
         response.message.additional_kwargs["tool_calls"] = [tool_calls[0]]
 
 
-class MistralAI(LLM):
+class MistralAI(FunctionCallingLLM):
     """MistralAI LLM.
 
     Examples:
@@ -283,56 +283,6 @@ class MistralAI(LLM):
         stream_complete_fn = stream_chat_to_completion_decorator(self.stream_chat)
         return stream_complete_fn(prompt, **kwargs)
 
-    def predict_and_call(
-        self,
-        tools: List["BaseTool"],
-        user_msg: Optional[Union[str, ChatMessage]] = None,
-        chat_history: Optional[List[ChatMessage]] = None,
-        verbose: bool = False,
-        allow_parallel_tool_calls: bool = False,
-        **kwargs: Any,
-    ) -> "AgentChatResponse":
-        from llama_index.core.chat_engine.types import AgentChatResponse
-        from llama_index.core.tools.calling import (
-            call_tool_with_selection,
-        )
-
-        if not self.metadata.is_function_calling_model:
-            return super().predict_and_call(
-                tools,
-                user_msg=user_msg,
-                chat_history=chat_history,
-                verbose=verbose,
-                **kwargs,
-            )
-
-        response = self.chat_with_tools(
-            tools,
-            user_msg,
-            chat_history=chat_history,
-            verbose=verbose,
-            allow_parallel_tool_calls=allow_parallel_tool_calls,
-            **kwargs,
-        )
-        tool_calls = self._get_tool_calls_from_response(response)
-        tool_outputs = [
-            call_tool_with_selection(tool_call, tools, verbose=verbose)
-            for tool_call in tool_calls
-        ]
-        if allow_parallel_tool_calls:
-            output_text = "\n\n".join(
-                [tool_output.content for tool_output in tool_outputs]
-            )
-            return AgentChatResponse(response=output_text, sources=tool_outputs)
-        else:
-            if len(tool_outputs) > 1:
-                raise ValueError(
-                    "Can't have multiple tool outputs if `allow_parallel_tool_calls` is True."
-                )
-            return AgentChatResponse(
-                response=tool_outputs[0].content, sources=tool_outputs
-            )
-
     @llm_chat_callback()
     async def achat(
         self, messages: Sequence[ChatMessage], **kwargs: Any
@@ -395,57 +345,6 @@ class MistralAI(LLM):
         astream_complete_fn = astream_chat_to_completion_decorator(self.astream_chat)
         return await astream_complete_fn(prompt, **kwargs)
 
-    async def apredict_and_call(
-        self,
-        tools: List["BaseTool"],
-        user_msg: Optional[Union[str, ChatMessage]] = None,
-        chat_history: Optional[List[ChatMessage]] = None,
-        verbose: bool = False,
-        allow_parallel_tool_calls: bool = False,
-        **kwargs: Any,
-    ) -> "AgentChatResponse":
-        from llama_index.core.tools.calling import (
-            acall_tool_with_selection,
-        )
-        from llama_index.core.chat_engine.types import AgentChatResponse
-
-        if not self.metadata.is_function_calling_model:
-            return await super().apredict_and_call(
-                tools,
-                user_msg=user_msg,
-                chat_history=chat_history,
-                verbose=verbose,
-                **kwargs,
-            )
-
-        response = await self.achat_with_tools(
-            tools,
-            user_msg,
-            chat_history=chat_history,
-            verbose=verbose,
-            allow_parallel_tool_calls=allow_parallel_tool_calls,
-            **kwargs,
-        )
-        tool_calls = self._get_tool_calls_from_response(response)
-        tool_tasks = [
-            acall_tool_with_selection(tool_call, tools, verbose=verbose)
-            for tool_call in tool_calls
-        ]
-        tool_outputs = await asyncio.gather(*tool_tasks)
-        if allow_parallel_tool_calls:
-            output_text = "\n\n".join(
-                [tool_output.content for tool_output in tool_outputs]
-            )
-            return AgentChatResponse(response=output_text, sources=tool_outputs)
-        else:
-            if len(tool_outputs) > 1:
-                raise ValueError(
-                    "Can't have multiple tool outputs if `allow_parallel_tool_calls` is True."
-                )
-            return AgentChatResponse(
-                response=tool_outputs[0].content, sources=tool_outputs
-            )
-
     def chat_with_tools(
         self,
         tools: List["BaseTool"],
@@ -504,7 +403,7 @@ class MistralAI(LLM):
             force_single_tool_call(response)
         return response
 
-    def _get_tool_calls_from_response(
+    def get_tool_calls_from_response(
         self,
         response: "AgentChatResponse",
         error_on_no_tool_call: bool = True,
--
GitLab
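Usage sketch (illustrative, not part of the patch above): once an integration subclasses FunctionCallingLLM, as MistralAI does here, predict_and_call routes through chat_with_tools and get_tool_calls_from_response rather than the ReAct text-prompting fallback kept in llm.py. The tool, model name, and environment variable below are assumptions for the example; it expects llama-index-core, llama-index-llms-mistralai, and a MISTRAL_API_KEY to be set.

    from llama_index.core.tools import FunctionTool
    from llama_index.llms.mistralai import MistralAI


    def multiply(a: int, b: int) -> int:
        """Multiply two integers."""
        return a * b


    # Wrap a plain Python function so the LLM can select and call it.
    multiply_tool = FunctionTool.from_defaults(fn=multiply)

    # MistralAI now inherits from FunctionCallingLLM, so this performs a real
    # function-calling round trip instead of ReAct-style text prompting.
    llm = MistralAI(model="mistral-large-latest")
    response = llm.predict_and_call(
        [multiply_tool], user_msg="What is 12 multiplied by 19?"
    )
    print(response.response)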