diff --git a/llama-index-integrations/llms/llama-index-llms-huggingface-api/llama_index/llms/huggingface_api/base.py b/llama-index-integrations/llms/llama-index-llms-huggingface-api/llama_index/llms/huggingface_api/base.py
index 7e0530f395b055715968a933765f740cc4a189f1..a20e1bf3909d91b7ae016618655d9cffb6be9d9b 100644
--- a/llama-index-integrations/llms/llama-index-llms-huggingface-api/llama_index/llms/huggingface_api/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-huggingface-api/llama_index/llms/huggingface_api/base.py
@@ -1,3 +1,4 @@
+import json
 import logging
 from huggingface_hub import AsyncInferenceClient, InferenceClient, model_info
 from huggingface_hub.hf_api import ModelInfo
@@ -5,6 +6,7 @@ from huggingface_hub.inference._generated.types import (
     ChatCompletionOutput,
     ChatCompletionStreamOutput,
     ChatCompletionOutputToolCall,
+    ChatCompletionOutputFunctionDefinition,
 )
 from typing import Any, Callable, Dict, List, Optional, Sequence, Union

@@ -136,20 +138,12 @@ class HuggingFaceInferenceAPI(FunctionCallingLLM):
         gt=0.0,
     )
     is_chat_model: bool = Field(
-        default=False,
-        description=(
-            LLMMetadata.model_fields["is_chat_model"].description
-            + " Unless chat templating is intentionally applied, Hugging Face models"
-            " are not chat models."
-        ),
+        default=True,
+        description="Controls whether the chat or text generation methods are used.",
     )
     is_function_calling_model: bool = Field(
         default=False,
-        description=(
-            LLMMetadata.model_fields["is_function_calling_model"].description
-            + " As of 10/17/2023, Hugging Face doesn't support function calling"
-            " messages."
-        ),
+        description="Controls whether the function calling methods are used.",
     )

     def __init__(self, **kwargs: Any) -> None:
@@ -169,10 +163,21 @@ class HuggingFaceInferenceAPI(FunctionCallingLLM):
         else:
             task = kwargs["task"].lower()

+        if kwargs.get("is_function_calling_model", False):
+            print(
+                "Function calling is currently not supported for Hugging Face Inference API, setting is_function_calling_model to False"
+            )
+            kwargs["is_function_calling_model"] = False
+
         super().__init__(**kwargs)  # Populate pydantic Fields
         self._sync_client = InferenceClient(**self._get_inference_client_kwargs())
         self._async_client = AsyncInferenceClient(**self._get_inference_client_kwargs())

+        # set context window if not provided
+        info = self._sync_client.get_endpoint_info()
+        if "max_input_tokens" in info and kwargs.get("context_window") is None:
+            self.context_window = info["max_input_tokens"]
+
     def _get_inference_client_kwargs(self) -> Dict[str, Any]:
         """Extract the Hugging Face InferenceClient construction parameters."""
         return {
@@ -194,7 +199,53 @@ class HuggingFaceInferenceAPI(FunctionCallingLLM):
     def _to_huggingface_messages(
         self, messages: Sequence[ChatMessage]
     ) -> List[Dict[str, Any]]:
-        return [{"role": m.role.value, "content": m.content} for m in messages]
+        hf_dicts = []
+        for m in messages:
+            hf_dicts.append(
+                {"role": m.role.value, "content": m.content if m.content else ""}
+            )
+            if m.additional_kwargs.get("tool_calls", []):
+                tool_call_dicts = []
+                for tool_call in m.additional_kwargs["tool_calls"]:
+                    function_dict = {
+                        "name": tool_call.id,
+                        "arguments": tool_call.function.arguments,
+                    }
+                    tool_call_dicts.append(
+                        {"type": "function", "function": function_dict}
+                    )
+
+                hf_dicts[-1]["tool_calls"] = tool_call_dicts
+
+            if m.role == MessageRole.TOOL:
+                hf_dicts[-1]["name"] = m.additional_kwargs.get("tool_call_id")
+
+        return hf_dicts
+
+    def _parse_streaming_tool_calls(
+        self, tool_call_strs: List[str]
+    ) -> List[Union[ChatCompletionOutputToolCall, str]]:
+        tool_calls = []
+        # Try to parse into complete objects, otherwise keep as strings
+        for tool_call_str in tool_call_strs:
+            try:
+                tool_call_dict = json.loads(tool_call_str)
+                args = tool_call_dict["function"]
+                name = args.pop("_name")
+                tool_calls.append(
+                    ChatCompletionOutputToolCall(
+                        id=name,
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            arguments=args,
+                            name=name,
+                        ),
+                    )
+                )
+            except Exception:
+                tool_calls.append(tool_call_str)
+
+        return tool_calls

     def get_model_info(self, **kwargs: Any) -> "ModelInfo":
         """Get metadata on the current model from Hugging Face."""
@@ -260,6 +311,8 @@ class HuggingFaceInferenceAPI(FunctionCallingLLM):

         def gen() -> ChatResponseGen:
             response = ""
+            tool_call_strs = []
+            cur_index = -1
             for chunk in self._sync_client.chat_completion(
                 messages=self._to_huggingface_messages(messages),
                 stream=True,
@@ -269,9 +322,18 @@ class HuggingFaceInferenceAPI(FunctionCallingLLM):
                 delta = chunk.choices[0].delta.content or ""
                 response += delta

-                tool_calls = chunk.choices[0].delta.tool_calls or []
+                tool_call_delta = chunk.choices[0].delta.tool_calls
+                if tool_call_delta:
+                    if tool_call_delta.index != cur_index:
+                        cur_index = tool_call_delta.index
+                        tool_call_strs.append(tool_call_delta.function.arguments)
+                    else:
+                        tool_call_strs[
+                            cur_index
+                        ] += tool_call_delta.function.arguments
+
+                tool_calls = self._parse_streaming_tool_calls(tool_call_strs)
                 additional_kwargs = {"tool_calls": tool_calls} if tool_calls else {}
-
                 yield ChatResponse(
                     message=ChatMessage(
                         role=MessageRole.ASSISTANT,
@@ -359,16 +421,32 @@ class HuggingFaceInferenceAPI(FunctionCallingLLM):

         async def gen() -> ChatResponseAsyncGen:
             response = ""
+            tool_call_strs = []
+            cur_index = -1
             async for chunk in await self._async_client.chat_completion(
                 messages=self._to_huggingface_messages(messages),
                 stream=True,
                 **model_kwargs,
             ):
+                if chunk.choices[0].finish_reason is not None:
+                    break
+
                 chunk: ChatCompletionStreamOutput = chunk
                 delta = chunk.choices[0].delta.content or ""
                 response += delta

-                tool_calls = chunk.choices[0].delta.tool_calls or []
+                tool_call_delta = chunk.choices[0].delta.tool_calls
+                if tool_call_delta:
+                    if tool_call_delta.index != cur_index:
+                        cur_index = tool_call_delta.index
+                        tool_call_strs.append(tool_call_delta.function.arguments)
+                    else:
+                        tool_call_strs[
+                            cur_index
+                        ] += tool_call_delta.function.arguments
+
+                tool_calls = self._parse_streaming_tool_calls(tool_call_strs)
+
                 additional_kwargs = {"tool_calls": tool_calls} if tool_calls else {}

                 yield ChatResponse(
@@ -474,6 +552,10 @@ class HuggingFaceInferenceAPI(FunctionCallingLLM):

         tool_selections = []
         for tool_call in tool_calls:
+            # while streaming, tool_call is a string
+            if isinstance(tool_call, str):
+                continue
+
             tool_selections.append(
                 ToolSelection(
                     tool_id=tool_call.id,
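For reviewers: below is a minimal, self-contained sketch (not part of the patch) of the streaming tool-call accumulation scheme the new code relies on. Fragments of a tool call's JSON arrive across chunks keyed by `index`; same-index fragments are concatenated into one string per call, then each string is parsed once it forms complete JSON. The plain dicts and the "multiply" tool here are hypothetical stand-ins for huggingface_hub's delta objects, and the patch itself builds ChatCompletionOutputToolCall instances rather than dicts.

import json
from typing import Dict, List, Union


def accumulate_tool_call_deltas(deltas: List[Dict]) -> List[str]:
    """Concatenate streamed argument fragments into one string per tool call."""
    tool_call_strs: List[str] = []
    cur_index = -1
    for delta in deltas:
        if delta["index"] != cur_index:
            # A new tool call begins: start a fresh accumulator.
            cur_index = delta["index"]
            tool_call_strs.append(delta["arguments"])
        else:
            # Same tool call: append this fragment to its accumulator.
            tool_call_strs[cur_index] += delta["arguments"]
    return tool_call_strs


def parse_tool_call_strs(tool_call_strs: List[str]) -> List[Union[dict, str]]:
    """Parse strings that are complete JSON; keep still-partial ones as-is."""
    parsed: List[Union[dict, str]] = []
    for s in tool_call_strs:
        try:
            parsed.append(json.loads(s))
        except json.JSONDecodeError:
            parsed.append(s)  # mid-stream fragment, returned unparsed
    return parsed


# Hypothetical fragments, as a server might stream them for one tool call:
deltas = [
    {"index": 0, "arguments": '{"function": {"_name": "multiply", '},
    {"index": 0, "arguments": '"a": 2, "b": 3}}'},
]
print(parse_tool_call_strs(accumulate_tool_call_deltas(deltas)))
# [{'function': {'_name': 'multiply', 'a': 2, 'b': 3}}]

Because partially accumulated strings can surface while the stream is still running, downstream consumers must tolerate them; that is why get_tool_calls_from_response skips plain strings in the last hunk above.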