diff --git a/llama_index/llms/huggingface.py b/llama_index/llms/huggingface.py
index 7fb33292e28f085bd4e5afb8cd727c2f2cfc07e8..befb4bcecbaf951b3734447fca6161b670909c60 100644
--- a/llama_index/llms/huggingface.py
+++ b/llama_index/llms/huggingface.py
@@ -395,16 +395,6 @@ def chat_messages_to_conversational_kwargs(
     return kwargs
 
 
-def chat_messages_to_completion_prompt(messages: Sequence[ChatMessage]) -> str:
-    """Convert ChatMessages to a completion prompt."""
-    return (
-        "\n".join(
-            f"{message.role.capitalize()}: {message.content}" for message in messages
-        )
-        + "\nAssistant:"
-    )
-
-
 class HuggingFaceInferenceAPI(CustomLLM):
     """
     Wrapper on the Hugging Face's Inference API.
@@ -471,6 +461,7 @@ class HuggingFaceInferenceAPI(CustomLLM):
             " model_name is left as default of None."
         ),
     )
+
     _sync_client: "InferenceClient" = PrivateAttr()
     _async_client: "AsyncInferenceClient" = PrivateAttr()
     _get_model_info: "Callable[..., ModelInfo]" = PrivateAttr()
@@ -543,6 +534,7 @@ class HuggingFaceInferenceAPI(CustomLLM):
             task = "conversational"
         else:
             task = kwargs["task"].lower()
+
         super().__init__(**kwargs)  # Populate pydantic Fields
         self._sync_client = InferenceClient(**self._get_inference_client_kwargs())
         self._async_client = AsyncInferenceClient(**self._get_inference_client_kwargs())
@@ -595,7 +587,7 @@ class HuggingFaceInferenceAPI(CustomLLM):
             )
         else:
             # try and use text generation
-            prompt = chat_messages_to_completion_prompt(messages=messages)
+            prompt = self.messages_to_prompt(messages)
             completion = self.complete(prompt)
             return ChatResponse(
                 message=ChatMessage(role=MessageRole.ASSISTANT, content=completion.text)
diff --git a/tests/llms/test_huggingface.py b/tests/llms/test_huggingface.py
index f90a22e4bca6a0179154de1abf1dcd202a892612..d993286db337604a7680d19f1458babc9e94aeb5 100644
--- a/tests/llms/test_huggingface.py
+++ b/tests/llms/test_huggingface.py
@@ -74,7 +74,11 @@ class TestHuggingFaceInferenceAPI:
     def test_chat_text_generation(
         self, hf_inference_api: HuggingFaceInferenceAPI
     ) -> None:
+        mock_message_to_prompt = MagicMock(
+            return_value="System: You are an expert movie reviewer\nUser: Which movie is the best?\nAssistant:"
+        )
         hf_inference_api.task = "text-generation"
+        hf_inference_api.messages_to_prompt = mock_message_to_prompt
         messages = [
             ChatMessage(
                 role=MessageRole.SYSTEM, content="You are an expert movie reviewer"
@@ -89,7 +93,8 @@ class TestHuggingFaceInferenceAPI:
                 return_value=conversational_return,
             ) as mock_complete:
                 response = hf_inference_api.chat(messages=messages)
-            print(response)
+
+            hf_inference_api.messages_to_prompt.assert_called_once_with(messages)
             assert response.message.role == MessageRole.ASSISTANT
             assert response.message.content == conversational_return
             mock_complete.assert_called_once_with(