diff --git a/docs/docs/examples/llm/huggingface.ipynb b/docs/docs/examples/llm/huggingface.ipynb index 51d925e0c1f04e37c2e2be708a5cd6ab1274d528..86251178825280ef546ca254342c15abf49eea6f 100644 --- a/docs/docs/examples/llm/huggingface.ipynb +++ b/docs/docs/examples/llm/huggingface.ipynb @@ -210,6 +210,167 @@ "\n", "Both of the above two subclass `llama_index.embeddings.base.BaseEmbedding`." ] + }, + { + "cell_type": "markdown", + "id": "92c09b9f", + "metadata": {}, + "source": [ + "### Using Hugging Face `text-generation-inference`" + ] + }, + { + "cell_type": "markdown", + "id": "752520ec", + "metadata": {}, + "source": [ + "The new `TextGenerationInference` class lets you interface with endpoints running [`text-generation-inference` (TGI)](https://huggingface.co/docs/text-generation-inference/index). In addition to blazingly fast inference, it supports `tool` usage starting from version `2.0.1`." + ] + }, + { + "cell_type": "markdown", + "id": "055ddcb1", + "metadata": {}, + "source": [ + "To initialize an instance of `TextGenerationInference`, you need to provide the endpoint URL (either a self-hosted TGI instance or a public Inference Endpoint on Hugging Face created with TGI). For a private Inference Endpoint, you must also provide your HF token (either as an initialization argument or as an environment variable)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c02f350f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " beyond! This phrase is a reference to the famous line from the movie \"Toy Story\" when Buzz Lightyear, a toy astronaut, exclaims \"To infinity and beyond!\" as he soars through space. It has since become a catchphrase for reaching for the stars and striving for greatness. However, if you meant to ask a mathematical question, \"To infinity\" refers to a very large, infinite number, and \"and beyond\" could be interpreted as continuing infinitely in a certain direction. 
For example, \"2 to the power of infinity\" would represent a very large, infinite number.\n" + ] + } + ], + "source": [ + "import os\n", + "from typing import List, Optional\n", + "\n", + "from llama_index.llms.huggingface import (\n", + " TextGenerationInference,\n", + ")\n", + "\n", + "URL = \"your_tgi_endpoint\"\n", + "model = TextGenerationInference(\n", + " model_url=URL, token=False\n", + ") # set token to False in case of public endpoint\n", + "\n", + "completion_response = model.complete(\"To infinity, and\")\n", + "print(completion_response)" + ] + }, + { + "cell_type": "markdown", + "id": "e9270b99", + "metadata": {}, + "source": [ + "To use tools with the `TextGenerationInference`, you may use an already existing tool or define your own:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "90a041cc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'tool_calls': [{'id': 0, 'type': 'function', 'function': {'description': None, 'name': 'get_current_weather_n_days', 'arguments': {'format': 'celsius', 'location': 'Paris, Ile-de-France', 'num_days': 7}}}]}\n" + ] + } + ], + "source": [ + "from typing import List, Literal\n", + "from llama_index.core.bridge.pydantic import BaseModel, Field\n", + "from llama_index.core.tools import FunctionTool\n", + "from llama_index.core.base.llms.types import (\n", + " ChatMessage,\n", + " MessageRole,\n", + ")\n", + "\n", + "\n", + "def get_current_weather(location: str, format: str):\n", + " \"\"\"Get the current weather\n", + "\n", + " Args:\n", + " location (str): The city and state, e.g. San Francisco, CA\n", + " format (str): The temperature unit to use ('celsius' or 'fahrenheit'). Infer this from the users location.\n", + " \"\"\"\n", + " ...\n", + "\n", + "\n", + "class WeatherArgs(BaseModel):\n", + " location: str = Field(\n", + " description=\"The city and region, e.g. Paris, Ile-de-France\"\n", + " )\n", + " format: Literal[\"fahrenheit\", \"celsius\"] = Field(\n", + " description=\"The temperature unit to use ('fahrenheit' or 'celsius'). Infer this from the location.\",\n", + " )\n", + "\n", + "\n", + "weather_tool = FunctionTool.from_defaults(\n", + " fn=get_current_weather,\n", + " name=\"get_current_weather\",\n", + " description=\"Get the current weather\",\n", + " fn_schema=WeatherArgs,\n", + ")\n", + "\n", + "\n", + "def get_current_weather_n_days(location: str, format: str, num_days: int):\n", + " \"\"\"Get the weather forecast for the next N days\n", + "\n", + " Args:\n", + " location (str): The city and state, e.g. San Francisco, CA\n", + " format (str): The temperature unit to use ('celsius' or 'fahrenheit'). Infer this from the users location.\n", + " num_days (int): The number of days for the weather forecast.\n", + " \"\"\"\n", + " ...\n", + "\n", + "\n", + "class ForecastArgs(BaseModel):\n", + " location: str = Field(\n", + " description=\"The city and region, e.g. Paris, Ile-de-France\"\n", + " )\n", + " format: Literal[\"fahrenheit\", \"celsius\"] = Field(\n", + " description=\"The temperature unit to use ('fahrenheit' or 'celsius'). 
Infer this from the location.\",\n", + " )\n", + " num_days: int = Field(\n", + " description=\"The duration for the weather forecast in days.\",\n", + " )\n", + "\n", + "\n", + "forecast_tool = FunctionTool.from_defaults(\n", + " fn=get_current_weather_n_days,\n", + " name=\"get_current_weather_n_days\",\n", + " description=\"Get the current weather for n days\",\n", + " fn_schema=ForecastArgs,\n", + ")\n", + "\n", + "usr_msg = ChatMessage(\n", + " role=MessageRole.USER,\n", + " content=\"What's the weather like in Paris over next week?\",\n", + ")\n", + "\n", + "response = model.chat_with_tools(\n", + " user_msg=usr_msg,\n", + " tools=[\n", + " weather_tool,\n", + " forecast_tool,\n", + " ],\n", + " tool_choice=\"get_current_weather_n_days\",\n", + ")\n", + "\n", + "print(response.message.additional_kwargs)" + ] } ], "metadata": { diff --git a/llama-index-integrations/llms/llama-index-llms-huggingface/llama_index/llms/huggingface/__init__.py b/llama-index-integrations/llms/llama-index-llms-huggingface/llama_index/llms/huggingface/__init__.py index fa15d472395c08385cdcc3fc451a2e1e6caa5a96..c22da6bbf1b6ad51a23625c5d35c5bf597091bc6 100644 --- a/llama-index-integrations/llms/llama-index-llms-huggingface/llama_index/llms/huggingface/__init__.py +++ b/llama-index-integrations/llms/llama-index-llms-huggingface/llama_index/llms/huggingface/__init__.py @@ -1,3 +1,7 @@ -from llama_index.llms.huggingface.base import HuggingFaceInferenceAPI, HuggingFaceLLM +from llama_index.llms.huggingface.base import ( + HuggingFaceInferenceAPI, + HuggingFaceLLM, + TextGenerationInference, +) -__all__ = ["HuggingFaceLLM", "HuggingFaceInferenceAPI"] +__all__ = ["HuggingFaceLLM", "HuggingFaceInferenceAPI", "TextGenerationInference"] diff --git a/llama-index-integrations/llms/llama-index-llms-huggingface/llama_index/llms/huggingface/base.py b/llama-index-integrations/llms/llama-index-llms-huggingface/llama_index/llms/huggingface/base.py index 11646b9205d3d81365c62be6ae099ec0d75fdee0..6cd806297a24f61ff0200090426c36a73f69ff73 100644 --- a/llama-index-integrations/llms/llama-index-llms-huggingface/llama_index/llms/huggingface/base.py +++ b/llama-index-integrations/llms/llama-index-llms-huggingface/llama_index/llms/huggingface/base.py @@ -20,6 +20,7 @@ from llama_index.core.base.llms.types import ( from llama_index.core.bridge.pydantic import Field, PrivateAttr from llama_index.core.callbacks import CallbackManager from llama_index.core.constants import ( + DEFAULT_TEMPERATURE, DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS, ) @@ -27,22 +28,40 @@ from llama_index.core.llms.callbacks import ( llm_chat_callback, llm_completion_callback, ) +from llama_index.core.llms.llm import ToolSelection from llama_index.core.llms.custom import CustomLLM +from llama_index.core.llms.function_calling import FunctionCallingLLM from llama_index.core.base.llms.generic_utils import ( completion_response_to_chat_response, stream_completion_response_to_chat_response, -) -from llama_index.core.base.llms.generic_utils import ( messages_to_prompt as generic_messages_to_prompt, + chat_to_completion_decorator, + achat_to_completion_decorator, + stream_chat_to_completion_decorator, + astream_chat_to_completion_decorator, + get_from_param_or_env, ) from llama_index.core.prompts.base import PromptTemplate from llama_index.core.types import BaseOutputParser, PydanticProgramMode +from llama_index.core.chat_engine.types import AgentChatResponse +from llama_index.core.tools.types import BaseTool +from llama_index.llms.huggingface.utils import ( + 
to_tgi_messages, + force_single_tool_call, + resolve_tgi_function_call, + get_max_input_length, + resolve_tool_choice, +) from transformers import ( AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, ) +from text_generation import ( + Client as TGIClient, + AsyncClient as TGIAsyncClient, +) DEFAULT_HUGGINGFACE_MODEL = "StabilityAI/stablelm-tuned-alpha-3b" @@ -665,3 +684,394 @@ class HuggingFaceInferenceAPI(CustomLLM): self, prompt: str, formatted: bool = False, **kwargs: Any ) -> CompletionResponseAsyncGen: raise NotImplementedError + + +class TextGenerationInference(FunctionCallingLLM): + model_name: Optional[str] = Field( + default=None, + description=("The name of the model served at the TGI endpoint"), + ) + temperature: float = Field( + default=DEFAULT_TEMPERATURE, + description=("The temperature to use for sampling."), + ge=0.0, + le=1.0, + ) + max_tokens: int = Field( + default=DEFAULT_NUM_OUTPUTS, + description=("The maximum number of tokens to generate."), + gt=0, + ) + token: Union[str, bool, None] = Field( + default=None, + description=( + "Hugging Face token. Will default to the locally saved token. Pass " + "token=False if you don’t want to send your token to the server." + ), + ) + timeout: float = Field( + default=120, description=("The timeout to use in seconds."), ge=0 + ) + max_retries: int = Field( + default=5, description=("The maximum number of API retries."), ge=0 + ) + headers: Optional[Dict[str, str]] = Field( + default=None, + description=( + "Additional headers to send to the server. By default only the" + " authorization headers are sent. Values in this dictionary" + " will override the default values." + ), + ) + cookies: Optional[Dict[str, str]] = Field( + default=None, description=("Additional cookies to send to the server.") + ) + seed: Optional[int] = Field( + default=None, description=("The random seed to use for sampling.") + ) + additional_kwargs: Dict[str, Any] = Field( + default_factory=dict, description=("Additional kwargs for the TGI API.") + ) + + _sync_client: "TGIClient" = PrivateAttr() + _async_client: "TGIAsyncClient" = PrivateAttr() + + context_window: int = Field( + default=DEFAULT_CONTEXT_WINDOW, + description=("Maximum input length in tokens returned from TGI endpoint"), + ) + is_chat_model: bool = Field( + default=True, + description=( + LLMMetadata.__fields__["is_chat_model"].field_info.description + + " TGI makes use of chat templating," + " function call is available only for '/v1/chat/completions' route" + " of TGI endpoint" + ), + ) + is_function_calling_model: bool = Field( + default=False, + description=( + LLMMetadata.__fields__["is_function_calling_model"].field_info.description + + " 'text-generation-inference' supports function call" + " starting from v1.4.3" + ), + ) + + def __init__( + self, + model_url: str, + model_name: Optional[str] = None, + cookies: Optional[dict] = None, + temperature: float = DEFAULT_TEMPERATURE, + max_tokens: int = DEFAULT_NUM_OUTPUTS, + timeout: int = 120, + max_retries: int = 5, + seed: Optional[int] = None, + token: Union[str, bool, None] = None, + additional_kwargs: Optional[Dict[str, Any]] = None, + callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, + ) -> None: +
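# Resolve the HF token (argument or HF_TOKEN env var), build auth headers and the sync/async TGI clients, then probe the endpoint for tool-call support and max input length before delegating to super().__init__(). +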
additional_kwargs = additional_kwargs or {} + callback_manager = callback_manager or CallbackManager([]) + + token = get_from_param_or_env("token", token, "HF_TOKEN", "") + + headers = {} + if token: + headers.update({"Authorization": f"Bearer {token}"}) + + self._sync_client = TGIClient( + base_url=model_url, + headers=headers, + cookies=cookies, + timeout=timeout, + ) + self._async_client = TGIAsyncClient( + base_url=model_url, + headers=headers, + cookies=cookies, + timeout=timeout, + ) + + try: + is_function_calling_model = resolve_tgi_function_call(model_url) + except Exception as e: + logger.warning(f"TGI client has no function call support: {e}") + is_function_calling_model = False + + context_window = get_max_input_length(model_url) or DEFAULT_CONTEXT_WINDOW + + super().__init__( + context_window=context_window, + temperature=temperature, + max_tokens=max_tokens, + additional_kwargs=additional_kwargs, + timeout=timeout, + max_retries=max_retries, + seed=seed, + model=model_name, + is_function_calling_model=is_function_calling_model, + callback_manager=callback_manager, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, + ) + + @classmethod + def class_name(cls) -> str: + return "TextGenerationInference" + + @property + def metadata(self) -> LLMMetadata: + return LLMMetadata( + context_window=self.context_window, + num_output=self.max_tokens, + is_chat_model=True, + model_name=self.model_name, + random_seed=self.seed, + is_function_calling_model=self.is_function_calling_model, + ) + + @property + def _model_kwargs(self) -> Dict[str, Any]: + base_kwargs = { + "temperature": self.temperature, + "max_tokens": self.max_tokens, + "seed": self.seed, + } + return { + **base_kwargs, + **self.additional_kwargs, + } + + def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]: + return { + **self._model_kwargs, + **kwargs, + } + + @llm_chat_callback() + def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: + # convert to TGI Message + messages = to_tgi_messages(messages) + all_kwargs = self._get_all_kwargs(**kwargs) + response = self._sync_client.chat(messages=messages, **all_kwargs) + tool_calls = response.choices[0].message.tool_calls + + return ChatResponse( + message=ChatMessage( + role=MessageRole.ASSISTANT, + content=response.choices[0].message.content, + additional_kwargs={"tool_calls": tool_calls} + if tool_calls is not None + else {}, + ), + raw=dict(response), + ) + + @llm_completion_callback() + def complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponse: + complete_fn = chat_to_completion_decorator(self.chat) + return complete_fn(prompt, **kwargs) + + @llm_chat_callback() + def stream_chat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponseGen: + # convert to TGI Message + messages = to_tgi_messages(messages) + all_kwargs = self._get_all_kwargs(**kwargs) + response = self._sync_client.chat(messages=messages, stream=True, **all_kwargs) + + def generator() -> ChatResponseGen: + content = "" + role = MessageRole.ASSISTANT + for chunk in response: + content_delta = chunk.choices[0].delta.content + if content_delta is None: + continue + content += content_delta + yield ChatResponse( + message=ChatMessage(role=role, content=content), + delta=content_delta, + raw=chunk, + ) + + return generator() + + @llm_completion_callback() + def stream_complete( + 
self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponseGen: + stream_complete_fn = stream_chat_to_completion_decorator(self.stream_chat) + return stream_complete_fn(prompt, **kwargs) + + @llm_chat_callback() + async def achat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponse: + # convert to TGI Message + messages = to_tgi_messages(messages) + all_kwargs = self._get_all_kwargs(**kwargs) + response = await self._async_client.chat(messages=messages, **all_kwargs) + tool_calls = response.choices[0].message.tool_calls + + return ChatResponse( + message=ChatMessage( + role=MessageRole.ASSISTANT, + content=response.choices[0].message.content, + additional_kwargs={"tool_calls": tool_calls} + if tool_calls is not None + else {}, + ), + raw=dict(response), + ) + + @llm_completion_callback() + async def acomplete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponse: + acomplete_fn = achat_to_completion_decorator(self.achat) + return await acomplete_fn(prompt, **kwargs) + + @llm_chat_callback() + async def astream_chat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponseAsyncGen: + # convert to TGI Message + messages = to_tgi_messages(messages) + all_kwargs = self._get_all_kwargs(**kwargs) + response = await self._async_client.chat( + messages=messages, stream=True, **all_kwargs + ) + + async def generator() -> ChatResponseAsyncGen: + content = "" + role = MessageRole.ASSISTANT + async for chunk in response: + content_delta = chunk.choices[0].delta.content + if content_delta is None: + continue + content += content_delta + yield ChatResponse( + message=ChatMessage(role=role, content=content), + delta=content_delta, + raw=chunk, + ) + + return generator() + + @llm_completion_callback() + async def astream_complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponseAsyncGen: + astream_complete_fn = astream_chat_to_completion_decorator(self.astream_chat) + return await astream_complete_fn(prompt, **kwargs) + + def chat_with_tools( + self, + tools: List["BaseTool"], + user_msg: Optional[Union[str, ChatMessage]] = None, + chat_history: Optional[List[ChatMessage]] = None, + verbose: bool = False, + allow_parallel_tool_calls: bool = False, + tool_choice: str = "auto", + **kwargs: Any, + ) -> ChatResponse: + """Predict and call the tool.""" + # use openai tool format + tool_specs = [tool.metadata.to_openai_tool() for tool in tools] + + if isinstance(user_msg, str): + user_msg = ChatMessage(role=MessageRole.USER, content=user_msg) + + messages = chat_history or [] + if user_msg: + messages.append(user_msg) + + response = self.chat( + messages=messages, + tools=tool_specs, + tool_choice=resolve_tool_choice(tool_specs, tool_choice), + **kwargs, + ) + if not allow_parallel_tool_calls: + force_single_tool_call(response) + return response + + async def achat_with_tools( + self, + tools: List["BaseTool"], + user_msg: Optional[Union[str, ChatMessage]] = None, + chat_history: Optional[List[ChatMessage]] = None, + verbose: bool = False, + allow_parallel_tool_calls: bool = False, + tool_choice: str = "auto", + **kwargs: Any, + ) -> ChatResponse: + # use openai tool format + tool_specs = [tool.metadata.to_openai_tool() for tool in tools] + + if isinstance(user_msg, str): + user_msg = ChatMessage(role=MessageRole.USER, content=user_msg) + + messages = chat_history or [] + if user_msg: + messages.append(user_msg) + + response = await self.achat( + messages=messages, + tools=tool_specs,
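+ # resolve_tool_choice validates tool_choice against the provided tool names; with TGI, "auto" always selects a tool, unlike the OpenAI behavior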
+ tool_choice=resolve_tool_choice(tool_specs, tool_choice), + **kwargs, + ) + if not allow_parallel_tool_calls: + force_single_tool_call(response) + return response + + def get_tool_calls_from_response( + self, + response: "AgentChatResponse", + error_on_no_tool_call: bool = True, + ) -> List[ToolSelection]: + """Predict and call the tool.""" + tool_calls = response.message.additional_kwargs.get("tool_calls", []) + + if len(tool_calls) < 1: + if error_on_no_tool_call: + raise ValueError( + f"Expected at least one tool call, but got {len(tool_calls)} tool calls." + ) + else: + return [] + + tool_selections = [] + for tool_call in tool_calls: + # TODO Add typecheck with ToolCall from TGI once the client is updated + if tool_call and (tc_type := tool_call["type"]) != "function": + raise ValueError( + f"Invalid tool type: got {tc_type}, expect 'function'." + ) + argument_dict = tool_call["function"]["parameters"] + + tool_selections.append( + ToolSelection( + tool_id=tool_call["id"], + tool_name=tool_call["function"][ + "name" + ], # NOTE for now the tool_name is hardcoded 'tools' in TGI + tool_kwargs=argument_dict, + ) + ) + + return tool_selections diff --git a/llama-index-integrations/llms/llama-index-llms-huggingface/llama_index/llms/huggingface/utils.py b/llama-index-integrations/llms/llama-index-llms-huggingface/llama_index/llms/huggingface/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..71843873d1e09e8013f1fdb5a86bb1a5d685bd28 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-huggingface/llama_index/llms/huggingface/utils.py @@ -0,0 +1,66 @@ +import requests +from packaging import version +from typing import Sequence, Union, List, Optional +from llama_index.core.base.llms.types import ( + ChatMessage, + ChatResponse, +) +from text_generation.types import ( + Message, +) + + +def resolve_tgi_function_call(url: str) -> bool: + url = f"{url}/info" + model_info = dict(requests.get(url).json()) + tgi_version = model_info.get("version", None) + if version.parse(tgi_version) >= version.parse("2.0.1"): + return True + else: + raise ValueError( + "'text-generation-inference' version ", + f"incompatible with function call: {tgi_version}. ", + "Function call support was added in v2.0.1", + ) + + +def get_max_input_length(url: str) -> Union[int, None]: + url = f"{url}/info" + model_info = dict(requests.get(url).json()) + return model_info.get("max_input_length", None) + + +def to_tgi_messages(messages: Sequence[ChatMessage]) -> Sequence[Message]: + out_messages = [] + for m in messages: + tool_calls = m.additional_kwargs.get("tool_calls") + out_messages.append( + Message(role=m.role.value, content=m.content, tool_calls=tool_calls) + ) + + return out_messages + + +def force_single_tool_call(response: ChatResponse) -> None: + tool_calls = response.message.additional_kwargs.get("tool_calls", []) + if len(tool_calls) > 1: + response.message.additional_kwargs["tool_calls"] = [tool_calls[0]] + + +def resolve_tool_choice( + tools: Optional[List[dict]] = None, tool_choice: str = "none" +) -> Union[str, dict]: + """Resolve tool choice. + + Check if tool_name exists in tools. + Note that unlike in OpenAI specification, 'auto' will ALWAYS choose the tool for you. + Set to 'none' explicitly if do not wish to use tool. + """ + valid_tool_choices = ["none", "auto"] + [t["function"]["name"] for t in tools or []] + + if tool_choice not in valid_tool_choices: + raise ValueError( + f"{tool_choice} is not a valid tool_choice. 
Must be one of {valid_tool_choices}" + ) + + return tool_choice diff --git a/llama-index-integrations/llms/llama-index-llms-huggingface/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-huggingface/pyproject.toml index 39dde2eb2d3c84e470e8dd1afc4734366819375d..0d3b2344f2597630f63fd6356ccaed95f37a9df2 100644 --- a/llama-index-integrations/llms/llama-index-llms-huggingface/pyproject.toml +++ b/llama-index-integrations/llms/llama-index-llms-huggingface/pyproject.toml @@ -28,13 +28,14 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-llms-huggingface" readme = "README.md" -version = "0.1.4" +version = "0.1.5" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" llama-index-core = "^0.10.1" huggingface-hub = "^0.20.3" torch = "^2.1.2" +text-generation = "^0.7.0" [tool.poetry.dependencies.transformers] extras = ["torch"]
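A quick way to sanity-check the `/info` probing that `resolve_tgi_function_call` and `get_max_input_length` rely on above is to query the route directly. A minimal sketch, assuming a reachable TGI server at the hypothetical URL http://localhost:8080 and the `requests` and `packaging` dependencies already pulled in by this package:

import requests
from packaging import version

TGI_URL = "http://localhost:8080"  # hypothetical endpoint; replace with your own TGI deployment

# TGI serves model/server metadata on the /info route
info = requests.get(f"{TGI_URL}/info").json()
print("TGI version:", info.get("version"))
print("max input length:", info.get("max_input_length"))

# Mirrors resolve_tgi_function_call: tool calling is only enabled for TGI >= 2.0.1
print("function calling:", version.parse(info["version"]) >= version.parse("2.0.1"))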