From 57edd0367b4055654ea778655aa1499f95723678 Mon Sep 17 00:00:00 2001
From: Janaka Abeywardhana <contact@janaka.co.uk>
Date: Tue, 19 Mar 2024 15:27:35 +0000
Subject: [PATCH] fix(OptimumEmbedding): removing `token_type_ids` causing ONNX
 validation error (#12015)

---
 .../embeddings/huggingface_optimum/base.py       |  3 ---
 .../pyproject.toml                               |  2 +-
 .../tests/test_embeddings_huggingface_optimum.py | 16 ++++++++++++++++
 3 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/llama_index/embeddings/huggingface_optimum/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/llama_index/embeddings/huggingface_optimum/base.py
index 72f0f1aca5..629f872ab0 100644
--- a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/llama_index/embeddings/huggingface_optimum/base.py
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/llama_index/embeddings/huggingface_optimum/base.py
@@ -133,9 +133,6 @@ class OptimumEmbedding(BaseEmbedding):
             return_tensors="pt",
         )
 
-        # pop token_type_ids
-        encoded_input.pop("token_type_ids", None)
-
         model_output = self._model(**encoded_input)
 
         if self.pooling == "cls":
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/pyproject.toml b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/pyproject.toml
index 8d39afb840..3f2e02f390 100644
--- a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/pyproject.toml
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-embeddings-huggingface-optimum"
 readme = "README.md"
-version = "0.1.3"
+version = "0.1.4"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/tests/test_embeddings_huggingface_optimum.py b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/tests/test_embeddings_huggingface_optimum.py
index dba6f36fd2..5c1b57df5a 100644
--- a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/tests/test_embeddings_huggingface_optimum.py
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-optimum/tests/test_embeddings_huggingface_optimum.py
@@ -5,3 +5,19 @@ from llama_index.embeddings.huggingface_optimum import OptimumEmbedding
 def test_optimumembedding_class():
     names_of_base_classes = [b.__name__ for b in OptimumEmbedding.__mro__]
     assert BaseEmbedding.__name__ in names_of_base_classes
+
+
+def test_optimumembedding_get_text_embedding(tmp_path):
+    model_dir = str(tmp_path / "models/bge_onnx")
+    OptimumEmbedding.create_and_save_optimum_model("BAAI/bge-small-en-v1.5", model_dir)
+    embed_model = OptimumEmbedding(folder_name=model_dir)
+    embeddings = embed_model.get_text_embedding("Hello World!")
+
+    assert len(embeddings) == 384
+    assert embeddings[:5] == [
+        -0.0032756966538727283,
+        -0.011690770275890827,
+        0.04155917093157768,
+        -0.038148097693920135,
+        0.024183034896850586,
+    ]
-- 
GitLab