diff --git a/llama_index/vector_stores/astra.py b/llama_index/vector_stores/astra.py
index a574f38d3b2bea6e5accddceb6ccb01685f5e16a..22fa90eec749cd75df94ad3bcbc2839b3f874a1d 100644
--- a/llama_index/vector_stores/astra.py
+++ b/llama_index/vector_stores/astra.py
@@ -8,6 +8,7 @@ powered by the astrapy library
 import json
 import logging
 from typing import Any, Dict, List, Optional, cast
+from warnings import warn
 
 from llama_index.indices.query.embedding_utils import get_top_k_mmr_embeddings
 from llama_index.schema import BaseNode, MetadataMode
@@ -31,6 +32,8 @@ _logger = logging.getLogger(__name__)
 DEFAULT_MMR_PREFETCH_FACTOR = 4.0
 MAX_INSERT_BATCH_SIZE = 20
 
+NON_INDEXED_FIELDS = ["metadata._node_content", "content"]
+
 
 class AstraDBVectorStore(VectorStore):
     """
@@ -89,10 +92,67 @@ class AstraDBVectorStore(VectorStore):
             api_endpoint=api_endpoint, token=token, namespace=namespace
         )
 
-        # Create and connect to the newly created collection
-        self._astra_db_collection = self._astra_db.create_collection(
-            collection_name=collection_name, dimension=embedding_dimension
-        )
+        from astrapy.api import APIRequestError
+
+        try:
+            # Create and connect to the newly created collection
+            self._astra_db_collection = self._astra_db.create_collection(
+                collection_name=collection_name,
+                dimension=embedding_dimension,
+                options={"indexing": {"deny": NON_INDEXED_FIELDS}},
+            )
+        except APIRequestError as e:
+            # possibly the collection is preexisting and has legacy
+            # indexing settings: verify
+            get_coll_response = self._astra_db.get_collections(
+                options={"explain": True}
+            )
+            collections = (get_coll_response["status"] or {}).get("collections") or []
+            preexisting = [
+                collection
+                for collection in collections
+                if collection["name"] == collection_name
+            ]
+            if preexisting:
+                pre_collection = preexisting[0]
+                # if it has no "indexing", it is a legacy collection;
+                # otherwise it's unexpected warn and proceed at user's risk
+                pre_col_options = pre_collection.get("options") or {}
+                if "indexing" not in pre_col_options:
+                    warn(
+                        (
+                            f"Collection '{collection_name}' is detected as legacy"
+                            " and has indexing turned on for all fields. This"
+                            " implies stricter limitations on the amount of text"
+                            " each entry can store. Consider reindexing anew on a"
+                            " fresh collection to be able to store longer texts."
+                        ),
+                        UserWarning,
+                        stacklevel=2,
+                    )
+                    self._astra_db_collection = self._astra_db.collection(
+                        collection_name=collection_name,
+                    )
+                else:
+                    options_json = json.dumps(pre_col_options["indexing"])
+                    warn(
+                        (
+                            f"Collection '{collection_name}' has unexpected 'indexing'"
+                            f" settings (options.indexing = {options_json})."
+                            " This can result in odd behaviour when running "
+                            " metadata filtering and/or unwarranted limitations"
+                            " on storing long texts. Consider reindexing anew on a"
+                            " fresh collection."
+                        ),
+                        UserWarning,
+                        stacklevel=2,
+                    )
+                    self._astra_db_collection = self._astra_db.collection(
+                        collection_name=collection_name,
+                    )
+            else:
+                # other exception
+                raise
 
     def add(
         self,
diff --git a/tests/vector_stores/test_astra.py b/tests/vector_stores/test_astra.py
index 5b67cd728f53a26836ba5d825ca3f6200f42d852..f46878c61e8c4438b147fb60fd98bece3778026e 100644
--- a/tests/vector_stores/test_astra.py
+++ b/tests/vector_stores/test_astra.py
@@ -1,4 +1,5 @@
-import unittest
+import os
+from typing import Iterable
 
 import pytest
 from llama_index.schema import NodeRelationship, RelatedNodeInfo, TextNode
@@ -15,45 +16,54 @@ except ImportError:
     has_astrapy = False
 
 
-def get_astra_db_store() -> AstraDBVectorStore:
-    return AstraDBVectorStore(
-        token="AstraCS:<...>",
-        api_endpoint=f"https://<...>",
+# env variables
+ASTRA_DB_APPLICATION_TOKEN = os.getenv("ASTRA_DB_APPLICATION_TOKEN", "")
+ASTRA_DB_API_ENDPOINT = os.getenv("ASTRA_DB_API_ENDPOINT", "")
+
+
+@pytest.fixture(scope="module")
+def astra_db_store() -> Iterable[AstraDBVectorStore]:
+    store = AstraDBVectorStore(
+        token=ASTRA_DB_APPLICATION_TOKEN,
+        api_endpoint=ASTRA_DB_API_ENDPOINT,
         collection_name="test_collection",
         embedding_dimension=2,
-        namespace="default_keyspace",
-        ttl_seconds=123,
     )
+    yield store
 
+    store._astra_db.delete_collection("test_collection")
 
-class TestAstraDBVectorStore(unittest.TestCase):
-    @pytest.mark.skipif(not has_astrapy, reason="astrapy not installed")
-    def test_astra_db_create_and_crud(self) -> None:
-        vector_store = get_astra_db_store()
-
-        vector_store.add(
-            [
-                TextNode(
-                    text="test node text",
-                    id_="test node id",
-                    relationships={
-                        NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test doc id")
-                    },
-                    embedding=[0.5, 0.5],
-                )
-            ]
-        )
 
-        vector_store.delete("test node id")
+@pytest.mark.skipif(not has_astrapy, reason="astrapy not installed")
+@pytest.mark.skipif(
+    ASTRA_DB_APPLICATION_TOKEN == "" or ASTRA_DB_API_ENDPOINT == "",
+    reason="missing Astra DB credentials",
+)
+def test_astra_db_create_and_crud(astra_db_store: AstraDBVectorStore) -> None:
+    astra_db_store.add(
+        [
+            TextNode(
+                text="test node text",
+                id_="test node id",
+                relationships={
+                    NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test doc id")
+                },
+                embedding=[0.5, 0.5],
+            )
+        ]
+    )
 
-        vector_store.client
+    astra_db_store.delete("test node id")
 
-    @pytest.mark.skipif(not has_astrapy, reason="astrapy not installed")
-    def test_astra_db_queries(self) -> None:
-        vector_store = get_astra_db_store()
 
-        query = VectorStoreQuery(query_embedding=[1, 1], similarity_top_k=3)
+@pytest.mark.skipif(not has_astrapy, reason="astrapy not installed")
+@pytest.mark.skipif(
+    ASTRA_DB_APPLICATION_TOKEN == "" or ASTRA_DB_API_ENDPOINT == "",
+    reason="missing Astra DB credentials",
+)
+def test_astra_db_queries(astra_db_store: AstraDBVectorStore) -> None:
+    query = VectorStoreQuery(query_embedding=[1, 1], similarity_top_k=3)
 
-        vector_store.query(
-            query,
-        )
+    astra_db_store.query(
+        query,
+    )