diff --git a/llama-index-core/llama_index/core/llms/llm.py b/llama-index-core/llama_index/core/llms/llm.py
index e78f52b006ea9f5ab2c374924c34f72fa5f7304e..04a201ce128f2d5da4b73fe26a2a2db0405014a5 100644
--- a/llama-index-core/llama_index/core/llms/llm.py
+++ b/llama-index-core/llama_index/core/llms/llm.py
@@ -7,8 +7,10 @@ from typing import (
     Optional,
     Protocol,
     Sequence,
+    Union,
     get_args,
     runtime_checkable,
+    TYPE_CHECKING,
 )
 
 from llama_index.core.base.llms.types import (
@@ -55,6 +57,10 @@ import llama_index.core.instrumentation as instrument
 
 dispatcher = instrument.get_dispatcher(__name__)
 
+if TYPE_CHECKING:
+    from llama_index.core.chat_engine.types import AgentChatResponse
+    from llama_index.core.tools.types import BaseTool
+
 
 # NOTE: These two protocols are needed to appease mypy
 @runtime_checkable
@@ -165,6 +171,8 @@ class LLM(BaseLLM):
         exclude=True,
     )
 
+    # -- Pydantic Configs --
+
     @validator("messages_to_prompt", pre=True)
     def set_messages_to_prompt(
         cls, messages_to_prompt: Optional[MessagesToPromptType]
@@ -185,6 +193,8 @@ class LLM(BaseLLM):
             values["messages_to_prompt"] = generic_messages_to_prompt
         return values
 
+    # -- Utils --
+
     def _log_template_data(
         self, prompt: BasePromptTemplate, **prompt_args: Any
     ) -> None:
@@ -223,6 +233,47 @@ class LLM(BaseLLM):
             messages = self.output_parser.format_messages(messages)
         return self._extend_messages(messages)
 
+    def _parse_output(self, output: str) -> str:
+        if self.output_parser is not None:
+            return str(self.output_parser.parse(output))
+
+        return output
+
+    def _extend_prompt(
+        self,
+        formatted_prompt: str,
+    ) -> str:
+        """Add system and query wrapper prompts to base prompt."""
+        extended_prompt = formatted_prompt
+
+        if self.system_prompt:
+            extended_prompt = self.system_prompt + "\n\n" + extended_prompt
+
+        if self.query_wrapper_prompt:
+            extended_prompt = self.query_wrapper_prompt.format(
+                query_str=extended_prompt
+            )
+
+        return extended_prompt
+
+    def _extend_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]:
+        """Add system prompt to chat message list."""
+        if self.system_prompt:
+            messages = [
+                ChatMessage(role=MessageRole.SYSTEM, content=self.system_prompt),
+                *messages,
+            ]
+        return messages
+
+    def _as_query_component(self, **kwargs: Any) -> QueryComponent:
+        """Return query component."""
+        if self.metadata.is_chat_model:
+            return LLMChatComponent(llm=self, **kwargs)
+        else:
+            return LLMCompleteComponent(llm=self, **kwargs)
+
+    # -- Structured outputs --
+
     def structured_predict(
         self,
         output_cls: BaseModel,
@@ -313,11 +364,7 @@ class LLM(BaseLLM):
 
         return await program.acall(**prompt_args)
 
-    def _parse_output(self, output: str) -> str:
-        if self.output_parser is not None:
-            return str(self.output_parser.parse(output))
-
-        return output
+    # -- Prompt Chaining --
 
     @dispatcher.span
     def predict(
@@ -485,38 +532,101 @@ class LLM(BaseLLM):
 
         return stream_tokens
 
-    def _extend_prompt(
+    # -- Tool Calling --
+
+    def predict_and_call(
         self,
-        formatted_prompt: str,
-    ) -> str:
-        """Add system and query wrapper prompts to base prompt."""
-        extended_prompt = formatted_prompt
+        tools: List["BaseTool"],
+        user_msg: Optional[Union[str, ChatMessage]] = None,
+        chat_history: Optional[List[ChatMessage]] = None,
+        verbose: bool = False,
+        **kwargs: Any,
+    ) -> "AgentChatResponse":
+        """Predict and call the tool."""
+        from llama_index.core.agent.react import ReActAgentWorker
+        from llama_index.core.agent.types import Task
+        from llama_index.core.chat_engine.types import AgentChatResponse
+        from llama_index.core.memory import ChatMemoryBuffer
+
+        worker = ReActAgentWorker(
+            tools,
+            llm=self,
+            callback_manager=self.callback_manager,
+            verbose=verbose,
+            **kwargs,
+        )
 
-        if self.system_prompt:
-            extended_prompt = self.system_prompt + "\n\n" + extended_prompt
+        if isinstance(user_msg, ChatMessage):
+            user_msg = user_msg.content
 
-        if self.query_wrapper_prompt:
-            extended_prompt = self.query_wrapper_prompt.format(
-                query_str=extended_prompt
+        task = Task(
+            input=user_msg,
+            memory=ChatMemoryBuffer.from_defaults(chat_history=chat_history),
+            extra_state={},
+            callback_manager=self.callback_manager,
+        )
+        step = worker.initialize_step(task)
+
+        try:
+            output = worker.run_step(step, task).output
+
+            # react agent worker inserts an "Observation: " prefix to the response
+            if output.response and output.response.startswith("Observation: "):
+                output.response = output.response.replace("Observation: ", "")
+        except Exception as e:
+            output = AgentChatResponse(
+                response="An error occurred while running the tool: " + str(e),
+                sources=[],
             )
 
-        return extended_prompt
+        return output
 
-    def _extend_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]:
-        """Add system prompt to chat message list."""
-        if self.system_prompt:
-            messages = [
-                ChatMessage(role=MessageRole.SYSTEM, content=self.system_prompt),
-                *messages,
-            ]
-        return messages
+    async def apredict_and_call(
+        self,
+        tools: List["BaseTool"],
+        user_msg: Optional[Union[str, ChatMessage]] = None,
+        chat_history: Optional[List[ChatMessage]] = None,
+        verbose: bool = False,
+        **kwargs: Any,
+    ) -> "AgentChatResponse":
+        """Predict and call the tool."""
+        from llama_index.core.agent.react import ReActAgentWorker
+        from llama_index.core.agent.types import Task
+        from llama_index.core.chat_engine.types import AgentChatResponse
+        from llama_index.core.memory import ChatMemoryBuffer
+
+        worker = ReActAgentWorker(
+            tools,
+            llm=self,
+            callback_manager=self.callback_manager,
+            verbose=verbose,
+            **kwargs,
+        )
 
-    def _as_query_component(self, **kwargs: Any) -> QueryComponent:
-        """Return query component."""
-        if self.metadata.is_chat_model:
-            return LLMChatComponent(llm=self, **kwargs)
-        else:
-            return LLMCompleteComponent(llm=self, **kwargs)
+        if isinstance(user_msg, ChatMessage):
+            user_msg = user_msg.content
+
+        task = Task(
+            input=user_msg,
+            memory=ChatMemoryBuffer.from_defaults(chat_history=chat_history),
+            extra_state={},
+            callback_manager=self.callback_manager,
+        )
+        step = worker.initialize_step(task)
+
+        try:
+            output = (await worker.arun_step(step, task)).output
+
+            # react agent worker inserts an "Observation: " prefix to the response
+            if output.response and output.response.startswith("Observation: "):
+                output.response = output.response.replace("Observation: ", "")
+        except Exception as e:
+            output = AgentChatResponse(
+                response="An error occurred while running the tool: " + str(e),
+                sources=[],
+            )
+
+        return output
 
 
 class BaseLLMComponent(QueryComponent):
diff --git a/llama-index-core/llama_index/core/tools/calling.py b/llama-index-core/llama_index/core/tools/calling.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dbe600674c96b51ba9ca7a8b54b033119cffb81
--- /dev/null
+++ b/llama-index-core/llama_index/core/tools/calling.py
@@ -0,0 +1,42 @@
+from llama_index.core.tools.types import BaseTool, ToolOutput, adapt_to_async_tool
+
+
+def call_tool(tool: BaseTool, arguments: dict) -> ToolOutput:
+    """Call a tool with arguments."""
+    try:
+        if (
+            len(tool.metadata.get_parameters_dict()["properties"]) == 1
+            and len(arguments) == 1
+        ):
+            single_arg = arguments[next(iter(arguments))]
+            return tool(single_arg)
+        else:
+            return tool(**arguments)
+    except Exception as e:
+        return ToolOutput(
+            content="Encountered error: " + str(e),
+            tool_name=tool.metadata.name,
+            raw_input=arguments,
+            raw_output=str(e),
+        )
+
+
+async def acall_tool(tool: BaseTool, arguments: dict) -> ToolOutput:
+    """Call a tool with arguments asynchronously."""
+    async_tool = adapt_to_async_tool(tool)
+    try:
+        if (
+            len(tool.metadata.get_parameters_dict()["properties"]) == 1
+            and len(arguments) == 1
+        ):
+            single_arg = arguments[next(iter(arguments))]
+            return await async_tool.acall(single_arg)
+        else:
+            return await async_tool.acall(**arguments)
+    except Exception as e:
+        return ToolOutput(
+            content="Encountered error: " + str(e),
+            tool_name=tool.metadata.name,
+            raw_input=arguments,
+            raw_output=str(e),
+        )
diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py
index 784f574d73f4e491bdf6f6e4b39d43ec69be469b..7484494b4556e9d2c135452805663ec933be21f7 100644
--- a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py
@@ -1,6 +1,6 @@
-from typing import Any, Callable, Dict, Optional, Sequence
+import json
+from typing import Any, Callable, Dict, List, Optional, Sequence, Union, TYPE_CHECKING
 
-# from mistralai.models.chat_completion import ChatMessage
 from llama_index.core.base.llms.types import (
     ChatMessage,
     ChatResponse,
@@ -29,11 +29,17 @@ from llama_index.core.base.llms.generic_utils import (
 from llama_index.core.llms.llm import LLM
 from llama_index.core.types import BaseOutputParser, PydanticProgramMode
 from llama_index.llms.mistralai.utils import (
+    is_mistralai_function_calling_model,
     mistralai_modelname_to_contextsize,
 )
 
 from mistralai.async_client import MistralAsyncClient
 from mistralai.client import MistralClient
+from mistralai.models.chat_completion import ToolCall
+
+if TYPE_CHECKING:
+    from llama_index.core.tools.types import BaseTool
+    from llama_index.core.chat_engine.types import AgentChatResponse
 
 DEFAULT_MISTRALAI_MODEL = "mistral-tiny"
 DEFAULT_MISTRALAI_ENDPOINT = "https://api.mistral.ai"
@@ -168,6 +174,7 @@ class MistralAI(LLM):
             model_name=self.model,
             safe_mode=self.safe_mode,
             random_seed=self.random_seed,
+            is_function_calling_model=is_mistralai_function_calling_model(self.model),
         )
 
     @property
@@ -200,9 +207,16 @@ class MistralAI(LLM):
         ]
         all_kwargs = self._get_all_kwargs(**kwargs)
         response = self._client.chat(messages=messages, **all_kwargs)
+
+        tool_calls = response.choices[0].message.tool_calls
+
         return ChatResponse(
             message=ChatMessage(
-                role=MessageRole.ASSISTANT, content=response.choices[0].message.content
+                role=MessageRole.ASSISTANT,
+                content=response.choices[0].message.content,
+                additional_kwargs={"tool_calls": tool_calls}
+                if tool_calls is not None
+                else {},
             ),
             raw=dict(response),
         )
@@ -252,6 +266,86 @@ class MistralAI(LLM):
         stream_complete_fn = stream_chat_to_completion_decorator(self.stream_chat)
         return stream_complete_fn(prompt, **kwargs)
 
+    def _get_tool_call(
+        self,
+        response: ChatResponse,
+    ) -> ToolCall:
+        tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+
+        if len(tool_calls) < 1:
+            raise ValueError(
+                f"Expected at least one tool call, but got {len(tool_calls)} tool calls."
+            )
+
+        # TODO: support more than one tool call?
+        tool_call = tool_calls[0]
+        if not isinstance(tool_call, ToolCall):
+            raise ValueError("Invalid tool_call object")
+
+        if tool_call.type != "function":
+            raise ValueError("Invalid tool type. Unsupported by Mistralai.")
+
+        return tool_call
+
+    def _call_tool(
+        self,
+        tool_call: ToolCall,
+        tools_by_name: Dict[str, "BaseTool"],
+        verbose: bool = False,
+    ) -> "AgentChatResponse":
+        from llama_index.core.chat_engine.types import AgentChatResponse
+        from llama_index.core.tools.calling import call_tool
+
+        arguments_str = tool_call.function.arguments
+        name = tool_call.function.name
+        if verbose:
+            print("=== Calling Function ===")
+            print(f"Calling function: {name} with args: {arguments_str}")
+        tool = tools_by_name[name]
+        argument_dict = json.loads(arguments_str)
+
+        tool_output = call_tool(tool, argument_dict)
+
+        return AgentChatResponse(response=tool_output.content, sources=[tool_output])
+
+    def predict_and_call(
+        self,
+        tools: List["BaseTool"],
+        user_msg: Optional[Union[str, ChatMessage]] = None,
+        chat_history: Optional[List[ChatMessage]] = None,
+        verbose: bool = False,
+        **kwargs: Any,
+    ) -> "AgentChatResponse":
+        if not self.metadata.is_function_calling_model:
+            return super().predict_and_call(
+                tools,
+                user_msg=user_msg,
+                chat_history=chat_history,
+                verbose=verbose,
+                **kwargs,
+            )
+
+        # mistralai uses the same openai tool format
+        tool_specs = [tool.metadata.to_openai_tool() for tool in tools]
+        tools_by_name = {tool.metadata.name: tool for tool in tools}
+
+        if isinstance(user_msg, str):
+            user_msg = ChatMessage(role=MessageRole.USER, content=user_msg)
+
+        messages = chat_history or []
+        if user_msg:
+            messages.append(user_msg)
+
+        response = self.chat(
+            messages,
+            tools=tool_specs,
+            **kwargs,
+        )
+
+        tool_call = self._get_tool_call(response)
+
+        return self._call_tool(tool_call, tools_by_name, verbose=verbose)
+
     @llm_chat_callback()
     async def achat(
         self, messages: Sequence[ChatMessage], **kwargs: Any
@@ -315,3 +409,56 @@ class MistralAI(LLM):
     ) -> CompletionResponseAsyncGen:
         astream_complete_fn = astream_chat_to_completion_decorator(self.astream_chat)
         return await astream_complete_fn(prompt, **kwargs)
+
+    async def _acall_tool(
+        self,
+        tool_call: ToolCall,
+        tools_by_name: Dict[str, "BaseTool"],
+        verbose: bool = False,
+    ) -> "AgentChatResponse":
+        from llama_index.core.chat_engine.types import AgentChatResponse
+        from llama_index.core.tools.calling import acall_tool
+
+        arguments_str = tool_call.function.arguments
+        name = tool_call.function.name
+        if verbose:
+            print("=== Calling Function ===")
+            print(f"Calling function: {name} with args: {arguments_str}")
+        tool = tools_by_name[name]
+        argument_dict = json.loads(arguments_str)
+
+        tool_output = await acall_tool(tool, argument_dict)
+
+        return AgentChatResponse(response=tool_output.content, sources=[tool_output])
+
+    async def apredict_and_call(
+        self,
+        tools: List["BaseTool"],
+        user_msg: Optional[Union[str, ChatMessage]] = None,
+        chat_history: Optional[List[ChatMessage]] = None,
+        verbose: bool = False,
+        **kwargs: Any,
+    ) -> "AgentChatResponse":
+        if not self.metadata.is_function_calling_model:
+            return await super().apredict_and_call(tools, user_msg, chat_history, verbose, **kwargs)
+
+        # mistralai uses the same openai tool format
+        tool_specs = [tool.metadata.to_openai_tool() for tool in tools]
+        tools_by_name = {tool.metadata.name: tool for tool in tools}
+
+        if isinstance(user_msg, str):
+            user_msg = ChatMessage(role=MessageRole.USER, content=user_msg)
+
+        messages = chat_history or []
+        if user_msg:
+            messages.append(user_msg)
+
+        response = await self.achat(
+            messages,
+            tools=tool_specs,
+            **kwargs,
+        )
+
+        tool_call = self._get_tool_call(response)
+
+        return await self._acall_tool(tool_call, tools_by_name, verbose=verbose)
diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/utils.py b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/utils.py
index 5598066d7cb5a388a39d563b197934dca8ccc303..613c3bcadd472947f48a5a1ccf3520415633e940 100644
--- a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/utils.py
+++ b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/utils.py
@@ -12,6 +12,8 @@ MISTRALAI_MODELS: Dict[str, int] = {
     "mistral-large-latest": 32000,
 }
 
+MISTRALAI_FUNCTION_CALLING_MODELS = ("mistral-large-latest",)
+
 
 def mistralai_modelname_to_contextsize(modelname: str) -> int:
     if modelname not in MISTRALAI_MODELS:
@@ -21,3 +23,7 @@ def mistralai_modelname_to_contextsize(modelname: str) -> int:
         )
 
     return MISTRALAI_MODELS[modelname]
+
+
+def is_mistralai_function_calling_model(modelname: str) -> bool:
+    return modelname in MISTRALAI_FUNCTION_CALLING_MODELS
diff --git a/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py b/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py
index 58bb23f1a6f2cc7bdca0331cda87732a26527388..9f0b43b8f32fe4ced557db143e99b53e3223839c 100644
--- a/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py
@@ -7,10 +7,14 @@ from typing import (
     Optional,
     Protocol,
     Sequence,
+    Union,
     cast,
+    get_args,
     runtime_checkable,
+    TYPE_CHECKING,
 )
 
+import json
 import httpx
 import tiktoken
 from llama_index.core.base.llms.types import (
@@ -46,6 +50,7 @@ from llama_index.core.base.llms.generic_utils import (
 from llama_index.core.llms.llm import LLM
 from llama_index.core.types import BaseOutputParser, PydanticProgramMode
 from llama_index.llms.openai.utils import (
+    OpenAIToolCall,
     create_retry_decorator,
     from_openai_message,
     from_openai_token_logprobs,
@@ -65,6 +70,10 @@ from openai.types.chat.chat_completion_chunk import (
     ChoiceDeltaToolCall,
 )
 
+if TYPE_CHECKING:
+    from llama_index.core.chat_engine.types import AgentChatResponse
+    from llama_index.core.tools.types import BaseTool
+
 DEFAULT_OPENAI_MODEL = "gpt-3.5-turbo"
 
 llm_retry_decorator = create_retry_decorator(
@@ -544,6 +553,85 @@ class OpenAI(LLM):
             "total_tokens": usage.get("total_tokens", 0),
         }
 
+    def _get_tool_call(
+        self,
+        response: ChatResponse,
+    ) -> OpenAIToolCall:
+        tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+
+        if len(tool_calls) < 1:
+            raise ValueError(
+                f"Expected at least one tool call, but got {len(tool_calls)} tool calls."
+            )
+
+        # TODO: support more than one tool call?
+        tool_call = tool_calls[0]
+        if not isinstance(tool_call, get_args(OpenAIToolCall)):
+            raise ValueError("Invalid tool_call object")
+
+        if tool_call.type != "function":
+            raise ValueError("Invalid tool type. Unsupported by OpenAI")
+
+        return tool_call
+
+    def _call_tool(
+        self,
+        tool_call: OpenAIToolCall,
+        tools_by_name: Dict[str, "BaseTool"],
+        verbose: bool = False,
+    ) -> "AgentChatResponse":
+        from llama_index.core.chat_engine.types import AgentChatResponse
+        from llama_index.core.tools.calling import call_tool
+
+        arguments_str = tool_call.function.arguments
+        name = tool_call.function.name
+        if verbose:
+            print("=== Calling Function ===")
+            print(f"Calling function: {name} with args: {arguments_str}")
+        tool = tools_by_name[name]
+        argument_dict = json.loads(arguments_str)
+
+        tool_output = call_tool(tool, argument_dict)
+
+        return AgentChatResponse(response=tool_output.content, sources=[tool_output])
+
+    def predict_and_call(
+        self,
+        tools: List["BaseTool"],
+        user_msg: Optional[Union[str, ChatMessage]] = None,
+        chat_history: Optional[List[ChatMessage]] = None,
+        verbose: bool = False,
+        **kwargs: Any,
+    ) -> "AgentChatResponse":
+        if not self.metadata.is_function_calling_model:
+            return super().predict_and_call(
+                tools,
+                user_msg=user_msg,
+                chat_history=chat_history,
+                verbose=verbose,
+                **kwargs,
+            )
+
+        tool_specs = [tool.metadata.to_openai_tool() for tool in tools]
+        tools_by_name = {tool.metadata.name: tool for tool in tools}
+
+        if isinstance(user_msg, str):
+            user_msg = ChatMessage(role=MessageRole.USER, content=user_msg)
+
+        messages = chat_history or []
+        if user_msg:
+            messages.append(user_msg)
+
+        response = self.chat(
+            messages,
+            tools=tool_specs,
+            **kwargs,
+        )
+
+        tool_call = self._get_tool_call(response)
+
+        return self._call_tool(tool_call, tools_by_name, verbose=verbose)
+
     # ===== Async Endpoints =====
     @llm_chat_callback()
     async def achat(
@@ -731,3 +819,55 @@
         )
 
         return gen()
+
+    async def _acall_tool(
+        self,
+        tool_call: OpenAIToolCall,
+        tools_by_name: Dict[str, "BaseTool"],
+        verbose: bool = False,
+    ) -> "AgentChatResponse":
+        from llama_index.core.chat_engine.types import AgentChatResponse
+        from llama_index.core.tools.calling import acall_tool
+
+        arguments_str = tool_call.function.arguments
+        name = tool_call.function.name
+        if verbose:
+            print("=== Calling Function ===")
+            print(f"Calling function: {name} with args: {arguments_str}")
+        tool = tools_by_name[name]
+        argument_dict = json.loads(arguments_str)
+
+        tool_output = await acall_tool(tool, argument_dict)
+
+        return AgentChatResponse(response=tool_output.content, sources=[tool_output])
+
+    async def apredict_and_call(
+        self,
+        tools: List["BaseTool"],
+        user_msg: Optional[Union[str, ChatMessage]] = None,
+        chat_history: Optional[List[ChatMessage]] = None,
+        verbose: bool = False,
+        **kwargs: Any,
+    ) -> "AgentChatResponse":
+        if not self.metadata.is_function_calling_model:
+            return await super().apredict_and_call(tools, user_msg, chat_history, verbose, **kwargs)
+
+        tool_specs = [tool.metadata.to_openai_tool() for tool in tools]
+        tools_by_name = {tool.metadata.name: tool for tool in tools}
+
+        if isinstance(user_msg, str):
+            user_msg = ChatMessage(role=MessageRole.USER, content=user_msg)
+
+        messages = chat_history or []
+        if user_msg:
+            messages.append(user_msg)
+
+        response = await self.achat(
+            messages,
+            tools=tool_specs,
+            **kwargs,
+        )
+
+        tool_call = self._get_tool_call(response)
+
+        return await self._acall_tool(tool_call, tools_by_name, verbose=verbose)
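Usage sketch (not part of the patch): a minimal example of the `predict_and_call` API introduced above, assuming a single function wrapped as a `FunctionTool` and an OpenAI API key in the environment. The `multiply` helper and the model choice are illustrative only; MistralAI with `mistral-large-latest` would follow the same call shape, and models without function calling fall back to the ReAct-based path in `LLM.predict_and_call`.

```python
# Usage sketch (illustrative only, not part of this patch).
from llama_index.core.tools import FunctionTool
from llama_index.llms.openai import OpenAI


def multiply(a: int, b: int) -> int:
    """Multiply two integers and return the result."""
    return a * b


# Wrap a plain Python function as a tool the LLM can call.
multiply_tool = FunctionTool.from_defaults(fn=multiply)

# gpt-3.5-turbo reports is_function_calling_model=True, so this takes the
# native tool-calling path in OpenAI.predict_and_call; non function-calling
# models use the ReActAgentWorker fallback in LLM.predict_and_call.
llm = OpenAI(model="gpt-3.5-turbo")

response = llm.predict_and_call(
    [multiply_tool],
    user_msg="What is 3 times 7?",
    verbose=True,
)
print(response.response)  # final text produced from the tool call
print(response.sources)   # ToolOutput objects recording the call
```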