Skip to content
Snippets Groups Projects
Unverified Commit 62d1a938 authored by Sammy Roberts's avatar Sammy Roberts Committed by GitHub
Browse files

Support custom prompt formatting for non-chat LLMs (#10466)

* fix message passing

* add custom formatter to tests

* simplify and use base class functionality to accomplish formatting
parent c2ed8e97
No related branches found
No related tags found
No related merge requests found
...@@ -395,16 +395,6 @@ def chat_messages_to_conversational_kwargs( ...@@ -395,16 +395,6 @@ def chat_messages_to_conversational_kwargs(
return kwargs return kwargs
def chat_messages_to_completion_prompt(messages: Sequence[ChatMessage]) -> str:
    """Render a sequence of ChatMessages as a single text-completion prompt.

    Each message becomes one ``Role: content`` line, and a trailing
    ``Assistant:`` line is appended to cue the model to produce the
    assistant's reply.
    """
    rendered = [f"{msg.role.capitalize()}: {msg.content}" for msg in messages]
    return "\n".join(rendered) + "\nAssistant:"
class HuggingFaceInferenceAPI(CustomLLM): class HuggingFaceInferenceAPI(CustomLLM):
""" """
Wrapper on the Hugging Face's Inference API. Wrapper on the Hugging Face's Inference API.
...@@ -471,6 +461,7 @@ class HuggingFaceInferenceAPI(CustomLLM): ...@@ -471,6 +461,7 @@ class HuggingFaceInferenceAPI(CustomLLM):
" model_name is left as default of None." " model_name is left as default of None."
), ),
) )
_sync_client: "InferenceClient" = PrivateAttr() _sync_client: "InferenceClient" = PrivateAttr()
_async_client: "AsyncInferenceClient" = PrivateAttr() _async_client: "AsyncInferenceClient" = PrivateAttr()
_get_model_info: "Callable[..., ModelInfo]" = PrivateAttr() _get_model_info: "Callable[..., ModelInfo]" = PrivateAttr()
...@@ -543,6 +534,7 @@ class HuggingFaceInferenceAPI(CustomLLM): ...@@ -543,6 +534,7 @@ class HuggingFaceInferenceAPI(CustomLLM):
task = "conversational" task = "conversational"
else: else:
task = kwargs["task"].lower() task = kwargs["task"].lower()
super().__init__(**kwargs) # Populate pydantic Fields super().__init__(**kwargs) # Populate pydantic Fields
self._sync_client = InferenceClient(**self._get_inference_client_kwargs()) self._sync_client = InferenceClient(**self._get_inference_client_kwargs())
self._async_client = AsyncInferenceClient(**self._get_inference_client_kwargs()) self._async_client = AsyncInferenceClient(**self._get_inference_client_kwargs())
...@@ -595,7 +587,7 @@ class HuggingFaceInferenceAPI(CustomLLM): ...@@ -595,7 +587,7 @@ class HuggingFaceInferenceAPI(CustomLLM):
) )
else: else:
# try and use text generation # try and use text generation
prompt = chat_messages_to_completion_prompt(messages=messages) prompt = self.messages_to_prompt(messages)
completion = self.complete(prompt) completion = self.complete(prompt)
return ChatResponse( return ChatResponse(
message=ChatMessage(role=MessageRole.ASSISTANT, content=completion.text) message=ChatMessage(role=MessageRole.ASSISTANT, content=completion.text)
......
...@@ -74,7 +74,11 @@ class TestHuggingFaceInferenceAPI: ...@@ -74,7 +74,11 @@ class TestHuggingFaceInferenceAPI:
def test_chat_text_generation( def test_chat_text_generation(
self, hf_inference_api: HuggingFaceInferenceAPI self, hf_inference_api: HuggingFaceInferenceAPI
) -> None: ) -> None:
mock_message_to_prompt = MagicMock(
return_value="System: You are an expert movie reviewer\nUser: Which movie is the best?\nAssistant:"
)
hf_inference_api.task = "text-generation" hf_inference_api.task = "text-generation"
hf_inference_api.messages_to_prompt = mock_message_to_prompt
messages = [ messages = [
ChatMessage( ChatMessage(
role=MessageRole.SYSTEM, content="You are an expert movie reviewer" role=MessageRole.SYSTEM, content="You are an expert movie reviewer"
...@@ -89,7 +93,8 @@ class TestHuggingFaceInferenceAPI: ...@@ -89,7 +93,8 @@ class TestHuggingFaceInferenceAPI:
return_value=conversational_return, return_value=conversational_return,
) as mock_complete: ) as mock_complete:
response = hf_inference_api.chat(messages=messages) response = hf_inference_api.chat(messages=messages)
print(response)
hf_inference_api.messages_to_prompt.assert_called_once_with(messages)
assert response.message.role == MessageRole.ASSISTANT assert response.message.role == MessageRole.ASSISTANT
assert response.message.content == conversational_return assert response.message.content == conversational_return
mock_complete.assert_called_once_with( mock_complete.assert_called_once_with(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment