diff --git a/llama_index/llms/huggingface.py b/llama_index/llms/huggingface.py
index 0be189e1ff2009c9c555d624175bf5390ede4426..7fb33292e28f085bd4e5afb8cd727c2f2cfc07e8 100644
--- a/llama_index/llms/huggingface.py
+++ b/llama_index/llms/huggingface.py
@@ -395,6 +395,16 @@ def chat_messages_to_conversational_kwargs(
     return kwargs
 
 
+def chat_messages_to_completion_prompt(messages: Sequence[ChatMessage]) -> str:
+    """Convert ChatMessages to a completion prompt."""
+    return (
+        "\n".join(
+            f"{message.role.capitalize()}: {message.content}" for message in messages
+        )
+        + "\nAssistant:"
+    )
+
+
 class HuggingFaceInferenceAPI(CustomLLM):
     """
     Wrapper on the Hugging Face's Inference API.
@@ -529,6 +539,10 @@ class HuggingFaceInferenceAPI(CustomLLM):
                 f"Using Hugging Face's recommended model {kwargs['model_name']}"
                 f" given task {task}."
             )
+        if kwargs.get("task") is None:
+            task = "conversational"
+        else:
+            task = kwargs["task"].lower()
         super().__init__(**kwargs)  # Populate pydantic Fields
         self._sync_client = InferenceClient(**self._get_inference_client_kwargs())
         self._async_client = AsyncInferenceClient(**self._get_inference_client_kwargs())
@@ -569,14 +583,23 @@ class HuggingFaceInferenceAPI(CustomLLM):
         )
 
     def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
-        output: "ConversationalOutput" = self._sync_client.conversational(
-            **{**chat_messages_to_conversational_kwargs(messages), **kwargs}
-        )
-        return ChatResponse(
-            message=ChatMessage(
-                role=MessageRole.ASSISTANT, content=output["generated_text"]
+        # Default to the conversational task, preserving the previous behavior.
+        if self.task == "conversational" or self.task is None:
+            output: "ConversationalOutput" = self._sync_client.conversational(
+                **{**chat_messages_to_conversational_kwargs(messages), **kwargs}
+            )
+            return ChatResponse(
+                message=ChatMessage(
+                    role=MessageRole.ASSISTANT, content=output["generated_text"]
+                )
+            )
+        else:
+            # Fall back to text generation for other tasks.
+            prompt = chat_messages_to_completion_prompt(messages=messages)
+            completion = self.complete(prompt)
+            return ChatResponse(
+                message=ChatMessage(role=MessageRole.ASSISTANT, content=completion.text)
             )
-        )
 
     def complete(
         self, prompt: str, formatted: bool = False, **kwargs: Any
diff --git a/tests/llms/test_huggingface.py b/tests/llms/test_huggingface.py
index 8c6eebbb61399cc9fac9915378c870caa98d5371..f90a22e4bca6a0179154de1abf1dcd202a892612 100644
--- a/tests/llms/test_huggingface.py
+++ b/tests/llms/test_huggingface.py
@@ -62,6 +62,7 @@ class TestHuggingFaceInferenceAPI:
             return_value=conversational_return,
         ) as mock_conversational:
             response = hf_inference_api.chat(messages=messages)
 
+        assert response.message.role == MessageRole.ASSISTANT
         assert response.message.content == generated_response
         mock_conversational.assert_called_once_with(
@@ -70,6 +71,32 @@ class TestHuggingFaceInferenceAPI:
             generated_responses=["It's Die Hard for sure."],
         )
 
+    def test_chat_text_generation(
+        self, hf_inference_api: HuggingFaceInferenceAPI
+    ) -> None:
+        hf_inference_api.task = "text-generation"
+        messages = [
+            ChatMessage(
+                role=MessageRole.SYSTEM, content="You are an expert movie reviewer"
+            ),
+            ChatMessage(role=MessageRole.USER, content="Which movie is the best?"),
+        ]
+        conversational_return = "It's Die Hard for sure."
+
+        with patch.object(
+            hf_inference_api._sync_client,
+            "text_generation",
+            return_value=conversational_return,
+        ) as mock_complete:
+            response = hf_inference_api.chat(messages=messages)
+
+        assert response.message.role == MessageRole.ASSISTANT
+        assert response.message.content == conversational_return
+        mock_complete.assert_called_once_with(
+            "System: You are an expert movie reviewer\nUser: Which movie is the best?\nAssistant:",
+            max_new_tokens=256,
+        )
+
     def test_complete(self, hf_inference_api: HuggingFaceInferenceAPI) -> None:
         prompt = "My favorite color is "
         generated_text = '"green" and I love to paint. I have been painting for 30 years and have been'
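A minimal usage sketch of the new code path, not part of the patch: with task="text-generation", chat() builds a single completion prompt and delegates to complete(). The model name below is an illustrative assumption, and a Hugging Face API token is assumed to be configured for InferenceClient.

from llama_index.llms import ChatMessage, MessageRole
from llama_index.llms.huggingface import HuggingFaceInferenceAPI

# task="text-generation" routes chat() through chat_messages_to_completion_prompt()
# and complete() instead of InferenceClient.conversational().
llm = HuggingFaceInferenceAPI(
    model_name="HuggingFaceH4/zephyr-7b-beta",  # illustrative; any text-generation model
    task="text-generation",
)

response = llm.chat(
    messages=[
        ChatMessage(role=MessageRole.SYSTEM, content="You are a movie expert."),
        ChatMessage(role=MessageRole.USER, content="Which movie is the best?"),
    ]
)
# The prompt sent to text_generation() is:
# "System: You are a movie expert.\nUser: Which movie is the best?\nAssistant:"
print(response.message.content)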