Unverified Commit 029f7433 authored by Ethan Yang, committed by GitHub

Add OpenVINO embedding (#12643)

* first commit

update dependency

update the model configuration

* fix CI issue
parent 08684b50
Showing 426 additions and 0 deletions
%% Cell type:markdown id: tags:
# Local Embeddings with OpenVINO
[OpenVINO™](https://github.com/openvinotoolkit/openvino) is an open-source toolkit for optimizing and deploying AI inference. The OpenVINO™ Runtime supports various hardware [devices](https://github.com/openvinotoolkit/openvino?tab=readme-ov-file#supported-hardware-matrix), including x86 and ARM CPUs and Intel GPUs, and can boost deep learning performance in Computer Vision, Automatic Speech Recognition, Natural Language Processing and other common tasks.
Hugging Face embedding models can be run with OpenVINO through the ``OpenVINOEmbedding`` class.
%% Cell type:markdown id: tags:
If you're opening this notebook on Colab, you will probably need to install LlamaIndex 🦙.
%% Cell type:code id: tags:
``` python
%pip install llama-index-embeddings-huggingface-openvino
```
%% Cell type:code id: tags:
``` python
!pip install llama-index
```
%% Cell type:markdown id: tags:
## Model Exporter
You can export your model to the OpenVINO IR format with the `create_and_save_openvino_model` function and then load the model from a local folder.
%% Cell type:code id: tags:
``` python
from llama_index.embeddings.huggingface_openvino import OpenVINOEmbedding
OpenVINOEmbedding.create_and_save_openvino_model(
"BAAI/bge-small-en-v1.5", "./bge_ov"
)
```
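%% Cell type:markdown id: tags:
The `optimum-cli` tool that ships with `optimum[openvino]` offers an equivalent command-line export; a minimal sketch, assuming the package is installed in the notebook environment:
%% Cell type:code id: tags:
``` python
# Equivalent command-line export to OpenVINO IR via optimum-cli.
!optimum-cli export openvino --model BAAI/bge-small-en-v1.5 ./bge_ov
```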
%% Cell type:markdown id: tags:
## Model Loading
If you have an Intel GPU, you can specify `device="gpu"` to run inference on it.
%% Cell type:code id: tags:
``` python
ov_embed_model = OpenVINOEmbedding(folder_name="./bge_ov", device="cpu")
```
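%% Cell type:markdown id: tags:
For example, a minimal sketch of GPU loading, assuming an Intel GPU is visible to the OpenVINO runtime:
%% Cell type:code id: tags:
``` python
# Assumes an Intel GPU with working OpenVINO GPU drivers is available.
ov_gpu_embed_model = OpenVINOEmbedding(folder_name="./bge_ov", device="gpu")
```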
%% Cell type:code id: tags:
``` python
embeddings = ov_embed_model.get_text_embedding("Hello World!")
print(len(embeddings))
print(embeddings[:5])
```
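%% Cell type:markdown id: tags:
Queries are embedded the same way; `get_query_embedding` additionally applies the optional `query_instruction` prefix before encoding:
%% Cell type:code id: tags:
``` python
query_embedding = ov_embed_model.get_query_embedding("Hello World!")
print(len(query_embedding))
```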
%% Cell type:markdown id: tags:
For more information, refer to:
* [OpenVINO LLM guide](https://docs.openvino.ai/2024/learn-openvino/llm_inference_guide.html).
* [OpenVINO Documentation](https://docs.openvino.ai/2024/home.html).
* [OpenVINO Get Started Guide](https://www.intel.com/content/www/us/en/content-details/819067/openvino-get-started-guide.html).
poetry_requirements(
name="poetry",
)
GIT_ROOT ?= $(shell git rev-parse --show-toplevel)
help: ## Show all Makefile targets.
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'
format: ## Run code autoformatters (black).
pre-commit install
git ls-files | xargs pre-commit run black --files
lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy
pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files
test: ## Run tests via pytest.
pytest tests
watch-docs: ## Build and watch documentation.
sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
# LlamaIndex Embeddings Integration: Huggingface OpenVINO
from llama_index.embeddings.huggingface_openvino.base import OpenVINOEmbedding
__all__ = ["OpenVINOEmbedding"]
from typing import Any, List, Optional, Dict
from llama_index.core.base.embeddings.base import (
DEFAULT_EMBED_BATCH_SIZE,
BaseEmbedding,
)
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CallbackManager
from llama_index.embeddings.huggingface.utils import format_query, format_text
from optimum.intel.openvino import OVModelForFeatureExtraction
from transformers import AutoTokenizer
class OpenVINOEmbedding(BaseEmbedding):
folder_name: str = Field(description="Folder name to load from.")
max_length: int = Field(description="Maximum length of input.")
pooling: str = Field(description="Pooling strategy. One of ['cls', 'mean'].")
    normalize: bool = Field(default=True, description="Normalize embeddings or not.")
query_instruction: Optional[str] = Field(
description="Instruction to prepend to query text."
)
text_instruction: Optional[str] = Field(
description="Instruction to prepend to text."
)
cache_folder: Optional[str] = Field(
description="Cache folder for huggingface files."
)
_model: Any = PrivateAttr()
_tokenizer: Any = PrivateAttr()
_device: Any = PrivateAttr()
def __init__(
self,
folder_name: str,
pooling: str = "cls",
max_length: Optional[int] = None,
normalize: bool = True,
query_instruction: Optional[str] = None,
text_instruction: Optional[str] = None,
model: Optional[Any] = None,
tokenizer: Optional[Any] = None,
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
callback_manager: Optional[CallbackManager] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
device: Optional[str] = "auto",
):
self._device = device
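        # Load the exported OpenVINO IR model from the local folder;
        # `device` selects the OpenVINO inference device (e.g. "cpu", "gpu", "auto").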
        self._model = model or OVModelForFeatureExtraction.from_pretrained(
            folder_name, device=self._device, **(model_kwargs or {})
        )
self._tokenizer = tokenizer or AutoTokenizer.from_pretrained(folder_name)
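        # Infer max_length from the model config when not given explicitly,
        # then cap it at the tokenizer's own limit if one is defined.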
if max_length is None:
try:
max_length = int(self._model.config.max_position_embeddings)
except Exception:
raise ValueError(
"Unable to find max_length from model config. "
"Please provide max_length."
)
try:
max_length = min(max_length, int(self._tokenizer.model_max_length))
except Exception as exc:
print(f"An error occurred while retrieving tokenizer max length: {exc}")
if pooling not in ["cls", "mean"]:
raise ValueError(f"Pooling {pooling} not supported.")
super().__init__(
embed_batch_size=embed_batch_size,
callback_manager=callback_manager,
folder_name=folder_name,
max_length=max_length,
pooling=pooling,
normalize=normalize,
query_instruction=query_instruction,
text_instruction=text_instruction,
)
@classmethod
def class_name(cls) -> str:
return "OpenVINOEmbedding"
@classmethod
def create_and_save_openvino_model(
cls,
model_name_or_path: str,
output_path: str,
export_kwargs: Optional[dict] = None,
) -> None:
try:
from optimum.intel.openvino import OVModelForFeatureExtraction
from transformers import AutoTokenizer
except ImportError:
raise ImportError(
"OptimumEmbedding requires transformers to be installed.\n"
"Please install transformers with "
"`pip install transformers optimum[openvino]`."
)
export_kwargs = export_kwargs or {}
model = OVModelForFeatureExtraction.from_pretrained(
model_name_or_path, export=True, **export_kwargs
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model.save_pretrained(output_path)
tokenizer.save_pretrained(output_path)
print(
f"Saved OpenVINO model to {output_path}. Use it with "
f"`embed_model = OpenVINOEmbedding(folder_name='{output_path}')`."
)
def _mean_pooling(self, model_output: Any, attention_mask: Any) -> Any:
"""Mean Pooling - Take attention mask into account for correct averaging."""
import torch
# First element of model_output contains all token embeddings
token_embeddings = model_output[0]
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
)
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
input_mask_expanded.sum(1), min=1e-9
)
def _cls_pooling(self, model_output: list) -> Any:
"""Use the CLS token as the pooling token."""
return model_output[0][:, 0]
def _embed(self, sentences: List[str]) -> List[List[float]]:
"""Embed sentences."""
encoded_input = self._tokenizer(
sentences,
padding=True,
max_length=self.max_length,
truncation=True,
return_tensors="pt",
)
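        # Run inference with the OpenVINO model to get per-token embeddings.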
model_output = self._model(**encoded_input)
if self.pooling == "cls":
embeddings = self._cls_pooling(model_output)
else:
embeddings = self._mean_pooling(
model_output, encoded_input["attention_mask"]
)
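        # L2-normalize so that cosine similarity reduces to a dot product.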
if self.normalize:
import torch
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
return embeddings.tolist()
def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
query = format_query(query, self.model_name, self.query_instruction)
return self._embed([query])[0]
async def _aget_query_embedding(self, query: str) -> List[float]:
"""Get query embedding async."""
return self._get_query_embedding(query)
async def _aget_text_embedding(self, text: str) -> List[float]:
"""Get text embedding async."""
return self._get_text_embedding(text)
def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
text = format_text(text, self.model_name, self.text_instruction)
return self._embed([text])[0]
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get text embeddings."""
texts = [
format_text(text, self.model_name, self.text_instruction) for text in texts
]
return self._embed(texts)
[build-system]
build-backend = "poetry.core.masonry.api"
requires = ["poetry-core"]
[tool.codespell]
check-filenames = true
check-hidden = true
skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"
[tool.llamahub]
contains_example = false
import_path = "llama_index.embeddings.openvino"
[tool.llamahub.class_authors]
OpenVINOEmbedding = "llama-index"
[tool.mypy]
disallow_untyped_defs = true
exclude = ["_static", "build", "examples", "notebooks", "venv"]
ignore_missing_imports = true
python_version = "3.8"
[tool.poetry]
authors = ["Your Name <you@example.com>"]
description = "llama-index embeddings openvino integration"
exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-embeddings-openvino"
readme = "README.md"
version = "0.1.5"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
llama-index-core = "^0.10.1"
llama-index-embeddings-huggingface = "^0.1.3"
[tool.poetry.dependencies.optimum]
extras = ["openvino"]
version = "^1.18.0"
[tool.poetry.group.dev.dependencies]
ipython = "8.10.0"
jupyter = "^1.0.0"
mypy = "0.991"
pre-commit = "3.2.0"
pylint = "2.15.10"
pytest = "7.2.1"
pytest-mock = "3.11.1"
ruff = "0.0.292"
tree-sitter-languages = "^1.8.0"
types-Deprecated = ">=0.1.0"
types-PyYAML = "^6.0.12.12"
types-protobuf = "^4.24.0.4"
types-redis = "4.5.5.0"
types-requests = "2.28.11.8"
types-setuptools = "67.1.0.0"
[tool.poetry.group.dev.dependencies.black]
extras = ["jupyter"]
version = "<=23.9.1,>=23.7.0"
[tool.poetry.group.dev.dependencies.codespell]
extras = ["toml"]
version = ">=v2.2.6"
[[tool.poetry.packages]]
include = "llama_index/"
python_tests()
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.embeddings.huggingface_openvino import OpenVINOEmbedding
def test_openvinoembedding_class():
names_of_base_classes = [b.__name__ for b in OpenVINOEmbedding.__mro__]
assert BaseEmbedding.__name__ in names_of_base_classes
def test_openvinoembedding_get_text_embedding(tmp_path):
model_dir = str(tmp_path / "models/bge_ov")
OpenVINOEmbedding.create_and_save_openvino_model(
"BAAI/bge-small-en-v1.5", model_dir
)
embed_model = OpenVINOEmbedding(folder_name=model_dir)
embeddings = embed_model.get_text_embedding("Hello World!")
assert len(embeddings) == 384
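    # First five expected values for "Hello World!" with bge-small-en-v1.5 (384-dim output).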
gold_embeddings = [
-0.0032756966538727283,
-0.011690770275890827,
0.04155917093157768,
-0.038148097693920135,
0.024183034896850586,
]
for i in range(5):
assert abs(embeddings[i] - gold_embeddings[i]) < 1e-4