Skip to content
Snippets Groups Projects
Unverified Commit 985b3e96 authored by Stefano Lottini, committed by GitHub
Browse files

[Astra DB] Use the "indexing" API option to unlock storage of long document...

[Astra DB] Use the "indexing" API option to unlock storage of long document texts in vector store (#10423)
parent 23af82e3
No related branches found
No related tags found
No related merge requests found
......@@ -8,6 +8,7 @@ powered by the astrapy library
import json
import logging
from typing import Any, Dict, List, Optional, cast
from warnings import warn
from llama_index.indices.query.embedding_utils import get_top_k_mmr_embeddings
from llama_index.schema import BaseNode, MetadataMode
......@@ -31,6 +32,8 @@ _logger = logging.getLogger(__name__)
DEFAULT_MMR_PREFETCH_FACTOR = 4.0
MAX_INSERT_BATCH_SIZE = 20
NON_INDEXED_FIELDS = ["metadata._node_content", "content"]
class AstraDBVectorStore(VectorStore):
"""
......@@ -89,10 +92,67 @@ class AstraDBVectorStore(VectorStore):
api_endpoint=api_endpoint, token=token, namespace=namespace
)
# Create and connect to the newly created collection
self._astra_db_collection = self._astra_db.create_collection(
collection_name=collection_name, dimension=embedding_dimension
)
from astrapy.api import APIRequestError
try:
# Create and connect to the newly created collection
self._astra_db_collection = self._astra_db.create_collection(
collection_name=collection_name,
dimension=embedding_dimension,
options={"indexing": {"deny": NON_INDEXED_FIELDS}},
)
except APIRequestError as e:
# possibly the collection is preexisting and has legacy
# indexing settings: verify
get_coll_response = self._astra_db.get_collections(
options={"explain": True}
)
collections = (get_coll_response["status"] or {}).get("collections") or []
preexisting = [
collection
for collection in collections
if collection["name"] == collection_name
]
if preexisting:
pre_collection = preexisting[0]
# if it has no "indexing", it is a legacy collection;
# otherwise its settings are unexpected: warn and proceed at the user's risk
pre_col_options = pre_collection.get("options") or {}
if "indexing" not in pre_col_options:
warn(
(
f"Collection '{collection_name}' is detected as legacy"
" and has indexing turned on for all fields. This"
" implies stricter limitations on the amount of text"
" each entry can store. Consider reindexing anew on a"
" fresh collection to be able to store longer texts."
),
UserWarning,
stacklevel=2,
)
self._astra_db_collection = self._astra_db.collection(
collection_name=collection_name,
)
else:
options_json = json.dumps(pre_col_options["indexing"])
warn(
(
f"Collection '{collection_name}' has unexpected 'indexing'"
f" settings (options.indexing = {options_json})."
" This can result in odd behaviour when running "
" metadata filtering and/or unwarranted limitations"
" on storing long texts. Consider reindexing anew on a"
" fresh collection."
),
UserWarning,
stacklevel=2,
)
self._astra_db_collection = self._astra_db.collection(
collection_name=collection_name,
)
else:
# other exception
raise
def add(
self,
......
import unittest
import os
from typing import Iterable
import pytest
from llama_index.schema import NodeRelationship, RelatedNodeInfo, TextNode
......@@ -15,45 +16,54 @@ except ImportError:
has_astrapy = False
def get_astra_db_store() -> AstraDBVectorStore:
return AstraDBVectorStore(
token="AstraCS:<...>",
api_endpoint=f"https://<...>",
# env variables
ASTRA_DB_APPLICATION_TOKEN = os.getenv("ASTRA_DB_APPLICATION_TOKEN", "")
ASTRA_DB_API_ENDPOINT = os.getenv("ASTRA_DB_API_ENDPOINT", "")
@pytest.fixture(scope="module")
def astra_db_store() -> Iterable[AstraDBVectorStore]:
    """Yield a module-scoped vector store, then delete its collection.

    The store is built from the module-level ASTRA_DB_APPLICATION_TOKEN and
    ASTRA_DB_API_ENDPOINT values and targets a scratch collection that is
    dropped once every test in the module has finished.
    """
    collection = "test_collection"
    vector_store = AstraDBVectorStore(
        token=ASTRA_DB_APPLICATION_TOKEN,
        api_endpoint=ASTRA_DB_API_ENDPOINT,
        collection_name=collection,
        embedding_dimension=2,
        namespace="default_keyspace",
        ttl_seconds=123,
    )
    yield vector_store
    # Teardown: remove the collection used by this module's tests.
    vector_store._astra_db.delete_collection(collection)
class TestAstraDBVectorStore(unittest.TestCase):
@pytest.mark.skipif(not has_astrapy, reason="astrapy not installed")
def test_astra_db_create_and_crud(self) -> None:
vector_store = get_astra_db_store()
vector_store.add(
[
TextNode(
text="test node text",
id_="test node id",
relationships={
NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test doc id")
},
embedding=[0.5, 0.5],
)
]
)
vector_store.delete("test node id")
@pytest.mark.skipif(not has_astrapy, reason="astrapy not installed")
@pytest.mark.skipif(
ASTRA_DB_APPLICATION_TOKEN == "" or ASTRA_DB_API_ENDPOINT == "",
reason="missing Astra DB credentials",
)
def test_astra_db_create_and_crud(astra_db_store: AstraDBVectorStore) -> None:
astra_db_store.add(
[
TextNode(
text="test node text",
id_="test node id",
relationships={
NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test doc id")
},
embedding=[0.5, 0.5],
)
]
)
vector_store.client
astra_db_store.delete("test node id")
@pytest.mark.skipif(not has_astrapy, reason="astrapy not installed")
def test_astra_db_queries(self) -> None:
vector_store = get_astra_db_store()
query = VectorStoreQuery(query_embedding=[1, 1], similarity_top_k=3)
@pytest.mark.skipif(not has_astrapy, reason="astrapy not installed")
@pytest.mark.skipif(
ASTRA_DB_APPLICATION_TOKEN == "" or ASTRA_DB_API_ENDPOINT == "",
reason="missing Astra DB credentials",
)
def test_astra_db_queries(astra_db_store: AstraDBVectorStore) -> None:
query = VectorStoreQuery(query_embedding=[1, 1], similarity_top_k=3)
vector_store.query(
query,
)
astra_db_store.query(
query,
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment