From 029f74339832c829cd6d8fbd8ac0a170dd5857dd Mon Sep 17 00:00:00 2001
From: Ethan Yang <ethan.yang@intel.com>
Date: Wed, 10 Apr 2024 02:09:49 +0800
Subject: [PATCH] Add OpenVINO embedding (#12643)

* first commit

update dependancy

update the model configuration

* fix CI issue
---
 docs/docs/examples/embeddings/openvino.ipynb  | 124 ++++++++++++
 .../BUILD                                     |   3 +
 .../Makefile                                  |  17 ++
 .../README.md                                 |   1 +
 .../embeddings/huggingface_openvino/BUILD     |   1 +
 .../huggingface_openvino/__init__.py          |   3 +
 .../embeddings/huggingface_openvino/base.py   | 181 ++++++++++++++++++
 .../pyproject.toml                            |  67 +++++++
 .../tests/BUILD                               |   1 +
 .../tests/__init__.py                         |   0
 .../test_embeddings_huggingface_openvino.py   |  28 +++
 11 files changed, 426 insertions(+)
 create mode 100644 docs/docs/examples/embeddings/openvino.ipynb
 create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/BUILD
 create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/Makefile
 create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/README.md
 create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/llama_index/embeddings/huggingface_openvino/BUILD
 create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/llama_index/embeddings/huggingface_openvino/__init__.py
 create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/llama_index/embeddings/huggingface_openvino/base.py
 create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/pyproject.toml
 create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/tests/BUILD
 create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/tests/__init__.py
 create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/tests/test_embeddings_huggingface_openvino.py

diff --git a/docs/docs/examples/embeddings/openvino.ipynb b/docs/docs/examples/embeddings/openvino.ipynb
new file mode 100644
index 000000000..34d704034
--- /dev/null
+++ b/docs/docs/examples/embeddings/openvino.ipynb
@@ -0,0 +1,124 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Local Embeddings with OpenVINO\n",
+    "\n",
+    "[OpenVINOâ„¢](https://github.com/openvinotoolkit/openvino) is an open-source toolkit for optimizing and deploying AI inference. The OpenVINOâ„¢ Runtime supports various hardware [devices](https://github.com/openvinotoolkit/openvino?tab=readme-ov-file#supported-hardware-matrix) including x86 and ARM CPUs, and Intel GPUs. It can help to boost deep learning performance in Computer Vision, Automatic Speech Recognition, Natural Language Processing and other common tasks.\n",
+    "\n",
+    "Hugging Face embedding model can be supported by OpenVINO through ``OpenVINOEmbedding`` class."
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "If you're opening this Notebook on colab, you will probably need to install LlamaIndex 🦙."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install llama-index-embeddings-huggingface-openvino"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install llama-index"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Model Exporter\n",
+    "\n",
+    "It is possible to export your model to the OpenVINO IR format with `create_and_save_openvino_model` function, and load the model from local folder."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llama_index.embeddings.huggingface_openvino import OpenVINOEmbedding\n",
+    "\n",
+    "OpenVINOEmbedding.create_and_save_openvino_model(\n",
+    "    \"BAAI/bge-small-en-v1.5\", \"./bge_ov\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Model Loading\n",
+    "If you have an Intel GPU, you can specify `device=\"gpu\"` to run inference on it."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ov_embed_model = OpenVINOEmbedding(folder_name=\"./bge_ov\", device=\"cpu\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "embeddings = ov_embed_model.get_text_embedding(\"Hello World!\")\n",
+    "print(len(embeddings))\n",
+    "print(embeddings[:5])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For more information refer to:\n",
+    "\n",
+    "* [OpenVINO LLM guide](https://docs.openvino.ai/2024/learn-openvino/llm_inference_guide.html).\n",
+    "\n",
+    "* [OpenVINO Documentation](https://docs.openvino.ai/2024/home.html).\n",
+    "\n",
+    "* [OpenVINO Get Started Guide](https://www.intel.com/content/www/us/en/content-details/819067/openvino-get-started-guide.html)."
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/BUILD b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/BUILD
new file mode 100644
index 000000000..0896ca890
--- /dev/null
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/BUILD
@@ -0,0 +1,3 @@
+poetry_requirements(
+    name="poetry",
+)
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/Makefile b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/Makefile
new file mode 100644
index 000000000..b9eab05aa
--- /dev/null
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/Makefile
@@ -0,0 +1,17 @@
+GIT_ROOT ?= $(shell git rev-parse --show-toplevel)
+
+help:	## Show all Makefile targets.
+	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'
+
+format:	## Run code autoformatters (black).
+	pre-commit install
+	git ls-files | xargs pre-commit run black --files
+
+lint:	## Run linters: pre-commit (black, ruff, codespell) and mypy
+	pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files
+
+test:	## Run tests via pytest.
+	pytest tests
+
+watch-docs:	## Build and watch documentation.
+	sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/README.md b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/README.md
new file mode 100644
index 000000000..d14464fb4
--- /dev/null
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/README.md
@@ -0,0 +1 @@
+# LlamaIndex Embeddings Integration: Huggingface OpenVINO
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/llama_index/embeddings/huggingface_openvino/BUILD b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/llama_index/embeddings/huggingface_openvino/BUILD
new file mode 100644
index 000000000..db46e8d6c
--- /dev/null
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/llama_index/embeddings/huggingface_openvino/BUILD
@@ -0,0 +1 @@
+python_sources()
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/llama_index/embeddings/huggingface_openvino/__init__.py b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/llama_index/embeddings/huggingface_openvino/__init__.py
new file mode 100644
index 000000000..91b1ac182
--- /dev/null
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/llama_index/embeddings/huggingface_openvino/__init__.py
@@ -0,0 +1,3 @@
+from llama_index.embeddings.huggingface_openvino.base import OpenVINOEmbedding
+
+__all__ = ["OpenVINOEmbedding"]
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/llama_index/embeddings/huggingface_openvino/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/llama_index/embeddings/huggingface_openvino/base.py
new file mode 100644
index 000000000..0a164040c
--- /dev/null
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/llama_index/embeddings/huggingface_openvino/base.py
@@ -0,0 +1,181 @@
+from typing import Any, List, Optional, Dict
+
+from llama_index.core.base.embeddings.base import (
+    DEFAULT_EMBED_BATCH_SIZE,
+    BaseEmbedding,
+)
+from llama_index.core.bridge.pydantic import Field, PrivateAttr
+from llama_index.core.callbacks import CallbackManager
+from llama_index.embeddings.huggingface.utils import format_query, format_text
+from optimum.intel.openvino import OVModelForFeatureExtraction
+from transformers import AutoTokenizer
+
+
+class OpenVINOEmbedding(BaseEmbedding):
+    folder_name: str = Field(description="Folder name to load from.")
+    max_length: int = Field(description="Maximum length of input.")
+    pooling: str = Field(description="Pooling strategy. One of ['cls', 'mean'].")
+    normalize: str = Field(default=True, description="Normalize embeddings or not.")
+    query_instruction: Optional[str] = Field(
+        description="Instruction to prepend to query text."
+    )
+    text_instruction: Optional[str] = Field(
+        description="Instruction to prepend to text."
+    )
+    cache_folder: Optional[str] = Field(
+        description="Cache folder for huggingface files."
+    )
+
+    _model: Any = PrivateAttr()
+    _tokenizer: Any = PrivateAttr()
+    _device: Any = PrivateAttr()
+
+    def __init__(
+        self,
+        folder_name: str,
+        pooling: str = "cls",
+        max_length: Optional[int] = None,
+        normalize: bool = True,
+        query_instruction: Optional[str] = None,
+        text_instruction: Optional[str] = None,
+        model: Optional[Any] = None,
+        tokenizer: Optional[Any] = None,
+        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
+        callback_manager: Optional[CallbackManager] = None,
+        model_kwargs: Dict[str, Any] = {},
+        device: Optional[str] = "auto",
+    ):
+        self._device = device
+        self._model = model or OVModelForFeatureExtraction.from_pretrained(
+            folder_name, device=self._device, **model_kwargs
+        )
+        self._tokenizer = tokenizer or AutoTokenizer.from_pretrained(folder_name)
+
+        if max_length is None:
+            try:
+                max_length = int(self._model.config.max_position_embeddings)
+            except Exception:
+                raise ValueError(
+                    "Unable to find max_length from model config. "
+                    "Please provide max_length."
+                )
+            try:
+                max_length = min(max_length, int(self._tokenizer.model_max_length))
+            except Exception as exc:
+                print(f"An error occurred while retrieving tokenizer max length: {exc}")
+
+        if pooling not in ["cls", "mean"]:
+            raise ValueError(f"Pooling {pooling} not supported.")
+
+        super().__init__(
+            embed_batch_size=embed_batch_size,
+            callback_manager=callback_manager,
+            folder_name=folder_name,
+            max_length=max_length,
+            pooling=pooling,
+            normalize=normalize,
+            query_instruction=query_instruction,
+            text_instruction=text_instruction,
+        )
+
+    @classmethod
+    def class_name(cls) -> str:
+        return "OpenVINOEmbedding"
+
+    @classmethod
+    def create_and_save_openvino_model(
+        cls,
+        model_name_or_path: str,
+        output_path: str,
+        export_kwargs: Optional[dict] = None,
+    ) -> None:
+        try:
+            from optimum.intel.openvino import OVModelForFeatureExtraction
+            from transformers import AutoTokenizer
+        except ImportError:
+            raise ImportError(
+                "OptimumEmbedding requires transformers to be installed.\n"
+                "Please install transformers with "
+                "`pip install transformers optimum[openvino]`."
+            )
+
+        export_kwargs = export_kwargs or {}
+        model = OVModelForFeatureExtraction.from_pretrained(
+            model_name_or_path, export=True, **export_kwargs
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+
+        model.save_pretrained(output_path)
+        tokenizer.save_pretrained(output_path)
+        print(
+            f"Saved OpenVINO model to {output_path}. Use it with "
+            f"`embed_model = OpenVINOEmbedding(folder_name='{output_path}')`."
+        )
+
+    def _mean_pooling(self, model_output: Any, attention_mask: Any) -> Any:
+        """Mean Pooling - Take attention mask into account for correct averaging."""
+        import torch
+
+        # First element of model_output contains all token embeddings
+        token_embeddings = model_output[0]
+        input_mask_expanded = (
+            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        )
+        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
+            input_mask_expanded.sum(1), min=1e-9
+        )
+
+    def _cls_pooling(self, model_output: list) -> Any:
+        """Use the CLS token as the pooling token."""
+        return model_output[0][:, 0]
+
+    def _embed(self, sentences: List[str]) -> List[List[float]]:
+        """Embed sentences."""
+        encoded_input = self._tokenizer(
+            sentences,
+            padding=True,
+            max_length=self.max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+
+        model_output = self._model(**encoded_input)
+
+        if self.pooling == "cls":
+            embeddings = self._cls_pooling(model_output)
+        else:
+            embeddings = self._mean_pooling(
+                model_output, encoded_input["attention_mask"]
+            )
+
+        if self.normalize:
+            import torch
+
+            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+
+        return embeddings.tolist()
+
+    def _get_query_embedding(self, query: str) -> List[float]:
+        """Get query embedding."""
+        query = format_query(query, self.model_name, self.query_instruction)
+        return self._embed([query])[0]
+
+    async def _aget_query_embedding(self, query: str) -> List[float]:
+        """Get query embedding async."""
+        return self._get_query_embedding(query)
+
+    async def _aget_text_embedding(self, text: str) -> List[float]:
+        """Get text embedding async."""
+        return self._get_text_embedding(text)
+
+    def _get_text_embedding(self, text: str) -> List[float]:
+        """Get text embedding."""
+        text = format_text(text, self.model_name, self.text_instruction)
+        return self._embed([text])[0]
+
+    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+        """Get text embeddings."""
+        texts = [
+            format_text(text, self.model_name, self.text_instruction) for text in texts
+        ]
+        return self._embed(texts)
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/pyproject.toml b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/pyproject.toml
new file mode 100644
index 000000000..72e48ebca
--- /dev/null
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/pyproject.toml
@@ -0,0 +1,67 @@
+[build-system]
+build-backend = "poetry.core.masonry.api"
+requires = ["poetry-core"]
+
+[tool.codespell]
+check-filenames = true
+check-hidden = true
+skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"
+
+[tool.llamahub]
+contains_example = false
+import_path = "llama_index.embeddings.openvino"
+
+[tool.llamahub.class_authors]
+OpenVINOEmbedding = "llama-index"
+
+[tool.mypy]
+disallow_untyped_defs = true
+exclude = ["_static", "build", "examples", "notebooks", "venv"]
+ignore_missing_imports = true
+python_version = "3.8"
+
+[tool.poetry]
+authors = ["Your Name <you@example.com>"]
+description = "llama-index embeddings openvino integration"
+exclude = ["**/BUILD"]
+license = "MIT"
+name = "llama-index-embeddings-openvino"
+readme = "README.md"
+version = "0.1.5"
+
+[tool.poetry.dependencies]
+python = ">=3.8.1,<4.0"
+llama-index-core = "^0.10.1"
+llama-index-embeddings-huggingface = "^0.1.3"
+
+[tool.poetry.dependencies.optimum]
+extras = ["openvino"]
+version = "^1.18.0"
+
+[tool.poetry.group.dev.dependencies]
+ipython = "8.10.0"
+jupyter = "^1.0.0"
+mypy = "0.991"
+pre-commit = "3.2.0"
+pylint = "2.15.10"
+pytest = "7.2.1"
+pytest-mock = "3.11.1"
+ruff = "0.0.292"
+tree-sitter-languages = "^1.8.0"
+types-Deprecated = ">=0.1.0"
+types-PyYAML = "^6.0.12.12"
+types-protobuf = "^4.24.0.4"
+types-redis = "4.5.5.0"
+types-requests = "2.28.11.8"
+types-setuptools = "67.1.0.0"
+
+[tool.poetry.group.dev.dependencies.black]
+extras = ["jupyter"]
+version = "<=23.9.1,>=23.7.0"
+
+[tool.poetry.group.dev.dependencies.codespell]
+extras = ["toml"]
+version = ">=v2.2.6"
+
+[[tool.poetry.packages]]
+include = "llama_index/"
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/tests/BUILD b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/tests/BUILD
new file mode 100644
index 000000000..dabf212d7
--- /dev/null
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/tests/BUILD
@@ -0,0 +1 @@
+python_tests()
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/tests/__init__.py b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/tests/test_embeddings_huggingface_openvino.py b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/tests/test_embeddings_huggingface_openvino.py
new file mode 100644
index 000000000..63aa8f2a3
--- /dev/null
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-openvino/tests/test_embeddings_huggingface_openvino.py
@@ -0,0 +1,28 @@
+from llama_index.core.base.embeddings.base import BaseEmbedding
+from llama_index.embeddings.huggingface_openvino import OpenVINOEmbedding
+
+
+def test_openvinoembedding_class():
+    names_of_base_classes = [b.__name__ for b in OpenVINOEmbedding.__mro__]
+    assert BaseEmbedding.__name__ in names_of_base_classes
+
+
+def test_openvinoembedding_get_text_embedding(tmp_path):
+    model_dir = str(tmp_path / "models/bge_ov")
+    OpenVINOEmbedding.create_and_save_openvino_model(
+        "BAAI/bge-small-en-v1.5", model_dir
+    )
+    embed_model = OpenVINOEmbedding(folder_name=model_dir)
+    embeddings = embed_model.get_text_embedding("Hello World!")
+
+    assert len(embeddings) == 384
+    gold_embeddings = [
+        -0.0032756966538727283,
+        -0.011690770275890827,
+        0.04155917093157768,
+        -0.038148097693920135,
+        0.024183034896850586,
+    ]
+
+    for i in range(5):
+        assert abs(embeddings[i] - gold_embeddings[i]) < 1e-4
-- 
GitLab