Unverified commit d78f5cb0, authored by Justin Hu, committed by GitHub
key value store for elasticsearch (#12068)

parent cafbdfdd
Showing 567 additions and 0 deletions
.gitignore
llama_index/_static
.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
bin/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
etc/
include/
lib/
lib64/
parts/
sdist/
share/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
.ruff_cache
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
notebooks/
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
pyvenv.cfg
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Jetbrains
.idea
modules/
*.swp
# VsCode
.vscode
# pipenv
Pipfile
Pipfile.lock
# pyright
pyrightconfig.json
BUILD
poetry_requirements(
    name="poetry",
)
Makefile
GIT_ROOT ?= $(shell git rev-parse --show-toplevel)

help:	## Show all Makefile targets.
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'

format:	## Run code autoformatters (black).
	pre-commit install
	git ls-files | xargs pre-commit run black --files

lint:	## Run linters: pre-commit (black, ruff, codespell) and mypy
	pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files

test:	## Run tests via pytest.
	pytest tests

watch-docs:	## Build and watch documentation.
	sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
README.md
# LlamaIndex Kvstore Integration: Elasticsearch Kvstore
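A minimal usage sketch (assuming a reachable Elasticsearch at `http://localhost:9200`; the URL and index name below are illustrative):

```python
from llama_index.storage.kvstore.elasticsearch import ElasticsearchKVStore

# Build a store from explicit connection details; alternatively, pass a
# pre-existing AsyncElasticsearch instance via `es_client`.
kvstore = ElasticsearchKVStore(
    index_name="llama_index_kvstore",  # illustrative name
    es_client=None,
    es_url="http://localhost:9200",  # assumed local instance
)

# Synchronous wrappers around the async client; each collection is backed
# by its own Elasticsearch index.
kvstore.put("doc_1", {"text": "hello world"})
print(kvstore.get("doc_1"))  # -> {'text': 'hello world'}
kvstore.delete("doc_1")
```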
llama_index/storage/kvstore/elasticsearch/__init__.py
from llama_index.storage.kvstore.elasticsearch.base import ElasticsearchKVStore

__all__ = ["ElasticsearchKVStore"]
llama_index/storage/kvstore/elasticsearch/base.py
import asyncio
from logging import getLogger
from typing import Any, Dict, List, Optional, Tuple

import nest_asyncio

from llama_index.core.storage.kvstore.types import (
    DEFAULT_BATCH_SIZE,
    DEFAULT_COLLECTION,
    BaseKVStore,
)

logger = getLogger(__name__)

IMPORT_ERROR_MSG = (
    "`elasticsearch` package not found, please run `pip install elasticsearch`"
)

# Guard the import at module level so a missing dependency fails with a
# helpful message before any of the code below is reached.
try:
    import elasticsearch
    from elasticsearch.helpers import async_bulk
except ImportError:
    raise ImportError(IMPORT_ERROR_MSG)
def _get_elasticsearch_client(
    *,
    es_url: Optional[str] = None,
    cloud_id: Optional[str] = None,
    api_key: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
) -> elasticsearch.AsyncElasticsearch:
    """Get AsyncElasticsearch client.

    Args:
        es_url: Elasticsearch URL.
        cloud_id: Elasticsearch cloud ID.
        api_key: Elasticsearch API key.
        username: Elasticsearch username.
        password: Elasticsearch password.

    Returns:
        AsyncElasticsearch client.

    Raises:
        ConnectionError: If Elasticsearch client cannot connect to Elasticsearch.
    """
    if es_url and cloud_id:
        raise ValueError(
            "Both es_url and cloud_id are defined. Please provide only one."
        )

    connection_params: Dict[str, Any] = {}

    if es_url:
        connection_params["hosts"] = [es_url]
    elif cloud_id:
        connection_params["cloud_id"] = cloud_id
    else:
        raise ValueError("Please provide either es_url or cloud_id.")

    if api_key:
        connection_params["api_key"] = api_key
    elif username and password:
        connection_params["basic_auth"] = (username, password)

    sync_es_client = elasticsearch.Elasticsearch(
        **connection_params,
        headers={"user-agent": ElasticsearchKVStore.get_user_agent()},
    )
    async_es_client = elasticsearch.AsyncElasticsearch(**connection_params)
    try:
        # Probe with the sync client so the connectivity check does not
        # require an event loop.
        sync_es_client.info()
    except Exception as e:
        logger.error(f"Error connecting to Elasticsearch: {e}")
        raise
    finally:
        sync_es_client.close()

    return async_es_client
class ElasticsearchKVStore(BaseKVStore):
    """Elasticsearch Key-Value store.

    Args:
        index_name: Name of the Elasticsearch index.
        es_client: Optional. Pre-existing AsyncElasticsearch client.
        es_url: Optional. Elasticsearch URL.
        es_cloud_id: Optional. Elasticsearch cloud ID.
        es_api_key: Optional. Elasticsearch API key.
        es_user: Optional. Elasticsearch username.
        es_password: Optional. Elasticsearch password.

    Raises:
        ConnectionError: If AsyncElasticsearch client cannot connect to Elasticsearch.
        ValueError: If neither es_client nor es_url nor es_cloud_id is provided.
    """

    es_client: Optional[Any]
    es_url: Optional[str]
    es_cloud_id: Optional[str]
    es_api_key: Optional[str]
    es_user: Optional[str]
    es_password: Optional[str]
    def __init__(
        self,
        index_name: str,
        es_client: Optional[Any],
        es_url: Optional[str] = None,
        es_cloud_id: Optional[str] = None,
        es_api_key: Optional[str] = None,
        es_user: Optional[str] = None,
        es_password: Optional[str] = None,
    ) -> None:
        """Init an ElasticsearchKVStore."""
        # Allow the sync wrappers to re-enter an already-running event loop.
        nest_asyncio.apply()

        if es_client is not None:
            self._client = es_client.options(
                headers={"user-agent": self.get_user_agent()}
            )
        elif es_url is not None or es_cloud_id is not None:
            self._client = _get_elasticsearch_client(
                es_url=es_url,
                username=es_user,
                password=es_password,
                cloud_id=es_cloud_id,
                api_key=es_api_key,
            )
        else:
            raise ValueError(
                "Either provide a pre-existing AsyncElasticsearch client or "
                "valid credentials for creating a new connection."
            )
    @property
    def client(self) -> Any:
        """Get async elasticsearch client."""
        return self._client

    @staticmethod
    def get_user_agent() -> str:
        """Get user agent for elasticsearch client."""
        return "llama_index-py-vs"
    async def _create_index_if_not_exists(self, index_name: str) -> None:
        """Create the Elasticsearch index if it doesn't already exist.

        Args:
            index_name: Name of the Elasticsearch index to create.
        """
        if await self.client.indices.exists(index=index_name):
            logger.debug(f"Index {index_name} already exists. Skipping creation.")
        else:
            index_settings = {"mappings": {"_source": {"enabled": True}}}
            logger.debug(
                f"Creating index {index_name} with mappings {index_settings['mappings']}"
            )
            await self.client.indices.create(index=index_name, **index_settings)
    def put(
        self,
        key: str,
        val: dict,
        collection: str = DEFAULT_COLLECTION,
    ) -> None:
        """Put a key-value pair into the store.

        Args:
            key (str): key
            val (dict): value
            collection (str): collection name
        """
        self.put_all([(key, val)], collection=collection)

    async def aput(
        self,
        key: str,
        val: dict,
        collection: str = DEFAULT_COLLECTION,
    ) -> None:
        """Put a key-value pair into the store.

        Args:
            key (str): key
            val (dict): value
            collection (str): collection name
        """
        await self.aput_all([(key, val)], collection=collection)
    def put_all(
        self,
        kv_pairs: List[Tuple[str, dict]],
        collection: str = DEFAULT_COLLECTION,
        batch_size: int = DEFAULT_BATCH_SIZE,
    ) -> None:
        return asyncio.get_event_loop().run_until_complete(
            self.aput_all(kv_pairs, collection, batch_size)
        )

    async def aput_all(
        self,
        kv_pairs: List[Tuple[str, dict]],
        collection: str = DEFAULT_COLLECTION,
        batch_size: int = DEFAULT_BATCH_SIZE,
    ) -> None:
        await self._create_index_if_not_exists(collection)

        # Prepare documents with '_id' set to the key for batch insertion.
        docs = [{"_id": key, **value} for key, value in kv_pairs]

        # Insert documents in batches.
        for batch in (
            docs[i : i + batch_size] for i in range(0, len(docs), batch_size)
        ):
            requests = []
            for doc in batch:
                doc_id = doc.pop("_id")
                logger.debug(doc)
                request = {
                    "_op_type": "index",
                    "_index": collection,
                    **doc,
                    "_id": doc_id,
                }
                requests.append(request)

            await async_bulk(self.client, requests, chunk_size=batch_size, refresh=True)
    def get(self, key: str, collection: str = DEFAULT_COLLECTION) -> Optional[dict]:
        """Get a value from the store.

        Args:
            key (str): key
            collection (str): collection name
        """
        return asyncio.get_event_loop().run_until_complete(self.aget(key, collection))

    async def aget(
        self, key: str, collection: str = DEFAULT_COLLECTION
    ) -> Optional[dict]:
        """Get a value from the store.

        Args:
            key (str): key
            collection (str): collection name
        """
        await self._create_index_if_not_exists(collection)
        try:
            response = await self._client.get(index=collection, id=key, source=True)
            return response.body["_source"]
        except elasticsearch.NotFoundError:
            return None
    def get_all(self, collection: str = DEFAULT_COLLECTION) -> Dict[str, dict]:
        """Get all values from the store.

        Args:
            collection (str): collection name
        """
        return asyncio.get_event_loop().run_until_complete(self.aget_all(collection))

    async def aget_all(self, collection: str = DEFAULT_COLLECTION) -> Dict[str, dict]:
        """Get all values from the store.

        Args:
            collection (str): collection name
        """
        await self._create_index_if_not_exists(collection)

        # NOTE: Elasticsearch returns only 10 hits by default; request a larger
        # page explicitly so small stores are returned in full.
        q = {"query": {"match_all": {}}, "size": 10000}
        response = await self._client.search(index=collection, body=q, source=True)

        result = {}
        for r in response["hits"]["hits"]:
            result[r["_id"]] = r["_source"]
        return result
    def delete(self, key: str, collection: str = DEFAULT_COLLECTION) -> bool:
        """Delete a value from the store.

        Args:
            key (str): key
            collection (str): collection name
        """
        return asyncio.get_event_loop().run_until_complete(
            self.adelete(key, collection)
        )

    async def adelete(self, key: str, collection: str = DEFAULT_COLLECTION) -> bool:
        """Delete a value from the store.

        Args:
            key (str): key
            collection (str): collection name
        """
        await self._create_index_if_not_exists(collection)
        try:
            response = await self._client.delete(index=collection, id=key)
            return response.body["result"] == "deleted"
        except elasticsearch.NotFoundError:
            return False
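The synchronous methods above are thin wrappers that drive the async coroutines on the current event loop; async-first applications can await the native aput/aget/aget_all/adelete methods directly. A minimal sketch, assuming the same reachable local Elasticsearch as in the README (URL, index, and collection names are illustrative):

import asyncio

from llama_index.storage.kvstore.elasticsearch import ElasticsearchKVStore


async def main() -> None:
    kvstore = ElasticsearchKVStore(
        index_name="llama_index_kvstore",  # illustrative
        es_client=None,
        es_url="http://localhost:9200",  # assumed local instance
    )
    # Each collection is stored as its own Elasticsearch index.
    await kvstore.aput("doc_1", {"text": "hello"}, collection="my_collection")
    print(await kvstore.aget("doc_1", collection="my_collection"))
    print(await kvstore.aget_all(collection="my_collection"))
    await kvstore.adelete("doc_1", collection="my_collection")


asyncio.run(main())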
pyproject.toml
[build-system]
build-backend = "poetry.core.masonry.api"
requires = ["poetry-core"]
[tool.codespell]
check-filenames = true
check-hidden = true
skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"
[tool.llamahub]
contains_example = false
import_path = "llama_index.storage.kvstore.elasticsearch"
[tool.llamahub.class_authors]
ElasticsearchKVStore = "llama-index"
[tool.mypy]
disallow_untyped_defs = true
exclude = ["_static", "build", "examples", "notebooks", "venv"]
ignore_missing_imports = true
python_version = "3.8"
[tool.poetry]
authors = ["Your Name <you@example.com>"]
description = "llama-index kvstore elasticsearch integration"
exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-storage-kvstore-elasticsearch"
readme = "README.md"
version = "0.1.2"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
llama-index-core = "^0.10.1"
elasticsearch = "^8.12.1"
[tool.poetry.group.dev.dependencies]
ipython = "8.10.0"
jupyter = "^1.0.0"
mypy = "0.991"
pre-commit = "3.2.0"
pylint = "2.15.10"
pytest = "7.2.1"
pytest-mock = "3.11.1"
ruff = "0.0.292"
tree-sitter-languages = "^1.8.0"
types-Deprecated = ">=0.1.0"
types-PyYAML = "^6.0.12.12"
types-protobuf = "^4.24.0.4"
types-redis = "4.5.5.0"
types-requests = "2.28.11.8"
types-setuptools = "67.1.0.0"
[tool.poetry.group.dev.dependencies.black]
extras = ["jupyter"]
version = "<=23.9.1,>=23.7.0"
[tool.poetry.group.dev.dependencies.codespell]
extras = ["toml"]
version = ">=v2.2.6"
[[tool.poetry.packages]]
include = "llama_index/"
tests/BUILD
python_tests()
tests/test_kvstore_elasticsearch.py
from llama_index.core.storage.kvstore.types import BaseKVStore
from llama_index.storage.kvstore.elasticsearch import ElasticsearchKVStore


def test_class():
    names_of_base_classes = [b.__name__ for b in ElasticsearchKVStore.__mro__]
    assert BaseKVStore.__name__ in names_of_base_classes
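A possible live round-trip test, not part of this commit: it assumes a disposable Elasticsearch at http://localhost:9200 (hypothetical, for illustration only) and skips itself when none is reachable:

import pytest

from llama_index.storage.kvstore.elasticsearch import ElasticsearchKVStore


def test_put_get_delete_roundtrip() -> None:
    # Skip when no local Elasticsearch is reachable, keeping the suite hermetic.
    elasticsearch = pytest.importorskip("elasticsearch")
    client = elasticsearch.Elasticsearch("http://localhost:9200")
    try:
        client.info()
    except Exception:
        pytest.skip("no local Elasticsearch available")
    finally:
        client.close()

    kvstore = ElasticsearchKVStore(
        index_name="test_kvstore",  # illustrative
        es_client=None,
        es_url="http://localhost:9200",
    )
    kvstore.put("k1", {"v": 1}, collection="test_collection")
    assert kvstore.get("k1", collection="test_collection") == {"v": 1}
    assert kvstore.delete("k1", collection="test_collection") is True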