Skip to content
Snippets Groups Projects
Unverified Commit 57edd036 authored by Janaka Abeywardhana's avatar Janaka Abeywardhana Committed by GitHub
Browse files

fix(OptimumEmbedding): removing `token_type_ids` causing ONNX validation error (#12015)

parent 05595b18
No related branches found
No related tags found
No related merge requests found
......@@ -133,9 +133,6 @@ class OptimumEmbedding(BaseEmbedding):
return_tensors="pt",
)
# pop token_type_ids
encoded_input.pop("token_type_ids", None)
model_output = self._model(**encoded_input)
if self.pooling == "cls":
......
......@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-embeddings-huggingface-optimum"
readme = "README.md"
version = "0.1.3"
version = "0.1.4"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
......
......@@ -5,3 +5,19 @@ from llama_index.embeddings.huggingface_optimum import OptimumEmbedding
def test_optimumembedding_class():
    """Check that BaseEmbedding's name appears among OptimumEmbedding's MRO class names."""
    # Compare by class name, walking the full method resolution order.
    mro_class_names = tuple(klass.__name__ for klass in OptimumEmbedding.__mro__)
    assert BaseEmbedding.__name__ in mro_class_names
def test_optimumembedding_get_text_embedding(tmp_path):
    """Check the embedding OptimumEmbedding produces for a fixed input string.

    Calls ``OptimumEmbedding.create_and_save_optimum_model`` for
    "BAAI/bge-small-en-v1.5" into a directory under ``tmp_path``, builds an
    ``OptimumEmbedding`` from that folder, embeds "Hello World!", and asserts
    that the result has 384 entries whose first five values match the
    recorded reference values.

    Args:
        tmp_path: Path-like base directory for the saved model (presumably
            pytest's ``tmp_path`` fixture -- confirm against the test setup).
    """
    # Local import keeps this test self-contained without touching the
    # module-level import block.
    import math

    model_dir = str(tmp_path / "models/bge_onnx")
    OptimumEmbedding.create_and_save_optimum_model("BAAI/bge-small-en-v1.5", model_dir)
    embed_model = OptimumEmbedding(folder_name=model_dir)

    embeddings = embed_model.get_text_embedding("Hello World!")

    assert len(embeddings) == 384

    expected_head = [
        -0.0032756966538727283,
        -0.011690770275890827,
        0.04155917093157768,
        -0.038148097693920135,
        0.024183034896850586,
    ]
    actual_head = embeddings[:5]
    # Exact float equality is brittle for model outputs: tiny numeric
    # differences between runtimes/platforms would fail the test even though
    # the embedding is effectively the same. Compare within a tolerance.
    for position, (actual, expected) in enumerate(zip(actual_head, expected_head)):
        assert math.isclose(actual, expected, rel_tol=0.0, abs_tol=1e-5), (
            f"embedding[{position}] = {actual!r}, expected approximately {expected!r}"
        )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.