diff --git a/CHANGELOG.md b/CHANGELOG.md index b538b5481538fac63df7991c193aa86c5485c38e..20890a21291d90cc18788ba530d3f83d98f3bdd6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,11 @@ - add semantic kernel document format (#8226) - Improve MyScale Hybrid Search and Add `DELETE` for MyScale vector store (#8159) +### New Features + +- Support for Hugging Face Inference API's `conversational`, `text_generation`, + and `feature_extraction` endpoints via `huggingface_hub[inference]` (#8098) + ### Bug Fixes / Nits - Fixed additional kwargs in ReActAgent.from_tools() (#8206) diff --git a/llama_index/embeddings/__init__.py b/llama_index/embeddings/__init__.py index cecca22c1ec490c371c00f4cf3d6d0334d18f688..ed535de1f03215eb033f3c40dafd0cb181de82f8 100644 --- a/llama_index/embeddings/__init__.py +++ b/llama_index/embeddings/__init__.py @@ -9,31 +9,37 @@ from llama_index.embeddings.clarifai import ClarifaiEmbedding from llama_index.embeddings.elasticsearch import ElasticsearchEmbeddings from llama_index.embeddings.google import GoogleUnivSentEncoderEmbedding from llama_index.embeddings.gradient import GradientEmbedding -from llama_index.embeddings.huggingface import HuggingFaceEmbedding +from llama_index.embeddings.huggingface import ( + HuggingFaceEmbedding, + HuggingFaceInferenceAPIEmbeddings, +) from llama_index.embeddings.huggingface_optimum import OptimumEmbedding from llama_index.embeddings.huggingface_utils import DEFAULT_HUGGINGFACE_EMBEDDING_MODEL from llama_index.embeddings.instructor import InstructorEmbedding from llama_index.embeddings.langchain import LangchainEmbedding from llama_index.embeddings.llm_rails import LLMRailsEmbeddings from llama_index.embeddings.openai import OpenAIEmbedding +from llama_index.embeddings.pooling import Pooling from llama_index.embeddings.text_embeddings_inference import TextEmbeddingsInference from llama_index.embeddings.utils import resolve_embed_model __all__ = [ - "GoogleUnivSentEncoderEmbedding", - "LangchainEmbedding", - "OpenAIEmbedding", - "LinearAdapterEmbeddingModel", "AdapterEmbeddingModel", + "ClarifaiEmbedding", + "DEFAULT_HUGGINGFACE_EMBEDDING_MODEL", + "ElasticsearchEmbeddings", + "GoogleUnivSentEncoderEmbedding", + "GradientEmbedding", + "HuggingFaceInferenceAPIEmbeddings", "HuggingFaceEmbedding", "InstructorEmbedding", + "LangchainEmbedding", + "LinearAdapterEmbeddingModel", + "LLMRailsEmbeddings", + "OpenAIEmbedding", "OptimumEmbedding", - "resolve_embed_model", - "DEFAULT_HUGGINGFACE_EMBEDDING_MODEL", + "Pooling", "SimilarityMode", - "ElasticsearchEmbeddings", - "ClarifaiEmbedding", - "GradientEmbedding", "TextEmbeddingsInference", - "LLMRailsEmbeddings", + "resolve_embed_model", ] diff --git a/llama_index/embeddings/huggingface.py b/llama_index/embeddings/huggingface.py index c5320234b05029ffa11e3394142527ba6b731349..e2e07d5aded6fb10761d3b92254b80f8e01e1843 100644 --- a/llama_index/embeddings/huggingface.py +++ b/llama_index/embeddings/huggingface.py @@ -1,13 +1,20 @@ -from typing import Any, List, Optional +import asyncio +from typing import Any, List, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager -from llama_index.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding +from llama_index.embeddings.base import ( + DEFAULT_EMBED_BATCH_SIZE, + BaseEmbedding, + Embedding, +) from llama_index.embeddings.huggingface_utils import ( DEFAULT_HUGGINGFACE_EMBEDDING_MODEL, format_query, format_text, ) +from llama_index.embeddings.pooling import Pooling 
+from llama_index.llms.huggingface import HuggingFaceInferenceAPI from llama_index.utils import get_cache_dir, infer_torch_device @@ -176,3 +183,111 @@ class HuggingFaceEmbedding(BaseEmbedding): format_text(text, self.model_name, self.text_instruction) for text in texts ] return self._embed(texts) + + +class HuggingFaceInferenceAPIEmbeddings(HuggingFaceInferenceAPI, BaseEmbedding): # type: ignore[misc] + """ + Wrapper on the Hugging Face's Inference API for embeddings. + + Overview of the design: + - Uses the feature extraction task: https://huggingface.co/tasks/feature-extraction + """ + + pooling: Optional[Pooling] = Field( + default=Pooling.CLS, + description=( + "Optional pooling technique to use with embeddings capability, if" + " the model's raw output needs pooling." + ), + ) + query_instruction: Optional[str] = Field( + default=None, + description=( + "Instruction to prepend during query embedding." + " Use of None means infer the instruction based on the model." + " Use of empty string will defeat instruction prepending entirely." + ), + ) + text_instruction: Optional[str] = Field( + default=None, + description=( + "Instruction to prepend during text embedding." + " Use of None means infer the instruction based on the model." + " Use of empty string will defeat instruction prepending entirely." + ), + ) + + @classmethod + def class_name(cls) -> str: + return "HuggingFaceInferenceAPIEmbeddings" + + async def _async_embed_single(self, text: str) -> Embedding: + embedding = (await self._async_client.feature_extraction(text)).squeeze(axis=0) + if len(embedding.shape) == 1: # Some models pool internally + return list(embedding) + try: + return list(self.pooling(embedding)) # type: ignore[misc] + except TypeError as exc: + raise ValueError( + f"Pooling is required for {self.model_name} because it returned" + " a > 1-D value, please specify pooling as not None." + ) from exc + + async def _async_embed_bulk(self, texts: Sequence[str]) -> List[Embedding]: + """ + Embed a sequence of text, in parallel and asynchronously. + + NOTE: this uses an externally created asyncio event loop. + """ + tasks = [self._async_embed_single(text) for text in texts] + return await asyncio.gather(*tasks) + + def _get_query_embedding(self, query: str) -> Embedding: + """ + Embed the input query synchronously. + + NOTE: a new asyncio event loop is created internally for this. + """ + return asyncio.run(self._aget_query_embedding(query)) + + def _get_text_embedding(self, text: str) -> Embedding: + """ + Embed the text query synchronously. + + NOTE: a new asyncio event loop is created internally for this. + """ + return asyncio.run(self._aget_text_embedding(text)) + + def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]: + """ + Embed the input sequence of text synchronously and in parallel. + + NOTE: a new asyncio event loop is created internally for this. 
+ """ + loop = asyncio.new_event_loop() + try: + tasks = [ + loop.create_task(self._aget_text_embedding(text)) for text in texts + ] + loop.run_until_complete(asyncio.wait(tasks)) + finally: + loop.close() + return [task.result() for task in tasks] + + async def _aget_query_embedding(self, query: str) -> Embedding: + return await self._async_embed_single( + text=format_query(query, self.model_name, self.query_instruction) + ) + + async def _aget_text_embedding(self, text: str) -> Embedding: + return await self._async_embed_single( + text=format_text(text, self.model_name, self.text_instruction) + ) + + async def _aget_text_embeddings(self, texts: List[str]) -> List[Embedding]: + return await self._async_embed_bulk( + texts=[ + format_text(text, self.model_name, self.text_instruction) + for text in texts + ] + ) diff --git a/llama_index/embeddings/huggingface_utils.py b/llama_index/embeddings/huggingface_utils.py index 7ab0d56e6e5b16365717cbe6edba49f25e179675..606bced13b6b616b57b14bfdc3e924b96feeffc7 100644 --- a/llama_index/embeddings/huggingface_utils.py +++ b/llama_index/embeddings/huggingface_utils.py @@ -38,20 +38,20 @@ INSTRUCTOR_MODELS = ( ) -def get_query_instruct_for_model_name(model_name: str) -> str: +def get_query_instruct_for_model_name(model_name: Optional[str]) -> str: """Get query text instruction for a given model name.""" if model_name in INSTRUCTOR_MODELS: return DEFAULT_QUERY_INSTRUCTION - elif model_name in BGE_MODELS: + if model_name in BGE_MODELS: if "zh" in model_name: return DEFAULT_QUERY_BGE_INSTRUCTION_ZH - else: - return DEFAULT_QUERY_BGE_INSTRUCTION_EN - else: - return "" + return DEFAULT_QUERY_BGE_INSTRUCTION_EN + return "" -def format_query(query: str, model_name: str, instruction: Optional[str] = None) -> str: +def format_query( + query: str, model_name: Optional[str], instruction: Optional[str] = None +) -> str: if instruction is None: instruction = get_query_instruct_for_model_name(model_name) # NOTE: strip() enables backdoor for defeating instruction prepend by @@ -59,12 +59,14 @@ def format_query(query: str, model_name: str, instruction: Optional[str] = None) return f"{instruction} {query}".strip() -def get_text_instruct_for_model_name(model_name: str) -> str: +def get_text_instruct_for_model_name(model_name: Optional[str]) -> str: """Get text instruction for a given model name.""" return DEFAULT_EMBED_INSTRUCTION if model_name in INSTRUCTOR_MODELS else "" -def format_text(text: str, model_name: str, instruction: Optional[str] = None) -> str: +def format_text( + text: str, model_name: Optional[str], instruction: Optional[str] = None +) -> str: if instruction is None: instruction = get_text_instruct_for_model_name(model_name) # NOTE: strip() enables backdoor for defeating instruction prepend by diff --git a/llama_index/embeddings/pooling.py b/llama_index/embeddings/pooling.py new file mode 100644 index 0000000000000000000000000000000000000000..08ec6f4e007c1424cba090220b6c5a1afc951f81 --- /dev/null +++ b/llama_index/embeddings/pooling.py @@ -0,0 +1,31 @@ +from enum import Enum + +import numpy as np + + +class Pooling(str, Enum): + """Enum of possible pooling choices with pooling behaviors.""" + + CLS = "cls" + MEAN = "mean" + + def __call__(self, array: np.ndarray) -> np.ndarray: + if self == self.CLS: + return self.cls_pooling(array) + return self.mean_pooling(array) + + @classmethod + def cls_pooling(cls, array: np.ndarray) -> np.ndarray: + if len(array.shape) == 3: + return array[:, 0] + if len(array.shape) == 2: + return array[0] + raise 
NotImplementedError(f"Unhandled shape {array.shape}.") + + @classmethod + def mean_pooling(cls, array: np.ndarray) -> np.ndarray: + if len(array.shape) == 3: + return array.mean(axis=1) + if len(array.shape) == 2: + return array.mean(axis=0) + raise NotImplementedError(f"Unhandled shape {array.shape}.") diff --git a/llama_index/llms/__init__.py b/llama_index/llms/__init__.py index 49cd0d6acf62d651388b062785fa7a8cb7232e46..e9c27380dfd818d0a73e96a80e1189b9921c31e0 100644 --- a/llama_index/llms/__init__.py +++ b/llama_index/llms/__init__.py @@ -18,7 +18,7 @@ from llama_index.llms.cohere import Cohere from llama_index.llms.custom import CustomLLM from llama_index.llms.everlyai import EverlyAI from llama_index.llms.gradient import GradientBaseModelLLM, GradientModelAdapterLLM -from llama_index.llms.huggingface import HuggingFaceLLM +from llama_index.llms.huggingface import HuggingFaceInferenceAPI, HuggingFaceLLM from llama_index.llms.konko import Konko from llama_index.llms.langchain import LangChainLLM from llama_index.llms.litellm import LiteLLM @@ -53,6 +53,7 @@ __all__ = [ "EverlyAI", "GradientBaseModelLLM", "GradientModelAdapterLLM", + "HuggingFaceInferenceAPI", "HuggingFaceLLM", "Konko", "LLMMetadata", diff --git a/llama_index/llms/huggingface.py b/llama_index/llms/huggingface.py index 0175a5342b8f6984f6756dcd88af240d46b01593..3705023007e4c7491e10c7a70d35dab39cf10006 100644 --- a/llama_index/llms/huggingface.py +++ b/llama_index/llms/huggingface.py @@ -1,16 +1,20 @@ import logging from threading import Thread -from typing import Any, Callable, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager +from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS +from llama_index.llms import ChatResponseAsyncGen, CompletionResponseAsyncGen from llama_index.llms.base import ( + LLM, ChatMessage, ChatResponse, ChatResponseGen, CompletionResponse, CompletionResponseGen, LLMMetadata, + MessageRole, llm_chat_callback, llm_completion_callback, ) @@ -24,6 +28,17 @@ from llama_index.llms.generic_utils import ( ) from llama_index.prompts.base import PromptTemplate +if TYPE_CHECKING: + try: + from huggingface_hub import AsyncInferenceClient, InferenceClient + from huggingface_hub.hf_api import ModelInfo + from huggingface_hub.inference._types import ConversationalOutput + except ModuleNotFoundError: + AsyncInferenceClient = Any + InferenceClient = Any + ConversationalOutput = dict + ModelInfo = Any + logger = logging.getLogger(__name__) @@ -124,7 +139,7 @@ class HuggingFaceLLM(CustomLLM): except ImportError as exc: raise ImportError( f"{type(self).__name__} requires torch and transformers packages.\n" - f"Please install both with `pip install transformers[torch]`." + "Please install both with `pip install transformers[torch]`." 
) from exc model_kwargs = model_kwargs or {} @@ -294,3 +309,232 @@ class HuggingFaceLLM(CustomLLM): prompt = self._messages_to_prompt(messages) completion_response = self.stream_complete(prompt, formatted=True, **kwargs) return stream_completion_response_to_chat_response(completion_response) + + +def chat_messages_to_conversational_kwargs( + messages: Sequence[ChatMessage], +) -> Dict[str, Any]: + """Convert ChatMessages to keyword arguments for Inference API conversational.""" + if len(messages) % 2 != 1: + raise NotImplementedError("Messages passed in must be of odd length.") + last_message = messages[-1] + kwargs: Dict[str, Any] = { + "text": last_message.content, + **last_message.additional_kwargs, + } + if len(messages) != 1: + kwargs["past_user_inputs"] = [] + kwargs["generated_responses"] = [] + for user_msg, assistant_msg in zip(messages[::2], messages[1::2]): + if ( + user_msg.role != MessageRole.USER + or assistant_msg.role != MessageRole.ASSISTANT + ): + raise NotImplementedError( + "Didn't handle when messages aren't ordered in alternating" + f" pairs of {(MessageRole.USER, MessageRole.ASSISTANT)}." + ) + kwargs["past_user_inputs"].append(user_msg.content) + kwargs["generated_responses"].append(assistant_msg.content) + return kwargs + + +class HuggingFaceInferenceAPI(LLM): + """ + Wrapper on the Hugging Face's Inference API. + + Overview of the design: + - Synchronous uses InferenceClient, asynchronous uses AsyncInferenceClient + - chat uses the conversational task: https://huggingface.co/tasks/conversational + - complete uses the text generation task: https://huggingface.co/tasks/text-generation + + Note: some models that support the text generation task can leverage Hugging + Face's optimized deployment toolkit called text-generation-inference (TGI). + Use InferenceClient.get_model_status to check if TGI is being used. + + Relevant links: + - General Docs: https://huggingface.co/docs/api-inference/index + - API Docs: https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client + - Source: https://github.com/huggingface/huggingface_hub/tree/main/src/huggingface_hub/inference + """ + + @classmethod + def class_name(cls) -> str: + return "HuggingFaceInferenceAPI" + + # Corresponds with huggingface_hub.InferenceClient + model_name: Optional[str] = Field( + default=None, + description=( + "The model to run inference with. Can be a model id hosted on the Hugging" + " Face Hub, e.g. bigcode/starcoder or a URL to a deployed Inference" + " Endpoint. Defaults to None, in which case a recommended model is" + " automatically selected for the task." + ), + ) + token: Union[str, bool, None] = Field( + default=None, + description=( + "Hugging Face token. Will default to the locally saved token. Pass " + "token=False if you don’t want to send your token to the server." + ), + ) + timeout: Optional[float] = Field( + default=None, + description=( + "The maximum number of seconds to wait for a response from the server." + " Loading a new model in Inference API can take up to several minutes." + " Defaults to None, meaning it will loop until the server is available." + ), + ) + headers: Dict[str, str] = Field( + default=None, + description=( + "Additional headers to send to the server. By default only the" + " authorization and user-agent headers are sent. Values in this dictionary" + " will override the default values." + ), + ) + cookies: Dict[str, str] = Field( + default=None, description="Additional cookies to send to the server." 
+ ) + _sync_client: "InferenceClient" = PrivateAttr() + _async_client: "AsyncInferenceClient" = PrivateAttr() + _get_model_info: "Callable[..., ModelInfo]" = PrivateAttr() + + context_window: int = Field( + default=DEFAULT_CONTEXT_WINDOW, + description=( + LLMMetadata.__fields__["context_window"].field_info.description + + " This may be looked up in a model's `config.json`." + ), + ) + num_output: int = Field( + default=DEFAULT_NUM_OUTPUTS, + description=LLMMetadata.__fields__["num_output"].field_info.description, + ) + is_chat_model: bool = Field( + default=False, + description=( + LLMMetadata.__fields__["is_chat_model"].field_info.description + + " Unless chat templating is intentionally applied, Hugging Face models" + " are not chat models." + ), + ) + is_function_calling_model: bool = Field( + default=False, + description=( + LLMMetadata.__fields__["is_function_calling_model"].field_info.description + + " As of 10/17/2023, Hugging Face doesn't support function calling" + " messages." + ), + ) + + def _get_inference_client_kwargs(self) -> Dict[str, Any]: + """Extract the Hugging Face InferenceClient construction parameters.""" + return { + "model": self.model_name, + "token": self.token, + "timeout": self.timeout, + "headers": self.headers, + "cookies": self.cookies, + } + + def __init__(self, **kwargs: Any) -> None: + """Initialize. + + Args: + kwargs: See the class-level Fields. + """ + super().__init__(**kwargs) # Populate pydantic Fields + try: + from huggingface_hub import ( + AsyncInferenceClient, + InferenceClient, + model_info, + ) + except ModuleNotFoundError as exc: + raise ImportError( + f"{type(self).__name__} requires huggingface_hub with its inference" + " extras, please run `pip install huggingface_hub[inference]`." + ) from exc + self._sync_client = InferenceClient(**self._get_inference_client_kwargs()) + self._async_client = AsyncInferenceClient(**self._get_inference_client_kwargs()) + self._get_model_info = model_info + + def validate_supported(self, task: str) -> None: + """ + Confirm the contained model_name is deployed on the Inference API service. + + Args: + task: Hugging Face task to check within. A list of all tasks can be + found here: https://huggingface.co/tasks + """ + all_models = self._sync_client.list_deployed_models(frameworks="all") + try: + if self.model_name not in all_models[task]: + raise ValueError( + "The Inference API service doesn't have the model" + f" {self.model_name!r} deployed." + ) + except KeyError as exc: + raise KeyError( + f"Input task {task!r} not in possible tasks {list(all_models.keys())}." 
+ ) from exc + + def get_model_info(self, **kwargs: Any) -> "ModelInfo": + """Get metadata on the current model from Hugging Face.""" + return self._get_model_info(self.model_name, **kwargs) + + @property + def metadata(self) -> LLMMetadata: + return LLMMetadata( + context_window=self.context_window, + num_output=self.num_output, + is_chat_model=self.is_chat_model, + is_function_calling_model=self.is_function_calling_model, + model_name=self.model_name, + ) + + def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: + output: "ConversationalOutput" = self._sync_client.conversational( + **{**chat_messages_to_conversational_kwargs(messages), **kwargs} + ) + return ChatResponse( + message=ChatMessage( + role=MessageRole.ASSISTANT, content=output["generated_text"] + ) + ) + + def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: + return CompletionResponse( + text=self._sync_client.text_generation( + prompt, **{**{"max_new_tokens": self.num_output}, **kwargs} + ) + ) + + def stream_chat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponseGen: + raise NotImplementedError + + def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen: + raise NotImplementedError + + async def achat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponse: + raise NotImplementedError + + async def acomplete(self, prompt: str, **kwargs: Any) -> CompletionResponse: + raise NotImplementedError + + async def astream_chat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponseAsyncGen: + raise NotImplementedError + + async def astream_complete( + self, prompt: str, **kwargs: Any + ) -> CompletionResponseAsyncGen: + raise NotImplementedError diff --git a/tests/embeddings/test_huggingface.py b/tests/embeddings/test_huggingface.py new file mode 100644 index 0000000000000000000000000000000000000000..952d9be839f4372d7b9e9178d48d80a66d385e67 --- /dev/null +++ b/tests/embeddings/test_huggingface.py @@ -0,0 +1,73 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import numpy as np +import pytest +from llama_index.embeddings.huggingface import HuggingFaceInferenceAPIEmbeddings +from llama_index.embeddings.pooling import Pooling + +from tests.llms.test_huggingface import STUB_MODEL_NAME + + +@pytest.fixture(name="hf_inference_api_embeddings") +def fixture_hf_inference_api_embeddings() -> HuggingFaceInferenceAPIEmbeddings: + with patch.dict("sys.modules", huggingface_hub=MagicMock()): + return HuggingFaceInferenceAPIEmbeddings(model_name=STUB_MODEL_NAME) + + +class TestHuggingFaceInferenceAPIEmbeddings: + def test_class_name( + self, hf_inference_api_embeddings: HuggingFaceInferenceAPIEmbeddings + ) -> None: + assert ( + HuggingFaceInferenceAPIEmbeddings.class_name() + == HuggingFaceInferenceAPIEmbeddings.__name__ + ) + assert ( + hf_inference_api_embeddings.class_name() + == HuggingFaceInferenceAPIEmbeddings.__name__ + ) + + def test_embed_query( + self, hf_inference_api_embeddings: HuggingFaceInferenceAPIEmbeddings + ) -> None: + raw_single_embedding = np.random.rand(1, 3, 1024) + + hf_inference_api_embeddings.pooling = Pooling.CLS + with patch.object( + hf_inference_api_embeddings._async_client, + "feature_extraction", + AsyncMock(return_value=raw_single_embedding), + ) as mock_feature_extraction: + embedding = hf_inference_api_embeddings.get_query_embedding("test") + assert isinstance(embedding, list) + assert len(embedding) == 1024 + assert np.all( + np.array(embedding, 
dtype=raw_single_embedding.dtype) + == raw_single_embedding[0, 0] + ) + mock_feature_extraction.assert_awaited_once_with("test") + + hf_inference_api_embeddings.pooling = Pooling.MEAN + with patch.object( + hf_inference_api_embeddings._async_client, + "feature_extraction", + AsyncMock(return_value=raw_single_embedding), + ) as mock_feature_extraction: + embedding = hf_inference_api_embeddings.get_query_embedding("test") + assert isinstance(embedding, list) + assert len(embedding) == 1024 + assert np.all( + np.array(embedding, dtype=raw_single_embedding.dtype) + == raw_single_embedding[0].mean(axis=0) + ) + mock_feature_extraction.assert_awaited_once_with("test") + + def test_serialization( + self, hf_inference_api_embeddings: HuggingFaceInferenceAPIEmbeddings + ) -> None: + serialized = hf_inference_api_embeddings.to_dict() + # Check Hugging Face Inference API base class specifics + assert serialized["model_name"] == STUB_MODEL_NAME + assert isinstance(serialized["context_window"], int) + # Check Hugging Face Inference API Embeddings derived class specifics + assert serialized["pooling"] == Pooling.CLS diff --git a/tests/llms/test_huggingface.py b/tests/llms/test_huggingface.py new file mode 100644 index 0000000000000000000000000000000000000000..8c6eebbb61399cc9fac9915378c870caa98d5371 --- /dev/null +++ b/tests/llms/test_huggingface.py @@ -0,0 +1,83 @@ +from unittest.mock import MagicMock, patch + +import pytest +from llama_index.llms import ChatMessage, MessageRole +from llama_index.llms.huggingface import HuggingFaceInferenceAPI + +STUB_MODEL_NAME = "placeholder_model" + + +@pytest.fixture(name="hf_inference_api") +def fixture_hf_inference_api() -> HuggingFaceInferenceAPI: + with patch.dict("sys.modules", huggingface_hub=MagicMock()): + return HuggingFaceInferenceAPI(model_name=STUB_MODEL_NAME) + + +class TestHuggingFaceInferenceAPI: + def test_class_name(self, hf_inference_api: HuggingFaceInferenceAPI) -> None: + assert HuggingFaceInferenceAPI.class_name() == HuggingFaceInferenceAPI.__name__ + assert hf_inference_api.class_name() == HuggingFaceInferenceAPI.__name__ + + def test_instantiation(self) -> None: + mock_hub = MagicMock() + with patch.dict("sys.modules", huggingface_hub=mock_hub): + llm = HuggingFaceInferenceAPI(model_name=STUB_MODEL_NAME) + + assert llm.model_name == STUB_MODEL_NAME + + # Check can be both a large language model and an embedding model + assert isinstance(llm, HuggingFaceInferenceAPI) + + # Confirm Clients are instantiated correctly + mock_hub.InferenceClient.assert_called_once_with( + model=STUB_MODEL_NAME, token=None, timeout=None, headers=None, cookies=None + ) + mock_hub.AsyncInferenceClient.assert_called_once_with( + model=STUB_MODEL_NAME, token=None, timeout=None, headers=None, cookies=None + ) + + def test_chat(self, hf_inference_api: HuggingFaceInferenceAPI) -> None: + messages = [ + ChatMessage(content="Which movie is the best?"), + ChatMessage(content="It's Die Hard for sure.", role=MessageRole.ASSISTANT), + ChatMessage(content="Can you explain why?"), + ] + generated_response = ( + " It's based on the book of the same name by James Fenimore Cooper." 
+ ) + conversational_return = { + "generated_text": generated_response, + "conversation": { + "generated_responses": ["It's Die Hard for sure.", generated_response], + "past_user_inputs": [ + "Which movie is the best?", + "Can you explain why?", + ], + }, + } + + with patch.object( + hf_inference_api._sync_client, + "conversational", + return_value=conversational_return, + ) as mock_conversational: + response = hf_inference_api.chat(messages=messages) + assert response.message.role == MessageRole.ASSISTANT + assert response.message.content == generated_response + mock_conversational.assert_called_once_with( + text="Can you explain why?", + past_user_inputs=["Which movie is the best?"], + generated_responses=["It's Die Hard for sure."], + ) + + def test_complete(self, hf_inference_api: HuggingFaceInferenceAPI) -> None: + prompt = "My favorite color is " + generated_text = '"green" and I love to paint. I have been painting for 30 years and have been' + with patch.object( + hf_inference_api._sync_client, + "text_generation", + return_value=generated_text, + ) as mock_text_generation: + response = hf_inference_api.complete(prompt) + mock_text_generation.assert_called_once_with(prompt, max_new_tokens=256) + assert response.text == generated_text
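Below is a minimal usage sketch for the two classes this patch introduces, `HuggingFaceInferenceAPI` and `HuggingFaceInferenceAPIEmbeddings`. It is illustrative only: it assumes `huggingface_hub[inference]` is installed and that a valid Hugging Face token is available; the LLM model ID reuses the `bigcode/starcoder` example from the `model_name` field description, while the embedding model ID and the token value are arbitrary placeholders, not anything mandated by the patch.

```python
# Usage sketch (assumptions: `pip install huggingface_hub[inference]` has been run,
# HF_TOKEN holds a real Hugging Face token, and the model IDs are placeholders for
# any models deployed on the Inference API for the relevant tasks).
from llama_index.embeddings import HuggingFaceInferenceAPIEmbeddings, Pooling
from llama_index.llms import ChatMessage, HuggingFaceInferenceAPI, MessageRole

HF_TOKEN = "hf_..."  # placeholder token

# LLM usage: `complete` maps to the text-generation task, `chat` to conversational.
llm = HuggingFaceInferenceAPI(model_name="bigcode/starcoder", token=HF_TOKEN)
completion = llm.complete("def fibonacci(n):")
print(completion.text)

# Messages must end with a user message (odd length), alternating user/assistant.
chat_response = llm.chat(
    messages=[ChatMessage(role=MessageRole.USER, content="Which movie is the best?")]
)
print(chat_response.message.content)

# Embeddings usage: maps to the feature-extraction task. Pooling is only applied
# when the model returns an un-pooled (2-D) output.
embed_model = HuggingFaceInferenceAPIEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",  # placeholder model ID
    token=HF_TOKEN,
    pooling=Pooling.CLS,
)
query_embedding = embed_model.get_query_embedding("What is the capital of France?")
text_embedding = embed_model.get_text_embedding("Paris is the capital of France.")
```

Note that the synchronous embedding methods create a fresh asyncio event loop internally (per their docstrings), so they should not be called from code already running inside an event loop; the async counterparts such as `aget_query_embedding` are the appropriate entry points there.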