From 57edd0367b4055654ea778655aa1499f95723678 Mon Sep 17 00:00:00 2001 From: Janaka Abeywardhana <contact@janaka.co.uk> Date: Tue, 19 Mar 2024 15:27:35 +0000 Subject: [PATCH] fix(OptimumEmbedding): removing `token_type_ids` causing ONNX validation error (#12015) --- .../embeddings/huggingface_optimum/base.py | 3 --- .../pyproject.toml | 2 +- .../tests/test_embeddings_huggingface_optimum.py | 16 ++++++++++++++++ 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/llama_index/embeddings/huggingface_optimum/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/llama_index/embeddings/huggingface_optimum/base.py index 72f0f1aca5..629f872ab0 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/llama_index/embeddings/huggingface_optimum/base.py +++ b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/llama_index/embeddings/huggingface_optimum/base.py @@ -133,9 +133,6 @@ class OptimumEmbedding(BaseEmbedding): return_tensors="pt", ) - # pop token_type_ids - encoded_input.pop("token_type_ids", None) - model_output = self._model(**encoded_input) if self.pooling == "cls": diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/pyproject.toml b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/pyproject.toml index 8d39afb840..3f2e02f390 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/pyproject.toml +++ b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/pyproject.toml @@ -27,7 +27,7 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-embeddings-huggingface-optimum" readme = "README.md" -version = "0.1.3" +version = "0.1.4" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/tests/test_embeddings_huggingface_optimum.py b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/tests/test_embeddings_huggingface_optimum.py index dba6f36fd2..5c1b57df5a 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/tests/test_embeddings_huggingface_optimum.py +++ b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/tests/test_embeddings_huggingface_optimum.py @@ -5,3 +5,19 @@ from llama_index.embeddings.huggingface_optimum import OptimumEmbedding def test_optimumembedding_class(): names_of_base_classes = [b.__name__ for b in OptimumEmbedding.__mro__] assert BaseEmbedding.__name__ in names_of_base_classes + + +def test_optimumembedding_get_text_embedding(tmp_path): + model_dir = str(tmp_path / "models/bge_onnx") + OptimumEmbedding.create_and_save_optimum_model("BAAI/bge-small-en-v1.5", model_dir) + embed_model = OptimumEmbedding(folder_name=model_dir) + embeddings = embed_model.get_text_embedding("Hello World!") + + assert len(embeddings) == 384 + assert embeddings[:5] == [ + -0.0032756966538727283, + -0.011690770275890827, + 0.04155917093157768, + -0.038148097693920135, + 0.024183034896850586, + ] -- GitLab