diff --git a/llama-index-integrations/llms/llama-index-llms-huggingface-api/llama_index/llms/huggingface_api/base.py b/llama-index-integrations/llms/llama-index-llms-huggingface-api/llama_index/llms/huggingface_api/base.py
index dfff8ad7769fd10456ba07f6ddb8d784c5c31728..7e0530f395b055715968a933765f740cc4a189f1 100644
--- a/llama-index-integrations/llms/llama-index-llms-huggingface-api/llama_index/llms/huggingface_api/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-huggingface-api/llama_index/llms/huggingface_api/base.py
@@ -1,8 +1,19 @@
 import logging
-from typing import Any, Callable, Dict, Optional, Sequence, Union
-
 from huggingface_hub import AsyncInferenceClient, InferenceClient, model_info
 from huggingface_hub.hf_api import ModelInfo
+from huggingface_hub.inference._generated.types import (
+    ChatCompletionOutput,
+    ChatCompletionStreamOutput,
+    ChatCompletionOutputToolCall,
+)
+from typing import Any, Callable, Dict, List, Optional, Sequence, Union
+
+
+from llama_index.core.base.llms.generic_utils import (
+    completion_response_to_chat_response,
+    stream_completion_response_to_chat_response,
+    astream_completion_response_to_chat_response,
+)
 from llama_index.core.base.llms.types import (
     ChatMessage,
     ChatResponse,
@@ -14,18 +25,20 @@ from llama_index.core.base.llms.types import (
     LLMMetadata,
     MessageRole,
 )
+from llama_index.core.llms.llm import ToolSelection
 from llama_index.core.bridge.pydantic import Field, PrivateAttr
 from llama_index.core.constants import (
     DEFAULT_CONTEXT_WINDOW,
     DEFAULT_NUM_OUTPUTS,
 )
-from llama_index.core.llms.custom import CustomLLM
+from llama_index.core.llms.function_calling import FunctionCallingLLM
+from llama_index.core.tools.types import BaseTool
 
 logger = logging.getLogger(__name__)
 
 
-class HuggingFaceInferenceAPI(CustomLLM):
+class HuggingFaceInferenceAPI(FunctionCallingLLM):
     """
     Wrapper on the Hugging Face's Inference API.
 
@@ -48,7 +61,17 @@ class HuggingFaceInferenceAPI(CustomLLM):
     def class_name(cls) -> str:
         return "HuggingFaceInferenceAPI"
 
-    # Corresponds with huggingface_hub.InferenceClient
+    model: Optional[str] = Field(
+        default=None,
+        description=(
+            "The model to run inference with. Can be a model id hosted on the Hugging"
+            " Face Hub, e.g. bigcode/starcoder or a URL to a deployed Inference"
+            " Endpoint. Defaults to None, in which case a recommended model is"
+            " automatically selected for the task (see Field below)."
+        ),
+    )
+
+    # TODO: deprecate this field
     model_name: Optional[str] = Field(
         default=None,
         description=(
@@ -92,9 +115,9 @@ class HuggingFaceInferenceAPI(CustomLLM):
         ),
     )
 
-    _sync_client: "InferenceClient" = PrivateAttr()
-    _async_client: "AsyncInferenceClient" = PrivateAttr()
-    _get_model_info: "Callable[..., ModelInfo]" = PrivateAttr()
+    _sync_client: InferenceClient = PrivateAttr()
+    _async_client: AsyncInferenceClient = PrivateAttr()
+    _get_model_info: Callable[..., ModelInfo] = PrivateAttr()
 
     context_window: int = Field(
         default=DEFAULT_CONTEXT_WINDOW,
@@ -107,6 +130,11 @@ class HuggingFaceInferenceAPI(CustomLLM):
         default=DEFAULT_NUM_OUTPUTS,
         description=LLMMetadata.model_fields["num_output"].description,
     )
+    temperature: float = Field(
+        default=0.1,
+        description="The temperature to use for the model.",
+        gt=0.0,
+    )
     is_chat_model: bool = Field(
         default=False,
         description=(
@@ -124,23 +152,9 @@ class HuggingFaceInferenceAPI(CustomLLM):
         ),
     )
 
-    def _get_inference_client_kwargs(self) -> Dict[str, Any]:
-        """Extract the Hugging Face InferenceClient construction parameters."""
-        return {
-            "model": self.model_name,
-            "token": self.token,
-            "timeout": self.timeout,
-            "headers": self.headers,
-            "cookies": self.cookies,
-        }
-
     def __init__(self, **kwargs: Any) -> None:
-        """Initialize.
-
-        Args:
-            kwargs: See the class-level Fields.
-        """
-        if kwargs.get("model_name") is None:
+        model_name = kwargs.get("model_name") or kwargs.get("model")
+        if model_name is None:
             task = kwargs.get("task", "")
             # NOTE: task being None or empty string leads to ValueError,
             # which ensures model is present
@@ -149,6 +163,7 @@ class HuggingFaceInferenceAPI(CustomLLM):
                 f"Using Hugging Face's recommended model {kwargs['model_name']}"
                 f" given task {task}."
             )
+
         if kwargs.get("task") is None:
             task = "conversational"
         else:
@@ -157,31 +172,34 @@ class HuggingFaceInferenceAPI(CustomLLM):
         super().__init__(**kwargs)  # Populate pydantic Fields
         self._sync_client = InferenceClient(**self._get_inference_client_kwargs())
         self._async_client = AsyncInferenceClient(**self._get_inference_client_kwargs())
-        self._get_model_info = model_info
-
-    def validate_supported(self, task: str) -> None:
-        """
-        Confirm the contained model_name is deployed on the Inference API service.
-
-        Args:
-            task: Hugging Face task to check within. A list of all tasks can be
-                found here: https://huggingface.co/tasks
-        """
-        all_models = self._sync_client.list_deployed_models(frameworks="all")
-        try:
-            if self.model_name not in all_models[task]:
-                raise ValueError(
-                    "The Inference API service doesn't have the model"
-                    f" {self.model_name!r} deployed."
-                )
-        except KeyError as exc:
-            raise KeyError(
-                f"Input task {task!r} not in possible tasks {list(all_models.keys())}."
-            ) from exc
+
+    def _get_inference_client_kwargs(self) -> Dict[str, Any]:
+        """Extract the Hugging Face InferenceClient construction parameters."""
+        return {
+            "model": self.model_name or self.model,
+            "token": self.token,
+            "timeout": self.timeout,
+            "headers": self.headers,
+            "cookies": self.cookies,
+        }
+
+    def _get_model_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
+        base_kwargs = {
+            "model": self.model_name or self.model,
+            "max_tokens": self.num_output,
+            "temperature": self.temperature,
+        }
+        return {**base_kwargs, **kwargs}
+
+    def _to_huggingface_messages(
+        self, messages: Sequence[ChatMessage]
+    ) -> List[Dict[str, Any]]:
+        return [{"role": m.role.value, "content": m.content} for m in messages]
 
     def get_model_info(self, **kwargs: Any) -> "ModelInfo":
         """Get metadata on the current model from Hugging Face."""
-        return self._get_model_info(self.model_name, **kwargs)
+        model_name = self.model_name or self.model
+        return model_info(model_name, **kwargs)
 
     @property
     def metadata(self) -> LLMMetadata:
@@ -190,92 +208,278 @@ class HuggingFaceInferenceAPI(CustomLLM):
             num_output=self.num_output,
             is_chat_model=self.is_chat_model,
             is_function_calling_model=self.is_function_calling_model,
-            model_name=self.model_name,
+            model_name=self.model_name or self.model,
         )
 
     def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
-        # default to conversational task as that was the previous functionality
         if self.task == "conversational" or self.task is None:
-            output = self._sync_client.chat_completion(
-                messages=[
-                    {"role": m.role.value, "content": m.content} for m in messages
-                ],
-                model=self.model_name,
-                **kwargs,
+            model_kwargs = self._get_model_kwargs(**kwargs)
+
+            output: ChatCompletionOutput = self._sync_client.chat_completion(
+                messages=self._to_huggingface_messages(messages),
+                **model_kwargs,
             )
+
+            content = output.choices[0].message.content or ""
+            tool_calls = output.choices[0].message.tool_calls or []
+            additional_kwargs = {"tool_calls": tool_calls} if tool_calls else {}
+
             return ChatResponse(
                 message=ChatMessage(
                     role=MessageRole.ASSISTANT,
-                    content=output["choices"][0]["message"]["content"] or "",
-                )
+                    content=content,
+                    additional_kwargs=additional_kwargs,
+                ),
+                raw=output,
             )
         else:
             # try and use text generation
             prompt = self.messages_to_prompt(messages)
-            completion = self.complete(prompt)
-            return ChatResponse(
-                message=ChatMessage(role=MessageRole.ASSISTANT, content=completion.text)
-            )
+            completion = self.complete(prompt, formatted=True, **kwargs)
+            return completion_response_to_chat_response(completion)
 
     def complete(
         self, prompt: str, formatted: bool = False, **kwargs: Any
     ) -> CompletionResponse:
+        model_kwargs = self._get_model_kwargs(**kwargs)
+        model_kwargs["max_new_tokens"] = model_kwargs["max_tokens"]
+        del model_kwargs["max_tokens"]
+
+        if not formatted:
+            prompt = self.completion_to_prompt(prompt)
+
         return CompletionResponse(
-            text=self._sync_client.text_generation(
-                prompt, **{**{"max_new_tokens": self.num_output}, **kwargs}
-            )
+            text=self._sync_client.text_generation(prompt, **model_kwargs)
         )
 
     def stream_chat(
         self, messages: Sequence[ChatMessage], **kwargs: Any
     ) -> ChatResponseGen:
-        raise NotImplementedError
+        if self.task == "conversational" or self.task is None:
+            model_kwargs = self._get_model_kwargs(**kwargs)
+
+            def gen() -> ChatResponseGen:
+                response = ""
+                for chunk in self._sync_client.chat_completion(
+                    messages=self._to_huggingface_messages(messages),
+                    stream=True,
+                    **model_kwargs,
+                ):
+                    chunk: ChatCompletionStreamOutput = chunk
+
+                    delta = chunk.choices[0].delta.content or ""
+                    response += delta
+                    tool_calls = chunk.choices[0].delta.tool_calls or []
+                    additional_kwargs = {"tool_calls": tool_calls} if tool_calls else {}
+
+                    yield ChatResponse(
+                        message=ChatMessage(
+                            role=MessageRole.ASSISTANT,
+                            content=response,
+                            additional_kwargs=additional_kwargs,
+                        ),
+                        delta=delta,
+                        raw=chunk,
+                    )
+
+            return gen()
+        else:
+            prompt = self.messages_to_prompt(messages)
+            completion_stream = self.stream_complete(prompt, formatted=True, **kwargs)
+            return stream_completion_response_to_chat_response(completion_stream)
 
     def stream_complete(
         self, prompt: str, formatted: bool = False, **kwargs: Any
     ) -> CompletionResponseGen:
-        raise NotImplementedError
+        model_kwargs = self._get_model_kwargs(**kwargs)
+        model_kwargs["max_new_tokens"] = model_kwargs["max_tokens"]
+        del model_kwargs["max_tokens"]
+
+        if not formatted:
+            prompt = self.completion_to_prompt(prompt)
+
+        def gen() -> CompletionResponseGen:
+            response = ""
+            for delta in self._sync_client.text_generation(
+                prompt, stream=True, **model_kwargs
+            ):
+                response += delta
+                yield CompletionResponse(text=response, delta=delta)
+
+        return gen()
 
     async def achat(
         self, messages: Sequence[ChatMessage], **kwargs: Any
     ) -> ChatResponse:
-        raise NotImplementedError
+        if self.task == "conversational" or self.task is None:
+            model_kwargs = self._get_model_kwargs(**kwargs)
+
+            output: ChatCompletionOutput = await self._async_client.chat_completion(
+                messages=self._to_huggingface_messages(messages),
+                **model_kwargs,
+            )
+
+            content = output.choices[0].message.content or ""
+            tool_calls = output.choices[0].message.tool_calls or []
+            additional_kwargs = {"tool_calls": tool_calls} if tool_calls else {}
+
+            return ChatResponse(
+                message=ChatMessage(
+                    role=MessageRole.ASSISTANT,
+                    content=content,
+                    additional_kwargs=additional_kwargs,
+                ),
+                raw=output,
+            )
+        else:
+            # try and use text generation
+            prompt = self.messages_to_prompt(messages)
+            completion = await self.acomplete(prompt, formatted=True, **kwargs)
+            return completion_response_to_chat_response(completion)
 
     async def acomplete(
         self, prompt: str, formatted: bool = False, **kwargs: Any
     ) -> CompletionResponse:
-        response = await self._async_client.text_generation(
-            prompt, **{**{"max_new_tokens": self.num_output}, **kwargs}
+        model_kwargs = self._get_model_kwargs(**kwargs)
+        model_kwargs["max_new_tokens"] = model_kwargs["max_tokens"]
+        del model_kwargs["max_tokens"]
+
+        if not formatted:
+            prompt = self.completion_to_prompt(prompt)
+
+        return CompletionResponse(
+            text=await self._async_client.text_generation(prompt, **model_kwargs)
         )
-        return CompletionResponse(text=response)
 
     async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
-        # default to conversational task as that was the previous functionality
         if self.task == "conversational" or self.task is None:
-            output = await self._async_client.chat_completion(
-                messages=[
-                    {"role": m.role.value, "content": m.content} for m in messages
-                ],
-                model=self.model_name,
-                **kwargs,
-            )
-            return ChatResponse(
-                message=ChatMessage(
-                    role=MessageRole.ASSISTANT,
-                    content=output["choices"][0]["message"]["content"] or "",
-                )
-            )
+            model_kwargs = self._get_model_kwargs(**kwargs)
+
+            async def gen() -> ChatResponseAsyncGen:
+                response = ""
+                async for chunk in await self._async_client.chat_completion(
+                    messages=self._to_huggingface_messages(messages),
+                    stream=True,
+                    **model_kwargs,
+                ):
+                    chunk: ChatCompletionStreamOutput = chunk
+
+                    delta = chunk.choices[0].delta.content or ""
+                    response += delta
+                    tool_calls = chunk.choices[0].delta.tool_calls or []
+                    additional_kwargs = {"tool_calls": tool_calls} if tool_calls else {}
+
+                    yield ChatResponse(
+                        message=ChatMessage(
+                            role=MessageRole.ASSISTANT,
+                            content=response,
+                            additional_kwargs=additional_kwargs,
+                        ),
+                        delta=delta,
+                        raw=chunk,
+                    )
+
+                await self._async_client.close()
+
+            return gen()
         else:
-            # try and use text generation
             prompt = self.messages_to_prompt(messages)
-            completion = await self.acomplete(prompt)
-            return ChatResponse(
-                message=ChatMessage(role=MessageRole.ASSISTANT, content=completion.text)
+            completion_stream = await self.astream_complete(
+                prompt, formatted=True, **kwargs
             )
+            return astream_completion_response_to_chat_response(completion_stream)
 
     async def astream_complete(
         self, prompt: str, formatted: bool = False, **kwargs: Any
     ) -> CompletionResponseAsyncGen:
-        raise NotImplementedError
+        model_kwargs = self._get_model_kwargs(**kwargs)
+        model_kwargs["max_new_tokens"] = model_kwargs["max_tokens"]
+        del model_kwargs["max_tokens"]
+
+        if not formatted:
+            prompt = self.completion_to_prompt(prompt)
+
+        async def gen() -> CompletionResponseAsyncGen:
+            response = ""
+            async for delta in await self._async_client.text_generation(
+                prompt, stream=True, **model_kwargs
+            ):
+                response += delta
+                yield CompletionResponse(text=response, delta=delta)
+
+            await self._async_client.close()
+
+        return gen()
+
+    def _prepare_chat_with_tools(
+        self,
+        tools: List["BaseTool"],
+        user_msg: Optional[Union[str, ChatMessage]] = None,
+        chat_history: Optional[List[ChatMessage]] = None,
+        verbose: bool = False,
+        allow_parallel_tool_calls: bool = False,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        tool_specs = [
+            tool.metadata.to_openai_tool(skip_length_check=True) for tool in tools
+        ]
+
+        if isinstance(user_msg, str):
+            user_msg = ChatMessage(role=MessageRole.USER, content=user_msg)
+
+        messages = chat_history or []
+        if user_msg:
+            messages.append(user_msg)
+
+        return {
+            "messages": messages,
+            "tools": tool_specs or None,
+        }
+
+    def _validate_chat_with_tools_response(
+        self,
+        response: ChatResponse,
+        tools: List["BaseTool"],
+        allow_parallel_tool_calls: bool = False,
+        **kwargs: Any,
+    ) -> ChatResponse:
+        """Validate the response from chat_with_tools."""
+        if not allow_parallel_tool_calls and response.message.additional_kwargs.get(
+            "tool_calls", []
+        ):
+            response.additional_kwargs[
+                "tool_calls"
+            ] = response.message.additional_kwargs["tool_calls"][0]
+
+        return response
+
+    def get_tool_calls_from_response(
+        self,
+        response: "ChatResponse",
+        error_on_no_tool_call: bool = True,
+    ) -> List[ToolSelection]:
+        """Predict and call the tool."""
+        tool_calls: List[
+            ChatCompletionOutputToolCall
+        ] = response.message.additional_kwargs.get("tool_calls", [])
+        if len(tool_calls) < 1:
+            if error_on_no_tool_call:
+                raise ValueError(
+                    f"Expected at least one tool call, but got {len(tool_calls)} tool calls."
+                )
+            else:
+                return []
+
+        tool_selections = []
+        for tool_call in tool_calls:
+            tool_selections.append(
+                ToolSelection(
+                    tool_id=tool_call.id,
+                    tool_name=tool_call.function.name,
+                    tool_kwargs=tool_call.function.arguments,
+                )
+            )
+
+        return tool_selections
diff --git a/llama-index-integrations/llms/llama-index-llms-huggingface-api/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-huggingface-api/pyproject.toml
index 6670b0150a0c0557599fe63c6d8d82783665f976..fcd5d95b5210bd8d383541f613a8c2013ef759bc 100644
--- a/llama-index-integrations/llms/llama-index-llms-huggingface-api/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-huggingface-api/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-huggingface-api"
 readme = "README.md"
-version = "0.3.1"
+version = "0.4.0"
 
 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
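
For reviewers: below is a minimal usage sketch of the tool-calling path this patch enables, not part of the diff itself. It assumes the package is installed and that a tool-call-capable chat model is reachable through the Inference API; the model id and the multiply tool are illustrative assumptions.

# Illustrative sketch only -- exercises chat_with_tools() (inherited from
# FunctionCallingLLM) and the get_tool_calls_from_response() method added above.
from llama_index.core.tools import FunctionTool
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI


def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


tool = FunctionTool.from_defaults(fn=multiply)
llm = HuggingFaceInferenceAPI(
    model_name="meta-llama/Meta-Llama-3.1-8B-Instruct",  # assumed tool-capable model
    task="conversational",
)

# chat_with_tools() routes through _prepare_chat_with_tools(); the returned
# ChatResponse carries any tool calls in message.additional_kwargs["tool_calls"],
# which get_tool_calls_from_response() converts into ToolSelection objects.
response = llm.chat_with_tools(tools=[tool], user_msg="What is 3 times 7?")
for selection in llm.get_tool_calls_from_response(response, error_on_no_tool_call=False):
    print(selection.tool_name, selection.tool_kwargs)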