Unverified commit f438a5c7, authored by Logan and committed by GitHub

llm-level tool calling (#12188)

parent 4636776a
@@ -7,8 +7,10 @@ from typing import (
Optional,
Protocol,
Sequence,
Union,
get_args,
runtime_checkable,
TYPE_CHECKING,
)
from llama_index.core.base.llms.types import (
@@ -55,6 +57,10 @@ import llama_index.core.instrumentation as instrument
dispatcher = instrument.get_dispatcher(__name__)
if TYPE_CHECKING:
from llama_index.core.chat_engine.types import AgentChatResponse
from llama_index.core.tools.types import BaseTool
# NOTE: These two protocols are needed to appease mypy
@runtime_checkable
@@ -165,6 +171,8 @@ class LLM(BaseLLM):
exclude=True,
)
# -- Pydantic Configs --
@validator("messages_to_prompt", pre=True)
def set_messages_to_prompt(
cls, messages_to_prompt: Optional[MessagesToPromptType]
@@ -185,6 +193,8 @@ class LLM(BaseLLM):
values["messages_to_prompt"] = generic_messages_to_prompt
return values
# -- Utils --
def _log_template_data(
self, prompt: BasePromptTemplate, **prompt_args: Any
) -> None:
@@ -223,6 +233,47 @@ class LLM(BaseLLM):
messages = self.output_parser.format_messages(messages)
return self._extend_messages(messages)
def _parse_output(self, output: str) -> str:
if self.output_parser is not None:
return str(self.output_parser.parse(output))
return output
def _extend_prompt(
self,
formatted_prompt: str,
) -> str:
"""Add system and query wrapper prompts to base prompt."""
extended_prompt = formatted_prompt
if self.system_prompt:
extended_prompt = self.system_prompt + "\n\n" + extended_prompt
if self.query_wrapper_prompt:
extended_prompt = self.query_wrapper_prompt.format(
query_str=extended_prompt
)
return extended_prompt
def _extend_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]:
"""Add system prompt to chat message list."""
if self.system_prompt:
messages = [
ChatMessage(role=MessageRole.SYSTEM, content=self.system_prompt),
*messages,
]
return messages
def _as_query_component(self, **kwargs: Any) -> QueryComponent:
"""Return query component."""
if self.metadata.is_chat_model:
return LLMChatComponent(llm=self, **kwargs)
else:
return LLMCompleteComponent(llm=self, **kwargs)
# -- Structured outputs --
def structured_predict(
self,
output_cls: BaseModel,
@@ -313,11 +364,7 @@ class LLM(BaseLLM):
return await program.acall(**prompt_args)
def _parse_output(self, output: str) -> str:
if self.output_parser is not None:
return str(self.output_parser.parse(output))
return output
# -- Prompt Chaining --
@dispatcher.span
def predict(
@@ -485,38 +532,99 @@ class LLM(BaseLLM):
return stream_tokens
def _extend_prompt(
# -- Tool Calling --
def predict_and_call(
self,
formatted_prompt: str,
) -> str:
"""Add system and query wrapper prompts to base prompt."""
extended_prompt = formatted_prompt
tools: List["BaseTool"],
user_msg: Optional[Union[str, ChatMessage]] = None,
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
**kwargs: Any,
) -> "AgentChatResponse":
"""Predict and call the tool."""
from llama_index.core.agent.react import ReActAgentWorker
from llama_index.core.agent.types import Task
from llama_index.core.chat_engine.types import AgentChatResponse
from llama_index.core.memory import ChatMemoryBuffer
worker = ReActAgentWorker(
tools,
llm=self,
callback_manager=self.callback_manager,
verbose=verbose,
**kwargs,
)
if self.system_prompt:
extended_prompt = self.system_prompt + "\n\n" + extended_prompt
if isinstance(user_msg, ChatMessage):
user_msg = user_msg.content
if self.query_wrapper_prompt:
extended_prompt = self.query_wrapper_prompt.format(
query_str=extended_prompt
task = Task(
input=user_msg,
memory=ChatMemoryBuffer.from_defaults(chat_history=chat_history),
extra_state={},
callback_manager=self.callback_manager,
)
step = worker.initialize_step(task)
try:
output = worker.run_step(step, task).output
# react agent worker inserts an "Observation: " prefix into the response
if output.response and output.response.startswith("Observation: "):
output.response = output.response.replace("Observation: ", "")
except Exception as e:
output = AgentChatResponse(
response="An error occurred while running the tool: " + str(e),
sources=[],
)
return extended_prompt
return output
def _extend_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]:
"""Add system prompt to chat message list."""
if self.system_prompt:
messages = [
ChatMessage(role=MessageRole.SYSTEM, content=self.system_prompt),
*messages,
]
return messages
async def apredict_and_call(
self,
tools: List["BaseTool"],
user_msg: Optional[Union[str, ChatMessage]] = None,
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
**kwargs: Any,
) -> "AgentChatResponse":
"""Predict and call the tool."""
from llama_index.core.agent.react import ReActAgentWorker
from llama_index.core.agent.types import Task
from llama_index.core.chat_engine.types import AgentChatResponse
from llama_index.core.memory import ChatMemoryBuffer
worker = ReActAgentWorker(
tools,
llm=self,
callback_manager=self.callback_manager,
verbose=verbose,
**kwargs,
)
def _as_query_component(self, **kwargs: Any) -> QueryComponent:
"""Return query component."""
if self.metadata.is_chat_model:
return LLMChatComponent(llm=self, **kwargs)
else:
return LLMCompleteComponent(llm=self, **kwargs)
if isinstance(user_msg, ChatMessage):
user_msg = user_msg.content
task = Task(
input=user_msg,
memory=ChatMemoryBuffer.from_defaults(chat_history=chat_history),
extra_state={},
callback_manager=self.callback_manager,
)
step = worker.initialize_step(task)
try:
output = (await worker.arun_step(step, task)).output
# react agent worker inserts an "Observation: " prefix into the response
if output.response and output.response.startswith("Observation: "):
output.response = output.response.replace("Observation: ", "")
except Exception as e:
output = AgentChatResponse(
response="An error occurred while running the tool: " + str(e),
sources=[],
)
return output
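The two methods above give every LLM a generic tool-calling entry point: models without native function calling fall back to a single ReActAgentWorker step, and any worker error is caught and returned as the response text. A minimal usage sketch (the multiply tool and the MockLLM stand-in are illustrative, not part of this change):

from llama_index.core.llms import MockLLM
from llama_index.core.tools import FunctionTool


def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


tool = FunctionTool.from_defaults(fn=multiply)

# Any concrete LLM works here; MockLLM is only a stand-in so the snippet runs
# without credentials. Non-function-calling models route through the ReAct
# fallback defined above; with the mock model you will see the captured error
# text rather than a real tool result.
llm = MockLLM()
response = llm.predict_and_call([tool], user_msg="What is 3 times 4?", verbose=True)
print(response.response)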
class BaseLLMComponent(QueryComponent):
......
from llama_index.core.tools.types import BaseTool, ToolOutput, adapt_to_async_tool
def call_tool(tool: BaseTool, arguments: dict) -> ToolOutput:
"""Call a tool with arguments."""
try:
if (
len(tool.metadata.get_parameters_dict()["properties"]) == 1
and len(arguments) == 1
):
single_arg = arguments[next(iter(arguments))]
return tool(single_arg)
else:
return tool(**arguments)
except Exception as e:
return ToolOutput(
content="Encountered error: " + str(e),
tool_name=tool.metadata.name,
raw_input=arguments,
raw_output=str(e),
)
async def acall_tool(tool: BaseTool, arguments: dict) -> ToolOutput:
"""Call a tool with arguments asynchronously."""
async_tool = adapt_to_async_tool(tool)
try:
if (
len(tool.metadata.get_parameters_dict()["properties"]) == 1
and len(arguments) == 1
):
single_arg = arguments[next(iter(arguments))]
return await async_tool.acall(single_arg)
else:
return await async_tool.acall(**arguments)
except Exception as e:
return ToolOutput(
content="Encountered error: " + str(e),
tool_name=tool.metadata.name,
raw_input=arguments,
raw_output=str(e),
)
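call_tool and acall_tool are the shared execution helpers used by the function-calling backends later in this diff. A self-contained sketch of the sync path (the echo tool is made up for illustration; the import path matches the llama_index.core.tools.calling module referenced below):

from llama_index.core.tools import FunctionTool
from llama_index.core.tools.calling import call_tool


def echo(text: str) -> str:
    """Return the input text unchanged."""
    return text


tool = FunctionTool.from_defaults(fn=echo)

# A schema with a single property and a single supplied argument is unpacked
# positionally; everything else is passed as keyword arguments. Exceptions are
# converted into a ToolOutput whose content starts with "Encountered error: ".
output = call_tool(tool, {"text": "hello"})
print(output.tool_name, output.content)  # expected: echo hello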
from typing import Any, Callable, Dict, Optional, Sequence
import json
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, TYPE_CHECKING
# from mistralai.models.chat_completion import ChatMessage
from llama_index.core.base.llms.types import (
ChatMessage,
ChatResponse,
@@ -29,11 +29,17 @@ from llama_index.core.base.llms.generic_utils import (
from llama_index.core.llms.llm import LLM
from llama_index.core.types import BaseOutputParser, PydanticProgramMode
from llama_index.llms.mistralai.utils import (
is_mistralai_function_calling_model,
mistralai_modelname_to_contextsize,
)
from mistralai.async_client import MistralAsyncClient
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ToolCall
if TYPE_CHECKING:
from llama_index.core.tools.types import BaseTool
from llama_index.core.chat_engine.types import AgentChatResponse
DEFAULT_MISTRALAI_MODEL = "mistral-tiny"
DEFAULT_MISTRALAI_ENDPOINT = "https://api.mistral.ai"
@@ -168,6 +174,7 @@ class MistralAI(LLM):
model_name=self.model,
safe_mode=self.safe_mode,
random_seed=self.random_seed,
is_function_calling_model=is_mistralai_function_calling_model(self.model),
)
@property
@@ -200,9 +207,16 @@ class MistralAI(LLM):
]
all_kwargs = self._get_all_kwargs(**kwargs)
response = self._client.chat(messages=messages, **all_kwargs)
tool_calls = response.choices[0].message.tool_calls
return ChatResponse(
message=ChatMessage(
role=MessageRole.ASSISTANT, content=response.choices[0].message.content
role=MessageRole.ASSISTANT,
content=response.choices[0].message.content,
additional_kwargs={"tool_calls": tool_calls}
if tool_calls is not None
else {},
),
raw=dict(response),
)
@@ -252,6 +266,86 @@ class MistralAI(LLM):
stream_complete_fn = stream_chat_to_completion_decorator(self.stream_chat)
return stream_complete_fn(prompt, **kwargs)
def _get_tool_call(
self,
response: ChatResponse,
) -> ToolCall:
tool_calls = response.message.additional_kwargs.get("tool_calls", [])
if len(tool_calls) < 1:
raise ValueError(
f"Expected at least one tool call, but got {len(tool_calls)} tool calls."
)
# TODO: support more than one tool call?
tool_call = tool_calls[0]
if not isinstance(tool_call, ToolCall):
raise ValueError("Invalid tool_call object")
if tool_call.type != "function":
raise ValueError("Invalid tool type. Unsupported by Mistralai.")
return tool_call
def _call_tool(
self,
tool_call: ToolCall,
tools_by_name: Dict[str, "BaseTool"],
verbose: bool = False,
) -> "AgentChatResponse":
from llama_index.core.chat_engine.types import AgentChatResponse
from llama_index.core.tools.calling import call_tool
arguments_str = tool_call.function.arguments
name = tool_call.function.name
if verbose:
print("=== Calling Function ===")
print(f"Calling function: {name} with args: {arguments_str}")
tool = tools_by_name[name]
argument_dict = json.loads(arguments_str)
tool_output = call_tool(tool, argument_dict)
return AgentChatResponse(response=tool_output.content, sources=[tool_output])
def predict_and_call(
self,
tools: List["BaseTool"],
user_msg: Optional[Union[str, ChatMessage]] = None,
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
**kwargs: Any,
) -> "AgentChatResponse":
if not self.metadata.is_function_calling_model:
return super().predict_and_call(
tools,
user_msg=user_msg,
chat_history=chat_history,
verbose=verbose,
**kwargs,
)
# MistralAI uses the same tool format as OpenAI
tool_specs = [tool.metadata.to_openai_tool() for tool in tools]
tools_by_name = {tool.metadata.name: tool for tool in tools}
if isinstance(user_msg, str):
user_msg = ChatMessage(role=MessageRole.USER, content=user_msg)
messages = chat_history or []
if user_msg:
messages.append(user_msg)
response = self.chat(
messages,
tools=tool_specs,
**kwargs,
)
tool_call = self._get_tool_call(response)
return self._call_tool(tool_call, tools_by_name, verbose=verbose)
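With a function-calling model, the MistralAI override above sends the OpenAI-style tool specs on a single chat call and executes the first returned tool call. A hedged usage sketch (assumes MISTRAL_API_KEY is set; get_weather is illustrative, and mistral-large-latest is the only entry in MISTRALAI_FUNCTION_CALLING_MODELS):

from llama_index.core.tools import FunctionTool
from llama_index.llms.mistralai import MistralAI


def get_weather(city: str) -> str:
    """Return a canned weather report for a city."""
    return f"It is sunny in {city}."


tool = FunctionTool.from_defaults(fn=get_weather)
llm = MistralAI(model="mistral-large-latest")  # requires MISTRAL_API_KEY

# Returns an AgentChatResponse; the executed ToolOutput is kept in response.sources.
response = llm.predict_and_call([tool], user_msg="What is the weather in Paris?", verbose=True)
print(response.response)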
@llm_chat_callback()
async def achat(
self, messages: Sequence[ChatMessage], **kwargs: Any
@@ -315,3 +409,56 @@ class MistralAI(LLM):
) -> CompletionResponseAsyncGen:
astream_complete_fn = astream_chat_to_completion_decorator(self.astream_chat)
return await astream_complete_fn(prompt, **kwargs)
async def _acall_tool(
self,
tool_call: ToolCall,
tools_by_name: Dict[str, "BaseTool"],
verbose: bool = False,
) -> "AgentChatResponse":
from llama_index.core.chat_engine.types import AgentChatResponse
from llama_index.core.tools.calling import acall_tool
arguments_str = tool_call.function.arguments
name = tool_call.function.name
if verbose:
print("=== Calling Function ===")
print(f"Calling function: {name} with args: {arguments_str}")
tool = tools_by_name[name]
argument_dict = json.loads(arguments_str)
tool_output = await acall_tool(tool, argument_dict)
return AgentChatResponse(response=tool_output.content, sources=[tool_output])
async def apredict_and_call(
self,
tools: List["BaseTool"],
user_msg: Optional[Union[str, ChatMessage]] = None,
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
**kwargs: Any,
) -> "AgentChatResponse":
if not self.metadata.is_function_calling_model:
return await super().apredict_and_call(
tools,
user_msg=user_msg,
chat_history=chat_history,
verbose=verbose,
**kwargs,
)
# MistralAI uses the same tool format as OpenAI
tool_specs = [tool.metadata.to_openai_tool() for tool in tools]
tools_by_name = {tool.metadata.name: tool for tool in tools}
if isinstance(user_msg, str):
user_msg = ChatMessage(role=MessageRole.USER, content=user_msg)
messages = chat_history or []
if user_msg:
messages.append(user_msg)
response = await self.achat(
messages,
tools=tool_specs,
**kwargs,
)
tool_call = self._get_tool_call(response)
return await self._acall_tool(tool_call, tools_by_name, verbose=verbose)
@@ -12,6 +12,8 @@ MISTRALAI_MODELS: Dict[str, int] = {
"mistral-large-latest": 32000,
}
MISTRALAI_FUNCTION_CALLING_MODELS = ("mistral-large-latest",)
def mistralai_modelname_to_contextsize(modelname: str) -> int:
if modelname not in MISTRALAI_MODELS:
@@ -21,3 +23,7 @@ def mistralai_modelname_to_contextsize(modelname: str) -> int:
)
return MISTRALAI_MODELS[modelname]
def is_mistralai_function_calling_model(modelname: str) -> bool:
return modelname in MISTRALAI_FUNCTION_CALLING_MODELS
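A quick sanity check of the new helper against the model table above:

from llama_index.llms.mistralai.utils import is_mistralai_function_calling_model

assert is_mistralai_function_calling_model("mistral-large-latest")
assert not is_mistralai_function_calling_model("mistral-tiny")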
@@ -7,10 +7,14 @@ from typing import (
Optional,
Protocol,
Sequence,
Union,
cast,
get_args,
runtime_checkable,
TYPE_CHECKING,
)
import json
import httpx
import tiktoken
from llama_index.core.base.llms.types import (
@@ -46,6 +50,7 @@ from llama_index.core.base.llms.generic_utils import (
from llama_index.core.llms.llm import LLM
from llama_index.core.types import BaseOutputParser, PydanticProgramMode
from llama_index.llms.openai.utils import (
OpenAIToolCall,
create_retry_decorator,
from_openai_message,
from_openai_token_logprobs,
@@ -65,6 +70,10 @@ from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCall,
)
if TYPE_CHECKING:
from llama_index.core.chat_engine.types import AgentChatResponse
from llama_index.core.tools.types import BaseTool
DEFAULT_OPENAI_MODEL = "gpt-3.5-turbo"
llm_retry_decorator = create_retry_decorator(
@@ -544,6 +553,85 @@ class OpenAI(LLM):
"total_tokens": usage.get("total_tokens", 0),
}
def _get_tool_call(
self,
response: ChatResponse,
) -> OpenAIToolCall:
tool_calls = response.message.additional_kwargs.get("tool_calls", [])
if len(tool_calls) < 1:
raise ValueError(
f"Expected at least one tool call, but got {len(tool_calls)} tool calls."
)
# TODO: support more than one tool call?
tool_call = tool_calls[0]
if not isinstance(tool_call, get_args(OpenAIToolCall)):
raise ValueError("Invalid tool_call object")
if tool_call.type != "function":
raise ValueError("Invalid tool type. Unsupported by OpenAI")
return tool_call
def _call_tool(
self,
tool_call: OpenAIToolCall,
tools_by_name: Dict[str, "BaseTool"],
verbose: bool = False,
) -> "AgentChatResponse":
from llama_index.core.chat_engine.types import AgentChatResponse
from llama_index.core.tools.calling import call_tool
arguments_str = tool_call.function.arguments
name = tool_call.function.name
if verbose:
print("=== Calling Function ===")
print(f"Calling function: {name} with args: {arguments_str}")
tool = tools_by_name[name]
argument_dict = json.loads(arguments_str)
tool_output = call_tool(tool, argument_dict)
return AgentChatResponse(response=tool_output.content, sources=[tool_output])
def predict_and_call(
self,
tools: List["BaseTool"],
user_msg: Optional[Union[str, ChatMessage]] = None,
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
**kwargs: Any,
) -> "AgentChatResponse":
if not self.metadata.is_function_calling_model:
return super().predict_and_call(
tools,
user_msg=user_msg,
chat_history=chat_history,
verbose=verbose,
**kwargs,
)
tool_specs = [tool.metadata.to_openai_tool() for tool in tools]
tools_by_name = {tool.metadata.name: tool for tool in tools}
if isinstance(user_msg, str):
user_msg = ChatMessage(role=MessageRole.USER, content=user_msg)
messages = chat_history or []
if user_msg:
messages.append(user_msg)
response = self.chat(
messages,
tools=tool_specs,
**kwargs,
)
tool_call = self._get_tool_call(response)
return self._call_tool(tool_call, tools_by_name, verbose=verbose)
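Usage mirrors the MistralAI backend: tool specs go out via the tools kwarg on chat, and the first tool call found in additional_kwargs is executed through call_tool. A sketch assuming OPENAI_API_KEY is set (the add tool is illustrative):

from llama_index.core.tools import FunctionTool
from llama_index.llms.openai import OpenAI


def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


tool = FunctionTool.from_defaults(fn=add)
llm = OpenAI(model="gpt-3.5-turbo")  # requires OPENAI_API_KEY

response = llm.predict_and_call([tool], user_msg="What is 40 + 2?")
# response.response carries the tool output text; the ToolOutput object itself
# is available in response.sources.
print(response.response)  # expected: 42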
# ===== Async Endpoints =====
@llm_chat_callback()
async def achat(
@@ -731,3 +819,55 @@ class OpenAI(LLM):
)
return gen()
async def _acall_tool(
self,
tool_call: OpenAIToolCall,
tools_by_name: Dict[str, "BaseTool"],
verbose: bool = False,
) -> "AgentChatResponse":
from llama_index.core.chat_engine.types import AgentChatResponse
from llama_index.core.tools.calling import acall_tool
arguments_str = tool_call.function.arguments
name = tool_call.function.name
if verbose:
print("=== Calling Function ===")
print(f"Calling function: {name} with args: {arguments_str}")
tool = tools_by_name[name]
argument_dict = json.loads(arguments_str)
tool_output = await acall_tool(tool, argument_dict)
return AgentChatResponse(response=tool_output.content, sources=[tool_output])
async def apredict_and_call(
self,
tools: List["BaseTool"],
user_msg: Optional[Union[str, ChatMessage]] = None,
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
**kwargs: Any,
) -> "AgentChatResponse":
if not self.metadata.is_function_calling_model:
return await super().apredict_and_call(
tools,
user_msg=user_msg,
chat_history=chat_history,
verbose=verbose,
**kwargs,
)
tool_specs = [tool.metadata.to_openai_tool() for tool in tools]
tools_by_name = {tool.metadata.name: tool for tool in tools}
if isinstance(user_msg, str):
user_msg = ChatMessage(role=MessageRole.USER, content=user_msg)
messages = chat_history or []
if user_msg:
messages.append(user_msg)
response = await self.achat(
messages,
tools=tool_specs,
**kwargs,
)
tool_call = self._get_tool_call(response)
return await self._acall_tool(tool_call, tools_by_name, verbose=verbose)
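The async variant follows the same flow through achat and acall_tool. A hedged sketch (assumes OPENAI_API_KEY is set; lookup_capital is illustrative):

import asyncio

from llama_index.core.tools import FunctionTool
from llama_index.llms.openai import OpenAI


def lookup_capital(country: str) -> str:
    """Return the capital for a hard-coded set of countries."""
    return {"France": "Paris"}.get(country, "unknown")


async def main() -> None:
    tool = FunctionTool.from_defaults(fn=lookup_capital)
    llm = OpenAI(model="gpt-3.5-turbo")  # requires OPENAI_API_KEY
    response = await llm.apredict_and_call([tool], user_msg="What is the capital of France?")
    print(response.response)


asyncio.run(main())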